mdarray_linalg/utils.rs
1//! Utility functions for matrix printing, shape retrieval, identity
2//! generation, Kronecker product, trace, transpose operations, ...
3//!
4//! These functions were necessary for implementing this crate. They are
5//! exposed because they can be generally useful, but this is not meant to be
6//! a complete collection of linear algebra utilities at this time.
7
8use mdarray::{DSlice, DTensor, Layout, Shape, Slice, tensor};
9use num_complex::ComplexFloat;
10use num_traits::{One, Zero};
11
12/// Displays a numeric `mdarray` in a human-readable format (NumPy-style)
13pub fn pretty_print<T: ComplexFloat + std::fmt::Display>(mat: &DTensor<T, 2>)
14where
15 <T as num_complex::ComplexFloat>::Real: std::fmt::Display,
16{
17 let shape = mat.shape();
18 for i in 0..shape.0 {
19 for j in 0..shape.1 {
20 let v = mat[[i, j]];
21 print!("{:>10.4} {:+.4}i ", v.re(), v.im(),);
22 }
23 println!();
24 }
25 println!();
26}
27
/// Converts `x` into an `i32`, panicking if the value is out of range.
///
/// Used to turn `usize` dimensions and strides into the `i32` values returned
/// by helpers such as `dims2` / `dims3`.
///
/// # Panics
///
/// Panics with `"dimension must fit into i32"` (followed by the conversion
/// error) if `x` cannot be represented as an `i32`.
pub fn into_i32<T>(x: T) -> i32
where
    T: TryInto<i32>,
    <T as TryInto<i32>>::Error: std::fmt::Debug,
{
    let converted: Result<i32, _> = x.try_into();
    converted.expect("dimension must fit into i32")
}
36
/// Returns the `(rows, cols)` dimensions of an arbitrary number of matrices,
/// each converted to `i32`: `get_dims!(a, b, c)` evaluates to
/// `((ma, na), (mb, nb), (mc, nc))`.
///
/// Each argument must provide a `.shape()` whose `.0` / `.1` fields convert
/// into `i32`. With a single argument the result is the `(rows, cols)` pair
/// itself (a parenthesised expression, not a 1-tuple). A trailing comma is
/// accepted.
///
/// # Panics
///
/// Panics with `"dimension must fit into i32"` if a dimension does not fit,
/// exactly like `into_i32`.
#[macro_export]
macro_rules! get_dims {
    ( $( $matrix:expr ),+ $(,)? ) => {
        (
            $(
                {
                    let shape = $matrix.shape();
                    // Fully-qualified conversion: paths inside an exported
                    // `macro_rules!` resolve at the call site, so relying on
                    // `into_i32` being imported there breaks other callers.
                    (
                        ::core::convert::TryInto::<i32>::try_into(shape.0)
                            .expect("dimension must fit into i32"),
                        ::core::convert::TryInto::<i32>::try_into(shape.1)
                            .expect("dimension must fit into i32"),
                    )
                }
            ),*
        )
    };
}
52
53/// Make sure that matrix shapes are compatible with `C = A * B`, and
54/// return the dimensions `(m, n, k)` safely cast to `i32`, where `C` is `(m
55/// x n)`, and `k` is the common dimension of `A` and `B`
56pub fn dims3(
57 a_shape: &(usize, usize),
58 b_shape: &(usize, usize),
59 c_shape: &(usize, usize),
60) -> (i32, i32, i32) {
61 let (m, k) = *a_shape;
62 let (k2, n) = *b_shape;
63 let (m2, n2) = *c_shape;
64
65 assert!(m == m2, "a and c must agree in number of rows");
66 assert!(n == n2, "b and c must agree in number of columns");
67 assert!(
68 k == k2,
69 "a's number of columns must be equal to b's number of rows"
70 );
71
72 (into_i32(m), into_i32(n), into_i32(k))
73}
74
75/// Make sure that matrix shapes are compatible with `A * B`, and return
76/// the dimensions `(m, n)` safely cast to `i32`
77pub fn dims2(a_shape: &(usize, usize), b_shape: &(usize, usize)) -> (i32, i32) {
78 let (m, k) = *a_shape;
79 let (k2, n) = *b_shape;
80
81 assert!(
82 k == k2,
83 "a's number of columns must be equal to b's number of rows"
84 );
85
86 (into_i32(m), into_i32(n))
87}
88
/// Selects the memory-order flag and stride for a 2-D matrix that is
/// contiguous along one axis.
///
/// Evaluates `$x` once, then:
/// - if `$x.stride(1) == 1` (last axis unit-stride), returns
///   `($same_order, stride(0))`;
/// - otherwise returns `($other_order, stride(1))`, after asserting that
///   `$x.stride(0) == 1`.
///
/// The stride is converted to `i32`.
///
/// # Panics
///
/// Panics with `"<expr> must be contiguous in one dimension"` if neither axis
/// has unit stride. Panics with `"dimension must fit into i32"` if the stride
/// does not fit into `i32`.
#[macro_export]
macro_rules! trans_stride {
    ($x:expr, $same_order:expr, $other_order:expr) => {{
        // Bind once so `$x` is not evaluated repeatedly.
        let matrix = &$x;
        // Fully-qualified conversion: paths inside an exported `macro_rules!`
        // resolve at the call site, where `into_i32` may not be imported.
        if matrix.stride(1) == 1 {
            (
                $same_order,
                ::core::convert::TryInto::<i32>::try_into(matrix.stride(0))
                    .expect("dimension must fit into i32"),
            )
        } else {
            assert!(
                matrix.stride(0) == 1,
                "{} must be contiguous in one dimension",
                stringify!($x)
            );
            (
                $other_order,
                ::core::convert::TryInto::<i32>::try_into(matrix.stride(1))
                    .expect("dimension must fit into i32"),
            )
        }
    }};
}
104
105/// Transposes a matrix in-place. Dimensions stay the same, only the memory ordering changes.
106/// - For square matrices: swaps elements across the main diagonal.
107/// - For rectangular matrices: reshuffles data in a temporary buffer so that the
108/// same `(rows, cols)` slice now represents the transposed layout.
109pub fn transpose_in_place<T, L>(c: &mut DSlice<T, 2, L>)
110where
111 T: ComplexFloat + Default,
112 L: Layout,
113{
114 let (m, n) = *c.shape();
115
116 if n == m {
117 for i in 0..m {
118 for j in (i + 1)..n {
119 c.swap(i * n + j, j * n + i);
120 }
121 }
122 } else {
123 let mut result = tensor![[T::default(); m]; n];
124 for j in 0..n {
125 for i in 0..m {
126 result[j * m + i] = c[i * n + j];
127 }
128 }
129 for j in 0..n {
130 for i in 0..m {
131 c[j * m + i] = result[j * m + i];
132 }
133 }
134 }
135}
136
137/// Convert pivot indices to permutation matrix
138pub fn ipiv_to_perm_mat<T: ComplexFloat>(ipiv: &[i32], m: usize) -> DTensor<T, 2> {
139 let mut p = tensor![[T::zero(); m]; m];
140
141 for i in 0..m {
142 p[[i, i]] = T::one();
143 }
144
145 // Apply row swaps according to LAPACK's ipiv convention
146 for i in 0..ipiv.len() {
147 let pivot_row = (ipiv[i] - 1) as usize; // LAPACK uses 1-based indexing
148 if pivot_row != i {
149 for j in 0..m {
150 let temp = p[[i, j]];
151 p[[i, j]] = p[[pivot_row, j]];
152 p[[pivot_row, j]] = temp;
153 }
154 }
155 }
156
157 p
158}
159
160/// Given an input matrix of shape `(m × n)`, this function creates and returns
161/// a new matrix of shape `(n × m)`, where each element at position `(i, j)` in the
162/// original is moved to position `(j, i)` in the result.
163pub fn to_col_major<T, L>(c: &DSlice<T, 2, L>) -> DTensor<T, 2>
164where
165 T: ComplexFloat + Default + Clone,
166 L: Layout,
167{
168 let (m, n) = *c.shape();
169 let mut result = DTensor::<T, 2>::zeros([n, m]);
170
171 for i in 0..m {
172 for j in 0..n {
173 result[[j, i]] = c[[i, j]];
174 }
175 }
176
177 result
178}
179
180/// Computes the trace of a square matrix (sum of diagonal elements).
181/// # Examples
182/// ```
183/// use mdarray::tensor;
184/// use mdarray_linalg::trace;
185///
186/// let a = tensor![[1., 2., 3.],
187/// [4., 5., 6.],
188/// [7., 8., 9.]];
189///
190/// let tr = trace(&a);
191/// assert_eq!(tr, 15.0);
192/// ```
193pub fn trace<T, L>(a: &DSlice<T, 2, L>) -> T
194where
195 T: ComplexFloat + std::ops::Add<Output = T> + Copy,
196 L: Layout,
197{
198 let (m, n) = *a.shape();
199 assert_eq!(m, n, "trace is only defined for square matrices");
200
201 let mut tr = T::zero();
202 for i in 0..n {
203 tr = tr + a[[i, i]];
204 }
205 tr
206}
207
208/// Creates an identity matrix of size `n x n`.
209/// # Examples
210/// ```
211/// use mdarray::tensor;
212/// use mdarray_linalg::identity;
213///
214/// let i3 = identity::<f64>(3);
215/// assert_eq!(i3, tensor![[1.,0.,0.],[0.,1.,0.],[0.,0.,1.]]);
216/// ```
217pub fn identity<T: Zero + One>(n: usize) -> DTensor<T, 2> {
218 DTensor::<T, 2>::from_fn([n, n], |i| if i[0] == i[1] { T::one() } else { T::zero() })
219}
220
221/// Creates a diagonal matrix of size `n x n` with ones on a specified diagonal.
222///
223/// The diagonal can be shifted using `k`:
224/// - `k = 0` → main diagonal (default, standard identity)
225/// - `k > 0` → k-th diagonal above the main one
226/// - `k < 0` → k-th diagonal below the main one
227/// # Examples
228/// ```
229/// use mdarray::tensor;
230/// use mdarray_linalg::identity_k;
231///
232/// let i3 = identity_k::<f64>(3, 1);
233/// assert_eq!(i3, tensor![[0.,1.,0.],[0.,0.,1.],[0.,0.,0.]]);
234/// ```
235pub fn identity_k<T: Zero + One>(n: usize, k: isize) -> DTensor<T, 2> {
236 DTensor::<T, 2>::from_fn([n, n], |i| {
237 if (i[1] as isize - i[0] as isize) == k {
238 T::one()
239 } else {
240 T::zero()
241 }
242 })
243}
244
245/// Computes the Kronecker product of two 2D tensors.
246///
247/// The Kronecker product of matrices `A (m×n)` and `B (p×q)` is defined as the
248/// block matrix of size `(m*p) × (n*q)` where each element `a[i, j]` of `A`
249/// multiplies the entire matrix `B`.
250///
251/// # Examples
252/// ```
253/// use mdarray::tensor;
254/// use mdarray_linalg::kron;
255///
256/// let a = tensor![[1., 2.],
257/// [3., 4.]];
258///
259/// let b = tensor![[0., 5.],
260/// [6., 7.]];
261///
262/// let k = kron(&a, &b);
263///
264/// assert_eq!(k, tensor![
265/// [ 0., 5., 0., 10.],
266/// [ 6., 7., 12., 14.],
267/// [ 0., 15., 0., 20.],
268/// [18., 21., 24., 28.]
269/// ]);
270/// ```
271pub fn kron<T, La, Lb>(a: &DSlice<T, 2, La>, b: &DSlice<T, 2, Lb>) -> DTensor<T, 2>
272where
273 T: ComplexFloat + std::ops::Mul<Output = T> + Copy,
274 La: Layout,
275 Lb: Layout,
276{
277 let (ma, na) = *a.shape();
278 let (mb, nb) = *b.shape();
279
280 let out_shape = [ma * mb, na * nb];
281
282 DTensor::<T, 2>::from_fn(out_shape, |idx| {
283 let i = idx[0];
284 let j = idx[1];
285
286 let ai = i / mb;
287 let bi = i % mb;
288 let aj = j / nb;
289 let bj = j % nb;
290
291 a[[ai, aj]] * b[[bi, bj]]
292 })
293}
294
295/// Converts a flat index to multidimensional coordinates.
296///
297/// # Examples
298///
299/// ```
300/// use mdarray::DTensor;
301/// use mdarray_linalg::unravel_index;
302///
303/// let x = DTensor::<usize, 2>::from_fn([2,3], |i| i[0] + i[1]);
304///
305/// assert_eq!(unravel_index(&x, 0), vec![0, 0]);
306/// assert_eq!(unravel_index(&x, 4), vec![1, 1]);
307/// assert_eq!(unravel_index(&x, 5), vec![1, 2]);
308/// ```
309///
310/// # Panics
311///
312/// Panics if `flat` is out of bounds (>= `x.len()`).
313pub fn unravel_index<T, S: Shape, L: Layout>(x: &Slice<T, S, L>, mut flat: usize) -> Vec<usize> {
314 let rank = x.rank();
315
316 assert!(
317 flat < x.len(),
318 "flat index out of bounds: {} >= {}",
319 flat,
320 x.len()
321 );
322
323 let mut coords = vec![0usize; rank];
324
325 for i in (0..rank).rev() {
326 let dim = x.shape().dim(i);
327 coords[i] = flat % dim;
328 flat /= dim;
329 }
330
331 coords
332}