Skip to main content

oar_ocr_core/utils/
validation.rs

1//! Reusable validation components for OCR tasks.
2//!
3//! This module provides composable validators that can be used across different tasks
4//! to validate common patterns like score ranges, dimensions, and other constraints.
5
6use crate::core::OCRError;
7
8/// A reusable validator for score ranges.
9///
10/// This validator can be configured with custom min/max bounds and field names,
11/// making it suitable for validating confidence scores, probabilities, and other
12/// numerical ranges across different tasks.
13///
14/// # Examples
15///
16/// ```rust,no_run
17/// use oar_ocr_core::utils::validation::ScoreValidator;
18/// use oar_ocr_core::core::OCRError;
19/// # fn main() -> Result<(), OCRError> {
20/// let validator = ScoreValidator::new_unit_range("confidence");
21/// validator.validate_scores(&[0.5, 0.8, 0.95], "Detection")?;
22/// # Ok(())
23/// # }
24/// ```
25#[derive(Debug, Clone)]
26pub struct ScoreValidator {
27    min: f32,
28    max: f32,
29    field_name: String,
30}
31
32impl ScoreValidator {
33    /// Creates a new score validator with custom bounds.
34    ///
35    /// # Arguments
36    ///
37    /// * `min` - Minimum valid score (inclusive)
38    /// * `max` - Maximum valid score (inclusive)
39    /// * `field_name` - Name of the field being validated (for error messages)
40    pub fn new(min: f32, max: f32, field_name: impl Into<String>) -> Self {
41        Self {
42            min,
43            max,
44            field_name: field_name.into(),
45        }
46    }
47
48    /// Creates a validator for unit range scores [0.0, 1.0].
49    ///
50    /// This is the most common case for confidence scores and probabilities.
51    pub fn new_unit_range(field_name: impl Into<String>) -> Self {
52        Self::new(0.0, 1.0, field_name)
53    }
54
55    /// Validates a single score value.
56    ///
57    /// # Errors
58    ///
59    /// Returns `OCRError::InvalidInput` if the score is outside the valid range.
60    pub fn validate_score(&self, score: f32, context: &str) -> Result<(), OCRError> {
61        if !(self.min..=self.max).contains(&score) {
62            return Err(OCRError::InvalidInput {
63                message: format!(
64                    "{}: {} {} is out of valid range [{}, {}]",
65                    context, self.field_name, score, self.min, self.max
66                ),
67            });
68        }
69        Ok(())
70    }
71
72    /// Validates a collection of scores.
73    ///
74    /// # Errors
75    ///
76    /// Returns `OCRError::InvalidInput` if any score is outside the valid range.
77    /// The error message includes the index of the invalid score.
78    pub fn validate_scores(&self, scores: &[f32], context_prefix: &str) -> Result<(), OCRError> {
79        for (idx, &score) in scores.iter().enumerate() {
80            self.validate_score(score, &format!("{} {}", context_prefix, idx))?;
81        }
82        Ok(())
83    }
84
85    /// Validates a collection of scores with a custom index formatter.
86    ///
87    /// This is useful when you want to provide more context in error messages,
88    /// such as "Image 3, detection 2" instead of just an index.
89    pub fn validate_scores_with<F>(&self, scores: &[f32], format_context: F) -> Result<(), OCRError>
90    where
91        F: Fn(usize) -> String,
92    {
93        for (idx, &score) in scores.iter().enumerate() {
94            self.validate_score(score, &format_context(idx))?;
95        }
96        Ok(())
97    }
98}
99
100/// Validates that a vector's length matches an expected size.
101///
102/// # Errors
103///
104/// Returns `OCRError::InvalidInput` if lengths don't match.
105pub fn validate_length_match(
106    actual: usize,
107    expected: usize,
108    actual_name: &str,
109    expected_name: &str,
110) -> Result<(), OCRError> {
111    if actual != expected {
112        return Err(OCRError::InvalidInput {
113            message: format!(
114                "Mismatch between {} count ({}) and {} count ({})",
115                actual_name, actual, expected_name, expected
116            ),
117        });
118    }
119    Ok(())
120}
121
122/// Validates that a value doesn't exceed a maximum.
123///
124/// # Errors
125///
126/// Returns `OCRError::InvalidInput` if value exceeds maximum.
127pub fn validate_max_value<T: PartialOrd + std::fmt::Display>(
128    value: T,
129    max: T,
130    field_name: &str,
131    context: &str,
132) -> Result<(), OCRError> {
133    if value > max {
134        return Err(OCRError::InvalidInput {
135            message: format!(
136                "{}: {} {} exceeds maximum {}",
137                context, field_name, value, max
138            ),
139        });
140    }
141    Ok(())
142}
143
144/// Validates that dimensions are positive (non-zero).
145///
146/// # Errors
147///
148/// Returns `OCRError::InvalidInput` if either dimension is zero.
149pub fn validate_positive_dimensions(
150    width: u32,
151    height: u32,
152    context: &str,
153) -> Result<(), OCRError> {
154    if width == 0 || height == 0 {
155        return Err(OCRError::InvalidInput {
156            message: format!(
157                "{}: invalid dimensions width={}, height={} (must be positive)",
158                context, width, height
159            ),
160        });
161    }
162    Ok(())
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_score_validator_unit_range() {
        let validator = ScoreValidator::new_unit_range("score");

        // Both bounds are inclusive.
        for score in [0.0, 0.5, 1.0] {
            assert!(
                validator.validate_score(score, "test").is_ok(),
                "{} should be accepted",
                score
            );
        }

        // Out-of-range and non-finite values must be rejected.
        for score in [-0.1, 1.1, f32::NAN, f32::INFINITY] {
            assert!(
                validator.validate_score(score, "test").is_err(),
                "{} should be rejected",
                score
            );
        }
    }

    #[test]
    fn test_score_validator_custom_range() {
        let validator = ScoreValidator::new(0.5, 2.0, "custom");

        for score in [0.5, 1.0, 2.0] {
            assert!(validator.validate_score(score, "test").is_ok());
        }
        for score in [0.4, 2.1] {
            assert!(validator.validate_score(score, "test").is_err());
        }
    }

    #[test]
    fn test_validate_scores() {
        let validator = ScoreValidator::new_unit_range("score");

        // An empty collection has nothing to reject.
        assert!(validator.validate_scores(&[], "test").is_ok());

        // All valid scores
        assert!(validator.validate_scores(&[0.1, 0.5, 0.9], "test").is_ok());

        // One invalid score: the error must point at its index.
        let err = validator
            .validate_scores(&[0.1, 1.5, 0.9], "test")
            .unwrap_err();
        assert!(format!("{:?}", err).contains("test 1"));
    }

    #[test]
    fn test_validate_scores_with_formatter() {
        let validator = ScoreValidator::new_unit_range("score");

        assert!(validator
            .validate_scores_with(&[0.2, 0.8], |idx| format!("Image 0, detection {}", idx))
            .is_ok());

        let result = validator
            .validate_scores_with(&[0.5, 1.5], |idx| format!("Image 0, detection {}", idx));

        assert!(result.is_err());
        let err_msg = format!("{:?}", result.unwrap_err());
        assert!(err_msg.contains("detection 1"));
    }

    #[test]
    fn test_validate_length_match() {
        assert!(validate_length_match(3, 3, "texts", "scores").is_ok());
        assert!(validate_length_match(0, 0, "texts", "scores").is_ok());
        assert!(validate_length_match(3, 5, "texts", "scores").is_err());
    }

    #[test]
    fn test_validate_max_value() {
        assert!(validate_max_value(50, 100, "length", "text 0").is_ok());
        // The maximum itself is allowed.
        assert!(validate_max_value(100, 100, "length", "text 0").is_ok());
        assert!(validate_max_value(150, 100, "length", "text 0").is_err());
    }

    #[test]
    fn test_validate_positive_dimensions() {
        assert!(validate_positive_dimensions(100, 200, "image 0").is_ok());
        assert!(validate_positive_dimensions(0, 200, "image 0").is_err());
        assert!(validate_positive_dimensions(100, 0, "image 0").is_err());
        assert!(validate_positive_dimensions(0, 0, "image 0").is_err());
    }
}