use std::{
cell::Cell,
ffi::{CStr, c_int},
};
use crate::{
array::Array,
dtype::Dtype,
error::{
EmptyInputPayload, Error, InvariantViolationPayload, OutOfRangePayload, RankMismatchPayload,
Result, check, check_vector_array_handle,
},
ffi::{VectorArrayGuard, drain_vector},
ops::{
reduction::{any_axes, sum_axes},
shape::squeeze_axes,
},
shape::dim_ptr,
stream::default_stream,
};
thread_local! {
// Per-thread cache of the CPU stream handle used by the linalg wrappers below.
// Starts out as `None`; `linalg_cpu_stream` fills it in the first time it runs
// on a given thread and returns the stored copy on every later call.
static CPU_STREAM: Cell<Option<mlxrs_sys::mlx_stream>> = const { Cell::new(None) };
}
/// Returns this thread's cached CPU stream handle, creating it on first use.
///
/// `ensure_handler_installed` and `assert_streams_not_cleared` run on every
/// call, before the cache is consulted.
///
/// # Panics
///
/// Panics when `mlx_default_cpu_stream_new` returns a stream whose `ctx`
/// pointer is null.
fn linalg_cpu_stream() -> mlxrs_sys::mlx_stream {
    crate::error::ensure_handler_installed();
    crate::stream::assert_streams_not_cleared();
    CPU_STREAM.with(|slot| match slot.get() {
        Some(cached) => cached,
        None => {
            // SAFETY: the constructor takes no arguments; the returned handle
            // is validated right below before it is stored or handed out.
            let fresh = unsafe { mlxrs_sys::mlx_default_cpu_stream_new() };
            assert!(
                !fresh.ctx.is_null(),
                "mlxrs::ops::linalg_full: mlx_default_cpu_stream_new returned NULL ctx — \
                 CPU stream initialization failed. Aborting."
            );
            slot.set(Some(fresh));
            fresh
        }
    })
}
pub(crate) fn reject_empty_matrix(a: &Array, op: &'static str) -> Result<()> {
let shape = a.shape();
if shape.len() >= 2 && (shape[shape.len() - 1] == 0 || shape[shape.len() - 2] == 0) {
return Err(Error::EmptyInput(EmptyInputPayload::new(op)));
}
Ok(())
}
/// Validates an explicit pair of matrix-reduction axes and rejects the input
/// when either selected axis has zero length.
///
/// Negative axes count from the end (`ax + ndim`). The axes are resolved in
/// order, so an out-of-range first axis is reported before the second one is
/// looked at.
///
/// # Errors
///
/// * `Error::OutOfRange` when an axis falls outside `[-ndim, ndim)`.
/// * `Error::OutOfRange` when both axes resolve to the same dimension.
/// * `Error::EmptyInput` (carrying `op`) when a selected dimension is zero.
fn reject_empty_matrix_axes(a: &Array, axes: [i32; 2], op: &'static str) -> Result<()> {
    let dims = a.shape();
    let rank = dims.len();

    // Maps a possibly-negative axis onto an index into `dims`.
    let resolve = |ax: i32| -> Result<usize> {
        let shifted = match ax {
            negative if negative < 0 => negative as isize + rank as isize,
            non_negative => non_negative as isize,
        };
        if shifted < 0 || (shifted as usize) >= rank {
            return Err(Error::OutOfRange(OutOfRangePayload::new(
                "linalg norm: matrix-norm reduction axis",
                "must be in range [-ndim, ndim)",
                ax.to_string(),
            )));
        }
        Ok(shifted as usize)
    };

    let [first_axis, second_axis] = axes;
    let first = resolve(first_axis)?;
    let second = resolve(second_axis)?;

    if first == second {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "linalg norm: matrix-norm reduction axes",
            "the two reduction axes must be distinct (a matrix reduction needs two different axes)",
            first.to_string(),
        )));
    }

    if dims[first] == 0 || dims[second] == 0 {
        return Err(Error::EmptyInput(EmptyInputPayload::new(op)));
    }
    Ok(())
}
/// Thin wrapper over the MLX C routine `mlx_linalg_inv`.
///
/// The call is issued on the stream returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn inv(a: &Array) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut inverse = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: `inverse.0` was allocated just above and `a.0` belongs to the
    // borrowed `Array`, which outlives this call.
    let status = unsafe { mlxrs_sys::mlx_linalg_inv(&mut inverse.0, a.0, stream) };
    check(status)?;
    Ok(inverse)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_tri_inv`.
///
/// `upper` is forwarded unchanged to the C call (presumably it selects the
/// upper triangle — confirm against the MLX documentation). The call runs on
/// the stream returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn tri_inv(a: &Array, upper: bool) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut inverse = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: `inverse.0` was allocated just above and `a.0` belongs to the
    // borrowed `Array`, which outlives this call.
    let status = unsafe { mlxrs_sys::mlx_linalg_tri_inv(&mut inverse.0, a.0, upper, stream) };
    check(status)?;
    Ok(inverse)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_pinv`.
///
/// Inputs whose trailing two dimensions include a zero extent are rejected
/// up front by `reject_empty_matrix`, before any MLX call is made. The C call
/// runs on the stream returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// * `Error::EmptyInput` for a zero-length row or column dimension.
/// * Otherwise, the error `check` produces for the status of the C call.
pub fn pinv(a: &Array) -> Result<Array> {
    const EMPTY_MSG: &str = "pinv: input matrix has a zero-length row or column dimension";
    reject_empty_matrix(a, EMPTY_MSG)?;

    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut pseudo_inverse = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: `pseudo_inverse.0` was allocated just above and `a.0` belongs
    // to the borrowed `Array`, which outlives this call.
    let status = unsafe { mlxrs_sys::mlx_linalg_pinv(&mut pseudo_inverse.0, a.0, stream) };
    check(status)?;
    Ok(pseudo_inverse)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_cholesky_inv`.
///
/// `upper` is forwarded unchanged to the C call (presumably it states which
/// triangular factor `a` holds — confirm against the MLX documentation). The
/// call runs on the stream returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn cholesky_inv(a: &Array, upper: bool) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut inverse = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: `inverse.0` was allocated just above and `a.0` belongs to the
    // borrowed `Array`, which outlives this call.
    let status =
        unsafe { mlxrs_sys::mlx_linalg_cholesky_inv(&mut inverse.0, a.0, upper, stream) };
    check(status)?;
    Ok(inverse)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_cholesky`.
///
/// `upper` is forwarded unchanged to the C call (presumably it selects an
/// upper- rather than lower-triangular factor — confirm against the MLX
/// documentation). The call runs on the stream returned by
/// `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn cholesky(a: &Array, upper: bool) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut factor = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: `factor.0` was allocated just above and `a.0` belongs to the
    // borrowed `Array`, which outlives this call.
    let status = unsafe { mlxrs_sys::mlx_linalg_cholesky(&mut factor.0, a.0, upper, stream) };
    check(status)?;
    Ok(factor)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_qr`.
///
/// Returns the two output arrays in the order the C call fills them: the
/// first output slot, then the second (`(q, r)`). The call runs on the stream
/// returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn qr(a: &Array) -> Result<(Array, Array)> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut q_factor = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: as above.
    let mut r_factor = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: both output handles were allocated just above and `a.0` belongs
    // to the borrowed `Array`, which outlives this call.
    let status =
        unsafe { mlxrs_sys::mlx_linalg_qr(&mut q_factor.0, &mut r_factor.0, a.0, stream) };
    check(status)?;
    Ok((q_factor, r_factor))
}
/// Thin wrapper over the MLX C routine `mlx_linalg_svd`.
///
/// Inputs whose trailing two dimensions include a zero extent are rejected by
/// `reject_empty_matrix` before any MLX call. `compute_uv` is forwarded
/// unchanged to the C call. The result is whatever `drain_vector` extracts
/// from the output vector, in the vector's order.
///
/// # Errors
///
/// * `Error::EmptyInput` for a zero-length row or column dimension.
/// * Any error from `check_vector_array_handle`, `check`, or `drain_vector`.
pub fn svd(a: &Array, compute_uv: bool) -> Result<Vec<Array>> {
    const EMPTY_MSG: &str = "svd: input matrix has a zero-length row or column dimension";
    reject_empty_matrix(a, EMPTY_MSG)?;

    let stream = linalg_cpu_stream();
    // SAFETY: the constructor takes no arguments; the handle is validated on
    // the next line before being used.
    let mut factors = unsafe { mlxrs_sys::mlx_vector_array_new() };
    check_vector_array_handle(factors)?;
    // The guard keeps a copy of the handle for the rest of the function
    // (presumably releasing the vector when dropped — confirm in `crate::ffi`).
    let _release_on_exit = VectorArrayGuard(factors);
    // SAFETY: `factors` passed validation above and `a.0` belongs to the
    // borrowed `Array`, which outlives this call.
    let status = unsafe { mlxrs_sys::mlx_linalg_svd(&mut factors, a.0, compute_uv, stream) };
    check(status)?;
    drain_vector(factors)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_lu`.
///
/// The result is whatever `drain_vector` extracts from the output vector, in
/// the vector's order. The call runs on the stream returned by
/// `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates any error from `check_vector_array_handle`, `check`, or
/// `drain_vector`.
pub fn lu(a: &Array) -> Result<Vec<Array>> {
    let stream = linalg_cpu_stream();
    // SAFETY: the constructor takes no arguments; the handle is validated on
    // the next line before being used.
    let mut factors = unsafe { mlxrs_sys::mlx_vector_array_new() };
    check_vector_array_handle(factors)?;
    // The guard keeps a copy of the handle for the rest of the function
    // (presumably releasing the vector when dropped — confirm in `crate::ffi`).
    let _release_on_exit = VectorArrayGuard(factors);
    // SAFETY: `factors` passed validation above and `a.0` belongs to the
    // borrowed `Array`, which outlives this call.
    let status = unsafe { mlxrs_sys::mlx_linalg_lu(&mut factors, a.0, stream) };
    check(status)?;
    drain_vector(factors)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_lu_factor`.
///
/// Returns the two output arrays in the order the C call fills them. Within
/// this file, `slogdet_impl` treats the first as the packed LU matrix and the
/// second as the pivot indices. The call runs on the stream returned by
/// `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn lu_factor(a: &Array) -> Result<(Array, Array)> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut first = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: as above.
    let mut second = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: both output handles were allocated just above and `a.0` belongs
    // to the borrowed `Array`, which outlives this call.
    let status =
        unsafe { mlxrs_sys::mlx_linalg_lu_factor(&mut first.0, &mut second.0, a.0, stream) };
    check(status)?;
    Ok((first, second))
}
/// Thin wrapper over the MLX C routine `mlx_linalg_solve`.
///
/// `a` and `b` are forwarded in that order. The call runs on the stream
/// returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn solve(a: &Array, b: &Array) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut solution = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: `solution.0` was allocated just above; `a.0` and `b.0` belong
    // to borrowed `Array`s that outlive this call.
    let status = unsafe { mlxrs_sys::mlx_linalg_solve(&mut solution.0, a.0, b.0, stream) };
    check(status)?;
    Ok(solution)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_solve_triangular`.
///
/// `a`, `b` and `upper` are forwarded unchanged (presumably `upper` marks `a`
/// as upper-triangular — confirm against the MLX documentation). The call
/// runs on the stream returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn solve_triangular(a: &Array, b: &Array, upper: bool) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut solution = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: `solution.0` was allocated just above; `a.0` and `b.0` belong
    // to borrowed `Array`s that outlive this call.
    let status = unsafe {
        mlxrs_sys::mlx_linalg_solve_triangular(&mut solution.0, a.0, b.0, upper, stream)
    };
    check(status)?;
    Ok(solution)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_eig`.
///
/// Returns `(values, vectors)`: the first and second output slots of the C
/// call, in that order. The call runs on the stream returned by
/// `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn eig(a: &Array) -> Result<(Array, Array)> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut values = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: as above.
    let mut vectors = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: both output handles were allocated just above and `a.0` belongs
    // to the borrowed `Array`, which outlives this call.
    let status =
        unsafe { mlxrs_sys::mlx_linalg_eig(&mut values.0, &mut vectors.0, a.0, stream) };
    check(status)?;
    Ok((values, vectors))
}
/// Triangle selector handed to [`eigh`] and [`eigvalsh`].
///
/// The default is [`Uplo::Lower`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum Uplo {
    /// Encoded as the C string `"U"`.
    Upper,
    /// Encoded as the C string `"L"`.
    #[default]
    Lower,
}

impl Uplo {
    /// Returns the NUL-terminated one-letter code (`"U"` or `"L"`) that is
    /// passed to the MLX C API for this variant.
    #[inline]
    pub const fn as_cstr(self) -> &'static CStr {
        match self {
            Self::Upper => c"U",
            Self::Lower => c"L",
        }
    }
}
/// Thin wrapper over the MLX C routine `mlx_linalg_eigh`.
///
/// `uplo` is passed to the C call as its one-letter C string (`"U"`/`"L"`).
/// Returns `(values, vectors)`: the first and second output slots of the C
/// call, in that order. The call runs on the stream returned by
/// `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn eigh(a: &Array, uplo: Uplo) -> Result<(Array, Array)> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut values = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: as above.
    let mut vectors = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // Points into a `'static` C string, so it stays valid across the call.
    let triangle = uplo.as_cstr().as_ptr();
    let stream = linalg_cpu_stream();
    // SAFETY: both output handles were allocated just above, `a.0` belongs to
    // the borrowed `Array`, and `triangle` is a valid NUL-terminated string.
    let status = unsafe {
        mlxrs_sys::mlx_linalg_eigh(&mut values.0, &mut vectors.0, a.0, triangle, stream)
    };
    check(status)?;
    Ok((values, vectors))
}
/// Thin wrapper over the MLX C routine `mlx_linalg_eigvals`.
///
/// The call runs on the stream returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn eigvals(a: &Array) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut values = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let stream = linalg_cpu_stream();
    // SAFETY: `values.0` was allocated just above and `a.0` belongs to the
    // borrowed `Array`, which outlives this call.
    let status = unsafe { mlxrs_sys::mlx_linalg_eigvals(&mut values.0, a.0, stream) };
    check(status)?;
    Ok(values)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_eigvalsh`.
///
/// `uplo` is passed to the C call as its one-letter C string (`"U"`/`"L"`).
/// The call runs on the stream returned by `linalg_cpu_stream`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn eigvalsh(a: &Array, uplo: Uplo) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut values = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // Points into a `'static` C string, so it stays valid across the call.
    let triangle = uplo.as_cstr().as_ptr();
    let stream = linalg_cpu_stream();
    // SAFETY: `values.0` was allocated just above, `a.0` belongs to the
    // borrowed `Array`, and `triangle` is a valid NUL-terminated string.
    let status =
        unsafe { mlxrs_sys::mlx_linalg_eigvalsh(&mut values.0, a.0, triangle, stream) };
    check(status)?;
    Ok(values)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_norm` (numeric `ord`).
///
/// `ord`, `axis` and `keepdims` are forwarded unchanged and the call is
/// issued on `default_stream()`.
///
/// For `ord == 2.0` or `ord == -2.0` a matrix reduction is checked up front
/// and rejected when either reduced axis has zero length. A reduction counts
/// as a matrix reduction when:
///
/// * `axis` names exactly two axes, or
/// * `axis` is empty and `a` has rank 2. MLX's `linalg.norm` reduces over
///   all axes when none are specified, which for a rank-2 input is a matrix
///   norm over `[0, 1]`. NOTE(review): this relies on `dim_ptr` passing "no
///   axes" to MLX for an empty slice — confirm.
///
/// # Errors
///
/// * `Error::OutOfRange` when an explicit matrix axis is out of range or the
///   two axes coincide (only checked for `ord` of 2 / -2).
/// * `Error::EmptyInput` when a reduced matrix axis has zero length (only
///   checked for `ord` of 2 / -2).
/// * Otherwise, the error `check` produces for the status of the C call.
pub fn norm(a: &Array, ord: f64, axis: &[i32], keepdims: bool) -> Result<Array> {
    if ord == 2.0 || ord == -2.0 {
        let matrix_axes: Option<[i32; 2]> = match axis {
            [first, second] => Some([*first, *second]),
            // No explicit axes on a rank-2 input: MLX reduces over both axes,
            // so the empty-matrix guard has to cover this case as well.
            [] if a.shape().len() == 2 => Some([0, 1]),
            _ => None,
        };
        if let Some(axes) = matrix_axes {
            reject_empty_matrix_axes(
                a,
                axes,
                "norm: matrix has a zero-length axis for the SVD-backed spectral order \
                 (ord = 2 / -2)",
            )?;
        }
    }

    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `out.0` was allocated just above, `a.0` belongs to the borrowed
    // `Array`, and the axis pointer/length pair both come from `axis`.
    let status = unsafe {
        mlxrs_sys::mlx_linalg_norm(
            &mut out.0,
            a.0,
            ord,
            dim_ptr(axis),
            axis.len(),
            keepdims,
            default_stream(),
        )
    };
    check(status)?;
    Ok(out)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_norm_matrix` (string
/// `ord`).
///
/// `ord`, `axis` and `keepdims` are forwarded unchanged and the call is
/// issued on `default_stream()`.
///
/// For `ord == "nuc"` the reduction is checked up front and rejected when
/// either reduced axis has zero length. The axes checked are:
///
/// * the two axes named by `axis`, or
/// * `[0, 1]` when `axis` is empty and `a` has rank 2. MLX's `linalg.norm`
///   reduces over all axes when none are specified. NOTE(review): this
///   relies on `dim_ptr` passing "no axes" to MLX for an empty slice —
///   confirm.
///
/// # Errors
///
/// * `Error::OutOfRange` when an explicit matrix axis is out of range or the
///   two axes coincide (only checked for `"nuc"`).
/// * `Error::EmptyInput` when a reduced matrix axis has zero length (only
///   checked for `"nuc"`).
/// * Otherwise, the error `check` produces for the status of the C call.
pub fn norm_matrix(a: &Array, ord: &CStr, axis: &[i32], keepdims: bool) -> Result<Array> {
    if ord.to_bytes() == b"nuc" {
        let matrix_axes: Option<[i32; 2]> = match axis {
            [first, second] => Some([*first, *second]),
            // No explicit axes on a rank-2 input: MLX reduces over both axes,
            // so the empty-matrix guard has to cover this case as well.
            [] if a.shape().len() == 2 => Some([0, 1]),
            _ => None,
        };
        if let Some(axes) = matrix_axes {
            reject_empty_matrix_axes(
                a,
                axes,
                "norm_matrix: matrix has a zero-length axis for the SVD-backed nuclear \
                 order (ord = \"nuc\")",
            )?;
        }
    }

    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut out = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `out.0` was allocated just above, `a.0` belongs to the borrowed
    // `Array`, `ord` is a valid NUL-terminated string for the duration of the
    // call, and the axis pointer/length pair both come from `axis`.
    let status = unsafe {
        mlxrs_sys::mlx_linalg_norm_matrix(
            &mut out.0,
            a.0,
            ord.as_ptr(),
            dim_ptr(axis),
            axis.len(),
            keepdims,
            default_stream(),
        )
    };
    check(status)?;
    Ok(out)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_norm_l2`.
///
/// `axis` and `keepdims` are forwarded unchanged; the call is issued on
/// `default_stream()`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn norm_l2(a: &Array, axis: &[i32], keepdims: bool) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut result = Array(unsafe { mlxrs_sys::mlx_array_new() });
    // SAFETY: `result.0` was allocated just above, `a.0` belongs to the
    // borrowed `Array`, and the axis pointer/length pair both come from `axis`.
    let status = unsafe {
        mlxrs_sys::mlx_linalg_norm_l2(
            &mut result.0,
            a.0,
            dim_ptr(axis),
            axis.len(),
            keepdims,
            default_stream(),
        )
    };
    check(status)?;
    Ok(result)
}
/// Thin wrapper over the MLX C routine `mlx_linalg_cross`.
///
/// `a`, `b` and `axis` (cast to `c_int`) are forwarded in that order; the
/// call is issued on `default_stream()`.
///
/// # Errors
///
/// Propagates the error `check` produces for the status code of the C call.
pub fn cross(a: &Array, b: &Array, axis: i32) -> Result<Array> {
    // SAFETY: the constructor takes no arguments and yields a fresh handle.
    let mut product = Array(unsafe { mlxrs_sys::mlx_array_new() });
    let c_axis = axis as c_int;
    // SAFETY: `product.0` was allocated just above; `a.0` and `b.0` belong to
    // borrowed `Array`s that outlive this call.
    let status = unsafe {
        mlxrs_sys::mlx_linalg_cross(&mut product.0, a.0, b.0, c_axis, default_stream())
    };
    check(status)?;
    Ok(product)
}
/// Shared input validation for `det` and `slogdet`.
///
/// Checks, in this order: the dtype is not `Complex64`, the rank is at least
/// 2, and the trailing two dimensions are equal. On success returns the dtype
/// to compute in: the input's own dtype when it is `F16`, `F32`, `F64` or
/// `BF16`, and `F32` for every other dtype.
///
/// `context` labels the complex-input and non-square errors.
///
/// # Errors
///
/// * Any error from `a.dtype()`.
/// * `Error::InvariantViolation` for complex input or a non-square matrix.
/// * `Error::RankMismatch` when the input has fewer than two dimensions.
fn validate_det(a: &Array, context: &'static str) -> Result<Dtype> {
    let input_dtype = a.dtype()?;
    if input_dtype == Dtype::Complex64 {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            context,
            "complex inputs are not supported",
        )));
    }

    let dims = a.shape();
    let rank = dims.len();
    if rank < 2 {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "linalg det/slogdet: input must be rank >= 2 (a square matrix or a batch of them)",
            rank as u32,
            dims,
        )));
    }

    let (rows, cols) = (dims[rank - 2], dims[rank - 1]);
    if rows != cols {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            context,
            "only defined for square matrices (the trailing two dimensions must be equal)",
        )));
    }

    // Floating-point inputs keep their dtype; everything else computes in F32.
    let is_float = matches!(
        input_dtype,
        Dtype::F16 | Dtype::F32 | Dtype::F64 | Dtype::BF16
    );
    Ok(if is_float { input_dtype } else { Dtype::F32 })
}
/// Closed-form determinant over the trailing two axes for small matrices.
///
/// * `n == 0`: a scalar `1` in `a`'s dtype, broadcast to the batch shape
///   (every dimension except the last two).
/// * `n == 1`: the single entry.
/// * `n == 2`: `a00*a11 - a01*a10`.
/// * any other `n`: cofactor expansion along the first row of the leading
///   3x3 entries. Both callers in this file only reach this function with
///   `n <= 3`.
///
/// Entries are extracted with unit-stride slices and then squeezed, so each
/// one keeps the batch shape of `a`.
fn det_raw_small(a: &Array, n: usize) -> Result<Array> {
    let rank = a.ndim();
    let (row_axis, col_axis) = (rank - 2, rank - 1);

    // Extracts entry (row, col) of every matrix in the batch.
    let entry = |row: i32, col: i32| -> Result<Array> {
        let mut lower = vec![0i32; rank];
        let mut upper: Vec<i32> = a.shape().iter().map(|&d| d as i32).collect();
        let unit_steps = vec![1i32; rank];
        lower[row_axis] = row;
        upper[row_axis] = row + 1;
        lower[col_axis] = col;
        upper[col_axis] = col + 1;
        let window = a.slice(&lower, &upper, &unit_steps)?;
        squeeze_axes(&window, &[row_axis as i32, col_axis as i32])
    };

    // 2x2 determinant `p*q - r*s`.
    let minor = |p: &Array, q: &Array, r: &Array, s: &Array| -> Result<Array> {
        p.multiply(q)?.subtract(&r.multiply(s)?)
    };

    match n {
        0 => {
            let dims = a.shape();
            let batch_dims = dims[..dims.len() - 2].to_vec();
            let unit = Array::full::<f32>(&[0i32; 0], 1.0)?.astype(a.dtype()?)?;
            unit.broadcast_to(&batch_dims)
        }
        1 => entry(0, 0),
        2 => {
            let main_diagonal = entry(0, 0)?.multiply(&entry(1, 1)?)?;
            let anti_diagonal = entry(0, 1)?.multiply(&entry(1, 0)?)?;
            main_diagonal.subtract(&anti_diagonal)
        }
        _ => {
            let (a00, a01, a02) = (entry(0, 0)?, entry(0, 1)?, entry(0, 2)?);
            let (a10, a11, a12) = (entry(1, 0)?, entry(1, 1)?, entry(1, 2)?);
            let (a20, a21, a22) = (entry(2, 0)?, entry(2, 1)?, entry(2, 2)?);
            let term0 = a00.multiply(&minor(&a11, &a22, &a12, &a21)?)?;
            let term1 = a01.multiply(&minor(&a10, &a22, &a12, &a20)?)?;
            let term2 = a02.multiply(&minor(&a10, &a21, &a11, &a20)?)?;
            term0.subtract(&term1)?.add(&term2)
        }
    }
}
/// Computes `(sign, log|det|)` for an already-validated, already-cast input.
///
/// For `n <= 3` the closed-form determinant from `det_raw_small` is split
/// into its sign and the log of its absolute value. Larger matrices go
/// through `lu_factor`:
///
/// * the permutation parity is the number of positions where the pivot
///   vector differs from `0..k`;
/// * the sign is `1 - 2 * ((parity + negative diagonal entries) mod 2)`;
/// * `log|det|` is the sum of `log|diag|`;
/// * when any diagonal entry equals zero, the result is `(0, -inf)`.
fn slogdet_impl(input: &Array, dtype: Dtype, n: usize) -> Result<(Array, Array)> {
    if n <= 3 {
        let closed_form = det_raw_small(input, n)?;
        let sign = closed_form.sign()?;
        let log_magnitude = closed_form.abs()?.log()?;
        return Ok((sign, log_magnitude));
    }

    let (packed_lu, pivots) = lu_factor(input)?;
    let lu_diagonal = packed_lu.diagonal(0, -2, -1)?;

    let dims = input.shape();
    let rank = dims.len();
    let pivot_count = dims[rank - 1].min(dims[rank - 2]);

    // Count pivot entries that differ from their own index, i.e. row swaps.
    let identity_order = Array::arange::<u32>(0, pivot_count as f64, 1)?;
    let swap_count =
        sum_axes(&pivots.not_equal(&identity_order)?, &[-1], false)?.astype(Dtype::I32)?;

    // Each negative diagonal entry flips the sign once more.
    let zero = Array::full::<f32>(&[0i32; 0], 0.0)?.astype(dtype)?;
    let negative_count = sum_axes(&lu_diagonal.less(&zero)?, &[-1], false)?.astype(Dtype::I32)?;

    let one = Array::full::<i32>(&[0i32; 0], 1)?;
    let two = Array::full::<i32>(&[0i32; 0], 2)?;
    let flips = swap_count.add(&negative_count)?;
    let odd_flips = flips.remainder(&two)?;
    let sign = one.subtract(&two.multiply(&odd_flips)?)?.astype(dtype)?;

    let log_magnitude = sum_axes(&lu_diagonal.abs()?.log()?, &[-1], false)?;

    // A zero on the diagonal forces the result to (0, -inf).
    let singular = any_axes(&lu_diagonal.equal(&zero)?, &[-1], false)?;
    let neg_inf = Array::full::<f32>(&[0i32; 0], f32::NEG_INFINITY)?.astype(dtype)?;

    let final_sign = singular.select(&zero, &sign)?;
    let final_log = singular.select(&neg_inf, &log_magnitude)?;
    Ok((final_sign, final_log))
}
/// Determinant over the trailing two axes of `a`.
///
/// The input is validated by `validate_det` and cast to the dtype it
/// returns. Matrices of order 3 or less use the closed form in
/// `det_raw_small`; larger ones are reconstructed from `slogdet_impl` as
/// `sign * exp(log|det|)`.
///
/// # Errors
///
/// Propagates validation errors from `validate_det` and any error raised by
/// the underlying array operations.
pub fn det(a: &Array) -> Result<Array> {
    let compute_dtype = validate_det(a, "linalg::det")?;
    let matrix = a.astype(compute_dtype)?;
    let dims = matrix.shape();
    let order = dims[dims.len() - 1];

    if order > 3 {
        let (sign, log_magnitude) = slogdet_impl(&matrix, compute_dtype, order)?;
        sign.multiply(&log_magnitude.exp()?)
    } else {
        det_raw_small(&matrix, order)
    }
}
/// Sign and log-absolute-value of the determinant over the trailing two axes
/// of `a`, returned as `(sign, log|det|)`.
///
/// The input is validated by `validate_det`, cast to the dtype it returns,
/// and then handed to `slogdet_impl` together with the size of the last
/// dimension.
///
/// # Errors
///
/// Propagates validation errors from `validate_det` and any error raised by
/// the underlying array operations.
pub fn slogdet(a: &Array) -> Result<(Array, Array)> {
    let compute_dtype = validate_det(a, "linalg::slogdet")?;
    let matrix = a.astype(compute_dtype)?;
    let dims = matrix.shape();
    let order = dims[dims.len() - 1];
    slogdet_impl(&matrix, compute_dtype, order)
}