Skip to main content

burn_tensor/tensor/linalg/
matvec.rs

1use crate::Numeric;
2use crate::backend::Backend;
3use crate::tensor::{BasicOps, Shape, Tensor};
4
5/// Performs matrix-vector multiplication with optional batch dimensions.
6///
7/// The `matrix` tensor is expected to have rank `DM` with the last two dimensions representing
8/// the matrix rows and columns. The `vector` tensor should have rank `DV = DM - 1`, sharing
9/// broadcast-compatible batch dimensions and matching the last dimension of the matrix.
10///
11/// # Panics
12///
13/// * If the matrix rank is lower than 2.
14/// * If the vector rank isn't one less than the matrix rank.
15/// * If batch dimensions differ between the operands.
16/// * If the inner dimensions are incompatible for multiplication.
17pub fn matvec<B: Backend, const DM: usize, const DV: usize, K>(
18    matrix: Tensor<B, DM, K>,
19    vector: Tensor<B, DV, K>,
20) -> Tensor<B, DV, K>
21where
22    K: BasicOps<B> + Numeric<B>,
23{
24    assert!(
25        DM >= 2,
26        "matvec expects the matrix to be at least rank 2 (got {DM})"
27    );
28    assert!(
29        DM == DV + 1,
30        "matvec expects the vector rank ({DV}) to be exactly one less than the matrix rank ({DM})",
31    );
32
33    let matrix_dims = matrix.shape().dims::<DM>();
34    let vector_dims = vector.shape().dims::<DV>();
35
36    // Validate batch dimensions (all leading dimensions prior to the matrix axes).
37    let batch_rank = DM.saturating_sub(2);
38    if batch_rank > 0 {
39        let matrix_batch = Shape::from(&matrix_dims[..batch_rank]);
40        let vector_batch = Shape::from(&vector_dims[..batch_rank]);
41
42        assert!(
43            matrix_batch.broadcast(&vector_batch).is_ok(),
44            "Batch dimensions are not broadcast-compatible: matrix {:?} vs vector {:?}",
45            &matrix_dims[..batch_rank],
46            &vector_dims[..batch_rank]
47        );
48    }
49
50    let matrix_inner = matrix_dims[DM - 1];
51    let vector_inner = vector_dims[DV - 1];
52    assert!(
53        matrix_inner == vector_inner,
54        "Inner dimension mismatch: matrix has {matrix_inner} columns but vector has {vector_inner} entries",
55    );
56
57    let vector_expanded = vector.unsqueeze_dim::<DM>(DV);
58    matrix.matmul(vector_expanded).squeeze_dim::<DV>(DM - 1)
59}