burn_tensor/tensor/linalg/
matvec.rs1use crate::Numeric;
2use crate::backend::Backend;
3use crate::tensor::{BasicOps, Shape, Tensor};
4
5pub fn matvec<B: Backend, const DM: usize, const DV: usize, K>(
18 matrix: Tensor<B, DM, K>,
19 vector: Tensor<B, DV, K>,
20) -> Tensor<B, DV, K>
21where
22 K: BasicOps<B> + Numeric<B>,
23{
24 assert!(
25 DM >= 2,
26 "matvec expects the matrix to be at least rank 2 (got {DM})"
27 );
28 assert!(
29 DM == DV + 1,
30 "matvec expects the vector rank ({DV}) to be exactly one less than the matrix rank ({DM})",
31 );
32
33 let matrix_dims = matrix.shape().dims::<DM>();
34 let vector_dims = vector.shape().dims::<DV>();
35
36 let batch_rank = DM.saturating_sub(2);
38 if batch_rank > 0 {
39 let matrix_batch = Shape::from(&matrix_dims[..batch_rank]);
40 let vector_batch = Shape::from(&vector_dims[..batch_rank]);
41
42 assert!(
43 matrix_batch.broadcast(&vector_batch).is_ok(),
44 "Batch dimensions are not broadcast-compatible: matrix {:?} vs vector {:?}",
45 &matrix_dims[..batch_rank],
46 &vector_dims[..batch_rank]
47 );
48 }
49
50 let matrix_inner = matrix_dims[DM - 1];
51 let vector_inner = vector_dims[DV - 1];
52 assert!(
53 matrix_inner == vector_inner,
54 "Inner dimension mismatch: matrix has {matrix_inner} columns but vector has {vector_inner} entries",
55 );
56
57 let vector_expanded = vector.unsqueeze_dim::<DM>(DV);
58 matrix.matmul(vector_expanded).squeeze_dim::<DV>(DM - 1)
59}