ruvector-coherence 2.0.6

Coherence measurement proxies for comparing attention mechanisms
Documentation
//! Side-by-side comparison utilities for attention masks.

use serde::{Deserialize, Serialize};

/// Result of comparing two attention masks.
///
/// [`compare_attention_masks`] fills in every field, treating its first argument as the
/// baseline mask and its second as the gated mask. Where the masks differ in length, the
/// shorter one is compared as if padded with `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComparisonResult {
    /// Jaccard similarity of the two masks' `true` positions (`1.0` when neither mask has
    /// any `true` entry).
    pub jaccard: f64,
    /// Number of positions at which the two masks disagree.
    pub edge_flips: usize,
    /// Number of `true` entries in the baseline mask.
    pub baseline_edges: usize,
    /// Number of `true` entries in the gated mask.
    pub gated_edges: usize,
    /// Gated sparsity divided by baseline sparsity.
    ///
    /// Sparsity is `1 - edges / max(len)`, and it is `1.0` when both masks are empty.
    /// When the baseline sparsity is not above `f64::EPSILON` (e.g. a fully dense
    /// baseline), this holds the gated sparsity itself instead of a ratio.
    pub sparsity_ratio: f64,
}

/// Jaccard similarity of the `true` positions of two masks: `|A & B| / |A | B|`.
///
/// Positions past the end of the shorter mask are treated as `false`. Any `true` entries
/// in the longer mask's tail therefore enlarge the union but never the intersection.
/// When neither mask contains a `true` entry (including two empty masks), the masks are
/// considered identical and `1.0` is returned.
pub fn jaccard_similarity(mask_a: &[bool], mask_b: &[bool]) -> f64 {
    let overlap = mask_a.len().min(mask_b.len());

    // Walk the shared prefix once, counting "both set" and "either set" together.
    let (intersection, shared_union) = mask_a.iter().zip(mask_b).fold(
        (0usize, 0usize),
        |(inter, uni), (&a, &b)| (inter + usize::from(a && b), uni + usize::from(a || b)),
    );

    // Only one mask has entries beyond the overlap, and those can only add to the union.
    let tails = count_true_tail(mask_a, overlap) + count_true_tail(mask_b, overlap);
    let union_size = shared_union + tails;

    if union_size == 0 {
        return 1.0;
    }
    intersection as f64 / union_size as f64
}

/// Number of positions at which two masks disagree.
///
/// The shorter mask is compared as if padded with `false`. Every `true` entry in the
/// longer mask's tail therefore counts as one flip, while `false` tail entries do not.
pub fn edge_flip_count(mask_a: &[bool], mask_b: &[bool]) -> usize {
    let overlap = mask_a.len().min(mask_b.len());
    let disagreements = mask_a
        .iter()
        .zip(mask_b)
        .filter(|&(a, b)| a != b)
        .count();
    disagreements + count_true_tail(mask_a, overlap) + count_true_tail(mask_b, overlap)
}

/// Full comparison of a baseline attention mask against a gated one.
///
/// Edge counts are the number of `true` entries in each mask. Sparsity is the fraction of
/// `false` entries measured against the longer mask's length, so a shorter mask behaves
/// as if padded with `false`. Two empty masks both count as fully sparse (`1.0`).
///
/// `sparsity_ratio` is gated sparsity over baseline sparsity. When the baseline sparsity
/// is at or below `f64::EPSILON` (e.g. a fully dense baseline), it falls back to the gated
/// sparsity alone rather than dividing by (near) zero.
pub fn compare_attention_masks(baseline: &[bool], gated: &[bool]) -> ComparisonResult {
    let count_edges = |mask: &[bool]| mask.iter().filter(|&&on| on).count();
    let baseline_edges = count_edges(baseline);
    let gated_edges = count_edges(gated);

    let width = baseline.len().max(gated.len());
    let sparsity = |edges: usize| match width {
        0 => 1.0,
        w => 1.0 - edges as f64 / w as f64,
    };
    let baseline_sparsity = sparsity(baseline_edges);
    let gated_sparsity = sparsity(gated_edges);

    let sparsity_ratio = if baseline_sparsity > f64::EPSILON {
        gated_sparsity / baseline_sparsity
    } else {
        gated_sparsity
    };

    ComparisonResult {
        jaccard: jaccard_similarity(baseline, gated),
        edge_flips: edge_flip_count(baseline, gated),
        baseline_edges,
        gated_edges,
        sparsity_ratio,
    }
}

/// Counts the `true` entries of `mask` at indices `from..`.
///
/// Yields `0` when `from` is at or past the end of the mask, rather than panicking.
fn count_true_tail(mask: &[bool], from: usize) -> usize {
    mask.get(from..).map_or(0, |tail| tail.iter().filter(|&&set| set).count())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn jaccard_cases() {
        let m = vec![true, false, true, true];
        assert!((jaccard_similarity(&m, &m) - 1.0).abs() < 1e-10);
        assert!(jaccard_similarity(&[true, false], &[false, true]).abs() < 1e-10);
        assert_eq!(jaccard_similarity(&[], &[]), 1.0);
        // partial: intersection=1, union=3
        let (a, b) = (
            vec![true, true, false, false],
            vec![true, false, true, false],
        );
        assert!((jaccard_similarity(&a, &b) - 1.0 / 3.0).abs() < 1e-10);
        // unequal lengths: the longer mask's `true` tail joins the union only
        // (intersection=1, union=1+2)
        let tail_heavy = jaccard_similarity(&[true, false], &[true, false, true, true]);
        assert!((tail_heavy - 1.0 / 3.0).abs() < 1e-10);
        // masks with no `true` entries at all are treated as identical
        assert_eq!(jaccard_similarity(&[false, false], &[false]), 1.0);
    }

    #[test]
    fn edge_flip_cases() {
        assert_eq!(edge_flip_count(&[true, false], &[true, false]), 0);
        assert_eq!(
            edge_flip_count(&[true, false, true], &[false, true, false]),
            3
        );
        assert_eq!(
            edge_flip_count(&[true, false], &[true, false, true, true]),
            2
        );
        assert_eq!(edge_flip_count(&[], &[]), 0);
    }

    #[test]
    fn count_true_tail_cases() {
        assert_eq!(count_true_tail(&[true, false, true], 0), 2);
        assert_eq!(count_true_tail(&[true, false, true], 1), 1);
        // `from` at or past the end must yield 0, not panic
        assert_eq!(count_true_tail(&[true, false, true], 3), 0);
        assert_eq!(count_true_tail(&[true, false, true], 10), 0);
    }

    #[test]
    fn compare_masks() {
        let bl = vec![true, true, false, false, true];
        let gt = vec![true, false, false, true, true];
        let r = compare_attention_masks(&bl, &gt);
        assert_eq!(r.baseline_edges, 3);
        assert_eq!(r.gated_edges, 3);
        assert_eq!(r.edge_flips, 2);
        assert!((r.jaccard - 0.5).abs() < 1e-10);
        // equal edge counts over equal lengths => equal sparsity => ratio 1
        assert!((r.sparsity_ratio - 1.0).abs() < 1e-10);
    }

    #[test]
    fn compare_masks_unequal_lengths() {
        // total=4: baseline sparsity 0.75, gated sparsity 0.5
        let r = compare_attention_masks(&[true, false, false, false], &[true, true]);
        assert_eq!(r.baseline_edges, 1);
        assert_eq!(r.gated_edges, 2);
        assert_eq!(r.edge_flips, 1);
        assert!((r.jaccard - 0.5).abs() < 1e-10);
        assert!((r.sparsity_ratio - 2.0 / 3.0).abs() < 1e-10);
    }

    #[test]
    fn compare_masks_dense_baseline_falls_back_to_gated_sparsity() {
        // baseline sparsity is 0, so the ratio falls back to gated sparsity (0.5)
        let r = compare_attention_masks(&[true, true, true, true], &[true, false, true, false]);
        assert_eq!(r.baseline_edges, 4);
        assert_eq!(r.gated_edges, 2);
        assert!((r.sparsity_ratio - 0.5).abs() < 1e-10);
    }

    #[test]
    fn compare_masks_empty() {
        let r = compare_attention_masks(&[], &[]);
        assert_eq!(r.baseline_edges, 0);
        assert_eq!(r.gated_edges, 0);
        assert_eq!(r.edge_flips, 0);
        assert_eq!(r.jaccard, 1.0);
        assert_eq!(r.sparsity_ratio, 1.0);
    }
}