speakrs 0.3.2

Fast Rust speaker diarization with pyannote-level accuracy and native CoreML/CUDA acceleration
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};

/// Returns a copy of `vector` scaled to unit Euclidean (L2) length.
///
/// The norm is computed as the square root of the vector's dot product with
/// itself. When that norm is exactly zero the result is an all-zero array of
/// the same length, so no division by zero takes place.
pub fn l2_normalize(vector: &ArrayView1<f32>) -> Array1<f32> {
    let length = vector.dot(vector).sqrt();
    if length != 0.0 {
        vector.mapv(|component| component / length)
    } else {
        // Nothing to scale: hand back zeros instead of producing NaNs.
        Array1::zeros(vector.len())
    }
}

/// Returns a copy of `embeddings` in which every row is scaled to unit L2 norm.
///
/// Rows whose norm is not strictly positive (for example all-zero rows) are
/// copied through unchanged.
pub fn l2_normalize_rows(embeddings: &ArrayView2<f32>) -> Array2<f32> {
    let mut result = embeddings.to_owned();
    result.rows_mut().into_iter().for_each(|mut lane| {
        let length = lane.dot(&lane).sqrt();
        // The strict comparison also skips rows whose norm is NaN.
        if length > 0.0 {
            lane.mapv_inplace(|value| value / length);
        }
    });
    result
}

/// Double-precision variant of row-wise L2 normalization.
///
/// Returns a copy of `embeddings` in which every row with a strictly positive
/// norm is divided by that norm; all other rows are left as they were.
pub fn l2_normalize_rows_f64(embeddings: &ArrayView2<f64>) -> Array2<f64> {
    let mut result = embeddings.to_owned();
    result.rows_mut().into_iter().for_each(|mut lane| {
        let length = lane.dot(&lane).sqrt();
        // The strict comparison also skips rows whose norm is NaN.
        if length > 0.0 {
            lane.mapv_inplace(|value| value / length);
        }
    });
    result
}

/// Computes the cosine similarity between `lhs` and `rhs`.
///
/// Both inputs are L2-normalized with [`l2_normalize`] and the dot product of
/// the two normalized vectors is returned. Because a zero vector normalizes to
/// zeros, the similarity is `0.0` whenever either input has zero norm.
pub fn cosine_similarity(lhs: &ArrayView1<f32>, rhs: &ArrayView1<f32>) -> f32 {
    l2_normalize(lhs).dot(&l2_normalize(rhs))
}

/// Computes `ln(sum(exp(values)))` using the max-shift trick.
///
/// The largest element is subtracted from every entry before exponentiating
/// and added back after taking the logarithm, which keeps the intermediate
/// exponentials from overflowing.
///
/// If the maximum is infinite it is returned directly; this covers an empty
/// input, for which the maximum is the fold's starting value of
/// `f64::NEG_INFINITY`. NaN entries are skipped while searching for the
/// maximum because `f64::max` prefers the non-NaN operand.
pub fn logsumexp_f64(values: &ArrayView1<f64>) -> f64 {
    let peak = values
        .iter()
        .copied()
        .fold(f64::NEG_INFINITY, f64::max);

    if peak.is_infinite() {
        peak
    } else {
        let shifted_total = values.map(|&value| (value - peak).exp()).sum();
        peak + shifted_total.ln()
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests for the vector and matrix helpers defined in the parent module.

    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::array;

    #[test]
    fn cosine_similarity_identical_vectors() {
        let v = array![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v.view(), &v.view());
        assert_abs_diff_eq!(sim, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn cosine_similarity_orthogonal_vectors() {
        let a = array![1.0, 0.0];
        let b = array![0.0, 1.0];
        let sim = cosine_similarity(&a.view(), &b.view());
        assert_abs_diff_eq!(sim, 0.0, epsilon = 1e-6);
    }

    #[test]
    fn cosine_similarity_opposite_vectors() {
        let a = array![1.0, 2.0, 3.0];
        let b = array![-1.0, -2.0, -3.0];
        let sim = cosine_similarity(&a.view(), &b.view());
        assert_abs_diff_eq!(sim, -1.0, epsilon = 1e-6);
    }

    #[test]
    fn cosine_similarity_with_zero_vector_is_zero() {
        // A zero vector normalizes to zeros, so the dot product must be zero.
        let zero = array![0.0_f32, 0.0, 0.0];
        let v = array![1.0_f32, 2.0, 3.0];
        let sim = cosine_similarity(&zero.view(), &v.view());
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn l2_normalize_has_unit_norm() {
        let v = array![3.0, 4.0];
        let normed = l2_normalize(&v.view());
        let norm = normed.dot(&normed).sqrt();
        assert_abs_diff_eq!(norm, 1.0, epsilon = 1e-6);
    }

    #[test]
    fn l2_normalize_zero_vector_stays_zero() {
        let v = array![0.0, 0.0, 0.0];
        let normed = l2_normalize(&v.view());
        assert_eq!(normed, array![0.0, 0.0, 0.0]);
    }

    #[test]
    fn l2_normalize_rows_scales_rows_and_keeps_zero_rows() {
        let m = array![[3.0_f32, 4.0], [0.0, 0.0]];
        let normed = l2_normalize_rows(&m.view());
        assert_abs_diff_eq!(normed[[0, 0]], 0.6_f32, epsilon = 1e-6);
        assert_abs_diff_eq!(normed[[0, 1]], 0.8_f32, epsilon = 1e-6);
        assert_eq!(normed[[1, 0]], 0.0);
        assert_eq!(normed[[1, 1]], 0.0);
    }

    #[test]
    fn l2_normalize_rows_f64_scales_rows_and_keeps_zero_rows() {
        let m = array![[3.0_f64, 4.0], [0.0, 0.0]];
        let normed = l2_normalize_rows_f64(&m.view());
        assert_abs_diff_eq!(normed[[0, 0]], 0.6_f64, epsilon = 1e-12);
        assert_abs_diff_eq!(normed[[0, 1]], 0.8_f64, epsilon = 1e-12);
        assert_eq!(normed[[1, 0]], 0.0);
        assert_eq!(normed[[1, 1]], 0.0);
    }

    #[test]
    fn logsumexp_matches_closed_form() {
        // ln(e^0 + e^0) = ln 2
        let v = array![0.0_f64, 0.0];
        let got = logsumexp_f64(&v.view());
        assert_abs_diff_eq!(got, 2.0_f64.ln(), epsilon = 1e-12);
    }

    #[test]
    fn logsumexp_does_not_overflow_for_large_inputs() {
        // A naive exp(1000) would overflow to infinity; the shifted form must not.
        let v = array![1000.0_f64, 1000.0];
        let got = logsumexp_f64(&v.view());
        assert_abs_diff_eq!(got, 1000.0 + 2.0_f64.ln(), epsilon = 1e-9);
    }

    #[test]
    fn logsumexp_of_negative_infinities_is_negative_infinity() {
        let v = array![f64::NEG_INFINITY, f64::NEG_INFINITY];
        assert_eq!(logsumexp_f64(&v.view()), f64::NEG_INFINITY);
    }

    #[test]
    fn logsumexp_with_positive_infinity_is_positive_infinity() {
        let v = array![f64::INFINITY, 1.0];
        assert_eq!(logsumexp_f64(&v.view()), f64::INFINITY);
    }
}