1use crate::core::OCRError;
7
8#[inline]
10pub fn validate_finite(value: f32, param_name: &str) -> Result<(), OCRError> {
11 if !value.is_finite() {
12 return Err(OCRError::InvalidInput {
13 message: format!("Parameter '{}' must be finite, got: {}", param_name, value),
14 });
15 }
16 Ok(())
17}
18
19#[inline]
21pub fn validate_range<T: PartialOrd + std::fmt::Display>(
22 value: T,
23 min: T,
24 max: T,
25 param_name: &str,
26) -> Result<(), OCRError> {
27 if value < min || value > max {
28 return Err(OCRError::InvalidInput {
29 message: format!(
30 "Parameter '{}' must be in range [{}, {}], got: {}",
31 param_name, min, max, value
32 ),
33 });
34 }
35 Ok(())
36}
37
38#[inline]
40pub fn validate_positive<T: PartialOrd + std::fmt::Display + Default>(
41 value: T,
42 param_name: &str,
43) -> Result<(), OCRError> {
44 if value <= T::default() {
45 return Err(OCRError::InvalidInput {
46 message: format!(
47 "Parameter '{}' must be positive, got: {}",
48 param_name, value
49 ),
50 });
51 }
52 Ok(())
53}
54
55#[inline]
57pub fn validate_non_negative<T: PartialOrd + std::fmt::Display + Default>(
58 value: T,
59 param_name: &str,
60) -> Result<(), OCRError> {
61 if value < T::default() {
62 return Err(OCRError::InvalidInput {
63 message: format!(
64 "Parameter '{}' must be non-negative, got: {}",
65 param_name, value
66 ),
67 });
68 }
69 Ok(())
70}
71
72#[inline]
74pub fn validate_non_empty<T>(items: &[T], param_name: &str) -> Result<(), OCRError> {
75 if items.is_empty() {
76 return Err(OCRError::InvalidInput {
77 message: format!("Parameter '{}' cannot be empty", param_name),
78 });
79 }
80 Ok(())
81}
82
83#[inline]
85pub fn validate_same_length<T, U>(
86 items1: &[T],
87 items2: &[U],
88 name1: &str,
89 name2: &str,
90) -> Result<(), OCRError> {
91 if items1.len() != items2.len() {
92 return Err(OCRError::InvalidInput {
93 message: format!(
94 "Length mismatch: {} has {} elements, but {} has {} elements",
95 name1,
96 items1.len(),
97 name2,
98 items2.len()
99 ),
100 });
101 }
102 Ok(())
103}
104
105pub fn validate_tensor_shape(
107 shape: &[usize],
108 expected_dims: usize,
109 tensor_name: &str,
110) -> Result<(), OCRError> {
111 if shape.len() != expected_dims {
112 return Err(OCRError::InvalidInput {
113 message: format!(
114 "Tensor '{}' expected {}D shape, got {}D: {:?}",
115 tensor_name,
116 expected_dims,
117 shape.len(),
118 shape
119 ),
120 });
121 }
122 Ok(())
123}
124
125pub fn validate_batch_size(shape: &[usize], tensor_name: &str) -> Result<usize, OCRError> {
127 validate_tensor_shape(shape, 4, tensor_name)?;
128
129 let batch_size = shape[0];
130 if batch_size == 0 {
131 return Err(OCRError::InvalidInput {
132 message: format!(
133 "Tensor '{}' has zero batch size. Shape: {:?}",
134 tensor_name, shape
135 ),
136 });
137 }
138
139 Ok(batch_size)
140}
141
142pub fn validate_image_dimensions(height: u32, width: u32, context: &str) -> Result<(), OCRError> {
144 if height == 0 || width == 0 {
145 return Err(OCRError::InvalidInput {
146 message: format!(
147 "{}: image dimensions must be positive, got {}x{}",
148 context, height, width
149 ),
150 });
151 }
152
153 const MAX_DIMENSION: u32 = 32768;
155 if height > MAX_DIMENSION || width > MAX_DIMENSION {
156 return Err(OCRError::InvalidInput {
157 message: format!(
158 "{}: image dimensions exceed maximum of {}x{}, got {}x{}",
159 context, MAX_DIMENSION, MAX_DIMENSION, height, width
160 ),
161 });
162 }
163
164 Ok(())
165}
166
167#[inline]
169pub fn validate_index_bounds<T>(
170 slice: &[T],
171 index: usize,
172 slice_name: &str,
173) -> Result<(), OCRError> {
174 if index >= slice.len() {
175 return Err(OCRError::InvalidInput {
176 message: format!(
177 "Index out of bounds for '{}': index {} >= length {}",
178 slice_name,
179 index,
180 slice.len()
181 ),
182 });
183 }
184 Ok(())
185}
186
187#[inline]
189pub fn validate_division(numerator: f32, denominator: f32, context: &str) -> Result<(), OCRError> {
190 validate_finite(numerator, &format!("{} numerator", context))?;
191 validate_finite(denominator, &format!("{} denominator", context))?;
192
193 if denominator.abs() < f32::EPSILON {
194 return Err(OCRError::InvalidInput {
195 message: format!(
196 "{}: division by zero (denominator: {})",
197 context, denominator
198 ),
199 });
200 }
201
202 Ok(())
203}
204
205pub fn validate_normalization_params(
207 mean: &[f32],
208 std: &[f32],
209 num_channels: usize,
210) -> Result<(), OCRError> {
211 if mean.len() != num_channels {
213 return Err(OCRError::InvalidInput {
214 message: format!(
215 "Mean length {} does not match number of channels {}",
216 mean.len(),
217 num_channels
218 ),
219 });
220 }
221
222 if std.len() != num_channels {
223 return Err(OCRError::InvalidInput {
224 message: format!(
225 "Std length {} does not match number of channels {}",
226 std.len(),
227 num_channels
228 ),
229 });
230 }
231
232 for (i, &m) in mean.iter().enumerate() {
234 validate_finite(m, &format!("mean[{}]", i))?;
235 }
236
237 for (i, &s) in std.iter().enumerate() {
239 validate_finite(s, &format!("std[{}]", i))?;
240 validate_positive(s, &format!("std[{}]", i))?;
241 }
242
243 Ok(())
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_finite() {
        for value in [1.0, 0.0, -1.0] {
            assert!(validate_finite(value, "test").is_ok());
        }
        for value in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(validate_finite(value, "test").is_err());
        }
    }

    #[test]
    fn test_validate_range() {
        for value in [5.0, 0.0, 10.0] {
            assert!(validate_range(value, 0.0, 10.0, "test").is_ok());
        }
        for value in [-1.0, 11.0] {
            assert!(validate_range(value, 0.0, 10.0, "test").is_err());
        }
    }

    #[test]
    fn test_validate_positive() {
        for value in [1.0, 0.1] {
            assert!(validate_positive(value, "test").is_ok());
        }
        for value in [0.0, -1.0] {
            assert!(validate_positive(value, "test").is_err());
        }
    }

    #[test]
    fn test_validate_non_negative() {
        for value in [1.0, 0.0] {
            assert!(validate_non_negative(value, "test").is_ok());
        }
        assert!(validate_non_negative(-1.0, "test").is_err());
    }

    #[test]
    fn test_validate_non_empty() {
        let three: [i32; 3] = [1, 2, 3];
        let one: [i32; 1] = [1];
        assert!(validate_non_empty(&three, "test").is_ok());
        assert!(validate_non_empty(&one, "test").is_ok());
        assert!(validate_non_empty::<i32>(&[], "test").is_err());
    }

    #[test]
    fn test_validate_same_length() {
        let pair: [i32; 2] = [1, 2];
        assert!(validate_same_length(&pair, &[3, 4], "a", "b").is_ok());
        assert!(validate_same_length(&[1], &[2], "a", "b").is_ok());
        assert!(validate_same_length(&pair, &[3], "a", "b").is_err());
    }

    #[test]
    fn test_validate_tensor_shape() {
        let nchw: [usize; 4] = [1, 3, 224, 224];
        let chw: [usize; 3] = [1, 3, 224];
        assert!(validate_tensor_shape(&nchw, 4, "test").is_ok());
        assert!(validate_tensor_shape(&chw, 3, "test").is_ok());
        assert!(validate_tensor_shape(&chw, 4, "test").is_err());
    }

    #[test]
    fn test_validate_batch_size() {
        let accepted: [([usize; 4], usize); 2] = [([2, 3, 224, 224], 2), ([1, 3, 224, 224], 1)];
        for (shape, expected) in accepted {
            match validate_batch_size(&shape, "test") {
                Ok(batch_size) => assert_eq!(batch_size, expected),
                Err(err) => panic!("expected validate_batch_size to succeed: {err}"),
            }
        }
        assert!(validate_batch_size(&[0, 3, 224, 224], "test").is_err());
        assert!(validate_batch_size(&[1, 3, 224], "test").is_err());
    }

    #[test]
    fn test_validate_image_dimensions() {
        for (height, width) in [(224, 224), (1, 1)] {
            assert!(validate_image_dimensions(height, width, "test").is_ok());
        }
        for (height, width) in [(0, 224), (224, 0), (99999, 99999)] {
            assert!(validate_image_dimensions(height, width, "test").is_err());
        }
    }

    #[test]
    fn test_validate_division() {
        for (numerator, denominator) in [(10.0, 2.0), (0.0, 2.0)] {
            assert!(validate_division(numerator, denominator, "test").is_ok());
        }
        let rejected: [(f32, f32); 4] = [
            (10.0, 0.0),
            (10.0, 1e-10),
            (f32::NAN, 2.0),
            (10.0, f32::INFINITY),
        ];
        for (numerator, denominator) in rejected {
            assert!(validate_division(numerator, denominator, "test").is_err());
        }
    }

    #[test]
    fn test_validate_normalization_params() {
        let mean: [f32; 3] = [0.485, 0.456, 0.406];
        let std_dev: [f32; 3] = [0.229, 0.224, 0.225];

        assert!(validate_normalization_params(&mean, &std_dev, 3).is_ok());

        // Too few mean entries for the channel count.
        assert!(validate_normalization_params(&mean[..2], &std_dev, 3).is_err());

        // Non-finite mean entry.
        assert!(validate_normalization_params(&[f32::NAN, 0.456, 0.406], &std_dev, 3).is_err());

        // Zero and negative std entries.
        assert!(validate_normalization_params(&mean, &[0.0, 0.224, 0.225], 3).is_err());
        assert!(validate_normalization_params(&mean, &[-0.229, 0.224, 0.225], 3).is_err());
    }
}