use metal::foreign_types::ForeignType;
use metal::{CommandBuffer, Device, NSUInteger};
use objc2::msg_send;
use objc2::runtime::AnyObject;
use crate::metal::{
buffer::MetalBuffer,
error::{metal_errors, MetalError, Result},
mps::{create_matrix_descriptor, MPSDataType, MPSOperation},
};
#[allow(dead_code)]
pub struct MPSMatMul {
matmul: *mut AnyObject,
output: MetalBuffer,
m: usize,
n: usize,
k: usize,
}
impl MPSMatMul {
#[allow(clippy::too_many_arguments)]
pub fn new(
device: &Device,
a: &MetalBuffer,
b: &MetalBuffer,
c: Option<&MetalBuffer>,
alpha: f32,
beta: f32,
transpose_a: bool,
transpose_b: bool,
) -> Result<Self> {
unsafe {
let a_shape = a.shape().dims();
let b_shape = b.shape().dims();
if a_shape.len() != 2 || b_shape.len() != 2 {
return Err(MetalError::ShapeMismatch {
expected: vec![2],
got: vec![a_shape.len(), b_shape.len()],
});
}
let (m, k_a) = if transpose_a {
(a_shape[1], a_shape[0])
} else {
(a_shape[0], a_shape[1])
};
let (k_b, n) = if transpose_b {
(b_shape[1], b_shape[0])
} else {
(b_shape[0], b_shape[1])
};
if k_a != k_b {
return Err(MetalError::ShapeMismatch {
expected: vec![k_a],
got: vec![k_b],
});
}
let k = k_a;
let output = if let Some(c_buffer) = c {
c_buffer.clone()
} else {
MetalBuffer::zeros(
&torsh_core::Shape::from(vec![m, n]),
&torsh_core::DType::F32,
&crate::metal::device::MetalDevice::new()?,
)?
};
let _a_desc = create_matrix_descriptor(
if transpose_a { k } else { m },
if transpose_a { m } else { k },
MPSDataType::Float32,
);
let _b_desc = create_matrix_descriptor(
if transpose_b { n } else { k },
if transpose_b { k } else { n },
MPSDataType::Float32,
);
let _c_desc = create_matrix_descriptor(m, n, MPSDataType::Float32);
let class = objc2::class!(MPSMatrixMultiplication);
let matmul: *mut AnyObject = msg_send![class, alloc];
let matmul: *mut AnyObject = msg_send![matmul,
initWithDevice: device.as_ptr() as *mut AnyObject,
transposeLeft: objc2::runtime::Bool::from(transpose_a),
transposeRight: objc2::runtime::Bool::from(transpose_b),
resultRows: m as NSUInteger,
resultColumns: n as NSUInteger,
interiorColumns: k as NSUInteger,
alpha: alpha as f64,
beta: beta as f64
];
Ok(Self {
matmul,
output,
m,
n,
k,
})
}
}
pub fn encode_matmul(
&self,
command_buffer: &CommandBuffer,
a: &MetalBuffer,
b: &MetalBuffer,
) -> Result<()> {
unsafe {
let class = objc2::class!(MPSMatrix);
let a_matrix: *mut AnyObject = msg_send![class, alloc];
let a_matrix: *mut AnyObject = msg_send![a_matrix,
initWithBuffer: a.buffer().as_ptr() as *mut AnyObject,
descriptor: create_matrix_descriptor(self.m, self.k, MPSDataType::Float32)
];
let b_matrix: *mut AnyObject = msg_send![class, alloc];
let b_matrix: *mut AnyObject = msg_send![b_matrix,
initWithBuffer: b.buffer().as_ptr() as *mut AnyObject,
descriptor: create_matrix_descriptor(self.k, self.n, MPSDataType::Float32)
];
let c_matrix: *mut AnyObject = msg_send![class, alloc];
let c_matrix: *mut AnyObject = msg_send![c_matrix,
initWithBuffer: self.output.buffer().as_ptr() as *mut AnyObject,
descriptor: create_matrix_descriptor(self.m, self.n, MPSDataType::Float32)
];
let _: () = msg_send![self.matmul,
encodeToCommandBuffer: command_buffer.as_ptr() as *mut AnyObject,
leftMatrix: a_matrix,
rightMatrix: b_matrix,
resultMatrix: c_matrix
];
Ok(())
}
}
pub fn output(&self) -> &MetalBuffer {
&self.output
}
}
impl MPSOperation for MPSMatMul {
fn encode(&self, _command_buffer: &CommandBuffer) -> Result<()> {
Err(metal_errors::kernel_execution_error(
"MPSMatMul::encode requires input buffers".to_string(),
None,
))
}
}
impl Drop for MPSMatMul {
fn drop(&mut self) {
unsafe {
let _: () = msg_send![self.matmul, release];
}
}
}