Skip to main content

aprender/verify/
ground_truth.rs

//! Ground-truth data structures for pipeline verification.
//!
//! Ground truth represents the expected output at a pipeline stage,
//! extracted from a reference implementation (whisper.cpp, `HuggingFace`, etc.).
5
6use std::io;
7use std::path::Path;
8
/// Ground truth statistics for a tensor at a pipeline stage.
///
/// Holds summary statistics (mean, population standard deviation, min, max),
/// a shape, and — when constructed from raw values — a copy of those values
/// for element-wise comparison.
#[derive(Debug, Clone)]
pub struct GroundTruth {
    /// Mean value
    mean: f32,
    /// Population standard deviation (variance divided by `n`, not `n - 1`)
    std: f32,
    /// Minimum value; `f32::NEG_INFINITY` when unknown
    min: f32,
    /// Maximum value; `f32::INFINITY` when unknown
    max: f32,
    /// Optional raw data for detailed comparison
    data: Option<Vec<f32>>,
    /// Shape information; empty when unknown
    shape: Vec<usize>,
}

impl GroundTruth {
    /// Create ground truth from precomputed statistics.
    ///
    /// `min`/`max` are set to negative/positive infinity (unbounded), no raw
    /// data is stored, and the shape is empty.
    ///
    /// # Arguments
    /// * `mean` - Expected mean value
    /// * `std` - Expected standard deviation
    #[must_use]
    pub fn from_stats(mean: f32, std: f32) -> Self {
        Self {
            mean,
            std,
            min: f32::NEG_INFINITY,
            max: f32::INFINITY,
            data: None,
            shape: vec![],
        }
    }

    /// Create ground truth from a data slice.
    ///
    /// Computes mean, population standard deviation, min and max, keeps a copy
    /// of the values, and records the 1-D shape `[data.len()]`. An empty slice
    /// yields all-zero statistics, no stored data, and shape `[0]`.
    ///
    /// Sums are accumulated in `f64` and rounded to `f32` once at the end, so
    /// statistics of large tensors do not drift from `f32` accumulation error.
    #[must_use]
    pub fn from_slice(data: &[f32]) -> Self {
        if data.is_empty() {
            return Self {
                mean: 0.0,
                std: 0.0,
                min: 0.0,
                max: 0.0,
                data: None,
                shape: vec![0],
            };
        }

        // `f32` running sums lose several significant digits over millions of
        // elements, and `len() as f32` is inexact beyond 2^24; use `f64`.
        let n = data.len() as f64;
        let mean = data.iter().map(|&x| f64::from(x)).sum::<f64>() / n;
        let variance = data
            .iter()
            .map(|&x| (f64::from(x) - mean).powi(2))
            .sum::<f64>()
            / n;
        // `f32::min`/`f32::max` return the non-NaN operand, so NaN elements are
        // skipped here (they still propagate into mean/std).
        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        Self {
            mean: mean as f32,
            std: variance.sqrt() as f32,
            min,
            max,
            data: Some(data.to_vec()),
            shape: vec![data.len()],
        }
    }

    /// Create ground truth with shape information.
    ///
    /// Same as [`GroundTruth::from_slice`], but replaces the inferred 1-D shape
    /// with `shape`. The shape is stored as given and is not checked against
    /// `data.len()`.
    #[must_use]
    pub fn from_slice_with_shape(data: &[f32], shape: Vec<usize>) -> Self {
        let mut gt = Self::from_slice(data);
        gt.shape = shape;
        gt
    }

    /// Load ground truth from a binary file.
    ///
    /// Format: raw `f32` values in little-endian byte order, with no header.
    /// Statistics are computed as in [`GroundTruth::from_slice`] and the
    /// values are retained.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file cannot be read, or an
    /// [`io::ErrorKind::InvalidData`] error if the file length is not a
    /// multiple of 4 bytes.
    pub fn from_bin_file<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let path = path.as_ref();
        let bytes = std::fs::read(path)?;
        let chunks = bytes.chunks_exact(4);
        // `chunks_exact` silently drops a short tail. A length that is not a
        // whole number of f32 values means a truncated or non-f32 file.
        if !chunks.remainder().is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "{}: length of {} bytes is not a multiple of 4; expected raw little-endian f32 values",
                    path.display(),
                    bytes.len()
                ),
            ));
        }
        let data: Vec<f32> = chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Ok(Self::from_slice(&data))
    }

    /// Load ground truth statistics from a JSON file.
    ///
    /// Expected format: `{"mean": f32, "std": f32, "min": f32, "max": f32}`.
    /// `mean` and `std` are required; `min`/`max` default to negative/positive
    /// infinity when the key is absent. No raw data is stored and the shape is
    /// empty.
    ///
    /// This is a minimal key scanner, not a full JSON parser. For each key it
    /// uses the first `"key"` occurrence that is followed by a colon, without
    /// regard to nesting.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file cannot be read. Returns an
    /// [`io::ErrorKind::InvalidData`] error if `mean` or `std` is missing, or
    /// if any present value cannot be parsed as `f32`.
    pub fn from_json_file<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let content = std::fs::read_to_string(path)?;
        // Simple JSON parsing without serde dependency
        let mean = Self::extract_json_f32(&content, "mean")?;
        let std = Self::extract_json_f32(&content, "std")?;
        // Optional keys: fall back to an unbounded range only when the key is
        // absent. A present-but-malformed value is reported, not swallowed.
        let min = Self::extract_optional_json_f32(&content, "min")?.unwrap_or(f32::NEG_INFINITY);
        let max = Self::extract_optional_json_f32(&content, "max")?.unwrap_or(f32::INFINITY);

        Ok(Self {
            mean,
            std,
            min,
            max,
            data: None,
            shape: vec![],
        })
    }

    /// Helper to extract a required f32 value from JSON content.
    ///
    /// Fails with `InvalidData` if the key is missing or its value does not
    /// parse as `f32`.
    fn extract_json_f32(content: &str, key: &str) -> io::Result<f32> {
        Self::extract_optional_json_f32(content, key)?.ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("key '{key}' not found in JSON"),
            )
        })
    }

    /// Helper to extract an optional f32 value from JSON content.
    ///
    /// Returns `Ok(None)` if the key is absent. Returns `InvalidData` if the
    /// key is present but its value does not parse as `f32`.
    fn extract_optional_json_f32(content: &str, key: &str) -> io::Result<Option<f32>> {
        let Some(raw) = Self::find_json_value(content, key) else {
            return Ok(None);
        };
        let value_end = raw.find([',', '}', '\n']).unwrap_or(raw.len());
        let value_str = raw[..value_end].trim();
        value_str.parse::<f32>().map(Some).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("could not parse '{value_str}' as f32"),
            )
        })
    }

    /// Return the text after `"key"` and its colon, with leading whitespace
    /// trimmed. Returns `None` if no occurrence of `"key"` is followed by a
    /// colon.
    fn find_json_value<'a>(content: &'a str, key: &str) -> Option<&'a str> {
        let quoted = format!("\"{key}\"");
        let mut search_from = 0;
        while let Some(offset) = content[search_from..].find(quoted.as_str()) {
            // `quoted` ends in an ASCII quote, so this is a char boundary.
            let after_key = search_from + offset + quoted.len();
            // Whitespace is allowed before the colon (`"mean" : 0.5`). Without a
            // colon, this occurrence is a string value such as `"label": "mean"`.
            if let Some(rest) = content[after_key..].trim_start().strip_prefix(':') {
                return Some(rest.trim_start());
            }
            search_from = after_key;
        }
        None
    }

    /// Get the mean value.
    #[must_use]
    pub fn mean(&self) -> f32 {
        self.mean
    }

    /// Get the standard deviation.
    #[must_use]
    pub fn std(&self) -> f32 {
        self.std
    }

    /// Get the minimum value.
    #[must_use]
    pub fn min(&self) -> f32 {
        self.min
    }

    /// Get the maximum value.
    #[must_use]
    pub fn max(&self) -> f32 {
        self.max
    }

    /// Get the raw data if available.
    #[must_use]
    pub fn data(&self) -> Option<&[f32]> {
        self.data.as_deref()
    }

    /// Get the shape.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Check if raw data is available for detailed comparison.
    #[must_use]
    pub fn has_data(&self) -> bool {
        self.data.is_some()
    }
}
182
183impl Default for GroundTruth {
184    fn default() -> Self {
185        Self::from_stats(0.0, 1.0)
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::PathBuf;

    /// Absolute tolerance shared by every floating-point check below.
    const TOL: f32 = 1e-6;

    /// True when `actual` is within [`TOL`] of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < TOL
    }

    /// Write `contents` to a file called `name` inside a fresh temp dir.
    ///
    /// The returned `TempDir` must stay alive while the path is in use;
    /// dropping it removes the directory.
    fn fixture(name: &str, contents: &[u8]) -> (tempfile::TempDir, PathBuf) {
        let dir = tempfile::tempdir().expect("tempdir creation should succeed");
        let path = dir.path().join(name);
        std::fs::File::create(&path)
            .expect("test file creation should succeed")
            .write_all(contents)
            .expect("test file write should succeed");
        (dir, path)
    }

    #[test]
    fn test_from_stats() {
        let gt = GroundTruth::from_stats(-0.215, 0.448);
        assert!(close(gt.mean(), -0.215));
        assert!(close(gt.std(), 0.448));
    }

    #[test]
    fn test_from_slice_mean() {
        let gt = GroundTruth::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(close(gt.mean(), 3.0));
    }

    #[test]
    fn test_from_slice_std() {
        // Textbook population-std example: mean 5, variance 4, std 2.
        let samples = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let gt = GroundTruth::from_slice(&samples);
        assert!(close(gt.std(), 2.0));
    }

    #[test]
    fn test_from_slice_min_max() {
        let gt = GroundTruth::from_slice(&[-5.0, 0.0, 10.0, 3.0]);
        assert!(close(gt.min(), -5.0));
        assert!(close(gt.max(), 10.0));
    }

    #[test]
    fn test_empty_slice() {
        let gt = GroundTruth::from_slice(&[]);
        assert_eq!(gt.mean(), 0.0);
        assert_eq!(gt.std(), 0.0);
    }

    #[test]
    fn test_from_slice_with_shape() {
        let gt = GroundTruth::from_slice_with_shape(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(gt.shape(), &[2, 3]);
        assert!(gt.has_data());
        let stored = gt
            .data()
            .expect("data should be present for slice-constructed GroundTruth");
        assert_eq!(stored.len(), 6);
    }

    #[test]
    fn test_has_data() {
        let with_values = GroundTruth::from_slice(&[1.0, 2.0, 3.0]);
        let stats_only = GroundTruth::from_stats(1.0, 0.5);
        assert!(with_values.has_data());
        assert!(!stats_only.has_data());
    }

    #[test]
    fn test_data_accessor() {
        let gt = GroundTruth::from_slice(&[1.0, 2.0, 3.0]);
        let stored = gt.data();
        assert!(stored.is_some());
        assert_eq!(
            stored.expect("data should be present for slice-constructed GroundTruth"),
            &[1.0, 2.0, 3.0]
        );
    }

    #[test]
    fn test_shape_accessor() {
        let gt = GroundTruth::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(gt.shape(), &[4]);
    }

    #[test]
    fn test_default() {
        let gt = GroundTruth::default();
        assert!(close(gt.mean(), 0.0));
        assert!(close(gt.std(), 1.0));
    }

    #[test]
    fn test_from_stats_min_max_defaults() {
        let gt = GroundTruth::from_stats(0.0, 1.0);
        let (lo, hi) = (gt.min(), gt.max());
        assert!(lo.is_infinite() && lo.is_sign_negative());
        assert!(hi.is_infinite() && hi.is_sign_positive());
    }

    #[test]
    fn test_from_bin_file() {
        let values = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        let (_dir, path) = fixture("test.bin", &bytes);

        let gt = GroundTruth::from_bin_file(&path)
            .expect("from_bin_file should parse valid binary data");
        assert!(close(gt.mean(), 3.0));
        assert!(gt.has_data());
    }

    #[test]
    fn test_from_bin_file_not_found() {
        assert!(GroundTruth::from_bin_file("/nonexistent/path.bin").is_err());
    }

    #[test]
    fn test_from_json_file() {
        let json = r#"{"mean": 0.5, "std": 1.2, "min": -0.1, "max": 2.0}"#;
        let (_dir, path) = fixture("test.json", json.as_bytes());

        let gt =
            GroundTruth::from_json_file(&path).expect("from_json_file should parse valid JSON");
        assert!(close(gt.mean(), 0.5));
        assert!(close(gt.std(), 1.2));
        assert!(close(gt.min(), -0.1));
        assert!(close(gt.max(), 2.0));
    }

    #[test]
    fn test_from_json_file_partial() {
        let json = r#"{"mean": 0.5, "std": 1.2}"#;
        let (_dir, path) = fixture("partial.json", json.as_bytes());

        let gt = GroundTruth::from_json_file(&path)
            .expect("from_json_file should parse partial JSON with defaults");
        assert!(close(gt.mean(), 0.5));
        // Absent min/max fall back to unbounded defaults.
        assert!(gt.min().is_infinite());
        assert!(gt.max().is_infinite());
    }

    #[test]
    fn test_from_json_file_not_found() {
        assert!(GroundTruth::from_json_file("/nonexistent/path.json").is_err());
    }

    #[test]
    fn test_from_json_file_missing_key() {
        // "mean" is required but absent.
        let json = r#"{"std": 1.2}"#;
        let (_dir, path) = fixture("missing.json", json.as_bytes());
        assert!(GroundTruth::from_json_file(&path).is_err());
    }

    #[test]
    fn test_from_json_file_invalid_value() {
        let json = r#"{"mean": "not_a_number", "std": 1.2}"#;
        let (_dir, path) = fixture("invalid.json", json.as_bytes());
        assert!(GroundTruth::from_json_file(&path).is_err());
    }
}
382}