// llm_cost_ops/forecasting/models.rs

// Forecasting models and algorithms
2
3use chrono::Duration;
4use rust_decimal::Decimal;
5
6use super::{
7    types::{DataPoint, TimeSeriesData, TrendDirection},
8    ForecastError, ForecastResult,
9};
10
11/// Trait for forecasting models
12pub trait ForecastModel: Send + Sync {
13    /// Get the name of the model
14    fn name(&self) -> &str;
15
16    /// Train the model on historical data
17    fn train(&mut self, data: &TimeSeriesData) -> ForecastResult<()>;
18
19    /// Generate forecast for n periods ahead
20    fn forecast(&self, n_periods: usize) -> ForecastResult<Vec<Decimal>>;
21
22    /// Detect trend direction
23    fn detect_trend(&self) -> TrendDirection;
24}
25
26/// Linear Trend Model
27pub struct LinearTrendModel {
28    slope: f64,
29    intercept: f64,
30    last_value: Option<Decimal>,
31    interval_secs: i64,
32    trained: bool,
33}
34
35impl LinearTrendModel {
36    /// Create a new linear trend model
37    pub fn new() -> Self {
38        Self {
39            slope: 0.0,
40            intercept: 0.0,
41            last_value: None,
42            interval_secs: 3600, // Default 1 hour
43            trained: false,
44        }
45    }
46
47    /// Calculate linear regression parameters
48    fn calculate_regression(values: &[f64]) -> (f64, f64) {
49        let n = values.len() as f64;
50        let x: Vec<f64> = (0..values.len()).map(|i| i as f64).collect();
51
52        let sum_x: f64 = x.iter().sum();
53        let sum_y: f64 = values.iter().sum();
54        let sum_xy: f64 = x.iter().zip(values.iter()).map(|(a, b)| a * b).sum();
55        let sum_x2: f64 = x.iter().map(|a| a * a).sum();
56
57        let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
58        let intercept = (sum_y - slope * sum_x) / n;
59
60        (slope, intercept)
61    }
62}
63
64impl Default for LinearTrendModel {
65    fn default() -> Self {
66        Self::new()
67    }
68}
69
70impl ForecastModel for LinearTrendModel {
71    fn name(&self) -> &str {
72        "Linear Trend"
73    }
74
75    fn train(&mut self, data: &TimeSeriesData) -> ForecastResult<()> {
76        if data.len() < 2 {
77            return Err(ForecastError::InsufficientData(
78                "Linear trend requires at least 2 data points".to_string(),
79            ));
80        }
81
82        let values = data.values_f64();
83        let (slope, intercept) = Self::calculate_regression(&values);
84
85        self.slope = slope;
86        self.intercept = intercept;
87        self.last_value = data.last().map(|p| p.value);
88        self.interval_secs = data.interval_secs.unwrap_or(3600);
89        self.trained = true;
90
91        Ok(())
92    }
93
94    fn forecast(&self, n_periods: usize) -> ForecastResult<Vec<Decimal>> {
95        if !self.trained {
96            return Err(ForecastError::ModelError(
97                "Model must be trained before forecasting".to_string(),
98            ));
99        }
100
101        let mut forecasts = Vec::with_capacity(n_periods);
102        let last_idx = self.intercept.abs() + self.slope.abs();
103
104        for i in 1..=n_periods {
105            let forecast_value = self.slope * (last_idx + i as f64) + self.intercept;
106            let forecast_value = forecast_value.max(0.0); // Ensure non-negative
107
108            forecasts.push(
109                Decimal::try_from(forecast_value)
110                    .unwrap_or(Decimal::ZERO),
111            );
112        }
113
114        Ok(forecasts)
115    }
116
117    fn detect_trend(&self) -> TrendDirection {
118        if !self.trained {
119            return TrendDirection::Unknown;
120        }
121
122        if self.slope > 0.01 {
123            TrendDirection::Increasing
124        } else if self.slope < -0.01 {
125            TrendDirection::Decreasing
126        } else {
127            TrendDirection::Stable
128        }
129    }
130}
131
132/// Moving Average Model
133pub struct MovingAverageModel {
134    window_size: usize,
135    values: Vec<Decimal>,
136    interval_secs: i64,
137}
138
139impl MovingAverageModel {
140    /// Create a new moving average model
141    pub fn new(window_size: usize) -> Self {
142        Self {
143            window_size,
144            values: Vec::new(),
145            interval_secs: 3600,
146        }
147    }
148}
149
150impl ForecastModel for MovingAverageModel {
151    fn name(&self) -> &str {
152        "Moving Average"
153    }
154
155    fn train(&mut self, data: &TimeSeriesData) -> ForecastResult<()> {
156        if data.len() < self.window_size {
157            return Err(ForecastError::InsufficientData(format!(
158                "Moving average requires at least {} data points",
159                self.window_size
160            )));
161        }
162
163        self.values = data.values();
164        self.interval_secs = data.interval_secs.unwrap_or(3600);
165
166        Ok(())
167    }
168
169    fn forecast(&self, n_periods: usize) -> ForecastResult<Vec<Decimal>> {
170        if self.values.is_empty() {
171            return Err(ForecastError::ModelError(
172                "Model must be trained before forecasting".to_string(),
173            ));
174        }
175
176        let mut forecasts = Vec::with_capacity(n_periods);
177        let mut extended_values = self.values.clone();
178
179        for _ in 0..n_periods {
180            // Calculate moving average of last window_size values
181            let start_idx = extended_values.len().saturating_sub(self.window_size);
182            let window = &extended_values[start_idx..];
183
184            let sum: Decimal = window.iter().sum();
185            let avg = sum / Decimal::from(window.len());
186
187            forecasts.push(avg);
188            extended_values.push(avg);
189        }
190
191        Ok(forecasts)
192    }
193
194    fn detect_trend(&self) -> TrendDirection {
195        if self.values.len() < 2 {
196            return TrendDirection::Unknown;
197        }
198
199        let mid_point = self.values.len() / 2;
200        let first_half: Decimal = self.values[..mid_point].iter().sum::<Decimal>()
201            / Decimal::from(mid_point);
202        let second_half: Decimal = self.values[mid_point..].iter().sum::<Decimal>()
203            / Decimal::from(self.values.len() - mid_point);
204
205        if second_half > first_half * Decimal::new(101, 2) {
206            // 1% increase
207            TrendDirection::Increasing
208        } else if second_half < first_half * Decimal::new(99, 2) {
209            // 1% decrease
210            TrendDirection::Decreasing
211        } else {
212            TrendDirection::Stable
213        }
214    }
215}
216
217/// Exponential Smoothing Model
218pub struct ExponentialSmoothingModel {
219    alpha: f64, // Smoothing factor (0 < alpha < 1)
220    last_smoothed: Option<f64>,
221    interval_secs: i64,
222    trained: bool,
223}
224
225impl ExponentialSmoothingModel {
226    /// Create a new exponential smoothing model
227    pub fn new(alpha: f64) -> ForecastResult<Self> {
228        if !(0.0..=1.0).contains(&alpha) {
229            return Err(ForecastError::InvalidConfig(
230                "Alpha must be between 0 and 1".to_string(),
231            ));
232        }
233
234        Ok(Self {
235            alpha,
236            last_smoothed: None,
237            interval_secs: 3600,
238            trained: false,
239        })
240    }
241
242    /// Create with default alpha (0.3)
243    pub fn with_default_alpha() -> Self {
244        Self {
245            alpha: 0.3,
246            last_smoothed: None,
247            interval_secs: 3600,
248            trained: false,
249        }
250    }
251}
252
253impl ForecastModel for ExponentialSmoothingModel {
254    fn name(&self) -> &str {
255        "Exponential Smoothing"
256    }
257
258    fn train(&mut self, data: &TimeSeriesData) -> ForecastResult<()> {
259        if data.is_empty() {
260            return Err(ForecastError::InsufficientData(
261                "Exponential smoothing requires at least 1 data point".to_string(),
262            ));
263        }
264
265        let values = data.values_f64();
266        let mut smoothed = values[0];
267
268        for &value in &values[1..] {
269            smoothed = self.alpha * value + (1.0 - self.alpha) * smoothed;
270        }
271
272        self.last_smoothed = Some(smoothed);
273        self.interval_secs = data.interval_secs.unwrap_or(3600);
274        self.trained = true;
275
276        Ok(())
277    }
278
279    fn forecast(&self, n_periods: usize) -> ForecastResult<Vec<Decimal>> {
280        if !self.trained {
281            return Err(ForecastError::ModelError(
282                "Model must be trained before forecasting".to_string(),
283            ));
284        }
285
286        let forecast_value = self.last_smoothed.unwrap_or(0.0).max(0.0);
287        let decimal_value = Decimal::try_from(forecast_value)
288            .unwrap_or(Decimal::ZERO);
289
290        // Exponential smoothing produces constant forecast
291        Ok(vec![decimal_value; n_periods])
292    }
293
294    fn detect_trend(&self) -> TrendDirection {
295        // Exponential smoothing doesn't directly detect trends
296        TrendDirection::Stable
297    }
298}
299
300/// Generate forecast data points with timestamps
301pub fn generate_forecast_points(
302    last_timestamp: chrono::DateTime<chrono::Utc>,
303    interval_secs: i64,
304    values: Vec<Decimal>,
305) -> Vec<DataPoint> {
306    values
307        .into_iter()
308        .enumerate()
309        .map(|(i, value)| DataPoint::new(
310            last_timestamp + Duration::seconds((i as i64 + 1) * interval_secs),
311            value,
312        ))
313        .collect()
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Build an hourly series, starting now, from integer observations.
    fn create_test_series(values: Vec<i32>) -> TimeSeriesData {
        let start = Utc::now();
        let points: Vec<DataPoint> = values
            .into_iter()
            .enumerate()
            .map(|(i, v)| DataPoint::new(start + Duration::hours(i as i64), Decimal::from(v)))
            .collect();

        TimeSeriesData::with_auto_interval(points)
    }

    #[test]
    fn test_linear_trend_increasing() {
        let data = create_test_series(vec![10, 20, 30, 40, 50]);
        let mut model = LinearTrendModel::new();

        assert!(model.train(&data).is_ok());
        assert_eq!(model.detect_trend(), TrendDirection::Increasing);

        let forecast = model.forecast(3).unwrap();
        assert_eq!(forecast.len(), 3);
        // Should predict continuation of upward trend
        assert!(forecast[0] > Decimal::from(50));
        assert!(forecast[0] < forecast[1] && forecast[1] < forecast[2]);
    }

    #[test]
    fn test_linear_trend_decreasing() {
        let data = create_test_series(vec![50, 40, 30, 20, 10]);
        let mut model = LinearTrendModel::new();

        assert!(model.train(&data).is_ok());
        assert_eq!(model.detect_trend(), TrendDirection::Decreasing);
    }

    #[test]
    fn test_linear_trend_forecast_is_non_negative() {
        let data = create_test_series(vec![50, 40, 30, 20, 10]);
        let mut model = LinearTrendModel::new();
        model.train(&data).unwrap();

        let forecast = model.forecast(10).unwrap();
        assert_eq!(forecast.len(), 10);
        assert!(forecast.iter().all(|v| *v >= Decimal::ZERO));
    }

    #[test]
    fn test_linear_trend_flat_series() {
        let data = create_test_series(vec![5, 5, 5, 5]);
        let mut model = LinearTrendModel::new();
        model.train(&data).unwrap();

        assert_eq!(model.detect_trend(), TrendDirection::Stable);
        let forecast = model.forecast(3).unwrap();
        assert!(forecast.iter().all(|v| *v == Decimal::from(5)));
    }

    #[test]
    fn test_moving_average() {
        let data = create_test_series(vec![10, 20, 15, 25, 20]);
        let mut model = MovingAverageModel::new(3);

        assert!(model.train(&data).is_ok());

        let forecast = model.forecast(2).unwrap();
        assert_eq!(forecast.len(), 2);
        // mean(15, 25, 20) = 20, then the forecast is fed back: mean(25, 20, 20) = 65 / 3
        assert_eq!(forecast[0], Decimal::from(20));
        assert_eq!(forecast[1], Decimal::from(65) / Decimal::from(3));
    }

    #[test]
    fn test_moving_average_insufficient_data() {
        let data = create_test_series(vec![10, 20]);
        let mut model = MovingAverageModel::new(3);

        assert!(model.train(&data).is_err());
    }

    #[test]
    fn test_moving_average_detects_increasing_trend() {
        let data = create_test_series(vec![10, 20, 30, 40, 50]);
        let mut model = MovingAverageModel::new(2);
        model.train(&data).unwrap();

        assert_eq!(model.detect_trend(), TrendDirection::Increasing);
    }

    #[test]
    fn test_exponential_smoothing() {
        let data = create_test_series(vec![10, 12, 11, 13, 12]);
        let mut model = ExponentialSmoothingModel::with_default_alpha();

        assert!(model.train(&data).is_ok());
        assert_eq!(model.detect_trend(), TrendDirection::Stable);

        let forecast = model.forecast(3).unwrap();
        assert_eq!(forecast.len(), 3);
        // All forecasts should be the same (constant forecast)
        assert_eq!(forecast[0], forecast[1]);
        assert_eq!(forecast[1], forecast[2]);
        // The smoothed level must stay within the range of the observations.
        assert!(forecast[0] > Decimal::from(10) && forecast[0] < Decimal::from(13));
    }

    #[test]
    fn test_zero_period_forecasts_are_empty() {
        let data = create_test_series(vec![10, 20, 30]);

        let mut linear = LinearTrendModel::new();
        linear.train(&data).unwrap();
        assert!(linear.forecast(0).unwrap().is_empty());

        let mut moving = MovingAverageModel::new(2);
        moving.train(&data).unwrap();
        assert!(moving.forecast(0).unwrap().is_empty());

        let mut smoothing = ExponentialSmoothingModel::with_default_alpha();
        smoothing.train(&data).unwrap();
        assert!(smoothing.forecast(0).unwrap().is_empty());
    }

    #[test]
    fn test_insufficient_data() {
        let data = create_test_series(vec![10]);
        let mut model = LinearTrendModel::new();

        assert!(model.train(&data).is_err());
    }

    #[test]
    fn test_untrained_forecast() {
        let model = LinearTrendModel::new();
        assert!(model.forecast(5).is_err());

        assert!(MovingAverageModel::new(3).forecast(1).is_err());
        assert!(ExponentialSmoothingModel::with_default_alpha().forecast(1).is_err());
    }

    #[test]
    fn test_invalid_alpha() {
        assert!(ExponentialSmoothingModel::new(1.5).is_err());
        assert!(ExponentialSmoothingModel::new(-0.1).is_err());
        assert!(ExponentialSmoothingModel::new(f64::NAN).is_err());
        assert!(ExponentialSmoothingModel::new(0.5).is_ok());
        // The accepted range is inclusive at both ends.
        assert!(ExponentialSmoothingModel::new(0.0).is_ok());
        assert!(ExponentialSmoothingModel::new(1.0).is_ok());
    }

    #[test]
    fn test_generate_forecast_points_preserves_values() {
        let values = vec![Decimal::from(1), Decimal::from(2), Decimal::from(3)];
        let points = generate_forecast_points(Utc::now(), 3600, values.clone());

        assert_eq!(points.len(), 3);
        let stamped: Vec<Decimal> = points.iter().map(|p| p.value).collect();
        assert_eq!(stamped, values);
    }
}