pylate_rs/utils.rs
1use crate::error::ColbertError;
2use candle_core::Tensor;
3
4/// Normalizes a tensor using L2 normalization along the last dimension.
5pub fn normalize_l2(v: &Tensor) -> Result<Tensor, ColbertError> {
6 let norm_l2 = v.sqr()?.sum_keepdim(v.rank() - 1)?.sqrt()?;
7 v.broadcast_div(&norm_l2).map_err(ColbertError::from)
8}