use crate::Error;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
/// Broad area a risk belongs to.
///
/// Serialized in `snake_case` (e.g. `"technical"`). Derives `Hash` so it can key
/// per-category tallies such as [`RiskSummary::by_category`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RiskCategory {
    Technical,
    Operational,
    Compliance,
    Business,
}
/// How probable it is that a risk materialises, on a 1–5 scale.
///
/// The explicit discriminants are the numeric weights returned by
/// [`Likelihood::value`]; they are multiplied by [`Impact::value`] to form a risk score.
/// Serialized by variant name in lowercase, so `AlmostCertain` is `"almostcertain"`.
/// The derived `Ord` follows declaration order (`Rare < … < AlmostCertain`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Likelihood {
    Rare = 1,
    Unlikely = 2,
    Possible = 3,
    Likely = 4,
    AlmostCertain = 5,
}
impl Likelihood {
    /// Numeric weight of this likelihood on the 1–5 scale, taken directly from the
    /// variant's discriminant (`Rare` = 1 … `AlmostCertain` = 5).
    pub fn value(&self) -> u8 {
        let &level = self;
        level as u8
    }
}
/// How severe the consequences of a risk would be, on a 1–5 scale.
///
/// The explicit discriminants are the numeric weights returned by [`Impact::value`];
/// they are multiplied by [`Likelihood::value`] to form a risk score.
/// Serialized in lowercase; the derived `Ord` follows declaration order
/// (`Negligible < … < Critical`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Impact {
    Negligible = 1,
    Low = 2,
    Medium = 3,
    High = 4,
    Critical = 5,
}
impl Impact {
    /// Numeric weight of this impact on the 1–5 scale, taken directly from the
    /// variant's discriminant (`Negligible` = 1 … `Critical` = 5).
    pub fn value(&self) -> u8 {
        let &severity = self;
        severity as u8
    }
}
/// Severity band that a numeric risk score falls into (see [`RiskLevel::from_score`]).
///
/// Declaration order matters: the derived `Ord` ranks `Low < Medium < High < Critical`.
/// Serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RiskLevel {
    Low,
    Medium,
    High,
    Critical,
}
impl RiskLevel {
    /// Maps a likelihood × impact score to its severity band.
    ///
    /// Bands on the 1–25 scale: 1–5 `Low`, 6–11 `Medium`, 12–19 `High`, 20–25 `Critical`.
    /// A score of 0 is treated as `Low`. Any score above 25 saturates to `Critical`
    /// rather than falling back to the least severe band.
    pub fn from_score(score: u8) -> Self {
        // Exhaustive over all of `u8`, so no catch-all arm can silently misclassify.
        match score {
            0..=5 => RiskLevel::Low,
            6..=11 => RiskLevel::Medium,
            12..=19 => RiskLevel::High,
            20..=u8::MAX => RiskLevel::Critical,
        }
    }
}
/// Strategy chosen for handling a risk.
///
/// [`Risk::new`] starts every risk at `Accept`. Serialized in lowercase.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TreatmentOption {
    Avoid,
    Mitigate,
    Transfer,
    Accept,
}
/// Progress of a risk's treatment plan.
///
/// Serialized in `snake_case` (`"not_started"`, `"in_progress"`, `"completed"`,
/// `"on_hold"`). A snapshot of it is recorded in every [`RiskReview`], and it keys
/// [`RiskSummary::by_treatment_status`] (hence `Hash`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TreatmentStatus {
    NotStarted,
    InProgress,
    Completed,
    OnHold,
}
/// How often a risk is scheduled for review; consumed by [`Risk::calculate_next_review`].
///
/// Serialized in lowercase, so `AdHoc` is `"adhoc"`.
// A JSON schema is derived when the `schema` feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[serde(rename_all = "lowercase")]
pub enum RiskReviewFrequency {
    Monthly,
    Quarterly,
    Annually,
    AdHoc,
}
/// One entry in a risk's review history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskReview {
    /// When the review was recorded.
    pub reviewed_at: DateTime<Utc>,
    /// Identifier of the reviewer.
    pub reviewed_by: Uuid,
    /// Optional free-text notes supplied with the review.
    pub notes: Option<String>,
    /// The risk's treatment status at the moment the review was recorded.
    pub status: TreatmentStatus,
}
/// A single entry in the risk register: assessment, treatment, and review state.
///
/// Field order is also the serialized JSON field order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Risk {
    /// Register identifier; engine-issued IDs look like `RISK-001`.
    pub risk_id: String,
    /// Short human-readable name.
    pub title: String,
    /// Longer description of the risk.
    pub description: String,
    /// Broad classification.
    pub category: RiskCategory,
    /// Optional finer-grained classification (free text).
    pub subcategory: Option<String>,
    /// Assessed likelihood; one of the two inputs to `risk_score`.
    pub likelihood: Likelihood,
    /// Assessed impact; one of the two inputs to `risk_score`.
    pub impact: Impact,
    /// `likelihood.value() * impact.value()` (1–25); re-derived by [`Risk::recalculate`].
    pub risk_score: u8,
    /// Band for `risk_score`, via [`RiskLevel::from_score`].
    pub risk_level: RiskLevel,
    /// Optional free-text description of the threat.
    pub threat: Option<String>,
    /// Optional free-text description of the vulnerability.
    pub vulnerability: Option<String>,
    /// Optional free-text description of the affected asset.
    pub asset: Option<String>,
    /// Descriptions of existing controls.
    pub existing_controls: Vec<String>,
    /// Chosen treatment strategy; `Accept` for a freshly created risk.
    pub treatment_option: TreatmentOption,
    /// Treatment plan items (free text).
    pub treatment_plan: Vec<String>,
    /// Who owns the treatment (free text; no format is enforced here).
    pub treatment_owner: Option<String>,
    /// Optional due date for the treatment.
    pub treatment_deadline: Option<DateTime<Utc>>,
    /// Progress of the treatment plan.
    pub treatment_status: TreatmentStatus,
    /// Residual likelihood; when both residual inputs are set, `recalculate` derives
    /// `residual_risk_score` and `residual_risk_level` from them.
    pub residual_likelihood: Option<Likelihood>,
    /// Residual impact; see `residual_likelihood`.
    pub residual_impact: Option<Impact>,
    /// `residual_likelihood.value() * residual_impact.value()` once both are set.
    // NOTE(review): `recalculate` leaves this untouched if either residual input is
    // later reset to `None`, so it can go stale.
    pub residual_risk_score: Option<u8>,
    /// Band for `residual_risk_score`.
    pub residual_risk_level: Option<RiskLevel>,
    /// Time of the most recent review, if any.
    pub last_reviewed: Option<DateTime<Utc>>,
    /// Reviewer of the most recent review, if any.
    pub reviewed_by: Option<Uuid>,
    /// Every recorded review, appended in the order they were recorded.
    pub review_history: Vec<RiskReview>,
    /// When the next review is due; set by [`Risk::calculate_next_review`].
    pub next_review: Option<DateTime<Utc>>,
    /// Review cadence used to schedule `next_review`.
    pub review_frequency: RiskReviewFrequency,
    /// Related compliance requirements (free-text references).
    pub compliance_requirements: Vec<String>,
    /// Creation timestamp.
    pub created_at: DateTime<Utc>,
    /// Timestamp of the last modification.
    pub updated_at: DateTime<Utc>,
    /// Identifier of the creator.
    pub created_by: Uuid,
}
impl Risk {
    /// Creates a risk from its core assessment, with defaults for everything else.
    ///
    /// `risk_score`/`risk_level` are derived from `likelihood` × `impact`. The treatment
    /// option starts as `Accept` with status `NotStarted`, the review frequency is
    /// `Quarterly`, no review is scheduled yet, and all optional/list fields are empty.
    /// `created_at` and `updated_at` are set to the same instant.
    pub fn new(
        risk_id: String,
        title: String,
        description: String,
        category: RiskCategory,
        likelihood: Likelihood,
        impact: Impact,
        created_by: Uuid,
    ) -> Self {
        let risk_score = likelihood.value() * impact.value();
        let risk_level = RiskLevel::from_score(risk_score);
        // A single clock read, so a freshly created risk has `created_at == updated_at`.
        let now = Utc::now();
        Self {
            risk_id,
            title,
            description,
            category,
            subcategory: None,
            likelihood,
            impact,
            risk_score,
            risk_level,
            threat: None,
            vulnerability: None,
            asset: None,
            existing_controls: Vec::new(),
            treatment_option: TreatmentOption::Accept,
            treatment_plan: Vec::new(),
            treatment_owner: None,
            treatment_deadline: None,
            treatment_status: TreatmentStatus::NotStarted,
            residual_likelihood: None,
            residual_impact: None,
            residual_risk_score: None,
            residual_risk_level: None,
            last_reviewed: None,
            reviewed_by: None,
            review_history: Vec::new(),
            next_review: None,
            review_frequency: RiskReviewFrequency::Quarterly,
            compliance_requirements: Vec::new(),
            created_at: now,
            updated_at: now,
            created_by,
        }
    }

    /// Re-derives `risk_score`/`risk_level` from `likelihood` × `impact`.
    ///
    /// When both `residual_likelihood` and `residual_impact` are set, it also re-derives
    /// `residual_risk_score`/`residual_risk_level`. Otherwise the residual values are left
    /// as they were.
    pub fn recalculate(&mut self) {
        self.risk_score = self.likelihood.value() * self.impact.value();
        self.risk_level = RiskLevel::from_score(self.risk_score);
        if let (Some(res_likelihood), Some(res_impact)) =
            (self.residual_likelihood, self.residual_impact)
        {
            self.residual_risk_score = Some(res_likelihood.value() * res_impact.value());
            self.residual_risk_level = self.residual_risk_score.map(RiskLevel::from_score);
        }
    }

    /// Sets `next_review` to now plus the interval implied by `review_frequency`.
    ///
    /// Intervals are fixed day counts: 30 (`Monthly`), 90 (`Quarterly`), and 365
    /// (`Annually`). `AdHoc` uses the same 90 days as `Quarterly`.
    pub fn calculate_next_review(&mut self) {
        let interval_days = match self.review_frequency {
            RiskReviewFrequency::Monthly => 30,
            RiskReviewFrequency::Quarterly | RiskReviewFrequency::AdHoc => 90,
            RiskReviewFrequency::Annually => 365,
        };
        self.next_review = Some(Utc::now() + chrono::Duration::days(interval_days));
    }
}
/// Aggregate counts over the whole register, as built by
/// `RiskAssessmentEngine::get_risk_summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RiskSummary {
    /// Number of risks in the register.
    pub total_risks: u32,
    /// Number of risks whose `risk_level` is `Critical`.
    pub critical: u32,
    /// Number of risks whose `risk_level` is `High`.
    pub high: u32,
    /// Number of risks whose `risk_level` is `Medium`.
    pub medium: u32,
    /// Number of risks whose `risk_level` is `Low`.
    pub low: u32,
    /// Risk count per category; categories with no risks have no entry.
    pub by_category: HashMap<RiskCategory, u32>,
    /// Risk count per treatment status; statuses with no risks have no entry.
    pub by_treatment_status: HashMap<TreatmentStatus, u32>,
}
/// Settings for a [`RiskAssessmentEngine`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct RiskAssessmentConfig {
    /// Whether risk assessment is enabled.
    // NOTE(review): not read by `RiskAssessmentEngine` in the code shown; presumably
    // checked by callers — confirm.
    pub enabled: bool,
    /// Review frequency assigned to newly created risks.
    pub default_review_frequency: RiskReviewFrequency,
    /// Score thresholds for acceptance and treatment.
    pub risk_tolerance: RiskTolerance,
}
/// Score thresholds describing which risks may be accepted and which need treatment.
// NOTE(review): neither field is read by `RiskAssessmentEngine` in the code shown. The
// per-field meanings below are inferred from the names (and presumably compared against
// `Risk::risk_score`, 1–25) — confirm against consumers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
pub struct RiskTolerance {
    // Presumably the highest score that may be accepted without treatment.
    pub max_acceptable_score: u8,
    // Presumably scores strictly above this value must have a treatment plan.
    pub require_treatment_above: u8,
}
impl Default for RiskAssessmentConfig {
    /// Enabled, quarterly reviews, and tolerance thresholds of 5 (acceptable) and
    /// 11 (treatment above). These match the upper bounds of the `Low` and `Medium`
    /// bands in `RiskLevel::from_score`.
    fn default() -> Self {
        let risk_tolerance = RiskTolerance {
            require_treatment_above: 11,
            max_acceptable_score: 5,
        };
        Self {
            risk_tolerance,
            default_review_frequency: RiskReviewFrequency::Quarterly,
            enabled: true,
        }
    }
}
/// Risk register kept in memory, with optional JSON persistence.
///
/// State lives behind `Arc<tokio::sync::RwLock<_>>`, so every operation takes `&self`.
pub struct RiskAssessmentEngine {
    /// Engine settings (e.g. the review frequency given to new risks).
    config: RiskAssessmentConfig,
    /// All risks, keyed by risk ID.
    risks: std::sync::Arc<tokio::sync::RwLock<HashMap<String, Risk>>>,
    /// Numeric suffix of the most recently issued `RISK-<n>` ID.
    risk_id_counter: std::sync::Arc<tokio::sync::RwLock<u64>>,
    /// JSON file the register is saved to after each mutation; `None` keeps it in memory only.
    persistence_path: Option<std::path::PathBuf>,
}
impl RiskAssessmentEngine {
pub fn new(config: RiskAssessmentConfig) -> Self {
Self {
config,
risks: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
risk_id_counter: std::sync::Arc::new(tokio::sync::RwLock::new(0)),
persistence_path: None,
}
}
pub async fn with_persistence<P: AsRef<std::path::Path>>(
config: RiskAssessmentConfig,
persistence_path: P,
) -> Result<Self, Error> {
let path = persistence_path.as_ref().to_path_buf();
let mut engine = Self {
config,
risks: std::sync::Arc::new(tokio::sync::RwLock::new(HashMap::new())),
risk_id_counter: std::sync::Arc::new(tokio::sync::RwLock::new(0)),
persistence_path: Some(path.clone()),
};
engine.load_risks().await?;
Ok(engine)
}
async fn load_risks(&mut self) -> Result<(), Error> {
let path = match &self.persistence_path {
Some(p) => p,
None => return Ok(()), };
if !path.exists() {
return Ok(()); }
let content = tokio::fs::read_to_string(path)
.await
.map_err(|e| Error::io_with_context("reading risk register", e.to_string()))?;
let risks: HashMap<String, Risk> = serde_json::from_str(&content)
.map_err(|e| Error::io_with_context("parsing risk register", e.to_string()))?;
let max_id = risks
.keys()
.filter_map(|id| id.strip_prefix("RISK-").and_then(|num| num.parse::<u64>().ok()))
.max()
.unwrap_or(0);
let mut risk_map = self.risks.write().await;
*risk_map = risks;
drop(risk_map);
let mut counter = self.risk_id_counter.write().await;
*counter = max_id;
drop(counter);
Ok(())
}
async fn save_risks(&self) -> Result<(), Error> {
let path = match &self.persistence_path {
Some(p) => p,
None => return Ok(()), };
if let Some(parent) = path.parent() {
tokio::fs::create_dir_all(parent).await.map_err(|e| {
Error::io_with_context("creating risk register directory", e.to_string())
})?;
}
let risks = self.risks.read().await;
let content = serde_json::to_string_pretty(&*risks)
.map_err(|e| Error::io_with_context("serializing risk register", e.to_string()))?;
tokio::fs::write(path, content)
.await
.map_err(|e| Error::io_with_context("writing risk register", e.to_string()))?;
Ok(())
}
async fn generate_risk_id(&self) -> String {
let mut counter = self.risk_id_counter.write().await;
*counter += 1;
format!("RISK-{:03}", *counter)
}
pub async fn create_risk(
&self,
title: String,
description: String,
category: RiskCategory,
likelihood: Likelihood,
impact: Impact,
created_by: Uuid,
) -> Result<Risk, Error> {
let risk_id = self.generate_risk_id().await;
let mut risk = Risk::new(
risk_id.clone(),
title,
description,
category,
likelihood,
impact,
created_by,
);
risk.review_frequency = self.config.default_review_frequency;
risk.calculate_next_review();
let mut risks = self.risks.write().await;
risks.insert(risk_id, risk.clone());
drop(risks);
self.save_risks().await?;
Ok(risk)
}
pub async fn get_risk(&self, risk_id: &str) -> Result<Option<Risk>, Error> {
let risks = self.risks.read().await;
Ok(risks.get(risk_id).cloned())
}
pub async fn get_all_risks(&self) -> Result<Vec<Risk>, Error> {
let risks = self.risks.read().await;
Ok(risks.values().cloned().collect())
}
pub async fn get_risks_by_level(&self, level: RiskLevel) -> Result<Vec<Risk>, Error> {
let risks = self.risks.read().await;
Ok(risks.values().filter(|r| r.risk_level == level).cloned().collect())
}
pub async fn get_risks_by_category(&self, category: RiskCategory) -> Result<Vec<Risk>, Error> {
let risks = self.risks.read().await;
Ok(risks.values().filter(|r| r.category == category).cloned().collect())
}
pub async fn get_risks_by_treatment_status(
&self,
status: TreatmentStatus,
) -> Result<Vec<Risk>, Error> {
let risks = self.risks.read().await;
Ok(risks.values().filter(|r| r.treatment_status == status).cloned().collect())
}
pub async fn update_risk(&self, risk_id: &str, mut risk: Risk) -> Result<(), Error> {
risk.recalculate();
risk.updated_at = Utc::now();
let mut risks = self.risks.write().await;
if risks.contains_key(risk_id) {
risks.insert(risk_id.to_string(), risk);
drop(risks);
self.save_risks().await?;
Ok(())
} else {
Err(Error::not_found("Risk", risk_id))
}
}
pub async fn update_risk_assessment(
&self,
risk_id: &str,
likelihood: Option<Likelihood>,
impact: Option<Impact>,
) -> Result<(), Error> {
let mut risks = self.risks.write().await;
if let Some(risk) = risks.get_mut(risk_id) {
if let Some(l) = likelihood {
risk.likelihood = l;
}
if let Some(i) = impact {
risk.impact = i;
}
risk.recalculate();
risk.updated_at = Utc::now();
drop(risks);
self.save_risks().await?;
Ok(())
} else {
Err(Error::not_found("Risk", risk_id))
}
}
pub async fn update_treatment_plan(
&self,
risk_id: &str,
treatment_option: TreatmentOption,
treatment_plan: Vec<String>,
treatment_owner: Option<String>,
treatment_deadline: Option<DateTime<Utc>>,
) -> Result<(), Error> {
let mut risks = self.risks.write().await;
if let Some(risk) = risks.get_mut(risk_id) {
risk.treatment_option = treatment_option;
risk.treatment_plan = treatment_plan;
risk.treatment_owner = treatment_owner;
risk.treatment_deadline = treatment_deadline;
risk.updated_at = Utc::now();
drop(risks);
self.save_risks().await?;
Ok(())
} else {
Err(Error::not_found("Risk", risk_id))
}
}
pub async fn update_treatment_status(
&self,
risk_id: &str,
status: TreatmentStatus,
) -> Result<(), Error> {
let mut risks = self.risks.write().await;
if let Some(risk) = risks.get_mut(risk_id) {
risk.treatment_status = status;
risk.updated_at = Utc::now();
drop(risks);
self.save_risks().await?;
Ok(())
} else {
Err(Error::not_found("Risk", risk_id))
}
}
pub async fn set_residual_risk(
&self,
risk_id: &str,
residual_likelihood: Likelihood,
residual_impact: Impact,
) -> Result<(), Error> {
let mut risks = self.risks.write().await;
if let Some(risk) = risks.get_mut(risk_id) {
risk.residual_likelihood = Some(residual_likelihood);
risk.residual_impact = Some(residual_impact);
risk.recalculate();
risk.updated_at = Utc::now();
drop(risks);
self.save_risks().await?;
Ok(())
} else {
Err(Error::not_found("Risk", risk_id))
}
}
pub async fn review_risk(&self, risk_id: &str, reviewed_by: Uuid) -> Result<(), Error> {
self.review_risk_with_notes(risk_id, reviewed_by, None).await
}
pub async fn review_risk_with_notes(
&self,
risk_id: &str,
reviewed_by: Uuid,
notes: Option<String>,
) -> Result<(), Error> {
let mut risks = self.risks.write().await;
if let Some(risk) = risks.get_mut(risk_id) {
let now = Utc::now();
let review = RiskReview {
reviewed_at: now,
reviewed_by,
notes,
status: risk.treatment_status,
};
risk.last_reviewed = Some(now);
risk.reviewed_by = Some(reviewed_by);
risk.review_history.push(review);
risk.calculate_next_review();
risk.updated_at = now;
drop(risks);
self.save_risks().await?;
Ok(())
} else {
Err(Error::not_found("Risk", risk_id))
}
}
pub async fn get_risk_summary(&self) -> Result<RiskSummary, Error> {
let risks = self.risks.read().await;
let mut summary = RiskSummary {
total_risks: risks.len() as u32,
critical: 0,
high: 0,
medium: 0,
low: 0,
by_category: HashMap::new(),
by_treatment_status: HashMap::new(),
};
for risk in risks.values() {
match risk.risk_level {
RiskLevel::Critical => summary.critical += 1,
RiskLevel::High => summary.high += 1,
RiskLevel::Medium => summary.medium += 1,
RiskLevel::Low => summary.low += 1,
}
*summary.by_category.entry(risk.category).or_insert(0) += 1;
let count = summary.by_treatment_status.entry(risk.treatment_status).or_insert(0);
*count += 1;
}
Ok(summary)
}
pub async fn get_risks_due_for_review(&self) -> Result<Vec<Risk>, Error> {
let risks = self.risks.read().await;
let now = Utc::now();
Ok(risks
.values()
.filter(|r| r.next_review.map(|next| next <= now).unwrap_or(false))
.cloned()
.collect())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_risk_creation() {
        let config = RiskAssessmentConfig::default();
        let engine = RiskAssessmentEngine::new(config);
        let risk = engine
            .create_risk(
                "Test Risk".to_string(),
                "Test description".to_string(),
                RiskCategory::Technical,
                Likelihood::Possible,
                Impact::High,
                Uuid::new_v4(),
            )
            .await
            .unwrap();
        // Possible (3) × High (4) = 12, the lower edge of the High band.
        assert_eq!(risk.risk_score, 12);
        assert_eq!(risk.risk_level, RiskLevel::High);
        // IDs are issued sequentially from RISK-001. New risks use the configured default
        // review frequency and already have a first review scheduled.
        assert_eq!(risk.risk_id, "RISK-001");
        assert_eq!(risk.review_frequency, RiskReviewFrequency::Quarterly);
        assert!(risk.next_review.is_some());
        let stored = engine
            .get_risk("RISK-001")
            .await
            .unwrap()
            .expect("created risk should be stored");
        assert_eq!(stored.title, "Test Risk");
    }

    #[tokio::test]
    async fn test_residual_risk_and_missing_ids() {
        let engine = RiskAssessmentEngine::new(RiskAssessmentConfig::default());
        let risk = engine
            .create_risk(
                "Outage".to_string(),
                "Primary region outage".to_string(),
                RiskCategory::Operational,
                Likelihood::Likely,
                Impact::Critical,
                Uuid::new_v4(),
            )
            .await
            .unwrap();
        // Likely (4) × Critical (5) = 20, the lower edge of the Critical band.
        assert_eq!(risk.risk_level, RiskLevel::Critical);

        engine
            .set_residual_risk(&risk.risk_id, Likelihood::Rare, Impact::High)
            .await
            .unwrap();
        let updated = engine.get_risk(&risk.risk_id).await.unwrap().unwrap();
        assert_eq!(updated.residual_risk_score, Some(4));
        assert_eq!(updated.residual_risk_level, Some(RiskLevel::Low));

        // Operations on unknown IDs report an error instead of silently succeeding.
        assert!(engine
            .update_treatment_status("RISK-999", TreatmentStatus::Completed)
            .await
            .is_err());
    }

    #[test]
    fn test_risk_level_calculation() {
        assert_eq!(RiskLevel::from_score(3), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(9), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(15), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(22), RiskLevel::Critical);
        // Band edges.
        assert_eq!(RiskLevel::from_score(1), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(5), RiskLevel::Low);
        assert_eq!(RiskLevel::from_score(6), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(11), RiskLevel::Medium);
        assert_eq!(RiskLevel::from_score(12), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(19), RiskLevel::High);
        assert_eq!(RiskLevel::from_score(20), RiskLevel::Critical);
        assert_eq!(RiskLevel::from_score(25), RiskLevel::Critical);
    }
}