burn_tensor/tensor/linalg/outer.rs

use crate::backend::Backend;
use crate::tensor::{BasicOps, Tensor};
use crate::{Numeric, Shape};

/// Computes the outer product of two rank-1 tensors, or the batched outer
/// product of two rank-2 tensors.
///
/// Supported ranks:
/// - D = 1, R = 2: vectors (m,) × (n,) → (m, n)
/// - D = 2, R = 3: batched (b, m) × (b, n) → (b, m, n)
///
/// # Panics
/// - if D is not 1 or 2
/// - if (D, R) is not (1, 2) or (2, 3)
/// - if D = 2 and the batch dimensions differ
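///
/// # Examples
///
/// A minimal sketch of the rank-1 case, assuming some backend `B` and a
/// matching `device` are already in scope (both are illustrative, not
/// prescribed by this function):
///
/// ```ignore
/// let x = Tensor::<B, 1>::from_floats([1.0, 2.0, 3.0], &device);
/// let y = Tensor::<B, 1>::from_floats([4.0, 5.0], &device);
///
/// // out[i, j] = x[i] * y[j], so the result has shape (3, 2).
/// let out: Tensor<B, 2> = outer(x, y);
/// ```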
//
// Notes:
// - For large batched inputs, `x_col.matmul(y_row)` *might* be more performant
//   than a broadcasted elementwise multiply; benchmarking is needed to confirm
//   (see the sketch after this function).
pub fn outer<B: Backend, const D: usize, const R: usize, K>(
    x: Tensor<B, D, K>,
    y: Tensor<B, D, K>,
) -> Tensor<B, R, K>
where
    K: BasicOps<B> + Numeric<B>,
{
    if D == 1 {
        assert!(R == 2, "`outer` with D=1 must use R=2 (got R={})", R);
        let [m] = x.shape().dims();
        let [n] = y.shape().dims();

        let x_col = x.reshape(Shape::new([m, 1])); // (m, 1)
        let y_row = y.reshape(Shape::new([1, n])); // (1, n)

        x_col * y_row // (m, n)
    } else if D == 2 {
        assert!(R == 3, "`outer` with D=2 must use R=3 (got R={})", R);
        let [bx, m] = x.shape().dims();
        let [by, n] = y.shape().dims();
        assert_eq!(bx, by, "batch dimensions must match (got {} vs {})", bx, by);

        let x_col = x.reshape(Shape::new([bx, m, 1])); // (b, m, 1)
        let y_row = y.reshape(Shape::new([by, 1, n])); // (b, 1, n)

        x_col * y_row // (b, m, n)
    } else {
        panic!("`outer` only supports rank 1 or 2 tensors (got D={})", D);
    }
}
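
// A minimal sketch of the matmul-based variant mentioned in the note above,
// kept around only for benchmarking against the broadcasted multiply. It is
// hypothetical and not part of the public API: the helper name
// `outer_batched_matmul` is an assumption, and it is restricted to float
// tensors because `matmul` is a float operation.
#[allow(dead_code)]
fn outer_batched_matmul<B: Backend>(x: Tensor<B, 2>, y: Tensor<B, 2>) -> Tensor<B, 3> {
    let [bx, m] = x.shape().dims();
    let [by, n] = y.shape().dims();
    assert_eq!(bx, by, "batch dimensions must match (got {} vs {})", bx, by);

    // Batched matrix multiply: (b, m, 1) x (b, 1, n) -> (b, m, n).
    x.reshape(Shape::new([bx, m, 1]))
        .matmul(y.reshape(Shape::new([by, 1, n])))
}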