use crate::ev_augmentation::{AugmentationError, AugmentationResult, Validatable};
use crate::ev_augmentation::COL_T;
#[cfg(unix)]
use tracing::{debug, instrument};
// Fallback logging macros for non-unix targets, where the `tracing` macros are
// not imported (the `use tracing::...` line is `#[cfg(unix)]`-gated).
// `debug!`, `info!`, `trace!` and `instrument!` expand to nothing; `warn!` and
// `error!` still print to stderr with a level prefix.
//
// `#[allow(unused_macros)]`: not every fallback is invoked in this module, and an
// unused `macro_rules!` would otherwise warn (or fail under `-D warnings`) on
// non-unix builds only.
#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! debug {
    ($($args:tt)*) => {};
}
#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! info {
    ($($args:tt)*) => {};
}
#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! warn {
    ($($args:tt)*) => {
        // `format_args!` formats in place; no intermediate `String` is allocated.
        eprintln!("[WARN] {}", format_args!($($args)*))
    };
}
#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! trace {
    ($($args:tt)*) => {};
}
#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! error {
    ($($args:tt)*) => {
        eprintln!("[ERROR] {}", format_args!($($args)*))
    };
}
// Bang-macro stand-in for `tracing::instrument`; the attribute use below is
// `cfg_attr(unix, ...)`-gated, so on non-unix this is never expanded.
#[cfg(not(unix))]
#[allow(unused_macros)]
macro_rules! instrument {
    ($($args:tt)*) => {};
}
use polars::prelude::*;
/// Configuration for a linear time skew of event timestamps: `t' = coefficient * t + offset`.
///
/// Either factor can be fixed or given as a `(min, max)` range; when a range is set the
/// scalar field holds the range midpoint as a representative value.
#[derive(Debug, Clone)]
pub struct TimeSkewAugmentation {
    /// Multiplicative factor applied to the timestamp column.
    pub coefficient: f64,
    /// Optional `(min, max)` range for the coefficient (set by [`TimeSkewAugmentation::random`]).
    pub coefficient_range: Option<(f64, f64)>,
    /// Additive offset applied after scaling.
    /// NOTE(review): `description` prints it with an `s` suffix; presumably seconds — confirm
    /// against the units of the timestamp column.
    pub offset: f64,
    /// Optional `(min, max)` range for the offset (set by `with_random_offset`).
    pub offset_range: Option<(f64, f64)>,
    /// Whether negative skewed timestamps are clipped by the applying code.
    pub clip_negative: bool,
    /// Optional seed intended for reproducible sampling of the configured ranges.
    pub seed: Option<u64>,
}

impl TimeSkewAugmentation {
    /// Fixed skew with the given coefficient, zero offset, clipping enabled and no seed.
    pub fn new(coefficient: f64) -> Self {
        Self {
            coefficient,
            coefficient_range: None,
            offset: 0.0,
            offset_range: None,
            clip_negative: true,
            seed: None,
        }
    }

    /// Skew whose coefficient range is `(min_coeff, max_coeff)`; `coefficient` holds the midpoint.
    ///
    /// All other fields match [`TimeSkewAugmentation::new`].
    pub fn random(min_coeff: f64, max_coeff: f64) -> Self {
        let midpoint = (min_coeff + max_coeff) / 2.0;
        Self {
            coefficient_range: Some((min_coeff, max_coeff)),
            ..Self::new(midpoint)
        }
    }

    /// Sets a fixed offset, discarding any previously configured offset range.
    pub fn with_offset(self, offset: f64) -> Self {
        Self {
            offset,
            offset_range: None,
            ..self
        }
    }

    /// Sets an offset range `(min_offset, max_offset)`; `offset` becomes its midpoint.
    pub fn with_random_offset(self, min_offset: f64, max_offset: f64) -> Self {
        Self {
            offset: (min_offset + max_offset) / 2.0,
            offset_range: Some((min_offset, max_offset)),
            ..self
        }
    }

    /// Enables or disables clipping of negative skewed timestamps.
    pub fn with_clipping(self, clip: bool) -> Self {
        Self {
            clip_negative: clip,
            ..self
        }
    }

    /// Attaches a seed to the configuration.
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }

    /// Human-readable summary of the configuration.
    ///
    /// Fixed configurations always show both coefficient and offset. Randomized ones show
    /// ranges where set, and omit the offset when it is fixed and (near) zero.
    pub fn description(&self) -> String {
        let randomized = self.coefficient_range.is_some() || self.offset_range.is_some();
        if !randomized {
            return format!("coeff={:.2}, offset={:.3}s", self.coefficient, self.offset);
        }

        let coeff_part = match self.coefficient_range {
            Some((lo, hi)) => format!("coeff∈[{:.2},{:.2}]", lo, hi),
            None => format!("coeff={:.2}", self.coefficient),
        };
        let offset_part = match self.offset_range {
            Some((lo, hi)) => Some(format!("offset∈[{:.2},{:.2}]s", lo, hi)),
            None if self.offset.abs() > 1e-10 => Some(format!("offset={:.3}s", self.offset)),
            None => None,
        };

        match offset_part {
            Some(offset_text) => format!("{}, {}", coeff_part, offset_text),
            None => coeff_part,
        }
    }
}
impl Validatable for TimeSkewAugmentation {
    /// Checks that the configuration describes a well-defined skew.
    ///
    /// # Errors
    /// Returns `AugmentationError::InvalidConfig` when:
    /// - `coefficient` is not strictly positive (including NaN) or is infinite;
    /// - `offset` is NaN or infinite;
    /// - `coefficient_range` has a non-positive (or NaN) minimum, a non-finite maximum,
    ///   or `min >= max`;
    /// - `offset_range` has a non-finite bound or `min >= max`.
    fn validate(&self) -> AugmentationResult<()> {
        fn invalid(message: &str) -> AugmentationResult<()> {
            Err(AugmentationError::InvalidConfig(message.to_string()))
        }

        // `x <= 0.0` is false for NaN, so NaN must be rejected explicitly.
        if self.coefficient.is_nan() || self.coefficient <= 0.0 {
            return invalid("Time skew coefficient must be positive");
        }
        if self.coefficient.is_infinite() {
            return invalid("Time skew coefficient must be finite");
        }
        if !self.offset.is_finite() {
            return invalid("Time skew offset must be finite");
        }

        if let Some((min, max)) = self.coefficient_range {
            if min.is_nan() || min <= 0.0 {
                return invalid("Minimum coefficient must be positive");
            }
            // A finite `max` plus `min < max` also bounds `min` to a finite value.
            if !max.is_finite() || min >= max {
                return invalid("Invalid coefficient range");
            }
        }

        if let Some((min, max)) = self.offset_range {
            if !min.is_finite() || !max.is_finite() || min >= max {
                return invalid("Invalid offset range");
            }
        }

        Ok(())
    }
}
/// Applies a linear time skew `t' = coefficient * t + offset` to the `COL_T` column.
///
/// When `config.coefficient_range` and/or `config.offset_range` are set, a concrete
/// coefficient and/or offset is drawn uniformly from the range for this call. With
/// `config.seed` set the draw is reproducible: the same seed always yields the same
/// parameters. Without a seed, a fresh random seed is used on every call.
///
/// Rows whose skewed timestamp fails `>= 0.0` are handled according to
/// `config.clip_negative`:
/// - `true`: those rows are removed;
/// - `false`: their timestamp is replaced with `0.0`.
///
/// The returned frame is still lazy; no work is executed until the caller collects.
///
/// NOTE(review): multiplying by an `f64` literal likely turns an integer `COL_T` into a
/// float column — confirm downstream consumers expect that.
///
/// # Errors
/// This function itself does not fail; errors in the plan surface when it is collected.
#[cfg_attr(unix, instrument(skip(df), fields(config = ?config)))]
pub fn apply_time_skew_polars(
    df: LazyFrame,
    config: &TimeSkewAugmentation,
) -> PolarsResult<LazyFrame> {
    debug!("Applying time skew with Polars: {:?}", config);
    let (coefficient, offset) = resolve_skew_parameters(config);
    debug!(
        "Resolved time skew parameters: coeff={}, offset={}",
        coefficient, offset
    );

    // Overwrite the timestamp column directly; no temporary column is needed, so
    // any user column that happens to share a scratch name is left untouched.
    let skewed_t = col(COL_T) * lit(coefficient) + lit(offset);
    let result = if config.clip_negative {
        df.with_columns([skewed_t.alias(COL_T)])
            .filter(col(COL_T).gt_eq(lit(0.0)))
    } else {
        df.with_columns([when(skewed_t.clone().gt_eq(lit(0.0)))
            .then(skewed_t)
            .otherwise(lit(0.0))
            .alias(COL_T)])
    };
    Ok(result)
}

/// Returns the concrete `(coefficient, offset)` to use for one application of `config`.
fn resolve_skew_parameters(config: &TimeSkewAugmentation) -> (f64, f64) {
    // Fixed configuration: nothing to sample, no entropy needed.
    if config.coefficient_range.is_none() && config.offset_range.is_none() {
        return (config.coefficient, config.offset);
    }
    let mut state = config.seed.unwrap_or_else(entropy_seed);
    // Sample the coefficient before the offset so a given seed is always reproducible.
    let coefficient = config
        .coefficient_range
        .map_or(config.coefficient, |(lo, hi)| sample_uniform(&mut state, lo, hi));
    let offset = config
        .offset_range
        .map_or(config.offset, |(lo, hi)| sample_uniform(&mut state, lo, hi));
    (coefficient, offset)
}

/// Draws a value uniformly from `[lo, hi)` by advancing `state`.
fn sample_uniform(state: &mut u64, lo: f64, hi: f64) -> f64 {
    // The top 53 bits give a uniform f64 in [0, 1).
    let unit = (splitmix64(state) >> 11) as f64 / (1u64 << 53) as f64;
    lo + (hi - lo) * unit
}

/// One step of the SplitMix64 generator: tiny, stdlib-only, reproducible per seed.
fn splitmix64(state: &mut u64) -> u64 {
    *state = state.wrapping_add(0x9E37_79B9_7F4A_7C15);
    let mut z = *state;
    z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    z ^ (z >> 31)
}

/// Non-deterministic seed taken from the standard library's randomly keyed hasher.
fn entropy_seed() -> u64 {
    let keyed = std::collections::hash_map::RandomState::new();
    std::hash::Hasher::finish(&std::hash::BuildHasher::build_hasher(&keyed))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_time_skew_creation() {
        let skew = TimeSkewAugmentation::new(1.5);
        assert_eq!(skew.coefficient, 1.5);
        assert_eq!(skew.offset, 0.0);
        assert!(skew.clip_negative);
        let skew = TimeSkewAugmentation::random(0.8, 1.2);
        assert!(skew.coefficient_range.is_some());
        assert_eq!(skew.coefficient_range.unwrap(), (0.8, 1.2));
    }

    #[test]
    fn test_time_skew_validation() {
        let valid = TimeSkewAugmentation::new(2.0);
        assert!(valid.validate().is_ok());
        let valid = TimeSkewAugmentation::random(0.5, 1.5);
        assert!(valid.validate().is_ok());
        let valid = TimeSkewAugmentation::new(1.0).with_random_offset(-0.5, 0.5);
        assert!(valid.validate().is_ok());
        let invalid = TimeSkewAugmentation::new(0.0);
        assert!(invalid.validate().is_err());
        let invalid = TimeSkewAugmentation::new(-1.0);
        assert!(invalid.validate().is_err());
        let invalid = TimeSkewAugmentation::random(1.5, 0.5);
        assert!(invalid.validate().is_err());
        // Midpoint is positive, but the range minimum is not.
        let invalid = TimeSkewAugmentation::random(-0.5, 1.5);
        assert!(invalid.validate().is_err());
        // Reversed offset range.
        let invalid = TimeSkewAugmentation::new(1.0).with_random_offset(1.0, 0.5);
        assert!(invalid.validate().is_err());
    }

    #[test]
    fn test_time_skew_builders() {
        let skew = TimeSkewAugmentation::new(1.0).with_random_offset(-1.0, 1.0);
        assert_eq!(skew.offset, 0.0);
        assert_eq!(skew.offset_range, Some((-1.0, 1.0)));
        // A fixed offset replaces a previously configured range.
        let skew = skew.with_offset(0.25);
        assert_eq!(skew.offset, 0.25);
        assert!(skew.offset_range.is_none());
        let skew = skew.with_clipping(false).with_seed(42);
        assert!(!skew.clip_negative);
        assert_eq!(skew.seed, Some(42));
    }

    #[test]
    fn test_time_skew_description() {
        assert_eq!(
            TimeSkewAugmentation::new(1.5).description(),
            "coeff=1.50, offset=0.000s"
        );
        assert_eq!(
            TimeSkewAugmentation::random(0.8, 1.2).description(),
            "coeff∈[0.80,1.20]"
        );
        assert_eq!(
            TimeSkewAugmentation::random(0.8, 1.2)
                .with_random_offset(-0.1, 0.1)
                .description(),
            "coeff∈[0.80,1.20], offset∈[-0.10,0.10]s"
        );
        assert_eq!(
            TimeSkewAugmentation::random(0.8, 1.2)
                .with_offset(0.5)
                .description(),
            "coeff∈[0.80,1.20], offset=0.500s"
        );
    }
}