1use burn::{
2 prelude::Backend,
3 tensor::{Float, Tensor},
4};
5
6pub fn normalize_tensor<B: Backend>(tensor: Tensor<B, 2, Float>) -> Tensor<B, 2, Float> {
11 let norm = tensor.clone().powf_scalar(2.0).sum_dim(1).sqrt(); tensor.div(norm) }
14
15pub fn cross_product<B: Backend>(
17 a: &Tensor<B, 2, Float>, b: &Tensor<B, 2, Float>, ) -> Tensor<B, 2, Float> {
20 let a_chunks = a.clone().chunk(3, 1); let b_chunks = b.clone().chunk(3, 1); let ax: Tensor<B, 1> = a_chunks[0].clone().squeeze(1); let ay: Tensor<B, 1> = a_chunks[1].clone().squeeze(1); let az: Tensor<B, 1> = a_chunks[2].clone().squeeze(1); let bx: Tensor<B, 1> = b_chunks[0].clone().squeeze(1); let by: Tensor<B, 1> = b_chunks[1].clone().squeeze(1); let bz: Tensor<B, 1> = b_chunks[2].clone().squeeze(1); let cx = ay.clone().mul(bz.clone()).sub(az.clone().mul(by.clone())); let cy = az.mul(bx.clone()).sub(ax.clone().mul(bz)); let cz = ax.mul(by).sub(ay.mul(bx)); Tensor::stack(vec![cx, cy, cz], 1) }