use crate::defaults::{
default_alert_threshold, default_ancestry_suspicious_weight, default_artifact_access_weight,
default_business_hours_end, default_business_hours_start, default_cooldown,
default_correlation_window, default_max_events, default_max_kills, default_off_hours_weight,
default_policy_version, default_rapid_enum_weight, default_suspicious_process_weight,
default_true,
};
use crate::secure_fs::{RestrictedInputKind, read_restricted_file};
use crate::timing::{TimingOperation, enforce_operation_min_timing};
use crate::{AgentError, POLICY_VERSION, Result};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::time::Instant;
// Numeric error codes attached to the `AgentError`s produced by policy loading
// and validation in this module.

/// The policy text could not be deserialized from TOML.
const CFG_PARSE_FAILED: u16 = 100;
/// A cross-item validation rule failed (e.g. duplicate rule severity,
/// unregistered custom condition).
const CFG_VALIDATION_FAILED: u16 = 101;
/// A required setting is missing or empty (e.g. `response.rules`).
const CFG_MISSING_REQUIRED: u16 = 102;
/// A setting holds an out-of-range or otherwise unacceptable value.
const CFG_INVALID_VALUE: u16 = 103;
/// The policy file's `version` is newer than this agent's `POLICY_VERSION`.
const CFG_VERSION_MISMATCH: u16 = 106;
/// Top-level agent policy, deserialized from TOML by `PolicyConfig::from_file`
/// / `from_toml_str`; `Default` supplies a built-in policy.
#[derive(Debug, Serialize, Deserialize)]
pub struct PolicyConfig {
    /// Policy schema version. Filled by `default_policy_version` when absent;
    /// `from_toml_str` rejects values greater than `POLICY_VERSION`.
    #[serde(default = "default_policy_version")]
    pub version: u32,
    /// Required `[scoring]` table (no serde default).
    pub scoring: ScoringPolicy,
    /// Required `[response]` table (no serde default).
    pub response: ResponsePolicy,
    /// Required `[deception]` table (no serde default).
    pub deception: DeceptionPolicy,
    /// Allow-list of names usable in `ResponseCondition::Custom`; `validate`
    /// rejects any custom condition whose name is not listed. Empty when absent.
    #[serde(default)]
    pub registered_custom_conditions: HashSet<String>,
}
/// `[scoring]` section: correlation and alerting parameters.
#[derive(Debug, Serialize, Deserialize)]
pub struct ScoringPolicy {
    /// Correlation window in seconds; `validate` requires 1..=3600.
    #[serde(default = "default_correlation_window")]
    pub correlation_window_secs: u64,
    /// Alert threshold; `validate` requires 0..=100 (NaN is rejected too).
    #[serde(default = "default_alert_threshold")]
    pub alert_threshold: f64,
    /// Upper bound on retained events; `validate` requires 1..=100_000.
    #[serde(default = "default_max_events")]
    pub max_events_in_memory: usize,
    /// Feature toggle, `true` when absent. NOTE(review): presumably gates
    /// off-hours scoring — confirm against the scorer.
    #[serde(default = "default_true")]
    pub enable_time_scoring: bool,
    /// Feature toggle, `true` when absent. NOTE(review): presumably gates
    /// process-ancestry tracking — confirm against the scorer.
    #[serde(default = "default_true")]
    pub enable_ancestry_tracking: bool,
    /// Per-signal weights; `ScoringWeights::default()` when the table is absent.
    #[serde(default)]
    pub weights: ScoringWeights,
    /// Not range-checked by `validate`. NOTE(review): presumably an hour of
    /// day (0-23) — confirm.
    #[serde(default = "default_business_hours_start")]
    pub business_hours_start: u8,
    /// Not range-checked by `validate`. NOTE(review): presumably an hour of
    /// day — confirm whether 24 is allowed.
    #[serde(default = "default_business_hours_end")]
    pub business_hours_end: u8,
}
/// `[scoring.weights]` table: score contribution of each signal type.
///
/// A field missing from the table falls back to its `default_*` function. If
/// the whole table is missing, `ScoringWeights::default()` is used instead, so
/// the two default sources need to agree.
#[derive(Debug, Serialize, Deserialize)]
pub struct ScoringWeights {
    #[serde(default = "default_artifact_access_weight")]
    pub artifact_access: f64,
    #[serde(default = "default_suspicious_process_weight")]
    pub suspicious_process: f64,
    #[serde(default = "default_rapid_enum_weight")]
    pub rapid_enumeration: f64,
    #[serde(default = "default_off_hours_weight")]
    pub off_hours_activity: f64,
    #[serde(default = "default_ancestry_suspicious_weight")]
    pub ancestry_suspicious: f64,
}
impl Default for ScoringWeights {
fn default() -> Self {
Self {
artifact_access: 50.0,
suspicious_process: 30.0,
rapid_enumeration: 20.0,
off_hours_activity: 15.0,
ancestry_suspicious: 10.0,
}
}
}
/// `[response]` section: how detections are acted upon.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponsePolicy {
    /// Required. `validate` rejects an empty list and more than one rule for
    /// the same severity.
    pub rules: Vec<ResponseRule>,
    /// Filled by `default_cooldown` when absent; `validate` rejects 0.
    #[serde(default = "default_cooldown")]
    pub cooldown_secs: u64,
    /// Filled by `default_max_kills` when absent; not checked by `validate`.
    /// NOTE(review): presumably caps `KillProcess` actions per incident — confirm.
    #[serde(default = "default_max_kills")]
    pub max_kills_per_incident: usize,
    /// `false` when absent. NOTE(review): presumably actions are reported but
    /// not executed when `true` — confirm in the responder.
    #[serde(default)]
    pub dry_run: bool,
}
/// One response rule: the action to take for a given severity, optionally
/// gated by conditions.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseRule {
    /// Severity this rule handles; `validate` allows only one rule per severity.
    pub severity: Severity,
    /// Extra gating conditions; empty when absent. NOTE(review): whether they
    /// combine with AND or OR is decided by the evaluator, not visible here.
    #[serde(default)]
    pub conditions: Vec<ResponseCondition>,
    pub action: ActionType,
}
/// A gating condition on a response rule.
///
/// Internally tagged: each condition is a table whose `type` key holds the
/// snake_case variant name, e.g. `{ type = "min_confidence", threshold = 70.0 }`
/// or `{ type = "repeat_count", count = 3, window_secs = 60 }`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ResponseCondition {
    /// `type = "min_confidence"`. The built-in default policy uses values on a
    /// 0-100 scale (70.0, 85.0).
    MinConfidence {
        threshold: f64,
    },
    /// `type = "not_parented_by"`.
    NotParentedBy {
        process_name: String,
    },
    /// `type = "min_signal_types"`.
    MinSignalTypes {
        count: usize,
    },
    /// `type = "repeat_count"`.
    RepeatCount {
        count: usize,
        window_secs: u64,
    },
    /// `type = "time_window"`. Hours are not range-checked by `validate`.
    TimeWindow {
        start_hour: u8,
        end_hour: u8,
    },
    /// `type = "custom"`. `validate` requires `name` to be listed in
    /// `PolicyConfig::registered_custom_conditions`.
    Custom {
        name: String,
        params: HashMap<String, String>,
    },
}
/// Incident severity.
///
/// The derived `Ord` follows declaration order: `Low < Medium < High < Critical`.
/// Serialized as the PascalCase variant name (`"Low"`, `"Medium"`, ...).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
#[serde(rename_all = "PascalCase")]
pub enum Severity {
    Low,
    Medium,
    High,
    Critical,
}
impl Severity {
    /// Buckets a numeric score into a severity.
    ///
    /// Lower bounds are inclusive: `>= 80` is Critical, `>= 60` High,
    /// `>= 40` Medium, everything else Low. A NaN score fails every comparison
    /// and therefore lands in `Low`.
    #[must_use]
    pub fn from_score(score: f64) -> Self {
        match score {
            s if s >= 80.0 => Self::Critical,
            s if s >= 60.0 => Self::High,
            s if s >= 40.0 => Self::Medium,
            _ => Self::Low,
        }
    }
}
impl std::fmt::Display for Severity {
    /// Writes the bare variant name. Width and fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Low => "Low",
            Self::Medium => "Medium",
            Self::High => "High",
            Self::Critical => "Critical",
        };
        f.write_str(label)
    }
}
/// Action taken when a response rule fires.
///
/// Serialized with snake_case names: `log`, `alert`, `kill_process`,
/// `isolate_host`, `custom_script` (the last one carries a `path`).
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ActionType {
    Log,
    Alert,
    KillProcess,
    IsolateHost,
    /// `path` is not checked by `validate` in this module.
    CustomScript {
        path: PathBuf,
    },
}
/// `[deception]` section: name lists used to flag suspicious activity.
#[derive(Debug, Serialize, Deserialize)]
pub struct DeceptionPolicy {
    /// Substrings matched against process names by
    /// `PolicyConfig::is_suspicious_process` (ASCII case-insensitive).
    /// Lowercased with `str::to_lowercase` when deserialized; entries set in
    /// code (e.g. by `Default`) are stored as given.
    /// NOTE(review): `to_lowercase` is Unicode-aware but the matcher folds
    /// ASCII only, so a non-ASCII uppercase pattern such as "ÄPFEL" becomes
    /// "äpfel" and then no longer matches "ÄPFEL.exe". Confirm that this is
    /// intended.
    #[serde(default, deserialize_with = "deserialize_lowercase_boxed")]
    pub suspicious_processes: Box<[String]>,
    /// Stored verbatim; empty when absent. Not consumed within this module.
    #[serde(default, deserialize_with = "deserialize_boxed")]
    pub suspicious_patterns: Box<[String]>,
}
fn deserialize_lowercase_boxed<'de, D>(
deserializer: D,
) -> std::result::Result<Box<[String]>, D::Error>
where
D: Deserializer<'de>,
{
let vec = Vec::<String>::deserialize(deserializer)?;
Ok(vec.into_iter().map(|s| s.to_lowercase()).collect())
}
/// Serde `deserialize_with` helper: reads a sequence of strings and stores it
/// as a boxed slice, leaving the entries untouched.
fn deserialize_boxed<'de, D>(deserializer: D) -> std::result::Result<Box<[String]>, D::Error>
where
    D: Deserializer<'de>,
{
    Vec::<String>::deserialize(deserializer).map(Vec::into_boxed_slice)
}
impl PolicyConfig {
    /// Reads, parses, and validates a policy file.
    ///
    /// The file is read with `read_restricted_file` using
    /// `RestrictedInputKind::Policy`, then handed to `from_toml_str`. The
    /// version check and `validate` therefore both run before a policy is
    /// returned. `enforce_operation_min_timing` (`TimingOperation::PolicyLoad`)
    /// is called after the attempt, whether it succeeded or failed.
    ///
    /// # Errors
    ///
    /// Propagates errors from reading the file, parsing the TOML, the version
    /// check, or validation.
    pub async fn from_file<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
        let started = Instant::now();
        let path = path.as_ref();
        let result = async {
            let contents = read_restricted_file(path, RestrictedInputKind::Policy).await?;
            Self::from_toml_str(&contents)
        }
        .await;
        // Runs on both the success and the error path.
        enforce_operation_min_timing(started, TimingOperation::PolicyLoad);
        result
    }

    /// Parses a policy from TOML text, rejects versions newer than
    /// `POLICY_VERSION`, and runs `validate`.
    ///
    /// # Errors
    ///
    /// - `CFG_PARSE_FAILED` when the text does not deserialize into this schema.
    /// - `CFG_VERSION_MISMATCH` when `version > POLICY_VERSION`.
    /// - Any error returned by `validate`.
    pub(crate) fn from_toml_str(contents: &str) -> Result<Self> {
        let policy: PolicyConfig = toml::from_str(contents).map_err(|e| {
            AgentError::new(
                CFG_PARSE_FAILED,
                "Configuration input could not be parsed",
                format!("operation=parse_policy_toml; Policy TOML syntax error: {e}"),
                "",
            )
        })?;
        if policy.version > POLICY_VERSION {
            return Err(AgentError::new(
                CFG_VERSION_MISMATCH,
                "Configuration version is not supported",
                format!(
                    "operation=validate_policy_version; Policy version too new (agent: {POLICY_VERSION}, policy: {}). Upgrade agent; file_version={}; expected_version={POLICY_VERSION}",
                    policy.version, policy.version
                ),
                "",
            ));
        }
        policy.validate()?;
        Ok(policy)
    }

    /// Checks the semantic constraints that deserialization alone cannot enforce.
    ///
    /// Runs the scoring, response, and deception checks in that order and
    /// stops at the first violation. `enforce_operation_min_timing`
    /// (`TimingOperation::PolicyValidate`) is called on both the success and
    /// the failure path.
    ///
    /// # Errors
    ///
    /// Returns an `AgentError` describing the first violated constraint.
    pub fn validate(&self) -> Result<()> {
        let started = Instant::now();
        let result = self
            .validate_scoring()
            .and_then(|()| self.validate_response())
            .and_then(|()| self.validate_deception());
        enforce_operation_min_timing(started, TimingOperation::PolicyValidate);
        result
    }

    /// Validates `scoring.*`: threshold, correlation window, event cap, and
    /// that every weight is a finite number.
    fn validate_scoring(&self) -> Result<()> {
        let scoring = &self.scoring;
        if !(0.0..=100.0).contains(&scoring.alert_threshold) {
            return Err(AgentError::new(
                CFG_INVALID_VALUE,
                "Configuration contains an invalid value",
                format!(
                    "operation=validate_policy_scoring; field=scoring.alert_threshold; reason=scoring.alert_threshold must be within valid range; actual_value={}; expected_range=0-100",
                    scoring.alert_threshold
                ),
                "scoring.alert_threshold",
            ));
        }
        if !(1..=3600).contains(&scoring.correlation_window_secs) {
            return Err(AgentError::new(
                CFG_INVALID_VALUE,
                "Configuration contains an invalid value",
                format!(
                    "operation=validate_policy_scoring; field=scoring.correlation_window_secs; reason=scoring.correlation_window_secs must be within valid range; actual_value={}; expected_range=1-3600",
                    scoring.correlation_window_secs
                ),
                "scoring.correlation_window_secs",
            ));
        }
        if !(1..=100_000).contains(&scoring.max_events_in_memory) {
            return Err(AgentError::new(
                CFG_INVALID_VALUE,
                "Configuration contains an invalid value",
                format!(
                    "operation=validate_policy_scoring; field=scoring.max_events_in_memory; reason=scoring.max_events_in_memory must be within valid range; actual_value={}; expected_range=1-100000",
                    scoring.max_events_in_memory
                ),
                "scoring.max_events_in_memory",
            ));
        }
        // TOML accepts `nan` and `inf` float literals. A NaN score fails every
        // comparison in `Severity::from_score` and is classified as `Low`, so a
        // non-finite weight could silently suppress detections.
        let weights = &scoring.weights;
        for (field, value) in [
            ("scoring.weights.artifact_access", weights.artifact_access),
            ("scoring.weights.suspicious_process", weights.suspicious_process),
            ("scoring.weights.rapid_enumeration", weights.rapid_enumeration),
            ("scoring.weights.off_hours_activity", weights.off_hours_activity),
            ("scoring.weights.ancestry_suspicious", weights.ancestry_suspicious),
        ] {
            if !value.is_finite() {
                return Err(AgentError::new(
                    CFG_INVALID_VALUE,
                    "Configuration contains an invalid value",
                    format!(
                        "operation=validate_policy_scoring; field={field}; reason={field} must be a finite number; actual_value={value}"
                    ),
                    field,
                ));
            }
        }
        Ok(())
    }

    /// Validates `response.*`:
    /// - the rule list is non-empty;
    /// - the cooldown is non-zero;
    /// - there is at most one rule per severity;
    /// - every custom condition is registered;
    /// - every `min_confidence` threshold is finite.
    fn validate_response(&self) -> Result<()> {
        if self.response.rules.is_empty() {
            return Err(AgentError::new(
                CFG_MISSING_REQUIRED,
                "Required configuration is missing",
                "operation=validate_policy_response; response.rules cannot be empty; impact=no_response_actions",
                "response.rules",
            ));
        }
        if self.response.cooldown_secs == 0 {
            return Err(AgentError::new(
                CFG_INVALID_VALUE,
                "Configuration contains an invalid value",
                "operation=validate_policy_response; field=response.cooldown_secs; response.cooldown_secs cannot be zero",
                "response.cooldown_secs",
            ));
        }
        let mut seen_severities = HashSet::with_capacity(self.response.rules.len());
        for (idx, rule) in self.response.rules.iter().enumerate() {
            // `insert` returns false when an earlier rule already claimed this severity.
            if !seen_severities.insert(&rule.severity) {
                return Err(AgentError::new(
                    CFG_VALIDATION_FAILED,
                    "Configuration validation failed",
                    format!(
                        "operation=validate_policy; Duplicate response rule for severity: {}",
                        rule.severity
                    ),
                    "",
                ));
            }
            for condition in &rule.conditions {
                if let ResponseCondition::Custom { name, .. } = condition
                    && !self.registered_custom_conditions.contains(name)
                {
                    return Err(AgentError::new(
                        CFG_VALIDATION_FAILED,
                        "Configuration validation failed",
                        format!(
                            "operation=validate_policy; Custom condition '{name}' not in registered_custom_conditions. Register it to prevent policy injection attacks."
                        ),
                        "",
                    ));
                }
                // A NaN threshold can never be met, so the rule would silently never fire.
                if let ResponseCondition::MinConfidence { threshold } = condition
                    && !threshold.is_finite()
                {
                    return Err(AgentError::new(
                        CFG_INVALID_VALUE,
                        "Configuration contains an invalid value",
                        format!(
                            "operation=validate_policy_response; field=response.rules[{idx}].conditions.threshold; reason=min_confidence threshold must be a finite number; actual_value={threshold}"
                        ),
                        "response.rules.conditions.threshold",
                    ));
                }
            }
        }
        Ok(())
    }

    /// Validates `deception.*`: rejects empty `suspicious_processes` entries.
    ///
    /// `is_suspicious_process` does substring matching, and an empty needle
    /// matches every haystack, so a single `""` entry would flag every process.
    fn validate_deception(&self) -> Result<()> {
        if let Some(idx) = self
            .deception
            .suspicious_processes
            .iter()
            .position(String::is_empty)
        {
            return Err(AgentError::new(
                CFG_INVALID_VALUE,
                "Configuration contains an invalid value",
                format!(
                    "operation=validate_policy_deception; field=deception.suspicious_processes; reason=deception.suspicious_processes entries cannot be empty; index={idx}"
                ),
                "deception.suspicious_processes",
            ));
        }
        Ok(())
    }

    /// Returns `true` when any `deception.suspicious_processes` entry occurs in
    /// `name` as an ASCII case-insensitive substring.
    ///
    /// `enforce_operation_min_timing`
    /// (`TimingOperation::PolicySuspiciousCheckLegacy`) is called before returning.
    #[inline]
    #[must_use]
    pub fn is_suspicious_process(&self, name: &str) -> bool {
        let started = Instant::now();
        let found = self
            .deception
            .suspicious_processes
            .iter()
            .any(|pattern| contains_ascii_case_insensitive(name, pattern.as_str()));
        enforce_operation_min_timing(started, TimingOperation::PolicySuspiciousCheckLegacy);
        found
    }
}
/// Reports whether `needle` occurs anywhere in `haystack`, treating ASCII
/// letters case-insensitively.
///
/// Non-ASCII bytes must match exactly. An empty `needle` matches every
/// haystack, as `str::contains("")` does.
#[inline]
fn contains_ascii_case_insensitive(haystack: &str, needle: &str) -> bool {
    let (hay, pat) = (haystack.as_bytes(), needle.as_bytes());
    if pat.is_empty() {
        return true;
    }
    // `windows` yields nothing when `pat` is longer than `hay`, which gives `false`.
    hay.windows(pat.len())
        .any(|window| window.eq_ignore_ascii_case(pat))
}
impl Default for PolicyConfig {
fn default() -> Self {
Self {
version: POLICY_VERSION,
scoring: ScoringPolicy {
correlation_window_secs: 300,
alert_threshold: 50.0,
max_events_in_memory: 10_000,
enable_time_scoring: true,
enable_ancestry_tracking: true,
weights: ScoringWeights::default(),
business_hours_start: 9,
business_hours_end: 17,
},
response: ResponsePolicy {
rules: vec![
ResponseRule {
severity: Severity::Low,
conditions: vec![],
action: ActionType::Log,
},
ResponseRule {
severity: Severity::Medium,
conditions: vec![],
action: ActionType::Alert,
},
ResponseRule {
severity: Severity::High,
conditions: vec![ResponseCondition::MinConfidence { threshold: 70.0 }],
action: ActionType::KillProcess,
},
ResponseRule {
severity: Severity::Critical,
conditions: vec![
ResponseCondition::MinConfidence { threshold: 85.0 },
ResponseCondition::MinSignalTypes { count: 2 },
],
action: ActionType::IsolateHost,
},
],
cooldown_secs: 60,
max_kills_per_incident: 10,
dry_run: false,
},
deception: DeceptionPolicy {
suspicious_processes: vec![
"mimikatz".to_string(),
"procdump".to_string(),
"lazagne".to_string(),
]
.into_boxed_slice(),
suspicious_patterns: Box::new([]),
},
registered_custom_conditions: HashSet::new(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in policy must satisfy its own validation rules.
    /// `validate` is synchronous, so no async runtime is needed.
    #[test]
    fn test_default_policy_validates() {
        let outcome = PolicyConfig::default().validate();
        assert!(outcome.is_ok());
    }

    /// Spot-checks one score inside each severity bucket.
    #[test]
    fn test_severity_from_score() {
        let cases = [
            (90.0, Severity::Critical),
            (70.0, Severity::High),
            (50.0, Severity::Medium),
            (30.0, Severity::Low),
        ];
        for (score, expected) in cases {
            assert_eq!(Severity::from_score(score), expected, "score={score}");
        }
    }

    /// Default patterns match regardless of ASCII case; unrelated names don't.
    #[test]
    fn test_suspicious_process_case_insensitive() {
        let policy = PolicyConfig::default();
        for candidate in ["MIMIKATZ.exe", "mimikatz", "MiMiKaTz"] {
            assert!(policy.is_suspicious_process(candidate), "{candidate} should match");
        }
        assert!(!policy.is_suspicious_process("firefox"));
    }

    /// A custom condition is rejected until its name is registered.
    #[test]
    fn test_custom_condition_validation() {
        let mut policy = PolicyConfig::default();
        let rules = &mut policy.response.rules;
        // Drop the stock Medium rule so the replacement doesn't trip the
        // duplicate-severity check.
        rules.retain(|rule| rule.severity != Severity::Medium);
        rules.push(ResponseRule {
            severity: Severity::Medium,
            conditions: vec![ResponseCondition::Custom {
                name: "unregistered".to_string(),
                params: HashMap::new(),
            }],
            action: ActionType::Log,
        });
        assert!(policy.validate().is_err());

        policy
            .registered_custom_conditions
            .insert("unregistered".to_string());
        assert!(policy.validate().is_ok());
    }

    /// `max_events_in_memory` above the cap fails; a value inside it passes.
    #[test]
    fn test_max_events_validation() {
        let mut policy = PolicyConfig::default();
        for (limit, accepted) in [(150_000, false), (50_000, true)] {
            policy.scoring.max_events_in_memory = limit;
            assert_eq!(
                policy.validate().is_ok(),
                accepted,
                "max_events_in_memory={limit}"
            );
        }
    }
}