use crate::advanced::rbf::{RBFInterpolator, RBFKernel};
use crate::bspline::BSpline;
use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, ToPrimitive};
use std::fmt::{Debug, Display, LowerExp};
use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
/// Shape of the temporal signal; `TimeSeriesInterpolator::fit` matches on this
/// value to choose which fitting routine to run.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum TemporalPattern {
/// Fitted with a degree-1 B-spline over the training data.
LinearTrend,
/// No dedicated routine in this file; `fit` falls through to the general
/// cubic B-spline.
PolynomialTrend,
/// Fitted with a Gaussian RBF on the raw values; no trend spline is built.
SeasonalOnly,
/// Cubic B-spline trend plus a Gaussian RBF on the residuals (the RBF is only
/// built when at least 8 training points are available). Default pattern.
TrendWithSeasonality,
/// No dedicated routine in this file; `fit` falls through to the general
/// cubic B-spline.
Irregular,
/// No dedicated routine in this file; `fit` falls through to the general
/// cubic B-spline.
StepChanges,
}
/// Label for the seasonal cycle of the series.
///
/// NOTE(review): within this file the value is only stored in
/// `TimeSeriesConfig::seasonality`; none of the visible fitting routines read it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SeasonalityType {
/// Daily cycle.
Daily,
/// Weekly cycle.
Weekly,
/// Monthly cycle.
Monthly,
/// Quarterly cycle.
Quarterly,
/// Yearly cycle.
Yearly,
/// User-supplied cycle; presumably the payload is the period length in the
/// same units as the timestamps — TODO confirm.
Custom(f64),
}
/// Strategy name for filling gaps in the series.
///
/// Selected through `TimeSeriesInterpolator::with_missing_data_strategy`.
/// NOTE(review): the visible fitting code stores this value but never reads it.
#[derive(Debug, Clone, PartialEq)]
pub enum MissingDataStrategy {
/// Selected by the string `"linear"`.
Linear,
/// Selected by the string `"spline"`; also the fallback for unknown strings
/// and the default.
Spline,
/// Selected by the string `"forward_fill"`.
ForwardFill,
/// Selected by the string `"backward_fill"`.
BackwardFill,
/// Selected by the string `"mean"`.
Mean,
/// Selected by the string `"seasonal"`.
Seasonal,
}
/// Tunable settings for a `TimeSeriesInterpolator`.
#[derive(Debug, Clone)]
pub struct TimeSeriesConfig<T> {
/// Signal shape; decides which fitting routine `fit` runs.
pub pattern: TemporalPattern,
/// Optional seasonal cycle label (defaults to `Some(Daily)`).
pub seasonality: Option<SeasonalityType>,
/// Gap-filling strategy (defaults to `Spline`).
pub missing_strategy: MissingDataStrategy,
/// Smoothing amount (defaults to 0.1).
/// NOTE(review): not read by the fitting code visible in this file.
pub temporal_smoothing: T,
/// Number of seasonal periods (defaults to 3).
/// NOTE(review): not read by the fitting code visible in this file.
pub seasonal_periods: usize,
/// When `true`, `interpolate` also returns lower/upper bounds.
pub estimate_uncertainty: bool,
/// When `Some(k)`, `fit` flags points further than `k` local standard
/// deviations from their sliding-window mean as outliers (defaults to 3.0).
pub outlier_threshold: Option<T>,
}
impl<T: Float + FromPrimitive> Default for TimeSeriesConfig<T> {
fn default() -> Self {
Self {
pattern: TemporalPattern::TrendWithSeasonality,
seasonality: Some(SeasonalityType::Daily),
missing_strategy: MissingDataStrategy::Spline,
temporal_smoothing: T::from_f64(0.1).expect("Operation failed"),
seasonal_periods: 3,
estimate_uncertainty: false,
outlier_threshold: Some(T::from_f64(3.0).expect("Operation failed")),
}
}
}
/// Output of `TimeSeriesInterpolator::interpolate`.
#[derive(Debug, Clone)]
pub struct TimeSeriesResult<T> {
/// Interpolated value for each query timestamp (trend plus seasonal part).
pub values: Array1<T>,
/// `values - uncertainty`; `Some` only when uncertainty estimation is enabled.
pub lower_bounds: Option<Array1<T>>,
/// `values + uncertainty`; `Some` only when uncertainty estimation is enabled.
pub upper_bounds: Option<Array1<T>>,
/// Optional pair of interval arrays. `interpolate` in this file always sets
/// this to `None`.
pub prediction_intervals: Option<(Array1<T>, Array1<T>)>,
}
/// Time-series interpolator combining an optional B-spline trend component
/// with an optional RBF seasonal component.
///
/// Configure with the `with_*` builders, call `fit` with training data, then
/// call `interpolate` with query timestamps.
#[derive(Debug)]
pub struct TimeSeriesInterpolator<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ std::iter::Sum
+ 'static,
{
// Settings chosen through the builder methods.
config: TimeSeriesConfig<T>,
// Owned copy of the timestamps passed to the last `fit` call.
train_times: Array1<T>,
// Owned copy of the values passed to the last `fit` call.
train_values: Array1<T>,
// Trend component; set by the spline-based fitting routines.
trend_interpolator: Option<BSpline<T>>,
// Seasonal component; set by the RBF-based fitting routines.
seasonal_interpolator: Option<RBFInterpolator<T>>,
// Becomes `true` once `fit` has completed successfully.
is_trained: bool,
// Indices into the training data flagged by outlier detection.
outliers: Vec<usize>,
// Summary statistics; only default-initialised in the code visible here.
#[allow(dead_code)]
temporal_stats: TemporalStats<T>,
}
/// Summary statistics describing a fitted time series.
///
/// NOTE(review): in this file the struct is only ever default-constructed;
/// no visible code computes these fields.
#[derive(Debug, Clone)]
pub struct TemporalStats<T> {
/// Estimated trend slope, if available.
pub trend_slope: Option<T>,
/// Estimated seasonal amplitude, if available.
pub seasonal_amplitude: Option<T>,
/// Noise level; used as the half-width of the uncertainty bounds returned by
/// `TimeSeriesInterpolator::interpolate`. Defaults to zero.
pub noise_level: T,
/// Estimated autocorrelation, if available.
pub autocorrelation: Option<T>,
/// Indices of detected changepoints (presumably into the training data —
/// TODO confirm).
pub changepoints: Vec<usize>,
}
impl<T: Float> Default for TemporalStats<T> {
fn default() -> Self {
Self {
trend_slope: None,
seasonal_amplitude: None,
noise_level: T::zero(),
autocorrelation: None,
changepoints: Vec::new(),
}
}
}
impl<T> Default for TimeSeriesInterpolator<T>
where
T: Float
+ FromPrimitive
+ ToPrimitive
+ Debug
+ Display
+ LowerExp
+ ScalarOperand
+ AddAssign
+ SubAssign
+ MulAssign
+ DivAssign
+ RemAssign
+ Copy
+ Send
+ Sync
+ std::iter::Sum
+ 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<T> TimeSeriesInterpolator<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + std::iter::Sum
        + 'static,
{
    /// Minimum number of training points needed before an RBF seasonal model
    /// is built.
    const MIN_SEASONAL_POINTS: usize = 8;

    /// Creates an untrained interpolator with the default configuration and
    /// empty training data.
    pub fn new() -> Self {
        Self {
            config: TimeSeriesConfig::default(),
            train_times: Array1::zeros(0),
            train_values: Array1::zeros(0),
            trend_interpolator: None,
            seasonal_interpolator: None,
            is_trained: false,
            outliers: Vec::new(),
            temporal_stats: TemporalStats::default(),
        }
    }

    /// Sets the temporal pattern that decides which routine `fit` runs.
    pub fn with_temporal_pattern(mut self, pattern: TemporalPattern) -> Self {
        self.config.pattern = pattern;
        self
    }

    /// Sets the seasonality label stored in the configuration.
    pub fn with_seasonality_type(mut self, seasonality: SeasonalityType) -> Self {
        self.config.seasonality = Some(seasonality);
        self
    }

    /// Sets the missing-data strategy from its string name.
    ///
    /// Recognised names are `"linear"`, `"spline"`, `"forward_fill"`,
    /// `"backward_fill"`, `"mean"` and `"seasonal"`. Any other string falls
    /// back to [`MissingDataStrategy::Spline`] rather than failing.
    pub fn with_missing_data_strategy(mut self, strategy: &str) -> Self {
        self.config.missing_strategy = match strategy {
            "linear" => MissingDataStrategy::Linear,
            "spline" => MissingDataStrategy::Spline,
            "forward_fill" => MissingDataStrategy::ForwardFill,
            "backward_fill" => MissingDataStrategy::BackwardFill,
            "mean" => MissingDataStrategy::Mean,
            "seasonal" => MissingDataStrategy::Seasonal,
            // Unknown names deliberately map to the default strategy.
            _ => MissingDataStrategy::Spline,
        };
        self
    }

    /// Sets the temporal smoothing value stored in the configuration.
    pub fn with_temporal_smoothing(mut self, smoothing: T) -> Self {
        self.config.temporal_smoothing = smoothing;
        self
    }

    /// Enables or disables the lower/upper bounds returned by `interpolate`.
    pub fn with_uncertainty_estimation(mut self, enable: bool) -> Self {
        self.config.estimate_uncertainty = enable;
        self
    }

    /// Fits the interpolator to the given training data.
    ///
    /// Any model from a previous successful fit is discarded before the new
    /// one is built, so a failed refit leaves the interpolator untrained
    /// instead of pairing new training data with stale components.
    ///
    /// # Errors
    ///
    /// * `DimensionMismatch` if `timestamps` and `values` differ in length.
    /// * `InvalidValue` if fewer than 3 points are supplied, or if the
    ///   pattern is `SeasonalOnly` and fewer than 8 points are supplied.
    /// * Any error raised while building the underlying spline or RBF model.
    pub fn fit(
        &mut self,
        timestamps: &ArrayView1<T>,
        values: &ArrayView1<T>,
    ) -> InterpolateResult<()> {
        if timestamps.len() != values.len() {
            return Err(InterpolateError::DimensionMismatch(format!(
                "timestamps and values must have the same length, got {} and {}",
                timestamps.len(),
                values.len()
            )));
        }
        if timestamps.len() < 3 {
            return Err(InterpolateError::InvalidValue(
                "At least 3 data points required for time series interpolation".to_string(),
            ));
        }

        // Drop everything derived from an earlier fit so that no stale
        // component can leak into the new model.
        self.is_trained = false;
        self.trend_interpolator = None;
        self.seasonal_interpolator = None;
        self.outliers.clear();

        self.train_times = timestamps.to_owned();
        self.train_values = values.to_owned();

        if let Some(threshold) = self.config.outlier_threshold {
            self.detect_outliers(threshold)?;
        }

        match self.config.pattern {
            TemporalPattern::LinearTrend => {
                self.fit_linear_trend()?;
            }
            TemporalPattern::TrendWithSeasonality => {
                self.fit_trend_and_seasonal()?;
            }
            TemporalPattern::SeasonalOnly => {
                self.fit_seasonal_only()?;
            }
            // Remaining patterns have no dedicated routine and share the
            // general cubic spline.
            _ => {
                self.fit_general_spline()?;
            }
        }

        self.is_trained = true;
        Ok(())
    }

    /// Evaluates the fitted model at the given timestamps.
    ///
    /// The result is the sum of the trend and seasonal components that were
    /// built during `fit`. When uncertainty estimation is enabled, the bounds
    /// are `values ± noise_level`, where `noise_level` comes from the stored
    /// temporal statistics.
    ///
    /// # Errors
    ///
    /// Returns `InvalidState` if `fit` has not completed successfully, and
    /// propagates any evaluation error from the underlying interpolators.
    pub fn interpolate(
        &self,
        timestamps: &ArrayView1<T>,
    ) -> InterpolateResult<TimeSeriesResult<T>> {
        if !self.is_trained {
            return Err(InterpolateError::InvalidState(
                "Interpolator must be fitted before interpolation".to_string(),
            ));
        }

        let mut interpolated_values: Array1<T> = Array1::zeros(timestamps.len());

        if let Some(ref trend_interp) = self.trend_interpolator {
            let trend_values = trend_interp.evaluate_array(timestamps)?;
            interpolated_values = interpolated_values + trend_values;
        }

        if let Some(ref seasonal_interp) = self.seasonal_interpolator {
            // The RBF interpolator takes points as rows of a 2-D array, so the
            // timestamps become an (n, 1) matrix.
            let timestamps_2d = Array2::from_shape_vec((timestamps.len(), 1), timestamps.to_vec())
                .map_err(|e| {
                    InterpolateError::ComputationError(format!(
                        "Failed to reshape timestamps: {}",
                        e
                    ))
                })?;
            let seasonal_values = seasonal_interp.interpolate(&timestamps_2d.view())?;
            interpolated_values = interpolated_values + seasonal_values;
        }

        let (lower_bounds, upper_bounds) = if self.config.estimate_uncertainty {
            let uncertainty = self.estimate_uncertainty(timestamps)?;
            // Reference arithmetic allocates only the two result arrays.
            (
                Some(&interpolated_values - &uncertainty),
                Some(&interpolated_values + &uncertainty),
            )
        } else {
            (None, None)
        };

        Ok(TimeSeriesResult {
            values: interpolated_values,
            lower_bounds,
            upper_bounds,
            prediction_intervals: None,
        })
    }

    /// Flags training points that deviate from their local neighbourhood.
    ///
    /// For each point a centred window is taken (total width `n / 10` clamped
    /// to `3..=20`, truncated at the array edges). The point is an outlier
    /// when its distance from the window mean exceeds `threshold` times the
    /// window's sample standard deviation.
    ///
    /// Only called from `fit`, which guarantees at least 3 points, so every
    /// window holds at least 2 elements.
    fn detect_outliers(&mut self, threshold: T) -> InterpolateResult<()> {
        let n = self.train_values.len();
        let half_window = (n / 10).clamp(3, 20) / 2;
        let mut outliers = Vec::new();

        for i in 0..n {
            let start = i.saturating_sub(half_window);
            let end = (i + half_window + 1).min(n);
            let window = self
                .train_values
                .slice(scirs2_core::ndarray::s![start..end]);
            let count = window.len();

            let mean = window.sum() / T::from_usize(count).expect("Operation failed");
            // Sample variance (divides by count - 1).
            let variance = window.iter().map(|&x| (x - mean) * (x - mean)).sum::<T>()
                / T::from_usize(count - 1).expect("Operation failed");
            let std_dev = variance.sqrt();

            if (self.train_values[i] - mean).abs() > threshold * std_dev {
                outliers.push(i);
            }
        }

        self.outliers = outliers;
        Ok(())
    }

    /// Builds a degree-1 B-spline over the training data and stores it as the
    /// trend component.
    fn fit_linear_trend(&mut self) -> InterpolateResult<()> {
        let degree = 1;
        let knots = crate::bspline::generate_knots(&self.train_times.view(), degree, "average")?;
        let trend_spline = BSpline::new(
            &knots.view(),
            &self.train_values.view(),
            degree,
            crate::bspline::ExtrapolateMode::Extrapolate,
        )?;
        self.trend_interpolator = Some(trend_spline);
        Ok(())
    }

    /// Builds a cubic B-spline trend and, when at least 8 points are
    /// available, a Gaussian RBF on the residuals left by that trend.
    ///
    /// With fewer points the seasonal component is skipped and the trend alone
    /// is used.
    fn fit_trend_and_seasonal(&mut self) -> InterpolateResult<()> {
        let degree = 3;
        let knots = crate::bspline::generate_knots(&self.train_times.view(), degree, "average")?;
        let trend_spline = BSpline::new(
            &knots.view(),
            &self.train_values.view(),
            degree,
            crate::bspline::ExtrapolateMode::Extrapolate,
        )?;

        // Whatever the trend does not explain is modelled by the seasonal RBF.
        let trend_values = trend_spline.evaluate_array(&self.train_times.view())?;
        let residuals = self.train_values.clone() - trend_values;

        if self.train_times.len() >= Self::MIN_SEASONAL_POINTS {
            let times_2d =
                Array2::from_shape_vec((self.train_times.len(), 1), self.train_times.to_vec())
                    .map_err(|e| {
                        InterpolateError::ComputationError(format!("Failed to reshape times: {e}"))
                    })?;
            let seasonal_rbf = RBFInterpolator::new(
                &times_2d.view(),
                &residuals.view(),
                RBFKernel::Gaussian,
                T::from_f64(1.0).expect("Operation failed"),
            )?;
            self.seasonal_interpolator = Some(seasonal_rbf);
        }

        self.trend_interpolator = Some(trend_spline);
        Ok(())
    }

    /// Builds a Gaussian RBF directly on the training values; no trend
    /// component is created.
    ///
    /// # Errors
    ///
    /// Returns `InvalidValue` when fewer than 8 points are available. Without
    /// this check the fit would report success with no model at all, and
    /// `interpolate` would silently return zeros.
    fn fit_seasonal_only(&mut self) -> InterpolateResult<()> {
        if self.train_times.len() < Self::MIN_SEASONAL_POINTS {
            return Err(InterpolateError::InvalidValue(format!(
                "At least {} data points required for seasonal-only interpolation, got {}",
                Self::MIN_SEASONAL_POINTS,
                self.train_times.len()
            )));
        }

        let times_2d =
            Array2::from_shape_vec((self.train_times.len(), 1), self.train_times.to_vec())
                .map_err(|e| {
                    InterpolateError::ComputationError(format!("Failed to reshape times: {e}"))
                })?;
        let seasonal_rbf = RBFInterpolator::new(
            &times_2d.view(),
            &self.train_values.view(),
            RBFKernel::Gaussian,
            T::from_f64(0.5).expect("Operation failed"),
        )?;
        self.seasonal_interpolator = Some(seasonal_rbf);
        Ok(())
    }

    /// Builds a cubic B-spline over the training data and stores it as the
    /// trend component. Used for every pattern without a dedicated routine.
    fn fit_general_spline(&mut self) -> InterpolateResult<()> {
        let degree = 3;
        let knots = crate::bspline::generate_knots(&self.train_times.view(), degree, "average")?;
        let spline = BSpline::new(
            &knots.view(),
            &self.train_values.view(),
            degree,
            crate::bspline::ExtrapolateMode::Extrapolate,
        )?;
        self.trend_interpolator = Some(spline);
        Ok(())
    }

    /// Returns a constant uncertainty, equal to the stored noise level, for
    /// each query timestamp.
    ///
    /// NOTE(review): nothing in this impl updates `temporal_stats`, so the
    /// noise level stays at its default of zero and the bounds coincide with
    /// the interpolated values.
    fn estimate_uncertainty(&self, timestamps: &ArrayView1<T>) -> InterpolateResult<Array1<T>> {
        let n = timestamps.len();
        let base_uncertainty = self.temporal_stats.noise_level;
        Ok(Array1::from_elem(n, base_uncertainty))
    }

    /// Indices of training points flagged as outliers by the last `fit`.
    pub fn get_outliers(&self) -> &[usize] {
        &self.outliers
    }

    /// Temporal statistics stored on the interpolator.
    pub fn get_temporal_stats(&self) -> &TemporalStats<T> {
        &self.temporal_stats
    }

    /// `true` once `fit` has completed successfully.
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }
}
/// Convenience constructor for series with a daily cycle.
///
/// Returns an untrained interpolator configured with the
/// trend-with-seasonality pattern, daily seasonality and the spline
/// missing-data strategy.
#[allow(dead_code)]
pub fn make_daily_interpolator<T>() -> TimeSeriesInterpolator<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + std::iter::Sum
        + 'static,
{
    let base = TimeSeriesInterpolator::new();
    let patterned = base.with_temporal_pattern(TemporalPattern::TrendWithSeasonality);
    let seasonal = patterned.with_seasonality_type(SeasonalityType::Daily);
    seasonal.with_missing_data_strategy("spline")
}
/// Convenience constructor for series with a weekly cycle.
///
/// Returns an untrained interpolator configured with the
/// trend-with-seasonality pattern, weekly seasonality and the spline
/// missing-data strategy.
#[allow(dead_code)]
pub fn make_weekly_interpolator<T>() -> TimeSeriesInterpolator<T>
where
    T: Float
        + FromPrimitive
        + ToPrimitive
        + Debug
        + Display
        + LowerExp
        + ScalarOperand
        + AddAssign
        + SubAssign
        + MulAssign
        + DivAssign
        + RemAssign
        + Copy
        + Send
        + Sync
        + std::iter::Sum
        + 'static,
{
    let base = TimeSeriesInterpolator::new();
    let patterned = base.with_temporal_pattern(TemporalPattern::TrendWithSeasonality);
    let seasonal = patterned.with_seasonality_type(SeasonalityType::Weekly);
    seasonal.with_missing_data_strategy("spline")
}
/// Forward-fills values onto the query timestamps.
///
/// Each query receives the value of the last sample whose timestamp is less
/// than or equal to the query. Queries earlier than the first sample receive
/// the first value. The scan stops at the first timestamp greater than the
/// query, so `timestamps` is expected to be sorted in ascending order.
///
/// # Errors
///
/// * `DimensionMismatch` if `timestamps` and `values` differ in length.
/// * `InvalidValue` if there are queries to fill but no samples to fill from.
#[allow(dead_code)]
pub fn forward_fill<T>(
    timestamps: &ArrayView1<T>,
    values: &ArrayView1<T>,
    query_timestamps: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float + PartialOrd + Copy,
{
    if timestamps.len() != values.len() {
        return Err(InterpolateError::DimensionMismatch(
            "timestamps and values must have same length".to_string(),
        ));
    }

    let mut result: Array1<T> = Array1::zeros(query_timestamps.len());
    if query_timestamps.is_empty() {
        return Ok(result);
    }
    // Indexing `values[0]` below would panic on empty input, so report it.
    if values.is_empty() {
        return Err(InterpolateError::InvalidValue(
            "At least 1 data point required for forward fill".to_string(),
        ));
    }

    let first_value = values[0];
    for (slot, &query_time) in result.iter_mut().zip(query_timestamps.iter()) {
        let mut last_value = first_value;
        for (&timestamp, &value) in timestamps.iter().zip(values.iter()) {
            if timestamp <= query_time {
                last_value = value;
            } else {
                break;
            }
        }
        *slot = last_value;
    }
    Ok(result)
}
/// Backward-fills values onto the query timestamps.
///
/// Each query receives the value of the earliest sample whose timestamp is
/// greater than or equal to the query. Queries later than the last sample
/// receive the last value. The reverse scan stops at the first timestamp
/// smaller than the query, so `timestamps` is expected to be sorted in
/// ascending order.
///
/// # Errors
///
/// * `DimensionMismatch` if `timestamps` and `values` differ in length.
/// * `InvalidValue` if there are queries to fill but no samples to fill from.
#[allow(dead_code)]
pub fn backward_fill<T>(
    timestamps: &ArrayView1<T>,
    values: &ArrayView1<T>,
    query_timestamps: &ArrayView1<T>,
) -> InterpolateResult<Array1<T>>
where
    T: Float + PartialOrd + Copy,
{
    if timestamps.len() != values.len() {
        return Err(InterpolateError::DimensionMismatch(
            "timestamps and values must have same length".to_string(),
        ));
    }

    let mut result: Array1<T> = Array1::zeros(query_timestamps.len());
    if query_timestamps.is_empty() {
        return Ok(result);
    }
    // Indexing the last element below would underflow/panic on empty input.
    if values.is_empty() {
        return Err(InterpolateError::InvalidValue(
            "At least 1 data point required for backward fill".to_string(),
        ));
    }

    let final_value = values[values.len() - 1];
    for (slot, &query_time) in result.iter_mut().zip(query_timestamps.iter()) {
        let mut next_value = final_value;
        // Walk from the latest sample towards the earliest.
        for (&timestamp, &value) in timestamps.iter().zip(values.iter()).rev() {
            if timestamp >= query_time {
                next_value = value;
            } else {
                break;
            }
        }
        *slot = next_value;
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array1;

    /// A fresh interpolator is untrained and uses the default pattern.
    #[test]
    fn test_time_series_interpolator_creation() {
        let interpolator = TimeSeriesInterpolator::<f64>::new();
        assert!(!interpolator.is_trained());
        assert_eq!(
            interpolator.config.pattern,
            TemporalPattern::TrendWithSeasonality
        );
    }

    /// Each builder method writes its value into the configuration.
    #[test]
    fn test_time_series_interpolator_configuration() {
        let interpolator = TimeSeriesInterpolator::<f64>::new()
            .with_temporal_pattern(TemporalPattern::LinearTrend)
            .with_seasonality_type(SeasonalityType::Weekly)
            .with_missing_data_strategy("linear")
            .with_temporal_smoothing(0.2);
        assert_eq!(interpolator.config.pattern, TemporalPattern::LinearTrend);
        assert_eq!(
            interpolator.config.seasonality,
            Some(SeasonalityType::Weekly)
        );
        assert_eq!(
            interpolator.config.missing_strategy,
            MissingDataStrategy::Linear
        );
        assert!((interpolator.config.temporal_smoothing - 0.2).abs() < 1e-10);
    }

    /// Fitting a linear-trend model succeeds and marks the model as trained.
    #[test]
    fn test_time_series_fitting() {
        let timestamps = Array1::linspace(0.0, 10.0, 11);
        let values = timestamps.mapv(|t| t + 0.1 * (2.0 * t).sin());
        let mut interpolator =
            TimeSeriesInterpolator::new().with_temporal_pattern(TemporalPattern::LinearTrend);
        let result = interpolator.fit(&timestamps.view(), &values.view());
        assert!(result.is_ok());
        assert!(interpolator.is_trained());
    }

    /// Interpolation returns one finite value per query timestamp.
    #[test]
    fn test_time_series_interpolation() {
        let timestamps = Array1::linspace(0.0, 10.0, 11);
        let values = timestamps.mapv(|t| t + 0.1 * (2.0 * t).sin());
        let mut interpolator =
            TimeSeriesInterpolator::new().with_temporal_pattern(TemporalPattern::LinearTrend);
        interpolator
            .fit(&timestamps.view(), &values.view())
            .expect("Operation failed");
        let query_times = Array1::from_vec(vec![2.5, 5.0, 7.5]);
        let result = interpolator
            .interpolate(&query_times.view())
            .expect("Operation failed");
        assert_eq!(result.values.len(), 3);
        assert!(result.values.iter().all(|&v| v.is_finite()));
    }

    /// Forward fill carries the previous sample forward; queries before the
    /// first sample take the first value.
    #[test]
    fn test_forward_fill() {
        let timestamps = Array1::from_vec(vec![1.0, 3.0, 5.0, 7.0]);
        let values = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
        let query_times = Array1::from_vec(vec![0.0, 2.0, 4.0, 6.0, 8.0]);
        let result = forward_fill(&timestamps.view(), &values.view(), &query_times.view())
            .expect("Operation failed");
        assert_eq!(result, Array1::from_vec(vec![10.0, 10.0, 20.0, 30.0, 40.0]));
    }

    /// Backward fill takes the next sample; queries after the last sample
    /// take the last value.
    #[test]
    fn test_backward_fill() {
        let timestamps = Array1::from_vec(vec![1.0, 3.0, 5.0, 7.0]);
        let values = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0]);
        let query_times = Array1::from_vec(vec![0.0, 2.0, 4.0, 6.0, 8.0]);
        let result = backward_fill(&timestamps.view(), &values.view(), &query_times.view())
            .expect("Operation failed");
        assert_eq!(result, Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 40.0]));
    }

    /// The daily factory sets daily seasonality and the combined pattern.
    #[test]
    fn test_make_daily_interpolator() {
        let interpolator = make_daily_interpolator::<f64>();
        assert_eq!(
            interpolator.config.seasonality,
            Some(SeasonalityType::Daily)
        );
        assert_eq!(
            interpolator.config.pattern,
            TemporalPattern::TrendWithSeasonality
        );
    }

    /// The weekly factory sets weekly seasonality and the combined pattern.
    #[test]
    fn test_make_weekly_interpolator() {
        let interpolator = make_weekly_interpolator::<f64>();
        assert_eq!(
            interpolator.config.seasonality,
            Some(SeasonalityType::Weekly)
        );
        assert_eq!(
            interpolator.config.pattern,
            TemporalPattern::TrendWithSeasonality
        );
    }

    /// With uncertainty estimation enabled, both bounds are populated.
    #[test]
    fn test_time_series_with_uncertainty() {
        let timestamps = Array1::linspace(0.0, 10.0, 11);
        let values = timestamps.mapv(|t| t + 0.1 * (2.0 * t).sin());
        let mut interpolator = TimeSeriesInterpolator::new()
            .with_temporal_pattern(TemporalPattern::LinearTrend)
            .with_uncertainty_estimation(true);
        interpolator
            .fit(&timestamps.view(), &values.view())
            .expect("Operation failed");
        let query_times = Array1::from_vec(vec![2.5, 5.0]);
        let result = interpolator
            .interpolate(&query_times.view())
            .expect("Operation failed");
        assert_eq!(result.values.len(), 2);
        assert!(result.lower_bounds.is_some());
        assert!(result.upper_bounds.is_some());
    }

    /// Fitting trend plus seasonality on 21 points builds a trend component.
    #[test]
    fn test_trend_and_seasonal_fitting() {
        let timestamps = Array1::linspace(0.0, 20.0, 21);
        let values = timestamps.mapv(|t| t * 0.5 + 2.0 * (t * 0.5).sin() + 0.1 * t.cos());
        let mut interpolator = TimeSeriesInterpolator::new()
            .with_temporal_pattern(TemporalPattern::TrendWithSeasonality)
            .with_seasonality_type(SeasonalityType::Custom(4.0));
        let result = interpolator.fit(&timestamps.view(), &values.view());
        assert!(result.is_ok());
        assert!(interpolator.is_trained());
        assert!(interpolator.trend_interpolator.is_some());
    }
}