1use ai3_lib::tensor::Tensor;
2use flate2::write::DeflateEncoder;
3use flate2::Compression;
4use pot_o_core::TribeResult;
5use std::io::Write;
6
/// Validator that scores a tensor transformation by comparing the compressed
/// size and byte-level entropy of its output against its input.
///
/// The only configuration is the DEFLATE level used when measuring compressed
/// sizes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MMLPathValidator {
    /// Compression level handed to `flate2::Compression::new` when measuring
    /// compressed lengths (flate2 documents levels on a 0 = none to 9 = best
    /// scale). The `Default` impl sets this to 6.
    pub compression_level: u32,
}
15
16impl Default for MMLPathValidator {
17 fn default() -> Self {
18 Self {
19 compression_level: 6,
20 }
21 }
22}
23
24impl MMLPathValidator {
25 pub fn compute_mml_score(&self, input: &Tensor, output: &Tensor) -> TribeResult<f64> {
28 let input_compressed_len = self.compressed_length(&input.data.to_bytes())?;
29 let output_compressed_len = self.compressed_length(&output.data.to_bytes())?;
30
31 if input_compressed_len == 0 {
32 return Ok(1.0);
33 }
34
35 Ok(output_compressed_len as f64 / input_compressed_len as f64)
36 }
37
38 pub fn compute_entropy_mml_score(&self, input: &Tensor, output: &Tensor) -> f64 {
45 fn entropy(bytes: &[u8]) -> f64 {
46 let mut hist = [0u64; 256];
47 for &b in bytes {
48 hist[b as usize] += 1;
49 }
50 let total = bytes.len() as f64;
51 if total == 0.0 {
52 return 0.0;
53 }
54 let mut ent = 0.0f64;
55 for &count in &hist {
56 if count == 0 {
57 continue;
58 }
59 let p = count as f64 / total;
60 ent -= p * p.ln();
61 }
62 ent
63 }
64
65 let input_bytes = input.data.to_bytes();
66 let output_bytes = output.data.to_bytes();
67 let in_ent = entropy(&input_bytes);
68 let out_ent = entropy(&output_bytes);
69 if in_ent.abs() < f64::EPSILON {
70 1.0
71 } else {
72 out_ent / in_ent
73 }
74 }
75
76 pub fn validate(&self, mml_score: f64, mml_threshold: f64) -> bool {
78 mml_score <= mml_threshold
79 }
80
81 pub fn threshold_for_difficulty(base_threshold: f64, difficulty: u64) -> f64 {
84 base_threshold / (1.0 + (difficulty as f64).log2().max(0.0))
85 }
86
87 fn compressed_length(&self, data: &[u8]) -> TribeResult<usize> {
88 let mut encoder = DeflateEncoder::new(Vec::new(), Compression::new(self.compression_level));
89 encoder
90 .write_all(data)
91 .map_err(|e| pot_o_core::TribeError::TensorError(format!("Compression failed: {e}")))?;
92 let compressed = encoder
93 .finish()
94 .map_err(|e| pot_o_core::TribeError::TensorError(format!("Compression failed: {e}")))?;
95 Ok(compressed.len())
96 }
97}
98
#[cfg(test)]
mod tests {
    use super::*;
    use ai3_lib::tensor::{TensorData, TensorShape};

    #[test]
    fn test_mml_score_computation() {
        let validator = MMLPathValidator::default();

        let input = Tensor::new(
            TensorShape::new(vec![4]),
            TensorData::F32(vec![1.0, 2.0, 3.0, 4.0]),
        )
        .unwrap();

        // Even a highly compressible all-zero output must yield a positive ratio.
        let output = Tensor::new(
            TensorShape::new(vec![4]),
            TensorData::F32(vec![0.0, 0.0, 0.0, 0.0]),
        )
        .unwrap();

        let score = validator.compute_mml_score(&input, &output).unwrap();
        assert!(score > 0.0, "Score should be positive");
    }

    #[test]
    fn test_identical_tensors_score_one() {
        let validator = MMLPathValidator::default();
        let tensor = Tensor::new(
            TensorShape::new(vec![4]),
            TensorData::F32(vec![1.0, 2.0, 3.0, 4.0]),
        )
        .unwrap();

        // Identical bytes compress to the same length and have the same
        // entropy, so both ratios must be exactly 1.0.
        let mml = validator.compute_mml_score(&tensor, &tensor).unwrap();
        assert_eq!(mml, 1.0, "Identical tensors should have an MML score of 1.0");

        let entropy = validator.compute_entropy_mml_score(&tensor, &tensor);
        assert_eq!(entropy, 1.0, "Identical tensors should have an entropy score of 1.0");
    }

    #[test]
    fn test_validate_boundaries() {
        let validator = MMLPathValidator::default();
        assert!(validator.validate(0.5, 0.85));
        assert!(validator.validate(0.85, 0.85), "A score equal to the threshold passes");
        assert!(!validator.validate(0.9, 0.85));
        assert!(!validator.validate(f64::NAN, 0.85), "NaN must never validate");
    }

    #[test]
    fn test_threshold_scaling() {
        let t1 = MMLPathValidator::threshold_for_difficulty(0.85, 1);
        let t4 = MMLPathValidator::threshold_for_difficulty(0.85, 4);
        let t8 = MMLPathValidator::threshold_for_difficulty(0.85, 8);
        assert!(t4 < t1, "Higher difficulty should give lower threshold");
        assert!(t8 < t4);
    }

    #[test]
    fn test_threshold_low_difficulty_is_unscaled() {
        // log2(0) is -inf and log2(1) is 0; both clamp to a divisor of 1.
        assert_eq!(MMLPathValidator::threshold_for_difficulty(0.85, 0), 0.85);
        assert_eq!(MMLPathValidator::threshold_for_difficulty(0.85, 1), 0.85);
    }
}