use crate::{Result, TreeBoostError};
/// How lag values that would reach before the first row are filled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, serde::Serialize, serde::Deserialize)]
pub enum NaNStrategy {
    /// Leave the missing lag as `NaN` (the default).
    #[default]
    Keep,
    /// Use the first row's value of the same feature.
    ForwardFill,
    /// Use a fixed value. The `f32` is stored as its raw bit pattern (reinterpreted as
    /// `i32`) so the enum can derive `Eq`; build it with `NaNStrategy::constant`.
    Constant(i32),
}
impl NaNStrategy {
pub fn constant(value: f32) -> Self {
Self::Constant(value.to_bits() as i32)
}
fn get_constant(&self) -> Option<f32> {
match self {
Self::Constant(bits) => Some(f32::from_bits(*bits as u32)),
_ => None,
}
}
}
/// Generates lagged copies of input features: the value of the same feature `lag` rows earlier.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct LagGenerator {
    /// Lag distances in rows; one output value per lag and per feature.
    lags: Vec<usize>,
    /// Fill rule for rows that have fewer than `lag` predecessors.
    nan_strategy: NaNStrategy,
    /// Base names used by `output_names`; `"feature"` is used when unset.
    feature_names: Option<Vec<String>>,
}
impl LagGenerator {
    /// Creates a generator that emits, for every input feature, its value `lag` rows
    /// earlier for each entry of `lags`.
    ///
    /// Missing history uses [`NaNStrategy::Keep`] until changed with
    /// [`Self::with_nan_strategy`].
    pub fn new(lags: Vec<usize>) -> Self {
        Self {
            lags,
            nan_strategy: NaNStrategy::default(),
            feature_names: None,
        }
    }

    /// Creates a generator with the consecutive lags `1..=max_lag` (none when `max_lag == 0`).
    pub fn range(max_lag: usize) -> Self {
        Self::new((1..=max_lag).collect())
    }

    /// Sets how lag values that would reach before the first row are filled.
    pub fn with_nan_strategy(mut self, strategy: NaNStrategy) -> Self {
        self.nan_strategy = strategy;
        self
    }

    /// Sets the base names used by [`Self::output_names`].
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = Some(names);
        self
    }

    /// Number of configured lags.
    pub fn num_lags(&self) -> usize {
        self.lags.len()
    }

    /// Names of the generated lag columns, formatted `"{base}_lag_{lag}"`, where `base`
    /// defaults to `"feature"` when no names were set.
    ///
    /// The order is feature-major (every lag of the first name, then the next name). That
    /// matches concatenating [`Self::transform_column`] output feature by feature.
    /// NOTE(review): [`Self::transform`] lays its lag columns out lag-major (every feature
    /// for the first lag, then the next lag). With several features *and* several lags
    /// these names therefore do not line up with `transform`'s columns. Confirm which
    /// layout callers rely on.
    pub fn output_names(&self) -> Vec<String> {
        let default_names = ["feature".to_string()];
        let base_names: &[String] = self.feature_names.as_deref().unwrap_or(&default_names);
        base_names
            .iter()
            .flat_map(|base| {
                self.lags
                    .iter()
                    .map(move |lag| format!("{}_lag_{}", base, lag))
            })
            .collect()
    }

    /// Value written where a lag reaches before the first row.
    ///
    /// `first` is the first row's value of the same feature, used by
    /// [`NaNStrategy::ForwardFill`].
    fn fill_value(&self, first: f32) -> f32 {
        match self.nan_strategy {
            NaNStrategy::Keep => f32::NAN,
            // No earlier value exists to carry forward, so the first observation is
            // copied backwards instead.
            NaNStrategy::ForwardFill => first,
            NaNStrategy::Constant(_) => self.nan_strategy.get_constant().unwrap_or(0.0),
        }
    }

    /// Appends lag features to row-major `data` holding `num_features` values per row.
    ///
    /// Each output row contains the original `num_features` values, followed by one group
    /// of `num_features` lagged values per entry in `lags` (in `lags` order). Output rows
    /// therefore have `num_features * (1 + lags.len())` values.
    ///
    /// # Errors
    /// Returns [`TreeBoostError::Data`] when `data` is non-empty and `num_features` is
    /// zero or does not divide `data.len()`. Empty `data` yields an empty result.
    pub fn transform(&self, data: &[f32], num_features: usize) -> Result<Vec<f32>> {
        if data.is_empty() {
            return Ok(Vec::new());
        }
        // Rejecting zero here prevents a division-by-zero panic below.
        if num_features == 0 || data.len() % num_features != 0 {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }

        let first_row = &data[..num_features];
        let num_rows = data.len() / num_features;
        let mut result = Vec::with_capacity(num_rows * num_features * (1 + self.lags.len()));
        for (row, current) in data.chunks_exact(num_features).enumerate() {
            result.extend_from_slice(current);
            for &lag in &self.lags {
                match row.checked_sub(lag) {
                    Some(src_row) => {
                        let start = src_row * num_features;
                        result.extend_from_slice(&data[start..start + num_features]);
                    }
                    // Not enough history: fill according to the configured strategy.
                    None => {
                        result.extend(first_row.iter().map(|&first| self.fill_value(first)))
                    }
                }
            }
        }
        Ok(result)
    }

    /// Lags a single column and returns column-major output.
    ///
    /// The output holds `column.len()` values for the first lag, then the next lag, and
    /// so on. An empty column yields an empty result.
    pub fn transform_column(&self, column: &[f32]) -> Vec<f32> {
        let mut result = Vec::with_capacity(column.len() * self.lags.len());
        for &lag in &self.lags {
            result.extend((0..column.len()).map(|row| match row.checked_sub(lag) {
                Some(src_row) => column[src_row],
                // `row` exists, so `column` is non-empty and `column[0]` is valid.
                None => self.fill_value(column[0]),
            }));
        }
        result
    }
}
/// Statistic computed over each rolling window; NaN values inside a window are skipped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum RollingStat {
    /// Arithmetic mean.
    Mean,
    /// Sample standard deviation (`n - 1` denominator); NaN for fewer than two values.
    Std,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// Sum of the values.
    Sum,
    /// Number of non-NaN values in the window.
    Count,
    /// Middle value; the mean of the two middle values for an even count.
    Median,
    /// Sample variance (`n - 1` denominator); NaN for fewer than two values.
    Var,
}
impl RollingStat {
pub fn suffix(&self) -> &'static str {
match self {
Self::Mean => "mean",
Self::Std => "std",
Self::Min => "min",
Self::Max => "max",
Self::Sum => "sum",
Self::Count => "count",
Self::Median => "median",
Self::Var => "var",
}
}
}
/// Computes rolling-window statistics for each input feature.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RollingGenerator {
    /// Window length in rows.
    window: usize,
    /// Statistics emitted per feature, in this order.
    stats: Vec<RollingStat>,
    /// Minimum number of non-NaN values a window needs; with fewer, the output is NaN.
    min_periods: usize,
    /// Whether windows are centred on the current row rather than ending at it.
    center: bool,
    /// Base names used by `output_names`; `"feature"` is used when unset.
    feature_names: Option<Vec<String>>,
}
impl RollingGenerator {
pub fn new(window: usize) -> Self {
Self {
window,
stats: vec![RollingStat::Mean],
min_periods: 1,
center: false,
feature_names: None,
}
}
pub fn with_stats(mut self, stats: Vec<RollingStat>) -> Self {
self.stats = stats;
self
}
pub fn with_min_periods(mut self, min_periods: usize) -> Self {
self.min_periods = min_periods;
self
}
pub fn centered(mut self) -> Self {
self.center = true;
self
}
pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
self.feature_names = Some(names);
self
}
pub fn output_names(&self) -> Vec<String> {
let base_names: Vec<String> = self
.feature_names
.clone()
.unwrap_or_else(|| vec!["feature".to_string()]);
let mut names = Vec::new();
for base in &base_names {
for stat in &self.stats {
names.push(format!("{}_roll_{}_{}", base, self.window, stat.suffix()));
}
}
names
}
pub fn transform(&self, data: &[f32], num_features: usize) -> Result<Vec<f32>> {
if data.is_empty() {
return Ok(Vec::new());
}
let num_rows = data.len() / num_features;
if num_rows * num_features != data.len() {
return Err(TreeBoostError::Data(format!(
"Data length {} not divisible by num_features {}",
data.len(),
num_features
)));
}
let num_new_features = num_features * self.stats.len();
let total_features = num_features + num_new_features;
let mut result = vec![f32::NAN; num_rows * total_features];
for row in 0..num_rows {
let src_start = row * num_features;
let dst_start = row * total_features;
result[dst_start..dst_start + num_features]
.copy_from_slice(&data[src_start..src_start + num_features]);
let mut stat_offset = num_features;
for stat in &self.stats {
for feat in 0..num_features {
let dst_idx = dst_start + stat_offset + feat;
let (start_row, end_row) = if self.center {
let half = self.window / 2;
let start = row.saturating_sub(half);
let end = (row + half + 1).min(num_rows);
(start, end)
} else {
let start = row.saturating_sub(self.window - 1);
(start, row + 1)
};
let mut window_vals: Vec<f32> = Vec::with_capacity(end_row - start_row);
for r in start_row..end_row {
let val = data[r * num_features + feat];
if !val.is_nan() {
window_vals.push(val);
}
}
if window_vals.len() < self.min_periods {
result[dst_idx] = f32::NAN;
continue;
}
result[dst_idx] = self.compute_stat(*stat, &window_vals);
}
stat_offset += num_features;
}
}
Ok(result)
}
fn compute_stat(&self, stat: RollingStat, values: &[f32]) -> f32 {
if values.is_empty() {
return f32::NAN;
}
match stat {
RollingStat::Mean => {
let sum: f32 = values.iter().sum();
sum / values.len() as f32
}
RollingStat::Std => {
if values.len() < 2 {
return f32::NAN;
}
let mean = values.iter().sum::<f32>() / values.len() as f32;
let variance = values.iter().map(|x| (x - mean).powi(2)).sum::<f32>()
/ (values.len() - 1) as f32;
variance.sqrt()
}
RollingStat::Var => {
if values.len() < 2 {
return f32::NAN;
}
let mean = values.iter().sum::<f32>() / values.len() as f32;
values.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / (values.len() - 1) as f32
}
RollingStat::Min => values.iter().cloned().fold(f32::INFINITY, f32::min),
RollingStat::Max => values.iter().cloned().fold(f32::NEG_INFINITY, f32::max),
RollingStat::Sum => values.iter().sum(),
RollingStat::Count => values.len() as f32,
RollingStat::Median => {
let mut sorted = values.to_vec();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let mid = sorted.len() / 2;
if sorted.len().is_multiple_of(2) {
(sorted[mid - 1] + sorted[mid]) / 2.0
} else {
sorted[mid]
}
}
}
}
pub fn transform_column(&self, column: &[f32]) -> Vec<f32> {
let num_rows = column.len();
let mut result = vec![f32::NAN; num_rows * self.stats.len()];
for (stat_idx, stat) in self.stats.iter().enumerate() {
let offset = stat_idx * num_rows;
for row in 0..num_rows {
let start_row = row.saturating_sub(self.window - 1);
let end_row = row + 1;
let mut window_vals: Vec<f32> = Vec::with_capacity(end_row - start_row);
for &val in &column[start_row..end_row] {
if !val.is_nan() {
window_vals.push(val);
}
}
if window_vals.len() < self.min_periods {
result[offset + row] = f32::NAN;
continue;
}
result[offset + row] = self.compute_stat(*stat, &window_vals);
}
}
result
}
}
/// Exponentially weighted moving average of each input feature.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EwmaGenerator {
    /// Smoothing factor; `new` asserts it lies in `(0, 1]`.
    alpha: f32,
    /// Optional feature names, set by `with_feature_names`.
    feature_names: Option<Vec<String>>,
    /// Set to `true` by `new` and cleared by `without_adjust`. Selects the adjusted
    /// (weight-normalised) average instead of the plain recursive one.
    adjust: bool,
}
impl EwmaGenerator {
    /// Creates an adjusted EWMA with smoothing factor `alpha`.
    ///
    /// # Panics
    /// Panics unless `0 < alpha <= 1`; a NaN `alpha` also panics.
    pub fn new(alpha: f32) -> Self {
        assert!(alpha > 0.0 && alpha <= 1.0, "Alpha must be in (0, 1]");
        Self {
            alpha,
            feature_names: None,
            adjust: true,
        }
    }

    /// Creates an EWMA with `alpha = 2 / (span + 1)`.
    ///
    /// # Panics
    /// Panics if `span == 0`.
    pub fn from_span(span: usize) -> Self {
        assert!(span >= 1, "Span must be >= 1");
        Self::new(2.0 / (span as f32 + 1.0))
    }

    /// Creates an EWMA with `alpha = 1 - exp(ln(0.5) / halflife)`, so that weights halve
    /// every `halflife` rows.
    ///
    /// # Panics
    /// Panics unless `halflife > 0`.
    pub fn from_halflife(halflife: f32) -> Self {
        assert!(halflife > 0.0, "Halflife must be > 0");
        let alpha = 1.0 - (0.5_f32.ln() / halflife).exp();
        Self::new(alpha)
    }

    /// Switches to the plain recursive form `y_t = alpha * x_t + (1 - alpha) * y_{t-1}`.
    pub fn without_adjust(mut self) -> Self {
        self.adjust = false;
        self
    }

    /// Stores optional feature names.
    pub fn with_feature_names(mut self, names: Vec<String>) -> Self {
        self.feature_names = Some(names);
        self
    }

    /// Feeds one observation into `state` and returns the smoothed value to emit.
    ///
    /// `state` is `(numerator, denominator)`, or `None` before the first non-NaN value.
    /// - The first observation seeds the average with its own value.
    /// - With `adjust`, the result is `sum (1-alpha)^i * x_{t-i} / sum (1-alpha)^i`.
    /// - Without `adjust`, it is `alpha * x_t + (1 - alpha) * y_{t-1}`.
    ///
    /// A NaN observation leaves the state untouched and repeats the previous output
    /// (NaN if nothing has been observed yet).
    fn step(&self, state: &mut Option<(f32, f32)>, val: f32) -> f32 {
        if val.is_nan() {
            return match *state {
                Some((num, den)) => num / den,
                None => f32::NAN,
            };
        }
        let decay = 1.0 - self.alpha;
        let (num, den) = match *state {
            None => (val, 1.0),
            // Running weighted sum and running sum of weights.
            Some((num, den)) if self.adjust => (val + decay * num, 1.0 + decay * den),
            // Plain recursion; the denominator stays 1.
            Some((num, _)) => (self.alpha * val + decay * num, 1.0),
        };
        *state = Some((num, den));
        num / den
    }

    /// Smooths each feature of row-major `data` (`num_features` values per row)
    /// independently down the rows. The output has the same shape as `data`.
    ///
    /// # Errors
    /// Returns [`TreeBoostError::Data`] when `data` is non-empty and `num_features` is
    /// zero or does not divide `data.len()`. Empty `data` yields an empty result.
    pub fn transform(&self, data: &[f32], num_features: usize) -> Result<Vec<f32>> {
        if data.is_empty() {
            return Ok(Vec::new());
        }
        // Rejecting zero here prevents a division-by-zero panic in `chunks_exact`.
        if num_features == 0 || data.len() % num_features != 0 {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        let mut states: Vec<Option<(f32, f32)>> = vec![None; num_features];
        let mut result = Vec::with_capacity(data.len());
        for row in data.chunks_exact(num_features) {
            for (&val, state) in row.iter().zip(states.iter_mut()) {
                result.push(self.step(state, val));
            }
        }
        Ok(result)
    }

    /// Smooths a single column; the output has the same length as `column`.
    pub fn transform_column(&self, column: &[f32]) -> Vec<f32> {
        let mut state: Option<(f32, f32)> = None;
        column
            .iter()
            .map(|&val| self.step(&mut state, val))
            .collect()
    }
}
/// Calendar component extracted from a timestamp by `SeasonalGenerator`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum SeasonalComponent {
    /// Hour of day, 0-23.
    Hour,
    /// Day of week, Monday = 0 through Sunday = 6.
    DayOfWeek,
    /// Day of month, 1-31.
    DayOfMonth,
    /// Day of year, 1-366.
    DayOfYear,
    /// Week of year, 1-53.
    WeekOfYear,
    /// Month, 1-12.
    Month,
    /// Quarter, 1-4.
    Quarter,
    /// Calendar year; never cyclically encoded.
    Year,
    /// 1.0 on Saturday or Sunday, otherwise 0.0; never cyclically encoded.
    IsWeekend,
}
impl SeasonalComponent {
pub fn suffix(&self) -> &'static str {
match self {
Self::Hour => "hour",
Self::DayOfWeek => "dow",
Self::DayOfMonth => "dom",
Self::DayOfYear => "doy",
Self::WeekOfYear => "woy",
Self::Month => "month",
Self::Quarter => "quarter",
Self::Year => "year",
Self::IsWeekend => "is_weekend",
}
}
pub fn max_value(&self) -> f32 {
match self {
Self::Hour => 24.0,
Self::DayOfWeek => 7.0,
Self::DayOfMonth => 31.0,
Self::DayOfYear => 366.0,
Self::WeekOfYear => 53.0,
Self::Month => 12.0,
Self::Quarter => 4.0,
Self::Year => 1.0, Self::IsWeekend => 1.0, }
}
pub fn is_cyclical(&self) -> bool {
!matches!(self, Self::Year | Self::IsWeekend)
}
}
/// Extracts calendar features (hour, weekday, month, ...) from timestamps given in
/// seconds since the Unix epoch, with no timezone offset applied.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SeasonalGenerator {
    /// Components emitted per timestamp, in this order.
    components: Vec<SeasonalComponent>,
    /// When `true`, cyclical components are encoded as a `sin`/`cos` pair.
    cyclical: bool,
}
impl SeasonalGenerator {
    /// Creates a generator for `components` that emits raw (non-cyclical) values.
    pub fn new(components: Vec<SeasonalComponent>) -> Self {
        Self {
            components,
            cyclical: false,
        }
    }

    /// Preset: hour, day of week, day of month and month.
    pub fn datetime() -> Self {
        Self::new(vec![
            SeasonalComponent::Hour,
            SeasonalComponent::DayOfWeek,
            SeasonalComponent::DayOfMonth,
            SeasonalComponent::Month,
        ])
    }

    /// Preset: day of week, day of month, month and quarter.
    pub fn date_only() -> Self {
        Self::new(vec![
            SeasonalComponent::DayOfWeek,
            SeasonalComponent::DayOfMonth,
            SeasonalComponent::Month,
            SeasonalComponent::Quarter,
        ])
    }

    /// When `cyclical` is `true`, components for which
    /// [`SeasonalComponent::is_cyclical`] holds are emitted as a `sin`/`cos` pair
    /// instead of one raw value.
    pub fn with_cyclical(mut self, cyclical: bool) -> Self {
        self.cyclical = cyclical;
        self
    }

    /// Column names matching [`Self::transform_timestamps`] output.
    ///
    /// Names are `"{prefix}_{suffix}"`, or `"{prefix}_{suffix}_sin"` and
    /// `"{prefix}_{suffix}_cos"` for sin/cos-encoded components.
    pub fn output_names(&self, prefix: &str) -> Vec<String> {
        let mut names = Vec::with_capacity(self.num_features());
        for comp in &self.components {
            if self.cyclical && comp.is_cyclical() {
                names.push(format!("{}_{}_sin", prefix, comp.suffix()));
                names.push(format!("{}_{}_cos", prefix, comp.suffix()));
            } else {
                names.push(format!("{}_{}", prefix, comp.suffix()));
            }
        }
        names
    }

    /// Number of values emitted per timestamp: two per sin/cos-encoded component, one
    /// for every other component.
    pub fn num_features(&self) -> usize {
        self.components
            .iter()
            .map(|c| if self.cyclical && c.is_cyclical() { 2 } else { 1 })
            .sum()
    }

    /// Extracts the configured calendar components from Unix timestamps (seconds).
    ///
    /// Output is row-major, with [`Self::num_features`] values per timestamp in component
    /// order. Day of week counts Monday as 0. Fractional seconds are floored, so instants
    /// before the epoch land on the correct, earlier day. Non-finite timestamps produce a
    /// row of NaN.
    pub fn transform_timestamps(&self, timestamps: &[f64]) -> Vec<f32> {
        use std::f64::consts::PI;
        const SECS_PER_DAY: i64 = 86_400;

        let num_features = self.num_features();
        let mut result = Vec::with_capacity(timestamps.len() * num_features);
        for &ts in timestamps {
            if !ts.is_finite() {
                // `NaN as i64` is 0, which would silently report 1970-01-01.
                result.resize(result.len() + num_features, f32::NAN);
                continue;
            }
            let secs = ts.floor() as i64;
            // Euclidean division keeps pre-epoch instants on the previous day instead of
            // rounding the day toward zero.
            let days = secs.div_euclid(SECS_PER_DAY);
            let hour = (secs.rem_euclid(SECS_PER_DAY) / 3600) as f32;
            // 1970-01-01 was a Thursday, which is 3 when Monday is 0.
            let day_of_week = (days + 3).rem_euclid(7) as f32;
            let (year, month, day_of_month, day_of_year) = days_to_ymd(days);
            // Days 1..=7 form week 1; days 365 and 366 fall into week 53.
            let week_of_year = ((day_of_year - 1) / 7 + 1) as f32;
            let quarter = ((month - 1) / 3 + 1) as f32;
            let is_weekend = if day_of_week >= 5.0 { 1.0 } else { 0.0 };

            for comp in &self.components {
                let value = match comp {
                    SeasonalComponent::Hour => hour,
                    SeasonalComponent::DayOfWeek => day_of_week,
                    SeasonalComponent::DayOfMonth => day_of_month as f32,
                    SeasonalComponent::DayOfYear => day_of_year as f32,
                    SeasonalComponent::WeekOfYear => week_of_year,
                    SeasonalComponent::Month => month as f32,
                    SeasonalComponent::Quarter => quarter,
                    SeasonalComponent::Year => year as f32,
                    SeasonalComponent::IsWeekend => is_weekend,
                };
                if self.cyclical && comp.is_cyclical() {
                    let angle = 2.0 * PI * f64::from(value) / f64::from(comp.max_value());
                    result.push(angle.sin() as f32);
                    result.push(angle.cos() as f32);
                } else {
                    result.push(value);
                }
            }
        }
        result
    }
}
/// Converts days since 1970-01-01 into `(year, month, day_of_month, day_of_year)`.
///
/// Implements Howard Hinnant's `civil_from_days` on the proleptic Gregorian calendar.
/// The epoch is shifted to 0000-03-01 so that leap days fall at the end of each
/// 400-year era; the date is then decomposed into era, year-of-era and day-of-year.
/// `month`, `day_of_month` and `day_of_year` are all 1-based.
fn days_to_ymd(days: i64) -> (i32, i32, i32, i32) {
    const DAYS_PER_ERA: i64 = 146_097;
    const CUMULATIVE_DAYS: [i32; 12] = [0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334];

    // Days from 0000-03-01 to 1970-01-01.
    let shifted = days + 719_468;
    // Floor division, so dates before 0000-03-01 belong to a negative era.
    let era = shifted.div_euclid(DAYS_PER_ERA);
    let day_of_era = (shifted - era * DAYS_PER_ERA) as u32; // 0..=146096
    let year_of_era =
        (day_of_era - day_of_era / 1_460 + day_of_era / 36_524 - day_of_era / 146_096) / 365; // 0..=399
    // Day within the March-based year, 0..=365.
    let day_of_march_year = day_of_era - (365 * year_of_era + year_of_era / 4 - year_of_era / 100);
    // March = 0 ... February = 11.
    let march_month = (5 * day_of_march_year + 2) / 153;
    let day_of_month = day_of_march_year - (153 * march_month + 2) / 5 + 1;
    let month = if march_month < 10 {
        march_month + 3
    } else {
        march_month - 9
    };
    // January and February belong to the following civil year.
    let year = (year_of_era as i64 + era * 400 + i64::from(month <= 2)) as i32;

    let is_leap = year % 4 == 0 && (year % 100 != 0 || year % 400 == 0);
    let leap_bonus = i32::from(month > 2 && is_leap);
    let day_of_year = CUMULATIVE_DAYS[month as usize - 1] + day_of_month as i32 + leap_bonus;
    (year, month as i32, day_of_month as i32, day_of_year)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 2024-01-15 12:00:00 UTC, a Monday.
    const MONDAY_NOON: f64 = 1_705_320_000.0;

    /// Tolerance check shared by the floating-point assertions below.
    fn approx_eq(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 0.001
    }

    #[test]
    fn test_lag_generator_basic() {
        let generator = LagGenerator::new(vec![1, 2]);
        // Five rows of two features each.
        let data = vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0, 4.0, 40.0, 5.0, 50.0];
        let result = generator.transform(&data, 2).unwrap();
        // Each output row holds 2 originals + 2 lags x 2 features = 6 values.
        assert_eq!(result.len(), 30);

        // Row 0: originals are copied and every lag is missing.
        assert_eq!(result[..2].to_vec(), vec![1.0, 10.0]);
        assert!(result[2..6].iter().all(|v| v.is_nan()));

        // Row 2: originals, then lag 1 (row 1), then lag 2 (row 0).
        assert_eq!(
            result[12..18].to_vec(),
            vec![3.0, 30.0, 2.0, 20.0, 1.0, 10.0]
        );
    }

    #[test]
    fn test_lag_generator_range() {
        assert_eq!(LagGenerator::range(3).lags, vec![1, 2, 3]);
    }

    #[test]
    fn test_lag_generator_forward_fill() {
        let generator =
            LagGenerator::new(vec![1, 2]).with_nan_strategy(NaNStrategy::ForwardFill);
        let result = generator.transform(&[5.0, 6.0, 7.0, 8.0], 1).unwrap();
        // Row 0 has no history, so both lags take the first observed value.
        assert_eq!(result[1..3].to_vec(), vec![5.0, 5.0]);
    }

    #[test]
    fn test_lag_generator_column() {
        let generator = LagGenerator::new(vec![1, 2]);
        let result = generator.transform_column(&[10.0, 20.0, 30.0, 40.0, 50.0]);
        // Column-major: five lag-1 values followed by five lag-2 values.
        assert_eq!(result.len(), 10);
        let (lag1, lag2) = result.split_at(5);
        assert!(lag1[0].is_nan());
        assert_eq!(lag1[1..].to_vec(), vec![10.0, 20.0, 30.0, 40.0]);
        assert!(lag2[0].is_nan() && lag2[1].is_nan());
        assert_eq!(lag2[2..].to_vec(), vec![10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_rolling_mean() {
        let generator = RollingGenerator::new(3).with_stats(vec![RollingStat::Mean]);
        let result = generator.transform(&[1.0, 2.0, 3.0, 4.0, 5.0], 1).unwrap();
        // Each row is (original, mean over up to three trailing rows).
        assert_eq!(result.len(), 10);
        assert_eq!((result[0], result[1]), (1.0, 1.0));
        assert_eq!((result[4], result[5]), (3.0, 2.0));
        assert_eq!((result[8], result[9]), (5.0, 4.0));
    }

    #[test]
    fn test_rolling_multiple_stats() {
        let generator = RollingGenerator::new(3).with_stats(vec![
            RollingStat::Min,
            RollingStat::Max,
            RollingStat::Sum,
        ]);
        let result = generator.transform(&[1.0, 5.0, 3.0, 7.0, 2.0], 1).unwrap();
        assert_eq!(result.len(), 20);
        // Row 4: original 2.0; window [3, 7, 2] gives min 2, max 7, sum 12.
        assert_eq!(result[16..20].to_vec(), vec![2.0, 2.0, 7.0, 12.0]);
    }

    #[test]
    fn test_rolling_min_periods() {
        let generator = RollingGenerator::new(3)
            .with_stats(vec![RollingStat::Mean])
            .with_min_periods(3);
        let result = generator.transform(&[1.0, 2.0, 3.0, 4.0, 5.0], 1).unwrap();
        // Rows 0 and 1 have fewer than three values; row 2 has a full window.
        assert!(result[1].is_nan());
        assert!(result[3].is_nan());
        assert!(!result[5].is_nan());
    }

    #[test]
    fn test_rolling_std() {
        let generator = RollingGenerator::new(3).with_stats(vec![RollingStat::Std]);
        let result = generator.transform(&[1.0, 2.0, 3.0], 1).unwrap();
        // The sample standard deviation of [1, 2, 3] is 1.
        assert!(approx_eq(result[5], 1.0));
    }

    #[test]
    fn test_rolling_median() {
        let generator = RollingGenerator::new(3).with_stats(vec![RollingStat::Median]);
        let result = generator.transform(&[1.0, 5.0, 3.0, 7.0, 2.0], 1).unwrap();
        // Windows [1, 5, 3] (row 2) and [3, 7, 2] (row 4) both have median 3.
        assert_eq!(result[5], 3.0);
        assert_eq!(result[9], 3.0);
    }

    #[test]
    fn test_ewma_basic() {
        let result = EwmaGenerator::new(0.5)
            .transform(&[1.0, 2.0, 3.0, 4.0], 1)
            .unwrap();
        assert_eq!(result.len(), 4);
        // The first value seeds the average; the next one lies strictly between inputs.
        assert_eq!(result[0], 1.0);
        assert!(result[1] > 1.0 && result[1] < 2.0);
    }

    #[test]
    fn test_ewma_from_span() {
        // alpha = 2 / (span + 1) = 0.5 for span 3.
        assert!(approx_eq(EwmaGenerator::from_span(3).alpha, 0.5));
    }

    #[test]
    fn test_ewma_handles_nan() {
        let result = EwmaGenerator::new(0.5)
            .transform(&[1.0, f32::NAN, 3.0, 4.0], 1)
            .unwrap();
        // A missing value repeats the previous output.
        assert_eq!(result[1], result[0]);
    }

    #[test]
    fn test_seasonal_basic() {
        let generator = SeasonalGenerator::new(vec![
            SeasonalComponent::Hour,
            SeasonalComponent::DayOfWeek,
            SeasonalComponent::Month,
        ]);
        // Hour 12, Monday (0), January (1).
        assert_eq!(
            generator.transform_timestamps(&[MONDAY_NOON]),
            vec![12.0, 0.0, 1.0]
        );
    }

    #[test]
    fn test_seasonal_cyclical() {
        let generator =
            SeasonalGenerator::new(vec![SeasonalComponent::Hour]).with_cyclical(true);
        // Midnight is angle 0, so (sin, cos) = (0, 1).
        let midnight = generator.transform_timestamps(&[1_705_276_800.0]);
        assert!(approx_eq(midnight[0], 0.0) && approx_eq(midnight[1], 1.0));
        // Noon is angle pi, so (sin, cos) = (0, -1).
        let noon = generator.transform_timestamps(&[MONDAY_NOON]);
        assert!(approx_eq(noon[0], 0.0) && approx_eq(noon[1], -1.0));
    }

    #[test]
    fn test_seasonal_weekend() {
        let generator = SeasonalGenerator::new(vec![SeasonalComponent::IsWeekend]);
        assert_eq!(generator.transform_timestamps(&[MONDAY_NOON])[0], 0.0);
        // 2024-01-13 12:00:00 UTC, a Saturday.
        assert_eq!(generator.transform_timestamps(&[1_705_147_200.0])[0], 1.0);
    }

    #[test]
    fn test_seasonal_output_names() {
        let generator = SeasonalGenerator::new(vec![
            SeasonalComponent::Hour,
            SeasonalComponent::DayOfWeek,
        ])
        .with_cyclical(true);
        // Each cyclical component expands to a sin/cos pair.
        assert_eq!(
            generator.output_names("timestamp"),
            vec![
                "timestamp_hour_sin",
                "timestamp_hour_cos",
                "timestamp_dow_sin",
                "timestamp_dow_cos",
            ]
        );
    }

    #[test]
    fn test_days_to_ymd() {
        let (year, month, day, _) = days_to_ymd(19_737);
        assert_eq!((year, month, day), (2024, 1, 15));
        // The epoch itself is day 1 of 1970.
        assert_eq!(days_to_ymd(0), (1970, 1, 1, 1));
    }

    #[test]
    fn test_lag_generator_serialization() {
        let original =
            LagGenerator::new(vec![1, 7, 14]).with_nan_strategy(NaNStrategy::ForwardFill);
        let json = serde_json::to_string(&original).unwrap();
        let loaded: LagGenerator = serde_json::from_str(&json).unwrap();
        assert_eq!(loaded.lags, vec![1, 7, 14]);
        assert_eq!(loaded.nan_strategy, NaNStrategy::ForwardFill);
    }

    #[test]
    fn test_rolling_generator_serialization() {
        let original = RollingGenerator::new(7)
            .with_stats(vec![RollingStat::Mean, RollingStat::Std])
            .with_min_periods(3);
        let json = serde_json::to_string(&original).unwrap();
        let loaded: RollingGenerator = serde_json::from_str(&json).unwrap();
        assert_eq!(loaded.window, 7);
        assert_eq!(loaded.stats.len(), 2);
        assert_eq!(loaded.min_periods, 3);
    }

    #[test]
    fn test_ewma_generator_serialization() {
        let original = EwmaGenerator::new(0.3).without_adjust();
        let json = serde_json::to_string(&original).unwrap();
        let loaded: EwmaGenerator = serde_json::from_str(&json).unwrap();
        assert!((loaded.alpha - 0.3).abs() < 1e-6);
        assert!(!loaded.adjust);
    }

    #[test]
    fn test_seasonal_generator_serialization() {
        let original = SeasonalGenerator::new(vec![
            SeasonalComponent::Hour,
            SeasonalComponent::DayOfWeek,
            SeasonalComponent::Month,
        ])
        .with_cyclical(true);
        let json = serde_json::to_string(&original).unwrap();
        let loaded: SeasonalGenerator = serde_json::from_str(&json).unwrap();
        assert_eq!(loaded.components.len(), 3);
        assert!(loaded.cyclical);
    }
}