use crate::error::{SslError, SslResult};
use crate::handle::LcgRng;
/// Builds a normalized 1-D Gaussian kernel with standard deviation `sigma`.
///
/// The kernel has radius `r = ceil(3 * sigma)` and length `2 * r + 1`. Tap `i` holds
/// `exp(-0.5 * ((i - r) / sigma)^2)`. Taps are computed in `f64`, normalized to sum to one,
/// then narrowed to `f32`.
///
/// # Errors
///
/// Returns [`SslError::InvalidParameter`] (named `"sigma"`) in two cases:
/// - `sigma` is not positive and finite;
/// - `sigma` is so large that the kernel length overflows `usize` or the kernel buffer
///   cannot be reserved.
fn build_gaussian_kernel(sigma: f64) -> SslResult<Vec<f32>> {
    if !sigma.is_finite() || sigma <= 0.0 {
        return Err(SslError::InvalidParameter {
            name: "sigma".into(),
            reason: format!("must be positive and finite, got {sigma}"),
        });
    }
    let too_large = || SslError::InvalidParameter {
        name: "sigma".into(),
        reason: format!("too large to build a Gaussian kernel, got {sigma}"),
    };
    // `as usize` saturates, so an enormous sigma yields usize::MAX here. The checked
    // arithmetic below then rejects it instead of overflowing `2 * r + 1`.
    let r = (3.0 * sigma).ceil() as usize;
    let k = r
        .checked_mul(2)
        .and_then(|n| n.checked_add(1))
        .ok_or_else(too_large)?;
    let mut weights: Vec<f64> = Vec::new();
    // Fallible reservation turns capacity overflow / allocation failure into an error
    // rather than a panic or abort.
    weights.try_reserve_exact(k).map_err(|_| too_large())?;
    weights.extend((0..k).map(|i| {
        // Scale the offset by sigma *before* squaring. The naive form
        // `d * d / (2 * sigma * sigma)` underflows the denominator to 0 for tiny sigma.
        // That makes the centre tap `exp(-0 / 0) = NaN` and poisons the whole kernel.
        // Here the centre is always `exp(0) = 1`, and far taps go cleanly to 0.
        let z = (i as f64 - r as f64) / sigma;
        (-0.5 * z * z).exp()
    }));
    // The centre tap is exactly 1, so `sum >= 1` and the division is well defined.
    let sum: f64 = weights.iter().sum();
    Ok(weights.iter().map(|&w| (w / sum) as f32).collect())
}
/// Convolves every row of a `height x width` row-major plane with `kernel`, writing into `dst`.
///
/// The kernel is centred on index `kernel.len() / 2`. Taps that fall left of column 0 or
/// right of the last column read the nearest edge column instead (replicate padding).
/// Indexing panics if `src` or `dst` is shorter than `height * width`.
fn convolve_horizontal(src: &[f32], dst: &mut [f32], height: usize, width: usize, kernel: &[f32]) {
    let half = kernel.len() / 2;
    for row in 0..height {
        let base = row * width;
        for col in 0..width {
            let last_col = width - 1;
            // Accumulate left-to-right from +0.0, one product per tap.
            dst[base + col] = kernel
                .iter()
                .enumerate()
                .fold(0.0_f32, |sum, (tap, &weight)| {
                    let sample_col = (col + tap).saturating_sub(half).min(last_col);
                    sum + src[base + sample_col] * weight
                });
        }
    }
}
/// Convolves every column of a `height x width` row-major plane with `kernel`, writing into `dst`.
///
/// The kernel is centred on index `kernel.len() / 2`. Taps that fall above row 0 or below
/// the last row read the nearest edge row instead (replicate padding). Indexing panics if
/// `src` or `dst` is shorter than `height * width`.
fn convolve_vertical(src: &[f32], dst: &mut [f32], height: usize, width: usize, kernel: &[f32]) {
    let half = kernel.len() / 2;
    for row in 0..height {
        let out_base = row * width;
        for col in 0..width {
            let last_row = height - 1;
            // Accumulate top-to-bottom from +0.0, one product per tap.
            dst[out_base + col] = kernel
                .iter()
                .enumerate()
                .fold(0.0_f32, |sum, (tap, &weight)| {
                    let sample_row = (row + tap).saturating_sub(half).min(last_row);
                    sum + src[sample_row * width + col] * weight
                });
        }
    }
}
/// Applies a separable Gaussian blur with standard deviation `sigma` to a CHW image.
///
/// `pixels` holds `channels` consecutive planes of `height * width` row-major values. Each
/// plane is blurred independently: first along rows, then along columns. Pixels beyond the
/// border are replicated from the nearest edge.
///
/// # Errors
///
/// - [`SslError::EmptyInput`] if any of `channels`, `height`, `width` is zero.
/// - [`SslError::DimensionMismatch`] if `pixels.len() != channels * height * width`. When
///   that product overflows `usize`, `expected` is reported as `usize::MAX`.
/// - Any error from kernel construction, e.g. [`SslError::InvalidParameter`] when `sigma`
///   is not positive and finite.
pub fn gaussian_blur_chw(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    sigma: f64,
) -> SslResult<Vec<f32>> {
    if channels == 0 || height == 0 || width == 0 {
        return Err(SslError::EmptyInput);
    }
    // Saturate instead of multiplying plainly. A plain product panics on overflow in debug
    // builds and wraps in release, where it could spuriously equal `pixels.len()` and let
    // the slicing below panic. A `&[f32]` can never hold `usize::MAX` elements, so a
    // saturated product always reports a mismatch.
    let expected = channels.saturating_mul(height).saturating_mul(width);
    if pixels.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: pixels.len(),
        });
    }
    let kernel = build_gaussian_kernel(sigma)?;
    // Cannot overflow: height * width <= expected == pixels.len().
    let plane = height * width;
    // Scratch buffer for the horizontal pass, reused for every channel.
    let mut after_h = vec![0.0_f32; plane];
    let mut output = vec![0.0_f32; expected];
    for (src_plane, dst_plane) in pixels
        .chunks_exact(plane)
        .zip(output.chunks_exact_mut(plane))
    {
        convolve_horizontal(src_plane, &mut after_h, height, width, &kernel);
        convolve_vertical(&after_h, dst_plane, height, width, &kernel);
    }
    Ok(output)
}
/// Blurs a CHW image with a Gaussian whose sigma is drawn at random from a range.
///
/// One value `u` is taken from `rng.next_f32()` (presumably in `[0, 1)` — confirm against
/// `LcgRng`). The blur then uses `sigma = sigma_min + u * (sigma_max - sigma_min)` via
/// [`gaussian_blur_chw`]. The parameters are checked before the RNG is touched.
///
/// # Errors
///
/// - [`SslError::InvalidParameter`] if `sigma_min` is not positive and finite, or if
///   `sigma_max` is not finite or not strictly greater than `sigma_min`.
/// - Any error returned by [`gaussian_blur_chw`] (empty or mismatched dimensions, ...).
pub fn random_gaussian_blur_chw(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    sigma_min: f64,
    sigma_max: f64,
    rng: &mut LcgRng,
) -> SslResult<Vec<f32>> {
    let lower_ok = sigma_min > 0.0 && sigma_min.is_finite();
    if !lower_ok {
        return Err(SslError::InvalidParameter {
            name: "sigma_min".into(),
            reason: format!("must be positive and finite, got {sigma_min}"),
        });
    }
    let upper_ok = sigma_max > sigma_min && sigma_max.is_finite();
    if !upper_ok {
        return Err(SslError::InvalidParameter {
            name: "sigma_max".into(),
            reason: format!("must be > sigma_min ({sigma_min}) and finite, got {sigma_max}"),
        });
    }
    let span = sigma_max - sigma_min;
    let fraction = f64::from(rng.next_f32());
    let sigma = sigma_min + fraction * span;
    gaussian_blur_chw(pixels, channels, height, width, sigma)
}
/// Solarizes a CHW image.
///
/// Every value `p` that is below `threshold` is kept. Every other value, including NaN,
/// becomes `1.0 - p`. The layout is not otherwise used, but `pixels.len()` must equal
/// `channels * height * width`.
///
/// # Errors
///
/// - [`SslError::EmptyInput`] if any of `channels`, `height`, `width` is zero.
/// - [`SslError::DimensionMismatch`] if `pixels.len() != channels * height * width`. When
///   that product overflows `usize`, `expected` is reported as `usize::MAX`.
/// - [`SslError::InvalidParameter`] if `threshold` is not finite or lies outside `[0, 1]`.
pub fn solarize(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    threshold: f32,
) -> SslResult<Vec<f32>> {
    if channels == 0 || height == 0 || width == 0 {
        return Err(SslError::EmptyInput);
    }
    // Saturating: a plain product panics on overflow in debug builds. A saturated value
    // can never equal the length of a real slice, so it always reports a mismatch.
    let expected = channels.saturating_mul(height).saturating_mul(width);
    if pixels.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: pixels.len(),
        });
    }
    if !threshold.is_finite() || !(0.0..=1.0).contains(&threshold) {
        return Err(SslError::InvalidParameter {
            name: "threshold".into(),
            reason: format!("must be in [0, 1] and finite, got {threshold}"),
        });
    }
    Ok(pixels
        .iter()
        .map(|&p| if p < threshold { p } else { 1.0 - p })
        .collect())
}
/// Applies [`solarize`] with probability `probability`; otherwise returns an unchanged copy.
///
/// Exactly one `rng.next_f32()` draw is consumed on success. Solarization is applied when
/// that draw is below `probability`. All inputs are validated before the draw, so a given
/// set of arguments either always errors or never does, independent of the RNG.
///
/// # Errors
///
/// - [`SslError::InvalidParameter`] if `probability` is not finite or lies outside `[0, 1]`.
/// - [`SslError::EmptyInput`] if any of `channels`, `height`, `width` is zero.
/// - [`SslError::DimensionMismatch`] if `pixels.len() != channels * height * width`. When
///   that product overflows `usize`, `expected` is reported as `usize::MAX`.
/// - [`SslError::InvalidParameter`] if `probability > 0` and `threshold` is not finite or
///   lies outside `[0, 1]`.
pub fn random_solarize(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    threshold: f32,
    probability: f32,
    rng: &mut LcgRng,
) -> SslResult<Vec<f32>> {
    if !probability.is_finite() || !(0.0..=1.0).contains(&probability) {
        return Err(SslError::InvalidParameter {
            name: "probability".into(),
            reason: format!("must be in [0, 1] and finite, got {probability}"),
        });
    }
    if channels == 0 || height == 0 || width == 0 {
        return Err(SslError::EmptyInput);
    }
    // Saturating so an overflowing product can never spuriously match `pixels.len()`.
    let expected = channels.saturating_mul(height).saturating_mul(width);
    if pixels.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: pixels.len(),
        });
    }
    // Check the threshold whenever solarization can fire. Previously a bad threshold was
    // only reported on draws that selected solarization, so identical calls randomly
    // succeeded or failed. With probability 0 solarization never runs, and the threshold
    // is still ignored as before.
    if probability > 0.0 && (!threshold.is_finite() || !(0.0..=1.0).contains(&threshold)) {
        return Err(SslError::InvalidParameter {
            name: "threshold".into(),
            reason: format!("must be in [0, 1] and finite, got {threshold}"),
        });
    }
    if rng.next_f32() < probability {
        solarize(pixels, channels, height, width, threshold)
    } else {
        Ok(pixels.to_vec())
    }
}
/// Adds random noise scaled by `std_dev` to every value and clamps the result to `[0, 1]`.
///
/// Noise samples come from `rng.next_normal_pair()`, presumably standard-normal pairs
/// (confirm against `LcgRng`). Values are processed two at a time, one pair per two
/// values. For an odd-length input, the final value uses the first element of one extra
/// pair, and the second element is discarded.
///
/// # Errors
///
/// - [`SslError::EmptyInput`] if `pixels` is empty.
/// - [`SslError::InvalidParameter`] if `std_dev` is negative or not finite.
pub fn add_gaussian_noise(pixels: &[f32], std_dev: f32, rng: &mut LcgRng) -> SslResult<Vec<f32>> {
    if pixels.is_empty() {
        return Err(SslError::EmptyInput);
    }
    let std_dev_ok = std_dev.is_finite() && std_dev >= 0.0;
    if !std_dev_ok {
        return Err(SslError::InvalidParameter {
            name: "std_dev".into(),
            reason: format!("must be >= 0 and finite, got {std_dev}"),
        });
    }
    let mut noisy = Vec::with_capacity(pixels.len());
    for pair in pixels.chunks(2) {
        let (first_noise, second_noise) = rng.next_normal_pair();
        noisy.push((pair[0] + first_noise * std_dev).clamp(0.0, 1.0));
        if let Some(&second) = pair.get(1) {
            noisy.push((second + second_noise * std_dev).clamp(0.0, 1.0));
        }
    }
    Ok(noisy)
}
/// Parameters for [`simclr_blur_solar`]: a randomly applied Gaussian blur followed by a
/// randomly applied solarization.
#[derive(Debug, Clone, PartialEq)]
pub struct SimClrBlurSolarConfig {
    /// Lower end of the blur sigma range; must be positive and finite when blur is applied.
    pub blur_sigma_min: f64,
    /// Upper end of the blur sigma range; must be finite and exceed `blur_sigma_min` when
    /// blur is applied.
    pub blur_sigma_max: f64,
    /// Probability of applying the blur; must lie in `[0, 1]`.
    pub blur_prob: f32,
    /// Solarization threshold; must lie in `[0, 1]` when solarization is applied.
    pub solar_threshold: f32,
    /// Probability of applying solarization; must lie in `[0, 1]`.
    pub solar_prob: f32,
}
impl Default for SimClrBlurSolarConfig {
    /// Blur with probability 0.5 and sigma range 0.1..2.0; solarization threshold 0.5 with
    /// probability 0.0 (solarization effectively off).
    fn default() -> Self {
        let (blur_sigma_min, blur_sigma_max) = (0.1, 2.0);
        let (solar_threshold, solar_prob) = (0.5, 0.0);
        Self {
            blur_prob: 0.5,
            blur_sigma_min,
            blur_sigma_max,
            solar_threshold,
            solar_prob,
        }
    }
}
/// SimCLR-style photometric pipeline: an optional Gaussian blur followed by optional
/// solarization.
///
/// The pipeline consumes RNG draws as follows:
/// - The first `rng.next_f32()` draw decides whether to blur. When it is below
///   `config.blur_prob`, the image goes through [`random_gaussian_blur_chw`] with the
///   configured sigma range, which consumes one more draw for sigma.
/// - The next draw decides whether to solarize. When it is below `config.solar_prob`, the
///   (possibly blurred) image goes through [`solarize`] with `config.solar_threshold`.
///
/// Parameters of every stage that can fire are validated before any RNG draw. A given
/// config therefore fails deterministically rather than only on the draws that select the
/// faulty stage.
///
/// # Errors
///
/// - [`SslError::EmptyInput`] if any of `channels`, `height`, `width` is zero.
/// - [`SslError::DimensionMismatch`] if `pixels.len() != channels * height * width`. When
///   that product overflows `usize`, `expected` is reported as `usize::MAX`.
/// - [`SslError::InvalidParameter`] in these cases:
///   - `blur_prob` or `solar_prob` is outside `[0, 1]`;
///   - `blur_prob > 0` and the sigma range is invalid;
///   - `solar_prob > 0` and `solar_threshold` is outside `[0, 1]`.
/// - Any error propagated from the blur or solarize stages.
pub fn simclr_blur_solar(
    pixels: &[f32],
    channels: usize,
    height: usize,
    width: usize,
    config: &SimClrBlurSolarConfig,
    rng: &mut LcgRng,
) -> SslResult<Vec<f32>> {
    if channels == 0 || height == 0 || width == 0 {
        return Err(SslError::EmptyInput);
    }
    // Saturating so an overflowing product can never spuriously match `pixels.len()`.
    let expected = channels.saturating_mul(height).saturating_mul(width);
    if pixels.len() != expected {
        return Err(SslError::DimensionMismatch {
            expected,
            got: pixels.len(),
        });
    }
    if !config.blur_prob.is_finite() || !(0.0..=1.0).contains(&config.blur_prob) {
        return Err(SslError::InvalidParameter {
            name: "blur_prob".into(),
            reason: format!("must be in [0, 1], got {}", config.blur_prob),
        });
    }
    if !config.solar_prob.is_finite() || !(0.0..=1.0).contains(&config.solar_prob) {
        return Err(SslError::InvalidParameter {
            name: "solar_prob".into(),
            reason: format!("must be in [0, 1], got {}", config.solar_prob),
        });
    }
    // Validate stage parameters up front whenever the stage can fire. The conditions
    // mirror the checks in `random_gaussian_blur_chw` and `solarize`, so a config that
    // passes here is never rejected by them. A stage with probability 0 never runs, so
    // its parameters stay unchecked, as before.
    if config.blur_prob > 0.0 {
        if config.blur_sigma_min <= 0.0 || !config.blur_sigma_min.is_finite() {
            return Err(SslError::InvalidParameter {
                name: "blur_sigma_min".into(),
                reason: format!(
                    "must be positive and finite, got {}",
                    config.blur_sigma_min
                ),
            });
        }
        if config.blur_sigma_max <= config.blur_sigma_min || !config.blur_sigma_max.is_finite()
        {
            return Err(SslError::InvalidParameter {
                name: "blur_sigma_max".into(),
                reason: format!(
                    "must be > blur_sigma_min ({}) and finite, got {}",
                    config.blur_sigma_min, config.blur_sigma_max
                ),
            });
        }
    }
    if config.solar_prob > 0.0
        && (!config.solar_threshold.is_finite()
            || !(0.0..=1.0).contains(&config.solar_threshold))
    {
        return Err(SslError::InvalidParameter {
            name: "solar_threshold".into(),
            reason: format!(
                "must be in [0, 1] and finite, got {}",
                config.solar_threshold
            ),
        });
    }
    let after_blur: Vec<f32> = if rng.next_f32() < config.blur_prob {
        random_gaussian_blur_chw(
            pixels,
            channels,
            height,
            width,
            config.blur_sigma_min,
            config.blur_sigma_max,
            rng,
        )?
    } else {
        pixels.to_vec()
    };
    if rng.next_f32() < config.solar_prob {
        solarize(&after_blur, channels, height, width, config.solar_threshold)
    } else {
        Ok(after_blur)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Population variance of a slice; used to check that blurring smooths an image.
    fn variance(v: &[f32]) -> f32 {
        let n = v.len() as f32;
        let mean = v.iter().sum::<f32>() / n;
        v.iter().map(|&x| (x - mean) * (x - mean)).sum::<f32>() / n
    }
    #[test]
    fn gaussian_blur_preserves_mean() {
        let (c, h, w) = (1, 32, 32);
        let pixels = vec![0.5_f32; c * h * w];
        let blurred =
            gaussian_blur_chw(&pixels, c, h, w, 1.0).expect("gaussian_blur_chw should succeed");
        let orig_sum: f32 = pixels.iter().sum();
        let blur_sum: f32 = blurred.iter().sum();
        assert!(
            (blur_sum - orig_sum).abs() < 1e-3,
            "orig_sum={orig_sum}, blur_sum={blur_sum}"
        );
    }
    #[test]
    fn gaussian_blur_shape_preserved() {
        let (c, h, w) = (3, 16, 24);
        let pixels = vec![0.3_f32; c * h * w];
        let out =
            gaussian_blur_chw(&pixels, c, h, w, 0.8).expect("gaussian_blur_chw should succeed");
        assert_eq!(out.len(), c * h * w);
    }
    #[test]
    fn gaussian_blur_identity_sigma_near_zero() {
        let (c, h, w) = (1, 8, 8);
        let mut rng = LcgRng::new(42);
        let pixels: Vec<f32> = (0..c * h * w).map(|_| rng.next_f32()).collect();
        let out =
            gaussian_blur_chw(&pixels, c, h, w, 1e-5).expect("gaussian_blur_chw should succeed");
        for (a, b) in pixels.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-4, "a={a} b={b}");
        }
    }
    #[test]
    fn gaussian_blur_decreases_variance() {
        let (c, h, w) = (1, 16, 16);
        // Checkerboard: maximal high-frequency content.
        let pixels: Vec<f32> = (0..c * h * w)
            .map(|i| {
                if (i + i / w) % 2 == 0 {
                    0.0_f32
                } else {
                    1.0_f32
                }
            })
            .collect();
        let orig_var = variance(&pixels);
        let blurred =
            gaussian_blur_chw(&pixels, c, h, w, 2.0).expect("gaussian_blur_chw should succeed");
        let blur_var = variance(&blurred);
        assert!(
            blur_var < orig_var,
            "expected variance reduction: orig={orig_var}, blurred={blur_var}"
        );
    }
    #[test]
    fn gaussian_blur_negative_sigma_error() {
        let pixels = vec![0.5_f32; 16];
        let result = gaussian_blur_chw(&pixels, 1, 4, 4, -1.0);
        assert!(
            matches!(result, Err(SslError::InvalidParameter { .. })),
            "expected InvalidParameter, got {:?}",
            result
        );
    }
    #[test]
    fn gaussian_blur_zero_sigma_error() {
        let pixels = vec![0.5_f32; 16];
        let result = gaussian_blur_chw(&pixels, 1, 4, 4, 0.0);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }
    #[test]
    fn gaussian_blur_nan_sigma_error() {
        let pixels = vec![0.5_f32; 16];
        let result = gaussian_blur_chw(&pixels, 1, 4, 4, f64::NAN);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }
    #[test]
    fn gaussian_blur_multichannel_independent() {
        let (c, h, w) = (3, 10, 10);
        let mut pixels = vec![0.0_f32; c * h * w];
        let plane = h * w;
        for ch in 0..c {
            for p in pixels[ch * plane..(ch + 1) * plane].iter_mut() {
                *p = (ch as f32 + 1.0) * 0.25;
            }
        }
        let out =
            gaussian_blur_chw(&pixels, c, h, w, 1.5).expect("gaussian_blur_chw should succeed");
        // Constant planes must stay constant; any cross-channel bleed would break this.
        for ch in 0..c {
            let expected_val = (ch as f32 + 1.0) * 0.25;
            for &v in &out[ch * plane..(ch + 1) * plane] {
                assert!(
                    (v - expected_val).abs() < 1e-4,
                    "ch={ch} expected={expected_val} got={v}"
                );
            }
        }
    }
    #[test]
    fn solarize_below_threshold_unchanged() {
        let pixels = vec![0.3_f32; 16];
        let out = solarize(&pixels, 1, 4, 4, 0.5).expect("solarize should succeed");
        for &v in &out {
            assert!((v - 0.3).abs() < 1e-6, "v={v}");
        }
    }
    #[test]
    fn solarize_above_threshold_inverted() {
        let pixels = vec![0.8_f32; 16];
        let out = solarize(&pixels, 1, 4, 4, 0.5).expect("solarize should succeed");
        for &v in &out {
            assert!((v - 0.2).abs() < 1e-6, "v={v}");
        }
    }
    #[test]
    fn solarize_at_threshold_inverted() {
        let threshold = 0.5_f32;
        let pixels = vec![threshold; 16];
        let out = solarize(&pixels, 1, 4, 4, threshold).expect("solarize should succeed");
        for &v in &out {
            assert!((v - (1.0 - threshold)).abs() < 1e-6, "v={v}");
        }
    }
    #[test]
    fn solarize_preserves_shape() {
        let (c, h, w) = (3, 8, 12);
        let pixels = vec![0.6_f32; c * h * w];
        let out = solarize(&pixels, c, h, w, 0.5).expect("solarize should succeed");
        assert_eq!(out.len(), c * h * w);
    }
    #[test]
    fn solarize_invalid_threshold_error() {
        let pixels = vec![0.5_f32; 16];
        let result = solarize(&pixels, 1, 4, 4, 1.5);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }
    #[test]
    fn random_solarize_prob_zero() {
        let pixels: Vec<f32> = (0..64).map(|i| i as f32 / 64.0).collect();
        let mut rng = LcgRng::new(99);
        let out = random_solarize(&pixels, 1, 8, 8, 0.5, 0.0, &mut rng)
            .expect("random_solarize should succeed");
        assert_eq!(out, pixels);
    }
    #[test]
    fn random_solarize_prob_one() {
        let pixels: Vec<f32> = (0..64).map(|i| i as f32 / 64.0).collect();
        let mut rng = LcgRng::new(7);
        let out = random_solarize(&pixels, 1, 8, 8, 0.5, 1.0, &mut rng)
            .expect("random_solarize should succeed");
        let expected = solarize(&pixels, 1, 8, 8, 0.5).expect("solarize should succeed");
        assert_eq!(out, expected);
    }
    #[test]
    fn random_solarize_invalid_probability_error() {
        let pixels = vec![0.5_f32; 16];
        let mut rng = LcgRng::new(5);
        let result = random_solarize(&pixels, 1, 4, 4, 0.5, 1.5, &mut rng);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }
    #[test]
    fn add_gaussian_noise_output_in_range() {
        let pixels = vec![0.5_f32; 3 * 16 * 16];
        let mut rng = LcgRng::new(123);
        let out =
            add_gaussian_noise(&pixels, 0.2, &mut rng).expect("add_gaussian_noise should succeed");
        for &v in &out {
            assert!((0.0..=1.0).contains(&v), "out-of-range pixel: {v}");
        }
    }
    #[test]
    fn add_gaussian_noise_odd_length() {
        // Exercises the trailing single-value path.
        let pixels = vec![0.5_f32; 7];
        let mut rng = LcgRng::new(9);
        let out =
            add_gaussian_noise(&pixels, 0.1, &mut rng).expect("add_gaussian_noise should succeed");
        assert_eq!(out.len(), pixels.len());
        for &v in &out {
            assert!((0.0..=1.0).contains(&v), "out-of-range pixel: {v}");
        }
    }
    #[test]
    fn add_gaussian_noise_nonzero_std_changes_image() {
        let pixels = vec![0.5_f32; 1024];
        let mut rng = LcgRng::new(55);
        let out =
            add_gaussian_noise(&pixels, 0.1, &mut rng).expect("add_gaussian_noise should succeed");
        let changed = pixels
            .iter()
            .zip(out.iter())
            .filter(|(a, b)| *a != *b)
            .count();
        assert!(changed > 0, "no pixels changed with std_dev=0.1");
    }
    #[test]
    fn add_gaussian_noise_zero_std_unchanged() {
        let pixels = vec![0.4_f32; 64];
        let mut rng = LcgRng::new(11);
        let out =
            add_gaussian_noise(&pixels, 0.0, &mut rng).expect("add_gaussian_noise should succeed");
        for (a, b) in pixels.iter().zip(out.iter()) {
            assert!((a - b).abs() < 1e-7, "a={a} b={b}");
        }
    }
    #[test]
    fn add_gaussian_noise_negative_std_error() {
        let pixels = vec![0.5_f32; 16];
        let mut rng = LcgRng::new(1);
        let result = add_gaussian_noise(&pixels, -0.1, &mut rng);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }
    #[test]
    fn add_gaussian_noise_empty_error() {
        let mut rng = LcgRng::new(2);
        let result = add_gaussian_noise(&[], 0.1, &mut rng);
        assert!(matches!(result, Err(SslError::EmptyInput)));
    }
    #[test]
    fn simclr_blur_solar_shape() {
        let (c, h, w) = (3, 16, 16);
        let pixels = vec![0.5_f32; c * h * w];
        let cfg = SimClrBlurSolarConfig::default();
        let mut rng = LcgRng::new(42);
        let out = simclr_blur_solar(&pixels, c, h, w, &cfg, &mut rng)
            .expect("simclr_blur_solar should succeed");
        assert_eq!(out.len(), c * h * w);
    }
    #[test]
    fn simclr_blur_solar_both_probs_one() {
        let (c, h, w) = (1, 8, 8);
        let pixels: Vec<f32> = (0..c * h * w).map(|i| (i as f32 * 0.01).min(1.0)).collect();
        let cfg = SimClrBlurSolarConfig {
            blur_sigma_min: 0.5,
            blur_sigma_max: 1.0,
            blur_prob: 1.0,
            solar_threshold: 0.5,
            solar_prob: 1.0,
        };
        let mut rng = LcgRng::new(77);
        let out = simclr_blur_solar(&pixels, c, h, w, &cfg, &mut rng)
            .expect("simclr_blur_solar should succeed");
        assert_eq!(out.len(), pixels.len());
        let identical = pixels
            .iter()
            .zip(out.iter())
            .all(|(a, b)| (a - b).abs() < 1e-8);
        assert!(!identical, "expected pipeline to modify the image");
    }
    #[test]
    fn simclr_blur_solar_probs_zero_passthrough() {
        let (c, h, w) = (1, 4, 4);
        let pixels: Vec<f32> = (0..c * h * w)
            .map(|i| i as f32 / (c * h * w) as f32)
            .collect();
        let cfg = SimClrBlurSolarConfig {
            blur_sigma_min: 0.1,
            blur_sigma_max: 2.0,
            blur_prob: 0.0,
            solar_threshold: 0.5,
            solar_prob: 0.0,
        };
        let mut rng = LcgRng::new(3);
        let out = simclr_blur_solar(&pixels, c, h, w, &cfg, &mut rng)
            .expect("simclr_blur_solar should succeed");
        assert_eq!(out, pixels);
    }
    #[test]
    fn simclr_blur_solar_invalid_prob_error() {
        let pixels = vec![0.5_f32; 16];
        let cfg = SimClrBlurSolarConfig {
            blur_prob: 1.5,
            ..SimClrBlurSolarConfig::default()
        };
        let mut rng = LcgRng::new(4);
        let result = simclr_blur_solar(&pixels, 1, 4, 4, &cfg, &mut rng);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }
    #[test]
    fn empty_image_error_blur() {
        let result = gaussian_blur_chw(&[], 0, 8, 8, 1.0);
        assert!(matches!(result, Err(SslError::EmptyInput)));
    }
    #[test]
    fn empty_image_error_solarize() {
        let result = solarize(&[], 0, 8, 8, 0.5);
        assert!(matches!(result, Err(SslError::EmptyInput)));
    }
    #[test]
    fn dim_mismatch_blur() {
        let pixels = vec![0.5_f32; 10];
        let result = gaussian_blur_chw(&pixels, 1, 4, 4, 1.0);
        assert!(matches!(result, Err(SslError::DimensionMismatch { .. })));
    }
    #[test]
    fn dim_mismatch_solarize() {
        let pixels = vec![0.5_f32; 10];
        let result = solarize(&pixels, 1, 4, 4, 0.5);
        assert!(matches!(result, Err(SslError::DimensionMismatch { .. })));
    }
    #[test]
    fn random_gaussian_blur_sigma_range_error() {
        let pixels = vec![0.5_f32; 16];
        let mut rng = LcgRng::new(1);
        let result = random_gaussian_blur_chw(&pixels, 1, 4, 4, 2.0, 0.5, &mut rng);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }
}