use super::trends::TrendsData;
/// Result of projecting usage forward from historical trends.
///
/// When `unavailable_reason` is `Some`, no forecast could be made and every
/// numeric field is zero (see [`ForecastData::unavailable`]).
#[derive(Debug, Clone)]
pub struct ForecastData {
    /// Forecast token count for the 30-day horizon.
    pub next_30_days_tokens: u64,
    /// `next_30_days_tokens` priced at the historical cost per token
    /// (or a fallback rate when no tokens were recorded).
    pub next_30_days_cost: f64,
    /// Mean observed daily cost scaled to 30 days.
    pub monthly_cost_estimate: f64,
    /// Goodness of fit of the underlying regression (R², clamped to `0.0..=1.0`).
    pub confidence: f64,
    /// Direction and magnitude of the fitted trend.
    pub trend_direction: TrendDirection,
    /// Why no forecast is available, or `None` when the other fields are valid.
    pub unavailable_reason: Option<String>,
}
/// Direction of the fitted usage trend.
///
/// The `Up`/`Down` payloads carry the magnitude of the projected change,
/// expressed as a percentage.
#[derive(Debug, Clone)]
pub enum TrendDirection {
    /// Usage is rising; payload is the size of the change in percent.
    Up(f64),
    /// Usage is falling; payload is the size of the change in percent.
    Down(f64),
    /// The change is too small to be considered significant.
    Stable,
}
impl ForecastData {
    /// Builds a placeholder forecast recording why no projection could be made.
    ///
    /// Every numeric field is zero, the trend is [`TrendDirection::Stable`], and
    /// `unavailable_reason` holds an owned copy of `reason`.
    pub fn unavailable(reason: &str) -> Self {
        let unavailable_reason = Some(reason.to_owned());
        Self {
            unavailable_reason,
            trend_direction: TrendDirection::Stable,
            confidence: 0.0,
            monthly_cost_estimate: 0.0,
            next_30_days_cost: 0.0,
            next_30_days_tokens: 0,
        }
    }
}
/// Projects token usage and cost over the 30 days following the observed period.
///
/// An ordinary-least-squares line is fitted to `trends.daily_tokens`
/// (x = day index, y = tokens) and extrapolated forward. The forecast contains:
///
/// - `next_30_days_tokens`: sum of the fitted daily values for the 30 days after
///   the last observation, with negative daily predictions clamped to zero;
/// - `next_30_days_cost`: that total priced at the historical cost per token, or
///   at a fallback rate of 0.01 per 1,000 tokens when no tokens were recorded;
/// - `monthly_cost_estimate`: mean observed daily cost scaled to 30 days;
/// - `confidence`: the fit's R², clamped to `0.0..=1.0`;
/// - `trend_direction`: `Stable` when the slope is below 1% of the baseline,
///   otherwise `Up`/`Down` carrying the projected 30-day change as a percentage
///   of the baseline (the fitted day-0 value, or the observed daily mean when
///   the fitted day-0 value is exactly zero).
///
/// Returns [`ForecastData::unavailable`] when fewer than 7 dates or fewer than
/// 7 daily token samples are available.
pub fn forecast_usage(trends: &TrendsData) -> ForecastData {
    const MIN_DAYS: usize = 7;
    const HORIZON_DAYS: usize = 30;
    // Fallback price (0.01 per 1,000 tokens) when history has no tokens to derive a rate from.
    const DEFAULT_COST_PER_TOKEN: f64 = 0.01 / 1000.0;
    // Slopes smaller than this fraction of the baseline are reported as stable.
    const STABLE_FRACTION: f64 = 0.01;

    // The regression runs on `daily_tokens`, so it must have enough samples too;
    // otherwise the fit degenerates (NaN for fewer than two points).
    if trends.dates.len() < MIN_DAYS || trends.daily_tokens.len() < MIN_DAYS {
        return ForecastData::unavailable("Insufficient data (<7 days)");
    }

    let points: Vec<(f64, f64)> = trends
        .daily_tokens
        .iter()
        .enumerate()
        .map(|(day, &tokens)| (day as f64, tokens as f64))
        .collect();
    let (slope, intercept, r_squared) = linear_regression(&points);
    let confidence = r_squared.clamp(0.0, 1.0);

    // Observed days are indexed 0..n-1, so the horizon covers n..n+HORIZON_DAYS-1.
    // Summing the per-day predictions yields a 30-day total that is comparable to
    // `monthly_cost_estimate`; a single extrapolated point would be only one day.
    let first_future_day = points.len();
    let projected_tokens: f64 = (first_future_day..first_future_day + HORIZON_DAYS)
        .map(|day| (slope * day as f64 + intercept).max(0.0))
        .sum();
    let next_30_days_tokens = projected_tokens as u64;

    let total_cost: f64 = trends.daily_cost.iter().sum();
    let total_tokens: u64 = trends.daily_tokens.iter().sum();
    let cost_per_token = if total_tokens > 0 {
        total_cost / total_tokens as f64
    } else {
        DEFAULT_COST_PER_TOKEN
    };
    let next_30_days_cost = next_30_days_tokens as f64 * cost_per_token;
    let monthly_cost_estimate = total_cost / trends.dates.len() as f64 * HORIZON_DAYS as f64;

    // Percentages are relative to the fitted day-0 usage. When the fit passes
    // exactly through zero, fall back to the observed mean so the division below
    // never yields inf. A non-zero slope implies some day had tokens, so the mean
    // is then positive. An all-zero history has slope 0 and is reported as
    // `Stable` rather than `Down(NaN)`.
    let baseline = if intercept != 0.0 {
        intercept.abs()
    } else {
        total_tokens as f64 / points.len() as f64
    };
    let trend_direction = if slope == 0.0 || slope.abs() < STABLE_FRACTION * baseline {
        TrendDirection::Stable
    } else {
        let change_pct = (slope * HORIZON_DAYS as f64 / baseline * 100.0).abs();
        if slope > 0.0 {
            TrendDirection::Up(change_pct)
        } else {
            TrendDirection::Down(change_pct)
        }
    };

    ForecastData {
        next_30_days_tokens,
        next_30_days_cost,
        monthly_cost_estimate,
        confidence,
        trend_direction,
        unavailable_reason: None,
    }
}
/// Fits an ordinary-least-squares line `y = slope * x + intercept` to `points`.
///
/// Returns `(slope, intercept, r_squared)`. Sums are taken over mean-centred
/// values, which avoids the cancellation error of the raw-sums formula
/// (`n·Σxy − Σx·Σy`) when x or y are large.
///
/// Degenerate inputs never produce NaN:
/// - an empty slice yields `(0.0, 0.0, 0.0)`;
/// - when every `x` is identical (including a single point) the slope is
///   undefined, so a horizontal line through the mean `y` is returned with
///   `r_squared = 0.0`.
///
/// `r_squared` is also `0.0` when every `y` is identical (zero total variance).
fn linear_regression(points: &[(f64, f64)]) -> (f64, f64, f64) {
    if points.is_empty() {
        return (0.0, 0.0, 0.0);
    }

    let n = points.len() as f64;
    let mean_x = points.iter().map(|&(x, _)| x).sum::<f64>() / n;
    let mean_y = points.iter().map(|&(_, y)| y).sum::<f64>() / n;

    // Centred second moments: Σ(x - x̄)² and Σ(x - x̄)(y - ȳ).
    let (s_xx, s_xy) = points.iter().fold((0.0, 0.0), |(s_xx, s_xy), &(x, y)| {
        let dx = x - mean_x;
        (s_xx + dx * dx, s_xy + dx * (y - mean_y))
    });
    if s_xx == 0.0 {
        // No spread in x: the slope is undefined, so fall back to a flat line.
        return (0.0, mean_y, 0.0);
    }

    let slope = s_xy / s_xx;
    let intercept = mean_y - slope * mean_x;

    let (ss_res, ss_tot) = points.iter().fold((0.0, 0.0), |(ss_res, ss_tot), &(x, y)| {
        let residual = y - (slope * x + intercept);
        let deviation = y - mean_y;
        (ss_res + residual * residual, ss_tot + deviation * deviation)
    });
    let r_squared = if ss_tot > 0.0 {
        1.0 - ss_res / ss_tot
    } else {
        0.0
    };

    (slope, intercept, r_squared)
}