// oar_ocr_core/processors/normalization.rs
1//! Image normalization utilities for OCR processing.
2//!
3//! This module provides functionality to normalize images for OCR processing,
4//! including standard normalization with mean and standard deviation, as well as
5//! specialized normalization for OCR recognition tasks.
6
use crate::core::OCRError;
use crate::processors::types::{ColorOrder, TensorLayout};
use image::{DynamicImage, GenericImageView};
use rayon::prelude::*;
11
/// Normalizes images for OCR processing.
///
/// This struct encapsulates the parameters needed to normalize images,
/// including scaling factors, mean values, standard deviations, and channel ordering.
/// It provides methods to apply normalization to single images or batches of images.
///
/// Normalization is applied per channel as `value * alpha[c] + beta[c]`, which is
/// algebraically equivalent to `(value * scale - mean[c]) / std[c]` (see
/// `with_color_order`, where the coefficients are derived).
#[derive(Debug)]
pub struct NormalizeImage {
    /// Scaling factors for each channel (alpha = scale / std)
    pub alpha: Vec<f32>,
    /// Offset values for each channel (beta = -mean / std)
    pub beta: Vec<f32>,
    /// Tensor data layout (CHW or HWC)
    pub order: TensorLayout,
    /// Color channel order (RGB or BGR)
    pub color_order: ColorOrder,
}
28
29impl NormalizeImage {
    /// Creates a new NormalizeImage instance with the specified parameters.
    ///
    /// This is a thin wrapper around [`NormalizeImage::with_color_order`].
    ///
    /// # Arguments
    ///
    /// * `scale` - Optional scaling factor (defaults to 1.0/255.0)
    /// * `mean` - Optional mean values for each channel (defaults to [0.485, 0.456, 0.406])
    /// * `std` - Optional standard deviation values for each channel (defaults to [0.229, 0.224, 0.225])
    /// * `order` - Optional tensor data layout (defaults to CHW)
    /// * `color_order` - Optional color channel order (defaults to `ColorOrder::default()`;
    ///   the concrete default variant is declared in `processors::types`)
    ///
    /// `mean`/`std` are interpreted in the output channel order given by `color_order`;
    /// see [`NormalizeImage::with_color_order`] for details.
    ///
    /// # Returns
    ///
    /// A Result containing the new NormalizeImage instance or an OCRError if validation fails.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// * Scale is less than or equal to 0
    /// * Mean or std vectors don't have exactly 3 elements
    /// * Any standard deviation value is less than or equal to 0
    pub fn new(
        scale: Option<f32>,
        mean: Option<Vec<f32>>,
        std: Option<Vec<f32>>,
        order: Option<TensorLayout>,
        color_order: Option<ColorOrder>,
    ) -> Result<Self, OCRError> {
        Self::with_color_order(scale, mean, std, order, color_order)
    }
59
60    /// Creates a new NormalizeImage instance with the specified parameters including color order.
61    ///
62    /// # Arguments
63    ///
64    /// * `scale` - Optional scaling factor (defaults to 1.0/255.0)
65    /// * `mean` - Optional mean values for each channel (defaults to [0.485, 0.456, 0.406])
66    /// * `std` - Optional standard deviation values for each channel (defaults to [0.229, 0.224, 0.225])
67    /// * `order` - Optional tensor data layout (defaults to CHW)
68    /// * `color_order` - Optional color channel order (defaults to RGB)
69    ///
70    /// # Mean/Std Semantics
71    ///
72    /// `mean` and `std` must be provided in the **output channel order** specified by `color_order`.
73    /// For example, if `color_order` is BGR, pass mean/std as `[B_mean, G_mean, R_mean]`.
74    ///
75    /// **Note:** This function does not validate that mean/std values match the specified
76    /// `color_order`. Ensuring consistency is the caller's responsibility. If you have stats
77    /// expressed in RGB order but need BGR output, prefer using
78    /// [`NormalizeImage::with_color_order_from_rgb_stats`] or
79    /// [`NormalizeImage::imagenet_bgr_from_rgb_stats`] which handle the reordering automatically.
80    ///
81    /// # Returns
82    ///
83    /// A Result containing the new NormalizeImage instance or an OCRError if validation fails.
84    pub fn with_color_order(
85        scale: Option<f32>,
86        mean: Option<Vec<f32>>,
87        std: Option<Vec<f32>>,
88        order: Option<TensorLayout>,
89        color_order: Option<ColorOrder>,
90    ) -> Result<Self, OCRError> {
91        let scale = scale.unwrap_or(1.0 / 255.0);
92        let mean = mean.unwrap_or_else(|| vec![0.485, 0.456, 0.406]);
93        let std = std.unwrap_or_else(|| vec![0.229, 0.224, 0.225]);
94        let order = order.unwrap_or(TensorLayout::CHW);
95        let color_order = color_order.unwrap_or_default();
96
97        if scale <= 0.0 {
98            return Err(OCRError::ConfigError {
99                message: "Scale must be greater than 0".to_string(),
100            });
101        }
102
103        if mean.len() != 3 {
104            return Err(OCRError::ConfigError {
105                message: "Mean must have exactly 3 elements (3-channel normalization)".to_string(),
106            });
107        }
108
109        if std.len() != 3 {
110            return Err(OCRError::ConfigError {
111                message: "Std must have exactly 3 elements (3-channel normalization)".to_string(),
112            });
113        }
114
115        for (i, &s) in std.iter().enumerate() {
116            if s <= 0.0 {
117                return Err(OCRError::ConfigError {
118                    message: format!(
119                        "Standard deviation at index {i} must be greater than 0, got {s}"
120                    ),
121                });
122            }
123        }
124
125        let alpha: Vec<f32> = std.iter().map(|s| scale / s).collect();
126        let beta: Vec<f32> = mean.iter().zip(&std).map(|(m, s)| -m / s).collect();
127
128        Ok(Self {
129            alpha,
130            beta,
131            order,
132            color_order,
133        })
134    }
135
136    /// Validates the configuration of the NormalizeImage instance.
137    ///
138    /// # Returns
139    ///
140    /// A Result indicating success or an OCRError if validation fails.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if:
145    /// * Alpha or beta vectors don't have exactly 3 elements
146    /// * Any alpha or beta value is not finite
147    pub fn validate_config(&self) -> Result<(), OCRError> {
148        if self.alpha.len() != 3 || self.beta.len() != 3 {
149            return Err(OCRError::ConfigError {
150                message: "Alpha and beta must have exactly 3 elements (3-channel normalization)"
151                    .to_string(),
152            });
153        }
154
155        for (i, &alpha) in self.alpha.iter().enumerate() {
156            if !alpha.is_finite() {
157                return Err(OCRError::ConfigError {
158                    message: format!("Alpha value at index {i} is not finite: {alpha}"),
159                });
160            }
161        }
162
163        for (i, &beta) in self.beta.iter().enumerate() {
164            if !beta.is_finite() {
165                return Err(OCRError::ConfigError {
166                    message: format!("Beta value at index {i} is not finite: {beta}"),
167                });
168            }
169        }
170
171        Ok(())
172    }
173
174    /// Creates a NormalizeImage instance with parameters suitable for OCR recognition.
175    ///
176    /// This creates a normalization configuration with:
177    /// * Scale: 2.0/255.0
178    /// * Mean: [1.0, 1.0, 1.0]
179    /// * Std: [1.0, 1.0, 1.0]
180    /// * Order: CHW
181    ///
182    /// # Returns
183    ///
184    /// A Result containing the new NormalizeImage instance or an OCRError.
185    pub fn for_ocr_recognition() -> Result<Self, OCRError> {
186        Self::new(
187            Some(2.0 / 255.0),
188            Some(vec![1.0, 1.0, 1.0]),
189            Some(vec![1.0, 1.0, 1.0]),
190            Some(TensorLayout::CHW),
191            Some(ColorOrder::BGR),
192        )
193    }
194
195    /// Creates an ImageNet-style RGB normalizer (mean/std in RGB order).
196    pub fn imagenet_rgb() -> Result<Self, OCRError> {
197        Self::with_color_order(
198            None,
199            Some(vec![0.485, 0.456, 0.406]),
200            Some(vec![0.229, 0.224, 0.225]),
201            Some(TensorLayout::CHW),
202            Some(ColorOrder::RGB),
203        )
204    }
205
206    /// Creates an ImageNet-style BGR normalizer from RGB stats.
207    ///
208    /// This is useful for PaddlePaddle-exported models that expect BGR input,
209    /// while configuration commonly provides ImageNet mean/std in RGB order.
210    pub fn imagenet_bgr_from_rgb_stats() -> Result<Self, OCRError> {
211        Self::with_color_order(
212            None,
213            Some(vec![0.406, 0.456, 0.485]),
214            Some(vec![0.225, 0.224, 0.229]),
215            Some(TensorLayout::CHW),
216            Some(ColorOrder::BGR),
217        )
218    }
219
220    /// Builds a normalizer for a given output `color_order` using RGB mean/std stats.
221    ///
222    /// Invariant: `mean`/`std` passed to `with_color_order` are interpreted in the output channel
223    /// order (`ColorOrder`). This helper makes the conversion explicit at call sites.
224    pub fn with_color_order_from_rgb_stats(
225        scale: Option<f32>,
226        mean_rgb: Vec<f32>,
227        std_rgb: Vec<f32>,
228        order: Option<TensorLayout>,
229        output_color_order: ColorOrder,
230    ) -> Result<Self, OCRError> {
231        if mean_rgb.len() != 3 || std_rgb.len() != 3 {
232            return Err(OCRError::ConfigError {
233                message: format!(
234                    "mean/std must have exactly 3 elements (got mean={}, std={})",
235                    mean_rgb.len(),
236                    std_rgb.len()
237                ),
238            });
239        }
240
241        let (mean, std) = match output_color_order {
242            ColorOrder::RGB => (mean_rgb, std_rgb),
243            ColorOrder::BGR => (
244                vec![mean_rgb[2], mean_rgb[1], mean_rgb[0]],
245                vec![std_rgb[2], std_rgb[1], std_rgb[0]],
246            ),
247        };
248
249        Self::with_color_order(
250            scale,
251            Some(mean),
252            Some(std),
253            order,
254            Some(output_color_order),
255        )
256    }
257
258    /// Applies normalization to a vector of images.
259    ///
260    /// # Arguments
261    ///
262    /// * `imgs` - A vector of DynamicImage instances to normalize
263    ///
264    /// # Returns
265    ///
266    /// A vector of normalized images represented as vectors of f32 values
267    pub fn apply(&self, imgs: Vec<DynamicImage>) -> Vec<Vec<f32>> {
268        imgs.into_iter().map(|img| self.normalize(img)).collect()
269    }
270
271    /// Validates inputs for batch processing operations.
272    ///
273    /// # Arguments
274    ///
275    /// * `imgs_len` - Number of images in the batch
276    /// * `shapes` - Shapes of the images as (channels, height, width) tuples
277    /// * `batch_tensor` - The batch tensor to validate against
278    ///
279    /// # Returns
280    ///
281    /// A Result containing a tuple of (batch_size, channels, height, max_width) or an OCRError.
282    fn validate_batch_inputs(
283        &self,
284        imgs_len: usize,
285        shapes: &[(usize, usize, usize)],
286        batch_tensor: &[f32],
287    ) -> Result<(usize, usize, usize, usize), OCRError> {
288        if imgs_len != shapes.len() {
289            return Err(OCRError::InvalidInput {
290                message: format!(
291                    "Images and shapes length mismatch: {} images vs {} shapes",
292                    imgs_len,
293                    shapes.len()
294                ),
295            });
296        }
297
298        let batch_size = imgs_len;
299        if batch_size == 0 {
300            return Ok((0, 0, 0, 0));
301        }
302
303        let max_width = shapes.iter().map(|(_, _, w)| *w).max().unwrap_or(0);
304        let channels = shapes.first().map(|(c, _, _)| *c).unwrap_or(0);
305        let height = shapes.first().map(|(_, h, _)| *h).unwrap_or(0);
306        let img_size = channels * height * max_width;
307
308        if batch_tensor.len() < batch_size * img_size {
309            return Err(OCRError::BufferTooSmall {
310                expected: batch_size * img_size,
311                actual: batch_tensor.len(),
312            });
313        }
314
315        Ok((batch_size, channels, height, max_width))
316    }
317
318    /// Applies normalization to a batch of images and stores the result in a pre-allocated tensor.
319    ///
320    /// # Arguments
321    ///
322    /// * `imgs` - A vector of DynamicImage instances to normalize
323    /// * `batch_tensor` - A mutable slice where the normalized batch will be stored
324    /// * `shapes` - Shapes of the images as (channels, height, width) tuples
325    ///
326    /// # Returns
327    ///
328    /// A Result indicating success or an OCRError if validation fails.
329    pub fn apply_to_batch(
330        &self,
331        imgs: Vec<DynamicImage>,
332        batch_tensor: &mut [f32],
333        shapes: &[(usize, usize, usize)],
334    ) -> Result<(), OCRError> {
335        let (batch_size, channels, height, max_width) =
336            self.validate_batch_inputs(imgs.len(), shapes, batch_tensor)?;
337
338        if batch_size == 0 {
339            return Ok(());
340        }
341
342        let img_size = channels * height * max_width;
343
344        for (batch_idx, (img, &(_c, h, w))) in imgs.into_iter().zip(shapes.iter()).enumerate() {
345            let normalized_img = self.normalize(img);
346
347            let batch_offset = batch_idx * img_size;
348
349            for ch in 0.._c {
350                for y in 0..h {
351                    for x in 0..w {
352                        let src_idx = ch * h * w + y * w + x;
353                        let dst_idx = batch_offset + ch * height * max_width + y * max_width + x;
354                        if src_idx < normalized_img.len() && dst_idx < batch_tensor.len() {
355                            batch_tensor[dst_idx] = normalized_img[src_idx];
356                        }
357                    }
358                }
359            }
360        }
361
362        Ok(())
363    }
364
365    /// Applies normalization to a batch of images and stores the result in a pre-allocated tensor,
366    /// processing images in a streaming fashion.
367    ///
368    /// # Arguments
369    ///
370    /// * `imgs` - A vector of DynamicImage instances to normalize
371    /// * `batch_tensor` - A mutable slice where the normalized batch will be stored
372    /// * `shapes` - Shapes of the images as (channels, height, width) tuples
373    ///
374    /// # Returns
375    ///
376    /// A Result indicating success or an OCRError if validation fails.
377    pub fn normalize_streaming_to_batch(
378        &self,
379        imgs: Vec<DynamicImage>,
380        batch_tensor: &mut [f32],
381        shapes: &[(usize, usize, usize)],
382    ) -> Result<(), OCRError> {
383        let (batch_size, channels, height, max_width) =
384            self.validate_batch_inputs(imgs.len(), shapes, batch_tensor)?;
385
386        if batch_size == 0 {
387            return Ok(());
388        }
389
390        let img_size = channels * height * max_width;
391        batch_tensor.fill(0.0);
392
393        // Pre-compute channel mapping for BGR support
394        let src_channels: [usize; 3] = match self.color_order {
395            ColorOrder::RGB => [0, 1, 2],
396            ColorOrder::BGR => [2, 1, 0],
397        };
398
399        for (batch_idx, (img, &(_c, h, w))) in imgs.into_iter().zip(shapes.iter()).enumerate() {
400            let rgb_img = img.to_rgb8();
401            let (width, height_img) = rgb_img.dimensions();
402            let batch_offset = batch_idx * img_size;
403
404            match self.order {
405                TensorLayout::CHW => {
406                    for (c, &src_c) in src_channels.iter().enumerate().take(channels.min(3)) {
407                        for y in 0..h.min(height_img as usize) {
408                            for x in 0..w.min(width as usize) {
409                                let pixel = rgb_img.get_pixel(x as u32, y as u32);
410                                let channel_value = pixel[src_c] as f32;
411                                let dst_idx =
412                                    batch_offset + c * height * max_width + y * max_width + x;
413                                if dst_idx < batch_tensor.len() {
414                                    batch_tensor[dst_idx] =
415                                        channel_value * self.alpha[c] + self.beta[c];
416                                }
417                            }
418                        }
419                    }
420                }
421                TensorLayout::HWC => {
422                    for y in 0..h.min(height_img as usize) {
423                        for x in 0..w.min(width as usize) {
424                            let pixel = rgb_img.get_pixel(x as u32, y as u32);
425                            for (c, &src_c) in src_channels.iter().enumerate().take(channels.min(3))
426                            {
427                                let channel_value = pixel[src_c] as f32;
428                                let dst_idx =
429                                    batch_offset + y * max_width * channels + x * channels + c;
430                                if dst_idx < batch_tensor.len() {
431                                    batch_tensor[dst_idx] =
432                                        channel_value * self.alpha[c] + self.beta[c];
433                                }
434                            }
435                        }
436                    }
437                }
438            }
439        }
440
441        Ok(())
442    }
443
444    /// Normalizes a single image.
445    ///
446    /// # Arguments
447    ///
448    /// * `img` - The DynamicImage to normalize
449    ///
450    /// # Returns
451    ///
452    /// A vector of normalized pixel values as f32
453    fn normalize(&self, img: DynamicImage) -> Vec<f32> {
454        let rgb_img = img.to_rgb8();
455        let (width, height) = rgb_img.dimensions();
456        let channels = 3;
457
458        // Map channel index based on color order
459        // RGB: c=0->R, c=1->G, c=2->B (same as pixel layout)
460        // BGR: c=0->B, c=1->G, c=2->R (swap R and B)
461        let map_channel = |c: u32| -> usize {
462            match self.color_order {
463                ColorOrder::RGB => c as usize,
464                ColorOrder::BGR => match c {
465                    0 => 2, // B -> pixel[2]
466                    1 => 1, // G -> pixel[1]
467                    2 => 0, // R -> pixel[0]
468                    _ => c as usize,
469                },
470            }
471        };
472
473        match self.order {
474            TensorLayout::CHW => {
475                let mut result = vec![0.0f32; (channels * height * width) as usize];
476
477                for c in 0..channels {
478                    let src_c = map_channel(c);
479                    for y in 0..height {
480                        for x in 0..width {
481                            let pixel = rgb_img.get_pixel(x, y);
482                            let channel_value = pixel[src_c] as f32;
483                            let dst_idx = (c * height * width + y * width + x) as usize;
484
485                            result[dst_idx] =
486                                channel_value * self.alpha[c as usize] + self.beta[c as usize];
487                        }
488                    }
489                }
490                result
491            }
492            TensorLayout::HWC => {
493                let mut result = vec![0.0f32; (height * width * channels) as usize];
494
495                for y in 0..height {
496                    for x in 0..width {
497                        let pixel = rgb_img.get_pixel(x, y);
498                        for c in 0..channels {
499                            let src_c = map_channel(c);
500                            let channel_value = pixel[src_c] as f32;
501                            let dst_idx = (y * width * channels + x * channels + c) as usize;
502
503                            result[dst_idx] =
504                                channel_value * self.alpha[c as usize] + self.beta[c as usize];
505                        }
506                    }
507                }
508                result
509            }
510        }
511    }
512
513    /// Normalizes a single image and returns it as a 4D tensor.
514    ///
515    /// # Arguments
516    ///
517    /// * `img` - The DynamicImage to normalize
518    ///
519    /// # Returns
520    ///
521    /// A Result containing the normalized image as a 4D tensor or an OCRError.
522    pub fn normalize_to(&self, img: DynamicImage) -> Result<crate::core::Tensor4D, OCRError> {
523        let rgb_img = img.to_rgb8();
524        let (width, height) = rgb_img.dimensions();
525        let channels = 3;
526
527        // Map channel index based on color order
528        let map_channel = |c: u32| -> usize {
529            match self.color_order {
530                ColorOrder::RGB => c as usize,
531                ColorOrder::BGR => match c {
532                    0 => 2,
533                    1 => 1,
534                    2 => 0,
535                    _ => c as usize,
536                },
537            }
538        };
539
540        match self.order {
541            TensorLayout::CHW => {
542                let mut result = vec![0.0f32; (channels * height * width) as usize];
543
544                for c in 0..channels {
545                    let src_c = map_channel(c);
546                    for y in 0..height {
547                        for x in 0..width {
548                            let pixel = rgb_img.get_pixel(x, y);
549                            let channel_value = pixel[src_c] as f32;
550                            let dst_idx = (c * height * width + y * width + x) as usize;
551
552                            result[dst_idx] =
553                                channel_value * self.alpha[c as usize] + self.beta[c as usize];
554                        }
555                    }
556                }
557
558                ndarray::Array4::from_shape_vec(
559                    (1, channels as usize, height as usize, width as usize),
560                    result,
561                )
562                .map_err(|e| {
563                    OCRError::tensor_operation_error(
564                        "normalization_tensor_creation_chw",
565                        &[1, channels as usize, height as usize, width as usize],
566                        &[(channels * height * width) as usize],
567                        &format!("Failed to create CHW normalization tensor for {}x{} image with {} channels",
568                            width, height, channels),
569                        e,
570                    )
571                })
572            }
573            TensorLayout::HWC => {
574                let mut result = vec![0.0f32; (height * width * channels) as usize];
575
576                for y in 0..height {
577                    for x in 0..width {
578                        let pixel = rgb_img.get_pixel(x, y);
579                        for c in 0..channels {
580                            let src_c = map_channel(c);
581                            let channel_value = pixel[src_c] as f32;
582                            let dst_idx = (y * width * channels + x * channels + c) as usize;
583
584                            result[dst_idx] =
585                                channel_value * self.alpha[c as usize] + self.beta[c as usize];
586                        }
587                    }
588                }
589
590                ndarray::Array4::from_shape_vec(
591                    (1, height as usize, width as usize, channels as usize),
592                    result,
593                )
594                .map_err(|e| {
595                    OCRError::tensor_operation_error(
596                        "normalization_tensor_creation_hwc",
597                        &[1, height as usize, width as usize, channels as usize],
598                        &[(height * width * channels) as usize],
599                        &format!("Failed to create HWC normalization tensor for {}x{} image with {} channels",
600                            width, height, channels),
601                        e,
602                    )
603                })
604            }
605        }
606    }
607
608    /// Normalizes a batch of images and returns them as a 4D tensor.
609    ///
610    /// # Arguments
611    ///
612    /// * `imgs` - A vector of DynamicImage instances to normalize
613    ///
614    /// # Returns
615    ///
616    /// A Result containing the normalized batch as a 4D tensor or an OCRError.
617    ///
618    /// # Errors
619    ///
620    /// Returns an error if:
621    /// * Images in the batch don't all have the same dimensions
622    pub fn normalize_batch_to(
623        &self,
624        imgs: Vec<DynamicImage>,
625    ) -> Result<crate::core::Tensor4D, OCRError> {
626        if imgs.is_empty() {
627            return Ok(ndarray::Array4::zeros((0, 0, 0, 0)));
628        }
629
630        let batch_size = imgs.len();
631
632        let rgb_imgs: Vec<_> = imgs.into_iter().map(|img| img.to_rgb8()).collect();
633        let dimensions: Vec<_> = rgb_imgs.iter().map(|img| img.dimensions()).collect();
634
635        let (first_width, first_height) = dimensions.first().copied().unwrap_or((0, 0));
636        for (i, &(width, height)) in dimensions.iter().enumerate() {
637            if width != first_width || height != first_height {
638                return Err(OCRError::InvalidInput {
639                    message: format!(
640                        "All images in batch must have the same dimensions. Image 0: {first_width}x{first_height}, Image {i}: {width}x{height}"
641                    ),
642                });
643            }
644        }
645
646        let (width, height) = (first_width, first_height);
647        let channels = 3u32;
648
649        // Pre-compute channel mapping for BGR support
650        // src_channels[c] gives the source pixel index for output channel c
651        let src_channels: [usize; 3] = match self.color_order {
652            ColorOrder::RGB => [0, 1, 2],
653            ColorOrder::BGR => [2, 1, 0], // B from pixel[2], G from pixel[1], R from pixel[0]
654        };
655
656        // Clone alpha/beta for parallel closure
657        let alpha = self.alpha.clone();
658        let beta = self.beta.clone();
659
660        match self.order {
661            TensorLayout::CHW => {
662                let mut result = vec![0.0f32; batch_size * (channels * height * width) as usize];
663
664                let img_size = (channels * height * width) as usize;
665                if batch_size == 1 {
666                    // Avoid rayon overhead for single-image batches
667                    let rgb_img = &rgb_imgs[0];
668                    let batch_slice = &mut result[0..img_size];
669                    for c in 0..channels {
670                        let src_c = src_channels[c as usize];
671                        for y in 0..height {
672                            for x in 0..width {
673                                let pixel = rgb_img.get_pixel(x, y);
674                                let channel_value = pixel[src_c] as f32;
675                                let dst_idx = (c * height * width + y * width + x) as usize;
676                                batch_slice[dst_idx] =
677                                    channel_value * alpha[c as usize] + beta[c as usize];
678                            }
679                        }
680                    }
681                } else {
682                    result.par_chunks_mut(img_size).enumerate().for_each(
683                        |(batch_idx, batch_slice)| {
684                            let rgb_img = &rgb_imgs[batch_idx];
685                            for c in 0..channels {
686                                let src_c = src_channels[c as usize];
687                                for y in 0..height {
688                                    for x in 0..width {
689                                        let pixel = rgb_img.get_pixel(x, y);
690                                        let channel_value = pixel[src_c] as f32;
691                                        let dst_idx = (c * height * width + y * width + x) as usize;
692                                        batch_slice[dst_idx] =
693                                            channel_value * alpha[c as usize] + beta[c as usize];
694                                    }
695                                }
696                            }
697                        },
698                    );
699                }
700
701                ndarray::Array4::from_shape_vec(
702                    (
703                        batch_size,
704                        channels as usize,
705                        height as usize,
706                        width as usize,
707                    ),
708                    result,
709                )
710                .map_err(|e| {
711                    OCRError::tensor_operation(
712                        "Failed to create batch normalization tensor in CHW format",
713                        e,
714                    )
715                })
716            }
717            TensorLayout::HWC => {
718                let mut result = vec![0.0f32; batch_size * (height * width * channels) as usize];
719
720                let img_size = (height * width * channels) as usize;
721                if batch_size == 1 {
722                    // Avoid rayon overhead for single-image batches
723                    let rgb_img = &rgb_imgs[0];
724                    let batch_slice = &mut result[0..img_size];
725                    for y in 0..height {
726                        for x in 0..width {
727                            let pixel = rgb_img.get_pixel(x, y);
728                            for c in 0..channels {
729                                let src_c = src_channels[c as usize];
730                                let channel_value = pixel[src_c] as f32;
731                                let dst_idx = (y * width * channels + x * channels + c) as usize;
732                                batch_slice[dst_idx] =
733                                    channel_value * alpha[c as usize] + beta[c as usize];
734                            }
735                        }
736                    }
737                } else {
738                    result.par_chunks_mut(img_size).enumerate().for_each(
739                        |(batch_idx, batch_slice)| {
740                            let rgb_img = &rgb_imgs[batch_idx];
741                            for y in 0..height {
742                                for x in 0..width {
743                                    let pixel = rgb_img.get_pixel(x, y);
744                                    for c in 0..channels {
745                                        let src_c = src_channels[c as usize];
746                                        let channel_value = pixel[src_c] as f32;
747                                        let dst_idx =
748                                            (y * width * channels + x * channels + c) as usize;
749                                        batch_slice[dst_idx] =
750                                            channel_value * alpha[c as usize] + beta[c as usize];
751                                    }
752                                }
753                            }
754                        },
755                    );
756                }
757
758                ndarray::Array4::from_shape_vec(
759                    (
760                        batch_size,
761                        height as usize,
762                        width as usize,
763                        channels as usize,
764                    ),
765                    result,
766                )
767                .map_err(|e| {
768                    OCRError::tensor_operation(
769                        "Failed to create batch normalization tensor in HWC format",
770                        e,
771                    )
772                })
773            }
774        }
775    }
776}
777
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Rgb, RgbImage};

    // These tests use a single 1x1 pixel so expected outputs can be stated exactly:
    // with scale = 1 the normalized value is (pixel - mean) / std, applied in the
    // normalizer's *output* channel order.

    #[test]
    fn test_normalize_image_color_order_rgb_vs_bgr_chw() -> Result<(), OCRError> {
        let mut img = RgbImage::new(1, 1);
        img.put_pixel(0, 0, Rgb([10, 20, 30])); // R, G, B

        let rgb = NormalizeImage::with_color_order(
            Some(1.0),
            Some(vec![0.0, 0.0, 0.0]),
            Some(vec![1.0, 1.0, 1.0]),
            Some(TensorLayout::CHW),
            Some(ColorOrder::RGB),
        )?;
        let bgr = NormalizeImage::with_color_order(
            Some(1.0),
            Some(vec![0.0, 0.0, 0.0]),
            Some(vec![1.0, 1.0, 1.0]),
            Some(TensorLayout::CHW),
            Some(ColorOrder::BGR),
        )?;

        let rgb_out = rgb.apply(vec![DynamicImage::ImageRgb8(img.clone())]);
        let bgr_out = bgr.apply(vec![DynamicImage::ImageRgb8(img)]);

        assert_eq!(rgb_out.len(), 1);
        assert_eq!(bgr_out.len(), 1);
        // Identity normalization: output is the raw pixel in the requested order.
        assert_eq!(rgb_out[0], vec![10.0, 20.0, 30.0]);
        assert_eq!(bgr_out[0], vec![30.0, 20.0, 10.0]);
        Ok(())
    }

    #[test]
    fn test_normalize_image_mean_std_applied_in_output_channel_order() -> Result<(), OCRError> {
        let mut img = RgbImage::new(1, 1);
        img.put_pixel(0, 0, Rgb([11, 22, 33])); // R, G, B

        let rgb = NormalizeImage::with_color_order(
            Some(1.0),
            Some(vec![1.0, 2.0, 3.0]), // RGB means
            Some(vec![2.0, 4.0, 5.0]), // RGB stds
            Some(TensorLayout::CHW),
            Some(ColorOrder::RGB),
        )?;
        let bgr = NormalizeImage::with_color_order(
            Some(1.0),
            Some(vec![3.0, 2.0, 1.0]), // BGR means
            Some(vec![5.0, 4.0, 2.0]), // BGR stds
            Some(TensorLayout::CHW),
            Some(ColorOrder::BGR),
        )?;

        let rgb_out = rgb.apply(vec![DynamicImage::ImageRgb8(img.clone())]);
        let bgr_out = bgr.apply(vec![DynamicImage::ImageRgb8(img)]);

        assert_eq!(rgb_out[0], vec![5.0, 5.0, 6.0]); // (R-1)/2, (G-2)/4, (B-3)/5
        assert_eq!(bgr_out[0], vec![6.0, 5.0, 5.0]); // (B-3)/5, (G-2)/4, (R-1)/2
        Ok(())
    }
}