pylate_rs/
utils.rs

1use crate::error::ColbertError;
2use candle_core::Tensor;
3
4/// Normalizes a tensor using L2 normalization along the last dimension.
5pub fn normalize_l2(v: &Tensor) -> Result<Tensor, ColbertError> {
6    let norm_l2 = v.sqr()?.sum_keepdim(v.rank() - 1)?.sqrt()?;
7    v.broadcast_div(&norm_l2).map_err(ColbertError::from)
8}