use crate::utils::{BufferOffset, EncoderProvider};
use crate::{get_tile_size, linear_split};
use crate::{set_params, Buffer, ComputeCommandEncoder, Device, Kernels, MetalKernelError, Source};
use objc2_metal::MTLResourceUsage;
/// Encodes one dispatch of the `name` kernel from the `Source::Affine` library.
///
/// The kernel is bound with, in this order: `size`, `mul`, `add`, the input
/// buffer (together with its offset) and the output buffer. The launch is
/// sized in tiles: `size` divided, rounding up, by the tile size that
/// `get_tile_size` reports for `dtype_size`.
///
/// NOTE(review): the arithmetic lives in the Metal shader, which is not
/// visible here — presumably `input * mul + add`; confirm against the kernel.
///
/// # Errors
///
/// Propagates any `MetalKernelError` returned by `Kernels::load_pipeline`.
// The argument list mirrors what the kernel launch needs; hence the lint opt-out.
#[allow(clippy::too_many_arguments)]
pub fn call_affine(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    dtype_size: usize,
    size: usize,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let pso = kernels.load_pipeline(device, Source::Affine, name)?;

    // Keep the provider's handle alive for as long as the borrowed encoder is used.
    let encoder_handle = ep.encoder();
    let enc: &ComputeCommandEncoder = encoder_handle.as_ref();
    enc.set_compute_pipeline_state(&pso);
    set_params!(enc, (size, mul, add, &input, output));

    // The launch is counted in tiles rather than elements: ceil(size / tile_size).
    let n_tiles = size.div_ceil(get_tile_size(dtype_size));
    let (groups, threads_per_group) = linear_split(&pso, n_tiles);

    enc.use_resource(input.buffer, MTLResourceUsage::Read);
    enc.use_resource(output, MTLResourceUsage::Write);
    enc.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}
/// Encodes one dispatch of the strided `name` kernel from the `Source::Affine`
/// library.
///
/// The element count is the product of `shape`. The kernel is bound with, in
/// this order: that element count, the rank (`shape.len()`), `shape`,
/// `input_stride`, `mul`, `add`, the input buffer (together with its offset)
/// and the output buffer. The launch is sized over the element count.
///
/// NOTE(review): the arithmetic lives in the Metal shader, which is not
/// visible here — presumably `input * mul + add` with `input_stride` given in
/// elements; confirm against the kernel.
///
/// # Errors
///
/// Propagates any `MetalKernelError` returned by `Kernels::load_pipeline`.
// The argument list mirrors what the kernel launch needs; hence the lint opt-out.
#[allow(clippy::too_many_arguments)]
pub fn call_affine_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: BufferOffset,
    input_stride: &[usize],
    output: &Buffer,
    mul: f32,
    add: f32,
) -> Result<(), MetalKernelError> {
    let pso = kernels.load_pipeline(device, Source::Affine, name)?;
    let elem_count: usize = shape.iter().product();
    let rank: usize = shape.len();

    // Keep the provider's handle alive for as long as the borrowed encoder is used.
    let encoder_handle = ep.encoder();
    let enc: &ComputeCommandEncoder = encoder_handle.as_ref();
    enc.set_compute_pipeline_state(&pso);
    set_params!(
        enc,
        (elem_count, rank, shape, input_stride, mul, add, &input, output)
    );

    // Strided launches are sized per element, not per tile.
    let (groups, threads_per_group) = linear_split(&pso, elem_count);

    enc.use_resource(input.buffer, MTLResourceUsage::Read);
    enc.use_resource(output, MTLResourceUsage::Write);
    enc.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}
/// Encodes one dispatch of the `name` kernel from the `Source::Affine` library,
/// using the four-argument (single scalar) parameter layout.
///
/// The kernel is bound with, in this order: `size`, `mul`, the input buffer
/// (together with its offset) and the output buffer. The launch is sized in
/// tiles: `size` divided, rounding up, by the tile size that `get_tile_size`
/// reports for `dtype_size`.
///
/// NOTE(review): the arithmetic lives in the Metal shader, which is not
/// visible here — presumably `mul` is the exponent applied to each input
/// element; confirm against the kernel.
///
/// # Errors
///
/// Propagates any `MetalKernelError` returned by `Kernels::load_pipeline`.
// The argument list mirrors what the kernel launch needs; hence the lint opt-out.
#[allow(clippy::too_many_arguments)]
pub fn call_powf(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    dtype_size: usize,
    size: usize,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pso = kernels.load_pipeline(device, Source::Affine, name)?;

    // Keep the provider's handle alive for as long as the borrowed encoder is used.
    let encoder_handle = ep.encoder();
    let enc: &ComputeCommandEncoder = encoder_handle.as_ref();
    enc.set_compute_pipeline_state(&pso);
    set_params!(enc, (size, mul, &input, output));

    // The launch is counted in tiles rather than elements: ceil(size / tile_size).
    let n_tiles = size.div_ceil(get_tile_size(dtype_size));
    let (groups, threads_per_group) = linear_split(&pso, n_tiles);

    enc.use_resource(input.buffer, MTLResourceUsage::Read);
    enc.use_resource(output, MTLResourceUsage::Write);
    enc.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}
/// Encodes one dispatch of the strided `name` kernel from the `Source::Affine`
/// library, using the single-scalar parameter layout.
///
/// The element count is the product of `shape`. The kernel is bound with, in
/// this order: that element count, the rank (`shape.len()`), `shape`,
/// `input_stride`, `mul`, the input buffer (together with its offset) and the
/// output buffer. The launch is sized over the element count.
///
/// NOTE(review): the arithmetic lives in the Metal shader, which is not
/// visible here — presumably `mul` is the exponent and `input_stride` is given
/// in elements; confirm against the kernel.
///
/// # Errors
///
/// Propagates any `MetalKernelError` returned by `Kernels::load_pipeline`.
// The argument list mirrors what the kernel launch needs; hence the lint opt-out.
#[allow(clippy::too_many_arguments)]
pub fn call_powf_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: BufferOffset,
    input_stride: &[usize],
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pso = kernels.load_pipeline(device, Source::Affine, name)?;
    let elem_count: usize = shape.iter().product();
    let rank: usize = shape.len();

    // Keep the provider's handle alive for as long as the borrowed encoder is used.
    let encoder_handle = ep.encoder();
    let enc: &ComputeCommandEncoder = encoder_handle.as_ref();
    enc.set_compute_pipeline_state(&pso);
    set_params!(
        enc,
        (elem_count, rank, shape, input_stride, mul, &input, output)
    );

    // Strided launches are sized per element, not per tile.
    let (groups, threads_per_group) = linear_split(&pso, elem_count);

    enc.use_resource(input.buffer, MTLResourceUsage::Read);
    enc.use_resource(output, MTLResourceUsage::Write);
    enc.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}
/// Encodes one dispatch of the `name` kernel from the `Source::Affine` library,
/// using the four-argument (single scalar) parameter layout.
///
/// The kernel is bound with, in this order: `size`, `mul`, the input buffer
/// (together with its offset) and the output buffer. The launch is sized in
/// tiles: `size` divided, rounding up, by the tile size that `get_tile_size`
/// reports for `dtype_size`.
///
/// NOTE(review): the arithmetic lives in the Metal shader, which is not
/// visible here — presumably `mul` is the ELU alpha coefficient; confirm
/// against the kernel.
///
/// # Errors
///
/// Propagates any `MetalKernelError` returned by `Kernels::load_pipeline`.
// The argument list mirrors what the kernel launch needs; hence the lint opt-out.
#[allow(clippy::too_many_arguments)]
pub fn call_elu(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    dtype_size: usize,
    size: usize,
    input: BufferOffset,
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pso = kernels.load_pipeline(device, Source::Affine, name)?;

    // Keep the provider's handle alive for as long as the borrowed encoder is used.
    let encoder_handle = ep.encoder();
    let enc: &ComputeCommandEncoder = encoder_handle.as_ref();
    enc.set_compute_pipeline_state(&pso);
    set_params!(enc, (size, mul, &input, output));

    // The launch is counted in tiles rather than elements: ceil(size / tile_size).
    let n_tiles = size.div_ceil(get_tile_size(dtype_size));
    let (groups, threads_per_group) = linear_split(&pso, n_tiles);

    enc.use_resource(input.buffer, MTLResourceUsage::Read);
    enc.use_resource(output, MTLResourceUsage::Write);
    enc.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}
/// Encodes one dispatch of the strided `name` kernel from the `Source::Affine`
/// library, using the single-scalar parameter layout.
///
/// The element count is the product of `shape`. The kernel is bound with, in
/// this order: that element count, the rank (`shape.len()`), `shape`,
/// `input_stride`, `mul`, the input buffer (together with its offset) and the
/// output buffer. The launch is sized over the element count.
///
/// NOTE(review): the arithmetic lives in the Metal shader, which is not
/// visible here — presumably `mul` is the ELU alpha coefficient and
/// `input_stride` is given in elements; confirm against the kernel.
///
/// # Errors
///
/// Propagates any `MetalKernelError` returned by `Kernels::load_pipeline`.
// The argument list mirrors what the kernel launch needs; hence the lint opt-out.
#[allow(clippy::too_many_arguments)]
pub fn call_elu_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: &'static str,
    shape: &[usize],
    input: BufferOffset,
    input_stride: &[usize],
    output: &Buffer,
    mul: f32,
) -> Result<(), MetalKernelError> {
    let pso = kernels.load_pipeline(device, Source::Affine, name)?;
    let elem_count: usize = shape.iter().product();
    let rank: usize = shape.len();

    // Keep the provider's handle alive for as long as the borrowed encoder is used.
    let encoder_handle = ep.encoder();
    let enc: &ComputeCommandEncoder = encoder_handle.as_ref();
    enc.set_compute_pipeline_state(&pso);
    set_params!(
        enc,
        (elem_count, rank, shape, input_stride, mul, &input, output)
    );

    // Strided launches are sized per element, not per tile.
    let (groups, threads_per_group) = linear_split(&pso, elem_count);

    enc.use_resource(input.buffer, MTLResourceUsage::Read);
    enc.use_resource(output, MTLResourceUsage::Write);
    enc.dispatch_thread_groups(groups, threads_per_group);
    Ok(())
}