Skip to main content

oar_ocr_core/core/
validation.rs

1//! Input Validation Utilities
2//!
3//! This module provides comprehensive validation utilities to prevent runtime panics
4//! and ensure data integrity across the OCR pipeline.
5
6use crate::core::OCRError;
7
8/// Validates that a float value is finite (not NaN or infinite).
9#[inline]
10pub fn validate_finite(value: f32, param_name: &str) -> Result<(), OCRError> {
11    if !value.is_finite() {
12        return Err(OCRError::InvalidInput {
13            message: format!("Parameter '{}' must be finite, got: {}", param_name, value),
14        });
15    }
16    Ok(())
17}
18
19/// Validates that a value is within a specified range (inclusive).
20#[inline]
21pub fn validate_range<T: PartialOrd + std::fmt::Display>(
22    value: T,
23    min: T,
24    max: T,
25    param_name: &str,
26) -> Result<(), OCRError> {
27    if value < min || value > max {
28        return Err(OCRError::InvalidInput {
29            message: format!(
30                "Parameter '{}' must be in range [{}, {}], got: {}",
31                param_name, min, max, value
32            ),
33        });
34    }
35    Ok(())
36}
37
38/// Validates that a value is positive (> 0).
39#[inline]
40pub fn validate_positive<T: PartialOrd + std::fmt::Display + Default>(
41    value: T,
42    param_name: &str,
43) -> Result<(), OCRError> {
44    if value <= T::default() {
45        return Err(OCRError::InvalidInput {
46            message: format!(
47                "Parameter '{}' must be positive, got: {}",
48                param_name, value
49            ),
50        });
51    }
52    Ok(())
53}
54
55/// Validates that a value is non-negative (>= 0).
56#[inline]
57pub fn validate_non_negative<T: PartialOrd + std::fmt::Display + Default>(
58    value: T,
59    param_name: &str,
60) -> Result<(), OCRError> {
61    if value < T::default() {
62        return Err(OCRError::InvalidInput {
63            message: format!(
64                "Parameter '{}' must be non-negative, got: {}",
65                param_name, value
66            ),
67        });
68    }
69    Ok(())
70}
71
72/// Validates that a collection is not empty.
73#[inline]
74pub fn validate_non_empty<T>(items: &[T], param_name: &str) -> Result<(), OCRError> {
75    if items.is_empty() {
76        return Err(OCRError::InvalidInput {
77            message: format!("Parameter '{}' cannot be empty", param_name),
78        });
79    }
80    Ok(())
81}
82
83/// Validates that two collections have the same length.
84#[inline]
85pub fn validate_same_length<T, U>(
86    items1: &[T],
87    items2: &[U],
88    name1: &str,
89    name2: &str,
90) -> Result<(), OCRError> {
91    if items1.len() != items2.len() {
92        return Err(OCRError::InvalidInput {
93            message: format!(
94                "Length mismatch: {} has {} elements, but {} has {} elements",
95                name1,
96                items1.len(),
97                name2,
98                items2.len()
99            ),
100        });
101    }
102    Ok(())
103}
104
105/// Validates tensor shape dimensions.
106pub fn validate_tensor_shape(
107    shape: &[usize],
108    expected_dims: usize,
109    tensor_name: &str,
110) -> Result<(), OCRError> {
111    if shape.len() != expected_dims {
112        return Err(OCRError::InvalidInput {
113            message: format!(
114                "Tensor '{}' expected {}D shape, got {}D: {:?}",
115                tensor_name,
116                expected_dims,
117                shape.len(),
118                shape
119            ),
120        });
121    }
122    Ok(())
123}
124
125/// Validates that tensor batch size is positive.
126pub fn validate_batch_size(shape: &[usize], tensor_name: &str) -> Result<usize, OCRError> {
127    validate_tensor_shape(shape, 4, tensor_name)?;
128
129    let batch_size = shape[0];
130    if batch_size == 0 {
131        return Err(OCRError::InvalidInput {
132            message: format!(
133                "Tensor '{}' has zero batch size. Shape: {:?}",
134                tensor_name, shape
135            ),
136        });
137    }
138
139    Ok(batch_size)
140}
141
142/// Validates image dimensions.
143pub fn validate_image_dimensions(height: u32, width: u32, context: &str) -> Result<(), OCRError> {
144    if height == 0 || width == 0 {
145        return Err(OCRError::InvalidInput {
146            message: format!(
147                "{}: image dimensions must be positive, got {}x{}",
148                context, height, width
149            ),
150        });
151    }
152
153    // Reasonable upper bounds to prevent memory issues
154    const MAX_DIMENSION: u32 = 32768;
155    if height > MAX_DIMENSION || width > MAX_DIMENSION {
156        return Err(OCRError::InvalidInput {
157            message: format!(
158                "{}: image dimensions exceed maximum of {}x{}, got {}x{}",
159                context, MAX_DIMENSION, MAX_DIMENSION, height, width
160            ),
161        });
162    }
163
164    Ok(())
165}
166
167/// Validates that array index is within bounds.
168#[inline]
169pub fn validate_index_bounds<T>(
170    slice: &[T],
171    index: usize,
172    slice_name: &str,
173) -> Result<(), OCRError> {
174    if index >= slice.len() {
175        return Err(OCRError::InvalidInput {
176            message: format!(
177                "Index out of bounds for '{}': index {} >= length {}",
178                slice_name,
179                index,
180                slice.len()
181            ),
182        });
183    }
184    Ok(())
185}
186
187/// Validates division operands to prevent division by zero.
188#[inline]
189pub fn validate_division(numerator: f32, denominator: f32, context: &str) -> Result<(), OCRError> {
190    validate_finite(numerator, &format!("{} numerator", context))?;
191    validate_finite(denominator, &format!("{} denominator", context))?;
192
193    if denominator.abs() < f32::EPSILON {
194        return Err(OCRError::InvalidInput {
195            message: format!(
196                "{}: division by zero (denominator: {})",
197                context, denominator
198            ),
199        });
200    }
201
202    Ok(())
203}
204
205/// Validates normalization parameters (mean and std).
206pub fn validate_normalization_params(
207    mean: &[f32],
208    std: &[f32],
209    num_channels: usize,
210) -> Result<(), OCRError> {
211    // Check lengths match expected channels
212    if mean.len() != num_channels {
213        return Err(OCRError::InvalidInput {
214            message: format!(
215                "Mean length {} does not match number of channels {}",
216                mean.len(),
217                num_channels
218            ),
219        });
220    }
221
222    if std.len() != num_channels {
223        return Err(OCRError::InvalidInput {
224            message: format!(
225                "Std length {} does not match number of channels {}",
226                std.len(),
227                num_channels
228            ),
229        });
230    }
231
232    // Validate all values are finite
233    for (i, &m) in mean.iter().enumerate() {
234        validate_finite(m, &format!("mean[{}]", i))?;
235    }
236
237    // Validate std values are positive and finite
238    for (i, &s) in std.iter().enumerate() {
239        validate_finite(s, &format!("std[{}]", i))?;
240        validate_positive(s, &format!("std[{}]", i))?;
241    }
242
243    Ok(())
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Typical three-channel normalization parameters shared by the tests below.
    const MEAN: [f32; 3] = [0.485, 0.456, 0.406];
    const STD: [f32; 3] = [0.229, 0.224, 0.225];

    /// Finite values pass; NaN and both infinities are rejected.
    #[test]
    fn test_validate_finite() {
        for finite in [1.0, 0.0, -1.0] {
            assert!(validate_finite(finite, "test").is_ok());
        }
        for non_finite in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(validate_finite(non_finite, "test").is_err());
        }
    }

    /// Both bounds are inclusive; values outside either bound are rejected.
    #[test]
    fn test_validate_range() {
        for inside in [5.0, 0.0, 10.0] {
            assert!(validate_range(inside, 0.0, 10.0, "test").is_ok());
        }
        for outside in [-1.0, 11.0] {
            assert!(validate_range(outside, 0.0, 10.0, "test").is_err());
        }
    }

    /// Zero is not positive.
    #[test]
    fn test_validate_positive() {
        for positive in [1.0, 0.1] {
            assert!(validate_positive(positive, "test").is_ok());
        }
        for non_positive in [0.0, -1.0] {
            assert!(validate_positive(non_positive, "test").is_err());
        }
    }

    /// Zero is non-negative.
    #[test]
    fn test_validate_non_negative() {
        for non_negative in [1.0, 0.0] {
            assert!(validate_non_negative(non_negative, "test").is_ok());
        }
        assert!(validate_non_negative(-1.0, "test").is_err());
    }

    #[test]
    fn test_validate_non_empty() {
        let several = [1, 2, 3];
        let single = [1];
        let none: [i32; 0] = [];

        assert!(validate_non_empty(&several, "test").is_ok());
        assert!(validate_non_empty(&single, "test").is_ok());
        assert!(validate_non_empty::<i32>(&none, "test").is_err());
    }

    #[test]
    fn test_validate_same_length() {
        let matched_pair = validate_same_length(&[1, 2], &[3, 4], "a", "b");
        let matched_single = validate_same_length(&[1], &[2], "a", "b");
        let mismatched = validate_same_length(&[1, 2], &[3], "a", "b");

        assert!(matched_pair.is_ok());
        assert!(matched_single.is_ok());
        assert!(mismatched.is_err());
    }

    #[test]
    fn test_validate_tensor_shape() {
        let four_d = [1, 3, 224, 224];
        let three_d = [1, 3, 224];

        assert!(validate_tensor_shape(&four_d, 4, "test").is_ok());
        assert!(validate_tensor_shape(&three_d, 3, "test").is_ok());
        assert!(validate_tensor_shape(&three_d, 4, "test").is_err());
    }

    /// A valid 4D shape yields its leading dimension; zero batch or wrong rank fails.
    #[test]
    fn test_validate_batch_size() {
        assert_eq!(
            validate_batch_size(&[2, 3, 224, 224], "test").ok(),
            Some(2),
            "expected validate_batch_size to succeed"
        );
        assert_eq!(
            validate_batch_size(&[1, 3, 224, 224], "test").ok(),
            Some(1),
            "expected validate_batch_size to succeed"
        );
        assert!(validate_batch_size(&[0, 3, 224, 224], "test").is_err());
        assert!(validate_batch_size(&[1, 3, 224], "test").is_err());
    }

    #[test]
    fn test_validate_image_dimensions() {
        for (height, width) in [(224, 224), (1, 1)] {
            assert!(validate_image_dimensions(height, width, "test").is_ok());
        }
        for (height, width) in [(0, 224), (224, 0), (99999, 99999)] {
            assert!(validate_image_dimensions(height, width, "test").is_err());
        }
    }

    #[test]
    fn test_validate_division() {
        for (numerator, denominator) in [(10.0, 2.0), (0.0, 2.0)] {
            assert!(validate_division(numerator, denominator, "test").is_ok());
        }
        for (numerator, denominator) in [
            (10.0, 0.0),
            (10.0, 1e-10),
            (f32::NAN, 2.0),
            (10.0, f32::INFINITY),
        ] {
            assert!(validate_division(numerator, denominator, "test").is_err());
        }
    }

    #[test]
    fn test_validate_normalization_params() {
        assert!(validate_normalization_params(&MEAN, &STD, 3).is_ok());

        // Wrong length
        assert!(validate_normalization_params(&MEAN[..2], &STD, 3).is_err());

        // NaN in mean
        let mean_with_nan = [f32::NAN, 0.456, 0.406];
        assert!(validate_normalization_params(&mean_with_nan, &STD, 3).is_err());

        // Zero in std
        let std_with_zero = [0.0, 0.224, 0.225];
        assert!(validate_normalization_params(&MEAN, &std_with_zero, 3).is_err());

        // Negative in std
        let std_with_negative = [-0.229, 0.224, 0.225];
        assert!(validate_normalization_params(&MEAN, &std_with_negative, 3).is_err());
    }
}
365}