Skip to main content

mockforge_intelligence/handlers/
threat_modeling.rs

//! HTTP handlers for contract threat modeling.
//!
//! **Internal / API-only.** No admin UI consumes these endpoints. They are
//! intended for automated security pipelines and external threat-modeling
//! tools. Do not build speculative UI for these routes without a
//! stakeholder-defined use case.
//!
//! This module provides endpoints for requesting, storing, and querying
//! security threat assessments of API contracts.

// ThreatAnalyzer stays in core (OpenApiSpec + LLM dependency).
11#![allow(deprecated)]
12
13use crate::threat_modeling::ThreatAnalyzer;
14use axum::{
15    extract::{Path, Query, State},
16    http::StatusCode,
17    response::Json,
18};
19use mockforge_foundation::threat_modeling_types::ThreatAssessment;
20use mockforge_openapi::OpenApiSpec;
21use serde::Deserialize;
22use std::sync::Arc;
23
24#[cfg(feature = "database")]
25use chrono::{DateTime, Utc};
26#[cfg(feature = "database")]
27use uuid::Uuid;
28
29#[cfg(feature = "database")]
30use crate::database::Database;
31
/// Map a `contract_threat_assessments` row onto a [`ThreatAssessment`].
///
/// * `workspace_id` / `service_id` are decoded from either a `uuid` or a `text`
///   column. The writer binds `Uuid` values, but the column type cannot be
///   assumed from here.
/// * `aggregation_level` / `threat_level` are matched case-insensitively.
///   Both `low` and the enum's `Debug` spelling `Low` therefore decode.
/// * The JSONB list columns decode JSON `null` as an empty list. Any other
///   malformed payload is logged and also treated as empty, so one bad column
///   does not hide the whole assessment.
///
/// # Errors
///
/// Returns [`sqlx::Error`] when a column is missing, has an incompatible type,
/// or is SQL `NULL` where a value is required. Also returns an error when
/// `aggregation_level` / `threat_level` hold an unrecognised value.
#[cfg(feature = "database")]
fn map_row_to_threat_assessment(
    row: &sqlx::postgres::PgRow,
) -> Result<ThreatAssessment, sqlx::Error> {
    use mockforge_foundation::threat_modeling_types::{
        AggregationLevel, RemediationSuggestion, ThreatCategory, ThreatFinding, ThreatLevel,
    };
    use sqlx::postgres::PgRow;
    use sqlx::Row;

    /// Read an optional identifier column that may be typed `uuid` or `text`.
    fn optional_id(row: &PgRow, column: &str) -> Result<Option<String>, sqlx::Error> {
        use sqlx::Row;
        match row.try_get::<Option<Uuid>, _>(column) {
            Ok(id) => Ok(id.map(|u| u.to_string())),
            // Type mismatch: the column is not `uuid`, so decode it as text instead.
            Err(sqlx::Error::ColumnDecode { .. }) => row.try_get::<Option<String>, _>(column),
            Err(e) => Err(e),
        }
    }

    /// Decode a JSONB column expected to hold an array of `T`.
    fn json_list<T: serde::de::DeserializeOwned>(
        row: &PgRow,
        column: &str,
    ) -> Result<Vec<T>, sqlx::Error> {
        use sqlx::Row;
        let value: serde_json::Value = row.try_get(column)?;
        if value.is_null() {
            return Ok(Vec::new());
        }
        Ok(serde_json::from_value(value).unwrap_or_else(|e| {
            tracing::warn!("Ignoring malformed `{}` JSON in threat assessment row: {}", column, e);
            Vec::new()
        }))
    }

    // Scalar columns
    let workspace_id = optional_id(row, "workspace_id")?;
    let service_id = optional_id(row, "service_id")?;
    let service_name: Option<String> = row.try_get("service_name")?;
    let endpoint: Option<String> = row.try_get("endpoint")?;
    let method: Option<String> = row.try_get("method")?;
    let aggregation_level_str: String = row.try_get("aggregation_level")?;
    let threat_level_str: String = row.try_get("threat_level")?;
    let threat_score: f64 = row.try_get("threat_score")?;
    let assessed_at: DateTime<Utc> = row.try_get("assessed_at")?;

    let aggregation_level = match aggregation_level_str.to_ascii_lowercase().as_str() {
        "workspace" => AggregationLevel::Workspace,
        "service" => AggregationLevel::Service,
        "endpoint" => AggregationLevel::Endpoint,
        _ => return Err(sqlx::Error::Decode("Invalid aggregation_level".into())),
    };

    let threat_level = match threat_level_str.to_ascii_lowercase().as_str() {
        "low" => ThreatLevel::Low,
        "medium" => ThreatLevel::Medium,
        "high" => ThreatLevel::High,
        "critical" => ThreatLevel::Critical,
        _ => return Err(sqlx::Error::Decode("Invalid threat_level".into())),
    };

    // JSONB columns
    let threat_categories: Vec<ThreatCategory> = json_list(row, "threat_categories")?;
    let findings: Vec<ThreatFinding> = json_list(row, "findings")?;
    let remediation_suggestions: Vec<RemediationSuggestion> =
        json_list(row, "remediation_suggestions")?;

    Ok(ThreatAssessment {
        workspace_id,
        service_id,
        service_name,
        endpoint,
        method,
        aggregation_level,
        threat_level,
        threat_score,
        threat_categories,
        findings,
        remediation_suggestions,
        assessed_at,
    })
}
97
98/// State for threat modeling handlers
99#[derive(Clone)]
100pub struct ThreatModelingState {
101    /// Threat analyzer
102    pub analyzer: Arc<ThreatAnalyzer>,
103    /// Webhook configs for notifications (optional)
104    pub webhook_configs: Vec<crate::incidents::integrations::WebhookConfig>,
105    /// Database connection (optional)
106    #[cfg(feature = "database")]
107    pub database: Option<Database>,
108}
109
/// Get the most recent workspace-level threat assessment.
///
/// GET /api/v1/threats/workspace/{workspace_id}
///
/// # Errors
///
/// * `503` when no database pool is configured.
/// * `404` when `workspace_id` is not a UUID or no assessment exists for it.
///   Only UUID workspace ids are ever persisted; others are stored as NULL.
/// * `500` on query or row-decoding failures.
#[cfg(feature = "database")]
pub async fn get_workspace_threats(
    State(state): State<ThreatModelingState>,
    Path(workspace_id): Path<String>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = state
        .database
        .as_ref()
        .and_then(|db| db.pool())
        .ok_or(StatusCode::SERVICE_UNAVAILABLE)?;

    // Canonicalise to lowercase hyphenated form, which matches Postgres' `uuid::text` output.
    let workspace_key =
        Uuid::parse_str(&workspace_id).map_err(|_| StatusCode::NOT_FOUND)?.to_string();

    // The writer binds `Uuid` values into `workspace_id`. Binding a Rust `String`
    // against a `uuid` column fails (`operator does not exist: uuid = text`).
    // Comparing the column's text form works whether it is typed `uuid` or `text`.
    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE workspace_id::text = $1 AND aggregation_level = 'workspace'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(workspace_key)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query workspace threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
148
149/// Get workspace-level threat assessment (no database)
150///
151/// GET /api/v1/threats/workspace/{workspace_id}
152#[cfg(not(feature = "database"))]
153pub async fn get_workspace_threats(
154    State(_state): State<ThreatModelingState>,
155    Path(_workspace_id): Path<String>,
156) -> Result<Json<ThreatAssessment>, StatusCode> {
157    Err(StatusCode::SERVICE_UNAVAILABLE)
158}
159
/// Get the most recent service-level threat assessment.
///
/// GET /api/v1/threats/service/{service_id}
///
/// # Errors
///
/// * `503` when no database pool is configured.
/// * `404` when `service_id` is not a UUID or no assessment exists for it.
///   Only UUID service ids are ever persisted; others are stored as NULL.
/// * `500` on query or row-decoding failures.
#[cfg(feature = "database")]
pub async fn get_service_threats(
    State(state): State<ThreatModelingState>,
    Path(service_id): Path<String>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = state
        .database
        .as_ref()
        .and_then(|db| db.pool())
        .ok_or(StatusCode::SERVICE_UNAVAILABLE)?;

    // Canonicalise to lowercase hyphenated form, which matches Postgres' `uuid::text` output.
    let service_key =
        Uuid::parse_str(&service_id).map_err(|_| StatusCode::NOT_FOUND)?.to_string();

    // The writer binds `Uuid` values into `service_id`. Binding a Rust `String`
    // against a `uuid` column fails (`operator does not exist: uuid = text`).
    // Comparing the column's text form works whether it is typed `uuid` or `text`.
    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE service_id::text = $1 AND aggregation_level = 'service'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(service_key)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query service threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
197
198/// Get service-level threat assessment (no database)
199///
200/// GET /api/v1/threats/service/{service_id}
201#[cfg(not(feature = "database"))]
202pub async fn get_service_threats(
203    State(_state): State<ThreatModelingState>,
204    Path(_service_id): Path<String>,
205) -> Result<Json<ThreatAssessment>, StatusCode> {
206    Err(StatusCode::SERVICE_UNAVAILABLE)
207}
208
/// Get the most recent endpoint-level threat assessment.
///
/// GET /api/v1/threats/endpoint/{endpoint}
///
/// Query parameters:
/// * `method` (optional): restrict the lookup to assessments for this exact HTTP
///   method string. When omitted, assessments for any method match, including
///   those stored without one.
///
/// # Errors
///
/// * `503` when no database pool is configured.
/// * `404` when no matching assessment exists.
/// * `500` on query or row-decoding failures.
#[cfg(feature = "database")]
pub async fn get_endpoint_threats(
    State(state): State<ThreatModelingState>,
    Path(endpoint): Path<String>,
    Query(params): Query<serde_json::Value>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = state
        .database
        .as_ref()
        .and_then(|db| db.pool())
        .ok_or(StatusCode::SERVICE_UNAVAILABLE)?;

    // Compared for equality instead of being used as a LIKE pattern, so `%` / `_`
    // in user input are not wildcards. The explicit NULL check also lets rows
    // with a NULL `method` match when no filter is given.
    let method: Option<&str> = params.get("method").and_then(|v| v.as_str());

    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE endpoint = $1
           AND ($2::text IS NULL OR method = $2)
           AND aggregation_level = 'endpoint'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(&endpoint)
    .bind(method)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query endpoint threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
250
251/// Get endpoint-level threat assessment (no database)
252///
253/// GET /api/v1/threats/endpoint/{endpoint}
254#[cfg(not(feature = "database"))]
255pub async fn get_endpoint_threats(
256    State(_state): State<ThreatModelingState>,
257    Path(_endpoint): Path<String>,
258    Query(_params): Query<serde_json::Value>,
259) -> Result<Json<ThreatAssessment>, StatusCode> {
260    Err(StatusCode::SERVICE_UNAVAILABLE)
261}
262
263/// Request to trigger threat assessment
264#[derive(Debug, Deserialize)]
265pub struct AssessThreatsRequest {
266    /// OpenAPI spec (YAML/JSON)
267    pub spec: String,
268    /// Workspace ID
269    pub workspace_id: Option<String>,
270    /// Service ID
271    pub service_id: Option<String>,
272    /// Service name
273    pub service_name: Option<String>,
274    /// Endpoint (optional)
275    pub endpoint: Option<String>,
276    /// Method (optional)
277    pub method: Option<String>,
278}
279
280/// Trigger threat assessment
281///
282/// POST /api/v1/threats/assess
283pub async fn assess_threats(
284    State(state): State<ThreatModelingState>,
285    Json(request): Json<AssessThreatsRequest>,
286) -> Result<Json<ThreatAssessment>, StatusCode> {
287    // Parse OpenAPI spec
288    let spec = match OpenApiSpec::from_string(&request.spec, None) {
289        Ok(spec) => spec,
290        Err(_) => return Err(StatusCode::BAD_REQUEST),
291    };
292
293    // Run threat analysis
294    match state
295        .analyzer
296        .analyze_contract(
297            &spec,
298            request.workspace_id.clone(),
299            request.service_id.clone(),
300            request.service_name.clone(),
301            request.endpoint.clone(),
302            request.method.clone(),
303        )
304        .await
305    {
306        Ok(assessment) => {
307            // Store assessment in database
308            #[cfg(feature = "database")]
309            if let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) {
310                if let Err(e) = store_threat_assessment(pool, &assessment).await {
311                    tracing::warn!("Failed to store threat assessment: {}", e);
312                }
313            }
314
315            // Trigger webhook notifications
316            trigger_threat_assessment_webhooks(&state.webhook_configs, &assessment).await;
317
318            Ok(Json(assessment))
319        }
320        Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
321    }
322}
323
/// Convert an enum's `Debug` name into the lowercase snake_case spelling that the
/// readers in this module match on. For example, `PiiExposure` becomes
/// `pii_exposure`, `DoSRisk` becomes `dos_risk`, and `Low` becomes `low`.
///
/// An underscore is inserted before each uppercase letter (other than the first
/// character) that starts a capitalised word, i.e. one followed by a lowercase letter.
/// A capital followed by another capital, like the `S` in `DoSRisk`, joins the current word.
#[cfg_attr(not(feature = "database"), allow(dead_code))]
fn debug_name_to_snake_case(name: &str) -> String {
    let chars: Vec<char> = name.chars().collect();
    let mut snake = String::with_capacity(name.len() + 4);
    for (i, &c) in chars.iter().enumerate() {
        let starts_word = i > 0
            && c.is_uppercase()
            && matches!(chars.get(i + 1), Some(next) if next.is_lowercase());
        if starts_word {
            snake.push('_');
        }
        snake.extend(c.to_lowercase());
    }
    snake
}

/// Store threat assessment in database.
///
/// The assessment row is upserted on
/// `(workspace_id, service_id, endpoint, method, aggregation_level)`. Its rows in
/// `threat_findings` are replaced in the same transaction, so a failure leaves
/// no partially written assessment behind.
///
/// Enum-valued columns are written in lowercase snake_case (`workspace`, `low`,
/// `pii_exposure`, ...), the spelling that the read paths in this module query
/// and decode. `workspace_id` / `service_id` are stored only when they parse as
/// UUIDs; otherwise NULL.
///
/// # Errors
///
/// Returns any database error. The transaction is rolled back when it is dropped
/// without being committed.
#[cfg(feature = "database")]
async fn store_threat_assessment(
    pool: &sqlx::PgPool,
    assessment: &ThreatAssessment,
) -> Result<(), sqlx::Error> {
    let workspace_uuid = assessment.workspace_id.as_deref().and_then(|id| Uuid::parse_str(id).ok());
    let service_uuid = assessment.service_id.as_deref().and_then(|id| Uuid::parse_str(id).ok());
    let now = Utc::now();

    let mut tx = pool.begin().await?;

    // `RETURNING id` yields the id of the row actually stored. On conflict the
    // existing row keeps its original id, so the freshly generated one must not
    // be used as the findings' `assessment_id`.
    let assessment_id = sqlx::query_scalar::<_, Uuid>(
        r#"
        INSERT INTO contract_threat_assessments (
            id, workspace_id, service_id, service_name, endpoint, method, aggregation_level,
            threat_level, threat_score, threat_categories, findings, remediation_suggestions,
            assessed_at, last_updated, created_at
        ) VALUES (
            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
        )
        ON CONFLICT (workspace_id, service_id, endpoint, method, aggregation_level)
        DO UPDATE SET
            threat_level = EXCLUDED.threat_level,
            threat_score = EXCLUDED.threat_score,
            threat_categories = EXCLUDED.threat_categories,
            findings = EXCLUDED.findings,
            remediation_suggestions = EXCLUDED.remediation_suggestions,
            assessed_at = EXCLUDED.assessed_at,
            last_updated = EXCLUDED.last_updated
        RETURNING id
        "#,
    )
    .bind(Uuid::new_v4())
    .bind(workspace_uuid)
    .bind(service_uuid)
    .bind(assessment.service_name.as_deref())
    .bind(assessment.endpoint.as_deref())
    .bind(assessment.method.as_deref())
    .bind(debug_name_to_snake_case(&format!("{:?}", assessment.aggregation_level)))
    .bind(debug_name_to_snake_case(&format!("{:?}", assessment.threat_level)))
    .bind(assessment.threat_score)
    .bind(serde_json::to_value(&assessment.threat_categories).unwrap_or_default())
    .bind(serde_json::to_value(&assessment.findings).unwrap_or_default())
    .bind(serde_json::to_value(&assessment.remediation_suggestions).unwrap_or_default())
    .bind(assessment.assessed_at)
    .bind(now)
    .bind(assessment.assessed_at)
    .fetch_one(&mut *tx)
    .await?;

    // When the upsert hit an existing row, findings from its previous run are stale.
    sqlx::query("DELETE FROM threat_findings WHERE assessment_id = $1")
        .bind(assessment_id)
        .execute(&mut *tx)
        .await?;

    // Store individual findings
    for finding in &assessment.findings {
        sqlx::query(
            r#"
            INSERT INTO threat_findings (
                id, assessment_id, finding_type, severity, description, field_path,
                context, remediation_suggestion, remediation_code_example, confidence,
                ai_generated_remediation, detected_at, created_at, updated_at
            ) VALUES (
                $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14
            )
            "#,
        )
        .bind(Uuid::new_v4())
        .bind(assessment_id)
        .bind(debug_name_to_snake_case(&format!("{:?}", finding.finding_type)))
        .bind(debug_name_to_snake_case(&format!("{:?}", finding.severity)))
        .bind(&finding.description)
        .bind(finding.field_path.as_deref())
        .bind(serde_json::to_value(&finding.context).unwrap_or_default())
        .bind(None::<String>) // remediation_suggestion (not derived per finding here)
        .bind(None::<String>) // remediation_code_example
        .bind(finding.confidence)
        .bind(false) // ai_generated_remediation
        .bind(now)
        .bind(now)
        .bind(now)
        .execute(&mut *tx)
        .await?;
    }

    tx.commit().await
}
407
/// Lowercase `raw` and drop underscores, so `PiiExposure` and `pii_exposure`
/// fold to the same character sequence.
#[cfg_attr(not(feature = "database"), allow(dead_code))]
fn fold_enum_key(raw: &str) -> impl Iterator<Item = char> + '_ {
    raw.chars().filter(|&c| c != '_').flat_map(char::to_lowercase)
}

/// Map a stored `finding_type` onto its canonical snake_case name.
///
/// Accepts both the snake_case spelling and the enum's `Debug` (PascalCase)
/// spelling. Returns `None` for values that are not a known threat category.
#[cfg_attr(not(feature = "database"), allow(dead_code))]
fn canonical_finding_type(raw: &str) -> Option<&'static str> {
    const KNOWN: [&str; 11] = [
        "pii_exposure",
        "dos_risk",
        "error_leakage",
        "schema_inconsistency",
        "unbounded_arrays",
        "missing_rate_limits",
        "stack_trace_leakage",
        "sensitive_data_exposure",
        "insecure_schema_design",
        "missing_validation",
        "excessive_optional_fields",
    ];
    KNOWN.iter().copied().find(|known| fold_enum_key(known).eq(fold_enum_key(raw)))
}

/// Map a stored `severity` onto its canonical lowercase name (case-insensitive).
/// Returns `None` for unknown severities.
#[cfg_attr(not(feature = "database"), allow(dead_code))]
fn canonical_severity(raw: &str) -> Option<&'static str> {
    const KNOWN: [&str; 4] = ["low", "medium", "high", "critical"];
    KNOWN.iter().copied().find(|known| fold_enum_key(known).eq(fold_enum_key(raw)))
}

/// List the 100 most recently detected threat findings.
///
/// GET /api/v1/threats/findings
///
/// Rows whose `finding_type` or `severity` is not recognised are skipped.
/// Recognised values are returned in canonical snake_case, whether they were
/// stored in snake_case or in the enum's `Debug` spelling. Without a database
/// pool an empty list is returned.
///
/// # Errors
///
/// `500` when the query fails or a row cannot be decoded.
#[cfg(feature = "database")]
pub async fn list_findings(
    State(state): State<ThreatModelingState>,
    Query(_params): Query<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use sqlx::Row;
    use std::collections::HashMap;

    let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) else {
        return Ok(Json(serde_json::json!({
            "findings": [],
            "total": 0
        })));
    };

    let rows = sqlx::query(
        "SELECT tf.*, ta.workspace_id, ta.service_id, ta.endpoint, ta.method
         FROM threat_findings tf
         JOIN contract_threat_assessments ta ON tf.assessment_id = ta.id
         ORDER BY tf.detected_at DESC LIMIT 100",
    )
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query threat findings: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Builds a `map_err` adapter that logs which column failed to decode.
    let column_err = |label: &'static str| {
        move |e: sqlx::Error| {
            tracing::error!("Failed to get {} from row: {}", label, e);
            StatusCode::INTERNAL_SERVER_ERROR
        }
    };

    let mut findings = Vec::with_capacity(rows.len());
    for row in rows {
        let finding_id: Uuid = row.try_get("id").map_err(column_err("finding id"))?;
        let finding_type_str: String =
            row.try_get("finding_type").map_err(column_err("finding_type"))?;
        let severity_str: String = row.try_get("severity").map_err(column_err("severity"))?;
        let description: String =
            row.try_get("description").map_err(column_err("description"))?;
        let field_path: Option<String> =
            row.try_get("field_path").map_err(column_err("field_path"))?;
        let context_json: serde_json::Value =
            row.try_get("context").map_err(column_err("context"))?;
        let confidence: f64 = row.try_get("confidence").map_err(column_err("confidence"))?;

        // Skip rows with an unknown finding type or severity.
        let (Some(finding_type), Some(severity)) =
            (canonical_finding_type(&finding_type_str), canonical_severity(&severity_str))
        else {
            continue;
        };

        let context: HashMap<String, serde_json::Value> =
            serde_json::from_value(context_json).unwrap_or_default();

        findings.push(serde_json::json!({
            "id": finding_id.to_string(),
            "finding_type": finding_type,
            "severity": severity,
            "description": description,
            "field_path": field_path,
            "context": context,
            "confidence": confidence,
        }));
    }

    Ok(Json(serde_json::json!({
        "findings": findings,
        "total": findings.len()
    })))
}
516
517/// List threat findings (no database)
518///
519/// GET /api/v1/threats/findings
520#[cfg(not(feature = "database"))]
521pub async fn list_findings(
522    State(_state): State<ThreatModelingState>,
523    Query(_params): Query<serde_json::Value>,
524) -> Result<Json<serde_json::Value>, StatusCode> {
525    Ok(Json(serde_json::json!({
526        "findings": []
527    })))
528}
529
/// Get remediation suggestions from the 50 most recent assessments that have any.
///
/// GET /api/v1/threats/remediations
///
/// Each assessment's `remediation_suggestions` array is flattened into a single
/// list. Without a database pool an empty list is returned.
///
/// # Errors
///
/// `500` when the query fails or a row cannot be decoded.
#[cfg(feature = "database")]
pub async fn get_remediations(
    State(state): State<ThreatModelingState>,
    Query(_params): Query<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use sqlx::Row;

    let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) else {
        return Ok(Json(serde_json::json!({
            "remediations": [],
            "total": 0
        })));
    };

    // `jsonb_array_length` raises an error on scalar JSONB, such as a stored JSON
    // `null`, which would fail the whole query. Checking `jsonb_typeof` and
    // comparing against `'[]'` selects non-empty arrays without calling it.
    let rows = sqlx::query(
        "SELECT remediation_suggestions FROM contract_threat_assessments
         WHERE jsonb_typeof(remediation_suggestions) = 'array'
           AND remediation_suggestions <> '[]'::jsonb
         ORDER BY assessed_at DESC LIMIT 50",
    )
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query remediations: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Extract and flatten remediation suggestions
    let mut remediations = Vec::new();
    for row in rows {
        let suggestions: serde_json::Value =
            row.try_get("remediation_suggestions").map_err(|e| {
                tracing::error!("Failed to get remediation_suggestions from row: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?;
        if let serde_json::Value::Array(items) = suggestions {
            remediations.extend(items);
        }
    }

    Ok(Json(serde_json::json!({
        "remediations": remediations,
        "total": remediations.len()
    })))
}
581
582/// Get remediation suggestions (no database)
583///
584/// GET /api/v1/threats/remediations
585#[cfg(not(feature = "database"))]
586pub async fn get_remediations(
587    State(_state): State<ThreatModelingState>,
588    Query(_params): Query<serde_json::Value>,
589) -> Result<Json<serde_json::Value>, StatusCode> {
590    Ok(Json(serde_json::json!({
591        "remediations": []
592    })))
593}
594
595/// Trigger webhook notifications for threat assessment
596async fn trigger_threat_assessment_webhooks(
597    webhook_configs: &[crate::incidents::integrations::WebhookConfig],
598    assessment: &ThreatAssessment,
599) {
600    use crate::incidents::integrations::send_webhook;
601    use serde_json::json;
602
603    for webhook in webhook_configs {
604        if !webhook.enabled {
605            continue;
606        }
607
608        let event_type = "threat.assessment.completed";
609        if !webhook.events.is_empty() && !webhook.events.contains(&event_type.to_string()) {
610            continue;
611        }
612
613        let payload = json!({
614            "event": event_type,
615            "assessment": {
616                "workspace_id": assessment.workspace_id,
617                "service_id": assessment.service_id,
618                "service_name": assessment.service_name,
619                "endpoint": assessment.endpoint,
620                "method": assessment.method,
621                "threat_level": format!("{:?}", assessment.threat_level),
622                "threat_score": assessment.threat_score,
623                "findings_count": assessment.findings.len(),
624                "assessed_at": assessment.assessed_at,
625            }
626        });
627
628        let webhook_clone = webhook.clone();
629        tokio::spawn(async move {
630            if let Err(e) = send_webhook(&webhook_clone, &payload).await {
631                tracing::warn!("Failed to send threat assessment webhook: {}", e);
632            }
633        });
634    }
635}
636
637/// Create router for threat modeling endpoints
638pub fn threat_modeling_router(state: ThreatModelingState) -> axum::Router {
639    use axum::routing::{get, post};
640    use axum::Router;
641
642    Router::new()
643        .route("/api/v1/threats/workspace/{workspace_id}", get(get_workspace_threats))
644        .route("/api/v1/threats/service/{service_id}", get(get_service_threats))
645        .route("/api/v1/threats/endpoint/{endpoint}", get(get_endpoint_threats))
646        .route("/api/v1/threats/assess", post(assess_threats))
647        .route("/api/v1/threats/findings", get(list_findings))
648        .route("/api/v1/threats/remediations", get(get_remediations))
649        .with_state(state)
650}