use crate::metal::{
buffer::MetalBuffer,
error::Result,
kernels::{kernel_names, KernelManager},
ops::execute_and_wait,
};
/// Runs the `UNARY_NEG_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn neg(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_NEG_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_EXP_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn exp(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_EXP_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_LOG_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn log(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_LOG_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_SQRT_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn sqrt(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_SQRT_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_TANH_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn tanh(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_TANH_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_RELU_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn relu(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_RELU_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_ABS_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn abs(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_ABS_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_SIN_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn sin(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_SIN_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_COS_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn cos(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_COS_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_SIGMOID_F32` Metal kernel over `input` and returns the
/// result in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn sigmoid(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_SIGMOID_F32, element_count)
    })?;

    Ok(result)
}
/// Runs the `UNARY_GELU_F32` Metal kernel over `input` and returns the result
/// in a newly allocated buffer.
///
/// The output is created with `MetalBuffer::zeros` using the shape, dtype and
/// device of `input`. The input is bound at buffer index 0 and the output at
/// buffer index 1; the dispatch is sized by `input.shape().numel()`.
///
/// # Errors
///
/// Propagates any error from allocating the output buffer, constructing the
/// [`KernelManager`], or from `execute_and_wait` / the kernel dispatch.
///
/// NOTE(review): the kernel used is the F32 variant and no dtype check happens
/// here — presumably callers only pass f32 buffers; confirm.
pub fn gelu(input: &MetalBuffer) -> Result<MetalBuffer> {
    let dev = input.device();
    let result = MetalBuffer::zeros(input.shape(), &input.dtype(), dev)?;
    let kernels = KernelManager::new(dev.device_ref())?;

    execute_and_wait(dev, |enc| {
        // Index 0 carries the source buffer, index 1 the destination buffer.
        enc.set_buffer(0, Some(input.buffer()), 0);
        enc.set_buffer(1, Some(result.buffer()), 0);

        let element_count = input.shape().numel();
        kernels.dispatch_1d(enc, kernel_names::UNARY_GELU_F32, element_count)
    })?;

    Ok(result)
}