use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::Duration;
use serde::Deserialize;
use rsigma_parser::Level;
use crate::Scope;
use crate::selector::{Selector, SelectorParseError};
use super::RiskLayer;
use super::accumulator::{IncidentConfig, RiskCaps};
use super::incident::IncludeMode;
use super::object::ObjectSelector;
use super::score::{Reducer, ScoreConfig};
/// Raw, serde-facing shape of a risk configuration document.
///
/// Every key is optional in the input; anything left out is filled from this
/// struct's derived `Default` (`false`, `None`, an empty list, or the nested
/// type's own default). Nothing is validated here; validation happens when
/// the value is handed to [`build_risk_layer`].
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct RiskFile {
    /// Handed to `RiskLayer::new` unchanged.
    pub strip_event: bool,
    /// Optional scope lists; when absent the builder uses `Scope::default()`.
    pub scope: Option<ScopeConfig>,
    /// Scoring settings, forwarded to `ScoreConfig::new`.
    pub score: ScoreFile,
    /// Risk-object selectors. Parsing accepts an empty or missing list, but
    /// [`build_risk_layer`] rejects it with `RiskConfigError::NoObjects`.
    pub objects: Vec<ObjectFile>,
    /// Handed to `RiskLayer::new` unchanged.
    pub emit_risk_events: bool,
    /// Handed to `RiskLayer::new` unchanged.
    pub nats_subject: Option<String>,
    /// Optional incident settings; checked by the builder when present.
    pub incident: Option<IncidentFile>,
}
/// Raw, serde-facing shape of the `incident` section.
///
/// All keys are optional at parse time and fall back to this struct's derived
/// `Default`. When the section is turned into an `IncidentConfig`, at least
/// one of `score_threshold` / `tactic_count_threshold` must be set.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct IncidentFile {
    /// Duration string parsed by `humantime_opt`; the builder substitutes
    /// `DEFAULT_WINDOW` when this is `None`.
    #[serde(with = "humantime_opt")]
    pub window: Option<Duration>,
    /// Copied as-is into `IncidentConfig::score_threshold`.
    pub score_threshold: Option<i64>,
    /// Copied as-is into `IncidentConfig::tactic_count_threshold`.
    pub tactic_count_threshold: Option<u64>,
    /// Duration string parsed by `humantime_opt`; the builder substitutes
    /// `DEFAULT_COOLDOWN` when this is `None`.
    #[serde(with = "humantime_opt")]
    pub cooldown: Option<Duration>,
    /// Mapped onto `IncludeMode`; defaults to `IncludeLabel::Refs`.
    pub include: IncludeLabel,
    /// Copied as-is into `IncidentConfig::nats_subject`.
    pub nats_subject: Option<String>,
    /// Optional cap overrides; unset entries fall back to `RiskCaps::default()`.
    pub caps: Option<RiskCapsFile>,
}
/// Config-file spelling of the incident `include` option.
///
/// Deserialized from the snake_case variant names (`refs`, `results`) and
/// mapped one-to-one onto `IncludeMode` by `build_incident_config`.
#[derive(Debug, Clone, Copy, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum IncludeLabel {
    /// Becomes `IncludeMode::Refs`; also the value used when the key is omitted.
    #[default]
    Refs,
    /// Becomes `IncludeMode::Results`.
    Results,
}
/// Raw, serde-facing shape of the `incident.caps` section.
///
/// Each entry is an optional override: `build_incident_config` replaces any
/// `None` with the matching field of `RiskCaps::default()`.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct RiskCapsFile {
    /// Override for `RiskCaps::max_open_entities`.
    pub max_open_entities: Option<usize>,
    /// Override for `RiskCaps::max_sources_per_entity`.
    pub max_sources_per_entity: Option<usize>,
    /// Override for `RiskCaps::max_results_per_incident`.
    pub max_results_per_incident: Option<usize>,
}
/// Raw, serde-facing shape of the `scope` section.
///
/// The three lists are passed, in this order, to `Scope::new`, which may
/// reject them; each list is empty when its key is omitted.
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct ScopeConfig {
    /// First argument to `Scope::new`.
    pub rules: Vec<String>,
    /// Second argument to `Scope::new`.
    pub tags: Vec<String>,
    /// Third argument to `Scope::new`.
    pub levels: Vec<String>,
}
/// Raw, serde-facing shape of the `score` section.
///
/// Every key is optional and falls back to this struct's derived `Default`.
/// The fields are forwarded to `ScoreConfig::new` by [`build_risk_layer`].
#[derive(Debug, Clone, Default, Deserialize)]
#[serde(default)]
pub struct ScoreFile {
    /// Optional attribute name, forwarded as-is.
    pub attribute: Option<String>,
    /// Score per tag key; empty when omitted.
    pub tag_scores: HashMap<String, i64>,
    /// Reducer selection; defaults to `ReducerLabel::Sum`.
    pub tag_reducer: ReducerLabel,
    /// Score per `Level`; empty when omitted.
    pub level_scores: HashMap<Level, i64>,
    /// Forwarded as-is; `0` when omitted.
    pub default_score: i64,
}
/// Config-file spelling of the `score.tag_reducer` option.
///
/// Deserialized from the snake_case variant names (`sum`, `max`) and mapped
/// one-to-one onto `Reducer` by `build_risk_layer`.
#[derive(Debug, Clone, Copy, Default, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ReducerLabel {
    /// Becomes `Reducer::Sum`; also the value used when the key is omitted.
    #[default]
    Sum,
    /// Becomes `Reducer::Max`.
    Max,
}
/// One entry of the `objects` list.
///
/// Neither field has a serde default, so both keys must be present in the
/// input for deserialization to succeed.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct ObjectFile {
    /// Read from the `type` key. `build_risk_layer` rejects a value that is
    /// empty after trimming whitespace.
    #[serde(rename = "type")]
    pub object_type: String,
    /// Selector expression, parsed with `Selector::parse` by `build_risk_layer`.
    pub selector: String,
}
/// Everything that can go wrong while loading or building a risk config.
#[derive(Debug)]
pub enum RiskConfigError {
    /// Reading the config file failed; carries the I/O error and the path.
    Io(std::io::Error, PathBuf),
    /// The text could not be deserialized as a `RiskFile`.
    Yaml(yaml_serde::Error),
    /// `Scope::new` rejected the `scope` section; carries its message.
    Scope(String),
    /// `Selector::parse` rejected an `objects[].selector` value.
    ObjectSelector(SelectorParseError),
    /// An `objects` entry has a `type` that is empty after trimming.
    EmptyObjectType,
    /// The `objects` list is empty.
    NoObjects,
    /// An `incident` section is present but sets neither threshold.
    NoThreshold,
}
impl std::fmt::Display for RiskConfigError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RiskConfigError::Io(e, p) => {
write!(f, "failed to read risk config '{}': {e}", p.display())
}
RiskConfigError::Yaml(e) => write!(f, "invalid risk YAML: {e}"),
RiskConfigError::Scope(message) => write!(f, "scope: {message}"),
RiskConfigError::ObjectSelector(e) => write!(f, "objects.selector: {e}"),
RiskConfigError::EmptyObjectType => {
write!(f, "objects: each entry requires a non-empty `type`")
}
RiskConfigError::NoObjects => write!(
f,
"objects is empty; list at least one risk-object selector"
),
RiskConfigError::NoThreshold => write!(
f,
"incident is configured but neither score_threshold nor \
tactic_count_threshold is set; set at least one"
),
}
}
}
// Empty impl: relies on the trait's provided methods, so `source()` returns
// `None`; wrapped errors are surfaced through the `Display` text instead.
impl std::error::Error for RiskConfigError {}
/// Reads the file at `path` and deserializes it as a [`RiskFile`].
///
/// The result is not validated; pass it to [`build_risk_layer`] for that.
///
/// # Errors
///
/// * [`RiskConfigError::Io`] (with `path`) if the file cannot be read.
/// * [`RiskConfigError::Yaml`] if the contents do not deserialize.
pub fn load_risk_file(path: &Path) -> Result<RiskFile, RiskConfigError> {
    match std::fs::read_to_string(path) {
        Ok(contents) => yaml_serde::from_str(&contents).map_err(RiskConfigError::Yaml),
        Err(io_err) => Err(RiskConfigError::Io(io_err, path.to_path_buf())),
    }
}
pub fn parse_risk_config(text: &str) -> Result<RiskLayer, RiskConfigError> {
let file: RiskFile = yaml_serde::from_str(text).map_err(RiskConfigError::Yaml)?;
build_risk_layer(file)
}
pub fn build_risk_layer(file: RiskFile) -> Result<RiskLayer, RiskConfigError> {
let scope = match file.scope {
Some(s) => Scope::new(s.rules, s.tags, s.levels).map_err(RiskConfigError::Scope)?,
None => Scope::default(),
};
let reducer = match file.score.tag_reducer {
ReducerLabel::Sum => Reducer::Sum,
ReducerLabel::Max => Reducer::Max,
};
let score = ScoreConfig::new(
file.score.attribute,
file.score.tag_scores,
reducer,
file.score.level_scores,
file.score.default_score,
);
if file.objects.is_empty() {
return Err(RiskConfigError::NoObjects);
}
let mut objects = Vec::with_capacity(file.objects.len());
for obj in file.objects {
if obj.object_type.trim().is_empty() {
return Err(RiskConfigError::EmptyObjectType);
}
let selector = Selector::parse(&obj.selector).map_err(RiskConfigError::ObjectSelector)?;
objects.push(ObjectSelector {
object_type: obj.object_type,
selector,
});
}
let incident = match file.incident {
Some(i) => Some(build_incident_config(i)?),
None => None,
};
Ok(RiskLayer::new(
scope,
file.strip_event,
score,
objects,
file.emit_risk_events,
file.nats_subject,
incident,
))
}
/// Incident window used when the config omits `incident.window`: 24 hours.
const DEFAULT_WINDOW: Duration = Duration::from_secs(24 * 3600);
/// Incident cooldown used when the config omits `incident.cooldown`: 1 hour.
const DEFAULT_COOLDOWN: Duration = Duration::from_secs(3600);
fn build_incident_config(file: IncidentFile) -> Result<IncidentConfig, RiskConfigError> {
if file.score_threshold.is_none() && file.tactic_count_threshold.is_none() {
return Err(RiskConfigError::NoThreshold);
}
let include = match file.include {
IncludeLabel::Refs => IncludeMode::Refs,
IncludeLabel::Results => IncludeMode::Results,
};
let caps_file = file.caps.unwrap_or_default();
let defaults = RiskCaps::default();
let caps = RiskCaps {
max_open_entities: caps_file
.max_open_entities
.unwrap_or(defaults.max_open_entities),
max_sources_per_entity: caps_file
.max_sources_per_entity
.unwrap_or(defaults.max_sources_per_entity),
max_results_per_incident: caps_file
.max_results_per_incident
.unwrap_or(defaults.max_results_per_incident),
};
Ok(IncidentConfig {
window: file.window.unwrap_or(DEFAULT_WINDOW),
score_threshold: file.score_threshold,
tactic_count_threshold: file.tactic_count_threshold,
cooldown: file.cooldown.unwrap_or(DEFAULT_COOLDOWN),
include,
nats_subject: file.nats_subject,
caps,
})
}
mod humantime_opt {
use std::time::Duration;
use serde::{Deserialize, Deserializer};
pub fn deserialize<'de, D>(d: D) -> Result<Option<Duration>, D::Error>
where
D: Deserializer<'de>,
{
let raw: Option<String> = Option::deserialize(d)?;
match raw {
Some(s) => humantime::parse_duration(&s)
.map(Some)
.map_err(serde::de::Error::custom),
None => Ok(None),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A lone object entry is enough to build a layer.
    #[test]
    fn minimal_config_builds() {
        let outcome = parse_risk_config("objects:\n - type: user\n selector: enrichment.user\n");
        outcome.unwrap();
    }

    /// A config without any `objects` entry is refused.
    #[test]
    fn empty_objects_is_rejected() {
        let outcome = parse_risk_config("score:\n default_score: 5\n");
        assert!(matches!(outcome.unwrap_err(), RiskConfigError::NoObjects));
    }

    /// The selector error names both the config field and the bad value.
    #[test]
    fn bad_object_selector_points_at_the_field() {
        let input = "objects:\n - type: user\n selector: bogus.field\n";
        let msg = parse_risk_config(input).unwrap_err().to_string();
        assert!(msg.contains("objects.selector"), "got: {msg}");
        assert!(msg.contains("bogus.field"), "got: {msg}");
    }

    /// An empty `type` string is refused.
    #[test]
    fn empty_object_type_is_rejected() {
        let input = "objects:\n - type: \"\"\n selector: enrichment.user\n";
        let failure = parse_risk_config(input).unwrap_err();
        assert!(matches!(failure, RiskConfigError::EmptyObjectType));
    }

    /// A config exercising every top-level section except `incident` builds.
    #[test]
    fn full_config_parses() {
        let document = r#"
strip_event: true
scope:
levels: [low, medium, high, critical]
score:
tag_scores:
"attack.*": 10
crown-jewel: 50
tag_reducer: max
level_scores:
high: 40
critical: 80
default_score: 1
objects:
- type: user
selector: enrichment.user
- type: src_ip
selector: match.SourceIp
emit_risk_events: true
nats_subject: risk.events
"#;
        parse_risk_config(document).unwrap();
    }

    /// A scope rule that `Scope::new` rejects surfaces as a `Scope` error.
    #[test]
    fn bad_scope_glob_is_rejected() {
        let input = "scope:\n rules: [\"[unclosed\"]\nobjects:\n - type: user\n selector: enrichment.user\n";
        let failure = parse_risk_config(input).unwrap_err();
        assert!(matches!(failure, RiskConfigError::Scope(_)));
    }
}