1use ai3_lib::tensor::Tensor;
4use flate2::write::DeflateEncoder;
5use flate2::Compression;
6use pot_o_core::TribeResult;
7use std::io::Write;
8
/// Validator that scores a tensor transformation by comparing how well the
/// input and output byte representations compress.
///
/// The derives make the type printable, cloneable, and comparable, which is
/// convenient for logging, configuration snapshots, and tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MMLPathValidator {
    /// Compression level passed to `flate2::Compression::new` when measuring
    /// compressed lengths. flate2 documents levels as typically ranging from
    /// 0 (no compression) to 9 (best compression).
    pub compression_level: u32,
}
17
18impl Default for MMLPathValidator {
19 fn default() -> Self {
20 Self {
21 compression_level: 6,
22 }
23 }
24}
25
26impl MMLPathValidator {
27 pub fn compute_mml_score(&self, input: &Tensor, output: &Tensor) -> TribeResult<f64> {
30 let input_compressed_len = self.compressed_length(&input.data.to_bytes())?;
31 let output_compressed_len = self.compressed_length(&output.data.to_bytes())?;
32
33 if input_compressed_len == 0 {
34 return Ok(1.0);
35 }
36
37 Ok(output_compressed_len as f64 / input_compressed_len as f64)
38 }
39
40 pub fn compute_entropy_mml_score(&self, input: &Tensor, output: &Tensor) -> f64 {
47 fn entropy(bytes: &[u8]) -> f64 {
48 let mut hist = [0u64; 256];
49 for &b in bytes {
50 hist[b as usize] += 1;
51 }
52 let total = bytes.len() as f64;
53 if total == 0.0 {
54 return 0.0;
55 }
56 let mut ent = 0.0f64;
57 for &count in &hist {
58 if count == 0 {
59 continue;
60 }
61 let p = count as f64 / total;
62 ent -= p * p.ln();
63 }
64 ent
65 }
66
67 let input_bytes = input.data.to_bytes();
68 let output_bytes = output.data.to_bytes();
69 let in_ent = entropy(&input_bytes);
70 let out_ent = entropy(&output_bytes);
71 if in_ent.abs() < f64::EPSILON {
72 1.0
73 } else {
74 out_ent / in_ent
75 }
76 }
77
78 pub fn validate(&self, mml_score: f64, mml_threshold: f64) -> bool {
80 mml_score <= mml_threshold
81 }
82
83 pub fn threshold_for_difficulty(base_threshold: f64, difficulty: u64) -> f64 {
86 base_threshold / (1.0 + (difficulty as f64).log2().max(0.0))
87 }
88
89 fn compressed_length(&self, data: &[u8]) -> TribeResult<usize> {
90 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::new(self.compression_level));
91 encoder
92 .write_all(data)
93 .map_err(|e| pot_o_core::TribeError::TensorError(format!("Compression failed: {e}")))?;
94 let compressed = encoder
95 .finish()
96 .map_err(|e| pot_o_core::TribeError::TensorError(format!("Compression failed: {e}")))?;
97 Ok(compressed.len())
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use ai3_lib::tensor::{TensorData, TensorShape};

    /// Compressing a pair of valid tensors must yield a strictly positive ratio.
    #[test]
    fn test_mml_score_computation() {
        // Both fixtures use shape `[4]` with `F32` data.
        let four_floats = |values| {
            Tensor::new(TensorShape::new(vec![4]), TensorData::F32(values)).unwrap()
        };
        let input = four_floats(vec![1.0, 2.0, 3.0, 4.0]);
        let output = four_floats(vec![0.0; 4]);

        let score = MMLPathValidator::default()
            .compute_mml_score(&input, &output)
            .unwrap();
        assert!(score > 0.0, "Score should be positive");
    }

    /// The threshold must shrink strictly as difficulty goes 1 -> 4 -> 8.
    #[test]
    fn test_threshold_scaling() {
        let threshold_at =
            |difficulty| MMLPathValidator::threshold_for_difficulty(0.85, difficulty);
        let (t1, t4, t8) = (threshold_at(1), threshold_at(4), threshold_at(8));

        assert!(t4 < t1, "Higher difficulty should give lower threshold");
        assert!(t8 < t4);
    }
}
135}