mdarray_linalg/
utils.rs

//! Utility functions for matrix printing, shape retrieval, identity
//! generation, Kronecker products, traces, transpose operations, and more.
//!
//! These functions were needed to implement this crate. They are exposed
//! because they can be generally useful, but they are not meant to be a
//! complete collection of linear algebra utilities at this time.
7
8use mdarray::{DSlice, DTensor, Layout, Shape, Slice, tensor};
9use num_complex::ComplexFloat;
10use num_traits::{One, Zero};
11
12/// Displays a numeric `mdarray` in a human-readable format (NumPy-style)
13pub fn pretty_print<T: ComplexFloat + std::fmt::Display>(mat: &DTensor<T, 2>)
14where
15    <T as num_complex::ComplexFloat>::Real: std::fmt::Display,
16{
17    let shape = mat.shape();
18    for i in 0..shape.0 {
19        for j in 0..shape.1 {
20            let v = mat[[i, j]];
21            print!("{:>10.4} {:+.4}i  ", v.re(), v.im(),);
22        }
23        println!();
24    }
25    println!();
26}
27
/// Converts `x` to `i32`, panicking instead of silently truncating.
///
/// Used throughout this module to turn `usize` dimensions and strides into
/// 32-bit integers.
///
/// # Panics
///
/// Panics with `"dimension must fit into i32"` (followed by the conversion
/// error) when `x` is not representable as an `i32`.
pub fn into_i32<T>(x: T) -> i32
where
    T: TryInto<i32>,
    <T as TryInto<i32>>::Error: std::fmt::Debug,
{
    let converted: Result<i32, _> = x.try_into();
    converted.expect("dimension must fit into i32")
}
36
/// Returns the `(rows, cols)` dimensions, as `i32`, of one or more matrices.
///
/// `get_dims!(a, b, c)` expands to `((ma, na), (mb, nb), (mc, nc))`. With a
/// single argument the result is just that pair, because the surrounding
/// parentheses then act as grouping rather than forming a 1-tuple.
///
/// Each argument must provide `.shape()` whose `.0` and `.1` fields are the
/// row and column counts. Both counts go through `into_i32`, so the expansion
/// panics if a dimension does not fit in an `i32`. `into_i32` is referenced
/// unqualified, so it must be in scope wherever the macro is invoked.
#[macro_export]
macro_rules! get_dims {
    ( $( $matrix:expr ),+ ) => {
        (
            $(
                {
                    let dims = $matrix.shape();
                    let rows = into_i32(dims.0);
                    let cols = into_i32(dims.1);
                    (rows, cols)
                }
            ),*
        )
    };
}
52
53/// Make sure that matrix shapes are compatible with `C = A * B`, and
54/// return the dimensions `(m, n, k)` safely cast to `i32`, where `C` is `(m
55/// x n)`, and `k` is the common dimension of `A` and `B`
56pub fn dims3(
57    a_shape: &(usize, usize),
58    b_shape: &(usize, usize),
59    c_shape: &(usize, usize),
60) -> (i32, i32, i32) {
61    let (m, k) = *a_shape;
62    let (k2, n) = *b_shape;
63    let (m2, n2) = *c_shape;
64
65    assert!(m == m2, "a and c must agree in number of rows");
66    assert!(n == n2, "b and c must agree in number of columns");
67    assert!(
68        k == k2,
69        "a's number of columns must be equal to b's number of rows"
70    );
71
72    (into_i32(m), into_i32(n), into_i32(k))
73}
74
75/// Make sure that matrix shapes are compatible with `A * B`, and return
76/// the dimensions `(m, n)` safely cast to `i32`
77pub fn dims2(a_shape: &(usize, usize), b_shape: &(usize, usize)) -> (i32, i32) {
78    let (m, k) = *a_shape;
79    let (k2, n) = *b_shape;
80
81    assert!(
82        k == k2,
83        "a's number of columns must be equal to b's number of rows"
84    );
85
86    (into_i32(m), into_i32(n))
87}
88
/// Picks the memory-order flag and leading stride for a 2-D operand that is
/// contiguous along at least one axis.
///
/// The result depends on the strides of `$x`:
/// - If `$x.stride(1) == 1` (unit stride along columns), it evaluates to
///   `($same_order, into_i32($x.stride(0)))`.
/// - Otherwise `$x.stride(0)` must be `1`, and it evaluates to
///   `($other_order, into_i32($x.stride(1)))`.
///
/// Panics with "`<x> must be contiguous in one dimension`" when neither
/// stride is `1`. `into_i32` is referenced unqualified, so it must be in scope
/// at the call site.
#[macro_export]
macro_rules! trans_stride {
    ($x:expr, $same_order:expr, $other_order:expr) => {{
        if $x.stride(1) != 1 {
            assert!($x.stride(0) == 1, stringify!($x must be contiguous in one dimension));
            ($other_order, into_i32($x.stride(1)))
        } else {
            ($same_order, into_i32($x.stride(0)))
        }
    }};
}
104
105/// Transposes a matrix in-place. Dimensions stay the same, only the memory ordering changes.
106/// - For square matrices: swaps elements across the main diagonal.
107/// - For rectangular matrices: reshuffles data in a temporary buffer so that the
108///   same `(rows, cols)` slice now represents the transposed layout.
109pub fn transpose_in_place<T, L>(c: &mut DSlice<T, 2, L>)
110where
111    T: ComplexFloat + Default,
112    L: Layout,
113{
114    let (m, n) = *c.shape();
115
116    if n == m {
117        for i in 0..m {
118            for j in (i + 1)..n {
119                c.swap(i * n + j, j * n + i);
120            }
121        }
122    } else {
123        let mut result = tensor![[T::default(); m]; n];
124        for j in 0..n {
125            for i in 0..m {
126                result[j * m + i] = c[i * n + j];
127            }
128        }
129        for j in 0..n {
130            for i in 0..m {
131                c[j * m + i] = result[j * m + i];
132            }
133        }
134    }
135}
136
137/// Convert pivot indices to permutation matrix
138pub fn ipiv_to_perm_mat<T: ComplexFloat>(ipiv: &[i32], m: usize) -> DTensor<T, 2> {
139    let mut p = tensor![[T::zero(); m]; m];
140
141    for i in 0..m {
142        p[[i, i]] = T::one();
143    }
144
145    // Apply row swaps according to LAPACK's ipiv convention
146    for i in 0..ipiv.len() {
147        let pivot_row = (ipiv[i] - 1) as usize; // LAPACK uses 1-based indexing
148        if pivot_row != i {
149            for j in 0..m {
150                let temp = p[[i, j]];
151                p[[i, j]] = p[[pivot_row, j]];
152                p[[pivot_row, j]] = temp;
153            }
154        }
155    }
156
157    p
158}
159
160/// Given an input matrix of shape `(m × n)`, this function creates and returns
161/// a new matrix of shape `(n × m)`, where each element at position `(i, j)` in the
162/// original is moved to position `(j, i)` in the result.
163pub fn to_col_major<T, L>(c: &DSlice<T, 2, L>) -> DTensor<T, 2>
164where
165    T: ComplexFloat + Default + Clone,
166    L: Layout,
167{
168    let (m, n) = *c.shape();
169    let mut result = DTensor::<T, 2>::zeros([n, m]);
170
171    for i in 0..m {
172        for j in 0..n {
173            result[[j, i]] = c[[i, j]];
174        }
175    }
176
177    result
178}
179
180/// Computes the trace of a square matrix (sum of diagonal elements).
181/// # Examples
182/// ```
183/// use mdarray::tensor;
184/// use mdarray_linalg::trace;
185///
186/// let a = tensor![[1., 2., 3.],
187///                 [4., 5., 6.],
188///                 [7., 8., 9.]];
189///
190/// let tr = trace(&a);
191/// assert_eq!(tr, 15.0);
192/// ```
193pub fn trace<T, L>(a: &DSlice<T, 2, L>) -> T
194where
195    T: ComplexFloat + std::ops::Add<Output = T> + Copy,
196    L: Layout,
197{
198    let (m, n) = *a.shape();
199    assert_eq!(m, n, "trace is only defined for square matrices");
200
201    let mut tr = T::zero();
202    for i in 0..n {
203        tr = tr + a[[i, i]];
204    }
205    tr
206}
207
208/// Creates an identity matrix of size `n x n`.
209/// # Examples
210/// ```
211/// use mdarray::tensor;
212/// use mdarray_linalg::identity;
213///
214/// let i3 = identity::<f64>(3);
215/// assert_eq!(i3, tensor![[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]]);
216/// ```
217pub fn identity<T: Zero + One>(n: usize) -> DTensor<T, 2> {
218    DTensor::<T, 2>::from_fn([n, n], |i| if i[0] == i[1] { T::one() } else { T::zero() })
219}
220
221/// Creates a diagonal matrix of size `n x n` with ones on a specified diagonal.
222///
223/// The diagonal can be shifted using `k`:
224/// - `k = 0` → main diagonal (default, standard identity)
225/// - `k > 0` → k-th diagonal above the main one
226/// - `k < 0` → k-th diagonal below the main one
227/// # Examples
228/// ```
229/// use mdarray::tensor;
230/// use mdarray_linalg::identity_k;
231///
232/// let i3 = identity_k::<f64>(3, 1);
233/// assert_eq!(i3, tensor![[0.,1.,0.],[0.,0.,1.],[0.,0.,0.]]);
234/// ```
235pub fn identity_k<T: Zero + One>(n: usize, k: isize) -> DTensor<T, 2> {
236    DTensor::<T, 2>::from_fn([n, n], |i| {
237        if (i[1] as isize - i[0] as isize) == k {
238            T::one()
239        } else {
240            T::zero()
241        }
242    })
243}
244
245/// Computes the Kronecker product of two 2D tensors.
246///
247/// The Kronecker product of matrices `A (m×n)` and `B (p×q)` is defined as the
248/// block matrix of size `(m*p) × (n*q)` where each element `a[i, j]` of `A`
249/// multiplies the entire matrix `B`.
250///
251/// # Examples
252/// ```
253/// use mdarray::tensor;
254/// use mdarray_linalg::kron;
255///
256/// let a = tensor![[1., 2.],
257///                 [3., 4.]];
258///
259/// let b = tensor![[0., 5.],
260///                 [6., 7.]];
261///
262/// let k = kron(&a, &b);
263///
264/// assert_eq!(k, tensor![
265///     [ 0.,  5.,  0., 10.],
266///     [ 6.,  7., 12., 14.],
267///     [ 0., 15.,  0., 20.],
268///     [18., 21., 24., 28.]
269/// ]);
270/// ```
271pub fn kron<T, La, Lb>(a: &DSlice<T, 2, La>, b: &DSlice<T, 2, Lb>) -> DTensor<T, 2>
272where
273    T: ComplexFloat + std::ops::Mul<Output = T> + Copy,
274    La: Layout,
275    Lb: Layout,
276{
277    let (ma, na) = *a.shape();
278    let (mb, nb) = *b.shape();
279
280    let out_shape = [ma * mb, na * nb];
281
282    DTensor::<T, 2>::from_fn(out_shape, |idx| {
283        let i = idx[0];
284        let j = idx[1];
285
286        let ai = i / mb;
287        let bi = i % mb;
288        let aj = j / nb;
289        let bj = j % nb;
290
291        a[[ai, aj]] * b[[bi, bj]]
292    })
293}
294
295/// Converts a flat index to multidimensional coordinates.
296///
297/// # Examples
298///
299/// ```
300/// use mdarray::DTensor;
301/// use mdarray_linalg::unravel_index;
302///
303/// let x = DTensor::<usize, 2>::from_fn([2,3], |i| i[0] + i[1]);
304///
305/// assert_eq!(unravel_index(&x, 0), vec![0, 0]);
306/// assert_eq!(unravel_index(&x, 4), vec![1, 1]);
307/// assert_eq!(unravel_index(&x, 5), vec![1, 2]);
308/// ```
309///
310/// # Panics
311///
312/// Panics if `flat` is out of bounds (>= `x.len()`).
313pub fn unravel_index<T, S: Shape, L: Layout>(x: &Slice<T, S, L>, mut flat: usize) -> Vec<usize> {
314    let rank = x.rank();
315
316    assert!(
317        flat < x.len(),
318        "flat index out of bounds: {} >= {}",
319        flat,
320        x.len()
321    );
322
323    let mut coords = vec![0usize; rank];
324
325    for i in (0..rank).rev() {
326        let dim = x.shape().dim(i);
327        coords[i] = flat % dim;
328        flat /= dim;
329    }
330
331    coords
332}