use std::{cell::Cell, ffi::c_int};
use crate::{
array::Array,
dtype::Dtype,
error::{Result, check},
shape::{IntoShape, dim_ptr, validate_dims},
stream::default_stream,
};
thread_local! {
    // Per-thread cache of the CPU stream handle handed out by
    // `random_cpu_stream`. It starts as `None` and is filled lazily on the
    // first call on each thread.
    // Nothing in this module ever frees the cached handle.
    // NOTE(review): presumably intentionally kept alive for the thread's
    // lifetime — confirm against mlx-c's `mlx_stream_free` ownership rules.
    static CPU_STREAM: Cell<Option<mlxrs_sys::mlx_stream>> = const { Cell::new(None) };
}
/// Returns this thread's CPU stream handle, creating and caching it in
/// `CPU_STREAM` on the first call.
///
/// Used by `multivariate_normal`, which runs on the CPU stream instead of
/// the default stream.
///
/// # Panics
///
/// Propagates any panic from `crate::stream::assert_streams_not_cleared`.
/// Also panics if `mlx_default_cpu_stream_new` returns a handle whose `ctx`
/// is null.
fn random_cpu_stream() -> mlxrs_sys::mlx_stream {
    crate::error::ensure_handler_installed();
    crate::stream::assert_streams_not_cleared();
    CPU_STREAM.with(|slot| match slot.get() {
        Some(cached) => cached,
        None => {
            let fresh = unsafe { mlxrs_sys::mlx_default_cpu_stream_new() };
            assert!(
                !fresh.ctx.is_null(),
                "mlxrs::ops::random: mlx_default_cpu_stream_new returned NULL ctx — \
                 CPU stream initialization failed. Aborting."
            );
            // Cache before returning so later calls on this thread reuse it.
            slot.set(Some(fresh));
            fresh
        }
    })
}
/// Builds a PRNG key array from `seed` via `mlx_random_key`.
///
/// # Errors
///
/// Returns an error if the MLX call reports failure (surfaced through `check`).
pub fn key(seed: u64) -> Result<Array> {
    // No stream is looked up on this path. Presumably that is why the crate's
    // error handler and stream-state checks are invoked explicitly here.
    crate::error::ensure_handler_installed();
    crate::stream::assert_streams_not_cleared();
    let mut prng_key = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe { mlxrs_sys::mlx_random_key(&mut prng_key.0, seed) };
    check(status).map(|()| prng_key)
}
/// Sets MLX's random seed via `mlx_random_seed`.
///
/// # Errors
///
/// Returns an error if the MLX call reports failure (surfaced through `check`).
pub fn seed(seed: u64) -> Result<()> {
    // No stream is involved, so the handler/stream checks run explicitly.
    crate::error::ensure_handler_installed();
    crate::stream::assert_streams_not_cleared();
    let status = unsafe { mlxrs_sys::mlx_random_seed(seed) };
    check(status)
}
/// Splits `key` into two new PRNG keys via `mlx_random_split` on the default
/// stream.
///
/// # Errors
///
/// Returns an error if the MLX call reports failure (surfaced through `check`).
pub fn split(key: &Array) -> Result<(Array, Array)> {
    let mut first = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let mut second = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = default_stream();
    let status =
        unsafe { mlxrs_sys::mlx_random_split(&mut first.0, &mut second.0, key.0, stream) };
    check(status).map(|()| (first, second))
}
/// Splits `key` into `num` new PRNG keys via `mlx_random_split_num` on the
/// default stream. The keys come back together in a single array.
///
/// NOTE(review): unlike the shape arguments elsewhere in this module, `num`
/// is not passed through `validate_dims` before reaching MLX. Confirm that
/// negative counts are rejected on the MLX side.
///
/// # Errors
///
/// Returns an error if the MLX call reports failure (surfaced through `check`).
pub fn split_num(key: &Array, num: i32) -> Result<Array> {
    let mut keys = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let count = num as c_int;
    let stream = default_stream();
    let status = unsafe { mlxrs_sys::mlx_random_split_num(&mut keys.0, key.0, count, stream) };
    check(status).map(|()| keys)
}
/// Draws Bernoulli samples via `mlx_random_bernoulli` on the default stream.
///
/// `p` is forwarded as the probability array, `shape` as the requested
/// output shape and `key` as the PRNG key.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn bernoulli(p: &Array, shape: &impl IntoShape, key: &Array) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_bernoulli(&mut samples.0, p.0, dims_ptr, ndim, key.0, stream)
        };
        check(status).map(|()| samples)
    })
}
/// Draws uniformly distributed samples bounded by `low` and `high` via
/// `mlx_random_uniform` on the default stream.
///
/// `shape` is the requested output shape, `dtype` the output element type
/// and `key` the PRNG key.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn uniform(
    low: &Array,
    high: &Array,
    shape: &impl IntoShape,
    dtype: Dtype,
    key: &Array,
) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let elem_type = mlxrs_sys::mlx_dtype::from(dtype);
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_uniform(
                &mut samples.0,
                low.0,
                high.0,
                dims_ptr,
                ndim,
                elem_type,
                key.0,
                stream,
            )
        };
        check(status).map(|()| samples)
    })
}
/// Draws normally distributed samples via `mlx_random_normal` on the default
/// stream.
///
/// `loc` and `scale` are scalars forwarded as the distribution's location
/// and scale. `shape` is the output shape, `dtype` the element type and
/// `key` the PRNG key. See `normal_broadcast` for array-valued parameters.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn normal(
    shape: &impl IntoShape,
    dtype: Dtype,
    loc: f32,
    scale: f32,
    key: &Array,
) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let elem_type = mlxrs_sys::mlx_dtype::from(dtype);
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_normal(
                &mut samples.0,
                dims_ptr,
                ndim,
                elem_type,
                loc,
                scale,
                key.0,
                stream,
            )
        };
        check(status).map(|()| samples)
    })
}
/// Draws normally distributed samples with array-valued `loc` and `scale`
/// via `mlx_random_normal_broadcast` on the default stream.
///
/// NOTE(review): presumably MLX broadcasts `loc`/`scale` against `shape`.
/// Confirm the exact broadcasting rules in the pinned MLX version.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn normal_broadcast(
    shape: &impl IntoShape,
    dtype: Dtype,
    loc: &Array,
    scale: &Array,
    key: &Array,
) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let elem_type = mlxrs_sys::mlx_dtype::from(dtype);
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_normal_broadcast(
                &mut samples.0,
                dims_ptr,
                ndim,
                elem_type,
                loc.0,
                scale.0,
                key.0,
                stream,
            )
        };
        check(status).map(|()| samples)
    })
}
/// Draws random integers bounded by `low` and `high` via
/// `mlx_random_randint` on the default stream.
///
/// NOTE(review): MLX documents the range as half-open `[low, high)`. Confirm
/// against the pinned MLX version.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn randint(
    low: &Array,
    high: &Array,
    shape: &impl IntoShape,
    dtype: Dtype,
    key: &Array,
) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let elem_type = mlxrs_sys::mlx_dtype::from(dtype);
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_randint(
                &mut samples.0,
                low.0,
                high.0,
                dims_ptr,
                ndim,
                elem_type,
                key.0,
                stream,
            )
        };
        check(status).map(|()| samples)
    })
}
/// Samples category indices from `logits` along `axis` via
/// `mlx_random_categorical` on the default stream.
///
/// `axis` is forwarded to MLX as a `c_int` without local validation.
///
/// # Errors
///
/// Returns an error if the MLX call reports failure (surfaced through `check`).
pub fn categorical(logits: &Array, axis: i32, key: &Array) -> Result<Array> {
    let mut indices = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let axis = axis as c_int;
    let stream = default_stream();
    let status = unsafe {
        mlxrs_sys::mlx_random_categorical(&mut indices.0, logits.0, axis, key.0, stream)
    };
    check(status).map(|()| indices)
}
/// Samples category indices from `logits` along `axis` with an explicit
/// output `shape`, via `mlx_random_categorical_shape` on the default stream.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn categorical_shape(
    logits: &Array,
    axis: i32,
    shape: &impl IntoShape,
    key: &Array,
) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut indices = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let axis = axis as c_int;
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_categorical_shape(
                &mut indices.0,
                logits.0,
                axis,
                dims_ptr,
                ndim,
                key.0,
                stream,
            )
        };
        check(status).map(|()| indices)
    })
}
/// Draws `num_samples` category indices from `logits` along `axis` via
/// `mlx_random_categorical_num_samples` on the default stream.
///
/// NOTE(review): `num_samples` is not passed through `validate_dims` before
/// reaching MLX, unlike the shape arguments elsewhere in this module. Confirm
/// that negative counts are rejected on the MLX side.
///
/// # Errors
///
/// Returns an error if the MLX call reports failure (surfaced through `check`).
pub fn categorical_num_samples(
    logits: &Array,
    axis: i32,
    num_samples: i32,
    key: &Array,
) -> Result<Array> {
    let mut indices = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let (axis, count) = (axis as c_int, num_samples as c_int);
    let stream = default_stream();
    let status = unsafe {
        mlxrs_sys::mlx_random_categorical_num_samples(
            &mut indices.0,
            logits.0,
            axis,
            count,
            key.0,
            stream,
        )
    };
    check(status).map(|()| indices)
}
/// Draws Gumbel-distributed samples via `mlx_random_gumbel` on the default
/// stream.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn gumbel(shape: &impl IntoShape, dtype: Dtype, key: &Array) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let elem_type = mlxrs_sys::mlx_dtype::from(dtype);
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_gumbel(&mut samples.0, dims_ptr, ndim, elem_type, key.0, stream)
        };
        check(status).map(|()| samples)
    })
}
/// Draws normally distributed samples truncated to the bounds `lower` and
/// `upper`, via `mlx_random_truncated_normal` on the default stream.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn truncated_normal(
    lower: &Array,
    upper: &Array,
    shape: &impl IntoShape,
    dtype: Dtype,
    key: &Array,
) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let elem_type = mlxrs_sys::mlx_dtype::from(dtype);
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_truncated_normal(
                &mut samples.0,
                lower.0,
                upper.0,
                dims_ptr,
                ndim,
                elem_type,
                key.0,
                stream,
            )
        };
        check(status).map(|()| samples)
    })
}
/// Draws samples from a multivariate normal distribution with mean `mean`
/// and covariance `cov`, via `mlx_random_multivariate_normal`.
///
/// Unlike the other samplers in this module, this one runs on this thread's
/// cached CPU stream (`random_cpu_stream`) rather than the default stream.
/// NOTE(review): presumably because MLX factorizes `cov` with a CPU-only
/// linear-algebra routine. Confirm against the pinned MLX version.
///
/// # Errors
///
/// Returns an error if:
/// - `reject_empty_matrix` rejects `cov` (a zero-length row or column
///   dimension); this check runs before `shape` is inspected;
/// - `shape` fails `validate_dims`;
/// - the MLX call reports failure.
///
/// # Panics
///
/// Propagates any panic from `random_cpu_stream`.
pub fn multivariate_normal(
    mean: &Array,
    cov: &Array,
    shape: &impl IntoShape,
    dtype: Dtype,
    key: &Array,
) -> Result<Array> {
    crate::ops::linalg_full::reject_empty_matrix(
        cov,
        "multivariate_normal: covariance matrix has a zero-length row or column dimension",
    )?;
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let elem_type = mlxrs_sys::mlx_dtype::from(dtype);
        let cpu = random_cpu_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_multivariate_normal(
                &mut samples.0,
                mean.0,
                cov.0,
                dims_ptr,
                ndim,
                elem_type,
                key.0,
                cpu,
            )
        };
        check(status).map(|()| samples)
    })
}
/// Draws Laplace-distributed samples via `mlx_random_laplace` on the default
/// stream.
///
/// `loc` and `scale` are scalars forwarded as the distribution's location
/// and scale.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn laplace(
    shape: &impl IntoShape,
    dtype: Dtype,
    loc: f32,
    scale: f32,
    key: &Array,
) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut samples = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let elem_type = mlxrs_sys::mlx_dtype::from(dtype);
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_laplace(
                &mut samples.0,
                dims_ptr,
                ndim,
                elem_type,
                loc,
                scale,
                key.0,
                stream,
            )
        };
        check(status).map(|()| samples)
    })
}
/// Generates raw random bits via `mlx_random_bits` on the default stream.
///
/// `width` is forwarded to MLX as a `c_int`.
/// NOTE(review): MLX treats it as the element byte width (1, 2 or 4). Confirm
/// against the pinned MLX version.
///
/// # Errors
///
/// Returns an error if `shape` fails `validate_dims`, or if the MLX call
/// reports failure.
pub fn bits(shape: &impl IntoShape, width: i32, key: &Array) -> Result<Array> {
    shape.with_shape(|dims| {
        validate_dims(dims)?;
        let mut raw = Array(unsafe { mlxrs_sys::mlx_array_new() });
        let (dims_ptr, ndim) = (dim_ptr(dims), dims.len());
        let byte_width = width as c_int;
        let stream = default_stream();
        let status = unsafe {
            mlxrs_sys::mlx_random_bits(&mut raw.0, dims_ptr, ndim, byte_width, key.0, stream)
        };
        check(status).map(|()| raw)
    })
}
/// Randomly permutes `x` along `axis` via `mlx_random_permutation` on the
/// default stream.
///
/// # Errors
///
/// Returns an error if the MLX call reports failure (surfaced through `check`).
pub fn permutation(x: &Array, axis: i32, key: &Array) -> Result<Array> {
    let mut shuffled = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let axis = axis as c_int;
    let stream = default_stream();
    let status =
        unsafe { mlxrs_sys::mlx_random_permutation(&mut shuffled.0, x.0, axis, key.0, stream) };
    check(status).map(|()| shuffled)
}
/// Returns a random permutation of an integer range whose size is given by
/// `x`, via `mlx_random_permutation_arange` on the default stream.
///
/// NOTE(review): in MLX this permutes `arange(x)`. Confirm the behavior for
/// non-positive `x` in the pinned MLX version.
///
/// # Errors
///
/// Returns an error if the MLX call reports failure (surfaced through `check`).
pub fn permutation_arange(x: i32, key: &Array) -> Result<Array> {
    let mut shuffled = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let upper = x as c_int;
    let stream = default_stream();
    let status = unsafe {
        mlxrs_sys::mlx_random_permutation_arange(&mut shuffled.0, upper, key.0, stream)
    };
    check(status).map(|()| shuffled)
}