1use crate::ops::arithmetic::ArithmeticOps;
7use crate::ops::reduction::ReductionOps;
8use crate::tensor::Tensor;
9use anyhow::{Result, anyhow};
10
11pub trait MatrixOps {
13 fn matmul(&self, other: &Tensor) -> Result<Tensor>;
15
16 fn transpose(&self) -> Result<Tensor>;
18
19 fn transpose_dims(&self, dim1: usize, dim2: usize) -> Result<Tensor>;
21
22 fn batch_matmul(&self, other: &Tensor) -> Result<Tensor>;
24}
25
26impl MatrixOps for Tensor {
27 fn matmul(&self, other: &Tensor) -> Result<Tensor> {
28 let self_shape = self.shape();
29 let other_shape = other.shape();
30
31 if self_shape.len() < 2 || other_shape.len() < 2 {
33 return Err(anyhow!(
34 "Matrix multiplication requires at least 2D tensors, got shapes {:?} and {:?}",
35 self_shape,
36 other_shape
37 ));
38 }
39
40 let _self_rows = self_shape[self_shape.len() - 2];
41 let self_cols = self_shape[self_shape.len() - 1];
42 let other_rows = other_shape[other_shape.len() - 2];
43 let _other_cols = other_shape[other_shape.len() - 1];
44
45 if self_cols != other_rows {
46 return Err(anyhow!(
47 "Incompatible dimensions for matrix multiplication: {} vs {}",
48 self_cols,
49 other_rows
50 ));
51 }
52
53 let result_candle = self.candle_tensor().matmul(other.candle_tensor())?;
54
55 Ok(Tensor::from_candle(
56 result_candle,
57 self.dtype(),
58 self.layout(),
59 ))
60 }
61
62 fn transpose(&self) -> Result<Tensor> {
63 let shape = self.shape();
64 if shape.len() < 2 {
65 return Err(anyhow!(
66 "Transpose requires at least 2D tensor, got shape {:?}",
67 shape
68 ));
69 }
70
71 let dim1 = shape.len() - 2;
72 let dim2 = shape.len() - 1;
73 self.transpose_dims(dim1, dim2)
74 }
75
76 fn transpose_dims(&self, dim1: usize, dim2: usize) -> Result<Tensor> {
77 let shape = self.shape();
78
79 if dim1 >= shape.len() || dim2 >= shape.len() {
80 return Err(anyhow!(
81 "Transpose dimensions {} and {} out of bounds for tensor with {} dimensions",
82 dim1,
83 dim2,
84 shape.len()
85 ));
86 }
87
88 let result_candle = self.candle_tensor().transpose(dim1, dim2)?;
89
90 Ok(Tensor::from_candle(
91 result_candle,
92 self.dtype(),
93 self.layout(),
94 ))
95 }
96
97 fn batch_matmul(&self, other: &Tensor) -> Result<Tensor> {
98 let self_shape = self.shape();
99 let other_shape = other.shape();
100
101 if self_shape.len() < 3 || other_shape.len() < 3 {
103 return self.matmul(other);
105 }
106
107 let self_batch = &self_shape[..self_shape.len() - 2];
109 let other_batch = &other_shape[..other_shape.len() - 2];
110
111 if self_batch != other_batch {
112 return Err(anyhow!(
113 "Incompatible batch dimensions for batch matrix multiplication: {:?} vs {:?}",
114 self_batch,
115 other_batch
116 ));
117 }
118
119 let result_candle = self.candle_tensor().matmul(other.candle_tensor())?;
120
121 Ok(Tensor::from_candle(
122 result_candle,
123 self.dtype(),
124 self.layout(),
125 ))
126 }
127}
128
129impl Tensor {
131 pub fn trace(&self) -> Result<Tensor> {
133 let shape = self.shape();
134 if shape.len() != 2 || shape[0] != shape[1] {
135 return Err(anyhow!(
136 "Trace requires a square 2D tensor, got shape {:?}",
137 shape
138 ));
139 }
140
141 let diag = self.diagonal()?;
142 diag.sum_all()
143 }
144
145 pub fn diagonal(&self) -> Result<Tensor> {
147 let shape = self.shape();
148 if shape.len() < 2 {
149 return Err(anyhow!(
150 "Diagonal requires at least 2D tensor, got shape {:?}",
151 shape
152 ));
153 }
154
155 if shape.len() == 2 {
157 let data = self.to_vec()?;
158 let rows = shape[0];
159 let cols = shape[1];
160 let min_dim = rows.min(cols);
161
162 let mut diag_data = Vec::with_capacity(min_dim);
163 for i in 0..min_dim {
164 diag_data.push(data[i * cols + i]);
165 }
166
167 return Ok(Tensor::from_data(
168 diag_data,
169 vec![min_dim],
170 self.dtype(),
171 self.layout(),
172 )?);
173 }
174
175 Err(anyhow!(
177 "Diagonal extraction for >2D tensors not yet implemented"
178 ))
179 }
180
181 pub fn eye(
183 size: usize,
184 dtype: crate::types::DataType,
185 layout: crate::types::TensorLayout,
186 ) -> Result<Tensor> {
187 use candle_core::Device;
188
189 let device = Device::Cpu;
190 let candle_tensor = candle_core::Tensor::eye(size, dtype_to_candle(&dtype)?, &device)?;
191
192 Ok(Tensor::from_candle(candle_tensor, dtype, layout))
193 }
194
195 pub fn det(&self) -> Result<Tensor> {
197 let shape = self.shape();
198 if shape.len() != 2 || shape[0] != shape[1] {
199 return Err(anyhow!(
200 "Determinant requires a square 2D tensor, got shape {:?}",
201 shape
202 ));
203 }
204
205 match shape[0] {
207 1 => {
208 let data = self.to_vec()?;
209 Ok(Tensor::from_data(
210 vec![data[0]],
211 vec![1],
212 self.dtype(),
213 self.layout(),
214 )?)
215 }
216 2 => {
217 let data = self.to_vec()?;
218 let det = data[0] * data[3] - data[1] * data[2];
219 Ok(Tensor::from_data(
220 vec![det],
221 vec![1],
222 self.dtype(),
223 self.layout(),
224 )?)
225 }
226 _ => {
227 Err(anyhow!(
230 "Determinant calculation for {}x{} matrices not yet implemented",
231 shape[0],
232 shape[1]
233 ))
234 }
235 }
236 }
237
238 pub fn inverse(&self) -> Result<Tensor> {
240 let shape = self.shape();
241 if shape.len() != 2 || shape[0] != shape[1] {
242 return Err(anyhow!(
243 "Inverse requires a square 2D tensor, got shape {:?}",
244 shape
245 ));
246 }
247
248 if shape[0] != 2 {
249 return Err(anyhow!(
250 "Matrix inverse only implemented for 2x2 matrices, got {}x{}",
251 shape[0],
252 shape[1]
253 ));
254 }
255
256 let data = self.to_vec()?;
257 let a = data[0];
258 let b = data[1];
259 let c = data[2];
260 let d = data[3];
261
262 let det = a * d - b * c;
263 if det.abs() < 1e-10 {
264 return Err(anyhow!("Matrix is singular (determinant ≈ 0)"));
265 }
266
267 let inv_det = 1.0 / det;
268 let inv_data = vec![d * inv_det, -b * inv_det, -c * inv_det, a * inv_det];
269
270 Ok(Tensor::from_data(
271 inv_data,
272 vec![2, 2],
273 self.dtype(),
274 self.layout(),
275 )?)
276 }
277
278 pub fn frobenius_norm(&self) -> Result<Tensor> {
280 let squared = self.mul(self)?;
281 let sum = squared.sum_all()?;
282 let sqrt_result = sum.sqrt()?;
283
284 let sqrt_candle = sqrt_result.candle_tensor();
286 let reshaped = if sqrt_candle.dims().is_empty() {
287 sqrt_candle.reshape(&[1])?
288 } else {
289 sqrt_candle.clone()
290 };
291
292 Ok(Tensor::from_candle(
293 reshaped,
294 sqrt_result.dtype(),
295 sqrt_result.layout(),
296 ))
297 }
298
299 pub fn diag_embed(&self) -> Result<Tensor> {
301 let shape = self.shape();
302 if shape.len() != 1 {
303 return Err(anyhow!(
304 "diag_embed requires a 1D tensor, got shape {:?}",
305 shape
306 ));
307 }
308
309 let n = shape[0];
310 let mut diag_data = vec![0.0; n * n];
311 let data = self.to_vec()?;
312
313 for i in 0..n {
314 diag_data[i * n + i] = data[i];
315 }
316
317 Ok(Tensor::from_data(
318 diag_data,
319 vec![n, n],
320 self.dtype(),
321 self.layout(),
322 )?)
323 }
324}
325
326fn dtype_to_candle(dtype: &crate::types::DataType) -> Result<candle_core::DType> {
328 use crate::types::DataType;
329 use candle_core::DType;
330
331 match dtype {
332 DataType::F32 => Ok(DType::F32),
333 DataType::F16 => Ok(DType::F16),
334 DataType::BF16 => Ok(DType::BF16),
335 DataType::F64 => Ok(DType::F64),
336 DataType::U8 => Ok(DType::U8),
337 DataType::U32 => Ok(DType::U32),
338 DataType::I8 | DataType::I32 | DataType::I64 | DataType::Bool => Ok(DType::F32),
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{DataType, TensorLayout};

    #[test]
    fn test_matrix_multiplication() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let b = Tensor::from_data(
            vec![2.0, 0.0, 1.0, 1.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let result = a.matmul(&b)?;
        let result_data = result.to_vec()?;

        assert_eq!(result_data, vec![4.0, 2.0, 10.0, 4.0]);

        Ok(())
    }

    #[test]
    fn test_transpose() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let transposed = MatrixOps::transpose(&a)?;
        let transposed_data = transposed.to_vec()?;
        assert_eq!(transposed.shape(), vec![3, 2]);

        assert_eq!(transposed_data, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);

        Ok(())
    }

    #[test]
    fn test_identity_matrix() -> Result<()> {
        let identity = Tensor::eye(3, DataType::F32, TensorLayout::RowMajor)?;
        let identity_data = identity.to_vec()?;

        assert_eq!(identity.shape(), vec![3, 3]);
        assert_eq!(
            identity_data,
            vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
        );

        Ok(())
    }

    #[test]
    fn test_diagonal() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let diag = a.diagonal()?;
        let diag_data = diag.to_vec()?;

        assert_eq!(diag_data, vec![1.0, 4.0]);

        Ok(())
    }

    #[test]
    fn test_determinant_2x2() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let det = a.det()?;
        let det_data = det.to_vec()?;

        assert!((det_data[0] + 2.0).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_matrix_inverse_2x2() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let inv = a.inverse()?;
        let inv_data = inv.to_vec()?;

        // inverse of [[1, 2], [3, 4]] is [[-2, 1], [1.5, -0.5]]
        assert!((inv_data[0] + 2.0).abs() < 1e-6);
        assert!((inv_data[1] - 1.0).abs() < 1e-6);
        assert!((inv_data[2] - 1.5).abs() < 1e-6);
        assert!((inv_data[3] + 0.5).abs() < 1e-6);

        Ok(())
    }

    #[test]
    fn test_trace() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let trace = a.trace()?;
        let trace_data = trace.to_vec()?;

        assert_eq!(trace_data[0], 5.0);

        Ok(())
    }

    #[test]
    fn test_diag_embed() -> Result<()> {
        let a = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let diag_matrix = a.diag_embed()?;
        let diag_data = diag_matrix.to_vec()?;

        assert_eq!(diag_matrix.shape(), vec![3, 3]);
        assert_eq!(diag_data, vec![1.0, 0.0, 0.0, 0.0, 2.0, 0.0, 0.0, 0.0, 3.0]);

        Ok(())
    }

    #[test]
    fn test_frobenius_norm() -> Result<()> {
        let a = Tensor::from_data(
            vec![3.0, 4.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let norm = a.frobenius_norm()?;
        let norm_data = norm.to_vec()?;

        assert_eq!(norm_data[0], 5.0);

        Ok(())
    }

    #[test]
    fn test_batch_matmul() -> Result<()> {
        // Two batches of 3x2 matrices.
        let a = Tensor::from_data(
            vec![
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            vec![2, 3, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        // Batch 0 is the identity and batch 1 is 2 * identity.
        let b = Tensor::from_data(
            vec![1.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0],
            vec![2, 2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )?;

        let result = a.batch_matmul(&b)?;
        assert_eq!(result.shape(), vec![2, 3, 2]);
        // Each batch must be paired with its own right-hand matrix.
        assert_eq!(
            result.to_vec()?,
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 14.0, 16.0, 18.0, 20.0, 22.0, 24.0]
        );

        Ok(())
    }

    #[test]
    fn test_error_handling() {
        // 1D operands cannot be matrix-multiplied or transposed.
        let a = Tensor::from_data(
            vec![1.0, 2.0],
            vec![2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        let b = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(a.matmul(&b).is_err());

        assert!(MatrixOps::transpose(&a).is_err());

        // Out-of-range transpose dimensions.
        let c = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(c.transpose_dims(5, 6).is_err());

        // A non-square matrix has no inverse, determinant or trace, and 2x3 @ 2x3
        // has mismatched inner dimensions.
        let d = Tensor::from_data(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            vec![2, 3],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(d.inverse().is_err());
        assert!(d.det().is_err());
        assert!(d.trace().is_err());
        assert!(d.matmul(&d).is_err());

        // Singular 2x2 matrix.
        let singular = Tensor::from_data(
            vec![1.0, 2.0, 2.0, 4.0],
            vec![2, 2],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(singular.inverse().is_err());

        // Mismatched batch dimensions.
        let batch_a = Tensor::from_data(
            vec![1.0, 2.0],
            vec![2, 1, 1],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        let batch_b = Tensor::from_data(
            vec![1.0, 2.0, 3.0],
            vec![3, 1, 1],
            DataType::F32,
            TensorLayout::RowMajor,
        )
        .unwrap();
        assert!(batch_a.batch_matmul(&batch_b).is_err());
    }
}