1use axum::{
6 extract::{Path, Query, State},
7 http::StatusCode,
8 response::Json,
9};
10use chrono::{DateTime, Utc};
11use mockforge_core::contract_drift::threat_modeling::{ThreatAnalyzer, ThreatAssessment};
12use mockforge_core::openapi::OpenApiSpec;
13use serde::Deserialize;
14use std::sync::Arc;
15use uuid::Uuid;
16
17use crate::database::Database;
18
#[cfg(feature = "database")]
/// Converts a `contract_threat_assessments` row into a [`ThreatAssessment`].
///
/// `aggregation_level` and `threat_level` must hold the lowercase labels used throughout this
/// module (`"workspace"`, `"service"`, `"endpoint"`; `"low"` .. `"critical"`). The JSONB list
/// columns are decoded best-effort: a malformed value is logged and treated as an empty list
/// rather than failing the whole row.
///
/// # Errors
/// Returns a [`sqlx::Error`] if a column is missing or has an incompatible type, or
/// [`sqlx::Error::Decode`] when `aggregation_level` / `threat_level` hold an unknown label.
fn map_row_to_threat_assessment(
    row: &sqlx::postgres::PgRow,
) -> Result<ThreatAssessment, sqlx::Error> {
    use mockforge_core::contract_drift::threat_modeling::{
        AggregationLevel, RemediationSuggestion, ThreatCategory, ThreatFinding, ThreatLevel,
    };
    // `try_get` is provided by the `Row` trait, which must be in scope to call it.
    use sqlx::Row;

    /// Deserializes a JSONB list column, logging and defaulting to empty on malformed data.
    fn json_list<T: serde::de::DeserializeOwned>(
        value: serde_json::Value,
        column: &str,
    ) -> Vec<T> {
        serde_json::from_value(value).unwrap_or_else(|e| {
            tracing::warn!("Ignoring malformed {} in threat assessment row: {}", column, e);
            Vec::new()
        })
    }

    let workspace_id: Option<Uuid> = row.try_get("workspace_id")?;
    // `store_threat_assessment` binds service_id as a UUID, which cannot be decoded as a
    // `String`; accept either representation so both UUID and text columns work.
    let service_id: Option<String> = match row.try_get::<Option<Uuid>, _>("service_id") {
        Ok(id) => id.map(|u| u.to_string()),
        Err(_) => row.try_get("service_id")?,
    };
    let service_name: Option<String> = row.try_get("service_name")?;
    let endpoint: Option<String> = row.try_get("endpoint")?;
    let method: Option<String> = row.try_get("method")?;
    let aggregation_level: String = row.try_get("aggregation_level")?;
    let threat_level: String = row.try_get("threat_level")?;
    let threat_score: f64 = row.try_get("threat_score")?;
    let assessed_at: DateTime<Utc> = row.try_get("assessed_at")?;

    let aggregation_level = match aggregation_level.as_str() {
        "workspace" => AggregationLevel::Workspace,
        "service" => AggregationLevel::Service,
        "endpoint" => AggregationLevel::Endpoint,
        other => {
            return Err(sqlx::Error::Decode(
                format!("Invalid aggregation_level: {}", other).into(),
            ))
        }
    };

    let threat_level = match threat_level.as_str() {
        "low" => ThreatLevel::Low,
        "medium" => ThreatLevel::Medium,
        "high" => ThreatLevel::High,
        "critical" => ThreatLevel::Critical,
        other => {
            return Err(sqlx::Error::Decode(format!("Invalid threat_level: {}", other).into()))
        }
    };

    let threat_categories: Vec<ThreatCategory> =
        json_list(row.try_get("threat_categories")?, "threat_categories");
    let findings: Vec<ThreatFinding> = json_list(row.try_get("findings")?, "findings");
    let remediation_suggestions: Vec<RemediationSuggestion> =
        json_list(row.try_get("remediation_suggestions")?, "remediation_suggestions");

    Ok(ThreatAssessment {
        workspace_id: workspace_id.map(|u| u.to_string()),
        service_id,
        service_name,
        endpoint,
        method,
        aggregation_level,
        threat_level,
        threat_score,
        threat_categories,
        findings,
        remediation_suggestions,
        assessed_at,
    })
}
83
/// Shared state handed to every threat-modeling HTTP handler.
#[derive(Clone)]
pub struct ThreatModelingState {
    /// Analyzer used by `assess_threats` to turn an OpenAPI spec into a `ThreatAssessment`.
    pub analyzer: Arc<ThreatAnalyzer>,
    /// Webhooks notified after each completed assessment (only those that are enabled and
    /// either subscribe to no specific events or to `threat.assessment.completed`).
    pub webhook_configs: Vec<mockforge_core::incidents::integrations::WebhookConfig>,
    /// Optional database. When absent (or when it exposes no pool) the lookup handlers
    /// return `503 Service Unavailable`, the listing handlers return empty lists, and new
    /// assessments are not persisted.
    pub database: Option<Database>,
}
94
95#[cfg(feature = "database")]
99pub async fn get_workspace_threats(
100 State(state): State<ThreatModelingState>,
101 Path(workspace_id): Path<String>,
102) -> Result<Json<ThreatAssessment>, StatusCode> {
103 let pool = match state.database.as_ref().and_then(|db| db.pool()) {
104 Some(pool) => pool,
105 None => return Err(StatusCode::SERVICE_UNAVAILABLE),
106 };
107
108 let row = sqlx::query(
110 "SELECT * FROM contract_threat_assessments
111 WHERE workspace_id = $1 AND aggregation_level = 'workspace'
112 ORDER BY assessed_at DESC LIMIT 1",
113 )
114 .bind(&workspace_id)
115 .fetch_optional(pool)
116 .await
117 .map_err(|e| {
118 tracing::error!("Failed to query workspace threats: {}", e);
119 StatusCode::INTERNAL_SERVER_ERROR
120 })?;
121
122 match row {
123 Some(row) => match map_row_to_threat_assessment(&row) {
124 Ok(assessment) => Ok(Json(assessment)),
125 Err(e) => {
126 tracing::error!("Failed to map threat assessment: {}", e);
127 Err(StatusCode::INTERNAL_SERVER_ERROR)
128 }
129 },
130 None => Err(StatusCode::NOT_FOUND),
131 }
132}
133
134#[cfg(not(feature = "database"))]
138pub async fn get_workspace_threats(
139 State(_state): State<ThreatModelingState>,
140 Path(_workspace_id): Path<String>,
141) -> Result<Json<ThreatAssessment>, StatusCode> {
142 Err(StatusCode::SERVICE_UNAVAILABLE)
143}
144
145#[cfg(feature = "database")]
149pub async fn get_service_threats(
150 State(state): State<ThreatModelingState>,
151 Path(service_id): Path<String>,
152) -> Result<Json<ThreatAssessment>, StatusCode> {
153 let pool = match state.database.as_ref().and_then(|db| db.pool()) {
154 Some(pool) => pool,
155 None => return Err(StatusCode::SERVICE_UNAVAILABLE),
156 };
157
158 let row = sqlx::query(
159 "SELECT * FROM contract_threat_assessments
160 WHERE service_id = $1 AND aggregation_level = 'service'
161 ORDER BY assessed_at DESC LIMIT 1",
162 )
163 .bind(&service_id)
164 .fetch_optional(pool)
165 .await
166 .map_err(|e| {
167 tracing::error!("Failed to query service threats: {}", e);
168 StatusCode::INTERNAL_SERVER_ERROR
169 })?;
170
171 match row {
172 Some(row) => match map_row_to_threat_assessment(&row) {
173 Ok(assessment) => Ok(Json(assessment)),
174 Err(e) => {
175 tracing::error!("Failed to map threat assessment: {}", e);
176 Err(StatusCode::INTERNAL_SERVER_ERROR)
177 }
178 },
179 None => Err(StatusCode::NOT_FOUND),
180 }
181}
182
183#[cfg(not(feature = "database"))]
187pub async fn get_service_threats(
188 State(_state): State<ThreatModelingState>,
189 Path(_service_id): Path<String>,
190) -> Result<Json<ThreatAssessment>, StatusCode> {
191 Err(StatusCode::SERVICE_UNAVAILABLE)
192}
193
194#[cfg(feature = "database")]
198pub async fn get_endpoint_threats(
199 State(state): State<ThreatModelingState>,
200 Path(endpoint): Path<String>,
201 Query(params): Query<serde_json::Value>,
202) -> Result<Json<ThreatAssessment>, StatusCode> {
203 let pool = match state.database.as_ref().and_then(|db| db.pool()) {
204 Some(pool) => pool,
205 None => return Err(StatusCode::SERVICE_UNAVAILABLE),
206 };
207
208 let method = params.get("method").and_then(|v| v.as_str()).unwrap_or("%");
209
210 let row = sqlx::query(
211 "SELECT * FROM contract_threat_assessments
212 WHERE endpoint = $1 AND method LIKE $2 AND aggregation_level = 'endpoint'
213 ORDER BY assessed_at DESC LIMIT 1",
214 )
215 .bind(&endpoint)
216 .bind(method)
217 .fetch_optional(pool)
218 .await
219 .map_err(|e| {
220 tracing::error!("Failed to query endpoint threats: {}", e);
221 StatusCode::INTERNAL_SERVER_ERROR
222 })?;
223
224 match row {
225 Some(row) => match map_row_to_threat_assessment(&row) {
226 Ok(assessment) => Ok(Json(assessment)),
227 Err(e) => {
228 tracing::error!("Failed to map threat assessment: {}", e);
229 Err(StatusCode::INTERNAL_SERVER_ERROR)
230 }
231 },
232 None => Err(StatusCode::NOT_FOUND),
233 }
234}
235
236#[cfg(not(feature = "database"))]
240pub async fn get_endpoint_threats(
241 State(_state): State<ThreatModelingState>,
242 Path(_endpoint): Path<String>,
243 Query(_params): Query<serde_json::Value>,
244) -> Result<Json<ThreatAssessment>, StatusCode> {
245 Err(StatusCode::SERVICE_UNAVAILABLE)
246}
247
/// JSON request body accepted by `assess_threats`.
#[derive(Debug, Deserialize)]
pub struct AssessThreatsRequest {
    /// OpenAPI specification source text, parsed with `OpenApiSpec::from_string`; a spec
    /// that fails to parse yields `400 Bad Request`.
    pub spec: String,
    /// Optional workspace identifier forwarded to the analyzer.
    pub workspace_id: Option<String>,
    /// Optional service identifier forwarded to the analyzer.
    pub service_id: Option<String>,
    /// Optional human-readable service name forwarded to the analyzer.
    pub service_name: Option<String>,
    /// Optional endpoint path forwarded to the analyzer.
    pub endpoint: Option<String>,
    /// Optional HTTP method forwarded to the analyzer.
    pub method: Option<String>,
}
264
265pub async fn assess_threats(
269 State(state): State<ThreatModelingState>,
270 Json(request): Json<AssessThreatsRequest>,
271) -> Result<Json<ThreatAssessment>, StatusCode> {
272 let spec = match OpenApiSpec::from_string(&request.spec, None) {
274 Ok(spec) => spec,
275 Err(_) => return Err(StatusCode::BAD_REQUEST),
276 };
277
278 match state
280 .analyzer
281 .analyze_contract(
282 &spec,
283 request.workspace_id.clone(),
284 request.service_id.clone(),
285 request.service_name.clone(),
286 request.endpoint.clone(),
287 request.method.clone(),
288 )
289 .await
290 {
291 Ok(assessment) => {
292 #[cfg(feature = "database")]
294 if let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) {
295 if let Err(e) = store_threat_assessment(pool, &assessment).await {
296 tracing::warn!("Failed to store threat assessment: {}", e);
297 }
298 }
299
300 trigger_threat_assessment_webhooks(&state.webhook_configs, &assessment).await;
302
303 Ok(Json(assessment))
304 }
305 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
306 }
307}
308
/// Converts the `Debug` name of a unit enum variant (e.g. `PiiExposure`) into the lowercase
/// snake_case label that this module's read paths expect (e.g. `pii_exposure`).
///
/// `DoSRisk` is special-cased to `dos_risk` (the label `list_findings` recognises), because a
/// mechanical conversion would produce `do_s_risk`.
#[cfg_attr(not(feature = "database"), allow(dead_code))]
fn db_enum_label(variant: &str) -> String {
    if variant == "DoSRisk" {
        return "dos_risk".to_string();
    }
    let mut label = String::with_capacity(variant.len() + 4);
    for (i, ch) in variant.chars().enumerate() {
        if ch.is_ascii_uppercase() {
            if i > 0 {
                label.push('_');
            }
            label.push(ch.to_ascii_lowercase());
        } else {
            label.push(ch);
        }
    }
    label
}

#[cfg(feature = "database")]
/// Persists `assessment` and its individual findings.
///
/// The assessment row is upserted on
/// `(workspace_id, service_id, endpoint, method, aggregation_level)`. Its findings are then
/// inserted into `threat_findings`, linked to the id of the row that actually exists after
/// the upsert. Enum columns are written as lowercase snake_case labels (`"workspace"`, `"high"`,
/// `"pii_exposure"`), which match what the query handlers filter and decode on. Workspace and
/// service identifiers that do not parse as UUIDs are stored as NULL.
///
/// All statements run in one transaction, so a failure leaves nothing half-written.
///
/// # Errors
/// Propagates any [`sqlx::Error`] from beginning, executing or committing the transaction.
async fn store_threat_assessment(
    pool: &sqlx::PgPool,
    assessment: &ThreatAssessment,
) -> Result<(), sqlx::Error> {
    let workspace_uuid =
        assessment.workspace_id.as_deref().and_then(|id| Uuid::parse_str(id).ok());
    let service_uuid = assessment.service_id.as_deref().and_then(|id| Uuid::parse_str(id).ok());
    let now = Utc::now();

    // Dropping `tx` without committing (e.g. via `?`) rolls everything back.
    let mut tx = pool.begin().await?;

    // On conflict the existing row keeps its original `id`, so the id to attach findings to
    // must come from `RETURNING`, not from the freshly generated candidate.
    let assessment_id = sqlx::query_scalar::<_, Uuid>(
        r#"
        INSERT INTO contract_threat_assessments (
            id, workspace_id, service_id, service_name, endpoint, method, aggregation_level,
            threat_level, threat_score, threat_categories, findings, remediation_suggestions,
            assessed_at, last_updated, created_at
        ) VALUES (
            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
        )
        ON CONFLICT (workspace_id, service_id, endpoint, method, aggregation_level)
        DO UPDATE SET
            threat_level = EXCLUDED.threat_level,
            threat_score = EXCLUDED.threat_score,
            threat_categories = EXCLUDED.threat_categories,
            findings = EXCLUDED.findings,
            remediation_suggestions = EXCLUDED.remediation_suggestions,
            assessed_at = EXCLUDED.assessed_at,
            last_updated = EXCLUDED.last_updated
        RETURNING id
        "#,
    )
    .bind(Uuid::new_v4())
    .bind(workspace_uuid)
    .bind(service_uuid)
    .bind(assessment.service_name.as_deref())
    .bind(assessment.endpoint.as_deref())
    .bind(assessment.method.as_deref())
    .bind(db_enum_label(&format!("{:?}", assessment.aggregation_level)))
    .bind(db_enum_label(&format!("{:?}", assessment.threat_level)))
    .bind(assessment.threat_score)
    .bind(serde_json::to_value(&assessment.threat_categories).unwrap_or_default())
    .bind(serde_json::to_value(&assessment.findings).unwrap_or_default())
    .bind(serde_json::to_value(&assessment.remediation_suggestions).unwrap_or_default())
    .bind(assessment.assessed_at)
    .bind(now)
    .bind(assessment.assessed_at)
    .fetch_one(&mut *tx)
    .await?;

    for finding in &assessment.findings {
        sqlx::query(
            r#"
            INSERT INTO threat_findings (
                id, assessment_id, finding_type, severity, description, field_path,
                context, remediation_suggestion, remediation_code_example, confidence,
                ai_generated_remediation, detected_at, created_at, updated_at
            ) VALUES (
                $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14
            )
            "#,
        )
        .bind(Uuid::new_v4())
        .bind(assessment_id)
        .bind(db_enum_label(&format!("{:?}", finding.finding_type)))
        .bind(db_enum_label(&format!("{:?}", finding.severity)))
        .bind(&finding.description)
        .bind(finding.field_path.as_deref())
        .bind(serde_json::to_value(&finding.context).unwrap_or_default())
        .bind(None::<String>) // remediation_suggestion
        .bind(None::<String>) // remediation_code_example
        .bind(finding.confidence)
        .bind(false) // ai_generated_remediation
        .bind(now) // detected_at
        .bind(now) // created_at
        .bind(now) // updated_at
        .execute(&mut *tx)
        .await?;
    }

    tx.commit().await
}
392
393#[cfg(feature = "database")]
397pub async fn list_findings(
398 State(state): State<ThreatModelingState>,
399 Query(_params): Query<serde_json::Value>,
400) -> Result<Json<serde_json::Value>, StatusCode> {
401 let pool = match state.database.as_ref().and_then(|db| db.pool()) {
402 Some(pool) => pool,
403 None => {
404 return Ok(Json(serde_json::json!({
405 "findings": []
406 })));
407 }
408 };
409
410 let rows = sqlx::query(
411 "SELECT tf.*, ta.workspace_id, ta.service_id, ta.endpoint, ta.method
412 FROM threat_findings tf
413 JOIN contract_threat_assessments ta ON tf.assessment_id = ta.id
414 ORDER BY tf.detected_at DESC LIMIT 100",
415 )
416 .fetch_all(pool)
417 .await
418 .map_err(|e| {
419 tracing::error!("Failed to query threat findings: {}", e);
420 StatusCode::INTERNAL_SERVER_ERROR
421 })?;
422
423 use sqlx::Row;
425 let mut findings = Vec::new();
426 for row in rows {
427 let finding_id: uuid::Uuid = row.try_get("id").map_err(|e| {
428 tracing::error!("Failed to get finding id from row: {}", e);
429 StatusCode::INTERNAL_SERVER_ERROR
430 })?;
431 let finding_type_str: String = row.try_get("finding_type").map_err(|e| {
432 tracing::error!("Failed to get finding_type from row: {}", e);
433 StatusCode::INTERNAL_SERVER_ERROR
434 })?;
435 let severity_str: String = row.try_get("severity").map_err(|e| {
436 tracing::error!("Failed to get severity from row: {}", e);
437 StatusCode::INTERNAL_SERVER_ERROR
438 })?;
439 let description: String = row.try_get("description").map_err(|e| {
440 tracing::error!("Failed to get description from row: {}", e);
441 StatusCode::INTERNAL_SERVER_ERROR
442 })?;
443 let field_path: Option<String> = row.try_get("field_path").map_err(|e| {
444 tracing::error!("Failed to get field_path from row: {}", e);
445 StatusCode::INTERNAL_SERVER_ERROR
446 })?;
447 let context_json: serde_json::Value = row.try_get("context").map_err(|e| {
448 tracing::error!("Failed to get context from row: {}", e);
449 StatusCode::INTERNAL_SERVER_ERROR
450 })?;
451 let confidence: f64 = row.try_get("confidence").map_err(|e| {
452 tracing::error!("Failed to get confidence from row: {}", e);
453 StatusCode::INTERNAL_SERVER_ERROR
454 })?;
455
456 use mockforge_core::contract_drift::threat_modeling::{ThreatCategory, ThreatLevel};
457 use std::collections::HashMap;
458
459 let finding_type = match finding_type_str.as_str() {
460 "pii_exposure" => ThreatCategory::PiiExposure,
461 "dos_risk" => ThreatCategory::DoSRisk,
462 "error_leakage" => ThreatCategory::ErrorLeakage,
463 "schema_inconsistency" => ThreatCategory::SchemaInconsistency,
464 "unbounded_arrays" => ThreatCategory::UnboundedArrays,
465 "missing_rate_limits" => ThreatCategory::MissingRateLimits,
466 "stack_trace_leakage" => ThreatCategory::StackTraceLeakage,
467 "sensitive_data_exposure" => ThreatCategory::SensitiveDataExposure,
468 "insecure_schema_design" => ThreatCategory::InsecureSchemaDesign,
469 "missing_validation" => ThreatCategory::MissingValidation,
470 "excessive_optional_fields" => ThreatCategory::ExcessiveOptionalFields,
471 _ => continue, };
473
474 let severity = match severity_str.as_str() {
475 "low" => ThreatLevel::Low,
476 "medium" => ThreatLevel::Medium,
477 "high" => ThreatLevel::High,
478 "critical" => ThreatLevel::Critical,
479 _ => continue, };
481
482 let context: HashMap<String, serde_json::Value> =
483 serde_json::from_value(context_json).unwrap_or_default();
484
485 findings.push(serde_json::json!({
486 "id": finding_id.to_string(),
487 "finding_type": finding_type_str,
488 "severity": severity_str,
489 "description": description,
490 "field_path": field_path,
491 "context": context,
492 "confidence": confidence,
493 }));
494 }
495
496 Ok(Json(serde_json::json!({
497 "findings": findings,
498 "total": findings.len()
499 })))
500}
501
502#[cfg(not(feature = "database"))]
506pub async fn list_findings(
507 State(_state): State<ThreatModelingState>,
508 Query(_params): Query<serde_json::Value>,
509) -> Result<Json<serde_json::Value>, StatusCode> {
510 Ok(Json(serde_json::json!({
511 "findings": []
512 })))
513}
514
515#[cfg(feature = "database")]
519pub async fn get_remediations(
520 State(state): State<ThreatModelingState>,
521 Query(_params): Query<serde_json::Value>,
522) -> Result<Json<serde_json::Value>, StatusCode> {
523 let pool = match state.database.as_ref().and_then(|db| db.pool()) {
524 Some(pool) => pool,
525 None => {
526 return Ok(Json(serde_json::json!({
527 "remediations": []
528 })));
529 }
530 };
531
532 let rows = sqlx::query(
534 "SELECT remediation_suggestions FROM contract_threat_assessments
535 WHERE remediation_suggestions IS NOT NULL AND jsonb_array_length(remediation_suggestions) > 0
536 ORDER BY assessed_at DESC LIMIT 50",
537 )
538 .fetch_all(pool)
539 .await
540 .map_err(|e| {
541 tracing::error!("Failed to query remediations: {}", e);
542 StatusCode::INTERNAL_SERVER_ERROR
543 })?;
544
545 use sqlx::Row;
547 let mut remediations = Vec::new();
548 for row in rows {
549 let remediations_json: serde_json::Value = row.try_get("remediation_suggestions").map_err(|e| {
550 tracing::error!("Failed to get remediation_suggestions from row: {}", e);
551 StatusCode::INTERNAL_SERVER_ERROR
552 })?;
553 if let serde_json::Value::Array(remediation_array) = remediations_json {
554 for remediation in remediation_array {
555 remediations.push(remediation);
556 }
557 }
558 }
559
560 Ok(Json(serde_json::json!({
561 "remediations": remediations,
562 "total": remediations.len()
563 })))
564}
565
566#[cfg(not(feature = "database"))]
570pub async fn get_remediations(
571 State(_state): State<ThreatModelingState>,
572 Query(_params): Query<serde_json::Value>,
573) -> Result<Json<serde_json::Value>, StatusCode> {
574 Ok(Json(serde_json::json!({
575 "remediations": []
576 })))
577}
578
579async fn trigger_threat_assessment_webhooks(
581 webhook_configs: &[mockforge_core::incidents::integrations::WebhookConfig],
582 assessment: &ThreatAssessment,
583) {
584 use mockforge_core::incidents::integrations::send_webhook;
585 use serde_json::json;
586
587 for webhook in webhook_configs {
588 if !webhook.enabled {
589 continue;
590 }
591
592 let event_type = "threat.assessment.completed";
593 if !webhook.events.is_empty() && !webhook.events.contains(&event_type.to_string()) {
594 continue;
595 }
596
597 let payload = json!({
598 "event": event_type,
599 "assessment": {
600 "workspace_id": assessment.workspace_id,
601 "service_id": assessment.service_id,
602 "service_name": assessment.service_name,
603 "endpoint": assessment.endpoint,
604 "method": assessment.method,
605 "threat_level": format!("{:?}", assessment.threat_level),
606 "threat_score": assessment.threat_score,
607 "findings_count": assessment.findings.len(),
608 "assessed_at": assessment.assessed_at,
609 }
610 });
611
612 let webhook_clone = webhook.clone();
613 tokio::spawn(async move {
614 if let Err(e) = send_webhook(&webhook_clone, &payload).await {
615 tracing::warn!("Failed to send threat assessment webhook: {}", e);
616 }
617 });
618 }
619}
620
621pub fn threat_modeling_router(state: ThreatModelingState) -> axum::Router {
623 use axum::routing::{get, post};
624 use axum::Router;
625
626 Router::new()
627 .route("/api/v1/threats/workspace/{workspace_id}", get(get_workspace_threats))
628 .route("/api/v1/threats/service/{service_id}", get(get_service_threats))
629 .route("/api/v1/threats/endpoint/{endpoint}", get(get_endpoint_threats))
630 .route("/api/v1/threats/assess", post(assess_threats))
631 .route("/api/v1/threats/findings", get(list_findings))
632 .route("/api/v1/threats/remediations", get(get_remediations))
633 .with_state(state)
634}