use std::collections::BTreeSet;
/// Minimum relative perplexity improvement, expressed as a fraction (0.005 = 0.5 %).
///
/// The unit tests in this file pass it as the `threshold` argument of
/// [`classify_imatrix_improvement`].
pub const MIN_PPL_IMPROVEMENT: f64 = 0.005;
/// Verdict produced by [`classify_imatrix_improvement`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ImprovementOutcome {
    /// The relative perplexity reduction `delta` reached the threshold.
    Improved { delta: f64 },
    /// The relative reduction `delta` stayed below `threshold`, or could not be
    /// computed from the inputs (in which case `delta` is NaN).
    Insufficient { delta: f64, threshold: f64 },
}

/// Classifies the perplexity change between a naive and a calibrated run.
///
/// The relative reduction is `delta = (ppl_naive - ppl_calib) / ppl_naive`.
/// The result is [`ImprovementOutcome::Improved`] when `delta >= threshold`
/// (the comparison is inclusive) and [`ImprovementOutcome::Insufficient`]
/// otherwise, which includes regressions where `delta` is negative.
///
/// Both perplexities must be finite and strictly positive. If either one is
/// not, no meaningful ratio exists and the function returns `Insufficient`
/// with `delta` set to NaN instead of panicking. A NaN `threshold` also yields
/// `Insufficient`, because every comparison against NaN is false.
#[must_use]
pub fn classify_imatrix_improvement(
    ppl_naive: f64,
    ppl_calib: f64,
    threshold: f64,
) -> ImprovementOutcome {
    let is_valid_ppl = |ppl: f64| ppl.is_finite() && ppl > 0.0;

    // Validate the calibrated value as strictly as the baseline: a zero,
    // negative or -inf `ppl_calib` would otherwise produce a large positive
    // delta and be reported as an improvement.
    if !is_valid_ppl(ppl_naive) || !is_valid_ppl(ppl_calib) {
        return ImprovementOutcome::Insufficient {
            delta: f64::NAN,
            threshold,
        };
    }

    let delta = (ppl_naive - ppl_calib) / ppl_naive;
    if delta >= threshold {
        ImprovementOutcome::Improved { delta }
    } else {
        ImprovementOutcome::Insufficient { delta, threshold }
    }
}
/// Reports whether the calibration and evaluation hash sets have no element in common.
///
/// Returns `true` when no hash appears in both sets; two empty sets therefore
/// count as disjoint. Returns `false` as soon as a shared hash exists.
#[must_use]
pub fn calibration_eval_disjoint(
    calib_hashes: &BTreeSet<String>,
    eval_hashes: &BTreeSet<String>,
) -> bool {
    calib_hashes.is_disjoint(eval_hashes)
}
/// Extracts the value of the first `--imatrix` option found in `argv`.
///
/// Two spellings are recognised:
/// * `--imatrix VALUE` — the value is the argument that follows, whatever it is;
///   `None` is returned when the flag is the last argument.
/// * `--imatrix=VALUE` — the value is everything after the `=`, which may be empty.
///
/// Arguments that merely start with `--imatrix` (for example `--imatrix-force`)
/// are ignored. Returns `None` when neither spelling occurs.
#[must_use]
pub fn parse_imatrix_flag(argv: &[&str]) -> Option<String> {
    let mut remaining = argv.iter().copied();
    while let Some(arg) = remaining.next() {
        if arg == "--imatrix" {
            // Space-separated form: the value is whatever argument comes next.
            return remaining.next().map(str::to_owned);
        }
        if let Some(value) = arg.strip_prefix("--imatrix=") {
            return Some(value.to_owned());
        }
    }
    None
}
/// Computes the SHA-256 digest of `bytes` and returns it as a 64-character
/// lowercase hexadecimal string.
#[must_use]
pub fn compute_provenance_sha256(bytes: &[u8]) -> String {
    use sha2::{Digest, Sha256};
    use std::fmt::Write as _;

    let mut hasher = Sha256::new();
    hasher.update(bytes);
    let digest = hasher.finalize();

    // 32 digest bytes, two hex characters each.
    let mut hex = String::with_capacity(64);
    for byte in digest {
        // Format straight into the pre-sized buffer instead of allocating a
        // temporary `String` per byte. Writing to a `String` cannot fail, so
        // the `fmt::Result` is deliberately discarded.
        let _ = write!(hex, "{:02x}", byte);
    }
    hex
}
/// Result of checking a recorded provenance digest against the expected one.
#[derive(Debug, Clone, PartialEq)]
pub enum ProvenanceOutcome {
    /// The recorded digest equals the expected digest, ignoring ASCII case.
    Match,
    /// No digest was recorded.
    Missing,
    /// A digest was recorded but differs; both values are carried verbatim.
    Mismatch { recorded: String, expected: String },
}

/// Compares an optionally recorded digest with the expected digest.
///
/// * `None` yields [`ProvenanceOutcome::Missing`].
/// * A recorded value equal to `expected` under ASCII case-insensitive
///   comparison yields [`ProvenanceOutcome::Match`].
/// * Anything else yields [`ProvenanceOutcome::Mismatch`] holding owned copies
///   of both strings exactly as they were passed in.
#[must_use]
pub fn validate_recorded_provenance(recorded: Option<&str>, expected: &str) -> ProvenanceOutcome {
    let found = match recorded {
        Some(digest) => digest,
        None => return ProvenanceOutcome::Missing,
    };

    if found.eq_ignore_ascii_case(expected) {
        ProvenanceOutcome::Match
    } else {
        ProvenanceOutcome::Mismatch {
            recorded: found.to_owned(),
            expected: expected.to_owned(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned hash set from string literals.
    fn hash_set(items: &[&str]) -> BTreeSet<String> {
        items.iter().map(|item| (*item).to_owned()).collect()
    }

    /// Classifies a naive/calibrated perplexity pair against `MIN_PPL_IMPROVEMENT`.
    fn classify(ppl_naive: f64, ppl_calib: f64) -> ImprovementOutcome {
        classify_imatrix_improvement(ppl_naive, ppl_calib, MIN_PPL_IMPROVEMENT)
    }

    /// True when the outcome is the `Insufficient` variant.
    fn is_insufficient(outcome: ImprovementOutcome) -> bool {
        matches!(outcome, ImprovementOutcome::Insufficient { .. })
    }

    // ---- classify_imatrix_improvement ----

    #[test]
    fn improvement_meets_threshold_exact() {
        let outcome = classify(100.0, 99.5);
        if let ImprovementOutcome::Improved { delta } = outcome {
            assert!((delta - 0.005).abs() < 1e-9);
        } else {
            panic!("expected Improved, got {:?}", outcome);
        }
    }

    #[test]
    fn improvement_above_threshold_classifies_improved() {
        let outcome = classify(100.0, 90.0);
        assert!(matches!(outcome, ImprovementOutcome::Improved { .. }));
    }

    #[test]
    fn improvement_just_below_threshold_is_insufficient() {
        assert!(is_insufficient(classify(100.0, 99.6)));
    }

    #[test]
    fn regression_is_insufficient() {
        assert!(is_insufficient(classify(100.0, 110.0)));
    }

    #[test]
    fn zero_or_negative_baseline_is_insufficient_not_panic() {
        let baselines = [0.0_f64, -1.0];
        for &baseline in baselines.iter() {
            assert!(is_insufficient(classify(baseline, 5.0)));
        }
    }

    #[test]
    fn improvement_is_deterministic() {
        let render = |outcome: ImprovementOutcome| format!("{:?}", outcome);
        let first = classify(42.0, 40.0);
        let second = classify(42.0, 40.0);
        assert_eq!(render(first), render(second));
    }

    // ---- calibration_eval_disjoint ----

    #[test]
    fn disjoint_sets_are_disjoint() {
        let calib = hash_set(&["a", "b"]);
        let eval = hash_set(&["c", "d"]);
        assert!(calibration_eval_disjoint(&calib, &eval));
    }

    #[test]
    fn overlapping_sets_are_not_disjoint() {
        let calib = hash_set(&["a", "b", "c"]);
        let eval = hash_set(&["c", "d"]);
        assert!(!calibration_eval_disjoint(&calib, &eval));
    }

    #[test]
    fn empty_sets_are_trivially_disjoint() {
        let empty = hash_set(&[]);
        assert!(calibration_eval_disjoint(&empty, &empty));
    }

    // ---- parse_imatrix_flag ----

    #[test]
    fn parse_imatrix_flag_finds_space_form() {
        let parsed = parse_imatrix_flag(&["quantize", "model.apr", "--imatrix", "calib.jsonl"]);
        assert_eq!(parsed, Some("calib.jsonl".to_string()));
    }

    #[test]
    fn parse_imatrix_flag_finds_equals_form() {
        let parsed = parse_imatrix_flag(&["quantize", "model.apr", "--imatrix=calib.jsonl"]);
        assert_eq!(parsed, Some("calib.jsonl".to_string()));
    }

    #[test]
    fn parse_imatrix_flag_missing_returns_none() {
        let parsed = parse_imatrix_flag(&["quantize", "model.apr", "--method", "q4k"]);
        assert_eq!(parsed, None);
    }

    #[test]
    fn parse_imatrix_flag_without_value_returns_none() {
        let parsed = parse_imatrix_flag(&["quantize", "model.apr", "--imatrix"]);
        assert_eq!(parsed, None);
    }

    #[test]
    fn parse_imatrix_flag_ignores_similar_names() {
        let parsed = parse_imatrix_flag(&["quantize", "--imatrix-force", "true"]);
        assert_eq!(parsed, None);
    }

    // ---- compute_provenance_sha256 ----

    #[test]
    fn sha256_of_empty_is_known_constant() {
        let digest = compute_provenance_sha256(&[]);
        assert_eq!(
            digest,
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn sha256_is_deterministic() {
        let first = compute_provenance_sha256(b"calibration-bytes");
        let second = compute_provenance_sha256(b"calibration-bytes");
        assert_eq!(first, second);
        assert_eq!(first.len(), 64);
    }

    #[test]
    fn sha256_differs_on_different_input() {
        assert_ne!(
            compute_provenance_sha256(b"a"),
            compute_provenance_sha256(b"b")
        );
    }

    // ---- validate_recorded_provenance ----

    #[test]
    fn provenance_match_ok() {
        let expected = compute_provenance_sha256(b"calib-v1");
        let outcome = validate_recorded_provenance(Some(expected.as_str()), expected.as_str());
        assert_eq!(outcome, ProvenanceOutcome::Match);
    }

    #[test]
    fn provenance_match_is_case_insensitive() {
        let expected = compute_provenance_sha256(b"calib-v1");
        let upper = expected.to_uppercase();
        let outcome = validate_recorded_provenance(Some(upper.as_str()), expected.as_str());
        assert_eq!(outcome, ProvenanceOutcome::Match);
    }

    #[test]
    fn provenance_missing_is_failure() {
        let expected = compute_provenance_sha256(b"calib-v1");
        let outcome = validate_recorded_provenance(None, expected.as_str());
        assert_eq!(outcome, ProvenanceOutcome::Missing);
    }

    #[test]
    fn provenance_mismatch_carries_both_values() {
        let expected = compute_provenance_sha256(b"calib-v1");
        let wrong = compute_provenance_sha256(b"calib-v2");
        let outcome = validate_recorded_provenance(Some(wrong.as_str()), expected.as_str());
        if let ProvenanceOutcome::Mismatch {
            recorded,
            expected: exp,
        } = outcome
        {
            assert_eq!(recorded, wrong);
            assert_eq!(exp, expected);
        } else {
            panic!("expected Mismatch, got {:?}", outcome);
        }
    }
}