use crate::ev_augmentation::{AugmentationError, AugmentationResult, Validatable};
use rand::{Rng, SeedableRng};
#[cfg(unix)]
use tracing::{debug, instrument};
// Stand-in for `tracing::debug!` on non-unix targets, where the `tracing`
// import above is compiled out. The expansion is empty, so the arguments are
// neither evaluated nor type-checked and nothing is logged.
#[cfg(not(unix))]
macro_rules! debug {
($($args:tt)*) => {};
}
// No-op `info!` shim for non-unix targets: accepts any token sequence and
// expands to nothing. Not invoked in the visible part of this file.
#[cfg(not(unix))]
macro_rules! info {
($($args:tt)*) => {};
}
// `warn!` shim for non-unix targets: formats its arguments like `format!` and
// writes them to stderr behind a `[WARN] ` prefix.
//
// `format_args!` is used instead of `format!` so the message is streamed
// straight into `eprintln!` without allocating an intermediate `String`; the
// printed text is unchanged.
#[cfg(not(unix))]
macro_rules! warn {
    ($($args:tt)*) => {
        eprintln!("[WARN] {}", format_args!($($args)*))
    };
}
// No-op `trace!` shim for non-unix targets: accepts any token sequence and
// expands to nothing. Not invoked in the visible part of this file.
#[cfg(not(unix))]
macro_rules! trace {
($($args:tt)*) => {};
}
// `error!` shim for non-unix targets: formats its arguments like `format!` and
// writes them to stderr behind an `[ERROR] ` prefix.
//
// `format_args!` is used instead of `format!` so the message is streamed
// straight into `eprintln!` without allocating an intermediate `String`; the
// printed text is unchanged.
#[cfg(not(unix))]
macro_rules! error {
    ($($args:tt)*) => {
        eprintln!("[ERROR] {}", format_args!($($args)*))
    };
}
// No-op `instrument!` shim for non-unix targets. The only visible reference to
// `instrument` is the `cfg_attr(unix, instrument(...))` attribute on
// `apply_time_reversal_polars`, which is compiled in on unix only, so this
// macro is not referenced by the visible code on the targets where it exists.
#[cfg(not(unix))]
macro_rules! instrument {
($($args:tt)*) => {};
}
use crate::ev_augmentation::{COL_POLARITY, COL_T};
use polars::prelude::*;
/// Settings for the time-reversal augmentation of an event stream.
///
/// Built with [`TimeReversalAugmentation::new`] and refined through the
/// chainable `with_*` methods.
#[derive(Debug, Clone)]
pub struct TimeReversalAugmentation {
    /// Chance that a call actually reverses the stream. Validation accepts
    /// only values in `0.0..=1.0`.
    pub probability: f64,
    /// Whether the polarity column is inverted together with the timestamps.
    pub flip_polarity: bool,
    /// Seed for the random draw; with `None` the generator is seeded from
    /// entropy.
    pub seed: Option<u64>,
}

impl TimeReversalAugmentation {
    /// Creates a configuration with the given `probability`, polarity
    /// flipping switched on and no fixed seed.
    ///
    /// The probability is stored as given; range checking happens in
    /// `validate`, not here.
    pub fn new(probability: f64) -> Self {
        TimeReversalAugmentation {
            seed: None,
            flip_polarity: true,
            probability,
        }
    }

    /// Returns the configuration with `flip_polarity` replaced by `flip`.
    pub fn with_polarity_flip(self, flip: bool) -> Self {
        Self {
            flip_polarity: flip,
            ..self
        }
    }

    /// Returns the configuration with the random seed fixed to `seed`.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Short human-readable summary, e.g.
    /// `time_reversal(prob=0.50, flip_polarity=true)`. The seed is not part
    /// of the summary.
    pub fn description(&self) -> String {
        let Self {
            probability,
            flip_polarity,
            ..
        } = self;
        format!(
            "time_reversal(prob={:.2}, flip_polarity={})",
            probability, flip_polarity
        )
    }
}
impl Validatable for TimeReversalAugmentation {
    /// Checks that `probability` lies in the closed interval `[0.0, 1.0]`.
    ///
    /// # Errors
    ///
    /// Returns `AugmentationError::InvalidProbability` carrying the rejected
    /// value when it falls outside that interval. NaN is rejected too, since
    /// no range contains it.
    fn validate(&self) -> AugmentationResult<()> {
        let probability = self.probability;
        if (0.0..=1.0).contains(&probability) {
            return Ok(());
        }
        Err(AugmentationError::InvalidProbability(probability))
    }
}
/// Randomly mirrors the time axis of an event stream held in a Polars
/// `LazyFrame`.
///
/// One random number is drawn per call (from `config.seed` when set,
/// otherwise from entropy). If it is below `config.probability` the frame is
/// transformed, otherwise it is returned untouched:
///
/// * every timestamp in `COL_T` becomes `t_max - (t - t_min)`, which mirrors
///   the stream inside its original `[t_min, t_max]` span;
/// * when `config.flip_polarity` is set, `COL_POLARITY` becomes
///   `1 - polarity`;
/// * the frame is sorted by `COL_T` with the default sort options.
///
/// The transformation is only added to the lazy plan; nothing is computed
/// until the caller collects the returned frame.
///
/// NOTE(review): `1 - polarity` swaps the two polarity values only if they
/// are encoded as 0/1 — confirm the encoding used by the callers.
///
/// # Errors
///
/// Returns the error produced by `config.validate()` when the probability is
/// outside `[0.0, 1.0]`.
#[cfg_attr(unix, instrument(skip(df), level = "debug"))]
pub fn apply_time_reversal_polars(
    df: LazyFrame,
    config: &TimeReversalAugmentation,
) -> AugmentationResult<LazyFrame> {
    config.validate()?;

    let mut rng = match config.seed {
        Some(seed) => rand::rngs::StdRng::seed_from_u64(seed),
        None => rand::rngs::StdRng::from_entropy(),
    };

    // The draw is in [0, 1), so probability 0.0 never fires and 1.0 always does.
    let apply_reversal = rng.gen::<f64>() < config.probability;
    if !apply_reversal {
        debug!("Time reversal not applied (probability check failed) - Polars");
        return Ok(df);
    }

    debug!(
        "Applying time reversal (Polars) with polarity flip: {}",
        config.flip_polarity
    );

    // Scalar aggregates broadcast against the column inside `with_columns`,
    // so no window over a dummy constant partition is needed to spread them
    // over every row.
    let t_max = col(COL_T).max();
    let t_min = col(COL_T).min();
    let mut result = df.with_columns([(t_max - (col(COL_T) - t_min)).alias(COL_T)]);

    if config.flip_polarity {
        result = result.with_columns([(lit(1) - col(COL_POLARITY)).alias(COL_POLARITY)]);
    }

    // Mirroring inverts the row order in time; restore ascending timestamps.
    Ok(result.sort([COL_T], Default::default()))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Names of the pixel-coordinate columns in the fixture. They are only
    // carried along by the augmentation, so any names would do.
    const TEST_COL_X: &str = "x";
    const TEST_COL_Y: &str = "y";

    const ORIGINAL_T: [f64; 4] = [1.0, 2.0, 3.0, 4.0];
    const ORIGINAL_X: [i64; 4] = [100, 150, 200, 250];
    const ORIGINAL_Y: [i64; 4] = [200, 250, 300, 350];
    // Polarity uses the 0/1 encoding that `1 - polarity` flips correctly.
    const ORIGINAL_POLARITY: [i64; 4] = [1, 0, 1, 0];

    /// Four events with strictly increasing timestamps and alternating polarity.
    fn create_test_events() -> LazyFrame {
        df!(
            COL_T => ORIGINAL_T,
            TEST_COL_X => ORIGINAL_X,
            TEST_COL_Y => ORIGINAL_Y,
            COL_POLARITY => ORIGINAL_POLARITY
        )
        .expect("fixture columns all have four rows")
        .lazy()
    }

    /// Runs the augmentation on the fixture and materialises the result.
    ///
    /// Time and polarity are cast to fixed dtypes so the assertions do not
    /// depend on the dtype the arithmetic happens to produce.
    fn run(config: &TimeReversalAugmentation) -> DataFrame {
        apply_time_reversal_polars(create_test_events(), config)
            .expect("configuration is valid")
            .with_columns([
                col(COL_T).cast(DataType::Float64),
                col(COL_POLARITY).cast(DataType::Int64),
            ])
            .collect()
            .expect("lazy plan executes")
    }

    fn f64_column(df: &DataFrame, name: &str) -> Vec<f64> {
        let column = df.column(name).expect("column exists");
        let values = column.f64().expect("column is Float64");
        values.into_no_null_iter().collect()
    }

    fn i64_column(df: &DataFrame, name: &str) -> Vec<i64> {
        let column = df.column(name).expect("column exists");
        let values = column.i64().expect("column is Int64");
        values.into_no_null_iter().collect()
    }

    #[test]
    fn test_time_reversal_creation() {
        let reversal = TimeReversalAugmentation::new(0.5);
        assert_eq!(reversal.probability, 0.5);
        assert!(reversal.flip_polarity);
    }

    #[test]
    fn test_validation() {
        let valid_config = TimeReversalAugmentation::new(0.5);
        assert!(valid_config.validate().is_ok());

        let invalid_prob = TimeReversalAugmentation::new(1.5);
        assert!(invalid_prob.validate().is_err());

        let invalid_negative = TimeReversalAugmentation::new(-0.1);
        assert!(invalid_negative.validate().is_err());
    }

    #[test]
    fn test_invalid_probability_is_rejected_by_apply() {
        let invalid = TimeReversalAugmentation::new(1.5);
        assert!(apply_time_reversal_polars(create_test_events(), &invalid).is_err());
    }

    #[test]
    fn test_no_reversal() {
        // Probability 0.0 can never win the draw, so the frame passes through.
        let result = run(&TimeReversalAugmentation::new(0.0));

        assert_eq!(result.height(), ORIGINAL_T.len());
        assert_eq!(f64_column(&result, COL_T), ORIGINAL_T.to_vec());
        assert_eq!(i64_column(&result, TEST_COL_X), ORIGINAL_X.to_vec());
        assert_eq!(i64_column(&result, TEST_COL_Y), ORIGINAL_Y.to_vec());
        assert_eq!(i64_column(&result, COL_POLARITY), ORIGINAL_POLARITY.to_vec());
    }

    #[test]
    fn test_deterministic_reversal() {
        let result1 = run(&TimeReversalAugmentation::new(1.0).with_seed(42));
        let result2 = run(&TimeReversalAugmentation::new(1.0).with_seed(42));

        assert_eq!(result1.height(), result2.height());
        assert_eq!(f64_column(&result1, COL_T), f64_column(&result2, COL_T));
        assert_eq!(
            i64_column(&result1, TEST_COL_X),
            i64_column(&result2, TEST_COL_X)
        );
        assert_eq!(
            i64_column(&result1, TEST_COL_Y),
            i64_column(&result2, TEST_COL_Y)
        );
        assert_eq!(
            i64_column(&result1, COL_POLARITY),
            i64_column(&result2, COL_POLARITY)
        );
    }

    #[test]
    fn test_time_reversal_properties() {
        let result = run(&TimeReversalAugmentation::new(1.0).with_seed(42));
        assert_eq!(result.height(), ORIGINAL_T.len());

        // Timestamps are mirrored inside [1.0, 4.0] and come back sorted.
        let t = f64_column(&result, COL_T);
        assert!(t.windows(2).all(|pair| pair[0] < pair[1]));
        assert_eq!(t, ORIGINAL_T.to_vec());

        // Each row keeps its coordinates, so after sorting they appear in
        // reverse of the original order.
        let reversed_x: Vec<i64> = ORIGINAL_X.iter().rev().copied().collect();
        let reversed_y: Vec<i64> = ORIGINAL_Y.iter().rev().copied().collect();
        assert_eq!(i64_column(&result, TEST_COL_X), reversed_x);
        assert_eq!(i64_column(&result, TEST_COL_Y), reversed_y);

        // Polarity is flipped on every row, then reordered with the row.
        let flipped: Vec<i64> = ORIGINAL_POLARITY.iter().rev().map(|p| 1 - p).collect();
        assert_eq!(i64_column(&result, COL_POLARITY), flipped);
    }

    #[test]
    fn test_time_reversal_without_polarity_flip() {
        let reversal = TimeReversalAugmentation::new(1.0)
            .with_polarity_flip(false)
            .with_seed(42);
        let result = run(&reversal);
        assert_eq!(result.height(), ORIGINAL_T.len());

        // Polarity values are untouched; only the row order is reversed.
        let reordered: Vec<i64> = ORIGINAL_POLARITY.iter().rev().copied().collect();
        assert_eq!(i64_column(&result, COL_POLARITY), reordered);
    }
}