/// Incremental Cohen's kappa over a stream of `(true, predicted)` class labels.
///
/// Labels are used directly as indices into a square confusion matrix that is
/// enlarged on demand, so the matrix holds `(largest_label + 1)^2` counters.
#[derive(Debug, Clone, Default)]
pub struct CohenKappa {
    /// Confusion counts, addressed as `matrix[true_label][predicted_label]`.
    matrix: Vec<Vec<u64>>,
    /// Total number of observations recorded so far.
    n: u64,
}

impl CohenKappa {
    /// Creates a metric with an empty confusion matrix and zero samples.
    pub fn new() -> Self {
        Self::default()
    }

    /// Ensures the confusion matrix is at least `size` x `size`.
    ///
    /// Existing counts are kept; every new cell starts at zero. The matrix is
    /// never shrunk.
    fn grow(&mut self, size: usize) {
        if size > self.matrix.len() {
            // Widen the rows that already exist, then append fresh zero rows.
            self.matrix.iter_mut().for_each(|row| row.resize(size, 0));
            self.matrix.resize_with(size, || vec![0; size]);
        }
    }

    /// Records one observation with the given true and predicted labels.
    ///
    /// The matrix is grown first so that both labels are valid indices.
    pub fn update(&mut self, true_label: usize, predicted_label: usize) {
        self.grow(usize::max(true_label, predicted_label) + 1);
        self.n += 1;
        self.matrix[true_label][predicted_label] += 1;
    }

    /// Returns `(p_o - p_e) / (1 - p_e)`.
    ///
    /// `p_o` is the share of samples on the diagonal, and `p_e` is the sum over
    /// classes of `row_total * column_total`, divided by `n^2`.
    ///
    /// Returns `0.0` when no samples have been recorded, or when `1 - p_e` is
    /// smaller in magnitude than `f64::EPSILON`.
    pub fn kappa(&self) -> f64 {
        if self.n == 0 {
            return 0.0;
        }
        let total = self.n as f64;

        // Observed agreement: the diagonal of the confusion matrix.
        let agreements: u64 = self
            .matrix
            .iter()
            .enumerate()
            .map(|(class, row)| row[class])
            .sum();
        let p_o = agreements as f64 / total;

        // Chance agreement: products of the marginals, accumulated in class order.
        let marginal_products = self
            .matrix
            .iter()
            .enumerate()
            .fold(0.0_f64, |acc, (class, row)| {
                let row_total: u64 = row.iter().sum();
                let col_total: u64 = self.matrix.iter().map(|other| other[class]).sum();
                acc + (row_total as f64) * (col_total as f64)
            });
        let p_e = marginal_products / (total * total);

        let headroom = 1.0 - p_e;
        if headroom.abs() < f64::EPSILON {
            0.0
        } else {
            (p_o - p_e) / headroom
        }
    }

    /// Returns the number of observations recorded so far.
    pub fn n_samples(&self) -> u64 {
        self.n
    }

    /// Discards every recorded observation and empties the confusion matrix.
    pub fn reset(&mut self) {
        self.n = 0;
        self.matrix.clear();
    }
}
/// Incremental Kappa-M: accuracy measured against a majority-class baseline.
///
/// The baseline's accuracy is taken to be the share of samples whose *true*
/// label is the most frequent one seen so far.
#[derive(Debug, Clone, Default)]
pub struct KappaM {
    /// Number of samples seen per true label, indexed by label.
    class_counts: Vec<u64>,
    /// Total number of observations recorded so far.
    n: u64,
    /// Number of observations where the prediction matched the true label.
    n_correct: u64,
}

impl KappaM {
    /// Creates a metric with no recorded observations.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one observation with the given true and predicted labels.
    ///
    /// Only the true label contributes to the per-class counts; the predicted
    /// label is used solely to decide whether the sample counts as correct.
    pub fn update(&mut self, true_label: usize, predicted_label: usize) {
        if self.class_counts.len() <= true_label {
            self.class_counts.resize(true_label + 1, 0);
        }
        self.class_counts[true_label] += 1;
        self.n += 1;
        self.n_correct += u64::from(true_label == predicted_label);
    }

    /// Returns the most frequent true label, or `None` before any update.
    ///
    /// When several labels share the highest count, the smallest label wins.
    pub fn majority_class(&self) -> Option<usize> {
        if self.n == 0 {
            return None;
        }
        // Order by count first; `Reverse` makes a smaller index compare as
        // larger, so ties resolve to the lowest label.
        self.class_counts
            .iter()
            .enumerate()
            .max_by_key(|&(label, &count)| (count, std::cmp::Reverse(label)))
            .map(|(label, _)| label)
    }

    /// Returns `(p_model - p_majority) / (1 - p_majority)`.
    ///
    /// `p_model` is the share of correct predictions and `p_majority` is the
    /// share of samples belonging to the most frequent true label.
    ///
    /// Returns `0.0` when no samples have been recorded, or when
    /// `1 - p_majority` is smaller in magnitude than `f64::EPSILON`.
    pub fn kappa_m(&self) -> f64 {
        if self.n == 0 {
            return 0.0;
        }
        let total = self.n as f64;
        let p_model = self.n_correct as f64 / total;
        let largest_class = self.class_counts.iter().max().map_or(0, |&count| count);
        let p_majority = largest_class as f64 / total;

        let headroom = 1.0 - p_majority;
        if headroom.abs() < f64::EPSILON {
            0.0
        } else {
            (p_model - p_majority) / headroom
        }
    }

    /// Returns the number of observations recorded so far.
    pub fn n_samples(&self) -> u64 {
        self.n
    }

    /// Discards every recorded observation and all per-class counts.
    pub fn reset(&mut self) {
        self.n = 0;
        self.n_correct = 0;
        self.class_counts.clear();
    }
}
/// Incremental Kappa-Temporal: accuracy measured against a "no-change"
/// baseline that predicts the previous sample's true label.
#[derive(Debug, Clone, Default)]
pub struct KappaT {
    /// Total number of observations recorded so far.
    n: u64,
    /// Number of observations where the prediction matched the true label.
    n_correct: u64,
    /// Number of observations whose true label equalled the previous true label.
    n_nochange_correct: u64,
    /// True label of the most recent observation, if any.
    prev_true_label: Option<usize>,
}

impl KappaT {
    /// Creates a metric with no recorded observations.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records one observation with the given true and predicted labels.
    ///
    /// The first observation after construction or `reset` has no predecessor,
    /// so it never counts towards the no-change baseline.
    pub fn update(&mut self, true_label: usize, predicted_label: usize) {
        self.n += 1;
        if predicted_label == true_label {
            self.n_correct += 1;
        }
        // `replace` stores the new label and hands back the one it displaced.
        if self.prev_true_label.replace(true_label) == Some(true_label) {
            self.n_nochange_correct += 1;
        }
    }

    /// Returns `(p_model - p_nochange) / (1 - p_nochange)`.
    ///
    /// `p_model` is the share of correct predictions over all `n` samples.
    /// `p_nochange` is the share of the `n - 1` samples that have a
    /// predecessor and repeat its true label.
    ///
    /// Returns `0.0` when fewer than two samples have been recorded, or when
    /// `1 - p_nochange` is smaller in magnitude than `f64::EPSILON`.
    pub fn kappa_t(&self) -> f64 {
        if self.n < 2 {
            return 0.0;
        }
        let total = self.n as f64;
        let with_predecessor = (self.n - 1) as f64;
        let p_model = self.n_correct as f64 / total;
        let p_nochange = self.n_nochange_correct as f64 / with_predecessor;

        let headroom = 1.0 - p_nochange;
        if headroom.abs() < f64::EPSILON {
            0.0
        } else {
            (p_model - p_nochange) / headroom
        }
    }

    /// Returns the number of observations recorded so far.
    pub fn n_samples(&self) -> u64 {
        self.n
    }

    /// Discards every recorded observation, including the remembered label.
    pub fn reset(&mut self) {
        self.prev_true_label = None;
        self.n_nochange_correct = 0;
        self.n_correct = 0;
        self.n = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing floating-point kappa values.
    const TOLERANCE: f64 = 1e-10;

    /// Returns `true` when `actual` is within `TOLERANCE` of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOLERANCE
    }

    /// Runs `action` exactly `count` times.
    fn repeat(count: usize, mut action: impl FnMut()) {
        (0..count).for_each(|_| action());
    }

    // ----- CohenKappa -----

    #[test]
    fn cohen_empty_returns_zero() {
        let metric = CohenKappa::new();
        assert_eq!(metric.kappa(), 0.0);
        assert_eq!(metric.n_samples(), 0);
    }

    #[test]
    fn cohen_perfect_classifier() {
        let mut metric = CohenKappa::new();
        for class in 0..2 {
            repeat(50, || metric.update(class, class));
        }
        assert!(close(metric.kappa(), 1.0));
        assert_eq!(metric.n_samples(), 100);
    }

    #[test]
    fn cohen_perfect_multiclass() {
        let mut metric = CohenKappa::new();
        for class in 0..5 {
            repeat(20, || metric.update(class, class));
        }
        assert!(close(metric.kappa(), 1.0));
    }

    #[test]
    fn cohen_known_confusion_matrix() {
        // (true, predicted, count): p_o = 35/50 = 0.7, p_e = 1250/2500 = 0.5.
        let cells = [(0, 0, 20), (0, 1, 5), (1, 0, 10), (1, 1, 15)];
        let mut metric = CohenKappa::new();
        for &(truth, guess, count) in &cells {
            repeat(count, || metric.update(truth, guess));
        }
        assert!(close(metric.kappa(), 0.4));
    }

    #[test]
    fn cohen_chance_classifier() {
        // A uniform 2x2 matrix: p_o equals p_e, so kappa is zero.
        let cells = [(0, 0), (0, 1), (1, 0), (1, 1)];
        let mut metric = CohenKappa::new();
        for &(truth, guess) in &cells {
            repeat(25, || metric.update(truth, guess));
        }
        assert!(close(metric.kappa(), 0.0));
    }

    #[test]
    fn cohen_worse_than_chance() {
        // Every prediction is wrong: p_o = 0, p_e = 0.5.
        let mut metric = CohenKappa::new();
        repeat(50, || metric.update(0, 1));
        repeat(50, || metric.update(1, 0));
        assert!(close(metric.kappa(), -1.0));
    }

    #[test]
    fn cohen_single_class() {
        // p_e is 1, so the degenerate branch returns zero.
        let mut metric = CohenKappa::new();
        repeat(100, || metric.update(0, 0));
        assert_eq!(metric.kappa(), 0.0);
    }

    #[test]
    fn cohen_single_sample() {
        let mut metric = CohenKappa::new();
        metric.update(0, 0);
        assert_eq!(metric.kappa(), 0.0);
        assert_eq!(metric.n_samples(), 1);
    }

    #[test]
    fn cohen_auto_grow() {
        let mut metric = CohenKappa::new();
        for &class in &[5, 2, 0] {
            metric.update(class, class);
        }
        assert_eq!(metric.n_samples(), 3);
        assert!(metric.matrix.len() >= 6);
    }

    #[test]
    fn cohen_reset() {
        let mut metric = CohenKappa::new();
        metric.update(0, 0);
        metric.update(1, 1);
        metric.reset();
        assert_eq!(metric.n_samples(), 0);
        assert_eq!(metric.kappa(), 0.0);
        assert!(metric.matrix.is_empty());
    }

    #[test]
    fn cohen_default_is_empty() {
        let metric = CohenKappa::default();
        assert_eq!(metric.n_samples(), 0);
        assert_eq!(metric.kappa(), 0.0);
    }

    // ----- KappaM -----

    #[test]
    fn kappam_empty_returns_zero() {
        let metric = KappaM::new();
        assert_eq!(metric.kappa_m(), 0.0);
        assert_eq!(metric.n_samples(), 0);
        assert_eq!(metric.majority_class(), None);
    }

    #[test]
    fn kappam_perfect_classifier() {
        let mut metric = KappaM::new();
        repeat(60, || metric.update(0, 0));
        repeat(40, || metric.update(1, 1));
        assert!(close(metric.kappa_m(), 1.0));
    }

    #[test]
    fn kappam_majority_class_tracking() {
        let mut metric = KappaM::new();
        repeat(2, || metric.update(0, 0));
        metric.update(1, 1);
        assert_eq!(metric.majority_class(), Some(0));
        repeat(2, || metric.update(1, 1));
        assert_eq!(metric.majority_class(), Some(1));
    }

    #[test]
    fn kappam_majority_class_tie_breaks_lowest() {
        let mut metric = KappaM::new();
        for class in 0..2 {
            metric.update(class, class);
        }
        assert_eq!(metric.majority_class(), Some(0));
    }

    #[test]
    fn kappam_acts_like_majority_baseline() {
        // Always predicting class 0 matches the majority share of 0.6.
        let mut metric = KappaM::new();
        repeat(60, || metric.update(0, 0));
        repeat(40, || metric.update(1, 0));
        assert!(close(metric.kappa_m(), 0.0));
    }

    #[test]
    fn kappam_worse_than_majority() {
        // p_model = 0.2, p_majority = 0.8 -> (0.2 - 0.8) / 0.2 = -3.
        let mut metric = KappaM::new();
        repeat(80, || metric.update(0, 1));
        repeat(20, || metric.update(1, 1));
        assert!(close(metric.kappa_m(), -3.0));
    }

    #[test]
    fn kappam_single_class_returns_zero() {
        let mut metric = KappaM::new();
        repeat(50, || metric.update(0, 0));
        assert_eq!(metric.kappa_m(), 0.0);
    }

    #[test]
    fn kappam_single_sample() {
        let mut metric = KappaM::new();
        metric.update(0, 0);
        assert_eq!(metric.kappa_m(), 0.0);
        assert_eq!(metric.n_samples(), 1);
    }

    #[test]
    fn kappam_reset() {
        let mut metric = KappaM::new();
        metric.update(0, 0);
        metric.update(1, 1);
        metric.reset();
        assert_eq!(metric.n_samples(), 0);
        assert_eq!(metric.kappa_m(), 0.0);
        assert_eq!(metric.majority_class(), None);
    }

    #[test]
    fn kappam_default_is_empty() {
        let metric = KappaM::default();
        assert_eq!(metric.n_samples(), 0);
        assert_eq!(metric.kappa_m(), 0.0);
    }

    // ----- KappaT -----

    #[test]
    fn kappat_empty_returns_zero() {
        let metric = KappaT::new();
        assert_eq!(metric.kappa_t(), 0.0);
        assert_eq!(metric.n_samples(), 0);
    }

    #[test]
    fn kappat_single_sample_returns_zero() {
        let mut metric = KappaT::new();
        metric.update(0, 0);
        assert_eq!(metric.kappa_t(), 0.0);
        assert_eq!(metric.n_samples(), 1);
    }

    #[test]
    fn kappat_perfect_on_changing_stream() {
        // The true label flips every step, so the no-change baseline never scores.
        let mut metric = KappaT::new();
        [0, 1, 0, 1, 0, 1]
            .iter()
            .for_each(|&label| metric.update(label, label));
        assert!(close(metric.kappa_t(), 1.0));
    }

    #[test]
    fn kappat_no_change_baseline_is_perfect() {
        // p_nochange is 1, so the degenerate branch returns zero.
        let mut metric = KappaT::new();
        repeat(5, || metric.update(0, 0));
        assert_eq!(metric.kappa_t(), 0.0);
    }

    #[test]
    fn kappat_model_equals_nochange() {
        let trues = [0, 0, 1, 1, 0, 0];
        let preds = [0, 0, 0, 1, 1, 0];
        let mut metric = KappaT::new();
        for (&truth, &guess) in trues.iter().zip(preds.iter()) {
            metric.update(truth, guess);
        }
        // 4 of 6 predictions correct; 3 of 5 transitions repeat the label.
        let p_model = 4.0 / 6.0;
        let p_nochange = 3.0 / 5.0;
        let expected = (p_model - p_nochange) / (1.0 - p_nochange);
        assert!(close(metric.kappa_t(), expected));
    }

    #[test]
    fn kappat_worse_than_nochange() {
        // The true label is constant, so p_nochange is 1 and the result is zero.
        let mut metric = KappaT::new();
        repeat(5, || metric.update(0, 1));
        assert_eq!(metric.kappa_t(), 0.0);
    }

    #[test]
    fn kappat_computed_example() {
        let trues = [0, 1, 1, 0, 1];
        let preds = [0, 0, 1, 1, 1];
        let mut metric = KappaT::new();
        for (&truth, &guess) in trues.iter().zip(preds.iter()) {
            metric.update(truth, guess);
        }
        // 3 of 5 predictions correct; 1 of 4 transitions repeats the label.
        let expected = (0.6 - 0.25) / (1.0 - 0.25);
        assert!(close(metric.kappa_t(), expected));
    }

    #[test]
    fn kappat_reset() {
        let mut metric = KappaT::new();
        metric.update(0, 0);
        metric.update(1, 1);
        metric.reset();
        assert_eq!(metric.n_samples(), 0);
        assert_eq!(metric.kappa_t(), 0.0);
    }

    #[test]
    fn kappat_default_is_empty() {
        let metric = KappaT::default();
        assert_eq!(metric.n_samples(), 0);
        assert_eq!(metric.kappa_t(), 0.0);
    }

    #[test]
    fn kappat_two_samples() {
        let mut metric = KappaT::new();
        for class in 0..2 {
            metric.update(class, class);
        }
        assert!(close(metric.kappa_t(), 1.0));
    }
}