1use axum::{
6 extract::{Path, Query, State},
7 http::StatusCode,
8 response::Json,
9};
10use chrono::{DateTime, Utc};
11use mockforge_core::contract_drift::forecasting::types::SeasonalPattern;
12use mockforge_core::contract_drift::forecasting::{ChangeForecast, Forecaster};
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15use uuid::Uuid;
16
17use crate::database::Database;
18
/// Converts one `api_change_forecasts` row into a [`ChangeForecast`].
///
/// Every column except `seasonal_patterns` is required. `seasonal_patterns`
/// is read best-effort: if the column cannot be read, is not a JSON array, or
/// does not deserialize into `Vec<SeasonalPattern>`, an empty list is used.
///
/// # Errors
///
/// Returns the `sqlx::Error` produced by `try_get` when a required column is
/// missing or cannot be decoded, and `sqlx::Error::ColumnDecode` when
/// `forecast_window_days` holds a negative value (which cannot be represented
/// as the `u32` that `ChangeForecast` expects).
#[cfg(feature = "database")]
fn map_row_to_change_forecast(row: &sqlx::postgres::PgRow) -> Result<ChangeForecast, sqlx::Error> {
    use sqlx::Row;

    let service_id: Option<String> = row.try_get("service_id")?;
    let service_name: Option<String> = row.try_get("service_name")?;
    let endpoint: String = row.try_get("endpoint")?;
    let method: String = row.try_get("method")?;

    // The column is read as a signed integer; reject negative values instead
    // of letting an `as u32` cast wrap them into an enormous window.
    let raw_window_days: i32 = row.try_get("forecast_window_days")?;
    let forecast_window_days =
        u32::try_from(raw_window_days).map_err(|e| sqlx::Error::ColumnDecode {
            index: "forecast_window_days".to_string(),
            source: Box::new(e),
        })?;

    let predicted_change_probability: f64 = row.try_get("predicted_change_probability")?;
    let predicted_break_probability: f64 = row.try_get("predicted_break_probability")?;
    let next_expected_change_date: Option<DateTime<Utc>> =
        row.try_get("next_expected_change_date")?;
    let next_expected_break_date: Option<DateTime<Utc>> =
        row.try_get("next_expected_break_date")?;
    let volatility_score: f64 = row.try_get("volatility_score")?;
    let confidence: f64 = row.try_get("confidence")?;
    let predicted_at: DateTime<Utc> = row.try_get("predicted_at")?;
    let expires_at: DateTime<Utc> = row.try_get("expires_at")?;

    // Best-effort: any problem with the patterns column degrades to "no patterns"
    // rather than failing the whole row.
    let seasonal_patterns: Vec<SeasonalPattern> = row
        .try_get::<serde_json::Value, _>("seasonal_patterns")
        .ok()
        .filter(serde_json::Value::is_array)
        .and_then(|json| serde_json::from_value(json).ok())
        .unwrap_or_default();

    Ok(ChangeForecast {
        service_id,
        service_name,
        endpoint,
        method,
        forecast_window_days,
        predicted_change_probability,
        predicted_break_probability,
        next_expected_change_date,
        next_expected_break_date,
        volatility_score,
        confidence,
        seasonal_patterns,
        predicted_at,
        expires_at,
    })
}
66
/// Shared state handed to the forecasting route handlers.
#[derive(Clone)]
pub struct ForecastingState {
    /// Forecaster used by `refresh_forecasts` to generate forecasts from drift incidents.
    pub forecaster: Arc<Forecaster>,
    /// Optional database handle; when it is `None` (or exposes no pool) the
    /// handlers respond without querying anything.
    pub database: Option<Database>,
}
75
/// Query-string parameters accepted by the forecast listing handlers.
///
/// All fields are optional. `list_forecasts` uses each provided field as an
/// equality filter; `get_endpoint_forecasts` reads only `method`; and
/// `get_service_forecasts` accepts the parameters but does not read them.
#[derive(Debug, Deserialize)]
pub struct ListForecastsQuery {
    /// Workspace filter; `list_forecasts` parses it as a UUID.
    pub workspace_id: Option<String>,
    /// Service identifier filter.
    pub service_id: Option<String>,
    /// Endpoint filter.
    pub endpoint: Option<String>,
    /// HTTP method filter.
    pub method: Option<String>,
    /// Filter compared against the `forecast_window_days` column.
    pub window_days: Option<u32>,
}
90
/// JSON body returned by the forecast listing handlers.
#[derive(Debug, Serialize)]
pub struct ForecastListResponse {
    /// Forecasts included in this response.
    pub forecasts: Vec<ChangeForecast>,
    /// Count reported by the handler for this response. Handlers derive it from
    /// the rows fetched under their query `LIMIT`, so it is not a global count
    /// of all matching forecasts.
    pub total: usize,
}
99
/// Lists unexpired forecasts, optionally filtered by the query parameters.
///
/// Each provided field of [`ListForecastsQuery`] becomes an equality filter on
/// the column of the same meaning. Results are ordered by `predicted_at`
/// (newest first) and capped at 100 rows. Rows that cannot be mapped into a
/// [`ChangeForecast`] are logged and skipped.
///
/// An empty list is returned (without querying) when:
/// - no database pool is available,
/// - `workspace_id` is not a valid UUID, or
/// - `window_days` does not fit in an `i32`,
/// since in the latter two cases no stored row could match the filter.
///
/// # Errors
///
/// Returns `500 Internal Server Error` when the database query fails.
#[cfg(feature = "database")]
pub async fn list_forecasts(
    State(state): State<ForecastingState>,
    Query(params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let empty_response = || {
        Json(ForecastListResponse {
            forecasts: Vec::new(),
            total: 0,
        })
    };

    let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) else {
        return Ok(empty_response());
    };

    // Validate user-supplied values up front. Previously an invalid UUID was
    // bound as NULL (so `workspace_id = NULL` matched nothing) and an oversized
    // window wrapped to a negative number through `as i32`.
    let workspace_uuid = match params.workspace_id.as_deref().map(Uuid::parse_str) {
        Some(Ok(uuid)) => Some(uuid),
        Some(Err(e)) => {
            tracing::debug!("Invalid workspace_id filter, returning no forecasts: {}", e);
            return Ok(empty_response());
        }
        None => None,
    };
    let window_days = match params.window_days.map(i32::try_from) {
        Some(Ok(days)) => Some(days),
        Some(Err(e)) => {
            tracing::debug!("window_days filter out of range, returning no forecasts: {}", e);
            return Ok(empty_response());
        }
        None => None,
    };

    let mut query = String::from(
        "SELECT id, workspace_id, service_id, service_name, endpoint, method,
         forecast_window_days, predicted_change_probability, predicted_break_probability,
         next_expected_change_date, next_expected_break_date, volatility_score, confidence,
         seasonal_patterns, predicted_at, expires_at
         FROM api_change_forecasts WHERE expires_at > NOW()",
    );

    // Column/flag pairs in the exact order the values are bound below; the
    // placeholder number is the position among the filters that are present.
    let filters = [
        ("workspace_id", workspace_uuid.is_some()),
        ("service_id", params.service_id.is_some()),
        ("endpoint", params.endpoint.is_some()),
        ("method", params.method.is_some()),
        ("forecast_window_days", window_days.is_some()),
    ];
    let mut bind_index = 0;
    for (column, is_set) in filters {
        if is_set {
            bind_index += 1;
            query.push_str(&format!(" AND {} = ${}", column, bind_index));
        }
    }

    query.push_str(" ORDER BY predicted_at DESC LIMIT 100");

    // Bind values in the same order as `filters` above.
    let mut query_builder = sqlx::query(&query);
    if let Some(uuid) = workspace_uuid {
        query_builder = query_builder.bind(uuid);
    }
    if let Some(svc_id) = &params.service_id {
        query_builder = query_builder.bind(svc_id);
    }
    if let Some(ep) = &params.endpoint {
        query_builder = query_builder.bind(ep);
    }
    if let Some(m) = &params.method {
        query_builder = query_builder.bind(m);
    }
    if let Some(days) = window_days {
        query_builder = query_builder.bind(days);
    }

    let rows = query_builder.fetch_all(pool).await.map_err(|e| {
        tracing::error!("Failed to query forecasts: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let forecasts: Vec<ChangeForecast> = rows
        .iter()
        .filter_map(|row| match map_row_to_change_forecast(row) {
            Ok(forecast) => Some(forecast),
            Err(e) => {
                tracing::warn!("Failed to map forecast row: {}", e);
                None
            }
        })
        .collect();

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
201
202#[cfg(not(feature = "database"))]
206pub async fn list_forecasts(
207 State(_state): State<ForecastingState>,
208 Query(_params): Query<ListForecastsQuery>,
209) -> Result<Json<ForecastListResponse>, StatusCode> {
210 Ok(Json(ForecastListResponse {
211 forecasts: Vec::new(),
212 total: 0,
213 }))
214}
215
/// Returns the unexpired forecasts stored for one service.
///
/// Looks up rows whose `service_id` equals the path parameter, newest
/// `predicted_at` first, capped at 50 rows. Rows that cannot be mapped into a
/// [`ChangeForecast`] are logged and skipped; `total` is the number of
/// forecasts actually returned. The query-string parameters are accepted but
/// not used.
///
/// Returns an empty list when no database pool is available.
///
/// # Errors
///
/// Returns `500 Internal Server Error` when the database query fails.
#[cfg(feature = "database")]
pub async fn get_service_forecasts(
    State(state): State<ForecastingState>,
    Path(service_id): Path<String>,
    Query(_params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(ForecastListResponse {
                forecasts: Vec::new(),
                total: 0,
            }));
        }
    };

    let rows = sqlx::query(
        "SELECT * FROM api_change_forecasts
         WHERE service_id = $1 AND expires_at > NOW()
         ORDER BY predicted_at DESC LIMIT 50",
    )
    .bind(&service_id)
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query service forecasts: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Previously the fetched rows were discarded and an empty list was returned
    // alongside a non-zero `total`; map them the same way `list_forecasts` does.
    let forecasts: Vec<ChangeForecast> = rows
        .iter()
        .filter_map(|row| match map_row_to_change_forecast(row) {
            Ok(forecast) => Some(forecast),
            Err(e) => {
                tracing::warn!("Failed to map forecast row: {}", e);
                None
            }
        })
        .collect();

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
255
256#[cfg(not(feature = "database"))]
260pub async fn get_service_forecasts(
261 State(_state): State<ForecastingState>,
262 Path(_service_id): Path<String>,
263 Query(_params): Query<ListForecastsQuery>,
264) -> Result<Json<ForecastListResponse>, StatusCode> {
265 Ok(Json(ForecastListResponse {
266 forecasts: Vec::new(),
267 total: 0,
268 }))
269}
270
/// Returns the unexpired forecasts stored for one endpoint.
///
/// Looks up rows whose `endpoint` equals the path parameter, optionally
/// narrowed by the `method` query parameter, newest `predicted_at` first,
/// capped at 50 rows. Rows that cannot be mapped into a [`ChangeForecast`]
/// are logged and skipped; `total` is the number of forecasts actually
/// returned.
///
/// Returns an empty list when no database pool is available.
///
/// # Errors
///
/// Returns `500 Internal Server Error` when the database query fails.
#[cfg(feature = "database")]
pub async fn get_endpoint_forecasts(
    State(state): State<ForecastingState>,
    Path(endpoint): Path<String>,
    Query(params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(ForecastListResponse {
                forecasts: Vec::new(),
                total: 0,
            }));
        }
    };

    // No method given => match every method via the `%` LIKE pattern.
    // NOTE(review): a caller-supplied `method` is also used as a LIKE pattern,
    // so `%` and `_` inside it act as wildcards — confirm this is intended.
    let method = params.method.as_deref().unwrap_or("%");

    let rows = sqlx::query(
        "SELECT * FROM api_change_forecasts
         WHERE endpoint = $1 AND method LIKE $2 AND expires_at > NOW()
         ORDER BY predicted_at DESC LIMIT 50",
    )
    .bind(&endpoint)
    .bind(method)
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query endpoint forecasts: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Previously the fetched rows were discarded and an empty list was returned
    // alongside a non-zero `total`; map them the same way `list_forecasts` does.
    let forecasts: Vec<ChangeForecast> = rows
        .iter()
        .filter_map(|row| match map_row_to_change_forecast(row) {
            Ok(forecast) => Some(forecast),
            Err(e) => {
                tracing::warn!("Failed to map forecast row: {}", e);
                None
            }
        })
        .collect();

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
311
312#[cfg(not(feature = "database"))]
316pub async fn get_endpoint_forecasts(
317 State(_state): State<ForecastingState>,
318 Path(_endpoint): Path<String>,
319 Query(_params): Query<ListForecastsQuery>,
320) -> Result<Json<ForecastListResponse>, StatusCode> {
321 Ok(Json(ForecastListResponse {
322 forecasts: Vec::new(),
323 total: 0,
324 }))
325}
326
/// JSON body accepted by the forecast refresh handler.
#[derive(Debug, Deserialize)]
pub struct RefreshForecastsRequest {
    /// Workspace whose drift incidents should be considered; also passed on to
    /// the forecaster.
    pub workspace_id: Option<String>,
    /// Service identifier. NOTE(review): not read by `refresh_forecasts` in this file.
    pub service_id: Option<String>,
    /// Endpoint. NOTE(review): not read by `refresh_forecasts` in this file.
    pub endpoint: Option<String>,
    /// HTTP method. NOTE(review): not read by `refresh_forecasts` in this file.
    pub method: Option<String>,
}
339
/// Decodes one `drift_incidents` row into a `DriftIncident`.
///
/// Returns `Ok(None)` when the row should be skipped: a required column other
/// than `id` cannot be read, or `incident_type` / `severity` / `status` holds
/// a value this code does not recognize. `workspace_id` and `details` are
/// optional and fall back to `None` / the default JSON value.
///
/// # Errors
///
/// Returns the `sqlx::Error` from reading the `id` column.
#[cfg(feature = "database")]
fn drift_incident_from_row(
    row: &sqlx::postgres::PgRow,
) -> Result<Option<mockforge_core::incidents::types::DriftIncident>, sqlx::Error> {
    use mockforge_core::incidents::types::{
        DriftIncident, IncidentSeverity, IncidentStatus, IncidentType,
    };
    use sqlx::Row;

    let id: Uuid = row.try_get("id")?;
    let workspace_id = row.try_get::<Uuid, _>("workspace_id").ok();

    let Ok(endpoint) = row.try_get::<String, _>("endpoint") else {
        return Ok(None);
    };
    let Ok(method) = row.try_get::<String, _>("method") else {
        return Ok(None);
    };
    let Ok(incident_type_str) = row.try_get::<String, _>("incident_type") else {
        return Ok(None);
    };
    let Ok(severity_str) = row.try_get::<String, _>("severity") else {
        return Ok(None);
    };
    let Ok(status_str) = row.try_get::<String, _>("status") else {
        return Ok(None);
    };
    let Ok(detected_at) = row.try_get::<DateTime<Utc>, _>("detected_at") else {
        return Ok(None);
    };
    let details: serde_json::Value = row.try_get("details").unwrap_or_default();
    let Ok(created_at) = row.try_get::<DateTime<Utc>, _>("created_at") else {
        return Ok(None);
    };
    let Ok(updated_at) = row.try_get::<DateTime<Utc>, _>("updated_at") else {
        return Ok(None);
    };

    let incident_type = match incident_type_str.as_str() {
        "breaking_change" => IncidentType::BreakingChange,
        "threshold_exceeded" => IncidentType::ThresholdExceeded,
        _ => return Ok(None),
    };

    let severity = match severity_str.as_str() {
        "low" => IncidentSeverity::Low,
        "medium" => IncidentSeverity::Medium,
        "high" => IncidentSeverity::High,
        "critical" => IncidentSeverity::Critical,
        _ => return Ok(None),
    };

    let status = match status_str.as_str() {
        "open" => IncidentStatus::Open,
        "acknowledged" => IncidentStatus::Acknowledged,
        "resolved" => IncidentStatus::Resolved,
        "closed" => IncidentStatus::Closed,
        _ => return Ok(None),
    };

    Ok(Some(DriftIncident {
        id: id.to_string(),
        budget_id: None,
        workspace_id: workspace_id.map(|u| u.to_string()),
        endpoint,
        method,
        incident_type,
        severity,
        status,
        detected_at: detected_at.timestamp(),
        resolved_at: None,
        details,
        external_ticket_id: None,
        external_ticket_url: None,
        created_at: created_at.timestamp(),
        updated_at: updated_at.timestamp(),
        sync_cycle_id: None,
        contract_diff_id: None,
        before_sample: None,
        after_sample: None,
        fitness_test_results: Vec::new(),
        affected_consumers: None,
        protocol: None,
    }))
}

/// Regenerates forecasts from the stored drift incidents.
///
/// Loads drift incidents (restricted to `workspace_id` when the request
/// provides one), groups them by `(endpoint, method)`, asks the forecaster for
/// a 30-day forecast per group, and persists each generated forecast with
/// `store_forecast`. Persistence is best-effort: a failed write is logged and
/// does not abort the refresh.
///
/// The JSON response reports `forecasts_generated` and `forecasts_stored`.
/// When no database pool is available the response is
/// `{"success": false, "error": "Database not available"}`.
///
/// # Errors
///
/// - `400 Bad Request` when `workspace_id` is present but not a valid UUID.
/// - `500 Internal Server Error` when the incident query fails or a row's
///   `id` column cannot be read.
#[cfg(feature = "database")]
pub async fn refresh_forecasts(
    State(state): State<ForecastingState>,
    Json(request): Json<RefreshForecastsRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use mockforge_core::incidents::types::DriftIncident;
    use std::collections::HashMap;

    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(serde_json::json!({
                "success": false,
                "error": "Database not available"
            })));
        }
    };

    let workspace_uuid = match request.workspace_id.as_deref() {
        Some(raw) => Some(Uuid::parse_str(raw).map_err(|e| {
            tracing::warn!("Invalid workspace_id in refresh request: {}", e);
            StatusCode::BAD_REQUEST
        })?),
        None => None,
    };

    let mut incident_query = String::from(
        "SELECT id, workspace_id, endpoint, method, incident_type, severity, status,
         detected_at, details, created_at, updated_at
         FROM drift_incidents WHERE 1=1",
    );
    if workspace_uuid.is_some() {
        incident_query.push_str(" AND workspace_id = $1");
    }

    // The `$1` placeholder must actually be bound; leaving it unbound made
    // every workspace-scoped refresh fail at execution time.
    let mut incident_lookup = sqlx::query(&incident_query);
    if let Some(uuid) = workspace_uuid {
        incident_lookup = incident_lookup.bind(uuid);
    }

    let rows = incident_lookup.fetch_all(pool).await.map_err(|e| {
        tracing::error!("Failed to query drift incidents: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Group decodable incidents by the endpoint/method pair they belong to.
    let mut endpoint_groups: HashMap<(String, String), Vec<DriftIncident>> = HashMap::new();
    for row in rows {
        let decoded = drift_incident_from_row(&row).map_err(|e| {
            tracing::error!("Failed to get id from row: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
        let Some(incident) = decoded else {
            continue;
        };
        endpoint_groups
            .entry((incident.endpoint.clone(), incident.method.clone()))
            .or_default()
            .push(incident);
    }

    let mut forecasts_generated = 0_u32;
    let mut forecasts_stored = 0_u32;

    for ((endpoint, method), group_incidents) in endpoint_groups {
        let generated = state.forecaster.generate_forecast(
            &group_incidents,
            request.workspace_id.clone(),
            None, // service_id
            None, // service_name
            endpoint,
            method,
            30, // forecast window in days
        );
        let Some(forecast) = generated else {
            continue;
        };
        forecasts_generated += 1;

        // Persist so the listing handlers can see the refreshed forecast.
        match store_forecast(pool, &forecast, request.workspace_id.as_deref()).await {
            Ok(()) => forecasts_stored += 1,
            Err(e) => tracing::warn!("Failed to store generated forecast: {}", e),
        }
    }

    Ok(Json(serde_json::json!({
        "success": true,
        "message": "Forecasts refreshed",
        "forecasts_generated": forecasts_generated,
        "forecasts_stored": forecasts_stored
    })))
}
500
501#[cfg(not(feature = "database"))]
505pub async fn refresh_forecasts(
506 State(_state): State<ForecastingState>,
507 Json(_request): Json<RefreshForecastsRequest>,
508) -> Result<Json<serde_json::Value>, StatusCode> {
509 Ok(Json(serde_json::json!({
510 "success": false,
511 "error": "Database not available"
512 })))
513}
514
/// Inserts a forecast into `api_change_forecasts`, or updates the existing row
/// on a conflict over
/// `(workspace_id, service_id, endpoint, method, forecast_window_days)`.
///
/// A fresh random UUID is used as the row id. `workspace_id` is parsed as a
/// UUID; a missing or unparsable value is bound as NULL. The `u32` window is
/// converted with an `as i32` cast, and the seasonal patterns are serialized
/// to JSON (falling back to the default JSON value if serialization fails).
///
/// # Errors
///
/// Returns the `sqlx::Error` raised while executing the statement.
#[cfg(feature = "database")]
pub async fn store_forecast(
    pool: &sqlx::PgPool,
    forecast: &ChangeForecast,
    workspace_id: Option<&str>,
) -> Result<(), sqlx::Error> {
    const UPSERT_SQL: &str = r#"
        INSERT INTO api_change_forecasts (
            id, workspace_id, service_id, service_name, endpoint, method,
            forecast_window_days, predicted_change_probability, predicted_break_probability,
            next_expected_change_date, next_expected_break_date, volatility_score, confidence,
            seasonal_patterns, predicted_at, expires_at
        ) VALUES (
            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16
        )
        ON CONFLICT (workspace_id, service_id, endpoint, method, forecast_window_days)
        DO UPDATE SET
            predicted_change_probability = EXCLUDED.predicted_change_probability,
            predicted_break_probability = EXCLUDED.predicted_break_probability,
            next_expected_change_date = EXCLUDED.next_expected_change_date,
            next_expected_break_date = EXCLUDED.next_expected_break_date,
            volatility_score = EXCLUDED.volatility_score,
            confidence = EXCLUDED.confidence,
            seasonal_patterns = EXCLUDED.seasonal_patterns,
            predicted_at = EXCLUDED.predicted_at,
            expires_at = EXCLUDED.expires_at,
            updated_at = NOW()
        "#;

    // Values that need conversion before they can be bound.
    let row_id = Uuid::new_v4();
    let workspace_uuid = match workspace_id {
        Some(raw) => Uuid::parse_str(raw).ok(),
        None => None,
    };
    let window_days = forecast.forecast_window_days as i32;
    let seasonal_patterns = serde_json::to_value(&forecast.seasonal_patterns).unwrap_or_default();

    // Bind order mirrors the column list in `UPSERT_SQL` ($1..$16).
    let upsert = sqlx::query(UPSERT_SQL)
        .bind(row_id)
        .bind(workspace_uuid)
        .bind(forecast.service_id.as_deref())
        .bind(forecast.service_name.as_deref())
        .bind(&forecast.endpoint)
        .bind(&forecast.method)
        .bind(window_days)
        .bind(forecast.predicted_change_probability)
        .bind(forecast.predicted_break_probability)
        .bind(forecast.next_expected_change_date)
        .bind(forecast.next_expected_break_date)
        .bind(forecast.volatility_score)
        .bind(forecast.confidence)
        .bind(seasonal_patterns)
        .bind(forecast.predicted_at)
        .bind(forecast.expires_at);

    upsert.execute(pool).await.map(|_| ())
}
570
571pub fn forecasting_router(state: ForecastingState) -> axum::Router {
573 use axum::routing::{get, post};
574 use axum::Router;
575
576 Router::new()
577 .route("/api/v1/forecasts", get(list_forecasts))
578 .route("/api/v1/forecasts/service/{service_id}", get(get_service_forecasts))
579 .route("/api/v1/forecasts/endpoint/{endpoint}", get(get_endpoint_forecasts))
580 .route("/api/v1/forecasts/refresh", post(refresh_forecasts))
581 .with_state(state)
582}