use crate::ev_augmentation::{AugmentationError, AugmentationResult, Validatable};
#[cfg(unix)]
use tracing::{debug, instrument};
// Fallback logging shims for non-unix targets, where `tracing` is not imported
// (the `use tracing::{debug, instrument}` above is gated on `#[cfg(unix)]`).
// Each macro accepts arbitrary tokens and expands to nothing, so `debug!(...)`
// call sites still compile there.
// NOTE(review): the arguments are discarded without being evaluated, so a
// variable that is only used inside a log call may trigger an
// unused-variable warning on these targets.
#[cfg(not(unix))]
macro_rules! debug {
    ($($args:tt)*) => {};
}
// No-op `info!` shim for the same targets.
// NOTE(review): `info!` is not used in the visible part of this file and is
// not imported on unix — confirm whether this shim is still needed.
#[cfg(not(unix))]
macro_rules! info {
    ($($args:tt)*) => {};
}
use polars::prelude::*;
/// Column name `"x"` — presumably the event x-coordinate column (confirm against the frame schema).
pub const COL_X: &str = "x";
/// Column name `"y"` — presumably the event y-coordinate column (confirm against the frame schema).
pub const COL_Y: &str = "y";
/// Column name `"t"` — presumably the event timestamp column (confirm units against the producer).
pub const COL_T: &str = "t";
/// Column name `"polarity"` — presumably the event polarity column (confirm encoding against the producer).
pub const COL_POLARITY: &str = "polarity";
/// Configuration for spatial jitter of event coordinates.
///
/// Describes a 2-D noise distribution whose covariance matrix is
/// `[[var_x, sigma_xy], [sigma_xy, var_y]]` (the matrix checked by
/// `is_valid_covariance`), plus optional clipping to sensor bounds and an
/// optional seed.
/// NOTE(review): the noise distribution family (presumably Gaussian) is not
/// visible in this file — confirm before relying on it.
#[derive(Debug, Clone)]
pub struct SpatialJitterAugmentation {
    /// Variance of the jitter along x; `validate` rejects negative values.
    pub var_x: f64,
    /// Variance of the jitter along y; `validate` rejects negative values.
    pub var_y: f64,
    /// Off-diagonal covariance between the x and y jitter; `0.0` means the axes
    /// are uncorrelated. Set by `with_correlation` — despite that name, this is
    /// used as a covariance (it enters the determinant as `sigma_xy²`), not as a
    /// correlation coefficient in `[-1, 1]`.
    pub sigma_xy: f64,
    /// Whether outlier clipping is enabled; `validate` requires `sensor_size`
    /// to be set when this is `true`.
    pub clip_outliers: bool,
    /// Sensor bounds as `(width, height)`, in the order passed to `with_clipping`.
    pub sensor_size: Option<(u16, u16)>,
    /// Optional RNG seed set by `with_seed`.
    /// NOTE(review): not consumed anywhere in the visible file.
    pub seed: Option<u64>,
}
impl SpatialJitterAugmentation {
    /// Creates an uncorrelated jitter configuration.
    ///
    /// `var_x` and `var_y` are variances (squared standard deviations) along
    /// each axis. The covariance starts at `0.0`, clipping is disabled, no
    /// sensor size is recorded and no seed is set; use the `with_*` builder
    /// methods to change those.
    #[must_use]
    pub fn new(var_x: f64, var_y: f64) -> Self {
        Self {
            var_x,
            var_y,
            sigma_xy: 0.0,
            clip_outliers: false,
            sensor_size: None,
            seed: None,
        }
    }

    /// Sets the off-diagonal term `sigma_xy` of the covariance matrix.
    ///
    /// Despite the method name, the value is stored and validated as a
    /// covariance (same units as the variances), not as a correlation
    /// coefficient in `[-1, 1]`: `validate` requires
    /// `sigma_xy² <= var_x * var_y`.
    #[must_use]
    pub fn with_correlation(mut self, sigma_xy: f64) -> Self {
        self.sigma_xy = sigma_xy;
        self
    }

    /// Enables outlier clipping and records the sensor size as
    /// `(sensor_width, sensor_height)`.
    #[must_use]
    pub fn with_clipping(mut self, sensor_width: u16, sensor_height: u16) -> Self {
        self.clip_outliers = true;
        self.sensor_size = Some((sensor_width, sensor_height));
        self
    }

    /// Records a fixed RNG seed for reproducible jitter.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Returns `true` when `[[var_x, sigma_xy], [sigma_xy, var_y]]` is
    /// positive semi-definite.
    ///
    /// For a symmetric 2×2 matrix, both eigenvalues are non-negative exactly
    /// when the determinant and the trace are non-negative. Any NaN entry makes
    /// the comparisons false, so NaN is reported as invalid.
    fn is_valid_covariance(&self) -> bool {
        let det = self.var_x * self.var_y - self.sigma_xy * self.sigma_xy;
        let trace = self.var_x + self.var_y;
        det >= 0.0 && trace >= 0.0
    }

    /// Applies this jitter configuration to a lazy frame.
    ///
    /// # Errors
    /// Propagates any error returned by [`apply_spatial_jitter`].
    pub fn apply_to_dataframe(&self, df: LazyFrame) -> PolarsResult<LazyFrame> {
        apply_spatial_jitter(df, self)
    }

    /// Eager variant of [`Self::apply_to_dataframe`]: converts `df` to a lazy
    /// frame, applies the jitter and collects the result.
    ///
    /// # Errors
    /// Propagates errors from [`apply_spatial_jitter`] and from collecting the
    /// resulting lazy frame.
    pub fn apply_to_dataframe_eager(&self, df: DataFrame) -> PolarsResult<DataFrame> {
        apply_spatial_jitter(df.lazy(), self)?.collect()
    }

    /// Human-readable summary of the distribution parameters, formatted with
    /// two decimals.
    ///
    /// The covariance term is omitted when `|sigma_xy| < 1e-10`, i.e. when the
    /// axes are effectively uncorrelated.
    #[must_use]
    pub fn description(&self) -> String {
        if self.sigma_xy.abs() < 1e-10 {
            format!("σx²={:.2}, σy²={:.2}", self.var_x, self.var_y)
        } else {
            format!(
                "σx²={:.2}, σy²={:.2}, σxy={:.2}",
                self.var_x, self.var_y, self.sigma_xy
            )
        }
    }
}
impl Validatable for SpatialJitterAugmentation {
    /// Checks that the configuration describes a usable jitter distribution.
    ///
    /// # Errors
    /// Returns [`AugmentationError::InvalidConfig`] when:
    /// - `var_x` or `var_y` is negative;
    /// - any of `var_x`, `var_y`, `sigma_xy` is NaN or infinite;
    /// - the covariance matrix is not positive semi-definite;
    /// - clipping is enabled but no sensor size is set.
    fn validate(&self) -> AugmentationResult<()> {
        if self.var_x < 0.0 {
            return Err(AugmentationError::InvalidConfig(
                "X variance must be non-negative".to_string(),
            ));
        }
        if self.var_y < 0.0 {
            return Err(AugmentationError::InvalidConfig(
                "Y variance must be non-negative".to_string(),
            ));
        }
        // NaN slips past the `< 0.0` checks above (every comparison with NaN
        // is false), and `+inf` passes both them and the PSD test (e.g.
        // var_x = inf, var_y = 1, sigma_xy = 0 gives det = inf). Reject both
        // explicitly, after the sign checks so negative inputs keep their
        // specific error messages.
        if !(self.var_x.is_finite() && self.var_y.is_finite() && self.sigma_xy.is_finite()) {
            return Err(AugmentationError::InvalidConfig(
                "Jitter variances and covariance must be finite".to_string(),
            ));
        }
        if !self.is_valid_covariance() {
            return Err(AugmentationError::InvalidConfig(
                "Covariance matrix must be positive semi-definite".to_string(),
            ));
        }
        if self.clip_outliers && self.sensor_size.is_none() {
            return Err(AugmentationError::InvalidConfig(
                "Sensor size must be specified when clipping is enabled".to_string(),
            ));
        }
        Ok(())
    }
}
/// Applies spatial jitter described by `config` to a lazy event frame.
///
/// The jitter implementation is currently disabled, so this function always
/// fails; it is kept so existing call sites continue to compile.
///
/// # Errors
/// Always returns [`PolarsError::ComputeError`] explaining that spatial
/// jitter is temporarily disabled.
#[cfg_attr(unix, instrument(skip(df), fields(config = ?config)))]
pub fn apply_spatial_jitter(
    df: LazyFrame,
    config: &SpatialJitterAugmentation,
) -> PolarsResult<LazyFrame> {
    debug!("Applying spatial jitter with Polars: {:?}", config);
    // Deliberately NOT collected: collecting would execute the whole lazy
    // query (file scans, joins, ...) only to throw the result away, and any
    // failure in that query would mask the "disabled" error returned below.
    // The parameter keeps its name because `instrument(skip(df))` refers to it.
    drop(df);
    Err(PolarsError::ComputeError(
        "Spatial jitter temporarily disabled - Events type removed".into(),
    ))
}
pub fn apply_spatial_jitter_polars(
df: LazyFrame,
config: &SpatialJitterAugmentation,
) -> PolarsResult<LazyFrame> {
apply_spatial_jitter(df, config)
}
/// Convenience entry point taking standard deviations instead of variances.
///
/// `std_x` and `std_y` are squared to form the diagonal variances of an
/// uncorrelated [`SpatialJitterAugmentation`] (no clipping, no seed), which
/// is then passed to [`apply_spatial_jitter`]. Squaring discards the sign,
/// so a negative standard deviation behaves like its absolute value.
///
/// # Errors
/// Returns exactly what [`apply_spatial_jitter`] returns.
pub fn apply_spatial_jitter_df(df: LazyFrame, std_x: f64, std_y: f64) -> PolarsResult<LazyFrame> {
    let variance_x = std_x * std_x;
    let variance_y = std_y * std_y;
    apply_spatial_jitter(df, &SpatialJitterAugmentation::new(variance_x, variance_y))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spatial_jitter_creation() {
        let jitter = SpatialJitterAugmentation::new(1.0, 2.0);
        assert_eq!(jitter.var_x, 1.0);
        assert_eq!(jitter.var_y, 2.0);
        assert_eq!(jitter.sigma_xy, 0.0);
        assert!(!jitter.clip_outliers);
        assert!(jitter.sensor_size.is_none());
        assert!(jitter.seed.is_none());
    }

    #[test]
    fn test_builder_methods() {
        let jitter = SpatialJitterAugmentation::new(1.0, 2.0)
            .with_correlation(0.5)
            .with_clipping(640, 480)
            .with_seed(42);
        assert_eq!(jitter.sigma_xy, 0.5);
        assert!(jitter.clip_outliers);
        assert_eq!(jitter.sensor_size, Some((640, 480)));
        assert_eq!(jitter.seed, Some(42));
    }

    #[test]
    fn test_spatial_jitter_validation() {
        let valid = SpatialJitterAugmentation::new(1.0, 1.0);
        assert!(valid.validate().is_ok());

        let negative_x = SpatialJitterAugmentation::new(-1.0, 1.0);
        assert!(negative_x.validate().is_err());

        let negative_y = SpatialJitterAugmentation::new(1.0, -1.0);
        assert!(negative_y.validate().is_err());

        // sigma_xy² <= var_x * var_y keeps the covariance matrix PSD.
        let correlated = SpatialJitterAugmentation::new(1.0, 1.0).with_correlation(0.5);
        assert!(correlated.validate().is_ok());

        // sigma_xy² > var_x * var_y makes the determinant negative.
        let not_psd = SpatialJitterAugmentation::new(1.0, 1.0).with_correlation(2.0);
        assert!(not_psd.validate().is_err());

        // NaN must never validate.
        let nan_variance = SpatialJitterAugmentation::new(f64::NAN, 1.0);
        assert!(nan_variance.validate().is_err());

        // Clipping enabled while the sensor size is missing must be rejected.
        let mut clipping_without_sensor =
            SpatialJitterAugmentation::new(1.0, 1.0).with_clipping(640, 480);
        clipping_without_sensor.sensor_size = None;
        assert!(clipping_without_sensor.validate().is_err());
    }

    #[test]
    fn test_description() {
        let uncorrelated = SpatialJitterAugmentation::new(1.0, 2.0);
        assert_eq!(uncorrelated.description(), "σx²=1.00, σy²=2.00");

        let correlated = SpatialJitterAugmentation::new(1.0, 2.0).with_correlation(0.5);
        assert_eq!(correlated.description(), "σx²=1.00, σy²=2.00, σxy=0.50");
    }
}