use crate::ev_filtering::config::Validatable;
use crate::ev_filtering::{FilterError, FilterResult};
use polars::prelude::*;
#[cfg(unix)]
use tracing::{debug, instrument, warn};
// Fallback logging shims for non-unix targets, where the `tracing` import above is compiled out.
// `debug!` expands to an empty block, so its arguments are discarded and never evaluated.
#[cfg(not(unix))]
macro_rules! debug {
($($args:tt)*) => {};
}
// `warn!` formats its arguments and prints them to stderr with a "[WARN] " prefix.
#[cfg(not(unix))]
macro_rules! warn {
($($args:tt)*) => {
eprintln!("[WARN] {}", format!($($args)*))
};
}
// NOTE(review): every `instrument` use in this file is wrapped in `cfg_attr(unix, ...)`, so this
// no-op shim appears unused on non-unix targets — confirm before removing it.
#[cfg(not(unix))]
macro_rules! instrument {
($($args:tt)*) => {};
}
use crate::ev_filtering::utils::{COL_POLARITY, COL_T, COL_X, COL_Y};
/// Raw numeric polarity values together with the encoding they are interpreted with.
#[derive(Debug, Clone)]
pub struct RawPolarityData {
/// The raw polarity values, one per event.
pub values: Vec<f64>,
/// Encoding used to interpret `values`.
pub encoding: PolarityEncoding,
/// Confidence in `encoding`.
// NOTE(review): presumably a score in 0.0..=1.0 such as `PolarityEncoding::confidence_score`
// produces — confirm against whoever constructs this struct.
pub confidence: f64,
}
/// How raw numeric polarity values map to positive/negative events.
///
/// The per-variant mapping below is the one implemented by `to_bool` / `from_bool`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolarityEncoding {
/// 1.0 = true (positive), 0.0 = false (negative); described as "true/false".
TrueFalse,
/// 1 = positive, 0 = negative.
OneZero,
/// 1 = positive, -1 = negative.
OneMinus,
/// Arbitrary numbers: a value is positive iff it is > 0.
Raw,
/// A mix of -1, 0 and 1: 1 is positive, -1 and 0 are negative.
Mixed,
/// Encoding not determined; values are decoded like `Raw`.
Unknown,
}
impl PolarityEncoding {
pub fn to_bool(&self, value: f64) -> FilterResult<bool> {
match self {
PolarityEncoding::TrueFalse => {
if value == 0.0 {
Ok(false)
} else if value == 1.0 {
Ok(true)
} else {
Err(FilterError::InvalidInput(format!(
"Invalid TrueFalse polarity value: {}",
value
)))
}
}
PolarityEncoding::OneZero => {
if value == 0.0 {
Ok(false)
} else if value == 1.0 {
Ok(true)
} else {
Err(FilterError::InvalidInput(format!(
"Invalid OneZero polarity value: {}",
value
)))
}
}
PolarityEncoding::OneMinus => {
if value == -1.0 {
Ok(false)
} else if value == 1.0 {
Ok(true)
} else {
Err(FilterError::InvalidInput(format!(
"Invalid OneMinus polarity value: {}",
value
)))
}
}
PolarityEncoding::Raw => Ok(value > 0.0), PolarityEncoding::Mixed => {
match value {
-1.0 => Ok(false),
0.0 => Ok(false),
1.0 => Ok(true),
_ => Ok(value > 0.0),
}
}
PolarityEncoding::Unknown => Ok(value > 0.0), }
}
pub fn from_bool(&self, polarity: bool) -> f64 {
match self {
PolarityEncoding::TrueFalse => {
if polarity {
1.0
} else {
0.0
}
}
PolarityEncoding::OneZero => {
if polarity {
1.0
} else {
0.0
}
}
PolarityEncoding::OneMinus => {
if polarity {
1.0
} else {
-1.0
}
}
PolarityEncoding::Raw => {
if polarity {
1.0
} else {
0.0
}
}
PolarityEncoding::Mixed => {
if polarity {
1.0
} else {
0.0
}
}
PolarityEncoding::Unknown => {
if polarity {
1.0
} else {
0.0
}
}
}
}
pub fn detect_from_raw_values(values: &[f64]) -> Self {
if values.is_empty() {
return PolarityEncoding::Unknown;
}
let unique_values: std::collections::HashSet<_> =
values.iter().map(|&v| OrderedFloat::from(v)).collect();
match unique_values.len() {
0 => PolarityEncoding::Unknown,
1 => {
let value = values[0];
if value == 0.0 || value == 1.0 {
PolarityEncoding::OneZero
} else if value == -1.0 || value == 1.0 {
PolarityEncoding::OneMinus
} else {
PolarityEncoding::Raw
}
}
2 => {
let mut vals: Vec<f64> = unique_values.iter().map(|v| (*v).into()).collect();
vals.sort_by(|a, b| a.partial_cmp(b).unwrap());
match (vals[0], vals[1]) {
(0.0, 1.0) => PolarityEncoding::OneZero,
(-1.0, 1.0) => PolarityEncoding::OneMinus,
_ => PolarityEncoding::Raw,
}
}
_ => {
let has_zero = unique_values.contains(&OrderedFloat::from(0.0));
let has_one = unique_values.contains(&OrderedFloat::from(1.0));
let has_minus_one = unique_values.contains(&OrderedFloat::from(-1.0));
if has_zero && has_one && has_minus_one {
PolarityEncoding::Mixed
} else {
PolarityEncoding::Raw
}
}
}
}
pub fn detect_from_events_polars(df: LazyFrame) -> PolarsResult<Self> {
let stats_df = df
.select([
len().alias("total_events"),
col(COL_POLARITY).sum().alias("positive_count"),
])
.with_columns([(col("positive_count").cast(DataType::Float64)
/ col("total_events").cast(DataType::Float64))
.alias("positive_ratio")])
.collect()?;
if stats_df.height() == 0 {
return Ok(PolarityEncoding::TrueFalse);
}
let row = stats_df.get_row(0)?;
let positive_ratio = row.0[2].try_extract::<f64>().unwrap_or(0.5);
if !(0.1..=0.9).contains(&positive_ratio) {
warn!(
"Unusual polarity distribution: {:.1}% positive events. Check encoding.",
positive_ratio * 100.0
);
}
Ok(PolarityEncoding::TrueFalse)
}
pub fn description(&self) -> &'static str {
match self {
PolarityEncoding::TrueFalse => "true/false",
PolarityEncoding::OneZero => "1/0",
PolarityEncoding::OneMinus => "1/-1",
PolarityEncoding::Raw => "raw",
PolarityEncoding::Mixed => "mixed encodings",
PolarityEncoding::Unknown => "unknown/auto-detect",
}
}
pub fn expected_values(&self) -> Vec<f64> {
match self {
PolarityEncoding::TrueFalse => vec![0.0, 1.0],
PolarityEncoding::OneZero => vec![0.0, 1.0],
PolarityEncoding::OneMinus => vec![-1.0, 1.0],
PolarityEncoding::Raw => vec![], PolarityEncoding::Mixed => vec![-1.0, 0.0, 1.0], PolarityEncoding::Unknown => vec![], }
}
pub fn is_valid_value(&self, value: f64) -> bool {
match self {
PolarityEncoding::TrueFalse => value == 0.0 || value == 1.0,
PolarityEncoding::OneZero => value == 0.0 || value == 1.0,
PolarityEncoding::OneMinus => value == -1.0 || value == 1.0,
PolarityEncoding::Raw => true, PolarityEncoding::Mixed => [-1.0, 0.0, 1.0].contains(&value),
PolarityEncoding::Unknown => true, }
}
pub fn confidence_score(&self, values: &[f64]) -> f64 {
if values.is_empty() {
return 0.0;
}
let expected = self.expected_values();
if expected.is_empty() {
return 0.3;
}
let valid_count = values.iter().filter(|&&v| self.is_valid_value(v)).count();
valid_count as f64 / values.len() as f64
}
pub fn convert_to(&self, target: PolarityEncoding, values: &[f64]) -> FilterResult<Vec<f64>> {
if values.is_empty() {
return Ok(Vec::new());
}
let mut converted = Vec::with_capacity(values.len());
for &value in values {
let boolean_polarity = self.to_bool(value)?;
let target_value = target.from_bool(boolean_polarity);
converted.push(target_value);
}
Ok(converted)
}
}
/// `f64` wrapper with total equality, ordering and hashing, so that raw polarity values can be
/// deduplicated in a `HashSet`.
///
/// All comparisons go through a canonical form that folds `-0.0` into `+0.0` and every NaN
/// payload into a single NaN. This keeps `Eq`, `Ord` and `Hash` mutually consistent. A
/// bit-pattern hash combined with a derived `PartialEq` would hash `0.0` and `-0.0` differently
/// even though they compare equal, and NaN would never equal itself despite the `Eq` impl.
#[derive(Debug, Clone, Copy)]
struct OrderedFloat(f64);

impl OrderedFloat {
    /// The wrapped value with signed zeros and NaN payloads normalised.
    fn canonical(self) -> f64 {
        if self.0 == 0.0 {
            0.0
        } else if self.0.is_nan() {
            f64::NAN
        } else {
            self.0
        }
    }
}

impl From<f64> for OrderedFloat {
    fn from(val: f64) -> Self {
        OrderedFloat(val)
    }
}

impl From<OrderedFloat> for f64 {
    /// Returns the original (non-canonicalised) value.
    fn from(val: OrderedFloat) -> Self {
        val.0
    }
}

impl PartialEq for OrderedFloat {
    fn eq(&self, other: &Self) -> bool {
        self.canonical().to_bits() == other.canonical().to_bits()
    }
}

impl Eq for OrderedFloat {}

impl PartialOrd for OrderedFloat {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrderedFloat {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` on canonical values returns `Equal` exactly when the canonical bit
        // patterns match, which agrees with `eq` and `hash`.
        self.canonical().total_cmp(&other.canonical())
    }
}

impl std::hash::Hash for OrderedFloat {
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        std::hash::Hash::hash(&self.canonical().to_bits(), state);
    }
}
/// Which events a polarity filter keeps.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PolaritySelection {
/// Keep only positive events.
PositiveOnly,
/// Keep only negative events.
NegativeOnly,
/// Keep all events (no polarity filtering).
Both,
/// Keep events whose polarity differs from the previous event and that follow it by at least
/// a minimum interval (`PolarityFilter::alternating_min_interval`).
Alternating,
/// Keep events whose local spatio-temporal bin has a mean polarity close to a target ratio
/// (`PolarityFilter::balance_radius` / `PolarityFilter::balance_ratio`).
Balanced,
}
impl PolaritySelection {
    /// Whether an event with the given decoded polarity passes this selection on its own.
    ///
    /// `Alternating` and `Balanced` depend on neighbouring events, so they cannot reject a
    /// single event in isolation and always return `true` here.
    pub fn passes(&self, polarity: bool) -> bool {
        match self {
            PolaritySelection::PositiveOnly => polarity,
            PolaritySelection::NegativeOnly => !polarity,
            PolaritySelection::Both
            | PolaritySelection::Alternating
            | PolaritySelection::Balanced => true,
        }
    }

    /// Short human-readable name of the selection.
    pub fn description(&self) -> &'static str {
        match self {
            PolaritySelection::PositiveOnly => "positive only",
            PolaritySelection::NegativeOnly => "negative only",
            PolaritySelection::Both => "both polarities",
            PolaritySelection::Alternating => "alternating polarity",
            PolaritySelection::Balanced => "balanced polarity",
        }
    }

    /// Row-wise Polars predicate on the polarity column, or `None` when this selection has no
    /// row-wise predicate (`Both` keeps everything; `Alternating`/`Balanced` need context).
    ///
    /// Positive events have polarity `> 0` and negative events have polarity `<= 0`. This covers
    /// boolean, 0/1 and -1/+1 polarity columns alike.
    pub fn to_polars_expr(&self) -> Option<Expr> {
        match self {
            PolaritySelection::PositiveOnly => Some(col(COL_POLARITY).gt(lit(0))),
            // `lt(0)` would drop every negative event under boolean or 0/1 polarity, where
            // negatives are encoded as 0.
            PolaritySelection::NegativeOnly => Some(col(COL_POLARITY).lt_eq(lit(0))),
            PolaritySelection::Both
            | PolaritySelection::Alternating
            | PolaritySelection::Balanced => None,
        }
    }
}
/// Configuration for polarity-based event filtering.
#[derive(Debug, Clone)]
pub struct PolarityFilter {
/// Which events to keep.
pub selection: PolaritySelection,
/// Encoding of the input polarity column.
pub input_encoding: PolarityEncoding,
/// Whether polarity values should be validated.
// NOTE(review): this flag is not read anywhere in this file — confirm where it is honoured.
pub validate_polarity: bool,
/// Minimum interval in microseconds between kept alternating events (`Alternating` only).
/// Timestamp differences are multiplied by 1e6 before the comparison, so `t` is presumably in
/// seconds. The expression builder falls back to 1000.0 when this is unset.
pub alternating_min_interval: Option<f64>,
/// Spatial bin size applied to the x/y coordinates for `Balanced` (falls back to 5 when unset).
pub balance_radius: Option<u16>,
/// Target local mean polarity for `Balanced`, accepted within ±0.2 (falls back to 0.3 when unset).
pub balance_ratio: Option<f64>,
}
impl PolarityFilter {
pub fn positive_only() -> Self {
Self {
selection: PolaritySelection::PositiveOnly,
input_encoding: PolarityEncoding::TrueFalse,
validate_polarity: true,
alternating_min_interval: None,
balance_radius: None,
balance_ratio: None,
}
}
pub fn negative_only() -> Self {
Self {
selection: PolaritySelection::NegativeOnly,
input_encoding: PolarityEncoding::TrueFalse,
validate_polarity: true,
alternating_min_interval: None,
balance_radius: None,
balance_ratio: None,
}
}
pub fn both() -> Self {
Self {
selection: PolaritySelection::Both,
input_encoding: PolarityEncoding::TrueFalse,
validate_polarity: true,
alternating_min_interval: None,
balance_radius: None,
balance_ratio: None,
}
}
pub fn alternating(min_interval_us: f64) -> Self {
Self {
selection: PolaritySelection::Alternating,
input_encoding: PolarityEncoding::TrueFalse,
validate_polarity: true,
alternating_min_interval: Some(min_interval_us),
balance_radius: None,
balance_ratio: None,
}
}
pub fn balanced(radius: u16, ratio: f64) -> Self {
Self {
selection: PolaritySelection::Balanced,
input_encoding: PolarityEncoding::TrueFalse,
validate_polarity: true,
alternating_min_interval: None,
balance_radius: Some(radius),
balance_ratio: Some(ratio),
}
}
pub fn from_values(polarity_values: Vec<i8>) -> Self {
if polarity_values.is_empty() {
return Self::both();
}
let encoding = if polarity_values.contains(&-1) && polarity_values.contains(&1) {
PolarityEncoding::OneMinus
} else if polarity_values.contains(&0) && polarity_values.contains(&1) {
PolarityEncoding::OneZero
} else if polarity_values.len() == 1 {
match polarity_values[0] {
1 => return Self::positive_only().with_encoding(PolarityEncoding::OneZero),
0 => return Self::negative_only().with_encoding(PolarityEncoding::OneZero),
-1 => return Self::negative_only().with_encoding(PolarityEncoding::OneMinus),
_ => PolarityEncoding::Raw,
}
} else {
PolarityEncoding::Mixed
};
let has_positive = polarity_values.iter().any(|&v| v > 0);
let has_negative = polarity_values.iter().any(|&v| v <= 0);
let selection = match (has_positive, has_negative) {
(true, false) => PolaritySelection::PositiveOnly,
(false, true) => PolaritySelection::NegativeOnly,
(true, true) => PolaritySelection::Both,
(false, false) => PolaritySelection::Both, };
Self {
selection,
input_encoding: encoding,
validate_polarity: true,
alternating_min_interval: None,
balance_radius: None,
balance_ratio: None,
}
}
pub fn with_encoding(mut self, encoding: PolarityEncoding) -> Self {
self.input_encoding = encoding;
self
}
pub fn with_validation(mut self, validate: bool) -> Self {
self.validate_polarity = validate;
self
}
pub fn to_polars_expr(&self, df: &LazyFrame) -> PolarsResult<Option<Expr>> {
match self.selection {
PolaritySelection::PositiveOnly
| PolaritySelection::NegativeOnly
| PolaritySelection::Both => Ok(self.selection.to_polars_expr()),
PolaritySelection::Alternating => {
Ok(Some(self.build_alternating_expr(df)?))
}
PolaritySelection::Balanced => {
Ok(Some(self.build_balanced_expr(df)?))
}
}
}
fn build_alternating_expr(&self, _df: &LazyFrame) -> PolarsResult<Expr> {
let min_interval = self.alternating_min_interval.unwrap_or(1000.0);
let prev_polarity = col(COL_POLARITY).shift(lit(1));
let prev_time = col(COL_T).shift(lit(1));
let polarity_alternates = col(COL_POLARITY).neq(prev_polarity.clone());
let time_interval_ok = (col(COL_T) - prev_time)
* lit(1_000_000.0) .gt_eq(lit(min_interval));
let is_null = prev_polarity.is_null();
Ok(is_null.or(polarity_alternates.and(time_interval_ok)))
}
fn build_balanced_expr(&self, _df: &LazyFrame) -> PolarsResult<Expr> {
let radius = self.balance_radius.unwrap_or(5) as i64;
let required_ratio = self.balance_ratio.unwrap_or(0.3);
let tolerance = 0.2;
let spatial_bin_x = (col(COL_X) / lit(radius)).cast(DataType::Int64);
let spatial_bin_y = (col(COL_Y) / lit(radius)).cast(DataType::Int64);
let time_bin = (col(COL_T) * lit(10.0)).cast(DataType::Int64);
let local_positive_ratio = col(COL_POLARITY).cast(DataType::Float64).mean().over([
spatial_bin_x,
spatial_bin_y,
time_bin,
]);
Ok(local_positive_ratio
.clone()
.gt_eq(lit(required_ratio - tolerance))
.and(local_positive_ratio.lt_eq(lit(required_ratio + tolerance))))
}
pub fn estimate_pass_fraction_polars(&self, df: LazyFrame) -> PolarsResult<f64> {
match self.selection {
PolaritySelection::Both => Ok(1.0),
PolaritySelection::PositiveOnly | PolaritySelection::NegativeOnly => {
let stats_df = df
.select([
len().alias("total_events"),
col(COL_POLARITY).sum().alias("positive_count"),
])
.with_columns([(col("positive_count").cast(DataType::Float64)
/ col("total_events").cast(DataType::Float64))
.alias("positive_ratio")])
.collect()?;
if stats_df.height() == 0 {
return Ok(0.0);
}
let row = stats_df.get_row(0)?;
let positive_ratio = row.0[2].try_extract::<f64>().unwrap_or(0.0);
match self.selection {
PolaritySelection::PositiveOnly => Ok(positive_ratio),
PolaritySelection::NegativeOnly => Ok(1.0 - positive_ratio),
_ => unreachable!(),
}
}
PolaritySelection::Alternating => Ok(0.5), PolaritySelection::Balanced => Ok(0.8), }
}
pub fn description(&self) -> String {
let mut parts = vec![self.selection.description().to_string()];
if self.input_encoding != PolarityEncoding::TrueFalse {
parts.push(format!("encoding: {}", self.input_encoding.description()));
}
if let Some(interval) = self.alternating_min_interval {
parts.push(format!("min interval: {:.1}µs", interval));
}
if let (Some(radius), Some(ratio)) = (self.balance_radius, self.balance_ratio) {
parts.push(format!("balance: r={}, ratio={:.2}", radius, ratio));
}
parts.join(", ")
}
pub fn apply_to_dataframe(&self, df: LazyFrame) -> PolarsResult<LazyFrame> {
apply_polarity_filter(df, self)
}
pub fn apply_to_dataframe_eager(&self, df: DataFrame) -> PolarsResult<DataFrame> {
apply_polarity_filter(df.lazy(), self)?.collect()
}
}
impl Default for PolarityFilter {
fn default() -> Self {
Self::both()
}
}
impl Validatable for PolarityFilter {
    /// Checks that the parameters required by the selected mode are present and sane.
    ///
    /// # Errors
    ///
    /// Returns [`FilterError::InvalidConfig`] when:
    /// - `Alternating` has no interval, or its interval is negative or NaN;
    /// - `Balanced` has no radius or ratio, its ratio is outside `0.0..=1.0` (NaN included),
    ///   or its radius is 0, which would divide the coordinates by zero.
    fn validate(&self) -> FilterResult<()> {
        match self.selection {
            PolaritySelection::Alternating => {
                let Some(interval) = self.alternating_min_interval else {
                    return Err(FilterError::InvalidConfig(
                        "Alternating polarity filter requires min_interval".to_string(),
                    ));
                };
                // NaN compares false against everything, so `interval < 0.0` alone lets it
                // through.
                if interval.is_nan() || interval < 0.0 {
                    return Err(FilterError::InvalidConfig(
                        "Alternating min interval must be non-negative".to_string(),
                    ));
                }
            }
            PolaritySelection::Balanced => {
                let (Some(radius), Some(ratio)) = (self.balance_radius, self.balance_ratio) else {
                    return Err(FilterError::InvalidConfig(
                        "Balanced polarity filter requires radius and ratio".to_string(),
                    ));
                };
                if !(0.0..=1.0).contains(&ratio) {
                    return Err(FilterError::InvalidConfig(
                        "Balance ratio must be between 0.0 and 1.0".to_string(),
                    ));
                }
                if radius == 0 {
                    return Err(FilterError::InvalidConfig(
                        "Balance radius must be greater than zero".to_string(),
                    ));
                }
            }
            PolaritySelection::PositiveOnly
            | PolaritySelection::NegativeOnly
            | PolaritySelection::Both => {}
        }
        Ok(())
    }
}
/// Applies `filter` to `df` lazily.
///
/// Best-effort: an invalid configuration is logged as a warning and `df` is returned unfiltered
/// instead of failing. When the filter needs no predicate (e.g. `Both`), `df` is also returned
/// unchanged.
///
/// # Errors
///
/// Propagates errors from building the filter expression.
#[cfg_attr(unix, instrument(skip(df), fields(filter = ?filter)))]
pub fn apply_polarity_filter(df: LazyFrame, filter: &PolarityFilter) -> PolarsResult<LazyFrame> {
    debug!("Applying polarity filter: {:?}", filter);
    if let Err(validation_error) = filter.validate() {
        warn!("Invalid polarity filter configuration: {}", validation_error);
        return Ok(df);
    }
    let maybe_predicate = filter.to_polars_expr(&df)?;
    let Some(predicate) = maybe_predicate else {
        debug!("No polarity filtering needed");
        return Ok(df);
    };
    debug!("Polarity filter expression: {:?}", predicate);
    Ok(df.filter(predicate))
}
/// Keeps positive events (`positive == true`, polarity > 0) or negative events (polarity <= 0).
///
/// Using `<= 0` for negatives makes the predicate work for boolean, 0/1 and -1/+1 polarity
/// columns alike. This matches `PolaritySelection` and `PolarityFilter::from_values`, which also
/// treat codes <= 0 as negative.
#[cfg_attr(unix, instrument(skip(df)))]
pub fn filter_by_polarity_polars(df: LazyFrame, positive: bool) -> PolarsResult<LazyFrame> {
    let polarity = col(COL_POLARITY);
    let expr = if positive {
        polarity.gt(lit(0))
    } else {
        polarity.lt_eq(lit(0))
    };
    Ok(df.filter(expr))
}
/// Keeps positive or negative events by routing through a [`PolarityFilter`]
/// (`positive_only` / `negative_only`) and its validation/apply path.
pub fn filter_by_polarity_df(df: LazyFrame, positive: bool) -> PolarsResult<LazyFrame> {
    let selection_filter = match positive {
        true => PolarityFilter::positive_only(),
        false => PolarityFilter::negative_only(),
    };
    selection_filter.apply_to_dataframe(df)
}
/// Summary statistics of the polarity column of an event frame.
#[derive(Debug, Clone)]
pub struct PolarityStats {
/// Number of events (rows).
pub total_events: usize,
/// Sum of the polarity column; this is the positive-event count for boolean or 0/1 polarity.
pub positive_events: usize,
/// `total_events - positive_events`.
pub negative_events: usize,
/// `positive_events / total_events`.
pub positive_ratio: f64,
/// `negative_events / total_events`.
pub negative_ratio: f64,
/// `1 - 2 * |positive_ratio - 0.5|`: 1.0 when perfectly balanced, 0.0 when every event has the
/// same polarity.
pub polarity_balance: f64, }
impl PolarityStats {
    /// Computes polarity statistics for `df` in a single aggregation query.
    ///
    /// Positive events are counted as the sum of the polarity column, so the counts are exact
    /// for boolean or 0/1 polarity. For other encodings the sum is not a count (TODO confirm
    /// callers only pass 0/1 or boolean polarity). An empty frame yields [`Self::empty`]
    /// rather than NaN ratios.
    ///
    /// # Errors
    ///
    /// Propagates Polars errors from the query or from extracting the aggregated values (e.g. a
    /// negative polarity sum cannot be extracted as a count).
    #[cfg_attr(unix, instrument(skip(df)))]
    pub fn calculate_from_dataframe(df: LazyFrame) -> PolarsResult<Self> {
        let stats_df = df
            .select([
                len().alias("total_events"),
                col(COL_POLARITY).sum().alias("positive_events"),
            ])
            .with_columns([(col("total_events") - col("positive_events")).alias("negative_events")])
            .with_columns([
                (col("positive_events").cast(DataType::Float64)
                    / col("total_events").cast(DataType::Float64))
                .alias("positive_ratio"),
                (col("negative_events").cast(DataType::Float64)
                    / col("total_events").cast(DataType::Float64))
                .alias("negative_ratio"),
            ])
            .with_columns([
                // 1 - 2 * |positive_ratio - 0.5|
                (lit(1.0)
                    - when((col("positive_ratio") - lit(0.5)).gt(lit(0.0)))
                        .then((col("positive_ratio") - lit(0.5)) * lit(2.0))
                        .otherwise((lit(0.5) - col("positive_ratio")) * lit(2.0)))
                .alias("polarity_balance"),
            ])
            .collect()?;
        if stats_df.height() == 0 {
            return Ok(Self::empty());
        }
        let row = stats_df.get_row(0)?;
        // u64 so that frames with more than u32::MAX events still extract correctly.
        let total_events = row.0[0].try_extract::<u64>()? as usize;
        // An empty frame produces 0/0 = NaN ratios; report it as all-zero stats instead.
        if total_events == 0 {
            return Ok(Self::empty());
        }
        Ok(Self {
            total_events,
            positive_events: row.0[1].try_extract::<u64>()? as usize,
            negative_events: row.0[2].try_extract::<u64>()? as usize,
            positive_ratio: row.0[3].try_extract::<f64>()?,
            negative_ratio: row.0[4].try_extract::<f64>()?,
            polarity_balance: row.0[5].try_extract::<f64>()?,
        })
    }

    /// All-zero statistics, used for empty input.
    fn empty() -> Self {
        Self {
            total_events: 0,
            positive_events: 0,
            negative_events: 0,
            positive_ratio: 0.0,
            negative_ratio: 0.0,
            polarity_balance: 0.0,
        }
    }
}
impl std::fmt::Display for PolarityStats {
    /// One-line summary: counts with percentages for each polarity, then the balance score.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let positive_pct = self.positive_ratio * 100.0;
        let negative_pct = self.negative_ratio * 100.0;
        let balance = self.polarity_balance;
        write!(
            f,
            "Polarity: +{} ({:.1}%) / -{} ({:.1}%) | Balance: {:.3}",
            self.positive_events, positive_pct, self.negative_events, negative_pct, balance
        )
    }
}
/// Sorts `df` by time and keeps the first event, plus every event whose polarity differs from
/// the previous event and whose time gap to it (scaled by 1e6, i.e. microseconds if `t` is in
/// seconds) is at least `min_interval_us`.
///
/// # Errors
///
/// Currently never fails; the `Result` is kept for API symmetry with the other filters.
#[cfg_attr(unix, instrument(skip(df)))]
pub fn apply_alternating_polarity_filter_polars(
    df: LazyFrame,
    min_interval_us: f64,
) -> PolarsResult<LazyFrame> {
    let sorted_df = df.sort([COL_T], SortMultipleOptions::default());
    let prev_polarity = col(COL_POLARITY).shift(lit(1));
    let prev_time = col(COL_T).shift(lit(1));
    let polarity_alternates = col(COL_POLARITY).neq(prev_polarity.clone());
    // Method calls bind tighter than `*`: build the scaled gap first, then compare it. The
    // previous form compared the literal 1e6 against the threshold instead of the gap.
    let interval_us = (col(COL_T) - prev_time) * lit(1_000_000.0);
    let time_interval_ok = interval_us.gt_eq(lit(min_interval_us));
    // The first row has no predecessor (shifted value is null) and is always kept.
    let is_first = prev_polarity.is_null();
    let passes_filter = is_first.or(polarity_alternates.and(time_interval_ok));
    Ok(sorted_df.filter(passes_filter))
}
/// Keeps events whose spatio-temporal bin has a mean polarity within ±0.2 of `required_ratio`.
///
/// Bins are `(x / radius, y / radius, t * 10)` truncated to integers. The time bins are 0.1
/// units of `t`, presumably 100 ms if `t` is in seconds — TODO confirm. A `radius` of 0 is
/// treated as 1, because dividing by zero would collapse all coordinates into null/infinite
/// bins. The mean polarity is a positive fraction only for boolean or 0/1 polarity.
///
/// The helper columns `spatial_bin_x`, `spatial_bin_y`, `time_bin` and `local_positive_ratio`
/// remain in the returned frame.
///
/// # Errors
///
/// Currently never fails; the `Result` is kept for API symmetry with the other filters.
#[cfg_attr(unix, instrument(skip(df)))]
pub fn apply_balanced_polarity_filter_polars(
    df: LazyFrame,
    radius: u16,
    required_ratio: f64,
) -> PolarsResult<LazyFrame> {
    let bin_size = i64::from(radius.max(1));
    let tolerance = 0.2;
    let spatial_binned = df
        .with_columns([
            (col(COL_X) / lit(bin_size))
                .cast(DataType::Int64)
                .alias("spatial_bin_x"),
            (col(COL_Y) / lit(bin_size))
                .cast(DataType::Int64)
                .alias("spatial_bin_y"),
            (col(COL_T) * lit(10.0))
                .cast(DataType::Int64)
                .alias("time_bin"),
        ])
        .with_columns([col(COL_POLARITY)
            .cast(DataType::Float64)
            .mean()
            .over([col("spatial_bin_x"), col("spatial_bin_y"), col("time_bin")])
            .alias("local_positive_ratio")]);
    let balance_filter = col("local_positive_ratio")
        .gt_eq(lit(required_ratio - tolerance))
        .and(col("local_positive_ratio").lt_eq(lit(required_ratio + tolerance)));
    Ok(spatial_binned.filter(balance_filter))
}
/// Computes a one-row summary of polarity patterns for `df` (sorted by time first).
///
/// Columns: `total_events`, `positive_events`, `negative_events`, `positive_ratio`,
/// `negative_ratio`, `polarity_balance` (1 - 2·|ratio - 0.5|), `switch_count` (adjacent pairs
/// with differing polarity), `t_min`, `t_max`, `duration`, `polarity_switch_rate`
/// (switches / (events - 1)) and `event_rate` (events / duration). Counts and ratios assume
/// boolean or 0/1 polarity.
///
/// # Errors
///
/// Propagates Polars errors from collecting the query.
#[cfg_attr(unix, instrument(skip(df)))]
pub fn analyze_polarity_patterns_polars(df: LazyFrame) -> PolarsResult<DataFrame> {
    let sorted_df = df.sort([COL_T], SortMultipleOptions::default());
    sorted_df
        .select([
            len().alias("total_events"),
            col(COL_POLARITY).sum().alias("positive_events"),
            (len() - col(COL_POLARITY).sum()).alias("negative_events"),
            col(COL_POLARITY).mean().alias("positive_ratio"),
            (lit(1.0) - col(COL_POLARITY).mean()).alias("negative_ratio"),
            (lit(1.0)
                - when((col(COL_POLARITY).mean() - lit(0.5)).gt(lit(0.0)))
                    .then((col(COL_POLARITY).mean() - lit(0.5)) * lit(2.0))
                    .otherwise((lit(0.5) - col(COL_POLARITY).mean()) * lit(2.0)))
            .alias("polarity_balance"),
            (col(COL_POLARITY).neq(col(COL_POLARITY).shift(lit(1))))
                .sum()
                .cast(DataType::Float64)
                .alias("switch_count"),
            col(COL_T).min().alias("t_min"),
            col(COL_T).max().alias("t_max"),
            (col(COL_T).max() - col(COL_T).min()).alias("duration"),
        ])
        .with_columns([
            // After the aggregating `select` the frame has a single row, so `len()` here would
            // be 1. Use the captured `total_events` instead.
            (col("switch_count") / (col("total_events").cast(DataType::Float64) - lit(1.0)))
                .alias("polarity_switch_rate"),
            (col("total_events").cast(DataType::Float64) / col("duration")).alias("event_rate"),
        ])
        .collect()
}
/// Splits `df` into `(positive, negative)` frames: polarity > 0 and polarity <= 0.
///
/// Using `<= 0` for the negative side covers boolean, 0/1 and -1/+1 polarity, so every event
/// lands in exactly one of the two frames (null polarities are dropped from both).
#[cfg_attr(unix, instrument(skip(df)))]
pub fn separate_polarities_polars(df: LazyFrame) -> PolarsResult<(LazyFrame, LazyFrame)> {
    let positive_df = df.clone().filter(col(COL_POLARITY).gt(lit(0)));
    let negative_df = df.filter(col(COL_POLARITY).lt_eq(lit(0)));
    Ok((positive_df, negative_df))
}
// Unit tests for the polarity filters, statistics and encoding detection.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): `Events` (the return type of `create_test_events`) is neither imported here nor
// defined in this file; presumably it is a crate-level alias such as `Vec<Event>` that needs to
// be imported — confirm, otherwise this module will not compile.
use crate::{events_to_dataframe, Event};
/// Five events with alternating polarity (3 positive, 2 negative) at t = 1.0..=5.0.
fn create_test_events() -> Events {
vec![
Event {
t: 1.0,
x: 100,
y: 200,
polarity: true,
}, Event {
t: 2.0,
x: 150,
y: 250,
polarity: false,
}, Event {
t: 3.0,
x: 200,
y: 300,
polarity: true,
}, Event {
t: 4.0,
x: 250,
y: 350,
polarity: false,
}, Event {
t: 5.0,
x: 300,
y: 400,
polarity: true,
}, ]
}
#[test]
fn test_polarity_filter_creation() {
let filter = PolarityFilter::positive_only();
assert_eq!(filter.selection, PolaritySelection::PositiveOnly);
let filter = PolarityFilter::negative_only();
assert_eq!(filter.selection, PolaritySelection::NegativeOnly);
let filter = PolarityFilter::both();
assert_eq!(filter.selection, PolaritySelection::Both);
let filter = PolarityFilter::alternating(1000.0);
assert_eq!(filter.selection, PolaritySelection::Alternating);
assert_eq!(filter.alternating_min_interval, Some(1000.0));
}
#[test]
fn test_polarity_filtering_polars() -> PolarsResult<()> {
let events = create_test_events();
let df = events_to_dataframe(&events)?.lazy();
let positive_filtered = filter_by_polarity_polars(df.clone(), true)?;
let pos_result = positive_filtered.collect()?;
assert_eq!(pos_result.height(), 3);
let negative_filtered = filter_by_polarity_polars(df, false)?;
let neg_result = negative_filtered.collect()?;
assert_eq!(neg_result.height(), 2);
Ok(())
}
#[test]
fn test_polarity_stats_polars() -> PolarsResult<()> {
let events = create_test_events();
let df = events_to_dataframe(&events)?.lazy();
let stats = PolarityStats::calculate_from_dataframe(df)?;
assert_eq!(stats.total_events, 5);
assert_eq!(stats.positive_events, 3);
assert_eq!(stats.negative_events, 2);
assert!((stats.positive_ratio - 0.6).abs() < 0.001);
assert!((stats.negative_ratio - 0.4).abs() < 0.001);
// 1 - 2 * |0.6 - 0.5| = 0.8
assert!(stats.polarity_balance > 0.5);
Ok(())
}
#[test]
fn test_polarity_filter_dataframe_native() -> PolarsResult<()> {
let events = create_test_events();
let df = events_to_dataframe(&events)?.lazy();
let positive_filter = PolarityFilter::positive_only();
let positive_filtered = positive_filter.apply_to_dataframe(df.clone())?;
let pos_result = positive_filtered.collect()?;
assert_eq!(pos_result.height(), 3);
let negative_filter = PolarityFilter::negative_only();
let negative_filtered = negative_filter.apply_to_dataframe(df)?;
let neg_result = negative_filtered.collect()?;
assert_eq!(neg_result.height(), 2);
Ok(())
}
#[test]
fn test_filter_by_polarity_dataframe() -> PolarsResult<()> {
let events = create_test_events();
let df = events_to_dataframe(&events)?.lazy();
let positive_filtered = filter_by_polarity_df(df.clone(), true)?;
let pos_result = positive_filtered.collect()?;
assert_eq!(pos_result.height(), 3);
let negative_filtered = filter_by_polarity_df(df, false)?;
let neg_result = negative_filtered.collect()?;
assert_eq!(neg_result.height(), 2);
Ok(())
}
#[test]
fn test_pattern_analysis_polars() -> PolarsResult<()> {
let events = create_test_events();
let df = events_to_dataframe(&events)?.lazy();
let analysis_df = analyze_polarity_patterns_polars(df)?;
// The analysis is a single aggregated row.
assert_eq!(analysis_df.height(), 1);
assert!(analysis_df.column("positive_ratio").is_ok());
assert!(analysis_df.column("polarity_switch_rate").is_ok());
assert!(analysis_df.column("polarity_balance").is_ok());
Ok(())
}
#[test]
fn test_alternating_filter_polars() -> PolarsResult<()> {
// Successive time gaps are 0.0005, 0.0015 and 0.003, i.e. 500, 1500 and 3000 after the
// 1e6 scaling the filter applies, compared against a 1000.0 threshold.
let events = vec![
Event {
t: 1.0,
x: 100,
y: 200,
polarity: true,
},
Event {
t: 1.0005,
x: 100,
y: 200,
polarity: false,
}, Event {
t: 1.002,
x: 100,
y: 200,
polarity: false,
}, Event {
t: 1.005,
x: 100,
y: 200,
polarity: true,
}, ];
let df = events_to_dataframe(&events)?.lazy();
let filtered = apply_alternating_polarity_filter_polars(df, 1000.0)?; let result = filtered.collect()?;
assert!(result.height() <= events.len());
assert!(result.height() > 0);
Ok(())
}
#[test]
fn test_separate_polarities_polars() -> PolarsResult<()> {
let events = create_test_events();
let df = events_to_dataframe(&events)?.lazy();
let (pos_df, neg_df) = separate_polarities_polars(df)?;
let pos_result = pos_df.collect()?;
let neg_result = neg_df.collect()?;
assert_eq!(pos_result.height(), 3); assert_eq!(neg_result.height(), 2);
Ok(())
}
#[test]
fn test_encoding_detection() {
let zero_one_values = vec![0.0, 1.0, 0.0, 1.0];
let encoding = PolarityEncoding::detect_from_raw_values(&zero_one_values);
assert_eq!(encoding, PolarityEncoding::OneZero);
let minus_one_values = vec![-1.0, 1.0, -1.0, 1.0];
let encoding = PolarityEncoding::detect_from_raw_values(&minus_one_values);
assert_eq!(encoding, PolarityEncoding::OneMinus);
let mixed_values = vec![0.0, 1.0, -1.0, 1.0];
let encoding = PolarityEncoding::detect_from_raw_values(&mixed_values);
assert_eq!(encoding, PolarityEncoding::Mixed);
}
// NOTE(review): `PolarityFilter::apply`, `analyze_polarity_patterns` and
// `PolarityStats::calculate` are not defined in this file; presumably they are legacy APIs
// defined elsewhere in the crate — confirm they still exist, otherwise this test does not compile.
#[test]
fn test_legacy_compatibility() {
let events = create_test_events();
let filter = PolarityFilter::positive_only();
let filtered = filter.apply(&events).unwrap();
assert_eq!(filtered.len(), 3);
let analysis = analyze_polarity_patterns(&events).unwrap();
assert!(analysis.contains_key("positive_ratio"));
assert!(analysis.contains_key("polarity_switch_rate"));
let stats = PolarityStats::calculate(&events);
assert_eq!(stats.total_events, 5);
assert_eq!(stats.positive_events, 3);
}
#[test]
fn test_filter_validation() {
assert!(PolarityFilter::positive_only().validate().is_ok());
assert!(PolarityFilter::alternating(1000.0).validate().is_ok());
assert!(PolarityFilter::balanced(5, 0.3).validate().is_ok());
let mut invalid_alternating = PolarityFilter::alternating(1000.0);
invalid_alternating.alternating_min_interval = None;
assert!(invalid_alternating.validate().is_err());
let mut invalid_balanced = PolarityFilter::balanced(5, 0.3);
invalid_balanced.balance_ratio = Some(1.5); assert!(invalid_balanced.validate().is_err());
}
}