1use crate::core::OCRError;
8use crate::processors::types::ChannelOrder;
9use image::DynamicImage;
10use rayon::prelude::*;
11
/// Per-channel affine normalizer for RGB images.
///
/// Every output value is produced as `pixel * alpha[c] + beta[c]`, where `c`
/// is the channel index of the pixel component being converted.
#[derive(Debug)]
pub struct NormalizeImage {
    /// Per-channel multiplier; `NormalizeImage::new` sets it to `scale / std[c]`.
    pub alpha: Vec<f32>,
    /// Per-channel offset; `NormalizeImage::new` sets it to `-mean[c] / std[c]`.
    pub beta: Vec<f32>,
    /// Memory layout (`CHW` or `HWC`) used when writing normalized values.
    pub order: ChannelOrder,
}
26
27impl NormalizeImage {
28 pub fn new(
48 scale: Option<f32>,
49 mean: Option<Vec<f32>>,
50 std: Option<Vec<f32>>,
51 order: Option<ChannelOrder>,
52 ) -> Result<Self, OCRError> {
53 let scale = scale.unwrap_or(1.0 / 255.0);
54 let mean = mean.unwrap_or_else(|| vec![0.485, 0.456, 0.406]);
55 let std = std.unwrap_or_else(|| vec![0.229, 0.224, 0.225]);
56 let order = order.unwrap_or(ChannelOrder::CHW);
57
58 if scale <= 0.0 {
59 return Err(OCRError::ConfigError {
60 message: "Scale must be greater than 0".to_string(),
61 });
62 }
63
64 if mean.len() != 3 {
65 return Err(OCRError::ConfigError {
66 message: "Mean must have exactly 3 elements for RGB".to_string(),
67 });
68 }
69
70 if std.len() != 3 {
71 return Err(OCRError::ConfigError {
72 message: "Std must have exactly 3 elements for RGB".to_string(),
73 });
74 }
75
76 for (i, &s) in std.iter().enumerate() {
77 if s <= 0.0 {
78 return Err(OCRError::ConfigError {
79 message: format!(
80 "Standard deviation at index {i} must be greater than 0, got {s}"
81 ),
82 });
83 }
84 }
85
86 let alpha: Vec<f32> = std.iter().map(|s| scale / s).collect();
87 let beta: Vec<f32> = mean.iter().zip(&std).map(|(m, s)| -m / s).collect();
88
89 Ok(Self { alpha, beta, order })
90 }
91
92 pub fn validate_config(&self) -> Result<(), OCRError> {
104 if self.alpha.len() != 3 || self.beta.len() != 3 {
105 return Err(OCRError::ConfigError {
106 message: "Alpha and beta must have exactly 3 elements for RGB".to_string(),
107 });
108 }
109
110 for (i, &alpha) in self.alpha.iter().enumerate() {
111 if !alpha.is_finite() {
112 return Err(OCRError::ConfigError {
113 message: format!("Alpha value at index {i} is not finite: {alpha}"),
114 });
115 }
116 }
117
118 for (i, &beta) in self.beta.iter().enumerate() {
119 if !beta.is_finite() {
120 return Err(OCRError::ConfigError {
121 message: format!("Beta value at index {i} is not finite: {beta}"),
122 });
123 }
124 }
125
126 Ok(())
127 }
128
129 pub fn for_ocr_recognition() -> Result<Self, OCRError> {
141 Self::new(
142 Some(2.0 / 255.0),
143 Some(vec![1.0, 1.0, 1.0]),
144 Some(vec![1.0, 1.0, 1.0]),
145 Some(ChannelOrder::CHW),
146 )
147 }
148
149 pub fn apply(&self, imgs: Vec<DynamicImage>) -> Vec<Vec<f32>> {
159 imgs.into_iter().map(|img| self.normalize(img)).collect()
160 }
161
162 fn validate_batch_inputs(
174 &self,
175 imgs_len: usize,
176 shapes: &[(usize, usize, usize)],
177 batch_tensor: &[f32],
178 ) -> Result<(usize, usize, usize, usize), OCRError> {
179 if imgs_len != shapes.len() {
180 return Err(OCRError::InvalidInput {
181 message: format!(
182 "Images and shapes length mismatch: {} images vs {} shapes",
183 imgs_len,
184 shapes.len()
185 ),
186 });
187 }
188
189 let batch_size = imgs_len;
190 if batch_size == 0 {
191 return Ok((0, 0, 0, 0));
192 }
193
194 let max_width = shapes.iter().map(|(_, _, w)| *w).max().unwrap_or(0);
195 let channels = shapes.first().map(|(c, _, _)| *c).unwrap_or(0);
196 let height = shapes.first().map(|(_, h, _)| *h).unwrap_or(0);
197 let img_size = channels * height * max_width;
198
199 if batch_tensor.len() < batch_size * img_size {
200 return Err(OCRError::BufferTooSmall {
201 expected: batch_size * img_size,
202 actual: batch_tensor.len(),
203 });
204 }
205
206 Ok((batch_size, channels, height, max_width))
207 }
208
209 pub fn apply_to_batch(
221 &self,
222 imgs: Vec<DynamicImage>,
223 batch_tensor: &mut [f32],
224 shapes: &[(usize, usize, usize)],
225 ) -> Result<(), OCRError> {
226 let (batch_size, channels, height, max_width) =
227 self.validate_batch_inputs(imgs.len(), shapes, batch_tensor)?;
228
229 if batch_size == 0 {
230 return Ok(());
231 }
232
233 let img_size = channels * height * max_width;
234
235 for (batch_idx, (img, &(_c, h, w))) in imgs.into_iter().zip(shapes.iter()).enumerate() {
236 let normalized_img = self.normalize(img);
237
238 let batch_offset = batch_idx * img_size;
239
240 for ch in 0.._c {
241 for y in 0..h {
242 for x in 0..w {
243 let src_idx = ch * h * w + y * w + x;
244 let dst_idx = batch_offset + ch * height * max_width + y * max_width + x;
245 if src_idx < normalized_img.len() && dst_idx < batch_tensor.len() {
246 batch_tensor[dst_idx] = normalized_img[src_idx];
247 }
248 }
249 }
250 }
251 }
252
253 Ok(())
254 }
255
256 pub fn normalize_streaming_to_batch(
269 &self,
270 imgs: Vec<DynamicImage>,
271 batch_tensor: &mut [f32],
272 shapes: &[(usize, usize, usize)],
273 ) -> Result<(), OCRError> {
274 let (batch_size, channels, height, max_width) =
275 self.validate_batch_inputs(imgs.len(), shapes, batch_tensor)?;
276
277 if batch_size == 0 {
278 return Ok(());
279 }
280
281 let img_size = channels * height * max_width;
282 batch_tensor.fill(0.0);
283
284 for (batch_idx, (img, &(_c, h, w))) in imgs.into_iter().zip(shapes.iter()).enumerate() {
285 let rgb_img = img.to_rgb8();
286 let (width, height_img) = rgb_img.dimensions();
287 let batch_offset = batch_idx * img_size;
288
289 match self.order {
290 ChannelOrder::CHW => {
291 for c in 0..channels.min(3) {
292 for y in 0..h.min(height_img as usize) {
293 for x in 0..w.min(width as usize) {
294 let pixel = rgb_img.get_pixel(x as u32, y as u32);
295 let channel_value = pixel[c] as f32;
296 let dst_idx =
297 batch_offset + c * height * max_width + y * max_width + x;
298 if dst_idx < batch_tensor.len() {
299 batch_tensor[dst_idx] =
300 channel_value * self.alpha[c] + self.beta[c];
301 }
302 }
303 }
304 }
305 }
306 ChannelOrder::HWC => {
307 for y in 0..h.min(height_img as usize) {
308 for x in 0..w.min(width as usize) {
309 let pixel = rgb_img.get_pixel(x as u32, y as u32);
310 for c in 0..channels.min(3) {
311 let channel_value = pixel[c] as f32;
312 let dst_idx =
313 batch_offset + y * max_width * channels + x * channels + c;
314 if dst_idx < batch_tensor.len() {
315 batch_tensor[dst_idx] =
316 channel_value * self.alpha[c] + self.beta[c];
317 }
318 }
319 }
320 }
321 }
322 }
323 }
324
325 Ok(())
326 }
327
328 fn normalize(&self, img: DynamicImage) -> Vec<f32> {
338 let rgb_img = img.to_rgb8();
339 let (width, height) = rgb_img.dimensions();
340 let channels = 3;
341
342 match self.order {
343 ChannelOrder::CHW => {
344 let mut result = vec![0.0f32; (channels * height * width) as usize];
345
346 for c in 0..channels {
347 for y in 0..height {
348 for x in 0..width {
349 let pixel = rgb_img.get_pixel(x, y);
350 let channel_value = pixel[c as usize] as f32;
351 let dst_idx = (c * height * width + y * width + x) as usize;
352
353 result[dst_idx] =
354 channel_value * self.alpha[c as usize] + self.beta[c as usize];
355 }
356 }
357 }
358 result
359 }
360 ChannelOrder::HWC => {
361 let mut result = vec![0.0f32; (height * width * channels) as usize];
362
363 for y in 0..height {
364 for x in 0..width {
365 let pixel = rgb_img.get_pixel(x, y);
366 for c in 0..channels {
367 let channel_value = pixel[c as usize] as f32;
368 let dst_idx = (y * width * channels + x * channels + c) as usize;
369
370 result[dst_idx] =
371 channel_value * self.alpha[c as usize] + self.beta[c as usize];
372 }
373 }
374 }
375 result
376 }
377 }
378 }
379
380 pub fn normalize_to(&self, img: DynamicImage) -> Result<crate::core::Tensor4D, OCRError> {
390 let rgb_img = img.to_rgb8();
391 let (width, height) = rgb_img.dimensions();
392 let channels = 3;
393
394 match self.order {
395 ChannelOrder::CHW => {
396 let mut result = vec![0.0f32; (channels * height * width) as usize];
397
398 for c in 0..channels {
399 for y in 0..height {
400 for x in 0..width {
401 let pixel = rgb_img.get_pixel(x, y);
402 let channel_value = pixel[c as usize] as f32;
403 let dst_idx = (c * height * width + y * width + x) as usize;
404
405 result[dst_idx] =
406 channel_value * self.alpha[c as usize] + self.beta[c as usize];
407 }
408 }
409 }
410
411 ndarray::Array4::from_shape_vec(
412 (1, channels as usize, height as usize, width as usize),
413 result,
414 )
415 .map_err(|e| {
416 OCRError::tensor_operation_error(
417 "normalization_tensor_creation_chw",
418 &[1, channels as usize, height as usize, width as usize],
419 &[(channels * height * width) as usize],
420 &format!("Failed to create CHW normalization tensor for {}x{} image with {} channels",
421 width, height, channels),
422 e,
423 )
424 })
425 }
426 ChannelOrder::HWC => {
427 let mut result = vec![0.0f32; (height * width * channels) as usize];
428
429 for y in 0..height {
430 for x in 0..width {
431 let pixel = rgb_img.get_pixel(x, y);
432 for c in 0..channels {
433 let channel_value = pixel[c as usize] as f32;
434 let dst_idx = (y * width * channels + x * channels + c) as usize;
435
436 result[dst_idx] =
437 channel_value * self.alpha[c as usize] + self.beta[c as usize];
438 }
439 }
440 }
441
442 ndarray::Array4::from_shape_vec(
443 (1, height as usize, width as usize, channels as usize),
444 result,
445 )
446 .map_err(|e| {
447 OCRError::tensor_operation_error(
448 "normalization_tensor_creation_hwc",
449 &[1, height as usize, width as usize, channels as usize],
450 &[(height * width * channels) as usize],
451 &format!("Failed to create HWC normalization tensor for {}x{} image with {} channels",
452 width, height, channels),
453 e,
454 )
455 })
456 }
457 }
458 }
459
460 pub fn normalize_batch_to(
475 &self,
476 imgs: Vec<DynamicImage>,
477 ) -> Result<crate::core::Tensor4D, OCRError> {
478 if imgs.is_empty() {
479 return Ok(ndarray::Array4::zeros((0, 0, 0, 0)));
480 }
481
482 let batch_size = imgs.len();
483
484 let rgb_imgs: Vec<_> = imgs.into_iter().map(|img| img.to_rgb8()).collect();
485 let dimensions: Vec<_> = rgb_imgs.iter().map(|img| img.dimensions()).collect();
486
487 let (first_width, first_height) = dimensions.first().copied().unwrap_or((0, 0));
488 for (i, &(width, height)) in dimensions.iter().enumerate() {
489 if width != first_width || height != first_height {
490 return Err(OCRError::InvalidInput {
491 message: format!(
492 "All images in batch must have the same dimensions. Image 0: {first_width}x{first_height}, Image {i}: {width}x{height}"
493 ),
494 });
495 }
496 }
497
498 let (width, height) = (first_width, first_height);
499 let channels = 3;
500
501 match self.order {
502 ChannelOrder::CHW => {
503 let mut result = vec![0.0f32; batch_size * (channels * height * width) as usize];
504
505 let img_size = (channels * height * width) as usize;
506 if batch_size == 1 {
507 let rgb_img = &rgb_imgs[0];
509 let batch_slice = &mut result[0..img_size];
510 for c in 0..channels {
511 for y in 0..height {
512 for x in 0..width {
513 let pixel = rgb_img.get_pixel(x, y);
514 let channel_value = pixel[c as usize] as f32;
515 let dst_idx = (c * height * width + y * width + x) as usize;
516 batch_slice[dst_idx] =
517 channel_value * self.alpha[c as usize] + self.beta[c as usize];
518 }
519 }
520 }
521 } else {
522 result.par_chunks_mut(img_size).enumerate().for_each(
523 |(batch_idx, batch_slice)| {
524 let rgb_img = &rgb_imgs[batch_idx];
525 for c in 0..channels {
526 for y in 0..height {
527 for x in 0..width {
528 let pixel = rgb_img.get_pixel(x, y);
529 let channel_value = pixel[c as usize] as f32;
530 let dst_idx = (c * height * width + y * width + x) as usize;
531 batch_slice[dst_idx] = channel_value
532 * self.alpha[c as usize]
533 + self.beta[c as usize];
534 }
535 }
536 }
537 },
538 );
539 }
540
541 ndarray::Array4::from_shape_vec(
542 (
543 batch_size,
544 channels as usize,
545 height as usize,
546 width as usize,
547 ),
548 result,
549 )
550 .map_err(|e| {
551 OCRError::tensor_operation(
552 "Failed to create batch normalization tensor in CHW format",
553 e,
554 )
555 })
556 }
557 ChannelOrder::HWC => {
558 let mut result = vec![0.0f32; batch_size * (height * width * channels) as usize];
559
560 let img_size = (height * width * channels) as usize;
561 if batch_size == 1 {
562 let rgb_img = &rgb_imgs[0];
564 let batch_slice = &mut result[0..img_size];
565 for y in 0..height {
566 for x in 0..width {
567 let pixel = rgb_img.get_pixel(x, y);
568 for c in 0..channels {
569 let channel_value = pixel[c as usize] as f32;
570 let dst_idx = (y * width * channels + x * channels + c) as usize;
571 batch_slice[dst_idx] =
572 channel_value * self.alpha[c as usize] + self.beta[c as usize];
573 }
574 }
575 }
576 } else {
577 result.par_chunks_mut(img_size).enumerate().for_each(
578 |(batch_idx, batch_slice)| {
579 let rgb_img = &rgb_imgs[batch_idx];
580 for y in 0..height {
581 for x in 0..width {
582 let pixel = rgb_img.get_pixel(x, y);
583 for c in 0..channels {
584 let channel_value = pixel[c as usize] as f32;
585 let dst_idx =
586 (y * width * channels + x * channels + c) as usize;
587 batch_slice[dst_idx] = channel_value
588 * self.alpha[c as usize]
589 + self.beta[c as usize];
590 }
591 }
592 }
593 },
594 );
595 }
596
597 ndarray::Array4::from_shape_vec(
598 (
599 batch_size,
600 height as usize,
601 width as usize,
602 channels as usize,
603 ),
604 result,
605 )
606 .map_err(|e| {
607 OCRError::tensor_operation(
608 "Failed to create batch normalization tensor in HWC format",
609 e,
610 )
611 })
612 }
613 }
614 }
615}