llm_cost_ops/forecasting/
engine.rs

1// Forecasting engine - orchestrates models and generates forecasts
2
3use chrono::{DateTime, Duration, Utc};
4use rust_decimal::Decimal;
5use serde::{Deserialize, Serialize};
6
7use super::{
8    models::{ExponentialSmoothingModel, ForecastModel, LinearTrendModel, MovingAverageModel},
9    types::{
10        DataPoint, ForecastConfig, ForecastHorizon, ForecastResult as TypesForecastResult,
11        SeasonalityPattern, TimeSeriesData, TrendDirection,
12    },
13    ForecastError, ForecastResult,
14};
15
16/// Forecast request
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ForecastRequest {
19    /// Historical time series data
20    pub data: TimeSeriesData,
21
22    /// Forecast configuration
23    pub config: ForecastConfig,
24
25    /// Preferred model (if None, best model will be selected)
26    pub preferred_model: Option<ModelType>,
27}
28
29/// Available forecasting models
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
31#[serde(rename_all = "snake_case")]
32pub enum ModelType {
33    LinearTrend,
34    MovingAverage,
35    ExponentialSmoothing,
36    Auto, // Automatically select best model
37}
38
39/// Forecast engine
40pub struct ForecastEngine {
41    config: ForecastConfig,
42}
43
44impl ForecastEngine {
45    /// Create a new forecast engine
46    pub fn new(config: ForecastConfig) -> Self {
47        Self { config }
48    }
49
50    /// Create with default configuration
51    pub fn new_with_defaults() -> Self {
52        Self {
53            config: ForecastConfig::default(),
54        }
55    }
56
57    /// Generate forecast from request
58    pub fn forecast(&self, request: ForecastRequest) -> ForecastResult<TypesForecastResult> {
59        // Validate input data
60        self.validate_data(&request.data)?;
61
62        // Determine model to use
63        let model_type = request.preferred_model.unwrap_or(ModelType::Auto);
64        let model_type = if model_type == ModelType::Auto {
65            self.select_best_model(&request.data)?
66        } else {
67            model_type
68        };
69
70        // Create and train model
71        let mut model = self.create_model(model_type)?;
72        model.train(&request.data)?;
73
74        // Calculate number of periods to forecast
75        let n_periods = self.calculate_periods(&request)?;
76
77        // Generate forecast values
78        let forecast_values = model.forecast(n_periods)?;
79
80        // Generate forecast data points with timestamps
81        let last_timestamp = request
82            .data
83            .last()
84            .ok_or_else(|| ForecastError::InsufficientData("No data points".to_string()))?
85            .timestamp;
86
87        let interval_secs = request.data.interval_secs.unwrap_or(3600);
88
89        let forecast_points = self.generate_forecast_points(
90            last_timestamp,
91            interval_secs,
92            forecast_values.clone(),
93        );
94
95        // Calculate prediction intervals
96        let std_dev = request.data.std_dev().unwrap_or(0.0);
97        let z_score = self.calculate_z_score(request.config.confidence_level);
98
99        let (lower_bound, upper_bound) =
100            self.calculate_prediction_intervals(&forecast_points, std_dev, z_score);
101
102        // Detect trend
103        let trend = if request.config.include_trend {
104            model.detect_trend()
105        } else {
106            TrendDirection::Unknown
107        };
108
109        // Detect seasonality
110        let seasonality = if request.config.detect_seasonality {
111            self.detect_seasonality(&request.data)?
112        } else {
113            SeasonalityPattern {
114                detected: false,
115                period: None,
116                strength: 0.0,
117            }
118        };
119
120        // Calculate metrics if we have enough validation data
121        let metrics = self.calculate_validation_metrics(&request.data, &model)?;
122
123        Ok(TypesForecastResult {
124            forecast: forecast_points,
125            lower_bound,
126            upper_bound,
127            trend,
128            seasonality,
129            model_name: model.name().to_string(),
130            confidence_level: request.config.confidence_level,
131            metrics,
132        })
133    }
134
135    /// Validate input data
136    fn validate_data(&self, data: &TimeSeriesData) -> ForecastResult<()> {
137        if data.is_empty() {
138            return Err(ForecastError::InsufficientData(
139                "Time series data is empty".to_string(),
140            ));
141        }
142
143        if data.len() < self.config.min_data_points {
144            return Err(ForecastError::InsufficientData(format!(
145                "Insufficient data points: {} (minimum required: {})",
146                data.len(),
147                self.config.min_data_points
148            )));
149        }
150
151        Ok(())
152    }
153
154    /// Select the best model based on data characteristics
155    fn select_best_model(&self, data: &TimeSeriesData) -> ForecastResult<ModelType> {
156        // For now, use a simple heuristic:
157        // - Linear trend for data with clear trends
158        // - Moving average for stable data
159        // - Exponential smoothing as fallback
160
161        let values = data.values_f64();
162        if values.len() < 2 {
163            return Ok(ModelType::ExponentialSmoothing);
164        }
165
166        // Calculate simple trend indicator
167        let first_half_mean = values[..values.len() / 2].iter().sum::<f64>()
168            / (values.len() / 2) as f64;
169        let second_half_mean =
170            values[values.len() / 2..].iter().sum::<f64>() / (values.len() - values.len() / 2) as f64;
171
172        let trend_ratio = if first_half_mean.abs() > f64::EPSILON {
173            second_half_mean / first_half_mean
174        } else {
175            1.0
176        };
177
178        // If strong trend (>5% change), use linear trend
179        if !(0.95..=1.05).contains(&trend_ratio) {
180            Ok(ModelType::LinearTrend)
181        } else if data.len() >= 10 {
182            // Use moving average for stable data with enough points
183            Ok(ModelType::MovingAverage)
184        } else {
185            // Default to exponential smoothing
186            Ok(ModelType::ExponentialSmoothing)
187        }
188    }
189
190    /// Create a model instance
191    fn create_model(&self, model_type: ModelType) -> ForecastResult<Box<dyn ForecastModel>> {
192        match model_type {
193            ModelType::LinearTrend => Ok(Box::new(LinearTrendModel::new())),
194            ModelType::MovingAverage => {
195                let window_size = (self.config.min_data_points / 2).max(3);
196                Ok(Box::new(MovingAverageModel::new(window_size)))
197            }
198            ModelType::ExponentialSmoothing => Ok(Box::new(
199                ExponentialSmoothingModel::with_default_alpha(),
200            )),
201            ModelType::Auto => Err(ForecastError::InvalidConfig(
202                "Auto model type should have been resolved".to_string(),
203            )),
204        }
205    }
206
207    /// Calculate number of periods to forecast
208    fn calculate_periods(&self, request: &ForecastRequest) -> ForecastResult<usize> {
209        match request.config.horizon {
210            ForecastHorizon::Periods(n) => Ok(n),
211            ForecastHorizon::Days(days) => {
212                let interval_secs = request.data.interval_secs.unwrap_or(3600);
213                let periods_per_day = 86400 / interval_secs;
214                Ok((days as i64 * periods_per_day) as usize)
215            }
216            ForecastHorizon::UntilDate(target_date) => {
217                let last_timestamp = request
218                    .data
219                    .last()
220                    .ok_or_else(|| {
221                        ForecastError::InsufficientData("No data points".to_string())
222                    })?
223                    .timestamp;
224
225                let duration = target_date.signed_duration_since(last_timestamp);
226                let interval_secs = request.data.interval_secs.unwrap_or(3600);
227
228                let periods = duration.num_seconds() / interval_secs;
229                if periods <= 0 {
230                    return Err(ForecastError::InvalidConfig(
231                        "Target date must be in the future".to_string(),
232                    ));
233                }
234
235                Ok(periods as usize)
236            }
237        }
238    }
239
240    /// Generate forecast data points with timestamps
241    fn generate_forecast_points(
242        &self,
243        last_timestamp: DateTime<Utc>,
244        interval_secs: i64,
245        values: Vec<Decimal>,
246    ) -> Vec<DataPoint> {
247        values
248            .into_iter()
249            .enumerate()
250            .map(|(i, value)| {
251                DataPoint::new(
252                    last_timestamp + Duration::seconds((i as i64 + 1) * interval_secs),
253                    value,
254                )
255            })
256            .collect()
257    }
258
259    /// Calculate z-score for confidence level
260    fn calculate_z_score(&self, confidence_level: f64) -> f64 {
261        // Common z-scores for confidence levels
262        match (confidence_level * 100.0) as i32 {
263            90 => 1.645,
264            95 => 1.96,
265            99 => 2.576,
266            _ => 1.96, // Default to 95%
267        }
268    }
269
270    /// Calculate prediction intervals
271    fn calculate_prediction_intervals(
272        &self,
273        forecast: &[DataPoint],
274        std_dev: f64,
275        z_score: f64,
276    ) -> (Vec<DataPoint>, Vec<DataPoint>) {
277        let margin = std_dev * z_score;
278
279        let lower_bound: Vec<DataPoint> = forecast
280            .iter()
281            .map(|point| {
282                let lower_value = point.value
283                    - Decimal::try_from(margin).unwrap_or(Decimal::ZERO);
284                let lower_value = lower_value.max(Decimal::ZERO); // Ensure non-negative
285                DataPoint::new(point.timestamp, lower_value)
286            })
287            .collect();
288
289        let upper_bound: Vec<DataPoint> = forecast
290            .iter()
291            .map(|point| {
292                let upper_value = point.value
293                    + Decimal::try_from(margin).unwrap_or(Decimal::ZERO);
294                DataPoint::new(point.timestamp, upper_value)
295            })
296            .collect();
297
298        (lower_bound, upper_bound)
299    }
300
301    /// Detect seasonality in time series
302    fn detect_seasonality(&self, data: &TimeSeriesData) -> ForecastResult<SeasonalityPattern> {
303        // Simple autocorrelation-based seasonality detection
304        if data.len() < 14 {
305            return Ok(SeasonalityPattern {
306                detected: false,
307                period: None,
308                strength: 0.0,
309            });
310        }
311
312        let values = data.values_f64();
313        let mean = values.iter().sum::<f64>() / values.len() as f64;
314
315        // Test common periods: daily (24h), weekly (7d)
316        let test_periods = vec![24, 168]; // hours
317        let interval_hours = data.interval_secs.unwrap_or(3600) / 3600;
318
319        let mut best_period = None;
320        let mut best_correlation = 0.0;
321
322        for &period_hours in &test_periods {
323            let lag = period_hours / interval_hours;
324            if lag as usize >= values.len() {
325                continue;
326            }
327
328            let correlation = self.calculate_autocorrelation(&values, lag as usize, mean);
329            if correlation > best_correlation {
330                best_correlation = correlation;
331                best_period = Some(lag as usize);
332            }
333        }
334
335        let detected = best_correlation > 0.3; // Threshold for significance
336
337        Ok(SeasonalityPattern {
338            detected,
339            period: if detected { best_period } else { None },
340            strength: if detected { best_correlation } else { 0.0 },
341        })
342    }
343
344    /// Calculate autocorrelation at a given lag
345    fn calculate_autocorrelation(&self, values: &[f64], lag: usize, mean: f64) -> f64 {
346        if lag >= values.len() {
347            return 0.0;
348        }
349
350        let n = values.len() - lag;
351        let mut numerator = 0.0;
352        let mut denominator = 0.0;
353
354        for i in 0..n {
355            numerator += (values[i] - mean) * (values[i + lag] - mean);
356        }
357
358        for &value in values {
359            denominator += (value - mean).powi(2);
360        }
361
362        if denominator.abs() < f64::EPSILON {
363            return 0.0;
364        }
365
366        numerator / denominator
367    }
368
369    /// Calculate validation metrics using holdout set
370    fn calculate_validation_metrics(
371        &self,
372        data: &TimeSeriesData,
373        model: &Box<dyn ForecastModel>,
374    ) -> ForecastResult<Option<super::metrics::ForecastMetrics>> {
375        // Use 20% of data as holdout for validation
376        let holdout_size = (data.len() as f64 * 0.2).ceil() as usize;
377        if holdout_size < 2 || data.len() - holdout_size < self.config.min_data_points {
378            return Ok(None); // Not enough data for validation
379        }
380
381        // Split data
382        let train_size = data.len() - holdout_size;
383        let train_data = data.subset(0, train_size);
384        let holdout_data = data.subset(train_size, data.len());
385
386        // Train on training set
387        let mut validation_model = self.create_model(
388            match model.name() {
389                "Linear Trend" => ModelType::LinearTrend,
390                "Moving Average" => ModelType::MovingAverage,
391                "Exponential Smoothing" => ModelType::ExponentialSmoothing,
392                _ => ModelType::ExponentialSmoothing,
393            }
394        )?;
395
396        validation_model.train(&train_data)?;
397
398        // Forecast holdout period
399        let predictions = validation_model.forecast(holdout_size)?;
400        let actuals = holdout_data.values();
401
402        // Calculate metrics
403        match super::metrics::ForecastMetrics::new(&actuals, &predictions) {
404            Ok(metrics) => Ok(Some(metrics)),
405            Err(_) => Ok(None), // If metrics calculation fails, return None
406        }
407    }
408}
409
410impl Default for ForecastEngine {
411    fn default() -> Self {
412        Self::new_with_defaults()
413    }
414}
415
416#[cfg(test)]
417mod tests {
418    use super::*;
419    use chrono::Utc;
420
421    fn create_test_series(values: Vec<i32>) -> TimeSeriesData {
422        let start = Utc::now();
423        let points: Vec<DataPoint> = values
424            .into_iter()
425            .enumerate()
426            .map(|(i, v)| {
427                DataPoint::new(start + Duration::hours(i as i64), Decimal::from(v))
428            })
429            .collect();
430
431        TimeSeriesData::with_auto_interval(points)
432    }
433
434    #[test]
435    fn test_engine_creation() {
436        let engine = ForecastEngine::default();
437        assert_eq!(engine.config.confidence_level, 0.95);
438    }
439
440    #[test]
441    fn test_validate_data() {
442        let engine = ForecastEngine::default();
443
444        // Empty data
445        let empty_data = TimeSeriesData::new(vec![]);
446        assert!(engine.validate_data(&empty_data).is_err());
447
448        // Insufficient data
449        let small_data = create_test_series(vec![1, 2]);
450        assert!(engine.validate_data(&small_data).is_err());
451
452        // Valid data
453        let valid_data = create_test_series(vec![1, 2, 3, 4, 5, 6, 7, 8]);
454        assert!(engine.validate_data(&valid_data).is_ok());
455    }
456
457    #[test]
458    fn test_select_best_model() {
459        let engine = ForecastEngine::default();
460
461        // Trending data should select linear trend
462        let trending_data = create_test_series(vec![10, 20, 30, 40, 50, 60, 70, 80]);
463        let model_type = engine.select_best_model(&trending_data).unwrap();
464        assert_eq!(model_type, ModelType::LinearTrend);
465
466        // Stable data should select moving average
467        let stable_data = create_test_series(vec![50, 51, 49, 50, 52, 48, 50, 51, 49, 50]);
468        let model_type = engine.select_best_model(&stable_data).unwrap();
469        assert_eq!(model_type, ModelType::MovingAverage);
470    }
471
472    #[test]
473    fn test_calculate_periods() {
474        let engine = ForecastEngine::default();
475
476        // Test Periods horizon
477        let data = create_test_series(vec![1, 2, 3, 4, 5, 6, 7, 8]);
478        let mut config = ForecastConfig::default();
479        config.horizon = ForecastHorizon::Periods(10);
480
481        let request = ForecastRequest {
482            data: data.clone(),
483            config: config.clone(),
484            preferred_model: None,
485        };
486
487        let periods = engine.calculate_periods(&request).unwrap();
488        assert_eq!(periods, 10);
489
490        // Test Days horizon
491        config.horizon = ForecastHorizon::Days(7);
492        let request = ForecastRequest {
493            data: data.clone(),
494            config,
495            preferred_model: None,
496        };
497
498        let periods = engine.calculate_periods(&request).unwrap();
499        assert_eq!(periods, 168); // 7 days * 24 hours
500    }
501
502    #[test]
503    fn test_forecast_generation() {
504        let engine = ForecastEngine::default();
505        let data = create_test_series(vec![10, 20, 30, 40, 50, 60, 70, 80]);
506
507        let mut config = ForecastConfig::default();
508        config.horizon = ForecastHorizon::Periods(5);
509
510        let request = ForecastRequest {
511            data,
512            config,
513            preferred_model: Some(ModelType::LinearTrend),
514        };
515
516        let result = engine.forecast(request);
517        assert!(result.is_ok());
518
519        let forecast = result.unwrap();
520        assert_eq!(forecast.forecast.len(), 5);
521        assert_eq!(forecast.lower_bound.len(), 5);
522        assert_eq!(forecast.upper_bound.len(), 5);
523        assert_eq!(forecast.model_name, "Linear Trend");
524        assert_eq!(forecast.trend, TrendDirection::Increasing);
525    }
526
527    #[test]
528    fn test_z_score_calculation() {
529        let engine = ForecastEngine::default();
530
531        assert_eq!(engine.calculate_z_score(0.90), 1.645);
532        assert_eq!(engine.calculate_z_score(0.95), 1.96);
533        assert_eq!(engine.calculate_z_score(0.99), 2.576);
534    }
535
536    #[test]
537    fn test_autocorrelation() {
538        let engine = ForecastEngine::default();
539        let values = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
540        let mean = values.iter().sum::<f64>() / values.len() as f64;
541
542        let correlation = engine.calculate_autocorrelation(&values, 1, mean);
543        assert!(correlation > 0.0); // Positive autocorrelation for trending data
544    }
545
546    #[test]
547    fn test_seasonality_detection() {
548        let engine = ForecastEngine::default();
549
550        // Not enough data
551        let small_data = create_test_series(vec![1, 2, 3, 4, 5]);
552        let seasonality = engine.detect_seasonality(&small_data).unwrap();
553        assert!(!seasonality.detected);
554
555        // Enough data but no clear seasonality
556        let data = create_test_series(vec![
557            10, 20, 15, 25, 20, 30, 25, 35, 30, 40, 35, 45, 40, 50,
558        ]);
559        let seasonality = engine.detect_seasonality(&data).unwrap();
560        // Result depends on autocorrelation, might or might not detect
561        assert!(seasonality.strength >= 0.0 && seasonality.strength <= 1.0);
562    }
563
564    #[test]
565    fn test_insufficient_data_error() {
566        let engine = ForecastEngine::default();
567        let data = create_test_series(vec![1, 2]); // Not enough data
568
569        let config = ForecastConfig::default();
570        let request = ForecastRequest {
571            data,
572            config,
573            preferred_model: None,
574        };
575
576        let result = engine.forecast(request);
577        assert!(result.is_err());
578    }
579
580    #[test]
581    fn test_different_models() {
582        let engine = ForecastEngine::default();
583        let data = create_test_series(vec![10, 20, 30, 40, 50, 60, 70, 80]);
584
585        let config = ForecastConfig {
586            horizon: ForecastHorizon::Periods(3),
587            ..Default::default()
588        };
589
590        // Test Linear Trend
591        let request = ForecastRequest {
592            data: data.clone(),
593            config: config.clone(),
594            preferred_model: Some(ModelType::LinearTrend),
595        };
596        assert!(engine.forecast(request).is_ok());
597
598        // Test Moving Average
599        let request = ForecastRequest {
600            data: data.clone(),
601            config: config.clone(),
602            preferred_model: Some(ModelType::MovingAverage),
603        };
604        assert!(engine.forecast(request).is_ok());
605
606        // Test Exponential Smoothing
607        let request = ForecastRequest {
608            data: data.clone(),
609            config: config.clone(),
610            preferred_model: Some(ModelType::ExponentialSmoothing),
611        };
612        assert!(engine.forecast(request).is_ok());
613
614        // Test Auto selection
615        let request = ForecastRequest {
616            data,
617            config,
618            preferred_model: Some(ModelType::Auto),
619        };
620        assert!(engine.forecast(request).is_ok());
621    }
622}