1use std::io;
7use std::path::Path;
8
/// Ground-truth reference: summary statistics plus, optionally, the raw values
/// they were computed from.
///
/// Built either from statistics alone ([`GroundTruth::from_stats`],
/// [`GroundTruth::from_json_file`]) or from actual values
/// ([`GroundTruth::from_slice`], [`GroundTruth::from_slice_with_shape`],
/// [`GroundTruth::from_bin_file`]). Only the data-based constructors (for a
/// non-empty input) keep a copy of the values; see [`GroundTruth::has_data`].
#[derive(Debug, Clone, PartialEq)]
pub struct GroundTruth {
    /// Arithmetic mean of the values.
    mean: f32,
    /// Standard deviation (population form, i.e. divided by `n`, when computed
    /// by [`GroundTruth::from_slice`]).
    std: f32,
    /// Smallest value, or `f32::NEG_INFINITY` when no lower bound is known.
    min: f32,
    /// Largest value, or `f32::INFINITY` when no upper bound is known.
    max: f32,
    /// Raw values; `None` when only statistics are available.
    data: Option<Vec<f32>>,
    /// Logical shape of the data; empty when only statistics are available.
    shape: Vec<usize>,
}

impl GroundTruth {
    /// Creates a ground truth from summary statistics only.
    ///
    /// `min`/`max` are left unbounded (`-inf`/`+inf`), no raw data is stored,
    /// and the shape is empty.
    #[must_use]
    pub fn from_stats(mean: f32, std: f32) -> Self {
        Self {
            mean,
            std,
            min: f32::NEG_INFINITY,
            max: f32::INFINITY,
            data: None,
            shape: vec![],
        }
    }

    /// Computes statistics from `data` and stores a copy of the values.
    ///
    /// The standard deviation is the population form (divides by `n`).
    /// Sums are accumulated in `f64` and the results rounded back to `f32`, so
    /// long slices or slices mixing large and small magnitudes keep their
    /// precision. The shape is set to `[data.len()]`.
    ///
    /// An empty slice yields all-zero statistics, no stored data and shape `[0]`.
    #[must_use]
    pub fn from_slice(data: &[f32]) -> Self {
        if data.is_empty() {
            return Self {
                mean: 0.0,
                std: 0.0,
                min: 0.0,
                max: 0.0,
                data: None,
                shape: vec![0],
            };
        }

        // Accumulate in f64: a running f32 sum absorbs small addends once it is
        // large (1e8 + 1.0 == 1e8 in f32), and `len as f32` is inexact past 2^24.
        let n = data.len() as f64;
        let mean = data.iter().copied().map(f64::from).sum::<f64>() / n;
        let variance = data
            .iter()
            .map(|&x| (f64::from(x) - mean).powi(2))
            .sum::<f64>()
            / n;
        // `f32::min`/`f32::max` return the non-NaN operand, so NaNs are skipped here.
        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);

        Self {
            mean: mean as f32,
            std: variance.sqrt() as f32,
            min,
            max,
            data: Some(data.to_vec()),
            shape: vec![data.len()],
        }
    }

    /// Like [`GroundTruth::from_slice`], but records `shape` instead of
    /// `[data.len()]`.
    ///
    /// The shape is stored as given; it is not checked against `data.len()`.
    #[must_use]
    pub fn from_slice_with_shape(data: &[f32], shape: Vec<usize>) -> Self {
        Self {
            shape,
            ..Self::from_slice(data)
        }
    }

    /// Loads raw little-endian `f32` values from a headerless binary file and
    /// computes statistics exactly as [`GroundTruth::from_slice`] does.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file cannot be read.
    /// Returns [`io::ErrorKind::InvalidData`] if the file length is not a
    /// multiple of 4 bytes. Such a file is truncated or not `f32` data, and its
    /// trailing bytes would otherwise be silently dropped.
    pub fn from_bin_file<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let bytes = std::fs::read(path)?;
        let chunks = bytes.chunks_exact(4);
        if !chunks.remainder().is_empty() {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                format!(
                    "binary file length {} is not a multiple of 4 bytes (f32)",
                    bytes.len()
                ),
            ));
        }
        let data: Vec<f32> = chunks
            .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
            .collect();
        Ok(Self::from_slice(&data))
    }

    /// Loads summary statistics from a small flat JSON object such as
    /// `{"mean": 0.5, "std": 1.2, "min": -0.1, "max": 2.0}`.
    ///
    /// `mean` and `std` are required. `min` and `max` are optional; when absent
    /// or `null` they default to `-inf`/`+inf`. No raw data is stored and the
    /// shape is empty.
    ///
    /// This is a minimal key scanner, not a JSON parser. It does not track
    /// nesting or string boundaries, so it is only reliable for flat objects of
    /// numeric fields.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file cannot be read.
    /// Returns [`io::ErrorKind::InvalidData`] if `mean` or `std` is missing, or
    /// if any present value cannot be parsed as `f32`.
    pub fn from_json_file<P: AsRef<Path>>(path: P) -> io::Result<Self> {
        let content = std::fs::read_to_string(path)?;
        let mean = Self::extract_json_f32(&content, "mean")?;
        let std = Self::extract_json_f32(&content, "std")?;
        // A missing/null bound means "unbounded"; a malformed one is reported
        // instead of being silently replaced with the default.
        let min = Self::extract_optional_json_f32(&content, "min")?.unwrap_or(f32::NEG_INFINITY);
        let max = Self::extract_optional_json_f32(&content, "max")?.unwrap_or(f32::INFINITY);

        Ok(Self {
            mean,
            std,
            min,
            max,
            data: None,
            shape: vec![],
        })
    }

    /// Returns the trimmed raw text of the value for `key`, or `None` if absent.
    ///
    /// Uses the first `"key"` that is followed by optional whitespace and a
    /// colon. Occurrences of the key text used as a value are skipped. The value
    /// runs up to the next `,`, `}` or newline.
    fn find_json_value<'a>(content: &'a str, key: &str) -> Option<&'a str> {
        let quoted_key = format!("\"{key}\"");
        content
            .match_indices(quoted_key.as_str())
            .find_map(|(start, matched)| {
                let value = content[start + matched.len()..]
                    .trim_start()
                    .strip_prefix(':')?
                    .trim_start();
                let end = value.find([',', '}', '\n']).unwrap_or(value.len());
                Some(value[..end].trim())
            })
    }

    /// Extracts a required `f32` for `key`; a missing key is `InvalidData`.
    fn extract_json_f32(content: &str, key: &str) -> io::Result<f32> {
        let value_str = Self::find_json_value(content, key).ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("key '{key}' not found in JSON"),
            )
        })?;
        Self::parse_json_f32(value_str)
    }

    /// Extracts an optional `f32` for `key`: `Ok(None)` when absent or `null`.
    fn extract_optional_json_f32(content: &str, key: &str) -> io::Result<Option<f32>> {
        match Self::find_json_value(content, key) {
            None | Some("null") => Ok(None),
            Some(value_str) => Self::parse_json_f32(value_str).map(Some),
        }
    }

    /// Parses value text with `str::parse::<f32>`, mapping failure to `InvalidData`.
    fn parse_json_f32(value_str: &str) -> io::Result<f32> {
        value_str.parse::<f32>().map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("could not parse '{value_str}' as f32"),
            )
        })
    }

    /// Mean of the values.
    #[must_use]
    pub fn mean(&self) -> f32 {
        self.mean
    }

    /// Standard deviation of the values.
    #[must_use]
    pub fn std(&self) -> f32 {
        self.std
    }

    /// Minimum value; `-inf` when unknown.
    #[must_use]
    pub fn min(&self) -> f32 {
        self.min
    }

    /// Maximum value; `+inf` when unknown.
    #[must_use]
    pub fn max(&self) -> f32 {
        self.max
    }

    /// Raw values, if this ground truth was built from data.
    #[must_use]
    pub fn data(&self) -> Option<&[f32]> {
        self.data.as_deref()
    }

    /// Logical shape of the data; empty for statistics-only ground truths.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Whether raw values are stored (see [`GroundTruth::data`]).
    #[must_use]
    pub fn has_data(&self) -> bool {
        self.data.is_some()
    }
}

impl Default for GroundTruth {
    /// Statistics-only ground truth with mean 0, std 1 and unbounded min/max.
    fn default() -> Self {
        Self::from_stats(0.0, 1.0)
    }
}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;

    /// Tolerance shared by every approximate float comparison below.
    const EPS: f32 = 1e-6;

    /// True when `actual` lies strictly within `EPS` of `expected`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < EPS
    }

    /// Creates a fresh temp dir holding a file `name` with `contents`.
    ///
    /// The `TempDir` guard is returned with the path because dropping it deletes
    /// the directory; callers bind it (e.g. `_dir`) to keep the file alive.
    fn temp_file_with(name: &str, contents: &[u8]) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::tempdir().expect("tempdir creation should succeed");
        let path = dir.path().join(name);
        std::fs::File::create(&path)
            .expect("test file creation should succeed")
            .write_all(contents)
            .expect("test file write should succeed");
        (dir, path)
    }

    #[test]
    fn test_from_stats() {
        let gt = GroundTruth::from_stats(-0.215, 0.448);
        assert!(close(gt.mean(), -0.215));
        assert!(close(gt.std(), 0.448));
    }

    #[test]
    fn test_from_slice_mean() {
        let gt = GroundTruth::from_slice(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert!(close(gt.mean(), 3.0));
    }

    #[test]
    fn test_from_slice_std() {
        // Mean 5, population variance 32 / 8 = 4, so std is exactly 2.
        let gt = GroundTruth::from_slice(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]);
        assert!(close(gt.std(), 2.0));
    }

    #[test]
    fn test_from_slice_min_max() {
        let gt = GroundTruth::from_slice(&[-5.0, 0.0, 10.0, 3.0]);
        assert!(close(gt.min(), -5.0));
        assert!(close(gt.max(), 10.0));
    }

    #[test]
    fn test_empty_slice() {
        let gt = GroundTruth::from_slice(&[]);
        assert_eq!(gt.mean(), 0.0);
        assert_eq!(gt.std(), 0.0);
    }

    #[test]
    fn test_from_slice_with_shape() {
        let gt = GroundTruth::from_slice_with_shape(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        assert_eq!(gt.shape(), &[2, 3]);
        assert!(gt.has_data());
        let stored = gt
            .data()
            .expect("data should be present for slice-constructed GroundTruth");
        assert_eq!(stored.len(), 6);
    }

    #[test]
    fn test_has_data() {
        let from_values = GroundTruth::from_slice(&[1.0, 2.0, 3.0]);
        let from_summary = GroundTruth::from_stats(1.0, 0.5);
        assert!(from_values.has_data());
        assert!(!from_summary.has_data());
    }

    #[test]
    fn test_data_accessor() {
        let gt = GroundTruth::from_slice(&[1.0, 2.0, 3.0]);
        assert!(gt.data().is_some());
        let stored = gt
            .data()
            .expect("data should be present for slice-constructed GroundTruth");
        assert_eq!(stored, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_shape_accessor() {
        let gt = GroundTruth::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(gt.shape(), &[4]);
    }

    #[test]
    fn test_default() {
        let gt = GroundTruth::default();
        assert!(close(gt.mean(), 0.0));
        assert!(close(gt.std(), 1.0));
    }

    #[test]
    fn test_from_stats_min_max_defaults() {
        let gt = GroundTruth::from_stats(0.0, 1.0);
        let (lo, hi) = (gt.min(), gt.max());
        assert!(lo.is_infinite() && lo.is_sign_negative());
        assert!(hi.is_infinite() && hi.is_sign_positive());
    }

    #[test]
    fn test_from_bin_file() {
        let bytes: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0, 5.0]
            .iter()
            .flat_map(|f| f.to_le_bytes())
            .collect();
        let (_dir, path) = temp_file_with("test.bin", &bytes);

        let gt = GroundTruth::from_bin_file(&path)
            .expect("from_bin_file should parse valid binary data");
        assert!(close(gt.mean(), 3.0));
        assert!(gt.has_data());
    }

    #[test]
    fn test_from_bin_file_not_found() {
        assert!(GroundTruth::from_bin_file("/nonexistent/path.bin").is_err());
    }

    #[test]
    fn test_from_json_file() {
        let json = r#"{"mean": 0.5, "std": 1.2, "min": -0.1, "max": 2.0}"#;
        let (_dir, path) = temp_file_with("test.json", json.as_bytes());

        let gt =
            GroundTruth::from_json_file(&path).expect("from_json_file should parse valid JSON");
        assert!(close(gt.mean(), 0.5));
        assert!(close(gt.std(), 1.2));
        assert!(close(gt.min(), -0.1));
        assert!(close(gt.max(), 2.0));
    }

    #[test]
    fn test_from_json_file_partial() {
        let json = r#"{"mean": 0.5, "std": 1.2}"#;
        let (_dir, path) = temp_file_with("partial.json", json.as_bytes());

        let gt = GroundTruth::from_json_file(&path)
            .expect("from_json_file should parse partial JSON with defaults");
        assert!(close(gt.mean(), 0.5));
        assert!(gt.min().is_infinite());
        assert!(gt.max().is_infinite());
    }

    #[test]
    fn test_from_json_file_not_found() {
        assert!(GroundTruth::from_json_file("/nonexistent/path.json").is_err());
    }

    #[test]
    fn test_from_json_file_missing_key() {
        let json = r#"{"std": 1.2}"#;
        let (_dir, path) = temp_file_with("missing.json", json.as_bytes());

        assert!(GroundTruth::from_json_file(&path).is_err());
    }

    #[test]
    fn test_from_json_file_invalid_value() {
        let json = r#"{"mean": "not_a_number", "std": 1.2}"#;
        let (_dir, path) = temp_file_with("invalid.json", json.as_bytes());

        assert!(GroundTruth::from_json_file(&path).is_err());
    }
}