1use chrono::{DateTime, Duration, Utc};
4use rust_decimal::Decimal;
5use serde::{Deserialize, Serialize};
6
7use super::{
8 models::{ExponentialSmoothingModel, ForecastModel, LinearTrendModel, MovingAverageModel},
9 types::{
10 DataPoint, ForecastConfig, ForecastHorizon, ForecastResult as TypesForecastResult,
11 SeasonalityPattern, TimeSeriesData, TrendDirection,
12 },
13 ForecastError, ForecastResult,
14};
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ForecastRequest {
19 pub data: TimeSeriesData,
21
22 pub config: ForecastConfig,
24
25 pub preferred_model: Option<ModelType>,
27}
28
29#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum ModelType {
33 LinearTrend,
34 MovingAverage,
35 ExponentialSmoothing,
36 Auto, }
38
39pub struct ForecastEngine {
41 config: ForecastConfig,
42}
43
44impl ForecastEngine {
45 pub fn new(config: ForecastConfig) -> Self {
47 Self { config }
48 }
49
50 pub fn new_with_defaults() -> Self {
52 Self {
53 config: ForecastConfig::default(),
54 }
55 }
56
57 pub fn forecast(&self, request: ForecastRequest) -> ForecastResult<TypesForecastResult> {
59 self.validate_data(&request.data)?;
61
62 let model_type = request.preferred_model.unwrap_or(ModelType::Auto);
64 let model_type = if model_type == ModelType::Auto {
65 self.select_best_model(&request.data)?
66 } else {
67 model_type
68 };
69
70 let mut model = self.create_model(model_type)?;
72 model.train(&request.data)?;
73
74 let n_periods = self.calculate_periods(&request)?;
76
77 let forecast_values = model.forecast(n_periods)?;
79
80 let last_timestamp = request
82 .data
83 .last()
84 .ok_or_else(|| ForecastError::InsufficientData("No data points".to_string()))?
85 .timestamp;
86
87 let interval_secs = request.data.interval_secs.unwrap_or(3600);
88
89 let forecast_points = self.generate_forecast_points(
90 last_timestamp,
91 interval_secs,
92 forecast_values.clone(),
93 );
94
95 let std_dev = request.data.std_dev().unwrap_or(0.0);
97 let z_score = self.calculate_z_score(request.config.confidence_level);
98
99 let (lower_bound, upper_bound) =
100 self.calculate_prediction_intervals(&forecast_points, std_dev, z_score);
101
102 let trend = if request.config.include_trend {
104 model.detect_trend()
105 } else {
106 TrendDirection::Unknown
107 };
108
109 let seasonality = if request.config.detect_seasonality {
111 self.detect_seasonality(&request.data)?
112 } else {
113 SeasonalityPattern {
114 detected: false,
115 period: None,
116 strength: 0.0,
117 }
118 };
119
120 let metrics = self.calculate_validation_metrics(&request.data, &model)?;
122
123 Ok(TypesForecastResult {
124 forecast: forecast_points,
125 lower_bound,
126 upper_bound,
127 trend,
128 seasonality,
129 model_name: model.name().to_string(),
130 confidence_level: request.config.confidence_level,
131 metrics,
132 })
133 }
134
135 fn validate_data(&self, data: &TimeSeriesData) -> ForecastResult<()> {
137 if data.is_empty() {
138 return Err(ForecastError::InsufficientData(
139 "Time series data is empty".to_string(),
140 ));
141 }
142
143 if data.len() < self.config.min_data_points {
144 return Err(ForecastError::InsufficientData(format!(
145 "Insufficient data points: {} (minimum required: {})",
146 data.len(),
147 self.config.min_data_points
148 )));
149 }
150
151 Ok(())
152 }
153
154 fn select_best_model(&self, data: &TimeSeriesData) -> ForecastResult<ModelType> {
156 let values = data.values_f64();
162 if values.len() < 2 {
163 return Ok(ModelType::ExponentialSmoothing);
164 }
165
166 let first_half_mean = values[..values.len() / 2].iter().sum::<f64>()
168 / (values.len() / 2) as f64;
169 let second_half_mean =
170 values[values.len() / 2..].iter().sum::<f64>() / (values.len() - values.len() / 2) as f64;
171
172 let trend_ratio = if first_half_mean.abs() > f64::EPSILON {
173 second_half_mean / first_half_mean
174 } else {
175 1.0
176 };
177
178 if !(0.95..=1.05).contains(&trend_ratio) {
180 Ok(ModelType::LinearTrend)
181 } else if data.len() >= 10 {
182 Ok(ModelType::MovingAverage)
184 } else {
185 Ok(ModelType::ExponentialSmoothing)
187 }
188 }
189
190 fn create_model(&self, model_type: ModelType) -> ForecastResult<Box<dyn ForecastModel>> {
192 match model_type {
193 ModelType::LinearTrend => Ok(Box::new(LinearTrendModel::new())),
194 ModelType::MovingAverage => {
195 let window_size = (self.config.min_data_points / 2).max(3);
196 Ok(Box::new(MovingAverageModel::new(window_size)))
197 }
198 ModelType::ExponentialSmoothing => Ok(Box::new(
199 ExponentialSmoothingModel::with_default_alpha(),
200 )),
201 ModelType::Auto => Err(ForecastError::InvalidConfig(
202 "Auto model type should have been resolved".to_string(),
203 )),
204 }
205 }
206
207 fn calculate_periods(&self, request: &ForecastRequest) -> ForecastResult<usize> {
209 match request.config.horizon {
210 ForecastHorizon::Periods(n) => Ok(n),
211 ForecastHorizon::Days(days) => {
212 let interval_secs = request.data.interval_secs.unwrap_or(3600);
213 let periods_per_day = 86400 / interval_secs;
214 Ok((days as i64 * periods_per_day) as usize)
215 }
216 ForecastHorizon::UntilDate(target_date) => {
217 let last_timestamp = request
218 .data
219 .last()
220 .ok_or_else(|| {
221 ForecastError::InsufficientData("No data points".to_string())
222 })?
223 .timestamp;
224
225 let duration = target_date.signed_duration_since(last_timestamp);
226 let interval_secs = request.data.interval_secs.unwrap_or(3600);
227
228 let periods = duration.num_seconds() / interval_secs;
229 if periods <= 0 {
230 return Err(ForecastError::InvalidConfig(
231 "Target date must be in the future".to_string(),
232 ));
233 }
234
235 Ok(periods as usize)
236 }
237 }
238 }
239
240 fn generate_forecast_points(
242 &self,
243 last_timestamp: DateTime<Utc>,
244 interval_secs: i64,
245 values: Vec<Decimal>,
246 ) -> Vec<DataPoint> {
247 values
248 .into_iter()
249 .enumerate()
250 .map(|(i, value)| {
251 DataPoint::new(
252 last_timestamp + Duration::seconds((i as i64 + 1) * interval_secs),
253 value,
254 )
255 })
256 .collect()
257 }
258
259 fn calculate_z_score(&self, confidence_level: f64) -> f64 {
261 match (confidence_level * 100.0) as i32 {
263 90 => 1.645,
264 95 => 1.96,
265 99 => 2.576,
266 _ => 1.96, }
268 }
269
270 fn calculate_prediction_intervals(
272 &self,
273 forecast: &[DataPoint],
274 std_dev: f64,
275 z_score: f64,
276 ) -> (Vec<DataPoint>, Vec<DataPoint>) {
277 let margin = std_dev * z_score;
278
279 let lower_bound: Vec<DataPoint> = forecast
280 .iter()
281 .map(|point| {
282 let lower_value = point.value
283 - Decimal::try_from(margin).unwrap_or(Decimal::ZERO);
284 let lower_value = lower_value.max(Decimal::ZERO); DataPoint::new(point.timestamp, lower_value)
286 })
287 .collect();
288
289 let upper_bound: Vec<DataPoint> = forecast
290 .iter()
291 .map(|point| {
292 let upper_value = point.value
293 + Decimal::try_from(margin).unwrap_or(Decimal::ZERO);
294 DataPoint::new(point.timestamp, upper_value)
295 })
296 .collect();
297
298 (lower_bound, upper_bound)
299 }
300
301 fn detect_seasonality(&self, data: &TimeSeriesData) -> ForecastResult<SeasonalityPattern> {
303 if data.len() < 14 {
305 return Ok(SeasonalityPattern {
306 detected: false,
307 period: None,
308 strength: 0.0,
309 });
310 }
311
312 let values = data.values_f64();
313 let mean = values.iter().sum::<f64>() / values.len() as f64;
314
315 let test_periods = vec![24, 168]; let interval_hours = data.interval_secs.unwrap_or(3600) / 3600;
318
319 let mut best_period = None;
320 let mut best_correlation = 0.0;
321
322 for &period_hours in &test_periods {
323 let lag = period_hours / interval_hours;
324 if lag as usize >= values.len() {
325 continue;
326 }
327
328 let correlation = self.calculate_autocorrelation(&values, lag as usize, mean);
329 if correlation > best_correlation {
330 best_correlation = correlation;
331 best_period = Some(lag as usize);
332 }
333 }
334
335 let detected = best_correlation > 0.3; Ok(SeasonalityPattern {
338 detected,
339 period: if detected { best_period } else { None },
340 strength: if detected { best_correlation } else { 0.0 },
341 })
342 }
343
344 fn calculate_autocorrelation(&self, values: &[f64], lag: usize, mean: f64) -> f64 {
346 if lag >= values.len() {
347 return 0.0;
348 }
349
350 let n = values.len() - lag;
351 let mut numerator = 0.0;
352 let mut denominator = 0.0;
353
354 for i in 0..n {
355 numerator += (values[i] - mean) * (values[i + lag] - mean);
356 }
357
358 for &value in values {
359 denominator += (value - mean).powi(2);
360 }
361
362 if denominator.abs() < f64::EPSILON {
363 return 0.0;
364 }
365
366 numerator / denominator
367 }
368
369 fn calculate_validation_metrics(
371 &self,
372 data: &TimeSeriesData,
373 model: &Box<dyn ForecastModel>,
374 ) -> ForecastResult<Option<super::metrics::ForecastMetrics>> {
375 let holdout_size = (data.len() as f64 * 0.2).ceil() as usize;
377 if holdout_size < 2 || data.len() - holdout_size < self.config.min_data_points {
378 return Ok(None); }
380
381 let train_size = data.len() - holdout_size;
383 let train_data = data.subset(0, train_size);
384 let holdout_data = data.subset(train_size, data.len());
385
386 let mut validation_model = self.create_model(
388 match model.name() {
389 "Linear Trend" => ModelType::LinearTrend,
390 "Moving Average" => ModelType::MovingAverage,
391 "Exponential Smoothing" => ModelType::ExponentialSmoothing,
392 _ => ModelType::ExponentialSmoothing,
393 }
394 )?;
395
396 validation_model.train(&train_data)?;
397
398 let predictions = validation_model.forecast(holdout_size)?;
400 let actuals = holdout_data.values();
401
402 match super::metrics::ForecastMetrics::new(&actuals, &predictions) {
404 Ok(metrics) => Ok(Some(metrics)),
405 Err(_) => Ok(None), }
407 }
408}
409
410impl Default for ForecastEngine {
411 fn default() -> Self {
412 Self::new_with_defaults()
413 }
414}
415
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds a series with one point per hour, starting now.
    fn create_test_series(values: Vec<i32>) -> TimeSeriesData {
        let start = Utc::now();
        let mut points: Vec<DataPoint> = Vec::with_capacity(values.len());
        for (offset, value) in values.into_iter().enumerate() {
            let timestamp = start + Duration::hours(offset as i64);
            points.push(DataPoint::new(timestamp, Decimal::from(value)));
        }

        TimeSeriesData::with_auto_interval(points)
    }

    /// Default configuration with only the horizon overridden.
    fn config_with_horizon(horizon: ForecastHorizon) -> ForecastConfig {
        ForecastConfig {
            horizon,
            ..Default::default()
        }
    }

    /// Assembles a request from its three parts.
    fn build_request(
        data: TimeSeriesData,
        config: ForecastConfig,
        preferred_model: Option<ModelType>,
    ) -> ForecastRequest {
        ForecastRequest {
            data,
            config,
            preferred_model,
        }
    }

    #[test]
    fn test_engine_creation() {
        let engine = ForecastEngine::default();
        assert_eq!(engine.config.confidence_level, 0.95);
    }

    #[test]
    fn test_validate_data() {
        let engine = ForecastEngine::default();

        // No points at all.
        let empty_data = TimeSeriesData::new(vec![]);
        assert!(engine.validate_data(&empty_data).is_err());

        // Some points, but fewer than the configured minimum.
        let small_data = create_test_series(vec![1, 2]);
        assert!(engine.validate_data(&small_data).is_err());

        // Enough points.
        let valid_data = create_test_series(vec![1, 2, 3, 4, 5, 6, 7, 8]);
        assert!(engine.validate_data(&valid_data).is_ok());
    }

    #[test]
    fn test_select_best_model() {
        let engine = ForecastEngine::default();

        let trending_data = create_test_series(vec![10, 20, 30, 40, 50, 60, 70, 80]);
        let chosen = engine.select_best_model(&trending_data).unwrap();
        assert_eq!(chosen, ModelType::LinearTrend);

        let stable_data = create_test_series(vec![50, 51, 49, 50, 52, 48, 50, 51, 49, 50]);
        let chosen = engine.select_best_model(&stable_data).unwrap();
        assert_eq!(chosen, ModelType::MovingAverage);
    }

    #[test]
    fn test_calculate_periods() {
        let engine = ForecastEngine::default();
        let data = create_test_series(vec![1, 2, 3, 4, 5, 6, 7, 8]);

        let by_periods = build_request(
            data.clone(),
            config_with_horizon(ForecastHorizon::Periods(10)),
            None,
        );
        assert_eq!(engine.calculate_periods(&by_periods).unwrap(), 10);

        // Seven days of hourly data.
        let by_days = build_request(data, config_with_horizon(ForecastHorizon::Days(7)), None);
        assert_eq!(engine.calculate_periods(&by_days).unwrap(), 168);
    }

    #[test]
    fn test_forecast_generation() {
        let engine = ForecastEngine::default();
        let request = build_request(
            create_test_series(vec![10, 20, 30, 40, 50, 60, 70, 80]),
            config_with_horizon(ForecastHorizon::Periods(5)),
            Some(ModelType::LinearTrend),
        );

        let result = engine.forecast(request);
        assert!(result.is_ok());

        let forecast = result.unwrap();
        assert_eq!(forecast.forecast.len(), 5);
        assert_eq!(forecast.lower_bound.len(), 5);
        assert_eq!(forecast.upper_bound.len(), 5);
        assert_eq!(forecast.model_name, "Linear Trend");
        assert_eq!(forecast.trend, TrendDirection::Increasing);
    }

    #[test]
    fn test_z_score_calculation() {
        let engine = ForecastEngine::default();

        assert_eq!(engine.calculate_z_score(0.90), 1.645);
        assert_eq!(engine.calculate_z_score(0.95), 1.96);
        assert_eq!(engine.calculate_z_score(0.99), 2.576);
    }

    #[test]
    fn test_autocorrelation() {
        let engine = ForecastEngine::default();
        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mean = values.iter().sum::<f64>() / values.len() as f64;

        // A steadily rising series correlates positively with itself at lag 1.
        let correlation = engine.calculate_autocorrelation(&values, 1, mean);
        assert!(correlation > 0.0);
    }

    #[test]
    fn test_seasonality_detection() {
        let engine = ForecastEngine::default();

        // Fewer than 14 points: detection is skipped.
        let small_data = create_test_series(vec![1, 2, 3, 4, 5]);
        let seasonality = engine.detect_seasonality(&small_data).unwrap();
        assert!(!seasonality.detected);

        // Exactly 14 points: detection runs, strength must be within [0, 1].
        let data = create_test_series(vec![
            10, 20, 15, 25, 20, 30, 25, 35, 30, 40, 35, 45, 40, 50,
        ]);
        let seasonality = engine.detect_seasonality(&data).unwrap();
        assert!(seasonality.strength >= 0.0 && seasonality.strength <= 1.0);
    }

    #[test]
    fn test_insufficient_data_error() {
        let engine = ForecastEngine::default();
        let request = build_request(
            create_test_series(vec![1, 2]),
            ForecastConfig::default(),
            None,
        );

        let result = engine.forecast(request);
        assert!(result.is_err());
    }

    #[test]
    fn test_different_models() {
        let engine = ForecastEngine::default();
        let data = create_test_series(vec![10, 20, 30, 40, 50, 60, 70, 80]);
        let config = config_with_horizon(ForecastHorizon::Periods(3));

        let models = [
            ModelType::LinearTrend,
            ModelType::MovingAverage,
            ModelType::ExponentialSmoothing,
            ModelType::Auto,
        ];

        for &model in &models {
            let request = build_request(data.clone(), config.clone(), Some(model));
            assert!(engine.forecast(request).is_ok());
        }
    }
}