use crate::core::{error::BellandeError, tensor::Tensor};
/// A transformation that maps an input [`Tensor`] to a new [`Tensor`].
///
/// The `Send + Sync` bound lets implementors be shared and used across threads.
pub trait Preprocessor: Send + Sync {
/// Returns a new, transformed tensor computed from `tensor`.
///
/// # Errors
///
/// Returns a [`BellandeError`] when the input cannot be processed; the exact
/// conditions are implementation-specific.
fn process(&self, tensor: &Tensor) -> Result<Tensor, BellandeError>;
}
/// Per-channel normalization: every element `x` of channel `c` is mapped to
/// `(x - mean[c]) / std[c]`.
///
/// `mean` and `std` hold one entry per channel and always have the same length.
#[derive(Debug, Clone, PartialEq)]
pub struct Normalize {
    /// Value subtracted from every element of the corresponding channel.
    mean: Vec<f32>,
    /// Divisor applied to the corresponding channel after the mean is subtracted.
    /// A zero entry yields infinities/NaNs for that channel.
    std: Vec<f32>,
}

impl Normalize {
    /// Creates a normalizer from per-channel `mean` and `std` values.
    ///
    /// # Panics
    ///
    /// Panics if `mean` and `std` have different lengths.
    pub fn new(mean: Vec<f32>, std: Vec<f32>) -> Self {
        assert_eq!(
            mean.len(),
            std.len(),
            "Normalize::new: mean and std must have one entry per channel"
        );
        Normalize { mean, std }
    }
}
impl Preprocessor for Normalize {
    /// Normalizes a `[batch, channels, height, width]` tensor channel by channel.
    ///
    /// Every element `x` in channel `c` becomes `(x - mean[c]) / std[c]`. The result is a
    /// new tensor with the input's shape, `requires_grad` flag, device, and dtype; the
    /// input tensor is not modified.
    ///
    /// # Errors
    ///
    /// Returns [`BellandeError::InvalidShape`] if:
    /// - the tensor does not have exactly 4 dimensions,
    /// - its channel dimension differs from the number of configured means/stds, or
    /// - its data length does not match the product of its shape.
    fn process(&self, tensor: &Tensor) -> Result<Tensor, BellandeError> {
        if tensor.shape.len() != 4 {
            return Err(BellandeError::InvalidShape(format!(
                "Normalize expects a 4D tensor [batch, channels, height, width], got shape {:?}",
                tensor.shape
            )));
        }
        let (batch_size, channels, height, width) = (
            tensor.shape[0],
            tensor.shape[1],
            tensor.shape[2],
            tensor.shape[3],
        );

        // The tensor is caller input, so a mismatch is reported as an error, not a panic.
        if channels != self.mean.len() {
            return Err(BellandeError::InvalidShape(format!(
                "Normalize configured for {} channels, but tensor has {} channels",
                self.mean.len(),
                channels
            )));
        }

        // Reject inconsistent tensors up front so the walk below can neither index past
        // the buffer nor silently skip trailing elements. Checked arithmetic keeps an
        // absurd shape from overflowing.
        let expected_len = batch_size
            .checked_mul(channels)
            .and_then(|n| n.checked_mul(height))
            .and_then(|n| n.checked_mul(width));
        if expected_len != Some(tensor.data.len()) {
            return Err(BellandeError::InvalidShape(format!(
                "Normalize: tensor data has {} elements, which does not match shape {:?}",
                tensor.data.len(),
                tensor.shape
            )));
        }

        let mut normalized = tensor.data.clone();
        // A zero-sized plane (or one too large to represent, which the length check above
        // only allows when the data is empty) leaves nothing to normalize; skipping it also
        // avoids `chunks_exact_mut(0)`, which panics.
        if let Some(plane) = height.checked_mul(width).filter(|&p| p > 0) {
            // Data is laid out as [batch][channel][height][width], so the i-th contiguous
            // run of `height * width` values belongs to channel `i % channels`.
            for (i, values) in normalized.chunks_exact_mut(plane).enumerate() {
                let c = i % channels;
                let (mean, std) = (self.mean[c], self.std[c]);
                for v in values {
                    *v = (*v - mean) / std;
                }
            }
        }

        Ok(Tensor::new(
            normalized,
            tensor.shape.clone(),
            tensor.requires_grad,
            tensor.device.clone(),
            tensor.dtype,
        ))
    }
}