use crate::metal::{Buffer, CommandBuffer, ComputeCommandEncoder, ComputePipeline};
use crate::MTLSize;
use std::ffi::OsStr;
use std::ops::Deref;
use std::sync::{RwLockReadGuard, RwLockWriteGuard};
/// Splits a 1-D workload of `length` threads into a `(threadgroup_count, threadgroup_size)`
/// pair for dispatch.
///
/// The threadgroup width is the pipeline's `max_total_threads_per_threadgroup()` clamped to
/// `length`. Enough groups are launched to cover every element, so the last group may be only
/// partially used. Both returned sizes have `height == depth == 1`.
///
/// # Panics
/// Panics (division by zero) if `length` is 0 or the pipeline reports a maximum of 0.
pub(crate) fn linear_split(pipeline: &ComputePipeline, length: usize) -> (MTLSize, MTLSize) {
    let threads_per_group = pipeline.max_total_threads_per_threadgroup().min(length);
    let group_count = length.div_ceil(threads_per_group);
    (
        MTLSize {
            width: group_count,
            height: 1,
            depth: 1,
        },
        MTLSize {
            width: threads_per_group,
            height: 1,
            depth: 1,
        },
    )
}
/// Chooses power-of-two threadgroup dimensions for a 3-D problem of size
/// `dim0 x dim1 x dim2`.
///
/// Exponents are grown one step at a time, round-robin over the three axes (x, then y, then
/// z). An axis only grows while its next power of two still fits within that dimension. The
/// sum of the three exponents is capped at 10, so `width * height * depth <= 1024`. Growth
/// stops at the cap or when a full round makes no progress. Each returned extent is at
/// least 1.
pub fn get_block_dims(dim0: usize, dim1: usize, dim2: usize) -> MTLSize {
    const MAX_TOTAL_POW: u32 = 10;
    let extents = [dim0, dim1, dim2];
    let mut exps = [0u32; 3];
    let mut total = 0u32;
    'grow: loop {
        let total_before_round = total;
        for (axis, &extent) in extents.iter().enumerate() {
            if extent >= (1 << (exps[axis] + 1)) {
                exps[axis] += 1;
                total += 1;
            }
            if total == MAX_TOTAL_POW {
                break 'grow;
            }
        }
        // A full round with no axis able to grow means we are done.
        if total == total_before_round {
            break;
        }
    }
    MTLSize {
        width: 1 << exps[0],
        height: 1 << exps[1],
        depth: 1 << exps[2],
    }
}
/// Returns how many `dtype_size`-byte elements fit in 8 bytes (integer division),
/// clamped so the result is never below 1.
///
/// # Panics
/// Panics (division by zero) if `dtype_size` is 0.
#[inline(always)]
pub fn get_tile_size(dtype_size: usize) -> usize {
    let per_tile = 8 / dtype_size;
    per_tile.max(1)
}
/// Binds `data` to argument slot `position` of `encoder` by dispatching to the
/// `EncoderParam` implementation for `P`.
pub fn set_param<P>(encoder: &ComputeCommandEncoder, position: usize, data: P)
where
    P: EncoderParam,
{
    P::set_param(encoder, position, data)
}
/// A value that can be bound to an argument slot of a compute command encoder.
///
/// Implementations in this file either copy the value's bytes into the slot (`set_bytes` /
/// `set_bytes_directly`) or bind a buffer reference with an offset (`set_buffer`).
pub trait EncoderParam {
    /// Binds `data` to argument index `position` on `encoder`.
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self);
}
macro_rules! primitive {
($type:ty) => {
impl EncoderParam for $type {
fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
encoder.set_bytes(position, &data);
}
}
};
}
primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u8);
primitive!(u32);
primitive!(u64);
primitive!(f32);
primitive!(f64);
primitive!(half::bf16);
primitive!(half::f16);
/// A borrowed buffer paired with a byte offset into it. It is bound to an encoder argument
/// slot through `set_buffer(position, Some(buffer), offset_in_bytes)`.
pub struct BufferOffset<'a> {
    /// The referenced buffer.
    pub buffer: &'a Buffer,
    /// Offset, in bytes, from the start of `buffer`.
    pub offset_in_bytes: usize,
}
impl<'a> BufferOffset<'a> {
    /// References `buffer` from its very first byte (offset 0).
    pub fn zero_offset(buffer: &'a Buffer) -> Self {
        let offset_in_bytes = 0;
        Self {
            offset_in_bytes,
            buffer,
        }
    }
}
impl<T> EncoderParam for &[T] {
    /// Copies the raw bytes of the whole slice (`size_of_val(data)` bytes) into the argument
    /// slot via `set_bytes_directly`.
    ///
    /// NOTE(review): any `T` is accepted, so this is only meaningful for plain-old-data
    /// element types whose layout matches the kernel's expectation — confirm at call sites.
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        let byte_len = core::mem::size_of_val(data);
        encoder.set_bytes_directly(position, byte_len, data.as_ptr().cast());
    }
}
impl EncoderParam for &Buffer {
    /// Binds the whole buffer, starting at byte offset 0.
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        <(&Buffer, usize) as EncoderParam>::set_param(encoder, position, (data, 0));
    }
}
impl EncoderParam for (&Buffer, usize) {
    /// Binds the buffer starting at the given byte offset.
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        let (buffer, offset) = data;
        encoder.set_buffer(position, Some(buffer), offset);
    }
}
impl EncoderParam for &BufferOffset<'_> {
    /// Binds `data.buffer` starting at `data.offset_in_bytes`.
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        let pair = (data.buffer, data.offset_in_bytes);
        <(&Buffer, usize) as EncoderParam>::set_param(encoder, position, pair);
    }
}
impl EncoderParam for &mut Buffer {
    /// Same as the `&Buffer` case: binds the whole buffer from byte offset 0.
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        <(&Buffer, usize) as EncoderParam>::set_param(encoder, position, (&*data, 0));
    }
}
impl EncoderParam for (&mut Buffer, usize) {
    /// Same as the `(&Buffer, usize)` case: binds the buffer from the given byte offset.
    fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
        let (buffer, offset) = data;
        <(&Buffer, usize) as EncoderParam>::set_param(encoder, position, (&*buffer, offset));
    }
}
impl EncoderParam for () {
    /// No-op: the argument slot is left untouched.
    fn set_param(_encoder: &ComputeCommandEncoder, _position: usize, _data: Self) {}
}
/// Binds each parameter, in order, to consecutive argument slots `0, 1, 2, …` of the
/// encoder named by `$encoder`. Each parameter goes through `crate::utils::set_param`.
///
/// Usage: `set_params!(encoder, (a, b, c));`
#[macro_export]
macro_rules! set_params {
    ($encoder:ident, ($($param:expr),+)) => {
        let mut _slot = 0;
        $(
            $crate::utils::set_param($encoder, _slot, $param);
            _slot += 1;
        )+
    };
}
/// A source of a compute command encoder.
///
/// The returned handle only has to expose a `&ComputeCommandEncoder` through `AsRef`. In this
/// file, a `&CommandBuffer` creates a new encoder via `compute_command_encoder()`, and a
/// `&ComputeCommandEncoder` wraps itself in a `WrappedEncoder`.
pub trait EncoderProvider {
    /// Encoder handle type; it may borrow from `self` for the lifetime `'a`.
    type Encoder<'a>: AsRef<ComputeCommandEncoder>
    where
        Self: 'a;
    /// Returns an encoder handle to record compute commands into.
    fn encoder(&self) -> Self::Encoder<'_>;
}
/// A borrowed compute encoder that can optionally call `end_encoding` when dropped.
pub struct WrappedEncoder<'a> {
    /// The encoder being lent out.
    inner: &'a ComputeCommandEncoder,
    /// When `true`, dropping the wrapper ends encoding on `inner`.
    end_encoding_on_drop: bool,
}
impl Drop for WrappedEncoder<'_> {
    fn drop(&mut self) {
        if !self.end_encoding_on_drop {
            return;
        }
        self.inner.end_encoding();
    }
}
impl AsRef<ComputeCommandEncoder> for WrappedEncoder<'_> {
    fn as_ref(&self) -> &ComputeCommandEncoder {
        let WrappedEncoder { inner, .. } = self;
        inner
    }
}
/// Provides an encoder by asking the command buffer for a compute command encoder.
///
/// The encoder is returned as-is, not wrapped in `WrappedEncoder`. Any end-of-encoding
/// handling is therefore left to `ComputeCommandEncoder` itself or to the caller.
/// TODO(review): confirm where `end_encoding` happens for encoders created here.
impl EncoderProvider for &CommandBuffer {
    type Encoder<'a>
        = ComputeCommandEncoder
    where
        Self: 'a;
    fn encoder(&self) -> Self::Encoder<'_> {
        self.compute_command_encoder()
    }
}
impl EncoderProvider for &ComputeCommandEncoder {
    type Encoder<'a>
        = WrappedEncoder<'a>
    where
        Self: 'a;

    /// Lends out the existing encoder with `end_encoding_on_drop` set to `false`, so dropping
    /// the returned wrapper does not end encoding on it.
    fn encoder(&self) -> Self::Encoder<'_> {
        let inner: &ComputeCommandEncoder = self;
        WrappedEncoder {
            end_encoding_on_drop: false,
            inner,
        }
    }
}
/// Either a read or a write guard of a `std::sync::RwLock`, usable uniformly as `&T` through
/// `Deref`.
pub enum RwLockGuard<'a, T> {
    /// Shared (read) access.
    Read(RwLockReadGuard<'a, T>),
    /// Exclusive (write) access, exposed here only as `&T`.
    Write(RwLockWriteGuard<'a, T>),
}
impl<'a, T> Deref for RwLockGuard<'a, T> {
    type Target = T;
    fn deref(&self) -> &T {
        match self {
            Self::Read(read_guard) => &**read_guard,
            Self::Write(write_guard) => &**write_guard,
        }
    }
}
impl<'a, T> From<RwLockReadGuard<'a, T>> for RwLockGuard<'a, T> {
    fn from(read_guard: RwLockReadGuard<'a, T>) -> Self {
        Self::Read(read_guard)
    }
}
impl<'a, T> From<RwLockWriteGuard<'a, T>> for RwLockGuard<'a, T> {
    fn from(write_guard: RwLockWriteGuard<'a, T>) -> Self {
        Self::Write(write_guard)
    }
}
/// Interprets an environment-variable value as a boolean flag.
///
/// Surrounding whitespace is ignored and matching is ASCII case-insensitive. So `"TRUE"`,
/// `"True"`, `" yes\n"` and `"1"` are all truthy. The accepted spellings are `true`, `t`,
/// `yes`, `y` and `1`; anything else is falsy.
fn is_truthy(s: String) -> bool {
    matches!(
        s.trim().to_ascii_lowercase().as_str(),
        "true" | "t" | "yes" | "y" | "1"
    )
}
/// Reads environment variable `key` as a boolean flag.
///
/// Returns `default` when the variable is unset or its value is not valid Unicode (both are
/// errors from `std::env::var`). Otherwise returns whether the value is truthy, using the
/// same rules as `is_truthy`.
pub(crate) fn get_env_bool<K: AsRef<OsStr>>(key: K, default: bool) -> bool {
    std::env::var(key).map(is_truthy).unwrap_or(default)
}