llm_cost_ops/forecasting/
models.rs1use chrono::Duration;
4use rust_decimal::Decimal;
5
6use super::{
7 types::{DataPoint, TimeSeriesData, TrendDirection},
8 ForecastError, ForecastResult,
9};
10
11pub trait ForecastModel: Send + Sync {
13 fn name(&self) -> &str;
15
16 fn train(&mut self, data: &TimeSeriesData) -> ForecastResult<()>;
18
19 fn forecast(&self, n_periods: usize) -> ForecastResult<Vec<Decimal>>;
21
22 fn detect_trend(&self) -> TrendDirection;
24}
25
26pub struct LinearTrendModel {
28 slope: f64,
29 intercept: f64,
30 last_value: Option<Decimal>,
31 interval_secs: i64,
32 trained: bool,
33}
34
35impl LinearTrendModel {
36 pub fn new() -> Self {
38 Self {
39 slope: 0.0,
40 intercept: 0.0,
41 last_value: None,
42 interval_secs: 3600, trained: false,
44 }
45 }
46
47 fn calculate_regression(values: &[f64]) -> (f64, f64) {
49 let n = values.len() as f64;
50 let x: Vec<f64> = (0..values.len()).map(|i| i as f64).collect();
51
52 let sum_x: f64 = x.iter().sum();
53 let sum_y: f64 = values.iter().sum();
54 let sum_xy: f64 = x.iter().zip(values.iter()).map(|(a, b)| a * b).sum();
55 let sum_x2: f64 = x.iter().map(|a| a * a).sum();
56
57 let slope = (n * sum_xy - sum_x * sum_y) / (n * sum_x2 - sum_x * sum_x);
58 let intercept = (sum_y - slope * sum_x) / n;
59
60 (slope, intercept)
61 }
62}
63
64impl Default for LinearTrendModel {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl ForecastModel for LinearTrendModel {
71 fn name(&self) -> &str {
72 "Linear Trend"
73 }
74
75 fn train(&mut self, data: &TimeSeriesData) -> ForecastResult<()> {
76 if data.len() < 2 {
77 return Err(ForecastError::InsufficientData(
78 "Linear trend requires at least 2 data points".to_string(),
79 ));
80 }
81
82 let values = data.values_f64();
83 let (slope, intercept) = Self::calculate_regression(&values);
84
85 self.slope = slope;
86 self.intercept = intercept;
87 self.last_value = data.last().map(|p| p.value);
88 self.interval_secs = data.interval_secs.unwrap_or(3600);
89 self.trained = true;
90
91 Ok(())
92 }
93
94 fn forecast(&self, n_periods: usize) -> ForecastResult<Vec<Decimal>> {
95 if !self.trained {
96 return Err(ForecastError::ModelError(
97 "Model must be trained before forecasting".to_string(),
98 ));
99 }
100
101 let mut forecasts = Vec::with_capacity(n_periods);
102 let last_idx = self.intercept.abs() + self.slope.abs();
103
104 for i in 1..=n_periods {
105 let forecast_value = self.slope * (last_idx + i as f64) + self.intercept;
106 let forecast_value = forecast_value.max(0.0); forecasts.push(
109 Decimal::try_from(forecast_value)
110 .unwrap_or(Decimal::ZERO),
111 );
112 }
113
114 Ok(forecasts)
115 }
116
117 fn detect_trend(&self) -> TrendDirection {
118 if !self.trained {
119 return TrendDirection::Unknown;
120 }
121
122 if self.slope > 0.01 {
123 TrendDirection::Increasing
124 } else if self.slope < -0.01 {
125 TrendDirection::Decreasing
126 } else {
127 TrendDirection::Stable
128 }
129 }
130}
131
132pub struct MovingAverageModel {
134 window_size: usize,
135 values: Vec<Decimal>,
136 interval_secs: i64,
137}
138
139impl MovingAverageModel {
140 pub fn new(window_size: usize) -> Self {
142 Self {
143 window_size,
144 values: Vec::new(),
145 interval_secs: 3600,
146 }
147 }
148}
149
150impl ForecastModel for MovingAverageModel {
151 fn name(&self) -> &str {
152 "Moving Average"
153 }
154
155 fn train(&mut self, data: &TimeSeriesData) -> ForecastResult<()> {
156 if data.len() < self.window_size {
157 return Err(ForecastError::InsufficientData(format!(
158 "Moving average requires at least {} data points",
159 self.window_size
160 )));
161 }
162
163 self.values = data.values();
164 self.interval_secs = data.interval_secs.unwrap_or(3600);
165
166 Ok(())
167 }
168
169 fn forecast(&self, n_periods: usize) -> ForecastResult<Vec<Decimal>> {
170 if self.values.is_empty() {
171 return Err(ForecastError::ModelError(
172 "Model must be trained before forecasting".to_string(),
173 ));
174 }
175
176 let mut forecasts = Vec::with_capacity(n_periods);
177 let mut extended_values = self.values.clone();
178
179 for _ in 0..n_periods {
180 let start_idx = extended_values.len().saturating_sub(self.window_size);
182 let window = &extended_values[start_idx..];
183
184 let sum: Decimal = window.iter().sum();
185 let avg = sum / Decimal::from(window.len());
186
187 forecasts.push(avg);
188 extended_values.push(avg);
189 }
190
191 Ok(forecasts)
192 }
193
194 fn detect_trend(&self) -> TrendDirection {
195 if self.values.len() < 2 {
196 return TrendDirection::Unknown;
197 }
198
199 let mid_point = self.values.len() / 2;
200 let first_half: Decimal = self.values[..mid_point].iter().sum::<Decimal>()
201 / Decimal::from(mid_point);
202 let second_half: Decimal = self.values[mid_point..].iter().sum::<Decimal>()
203 / Decimal::from(self.values.len() - mid_point);
204
205 if second_half > first_half * Decimal::new(101, 2) {
206 TrendDirection::Increasing
208 } else if second_half < first_half * Decimal::new(99, 2) {
209 TrendDirection::Decreasing
211 } else {
212 TrendDirection::Stable
213 }
214 }
215}
216
217pub struct ExponentialSmoothingModel {
219 alpha: f64, last_smoothed: Option<f64>,
221 interval_secs: i64,
222 trained: bool,
223}
224
225impl ExponentialSmoothingModel {
226 pub fn new(alpha: f64) -> ForecastResult<Self> {
228 if !(0.0..=1.0).contains(&alpha) {
229 return Err(ForecastError::InvalidConfig(
230 "Alpha must be between 0 and 1".to_string(),
231 ));
232 }
233
234 Ok(Self {
235 alpha,
236 last_smoothed: None,
237 interval_secs: 3600,
238 trained: false,
239 })
240 }
241
242 pub fn with_default_alpha() -> Self {
244 Self {
245 alpha: 0.3,
246 last_smoothed: None,
247 interval_secs: 3600,
248 trained: false,
249 }
250 }
251}
252
253impl ForecastModel for ExponentialSmoothingModel {
254 fn name(&self) -> &str {
255 "Exponential Smoothing"
256 }
257
258 fn train(&mut self, data: &TimeSeriesData) -> ForecastResult<()> {
259 if data.is_empty() {
260 return Err(ForecastError::InsufficientData(
261 "Exponential smoothing requires at least 1 data point".to_string(),
262 ));
263 }
264
265 let values = data.values_f64();
266 let mut smoothed = values[0];
267
268 for &value in &values[1..] {
269 smoothed = self.alpha * value + (1.0 - self.alpha) * smoothed;
270 }
271
272 self.last_smoothed = Some(smoothed);
273 self.interval_secs = data.interval_secs.unwrap_or(3600);
274 self.trained = true;
275
276 Ok(())
277 }
278
279 fn forecast(&self, n_periods: usize) -> ForecastResult<Vec<Decimal>> {
280 if !self.trained {
281 return Err(ForecastError::ModelError(
282 "Model must be trained before forecasting".to_string(),
283 ));
284 }
285
286 let forecast_value = self.last_smoothed.unwrap_or(0.0).max(0.0);
287 let decimal_value = Decimal::try_from(forecast_value)
288 .unwrap_or(Decimal::ZERO);
289
290 Ok(vec![decimal_value; n_periods])
292 }
293
294 fn detect_trend(&self) -> TrendDirection {
295 TrendDirection::Stable
297 }
298}
299
300pub fn generate_forecast_points(
302 last_timestamp: chrono::DateTime<chrono::Utc>,
303 interval_secs: i64,
304 values: Vec<Decimal>,
305) -> Vec<DataPoint> {
306 values
307 .into_iter()
308 .enumerate()
309 .map(|(i, value)| DataPoint::new(
310 last_timestamp + Duration::seconds((i as i64 + 1) * interval_secs),
311 value,
312 ))
313 .collect()
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::Utc;

    /// Builds an hourly series, starting now, from plain integer values.
    fn series_from(raw: Vec<i32>) -> TimeSeriesData {
        let origin = Utc::now();
        let mut points = Vec::with_capacity(raw.len());
        for (hour, v) in raw.into_iter().enumerate() {
            points.push(DataPoint::new(
                origin + Duration::hours(hour as i64),
                Decimal::from(v),
            ));
        }
        TimeSeriesData::with_auto_interval(points)
    }

    #[test]
    fn test_linear_trend_increasing() {
        let series = series_from(vec![10, 20, 30, 40, 50]);
        let mut model = LinearTrendModel::new();

        assert!(model.train(&series).is_ok());
        assert_eq!(model.detect_trend(), TrendDirection::Increasing);

        let predicted = model.forecast(3).unwrap();
        assert_eq!(predicted.len(), 3);
        // A rising line must continue above the last observation.
        assert!(predicted[0] > Decimal::from(50));
    }

    #[test]
    fn test_linear_trend_decreasing() {
        let series = series_from(vec![50, 40, 30, 20, 10]);
        let mut model = LinearTrendModel::new();

        assert!(model.train(&series).is_ok());
        assert_eq!(model.detect_trend(), TrendDirection::Decreasing);
    }

    #[test]
    fn test_moving_average() {
        let series = series_from(vec![10, 20, 15, 25, 20]);
        let mut model = MovingAverageModel::new(3);

        assert!(model.train(&series).is_ok());
        assert_eq!(model.forecast(2).unwrap().len(), 2);
    }

    #[test]
    fn test_exponential_smoothing() {
        let series = series_from(vec![10, 12, 11, 13, 12]);
        let mut model = ExponentialSmoothingModel::with_default_alpha();

        assert!(model.train(&series).is_ok());

        let predicted = model.forecast(3).unwrap();
        assert_eq!(predicted.len(), 3);
        // Single exponential smoothing yields a flat forecast.
        assert!(predicted.windows(2).all(|pair| pair[0] == pair[1]));
    }

    #[test]
    fn test_insufficient_data() {
        let series = series_from(vec![10]);
        let mut model = LinearTrendModel::new();

        assert!(model.train(&series).is_err());
    }

    #[test]
    fn test_untrained_forecast() {
        assert!(LinearTrendModel::new().forecast(5).is_err());
    }

    #[test]
    fn test_invalid_alpha() {
        for bad_alpha in [1.5, -0.1] {
            assert!(ExponentialSmoothingModel::new(bad_alpha).is_err());
        }
        assert!(ExponentialSmoothingModel::new(0.5).is_ok());
    }
}