#![cfg(target_os = "macos")]
mod common;
use std::collections::BTreeMap;
use common::{Dt, gpu_lock, max_abs_diff, pack_bytes, unpack_bytes};
use metaltile_core::{dtype::DType, ir::Kernel};
use metaltile_runtime::Context;
use metaltile_std::mlx::binary::{
mt_atan2,
mt_div,
mt_logaddexp,
mt_max_elem,
mt_min_elem,
mt_mul,
mt_pow,
mt_sub,
};
/// Runs a two-input element-wise kernel through the runtime and returns its first `n` outputs.
///
/// `a` and `b` are packed with `dt` and bound under the buffer names `"a"` and `"b"`; a
/// zero-filled buffer of `n` elements is bound as `"out"`. The kernel comes from
/// `kernel_ir(dt.to_dtype())` and is dispatched with grid `[ceil(n / 256), 1, 1]` and
/// group size `[256, 1, 1]`. The `"out"` buffer of the result is unpacked with `dt` and
/// truncated to `n` elements.
///
/// # Panics
///
/// Panics if `a` or `b` does not hold exactly `n` elements, if the context cannot be
/// created, if the dispatch fails, or if the result carries no `"out"` buffer.
fn run_binary(kernel_ir: fn(DType) -> Kernel, a: &[f32], b: &[f32], dt: Dt, n: usize) -> Vec<f32> {
    // The grid is sized from `n` while the input buffers are sized from the slices.
    // Reject a mismatch up front instead of dispatching over inconsistent buffers.
    assert_eq!(a.len(), n, "run_binary: `a` must hold exactly n elements");
    assert_eq!(b.len(), n, "run_binary: `b` must hold exactly n elements");

    let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
    buffers.insert("a".into(), pack_bytes(a, dt));
    buffers.insert("b".into(), pack_bytes(b, dt));
    buffers.insert("out".into(), pack_bytes(&vec![0.0f32; n], dt));

    let ctx = Context::new().expect("Context::new on macOS");
    let kernel = kernel_ir(dt.to_dtype());

    // Round the group count up so the last, partially filled group is still dispatched.
    let tpg = 256usize;
    let groups = n.div_ceil(tpg);
    let result = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [groups, 1, 1], [tpg, 1, 1])
        .expect("binary dispatch");

    // Keep only the first `n` values in case the unpacked buffer is longer.
    let mut out = unpack_bytes(result.outputs.get("out").expect("out"), dt);
    out.truncate(n);
    out
}
/// Checks the `mt_mul` f32 kernel against element-wise CPU multiplication.
///
/// Runs on 1024 periodic input values with an absolute tolerance of 1e-5.
#[test]
fn binary_mul_matches_cpu_f32() {
    // Held until the end of the test; presumably serializes GPU access — confirm in `common`.
    let _gpu_guard = gpu_lock();
    let n = 1024usize;

    let lhs = (0..n).map(|i| (i % 17) as f32 * 0.05 - 0.4).collect::<Vec<f32>>();
    let rhs = (0..n).map(|i| (i % 13) as f32 * 0.04 - 0.25).collect::<Vec<f32>>();
    let reference: Vec<f32> = lhs.iter().zip(&rhs).map(|(p, q)| p * q).collect();

    let gpu = run_binary(mt_mul::kernel_ir_for, &lhs, &rhs, Dt::F32, n);

    let worst = max_abs_diff(&gpu, &reference);
    assert!(worst < 1e-5, "mul f32 mismatch");
}
/// Checks the `mt_sub` f32 kernel against element-wise CPU subtraction.
///
/// Runs on 512 periodic input values with an absolute tolerance of 1e-5.
#[test]
fn binary_sub_matches_cpu_f32() {
    // Held until the end of the test; presumably serializes GPU access — confirm in `common`.
    let _gpu_guard = gpu_lock();
    let n = 512usize;

    let minuend: Vec<f32> = (0..n).map(|idx| (idx % 19) as f32 * 0.07 - 0.6).collect();
    let subtrahend: Vec<f32> = (0..n).map(|idx| (idx % 11) as f32 * 0.05 - 0.3).collect();
    let reference: Vec<f32> = (0..n).map(|idx| minuend[idx] - subtrahend[idx]).collect();

    let gpu = run_binary(mt_sub::kernel_ir_for, &minuend, &subtrahend, Dt::F32, n);

    let worst = max_abs_diff(&gpu, &reference);
    assert!(worst < 1e-5, "sub f32 mismatch");
}
/// Checks the `mt_div` f32 kernel against element-wise CPU division.
///
/// Runs on 512 periodic input values with an absolute tolerance of 1e-4. The divisor
/// sequence starts at 0.2 and only grows within each period, so it never reaches zero.
#[test]
fn binary_div_matches_cpu_f32() {
    // Held until the end of the test; presumably serializes GPU access — confirm in `common`.
    let _gpu_guard = gpu_lock();
    let n = 512usize;

    let numer = (0..n).map(|i| (i % 17) as f32 * 0.06 - 0.4).collect::<Vec<f32>>();
    let denom = (0..n).map(|i| (i % 13) as f32 * 0.08 + 0.2).collect::<Vec<f32>>();

    let quotient = |(num, den): (&f32, &f32)| num / den;
    let reference: Vec<f32> = numer.iter().zip(denom.iter()).map(quotient).collect();

    let gpu = run_binary(mt_div::kernel_ir_for, &numer, &denom, Dt::F32, n);

    let worst = max_abs_diff(&gpu, &reference);
    assert!(worst < 1e-4, "div f32 mismatch");
}
/// Checks the `mt_max_elem` f32 kernel against the element-wise CPU maximum (`f32::max`).
///
/// Runs on 512 periodic input values with an absolute tolerance of 1e-5.
#[test]
fn binary_max_elem_matches_cpu_f32() {
    // Held until the end of the test; presumably serializes GPU access — confirm in `common`.
    let _gpu_guard = gpu_lock();
    let n = 512usize;

    let lhs: Vec<f32> = (0..n).map(|k| (k % 17) as f32 * 0.05 - 0.4).collect();
    let rhs: Vec<f32> = (0..n).map(|k| (k % 13) as f32 * 0.06 - 0.35).collect();
    let reference: Vec<f32> = lhs.iter().zip(&rhs).map(|(&p, &q)| f32::max(p, q)).collect();

    let gpu = run_binary(mt_max_elem::kernel_ir_for, &lhs, &rhs, Dt::F32, n);

    let worst = max_abs_diff(&gpu, &reference);
    assert!(worst < 1e-5, "max_elem f32 mismatch");
}
/// Checks the `mt_min_elem` f32 kernel against the element-wise CPU minimum (`f32::min`).
///
/// Runs on 512 periodic input values with an absolute tolerance of 1e-5.
#[test]
fn binary_min_elem_matches_cpu_f32() {
    // Held until the end of the test; presumably serializes GPU access — confirm in `common`.
    let _gpu_guard = gpu_lock();
    let n = 512usize;

    let lhs: Vec<f32> = (0..n).map(|k| (k % 17) as f32 * 0.05 - 0.4).collect();
    let rhs: Vec<f32> = (0..n).map(|k| (k % 13) as f32 * 0.06 - 0.35).collect();
    let reference: Vec<f32> = lhs.iter().zip(&rhs).map(|(&p, &q)| f32::min(p, q)).collect();

    let gpu = run_binary(mt_min_elem::kernel_ir_for, &lhs, &rhs, Dt::F32, n);

    let worst = max_abs_diff(&gpu, &reference);
    assert!(worst < 1e-5, "min_elem f32 mismatch");
}
/// Checks the `mt_pow` f32 kernel against CPU `f32::powf`.
///
/// Runs on 256 periodic input values with an absolute tolerance of 1e-3. Bases start at
/// 0.2 and only grow within each period, so every base is strictly positive.
#[test]
fn binary_pow_matches_cpu_f32() {
    // Held until the end of the test; presumably serializes GPU access — confirm in `common`.
    let _gpu_guard = gpu_lock();
    let n = 256usize;

    let bases = (0..n).map(|i| (i % 9) as f32 * 0.1 + 0.2).collect::<Vec<f32>>();
    let exponents = (0..n).map(|i| (i % 5) as f32 * 0.4 + 0.2).collect::<Vec<f32>>();
    let reference: Vec<f32> =
        bases.iter().zip(&exponents).map(|(&base, &exp)| f32::powf(base, exp)).collect();

    let gpu = run_binary(mt_pow::kernel_ir_for, &bases, &exponents, Dt::F32, n);

    let worst = max_abs_diff(&gpu, &reference);
    assert!(worst < 1e-3, "pow f32 mismatch");
}
/// Checks the `mt_atan2` f32 kernel against CPU `f32::atan2`.
///
/// Runs on 512 periodic input values with an absolute tolerance of 1e-4. The dispatch is
/// written out here because the inputs are bound under the buffer names `"y"` and `"x"`.
#[test]
fn binary_atan2_matches_cpu_f32() {
    const THREADS_PER_GROUP: usize = 256;

    // Held until the end of the test; presumably serializes GPU access — confirm in `common`.
    let _gpu_guard = gpu_lock();
    let n = 512usize;

    let ys = (0..n).map(|i| (i % 17) as f32 * 0.1 - 0.8).collect::<Vec<f32>>();
    let xs = (0..n).map(|i| (i % 11) as f32 * 0.1 - 0.5).collect::<Vec<f32>>();
    let reference: Vec<f32> = (0..n).map(|i| f32::atan2(ys[i], xs[i])).collect();

    // Inputs plus a zero-initialized output buffer, keyed by the names used at dispatch.
    let bindings: BTreeMap<String, Vec<u8>> = BTreeMap::from([
        ("y".to_string(), pack_bytes(&ys, Dt::F32)),
        ("x".to_string(), pack_bytes(&xs, Dt::F32)),
        ("out".to_string(), pack_bytes(&vec![0.0f32; n], Dt::F32)),
    ]);

    let ctx = Context::new().expect("Context::new");
    let kernel = mt_atan2::kernel_ir_for(DType::F32);

    // Round up so a final, partially filled group would still be dispatched.
    let group_count = n.div_ceil(THREADS_PER_GROUP);
    let grid = [group_count, 1, 1];
    let group_size = [THREADS_PER_GROUP, 1, 1];
    let dispatched = ctx
        .dispatch_with_grid(&kernel, &bindings, &BTreeMap::new(), grid, group_size)
        .expect("atan2 dispatch");

    // Keep only the first `n` unpacked values.
    let mut gpu: Vec<f32> = unpack_bytes(dispatched.outputs.get("out").expect("out"), Dt::F32);
    gpu.truncate(n);

    let worst = max_abs_diff(&gpu, &reference);
    assert!(worst < 1e-4, "atan2 f32 mismatch");
}
/// Checks the `mt_logaddexp` f32 kernel against a CPU reference `ln(exp(a) + exp(b))`.
///
/// Runs on 512 periodic input values with an absolute tolerance of 1e-3. The reference
/// uses the direct formula; the inputs stay within a small range around zero.
#[test]
fn binary_logaddexp_matches_cpu_f32() {
    // Held until the end of the test; presumably serializes GPU access — confirm in `common`.
    let _gpu_guard = gpu_lock();
    let n = 512usize;

    let lhs = (0..n).map(|i| (i % 11) as f32 * 0.3 - 1.5).collect::<Vec<f32>>();
    let rhs = (0..n).map(|i| (i % 7) as f32 * 0.4 - 1.0).collect::<Vec<f32>>();

    let log_sum_exp = |(p, q): (&f32, &f32)| f32::ln(f32::exp(*p) + f32::exp(*q));
    let reference: Vec<f32> = lhs.iter().zip(rhs.iter()).map(log_sum_exp).collect();

    let gpu = run_binary(mt_logaddexp::kernel_ir_for, &lhs, &rhs, Dt::F32, n);

    let worst = max_abs_diff(&gpu, &reference);
    assert!(worst < 1e-3, "logaddexp f32 mismatch");
}