1use axum::{
6 extract::{Path, Query, State},
7 http::StatusCode,
8 response::Json,
9};
10use chrono::{DateTime, Utc};
11use mockforge_core::contract_drift::forecasting::{ChangeForecast, Forecaster};
12use mockforge_core::contract_drift::forecasting::types::SeasonalPattern;
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15use uuid::Uuid;
16
17use crate::database::Database;
18
/// Decode one row of `api_change_forecasts` into a [`ChangeForecast`].
///
/// Only the columns read below are required; any extra columns (e.g. `id`,
/// `workspace_id`) are ignored, so both explicit column lists and `SELECT *` work.
///
/// `seasonal_patterns` is best-effort. A missing/NULL or non-array value yields an
/// empty list. A malformed array is logged and also yields an empty list.
///
/// # Errors
///
/// * The underlying `sqlx::Error` when a required column is missing or has an
///   unexpected type.
/// * `sqlx::Error::ColumnDecode` when `forecast_window_days` is negative, since it
///   cannot be represented as the `u32` stored in [`ChangeForecast`].
#[cfg(feature = "database")]
fn map_row_to_change_forecast(row: &sqlx::postgres::PgRow) -> Result<ChangeForecast, sqlx::Error> {
    use sqlx::Row;

    let service_id: Option<String> = row.try_get("service_id")?;
    let service_name: Option<String> = row.try_get("service_name")?;
    let endpoint: String = row.try_get("endpoint")?;
    let method: String = row.try_get("method")?;
    let raw_window_days: i32 = row.try_get("forecast_window_days")?;
    // The column is signed. A plain `as u32` would silently wrap negative values into
    // enormous windows, so surface them as a decode error instead.
    let forecast_window_days =
        u32::try_from(raw_window_days).map_err(|e| sqlx::Error::ColumnDecode {
            index: "forecast_window_days".to_string(),
            source: Box::new(e),
        })?;
    let predicted_change_probability: f64 = row.try_get("predicted_change_probability")?;
    let predicted_break_probability: f64 = row.try_get("predicted_break_probability")?;
    let next_expected_change_date: Option<DateTime<Utc>> =
        row.try_get("next_expected_change_date")?;
    let next_expected_break_date: Option<DateTime<Utc>> =
        row.try_get("next_expected_break_date")?;
    let volatility_score: f64 = row.try_get("volatility_score")?;
    let confidence: f64 = row.try_get("confidence")?;
    // Any decode failure here (missing column, NULL) falls back to `Value::Null`.
    let seasonal_patterns_json: serde_json::Value =
        row.try_get("seasonal_patterns").unwrap_or_default();
    let predicted_at: DateTime<Utc> = row.try_get("predicted_at")?;
    let expires_at: DateTime<Utc> = row.try_get("expires_at")?;

    let seasonal_patterns: Vec<SeasonalPattern> = if seasonal_patterns_json.is_array() {
        serde_json::from_value(seasonal_patterns_json).unwrap_or_else(|e| {
            tracing::warn!(
                "Ignoring malformed seasonal_patterns for {} {}: {}",
                method,
                endpoint,
                e
            );
            Vec::new()
        })
    } else {
        Vec::new()
    };

    Ok(ChangeForecast {
        service_id,
        service_name,
        endpoint,
        method,
        forecast_window_days,
        predicted_change_probability,
        predicted_break_probability,
        next_expected_change_date,
        next_expected_break_date,
        volatility_score,
        confidence,
        seasonal_patterns,
        predicted_at,
        expires_at,
    })
}
66
/// Shared state handed to every forecasting HTTP handler.
#[derive(Clone)]
pub struct ForecastingState {
    /// Forecast generator; `refresh_forecasts` calls `generate_forecast` on it.
    pub forecaster: Arc<Forecaster>,
    /// Optional database handle. When it is `None` (or yields no pool), the read
    /// handlers return an empty list and refresh reports "Database not available".
    pub database: Option<Database>,
}
75
/// Query-string filters accepted by the forecast listing endpoints.
///
/// `list_forecasts` adds an exact-match condition for each field that is present.
#[derive(Debug, Deserialize)]
pub struct ListForecastsQuery {
    /// Workspace to filter by; parsed as a UUID before it is bound.
    pub workspace_id: Option<String>,
    /// Value compared against the `service_id` column.
    pub service_id: Option<String>,
    /// Value compared against the `endpoint` column.
    pub endpoint: Option<String>,
    /// Value compared against the `method` column.
    pub method: Option<String>,
    /// Value compared against the `forecast_window_days` column.
    pub window_days: Option<u32>,
}
90
/// JSON body returned by the forecast listing endpoints.
#[derive(Debug, Serialize)]
pub struct ForecastListResponse {
    /// Forecasts included in the response.
    pub forecasts: Vec<ChangeForecast>,
    /// Count reported by the handler (`list_forecasts` sets it to `forecasts.len()`).
    pub total: usize,
}
99
/// `GET /api/v1/forecasts`: list unexpired forecasts, newest first, capped at 100.
///
/// Each filter in [`ListForecastsQuery`] that is present adds an exact-match `AND`
/// condition. Rows that fail to decode are logged and skipped, and `total` is the
/// number of forecasts actually returned. When no database pool is configured, the
/// response is an empty list.
///
/// # Errors
///
/// * `400 Bad Request` when `workspace_id` is not a valid UUID, or `window_days` does
///   not fit in an `i32`.
/// * `500 Internal Server Error` when the query fails.
#[cfg(feature = "database")]
pub async fn list_forecasts(
    State(state): State<ForecastingState>,
    Query(params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) else {
        return Ok(Json(ForecastListResponse {
            forecasts: Vec::new(),
            total: 0,
        }));
    };

    // QueryBuilder numbers each placeholder at the moment its value is bound, so the SQL
    // text and the bound arguments cannot drift out of sync.
    let mut query: sqlx::QueryBuilder<sqlx::Postgres> = sqlx::QueryBuilder::new(
        "SELECT id, workspace_id, service_id, service_name, endpoint, method,
                forecast_window_days, predicted_change_probability, predicted_break_probability,
                next_expected_change_date, next_expected_break_date, volatility_score, confidence,
                seasonal_patterns, predicted_at, expires_at
         FROM api_change_forecasts WHERE expires_at > NOW()",
    );

    if let Some(ws_id) = params.workspace_id.as_deref() {
        // Reject malformed ids explicitly; binding NULL would silently match no rows.
        let ws_uuid = Uuid::parse_str(ws_id).map_err(|_| StatusCode::BAD_REQUEST)?;
        query.push(" AND workspace_id = ").push_bind(ws_uuid);
    }

    if let Some(svc_id) = &params.service_id {
        query.push(" AND service_id = ").push_bind(svc_id.clone());
    }

    if let Some(ep) = &params.endpoint {
        query.push(" AND endpoint = ").push_bind(ep.clone());
    }

    if let Some(m) = &params.method {
        query.push(" AND method = ").push_bind(m.clone());
    }

    if let Some(window) = params.window_days {
        // The column is bound as `i32`; `as i32` would wrap values above i32::MAX.
        let window = i32::try_from(window).map_err(|_| StatusCode::BAD_REQUEST)?;
        query.push(" AND forecast_window_days = ").push_bind(window);
    }

    query.push(" ORDER BY predicted_at DESC LIMIT 100");

    let rows = query.build().fetch_all(pool).await.map_err(|e| {
        tracing::error!("Failed to query forecasts: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let forecasts: Vec<ChangeForecast> = rows
        .iter()
        .filter_map(|row| match map_row_to_change_forecast(row) {
            Ok(forecast) => Some(forecast),
            Err(e) => {
                tracing::warn!("Failed to map forecast row: {}", e);
                None
            }
        })
        .collect();

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
204
205#[cfg(not(feature = "database"))]
209pub async fn list_forecasts(
210 State(_state): State<ForecastingState>,
211 Query(_params): Query<ListForecastsQuery>,
212) -> Result<Json<ForecastListResponse>, StatusCode> {
213 Ok(Json(ForecastListResponse {
214 forecasts: Vec::new(),
215 total: 0,
216 }))
217}
218
/// `GET /api/v1/forecasts/service/{service_id}`: unexpired forecasts for one
/// service, newest first, capped at 50.
///
/// The query-string parameters are accepted for route symmetry but are not applied.
/// Rows that fail to decode are logged and skipped, and `total` is the number of
/// forecasts returned. When no database pool is configured, the response is an
/// empty list.
///
/// # Errors
///
/// `500 Internal Server Error` when the query fails.
#[cfg(feature = "database")]
pub async fn get_service_forecasts(
    State(state): State<ForecastingState>,
    Path(service_id): Path<String>,
    Query(_params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) else {
        return Ok(Json(ForecastListResponse {
            forecasts: Vec::new(),
            total: 0,
        }));
    };

    // Explicit column list: these are exactly the columns map_row_to_change_forecast reads.
    let rows = sqlx::query(
        "SELECT id, workspace_id, service_id, service_name, endpoint, method,
                forecast_window_days, predicted_change_probability, predicted_break_probability,
                next_expected_change_date, next_expected_break_date, volatility_score, confidence,
                seasonal_patterns, predicted_at, expires_at
         FROM api_change_forecasts
         WHERE service_id = $1 AND expires_at > NOW()
         ORDER BY predicted_at DESC LIMIT 50",
    )
    .bind(&service_id)
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query service forecasts: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Actually decode the rows instead of returning an empty list with a row count.
    let forecasts: Vec<ChangeForecast> = rows
        .iter()
        .filter_map(|row| match map_row_to_change_forecast(row) {
            Ok(forecast) => Some(forecast),
            Err(e) => {
                tracing::warn!("Failed to map forecast row: {}", e);
                None
            }
        })
        .collect();

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
258
259#[cfg(not(feature = "database"))]
263pub async fn get_service_forecasts(
264 State(_state): State<ForecastingState>,
265 Path(_service_id): Path<String>,
266 Query(_params): Query<ListForecastsQuery>,
267) -> Result<Json<ForecastListResponse>, StatusCode> {
268 Ok(Json(ForecastListResponse {
269 forecasts: Vec::new(),
270 total: 0,
271 }))
272}
273
/// `GET /api/v1/forecasts/endpoint/{endpoint}`: unexpired forecasts for one
/// endpoint, newest first, capped at 50.
///
/// If the `method` query parameter is present, only forecasts with exactly that method
/// are returned. Otherwise all methods match. The remaining query parameters are not
/// applied. Rows that fail to decode are logged and skipped. When no database pool is
/// configured, the response is an empty list.
///
/// NOTE(review): the route captures a single path segment, so endpoints containing `/`
/// presumably must be percent-encoded by the client. Confirm with callers.
///
/// # Errors
///
/// `500 Internal Server Error` when the query fails.
#[cfg(feature = "database")]
pub async fn get_endpoint_forecasts(
    State(state): State<ForecastingState>,
    Path(endpoint): Path<String>,
    Query(params): Query<ListForecastsQuery>,
) -> Result<Json<ForecastListResponse>, StatusCode> {
    let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) else {
        return Ok(Json(ForecastListResponse {
            forecasts: Vec::new(),
            total: 0,
        }));
    };

    // Exact match on method, or no method filter when the parameter is absent. Using the
    // caller's value as a LIKE pattern would let `%` / `_` act as wildcards.
    let rows = sqlx::query(
        "SELECT id, workspace_id, service_id, service_name, endpoint, method,
                forecast_window_days, predicted_change_probability, predicted_break_probability,
                next_expected_change_date, next_expected_break_date, volatility_score, confidence,
                seasonal_patterns, predicted_at, expires_at
         FROM api_change_forecasts
         WHERE endpoint = $1
           AND ($2::text IS NULL OR method = $2)
           AND expires_at > NOW()
         ORDER BY predicted_at DESC LIMIT 50",
    )
    .bind(&endpoint)
    .bind(params.method.as_deref())
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query endpoint forecasts: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Actually decode the rows instead of returning an empty list with a row count.
    let forecasts: Vec<ChangeForecast> = rows
        .iter()
        .filter_map(|row| match map_row_to_change_forecast(row) {
            Ok(forecast) => Some(forecast),
            Err(e) => {
                tracing::warn!("Failed to map forecast row: {}", e);
                None
            }
        })
        .collect();

    let total = forecasts.len();
    Ok(Json(ForecastListResponse { forecasts, total }))
}
314
315#[cfg(not(feature = "database"))]
319pub async fn get_endpoint_forecasts(
320 State(_state): State<ForecastingState>,
321 Path(_endpoint): Path<String>,
322 Query(_params): Query<ListForecastsQuery>,
323) -> Result<Json<ForecastListResponse>, StatusCode> {
324 Ok(Json(ForecastListResponse {
325 forecasts: Vec::new(),
326 total: 0,
327 }))
328}
329
/// JSON body accepted by `POST /api/v1/forecasts/refresh`.
///
/// All fields are optional. See `refresh_forecasts` for which ones narrow the refresh.
#[derive(Debug, Deserialize)]
pub struct RefreshForecastsRequest {
    /// Workspace id (string form of a UUID); also forwarded to the forecaster.
    pub workspace_id: Option<String>,
    /// Optional service identifier.
    pub service_id: Option<String>,
    /// Optional endpoint.
    pub endpoint: Option<String>,
    /// Optional method.
    pub method: Option<String>,
}
342
/// Decode one `drift_incidents` row into a `DriftIncident`.
///
/// Returns `Ok(None)` for rows to skip: any column other than `id` fails to decode, or
/// `incident_type` / `severity` / `status` holds an unrecognised value. A
/// `workspace_id` that fails to decode becomes `None` rather than skipping the row.
///
/// # Errors
///
/// Returns the decode error when the `id` column cannot be read; the caller treats that
/// as a hard failure.
#[cfg(feature = "database")]
fn map_row_to_drift_incident(
    row: &sqlx::postgres::PgRow,
) -> Result<Option<mockforge_core::incidents::types::DriftIncident>, sqlx::Error> {
    use mockforge_core::incidents::types::{
        DriftIncident, IncidentSeverity, IncidentStatus, IncidentType,
    };
    use sqlx::Row;

    let id: Uuid = row.try_get("id")?;

    // Everything past `id` is best-effort: any failure yields `None` (skip the row).
    let decode_rest = || -> Option<DriftIncident> {
        let workspace_id: Option<Uuid> = row.try_get("workspace_id").ok();
        let endpoint: String = row.try_get("endpoint").ok()?;
        let method: String = row.try_get("method").ok()?;
        let incident_type_str: String = row.try_get("incident_type").ok()?;
        let severity_str: String = row.try_get("severity").ok()?;
        let status_str: String = row.try_get("status").ok()?;
        let detected_at: DateTime<Utc> = row.try_get("detected_at").ok()?;
        let details: serde_json::Value = row.try_get("details").unwrap_or_default();
        let created_at: DateTime<Utc> = row.try_get("created_at").ok()?;
        let updated_at: DateTime<Utc> = row.try_get("updated_at").ok()?;

        let incident_type = match incident_type_str.as_str() {
            "breaking_change" => IncidentType::BreakingChange,
            "threshold_exceeded" => IncidentType::ThresholdExceeded,
            _ => return None,
        };

        let severity = match severity_str.as_str() {
            "low" => IncidentSeverity::Low,
            "medium" => IncidentSeverity::Medium,
            "high" => IncidentSeverity::High,
            "critical" => IncidentSeverity::Critical,
            _ => return None,
        };

        let status = match status_str.as_str() {
            "open" => IncidentStatus::Open,
            "acknowledged" => IncidentStatus::Acknowledged,
            "resolved" => IncidentStatus::Resolved,
            "closed" => IncidentStatus::Closed,
            _ => return None,
        };

        Some(DriftIncident {
            id: id.to_string(),
            budget_id: None,
            workspace_id: workspace_id.map(|u| u.to_string()),
            endpoint,
            method,
            incident_type,
            severity,
            status,
            detected_at: detected_at.timestamp(),
            resolved_at: None,
            details,
            external_ticket_id: None,
            external_ticket_url: None,
            created_at: created_at.timestamp(),
            updated_at: updated_at.timestamp(),
            sync_cycle_id: None,
            contract_diff_id: None,
            before_sample: None,
            after_sample: None,
            fitness_test_results: Vec::new(),
            affected_consumers: None,
            protocol: None,
        })
    };

    Ok(decode_rest())
}

/// `POST /api/v1/forecasts/refresh`: regenerate forecasts from recorded drift
/// incidents and persist them.
///
/// Steps:
///
/// 1. Load rows from `drift_incidents`, narrowed by the request's `workspace_id`,
///    `endpoint` and `method` when present.
/// 2. Group the rows by `(endpoint, method)`.
/// 3. Call `Forecaster::generate_forecast` once per group.
/// 4. Save every produced forecast with [`store_forecast`]. Storage failures are
///    logged and do not abort the refresh.
///
/// `service_id` in the request is not applied, because the incident query does not
/// select a service column.
///
/// Responses:
///
/// * No database pool configured: `{"success": false, "error": ...}`.
/// * Otherwise: `{"success": true, "forecasts_generated": N, "forecasts_stored": M}`.
///
/// # Errors
///
/// * `400 Bad Request` when `workspace_id` is not a valid UUID.
/// * `500 Internal Server Error` when the incident query fails or an incident `id`
///   cannot be decoded.
#[cfg(feature = "database")]
pub async fn refresh_forecasts(
    State(state): State<ForecastingState>,
    Json(request): Json<RefreshForecastsRequest>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use mockforge_core::incidents::types::DriftIncident;
    use std::collections::HashMap;

    let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) else {
        return Ok(Json(serde_json::json!({
            "success": false,
            "error": "Database not available"
        })));
    };

    // Each filter clause is pushed together with its bound value, so a placeholder can
    // never appear in the SQL without a matching argument.
    let mut incident_query: sqlx::QueryBuilder<sqlx::Postgres> = sqlx::QueryBuilder::new(
        "SELECT id, workspace_id, endpoint, method, incident_type, severity, status,
                detected_at, details, created_at, updated_at
         FROM drift_incidents WHERE 1=1",
    );

    if let Some(ws_id) = request.workspace_id.as_deref() {
        let ws_uuid = Uuid::parse_str(ws_id).map_err(|_| StatusCode::BAD_REQUEST)?;
        incident_query.push(" AND workspace_id = ").push_bind(ws_uuid);
    }

    if let Some(endpoint) = &request.endpoint {
        incident_query.push(" AND endpoint = ").push_bind(endpoint.clone());
    }

    if let Some(method) = &request.method {
        incident_query.push(" AND method = ").push_bind(method.clone());
    }

    let rows = incident_query.build().fetch_all(pool).await.map_err(|e| {
        tracing::error!("Failed to query drift incidents: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Group decodable incidents by (endpoint, method); each group yields one forecast.
    let mut endpoint_groups: HashMap<(String, String), Vec<DriftIncident>> = HashMap::new();
    for row in &rows {
        let incident = map_row_to_drift_incident(row).map_err(|e| {
            tracing::error!("Failed to get id from row: {}", e);
            StatusCode::INTERNAL_SERVER_ERROR
        })?;
        if let Some(incident) = incident {
            endpoint_groups
                .entry((incident.endpoint.clone(), incident.method.clone()))
                .or_default()
                .push(incident);
        }
    }

    let mut forecasts_generated: usize = 0;
    let mut forecasts_stored: usize = 0;

    for ((endpoint, method), group_incidents) in endpoint_groups {
        let Some(forecast) = state.forecaster.generate_forecast(
            &group_incidents,
            request.workspace_id.clone(),
            None, // NOTE(review): presumably service_id; not known for incidents
            None, // NOTE(review): presumably service_name
            endpoint,
            method,
            30, // NOTE(review): presumably the forecast window in days; confirm
        ) else {
            continue;
        };
        forecasts_generated += 1;

        // Persist the forecast; the GET handlers read `api_change_forecasts`, so an
        // unstored forecast would never be visible.
        match store_forecast(pool, &forecast, request.workspace_id.as_deref()).await {
            Ok(()) => forecasts_stored += 1,
            Err(e) => tracing::warn!(
                "Failed to store forecast for {} {}: {}",
                forecast.method,
                forecast.endpoint,
                e
            ),
        }
    }

    Ok(Json(serde_json::json!({
        "success": true,
        "message": "Forecasts refreshed",
        "forecasts_generated": forecasts_generated,
        "forecasts_stored": forecasts_stored
    })))
}
503
504#[cfg(not(feature = "database"))]
508pub async fn refresh_forecasts(
509 State(_state): State<ForecastingState>,
510 Json(_request): Json<RefreshForecastsRequest>,
511) -> Result<Json<serde_json::Value>, StatusCode> {
512 Ok(Json(serde_json::json!({
513 "success": false,
514 "error": "Database not available"
515 })))
516}
517
/// Upsert `forecast` into `api_change_forecasts`.
///
/// A fresh UUID is generated for `id` on every call. On a conflict over
/// `(workspace_id, service_id, endpoint, method, forecast_window_days)`, the existing
/// row's prediction columns are overwritten and `updated_at` is set to `NOW()` instead.
///
/// `workspace_id` is parsed as a UUID. An absent or unparsable value is stored as NULL.
///
/// NOTE(review): PostgreSQL treats NULLs as distinct in unique indexes by default. So,
/// unless the conflict target is declared `NULLS NOT DISTINCT`, rows with a NULL
/// `workspace_id` or `service_id` never conflict and each call inserts a new row.
/// Confirm against the schema.
///
/// # Errors
///
/// * `sqlx::Error::Encode` when `forecast_window_days` exceeds `i32::MAX` or the
///   seasonal patterns cannot be serialised to JSON.
/// * Any error from executing the statement.
#[cfg(feature = "database")]
pub async fn store_forecast(
    pool: &sqlx::PgPool,
    forecast: &ChangeForecast,
    workspace_id: Option<&str>,
) -> Result<(), sqlx::Error> {
    let id = Uuid::new_v4();
    let workspace_uuid = workspace_id.and_then(|id| Uuid::parse_str(id).ok());
    // The column is bound as `i32`; reject values an `as` cast would silently wrap.
    let window_days = i32::try_from(forecast.forecast_window_days)
        .map_err(|e| sqlx::Error::Encode(Box::new(e)))?;
    // Fail loudly rather than storing NULL when serialisation fails.
    let seasonal_patterns = serde_json::to_value(&forecast.seasonal_patterns)
        .map_err(|e| sqlx::Error::Encode(Box::new(e)))?;

    sqlx::query(
        r#"
        INSERT INTO api_change_forecasts (
            id, workspace_id, service_id, service_name, endpoint, method,
            forecast_window_days, predicted_change_probability, predicted_break_probability,
            next_expected_change_date, next_expected_break_date, volatility_score, confidence,
            seasonal_patterns, predicted_at, expires_at
        ) VALUES (
            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16
        )
        ON CONFLICT (workspace_id, service_id, endpoint, method, forecast_window_days)
        DO UPDATE SET
            predicted_change_probability = EXCLUDED.predicted_change_probability,
            predicted_break_probability = EXCLUDED.predicted_break_probability,
            next_expected_change_date = EXCLUDED.next_expected_change_date,
            next_expected_break_date = EXCLUDED.next_expected_break_date,
            volatility_score = EXCLUDED.volatility_score,
            confidence = EXCLUDED.confidence,
            seasonal_patterns = EXCLUDED.seasonal_patterns,
            predicted_at = EXCLUDED.predicted_at,
            expires_at = EXCLUDED.expires_at,
            updated_at = NOW()
        "#,
    )
    .bind(id)
    .bind(workspace_uuid)
    .bind(forecast.service_id.as_deref())
    .bind(forecast.service_name.as_deref())
    .bind(&forecast.endpoint)
    .bind(&forecast.method)
    .bind(window_days)
    .bind(forecast.predicted_change_probability)
    .bind(forecast.predicted_break_probability)
    .bind(forecast.next_expected_change_date)
    .bind(forecast.next_expected_break_date)
    .bind(forecast.volatility_score)
    .bind(forecast.confidence)
    .bind(seasonal_patterns)
    .bind(forecast.predicted_at)
    .bind(forecast.expires_at)
    .execute(pool)
    .await?;

    Ok(())
}
573
574pub fn forecasting_router(state: ForecastingState) -> axum::Router {
576 use axum::routing::{get, post};
577 use axum::Router;
578
579 Router::new()
580 .route("/api/v1/forecasts", get(list_forecasts))
581 .route("/api/v1/forecasts/service/{service_id}", get(get_service_forecasts))
582 .route("/api/v1/forecasts/endpoint/{endpoint}", get(get_endpoint_forecasts))
583 .route("/api/v1/forecasts/refresh", post(refresh_forecasts))
584 .with_state(state)
585}