burn_tensor/tensor/linalg/outer.rs
1use crate::backend::Backend;
2use crate::tensor::{BasicOps, Tensor};
3use crate::{AsIndex, Numeric};
4
5/// Computes the outer product for the last columns of 2 tensors.
6///
7/// See also: [`outer_dim`].
8///
9/// # Arguments
10/// - `lhs`: the "row" tensor, with shape ``[..., i]``.
11/// - `rhs`: the "col" tensor, with shape ``[..., j]``.
12/// - `dim`: the dimension to product.
13///
14/// # Returns
15///
16/// A tensor of rank `R = D + 1`, where:
17///
18/// ``
19/// result[..., i, j] = lhs[..., i] * rhs[..., j]
20/// ``
21pub fn outer<B: Backend, const D: usize, const R: usize, K>(
22 x: Tensor<B, D, K>,
23 y: Tensor<B, D, K>,
24) -> Tensor<B, R, K>
25where
26 K: BasicOps<B> + Numeric<B>,
27{
28 outer_dim(x, y, -1)
29}
30
31/// Computes the outer product along a specific dimension, broadcasting over others.
32///
33/// For the given `dim`, computes the outer product of elements along that dimension,
34/// expanding it into two dimensions of size ``M × N`` at positions ``(dim, dim + 1)``.
35///
36/// # Arguments
37///
38/// - `lhs`: left operand, the "row" tensor, with size `M` at dimension `dim`.
39/// - `rhs`: right operand, the "col" tensor, with size `N` at dimension `dim`.
40/// - `dim`: dimension to compute the outer product along (supports negative indexing).
41///
42/// # Returns
43///
44/// A tensor of rank `R = D + 1`, where:
45///
46/// ``
47/// result[..., i, j, ...] = lhs[..., i, ...] * rhs[..., j, ...]
48/// ``
49//
50// Notes:
51// - For large batched inputs, `x_col.matmul(y_row)` *might* be more performant
52// than broadcasted elemwise multiply; benchmarking needed to confirm.
53pub fn outer_dim<B: Backend, const D: usize, const R: usize, Dim: AsIndex, K>(
54 lhs: Tensor<B, D, K>,
55 rhs: Tensor<B, D, K>,
56 dim: Dim,
57) -> Tensor<B, R, K>
58where
59 K: BasicOps<B> + Numeric<B>,
60{
61 assert_eq!(
62 R,
63 D + 1,
64 "`outer` with D={D} expects R={} (got R={R})",
65 D + 1
66 );
67 let dim = dim.expect_dim_index(D);
68
69 // (..., i, 1, ...)
70 let x = lhs.unsqueeze_dim::<R>(dim + 1);
71
72 // (..., 1, j, ...)
73 let y = rhs.unsqueeze_dim::<R>(dim);
74
75 // (..., i, j, ...)
76 x * y
77}