1use crate::Error;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use uuid::Uuid;
11
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
14#[serde(rename_all = "snake_case")]
15pub enum RiskCategory {
16 Technical,
18 Operational,
20 Compliance,
22 Business,
24}
25
26#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
28#[serde(rename_all = "lowercase")]
29pub enum Likelihood {
30 Rare = 1,
32 Unlikely = 2,
34 Possible = 3,
36 Likely = 4,
38 AlmostCertain = 5,
40}
41
42impl Likelihood {
43 pub fn value(&self) -> u8 {
45 *self as u8
46 }
47}
48
49#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
51#[serde(rename_all = "lowercase")]
52pub enum Impact {
53 Negligible = 1,
55 Low = 2,
57 Medium = 3,
59 High = 4,
61 Critical = 5,
63}
64
65impl Impact {
66 pub fn value(&self) -> u8 {
68 *self as u8
69 }
70}
71
72#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
74#[serde(rename_all = "lowercase")]
75pub enum RiskLevel {
76 Low,
78 Medium,
80 High,
82 Critical,
84}
85
86impl RiskLevel {
87 pub fn from_score(score: u8) -> Self {
89 match score {
90 1..=5 => RiskLevel::Low,
91 6..=11 => RiskLevel::Medium,
92 12..=19 => RiskLevel::High,
93 20..=25 => RiskLevel::Critical,
94 _ => RiskLevel::Low,
95 }
96 }
97}
98
99#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
101#[serde(rename_all = "lowercase")]
102pub enum TreatmentOption {
103 Avoid,
105 Mitigate,
107 Transfer,
109 Accept,
111}
112
113#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
115#[serde(rename_all = "snake_case")]
116pub enum TreatmentStatus {
117 NotStarted,
119 InProgress,
121 Completed,
123 OnHold,
125}
126
127#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
129#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
130#[serde(rename_all = "lowercase")]
131pub enum RiskReviewFrequency {
132 Monthly,
134 Quarterly,
136 Annually,
138 AdHoc,
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub struct Risk {
145 pub risk_id: String,
147 pub title: String,
149 pub description: String,
151 pub category: RiskCategory,
153 pub subcategory: Option<String>,
155 pub likelihood: Likelihood,
157 pub impact: Impact,
159 pub risk_score: u8,
161 pub risk_level: RiskLevel,
163 pub threat: Option<String>,
165 pub vulnerability: Option<String>,
167 pub asset: Option<String>,
169 pub existing_controls: Vec<String>,
171 pub treatment_option: TreatmentOption,
173 pub treatment_plan: Vec<String>,
175 pub treatment_owner: Option<String>,
177 pub treatment_deadline: Option<DateTime<Utc>>,
179 pub treatment_status: TreatmentStatus,
181 pub residual_likelihood: Option<Likelihood>,
183 pub residual_impact: Option<Impact>,
185 pub residual_risk_score: Option<u8>,
187 pub residual_risk_level: Option<RiskLevel>,
189 pub last_reviewed: Option<DateTime<Utc>>,
191 pub next_review: Option<DateTime<Utc>>,
193 pub review_frequency: RiskReviewFrequency,
195 pub compliance_requirements: Vec<String>,
197 pub created_at: DateTime<Utc>,
199 pub updated_at: DateTime<Utc>,
201 pub created_by: Uuid,
203}
204
205impl Risk {
206 pub fn new(
208 risk_id: String,
209 title: String,
210 description: String,
211 category: RiskCategory,
212 likelihood: Likelihood,
213 impact: Impact,
214 created_by: Uuid,
215 ) -> Self {
216 let risk_score = likelihood.value() * impact.value();
217 let risk_level = RiskLevel::from_score(risk_score);
218
219 Self {
220 risk_id,
221 title,
222 description,
223 category,
224 subcategory: None,
225 likelihood,
226 impact,
227 risk_score,
228 risk_level,
229 threat: None,
230 vulnerability: None,
231 asset: None,
232 existing_controls: Vec::new(),
233 treatment_option: TreatmentOption::Accept,
234 treatment_plan: Vec::new(),
235 treatment_owner: None,
236 treatment_deadline: None,
237 treatment_status: TreatmentStatus::NotStarted,
238 residual_likelihood: None,
239 residual_impact: None,
240 residual_risk_score: None,
241 residual_risk_level: None,
242 last_reviewed: None,
243 next_review: None,
244 review_frequency: RiskReviewFrequency::Quarterly,
245 compliance_requirements: Vec::new(),
246 created_at: Utc::now(),
247 updated_at: Utc::now(),
248 created_by,
249 }
250 }
251
252 pub fn recalculate(&mut self) {
254 self.risk_score = self.likelihood.value() * self.impact.value();
255 self.risk_level = RiskLevel::from_score(self.risk_score);
256
257 if let (Some(res_likelihood), Some(res_impact)) =
258 (self.residual_likelihood, self.residual_impact)
259 {
260 self.residual_risk_score = Some(res_likelihood.value() * res_impact.value());
261 self.residual_risk_level = self.residual_risk_score.map(RiskLevel::from_score);
262 }
263 }
264
265 pub fn calculate_next_review(&mut self) {
267 let now = Utc::now();
268 let next = match self.review_frequency {
269 RiskReviewFrequency::Monthly => now + chrono::Duration::days(30),
270 RiskReviewFrequency::Quarterly => now + chrono::Duration::days(90),
271 RiskReviewFrequency::Annually => now + chrono::Duration::days(365),
272 RiskReviewFrequency::AdHoc => now + chrono::Duration::days(90), };
274 self.next_review = Some(next);
275 }
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct RiskSummary {
281 pub total_risks: u32,
283 pub critical: u32,
285 pub high: u32,
287 pub medium: u32,
289 pub low: u32,
291 pub by_category: HashMap<RiskCategory, u32>,
293 pub by_treatment_status: HashMap<TreatmentStatus, u32>,
295}
296
297#[derive(Debug, Clone, Serialize, Deserialize)]
299#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
300pub struct RiskAssessmentConfig {
301 pub enabled: bool,
303 pub default_review_frequency: RiskReviewFrequency,
305 pub risk_tolerance: RiskTolerance,
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
311#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
312pub struct RiskTolerance {
313 pub max_acceptable_score: u8,
315 pub require_treatment_above: u8,
317}
318
319impl Default for RiskAssessmentConfig {
320 fn default() -> Self {
321 Self {
322 enabled: true,
323 default_review_frequency: RiskReviewFrequency::Quarterly,
324 risk_tolerance: RiskTolerance {
325 max_acceptable_score: 5, require_treatment_above: 11, },
328 }
329 }
330}
331
332pub struct RiskAssessmentEngine {
334 config: RiskAssessmentConfig,
335 risks: std::sync::Arc<tokio::sync::RwLock<HashMap<String, Risk>>>,
337 risk_id_counter: std::sync::Arc<tokio::sync::RwLock<u64>>,
339 persistence_path: Option<std::path::PathBuf>,
341}
342
343impl RiskAssessmentEngine {
344 pub fn new(config: RiskAssessmentConfig) -> Self {
346 Self {
347 config,
348 risks: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
349 risk_id_counter: std::sync::Arc::new(tokio::sync::RwLock::new(0)),
350 persistence_path: None,
351 }
352 }
353
354 pub async fn with_persistence<P: AsRef<std::path::Path>>(
356 config: RiskAssessmentConfig,
357 persistence_path: P,
358 ) -> Result<Self, Error> {
359 let path = persistence_path.as_ref().to_path_buf();
360 let mut engine = Self {
361 config,
362 risks: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
363 risk_id_counter: std::sync::Arc::new(tokio::sync::RwLock::new(0)),
364 persistence_path: Some(path.clone()),
365 };
366
367 engine.load_risks().await?;
369
370 Ok(engine)
371 }
372
373 async fn load_risks(&mut self) -> Result<(), Error> {
375 let path = match &self.persistence_path {
376 Some(p) => p,
377 None => return Ok(()), };
379
380 if !path.exists() {
381 return Ok(()); }
383
384 let content = tokio::fs::read_to_string(path)
385 .await
386 .map_err(|e| Error::Generic(format!("Failed to read risk register: {}", e)))?;
387
388 let risks: HashMap<String, Risk> = serde_json::from_str(&content)
389 .map_err(|e| Error::Generic(format!("Failed to parse risk register: {}", e)))?;
390
391 let max_id = risks
393 .keys()
394 .filter_map(|id| id.strip_prefix("RISK-").and_then(|num| num.parse::<u64>().ok()))
395 .max()
396 .unwrap_or(0);
397
398 let mut risk_map = self.risks.write().await;
399 *risk_map = risks;
400 drop(risk_map);
401
402 let mut counter = self.risk_id_counter.write().await;
403 *counter = max_id;
404 drop(counter);
405
406 Ok(())
407 }
408
409 async fn save_risks(&self) -> Result<(), Error> {
411 let path = match &self.persistence_path {
412 Some(p) => p,
413 None => return Ok(()), };
415
416 if let Some(parent) = path.parent() {
418 tokio::fs::create_dir_all(parent)
419 .await
420 .map_err(|e| Error::Generic(format!("Failed to create directory: {}", e)))?;
421 }
422
423 let risks = self.risks.read().await;
424 let content = serde_json::to_string_pretty(&*risks)
425 .map_err(|e| Error::Generic(format!("Failed to serialize risk register: {}", e)))?;
426
427 tokio::fs::write(path, content)
428 .await
429 .map_err(|e| Error::Generic(format!("Failed to write risk register: {}", e)))?;
430
431 Ok(())
432 }
433
434 async fn generate_risk_id(&self) -> String {
436 let mut counter = self.risk_id_counter.write().await;
437 *counter += 1;
438 format!("RISK-{:03}", *counter)
439 }
440
441 pub async fn create_risk(
443 &self,
444 title: String,
445 description: String,
446 category: RiskCategory,
447 likelihood: Likelihood,
448 impact: Impact,
449 created_by: Uuid,
450 ) -> Result<Risk, Error> {
451 let risk_id = self.generate_risk_id().await;
452 let mut risk = Risk::new(
453 risk_id.clone(),
454 title,
455 description,
456 category,
457 likelihood,
458 impact,
459 created_by,
460 );
461 risk.review_frequency = self.config.default_review_frequency;
462 risk.calculate_next_review();
463
464 let mut risks = self.risks.write().await;
465 risks.insert(risk_id, risk.clone());
466 drop(risks);
467
468 self.save_risks().await?;
470
471 Ok(risk)
472 }
473
474 pub async fn get_risk(&self, risk_id: &str) -> Result<Option<Risk>, Error> {
476 let risks = self.risks.read().await;
477 Ok(risks.get(risk_id).cloned())
478 }
479
480 pub async fn get_all_risks(&self) -> Result<Vec<Risk>, Error> {
482 let risks = self.risks.read().await;
483 Ok(risks.values().cloned().collect())
484 }
485
486 pub async fn get_risks_by_level(&self, level: RiskLevel) -> Result<Vec<Risk>, Error> {
488 let risks = self.risks.read().await;
489 Ok(risks.values().filter(|r| r.risk_level == level).cloned().collect())
490 }
491
492 pub async fn get_risks_by_category(&self, category: RiskCategory) -> Result<Vec<Risk>, Error> {
494 let risks = self.risks.read().await;
495 Ok(risks.values().filter(|r| r.category == category).cloned().collect())
496 }
497
498 pub async fn get_risks_by_treatment_status(
500 &self,
501 status: TreatmentStatus,
502 ) -> Result<Vec<Risk>, Error> {
503 let risks = self.risks.read().await;
504 Ok(risks.values().filter(|r| r.treatment_status == status).cloned().collect())
505 }
506
507 pub async fn update_risk(&self, risk_id: &str, mut risk: Risk) -> Result<(), Error> {
509 risk.recalculate();
510 risk.updated_at = Utc::now();
511
512 let mut risks = self.risks.write().await;
513 if risks.contains_key(risk_id) {
514 risks.insert(risk_id.to_string(), risk);
515 drop(risks);
516 self.save_risks().await?;
518 Ok(())
519 } else {
520 Err(Error::Generic("Risk not found".to_string()))
521 }
522 }
523
524 pub async fn update_risk_assessment(
526 &self,
527 risk_id: &str,
528 likelihood: Option<Likelihood>,
529 impact: Option<Impact>,
530 ) -> Result<(), Error> {
531 let mut risks = self.risks.write().await;
532 if let Some(risk) = risks.get_mut(risk_id) {
533 if let Some(l) = likelihood {
534 risk.likelihood = l;
535 }
536 if let Some(i) = impact {
537 risk.impact = i;
538 }
539 risk.recalculate();
540 risk.updated_at = Utc::now();
541 drop(risks);
542 self.save_risks().await?;
544 Ok(())
545 } else {
546 Err(Error::Generic("Risk not found".to_string()))
547 }
548 }
549
550 pub async fn update_treatment_plan(
552 &self,
553 risk_id: &str,
554 treatment_option: TreatmentOption,
555 treatment_plan: Vec<String>,
556 treatment_owner: Option<String>,
557 treatment_deadline: Option<DateTime<Utc>>,
558 ) -> Result<(), Error> {
559 let mut risks = self.risks.write().await;
560 if let Some(risk) = risks.get_mut(risk_id) {
561 risk.treatment_option = treatment_option;
562 risk.treatment_plan = treatment_plan;
563 risk.treatment_owner = treatment_owner;
564 risk.treatment_deadline = treatment_deadline;
565 risk.updated_at = Utc::now();
566 drop(risks);
567 self.save_risks().await?;
569 Ok(())
570 } else {
571 Err(Error::Generic("Risk not found".to_string()))
572 }
573 }
574
575 pub async fn update_treatment_status(
577 &self,
578 risk_id: &str,
579 status: TreatmentStatus,
580 ) -> Result<(), Error> {
581 let mut risks = self.risks.write().await;
582 if let Some(risk) = risks.get_mut(risk_id) {
583 risk.treatment_status = status;
584 risk.updated_at = Utc::now();
585 drop(risks);
586 self.save_risks().await?;
588 Ok(())
589 } else {
590 Err(Error::Generic("Risk not found".to_string()))
591 }
592 }
593
594 pub async fn set_residual_risk(
596 &self,
597 risk_id: &str,
598 residual_likelihood: Likelihood,
599 residual_impact: Impact,
600 ) -> Result<(), Error> {
601 let mut risks = self.risks.write().await;
602 if let Some(risk) = risks.get_mut(risk_id) {
603 risk.residual_likelihood = Some(residual_likelihood);
604 risk.residual_impact = Some(residual_impact);
605 risk.recalculate();
606 risk.updated_at = Utc::now();
607 drop(risks);
608 self.save_risks().await?;
610 Ok(())
611 } else {
612 Err(Error::Generic("Risk not found".to_string()))
613 }
614 }
615
616 pub async fn review_risk(&self, risk_id: &str, reviewed_by: Uuid) -> Result<(), Error> {
618 let mut risks = self.risks.write().await;
619 if let Some(risk) = risks.get_mut(risk_id) {
620 risk.last_reviewed = Some(Utc::now());
621 risk.calculate_next_review();
622 risk.updated_at = Utc::now();
623 let _ = reviewed_by; drop(risks);
625 self.save_risks().await?;
627 Ok(())
628 } else {
629 Err(Error::Generic("Risk not found".to_string()))
630 }
631 }
632
633 pub async fn get_risk_summary(&self) -> Result<RiskSummary, Error> {
635 let risks = self.risks.read().await;
636
637 let mut summary = RiskSummary {
638 total_risks: risks.len() as u32,
639 critical: 0,
640 high: 0,
641 medium: 0,
642 low: 0,
643 by_category: HashMap::new(),
644 by_treatment_status: HashMap::new(),
645 };
646
647 for risk in risks.values() {
648 match risk.risk_level {
649 RiskLevel::Critical => summary.critical += 1,
650 RiskLevel::High => summary.high += 1,
651 RiskLevel::Medium => summary.medium += 1,
652 RiskLevel::Low => summary.low += 1,
653 }
654
655 *summary.by_category.entry(risk.category).or_insert(0) += 1;
656 let count = summary.by_treatment_status.entry(risk.treatment_status).or_insert(0);
657 *count += 1;
658 }
659
660 Ok(summary)
661 }
662
663 pub async fn get_risks_due_for_review(&self) -> Result<Vec<Risk>, Error> {
665 let risks = self.risks.read().await;
666 let now = Utc::now();
667
668 Ok(risks
669 .values()
670 .filter(|r| r.next_review.map(|next| next <= now).unwrap_or(false))
671 .cloned()
672 .collect())
673 }
674}
675
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory engine with the default configuration.
    fn engine() -> RiskAssessmentEngine {
        RiskAssessmentEngine::new(RiskAssessmentConfig::default())
    }

    /// Creates a risk with fixed title and description on `engine`, panicking on failure.
    async fn create(
        engine: &RiskAssessmentEngine,
        category: RiskCategory,
        likelihood: Likelihood,
        impact: Impact,
    ) -> Risk {
        engine
            .create_risk(
                "Test Risk".to_string(),
                "Test description".to_string(),
                category,
                likelihood,
                impact,
                Uuid::new_v4(),
            )
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_risk_creation() {
        let engine = engine();
        let risk = create(&engine, RiskCategory::Technical, Likelihood::Possible, Impact::High).await;

        // Possible (3) x High (4) = 12, the lowest High score.
        assert_eq!(risk.risk_score, 12);
        assert_eq!(risk.risk_level, RiskLevel::High);
        assert_eq!(risk.risk_id, "RISK-001");
        assert_eq!(risk.review_frequency, RiskReviewFrequency::Quarterly);
        assert!(risk.next_review.is_some());
        assert!(engine.get_risk(&risk.risk_id).await.unwrap().is_some());
    }

    #[tokio::test]
    async fn test_risk_ids_are_sequential() {
        let engine = engine();
        let first = create(&engine, RiskCategory::Business, Likelihood::Rare, Impact::Low).await;
        let second = create(&engine, RiskCategory::Business, Likelihood::Rare, Impact::Low).await;

        assert_eq!(first.risk_id, "RISK-001");
        assert_eq!(second.risk_id, "RISK-002");
    }

    #[tokio::test]
    async fn test_residual_risk_is_scored() {
        let engine = engine();
        let risk = create(&engine, RiskCategory::Operational, Likelihood::Likely, Impact::High).await;

        engine
            .set_residual_risk(&risk.risk_id, Likelihood::Rare, Impact::Low)
            .await
            .unwrap();
        let updated = engine.get_risk(&risk.risk_id).await.unwrap().unwrap();

        assert_eq!(updated.residual_risk_score, Some(2));
        assert_eq!(updated.residual_risk_level, Some(RiskLevel::Low));
        // The inherent assessment is untouched.
        assert_eq!(updated.risk_score, 16);
    }

    #[tokio::test]
    async fn test_unknown_risk_is_rejected() {
        let engine = engine();
        let risk = create(&engine, RiskCategory::Technical, Likelihood::Rare, Impact::Low).await;

        assert!(engine.update_risk("RISK-999", risk).await.is_err());
        assert!(engine
            .update_treatment_status("RISK-999", TreatmentStatus::Completed)
            .await
            .is_err());
        assert!(engine.review_risk("RISK-999", Uuid::new_v4()).await.is_err());
        assert!(engine.get_risk("RISK-999").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_risk_summary_counts() {
        let engine = engine();
        // 5 x 5 = 25 -> Critical; 1 x 1 = 1 -> Low.
        create(&engine, RiskCategory::Technical, Likelihood::AlmostCertain, Impact::Critical).await;
        create(&engine, RiskCategory::Compliance, Likelihood::Rare, Impact::Negligible).await;

        let summary = engine.get_risk_summary().await.unwrap();

        assert_eq!(summary.total_risks, 2);
        assert_eq!(summary.critical, 1);
        assert_eq!(summary.high, 0);
        assert_eq!(summary.medium, 0);
        assert_eq!(summary.low, 1);
        assert_eq!(summary.by_category.get(&RiskCategory::Technical), Some(&1));
        assert_eq!(summary.by_category.get(&RiskCategory::Business), None);
        assert_eq!(
            summary.by_treatment_status.get(&TreatmentStatus::NotStarted),
            Some(&2)
        );
    }

    #[tokio::test]
    async fn test_fresh_risks_are_not_due_for_review() {
        let engine = engine();
        create(&engine, RiskCategory::Technical, Likelihood::Possible, Impact::Medium).await;

        assert!(engine.get_risks_due_for_review().await.unwrap().is_empty());
    }

    #[test]
    fn test_risk_level_calculation() {
        assert_eq!(RiskLevel::from_score(3), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(9), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(15), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(22), RiskLevel::Critical);

        // Band edges on both sides of every threshold.
        assert_eq!(RiskLevel::from_score(1), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(5), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(6), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(11), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(12), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(19), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(20), RiskLevel::Critical);
        assert_eq!(RiskLevel::from_score(25), RiskLevel::Critical);
    }
}