use crate::handle::LcgRng;
/// Draws a multiplicative jitter factor for a magnitude `mag`.
///
/// The factor is `low + t * (high - low)` where `low = max(0, 1 - mag)`,
/// `high = 1 + mag` and `t = rng.next_f32()`. It therefore lies in
/// `[low, high]` whenever `t` lies in `[0, 1]` (presumably the range of
/// `LcgRng::next_f32` — confirm). A zero magnitude yields `1.0` for any
/// finite draw, but still consumes one value from `rng`.
#[inline]
fn sample_factor(mag: f32, rng: &mut LcgRng) -> f32 {
    let (low, high) = ((1.0 - mag).max(0.0), 1.0 + mag);
    let t = rng.next_f32();
    let span = high - low;
    low + t * span
}
/// Applies random brightness, contrast and (for 3-channel input) saturation
/// jitter to a planar image and returns the result as a new buffer.
///
/// `img` is treated as `channels` consecutive planes of `h * w` values each.
/// The saturation step reads the R, G and B planes at offsets `0`, `h * w` and
/// `2 * h * w`.
///
/// Each magnitude `m` selects a factor drawn by `sample_factor` from
/// `[max(0, 1 - m), 1 + m]` (assuming `LcgRng::next_f32` returns values in
/// `[0, 1]`). The steps run in this order:
///
/// 1. brightness: every value is multiplied by the brightness factor;
/// 2. contrast: every value `v` becomes `mean + fc * (v - mean)`, where `mean`
///    is the mean of the whole (brightness-adjusted) buffer;
/// 3. saturation, only when `channels == 3`: each pixel is blended with its
///    luma `0.299 R + 0.587 G + 0.114 B` as `(1 - fs) * luma + fs * channel`.
///
/// One factor is drawn from `rng` per applied step, in the order above. That
/// is two draws for non-RGB input and three for RGB input. No clamping is
/// performed, so output values may leave the input's range.
///
/// # Panics
///
/// Panics if `channels == 3` and `img.len() < 3 * h * w`.
pub fn color_jitter(
    img: &[f32],
    channels: usize,
    h: usize,
    w: usize,
    brightness: f32,
    contrast: f32,
    saturation: f32,
    rng: &mut LcgRng,
) -> Vec<f32> {
    let mut out: Vec<f32> = img.to_vec();

    // Brightness: uniform scaling of every value.
    let fb = sample_factor(brightness, rng);
    for v in &mut out {
        *v *= fb;
    }

    // Contrast. The mean is taken over exactly the values being adjusted
    // (`out.len()`), not over the nominal `channels * h * w`. It is
    // accumulated in f64 so rounding error does not grow with image size.
    let fc = sample_factor(contrast, rng);
    let mean = if out.is_empty() {
        0.0
    } else {
        let total: f64 = out.iter().map(|&v| f64::from(v)).sum();
        (total / out.len() as f64) as f32
    };
    for v in &mut out {
        *v = mean + fc * (*v - mean);
    }

    // Saturation: only defined for 3-channel (RGB) planar input.
    if channels == 3 {
        let fs = sample_factor(saturation, rng);
        // Loop-invariant blend weight for the grayscale component.
        let one_minus_fs = 1.0 - fs;
        let hw = h * w;
        let (r_plane, rest) = out[..3 * hw].split_at_mut(hw);
        let (g_plane, b_plane) = rest.split_at_mut(hw);
        for ((r, g), b) in r_plane.iter_mut().zip(g_plane).zip(b_plane) {
            // Luma uses the pre-saturation values of all three channels.
            let gray = 0.299 * *r + 0.587 * *g + 0.114 * *b;
            *r = one_minus_fs * gray + fs * *r;
            *g = one_minus_fs * gray + fs * *g;
            *b = one_minus_fs * gray + fs * *b;
        }
    }

    out
}
/// With probability `prob`, converts a planar RGB image to grayscale.
///
/// Conversion replaces each of the R, G and B planes with the per-pixel luma
/// `0.299 R + 0.587 G + 0.114 B`. The planes are read at offsets `0`, `h * w`
/// and `2 * h * w`. Input with `channels != 3` is returned unchanged.
///
/// `rng` is consulted (one draw) only when `channels == 3` and
/// `0 < prob < 1`. `prob <= 0` never converts and `prob >= 1` always converts.
/// A NaN `prob` consumes one draw and never converts.
///
/// The returned buffer always has the same length as `img`. Any values beyond
/// the first `3 * h * w` are copied through untouched.
///
/// # Panics
///
/// Panics if the image is converted and `img.len() < 3 * h * w`.
pub fn random_grayscale(
    img: &[f32],
    channels: usize,
    h: usize,
    w: usize,
    prob: f32,
    rng: &mut LcgRng,
) -> Vec<f32> {
    let mut out = img.to_vec();
    if channels != 3 {
        return out;
    }

    // The deterministic extremes skip the RNG, so they never advance the
    // random stream.
    let do_gray = if prob <= 0.0 {
        false
    } else if prob >= 1.0 {
        true
    } else {
        rng.next_f32() < prob
    };
    if !do_gray {
        return out;
    }

    let hw = h * w;
    let (r_plane, rest) = out[..3 * hw].split_at_mut(hw);
    let (g_plane, b_plane) = rest.split_at_mut(hw);
    for ((r, g), b) in r_plane.iter_mut().zip(g_plane).zip(b_plane) {
        let y = 0.299 * *r + 0.587 * *g + 0.114 * *b;
        *r = y;
        *g = y;
        *b = y;
    }
    out
}
#[cfg(test)]
mod tests {
    //! Unit tests for `sample_factor`, `color_jitter` and `random_grayscale`.
    use super::*;
    use crate::handle::LcgRng;

    /// Builds a planar (channel-major) RGB image of `h * w` pixels, all with
    /// colour `(r, g, b)`.
    fn const_rgb_image(r: f32, g: f32, b: f32, h: usize, w: usize) -> Vec<f32> {
        let hw = h * w;
        let mut img = vec![0.0f32; 3 * hw];
        for i in 0..hw {
            img[i] = r;
            img[hw + i] = g;
            img[2 * hw + i] = b;
        }
        img
    }

    #[test]
    fn sample_factor_stays_within_range() {
        // Assumes `LcgRng::next_f32` yields values in [0, 1]. The magnitudes
        // are chosen so that both interval endpoints are exactly representable.
        let mut rng = LcgRng::new(21);
        for &mag in &[0.0f32, 0.25, 0.5, 1.0, 2.0] {
            let lo = (1.0 - mag).max(0.0);
            let hi = 1.0 + mag;
            for _ in 0..100 {
                let f = sample_factor(mag, &mut rng);
                assert!(
                    (lo..=hi).contains(&f),
                    "mag {mag}: factor {f} outside [{lo}, {hi}]"
                );
            }
        }
    }

    #[test]
    fn color_jitter_output_finite() {
        let img = const_rgb_image(0.5, 0.4, 0.3, 8, 8);
        let mut rng = LcgRng::new(42);
        let out = color_jitter(&img, 3, 8, 8, 0.2, 0.2, 0.2, &mut rng);
        assert!(
            out.iter().all(|v| v.is_finite()),
            "color_jitter produced non-finite values"
        );
    }

    #[test]
    fn color_jitter_zero_magnitude_preserves_values() {
        let img = const_rgb_image(0.5, 0.5, 0.5, 4, 4);
        let mut rng = LcgRng::new(1);
        let out = color_jitter(&img, 3, 4, 4, 0.0, 0.0, 0.0, &mut rng);
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!((a - b).abs() < 1e-6, "pixel {i}: expected {a}, got {b}");
        }
    }

    #[test]
    fn color_jitter_output_shape_preserved() {
        let img: Vec<f32> = (0..3 * 16 * 16).map(|i| i as f32 / 100.0).collect();
        let mut rng = LcgRng::new(99);
        let out = color_jitter(&img, 3, 16, 16, 0.4, 0.4, 0.4, &mut rng);
        assert_eq!(
            out.len(),
            img.len(),
            "color_jitter must preserve buffer length"
        );
    }

    #[test]
    fn color_jitter_single_channel_skips_saturation() {
        let img = vec![0.5f32; 8 * 8];
        let mut rng = LcgRng::new(7);
        let out = color_jitter(&img, 1, 8, 8, 0.1, 0.1, 0.1, &mut rng);
        assert_eq!(out.len(), img.len());
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn color_jitter_brightness_scales_uniformly() {
        let img = vec![0.5f32; 3 * 4 * 4];
        let mut rng = LcgRng::new(3);
        let out = color_jitter(&img, 3, 4, 4, 0.5, 0.0, 0.0, &mut rng);
        let first = out[0];
        assert!(
            out.iter().all(|&v| (v - first).abs() < 1e-6),
            "brightness jitter should preserve uniformity of constant image"
        );
    }

    #[test]
    fn color_jitter_contrast_constant_image_unchanged() {
        let img = vec![0.8f32; 3 * 4 * 4];
        let mut rng = LcgRng::new(5);
        let out = color_jitter(&img, 3, 4, 4, 0.0, 0.8, 0.0, &mut rng);
        for (i, (&a, &b)) in img.iter().zip(out.iter()).enumerate() {
            assert!(
                (a - b).abs() < 1e-5,
                "pixel {i}: constant image should be unchanged by contrast jitter"
            );
        }
    }

    #[test]
    fn color_jitter_is_deterministic_for_seed() {
        let img: Vec<f32> = (0..3 * 4 * 4).map(|i| i as f32 / 48.0).collect();
        let a = color_jitter(&img, 3, 4, 4, 0.3, 0.3, 0.3, &mut LcgRng::new(23));
        let b = color_jitter(&img, 3, 4, 4, 0.3, 0.3, 0.3, &mut LcgRng::new(23));
        assert_eq!(a, b, "same seed must give identical output");
    }

    #[test]
    fn empty_image_is_handled() {
        let mut rng = LcgRng::new(22);
        assert!(color_jitter(&[], 3, 0, 0, 0.2, 0.2, 0.2, &mut rng).is_empty());
        assert!(random_grayscale(&[], 3, 0, 0, 1.0, &mut rng).is_empty());
    }

    #[test]
    fn grayscale_outputs_equal_channels() {
        let img = const_rgb_image(0.8, 0.5, 0.2, 8, 8);
        let mut rng = LcgRng::new(10);
        let out = random_grayscale(&img, 3, 8, 8, 1.0, &mut rng);
        let hw = 8 * 8;
        for i in 0..hw {
            let r_out = out[i];
            let g_out = out[hw + i];
            let b_out = out[2 * hw + i];
            assert!(
                (r_out - g_out).abs() < 1e-6 && (g_out - b_out).abs() < 1e-6,
                "pixel {i}: R={r_out}, G={g_out}, B={b_out} not equal after grayscale"
            );
        }
    }

    #[test]
    fn grayscale_prob_zero_returns_unchanged() {
        let img = const_rgb_image(0.9, 0.3, 0.1, 4, 4);
        let mut rng = LcgRng::new(11);
        let out = random_grayscale(&img, 3, 4, 4, 0.0, &mut rng);
        assert_eq!(out, img, "prob=0 should not modify image");
    }

    #[test]
    fn grayscale_non_rgb_returns_clone() {
        let img = vec![0.5f32; 8 * 8];
        let mut rng = LcgRng::new(12);
        let out = random_grayscale(&img, 1, 8, 8, 1.0, &mut rng);
        assert_eq!(out, img, "non-3-channel image should be returned unchanged");
    }

    #[test]
    fn grayscale_output_shape_preserved() {
        let img = const_rgb_image(0.6, 0.4, 0.2, 16, 16);
        let mut rng = LcgRng::new(13);
        let out = random_grayscale(&img, 3, 16, 16, 0.5, &mut rng);
        assert_eq!(
            out.len(),
            img.len(),
            "grayscale output should preserve buffer length"
        );
    }

    #[test]
    fn grayscale_luminance_correct() {
        // A single pure-red pixel: luma is exactly the red weight.
        let img = vec![1.0f32, 0.0, 0.0];
        let mut rng = LcgRng::new(14);
        let out = random_grayscale(&img, 3, 1, 1, 1.0, &mut rng);
        let expected_y = 0.299_f32;
        assert!(
            (out[0] - expected_y).abs() < 1e-5,
            "R channel: expected {expected_y}, got {}",
            out[0]
        );
        assert!(
            (out[1] - expected_y).abs() < 1e-5,
            "G channel: expected {expected_y}, got {}",
            out[1]
        );
        assert!(
            (out[2] - expected_y).abs() < 1e-5,
            "B channel: expected {expected_y}, got {}",
            out[2]
        );
    }

    #[test]
    fn grayscale_output_finite() {
        let mut rng_gen = LcgRng::new(50);
        let mut img = vec![0.0f32; 3 * 32 * 32];
        rng_gen.fill_normal(&mut img);
        let mut rng = LcgRng::new(51);
        let out = random_grayscale(&img, 3, 32, 32, 1.0, &mut rng);
        assert!(
            out.iter().all(|v| v.is_finite()),
            "grayscale output not finite"
        );
    }
}