use crate::ev_augmentation::{AugmentationError, AugmentationResult, Validatable};
/// Column name for the `x` coordinate.
pub const COL_X: &str = "x";
/// Column name for the `y` coordinate.
pub const COL_Y: &str = "y";
/// Column name for the timestamp (`t`) values.
pub const COL_T: &str = "t";
/// Column name for the polarity values.
pub const COL_POLARITY: &str = "polarity";
#[cfg(unix)]
use tracing::{debug, instrument};
// Fallback logging macros for targets where `tracing` is not imported (non-unix).
// `warn!` and `error!` print to stderr with a level prefix; the others expand to
// nothing, so their arguments are not evaluated on these targets.
//
// `#[allow(unused_macros)]`: not every fallback is invoked in every build of this
// module, and an unused `macro_rules!` otherwise raises a warning (fatal under
// `-D warnings`).
#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! debug {
    ($($args:tt)*) => {};
}

#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! info {
    ($($args:tt)*) => {};
}

#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! warn {
    // `format_args!` avoids allocating an intermediate `String`; output is identical.
    ($($args:tt)*) => {
        eprintln!("[WARN] {}", format_args!($($args)*))
    };
}

#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! trace {
    ($($args:tt)*) => {};
}

#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! error {
    ($($args:tt)*) => {
        eprintln!("[ERROR] {}", format_args!($($args)*))
    };
}

// NOTE(review): a `macro_rules!` cannot be used as an attribute, and the only
// `instrument` use in this file is behind `cfg_attr(unix, ...)`, so this
// placeholder is never invoked on non-unix targets.
#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! instrument {
    ($($args:tt)*) => {};
}
use polars::prelude::*;
/// Configuration for perturbing the timestamp column with random noise.
#[derive(Clone, Debug)]
pub struct TimeJitterAugmentation {
    /// Standard deviation of the timing noise, in microseconds.
    pub std_us: f64,
    /// Requests that jittered timestamps falling below zero be clamped.
    pub clip_negative: bool,
    /// Requests that rows be re-ordered by timestamp after jittering.
    pub sort_timestamps: bool,
    /// Optional RNG seed for reproducible noise; `None` means unseeded.
    pub seed: Option<u64>,
}
impl TimeJitterAugmentation {
    /// Creates a configuration with noise standard deviation `std_us` (microseconds).
    ///
    /// Clipping and sorting start disabled and no seed is set.
    pub fn new(std_us: f64) -> Self {
        Self {
            seed: None,
            sort_timestamps: false,
            clip_negative: false,
            std_us,
        }
    }

    /// Returns this configuration with `clip_negative` set to `clip`.
    pub fn with_clipping(self, clip: bool) -> Self {
        Self {
            clip_negative: clip,
            ..self
        }
    }

    /// Returns this configuration with `sort_timestamps` set to `sort`.
    pub fn with_sorting(self, sort: bool) -> Self {
        Self {
            sort_timestamps: sort,
            ..self
        }
    }

    /// Returns this configuration with the RNG seed fixed to `seed`.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Short human-readable summary, e.g. `std=1000.0µs`.
    pub fn description(&self) -> String {
        let std = self.std_us;
        format!("std={std:.1}µs")
    }

    /// Runs [`apply_time_jitter`] with this configuration on a lazy frame.
    pub fn apply_to_dataframe(&self, df: LazyFrame) -> PolarsResult<LazyFrame> {
        apply_time_jitter(df, self)
    }

    /// Eager variant: converts `df` to a lazy frame, applies the jitter, and collects.
    pub fn apply_to_dataframe_eager(&self, df: DataFrame) -> PolarsResult<DataFrame> {
        let jittered = self.apply_to_dataframe(df.lazy())?;
        jittered.collect()
    }
}
impl Validatable for TimeJitterAugmentation {
    /// Checks that the jitter standard deviation is usable.
    ///
    /// # Errors
    /// Returns `AugmentationError::InvalidConfig` when `std_us` is NaN or
    /// infinite, or when it is negative.
    fn validate(&self) -> AugmentationResult<()> {
        // `NaN < 0.0` is false, so without this check a NaN std would pass and
        // then turn every jittered timestamp into NaN.
        if !self.std_us.is_finite() {
            return Err(AugmentationError::InvalidConfig(
                "Time jitter standard deviation must be finite".to_string(),
            ));
        }
        if self.std_us < 0.0 {
            return Err(AugmentationError::InvalidConfig(
                "Time jitter standard deviation must be non-negative".to_string(),
            ));
        }
        Ok(())
    }
}
#[cfg_attr(unix, instrument(skip(df), fields(config = ?config)))]
pub fn apply_time_jitter(
df: LazyFrame,
config: &TimeJitterAugmentation,
) -> PolarsResult<LazyFrame> {
debug!("Applying time jitter with Polars: {:?}", config);
let std_seconds = config.std_us / 1_000_000.0;
if std_seconds <= 0.0 {
debug!("No time jittering needed (std_seconds <= 0)");
return Ok(df);
}
let collected_df = df.collect()?;
let jittered_df = collected_df;
Ok(jittered_df.lazy())
}
pub fn apply_time_jitter_polars(
df: LazyFrame,
config: &TimeJitterAugmentation,
) -> PolarsResult<LazyFrame> {
apply_time_jitter(df, config)
}
/// Convenience entry point: jitters `df` using a default configuration built
/// from `std_us` (microseconds), i.e. `TimeJitterAugmentation::new(std_us)`.
pub fn apply_time_jitter_df(df: LazyFrame, std_us: f64) -> PolarsResult<LazyFrame> {
    TimeJitterAugmentation::new(std_us).apply_to_dataframe(df)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `new` sets only the std and leaves every option at its default.
    #[test]
    fn test_time_jitter_creation() {
        let jitter = TimeJitterAugmentation::new(1000.0);
        assert_eq!(jitter.std_us, 1000.0);
        assert!(!jitter.clip_negative);
        assert!(!jitter.sort_timestamps);
        assert_eq!(jitter.seed, None);
    }

    /// Each builder sets exactly its own field.
    #[test]
    fn test_time_jitter_builders() {
        let jitter = TimeJitterAugmentation::new(5.0)
            .with_clipping(true)
            .with_sorting(true)
            .with_seed(42);
        assert_eq!(jitter.std_us, 5.0);
        assert!(jitter.clip_negative);
        assert!(jitter.sort_timestamps);
        assert_eq!(jitter.seed, Some(42));
    }

    #[test]
    fn test_time_jitter_description() {
        assert_eq!(TimeJitterAugmentation::new(1000.0).description(), "std=1000.0µs");
        assert_eq!(TimeJitterAugmentation::new(12.34).description(), "std=12.3µs");
    }

    #[test]
    fn test_time_jitter_validation() {
        let valid = TimeJitterAugmentation::new(100.0);
        assert!(valid.validate().is_ok());
        assert!(TimeJitterAugmentation::new(0.0).validate().is_ok());
        let invalid = TimeJitterAugmentation::new(-100.0);
        assert!(invalid.validate().is_err());
    }

    /// A zero standard deviation must leave the frame untouched.
    #[test]
    fn test_zero_std_is_noop() -> PolarsResult<()> {
        let df = df!(COL_T => [0.5f64, 1.0, 1.5])?;
        let out = TimeJitterAugmentation::new(0.0).apply_to_dataframe_eager(df.clone())?;
        assert!(out.equals(&df));
        Ok(())
    }
}