Skip to main content

mockforge_http/handlers/
forecasting.rs

1//! HTTP handlers for API change forecasting
2//!
3//! This module provides endpoints for querying and managing forecasts.
4
5use axum::{
6    extract::{Path, Query, State},
7    http::StatusCode,
8    response::Json,
9};
10use mockforge_core::contract_drift::forecasting::{ChangeForecast, Forecaster};
11use serde::{Deserialize, Serialize};
12use std::sync::Arc;
13
14#[cfg(feature = "database")]
15use chrono::{DateTime, Utc};
16#[cfg(feature = "database")]
17use mockforge_core::contract_drift::forecasting::SeasonalPattern;
18#[cfg(feature = "database")]
19use uuid::Uuid;
20
21use crate::database::Database;
22
/// Map a row of `api_change_forecasts` into a [`ChangeForecast`].
///
/// Columns are looked up by name, so the row may come from an explicit column list or
/// from `SELECT *`.
///
/// `seasonal_patterns` is best-effort: a NULL, undecodable, or non-array value yields an
/// empty pattern list instead of an error.
///
/// # Errors
///
/// Returns the `sqlx::Error` for the first required column that is missing or cannot be
/// decoded, or `sqlx::Error::ColumnDecode` if `forecast_window_days` is negative.
#[cfg(feature = "database")]
fn map_row_to_change_forecast(row: &sqlx::postgres::PgRow) -> Result<ChangeForecast, sqlx::Error> {
    use sqlx::Row;

    let service_id: Option<String> = row.try_get("service_id")?;
    let service_name: Option<String> = row.try_get("service_name")?;
    let endpoint: String = row.try_get("endpoint")?;
    let method: String = row.try_get("method")?;
    let forecast_window_days: i32 = row.try_get("forecast_window_days")?;
    // `as u32` would silently wrap a negative value into an enormous window; report it as
    // a decode failure so callers can skip the row instead.
    let forecast_window_days =
        u32::try_from(forecast_window_days).map_err(|e| sqlx::Error::ColumnDecode {
            index: "forecast_window_days".to_owned(),
            source: Box::new(e),
        })?;
    let predicted_change_probability: f64 = row.try_get("predicted_change_probability")?;
    let predicted_break_probability: f64 = row.try_get("predicted_break_probability")?;
    let next_expected_change_date: Option<DateTime<Utc>> =
        row.try_get("next_expected_change_date")?;
    let next_expected_break_date: Option<DateTime<Utc>> =
        row.try_get("next_expected_break_date")?;
    let volatility_score: f64 = row.try_get("volatility_score")?;
    let confidence: f64 = row.try_get("confidence")?;
    // A NULL/undecodable column falls back to `Value::Null`.
    let seasonal_patterns_json: serde_json::Value =
        row.try_get("seasonal_patterns").unwrap_or_default();
    let predicted_at: DateTime<Utc> = row.try_get("predicted_at")?;
    let expires_at: DateTime<Utc> = row.try_get("expires_at")?;

    // Anything that is not a JSON array of patterns (including `null`) fails to
    // deserialize into a `Vec` and therefore becomes an empty list.
    let seasonal_patterns: Vec<SeasonalPattern> =
        serde_json::from_value(seasonal_patterns_json).unwrap_or_default();

    Ok(ChangeForecast {
        service_id,
        service_name,
        endpoint,
        method,
        forecast_window_days,
        predicted_change_probability,
        predicted_break_probability,
        next_expected_change_date,
        next_expected_break_date,
        volatility_score,
        confidence,
        seasonal_patterns,
        predicted_at,
        expires_at,
    })
}
70
71/// State for forecasting handlers
72#[derive(Clone)]
73pub struct ForecastingState {
74    /// Forecaster engine
75    pub forecaster: Arc<Forecaster>,
76    /// Database connection (optional)
77    pub database: Option<Database>,
78}
79
80/// Query parameters for listing forecasts
81#[derive(Debug, Deserialize)]
82pub struct ListForecastsQuery {
83    /// Workspace ID filter
84    pub workspace_id: Option<String>,
85    /// Service ID filter
86    pub service_id: Option<String>,
87    /// Endpoint filter
88    pub endpoint: Option<String>,
89    /// Method filter
90    pub method: Option<String>,
91    /// Forecast window (30, 90, or 180 days)
92    pub window_days: Option<u32>,
93}
94
95/// Response for forecast list
96#[derive(Debug, Serialize)]
97pub struct ForecastListResponse {
98    /// Forecasts
99    pub forecasts: Vec<ChangeForecast>,
100    /// Total count
101    pub total: usize,
102}
103
/// Get forecasts
///
/// GET /api/v1/forecasts
///
/// Returns up to 100 unexpired forecasts, newest `predicted_at` first. The optional
/// `workspace_id`, `service_id`, `endpoint`, `method` and `window_days` query parameters
/// are applied as exact-match filters. Without a database pool the response is an empty
/// list; rows that fail to decode are logged and skipped.
///
/// # Errors
///
/// * `400 Bad Request` if `workspace_id` is not a valid UUID or `window_days` does not fit
///   in an `i32`.
/// * `500 Internal Server Error` if the query fails.
#[cfg(feature = "database")]
pub async fn list_forecasts(
    State(state): State<ForecastingState>,
    Query(params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(ForecastListResponse {
                forecasts: Vec::new(),
                total: 0,
            }));
        }
    };

    // Reject malformed filter values up front rather than binding NULL (which would make
    // the filter match nothing) or a wrapped negative integer.
    let workspace_uuid = params
        .workspace_id
        .as_deref()
        .map(Uuid::parse_str)
        .transpose()
        .map_err(|e| {
            tracing::warn!("Invalid workspace_id filter: {}", e);
            StatusCode::BAD_REQUEST
        })?;
    let window_days = params
        .window_days
        .map(i32::try_from)
        .transpose()
        .map_err(|_| StatusCode::BAD_REQUEST)?;

    // `push_bind` emits the `$n` placeholder and binds its value in one step, so the SQL
    // text and the argument list cannot drift out of sync.
    let mut builder = sqlx::QueryBuilder::<sqlx::Postgres>::new(
        "SELECT id, workspace_id, service_id, service_name, endpoint, method,
         forecast_window_days, predicted_change_probability, predicted_break_probability,
         next_expected_change_date, next_expected_break_date, volatility_score, confidence,
         seasonal_patterns, predicted_at, expires_at
         FROM api_change_forecasts WHERE expires_at > NOW()",
    );
    if let Some(workspace_uuid) = workspace_uuid {
        builder.push(" AND workspace_id = ").push_bind(workspace_uuid);
    }
    if let Some(service_id) = params.service_id.as_deref() {
        builder.push(" AND service_id = ").push_bind(service_id);
    }
    if let Some(endpoint) = params.endpoint.as_deref() {
        builder.push(" AND endpoint = ").push_bind(endpoint);
    }
    if let Some(method) = params.method.as_deref() {
        builder.push(" AND method = ").push_bind(method);
    }
    if let Some(window_days) = window_days {
        builder.push(" AND forecast_window_days = ").push_bind(window_days);
    }
    builder.push(" ORDER BY predicted_at DESC LIMIT 100");

    let rows = builder.build().fetch_all(pool).await.map_err(|e| {
        tracing::error!("Failed to query forecasts: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Undecodable rows are skipped so one bad record does not fail the whole listing.
    let mut forecasts = Vec::with_capacity(rows.len());
    for row in &rows {
        match map_row_to_change_forecast(row) {
            Ok(forecast) => forecasts.push(forecast),
            Err(e) => tracing::warn!("Failed to map forecast row: {}", e),
        }
    }

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
205
206/// List forecasts (no database)
207///
208/// GET /api/v1/forecasts
209#[cfg(not(feature = "database"))]
210pub async fn list_forecasts(
211    State(_state): State<ForecastingState>,
212    Query(_params): Query<ListForecastsQuery>,
213) -> Result<Json<ForecastListResponse>, StatusCode> {
214    Ok(Json(ForecastListResponse {
215        forecasts: Vec::new(),
216        total: 0,
217    }))
218}
219
/// Get service-level forecasts
///
/// GET /api/v1/forecasts/service/{service_id}
///
/// Returns up to 50 unexpired forecasts for `service_id`, newest `predicted_at` first.
/// Query-string parameters are accepted but not applied. Without a database pool the
/// response is an empty list; rows that fail to decode are logged and skipped.
///
/// # Errors
///
/// `500 Internal Server Error` if the query fails.
#[cfg(feature = "database")]
pub async fn get_service_forecasts(
    State(state): State<ForecastingState>,
    Path(service_id): Path<String>,
    Query(_params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) else {
        return Ok(Json(ForecastListResponse {
            forecasts: Vec::new(),
            total: 0,
        }));
    };

    let query_result = sqlx::query(
        "SELECT * FROM api_change_forecasts
         WHERE service_id = $1 AND expires_at > NOW()
         ORDER BY predicted_at DESC LIMIT 50",
    )
    .bind(&service_id)
    .fetch_all(pool)
    .await;

    let rows = match query_result {
        Ok(rows) => rows,
        Err(e) => {
            tracing::error!("Failed to query service forecasts: {}", e);
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
    };

    let forecasts: Vec<ChangeForecast> = rows
        .iter()
        .filter_map(|row| match map_row_to_change_forecast(row) {
            Ok(forecast) => Some(forecast),
            Err(e) => {
                tracing::warn!("Failed to map service forecast row: {}", e);
                None
            }
        })
        .collect();

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
268
269/// Get service-level forecasts (no database)
270///
271/// GET /api/v1/forecasts/service/{service_id}
272#[cfg(not(feature = "database"))]
273pub async fn get_service_forecasts(
274    State(_state): State<ForecastingState>,
275    Path(_service_id): Path<String>,
276    Query(_params): Query<ListForecastsQuery>,
277) -> Result<Json<ForecastListResponse>, StatusCode> {
278    Ok(Json(ForecastListResponse {
279        forecasts: Vec::new(),
280        total: 0,
281    }))
282}
283
/// Get endpoint-level forecasts
///
/// GET /api/v1/forecasts/endpoint/{endpoint}
///
/// Returns up to 50 unexpired forecasts for `endpoint`, newest `predicted_at` first.
/// When the `method` query parameter is given, only forecasts for exactly that method are
/// returned; otherwise every method matches. Other query parameters are accepted but
/// ignored. Without a database pool the response is an empty list; rows that fail to
/// decode are logged and skipped.
///
/// # Errors
///
/// `500 Internal Server Error` if the query fails.
#[cfg(feature = "database")]
pub async fn get_endpoint_forecasts(
    State(state): State<ForecastingState>,
    Path(endpoint): Path<String>,
    Query(params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(ForecastListResponse {
                forecasts: Vec::new(),
                total: 0,
            }));
        }
    };

    // Compare `method` with `=` rather than `LIKE` so `%` / `_` in user input are not
    // treated as wildcards; binding NULL (no `method` parameter) disables the filter.
    let rows = sqlx::query(
        "SELECT * FROM api_change_forecasts
         WHERE endpoint = $1 AND ($2::text IS NULL OR method = $2) AND expires_at > NOW()
         ORDER BY predicted_at DESC LIMIT 50",
    )
    .bind(&endpoint)
    .bind(params.method.as_deref())
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query endpoint forecasts: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let mut forecasts = Vec::with_capacity(rows.len());
    for row in &rows {
        match map_row_to_change_forecast(row) {
            Ok(forecast) => forecasts.push(forecast),
            Err(e) => tracing::warn!("Failed to map endpoint forecast row: {}", e),
        }
    }

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
334
335/// Get endpoint-level forecasts (no database)
336///
337/// GET /api/v1/forecasts/endpoint/{endpoint}
338#[cfg(not(feature = "database"))]
339pub async fn get_endpoint_forecasts(
340    State(_state): State<ForecastingState>,
341    Path(_endpoint): Path<String>,
342    Query(_params): Query<ListForecastsQuery>,
343) -> Result<Json<ForecastListResponse>, StatusCode> {
344    Ok(Json(ForecastListResponse {
345        forecasts: Vec::new(),
346        total: 0,
347    }))
348}
349
350/// Request to refresh forecasts
351#[derive(Debug, Deserialize)]
352pub struct RefreshForecastsRequest {
353    /// Workspace ID
354    pub workspace_id: Option<String>,
355    /// Service ID
356    pub service_id: Option<String>,
357    /// Endpoint (optional)
358    pub endpoint: Option<String>,
359    /// Method (optional)
360    pub method: Option<String>,
361}
362
/// Refresh forecasts
///
/// POST /api/v1/forecasts/refresh
///
/// Loads rows from `drift_incidents`, optionally narrowed by the request's `workspace_id`,
/// `endpoint` and `method`, and groups the decodable incidents by `(endpoint, method)`. It
/// asks the forecaster for a 30-day forecast per group and upserts each forecast produced
/// via [`store_forecast`]. Storage failures are logged and do not abort the refresh.
/// `service_id` is not used for filtering, because the incident query has no such column.
///
/// Without a database pool the response is
/// `{"success": false, "error": "Database not available"}`.
///
/// # Errors
///
/// * `400 Bad Request` if `workspace_id` is present but is not a valid UUID.
/// * `500 Internal Server Error` if the incident query fails or a row's `id` cannot be
///   decoded.
#[cfg(feature = "database")]
pub async fn refresh_forecasts(
    State(state): State<ForecastingState>,
    Json(request): Json<RefreshForecastsRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use mockforge_core::incidents::types::DriftIncident;
    use sqlx::Row;
    use std::collections::HashMap;

    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(serde_json::json!({
                "success": false,
                "error": "Database not available"
            })));
        }
    };

    // `drift_incidents.workspace_id` is decoded as a UUID, so validate the filter up front.
    let workspace_uuid = request
        .workspace_id
        .as_deref()
        .map(Uuid::parse_str)
        .transpose()
        .map_err(|e| {
            tracing::warn!("Invalid workspace_id in forecast refresh request: {}", e);
            StatusCode::BAD_REQUEST
        })?;

    // `push_bind` writes each placeholder together with its value, so every filter appended
    // here is guaranteed a matching bound argument.
    let mut incident_query = sqlx::QueryBuilder::<sqlx::Postgres>::new(
        "SELECT id, workspace_id, endpoint, method, incident_type, severity, status,
         detected_at, details, created_at, updated_at
         FROM drift_incidents WHERE 1=1",
    );
    if let Some(workspace_uuid) = workspace_uuid {
        incident_query.push(" AND workspace_id = ").push_bind(workspace_uuid);
    }
    if let Some(endpoint) = request.endpoint.as_deref() {
        incident_query.push(" AND endpoint = ").push_bind(endpoint);
    }
    if let Some(method) = request.method.as_deref() {
        incident_query.push(" AND method = ").push_bind(method);
    }

    let rows = incident_query.build().fetch_all(pool).await.map_err(|e| {
        tracing::error!("Failed to query drift incidents: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Group decodable incidents per (endpoint, method). A row without a readable `id`
    // fails the request; any other undecodable or unrecognised row is skipped.
    let mut endpoint_groups: HashMap<(String, String), Vec<DriftIncident>> = HashMap::new();
    for row in rows {
        let id: Uuid = row.try_get("id").map_err(|e| {
            tracing::error!("Failed to get id from row: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
        if let Some(incident) = drift_incident_from_row(&row, id) {
            endpoint_groups
                .entry((incident.endpoint.clone(), incident.method.clone()))
                .or_default()
                .push(incident);
        }
    }

    let mut forecasts_generated = 0;
    for ((endpoint, method), group_incidents) in endpoint_groups {
        if let Some(forecast) = state.forecaster.generate_forecast(
            &group_incidents,
            request.workspace_id.clone(),
            None, // service_id
            None, // service_name
            endpoint,
            method,
            30, // forecast_window_days
        ) {
            // Persistence is best-effort: a failed upsert is logged but still counted.
            if let Err(e) = store_forecast(pool, &forecast, request.workspace_id.as_deref()).await
            {
                tracing::warn!("Failed to store forecast: {}", e);
            }
            forecasts_generated += 1;
        }
    }

    Ok(Json(serde_json::json!({
        "success": true,
        "message": "Forecasts refreshed",
        "forecasts_generated": forecasts_generated
    })))
}

/// Decode the non-key columns of a `drift_incidents` row into a `DriftIncident` with `id`.
///
/// Returns `None` (so the caller skips the row) when a required column cannot be decoded
/// or when `incident_type`, `severity` or `status` holds an unrecognised value. A missing
/// or NULL `workspace_id` becomes `None`, and an undecodable `details` becomes JSON `null`.
#[cfg(feature = "database")]
fn drift_incident_from_row(
    row: &sqlx::postgres::PgRow,
    id: Uuid,
) -> Option<mockforge_core::incidents::types::DriftIncident> {
    use mockforge_core::incidents::types::{
        DriftIncident, IncidentSeverity, IncidentStatus, IncidentType,
    };
    use sqlx::Row;

    let workspace_id: Option<Uuid> = row.try_get("workspace_id").ok();
    let endpoint: String = row.try_get("endpoint").ok()?;
    let method: String = row.try_get("method").ok()?;
    let incident_type_str: String = row.try_get("incident_type").ok()?;
    let severity_str: String = row.try_get("severity").ok()?;
    let status_str: String = row.try_get("status").ok()?;
    let detected_at: DateTime<Utc> = row.try_get("detected_at").ok()?;
    let details_json: serde_json::Value = row.try_get("details").unwrap_or_default();
    let created_at: DateTime<Utc> = row.try_get("created_at").ok()?;
    let updated_at: DateTime<Utc> = row.try_get("updated_at").ok()?;

    let incident_type = match incident_type_str.as_str() {
        "breaking_change" => IncidentType::BreakingChange,
        "threshold_exceeded" => IncidentType::ThresholdExceeded,
        _ => return None,
    };
    let severity = match severity_str.as_str() {
        "low" => IncidentSeverity::Low,
        "medium" => IncidentSeverity::Medium,
        "high" => IncidentSeverity::High,
        "critical" => IncidentSeverity::Critical,
        _ => return None,
    };
    let status = match status_str.as_str() {
        "open" => IncidentStatus::Open,
        "acknowledged" => IncidentStatus::Acknowledged,
        "resolved" => IncidentStatus::Resolved,
        "closed" => IncidentStatus::Closed,
        _ => return None,
    };

    Some(DriftIncident {
        id: id.to_string(),
        budget_id: None,
        workspace_id: workspace_id.map(|u| u.to_string()),
        endpoint,
        method,
        incident_type,
        severity,
        status,
        detected_at: detected_at.timestamp(),
        resolved_at: None,
        details: details_json,
        external_ticket_id: None,
        external_ticket_url: None,
        created_at: created_at.timestamp(),
        updated_at: updated_at.timestamp(),
        sync_cycle_id: None,
        contract_diff_id: None,
        before_sample: None,
        after_sample: None,
        fitness_test_results: Vec::new(),
        affected_consumers: None,
        protocol: None,
    })
}
527
528/// Refresh forecasts (no database)
529///
530/// POST /api/v1/forecasts/refresh
531#[cfg(not(feature = "database"))]
532pub async fn refresh_forecasts(
533    State(_state): State<ForecastingState>,
534    Json(_request): Json<RefreshForecastsRequest>,
535) -> Result<Json<serde_json::Value>, StatusCode> {
536    Ok(Json(serde_json::json!({
537        "success": false,
538        "error": "Database not available"
539    })))
540}
541
/// Upsert a forecast into `api_change_forecasts`.
///
/// A fresh UUID is generated for the row id. If a row with the same
/// `(workspace_id, service_id, endpoint, method, forecast_window_days)` already exists, the
/// following columns are overwritten instead of inserting a new row: the prediction
/// columns, `seasonal_patterns`, `predicted_at`, `expires_at` and `updated_at`.
///
/// `workspace_id` is parsed as a UUID; a value that fails to parse is stored as NULL.
///
/// NOTE(review): PostgreSQL treats NULLs as distinct in unique constraints by default. Rows
/// with a NULL `workspace_id` or `service_id` may therefore never take the `ON CONFLICT`
/// branch and instead accumulate. Confirm how the constraint is defined
/// (e.g. `NULLS NOT DISTINCT`).
///
/// # Errors
///
/// Propagates any `sqlx::Error` raised while executing the statement.
#[cfg(feature = "database")]
pub async fn store_forecast(
    pool: &sqlx::PgPool,
    forecast: &ChangeForecast,
    workspace_id: Option<&str>,
) -> Result<(), sqlx::Error> {
    let row_id = Uuid::new_v4();
    let workspace_uuid = workspace_id.map(Uuid::parse_str).and_then(Result::ok);
    let window_days = forecast.forecast_window_days as i32;
    // A serialization failure falls back to JSON `null` rather than aborting the upsert.
    let seasonal_patterns_json =
        serde_json::to_value(&forecast.seasonal_patterns).unwrap_or_default();

    sqlx::query(
        r#"
        INSERT INTO api_change_forecasts (
            id, workspace_id, service_id, service_name, endpoint, method,
            forecast_window_days, predicted_change_probability, predicted_break_probability,
            next_expected_change_date, next_expected_break_date, volatility_score, confidence,
            seasonal_patterns, predicted_at, expires_at
        ) VALUES (
            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16
        )
        ON CONFLICT (workspace_id, service_id, endpoint, method, forecast_window_days)
        DO UPDATE SET
            predicted_change_probability = EXCLUDED.predicted_change_probability,
            predicted_break_probability = EXCLUDED.predicted_break_probability,
            next_expected_change_date = EXCLUDED.next_expected_change_date,
            next_expected_break_date = EXCLUDED.next_expected_break_date,
            volatility_score = EXCLUDED.volatility_score,
            confidence = EXCLUDED.confidence,
            seasonal_patterns = EXCLUDED.seasonal_patterns,
            predicted_at = EXCLUDED.predicted_at,
            expires_at = EXCLUDED.expires_at,
            updated_at = NOW()
        "#,
    )
    .bind(row_id)
    .bind(workspace_uuid)
    .bind(forecast.service_id.as_deref())
    .bind(forecast.service_name.as_deref())
    .bind(&forecast.endpoint)
    .bind(&forecast.method)
    .bind(window_days)
    .bind(forecast.predicted_change_probability)
    .bind(forecast.predicted_break_probability)
    .bind(forecast.next_expected_change_date)
    .bind(forecast.next_expected_break_date)
    .bind(forecast.volatility_score)
    .bind(forecast.confidence)
    .bind(seasonal_patterns_json)
    .bind(forecast.predicted_at)
    .bind(forecast.expires_at)
    .execute(pool)
    .await
    .map(|_| ())
}
597
598/// Create router for forecasting endpoints
599pub fn forecasting_router(state: ForecastingState) -> axum::Router {
600    use axum::routing::{get, post};
601    use axum::Router;
602
603    Router::new()
604        .route("/api/v1/forecasts", get(list_forecasts))
605        .route("/api/v1/forecasts/service/{service_id}", get(get_service_forecasts))
606        .route("/api/v1/forecasts/endpoint/{endpoint}", get(get_endpoint_forecasts))
607        .route("/api/v1/forecasts/refresh", post(refresh_forecasts))
608        .with_state(state)
609}