oar_ocr_core/utils/
validation.rs

1//! Reusable validation components for OCR tasks.
2//!
3//! This module provides composable validators that can be used across different tasks
4//! to validate common patterns like score ranges, dimensions, and other constraints.
5
6use crate::core::OCRError;
7
8/// A reusable validator for score ranges.
9///
10/// This validator can be configured with custom min/max bounds and field names,
11/// making it suitable for validating confidence scores, probabilities, and other
12/// numerical ranges across different tasks.
13///
14/// # Examples
15///
16/// ```rust,no_run
17/// # use oar_ocr_core::utils::validation::ScoreValidator;
18/// # use oar_ocr_core::core::OCRError;
19/// # fn example() -> Result<(), OCRError> {
20/// let validator = ScoreValidator::new_unit_range("confidence");
21/// validator.validate_scores(&[0.5, 0.8, 0.95], "Detection")?;
22/// #     Ok(())
23/// # }
24/// # example().unwrap();
25/// ```
26#[derive(Debug, Clone)]
27pub struct ScoreValidator {
28    min: f32,
29    max: f32,
30    field_name: String,
31}
32
33impl ScoreValidator {
34    /// Creates a new score validator with custom bounds.
35    ///
36    /// # Arguments
37    ///
38    /// * `min` - Minimum valid score (inclusive)
39    /// * `max` - Maximum valid score (inclusive)
40    /// * `field_name` - Name of the field being validated (for error messages)
41    pub fn new(min: f32, max: f32, field_name: impl Into<String>) -> Self {
42        Self {
43            min,
44            max,
45            field_name: field_name.into(),
46        }
47    }
48
49    /// Creates a validator for unit range scores [0.0, 1.0].
50    ///
51    /// This is the most common case for confidence scores and probabilities.
52    pub fn new_unit_range(field_name: impl Into<String>) -> Self {
53        Self::new(0.0, 1.0, field_name)
54    }
55
56    /// Validates a single score value.
57    ///
58    /// # Errors
59    ///
60    /// Returns `OCRError::InvalidInput` if the score is outside the valid range.
61    pub fn validate_score(&self, score: f32, context: &str) -> Result<(), OCRError> {
62        if !(self.min..=self.max).contains(&score) {
63            return Err(OCRError::InvalidInput {
64                message: format!(
65                    "{}: {} {} is out of valid range [{}, {}]",
66                    context, self.field_name, score, self.min, self.max
67                ),
68            });
69        }
70        Ok(())
71    }
72
73    /// Validates a collection of scores.
74    ///
75    /// # Errors
76    ///
77    /// Returns `OCRError::InvalidInput` if any score is outside the valid range.
78    /// The error message includes the index of the invalid score.
79    pub fn validate_scores(&self, scores: &[f32], context_prefix: &str) -> Result<(), OCRError> {
80        for (idx, &score) in scores.iter().enumerate() {
81            self.validate_score(score, &format!("{} {}", context_prefix, idx))?;
82        }
83        Ok(())
84    }
85
86    /// Validates a collection of scores with a custom index formatter.
87    ///
88    /// This is useful when you want to provide more context in error messages,
89    /// such as "Image 3, detection 2" instead of just an index.
90    pub fn validate_scores_with<F>(&self, scores: &[f32], format_context: F) -> Result<(), OCRError>
91    where
92        F: Fn(usize) -> String,
93    {
94        for (idx, &score) in scores.iter().enumerate() {
95            self.validate_score(score, &format_context(idx))?;
96        }
97        Ok(())
98    }
99}
100
101/// Validates that a vector's length matches an expected size.
102///
103/// # Errors
104///
105/// Returns `OCRError::InvalidInput` if lengths don't match.
106pub fn validate_length_match(
107    actual: usize,
108    expected: usize,
109    actual_name: &str,
110    expected_name: &str,
111) -> Result<(), OCRError> {
112    if actual != expected {
113        return Err(OCRError::InvalidInput {
114            message: format!(
115                "Mismatch between {} count ({}) and {} count ({})",
116                actual_name, actual, expected_name, expected
117            ),
118        });
119    }
120    Ok(())
121}
122
123/// Validates that a value doesn't exceed a maximum.
124///
125/// # Errors
126///
127/// Returns `OCRError::InvalidInput` if value exceeds maximum.
128pub fn validate_max_value<T: PartialOrd + std::fmt::Display>(
129    value: T,
130    max: T,
131    field_name: &str,
132    context: &str,
133) -> Result<(), OCRError> {
134    if value > max {
135        return Err(OCRError::InvalidInput {
136            message: format!(
137                "{}: {} {} exceeds maximum {}",
138                context, field_name, value, max
139            ),
140        });
141    }
142    Ok(())
143}
144
145/// Validates that dimensions are positive (non-zero).
146///
147/// # Errors
148///
149/// Returns `OCRError::InvalidInput` if either dimension is zero.
150pub fn validate_positive_dimensions(
151    width: u32,
152    height: u32,
153    context: &str,
154) -> Result<(), OCRError> {
155    if width == 0 || height == 0 {
156        return Err(OCRError::InvalidInput {
157            message: format!(
158                "{}: invalid dimensions width={}, height={} (must be positive)",
159                context, width, height
160            ),
161        });
162    }
163    Ok(())
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `validator` accepts every score in `accepted` and rejects
    /// every score in `rejected`.
    fn assert_score_outcomes(validator: &ScoreValidator, accepted: &[f32], rejected: &[f32]) {
        for &score in accepted {
            assert!(validator.validate_score(score, "test").is_ok());
        }
        for &score in rejected {
            assert!(validator.validate_score(score, "test").is_err());
        }
    }

    #[test]
    fn test_score_validator_unit_range() {
        // Both endpoints of [0, 1] are inclusive; values just outside fail.
        let unit = ScoreValidator::new_unit_range("score");
        assert_score_outcomes(&unit, &[0.0, 0.5, 1.0], &[-0.1, 1.1]);
    }

    #[test]
    fn test_score_validator_custom_range() {
        // Same endpoint checks against a non-unit interval.
        let custom = ScoreValidator::new(0.5, 2.0, "custom");
        assert_score_outcomes(&custom, &[0.5, 1.0, 2.0], &[0.4, 2.1]);
    }

    #[test]
    fn test_validate_scores() {
        let unit = ScoreValidator::new_unit_range("score");

        // Every element in range.
        let all_valid = unit.validate_scores(&[0.1, 0.5, 0.9], "test");
        assert!(all_valid.is_ok());

        // A single out-of-range element fails the whole slice.
        let one_invalid = unit.validate_scores(&[0.1, 1.5, 0.9], "test");
        assert!(one_invalid.is_err());
    }

    #[test]
    fn test_validate_scores_with_formatter() {
        let unit = ScoreValidator::new_unit_range("score");
        let describe = |idx: usize| format!("Image 0, detection {}", idx);

        let rendered = match unit.validate_scores_with(&[0.5, 1.5], describe) {
            Err(err) => format!("{:?}", err),
            Ok(()) => panic!("expected validation to fail"),
        };

        // The custom context for the failing index must reach the message.
        assert!(rendered.contains("detection 1"));
    }

    #[test]
    fn test_validate_length_match() {
        let equal = validate_length_match(3, 3, "texts", "scores");
        let unequal = validate_length_match(3, 5, "texts", "scores");
        assert!(equal.is_ok());
        assert!(unequal.is_err());
    }

    #[test]
    fn test_validate_max_value() {
        let within = validate_max_value(50, 100, "length", "text 0");
        let above = validate_max_value(150, 100, "length", "text 0");
        assert!(within.is_ok());
        assert!(above.is_err());
    }

    #[test]
    fn test_validate_positive_dimensions() {
        assert!(validate_positive_dimensions(100, 200, "image 0").is_ok());

        // A zero on either axis is rejected.
        for &(width, height) in &[(0u32, 200u32), (100, 0)] {
            assert!(validate_positive_dimensions(width, height, "image 0").is_err());
        }
    }
}
239}