Skip to main content

burn_tensor/tensor/linalg/
outer.rs

1use crate::backend::Backend;
2use crate::tensor::{BasicOps, Tensor};
3use crate::{AsIndex, Numeric};
4
5/// Computes the outer product for the last columns of 2 tensors.
6///
7/// See also: [`outer_dim`].
8///
9/// # Arguments
10/// - `lhs`: the "row" tensor, with shape ``[..., i]``.
11/// - `rhs`: the "col" tensor, with shape ``[..., j]``.
12/// - `dim`: the dimension to product.
13///
14/// # Returns
15///
16/// A tensor of rank `R = D + 1`, where:
17///
18/// ``
19/// result[..., i, j] = lhs[..., i] * rhs[..., j]
20/// ``
21pub fn outer<B: Backend, const D: usize, const R: usize, K>(
22    x: Tensor<B, D, K>,
23    y: Tensor<B, D, K>,
24) -> Tensor<B, R, K>
25where
26    K: BasicOps<B> + Numeric<B>,
27{
28    outer_dim(x, y, -1)
29}
30
31/// Computes the outer product along a specific dimension, broadcasting over others.
32///
33/// For the given `dim`, computes the outer product of elements along that dimension,
34/// expanding it into two dimensions of size ``M × N`` at positions ``(dim, dim + 1)``.
35///
36/// # Arguments
37///
38/// - `lhs`: left operand, the "row" tensor, with size `M` at dimension `dim`.
39/// - `rhs`: right operand, the "col" tensor, with size `N` at dimension `dim`.
40/// - `dim`: dimension to compute the outer product along (supports negative indexing).
41///
42/// # Returns
43///
44/// A tensor of rank `R = D + 1`, where:
45///
46/// ``
47/// result[..., i, j, ...] = lhs[..., i, ...] * rhs[..., j, ...]
48/// ``
49//
50// Notes:
51// - For large batched inputs, `x_col.matmul(y_row)` *might* be more performant
52//   than broadcasted elemwise multiply; benchmarking needed to confirm.
53pub fn outer_dim<B: Backend, const D: usize, const R: usize, Dim: AsIndex, K>(
54    lhs: Tensor<B, D, K>,
55    rhs: Tensor<B, D, K>,
56    dim: Dim,
57) -> Tensor<B, R, K>
58where
59    K: BasicOps<B> + Numeric<B>,
60{
61    assert_eq!(
62        R,
63        D + 1,
64        "`outer` with D={D} expects R={} (got R={R})",
65        D + 1
66    );
67    let dim = dim.expect_dim_index(D);
68
69    // (..., i, 1, ...)
70    let x = lhs.unsqueeze_dim::<R>(dim + 1);
71
72    // (..., 1, j, ...)
73    let y = rhs.unsqueeze_dim::<R>(dim);
74
75    // (..., i, j, ...)
76    x * y
77}