1use crate::Error;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use uuid::Uuid;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum RiskCategory {
16 Technical,
18 Operational,
20 Compliance,
22 Business,
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum Likelihood {
30 Rare = 1,
32 Unlikely = 2,
34 Possible = 3,
36 Likely = 4,
38 AlmostCertain = 5,
40}
41
42impl Likelihood {
43 pub fn value(&self) -> u8 {
45 *self as u8
46 }
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
51#[serde(rename_all = "lowercase")]
52pub enum Impact {
53 Negligible = 1,
55 Low = 2,
57 Medium = 3,
59 High = 4,
61 Critical = 5,
63}
64
65impl Impact {
66 pub fn value(&self) -> u8 {
68 *self as u8
69 }
70}
71
72#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
74#[serde(rename_all = "lowercase")]
75pub enum RiskLevel {
76 Low,
78 Medium,
80 High,
82 Critical,
84}
85
86impl RiskLevel {
87 pub fn from_score(score: u8) -> Self {
89 match score {
90 1..=5 => RiskLevel::Low,
91 6..=11 => RiskLevel::Medium,
92 12..=19 => RiskLevel::High,
93 20..=25 => RiskLevel::Critical,
94 _ => RiskLevel::Low,
95 }
96 }
97}
98
99#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
101#[serde(rename_all = "lowercase")]
102pub enum TreatmentOption {
103 Avoid,
105 Mitigate,
107 Transfer,
109 Accept,
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
115#[serde(rename_all = "snake_case")]
116pub enum TreatmentStatus {
117 NotStarted,
119 InProgress,
121 Completed,
123 OnHold,
125}
126
127#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
129#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
130#[serde(rename_all = "lowercase")]
131pub enum RiskReviewFrequency {
132 Monthly,
134 Quarterly,
136 Annually,
138 AdHoc,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct Risk {
145 pub risk_id: String,
147 pub title: String,
149 pub description: String,
151 pub category: RiskCategory,
153 pub subcategory: Option<String>,
155 pub likelihood: Likelihood,
157 pub impact: Impact,
159 pub risk_score: u8,
161 pub risk_level: RiskLevel,
163 pub threat: Option<String>,
165 pub vulnerability: Option<String>,
167 pub asset: Option<String>,
169 pub existing_controls: Vec<String>,
171 pub treatment_option: TreatmentOption,
173 pub treatment_plan: Vec<String>,
175 pub treatment_owner: Option<String>,
177 pub treatment_deadline: Option<DateTime<Utc>>,
179 pub treatment_status: TreatmentStatus,
181 pub residual_likelihood: Option<Likelihood>,
183 pub residual_impact: Option<Impact>,
185 pub residual_risk_score: Option<u8>,
187 pub residual_risk_level: Option<RiskLevel>,
189 pub last_reviewed: Option<DateTime<Utc>>,
191 pub next_review: Option<DateTime<Utc>>,
193 pub review_frequency: RiskReviewFrequency,
195 pub compliance_requirements: Vec<String>,
197 pub created_at: DateTime<Utc>,
199 pub updated_at: DateTime<Utc>,
201 pub created_by: Uuid,
203}
204
205impl Risk {
206 pub fn new(
208 risk_id: String,
209 title: String,
210 description: String,
211 category: RiskCategory,
212 likelihood: Likelihood,
213 impact: Impact,
214 created_by: Uuid,
215 ) -> Self {
216 let risk_score = likelihood.value() * impact.value();
217 let risk_level = RiskLevel::from_score(risk_score);
218
219 Self {
220 risk_id,
221 title,
222 description,
223 category,
224 subcategory: None,
225 likelihood,
226 impact,
227 risk_score,
228 risk_level,
229 threat: None,
230 vulnerability: None,
231 asset: None,
232 existing_controls: Vec::new(),
233 treatment_option: TreatmentOption::Accept,
234 treatment_plan: Vec::new(),
235 treatment_owner: None,
236 treatment_deadline: None,
237 treatment_status: TreatmentStatus::NotStarted,
238 residual_likelihood: None,
239 residual_impact: None,
240 residual_risk_score: None,
241 residual_risk_level: None,
242 last_reviewed: None,
243 next_review: None,
244 review_frequency: RiskReviewFrequency::Quarterly,
245 compliance_requirements: Vec::new(),
246 created_at: Utc::now(),
247 updated_at: Utc::now(),
248 created_by,
249 }
250 }
251
252 pub fn recalculate(&mut self) {
254 self.risk_score = self.likelihood.value() * self.impact.value();
255 self.risk_level = RiskLevel::from_score(self.risk_score);
256
257 if let (Some(res_likelihood), Some(res_impact)) =
258 (self.residual_likelihood, self.residual_impact)
259 {
260 self.residual_risk_score = Some(res_likelihood.value() * res_impact.value());
261 self.residual_risk_level = self.residual_risk_score.map(RiskLevel::from_score);
262 }
263 }
264
265 pub fn calculate_next_review(&mut self) {
267 let now = Utc::now();
268 let next = match self.review_frequency {
269 RiskReviewFrequency::Monthly => now + chrono::Duration::days(30),
270 RiskReviewFrequency::Quarterly => now + chrono::Duration::days(90),
271 RiskReviewFrequency::Annually => now + chrono::Duration::days(365),
272 RiskReviewFrequency::AdHoc => now + chrono::Duration::days(90), };
274 self.next_review = Some(next);
275 }
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct RiskSummary {
281 pub total_risks: u32,
283 pub critical: u32,
285 pub high: u32,
287 pub medium: u32,
289 pub low: u32,
291 pub by_category: HashMap<RiskCategory, u32>,
293 pub by_treatment_status: HashMap<TreatmentStatus, u32>,
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
299#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
300pub struct RiskAssessmentConfig {
301 pub enabled: bool,
303 pub default_review_frequency: RiskReviewFrequency,
305 pub risk_tolerance: RiskTolerance,
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
311#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
312pub struct RiskTolerance {
313 pub max_acceptable_score: u8,
315 pub require_treatment_above: u8,
317}
318
319impl Default for RiskAssessmentConfig {
320 fn default() -> Self {
321 Self {
322 enabled: true,
323 default_review_frequency: RiskReviewFrequency::Quarterly,
324 risk_tolerance: RiskTolerance {
325 max_acceptable_score: 5, require_treatment_above: 11, },
328 }
329 }
330}
331
332pub struct RiskAssessmentEngine {
334 config: RiskAssessmentConfig,
335 risks: std::sync::Arc<tokio::sync::RwLock<HashMap<String, Risk>>>,
337 risk_id_counter: std::sync::Arc<tokio::sync::RwLock<u64>>,
339 persistence_path: Option<std::path::PathBuf>,
341}
342
343impl RiskAssessmentEngine {
344 pub fn new(config: RiskAssessmentConfig) -> Self {
346 Self {
347 config,
348 risks: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
349 risk_id_counter: std::sync::Arc::new(tokio::sync::RwLock::new(0)),
350 persistence_path: None,
351 }
352 }
353
354 pub async fn with_persistence<P: AsRef<std::path::Path>>(
356 config: RiskAssessmentConfig,
357 persistence_path: P,
358 ) -> Result<Self, Error> {
359 let path = persistence_path.as_ref().to_path_buf();
360 let mut engine = Self {
361 config,
362 risks: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
363 risk_id_counter: std::sync::Arc::new(tokio::sync::RwLock::new(0)),
364 persistence_path: Some(path.clone()),
365 };
366
367 engine.load_risks().await?;
369
370 Ok(engine)
371 }
372
373 async fn load_risks(&mut self) -> Result<(), Error> {
375 let path = match &self.persistence_path {
376 Some(p) => p,
377 None => return Ok(()), };
379
380 if !path.exists() {
381 return Ok(()); }
383
384 let content = tokio::fs::read_to_string(path)
385 .await
386 .map_err(|e| Error::Generic(format!("Failed to read risk register: {}", e)))?;
387
388 let risks: HashMap<String, Risk> = serde_json::from_str(&content)
389 .map_err(|e| Error::Generic(format!("Failed to parse risk register: {}", e)))?;
390
391 let max_id = risks
393 .keys()
394 .filter_map(|id| {
395 id.strip_prefix("RISK-")
396 .and_then(|num| num.parse::<u64>().ok())
397 })
398 .max()
399 .unwrap_or(0);
400
401 let mut risk_map = self.risks.write().await;
402 *risk_map = risks;
403 drop(risk_map);
404
405 let mut counter = self.risk_id_counter.write().await;
406 *counter = max_id;
407 drop(counter);
408
409 Ok(())
410 }
411
412 async fn save_risks(&self) -> Result<(), Error> {
414 let path = match &self.persistence_path {
415 Some(p) => p,
416 None => return Ok(()), };
418
419 if let Some(parent) = path.parent() {
421 tokio::fs::create_dir_all(parent)
422 .await
423 .map_err(|e| Error::Generic(format!("Failed to create directory: {}", e)))?;
424 }
425
426 let risks = self.risks.read().await;
427 let content = serde_json::to_string_pretty(&*risks)
428 .map_err(|e| Error::Generic(format!("Failed to serialize risk register: {}", e)))?;
429
430 tokio::fs::write(path, content)
431 .await
432 .map_err(|e| Error::Generic(format!("Failed to write risk register: {}", e)))?;
433
434 Ok(())
435 }
436
437 async fn generate_risk_id(&self) -> String {
439 let mut counter = self.risk_id_counter.write().await;
440 *counter += 1;
441 format!("RISK-{:03}", *counter)
442 }
443
444 pub async fn create_risk(
446 &self,
447 title: String,
448 description: String,
449 category: RiskCategory,
450 likelihood: Likelihood,
451 impact: Impact,
452 created_by: Uuid,
453 ) -> Result<Risk, Error> {
454 let risk_id = self.generate_risk_id().await;
455 let mut risk = Risk::new(
456 risk_id.clone(),
457 title,
458 description,
459 category,
460 likelihood,
461 impact,
462 created_by,
463 );
464 risk.review_frequency = self.config.default_review_frequency;
465 risk.calculate_next_review();
466
467 let mut risks = self.risks.write().await;
468 risks.insert(risk_id, risk.clone());
469 drop(risks);
470
471 self.save_risks().await?;
473
474 Ok(risk)
475 }
476
477 pub async fn get_risk(&self, risk_id: &str) -> Result<Option<Risk>, Error> {
479 let risks = self.risks.read().await;
480 Ok(risks.get(risk_id).cloned())
481 }
482
483 pub async fn get_all_risks(&self) -> Result<Vec<Risk>, Error> {
485 let risks = self.risks.read().await;
486 Ok(risks.values().cloned().collect())
487 }
488
489 pub async fn get_risks_by_level(&self, level: RiskLevel) -> Result<Vec<Risk>, Error> {
491 let risks = self.risks.read().await;
492 Ok(risks.values().filter(|r| r.risk_level == level).cloned().collect())
493 }
494
495 pub async fn get_risks_by_category(&self, category: RiskCategory) -> Result<Vec<Risk>, Error> {
497 let risks = self.risks.read().await;
498 Ok(risks.values().filter(|r| r.category == category).cloned().collect())
499 }
500
501 pub async fn get_risks_by_treatment_status(
503 &self,
504 status: TreatmentStatus,
505 ) -> Result<Vec<Risk>, Error> {
506 let risks = self.risks.read().await;
507 Ok(risks.values().filter(|r| r.treatment_status == status).cloned().collect())
508 }
509
510 pub async fn update_risk(&self, risk_id: &str, mut risk: Risk) -> Result<(), Error> {
512 risk.recalculate();
513 risk.updated_at = Utc::now();
514
515 let mut risks = self.risks.write().await;
516 if risks.contains_key(risk_id) {
517 risks.insert(risk_id.to_string(), risk);
518 drop(risks);
519 self.save_risks().await?;
521 Ok(())
522 } else {
523 Err(Error::Generic("Risk not found".to_string()))
524 }
525 }
526
527 pub async fn update_risk_assessment(
529 &self,
530 risk_id: &str,
531 likelihood: Option<Likelihood>,
532 impact: Option<Impact>,
533 ) -> Result<(), Error> {
534 let mut risks = self.risks.write().await;
535 if let Some(risk) = risks.get_mut(risk_id) {
536 if let Some(l) = likelihood {
537 risk.likelihood = l;
538 }
539 if let Some(i) = impact {
540 risk.impact = i;
541 }
542 risk.recalculate();
543 risk.updated_at = Utc::now();
544 drop(risks);
545 self.save_risks().await?;
547 Ok(())
548 } else {
549 Err(Error::Generic("Risk not found".to_string()))
550 }
551 }
552
553 pub async fn update_treatment_plan(
555 &self,
556 risk_id: &str,
557 treatment_option: TreatmentOption,
558 treatment_plan: Vec<String>,
559 treatment_owner: Option<String>,
560 treatment_deadline: Option<DateTime<Utc>>,
561 ) -> Result<(), Error> {
562 let mut risks = self.risks.write().await;
563 if let Some(risk) = risks.get_mut(risk_id) {
564 risk.treatment_option = treatment_option;
565 risk.treatment_plan = treatment_plan;
566 risk.treatment_owner = treatment_owner;
567 risk.treatment_deadline = treatment_deadline;
568 risk.updated_at = Utc::now();
569 drop(risks);
570 self.save_risks().await?;
572 Ok(())
573 } else {
574 Err(Error::Generic("Risk not found".to_string()))
575 }
576 }
577
578 pub async fn update_treatment_status(
580 &self,
581 risk_id: &str,
582 status: TreatmentStatus,
583 ) -> Result<(), Error> {
584 let mut risks = self.risks.write().await;
585 if let Some(risk) = risks.get_mut(risk_id) {
586 risk.treatment_status = status;
587 risk.updated_at = Utc::now();
588 drop(risks);
589 self.save_risks().await?;
591 Ok(())
592 } else {
593 Err(Error::Generic("Risk not found".to_string()))
594 }
595 }
596
597 pub async fn set_residual_risk(
599 &self,
600 risk_id: &str,
601 residual_likelihood: Likelihood,
602 residual_impact: Impact,
603 ) -> Result<(), Error> {
604 let mut risks = self.risks.write().await;
605 if let Some(risk) = risks.get_mut(risk_id) {
606 risk.residual_likelihood = Some(residual_likelihood);
607 risk.residual_impact = Some(residual_impact);
608 risk.recalculate();
609 risk.updated_at = Utc::now();
610 drop(risks);
611 self.save_risks().await?;
613 Ok(())
614 } else {
615 Err(Error::Generic("Risk not found".to_string()))
616 }
617 }
618
619 pub async fn review_risk(&self, risk_id: &str, reviewed_by: Uuid) -> Result<(), Error> {
621 let mut risks = self.risks.write().await;
622 if let Some(risk) = risks.get_mut(risk_id) {
623 risk.last_reviewed = Some(Utc::now());
624 risk.calculate_next_review();
625 risk.updated_at = Utc::now();
626 let _ = reviewed_by; drop(risks);
628 self.save_risks().await?;
630 Ok(())
631 } else {
632 Err(Error::Generic("Risk not found".to_string()))
633 }
634 }
635
636 pub async fn get_risk_summary(&self) -> Result<RiskSummary, Error> {
638 let risks = self.risks.read().await;
639
640 let mut summary = RiskSummary {
641 total_risks: risks.len() as u32,
642 critical: 0,
643 high: 0,
644 medium: 0,
645 low: 0,
646 by_category: HashMap::new(),
647 by_treatment_status: HashMap::new(),
648 };
649
650 for risk in risks.values() {
651 match risk.risk_level {
652 RiskLevel::Critical => summary.critical += 1,
653 RiskLevel::High => summary.high += 1,
654 RiskLevel::Medium => summary.medium += 1,
655 RiskLevel::Low => summary.low += 1,
656 }
657
658 *summary.by_category.entry(risk.category).or_insert(0) += 1;
659 let count = summary.by_treatment_status.entry(risk.treatment_status).or_insert(0);
660 *count += 1;
661 }
662
663 Ok(summary)
664 }
665
666 pub async fn get_risks_due_for_review(&self) -> Result<Vec<Risk>, Error> {
668 let risks = self.risks.read().await;
669 let now = Utc::now();
670
671 Ok(risks
672 .values()
673 .filter(|r| r.next_review.map(|next| next <= now).unwrap_or(false))
674 .cloned()
675 .collect())
676 }
677}
678
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh in-memory engine and creates one Technical risk in it.
    async fn engine_with_risk(
        likelihood: Likelihood,
        impact: Impact,
    ) -> (RiskAssessmentEngine, Risk) {
        let engine = RiskAssessmentEngine::new(RiskAssessmentConfig::default());
        let risk = engine
            .create_risk(
                "Test Risk".to_string(),
                "Test description".to_string(),
                RiskCategory::Technical,
                likelihood,
                impact,
                Uuid::new_v4(),
            )
            .await
            .unwrap();
        (engine, risk)
    }

    #[tokio::test]
    async fn test_risk_creation() {
        let (_engine, risk) = engine_with_risk(Likelihood::Possible, Impact::High).await;

        assert_eq!(risk.risk_id, "RISK-001");
        assert_eq!(risk.risk_score, 12); // Possible (3) x High (4)
        assert_eq!(risk.risk_level, RiskLevel::High);
        assert_eq!(risk.review_frequency, RiskReviewFrequency::Quarterly);
        assert!(risk.next_review.is_some());
    }

    #[test]
    fn test_risk_level_calculation() {
        assert_eq!(RiskLevel::from_score(3), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(9), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(15), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(22), RiskLevel::Critical);
    }

    #[test]
    fn test_risk_level_band_boundaries() {
        let cases = [
            (1, RiskLevel::Low),
            (5, RiskLevel::Low),
            (6, RiskLevel::Medium),
            (11, RiskLevel::Medium),
            (12, RiskLevel::High),
            (19, RiskLevel::High),
            (20, RiskLevel::Critical),
            (25, RiskLevel::Critical),
        ];
        for &(score, expected) in cases.iter() {
            assert_eq!(RiskLevel::from_score(score), expected, "score {}", score);
        }
    }

    #[tokio::test]
    async fn test_update_risk_assessment_recalculates() {
        let (engine, risk) = engine_with_risk(Likelihood::Possible, Impact::High).await;

        engine
            .update_risk_assessment(&risk.risk_id, Some(Likelihood::Rare), None)
            .await
            .unwrap();

        let updated = engine.get_risk(&risk.risk_id).await.unwrap().unwrap();
        assert_eq!(updated.impact, Impact::High);
        assert_eq!(updated.risk_score, 4); // Rare (1) x High (4)
        assert_eq!(updated.risk_level, RiskLevel::Low);
    }

    #[tokio::test]
    async fn test_set_residual_risk() {
        let (engine, risk) = engine_with_risk(Likelihood::Likely, Impact::Critical).await;

        engine
            .set_residual_risk(&risk.risk_id, Likelihood::Unlikely, Impact::Medium)
            .await
            .unwrap();

        let updated = engine.get_risk(&risk.risk_id).await.unwrap().unwrap();
        assert_eq!(updated.risk_score, 20);
        assert_eq!(updated.risk_level, RiskLevel::Critical);
        assert_eq!(updated.residual_risk_score, Some(6)); // Unlikely (2) x Medium (3)
        assert_eq!(updated.residual_risk_level, Some(RiskLevel::Medium));
    }

    #[tokio::test]
    async fn test_unknown_risk() {
        let engine = RiskAssessmentEngine::new(RiskAssessmentConfig::default());

        assert!(engine.get_risk("RISK-999").await.unwrap().is_none());
        assert!(engine
            .update_treatment_status("RISK-999", TreatmentStatus::Completed)
            .await
            .is_err());
        assert!(engine
            .update_risk_assessment("RISK-999", Some(Likelihood::Rare), None)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_risk_summary_counts() {
        let (engine, _) = engine_with_risk(Likelihood::Possible, Impact::High).await;
        engine
            .create_risk(
                "Second".to_string(),
                "Low compliance risk".to_string(),
                RiskCategory::Compliance,
                Likelihood::Rare,
                Impact::Low,
                Uuid::new_v4(),
            )
            .await
            .unwrap();

        let summary = engine.get_risk_summary().await.unwrap();
        assert_eq!(summary.total_risks, 2);
        assert_eq!(summary.critical, 0);
        assert_eq!(summary.high, 1);
        assert_eq!(summary.medium, 0);
        assert_eq!(summary.low, 1);
        assert_eq!(summary.by_category.get(&RiskCategory::Technical), Some(&1));
        assert_eq!(summary.by_category.get(&RiskCategory::Compliance), Some(&1));
        assert_eq!(
            summary.by_treatment_status.get(&TreatmentStatus::NotStarted),
            Some(&2)
        );
    }
}