mockforge_http/handlers/
threat_modeling.rs

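//! HTTP handlers for contract threat modeling.
//!
//! This module exposes endpoints for running security threat assessments against OpenAPI
//! contracts and for querying stored assessments, findings, and remediation suggestions.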
4
5use axum::{
6    extract::{Path, Query, State},
7    http::StatusCode,
8    response::Json,
9};
10use chrono::{DateTime, Utc};
11use mockforge_core::contract_drift::threat_modeling::{ThreatAnalyzer, ThreatAssessment};
12use mockforge_core::openapi::OpenApiSpec;
13use serde::Deserialize;
14use std::sync::Arc;
15use uuid::Uuid;
16
17use crate::database::Database;
18
/// Map a `contract_threat_assessments` row to a [`ThreatAssessment`].
///
/// Enum columns (`aggregation_level`, `threat_level`) are matched case-insensitively, so rows
/// written with Rust `Debug` variant names (e.g. `"Workspace"`, `"High"`) decode as well as the
/// lowercase values the query handlers filter on. JSONB list columns whose contents do not
/// deserialize into the expected element type are logged and decoded as empty lists instead of
/// failing the whole row.
///
/// # Errors
///
/// Returns a [`sqlx::Error`] when a column is missing or has an incompatible SQL type, or when
/// an enum column holds an unrecognized value.
#[cfg(feature = "database")]
fn map_row_to_threat_assessment(
    row: &sqlx::postgres::PgRow,
) -> Result<ThreatAssessment, sqlx::Error> {
    use mockforge_core::contract_drift::threat_modeling::{
        AggregationLevel, RemediationSuggestion, ThreatCategory, ThreatFinding, ThreatLevel,
    };
    // `try_get` is a method of the `Row` trait, which must be in scope to be callable.
    use sqlx::Row;

    // Decode a JSONB array column into `Vec<T>`, falling back to an empty list (with a warning)
    // when the stored JSON does not deserialize into `T`.
    fn json_list<T: serde::de::DeserializeOwned>(
        row: &sqlx::postgres::PgRow,
        column: &str,
    ) -> Result<Vec<T>, sqlx::Error> {
        use sqlx::Row;

        let value: serde_json::Value = row.try_get(column)?;
        Ok(serde_json::from_value(value).unwrap_or_else(|e| {
            tracing::warn!("Failed to decode JSONB column {}: {}", column, e);
            Vec::new()
        }))
    }

    let workspace_id: Option<Uuid> = row.try_get("workspace_id")?;
    // Assessments are written with `service_id` bound as a UUID while the API exposes it as a
    // string; accept either a text or a uuid column.
    let service_id: Option<String> = match row.try_get::<Option<String>, _>("service_id") {
        Ok(id) => id,
        Err(_) => row
            .try_get::<Option<Uuid>, _>("service_id")?
            .map(|id| id.to_string()),
    };
    let service_name: Option<String> = row.try_get("service_name")?;
    let endpoint: Option<String> = row.try_get("endpoint")?;
    let method: Option<String> = row.try_get("method")?;
    let aggregation_level_str: String = row.try_get("aggregation_level")?;
    let threat_level_str: String = row.try_get("threat_level")?;
    let threat_score: f64 = row.try_get("threat_score")?;
    let assessed_at: DateTime<Utc> = row.try_get("assessed_at")?;

    let aggregation_level = match aggregation_level_str.to_ascii_lowercase().as_str() {
        "workspace" => AggregationLevel::Workspace,
        "service" => AggregationLevel::Service,
        "endpoint" => AggregationLevel::Endpoint,
        other => {
            return Err(sqlx::Error::Decode(
                format!("Invalid aggregation_level: {other}").into(),
            ))
        }
    };

    let threat_level = match threat_level_str.to_ascii_lowercase().as_str() {
        "low" => ThreatLevel::Low,
        "medium" => ThreatLevel::Medium,
        "high" => ThreatLevel::High,
        "critical" => ThreatLevel::Critical,
        other => {
            return Err(sqlx::Error::Decode(
                format!("Invalid threat_level: {other}").into(),
            ))
        }
    };

    let threat_categories: Vec<ThreatCategory> = json_list(row, "threat_categories")?;
    let findings: Vec<ThreatFinding> = json_list(row, "findings")?;
    let remediation_suggestions: Vec<RemediationSuggestion> =
        json_list(row, "remediation_suggestions")?;

    Ok(ThreatAssessment {
        workspace_id: workspace_id.map(|id| id.to_string()),
        service_id,
        service_name,
        endpoint,
        method,
        aggregation_level,
        threat_level,
        threat_score,
        threat_categories,
        findings,
        remediation_suggestions,
        assessed_at,
    })
}
83
84/// State for threat modeling handlers
85#[derive(Clone)]
86pub struct ThreatModelingState {
87    /// Threat analyzer
88    pub analyzer: Arc<ThreatAnalyzer>,
89    /// Webhook configs for notifications (optional)
90    pub webhook_configs: Vec<mockforge_core::incidents::integrations::WebhookConfig>,
91    /// Database connection (optional)
92    pub database: Option<Database>,
93}
94
/// Get the most recent workspace-level threat assessment.
///
/// GET /api/v1/threats/workspace/{workspace_id}
///
/// # Errors
///
/// * `503 Service Unavailable` – no database pool is configured.
/// * `404 Not Found` – no assessment exists for the workspace. This includes ids that are not
///   UUIDs, because assessments are only ever stored under UUID workspace ids.
/// * `500 Internal Server Error` – the query or row decoding failed.
#[cfg(feature = "database")]
pub async fn get_workspace_threats(
    State(state): State<ThreatModelingState>,
    Path(workspace_id): Path<String>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
    };

    // `workspace_id` is stored as a UUID. Binding the raw path string would compare
    // `uuid = text`, which Postgres rejects. Non-UUID ids are persisted as NULL, so they can
    // never match.
    let workspace_uuid = match Uuid::parse_str(&workspace_id) {
        Ok(id) => id,
        Err(_) => {
            tracing::debug!("Workspace id {:?} is not a UUID; no assessment can exist", workspace_id);
            return Err(StatusCode::NOT_FOUND);
        }
    };

    // Several assessments may exist for one workspace; only the newest is returned.
    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE workspace_id = $1 AND aggregation_level = 'workspace'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(workspace_uuid)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query workspace threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
133
134/// Get workspace-level threat assessment (no database)
135///
136/// GET /api/v1/threats/workspace/{workspace_id}
137#[cfg(not(feature = "database"))]
138pub async fn get_workspace_threats(
139    State(_state): State<ThreatModelingState>,
140    Path(_workspace_id): Path<String>,
141) -> Result<Json<ThreatAssessment>, StatusCode> {
142    Err(StatusCode::SERVICE_UNAVAILABLE)
143}
144
/// Get the most recent service-level threat assessment.
///
/// GET /api/v1/threats/service/{service_id}
///
/// # Errors
///
/// * `503 Service Unavailable` – no database pool is configured.
/// * `404 Not Found` – no assessment exists for the service.
/// * `500 Internal Server Error` – the query or row decoding failed.
#[cfg(feature = "database")]
pub async fn get_service_threats(
    State(state): State<ThreatModelingState>,
    Path(service_id): Path<String>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
    };

    // Assessments are written with `service_id` bound as a UUID. Comparing the column's text
    // form works whether the column is `uuid` or `text`, and avoids Postgres' missing
    // `uuid = text` operator. UUID-shaped ids are canonicalised to the lowercase hyphenated
    // form Postgres produces.
    let service_key = Uuid::parse_str(&service_id)
        .map(|id| id.to_string())
        .unwrap_or(service_id);

    // Several assessments may exist for one service; only the newest is returned.
    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE service_id::text = $1 AND aggregation_level = 'service'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(&service_key)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query service threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
182
183/// Get service-level threat assessment (no database)
184///
185/// GET /api/v1/threats/service/{service_id}
186#[cfg(not(feature = "database"))]
187pub async fn get_service_threats(
188    State(_state): State<ThreatModelingState>,
189    Path(_service_id): Path<String>,
190) -> Result<Json<ThreatAssessment>, StatusCode> {
191    Err(StatusCode::SERVICE_UNAVAILABLE)
192}
193
/// Get the most recent endpoint-level threat assessment.
///
/// GET /api/v1/threats/endpoint/{endpoint}?method=GET
///
/// The optional `method` query parameter filters by exact HTTP method. When it is absent or
/// empty, the newest assessment for the endpoint is returned regardless of method.
///
/// # Errors
///
/// * `503 Service Unavailable` – no database pool is configured.
/// * `404 Not Found` – no matching assessment exists.
/// * `500 Internal Server Error` – the query or row decoding failed.
#[cfg(feature = "database")]
pub async fn get_endpoint_threats(
    State(state): State<ThreatModelingState>,
    Path(endpoint): Path<String>,
    Query(params): Query<serde_json::Value>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
    };

    // Exact match rather than `LIKE`, so a caller-supplied method cannot act as a wildcard
    // pattern. Binding NULL when no method is given disables the filter and keeps rows whose
    // `method` column is NULL (`method LIKE '%'` is never true for those).
    let method: Option<&str> = params
        .get("method")
        .and_then(|v| v.as_str())
        .filter(|m| !m.is_empty());

    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE endpoint = $1
           AND ($2::text IS NULL OR method = $2)
           AND aggregation_level = 'endpoint'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(&endpoint)
    .bind(method)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query endpoint threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
235
236/// Get endpoint-level threat assessment (no database)
237///
238/// GET /api/v1/threats/endpoint/{endpoint}
239#[cfg(not(feature = "database"))]
240pub async fn get_endpoint_threats(
241    State(_state): State<ThreatModelingState>,
242    Path(_endpoint): Path<String>,
243    Query(_params): Query<serde_json::Value>,
244) -> Result<Json<ThreatAssessment>, StatusCode> {
245    Err(StatusCode::SERVICE_UNAVAILABLE)
246}
247
248/// Request to trigger threat assessment
249#[derive(Debug, Deserialize)]
250pub struct AssessThreatsRequest {
251    /// OpenAPI spec (YAML/JSON)
252    pub spec: String,
253    /// Workspace ID
254    pub workspace_id: Option<String>,
255    /// Service ID
256    pub service_id: Option<String>,
257    /// Service name
258    pub service_name: Option<String>,
259    /// Endpoint (optional)
260    pub endpoint: Option<String>,
261    /// Method (optional)
262    pub method: Option<String>,
263}
264
265/// Trigger threat assessment
266///
267/// POST /api/v1/threats/assess
268pub async fn assess_threats(
269    State(state): State<ThreatModelingState>,
270    Json(request): Json<AssessThreatsRequest>,
271) -> Result<Json<ThreatAssessment>, StatusCode> {
272    // Parse OpenAPI spec
273    let spec = match OpenApiSpec::from_string(&request.spec, None) {
274        Ok(spec) => spec,
275        Err(_) => return Err(StatusCode::BAD_REQUEST),
276    };
277
278    // Run threat analysis
279    match state
280        .analyzer
281        .analyze_contract(
282            &spec,
283            request.workspace_id.clone(),
284            request.service_id.clone(),
285            request.service_name.clone(),
286            request.endpoint.clone(),
287            request.method.clone(),
288        )
289        .await
290    {
291        Ok(assessment) => {
292            // Store assessment in database
293            #[cfg(feature = "database")]
294            if let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) {
295                if let Err(e) = store_threat_assessment(pool, &assessment).await {
296                    tracing::warn!("Failed to store threat assessment: {}", e);
297                }
298            }
299
300            // Trigger webhook notifications
301            trigger_threat_assessment_webhooks(&state.webhook_configs, &assessment).await;
302
303            Ok(Json(assessment))
304        }
305        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
306    }
307}
308
/// Name stored in `threat_findings.finding_type` for a `ThreatCategory`.
///
/// The names are spelled out because the `Debug` names do not convert mechanically to
/// snake_case (`DoSRisk` must become `dos_risk`, not `do_s_risk`).
#[cfg(feature = "database")]
#[allow(unreachable_patterns)] // the fallback arm only matters if upstream adds variants
fn threat_category_db_name(
    category: &mockforge_core::contract_drift::threat_modeling::ThreatCategory,
) -> String {
    use mockforge_core::contract_drift::threat_modeling::ThreatCategory;

    let name = match category {
        ThreatCategory::PiiExposure => "pii_exposure",
        ThreatCategory::DoSRisk => "dos_risk",
        ThreatCategory::ErrorLeakage => "error_leakage",
        ThreatCategory::SchemaInconsistency => "schema_inconsistency",
        ThreatCategory::UnboundedArrays => "unbounded_arrays",
        ThreatCategory::MissingRateLimits => "missing_rate_limits",
        ThreatCategory::StackTraceLeakage => "stack_trace_leakage",
        ThreatCategory::SensitiveDataExposure => "sensitive_data_exposure",
        ThreatCategory::InsecureSchemaDesign => "insecure_schema_design",
        ThreatCategory::MissingValidation => "missing_validation",
        ThreatCategory::ExcessiveOptionalFields => "excessive_optional_fields",
        other => return format!("{:?}", other).to_lowercase(),
    };
    name.to_owned()
}

/// Persist an assessment and its individual findings.
///
/// The assessment row is upserted on `(workspace_id, service_id, endpoint, method,
/// aggregation_level)`, and its rows in `threat_findings` are then replaced. Everything runs in
/// one transaction, so a failure never leaves findings without their assessment or the reverse.
///
/// Enum values are written in the lowercase / snake_case form that the read queries in this
/// module filter and match on (e.g. `"workspace"`, `"high"`, `"pii_exposure"`).
/// `workspace_id` / `service_id` values that are not UUIDs are stored as `NULL`.
///
/// # Errors
///
/// Returns any [`sqlx::Error`] raised while writing. The transaction is rolled back when it is
/// dropped without being committed.
#[cfg(feature = "database")]
async fn store_threat_assessment(
    pool: &sqlx::PgPool,
    assessment: &ThreatAssessment,
) -> Result<(), sqlx::Error> {
    let workspace_uuid = assessment
        .workspace_id
        .as_deref()
        .and_then(|id| Uuid::parse_str(id).ok());
    let service_uuid = assessment
        .service_id
        .as_deref()
        .and_then(|id| Uuid::parse_str(id).ok());
    let now = Utc::now();

    let mut tx = pool.begin().await?;

    // With `ON CONFLICT DO UPDATE` an existing row keeps its original `id`, so the stored id is
    // read back. Findings must reference it, not the candidate id generated for the insert.
    let assessment_id = sqlx::query_scalar::<_, Uuid>(
        r#"
        INSERT INTO contract_threat_assessments (
            id, workspace_id, service_id, service_name, endpoint, method, aggregation_level,
            threat_level, threat_score, threat_categories, findings, remediation_suggestions,
            assessed_at, last_updated, created_at
        ) VALUES (
            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
        )
        ON CONFLICT (workspace_id, service_id, endpoint, method, aggregation_level)
        DO UPDATE SET
            threat_level = EXCLUDED.threat_level,
            threat_score = EXCLUDED.threat_score,
            threat_categories = EXCLUDED.threat_categories,
            findings = EXCLUDED.findings,
            remediation_suggestions = EXCLUDED.remediation_suggestions,
            assessed_at = EXCLUDED.assessed_at,
            last_updated = EXCLUDED.last_updated
        RETURNING id
        "#,
    )
    .bind(Uuid::new_v4())
    .bind(workspace_uuid)
    .bind(service_uuid)
    .bind(assessment.service_name.as_deref())
    .bind(assessment.endpoint.as_deref())
    .bind(assessment.method.as_deref())
    .bind(format!("{:?}", assessment.aggregation_level).to_lowercase())
    .bind(format!("{:?}", assessment.threat_level).to_lowercase())
    .bind(assessment.threat_score)
    .bind(serde_json::to_value(&assessment.threat_categories).unwrap_or_default())
    .bind(serde_json::to_value(&assessment.findings).unwrap_or_default())
    .bind(serde_json::to_value(&assessment.remediation_suggestions).unwrap_or_default())
    .bind(assessment.assessed_at)
    .bind(now)
    .bind(assessment.assessed_at)
    .fetch_one(&mut *tx)
    .await?;

    // The upsert replaces the assessment in place, so its findings are replaced as well instead
    // of accumulating entries from earlier runs.
    sqlx::query("DELETE FROM threat_findings WHERE assessment_id = $1")
        .bind(assessment_id)
        .execute(&mut *tx)
        .await?;

    for finding in &assessment.findings {
        sqlx::query(
            r#"
            INSERT INTO threat_findings (
                id, assessment_id, finding_type, severity, description, field_path,
                context, remediation_suggestion, remediation_code_example, confidence,
                ai_generated_remediation, detected_at, created_at, updated_at
            ) VALUES (
                $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14
            )
            "#,
        )
        .bind(Uuid::new_v4())
        .bind(assessment_id)
        .bind(threat_category_db_name(&finding.finding_type))
        .bind(format!("{:?}", finding.severity).to_lowercase())
        .bind(&finding.description)
        .bind(finding.field_path.as_deref())
        .bind(serde_json::to_value(&finding.context).unwrap_or_default())
        .bind(None::<String>) // remediation_suggestion: not tracked per finding
        .bind(None::<String>) // remediation_code_example
        .bind(finding.confidence)
        .bind(false) // ai_generated_remediation
        .bind(now) // detected_at
        .bind(now) // created_at
        .bind(now) // updated_at
        .execute(&mut *tx)
        .await?;
    }

    tx.commit().await
}
392
/// List the most recent threat findings across all stored assessments.
///
/// GET /api/v1/threats/findings
///
/// Returns up to 100 findings, newest first, as `{ "findings": [...], "total": n }`. Without a
/// database pool the response is `{ "findings": [] }`.
///
/// Stored enum values are normalized to snake_case / lowercase, so rows written with Rust
/// `Debug` names (e.g. `"PiiExposure"`, `"High"`) are reported too. Rows with unrecognized
/// values are skipped.
///
/// # Errors
///
/// `500 Internal Server Error` when the query fails or a row lacks an expected column.
#[cfg(feature = "database")]
pub async fn list_findings(
    State(state): State<ThreatModelingState>,
    Query(_params): Query<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    // Recognized finding types as (canonical stored name, `ThreatCategory` variant name).
    const FINDING_TYPES: &[(&str, &str)] = &[
        ("pii_exposure", "PiiExposure"),
        ("dos_risk", "DoSRisk"),
        ("error_leakage", "ErrorLeakage"),
        ("schema_inconsistency", "SchemaInconsistency"),
        ("unbounded_arrays", "UnboundedArrays"),
        ("missing_rate_limits", "MissingRateLimits"),
        ("stack_trace_leakage", "StackTraceLeakage"),
        ("sensitive_data_exposure", "SensitiveDataExposure"),
        ("insecure_schema_design", "InsecureSchemaDesign"),
        ("missing_validation", "MissingValidation"),
        ("excessive_optional_fields", "ExcessiveOptionalFields"),
    ];
    // Recognized severities in canonical (lowercase) form.
    const SEVERITIES: &[&str] = &["low", "medium", "high", "critical"];

    // Read column `name` from `row`, logging any failure and mapping it to a 500.
    fn column<'r, T>(row: &'r sqlx::postgres::PgRow, name: &str) -> Result<T, StatusCode>
    where
        T: sqlx::Decode<'r, sqlx::Postgres> + sqlx::Type<sqlx::Postgres>,
    {
        use sqlx::Row;

        row.try_get(name).map_err(|e| {
            tracing::error!("Failed to get {} from row: {}", name, e);
            StatusCode::INTERNAL_SERVER_ERROR
        })
    }

    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(serde_json::json!({
                "findings": []
            })));
        }
    };

    let rows = sqlx::query(
        "SELECT tf.*, ta.workspace_id, ta.service_id, ta.endpoint, ta.method
         FROM threat_findings tf
         JOIN contract_threat_assessments ta ON tf.assessment_id = ta.id
         ORDER BY tf.detected_at DESC LIMIT 100",
    )
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query threat findings: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let mut findings = Vec::with_capacity(rows.len());
    for row in &rows {
        let finding_id: Uuid = column(row, "id")?;
        let finding_type_raw: String = column(row, "finding_type")?;
        let severity_raw: String = column(row, "severity")?;
        let description: String = column(row, "description")?;
        let field_path: Option<String> = column(row, "field_path")?;
        let context_json: serde_json::Value = column(row, "context")?;
        let confidence: f64 = column(row, "confidence")?;

        // Skip rows whose enum columns hold values this API does not recognize.
        let finding_type = match FINDING_TYPES.iter().find(|(canonical, variant)| {
            *canonical == finding_type_raw || *variant == finding_type_raw
        }) {
            Some((canonical, _)) => *canonical,
            None => continue,
        };
        let severity = severity_raw.to_ascii_lowercase();
        if !SEVERITIES.contains(&severity.as_str()) {
            continue;
        }

        // A context that is not a JSON object is reported as an empty object.
        let context = match context_json {
            serde_json::Value::Object(map) => serde_json::Value::Object(map),
            _ => serde_json::json!({}),
        };

        findings.push(serde_json::json!({
            "id": finding_id.to_string(),
            "finding_type": finding_type,
            "severity": severity,
            "description": description,
            "field_path": field_path,
            "context": context,
            "confidence": confidence,
        }));
    }

    Ok(Json(serde_json::json!({
        "findings": findings,
        "total": findings.len()
    })))
}
501
502/// List threat findings (no database)
503///
504/// GET /api/v1/threats/findings
505#[cfg(not(feature = "database"))]
506pub async fn list_findings(
507    State(_state): State<ThreatModelingState>,
508    Query(_params): Query<serde_json::Value>,
509) -> Result<Json<serde_json::Value>, StatusCode> {
510    Ok(Json(serde_json::json!({
511        "findings": []
512    })))
513}
514
/// List remediation suggestions from the most recent assessments.
///
/// GET /api/v1/threats/remediations
///
/// Flattens the `remediation_suggestions` arrays of up to 50 of the newest assessments that
/// have any, returning `{ "remediations": [...], "total": n }`. Without a database pool the
/// response is `{ "remediations": [] }`.
///
/// # Errors
///
/// `500 Internal Server Error` when the query fails or a row cannot be decoded.
#[cfg(feature = "database")]
pub async fn get_remediations(
    State(state): State<ThreatModelingState>,
    Query(_params): Query<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use sqlx::Row;

    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(serde_json::json!({
                "remediations": []
            })));
        }
    };

    // `jsonb_array_length` raises an error on non-array values, such as the JSON `null` stored
    // when serialization falls back to a default. Postgres does not guarantee the evaluation
    // order of `AND` operands, so the query avoids that function entirely. `jsonb_typeof`
    // selects arrays (and excludes SQL NULL), and `<> '[]'` drops empty ones.
    let rows = sqlx::query(
        "SELECT remediation_suggestions FROM contract_threat_assessments
         WHERE jsonb_typeof(remediation_suggestions) = 'array'
           AND remediation_suggestions <> '[]'::jsonb
         ORDER BY assessed_at DESC LIMIT 50",
    )
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query remediations: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let mut remediations = Vec::new();
    for row in &rows {
        let suggestions: serde_json::Value =
            row.try_get("remediation_suggestions").map_err(|e| {
                tracing::error!("Failed to get remediation_suggestions from row: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?;
        if let serde_json::Value::Array(items) = suggestions {
            remediations.extend(items);
        }
    }

    Ok(Json(serde_json::json!({
        "remediations": remediations,
        "total": remediations.len()
    })))
}
565
566/// Get remediation suggestions (no database)
567///
568/// GET /api/v1/threats/remediations
569#[cfg(not(feature = "database"))]
570pub async fn get_remediations(
571    State(_state): State<ThreatModelingState>,
572    Query(_params): Query<serde_json::Value>,
573) -> Result<Json<serde_json::Value>, StatusCode> {
574    Ok(Json(serde_json::json!({
575        "remediations": []
576    })))
577}
578
579/// Trigger webhook notifications for threat assessment
580async fn trigger_threat_assessment_webhooks(
581    webhook_configs: &[mockforge_core::incidents::integrations::WebhookConfig],
582    assessment: &ThreatAssessment,
583) {
584    use mockforge_core::incidents::integrations::send_webhook;
585    use serde_json::json;
586
587    for webhook in webhook_configs {
588        if !webhook.enabled {
589            continue;
590        }
591
592        let event_type = "threat.assessment.completed";
593        if !webhook.events.is_empty() && !webhook.events.contains(&event_type.to_string()) {
594            continue;
595        }
596
597        let payload = json!({
598            "event": event_type,
599            "assessment": {
600                "workspace_id": assessment.workspace_id,
601                "service_id": assessment.service_id,
602                "service_name": assessment.service_name,
603                "endpoint": assessment.endpoint,
604                "method": assessment.method,
605                "threat_level": format!("{:?}", assessment.threat_level),
606                "threat_score": assessment.threat_score,
607                "findings_count": assessment.findings.len(),
608                "assessed_at": assessment.assessed_at,
609            }
610        });
611
612        let webhook_clone = webhook.clone();
613        tokio::spawn(async move {
614            if let Err(e) = send_webhook(&webhook_clone, &payload).await {
615                tracing::warn!("Failed to send threat assessment webhook: {}", e);
616            }
617        });
618    }
619}
620
621/// Create router for threat modeling endpoints
622pub fn threat_modeling_router(state: ThreatModelingState) -> axum::Router {
623    use axum::routing::{get, post};
624    use axum::Router;
625
626    Router::new()
627        .route("/api/v1/threats/workspace/{workspace_id}", get(get_workspace_threats))
628        .route("/api/v1/threats/service/{service_id}", get(get_service_threats))
629        .route("/api/v1/threats/endpoint/{endpoint}", get(get_endpoint_threats))
630        .route("/api/v1/threats/assess", post(assess_threats))
631        .route("/api/v1/threats/findings", get(list_findings))
632        .route("/api/v1/threats/remediations", get(get_remediations))
633        .with_state(state)
634}