1use crate::Error;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use uuid::Uuid;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum RiskCategory {
16 Technical,
18 Operational,
20 Compliance,
22 Business,
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum Likelihood {
30 Rare = 1,
32 Unlikely = 2,
34 Possible = 3,
36 Likely = 4,
38 AlmostCertain = 5,
40}
41
42impl Likelihood {
43 pub fn value(&self) -> u8 {
45 *self as u8
46 }
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
51#[serde(rename_all = "lowercase")]
52pub enum Impact {
53 Negligible = 1,
55 Low = 2,
57 Medium = 3,
59 High = 4,
61 Critical = 5,
63}
64
65impl Impact {
66 pub fn value(&self) -> u8 {
68 *self as u8
69 }
70}
71
72#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
74#[serde(rename_all = "lowercase")]
75pub enum RiskLevel {
76 Low,
78 Medium,
80 High,
82 Critical,
84}
85
86impl RiskLevel {
87 pub fn from_score(score: u8) -> Self {
89 match score {
90 1..=5 => RiskLevel::Low,
91 6..=11 => RiskLevel::Medium,
92 12..=19 => RiskLevel::High,
93 20..=25 => RiskLevel::Critical,
94 _ => RiskLevel::Low,
95 }
96 }
97}
98
99#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
101#[serde(rename_all = "lowercase")]
102pub enum TreatmentOption {
103 Avoid,
105 Mitigate,
107 Transfer,
109 Accept,
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
115#[serde(rename_all = "snake_case")]
116pub enum TreatmentStatus {
117 NotStarted,
119 InProgress,
121 Completed,
123 OnHold,
125}
126
127#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
129#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
130#[serde(rename_all = "lowercase")]
131pub enum RiskReviewFrequency {
132 Monthly,
134 Quarterly,
136 Annually,
138 AdHoc,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct RiskReview {
145 pub reviewed_at: DateTime<Utc>,
147 pub reviewed_by: Uuid,
149 pub notes: Option<String>,
151 pub status: TreatmentStatus,
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct Risk {
158 pub risk_id: String,
160 pub title: String,
162 pub description: String,
164 pub category: RiskCategory,
166 pub subcategory: Option<String>,
168 pub likelihood: Likelihood,
170 pub impact: Impact,
172 pub risk_score: u8,
174 pub risk_level: RiskLevel,
176 pub threat: Option<String>,
178 pub vulnerability: Option<String>,
180 pub asset: Option<String>,
182 pub existing_controls: Vec<String>,
184 pub treatment_option: TreatmentOption,
186 pub treatment_plan: Vec<String>,
188 pub treatment_owner: Option<String>,
190 pub treatment_deadline: Option<DateTime<Utc>>,
192 pub treatment_status: TreatmentStatus,
194 pub residual_likelihood: Option<Likelihood>,
196 pub residual_impact: Option<Impact>,
198 pub residual_risk_score: Option<u8>,
200 pub residual_risk_level: Option<RiskLevel>,
202 pub last_reviewed: Option<DateTime<Utc>>,
204 pub reviewed_by: Option<Uuid>,
206 pub review_history: Vec<RiskReview>,
208 pub next_review: Option<DateTime<Utc>>,
210 pub review_frequency: RiskReviewFrequency,
212 pub compliance_requirements: Vec<String>,
214 pub created_at: DateTime<Utc>,
216 pub updated_at: DateTime<Utc>,
218 pub created_by: Uuid,
220}
221
222impl Risk {
223 pub fn new(
225 risk_id: String,
226 title: String,
227 description: String,
228 category: RiskCategory,
229 likelihood: Likelihood,
230 impact: Impact,
231 created_by: Uuid,
232 ) -> Self {
233 let risk_score = likelihood.value() * impact.value();
234 let risk_level = RiskLevel::from_score(risk_score);
235
236 Self {
237 risk_id,
238 title,
239 description,
240 category,
241 subcategory: None,
242 likelihood,
243 impact,
244 risk_score,
245 risk_level,
246 threat: None,
247 vulnerability: None,
248 asset: None,
249 existing_controls: Vec::new(),
250 treatment_option: TreatmentOption::Accept,
251 treatment_plan: Vec::new(),
252 treatment_owner: None,
253 treatment_deadline: None,
254 treatment_status: TreatmentStatus::NotStarted,
255 residual_likelihood: None,
256 residual_impact: None,
257 residual_risk_score: None,
258 residual_risk_level: None,
259 last_reviewed: None,
260 reviewed_by: None,
261 review_history: Vec::new(),
262 next_review: None,
263 review_frequency: RiskReviewFrequency::Quarterly,
264 compliance_requirements: Vec::new(),
265 created_at: Utc::now(),
266 updated_at: Utc::now(),
267 created_by,
268 }
269 }
270
271 pub fn recalculate(&mut self) {
273 self.risk_score = self.likelihood.value() * self.impact.value();
274 self.risk_level = RiskLevel::from_score(self.risk_score);
275
276 if let (Some(res_likelihood), Some(res_impact)) =
277 (self.residual_likelihood, self.residual_impact)
278 {
279 self.residual_risk_score = Some(res_likelihood.value() * res_impact.value());
280 self.residual_risk_level = self.residual_risk_score.map(RiskLevel::from_score);
281 }
282 }
283
284 pub fn calculate_next_review(&mut self) {
286 let now = Utc::now();
287 let next = match self.review_frequency {
288 RiskReviewFrequency::Monthly => now + chrono::Duration::days(30),
289 RiskReviewFrequency::Quarterly => now + chrono::Duration::days(90),
290 RiskReviewFrequency::Annually => now + chrono::Duration::days(365),
291 RiskReviewFrequency::AdHoc => now + chrono::Duration::days(90), };
293 self.next_review = Some(next);
294 }
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
299pub struct RiskSummary {
300 pub total_risks: u32,
302 pub critical: u32,
304 pub high: u32,
306 pub medium: u32,
308 pub low: u32,
310 pub by_category: HashMap<RiskCategory, u32>,
312 pub by_treatment_status: HashMap<TreatmentStatus, u32>,
314}
315
316#[derive(Debug, Clone, Serialize, Deserialize)]
318#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
319pub struct RiskAssessmentConfig {
320 pub enabled: bool,
322 pub default_review_frequency: RiskReviewFrequency,
324 pub risk_tolerance: RiskTolerance,
326}
327
328#[derive(Debug, Clone, Serialize, Deserialize)]
330#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
331pub struct RiskTolerance {
332 pub max_acceptable_score: u8,
334 pub require_treatment_above: u8,
336}
337
338impl Default for RiskAssessmentConfig {
339 fn default() -> Self {
340 Self {
341 enabled: true,
342 default_review_frequency: RiskReviewFrequency::Quarterly,
343 risk_tolerance: RiskTolerance {
344 max_acceptable_score: 5, require_treatment_above: 11, },
347 }
348 }
349}
350
351pub struct RiskAssessmentEngine {
353 config: RiskAssessmentConfig,
354 risks: std::sync::Arc<tokio::sync::RwLock<HashMap<String, Risk>>>,
356 risk_id_counter: std::sync::Arc<tokio::sync::RwLock<u64>>,
358 persistence_path: Option<std::path::PathBuf>,
360}
361
362impl RiskAssessmentEngine {
363 pub fn new(config: RiskAssessmentConfig) -> Self {
365 Self {
366 config,
367 risks: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
368 risk_id_counter: std::sync::Arc::new(tokio::sync::RwLock::new(0)),
369 persistence_path: None,
370 }
371 }
372
373 pub async fn with_persistence<P: AsRef<std::path::Path>>(
375 config: RiskAssessmentConfig,
376 persistence_path: P,
377 ) -> Result<Self, Error> {
378 let path = persistence_path.as_ref().to_path_buf();
379 let mut engine = Self {
380 config,
381 risks: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
382 risk_id_counter: std::sync::Arc::new(tokio::sync::RwLock::new(0)),
383 persistence_path: Some(path.clone()),
384 };
385
386 engine.load_risks().await?;
388
389 Ok(engine)
390 }
391
392 async fn load_risks(&mut self) -> Result<(), Error> {
394 let path = match &self.persistence_path {
395 Some(p) => p,
396 None => return Ok(()), };
398
399 if !path.exists() {
400 return Ok(()); }
402
403 let content = tokio::fs::read_to_string(path)
404 .await
405 .map_err(|e| Error::Generic(format!("Failed to read risk register: {}", e)))?;
406
407 let risks: HashMap<String, Risk> = serde_json::from_str(&content)
408 .map_err(|e| Error::Generic(format!("Failed to parse risk register: {}", e)))?;
409
410 let max_id = risks
412 .keys()
413 .filter_map(|id| id.strip_prefix("RISK-").and_then(|num| num.parse::<u64>().ok()))
414 .max()
415 .unwrap_or(0);
416
417 let mut risk_map = self.risks.write().await;
418 *risk_map = risks;
419 drop(risk_map);
420
421 let mut counter = self.risk_id_counter.write().await;
422 *counter = max_id;
423 drop(counter);
424
425 Ok(())
426 }
427
428 async fn save_risks(&self) -> Result<(), Error> {
430 let path = match &self.persistence_path {
431 Some(p) => p,
432 None => return Ok(()), };
434
435 if let Some(parent) = path.parent() {
437 tokio::fs::create_dir_all(parent)
438 .await
439 .map_err(|e| Error::Generic(format!("Failed to create directory: {}", e)))?;
440 }
441
442 let risks = self.risks.read().await;
443 let content = serde_json::to_string_pretty(&*risks)
444 .map_err(|e| Error::Generic(format!("Failed to serialize risk register: {}", e)))?;
445
446 tokio::fs::write(path, content)
447 .await
448 .map_err(|e| Error::Generic(format!("Failed to write risk register: {}", e)))?;
449
450 Ok(())
451 }
452
453 async fn generate_risk_id(&self) -> String {
455 let mut counter = self.risk_id_counter.write().await;
456 *counter += 1;
457 format!("RISK-{:03}", *counter)
458 }
459
460 pub async fn create_risk(
462 &self,
463 title: String,
464 description: String,
465 category: RiskCategory,
466 likelihood: Likelihood,
467 impact: Impact,
468 created_by: Uuid,
469 ) -> Result<Risk, Error> {
470 let risk_id = self.generate_risk_id().await;
471 let mut risk = Risk::new(
472 risk_id.clone(),
473 title,
474 description,
475 category,
476 likelihood,
477 impact,
478 created_by,
479 );
480 risk.review_frequency = self.config.default_review_frequency;
481 risk.calculate_next_review();
482
483 let mut risks = self.risks.write().await;
484 risks.insert(risk_id, risk.clone());
485 drop(risks);
486
487 self.save_risks().await?;
489
490 Ok(risk)
491 }
492
493 pub async fn get_risk(&self, risk_id: &str) -> Result<Option<Risk>, Error> {
495 let risks = self.risks.read().await;
496 Ok(risks.get(risk_id).cloned())
497 }
498
499 pub async fn get_all_risks(&self) -> Result<Vec<Risk>, Error> {
501 let risks = self.risks.read().await;
502 Ok(risks.values().cloned().collect())
503 }
504
505 pub async fn get_risks_by_level(&self, level: RiskLevel) -> Result<Vec<Risk>, Error> {
507 let risks = self.risks.read().await;
508 Ok(risks.values().filter(|r| r.risk_level == level).cloned().collect())
509 }
510
511 pub async fn get_risks_by_category(&self, category: RiskCategory) -> Result<Vec<Risk>, Error> {
513 let risks = self.risks.read().await;
514 Ok(risks.values().filter(|r| r.category == category).cloned().collect())
515 }
516
517 pub async fn get_risks_by_treatment_status(
519 &self,
520 status: TreatmentStatus,
521 ) -> Result<Vec<Risk>, Error> {
522 let risks = self.risks.read().await;
523 Ok(risks.values().filter(|r| r.treatment_status == status).cloned().collect())
524 }
525
526 pub async fn update_risk(&self, risk_id: &str, mut risk: Risk) -> Result<(), Error> {
528 risk.recalculate();
529 risk.updated_at = Utc::now();
530
531 let mut risks = self.risks.write().await;
532 if risks.contains_key(risk_id) {
533 risks.insert(risk_id.to_string(), risk);
534 drop(risks);
535 self.save_risks().await?;
537 Ok(())
538 } else {
539 Err(Error::Generic("Risk not found".to_string()))
540 }
541 }
542
543 pub async fn update_risk_assessment(
545 &self,
546 risk_id: &str,
547 likelihood: Option<Likelihood>,
548 impact: Option<Impact>,
549 ) -> Result<(), Error> {
550 let mut risks = self.risks.write().await;
551 if let Some(risk) = risks.get_mut(risk_id) {
552 if let Some(l) = likelihood {
553 risk.likelihood = l;
554 }
555 if let Some(i) = impact {
556 risk.impact = i;
557 }
558 risk.recalculate();
559 risk.updated_at = Utc::now();
560 drop(risks);
561 self.save_risks().await?;
563 Ok(())
564 } else {
565 Err(Error::Generic("Risk not found".to_string()))
566 }
567 }
568
569 pub async fn update_treatment_plan(
571 &self,
572 risk_id: &str,
573 treatment_option: TreatmentOption,
574 treatment_plan: Vec<String>,
575 treatment_owner: Option<String>,
576 treatment_deadline: Option<DateTime<Utc>>,
577 ) -> Result<(), Error> {
578 let mut risks = self.risks.write().await;
579 if let Some(risk) = risks.get_mut(risk_id) {
580 risk.treatment_option = treatment_option;
581 risk.treatment_plan = treatment_plan;
582 risk.treatment_owner = treatment_owner;
583 risk.treatment_deadline = treatment_deadline;
584 risk.updated_at = Utc::now();
585 drop(risks);
586 self.save_risks().await?;
588 Ok(())
589 } else {
590 Err(Error::Generic("Risk not found".to_string()))
591 }
592 }
593
594 pub async fn update_treatment_status(
596 &self,
597 risk_id: &str,
598 status: TreatmentStatus,
599 ) -> Result<(), Error> {
600 let mut risks = self.risks.write().await;
601 if let Some(risk) = risks.get_mut(risk_id) {
602 risk.treatment_status = status;
603 risk.updated_at = Utc::now();
604 drop(risks);
605 self.save_risks().await?;
607 Ok(())
608 } else {
609 Err(Error::Generic("Risk not found".to_string()))
610 }
611 }
612
613 pub async fn set_residual_risk(
615 &self,
616 risk_id: &str,
617 residual_likelihood: Likelihood,
618 residual_impact: Impact,
619 ) -> Result<(), Error> {
620 let mut risks = self.risks.write().await;
621 if let Some(risk) = risks.get_mut(risk_id) {
622 risk.residual_likelihood = Some(residual_likelihood);
623 risk.residual_impact = Some(residual_impact);
624 risk.recalculate();
625 risk.updated_at = Utc::now();
626 drop(risks);
627 self.save_risks().await?;
629 Ok(())
630 } else {
631 Err(Error::Generic("Risk not found".to_string()))
632 }
633 }
634
635 pub async fn review_risk(&self, risk_id: &str, reviewed_by: Uuid) -> Result<(), Error> {
637 self.review_risk_with_notes(risk_id, reviewed_by, None).await
638 }
639
640 pub async fn review_risk_with_notes(
642 &self,
643 risk_id: &str,
644 reviewed_by: Uuid,
645 notes: Option<String>,
646 ) -> Result<(), Error> {
647 let mut risks = self.risks.write().await;
648 if let Some(risk) = risks.get_mut(risk_id) {
649 let now = Utc::now();
650
651 let review = RiskReview {
653 reviewed_at: now,
654 reviewed_by,
655 notes,
656 status: risk.treatment_status,
657 };
658
659 risk.last_reviewed = Some(now);
661 risk.reviewed_by = Some(reviewed_by);
662 risk.review_history.push(review);
663 risk.calculate_next_review();
664 risk.updated_at = now;
665
666 drop(risks);
667 self.save_risks().await?;
669 Ok(())
670 } else {
671 Err(Error::Generic("Risk not found".to_string()))
672 }
673 }
674
675 pub async fn get_risk_summary(&self) -> Result<RiskSummary, Error> {
677 let risks = self.risks.read().await;
678
679 let mut summary = RiskSummary {
680 total_risks: risks.len() as u32,
681 critical: 0,
682 high: 0,
683 medium: 0,
684 low: 0,
685 by_category: HashMap::new(),
686 by_treatment_status: HashMap::new(),
687 };
688
689 for risk in risks.values() {
690 match risk.risk_level {
691 RiskLevel::Critical => summary.critical += 1,
692 RiskLevel::High => summary.high += 1,
693 RiskLevel::Medium => summary.medium += 1,
694 RiskLevel::Low => summary.low += 1,
695 }
696
697 *summary.by_category.entry(risk.category).or_insert(0) += 1;
698 let count = summary.by_treatment_status.entry(risk.treatment_status).or_insert(0);
699 *count += 1;
700 }
701
702 Ok(summary)
703 }
704
705 pub async fn get_risks_due_for_review(&self) -> Result<Vec<Risk>, Error> {
707 let risks = self.risks.read().await;
708 let now = Utc::now();
709
710 Ok(risks
711 .values()
712 .filter(|r| r.next_review.map(|next| next <= now).unwrap_or(false))
713 .cloned()
714 .collect())
715 }
716}
717
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory engine with the default configuration.
    fn new_engine() -> RiskAssessmentEngine {
        RiskAssessmentEngine::new(RiskAssessmentConfig::default())
    }

    /// Registers a technical risk with the given ratings and returns it.
    async fn add_risk(
        engine: &RiskAssessmentEngine,
        likelihood: Likelihood,
        impact: Impact,
    ) -> Risk {
        engine
            .create_risk(
                "Test Risk".to_string(),
                "Test description".to_string(),
                RiskCategory::Technical,
                likelihood,
                impact,
                Uuid::new_v4(),
            )
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_risk_creation() {
        let engine = new_engine();
        let risk = add_risk(&engine, Likelihood::Possible, Impact::High).await;

        assert_eq!(risk.risk_id, "RISK-001");
        assert_eq!(risk.risk_score, 12);
        assert_eq!(risk.risk_level, RiskLevel::High);
        assert!(risk.next_review.is_some());
        assert_eq!(engine.get_all_risks().await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_update_risk_assessment_recalculates() {
        let engine = new_engine();
        let risk = add_risk(&engine, Likelihood::Rare, Impact::Negligible).await;
        assert_eq!(risk.risk_level, RiskLevel::Low);

        engine
            .update_risk_assessment(
                &risk.risk_id,
                Some(Likelihood::AlmostCertain),
                Some(Impact::Critical),
            )
            .await
            .unwrap();

        let updated = engine.get_risk(&risk.risk_id).await.unwrap().unwrap();
        assert_eq!(updated.risk_score, 25);
        assert_eq!(updated.risk_level, RiskLevel::Critical);
    }

    #[tokio::test]
    async fn test_set_residual_risk() {
        let engine = new_engine();
        let risk = add_risk(&engine, Likelihood::Likely, Impact::Critical).await;

        engine
            .set_residual_risk(&risk.risk_id, Likelihood::Unlikely, Impact::Medium)
            .await
            .unwrap();

        let stored = engine.get_risk(&risk.risk_id).await.unwrap().unwrap();
        assert_eq!(stored.risk_score, 20);
        assert_eq!(stored.residual_risk_score, Some(6));
        assert_eq!(stored.residual_risk_level, Some(RiskLevel::Medium));
    }

    #[tokio::test]
    async fn test_review_records_history() {
        let engine = new_engine();
        let risk = add_risk(&engine, Likelihood::Possible, Impact::Medium).await;
        let reviewer = Uuid::new_v4();

        engine
            .review_risk_with_notes(&risk.risk_id, reviewer, Some("looks fine".to_string()))
            .await
            .unwrap();

        let stored = engine.get_risk(&risk.risk_id).await.unwrap().unwrap();
        assert_eq!(stored.review_history.len(), 1);
        assert_eq!(stored.reviewed_by, Some(reviewer));
        assert!(stored.last_reviewed.is_some());
        assert!(stored.next_review.is_some());
    }

    #[tokio::test]
    async fn test_unknown_risk_is_an_error() {
        let engine = new_engine();

        assert!(engine.get_risk("RISK-999").await.unwrap().is_none());
        assert!(engine
            .update_treatment_status("RISK-999", TreatmentStatus::Completed)
            .await
            .is_err());
        assert!(engine.review_risk("RISK-999", Uuid::new_v4()).await.is_err());
    }

    #[tokio::test]
    async fn test_risk_summary_counts() {
        let engine = new_engine();
        add_risk(&engine, Likelihood::Possible, Impact::High).await; // 12 -> High
        add_risk(&engine, Likelihood::Rare, Impact::Low).await; // 2 -> Low

        let summary = engine.get_risk_summary().await.unwrap();
        assert_eq!(summary.total_risks, 2);
        assert_eq!(
            (summary.critical, summary.high, summary.medium, summary.low),
            (0, 1, 0, 1)
        );
        assert_eq!(summary.by_category.get(&RiskCategory::Technical), Some(&2));
        assert_eq!(
            summary.by_treatment_status.get(&TreatmentStatus::NotStarted),
            Some(&2)
        );
    }

    #[test]
    fn test_risk_level_calculation() {
        assert_eq!(RiskLevel::from_score(3), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(9), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(15), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(22), RiskLevel::Critical);

        // Band boundaries.
        assert_eq!(RiskLevel::from_score(1), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(5), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(6), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(11), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(12), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(19), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(20), RiskLevel::Critical);
        assert_eq!(RiskLevel::from_score(25), RiskLevel::Critical);
    }
}
750}