Skip to main content

diffsol/matrix/
mod.rs

1use std::fmt::Debug;
2use std::ops::{Add, AddAssign, Mul, MulAssign, Sub, SubAssign};
3
4use crate::error::DiffsolError;
5use crate::scalar::Scale;
6use crate::vector::VectorHost;
7use crate::{Context, IndexType, Scalar, Vector, VectorIndex};
8
9use extract_block::combine;
10use num_traits::{One, Zero};
11use sparsity::{Dense, MatrixSparsity, MatrixSparsityRef};
12
13#[cfg(feature = "cuda")]
14pub mod cuda;
15
16#[cfg(feature = "nalgebra")]
17pub mod dense_nalgebra_serial;
18
19#[cfg(feature = "faer")]
20pub mod dense_faer_serial;
21
22#[cfg(feature = "faer")]
23pub mod sparse_faer;
24
25pub mod default_solver;
26pub mod extract_block;
27pub mod sparsity;
28
29#[macro_use]
30mod utils;
31
32/// Common interface for matrix types, providing access to scalar type, context, and dimensions.
33pub trait MatrixCommon: Sized + Debug {
34    type V: Vector<T = Self::T, C = Self::C, Index: VectorIndex<C = Self::C>>;
35    type T: Scalar;
36    type C: Context;
37    type Inner;
38
39    /// Get the number of rows in this matrix.
40    fn nrows(&self) -> IndexType;
41    /// Get the number of columns in this matrix.
42    fn ncols(&self) -> IndexType;
43    /// Get a reference to the inner representation of the matrix.
44    fn inner(&self) -> &Self::Inner;
45}
46
47impl<M> MatrixCommon for &M
48where
49    M: MatrixCommon,
50{
51    type T = M::T;
52    type V = M::V;
53    type C = M::C;
54    type Inner = M::Inner;
55
56    fn nrows(&self) -> IndexType {
57        M::nrows(*self)
58    }
59    fn ncols(&self) -> IndexType {
60        M::ncols(*self)
61    }
62    fn inner(&self) -> &Self::Inner {
63        M::inner(*self)
64    }
65}
66
67impl<M> MatrixCommon for &mut M
68where
69    M: MatrixCommon,
70{
71    type T = M::T;
72    type V = M::V;
73    type C = M::C;
74    type Inner = M::Inner;
75
76    fn ncols(&self) -> IndexType {
77        M::ncols(*self)
78    }
79    fn nrows(&self) -> IndexType {
80        M::nrows(*self)
81    }
82    fn inner(&self) -> &Self::Inner {
83        M::inner(*self)
84    }
85}
86
87/// Operations on matrices by value (addition and subtraction).
88///
89/// This trait defines matrix addition and subtraction when both operands are owned or references.
90pub trait MatrixOpsByValue<Rhs = Self, Output = Self>:
91    MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
92{
93}
94
95impl<M, Rhs, Output> MatrixOpsByValue<Rhs, Output> for M where
96    M: MatrixCommon + Add<Rhs, Output = Output> + Sub<Rhs, Output = Output>
97{
98}
99
100/// In-place operations on matrices (addition and subtraction).
101///
102/// This trait defines in-place matrix addition and subtraction (self += rhs, self -= rhs).
103pub trait MatrixMutOpsByValue<Rhs = Self>: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
104
105impl<M, Rhs> MatrixMutOpsByValue<Rhs> for M where M: MatrixCommon + AddAssign<Rhs> + SubAssign<Rhs> {}
106
107/// A trait allowing for references to implement matrix operations
108pub trait MatrixRef<M: MatrixCommon>: Mul<Scale<M::T>, Output = M> {}
109impl<RefT, M: MatrixCommon> MatrixRef<M> for RefT where RefT: Mul<Scale<M::T>, Output = M> {}
110
111/// A mutable view of a dense matrix, supporting in-place operations and modifications.
112///
113/// This trait represents a temporary mutable reference to a matrix's data, allowing in-place
114/// arithmetic operations (+=, -=, *=) and matrix-matrix multiplication. Mutable views can be
115/// created via the `columns_mut()` or `column_mut()` methods on a `DenseMatrix`.
116pub trait MatrixViewMut<'a>:
117    for<'b> MatrixMutOpsByValue<&'b Self>
118    + for<'b> MatrixMutOpsByValue<&'b Self::View>
119    + MulAssign<Scale<Self::T>>
120{
121    type Owned;
122    type View;
123    /// Convert this mutable view into an owned matrix, cloning the data if necessary.
124    fn into_owned(self) -> Self::Owned;
125    /// Perform matrix-matrix multiplication with owned matrices: self = alpha * a * b + beta * self
126    fn gemm_oo(&mut self, alpha: Self::T, a: &Self::Owned, b: &Self::Owned, beta: Self::T);
127    /// Perform matrix-matrix multiplication with a view and owned matrix: self = alpha * a * b + beta * self
128    fn gemm_vo(&mut self, alpha: Self::T, a: &Self::View, b: &Self::Owned, beta: Self::T);
129}
130
131/// A borrowed immutable view of a dense matrix, supporting read-only arithmetic operations.
132///
133/// This trait represents a temporary immutable reference to a matrix's data, allowing read-only
134/// operations like addition, subtraction, scalar multiplication, and matrix-vector multiplication.
135/// Matrix views can be created via the `columns()` methods on a `DenseMatrix`.
136pub trait MatrixView<'a>:
137    for<'b> MatrixOpsByValue<&'b Self::Owned, Self::Owned> + Mul<Scale<Self::T>, Output = Self::Owned>
138{
139    type Owned;
140
141    /// Convert this view into an owned matrix, cloning the data if necessary.
142    fn into_owned(self) -> Self::Owned;
143
144    /// Perform a matrix-vector multiplication with a vector view: y = alpha * self * x + beta * y
145    fn gemv_v(
146        &self,
147        alpha: Self::T,
148        x: &<Self::V as Vector>::View<'_>,
149        beta: Self::T,
150        y: &mut Self::V,
151    );
152
153    /// Perform a matrix-vector multiplication with an owned vector: y = alpha * self * x + beta * y
154    fn gemv_o(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
155}
156
157/// A base matrix trait supporting both sparse and dense matrices.
158///
159/// This trait provides a complete interface for matrix operations including:
160/// - Matrix creation and memory management
161/// - Matrix-vector and matrix-matrix multiplication
162/// - Element access and modification
163/// - Sparsity information and handling
164/// - Matrix decomposition and combination operations
165/// - Triplet-based construction for sparse matrices
166///
167/// Implementing matrices can be dense or sparse, and may be hosted on CPU or GPU.
168/// Users typically do not need to implement this trait; use provided implementations.
169pub trait Matrix:
170    MatrixCommon + Mul<Scale<Self::T>, Output = Self> + Clone + Send + 'static
171{
172    type Sparsity: MatrixSparsity<Self>;
173    type SparsityRef<'a>: MatrixSparsityRef<'a, Self>
174    where
175        Self: 'a;
176
177    /// Return sparsity information, or `None` if the matrix is dense.
178    fn sparsity(&self) -> Option<Self::SparsityRef<'_>>;
179
180    /// Get the context associated with this matrix (for device placement, memory management, etc.).
181    fn context(&self) -> &Self::C;
182
183    /// Get a mutable reference to the inner representation of the matrix.
184    fn inner_mut(&mut self) -> &mut Self::Inner;
185
186    /// Returns true if this matrix is stored in a sparse format
187    fn is_sparse() -> bool {
188        Self::zeros(1, 1, Default::default()).sparsity().is_some()
189    }
190
191    /// Partition the diagonal indices into two groups: those with zero diagonal elements and those with non-zero diagonal elements.
192    ///
193    /// This is useful for identifying algebraic constraints, which typically have zero diagonal elements in the mass matrix.
194    /// Returns a tuple of (zero_diagonal_indices, non_zero_diagonal_indices).
195    fn partition_indices_by_zero_diagonal(
196        &self,
197    ) -> (<Self::V as Vector>::Index, <Self::V as Vector>::Index);
198
199    /// Perform a matrix-vector multiplication: y = alpha * self * x + beta * y
200    fn gemv(&self, alpha: Self::T, x: &Self::V, beta: Self::T, y: &mut Self::V);
201
202    /// Copy the contents of `other` into this matrix.
203    fn copy_from(&mut self, other: &Self);
204
205    /// Create a new matrix of shape `nrows` x `ncols` filled with zeros.
206    fn zeros(nrows: IndexType, ncols: IndexType, ctx: Self::C) -> Self;
207
208    /// Create a new matrix from a sparsity pattern. Non-zero elements are not initialized.
209    fn new_from_sparsity(
210        nrows: IndexType,
211        ncols: IndexType,
212        sparsity: Option<Self::Sparsity>,
213        ctx: Self::C,
214    ) -> Self;
215
216    /// Create a new diagonal matrix from a vector holding the diagonal elements.
217    fn from_diagonal(v: &Self::V) -> Self;
218
219    /// Set the values of column `j` to be equal to the values in `v`.
220    ///
221    /// For sparse matrices, only the existing non-zero elements are updated.
222    fn set_column(&mut self, j: IndexType, v: &Self::V);
223
224    /// Add a column of this matrix to a vector: v += self[:, j]
225    fn add_column_to_vector(&self, j: IndexType, v: &mut Self::V);
226
227    /// Assign the values in the `data` vector to this matrix at the indices in `dst_indices`
228    /// from the indices in `src_indices`.
229    ///
230    /// For dense matrices, the index is the data index in column-major order.
231    /// For sparse matrices, the index is the index into the data array.
232    fn set_data_with_indices(
233        &mut self,
234        dst_indices: &<Self::V as Vector>::Index,
235        src_indices: &<Self::V as Vector>::Index,
236        data: &Self::V,
237    );
238
239    /// Gather values from another matrix at specified indices into this matrix.
240    ///
241    /// For sparse matrices: the index `idx_i` in `indices` is an index into the data array for `other`,
242    /// and is copied to the index `idx_i` in the data array for this matrix.
243    /// For dense matrices: the index is the data index in column-major order.
244    fn gather(&mut self, other: &Self, indices: &<Self::V as Vector>::Index);
245
246    /// Split this matrix into four submatrices based on algebraic constraint indices.
247    ///
248    /// Partitions the matrix into blocks:
249    /// ```text
250    /// M = [UL, UR]
251    ///     [LL, LR]
252    /// ```
253    /// where:
254    /// - UL contains rows and columns NOT in `algebraic_indices`
255    /// - UR contains rows NOT in `algebraic_indices` and columns in `algebraic_indices`
256    /// - LL contains rows in `algebraic_indices` and columns NOT in `algebraic_indices`
257    /// - LR contains rows and columns in `algebraic_indices`
258    ///
259    /// Returns an array of tuples, where each tuple contains a submatrix and the indices that were used to create it.
260    /// These indices can be used with `gather()` to update the submatrix.
261    fn split(
262        &self,
263        algebraic_indices: &<Self::V as Vector>::Index,
264    ) -> [(Self, <Self::V as Vector>::Index); 4] {
265        match self.sparsity() {
266            Some(sp) => sp.split(algebraic_indices).map(|(sp, src_indices)| {
267                let mut m = Self::new_from_sparsity(
268                    sp.nrows(),
269                    sp.ncols(),
270                    Some(sp),
271                    self.context().clone(),
272                );
273                m.gather(self, &src_indices);
274                (m, src_indices)
275            }),
276            None => Dense::<Self>::new(self.nrows(), self.ncols())
277                .split(algebraic_indices)
278                .map(|(sp, src_indices)| {
279                    let mut m = Self::new_from_sparsity(
280                        sp.nrows(),
281                        sp.ncols(),
282                        None,
283                        self.context().clone(),
284                    );
285                    m.gather(self, &src_indices);
286                    (m, src_indices)
287                }),
288        }
289    }
290
291    /// Combine four submatrices back into a single matrix based on algebraic constraint indices.
292    ///
293    /// Inverse operation of `split()`. Takes submatrices `ul`, `ur`, `ll`, `lr` and combines them
294    /// back into the original matrix structure.
295    fn combine(
296        ul: &Self,
297        ur: &Self,
298        ll: &Self,
299        lr: &Self,
300        algebraic_indices: &<Self::V as Vector>::Index,
301    ) -> Self {
302        combine(ul, ur, ll, lr, algebraic_indices)
303    }
304
305    /// Perform the assignment: self = x + beta * y where x and y are matrices and beta is a scalar.
306    ///
307    /// Note: Panics if the sparsity patterns of self, x, and y do not match.
308    /// The sparsity of self must be the union of the sparsity of x and y.
309    fn scale_add_and_assign(&mut self, x: &Self, beta: Self::T, y: &Self);
310
311    /// Iterate over structural positions and values of the matrix.
312    ///
313    /// Returns a tuple:
314    /// - First iterator: `(row, col)` pairs for each non-zero element (length `nnz`)
315    /// - Second iterator: values (length `nnz * nbatch`), laid out batch-contiguously:
316    ///   `[batch0_val0..batch0_valN, batch1_val0..batch1_valN, ...]`
317    fn triplet_iter(
318        &self,
319    ) -> (
320        impl Iterator<Item = (IndexType, IndexType)> + '_,
321        impl Iterator<Item = Self::T> + '_,
322    );
323
324    /// Create a new matrix from structural indices and values.
325    ///
326    /// - `indices`: `(row, col)` pairs for each non-zero element (length `nnz`)
327    /// - `values`: values laid out batch-contiguously (length `nnz * ctx.nbatch()`):
328    ///   `[batch0_val0..batch0_valN, batch1_val0..batch1_valN, ...]`
329    fn try_from_triplets(
330        nrows: IndexType,
331        ncols: IndexType,
332        indices: Vec<(IndexType, IndexType)>,
333        values: Vec<Self::T>,
334        ctx: Self::C,
335    ) -> Result<Self, DiffsolError>;
336}
337
338/// A host matrix is a matrix type whose vector type is hosted on the CPU.
339///
340/// This trait extends `Matrix` to ensure the associated vector type implements `VectorHost`,
341/// enabling direct CPU-side access to data. GPU matrices typically do not implement this trait.
342pub trait MatrixHost: Matrix<V: VectorHost> {}
343
344impl<T: Matrix<V: VectorHost>> MatrixHost for T {}
345
346/// A dense column-major matrix with efficient column access operations.
347///
348/// This trait represents matrices stored in column-major order, where accessing matrix columns
349/// is efficient. It supports:
350/// - Matrix views and mutable views
351/// - Matrix-matrix multiplication (GEMM)
352/// - Column operations (axpy, access, modification)
353/// - Element access and modification
354/// - Matrix resizing
355///
356/// The column-major layout makes operations on individual or ranges of columns very efficient.
357pub trait DenseMatrix:
358    Matrix
359    + for<'b> MatrixOpsByValue<&'b Self, Self>
360    + for<'b> MatrixMutOpsByValue<&'b Self>
361    + for<'a, 'b> MatrixOpsByValue<&'b Self::View<'a>, Self>
362    + for<'a, 'b> MatrixMutOpsByValue<&'b Self::View<'a>>
363{
364    /// A view of the dense matrix type
365    type View<'a>: MatrixView<'a, Owned = Self, T = Self::T, V = Self::V>
366    where
367        Self: 'a;
368
369    /// A mutable view of the dense matrix type
370    type ViewMut<'a>: MatrixViewMut<
371        'a,
372        Owned = Self,
373        T = Self::T,
374        V = Self::V,
375        View = Self::View<'a>,
376    >
377    where
378        Self: 'a;
379
380    /// Perform a matrix-matrix multiplication: self = alpha * a * b + beta * self
381    fn gemm(&mut self, alpha: Self::T, a: &Self, b: &Self, beta: Self::T);
382
383    /// Perform a column AXPY operation: column i = alpha * column j + column i
384    ///
385    /// This is equivalent to: self[:, i] += alpha * self[:, j]
386    fn column_axpy(&mut self, alpha: Self::T, j: IndexType, i: IndexType);
387
388    /// Get an immutable view of columns from `start` (inclusive) to `end` (exclusive).
389    fn columns(&self, start: IndexType, end: IndexType) -> Self::View<'_>;
390
391    /// Get an immutable vector view of column `i`.
392    fn column(&self, i: IndexType) -> <Self::V as Vector>::View<'_>;
393
394    /// Get a mutable view of columns from `start` (inclusive) to `end` (exclusive).
395    fn columns_mut(&mut self, start: IndexType, end: IndexType) -> Self::ViewMut<'_>;
396
397    /// Get a mutable vector view of column `i`.
398    fn column_mut(&mut self, i: IndexType) -> <Self::V as Vector>::ViewMut<'_>;
399
400    /// Set the value at the given row and column indices.
401    fn set_index(&mut self, i: IndexType, j: IndexType, value: Self::T);
402
403    /// Get the value at the given row and column indices.
404    fn get_index(&self, i: IndexType, j: IndexType) -> Self::T;
405
406    /// Perform matrix-matrix multiplication using GEMM, allocating a new matrix for the result.
407    fn mat_mul(&self, b: &Self) -> Self {
408        let nrows = self.nrows();
409        let ncols = b.ncols();
410        let mut ret = Self::zeros(nrows, ncols, self.context().clone());
411        ret.gemm(Self::T::one(), self, b, Self::T::zero());
412        ret
413    }
414
415    /// Resize the number of columns in the matrix, preserving existing data.
416    ///
417    /// New elements (if added) are uninitialized. If the number of columns decreases, trailing columns are discarded.
418    fn resize_cols(&mut self, ncols: IndexType);
419
420    /// Create a new matrix from a vector of values in column-major order.
421    ///
422    /// The values are assumed to be stored in column-major order (first column, then second column, etc.).
423    fn from_vec(nrows: IndexType, ncols: IndexType, data: Vec<Self::T>, ctx: Self::C) -> Self;
424}
425
426#[cfg(test)]
427pub(crate) mod tests {
428    use super::{DenseMatrix, Matrix, MatrixCommon, MatrixView, MatrixViewMut};
429    use crate::scalar::Scale;
430    use crate::{scalar::IndexType, Context, Vector, VectorIndex};
431    use num_traits::{FromPrimitive, One, Zero};
432
433    fn f<M: Matrix>(x: f64) -> M::T {
434        M::T::from_f64(x).unwrap()
435    }
436
437    fn triplet_values<M: Matrix>(m: &M) -> Vec<M::T> {
438        let (_, vals) = m.triplet_iter();
439        vals.collect()
440    }
441
442    fn triplet_indices<M: Matrix>(m: &M) -> Vec<(IndexType, IndexType)> {
443        let (idx, _) = m.triplet_iter();
444        idx.collect()
445    }
446
447    pub fn test_partition_indices_by_zero_diagonal<M: Matrix>() {
448        let indices = vec![(0, 0), (1, 1), (3, 3)];
449        let values = vec![M::T::one(), M::T::from_f64(2.0).unwrap(), M::T::one()];
450        let m = M::try_from_triplets(4, 4, indices, values, Default::default()).unwrap();
451        let (zero_diagonal_indices, non_zero_diagonal_indices) =
452            m.partition_indices_by_zero_diagonal();
453        assert_eq!(zero_diagonal_indices.clone_as_vec(), vec![2]);
454        assert_eq!(non_zero_diagonal_indices.clone_as_vec(), vec![0, 1, 3]);
455
456        let indices = vec![(0, 0), (1, 1), (2, 2), (3, 3)];
457        let values = vec![
458            M::T::one(),
459            M::T::from_f64(2.0).unwrap(),
460            M::T::zero(),
461            M::T::one(),
462        ];
463        let m = M::try_from_triplets(4, 4, indices, values, Default::default()).unwrap();
464        let (zero_diagonal_indices, non_zero_diagonal_indices) =
465            m.partition_indices_by_zero_diagonal();
466        assert_eq!(zero_diagonal_indices.clone_as_vec(), vec![2]);
467        assert_eq!(non_zero_diagonal_indices.clone_as_vec(), vec![0, 1, 3]);
468
469        let indices = vec![(0, 0), (1, 1), (2, 2), (3, 3)];
470        let values = vec![
471            M::T::one(),
472            M::T::from_f64(2.0).unwrap(),
473            M::T::from_f64(3.0).unwrap(),
474            M::T::one(),
475        ];
476        let m = M::try_from_triplets(4, 4, indices, values, Default::default()).unwrap();
477        let (zero_diagonal_indices, non_zero_diagonal_indices) =
478            m.partition_indices_by_zero_diagonal();
479        assert_eq!(
480            zero_diagonal_indices.clone_as_vec(),
481            Vec::<IndexType>::new()
482        );
483        assert_eq!(non_zero_diagonal_indices.clone_as_vec(), vec![0, 1, 2, 3]);
484    }
485
486    // --- Matrix-generic tests (work with both dense and sparse) ---
487
488    pub fn test_zeros<M: Matrix>() {
489        let a = M::zeros(2, 3, Default::default());
490        assert_eq!(a.nrows(), 2);
491        assert_eq!(a.ncols(), 3);
492        let vals = triplet_values(&a);
493        assert!(vals.is_empty() || vals.iter().all(|v| v.is_zero()));
494    }
495
496    pub fn test_from_diagonal<M: Matrix>() {
497        let v = M::V::from_vec(
498            vec![f::<M>(2.0), f::<M>(3.0), f::<M>(5.0)],
499            Default::default(),
500        );
501        let a = M::from_diagonal(&v);
502        assert_eq!(a.nrows(), 3);
503        assert_eq!(a.ncols(), 3);
504        let idx = triplet_indices(&a);
505        let vals = triplet_values(&a);
506        // diagonal matrix triplet_iter returns only the diagonal nnz entries
507        for &(i, j) in &idx {
508            let pos = idx.iter().position(|&x| x == (i, j)).unwrap();
509            if i == j {
510                assert!(
511                    vals[pos] != M::T::zero(),
512                    "diagonal entry should be non-zero"
513                );
514            } else {
515                assert!(vals[pos].is_zero(), "off-diagonal entry should be zero");
516            }
517        }
518    }
519
520    pub fn test_from_diagonal_dense<M: DenseMatrix>() {
521        let v = M::V::from_vec(
522            vec![f::<M>(2.0), f::<M>(3.0), f::<M>(5.0)],
523            Default::default(),
524        );
525        let a = M::from_diagonal(&v);
526        assert_eq!(a.nrows(), 3);
527        assert_eq!(a.ncols(), 3);
528        assert_eq!(a.get_index(0, 0), f::<M>(2.0));
529        assert_eq!(a.get_index(1, 1), f::<M>(3.0));
530        assert_eq!(a.get_index(2, 2), f::<M>(5.0));
531        assert_eq!(a.get_index(0, 1), f::<M>(0.0));
532        assert_eq!(a.get_index(1, 0), f::<M>(0.0));
533    }
534
535    pub fn test_gemv<M: Matrix>() {
536        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
537        let values = vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)];
538        let a = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
539        let x = M::V::from_vec(vec![f::<M>(1.0), f::<M>(2.0)], Default::default());
540        let mut y = M::V::zeros(2, Default::default());
541        a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
542        assert_eq!(y.clone_as_vec(), vec![f::<M>(5.0), f::<M>(11.0)]);
543    }
544
545    pub fn test_set_column<M: Matrix>() {
546        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
547        let values = vec![f::<M>(0.0), f::<M>(0.0), f::<M>(0.0), f::<M>(0.0)];
548        let mut a = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
549        let v = M::V::from_vec(vec![f::<M>(7.0), f::<M>(8.0)], Default::default());
550        a.set_column(1, &v);
551        let idx = triplet_indices(&a);
552        let vals = triplet_values(&a);
553        assert_eq!(idx, vec![(0, 0), (1, 0), (0, 1), (1, 1)]);
554        assert_eq!(
555            vals,
556            vec![f::<M>(0.0), f::<M>(0.0), f::<M>(7.0), f::<M>(8.0)]
557        );
558    }
559
560    pub fn test_copy_from<M: Matrix>() {
561        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
562        let values = vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)];
563        let a = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
564        let mut b = M::zeros(2, 2, Default::default());
565        b.copy_from(&a);
566        let vals = triplet_values(&b);
567        assert_eq!(
568            vals,
569            vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)]
570        );
571    }
572
573    pub fn test_scale_add_and_assign<M: Matrix>() {
574        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
575        let x_vals = vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)];
576        let y_vals = vec![f::<M>(10.0), f::<M>(20.0), f::<M>(30.0), f::<M>(40.0)];
577        let x = M::try_from_triplets(2, 2, indices.clone(), x_vals, Default::default()).unwrap();
578        let y = M::try_from_triplets(2, 2, indices, y_vals, Default::default()).unwrap();
579        let mut result = M::zeros(2, 2, Default::default());
580        result.copy_from(&x);
581        result.scale_add_and_assign(&x, f::<M>(2.0), &y);
582        let vals = triplet_values(&result);
583        assert_eq!(
584            vals,
585            vec![f::<M>(21.0), f::<M>(42.0), f::<M>(63.0), f::<M>(84.0)]
586        );
587    }
588
589    // --- DenseMatrix-specific tests ---
590
591    pub fn test_column_axpy<M: DenseMatrix>() {
592        let mut a = M::zeros(2, 2, Default::default());
593        a.set_index(0, 0, M::T::one());
594        a.set_index(0, 1, M::T::from_f64(2.0).unwrap());
595        a.set_index(1, 0, M::T::from_f64(3.0).unwrap());
596        a.set_index(1, 1, M::T::from_f64(4.0).unwrap());
597
598        a.column_axpy(M::T::from_f64(2.0).unwrap(), 0, 1);
599        assert_eq!(a.get_index(0, 0), M::T::one());
600        assert_eq!(a.get_index(0, 1), M::T::from_f64(4.0).unwrap());
601        assert_eq!(a.get_index(1, 0), M::T::from_f64(3.0).unwrap());
602        assert_eq!(a.get_index(1, 1), M::T::from_f64(10.0).unwrap());
603    }
604
605    pub fn test_resize_cols<M: DenseMatrix>() {
606        let mut a = M::zeros(2, 2, Default::default());
607        a.set_index(0, 0, M::T::one());
608        a.set_index(0, 1, M::T::from_f64(2.0).unwrap());
609        a.set_index(1, 0, M::T::from_f64(3.0).unwrap());
610        a.set_index(1, 1, M::T::from_f64(4.0).unwrap());
611
612        a.resize_cols(3);
613        assert_eq!(a.ncols(), 3);
614        assert_eq!(a.nrows(), 2);
615        assert_eq!(a.get_index(0, 0), M::T::one());
616        assert_eq!(a.get_index(0, 1), M::T::from_f64(2.0).unwrap());
617        assert_eq!(a.get_index(1, 0), M::T::from_f64(3.0).unwrap());
618        assert_eq!(a.get_index(1, 1), M::T::from_f64(4.0).unwrap());
619
620        a.set_index(0, 2, M::T::from_f64(5.0).unwrap());
621        a.set_index(1, 2, M::T::from_f64(6.0).unwrap());
622        assert_eq!(a.get_index(0, 2), M::T::from_f64(5.0).unwrap());
623        assert_eq!(a.get_index(1, 2), M::T::from_f64(6.0).unwrap());
624
625        a.resize_cols(2);
626        assert_eq!(a.ncols(), 2);
627        assert_eq!(a.nrows(), 2);
628        assert_eq!(a.get_index(0, 0), M::T::one());
629        assert_eq!(a.get_index(0, 1), M::T::from_f64(2.0).unwrap());
630        assert_eq!(a.get_index(1, 0), M::T::from_f64(3.0).unwrap());
631        assert_eq!(a.get_index(1, 1), M::T::from_f64(4.0).unwrap());
632    }
633
634    pub fn test_from_vec<M: DenseMatrix>() {
635        let a = M::from_vec(
636            2,
637            2,
638            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
639            Default::default(),
640        );
641        assert_eq!(a.nrows(), 2);
642        assert_eq!(a.ncols(), 2);
643        assert_eq!(a.get_index(0, 0), f::<M>(1.0));
644        assert_eq!(a.get_index(1, 0), f::<M>(3.0));
645        assert_eq!(a.get_index(0, 1), f::<M>(2.0));
646        assert_eq!(a.get_index(1, 1), f::<M>(4.0));
647    }
648
649    pub fn test_gemm<M: DenseMatrix>() {
650        let a = M::from_vec(
651            2,
652            2,
653            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
654            Default::default(),
655        );
656        let b = M::from_vec(
657            2,
658            2,
659            vec![f::<M>(2.0), f::<M>(1.0), f::<M>(0.0), f::<M>(3.0)],
660            Default::default(),
661        );
662        let mut c = M::zeros(2, 2, Default::default());
663        c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
664        assert_eq!(c.get_index(0, 0), f::<M>(4.0));
665        assert_eq!(c.get_index(1, 0), f::<M>(10.0));
666        assert_eq!(c.get_index(0, 1), f::<M>(6.0));
667        assert_eq!(c.get_index(1, 1), f::<M>(12.0));
668    }
669
670    pub fn test_mat_mul<M: DenseMatrix>() {
671        let a = M::from_vec(
672            2,
673            2,
674            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
675            Default::default(),
676        );
677        let b = M::from_vec(
678            2,
679            2,
680            vec![f::<M>(2.0), f::<M>(1.0), f::<M>(0.0), f::<M>(3.0)],
681            Default::default(),
682        );
683        let c = a.mat_mul(&b);
684        assert_eq!(c.get_index(0, 0), f::<M>(4.0));
685        assert_eq!(c.get_index(1, 0), f::<M>(10.0));
686        assert_eq!(c.get_index(0, 1), f::<M>(6.0));
687        assert_eq!(c.get_index(1, 1), f::<M>(12.0));
688    }
689
690    pub fn test_columns_view<M: DenseMatrix>() {
691        let a = M::from_vec(
692            2,
693            3,
694            vec![
695                f::<M>(1.0),
696                f::<M>(4.0),
697                f::<M>(2.0),
698                f::<M>(5.0),
699                f::<M>(3.0),
700                f::<M>(6.0),
701            ],
702            Default::default(),
703        );
704        let view = a.columns(0, 2);
705        assert_eq!(view.ncols(), 2);
706        assert_eq!(view.nrows(), 2);
707        let owned = view.into_owned();
708        assert_eq!(owned.get_index(0, 0), f::<M>(1.0));
709        assert_eq!(owned.get_index(1, 0), f::<M>(4.0));
710        assert_eq!(owned.get_index(0, 1), f::<M>(2.0));
711        assert_eq!(owned.get_index(1, 1), f::<M>(5.0));
712    }
713
714    pub fn test_column_view<M: DenseMatrix>() {
715        let a = M::from_vec(
716            2,
717            2,
718            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
719            Default::default(),
720        );
721        let col = a.column(1);
722        use crate::VectorView;
723        assert_eq!(col.get_index(0), f::<M>(2.0));
724        assert_eq!(col.get_index(1), f::<M>(4.0));
725    }
726
727    // --- Batched Matrix-generic tests ---
728
729    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
730    pub fn test_batched_zeros_m<M: Matrix>(ctx: M::C) {
731        assert_eq!(ctx.nbatch(), 2);
732        let a = M::zeros(2, 3, ctx);
733        assert_eq!(a.nrows(), 2);
734        assert_eq!(a.ncols(), 3);
735        let vals = triplet_values(&a);
736        assert!(vals.is_empty() || vals.iter().all(|v| v.is_zero()));
737    }
738
739    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
740    pub fn test_batched_gemv_m<M: Matrix>(ctx: M::C) {
741        assert_eq!(ctx.nbatch(), 2);
742        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
743        let values = vec![
744            f::<M>(1.0),
745            f::<M>(3.0),
746            f::<M>(2.0),
747            f::<M>(4.0), // batch 0
748            f::<M>(5.0),
749            f::<M>(7.0),
750            f::<M>(6.0),
751            f::<M>(8.0), // batch 1
752        ];
753        let a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
754        let x = M::V::from_vec(
755            vec![f::<M>(1.0), f::<M>(2.0), f::<M>(1.0), f::<M>(1.0)],
756            ctx.clone(),
757        );
758        let mut y = M::V::zeros(2, ctx);
759        a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
760        assert_eq!(
761            y.clone_as_vec(),
762            vec![f::<M>(5.0), f::<M>(11.0), f::<M>(11.0), f::<M>(15.0)]
763        );
764    }
765
766    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
767    pub fn test_batched_gemv_broadcast_x_m<M: Matrix>(ctx: M::C) {
768        assert_eq!(ctx.nbatch(), 2);
769        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
770        let values = vec![
771            f::<M>(1.0),
772            f::<M>(3.0),
773            f::<M>(2.0),
774            f::<M>(4.0),
775            f::<M>(5.0),
776            f::<M>(7.0),
777            f::<M>(6.0),
778            f::<M>(8.0),
779        ];
780        let a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
781        let x = M::V::from_vec(vec![f::<M>(1.0), f::<M>(2.0)], Default::default());
782        let mut y = M::V::zeros(2, ctx);
783        a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
784        assert_eq!(
785            y.clone_as_vec(),
786            vec![f::<M>(5.0), f::<M>(11.0), f::<M>(17.0), f::<M>(23.0)]
787        );
788    }
789
790    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
791    pub fn test_batched_gemv_broadcast_mat_m<M: Matrix>(ctx: M::C) {
792        assert_eq!(ctx.nbatch(), 2);
793        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
794        let values = vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)];
795        let a =
796            M::try_from_triplets(2, 2, indices, values, ctx.clone_with_nbatch(1).unwrap()).unwrap();
797        let x = M::V::from_vec(
798            vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)],
799            ctx.clone(),
800        );
801        let mut y = M::V::zeros(2, ctx);
802        a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
803        assert_eq!(
804            y.clone_as_vec(),
805            vec![f::<M>(5.0), f::<M>(11.0), f::<M>(11.0), f::<M>(25.0)]
806        );
807    }
808
809    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
810    pub fn test_batched_from_diagonal_m<M: Matrix>(ctx: M::C) {
811        assert_eq!(ctx.nbatch(), 2);
812        let v = M::V::from_vec(
813            vec![f::<M>(2.0), f::<M>(3.0), f::<M>(4.0), f::<M>(5.0)],
814            ctx,
815        );
816        let a = M::from_diagonal(&v);
817        assert_eq!(a.nrows(), 2);
818        assert_eq!(a.ncols(), 2);
819        let idx = triplet_indices(&a);
820        let vals = triplet_values(&a);
821        for &(i, j) in &idx {
822            let pos = idx.iter().position(|&x| x == (i, j)).unwrap();
823            if i == j {
824                assert!(
825                    vals[pos] != M::T::zero(),
826                    "diagonal entry should be non-zero"
827                );
828            } else {
829                assert!(vals[pos].is_zero(), "off-diagonal entry should be zero");
830            }
831        }
832    }
833
834    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
835    pub fn test_batched_copy_from_m<M: Matrix>(ctx: M::C) {
836        assert_eq!(ctx.nbatch(), 2);
837        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
838        let values = vec![
839            f::<M>(1.0),
840            f::<M>(2.0),
841            f::<M>(3.0),
842            f::<M>(4.0),
843            f::<M>(5.0),
844            f::<M>(6.0),
845            f::<M>(7.0),
846            f::<M>(8.0),
847        ];
848        let a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
849        let mut b = M::zeros(2, 2, ctx);
850        b.copy_from(&a);
851        let vals = triplet_values(&b);
852        assert_eq!(
853            vals,
854            vec![
855                f::<M>(1.0),
856                f::<M>(2.0),
857                f::<M>(3.0),
858                f::<M>(4.0),
859                f::<M>(5.0),
860                f::<M>(6.0),
861                f::<M>(7.0),
862                f::<M>(8.0),
863            ]
864        );
865    }
866
867    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
868    pub fn test_batched_set_column_m<M: Matrix>(ctx: M::C) {
869        assert_eq!(ctx.nbatch(), 2);
870        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
871        let values = vec![
872            f::<M>(0.0),
873            f::<M>(0.0),
874            f::<M>(0.0),
875            f::<M>(0.0),
876            f::<M>(0.0),
877            f::<M>(0.0),
878            f::<M>(0.0),
879            f::<M>(0.0),
880        ];
881        let mut a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
882        let v = M::V::from_vec(
883            vec![f::<M>(5.0), f::<M>(6.0), f::<M>(7.0), f::<M>(8.0)],
884            ctx,
885        );
886        a.set_column(0, &v);
887        let vals = triplet_values(&a);
888        assert_eq!(
889            vals,
890            vec![
891                f::<M>(5.0),
892                f::<M>(6.0),
893                f::<M>(0.0),
894                f::<M>(0.0),
895                f::<M>(7.0),
896                f::<M>(8.0),
897                f::<M>(0.0),
898                f::<M>(0.0),
899            ]
900        );
901    }
902
903    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
904    pub fn test_batched_scale_add_and_assign_m<M: Matrix>(ctx: M::C) {
905        assert_eq!(ctx.nbatch(), 2);
906        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
907        let x_vals = vec![
908            f::<M>(1.0),
909            f::<M>(2.0),
910            f::<M>(3.0),
911            f::<M>(4.0),
912            f::<M>(5.0),
913            f::<M>(6.0),
914            f::<M>(7.0),
915            f::<M>(8.0),
916        ];
917        let y_vals = vec![
918            f::<M>(10.0),
919            f::<M>(20.0),
920            f::<M>(30.0),
921            f::<M>(40.0),
922            f::<M>(50.0),
923            f::<M>(60.0),
924            f::<M>(70.0),
925            f::<M>(80.0),
926        ];
927        let x = M::try_from_triplets(2, 2, indices.clone(), x_vals, ctx.clone()).unwrap();
928        let y = M::try_from_triplets(2, 2, indices, y_vals, ctx.clone()).unwrap();
929        let mut result = M::zeros(2, 2, ctx);
930        result.copy_from(&x);
931        result.scale_add_and_assign(&x, f::<M>(2.0), &y);
932        let vals = triplet_values(&result);
933        assert_eq!(
934            vals,
935            vec![
936                f::<M>(21.0),
937                f::<M>(42.0),
938                f::<M>(63.0),
939                f::<M>(84.0),
940                f::<M>(105.0),
941                f::<M>(126.0),
942                f::<M>(147.0),
943                f::<M>(168.0),
944            ]
945        );
946    }
947
948    // --- Batched DenseMatrix-specific tests ---
949
950    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
951    pub fn test_batched_from_vec<M: DenseMatrix>(ctx: M::C) {
952        assert_eq!(ctx.nbatch(), 2);
953        // 2x2 matrix, nbatch=2: physical is 2x4
954        // batch0: col0=[1,3], col1=[2,4]; batch1: col0=[5,7], col1=[6,8]
955        let a = M::from_vec(
956            2,
957            2,
958            vec![
959                f::<M>(1.0),
960                f::<M>(3.0),
961                f::<M>(2.0),
962                f::<M>(4.0),
963                f::<M>(5.0),
964                f::<M>(7.0),
965                f::<M>(6.0),
966                f::<M>(8.0),
967            ],
968            ctx,
969        );
970        assert_eq!(a.nrows(), 2);
971        assert_eq!(a.ncols(), 2);
972        assert_eq!(a.get_index(0, 0), f::<M>(1.0));
973        assert_eq!(a.get_index(1, 0), f::<M>(3.0));
974        assert_eq!(a.get_index(0, 1), f::<M>(2.0));
975        assert_eq!(a.get_index(1, 1), f::<M>(4.0));
976    }
977
978    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
979    pub fn test_batched_gemm<M: DenseMatrix>(ctx: M::C) {
980        assert_eq!(ctx.nbatch(), 2);
981        // batch0: A=[[1,0],[0,1]](identity), batch1: A=[[2,0],[0,2]]
982        let a = M::from_vec(
983            2,
984            2,
985            vec![
986                f::<M>(1.0),
987                f::<M>(0.0),
988                f::<M>(0.0),
989                f::<M>(1.0),
990                f::<M>(2.0),
991                f::<M>(0.0),
992                f::<M>(0.0),
993                f::<M>(2.0),
994            ],
995            ctx.clone(),
996        );
997        // batch0: B=[[3,4],[5,6]], batch1: B=[[1,1],[1,1]]
998        let b = M::from_vec(
999            2,
1000            2,
1001            vec![
1002                f::<M>(3.0),
1003                f::<M>(5.0),
1004                f::<M>(4.0),
1005                f::<M>(6.0),
1006                f::<M>(1.0),
1007                f::<M>(1.0),
1008                f::<M>(1.0),
1009                f::<M>(1.0),
1010            ],
1011            ctx.clone(),
1012        );
1013        let mut c = M::zeros(2, 2, ctx);
1014        c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1015        // batch0: I*B=B=[[3,4],[5,6]], batch1: 2I*B=[[2,2],[2,2]]
1016        assert_eq!(c.get_index(0, 0), f::<M>(3.0));
1017        assert_eq!(c.get_index(1, 0), f::<M>(5.0));
1018        assert_eq!(c.get_index(0, 1), f::<M>(4.0));
1019        assert_eq!(c.get_index(1, 1), f::<M>(6.0));
1020    }
1021
1022    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1023    pub fn test_batched_columns<M: DenseMatrix>(ctx: M::C) {
1024        assert_eq!(ctx.nbatch(), 2);
1025        // 2x3 matrix, nbatch=2
1026        // batch0: [[1,3,5],[2,4,6]], batch1: [[7,9,11],[8,10,12]]
1027        let a = M::from_vec(
1028            2,
1029            3,
1030            vec![
1031                f::<M>(1.0),
1032                f::<M>(2.0),
1033                f::<M>(3.0),
1034                f::<M>(4.0),
1035                f::<M>(5.0),
1036                f::<M>(6.0),
1037                f::<M>(7.0),
1038                f::<M>(8.0),
1039                f::<M>(9.0),
1040                f::<M>(10.0),
1041                f::<M>(11.0),
1042                f::<M>(12.0),
1043            ],
1044            ctx.clone(),
1045        );
1046        let view = a.columns(0, 2);
1047        assert_eq!(view.ncols(), 2);
1048        assert_eq!(view.nrows(), 2);
1049        let owned = view.into_owned();
1050        assert_eq!(owned.nrows(), 2);
1051        assert_eq!(owned.ncols(), 2);
1052        // Verify via gemv_o: multiply columns(0,2) by [1,1] for each batch
1053        let view2 = a.columns(0, 2);
1054        let x = M::V::from_vec(
1055            vec![f::<M>(1.0), f::<M>(1.0), f::<M>(1.0), f::<M>(1.0)],
1056            ctx.clone(),
1057        );
1058        let mut y = M::V::zeros(2, ctx);
1059        view2.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1060        // batch0: [1,2]*1 + [3,4]*1 = [4,6], batch1: [7,8]*1 + [9,10]*1 = [16,18]
1061        assert_eq!(
1062            y.clone_as_vec(),
1063            vec![f::<M>(4.0), f::<M>(6.0), f::<M>(16.0), f::<M>(18.0)]
1064        );
1065    }
1066
1067    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1068    pub fn test_batched_gemv_o_on_columns<M: DenseMatrix>(ctx: M::C) {
1069        assert_eq!(ctx.nbatch(), 2);
1070        // 2x3 diff matrix, nbatch=2
1071        // batch0: [[1,2,3],[4,5,6]], batch1: [[7,8,9],[10,11,12]]
1072        let diff = M::from_vec(
1073            2,
1074            3,
1075            vec![
1076                f::<M>(1.0),
1077                f::<M>(4.0),
1078                f::<M>(2.0),
1079                f::<M>(5.0),
1080                f::<M>(3.0),
1081                f::<M>(6.0),
1082                f::<M>(7.0),
1083                f::<M>(10.0),
1084                f::<M>(8.0),
1085                f::<M>(11.0),
1086                f::<M>(9.0),
1087                f::<M>(12.0),
1088            ],
1089            ctx.clone(),
1090        );
1091        // take columns 0..2 from each batch
1092        let view = diff.columns(0, 2);
1093        // x has nbatch=2, length=2 (matches ncols of view)
1094        // batch0: x=[1,1], batch1: x=[2,2]
1095        let x = M::V::from_vec(
1096            vec![f::<M>(1.0), f::<M>(1.0), f::<M>(2.0), f::<M>(2.0)],
1097            ctx.clone(),
1098        );
1099        let mut y = M::V::zeros(2, ctx);
1100        view.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1101        // batch0: [[1,2],[4,5]] * [1,1] = [1+2, 4+5] = [3, 9]
1102        // batch1: [[7,8],[10,11]] * [2,2] = [14+16, 20+22] = [30, 42]
1103        assert_eq!(
1104            y.clone_as_vec(),
1105            vec![f::<M>(3.0), f::<M>(9.0), f::<M>(30.0), f::<M>(42.0)]
1106        );
1107    }
1108
1109    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1110    pub fn test_batched_gemv_v_broadcast_mat<M: DenseMatrix>(ctx3: M::C) {
1111        assert_eq!(ctx3.nbatch(), 2);
1112        // matrix view with nbatch=1 broadcasts to x/y with nbatch=2
1113        let ctx1 = M::C::default();
1114        // 2x3 matrix, nbatch=1: [[1,2,3],[4,5,6]]
1115        let diff = M::from_vec(
1116            2,
1117            3,
1118            vec![
1119                f::<M>(1.0),
1120                f::<M>(4.0),
1121                f::<M>(2.0),
1122                f::<M>(5.0),
1123                f::<M>(3.0),
1124                f::<M>(6.0),
1125            ],
1126            ctx1,
1127        );
1128        let view = diff.columns(0, 2);
1129        // x with nbatch=2, length=2
1130        // batch0: [1,1], batch1: [2,2]
1131        let x = M::V::from_vec(
1132            vec![f::<M>(1.0), f::<M>(1.0), f::<M>(2.0), f::<M>(2.0)],
1133            ctx3.clone(),
1134        );
1135        let mut y = M::V::zeros(2, ctx3);
1136        view.gemv_v(f::<M>(1.0), &x.as_view(), f::<M>(0.0), &mut y);
1137        // batch0: [[1,2],[4,5]] * [1,1] = [3, 9]
1138        // batch1: [[1,2],[4,5]] * [2,2] = [6, 18]
1139        assert_eq!(
1140            y.clone_as_vec(),
1141            vec![f::<M>(3.0), f::<M>(9.0), f::<M>(6.0), f::<M>(18.0)]
1142        );
1143    }
1144
1145    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1146    pub fn test_batched_gemv_o_broadcast_mat<M: DenseMatrix>(ctx3: M::C) {
1147        assert_eq!(ctx3.nbatch(), 2);
1148        // matrix view with nbatch=1 broadcasts to x/y with nbatch=2
1149        let ctx1 = M::C::default();
1150        // 2x3 matrix, nbatch=1: [[1,2,3],[4,5,6]]
1151        let diff = M::from_vec(
1152            2,
1153            3,
1154            vec![
1155                f::<M>(1.0),
1156                f::<M>(4.0),
1157                f::<M>(2.0),
1158                f::<M>(5.0),
1159                f::<M>(3.0),
1160                f::<M>(6.0),
1161            ],
1162            ctx1,
1163        );
1164        let view = diff.columns(0, 2);
1165        // x with nbatch=2, length=2
1166        // batch0: [1,1], batch1: [2,2]
1167        let x = M::V::from_vec(
1168            vec![f::<M>(1.0), f::<M>(1.0), f::<M>(2.0), f::<M>(2.0)],
1169            ctx3.clone(),
1170        );
1171        let mut y = M::V::zeros(2, ctx3);
1172        view.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1173        // batch0: [[1,2],[4,5]] * [1,1] = [3, 9]
1174        // batch1: [[1,2],[4,5]] * [2,2] = [6, 18]
1175        assert_eq!(
1176            y.clone_as_vec(),
1177            vec![f::<M>(3.0), f::<M>(9.0), f::<M>(6.0), f::<M>(18.0)]
1178        );
1179    }
1180
1181    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1182    pub fn test_batched_gemm_vo_on_columns<M: DenseMatrix>(ctx: M::C) {
1183        assert_eq!(ctx.nbatch(), 2);
1184        // 2x3 diff matrix, nbatch=2
1185        let diff = M::from_vec(
1186            2,
1187            3,
1188            vec![
1189                f::<M>(1.0),
1190                f::<M>(4.0),
1191                f::<M>(2.0),
1192                f::<M>(5.0),
1193                f::<M>(3.0),
1194                f::<M>(6.0),
1195                f::<M>(7.0),
1196                f::<M>(10.0),
1197                f::<M>(8.0),
1198                f::<M>(11.0),
1199                f::<M>(9.0),
1200                f::<M>(12.0),
1201            ],
1202            ctx.clone(),
1203        );
1204        // R is 2x2 (nbatch=2): batch0=identity, batch1=2*identity
1205        let r = M::from_vec(
1206            2,
1207            2,
1208            vec![
1209                f::<M>(1.0),
1210                f::<M>(0.0),
1211                f::<M>(0.0),
1212                f::<M>(1.0),
1213                f::<M>(2.0),
1214                f::<M>(0.0),
1215                f::<M>(0.0),
1216                f::<M>(2.0),
1217            ],
1218            ctx.clone(),
1219        );
1220        let mut result = M::zeros(2, 3, ctx);
1221        {
1222            let d_view = diff.columns(0, 2);
1223            let mut r_view = result.columns_mut(0, 2);
1224            r_view.gemm_vo(f::<M>(1.0), &d_view, &r, f::<M>(0.0));
1225        }
1226        // batch0: [[1,2],[4,5]] * I = [[1,2],[4,5]]
1227        // batch1: [[7,8],[10,11]] * 2I = [[14,16],[20,22]]
1228        assert_eq!(result.get_index(0, 0), f::<M>(1.0));
1229        assert_eq!(result.get_index(1, 0), f::<M>(4.0));
1230        assert_eq!(result.get_index(0, 1), f::<M>(2.0));
1231        assert_eq!(result.get_index(1, 1), f::<M>(5.0));
1232    }
1233
1234    // --- Broadcasting tests ---
1235
1236    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1237    pub fn test_batched_gemm_broadcast_b<M: DenseMatrix>(ctx: M::C) {
1238        assert_eq!(ctx.nbatch(), 2);
1239        // batch0: A=[[1,0],[0,1]], batch1: A=[[2,0],[0,3]]
1240        let a = M::from_vec(
1241            2,
1242            2,
1243            vec![
1244                f::<M>(1.0),
1245                f::<M>(0.0),
1246                f::<M>(0.0),
1247                f::<M>(1.0),
1248                f::<M>(2.0),
1249                f::<M>(0.0),
1250                f::<M>(0.0),
1251                f::<M>(3.0),
1252            ],
1253            ctx.clone(),
1254        );
1255        // B with nbatch=1: [[1,2],[3,4]]
1256        let b = M::from_vec(
1257            2,
1258            2,
1259            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1260            Default::default(),
1261        );
1262        let mut c = M::zeros(2, 2, ctx);
1263        c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1264        // batch0: I*B=[[1,2],[3,4]], batch1: diag(2,3)*B=[[2,4],[9,12]]
1265        assert_eq!(c.get_index(0, 0), f::<M>(1.0));
1266        assert_eq!(c.get_index(1, 0), f::<M>(3.0));
1267        assert_eq!(c.get_index(0, 1), f::<M>(2.0));
1268        assert_eq!(c.get_index(1, 1), f::<M>(4.0));
1269    }
1270
1271    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1272    pub fn test_batched_gemm_broadcast_a<M: DenseMatrix>(ctx: M::C) {
1273        assert_eq!(ctx.nbatch(), 2);
1274        // A with nbatch=1: [[1,0],[0,2]]
1275        let a = M::from_vec(
1276            2,
1277            2,
1278            vec![f::<M>(1.0), f::<M>(0.0), f::<M>(0.0), f::<M>(2.0)],
1279            Default::default(),
1280        );
1281        // batch0: B=[[3,4],[5,6]], batch1: B=[[1,1],[1,1]]
1282        let b = M::from_vec(
1283            2,
1284            2,
1285            vec![
1286                f::<M>(3.0),
1287                f::<M>(5.0),
1288                f::<M>(4.0),
1289                f::<M>(6.0),
1290                f::<M>(1.0),
1291                f::<M>(1.0),
1292                f::<M>(1.0),
1293                f::<M>(1.0),
1294            ],
1295            ctx.clone(),
1296        );
1297        let mut c = M::zeros(2, 2, ctx);
1298        c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1299        // batch0: [[1,0],[0,2]]*[[3,4],[5,6]]=[[3,4],[10,12]]
1300        assert_eq!(c.get_index(0, 0), f::<M>(3.0));
1301        assert_eq!(c.get_index(1, 0), f::<M>(10.0));
1302        assert_eq!(c.get_index(0, 1), f::<M>(4.0));
1303        assert_eq!(c.get_index(1, 1), f::<M>(12.0));
1304    }
1305
1306    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1307    pub fn test_batched_gemv_o_broadcast_x<M: DenseMatrix>(ctx: M::C) {
1308        assert_eq!(ctx.nbatch(), 2);
1309        // 2x3 diff matrix, nbatch=2
1310        let diff = M::from_vec(
1311            2,
1312            3,
1313            vec![
1314                f::<M>(1.0),
1315                f::<M>(4.0),
1316                f::<M>(2.0),
1317                f::<M>(5.0),
1318                f::<M>(3.0),
1319                f::<M>(6.0),
1320                f::<M>(7.0),
1321                f::<M>(10.0),
1322                f::<M>(8.0),
1323                f::<M>(11.0),
1324                f::<M>(9.0),
1325                f::<M>(12.0),
1326            ],
1327            ctx.clone(),
1328        );
1329        let view = diff.columns(0, 2);
1330        // x with nbatch=1, length=2 (broadcast)
1331        let x = M::V::from_vec(vec![f::<M>(1.0), f::<M>(1.0)], Default::default());
1332        let mut y = M::V::zeros(2, ctx);
1333        view.gemv_o(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1334        // batch0: [[1,2],[4,5]] * [1,1] = [3, 9]
1335        // batch1: [[7,8],[10,11]] * [1,1] = [15, 21]
1336        assert_eq!(
1337            y.clone_as_vec(),
1338            vec![f::<M>(3.0), f::<M>(9.0), f::<M>(15.0), f::<M>(21.0)]
1339        );
1340    }
1341
1342    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1343    pub fn test_batched_gemm_vo_broadcast_b<M: DenseMatrix>(ctx: M::C) {
1344        assert_eq!(ctx.nbatch(), 2);
1345        // 2x3 diff matrix, nbatch=2
1346        let diff = M::from_vec(
1347            2,
1348            3,
1349            vec![
1350                f::<M>(1.0),
1351                f::<M>(4.0),
1352                f::<M>(2.0),
1353                f::<M>(5.0),
1354                f::<M>(3.0),
1355                f::<M>(6.0),
1356                f::<M>(7.0),
1357                f::<M>(10.0),
1358                f::<M>(8.0),
1359                f::<M>(11.0),
1360                f::<M>(9.0),
1361                f::<M>(12.0),
1362            ],
1363            ctx.clone(),
1364        );
1365        // R with nbatch=1: 2x2 identity
1366        let r = M::from_vec(
1367            2,
1368            2,
1369            vec![f::<M>(1.0), f::<M>(0.0), f::<M>(0.0), f::<M>(1.0)],
1370            Default::default(),
1371        );
1372        let mut result = M::zeros(2, 3, ctx);
1373        {
1374            let d_view = diff.columns(0, 2);
1375            let mut r_view = result.columns_mut(0, 2);
1376            r_view.gemm_vo(f::<M>(1.0), &d_view, &r, f::<M>(0.0));
1377        }
1378        // Both batches: sub-matrix * I = sub-matrix (unchanged)
1379        // batch0: [[1,2],[4,5]], batch1: [[7,8],[10,11]]
1380        assert_eq!(result.get_index(0, 0), f::<M>(1.0));
1381        assert_eq!(result.get_index(1, 0), f::<M>(4.0));
1382        assert_eq!(result.get_index(0, 1), f::<M>(2.0));
1383        assert_eq!(result.get_index(1, 1), f::<M>(5.0));
1384    }
1385
1386    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1387    pub fn test_batched_gemm_vo_broadcast_a<M: DenseMatrix>(ctx: M::C) {
1388        assert_eq!(ctx.nbatch(), 2);
1389        // diff with nbatch=1: 2x3 matrix [[1,2,3],[4,5,6]]
1390        let diff = M::from_vec(
1391            2,
1392            3,
1393            vec![
1394                f::<M>(1.0),
1395                f::<M>(4.0),
1396                f::<M>(2.0),
1397                f::<M>(5.0),
1398                f::<M>(3.0),
1399                f::<M>(6.0),
1400            ],
1401            Default::default(),
1402        );
1403        // b with nbatch=2: batch0=[[1,0],[0,1]], batch1=[[2,0],[0,3]]
1404        let b = M::from_vec(
1405            2,
1406            2,
1407            vec![
1408                f::<M>(1.0),
1409                f::<M>(0.0),
1410                f::<M>(0.0),
1411                f::<M>(1.0),
1412                f::<M>(2.0),
1413                f::<M>(0.0),
1414                f::<M>(0.0),
1415                f::<M>(3.0),
1416            ],
1417            ctx.clone(),
1418        );
1419        let mut result = M::zeros(2, 3, ctx);
1420        {
1421            let d_view = diff.columns(0, 2);
1422            let mut r_view = result.columns_mut(0, 2);
1423            r_view.gemm_vo(f::<M>(1.0), &d_view, &b, f::<M>(0.0));
1424        }
1425        // batch0: [[1,2],[4,5]]*I=[[1,2],[4,5]], batch1: [[1,2],[4,5]]*[[2,0],[0,3]]=[[2,6],[8,15]]
1426        assert_eq!(result.get_index(0, 0), f::<M>(1.0));
1427        assert_eq!(result.get_index(1, 0), f::<M>(4.0));
1428        assert_eq!(result.get_index(0, 1), f::<M>(2.0));
1429        assert_eq!(result.get_index(1, 1), f::<M>(5.0));
1430    }
1431
1432    // --- Incompatible batch tests ---
1433
1434    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1435    pub fn test_batched_gemm_incompatible_a<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
1436        assert_eq!(ctx2.nbatch(), 2);
1437        assert_eq!(ctx3.nbatch(), 3);
1438        let a = M::zeros(2, 2, ctx3);
1439        let b = M::zeros(2, 2, ctx2.clone());
1440        let mut c = M::zeros(2, 2, ctx2);
1441        c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1442    }
1443
1444    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1445    pub fn test_batched_gemv_incompatible<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
1446        assert_eq!(ctx2.nbatch(), 2);
1447        assert_eq!(ctx3.nbatch(), 3);
1448        let a = M::zeros(2, 2, ctx2.clone());
1449        let x = M::V::zeros(2, ctx3);
1450        let mut y = M::V::zeros(2, ctx2);
1451        a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1452    }
1453
1454    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1455    pub fn test_batched_gemm_incompatible<M: DenseMatrix>(ctx2: M::C, ctx3: M::C) {
1456        assert_eq!(ctx2.nbatch(), 2);
1457        assert_eq!(ctx3.nbatch(), 3);
1458        let a = M::zeros(2, 2, ctx2.clone());
1459        let b = M::zeros(2, 2, ctx3);
1460        let mut c = M::zeros(2, 2, ctx2);
1461        c.gemm(f::<M>(1.0), &a, &b, f::<M>(0.0));
1462    }
1463
1464    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1465    pub fn test_batched_resize_cols<M: DenseMatrix>(ctx: M::C) {
1466        assert_eq!(ctx.nbatch(), 2);
1467        // 2x2, nbatch=2: batch0=[[1,2],[3,4]], batch1=[[5,6],[7,8]]
1468        let mut a = M::from_vec(
1469            2,
1470            2,
1471            vec![
1472                f::<M>(1.0),
1473                f::<M>(3.0),
1474                f::<M>(2.0),
1475                f::<M>(4.0),
1476                f::<M>(5.0),
1477                f::<M>(7.0),
1478                f::<M>(6.0),
1479                f::<M>(8.0),
1480            ],
1481            ctx.clone(),
1482        );
1483        // grow to 3 columns
1484        a.resize_cols(3);
1485        assert_eq!(a.ncols(), 3);
1486        assert_eq!(a.nrows(), 2);
1487        // existing data preserved per batch
1488        assert_eq!(a.get_index(0, 0), f::<M>(1.0));
1489        assert_eq!(a.get_index(1, 0), f::<M>(3.0));
1490        assert_eq!(a.get_index(0, 1), f::<M>(2.0));
1491        assert_eq!(a.get_index(1, 1), f::<M>(4.0));
1492        // new column is zero
1493        assert_eq!(a.get_index(0, 2), f::<M>(0.0));
1494        assert_eq!(a.get_index(1, 2), f::<M>(0.0));
1495        // verify via gemv that batch 1 data is intact
1496        let x = M::V::from_vec(
1497            vec![
1498                f::<M>(1.0),
1499                f::<M>(0.0),
1500                f::<M>(0.0),
1501                f::<M>(1.0),
1502                f::<M>(0.0),
1503                f::<M>(0.0),
1504            ],
1505            ctx.clone(),
1506        );
1507        let mut y = M::V::zeros(2, ctx.clone());
1508        a.gemv(f::<M>(1.0), &x, f::<M>(0.0), &mut y);
1509        // batch0: col0=[1,3], x=[1,0,0] → [1,3]
1510        // batch1: col0=[5,7], x=[1,0,0] → [5,7]
1511        assert_eq!(
1512            y.clone_as_vec(),
1513            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(5.0), f::<M>(7.0)]
1514        );
1515
1516        // shrink to 1 column
1517        a.resize_cols(1);
1518        assert_eq!(a.ncols(), 1);
1519        assert_eq!(a.get_index(0, 0), f::<M>(1.0));
1520        assert_eq!(a.get_index(1, 0), f::<M>(3.0));
1521        // verify batch1 col0 via gemv
1522        let x2 = M::V::from_vec(vec![f::<M>(1.0), f::<M>(1.0)], ctx.clone());
1523        let mut y2 = M::V::zeros(2, ctx);
1524        a.gemv(f::<M>(1.0), &x2, f::<M>(0.0), &mut y2);
1525        // batch0: [[1],[3]] * [1] = [1,3], batch1: [[5],[7]] * [1] = [5,7]
1526        assert_eq!(
1527            y2.clone_as_vec(),
1528            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(5.0), f::<M>(7.0)]
1529        );
1530    }
1531
1532    // --- New unbatched Matrix-generic tests ---
1533
1534    pub fn test_mul_scalar<M: Matrix>() {
1535        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1536        let values = vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)];
1537        let a = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
1538        let result = a * Scale(f::<M>(2.0));
1539        let (_, vals) = result.triplet_iter();
1540        let vals: Vec<_> = vals.collect();
1541        assert_eq!(
1542            vals,
1543            vec![f::<M>(2.0), f::<M>(6.0), f::<M>(4.0), f::<M>(8.0)]
1544        );
1545    }
1546
1547    pub fn test_add_column_to_vector<M: Matrix>() {
1548        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1549        let values = vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0)];
1550        let mat = M::try_from_triplets(2, 2, indices, values, Default::default()).unwrap();
1551        let mut v = M::V::zeros(2, Default::default());
1552        mat.add_column_to_vector(1, &mut v);
1553        assert_eq!(v.clone_as_vec(), vec![f::<M>(3.0), f::<M>(4.0)]);
1554    }
1555
1556    // --- New unbatched DenseMatrix-specific tests ---
1557
1558    pub fn test_add<M: DenseMatrix>() {
1559        let a = M::from_vec(
1560            2,
1561            2,
1562            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1563            Default::default(),
1564        );
1565        let b = M::from_vec(
1566            2,
1567            2,
1568            vec![f::<M>(5.0), f::<M>(7.0), f::<M>(6.0), f::<M>(8.0)],
1569            Default::default(),
1570        );
1571        let result = a + &b;
1572        assert_eq!(result.get_index(0, 0), f::<M>(6.0));
1573        assert_eq!(result.get_index(1, 1), f::<M>(12.0));
1574    }
1575
1576    pub fn test_sub<M: DenseMatrix>() {
1577        let a = M::from_vec(
1578            2,
1579            2,
1580            vec![f::<M>(5.0), f::<M>(7.0), f::<M>(6.0), f::<M>(8.0)],
1581            Default::default(),
1582        );
1583        let b = M::from_vec(
1584            2,
1585            2,
1586            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1587            Default::default(),
1588        );
1589        let result = a - &b;
1590        assert_eq!(result.get_index(0, 0), f::<M>(4.0));
1591        assert_eq!(result.get_index(1, 1), f::<M>(4.0));
1592    }
1593
1594    pub fn test_add_assign<M: DenseMatrix>() {
1595        let mut a = M::from_vec(
1596            2,
1597            2,
1598            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1599            Default::default(),
1600        );
1601        let b = M::from_vec(
1602            2,
1603            2,
1604            vec![f::<M>(5.0), f::<M>(7.0), f::<M>(6.0), f::<M>(8.0)],
1605            Default::default(),
1606        );
1607        a += &b;
1608        assert_eq!(a.get_index(0, 0), f::<M>(6.0));
1609        assert_eq!(a.get_index(1, 1), f::<M>(12.0));
1610    }
1611
1612    pub fn test_sub_assign<M: DenseMatrix>() {
1613        let mut a = M::from_vec(
1614            2,
1615            2,
1616            vec![f::<M>(5.0), f::<M>(7.0), f::<M>(6.0), f::<M>(8.0)],
1617            Default::default(),
1618        );
1619        let b = M::from_vec(
1620            2,
1621            2,
1622            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1623            Default::default(),
1624        );
1625        a -= &b;
1626        assert_eq!(a.get_index(0, 0), f::<M>(4.0));
1627        assert_eq!(a.get_index(1, 1), f::<M>(4.0));
1628    }
1629
1630    pub fn test_gather<M: DenseMatrix>() {
1631        let mat1 = M::from_vec(
1632            3,
1633            3,
1634            vec![
1635                f::<M>(1.0),
1636                f::<M>(2.0),
1637                f::<M>(3.0),
1638                f::<M>(4.0),
1639                f::<M>(5.0),
1640                f::<M>(6.0),
1641                f::<M>(7.0),
1642                f::<M>(8.0),
1643                f::<M>(9.0),
1644            ],
1645            Default::default(),
1646        );
1647        let mut mat2 = M::zeros(2, 2, Default::default());
1648        let indices = <M::V as Vector>::Index::from_vec(vec![0, 1, 3, 4], Default::default());
1649        mat2.gather(&mat1, &indices);
1650        assert_eq!(mat2.get_index(0, 0), f::<M>(1.0));
1651        assert_eq!(mat2.get_index(1, 0), f::<M>(2.0));
1652        assert_eq!(mat2.get_index(0, 1), f::<M>(4.0));
1653        assert_eq!(mat2.get_index(1, 1), f::<M>(5.0));
1654    }
1655
1656    pub fn test_set_data_with_indices<M: DenseMatrix>() {
1657        let mut mat = M::zeros(2, 2, Default::default());
1658        let dst_indices = <M::V as Vector>::Index::from_vec(vec![0, 3], Default::default());
1659        let src_indices = <M::V as Vector>::Index::from_vec(vec![0, 1], Default::default());
1660        let data = M::V::from_vec(vec![f::<M>(5.0), f::<M>(6.0)], Default::default());
1661        mat.set_data_with_indices(&dst_indices, &src_indices, &data);
1662        assert_eq!(mat.get_index(0, 0), f::<M>(5.0));
1663        assert_eq!(mat.get_index(1, 1), f::<M>(6.0));
1664    }
1665
1666    pub fn test_mul_assign_scalar<M: DenseMatrix>() {
1667        let mut mat = M::from_vec(
1668            2,
1669            2,
1670            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
1671            Default::default(),
1672        );
1673        {
1674            let mut view = mat.columns_mut(0, 2);
1675            view *= Scale(f::<M>(2.0));
1676        }
1677        assert_eq!(mat.get_index(0, 0), f::<M>(2.0));
1678        assert_eq!(mat.get_index(1, 1), f::<M>(8.0));
1679    }
1680
1681    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1682    pub fn test_batched_combine<M: DenseMatrix>(ctx: M::C) {
1683        assert_eq!(ctx.nbatch(), 2);
1684        #[rustfmt::skip]
1685        let data: Vec<M::T> = vec![
1686            // batch 0: 4x4 column-major (cols 0-3)
1687            f::<M>(1.0), f::<M>(2.0), f::<M>(3.0), f::<M>(4.0),
1688            f::<M>(5.0), f::<M>(6.0), f::<M>(7.0), f::<M>(8.0),
1689            f::<M>(9.0), f::<M>(10.0), f::<M>(11.0), f::<M>(12.0),
1690            f::<M>(13.0), f::<M>(14.0), f::<M>(15.0), f::<M>(16.0),
1691            // batch 1: 4x4 column-major (cols 0-3)
1692            f::<M>(101.0), f::<M>(102.0), f::<M>(103.0), f::<M>(104.0),
1693            f::<M>(105.0), f::<M>(106.0), f::<M>(107.0), f::<M>(108.0),
1694            f::<M>(109.0), f::<M>(110.0), f::<M>(111.0), f::<M>(112.0),
1695            f::<M>(113.0), f::<M>(114.0), f::<M>(115.0), f::<M>(116.0),
1696        ];
1697        let m = M::from_vec(4, 4, data, ctx.clone());
1698
1699        let alg_indices = <M::V as Vector>::Index::from_vec(vec![1, 3], Default::default());
1700        let [(ul, _), (ur, _), (ll, _), (lr, _)] = m.split(&alg_indices);
1701
1702        let recombined = M::combine(&ul, &ur, &ll, &lr, &alg_indices);
1703
1704        let (_orig_idx, orig_vals) = m.triplet_iter();
1705        let (_recom_idx, recom_vals) = recombined.triplet_iter();
1706        let orig_vals: Vec<_> = orig_vals.collect();
1707        let recom_vals: Vec<_> = recom_vals.collect();
1708        assert_eq!(orig_vals, recom_vals);
1709    }
1710
1711    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1712    pub fn test_batched_add_column_to_vector_m<M: Matrix>(ctx: M::C) {
1713        assert_eq!(ctx.nbatch(), 2);
1714        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1715        let values = vec![
1716            f::<M>(1.0),
1717            f::<M>(2.0),
1718            f::<M>(3.0),
1719            f::<M>(4.0),
1720            f::<M>(5.0),
1721            f::<M>(6.0),
1722            f::<M>(7.0),
1723            f::<M>(8.0),
1724        ];
1725        let mat = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
1726        let mut v = M::V::zeros(2, ctx);
1727        mat.add_column_to_vector(1, &mut v);
1728        assert_eq!(
1729            v.clone_as_vec(),
1730            vec![f::<M>(3.0), f::<M>(4.0), f::<M>(7.0), f::<M>(8.0)]
1731        );
1732    }
1733
1734    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1735    pub fn test_batched_set_data_with_indices_m<M: Matrix>(ctx: M::C) {
1736        assert_eq!(ctx.nbatch(), 2);
1737        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1738        let zero_values = vec![
1739            f::<M>(0.0),
1740            f::<M>(0.0),
1741            f::<M>(0.0),
1742            f::<M>(0.0),
1743            f::<M>(0.0),
1744            f::<M>(0.0),
1745            f::<M>(0.0),
1746            f::<M>(0.0),
1747        ];
1748        let mut mat = M::try_from_triplets(2, 2, indices, zero_values, ctx.clone()).unwrap();
1749        let dst_indices = <M::V as Vector>::Index::from_vec(vec![0, 3], Default::default());
1750        let src_indices = <M::V as Vector>::Index::from_vec(vec![0, 1], Default::default());
1751        let data = M::V::from_vec(
1752            vec![f::<M>(5.0), f::<M>(6.0), f::<M>(50.0), f::<M>(60.0)],
1753            ctx,
1754        );
1755        mat.set_data_with_indices(&dst_indices, &src_indices, &data);
1756        let (_, vals) = mat.triplet_iter();
1757        let vals: Vec<_> = vals.collect();
1758        assert_eq!(
1759            vals,
1760            vec![
1761                f::<M>(5.0),
1762                f::<M>(0.0),
1763                f::<M>(0.0),
1764                f::<M>(6.0),
1765                f::<M>(50.0),
1766                f::<M>(0.0),
1767                f::<M>(0.0),
1768                f::<M>(60.0),
1769            ]
1770        );
1771    }
1772
1773    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1774    pub fn test_batched_gather_m<M: Matrix>(ctx: M::C) {
1775        assert_eq!(ctx.nbatch(), 2);
1776        let indices: Vec<(IndexType, IndexType)> =
1777            (0..3).flat_map(|j| (0..3).map(move |i| (i, j))).collect();
1778        let values = vec![
1779            f::<M>(1.0),
1780            f::<M>(2.0),
1781            f::<M>(3.0),
1782            f::<M>(4.0),
1783            f::<M>(5.0),
1784            f::<M>(6.0),
1785            f::<M>(7.0),
1786            f::<M>(8.0),
1787            f::<M>(9.0),
1788            f::<M>(10.0),
1789            f::<M>(20.0),
1790            f::<M>(30.0),
1791            f::<M>(40.0),
1792            f::<M>(50.0),
1793            f::<M>(60.0),
1794            f::<M>(70.0),
1795            f::<M>(80.0),
1796            f::<M>(90.0),
1797        ];
1798        let mat1 = M::try_from_triplets(3, 3, indices, values, ctx.clone()).unwrap();
1799        let dest_indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1800        let zero_values = vec![
1801            f::<M>(0.0),
1802            f::<M>(0.0),
1803            f::<M>(0.0),
1804            f::<M>(0.0),
1805            f::<M>(0.0),
1806            f::<M>(0.0),
1807            f::<M>(0.0),
1808            f::<M>(0.0),
1809        ];
1810        let mut mat2 = M::try_from_triplets(2, 2, dest_indices, zero_values, ctx).unwrap();
1811        let gather_indices =
1812            <M::V as Vector>::Index::from_vec(vec![0, 1, 3, 4], Default::default());
1813        mat2.gather(&mat1, &gather_indices);
1814        let (_, vals) = mat2.triplet_iter();
1815        let vals: Vec<_> = vals.collect();
1816        assert_eq!(
1817            vals,
1818            vec![
1819                f::<M>(1.0),
1820                f::<M>(2.0),
1821                f::<M>(4.0),
1822                f::<M>(5.0),
1823                f::<M>(10.0),
1824                f::<M>(20.0),
1825                f::<M>(40.0),
1826                f::<M>(50.0),
1827            ]
1828        );
1829    }
1830
1831    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1832    pub fn test_batched_mul_scalar_m<M: Matrix>(ctx: M::C) {
1833        assert_eq!(ctx.nbatch(), 2);
1834        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
1835        let values = vec![
1836            f::<M>(1.0),
1837            f::<M>(3.0),
1838            f::<M>(2.0),
1839            f::<M>(4.0),
1840            f::<M>(5.0),
1841            f::<M>(7.0),
1842            f::<M>(6.0),
1843            f::<M>(8.0),
1844        ];
1845        let a = M::try_from_triplets(2, 2, indices, values, ctx.clone()).unwrap();
1846        let result = a * Scale(f::<M>(2.0));
1847        let (_, vals) = result.triplet_iter();
1848        let vals: Vec<_> = vals.collect();
1849        assert_eq!(
1850            vals,
1851            vec![
1852                f::<M>(2.0),
1853                f::<M>(6.0),
1854                f::<M>(4.0),
1855                f::<M>(8.0),
1856                f::<M>(10.0),
1857                f::<M>(14.0),
1858                f::<M>(12.0),
1859                f::<M>(16.0),
1860            ]
1861        );
1862    }
1863
1864    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1865    pub fn test_batched_partition_indices<M: Matrix>(ctx: M::C) {
1866        assert_eq!(ctx.nbatch(), 2);
1867        let zero_val = M::T::zero();
1868        let one_val = f::<M>(1.0);
1869        let two_val = f::<M>(2.0);
1870        let indices = vec![(0, 0), (1, 1), (2, 2)];
1871        let values = vec![one_val, zero_val, one_val, two_val, zero_val, two_val];
1872        let m = M::try_from_triplets(3, 3, indices, values, ctx).unwrap();
1873        let (zero_idx, nonzero_idx) = m.partition_indices_by_zero_diagonal();
1874        assert_eq!(zero_idx.clone_as_vec(), vec![1]);
1875        assert_eq!(nonzero_idx.clone_as_vec(), vec![0, 2]);
1876    }
1877
1878    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1879    pub fn test_batched_column_axpy<M: DenseMatrix>(ctx: M::C) {
1880        assert_eq!(ctx.nbatch(), 2);
1881        let mut a = M::from_vec(
1882            2,
1883            2,
1884            vec![
1885                f::<M>(1.0),
1886                f::<M>(3.0),
1887                f::<M>(2.0),
1888                f::<M>(4.0),
1889                f::<M>(5.0),
1890                f::<M>(7.0),
1891                f::<M>(6.0),
1892                f::<M>(8.0),
1893            ],
1894            ctx,
1895        );
1896        a.column_axpy(f::<M>(2.0), 0, 1);
1897        assert_eq!(a.get_index(0, 0), f::<M>(1.0));
1898        assert_eq!(a.get_index(0, 1), f::<M>(4.0));
1899        assert_eq!(a.get_index(1, 0), f::<M>(3.0));
1900        assert_eq!(a.get_index(1, 1), f::<M>(10.0));
1901    }
1902
1903    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1904    pub fn test_batched_mat_mul<M: DenseMatrix>(ctx: M::C) {
1905        assert_eq!(ctx.nbatch(), 2);
1906        let a = M::from_vec(
1907            2,
1908            2,
1909            vec![
1910                f::<M>(1.0),
1911                f::<M>(3.0),
1912                f::<M>(2.0),
1913                f::<M>(4.0),
1914                f::<M>(2.0),
1915                f::<M>(1.0),
1916                f::<M>(0.0),
1917                f::<M>(3.0),
1918            ],
1919            ctx.clone(),
1920        );
1921        let b = M::from_vec(
1922            2,
1923            2,
1924            vec![
1925                f::<M>(2.0),
1926                f::<M>(1.0),
1927                f::<M>(0.0),
1928                f::<M>(3.0),
1929                f::<M>(1.0),
1930                f::<M>(0.0),
1931                f::<M>(2.0),
1932                f::<M>(1.0),
1933            ],
1934            ctx.clone(),
1935        );
1936        let c = a.mat_mul(&b);
1937        assert_eq!(c.get_index(0, 0), f::<M>(4.0));
1938        assert_eq!(c.get_index(1, 0), f::<M>(10.0));
1939        assert_eq!(c.get_index(0, 1), f::<M>(6.0));
1940        assert_eq!(c.get_index(1, 1), f::<M>(12.0));
1941    }
1942
1943    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1944    pub fn test_batched_from_diagonal_dense<M: DenseMatrix>(ctx: M::C) {
1945        assert_eq!(ctx.nbatch(), 2);
1946        let v = M::V::from_vec(
1947            vec![f::<M>(2.0), f::<M>(3.0), f::<M>(4.0), f::<M>(5.0)],
1948            ctx,
1949        );
1950        let a = M::from_diagonal(&v);
1951        assert_eq!(a.nrows(), 2);
1952        assert_eq!(a.ncols(), 2);
1953        assert_eq!(a.get_index(0, 0), f::<M>(2.0));
1954        assert_eq!(a.get_index(1, 1), f::<M>(3.0));
1955        assert_eq!(a.get_index(0, 1), f::<M>(0.0));
1956        assert_eq!(a.get_index(1, 0), f::<M>(0.0));
1957    }
1958
1959    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1960    fn make_strided_matrix<M: DenseMatrix>(nbatch: usize) -> M {
1961        let ctx = M::C::default().clone_with_nbatch(nbatch).unwrap();
1962        let nrows = 3;
1963        let ncols = 4;
1964        let mut data = Vec::with_capacity(nrows * ncols * nbatch);
1965        for b in 0..nbatch {
1966            for col in 0..ncols {
1967                for row in 0..nrows {
1968                    data.push(f::<M>(row as f64 + col as f64 * 10.0 + b as f64 * 100.0));
1969                }
1970            }
1971        }
1972        M::from_vec(nrows, ncols, data, ctx)
1973    }
1974
1975    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1976    pub fn test_strided_matrix_view_into_owned<M: DenseMatrix>(ctx: M::C) {
1977        let matrix = make_strided_matrix::<M>(ctx.nbatch());
1978        let view = matrix.columns(0, 2);
1979        let owned = view.into_owned();
1980        assert_eq!(owned.nrows(), 3);
1981        assert_eq!(owned.ncols(), 2);
1982        // column 0, batch 0: [0,1,2]
1983        assert_eq!(owned.get_index(0, 0), f::<M>(0.0));
1984        assert_eq!(owned.get_index(1, 0), f::<M>(1.0));
1985        assert_eq!(owned.get_index(2, 0), f::<M>(2.0));
1986        // column 1, batch 0: [10,11,12]
1987        assert_eq!(owned.get_index(0, 1), f::<M>(10.0));
1988        assert_eq!(owned.get_index(1, 1), f::<M>(11.0));
1989        assert_eq!(owned.get_index(2, 1), f::<M>(12.0));
1990    }
1991
1992    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
1993    pub fn test_strided_matrix_view_add_owned<M: DenseMatrix>(ctx: M::C) {
1994        let matrix = make_strided_matrix::<M>(ctx.nbatch());
1995        let view = matrix.columns(0, 2);
1996        // owned 3x2 with nbatch=1 (broadcast) — column-major: col0=[1,2,3], col1=[4,5,6]
1997        let rhs = M::from_vec(
1998            3,
1999            2,
2000            vec![
2001                f::<M>(1.0),
2002                f::<M>(2.0),
2003                f::<M>(3.0),
2004                f::<M>(4.0),
2005                f::<M>(5.0),
2006                f::<M>(6.0),
2007            ],
2008            M::C::default(),
2009        );
2010        let result = view + &rhs;
2011        // batch0: [0,1,2,10,11,12] + [1,2,3,4,5,6] = [1,3,5,14,16,18]
2012        assert_eq!(result.get_index(0, 0), f::<M>(1.0));
2013        assert_eq!(result.get_index(0, 1), f::<M>(14.0));
2014    }
2015
2016    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2017    pub fn test_strided_matrix_view_sub_owned<M: DenseMatrix>(ctx: M::C) {
2018        let matrix = make_strided_matrix::<M>(ctx.nbatch());
2019        let view = matrix.columns(0, 2);
2020        let rhs = M::from_vec(
2021            3,
2022            2,
2023            vec![
2024                f::<M>(0.0),
2025                f::<M>(1.0),
2026                f::<M>(2.0),
2027                f::<M>(10.0),
2028                f::<M>(11.0),
2029                f::<M>(12.0),
2030            ],
2031            M::C::default(),
2032        );
2033        let result = view - &rhs;
2034        // batch0: [0,1,2,10,11,12] - [0,1,2,10,11,12] = all zeros
2035        assert_eq!(result.get_index(0, 0), f::<M>(0.0));
2036        assert_eq!(result.get_index(0, 1), f::<M>(0.0));
2037    }
2038
2039    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2040    pub fn test_strided_matrix_view_mul_scalar<M: DenseMatrix>(ctx: M::C) {
2041        let matrix = make_strided_matrix::<M>(ctx.nbatch());
2042        let view = matrix.columns(0, 2);
2043        let result = view * Scale(f::<M>(2.0));
2044        assert_eq!(result.get_index(0, 0), f::<M>(0.0));
2045        assert_eq!(result.get_index(1, 0), f::<M>(2.0));
2046        assert_eq!(result.get_index(0, 1), f::<M>(20.0));
2047    }
2048
2049    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2050    pub fn test_strided_matrix_view_mut_add_assign_view<M: DenseMatrix>(ctx: M::C) {
2051        let mut a = make_strided_matrix::<M>(ctx.nbatch());
2052        let b = make_strided_matrix::<M>(ctx.nbatch());
2053        {
2054            let mut a_view = a.columns_mut(0, 2);
2055            let b_view = b.columns(2, 4);
2056            a_view += &b_view;
2057        }
2058        // a columns 0-1 now = original a[0..2] + b[2..4]
2059        // batch0 a[0..2]: [[0,10],[1,11],[2,12]]
2060        // batch0 b[2..4]: [[20,30],[21,31],[22,32]]
2061        // sum: [[20,40],[22,42],[24,44]]
2062        assert_eq!(a.get_index(0, 0), f::<M>(20.0));
2063        assert_eq!(a.get_index(1, 0), f::<M>(22.0));
2064        assert_eq!(a.get_index(0, 1), f::<M>(40.0));
2065    }
2066
2067    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2068    pub fn test_strided_matrix_view_mut_sub_assign_view<M: DenseMatrix>(ctx: M::C) {
2069        let mut a = make_strided_matrix::<M>(ctx.nbatch());
2070        let b = make_strided_matrix::<M>(ctx.nbatch());
2071        {
2072            let mut a_view = a.columns_mut(0, 2);
2073            let b_view = b.columns(0, 2);
2074            a_view -= &b_view;
2075        }
2076        // same columns subtracted = all zero
2077        assert_eq!(a.get_index(0, 0), f::<M>(0.0));
2078        assert_eq!(a.get_index(1, 0), f::<M>(0.0));
2079    }
2080
2081    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2082    pub fn test_strided_matrix_view_mut_mul_assign_scalar<M: DenseMatrix>(ctx: M::C) {
2083        let mut a = make_strided_matrix::<M>(ctx.nbatch());
2084        {
2085            let mut a_view = a.columns_mut(0, 2);
2086            a_view *= Scale(f::<M>(2.0));
2087        }
2088        assert_eq!(a.get_index(0, 0), f::<M>(0.0));
2089        assert_eq!(a.get_index(1, 0), f::<M>(2.0));
2090        assert_eq!(a.get_index(0, 1), f::<M>(20.0));
2091    }
2092
2093    // --- View-mut tests (into_owned, gemm_oo, += / -= between two mutable views) ---
2094
2095    pub fn test_view_mut_into_owned<M: DenseMatrix>() {
2096        let mut a = M::from_vec(
2097            2,
2098            3,
2099            vec![
2100                f::<M>(1.0),
2101                f::<M>(2.0),
2102                f::<M>(3.0),
2103                f::<M>(4.0),
2104                f::<M>(5.0),
2105                f::<M>(6.0),
2106            ],
2107            Default::default(),
2108        );
2109        let owned = a.columns_mut(0, 2).into_owned();
2110        assert_eq!(owned.nrows(), 2);
2111        assert_eq!(owned.ncols(), 2);
2112        assert_eq!(owned.get_index(0, 0), f::<M>(1.0));
2113        assert_eq!(owned.get_index(1, 0), f::<M>(2.0));
2114        assert_eq!(owned.get_index(0, 1), f::<M>(3.0));
2115        assert_eq!(owned.get_index(1, 1), f::<M>(4.0));
2116    }
2117
2118    pub fn test_view_mut_add_assign_view_mut<M: DenseMatrix>() {
2119        let mut a = M::from_vec(
2120            2,
2121            2,
2122            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
2123            Default::default(),
2124        );
2125        let mut b = M::from_vec(
2126            2,
2127            2,
2128            vec![f::<M>(10.0), f::<M>(30.0), f::<M>(20.0), f::<M>(40.0)],
2129            Default::default(),
2130        );
2131        {
2132            let mut a_view = a.columns_mut(0, 2);
2133            let b_view = b.columns_mut(0, 2);
2134            a_view += &b_view;
2135        }
2136        assert_eq!(a.get_index(0, 0), f::<M>(11.0));
2137        assert_eq!(a.get_index(1, 0), f::<M>(33.0));
2138        assert_eq!(a.get_index(0, 1), f::<M>(22.0));
2139        assert_eq!(a.get_index(1, 1), f::<M>(44.0));
2140    }
2141
2142    pub fn test_view_mut_sub_assign_view_mut<M: DenseMatrix>() {
2143        let mut a = M::from_vec(
2144            2,
2145            2,
2146            vec![f::<M>(10.0), f::<M>(30.0), f::<M>(20.0), f::<M>(40.0)],
2147            Default::default(),
2148        );
2149        let mut b = M::from_vec(
2150            2,
2151            2,
2152            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
2153            Default::default(),
2154        );
2155        {
2156            let mut a_view = a.columns_mut(0, 2);
2157            let b_view = b.columns_mut(0, 2);
2158            a_view -= &b_view;
2159        }
2160        assert_eq!(a.get_index(0, 0), f::<M>(9.0));
2161        assert_eq!(a.get_index(1, 0), f::<M>(27.0));
2162        assert_eq!(a.get_index(0, 1), f::<M>(18.0));
2163        assert_eq!(a.get_index(1, 1), f::<M>(36.0));
2164    }
2165
2166    pub fn test_gemm_oo_on_columns<M: DenseMatrix>() {
2167        // a = [[1,2],[3,4]] (col-major [1,3,2,4])
2168        let a = M::from_vec(
2169            2,
2170            2,
2171            vec![f::<M>(1.0), f::<M>(3.0), f::<M>(2.0), f::<M>(4.0)],
2172            Default::default(),
2173        );
2174        // b = identity
2175        let b = M::from_vec(
2176            2,
2177            2,
2178            vec![f::<M>(1.0), f::<M>(0.0), f::<M>(0.0), f::<M>(1.0)],
2179            Default::default(),
2180        );
2181        let mut result = M::zeros(2, 3, Default::default());
2182        {
2183            let mut r_view = result.columns_mut(0, 2);
2184            r_view.gemm_oo(f::<M>(1.0), &a, &b, f::<M>(0.0));
2185        }
2186        // result columns 0-1 = a * I = a; column 2 untouched (zero)
2187        assert_eq!(result.get_index(0, 0), f::<M>(1.0));
2188        assert_eq!(result.get_index(1, 0), f::<M>(3.0));
2189        assert_eq!(result.get_index(0, 1), f::<M>(2.0));
2190        assert_eq!(result.get_index(1, 1), f::<M>(4.0));
2191        assert_eq!(result.get_index(0, 2), f::<M>(0.0));
2192        assert_eq!(result.get_index(1, 2), f::<M>(0.0));
2193    }
2194
2195    pub fn test_try_from_triplets_wrong_length<M: Matrix>() {
2196        let indices = vec![(0, 0), (1, 0), (0, 1), (1, 1)];
2197        // one value too few: triggers the length assertion inside try_from_triplets
2198        let values = vec![f::<M>(1.0), f::<M>(2.0), f::<M>(3.0)];
2199        let _ = M::try_from_triplets(2, 2, indices, values, Default::default());
2200    }
2201
2202    // --- Batched view-mut tests ---
2203
2204    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2205    pub fn test_strided_matrix_view_mut_into_owned<M: DenseMatrix>(ctx: M::C) {
2206        let mut matrix = make_strided_matrix::<M>(ctx.nbatch());
2207        let owned = matrix.columns_mut(0, 2).into_owned();
2208        assert_eq!(owned.nrows(), 3);
2209        assert_eq!(owned.ncols(), 2);
2210        // batch 0 col0=[0,1,2], col1=[10,11,12]
2211        assert_eq!(owned.get_index(0, 0), f::<M>(0.0));
2212        assert_eq!(owned.get_index(1, 0), f::<M>(1.0));
2213        assert_eq!(owned.get_index(2, 0), f::<M>(2.0));
2214        assert_eq!(owned.get_index(0, 1), f::<M>(10.0));
2215        assert_eq!(owned.get_index(1, 1), f::<M>(11.0));
2216        assert_eq!(owned.get_index(2, 1), f::<M>(12.0));
2217        // verify both batches via triplet_iter
2218        let (_, vals) = owned.triplet_iter();
2219        let vals: Vec<_> = vals.collect();
2220        assert_eq!(
2221            vals,
2222            vec![
2223                f::<M>(0.0),
2224                f::<M>(1.0),
2225                f::<M>(2.0),
2226                f::<M>(10.0),
2227                f::<M>(11.0),
2228                f::<M>(12.0),
2229                f::<M>(100.0),
2230                f::<M>(101.0),
2231                f::<M>(102.0),
2232                f::<M>(110.0),
2233                f::<M>(111.0),
2234                f::<M>(112.0),
2235            ]
2236        );
2237    }
2238
2239    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2240    pub fn test_batched_view_mut_add_assign_view_mut<M: DenseMatrix>(ctx: M::C) {
2241        assert_eq!(ctx.nbatch(), 2);
2242        // a 2x2 nbatch=2: batch0 [[1,2],[3,4]], batch1 [[5,6],[7,8]]
2243        let mut a = M::from_vec(
2244            2,
2245            2,
2246            vec![
2247                f::<M>(1.0),
2248                f::<M>(3.0),
2249                f::<M>(2.0),
2250                f::<M>(4.0),
2251                f::<M>(5.0),
2252                f::<M>(7.0),
2253                f::<M>(6.0),
2254                f::<M>(8.0),
2255            ],
2256            ctx.clone(),
2257        );
2258        let mut b = M::from_vec(
2259            2,
2260            2,
2261            vec![
2262                f::<M>(10.0),
2263                f::<M>(30.0),
2264                f::<M>(20.0),
2265                f::<M>(40.0),
2266                f::<M>(50.0),
2267                f::<M>(70.0),
2268                f::<M>(60.0),
2269                f::<M>(80.0),
2270            ],
2271            ctx,
2272        );
2273        {
2274            let mut a_view = a.columns_mut(0, 2);
2275            let b_view = b.columns_mut(0, 2);
2276            a_view += &b_view;
2277        }
2278        let (_, vals) = a.triplet_iter();
2279        let vals: Vec<_> = vals.collect();
2280        assert_eq!(
2281            vals,
2282            vec![
2283                f::<M>(11.0),
2284                f::<M>(33.0),
2285                f::<M>(22.0),
2286                f::<M>(44.0),
2287                f::<M>(55.0),
2288                f::<M>(77.0),
2289                f::<M>(66.0),
2290                f::<M>(88.0),
2291            ]
2292        );
2293    }
2294
2295    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2296    pub fn test_batched_view_mut_sub_assign_view_mut<M: DenseMatrix>(ctx: M::C) {
2297        assert_eq!(ctx.nbatch(), 2);
2298        let mut a = M::from_vec(
2299            2,
2300            2,
2301            vec![
2302                f::<M>(10.0),
2303                f::<M>(30.0),
2304                f::<M>(20.0),
2305                f::<M>(40.0),
2306                f::<M>(50.0),
2307                f::<M>(70.0),
2308                f::<M>(60.0),
2309                f::<M>(80.0),
2310            ],
2311            ctx.clone(),
2312        );
2313        let mut b = M::from_vec(
2314            2,
2315            2,
2316            vec![
2317                f::<M>(1.0),
2318                f::<M>(3.0),
2319                f::<M>(2.0),
2320                f::<M>(4.0),
2321                f::<M>(5.0),
2322                f::<M>(7.0),
2323                f::<M>(6.0),
2324                f::<M>(8.0),
2325            ],
2326            ctx,
2327        );
2328        {
2329            let mut a_view = a.columns_mut(0, 2);
2330            let b_view = b.columns_mut(0, 2);
2331            a_view -= &b_view;
2332        }
2333        let (_, vals) = a.triplet_iter();
2334        let vals: Vec<_> = vals.collect();
2335        assert_eq!(
2336            vals,
2337            vec![
2338                f::<M>(9.0),
2339                f::<M>(27.0),
2340                f::<M>(18.0),
2341                f::<M>(36.0),
2342                f::<M>(45.0),
2343                f::<M>(63.0),
2344                f::<M>(54.0),
2345                f::<M>(72.0),
2346            ]
2347        );
2348    }
2349
2350    #[cfg_attr(not(feature = "cuda"), allow(dead_code))]
2351    pub fn test_batched_gemm_oo_on_columns<M: DenseMatrix>(ctx: M::C) {
2352        assert_eq!(ctx.nbatch(), 2);
2353        // a 2x2 nbatch=2: batch0 [[1,2],[3,4]], batch1 [[5,6],[7,8]]
2354        let a = M::from_vec(
2355            2,
2356            2,
2357            vec![
2358                f::<M>(1.0),
2359                f::<M>(3.0),
2360                f::<M>(2.0),
2361                f::<M>(4.0),
2362                f::<M>(5.0),
2363                f::<M>(7.0),
2364                f::<M>(6.0),
2365                f::<M>(8.0),
2366            ],
2367            ctx.clone(),
2368        );
2369        // b 2x2 nbatch=2: batch0 identity, batch1 2*identity
2370        let b = M::from_vec(
2371            2,
2372            2,
2373            vec![
2374                f::<M>(1.0),
2375                f::<M>(0.0),
2376                f::<M>(0.0),
2377                f::<M>(1.0),
2378                f::<M>(2.0),
2379                f::<M>(0.0),
2380                f::<M>(0.0),
2381                f::<M>(2.0),
2382            ],
2383            ctx.clone(),
2384        );
2385        let mut result = M::zeros(2, 3, ctx);
2386        {
2387            let mut r_view = result.columns_mut(0, 2);
2388            r_view.gemm_oo(f::<M>(1.0), &a, &b, f::<M>(0.0));
2389        }
2390        // batch0: a*I = [[1,2],[3,4]]; batch1: a*2I = [[10,12],[14,16]]; col2 = 0
2391        assert_eq!(result.get_index(0, 0), f::<M>(1.0));
2392        assert_eq!(result.get_index(1, 0), f::<M>(3.0));
2393        assert_eq!(result.get_index(0, 1), f::<M>(2.0));
2394        assert_eq!(result.get_index(1, 1), f::<M>(4.0));
2395        let (_, vals) = result.triplet_iter();
2396        let vals: Vec<_> = vals.collect();
2397        assert_eq!(
2398            vals,
2399            vec![
2400                f::<M>(1.0),
2401                f::<M>(3.0),
2402                f::<M>(2.0),
2403                f::<M>(4.0),
2404                f::<M>(0.0),
2405                f::<M>(0.0),
2406                f::<M>(10.0),
2407                f::<M>(14.0),
2408                f::<M>(12.0),
2409                f::<M>(16.0),
2410                f::<M>(0.0),
2411                f::<M>(0.0),
2412            ]
2413        );
2414    }
2415}
2416
// Expands to one `#[test]` function per generic non-batched matrix test in
// `$crate::matrix::tests`, each instantiated for the matrix type `$M`.
// `$suffix` is pasted onto every generated function name (via `paste::paste!`),
// presumably so the macro can be invoked for several matrix types without
// name clashes.
#[cfg(test)]
macro_rules! generate_matrix_tests_nonbatched {
    ($suffix:ident, $M:ty) => {
        paste::paste! {
            #[test]
            fn [<test_zeros_ $suffix>]() {
                $crate::matrix::tests::test_zeros::<$M>();
            }
            #[test]
            fn [<test_from_diagonal_ $suffix>]() {
                $crate::matrix::tests::test_from_diagonal::<$M>();
            }
            #[test]
            fn [<test_gemv_ $suffix>]() {
                $crate::matrix::tests::test_gemv::<$M>();
            }
            #[test]
            fn [<test_set_column_ $suffix>]() {
                $crate::matrix::tests::test_set_column::<$M>();
            }
            #[test]
            fn [<test_copy_from_ $suffix>]() {
                $crate::matrix::tests::test_copy_from::<$M>();
            }
            #[test]
            fn [<test_scale_add_and_assign_ $suffix>]() {
                $crate::matrix::tests::test_scale_add_and_assign::<$M>();
            }
            #[test]
            fn [<test_partition_indices_ $suffix>]() {
                $crate::matrix::tests::test_partition_indices_by_zero_diagonal::<$M>();
            }
            #[test]
            fn [<test_mul_scalar_ $suffix>]() {
                $crate::matrix::tests::test_mul_scalar::<$M>();
            }
            #[test]
            fn [<test_add_column_to_vector_ $suffix>]() {
                $crate::matrix::tests::test_add_column_to_vector::<$M>();
            }
            // This wrapper passes only if the helper panics: it feeds
            // `try_from_triplets` fewer values than index pairs.
            #[test]
            #[should_panic]
            fn [<test_try_from_triplets_wrong_length_ $suffix>]() {
                $crate::matrix::tests::test_try_from_triplets_wrong_length::<$M>();
            }
        }
    };
}
2465
/// Generates the batched matrix test suite for one matrix type.
///
/// Arguments, in order:
/// * `$tag` - identifier appended to every generated test function name.
/// * `$Mat` - matrix type used to instantiate each generic test helper.
/// * `$base_ctx` - context expression; accepted but not used by this macro.
/// * `$batch_ctx` - context expression passed as the sole argument to every
///   helper in `$crate::matrix::tests`.
///
/// `$batch_ctx` is substituted into each generated function, so the
/// expression is evaluated once per test.
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_macros))]
macro_rules! generate_matrix_tests_batched {
    ($tag:ident, $Mat:ty, $base_ctx:expr, $batch_ctx:expr) => {
        paste::paste! {
            // --- zeros / from_diagonal / copy_from ---
            #[test]
            fn [<test_batched_zeros_ $tag>]() {
                $crate::matrix::tests::test_batched_zeros_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_from_diagonal_ $tag>]() {
                $crate::matrix::tests::test_batched_from_diagonal_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_copy_from_ $tag>]() {
                $crate::matrix::tests::test_batched_copy_from_m::<$Mat>($batch_ctx);
            }

            // --- column and indexed access ---
            #[test]
            fn [<test_batched_set_column_ $tag>]() {
                $crate::matrix::tests::test_batched_set_column_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_add_column_to_vector_ $tag>]() {
                $crate::matrix::tests::test_batched_add_column_to_vector_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_set_data_with_indices_ $tag>]() {
                $crate::matrix::tests::test_batched_set_data_with_indices_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gather_ $tag>]() {
                $crate::matrix::tests::test_batched_gather_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_partition_indices_ $tag>]() {
                $crate::matrix::tests::test_batched_partition_indices::<$Mat>($batch_ctx);
            }

            // --- mul_scalar / scale_add ---
            #[test]
            fn [<test_batched_mul_scalar_ $tag>]() {
                $crate::matrix::tests::test_batched_mul_scalar_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_scale_add_ $tag>]() {
                $crate::matrix::tests::test_batched_scale_add_and_assign_m::<$Mat>($batch_ctx);
            }

            // --- gemv ---
            #[test]
            fn [<test_batched_gemv_ $tag>]() {
                $crate::matrix::tests::test_batched_gemv_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemv_broadcast_x_ $tag>]() {
                $crate::matrix::tests::test_batched_gemv_broadcast_x_m::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemv_broadcast_mat_ $tag>]() {
                $crate::matrix::tests::test_batched_gemv_broadcast_mat_m::<$Mat>($batch_ctx);
            }
        }
    };
}
2526
/// Generates the non-batched dense matrix test suite for one matrix type.
///
/// Arguments, in order:
/// * `$tag` - identifier appended to every generated test function name.
/// * `$Mat` - matrix type used to instantiate each generic test helper in
///   `$crate::matrix::tests`.
///
/// Each generated `#[test]` function calls exactly one helper with no
/// arguments.
#[cfg(test)]
macro_rules! generate_dense_matrix_tests_nonbatched {
    ($tag:ident, $Mat:ty) => {
        paste::paste! {
            // --- from_vec / from_diagonal / resize ---
            #[test]
            fn [<test_from_vec_ $tag>]() {
                $crate::matrix::tests::test_from_vec::<$Mat>();
            }
            #[test]
            fn [<test_from_diagonal_dense_ $tag>]() {
                $crate::matrix::tests::test_from_diagonal_dense::<$Mat>();
            }
            #[test]
            fn [<test_resize_cols_ $tag>]() {
                $crate::matrix::tests::test_resize_cols::<$Mat>();
            }

            // --- columns ---
            #[test]
            fn [<test_columns_view_ $tag>]() {
                $crate::matrix::tests::test_columns_view::<$Mat>();
            }
            #[test]
            fn [<test_column_view_ $tag>]() {
                $crate::matrix::tests::test_column_view::<$Mat>();
            }
            #[test]
            fn [<test_column_axpy_ $tag>]() {
                $crate::matrix::tests::test_column_axpy::<$Mat>();
            }

            // --- gemm / mat_mul ---
            #[test]
            fn [<test_gemm_ $tag>]() {
                $crate::matrix::tests::test_gemm::<$Mat>();
            }
            #[test]
            fn [<test_mat_mul_ $tag>]() {
                $crate::matrix::tests::test_mat_mul::<$Mat>();
            }
            #[test]
            fn [<test_gemm_oo_on_columns_ $tag>]() {
                $crate::matrix::tests::test_gemm_oo_on_columns::<$Mat>();
            }

            // --- add / sub / mul operators ---
            #[test]
            fn [<test_add_ $tag>]() {
                $crate::matrix::tests::test_add::<$Mat>();
            }
            #[test]
            fn [<test_sub_ $tag>]() {
                $crate::matrix::tests::test_sub::<$Mat>();
            }
            #[test]
            fn [<test_add_assign_ $tag>]() {
                $crate::matrix::tests::test_add_assign::<$Mat>();
            }
            #[test]
            fn [<test_sub_assign_ $tag>]() {
                $crate::matrix::tests::test_sub_assign::<$Mat>();
            }
            #[test]
            fn [<test_mul_assign_scalar_ $tag>]() {
                $crate::matrix::tests::test_mul_assign_scalar::<$Mat>();
            }

            // --- gather / set_data_with_indices ---
            #[test]
            fn [<test_gather_ $tag>]() {
                $crate::matrix::tests::test_gather::<$Mat>();
            }
            #[test]
            fn [<test_set_data_with_indices_ $tag>]() {
                $crate::matrix::tests::test_set_data_with_indices::<$Mat>();
            }

            // --- view_mut ---
            #[test]
            fn [<test_view_mut_into_owned_ $tag>]() {
                $crate::matrix::tests::test_view_mut_into_owned::<$Mat>();
            }
            #[test]
            fn [<test_view_mut_add_assign_view_mut_ $tag>]() {
                $crate::matrix::tests::test_view_mut_add_assign_view_mut::<$Mat>();
            }
            #[test]
            fn [<test_view_mut_sub_assign_view_mut_ $tag>]() {
                $crate::matrix::tests::test_view_mut_sub_assign_view_mut::<$Mat>();
            }
        }
    };
}
2610
/// Generates the batched dense matrix test suite for one matrix type.
///
/// Arguments, in order:
/// * `$tag` - identifier appended to every generated test function name.
/// * `$Mat` - matrix type used to instantiate each generic test helper in
///   `$crate::matrix::tests`.
/// * `$base_ctx` - context expression; used only by the three
///   `*_incompatible*` tests, which pass `$base_ctx.clone_with_nbatch(3).unwrap()`
///   as a second context.
/// * `$batch_ctx` - context expression passed as the first argument to every
///   helper.
///
/// Both context expressions are substituted into each generated function, so
/// they are evaluated once per test that uses them.
#[cfg(test)]
#[cfg_attr(not(feature = "cuda"), allow(unused_macros))]
macro_rules! generate_dense_matrix_tests_batched {
    ($tag:ident, $Mat:ty, $base_ctx:expr, $batch_ctx:expr) => {
        paste::paste! {
            // --- from_vec / from_diagonal / resize / combine / columns ---
            #[test]
            fn [<test_batched_from_vec_ $tag>]() {
                $crate::matrix::tests::test_batched_from_vec::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_from_diagonal_dense_ $tag>]() {
                $crate::matrix::tests::test_batched_from_diagonal_dense::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_resize_cols_ $tag>]() {
                $crate::matrix::tests::test_batched_resize_cols::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_combine_ $tag>]() {
                $crate::matrix::tests::test_batched_combine::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_columns_ $tag>]() {
                $crate::matrix::tests::test_batched_columns::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_column_axpy_ $tag>]() {
                $crate::matrix::tests::test_batched_column_axpy::<$Mat>($batch_ctx);
            }

            // --- gemv ---
            #[test]
            fn [<test_batched_gemv_o_on_columns_ $tag>]() {
                $crate::matrix::tests::test_batched_gemv_o_on_columns::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemv_o_broadcast_x_ $tag>]() {
                $crate::matrix::tests::test_batched_gemv_o_broadcast_x::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemv_v_broadcast_mat_ $tag>]() {
                $crate::matrix::tests::test_batched_gemv_v_broadcast_mat::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemv_o_broadcast_mat_ $tag>]() {
                $crate::matrix::tests::test_batched_gemv_o_broadcast_mat::<$Mat>($batch_ctx);
            }

            // --- gemm / mat_mul ---
            #[test]
            fn [<test_batched_mat_mul_ $tag>]() {
                $crate::matrix::tests::test_batched_mat_mul::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemm_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemm_vo_on_columns_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm_vo_on_columns::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemm_oo_on_columns_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm_oo_on_columns::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemm_broadcast_a_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm_broadcast_a::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemm_broadcast_b_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm_broadcast_b::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemm_vo_broadcast_a_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm_vo_broadcast_a::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_gemm_vo_broadcast_b_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm_vo_broadcast_b::<$Mat>($batch_ctx);
            }

            // --- expected panics: second context is derived from `$base_ctx`
            // via `clone_with_nbatch(3)` ---
            #[test]
            #[should_panic(expected = "incompatible nbatch")]
            fn [<test_batched_gemv_incompatible_ $tag>]() {
                $crate::matrix::tests::test_batched_gemv_incompatible::<$Mat>(
                    $batch_ctx,
                    $base_ctx.clone_with_nbatch(3).unwrap(),
                );
            }
            #[test]
            #[should_panic(expected = "incompatible nbatch")]
            fn [<test_batched_gemm_incompatible_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm_incompatible::<$Mat>(
                    $batch_ctx,
                    $base_ctx.clone_with_nbatch(3).unwrap(),
                );
            }
            #[test]
            #[should_panic(expected = "incompatible nbatch")]
            fn [<test_batched_gemm_incompatible_a_ $tag>]() {
                $crate::matrix::tests::test_batched_gemm_incompatible_a::<$Mat>(
                    $batch_ctx,
                    $base_ctx.clone_with_nbatch(3).unwrap(),
                );
            }

            // --- strided_matrix_view ---
            #[test]
            fn [<test_strided_matrix_view_into_owned_ $tag>]() {
                $crate::matrix::tests::test_strided_matrix_view_into_owned::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_strided_matrix_view_add_owned_ $tag>]() {
                $crate::matrix::tests::test_strided_matrix_view_add_owned::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_strided_matrix_view_sub_owned_ $tag>]() {
                $crate::matrix::tests::test_strided_matrix_view_sub_owned::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_strided_matrix_view_mul_scalar_ $tag>]() {
                $crate::matrix::tests::test_strided_matrix_view_mul_scalar::<$Mat>($batch_ctx);
            }

            // --- strided_matrix_view_mut ---
            #[test]
            fn [<test_strided_matrix_view_mut_add_assign_view_ $tag>]() {
                $crate::matrix::tests::test_strided_matrix_view_mut_add_assign_view::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_strided_matrix_view_mut_sub_assign_view_ $tag>]() {
                $crate::matrix::tests::test_strided_matrix_view_mut_sub_assign_view::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_strided_matrix_view_mut_mul_assign_scalar_ $tag>]() {
                $crate::matrix::tests::test_strided_matrix_view_mut_mul_assign_scalar::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_strided_matrix_view_mut_into_owned_ $tag>]() {
                $crate::matrix::tests::test_strided_matrix_view_mut_into_owned::<$Mat>($batch_ctx);
            }

            // --- batched view_mut ---
            #[test]
            fn [<test_batched_view_mut_add_assign_view_mut_ $tag>]() {
                $crate::matrix::tests::test_batched_view_mut_add_assign_view_mut::<$Mat>($batch_ctx);
            }
            #[test]
            fn [<test_batched_view_mut_sub_assign_view_mut_ $tag>]() {
                $crate::matrix::tests::test_batched_view_mut_sub_assign_view_mut::<$Mat>($batch_ctx);
            }
        }
    };
}
2746
2747#[cfg(test)]
2748#[cfg_attr(not(feature = "cuda"), allow(unused_imports))]
2749pub(crate) use generate_dense_matrix_tests_batched;
2750#[cfg(test)]
2751pub(crate) use generate_dense_matrix_tests_nonbatched;
2752#[cfg(test)]
2753#[cfg_attr(not(feature = "cuda"), allow(unused_imports))]
2754pub(crate) use generate_matrix_tests_batched;
2755#[cfg(test)]
2756pub(crate) use generate_matrix_tests_nonbatched;