use candle_metal_kernels::metal::{Buffer, CommandBuffer, ComputeCommandEncoder, ComputePipeline};
use objc2_metal::MTLSize;
use std::ffi::c_void;
/// Collapses an N-dimensional `shape` into a two-dimensional grid size.
///
/// Dimensions whose stride is `0` are skipped. Every remaining dimension is multiplied
/// into the first axis as long as that product stays strictly below `u32::MAX`; a
/// dimension that would not fit is multiplied into the second axis instead. The larger
/// axis is returned as `width`, the smaller one as `height`, and `depth` is always `1`.
///
/// # Panics
///
/// * With "Unable to safely factor shape." when either axis ends up above `u32::MAX`,
///   including when the second axis would not even fit in a `usize`.
/// * When `strides` has fewer entries than `shape` (out-of-bounds index).
pub(crate) fn get_2d_grid_dims(shape: &[usize], strides: &[usize]) -> MTLSize {
    const FACTOR_ERROR: &str = "Unable to safely factor shape.";
    let limit = u32::MAX as usize;

    let mut grid_x: usize = 1;
    let mut grid_y: usize = 1;
    for (i, &dim) in shape.iter().enumerate() {
        if strides[i] == 0 {
            continue;
        }
        if grid_x.saturating_mul(dim) < limit {
            // Guarded by the saturating check above, so this cannot overflow.
            grid_x *= dim;
        } else {
            // A plain `*=` wraps in release builds and could sneak a bogus small value
            // past the range check below; treat overflow as the same factoring failure.
            grid_y = grid_y.checked_mul(dim).expect(FACTOR_ERROR);
        }
    }
    if grid_y > limit || grid_x > limit {
        panic!("{FACTOR_ERROR}");
    }
    // Keep the larger extent on the first axis.
    if grid_y > grid_x {
        std::mem::swap(&mut grid_x, &mut grid_y);
    }
    MTLSize {
        width: grid_x,
        height: grid_y,
        depth: 1,
    }
}
/// Collapses `shape` into a two-dimensional grid while dividing the total size by `divisor`.
///
/// Dimensions whose stride is `0` are skipped. A dimension that evenly divides the
/// remaining divisor is removed from the divisor instead of being added to the grid.
/// Any other dimension is multiplied into the first axis while that product stays
/// strictly below `u32::MAX`, otherwise into the second axis; afterwards, if a divisor
/// greater than `1` is still pending and evenly divides one of the axes, that axis is
/// divided by it. The larger axis is returned as `width`, the smaller as `height`, and
/// `depth` is always `1`.
///
/// A divisor that is never fully absorbed is dropped without an error.
///
/// # Panics
///
/// * With "Unable to safely factor shape." when either axis ends up above `u32::MAX`,
///   including when the second axis would not even fit in a `usize`.
/// * When a dimension with a non-zero stride is `0` (remainder by zero).
/// * When `strides` has fewer entries than `shape` (out-of-bounds index).
pub(crate) fn get_2d_grid_dims_divisor(
    shape: &[usize],
    strides: &[usize],
    mut divisor: usize,
) -> MTLSize {
    const FACTOR_ERROR: &str = "Unable to safely factor shape.";
    let limit = u32::MAX as usize;

    let mut grid_x: usize = 1;
    let mut grid_y: usize = 1;
    for (i, &dim) in shape.iter().enumerate() {
        if strides[i] == 0 {
            continue;
        }
        // This dimension can be cancelled against the divisor; nothing to add to the grid.
        if divisor % dim == 0 {
            divisor /= dim;
            continue;
        }
        if grid_x.saturating_mul(dim) < limit {
            // Guarded by the saturating check above, so this cannot overflow.
            grid_x *= dim;
        } else {
            // A plain `*=` wraps in release builds and could sneak a bogus small value
            // past the range check below; treat overflow as the same factoring failure.
            grid_y = grid_y.checked_mul(dim).expect(FACTOR_ERROR);
        }
        // Try to absorb whatever is left of the divisor into one of the axes.
        if divisor > 1 {
            if grid_x % divisor == 0 {
                grid_x /= divisor;
                divisor = 1;
            } else if grid_y % divisor == 0 {
                grid_y /= divisor;
                divisor = 1;
            }
        }
    }
    if grid_y > limit || grid_x > limit {
        panic!("{FACTOR_ERROR}");
    }
    // Keep the larger extent on the first axis.
    if grid_y > grid_x {
        std::mem::swap(&mut grid_x, &mut grid_y);
    }
    MTLSize {
        width: grid_x,
        height: grid_y,
        depth: 1,
    }
}
/// Chooses power-of-two block extents for up to three dimensions.
///
/// Starting from `1 x 1 x 1`, the axes are visited in order (`dim0`, `dim1`, `dim2`) over
/// and over; an axis is doubled whenever the doubled extent still does not exceed its
/// dimension. The walk stops as soon as the sum of the three exponents equals `pow2`
/// (tested by equality right after each axis is visited) or a full pass over the axes
/// doubles nothing.
///
/// Returns the extents as `width`, `height` and `depth` respectively.
pub(crate) fn get_block_dims(dim0: usize, dim1: usize, dim2: usize, pow2: usize) -> MTLSize {
    let extents = [dim0, dim1, dim2];
    let mut exponents = [0usize; 3];
    let mut budget_used = 0usize;

    'grow: loop {
        let used_before_pass = budget_used;
        for (axis, &extent) in extents.iter().enumerate() {
            if extent >= (1usize << (exponents[axis] + 1)) {
                exponents[axis] += 1;
                budget_used += 1;
            }
            if budget_used == pow2 {
                break 'grow;
            }
        }
        // A pass in which no axis grew means nothing can grow any further.
        if budget_used == used_before_pass {
            break;
        }
    }

    let [exp_w, exp_h, exp_d] = exponents;
    MTLSize {
        width: 1usize << exp_w,
        height: 1usize << exp_h,
        depth: 1usize << exp_d,
    }
}
/// Splits `length` work items into a one-dimensional dispatch.
///
/// Returns `(thread_group_count, thread_group_size)`. The group width is `length` capped
/// at `pipeline.max_total_threads_per_threadgroup()`, and the group count is `length`
/// divided by that width, rounded up. Heights and depths are always `1`.
///
/// The width is never smaller than `1`, so `length == 0` yields zero groups of one
/// thread instead of dividing by zero; callers may want to skip the dispatch entirely
/// in that case.
pub(crate) fn linear_split(pipeline: &ComputePipeline, length: usize) -> (MTLSize, MTLSize) {
    // `max(1)` keeps the division below well-defined for an empty workload.
    let width = pipeline
        .max_total_threads_per_threadgroup()
        .min(length)
        .max(1);
    let count = length.div_ceil(width);

    let thread_group_count = MTLSize {
        width: count,
        height: 1,
        depth: 1,
    };
    let thread_group_size = MTLSize {
        width,
        height: 1,
        depth: 1,
    };
    (thread_group_count, thread_group_size)
}
/// Minimal encoder interface that accepts an untyped pointer together with a slot index
/// and a length.
///
/// NOTE(review): `length` is presumably the number of bytes readable behind `bytes`, and
/// the pointer presumably only needs to stay valid for the duration of the call —
/// confirm against the implementors.
pub trait RawBytesEncoder {
/// Passes `length` and the raw pointer `bytes` to the encoder for slot `index`.
fn set_bytes_raw(&self, index: usize, length: usize, bytes: *const c_void);
}
/// `ComputeCommandEncoder` satisfies [`RawBytesEncoder`] by delegating to its own
/// `set_bytes_directly` method.
impl RawBytesEncoder for ComputeCommandEncoder {
/// Forwards `index`, `length` and `bytes` unchanged to `set_bytes_directly`.
fn set_bytes_raw(&self, index: usize, length: usize, bytes: *const c_void) {
self.set_bytes_directly(index, length, bytes);
}
}
/// Binds `data` at argument slot `position` on `encoder`.
///
/// This is a free-function front end for [`EncoderParam::set_param`]: the concrete
/// binding strategy is chosen by the `EncoderParam` implementation of `P`.
pub fn set_param<P: EncoderParam>(encoder: &ComputeCommandEncoder, position: usize, data: P) {
    P::set_param(encoder, position, data);
}
/// A value that knows how to bind itself to an argument slot of a
/// `ComputeCommandEncoder`.
///
/// The value is taken by `Self`, so implementations exist both for plain values and for
/// reference types.
pub trait EncoderParam {
/// Binds `data` at slot `position` on `encoder`.
fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self);
}
/// Implements [`EncoderParam`] for each listed type by handing a reference to the value
/// to `ComputeCommandEncoder::set_bytes`.
///
/// Accepts one or more comma-separated types.
macro_rules! primitive {
    ($($type:ty),+) => {
        $(
            impl EncoderParam for $type {
                fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
                    encoder.set_bytes(position, &data);
                }
            }
        )+
    };
}

// Scalar types that can be passed to an encoder by value.
primitive!(bool, usize, i32, i64, u32, u64, f32);
/// A borrowed `Buffer` paired with an offset into it.
///
/// A `&BufferOffset` can be used as an [`EncoderParam`]; both fields are then passed to
/// `ComputeCommandEncoder::set_buffer`.
pub struct BufferOffset<'a> {
/// The buffer being referred to.
pub buffer: &'a Buffer,
/// Offset into `buffer`, expressed in bytes.
pub offset_in_bytes: usize,
}
impl<T> EncoderParam for &[T] {
fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
encoder.set_bytes_directly(position, core::mem::size_of_val(data), data.as_ptr().cast());
}
}
/// A bare buffer reference is bound with `set_buffer` at offset `0`.
impl EncoderParam for &Buffer {
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        const NO_OFFSET: usize = 0;
        encoder.set_buffer(position, Some(data), NO_OFFSET);
    }
}
/// A `(buffer, offset)` pair is bound with `set_buffer`, using the second element as the
/// offset argument.
impl EncoderParam for (&Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        let (buffer, offset) = data;
        encoder.set_buffer(position, Some(buffer), offset);
    }
}
/// A [`BufferOffset`] is bound with `set_buffer`, using its `buffer` and
/// `offset_in_bytes` fields.
impl EncoderParam for &BufferOffset<'_> {
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        let buffer = data.buffer;
        let offset = data.offset_in_bytes;
        encoder.set_buffer(position, Some(buffer), offset);
    }
}
/// A mutable buffer reference is bound exactly like a shared one: with `set_buffer` at
/// offset `0`.
impl EncoderParam for &mut Buffer {
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        // `set_buffer` only needs shared access, so downgrade the borrow explicitly.
        let buffer: &Buffer = data;
        encoder.set_buffer(position, Some(buffer), 0);
    }
}
/// A `(mutable buffer, offset)` pair is bound with `set_buffer`, using the second element
/// as the offset argument.
impl EncoderParam for (&mut Buffer, usize) {
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        let (buffer, offset) = data;
        // `set_buffer` only needs shared access, so downgrade the borrow explicitly.
        let buffer: &Buffer = buffer;
        encoder.set_buffer(position, Some(buffer), offset);
    }
}
/// Binds a list of parameters to consecutive argument slots of an encoder.
///
/// Invoked as `set_params!(encoder, (a, b, c))`: the first argument is an identifier
/// naming the encoder, the second a parenthesised, comma-separated list of one or more
/// expressions (a trailing comma is not accepted). Each expression is handed to
/// `$crate::metal_kernels::utils::set_param` with an index that starts at `0` and is
/// increased by one after every parameter.
///
/// The expansion is a sequence of statements rather than a single block or expression,
/// so the macro can only be used in statement position.
#[macro_export]
macro_rules! set_params {
($encoder:ident, ($($param:expr),+)) => (
// Running slot index; the leading underscore avoids an unused warning on the final
// increment.
let mut _index = 0;
$(
$crate::metal_kernels::utils::set_param($encoder, _index, $param);
_index += 1;
)*
);
}
/// Something that can hand out a value usable as a `ComputeCommandEncoder`.
///
/// The handed-out type is only required to implement `AsRef<ComputeCommandEncoder>`, so
/// it may be an encoder itself or a wrapper around a borrowed one.
pub trait EncoderProvider {
/// The encoder-like type returned by [`EncoderProvider::encoder`]; it may borrow from
/// the provider for `'a`.
type Encoder<'a>: AsRef<ComputeCommandEncoder>
where
Self: 'a;
/// Returns an encoder-like value tied to the borrow of `self`.
fn encoder(&self) -> Self::Encoder<'_>;
}
/// A borrowed `ComputeCommandEncoder` plus a flag that decides whether `end_encoding` is
/// called on it when this wrapper is dropped.
pub struct WrappedEncoder<'a> {
// The encoder being wrapped; it is borrowed, not owned.
inner: &'a ComputeCommandEncoder,
// When `true`, the `Drop` impl calls `end_encoding` on `inner`.
end_encoding_on_drop: bool,
}
impl Drop for WrappedEncoder<'_> {
    /// Calls `end_encoding` on the wrapped encoder, but only when the wrapper was created
    /// with `end_encoding_on_drop` set.
    fn drop(&mut self) {
        if !self.end_encoding_on_drop {
            return;
        }
        self.inner.end_encoding();
    }
}
/// Gives access to the wrapped encoder so a `WrappedEncoder` can be used wherever an
/// `AsRef<ComputeCommandEncoder>` is expected.
impl AsRef<ComputeCommandEncoder> for WrappedEncoder<'_> {
/// Returns the borrowed encoder stored in `inner`.
fn as_ref(&self) -> &ComputeCommandEncoder {
self.inner
}
}
/// A command buffer reference provides encoders by value: every call to `encoder()`
/// returns the result of `compute_command_encoder()`.
///
/// NOTE(review): presumably each call creates a fresh encoder on the command buffer —
/// confirm against `CommandBuffer::compute_command_encoder`.
impl EncoderProvider for &CommandBuffer {
type Encoder<'a>
= ComputeCommandEncoder
where
Self: 'a;
/// Returns `self.compute_command_encoder()`.
fn encoder(&self) -> Self::Encoder<'_> {
self.compute_command_encoder()
}
}
/// An existing encoder reference provides itself, wrapped in a [`WrappedEncoder`] that
/// does not end encoding when dropped.
impl EncoderProvider for &ComputeCommandEncoder {
    type Encoder<'a>
        = WrappedEncoder<'a>
    where
        Self: 'a;

    /// Wraps the borrowed encoder with `end_encoding_on_drop` set to `false`.
    fn encoder(&self) -> Self::Encoder<'_> {
        let inner: &ComputeCommandEncoder = self;
        WrappedEncoder {
            inner,
            end_encoding_on_drop: false,
        }
    }
}