Skip to main content

ronn_core/ops/
matrix.rs

1//! Matrix operations for tensors.
2//!
3//! This module provides matrix operations including matrix multiplication,
4//! transpose, and other linear algebra operations using the Candle backend.
5
6use crate::ops::arithmetic::ArithmeticOps;
7use crate::ops::reduction::ReductionOps;
8use crate::tensor::Tensor;
9use anyhow::{Result, anyhow};
10
11/// Trait for matrix operations on tensors.
12pub trait MatrixOps {
13    /// Matrix multiplication.
14    fn matmul(&self, other: &Tensor) -> Result<Tensor>;
15
16    /// Transpose the tensor (swap last two dimensions).
17    fn transpose(&self) -> Result<Tensor>;
18
19    /// Transpose with specific dimension indices.
20    fn transpose_dims(&self, dim1: usize, dim2: usize) -> Result<Tensor>;
21
22    /// Batch matrix multiplication.
23    fn batch_matmul(&self, other: &Tensor) -> Result<Tensor>;
24}
25
26impl MatrixOps for Tensor {
27    fn matmul(&self, other: &Tensor) -> Result<Tensor> {
28        let self_shape = self.shape();
29        let other_shape = other.shape();
30
31        // Check dimension compatibility for matrix multiplication
32        if self_shape.len() < 2 || other_shape.len() < 2 {
33            return Err(anyhow!(
34                "Matrix multiplication requires at least 2D tensors, got shapes {:?} and {:?}",
35                self_shape,
36                other_shape
37            ));
38        }
39
40        let _self_rows = self_shape[self_shape.len() - 2];
41        let self_cols = self_shape[self_shape.len() - 1];
42        let other_rows = other_shape[other_shape.len() - 2];
43        let _other_cols = other_shape[other_shape.len() - 1];
44
45        if self_cols != other_rows {
46            return Err(anyhow!(
47                "Incompatible dimensions for matrix multiplication: {} vs {}",
48                self_cols,
49                other_rows
50            ));
51        }
52
53        let result_candle = self.candle_tensor().matmul(other.candle_tensor())?;
54
55        Ok(Tensor::from_candle(
56            result_candle,
57            self.dtype(),
58            self.layout(),
59        ))
60    }
61
62    fn transpose(&self) -> Result<Tensor> {
63        let shape = self.shape();
64        if shape.len() < 2 {
65            return Err(anyhow!(
66                "Transpose requires at least 2D tensor, got shape {:?}",
67                shape
68            ));
69        }
70
71        let dim1 = shape.len() - 2;
72        let dim2 = shape.len() - 1;
73        self.transpose_dims(dim1, dim2)
74    }
75
76    fn transpose_dims(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
77        let shape = self.shape();
78
79        if dim1 >= shape.len() || dim2 >= shape.len() {
80            return Err(anyhow!(
81                "Transpose dimensions {} and {} out of bounds for tensor with {} dimensions",
82                dim1,
83                dim2,
84                shape.len()
85            ));
86        }
87
88        let result_candle = self.candle_tensor().transpose(dim1, dim2)?;
89
90        Ok(Tensor::from_candle(
91            result_candle,
92            self.dtype(),
93            self.layout(),
94        ))
95    }
96
97    fn batch_matmul(&self, other: &Tensor) -> Result<Tensor> {
98        let self_shape = self.shape();
99        let other_shape = other.shape();
100
101        // For batch matrix multiplication, we need at least 3D tensors
102        if self_shape.len() < 3 || other_shape.len() < 3 {
103            // If one tensor is 2D, we can broadcast it
104            return self.matmul(other);
105        }
106
107        // Check batch dimensions are compatible
108        let self_batch = &self_shape[..self_shape.len() - 2];
109        let other_batch = &other_shape[..other_shape.len() - 2];
110
111        if self_batch != other_batch {
112            return Err(anyhow!(
113                "Incompatible batch dimensions for batch matrix multiplication: {:?} vs {:?}",
114                self_batch,
115                other_batch
116            ));
117        }
118
119        let result_candle = self.candle_tensor().matmul(other.candle_tensor())?;
120
121        Ok(Tensor::from_candle(
122            result_candle,
123            self.dtype(),
124            self.layout(),
125        ))
126    }
127}
128
129/// Additional matrix operations as methods on Tensor.
130impl Tensor {
131    /// Compute the trace (sum of diagonal elements) of a 2D tensor.
132    pub fn trace(&self) -> Result<Tensor> {
133        let shape = self.shape();
134        if shape.len() != 2 || shape[0] != shape[1] {
135            return Err(anyhow!(
136                "Trace requires a square 2D tensor, got shape {:?}",
137                shape
138            ));
139        }
140
141        let diag = self.diagonal()?;
142        diag.sum_all()
143    }
144
145    /// Get the diagonal elements of a 2D tensor.
146    pub fn diagonal(&self) -> Result<Tensor> {
147        let shape = self.shape();
148        if shape.len() < 2 {
149            return Err(anyhow!(
150                "Diagonal requires at least 2D tensor, got shape {:?}",
151                shape
152            ));
153        }
154
155        // For 2D matrices, manually extract diagonal elements
156        if shape.len() == 2 {
157            let data = self.to_vec()?;
158            let rows = shape[0];
159            let cols = shape[1];
160            let min_dim = rows.min(cols);
161
162            let mut diag_data = Vec::with_capacity(min_dim);
163            for i in 0..min_dim {
164                diag_data.push(data[i * cols + i]);
165            }
166
167            return Ok(Tensor::from_data(
168                diag_data,
169                vec![min_dim],
170                self.dtype(),
171                self.layout(),
172            )?);
173        }
174
175        // For higher dimensions, this is more complex - placeholder implementation
176        Err(anyhow!(
177            "Diagonal extraction for >2D tensors not yet implemented"
178        ))
179    }
180
181    /// Create an identity matrix of given size.
182    pub fn eye(
183        size: usize,
184        dtype: crate::types::DataType,
185        layout: crate::types::TensorLayout,
186    ) -> Result<Tensor> {
187        use candle_core::Device;
188
189        let device = Device::Cpu;
190        let candle_tensor = candle_core::Tensor::eye(size, dtype_to_candle(&dtype)?, &device)?;
191
192        Ok(Tensor::from_candle(candle_tensor, dtype, layout))
193    }
194
195    /// Compute the determinant of a 2D square tensor.
196    pub fn det(&self) -> Result<Tensor> {
197        let shape = self.shape();
198        if shape.len() != 2 || shape[0] != shape[1] {
199            return Err(anyhow!(
200                "Determinant requires a square 2D tensor, got shape {:?}",
201                shape
202            ));
203        }
204
205        // For small matrices, we can compute determinant directly
206        match shape[0] {
207            1 => {
208                let data = self.to_vec()?;
209                Ok(Tensor::from_data(
210                    vec![data[0]],
211                    vec![1],
212                    self.dtype(),
213                    self.layout(),
214                )?)
215            }
216            2 => {
217                let data = self.to_vec()?;
218                let det = data[0] * data[3] - data[1] * data[2];
219                Ok(Tensor::from_data(
220                    vec![det],
221                    vec![1],
222                    self.dtype(),
223                    self.layout(),
224                )?)
225            }
226            _ => {
227                // For larger matrices, we'd need more complex algorithms like LU decomposition
228                // This is a placeholder implementation
229                Err(anyhow!(
230                    "Determinant calculation for {}x{} matrices not yet implemented",
231                    shape[0],
232                    shape[1]
233                ))
234            }
235        }
236    }
237
238    /// Compute matrix inverse (for 2x2 matrices only in this implementation).
239    pub fn inverse(&self) -> Result<Tensor> {
240        let shape = self.shape();
241        if shape.len() != 2 || shape[0] != shape[1] {
242            return Err(anyhow!(
243                "Inverse requires a square 2D tensor, got shape {:?}",
244                shape
245            ));
246        }
247
248        if shape[0] != 2 {
249            return Err(anyhow!(
250                "Matrix inverse only implemented for 2x2 matrices, got {}x{}",
251                shape[0],
252                shape[1]
253            ));
254        }
255
256        let data = self.to_vec()?;
257        let a = data[0];
258        let b = data[1];
259        let c = data[2];
260        let d = data[3];
261
262        let det = a * d - b * c;
263        if det.abs() < 1e-10 {
264            return Err(anyhow!("Matrix is singular (determinant ≈ 0)"));
265        }
266
267        let inv_det = 1.0 / det;
268        let inv_data = vec![d * inv_det, -b * inv_det, -c * inv_det, a * inv_det];
269
270        Ok(Tensor::from_data(
271            inv_data,
272            vec![2, 2],
273            self.dtype(),
274            self.layout(),
275        )?)
276    }
277
278    /// Compute the Frobenius norm of the tensor.
279    pub fn frobenius_norm(&self) -> Result<Tensor> {
280        let squared = self.mul(self)?;
281        let sum = squared.sum_all()?;
282        let sqrt_result = sum.sqrt()?;
283
284        // Ensure result is at least 1D
285        let sqrt_candle = sqrt_result.candle_tensor();
286        let reshaped = if sqrt_candle.dims().is_empty() {
287            sqrt_candle.reshape(&[1])?
288        } else {
289            sqrt_candle.clone()
290        };
291
292        Ok(Tensor::from_candle(
293            reshaped,
294            sqrt_result.dtype(),
295            sqrt_result.layout(),
296        ))
297    }
298
299    /// Create a tensor with ones on the diagonal and zeros elsewhere.
300    pub fn diag_embed(&self) -> Result<Tensor> {
301        let shape = self.shape();
302        if shape.len() != 1 {
303            return Err(anyhow!(
304                "diag_embed requires a 1D tensor, got shape {:?}",
305                shape
306            ));
307        }
308
309        let n = shape[0];
310        let mut diag_data = vec![0.0; n * n];
311        let data = self.to_vec()?;
312
313        for i in 0..n {
314            diag_data[i * n + i] = data[i];
315        }
316
317        Ok(Tensor::from_data(
318            diag_data,
319            vec![n, n],
320            self.dtype(),
321            self.layout(),
322        )?)
323    }
324}
325
326/// Convert RONN DataType to Candle DType (helper function).
327fn dtype_to_candle(dtype: &crate::types::DataType) -> Result<candle_core::DType> {
328    use crate::types::DataType;
329    use candle_core::DType;
330
331    match dtype {
332        DataType::F32 => Ok(DType::F32),
333        DataType::F16 => Ok(DType::F16),
334        DataType::BF16 => Ok(DType::BF16),
335        DataType::F64 => Ok(DType::F64),
336        DataType::U8 => Ok(DType::U8),
337        DataType::U32 => Ok(DType::U32),
338        // For unsupported types, use F32
339        DataType::I8 | DataType::I32 | DataType::I64 | DataType::Bool => Ok(DType::F32),
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, TensorLayout};

    /// Builds the row-major F32 matrix `[[1, 2], [3, 4]]` shared by several tests.
    fn sample_2x2() -> Result<Tensor> {
        Ok(Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?)
    }

    #[test]
    fn test_matrix_multiplication() -> Result<()> {
        let lhs = sample_2x2()?;
        let rhs = Tensor::from_data(
            vec![2.0, 0.0, 1.0, 1.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let product = lhs.matmul(&rhs)?;

        // Expected: [[1*2+2*1, 1*0+2*1], [3*2+4*1, 3*0+4*1]] = [[4, 2], [10, 4]]
        assert_eq!(product.to_vec()?, vec![4.0, 2.0, 10.0, 4.0]);

        Ok(())
    }

    #[test]
    fn test_transpose() -> Result<()> {
        let matrix = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let transposed = MatrixOps::transpose(&matrix)?;
        assert_eq!(transposed.shape(), vec![3, 2]);

        // Expected: [[1, 4], [2, 5], [3, 6]] (flattened: [1, 4, 2, 5, 3, 6])
        assert_eq!(transposed.to_vec()?, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        Ok(())
    }

    #[test]
    fn test_identity_matrix() -> Result<()> {
        let identity = Tensor::eye(3, DataType::F32, TensorLayout::RowMajor)?;

        assert_eq!(identity.shape(), vec![3, 3]);
        assert_eq!(
            identity.to_vec()?,
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        );

        Ok(())
    }

    #[test]
    fn test_diagonal() -> Result<()> {
        let diag = sample_2x2()?.diagonal()?;

        assert_eq!(diag.to_vec()?, vec![1.0, 4.0]);

        Ok(())
    }

    #[test]
    fn test_determinant_2x2() -> Result<()> {
        let det = sample_2x2()?.det()?;
        let det_data = det.to_vec()?;

        // det([[1, 2], [3, 4]]) = 1*4 - 2*3 = -2
        assert!((det_data[0] + 2.0).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_matrix_inverse_2x2() -> Result<()> {
        let inv = sample_2x2()?.inverse()?;
        let inv_data = inv.to_vec()?;

        // inv([[1, 2], [3, 4]]) = (1/(-2)) * [[4, -2], [-3, 1]] = [[-2, 1], [1.5, -0.5]]
        assert!((inv_data[0] + 2.0).abs() < 1e-6);
        assert!((inv_data[1] - 1.0).abs() < 1e-6);
        assert!((inv_data[2] - 1.5).abs() < 1e-6);
        assert!((inv_data[3] + 0.5).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_trace() -> Result<()> {
        let trace = sample_2x2()?.trace()?;
        let trace_data = trace.to_vec()?;

        // trace([[1, 2], [3, 4]]) = 1 + 4 = 5
        assert_eq!(trace_data[0], 5.0);

        Ok(())
    }

    #[test]
    fn test_diag_embed() -> Result<()> {
        let vector = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let embedded = vector.diag_embed()?;

        assert_eq!(embedded.shape(), vec![3, 3]);
        assert_eq!(
            embedded.to_vec()?,
            vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]
        );

        Ok(())
    }

    #[test]
    fn test_frobenius_norm() -> Result<()> {
        let vector = Tensor::from_data(
            vec![3.0, 4.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let norm = vector.frobenius_norm()?;
        let norm_data = norm.to_vec()?;

        // ||[3, 4]||_F = sqrt(3^2 + 4^2) = sqrt(25) = 5
        assert_eq!(norm_data[0], 5.0);

        Ok(())
    }

    #[test]
    fn test_batch_matmul() -> Result<()> {
        // Two batches of 3x2 matrices.
        let lhs = Tensor::from_data(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![2, 3, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        // Batch 0 is the identity, batch 1 is twice the identity.
        let rhs = Tensor::from_data(
            vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0],
            vec![2, 2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let product = lhs.batch_matmul(&rhs)?;
        assert_eq!(product.shape(), vec![2, 3, 2]);

        // Batch 0 is unchanged; every element of batch 1 is doubled.
        assert_eq!(
            product.to_vec()?,
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0,
            ]
        );

        Ok(())
    }

    #[test]
    fn test_error_handling() {
        // matmul rejects operands with fewer than two dimensions
        let vec_a = Tensor::from_data(
            vec![1.0, 2.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        let vec_b = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(vec_a.matmul(&vec_b).is_err());

        // Test transpose on 1D tensor
        assert!(MatrixOps::transpose(&vec_a).is_err());

        // Test invalid transpose dimensions
        let square = sample_2x2().unwrap();
        assert!(square.transpose_dims(5, 6).is_err());

        // Test inverse on non-square matrix
        let rectangular = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(rectangular.inverse().is_err());

        // Test singular matrix inverse
        let singular = Tensor::from_data(
            vec![1.0, 2.0, 2.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(singular.inverse().is_err());
    }
}