use torsh_core::Shape;
use crate::metal::{
buffer::MetalBuffer,
error::{MetalError, Result},
kernels::{kernel_names, KernelManager},
ops::execute_and_wait,
};
/// Multiplies the trailing two dimensions of `a` and `b` on the Metal device.
///
/// With `a` ending in `[m, k]` and `b` ending in `[k, n]`, the returned buffer
/// has `a`'s shape with its last two dimensions replaced by `[m, n]`.
///
/// # Errors
///
/// * [`MetalError::ShapeMismatch`] with `expected: [2, 2]` and `got` holding
///   the two ranks when either input has fewer than two dimensions.
/// * [`MetalError::ShapeMismatch`] with the two inner dimensions when the last
///   dimension of `a` differs from the second-to-last dimension of `b`.
/// * Any error propagated from `matmul_kernel`.
///
/// NOTE(review): leading (batch) dimensions of `a` and `b` are not compared
/// here, and `matmul_kernel` dispatches a single `n x m` grid — confirm
/// whether inputs of rank > 2 are actually supported by the kernel.
pub fn matmul(a: &MetalBuffer, b: &MetalBuffer) -> Result<MetalBuffer> {
    let a_dims = a.shape().dims();
    let b_dims = b.shape().dims();
    let (a_rank, b_rank) = (a_dims.len(), b_dims.len());

    if a_rank < 2 || b_rank < 2 {
        return Err(MetalError::ShapeMismatch {
            expected: vec![2, 2],
            got: vec![a_rank, b_rank],
        });
    }

    // Only the trailing matrix dimensions take part in the contraction.
    let (m, k) = (a_dims[a_rank - 2], a_dims[a_rank - 1]);
    let (k_b, n) = (b_dims[b_rank - 2], b_dims[b_rank - 1]);

    if k != k_b {
        return Err(MetalError::ShapeMismatch {
            expected: vec![k],
            got: vec![k_b],
        });
    }

    matmul_kernel(a, b, m, n, k)
}
fn matmul_kernel(
a: &MetalBuffer,
b: &MetalBuffer,
m: usize,
n: usize,
k: usize,
) -> Result<MetalBuffer> {
let device = a.device();
let output_shape = {
let mut shape = a.shape().dims().to_vec();
let len = shape.len();
shape[len - 2] = m;
shape[len - 1] = n;
Shape::from(shape)
};
let output = MetalBuffer::zeros(&output_shape, &a.dtype(), device)?;
let kernel_manager = KernelManager::new(device.device_ref())?;
let dims = [m as u32, n as u32, k as u32];
let dims_buffer = device.device().new_buffer_with_data(
dims.as_ptr() as *const _,
(dims.len() * std::mem::size_of::<u32>()) as u64,
device.resource_options(),
);
execute_and_wait(device, |encoder| {
encoder.set_buffer(0, Some(a.buffer()), 0);
encoder.set_buffer(1, Some(b.buffer()), 0);
encoder.set_buffer(2, Some(output.buffer()), 0);
encoder.set_buffer(3, Some(&dims_buffer), 0);
kernel_manager.dispatch_2d(encoder, kernel_names::MATMUL_F32, n, m)
})?;
Ok(output)
}
/// Batched matrix multiplication of two rank-3 buffers.
///
/// `a` is expected as `[batch, m, k]` and `b` as `[batch, k, n]`. Each batch
/// index is processed by calling `extract_batch` on both inputs, multiplying
/// the results with [`matmul`], and passing all per-batch outputs to
/// `concatenate_batches`.
///
/// # Errors
///
/// * [`MetalError::ShapeMismatch`] with `expected: [3, 3]` and `got` holding
///   the two ranks when either input is not rank 3.
/// * [`MetalError::ShapeMismatch`] when the batch sizes differ.
/// * [`MetalError::ShapeMismatch`] when `a`'s last dimension differs from
///   `b`'s middle dimension.
/// * Any error from `extract_batch`, [`matmul`] or `concatenate_batches`;
///   in particular a batch size of zero reaches `concatenate_batches` with an
///   empty slice.
///
/// NOTE(review): the numerical result depends on `extract_batch` selecting
/// the i-th slice and `concatenate_batches` copying the per-batch results
/// into the output — verify both helpers.
pub fn bmm(a: &MetalBuffer, b: &MetalBuffer) -> Result<MetalBuffer> {
    let a_dims = a.shape().dims();
    let b_dims = b.shape().dims();

    if a_dims.len() != 3 || b_dims.len() != 3 {
        return Err(MetalError::ShapeMismatch {
            expected: vec![3, 3],
            got: vec![a_dims.len(), b_dims.len()],
        });
    }

    let batch_size = a_dims[0];
    if batch_size != b_dims[0] {
        return Err(MetalError::ShapeMismatch {
            expected: vec![batch_size],
            got: vec![b_dims[0]],
        });
    }

    let (k_a, k_b) = (a_dims[2], b_dims[1]);
    if k_a != k_b {
        return Err(MetalError::ShapeMismatch {
            expected: vec![k_a],
            got: vec![k_b],
        });
    }

    // Collecting into `Result` stops at the first failing batch.
    let outputs = (0..batch_size)
        .map(|i| {
            let a_batch = extract_batch(a, i)?;
            let b_batch = extract_batch(b, i)?;
            matmul(&a_batch, &b_batch)
        })
        .collect::<Result<Vec<_>>>()?;

    concatenate_batches(&outputs)
}
/// Produces a 2-D `[rows, cols]` view of a rank-3 `tensor` for `batch_idx`.
///
/// # Errors
///
/// Returns [`MetalError::BackendError`] ("Invalid batch index") when `tensor`
/// is not rank 3 or `batch_idx` is not smaller than its first dimension, and
/// propagates any error from `MetalBuffer::view`.
///
/// NOTE(review): the element offset of the requested batch is computed but
/// never passed to `view`, so `batch_idx` does not influence which data the
/// returned buffer refers to — confirm whether `view` needs an offset here.
fn extract_batch(tensor: &MetalBuffer, batch_idx: usize) -> Result<MetalBuffer> {
    let dims = tensor.shape().dims();

    let index_is_valid = dims.len() == 3 && batch_idx < dims[0];
    if !index_is_valid {
        return Err(MetalError::BackendError("Invalid batch index".to_string()));
    }

    let (rows, cols) = (dims[1], dims[2]);
    let matrix_shape = Shape::from(vec![rows, cols]);

    // Number of elements in one batch slice, and where this slice would start.
    let elems_per_batch = rows * cols;
    let _offset = batch_idx * elems_per_batch;

    tensor.view(&matrix_shape)
}
/// Allocates the stacked output buffer for a slice of per-batch results.
///
/// The returned buffer has shape `[batches.len(), ..shape of batches[0]]`,
/// the dtype of `batches[0]`, and lives on the device of `batches[0]`.
///
/// # Errors
///
/// Returns [`MetalError::BackendError`] ("No batches to concatenate") when
/// `batches` is empty, and propagates any error from `MetalBuffer::zeros`.
///
/// NOTE(review): the output is only allocated via `MetalBuffer::zeros`;
/// nothing in this function copies the contents of `batches` into it, and
/// only the first batch's shape and dtype are consulted.
fn concatenate_batches(batches: &[MetalBuffer]) -> Result<MetalBuffer> {
    let first = match batches.first() {
        Some(first) => first,
        None => {
            return Err(MetalError::BackendError(
                "No batches to concatenate".to_string(),
            ));
        }
    };

    // Prepend the batch count to the per-batch dimensions.
    let stacked_dims: Vec<usize> = std::iter::once(batches.len())
        .chain(first.shape().dims().iter().copied())
        .collect();

    let target_device = first.device();
    let stacked = MetalBuffer::zeros(&Shape::from(stacked_dims), &first.dtype(), target_device)?;
    Ok(stacked)
}