1#![allow(deprecated)]
12
13use crate::threat_modeling::ThreatAnalyzer;
14use axum::{
15 extract::{Path, Query, State},
16 http::StatusCode,
17 response::Json,
18};
19use mockforge_foundation::threat_modeling_types::ThreatAssessment;
20use mockforge_openapi::OpenApiSpec;
21use serde::Deserialize;
22use std::sync::Arc;
23
24#[cfg(feature = "database")]
25use chrono::{DateTime, Utc};
26#[cfg(feature = "database")]
27use uuid::Uuid;
28
29#[cfg(feature = "database")]
30use crate::database::Database;
31
/// Converts a `contract_threat_assessments` row into a [`ThreatAssessment`].
///
/// The `aggregation_level` and `threat_level` text columns are matched case-insensitively.
/// That way, rows written with `Debug`-style casing (`Workspace`, `High`) decode just like
/// the lowercase tokens used by the SQL filters in this module.
///
/// `service_id` is accepted either as a text column or as a UUID column. The insert path
/// binds it as a `Uuid`, while this reader historically decoded it as text.
///
/// JSONB list columns that cannot be deserialized are logged and replaced by an empty list
/// instead of failing the whole row. A JSON `null` quietly becomes an empty list.
///
/// # Errors
/// Returns the underlying `sqlx::Error` when a required column is missing or has an
/// incompatible type. Returns `sqlx::Error::Decode` for unknown enum tokens.
#[cfg(feature = "database")]
fn map_row_to_threat_assessment(
    row: &sqlx::postgres::PgRow,
) -> Result<ThreatAssessment, sqlx::Error> {
    use mockforge_foundation::threat_modeling_types::{
        AggregationLevel, RemediationSuggestion, ThreatCategory, ThreatFinding, ThreatLevel,
    };
    use sqlx::Row;

    /// Deserializes a JSONB array column, logging (rather than hiding) malformed payloads.
    fn decode_list<T: serde::de::DeserializeOwned>(
        column: &str,
        value: serde_json::Value,
    ) -> Vec<T> {
        if value.is_null() {
            return Vec::new();
        }
        serde_json::from_value(value).unwrap_or_else(|e| {
            tracing::warn!("Failed to decode {} of threat assessment: {}", column, e);
            Vec::new()
        })
    }

    let workspace_id: Option<Uuid> = row.try_get("workspace_id")?;
    // Try text first (the historical reading); fall back to a UUID column.
    let service_id: Option<String> = match row.try_get::<Option<String>, _>("service_id") {
        Ok(value) => value,
        Err(_) => row
            .try_get::<Option<Uuid>, _>("service_id")?
            .map(|id| id.to_string()),
    };
    let service_name: Option<String> = row.try_get("service_name")?;
    let endpoint: Option<String> = row.try_get("endpoint")?;
    let method: Option<String> = row.try_get("method")?;
    let aggregation_level_str: String = row.try_get("aggregation_level")?;
    let threat_level_str: String = row.try_get("threat_level")?;
    let threat_score: f64 = row.try_get("threat_score")?;
    let assessed_at: DateTime<Utc> = row.try_get("assessed_at")?;

    let aggregation_level = match aggregation_level_str.to_ascii_lowercase().as_str() {
        "workspace" => AggregationLevel::Workspace,
        "service" => AggregationLevel::Service,
        "endpoint" => AggregationLevel::Endpoint,
        _ => return Err(sqlx::Error::Decode("Invalid aggregation_level".into())),
    };

    let threat_level = match threat_level_str.to_ascii_lowercase().as_str() {
        "low" => ThreatLevel::Low,
        "medium" => ThreatLevel::Medium,
        "high" => ThreatLevel::High,
        "critical" => ThreatLevel::Critical,
        _ => return Err(sqlx::Error::Decode("Invalid threat_level".into())),
    };

    let threat_categories: Vec<ThreatCategory> =
        decode_list("threat_categories", row.try_get("threat_categories")?);
    let findings: Vec<ThreatFinding> = decode_list("findings", row.try_get("findings")?);
    let remediation_suggestions: Vec<RemediationSuggestion> =
        decode_list("remediation_suggestions", row.try_get("remediation_suggestions")?);

    Ok(ThreatAssessment {
        workspace_id: workspace_id.map(|u| u.to_string()),
        service_id,
        service_name,
        endpoint,
        method,
        aggregation_level,
        threat_level,
        threat_score,
        threat_categories,
        findings,
        remediation_suggestions,
        assessed_at,
    })
}
97
98#[derive(Clone)]
100pub struct ThreatModelingState {
101 pub analyzer: Arc<ThreatAnalyzer>,
103 pub webhook_configs: Vec<crate::incidents::integrations::WebhookConfig>,
105 #[cfg(feature = "database")]
107 pub database: Option<Database>,
108}
109
/// Returns the most recent workspace-level threat assessment for `workspace_id`.
///
/// The row mapper decodes the `workspace_id` column as a `Uuid`, and the insert path binds
/// it as one. The path segment is therefore parsed into a `Uuid` before querying; binding
/// the raw string would make Postgres reject the `uuid = text` comparison.
///
/// # Errors
/// * `503 Service Unavailable`: no database pool is configured.
/// * `400 Bad Request`: `workspace_id` is not a valid UUID.
/// * `404 Not Found`: no workspace-level assessment is stored for the workspace.
/// * `500 Internal Server Error`: the query or the row mapping failed.
#[cfg(feature = "database")]
pub async fn get_workspace_threats(
    State(state): State<ThreatModelingState>,
    Path(workspace_id): Path<String>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
    };

    let workspace_uuid = Uuid::parse_str(&workspace_id).map_err(|e| {
        tracing::debug!("Rejecting non-UUID workspace id {:?}: {}", workspace_id, e);
        StatusCode::BAD_REQUEST
    })?;

    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE workspace_id = $1 AND aggregation_level = 'workspace'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(workspace_uuid)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query workspace threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
148
149#[cfg(not(feature = "database"))]
153pub async fn get_workspace_threats(
154 State(_state): State<ThreatModelingState>,
155 Path(_workspace_id): Path<String>,
156) -> Result<Json<ThreatAssessment>, StatusCode> {
157 Err(StatusCode::SERVICE_UNAVAILABLE)
158}
159
/// Returns the most recent service-level threat assessment for `service_id`.
///
/// Within this module, the insert path binds `service_id` as a `Uuid` while the row mapper
/// decodes it as text. The column is therefore compared as text (`service_id::text = $1`),
/// which works whether the column is `uuid` or `text`. Binding a string against a `uuid`
/// column would make Postgres reject the comparison.
///
/// # Errors
/// * `503 Service Unavailable`: no database pool is configured.
/// * `404 Not Found`: no service-level assessment is stored for the service.
/// * `500 Internal Server Error`: the query or the row mapping failed.
#[cfg(feature = "database")]
pub async fn get_service_threats(
    State(state): State<ThreatModelingState>,
    Path(service_id): Path<String>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
    };

    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE service_id::text = $1 AND aggregation_level = 'service'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(&service_id)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query service threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
197
198#[cfg(not(feature = "database"))]
202pub async fn get_service_threats(
203 State(_state): State<ThreatModelingState>,
204 Path(_service_id): Path<String>,
205) -> Result<Json<ThreatAssessment>, StatusCode> {
206 Err(StatusCode::SERVICE_UNAVAILABLE)
207}
208
/// Returns the most recent endpoint-level threat assessment for `endpoint`.
///
/// The optional `method` query parameter is trimmed and compared exactly, ignoring ASCII
/// case. It is not used as a `LIKE` pattern, so `%` and `_` in user input are not
/// wildcards. A missing or blank `method` matches any row, including rows whose `method`
/// is `NULL`.
///
/// # Errors
/// * `503 Service Unavailable`: no database pool is configured.
/// * `404 Not Found`: no matching endpoint-level assessment is stored.
/// * `500 Internal Server Error`: the query or the row mapping failed.
#[cfg(feature = "database")]
pub async fn get_endpoint_threats(
    State(state): State<ThreatModelingState>,
    Path(endpoint): Path<String>,
    Query(params): Query<serde_json::Value>,
) -> Result<Json<ThreatAssessment>, StatusCode> {
    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => return Err(StatusCode::SERVICE_UNAVAILABLE),
    };

    let method: Option<&str> = params
        .get("method")
        .and_then(|v| v.as_str())
        .map(str::trim)
        .filter(|m| !m.is_empty());

    let row = sqlx::query(
        "SELECT * FROM contract_threat_assessments
         WHERE endpoint = $1
           AND ($2::text IS NULL OR UPPER(method) = UPPER($2))
           AND aggregation_level = 'endpoint'
         ORDER BY assessed_at DESC LIMIT 1",
    )
    .bind(&endpoint)
    .bind(method)
    .fetch_optional(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query endpoint threats: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?
    .ok_or(StatusCode::NOT_FOUND)?;

    map_row_to_threat_assessment(&row).map(Json).map_err(|e| {
        tracing::error!("Failed to map threat assessment: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })
}
250
251#[cfg(not(feature = "database"))]
255pub async fn get_endpoint_threats(
256 State(_state): State<ThreatModelingState>,
257 Path(_endpoint): Path<String>,
258 Query(_params): Query<serde_json::Value>,
259) -> Result<Json<ThreatAssessment>, StatusCode> {
260 Err(StatusCode::SERVICE_UNAVAILABLE)
261}
262
263#[derive(Debug, Deserialize)]
265pub struct AssessThreatsRequest {
266 pub spec: String,
268 pub workspace_id: Option<String>,
270 pub service_id: Option<String>,
272 pub service_name: Option<String>,
274 pub endpoint: Option<String>,
276 pub method: Option<String>,
278}
279
280pub async fn assess_threats(
284 State(state): State<ThreatModelingState>,
285 Json(request): Json<AssessThreatsRequest>,
286) -> Result<Json<ThreatAssessment>, StatusCode> {
287 let spec = match OpenApiSpec::from_string(&request.spec, None) {
289 Ok(spec) => spec,
290 Err(_) => return Err(StatusCode::BAD_REQUEST),
291 };
292
293 match state
295 .analyzer
296 .analyze_contract(
297 &spec,
298 request.workspace_id.clone(),
299 request.service_id.clone(),
300 request.service_name.clone(),
301 request.endpoint.clone(),
302 request.method.clone(),
303 )
304 .await
305 {
306 Ok(assessment) => {
307 #[cfg(feature = "database")]
309 if let Some(pool) = state.database.as_ref().and_then(|db| db.pool()) {
310 if let Err(e) = store_threat_assessment(pool, &assessment).await {
311 tracing::warn!("Failed to store threat assessment: {}", e);
312 }
313 }
314
315 trigger_threat_assessment_webhooks(&state.webhook_configs, &assessment).await;
317
318 Ok(Json(assessment))
319 }
320 Err(_) => Err(StatusCode::INTERNAL_SERVER_ERROR),
321 }
322}
323
/// Renders an enum value as the lowercase `snake_case` token stored in the
/// `aggregation_level`, `threat_level`, `finding_type` and `severity` text columns.
///
/// The readers in this module compare against tokens such as `'workspace'`, `"critical"`
/// and `"pii_exposure"`. A derived `Debug` impl would instead print `Workspace`, `Critical`
/// and `PiiExposure`.
///
/// An underscore is inserted only before an uppercase letter that is followed by a
/// lowercase letter, i.e. at the start of a capitalised word. This turns acronym-style
/// names such as `DoSRisk` into `dos_risk` rather than `do_s_risk`.
#[cfg(feature = "database")]
fn db_enum_token<T: std::fmt::Debug>(value: &T) -> String {
    let name = format!("{value:?}");
    let chars: Vec<char> = name.chars().collect();
    let mut token = String::with_capacity(name.len() + 4);
    for (i, &c) in chars.iter().enumerate() {
        let starts_word = i > 0
            && c.is_uppercase()
            && matches!(chars.get(i + 1), Some(next) if next.is_lowercase());
        if starts_word {
            token.push('_');
        }
        token.extend(c.to_lowercase());
    }
    token
}

/// Persists `assessment` and its findings atomically in one transaction.
///
/// The assessment row is upserted on
/// `(workspace_id, service_id, endpoint, method, aggregation_level)`. On conflict Postgres
/// keeps the existing row's `id`. The id actually in the table is therefore read back with
/// `RETURNING id`, and the findings are attached to it.
///
/// Findings previously stored for that row are deleted first. This keeps `threat_findings`
/// in step with the refreshed `findings` JSON instead of accumulating duplicates.
///
/// `workspace_id` and `service_id` values that are not UUIDs are stored as `NULL`, with a
/// warning.
///
/// NOTE(review): if the conflict target is a plain unique constraint, rows whose key
/// contains `NULL` never conflict (NULLs are distinct by default), so such levels insert a
/// new row on every assessment. Confirm this matches the migration.
///
/// # Errors
/// Returns the first `sqlx::Error` encountered. The transaction is rolled back when it is
/// dropped without being committed.
#[cfg(feature = "database")]
async fn store_threat_assessment(
    pool: &sqlx::PgPool,
    assessment: &ThreatAssessment,
) -> Result<(), sqlx::Error> {
    // Parses an optional identifier as a UUID, warning instead of silently dropping it.
    let to_uuid = |field: &str, raw: Option<&str>| -> Option<Uuid> {
        let raw = raw?;
        match Uuid::parse_str(raw) {
            Ok(id) => Some(id),
            Err(_) => {
                tracing::warn!(
                    "Threat assessment {} {:?} is not a UUID; storing NULL",
                    field,
                    raw
                );
                None
            }
        }
    };
    let workspace_uuid = to_uuid("workspace_id", assessment.workspace_id.as_deref());
    let service_uuid = to_uuid("service_id", assessment.service_id.as_deref());
    let now = Utc::now();

    let mut tx = pool.begin().await?;

    let assessment_id: Uuid = sqlx::query_scalar(
        r#"
        INSERT INTO contract_threat_assessments (
            id, workspace_id, service_id, service_name, endpoint, method, aggregation_level,
            threat_level, threat_score, threat_categories, findings, remediation_suggestions,
            assessed_at, last_updated, created_at
        ) VALUES (
            $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15
        )
        ON CONFLICT (workspace_id, service_id, endpoint, method, aggregation_level)
        DO UPDATE SET
            threat_level = EXCLUDED.threat_level,
            threat_score = EXCLUDED.threat_score,
            threat_categories = EXCLUDED.threat_categories,
            findings = EXCLUDED.findings,
            remediation_suggestions = EXCLUDED.remediation_suggestions,
            assessed_at = EXCLUDED.assessed_at,
            last_updated = EXCLUDED.last_updated
        RETURNING id
        "#,
    )
    .bind(Uuid::new_v4())
    .bind(workspace_uuid)
    .bind(service_uuid)
    .bind(assessment.service_name.as_deref())
    .bind(assessment.endpoint.as_deref())
    .bind(assessment.method.as_deref())
    .bind(db_enum_token(&assessment.aggregation_level))
    .bind(db_enum_token(&assessment.threat_level))
    .bind(assessment.threat_score)
    .bind(serde_json::to_value(&assessment.threat_categories).unwrap_or_default())
    .bind(serde_json::to_value(&assessment.findings).unwrap_or_default())
    .bind(serde_json::to_value(&assessment.remediation_suggestions).unwrap_or_default())
    .bind(assessment.assessed_at)
    .bind(now)
    .bind(assessment.assessed_at)
    .fetch_one(&mut *tx)
    .await?;

    // Replace, rather than accumulate, the findings of a re-assessed row.
    sqlx::query("DELETE FROM threat_findings WHERE assessment_id = $1")
        .bind(assessment_id)
        .execute(&mut *tx)
        .await?;

    for finding in &assessment.findings {
        sqlx::query(
            r#"
            INSERT INTO threat_findings (
                id, assessment_id, finding_type, severity, description, field_path,
                context, remediation_suggestion, remediation_code_example, confidence,
                ai_generated_remediation, detected_at, created_at, updated_at
            ) VALUES (
                $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14
            )
            "#,
        )
        .bind(Uuid::new_v4())
        .bind(assessment_id)
        .bind(db_enum_token(&finding.finding_type))
        .bind(db_enum_token(&finding.severity))
        .bind(&finding.description)
        .bind(finding.field_path.as_deref())
        .bind(serde_json::to_value(&finding.context).unwrap_or_default())
        .bind(None::<String>) // remediation_suggestion
        .bind(None::<String>) // remediation_code_example
        .bind(finding.confidence)
        .bind(false) // ai_generated_remediation
        .bind(now) // detected_at
        .bind(now) // created_at
        .bind(now) // updated_at
        .execute(&mut *tx)
        .await?;
    }

    tx.commit().await
}
407
/// Lists the 100 most recently detected threat findings across all stored assessments.
///
/// `finding_type` and `severity` are matched against the known tokens while ignoring ASCII
/// case and `_`/`-` separators. Rows written as `Debug` names (`PiiExposure`, `High`) are
/// therefore returned alongside `snake_case` ones. The response always carries the
/// canonical `snake_case` token.
///
/// Rows with unrecognised values are skipped with a warning. Query parameters are currently
/// ignored. Without a configured pool, an empty list is returned.
///
/// # Errors
/// `500 Internal Server Error` if the query fails or a row lacks an expected column.
#[cfg(feature = "database")]
pub async fn list_findings(
    State(state): State<ThreatModelingState>,
    Query(_params): Query<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use sqlx::Row;
    use std::collections::HashMap;

    const FINDING_TYPES: [&str; 11] = [
        "pii_exposure",
        "dos_risk",
        "error_leakage",
        "schema_inconsistency",
        "unbounded_arrays",
        "missing_rate_limits",
        "stack_trace_leakage",
        "sensitive_data_exposure",
        "insecure_schema_design",
        "missing_validation",
        "excessive_optional_fields",
    ];
    const SEVERITIES: [&str; 4] = ["low", "medium", "high", "critical"];

    /// Maps `raw` onto the entry of `known` it spells, ignoring case and `_`/`-` separators.
    fn canonical_token(raw: &str, known: &[&'static str]) -> Option<&'static str> {
        fn squash(s: &str) -> impl Iterator<Item = char> + '_ {
            s.chars()
                .filter(|c| *c != '_' && *c != '-')
                .flat_map(char::to_lowercase)
        }
        known.iter().copied().find(|name| squash(name).eq(squash(raw)))
    }

    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(serde_json::json!({
                "findings": []
            })));
        }
    };

    let rows = sqlx::query(
        "SELECT tf.*, ta.workspace_id, ta.service_id, ta.endpoint, ta.method
         FROM threat_findings tf
         JOIN contract_threat_assessments ta ON tf.assessment_id = ta.id
         ORDER BY tf.detected_at DESC LIMIT 100",
    )
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query threat findings: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    // Builds a `map_err` adapter that logs which column could not be read.
    let column_error = |label: &'static str| {
        move |e: sqlx::Error| {
            tracing::error!("Failed to get {} from row: {}", label, e);
            StatusCode::INTERNAL_SERVER_ERROR
        }
    };

    let mut findings = Vec::with_capacity(rows.len());
    for row in rows {
        let finding_id: Uuid = row.try_get("id").map_err(column_error("finding id"))?;
        let finding_type_str: String = row
            .try_get("finding_type")
            .map_err(column_error("finding_type"))?;
        let severity_str: String = row.try_get("severity").map_err(column_error("severity"))?;
        let description: String = row
            .try_get("description")
            .map_err(column_error("description"))?;
        let field_path: Option<String> = row
            .try_get("field_path")
            .map_err(column_error("field_path"))?;
        let context_json: serde_json::Value =
            row.try_get("context").map_err(column_error("context"))?;
        let confidence: f64 = row
            .try_get("confidence")
            .map_err(column_error("confidence"))?;

        let (finding_type, severity) = match (
            canonical_token(&finding_type_str, &FINDING_TYPES),
            canonical_token(&severity_str, &SEVERITIES),
        ) {
            (Some(finding_type), Some(severity)) => (finding_type, severity),
            _ => {
                tracing::warn!(
                    "Skipping threat finding {} with unrecognised finding_type {:?} or severity {:?}",
                    finding_id,
                    finding_type_str,
                    severity_str
                );
                continue;
            }
        };

        // A non-object context degrades to an empty map rather than failing the listing.
        let context: HashMap<String, serde_json::Value> =
            serde_json::from_value(context_json).unwrap_or_default();

        findings.push(serde_json::json!({
            "id": finding_id.to_string(),
            "finding_type": finding_type,
            "severity": severity,
            "description": description,
            "field_path": field_path,
            "context": context,
            "confidence": confidence,
        }));
    }

    Ok(Json(serde_json::json!({
        "findings": findings,
        "total": findings.len()
    })))
}
516
517#[cfg(not(feature = "database"))]
521pub async fn list_findings(
522 State(_state): State<ThreatModelingState>,
523 Query(_params): Query<serde_json::Value>,
524) -> Result<Json<serde_json::Value>, StatusCode> {
525 Ok(Json(serde_json::json!({
526 "findings": []
527 })))
528}
529
/// Returns remediation suggestions from the 50 most recent assessments that have any,
/// flattened into a single list.
///
/// Rows are selected with `jsonb_typeof(...) = 'array'` instead of `jsonb_array_length`.
/// The latter raises an error for non-array JSONB values (for example a stored JSON
/// `null`), which would fail the whole listing. Without a configured pool, an empty list
/// is returned.
///
/// # Errors
/// `500 Internal Server Error` if the query fails or the column cannot be read.
#[cfg(feature = "database")]
pub async fn get_remediations(
    State(state): State<ThreatModelingState>,
    Query(_params): Query<serde_json::Value>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    use sqlx::Row;

    let pool = match state.database.as_ref().and_then(|db| db.pool()) {
        Some(pool) => pool,
        None => {
            return Ok(Json(serde_json::json!({
                "remediations": []
            })));
        }
    };

    let rows = sqlx::query(
        "SELECT remediation_suggestions FROM contract_threat_assessments
         WHERE jsonb_typeof(remediation_suggestions) = 'array'
           AND remediation_suggestions <> '[]'::jsonb
         ORDER BY assessed_at DESC LIMIT 50",
    )
    .fetch_all(pool)
    .await
    .map_err(|e| {
        tracing::error!("Failed to query remediations: {}", e);
        StatusCode::INTERNAL_SERVER_ERROR
    })?;

    let mut remediations = Vec::new();
    for row in rows {
        let remediations_json: serde_json::Value =
            row.try_get("remediation_suggestions").map_err(|e| {
                tracing::error!("Failed to get remediation_suggestions from row: {}", e);
                StatusCode::INTERNAL_SERVER_ERROR
            })?;
        if let serde_json::Value::Array(items) = remediations_json {
            remediations.extend(items);
        }
    }

    Ok(Json(serde_json::json!({
        "remediations": remediations,
        "total": remediations.len()
    })))
}
581
582#[cfg(not(feature = "database"))]
586pub async fn get_remediations(
587 State(_state): State<ThreatModelingState>,
588 Query(_params): Query<serde_json::Value>,
589) -> Result<Json<serde_json::Value>, StatusCode> {
590 Ok(Json(serde_json::json!({
591 "remediations": []
592 })))
593}
594
595async fn trigger_threat_assessment_webhooks(
597 webhook_configs: &[crate::incidents::integrations::WebhookConfig],
598 assessment: &ThreatAssessment,
599) {
600 use crate::incidents::integrations::send_webhook;
601 use serde_json::json;
602
603 for webhook in webhook_configs {
604 if !webhook.enabled {
605 continue;
606 }
607
608 let event_type = "threat.assessment.completed";
609 if !webhook.events.is_empty() && !webhook.events.contains(&event_type.to_string()) {
610 continue;
611 }
612
613 let payload = json!({
614 "event": event_type,
615 "assessment": {
616 "workspace_id": assessment.workspace_id,
617 "service_id": assessment.service_id,
618 "service_name": assessment.service_name,
619 "endpoint": assessment.endpoint,
620 "method": assessment.method,
621 "threat_level": format!("{:?}", assessment.threat_level),
622 "threat_score": assessment.threat_score,
623 "findings_count": assessment.findings.len(),
624 "assessed_at": assessment.assessed_at,
625 }
626 });
627
628 let webhook_clone = webhook.clone();
629 tokio::spawn(async move {
630 if let Err(e) = send_webhook(&webhook_clone, &payload).await {
631 tracing::warn!("Failed to send threat assessment webhook: {}", e);
632 }
633 });
634 }
635}
636
637pub fn threat_modeling_router(state: ThreatModelingState) -> axum::Router {
639 use axum::routing::{get, post};
640 use axum::Router;
641
642 Router::new()
643 .route("/api/v1/threats/workspace/{workspace_id}", get(get_workspace_threats))
644 .route("/api/v1/threats/service/{service_id}", get(get_service_threats))
645 .route("/api/v1/threats/endpoint/{endpoint}", get(get_endpoint_threats))
646 .route("/api/v1/threats/assess", post(assess_threats))
647 .route("/api/v1/threats/findings", get(list_findings))
648 .route("/api/v1/threats/remediations", get(get_remediations))
649 .with_state(state)
650}