burn_tensor/tensor/linalg/
outer.rs1use crate::backend::Backend;
2use crate::tensor::{BasicOps, Tensor};
3use crate::{Numeric, Shape};
4
5pub fn outer<B: Backend, const D: usize, const R: usize, K>(
20 x: Tensor<B, D, K>,
21 y: Tensor<B, D, K>,
22) -> Tensor<B, R, K>
23where
24 K: BasicOps<B> + Numeric<B>,
25{
26 if D == 1 {
27 assert!(R == 2, "`outer` with D=1 must use R=2 (got R={})", R);
28 let [m] = x.shape().dims();
29 let [n] = y.shape().dims();
30
31 let x_col = x.reshape(Shape::new([m, 1])); let y_row = y.reshape(Shape::new([1, n])); x_col * y_row } else if D == 2 {
36 assert!(R == 3, "`outer` with D=2 must use R=3 (got R={})", R);
37 let [bx, m] = x.shape().dims();
38 let [by, n] = y.shape().dims();
39 assert_eq!(bx, by, "batch dimensions must match (got {} vs {})", bx, by);
40
41 let x_col = x.reshape(Shape::new([bx, m, 1])); let y_row = y.reshape(Shape::new([by, 1, n])); x_col * y_row } else {
46 panic!("`outer` only supports rank 1 or 2 tensors (got D={})", D);
47 }
48}