Skip to main content

gloss_utils/
tensor.rs

1use burn::{
2    prelude::Backend,
3    tensor::{Float, Tensor},
4};
5
6/////
7///// Some burn utilities
8/////
9/// Normalise a 2D tensor across dim 1
10pub fn normalize_tensor<B: Backend>(tensor: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
11    let norm = tensor.clone().powf_scalar(2.0).sum_dim(1).sqrt(); // Compute the L2 norm along the last axis (dim = 1)
12    tensor.div(norm) // Divide each vector by its norm
13}
14
15/// Cross product of 2 2D Tensors
16pub fn cross_product<B: Backend>(
17    a: &Tensor<B, 2, Float>, // Tensor of shape [N, 3]
18    b: &Tensor<B, 2, Float>, // Tensor of shape [N, 3]
19) -> Tensor<B, 2, Float> {
20    // Split the input tensors along dimension 1 (the 3 components) using chunk
21    let a_chunks = a.clone().chunk(3, 1); // Split tensor `a` into 3 chunks: ax, ay, az
22    let b_chunks = b.clone().chunk(3, 1); // Split tensor `b` into 3 chunks: bx, by, bz
23
24    let ax: Tensor<B, 1> = a_chunks[0].clone().squeeze(1); // x component of a
25    let ay: Tensor<B, 1> = a_chunks[1].clone().squeeze(1); // y component of a
26    let az: Tensor<B, 1> = a_chunks[2].clone().squeeze(1); // z component of a
27
28    let bx: Tensor<B, 1> = b_chunks[0].clone().squeeze(1); // x component of b
29    let by: Tensor<B, 1> = b_chunks[1].clone().squeeze(1); // y component of b
30    let bz: Tensor<B, 1> = b_chunks[2].clone().squeeze(1); // z component of b
31
32    // Compute the components of the cross product
33    let cx = ay.clone().mul(bz.clone()).sub(az.clone().mul(by.clone())); // cx = ay * bz - az * by
34    let cy = az.mul(bx.clone()).sub(ax.clone().mul(bz)); // cy = az * bx - ax * bz
35    let cz = ax.mul(by).sub(ay.mul(bx)); // cz = ax * by - ay * bx
36
37    // Stack the result to form the resulting [N, 3] tensor
38    Tensor::stack(vec![cx, cy, cz], 1) // Concatenate along the second dimension
39}