use std::ffi::CString;
use crate::{
array::Array,
dtype::Dtype,
error::{Error, InteriorNulPayload, LengthMismatchPayload, Result, check},
ffi::{VectorArrayGuard, drain_vector, opt_array},
stream::default_stream,
};
/// Wraps a plain `i32` in an `mlx_optional_int` that is marked as populated.
///
/// The returned struct always has `has_value` set to `true`; this helper never
/// produces an empty optional.
#[inline(always)]
fn opt_int(v: i32) -> mlxrs_sys::mlx_optional_int {
    let has_value = true;
    mlxrs_sys::mlx_optional_int {
        has_value,
        value: v,
    }
}
/// Converts a quantization mode name into an owned, NUL-terminated C string
/// suitable for passing across the FFI boundary.
///
/// # Errors
///
/// Returns [`Error::InteriorNul`] when `mode` contains an interior NUL byte,
/// which is the only condition under which `CString::new` fails. The
/// underlying `NulError` is intentionally discarded; the payload records the
/// originating function and the offending parameter name instead.
#[inline(always)]
fn mode_cstring(mode: &str) -> Result<CString> {
    CString::new(mode).map_err(|_| {
        Error::InteriorNul(InteriorNulPayload::new(
            "mlxrs::ops::quantized::mode_cstring",
            "mode",
        ))
    })
}
pub fn quantize(
w: &Array,
group_size: i32,
bits: i32,
mode: &str,
global_scale: Option<&Array>,
) -> Result<(Array, Array, Option<Array>)> {
let mode_c = mode_cstring(mode)?;
let (gs, _gs_guard) = opt_array(global_scale);
let s = default_stream();
let mut vec_out = unsafe { mlxrs_sys::mlx_vector_array_new() };
crate::error::check_vector_array_handle(vec_out)?;
let _vec_guard = VectorArrayGuard(vec_out);
check(unsafe {
mlxrs_sys::mlx_quantize(
&mut vec_out,
w.0,
opt_int(group_size),
opt_int(bits),
mode_c.as_ptr(),
gs,
s,
)
})?;
let mut parts = drain_vector(vec_out)?;
let (w_q, scales, biases) = match parts.len() {
2 => {
let scales = parts.pop().expect("len checked == 2");
let w_q = parts.pop().expect("len checked == 2");
(w_q, scales, None)
}
3 => {
let biases = parts.pop().expect("len checked == 3");
let scales = parts.pop().expect("len checked == 3");
let w_q = parts.pop().expect("len checked == 3");
(w_q, scales, Some(biases))
}
n => return Err(unexpected_arity(n)),
};
Ok((w_q, scales, biases))
}
/// Dequantizes `w` by calling `mlx_dequantize` on the default stream.
///
/// `scales` is passed through directly; `biases` and `global_scale` are
/// converted to raw handles through `opt_array`. `group_size` and `bits` are
/// always forwarded as populated optional ints.
///
/// When `dtype` is `Some`, it is converted and forwarded with `has_value`
/// set. When it is `None`, the optional is forwarded with `has_value` unset
/// and `MLX_FLOAT32` as the filler for the `value` field.
///
/// # Errors
///
/// * [`Error::InteriorNul`] if `mode` contains an interior NUL byte.
/// * Any error reported by `check` for the status returned by
///   `mlx_dequantize`.
// The parameter list mirrors the arguments forwarded to `mlx_dequantize`.
#[allow(clippy::too_many_arguments)]
pub fn dequantize(
    w: &Array,
    scales: &Array,
    biases: Option<&Array>,
    group_size: i32,
    bits: i32,
    mode: &str,
    global_scale: Option<&Array>,
    dtype: Option<Dtype>,
) -> Result<Array> {
    let c_mode = mode_cstring(mode)?;
    // The second tuple elements stay bound until the end of the function;
    // presumably they keep the raw handles valid (see `crate::ffi::opt_array`).
    let (biases_h, _biases_guard) = opt_array(biases);
    let (global_scale_h, _global_scale_guard) = opt_array(global_scale);

    let dtype_opt = match dtype {
        Some(requested) => mlxrs_sys::mlx_optional_dtype {
            value: requested.into(),
            has_value: true,
        },
        None => mlxrs_sys::mlx_optional_dtype {
            value: mlxrs_sys::mlx_dtype__MLX_FLOAT32,
            has_value: false,
        },
    };

    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_dequantize(
            &mut result.0,
            w.0,
            scales.0,
            biases_h,
            opt_int(group_size),
            opt_int(bits),
            c_mode.as_ptr(),
            global_scale_h,
            dtype_opt,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Multiplies `x` by the quantized matrix `w` by calling
/// `mlx_quantized_matmul` on the default stream.
///
/// `scales` and `transpose` are passed through directly; `biases` is
/// converted to a raw handle through `opt_array`. `group_size` and `bits` are
/// always forwarded as populated optional ints, and `mode` as a
/// NUL-terminated C string.
///
/// # Errors
///
/// * [`Error::InteriorNul`] if `mode` contains an interior NUL byte.
/// * Any error reported by `check` for the status returned by
///   `mlx_quantized_matmul`.
// The parameter list mirrors the arguments forwarded to `mlx_quantized_matmul`.
#[allow(clippy::too_many_arguments)]
pub fn quantized_matmul(
    x: &Array,
    w: &Array,
    scales: &Array,
    biases: Option<&Array>,
    transpose: bool,
    group_size: i32,
    bits: i32,
    mode: &str,
) -> Result<Array> {
    let c_mode = mode_cstring(mode)?;
    // The second tuple element stays bound until the end of the function;
    // presumably it keeps the raw handle valid (see `crate::ffi::opt_array`).
    let (bias_handle, _bias_guard) = opt_array(biases);

    let mut product = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_quantized_matmul(
            &mut product.0,
            x.0,
            w.0,
            scales.0,
            bias_handle,
            transpose,
            opt_int(group_size),
            opt_int(bits),
            c_mode.as_ptr(),
            default_stream(),
        )
    };
    check(status)?;
    Ok(product)
}
/// Calls `mlx_gather_qmm` on the default stream and returns the resulting
/// array.
///
/// `x`, `w`, `scales`, `transpose` and `sorted_indices` are passed through
/// directly. `biases`, `lhs_indices` and `rhs_indices` are each converted to
/// a raw handle through `opt_array`. `group_size` and `bits` are always
/// forwarded as populated optional ints, and `mode` as a NUL-terminated C
/// string.
///
/// # Errors
///
/// * [`Error::InteriorNul`] if `mode` contains an interior NUL byte.
/// * Any error reported by `check` for the status returned by
///   `mlx_gather_qmm`.
// The parameter list mirrors the arguments forwarded to `mlx_gather_qmm`.
#[allow(clippy::too_many_arguments)]
pub fn gather_qmm(
    x: &Array,
    w: &Array,
    scales: &Array,
    biases: Option<&Array>,
    lhs_indices: Option<&Array>,
    rhs_indices: Option<&Array>,
    transpose: bool,
    group_size: i32,
    bits: i32,
    mode: &str,
    sorted_indices: bool,
) -> Result<Array> {
    let c_mode = mode_cstring(mode)?;
    // The second tuple elements stay bound until the end of the function;
    // presumably they keep the raw handles valid (see `crate::ffi::opt_array`).
    let (bias_handle, _bias_guard) = opt_array(biases);
    let (lhs_handle, _lhs_guard) = opt_array(lhs_indices);
    let (rhs_handle, _rhs_guard) = opt_array(rhs_indices);

    let mut gathered = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let status = unsafe {
        mlxrs_sys::mlx_gather_qmm(
            &mut gathered.0,
            x.0,
            w.0,
            scales.0,
            bias_handle,
            lhs_handle,
            rhs_handle,
            transpose,
            opt_int(group_size),
            opt_int(bits),
            c_mode.as_ptr(),
            sorted_indices,
            default_stream(),
        )
    };
    check(status)?;
    Ok(gathered)
}
fn unexpected_arity(n: usize) -> Error {
Error::LengthMismatch(LengthMismatchPayload::new(
"ops::quantized::quantize: mlx_quantize output arity (must be 2 for bias-less float modes or 3 for affine)",
3,
n,
))
}