use crate::{
XResult,
gpu::metal::{
METAL_DEVICE, METAL_QUEUE, RANDOM_METALLIB, get_pipeline, load_library, thread_config,
},
};
use metal::MTLResourceOptions;
use rand::RngExt;
use std::sync::LazyLock;
/// Metal library obtained from `RANDOM_METALLIB` via `load_library`.
///
/// Evaluated on first access; the outcome (library or error) is stored and
/// reused by every pipeline static in this file.
static LIBRARY: LazyLock<XResult<metal::Library>> = LazyLock::new(|| {
    load_library(RANDOM_METALLIB)
});
/// Compute pipeline requested from [`LIBRARY`] under the name
/// `"standard_stable_rand"`.
///
/// Built on first access; a failure to load the library is propagated as the
/// stored error.
static STANDARD_STABLE_PIPELINE: LazyLock<XResult<metal::ComputePipelineState>> =
    LazyLock::new(|| get_pipeline(LIBRARY.as_ref()?, "standard_stable_rand"));
/// Compute pipeline requested from [`LIBRARY`] under the name `"randuniform"`.
///
/// Built on first access; a failure to load the library is propagated as the
/// stored error.
static UNIFORM_PIPELINE: LazyLock<XResult<metal::ComputePipelineState>> =
    LazyLock::new(|| get_pipeline(LIBRARY.as_ref()?, "randuniform"));
/// Compute pipeline requested from [`LIBRARY`] under the name `"randnormal"`.
///
/// Built on first access; a failure to load the library is propagated as the
/// stored error.
static NORMAL_PIPELINE: LazyLock<XResult<metal::ComputePipelineState>> =
    LazyLock::new(|| get_pipeline(LIBRARY.as_ref()?, "randnormal"));
/// Compute pipeline requested from [`LIBRARY`] under the name `"randexp"`.
///
/// Built on first access; a failure to load the library is propagated as the
/// stored error.
static EXP_PIPELINE: LazyLock<XResult<metal::ComputePipelineState>> =
    LazyLock::new(|| get_pipeline(LIBRARY.as_ref()?, "randexp"));
/// Fills a vector with `len` values produced by the `standard_stable_rand`
/// Metal kernel.
///
/// Presumably these are samples of a standard stable distribution with
/// stability `alpha` and skewness `beta` — confirm against the shader source.
///
/// The function precomputes four constants from `alpha` and `beta` on the CPU,
/// binds them together with a freshly drawn 64-bit seed, runs the kernel
/// synchronously and copies the output buffer into a `Vec<f32>`.
///
/// Kernel argument layout (buffer index → value):
/// `0` output, `1` `alpha`, `2` `beta`, `3` `1 / alpha`,
/// `4` `(1 - alpha) / alpha`, `5` `b`, `6` `s`, `7` element count (`u32`),
/// `8` seed (`u64`).
///
/// # Arguments
/// * `alpha` - forwarded to the kernel; values within `1e-3` of `1.0` make the
///   four derived constants zero.
/// * `beta` - forwarded to the kernel and used for the derived constants.
/// * `len` - number of `f32` values to generate. `0` yields an empty vector
///   without dispatching any GPU work.
///
/// # Errors
/// Returns the stored error if the Metal device, command queue or compute
/// pipeline could not be initialised.
pub fn standard_stable_rands(alpha: f32, beta: f32, len: usize) -> XResult<Vec<f32>> {
    let device = METAL_DEVICE.as_ref()?;
    let queue = METAL_QUEUE.as_ref()?;
    let pipeline = STANDARD_STABLE_PIPELINE.as_ref()?;

    // A zero-length request would ask Metal for a zero-byte buffer and then
    // build a slice from its contents pointer, which is not guaranteed to be
    // non-null. Nothing needs to run on the GPU, so return early.
    if len == 0 {
        return Ok(Vec::new());
    }

    // Constants shared by every sample, computed once on the CPU.
    let (inv_alpha, one_minus_alpha_div_alpha, b, s) = if (alpha - 1.0).abs() < 1e-3 {
        // `tan(alpha * pi / 2)` blows up near alpha = 1, so the derived
        // constants are zeroed instead; presumably the kernel takes a
        // dedicated branch for this case — confirm against the shader.
        (0.0f32, 0.0f32, 0.0f32, 0.0f32)
    } else {
        let inv_alpha = 1.0 / alpha;
        let one_minus_alpha_div_alpha = (1.0 - alpha) * inv_alpha;
        let tmp = beta * (alpha * std::f32::consts::FRAC_PI_2).tan();
        let b = tmp.atan() * inv_alpha;
        let s = (1.0 + tmp * tmp).powf(0.5 * inv_alpha);
        (inv_alpha, one_minus_alpha_div_alpha, b, s)
    };

    let out_buffer = device.new_buffer(
        (len * std::mem::size_of::<f32>()) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let seed: u64 = rand::rng().random();
    // NOTE(review): this cast truncates when `len > u32::MAX`; the kernel would
    // then be told a smaller count than the buffer holds.
    let len_u32 = len as u32;
    let (thread_groups, threads_per_group) = thread_config(len);

    let command_buffer = queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    encoder.set_buffer(0, Some(&out_buffer), 0);

    // The six `f32` arguments occupy consecutive indices 1..=6.
    let float_args = [alpha, beta, inv_alpha, one_minus_alpha_div_alpha, b, s];
    for (index, value) in (1u64..).zip(float_args.iter()) {
        encoder.set_bytes(
            index,
            std::mem::size_of::<f32>() as u64,
            value as *const f32 as *const _,
        );
    }
    encoder.set_bytes(
        7,
        std::mem::size_of::<u32>() as u64,
        &len_u32 as *const u32 as *const _,
    );
    encoder.set_bytes(
        8,
        std::mem::size_of::<u64>() as u64,
        &seed as *const u64 as *const _,
    );

    encoder.dispatch_thread_groups(thread_groups, threads_per_group);
    encoder.end_encoding();
    command_buffer.commit();
    // Block until the GPU is done so the buffer can be read back below.
    command_buffer.wait_until_completed();

    // SAFETY: `out_buffer` was allocated above with room for `len` `f32`
    // values (`len > 0`) in shared storage mode, and the command buffer has
    // completed, so no GPU write is still in flight while the slice is read.
    let result = unsafe {
        let ptr = out_buffer.contents() as *const f32;
        std::slice::from_raw_parts(ptr, len).to_vec()
    };
    Ok(result)
}
/// Fills a vector with `n` values produced by the `randuniform` Metal kernel.
///
/// Presumably these are uniformly distributed samples; the exact range is
/// defined by the shader — confirm there.
///
/// A fresh 64-bit seed is drawn on the CPU for every call, the kernel is run
/// synchronously, and the output buffer is copied into a `Vec<f32>`.
///
/// Kernel argument layout (buffer index → value):
/// `0` output, `1` element count (`u32`), `2` seed (`u64`).
///
/// # Arguments
/// * `n` - number of `f32` values to generate. `0` yields an empty vector
///   without dispatching any GPU work.
///
/// # Errors
/// Returns the stored error if the Metal device, command queue or compute
/// pipeline could not be initialised.
pub fn metalrands(n: usize) -> XResult<Vec<f32>> {
    let device = METAL_DEVICE.as_ref()?;
    let queue = METAL_QUEUE.as_ref()?;
    let pipeline = UNIFORM_PIPELINE.as_ref()?;

    // A zero-length request would ask Metal for a zero-byte buffer and then
    // build a slice from its contents pointer, which is not guaranteed to be
    // non-null. Nothing needs to run on the GPU, so return early.
    if n == 0 {
        return Ok(Vec::new());
    }

    let out_buffer = device.new_buffer(
        (n * std::mem::size_of::<f32>()) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let seed: u64 = rand::rng().random();
    // NOTE(review): this cast truncates when `n > u32::MAX`; the kernel would
    // then be told a smaller count than the buffer holds.
    let len_u32 = n as u32;
    let (thread_groups, threads_per_group) = thread_config(n);

    let command_buffer = queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    encoder.set_buffer(0, Some(&out_buffer), 0);
    encoder.set_bytes(
        1,
        std::mem::size_of::<u32>() as u64,
        &len_u32 as *const u32 as *const _,
    );
    encoder.set_bytes(
        2,
        std::mem::size_of::<u64>() as u64,
        &seed as *const u64 as *const _,
    );

    encoder.dispatch_thread_groups(thread_groups, threads_per_group);
    encoder.end_encoding();
    command_buffer.commit();
    // Block until the GPU is done so the buffer can be read back below.
    command_buffer.wait_until_completed();

    // SAFETY: `out_buffer` was allocated above with room for `n` `f32` values
    // (`n > 0`) in shared storage mode, and the command buffer has completed,
    // so no GPU write is still in flight while the slice is read.
    let result = unsafe {
        let ptr = out_buffer.contents() as *const f32;
        std::slice::from_raw_parts(ptr, n).to_vec()
    };
    Ok(result)
}
/// Fills a vector with `n` values produced by the `randnormal` Metal kernel.
///
/// Presumably these are normally distributed samples with mean `mu` and
/// standard deviation `sigma` — confirm against the shader source.
///
/// A fresh 64-bit seed is drawn on the CPU for every call, the kernel is run
/// synchronously, and the output buffer is copied into a `Vec<f32>`.
///
/// Kernel argument layout (buffer index → value):
/// `0` output, `1` element count (`u32`), `2` `mu`, `3` `sigma`,
/// `4` seed (`u64`).
///
/// # Arguments
/// * `n` - number of `f32` values to generate. `0` yields an empty vector
///   without dispatching any GPU work.
/// * `mu` - forwarded unchanged to the kernel.
/// * `sigma` - forwarded unchanged to the kernel.
///
/// # Errors
/// Returns the stored error if the Metal device, command queue or compute
/// pipeline could not be initialised.
pub fn metalrandn(n: usize, mu: f32, sigma: f32) -> XResult<Vec<f32>> {
    let device = METAL_DEVICE.as_ref()?;
    let queue = METAL_QUEUE.as_ref()?;
    let pipeline = NORMAL_PIPELINE.as_ref()?;

    // A zero-length request would ask Metal for a zero-byte buffer and then
    // build a slice from its contents pointer, which is not guaranteed to be
    // non-null. Nothing needs to run on the GPU, so return early.
    if n == 0 {
        return Ok(Vec::new());
    }

    let out_buffer = device.new_buffer(
        (n * std::mem::size_of::<f32>()) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let seed: u64 = rand::rng().random();
    // NOTE(review): this cast truncates when `n > u32::MAX`; the kernel would
    // then be told a smaller count than the buffer holds.
    let len_u32 = n as u32;
    let (thread_groups, threads_per_group) = thread_config(n);

    let command_buffer = queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    encoder.set_buffer(0, Some(&out_buffer), 0);
    encoder.set_bytes(
        1,
        std::mem::size_of::<u32>() as u64,
        &len_u32 as *const u32 as *const _,
    );
    encoder.set_bytes(
        2,
        std::mem::size_of::<f32>() as u64,
        &mu as *const f32 as *const _,
    );
    encoder.set_bytes(
        3,
        std::mem::size_of::<f32>() as u64,
        &sigma as *const f32 as *const _,
    );
    encoder.set_bytes(
        4,
        std::mem::size_of::<u64>() as u64,
        &seed as *const u64 as *const _,
    );

    encoder.dispatch_thread_groups(thread_groups, threads_per_group);
    encoder.end_encoding();
    command_buffer.commit();
    // Block until the GPU is done so the buffer can be read back below.
    command_buffer.wait_until_completed();

    // SAFETY: `out_buffer` was allocated above with room for `n` `f32` values
    // (`n > 0`) in shared storage mode, and the command buffer has completed,
    // so no GPU write is still in flight while the slice is read.
    let result = unsafe {
        let ptr = out_buffer.contents() as *const f32;
        std::slice::from_raw_parts(ptr, n).to_vec()
    };
    Ok(result)
}
/// Fills a vector with `n` values produced by the `randexp` Metal kernel.
///
/// Presumably these are exponentially distributed samples; the rate is
/// defined by the shader — confirm there.
///
/// A fresh 64-bit seed is drawn on the CPU for every call, the kernel is run
/// synchronously, and the output buffer is copied into a `Vec<f32>`.
///
/// Kernel argument layout (buffer index → value):
/// `0` output, `1` element count (`u32`), `2` seed (`u64`).
///
/// # Arguments
/// * `n` - number of `f32` values to generate. `0` yields an empty vector
///   without dispatching any GPU work.
///
/// # Errors
/// Returns the stored error if the Metal device, command queue or compute
/// pipeline could not be initialised.
pub fn metalrandexp(n: usize) -> XResult<Vec<f32>> {
    let device = METAL_DEVICE.as_ref()?;
    let queue = METAL_QUEUE.as_ref()?;
    let pipeline = EXP_PIPELINE.as_ref()?;

    // A zero-length request would ask Metal for a zero-byte buffer and then
    // build a slice from its contents pointer, which is not guaranteed to be
    // non-null. Nothing needs to run on the GPU, so return early.
    if n == 0 {
        return Ok(Vec::new());
    }

    let out_buffer = device.new_buffer(
        (n * std::mem::size_of::<f32>()) as u64,
        MTLResourceOptions::StorageModeShared,
    );
    let seed: u64 = rand::rng().random();
    // NOTE(review): this cast truncates when `n > u32::MAX`; the kernel would
    // then be told a smaller count than the buffer holds.
    let len_u32 = n as u32;
    let (thread_groups, threads_per_group) = thread_config(n);

    let command_buffer = queue.new_command_buffer();
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);
    encoder.set_buffer(0, Some(&out_buffer), 0);
    encoder.set_bytes(
        1,
        std::mem::size_of::<u32>() as u64,
        &len_u32 as *const u32 as *const _,
    );
    encoder.set_bytes(
        2,
        std::mem::size_of::<u64>() as u64,
        &seed as *const u64 as *const _,
    );

    encoder.dispatch_thread_groups(thread_groups, threads_per_group);
    encoder.end_encoding();
    command_buffer.commit();
    // Block until the GPU is done so the buffer can be read back below.
    command_buffer.wait_until_completed();

    // SAFETY: `out_buffer` was allocated above with room for `n` `f32` values
    // (`n > 0`) in shared storage mode, and the command buffer has completed,
    // so no GPU write is still in flight while the slice is read.
    let result = unsafe {
        let ptr = out_buffer.contents() as *const f32;
        std::slice::from_raw_parts(ptr, n).to_vec()
    };
    Ok(result)
}