oar_ocr/processors/normalization.rs

//! Image normalization utilities for OCR processing.
//!
//! This module provides functionality to normalize images for OCR processing,
//! including standard normalization with mean and standard deviation, as well as
//! specialized normalization for OCR recognition tasks.
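//!
//! # Examples
//!
//! A minimal sketch of the two entry points; the `oar_ocr::processors` import
//! path is assumed from the crate layout, so the snippet is marked `ignore`.
//!
//! ```ignore
//! use oar_ocr::processors::NormalizeImage;
//!
//! // ImageNet-style defaults: scale 1/255, per-channel mean/std, CHW output.
//! let standard = NormalizeImage::new(None, None, None, None).unwrap();
//!
//! // Recognition preset: maps pixel values from [0, 255] to roughly [-1, 1].
//! let recognition = NormalizeImage::for_ocr_recognition().unwrap();
//! ```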

use crate::core::OCRError;
use crate::processors::types::ChannelOrder;
use image::DynamicImage;
use rayon::prelude::*;

/// Normalizes images for OCR processing.
///
/// This struct encapsulates the parameters needed to normalize images,
/// including scaling factors, mean values, standard deviations, and channel ordering.
/// It provides methods to apply normalization to single images or batches of images.
#[derive(Debug)]
pub struct NormalizeImage {
    /// Scaling factors for each channel (alpha = scale / std)
    pub alpha: Vec<f32>,
    /// Offset values for each channel (beta = -mean / std)
    pub beta: Vec<f32>,
    /// Channel ordering (CHW or HWC)
    pub order: ChannelOrder,
}

impl NormalizeImage {
    /// Creates a new NormalizeImage instance with the specified parameters.
    ///
    /// # Arguments
    ///
    /// * `scale` - Optional scaling factor (defaults to 1.0/255.0)
    /// * `mean` - Optional mean values for each channel (defaults to [0.485, 0.456, 0.406])
    /// * `std` - Optional standard deviation values for each channel (defaults to [0.229, 0.224, 0.225])
    /// * `order` - Optional channel ordering (defaults to CHW)
    ///
    /// # Returns
    ///
    /// A Result containing the new NormalizeImage instance or an OCRError if validation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * Scale is less than or equal to 0
    /// * Mean or std vectors don't have exactly 3 elements
    /// * Any standard deviation value is less than or equal to 0
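    ///
    /// # Examples
    ///
    /// A minimal sketch; the `oar_ocr::processors` import path is assumed from
    /// the crate layout, so the snippet is marked `ignore`.
    ///
    /// ```ignore
    /// use oar_ocr::processors::NormalizeImage;
    ///
    /// // ImageNet-style defaults (scale 1/255, per-channel mean/std, CHW).
    /// let normalizer = NormalizeImage::new(None, None, None, None).unwrap();
    ///
    /// // A zero standard deviation is rejected.
    /// assert!(NormalizeImage::new(None, None, Some(vec![0.0, 1.0, 1.0]), None).is_err());
    /// ```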
    pub fn new(
        scale: Option<f32>,
        mean: Option<Vec<f32>>,
        std: Option<Vec<f32>>,
        order: Option<ChannelOrder>,
    ) -> Result<Self, OCRError> {
        let scale = scale.unwrap_or(1.0 / 255.0);
        let mean = mean.unwrap_or_else(|| vec![0.485, 0.456, 0.406]);
        let std = std.unwrap_or_else(|| vec![0.229, 0.224, 0.225]);
        let order = order.unwrap_or(ChannelOrder::CHW);

        if scale <= 0.0 {
            return Err(OCRError::ConfigError {
                message: "Scale must be greater than 0".to_string(),
            });
        }

        if mean.len() != 3 {
            return Err(OCRError::ConfigError {
                message: "Mean must have exactly 3 elements for RGB".to_string(),
            });
        }

        if std.len() != 3 {
            return Err(OCRError::ConfigError {
                message: "Std must have exactly 3 elements for RGB".to_string(),
            });
        }

        for (i, &s) in std.iter().enumerate() {
            if s <= 0.0 {
                return Err(OCRError::ConfigError {
                    message: format!(
                        "Standard deviation at index {i} must be greater than 0, got {s}"
                    )
                });
            }
        }

        // Fold scale, mean, and std into per-channel affine terms so that
        // normalized = pixel * alpha + beta, with alpha = scale / std and beta = -mean / std.
        let alpha: Vec<f32> = std.iter().map(|s| scale / s).collect();
        let beta: Vec<f32> = mean.iter().zip(&std).map(|(m, s)| -m / s).collect();

        Ok(Self { alpha, beta, order })
    }

    /// Validates the configuration of the NormalizeImage instance.
    ///
    /// # Returns
    ///
    /// A Result indicating success or an OCRError if validation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * Alpha or beta vectors don't have exactly 3 elements
    /// * Any alpha or beta value is not finite
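    ///
    /// # Examples
    ///
    /// A minimal sketch; the import path is assumed, so the snippet is marked `ignore`.
    ///
    /// ```ignore
    /// use oar_ocr::processors::NormalizeImage;
    ///
    /// let normalizer = NormalizeImage::new(None, None, None, None).unwrap();
    /// // Instances built through `new` already satisfy these checks.
    /// assert!(normalizer.validate_config().is_ok());
    /// ```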
    pub fn validate_config(&self) -> Result<(), OCRError> {
        if self.alpha.len() != 3 || self.beta.len() != 3 {
            return Err(OCRError::ConfigError {
                message: "Alpha and beta must have exactly 3 elements for RGB".to_string(),
            });
        }

        for (i, &alpha) in self.alpha.iter().enumerate() {
            if !alpha.is_finite() {
                return Err(OCRError::ConfigError {
                    message: format!("Alpha value at index {i} is not finite: {alpha}"),
                });
            }
        }

        for (i, &beta) in self.beta.iter().enumerate() {
            if !beta.is_finite() {
                return Err(OCRError::ConfigError {
                    message: format!("Beta value at index {i} is not finite: {beta}"),
                });
            }
        }

        Ok(())
    }

    /// Creates a NormalizeImage instance with parameters suitable for OCR recognition.
    ///
    /// This creates a normalization configuration with:
    /// * Scale: 2.0/255.0
    /// * Mean: [1.0, 1.0, 1.0]
    /// * Std: [1.0, 1.0, 1.0]
    /// * Order: CHW
    ///
    /// The resulting affine transform maps pixel values from [0, 255] to approximately [-1, 1].
    ///
    /// # Returns
    ///
    /// A Result containing the new NormalizeImage instance or an OCRError.
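    ///
    /// # Examples
    ///
    /// A minimal sketch; the import path is assumed, so the snippet is marked `ignore`.
    ///
    /// ```ignore
    /// use image::DynamicImage;
    /// use oar_ocr::processors::NormalizeImage;
    ///
    /// let normalizer = NormalizeImage::for_ocr_recognition().unwrap();
    /// // An all-black image normalizes to values close to -1.0.
    /// let planes = normalizer.apply(vec![DynamicImage::new_rgb8(4, 4)]);
    /// assert!(planes[0].iter().all(|&v| (v + 1.0).abs() < 1e-5));
    /// ```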
    pub fn for_ocr_recognition() -> Result<Self, OCRError> {
        Self::new(
            Some(2.0 / 255.0),
            Some(vec![1.0, 1.0, 1.0]),
            Some(vec![1.0, 1.0, 1.0]),
            Some(ChannelOrder::CHW),
        )
    }

    /// Applies normalization to a vector of images.
    ///
    /// # Arguments
    ///
    /// * `imgs` - A vector of DynamicImage instances to normalize
    ///
    /// # Returns
    ///
    /// A vector of normalized images, each represented as a flat vector of f32 values
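    ///
    /// # Examples
    ///
    /// A minimal sketch; the import path is assumed, so the snippet is marked `ignore`.
    ///
    /// ```ignore
    /// use image::DynamicImage;
    /// use oar_ocr::processors::NormalizeImage;
    ///
    /// let normalizer = NormalizeImage::new(None, None, None, None).unwrap();
    /// let planes = normalizer.apply(vec![DynamicImage::new_rgb8(10, 8)]);
    /// // One flat buffer per image: 3 channels * 8 rows * 10 columns.
    /// assert_eq!(planes[0].len(), 3 * 8 * 10);
    /// ```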
    pub fn apply(&self, imgs: Vec<DynamicImage>) -> Vec<Vec<f32>> {
        imgs.into_iter().map(|img| self.normalize(img)).collect()
    }

    /// Validates inputs for batch processing operations.
    ///
    /// # Arguments
    ///
    /// * `imgs_len` - Number of images in the batch
    /// * `shapes` - Shapes of the images as (channels, height, width) tuples
    /// * `batch_tensor` - The batch tensor to validate against
    ///
    /// # Returns
    ///
    /// A Result containing a tuple of (batch_size, channels, height, max_width) or an OCRError.
    fn validate_batch_inputs(
        &self,
        imgs_len: usize,
        shapes: &[(usize, usize, usize)],
        batch_tensor: &[f32],
    ) -> Result<(usize, usize, usize, usize), OCRError> {
        if imgs_len != shapes.len() {
            return Err(OCRError::InvalidInput {
                message: format!(
                    "Images and shapes length mismatch: {} images vs {} shapes",
                    imgs_len,
                    shapes.len()
                ),
            });
        }

        let batch_size = imgs_len;
        if batch_size == 0 {
            return Ok((0, 0, 0, 0));
        }

        // Slot geometry: channels and height come from the first shape, width from the widest image.
        let max_width = shapes.iter().map(|(_, _, w)| *w).max().unwrap_or(0);
        let channels = shapes.first().map(|(c, _, _)| *c).unwrap_or(0);
        let height = shapes.first().map(|(_, h, _)| *h).unwrap_or(0);
        let img_size = channels * height * max_width;

        if batch_tensor.len() < batch_size * img_size {
            return Err(OCRError::BufferTooSmall {
                expected: batch_size * img_size,
                actual: batch_tensor.len(),
            });
        }

        Ok((batch_size, channels, height, max_width))
    }

    /// Applies normalization to a batch of images and stores the result in a pre-allocated tensor.
    ///
    /// Each image is written into a slot of `channels * height * max_width` elements. Positions
    /// in a slot that are not covered by the image's own shape are left untouched, so callers
    /// that care about padding should clear the buffer first.
    ///
    /// # Arguments
    ///
    /// * `imgs` - A vector of DynamicImage instances to normalize
    /// * `batch_tensor` - A mutable slice where the normalized batch will be stored
    /// * `shapes` - Shapes of the images as (channels, height, width) tuples
    ///
    /// # Returns
    ///
    /// A Result indicating success or an OCRError if validation fails.
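    ///
    /// # Examples
    ///
    /// A minimal sketch; the import path is assumed, so the snippet is marked `ignore`.
    ///
    /// ```ignore
    /// use image::DynamicImage;
    /// use oar_ocr::processors::NormalizeImage;
    ///
    /// let normalizer = NormalizeImage::for_ocr_recognition().unwrap();
    /// let imgs = vec![DynamicImage::new_rgb8(100, 32), DynamicImage::new_rgb8(80, 32)];
    /// let shapes = vec![(3, 32, 100), (3, 32, 80)];
    ///
    /// // The buffer must hold batch_size * channels * height * max_width elements.
    /// let mut batch = vec![0.0f32; 2 * 3 * 32 * 100];
    /// normalizer.apply_to_batch(imgs, &mut batch, &shapes).unwrap();
    /// ```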
    pub fn apply_to_batch(
        &self,
        imgs: Vec<DynamicImage>,
        batch_tensor: &mut [f32],
        shapes: &[(usize, usize, usize)],
    ) -> Result<(), OCRError> {
        let (batch_size, channels, height, max_width) =
            self.validate_batch_inputs(imgs.len(), shapes, batch_tensor)?;

        if batch_size == 0 {
            return Ok(());
        }

        let img_size = channels * height * max_width;

        for (batch_idx, (img, &(c, h, w))) in imgs.into_iter().zip(shapes.iter()).enumerate() {
            let normalized_img = self.normalize(img);

            let batch_offset = batch_idx * img_size;

            // Copy the normalized buffer (channel-major indexing) into this image's slot;
            // rows in the slot are strided by max_width.
            for ch in 0..c {
                for y in 0..h {
                    for x in 0..w {
                        let src_idx = ch * h * w + y * w + x;
                        let dst_idx = batch_offset + ch * height * max_width + y * max_width + x;
                        if src_idx < normalized_img.len() && dst_idx < batch_tensor.len() {
                            batch_tensor[dst_idx] = normalized_img[src_idx];
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Applies normalization to a batch of images and stores the result in a pre-allocated tensor,
    /// processing images in a streaming fashion.
    ///
    /// Unlike [`apply_to_batch`](Self::apply_to_batch), this method writes normalized pixels
    /// directly into `batch_tensor` without allocating an intermediate buffer per image, and it
    /// zero-fills the tensor first so padded regions are well defined.
    ///
    /// # Arguments
    ///
    /// * `imgs` - A vector of DynamicImage instances to normalize
    /// * `batch_tensor` - A mutable slice where the normalized batch will be stored
    /// * `shapes` - Shapes of the images as (channels, height, width) tuples
    ///
    /// # Returns
    ///
    /// A Result indicating success or an OCRError if validation fails.
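    ///
    /// # Examples
    ///
    /// A minimal sketch; the import path is assumed, so the snippet is marked `ignore`.
    ///
    /// ```ignore
    /// use image::DynamicImage;
    /// use oar_ocr::processors::NormalizeImage;
    ///
    /// let normalizer = NormalizeImage::for_ocr_recognition().unwrap();
    /// let imgs = vec![DynamicImage::new_rgb8(100, 32), DynamicImage::new_rgb8(80, 32)];
    /// let shapes = vec![(3, 32, 100), (3, 32, 80)];
    ///
    /// let mut batch = vec![0.0f32; 2 * 3 * 32 * 100];
    /// normalizer
    ///     .normalize_streaming_to_batch(imgs, &mut batch, &shapes)
    ///     .unwrap();
    /// // Columns 80..100 of the second image's slot remain zero padding.
    /// ```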
    pub fn normalize_streaming_to_batch(
        &self,
        imgs: Vec<DynamicImage>,
        batch_tensor: &mut [f32],
        shapes: &[(usize, usize, usize)],
    ) -> Result<(), OCRError> {
        let (batch_size, channels, height, max_width) =
            self.validate_batch_inputs(imgs.len(), shapes, batch_tensor)?;

        if batch_size == 0 {
            return Ok(());
        }

        let img_size = channels * height * max_width;
        batch_tensor.fill(0.0);

        for (batch_idx, (img, &(_c, h, w))) in imgs.into_iter().zip(shapes.iter()).enumerate() {
            let rgb_img = img.to_rgb8();
            let (width, height_img) = rgb_img.dimensions();
            let batch_offset = batch_idx * img_size;

            match self.order {
                ChannelOrder::CHW => {
                    for c in 0..channels.min(3) {
                        for y in 0..h.min(height_img as usize) {
                            for x in 0..w.min(width as usize) {
                                let pixel = rgb_img.get_pixel(x as u32, y as u32);
                                let channel_value = pixel[c] as f32;
                                let dst_idx =
                                    batch_offset + c * height * max_width + y * max_width + x;
                                if dst_idx < batch_tensor.len() {
                                    batch_tensor[dst_idx] =
                                        channel_value * self.alpha[c] + self.beta[c];
                                }
                            }
                        }
                    }
                }
                ChannelOrder::HWC => {
                    for y in 0..h.min(height_img as usize) {
                        for x in 0..w.min(width as usize) {
                            let pixel = rgb_img.get_pixel(x as u32, y as u32);
                            for c in 0..channels.min(3) {
                                let channel_value = pixel[c] as f32;
                                let dst_idx =
                                    batch_offset + y * max_width * channels + x * channels + c;
                                if dst_idx < batch_tensor.len() {
                                    batch_tensor[dst_idx] =
                                        channel_value * self.alpha[c] + self.beta[c];
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(())
    }

    /// Normalizes a single image.
    ///
    /// # Arguments
    ///
    /// * `img` - The DynamicImage to normalize
    ///
    /// # Returns
    ///
    /// A vector of normalized pixel values as f32, laid out according to `self.order`
    fn normalize(&self, img: DynamicImage) -> Vec<f32> {
        let rgb_img = img.to_rgb8();
        let (width, height) = rgb_img.dimensions();
        let channels = 3;

        match self.order {
            ChannelOrder::CHW => {
                let mut result = vec![0.0f32; (channels * height * width) as usize];

                for c in 0..channels {
                    for y in 0..height {
                        for x in 0..width {
                            let pixel = rgb_img.get_pixel(x, y);
                            let channel_value = pixel[c as usize] as f32;
                            let dst_idx = (c * height * width + y * width + x) as usize;

                            result[dst_idx] =
                                channel_value * self.alpha[c as usize] + self.beta[c as usize];
                        }
                    }
                }
                result
            }
            ChannelOrder::HWC => {
                let mut result = vec![0.0f32; (height * width * channels) as usize];

                for y in 0..height {
                    for x in 0..width {
                        let pixel = rgb_img.get_pixel(x, y);
                        for c in 0..channels {
                            let channel_value = pixel[c as usize] as f32;
                            let dst_idx = (y * width * channels + x * channels + c) as usize;

                            result[dst_idx] =
                                channel_value * self.alpha[c as usize] + self.beta[c as usize];
                        }
                    }
                }
                result
            }
        }
    }

    /// Normalizes a single image and returns it as a 4D tensor.
    ///
    /// # Arguments
    ///
    /// * `img` - The DynamicImage to normalize
    ///
    /// # Returns
    ///
    /// A Result containing the normalized image as a 4D tensor (with a leading batch dimension of 1) or an OCRError.
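    ///
    /// # Examples
    ///
    /// A minimal sketch; the import path is assumed, so the snippet is marked `ignore`.
    ///
    /// ```ignore
    /// use image::DynamicImage;
    /// use oar_ocr::processors::NormalizeImage;
    ///
    /// let normalizer = NormalizeImage::new(None, None, None, None).unwrap();
    /// let tensor = normalizer.normalize_to(DynamicImage::new_rgb8(100, 32)).unwrap();
    /// // CHW ordering yields a (batch, channels, height, width) tensor.
    /// assert_eq!(tensor.shape(), &[1, 3, 32, 100]);
    /// ```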
    pub fn normalize_to(&self, img: DynamicImage) -> Result<crate::core::Tensor4D, OCRError> {
        let rgb_img = img.to_rgb8();
        let (width, height) = rgb_img.dimensions();
        let channels = 3;

        match self.order {
            ChannelOrder::CHW => {
                let mut result = vec![0.0f32; (channels * height * width) as usize];

                for c in 0..channels {
                    for y in 0..height {
                        for x in 0..width {
                            let pixel = rgb_img.get_pixel(x, y);
                            let channel_value = pixel[c as usize] as f32;
                            let dst_idx = (c * height * width + y * width + x) as usize;

                            result[dst_idx] =
                                channel_value * self.alpha[c as usize] + self.beta[c as usize];
                        }
                    }
                }

                ndarray::Array4::from_shape_vec(
                    (1, channels as usize, height as usize, width as usize),
                    result,
                )
                .map_err(|e| {
                    OCRError::tensor_operation_error(
                        "normalization_tensor_creation_chw",
                        &[1, channels as usize, height as usize, width as usize],
                        &[(channels * height * width) as usize],
                        &format!(
                            "Failed to create CHW normalization tensor for {}x{} image with {} channels",
                            width, height, channels
                        ),
                        e,
                    )
                })
            }
            ChannelOrder::HWC => {
                let mut result = vec![0.0f32; (height * width * channels) as usize];

                for y in 0..height {
                    for x in 0..width {
                        let pixel = rgb_img.get_pixel(x, y);
                        for c in 0..channels {
                            let channel_value = pixel[c as usize] as f32;
                            let dst_idx = (y * width * channels + x * channels + c) as usize;

                            result[dst_idx] =
                                channel_value * self.alpha[c as usize] + self.beta[c as usize];
                        }
                    }
                }

                ndarray::Array4::from_shape_vec(
                    (1, height as usize, width as usize, channels as usize),
                    result,
                )
                .map_err(|e| {
                    OCRError::tensor_operation_error(
                        "normalization_tensor_creation_hwc",
                        &[1, height as usize, width as usize, channels as usize],
                        &[(height * width * channels) as usize],
                        &format!(
                            "Failed to create HWC normalization tensor for {}x{} image with {} channels",
                            width, height, channels
                        ),
                        e,
                    )
                })
            }
        }
    }

    /// Normalizes a batch of images and returns them as a 4D tensor.
    ///
    /// # Arguments
    ///
    /// * `imgs` - A vector of DynamicImage instances to normalize
    ///
    /// # Returns
    ///
    /// A Result containing the normalized batch as a 4D tensor or an OCRError.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * Images in the batch don't all have the same dimensions
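    ///
    /// # Examples
    ///
    /// A minimal sketch; the import path is assumed, so the snippet is marked `ignore`.
    ///
    /// ```ignore
    /// use image::DynamicImage;
    /// use oar_ocr::processors::NormalizeImage;
    ///
    /// let normalizer = NormalizeImage::new(None, None, None, None).unwrap();
    /// // Four identically sized images produce a (4, 3, 32, 100) tensor in CHW order.
    /// let imgs = vec![DynamicImage::new_rgb8(100, 32); 4];
    /// let batch = normalizer.normalize_batch_to(imgs).unwrap();
    /// assert_eq!(batch.shape(), &[4, 3, 32, 100]);
    ///
    /// // Mixed dimensions are rejected.
    /// let mixed = vec![DynamicImage::new_rgb8(100, 32), DynamicImage::new_rgb8(80, 32)];
    /// assert!(normalizer.normalize_batch_to(mixed).is_err());
    /// ```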
    pub fn normalize_batch_to(
        &self,
        imgs: Vec<DynamicImage>,
    ) -> Result<crate::core::Tensor4D, OCRError> {
        if imgs.is_empty() {
            return Ok(ndarray::Array4::zeros((0, 0, 0, 0)));
        }

        let batch_size = imgs.len();

        let rgb_imgs: Vec<_> = imgs.into_iter().map(|img| img.to_rgb8()).collect();
        let dimensions: Vec<_> = rgb_imgs.iter().map(|img| img.dimensions()).collect();

        let (first_width, first_height) = dimensions.first().copied().unwrap_or((0, 0));
        for (i, &(width, height)) in dimensions.iter().enumerate() {
            if width != first_width || height != first_height {
                return Err(OCRError::InvalidInput {
                    message: format!(
                        "All images in batch must have the same dimensions. Image 0: {first_width}x{first_height}, Image {i}: {width}x{height}"
                    ),
                });
            }
        }

        let (width, height) = (first_width, first_height);
        let channels = 3;

        match self.order {
            ChannelOrder::CHW => {
                let mut result = vec![0.0f32; batch_size * (channels * height * width) as usize];

                let img_size = (channels * height * width) as usize;
                if batch_size == 1 {
                    // Avoid rayon overhead for single-image batches
                    let rgb_img = &rgb_imgs[0];
                    let batch_slice = &mut result[0..img_size];
                    for c in 0..channels {
                        for y in 0..height {
                            for x in 0..width {
                                let pixel = rgb_img.get_pixel(x, y);
                                let channel_value = pixel[c as usize] as f32;
                                let dst_idx = (c * height * width + y * width + x) as usize;
                                batch_slice[dst_idx] =
                                    channel_value * self.alpha[c as usize] + self.beta[c as usize];
                            }
                        }
                    }
                } else {
                    // Normalize each image's chunk of the output buffer in parallel
                    result.par_chunks_mut(img_size).enumerate().for_each(
                        |(batch_idx, batch_slice)| {
                            let rgb_img = &rgb_imgs[batch_idx];
                            for c in 0..channels {
                                for y in 0..height {
                                    for x in 0..width {
                                        let pixel = rgb_img.get_pixel(x, y);
                                        let channel_value = pixel[c as usize] as f32;
                                        let dst_idx = (c * height * width + y * width + x) as usize;
                                        batch_slice[dst_idx] = channel_value
                                            * self.alpha[c as usize]
                                            + self.beta[c as usize];
                                    }
                                }
                            }
                        },
                    );
                }

                ndarray::Array4::from_shape_vec(
                    (
                        batch_size,
                        channels as usize,
                        height as usize,
                        width as usize,
                    ),
                    result,
                )
                .map_err(|e| {
                    OCRError::tensor_operation(
                        "Failed to create batch normalization tensor in CHW format",
                        e,
                    )
                })
            }
            ChannelOrder::HWC => {
                let mut result = vec![0.0f32; batch_size * (height * width * channels) as usize];

                let img_size = (height * width * channels) as usize;
                if batch_size == 1 {
                    // Avoid rayon overhead for single-image batches
                    let rgb_img = &rgb_imgs[0];
                    let batch_slice = &mut result[0..img_size];
                    for y in 0..height {
                        for x in 0..width {
                            let pixel = rgb_img.get_pixel(x, y);
                            for c in 0..channels {
                                let channel_value = pixel[c as usize] as f32;
                                let dst_idx = (y * width * channels + x * channels + c) as usize;
                                batch_slice[dst_idx] =
                                    channel_value * self.alpha[c as usize] + self.beta[c as usize];
                            }
                        }
                    }
                } else {
                    // Normalize each image's chunk of the output buffer in parallel
                    result.par_chunks_mut(img_size).enumerate().for_each(
                        |(batch_idx, batch_slice)| {
                            let rgb_img = &rgb_imgs[batch_idx];
                            for y in 0..height {
                                for x in 0..width {
                                    let pixel = rgb_img.get_pixel(x, y);
                                    for c in 0..channels {
                                        let channel_value = pixel[c as usize] as f32;
                                        let dst_idx =
                                            (y * width * channels + x * channels + c) as usize;
                                        batch_slice[dst_idx] = channel_value
                                            * self.alpha[c as usize]
                                            + self.beta[c as usize];
                                    }
                                }
                            }
                        },
                    );
                }

                ndarray::Array4::from_shape_vec(
                    (
                        batch_size,
                        height as usize,
                        width as usize,
                        channels as usize,
                    ),
                    result,
                )
                .map_err(|e| {
                    OCRError::tensor_operation(
                        "Failed to create batch normalization tensor in HWC format",
                        e,
                    )
                })
            }
        }
    }
}