use std::collections::{HashMap, HashSet, VecDeque};
use std::io::{Read as IoRead, Write as IoWrite};
use flate2::Compression;
use flate2::read::DeflateDecoder;
use flate2::write::DeflateEncoder;
use serde::Serialize;
use rsigma_parser::{
ConditionExpr, ConditionOperator, CorrelationCondition, CorrelationRule, CorrelationType,
FieldAlias, Level,
};
use crate::error::{EvalError, Result};
use crate::event::Event;
/// A correlation rule lowered into the form the engine executes.
///
/// Produced by [`compile_correlation`]: group-by fields have their aliases
/// resolved, the trigger condition is compiled to numeric predicates, and
/// optional engine knobs are read from `rsigma.*` custom attributes.
#[derive(Debug, Clone)]
pub struct CompiledCorrelation {
    /// Rule id copied from the source rule, if present.
    pub id: Option<String>,
    /// Rule name copied from the source rule, if present.
    pub name: Option<String>,
    /// Human-readable rule title.
    pub title: String,
    /// Severity level, if the source rule declared one.
    pub level: Option<Level>,
    /// Tags copied from the source rule.
    pub tags: Vec<String>,
    /// Which correlation algorithm this rule uses (event_count, temporal, ...).
    pub correlation_type: CorrelationType,
    /// References (ids/names) to the base detection rules being correlated.
    pub rule_refs: Vec<String>,
    /// Fields events are grouped by before windowing (aliases resolved per rule).
    pub group_by: Vec<GroupByField>,
    /// Sliding window length in seconds.
    pub timespan_secs: u64,
    /// Numeric trigger condition evaluated against the window's aggregate.
    pub condition: CompiledCondition,
    /// Extended boolean expression; only set for temporal correlation types.
    pub extended_expr: Option<ConditionExpr>,
    /// Whether matches of the referenced rules are also emitted on their own.
    pub generate: bool,
    /// Optional alert suppression interval (from `rsigma.suppress`), in seconds.
    pub suppress_secs: Option<u64>,
    /// Optional engine action (from `rsigma.action`).
    pub action: Option<crate::correlation_engine::CorrelationAction>,
    /// Optional event retention mode (from `rsigma.correlation_event_mode`).
    pub event_mode: Option<crate::correlation_engine::CorrelationEventMode>,
    /// Optional cap on buffered correlation events (from `rsigma.max_correlation_events`).
    pub max_events: Option<usize>,
}
/// A group-by field: either a literal field name, or an alias that maps to
/// a different concrete field per referenced rule.
#[derive(Debug, Clone)]
pub enum GroupByField {
    Direct(String),
    Aliased {
        alias: String,
        mapping: HashMap<String, String>,
    },
}
impl GroupByField {
    /// The stable display name of this field (the alias name for aliased fields).
    pub fn name(&self) -> &str {
        match self {
            GroupByField::Direct(field) => field,
            GroupByField::Aliased { alias, .. } => alias,
        }
    }
    /// Resolve the concrete field name for the first rule in `rule_refs`
    /// that has a mapping; falls back to the alias name when none matches.
    pub fn resolve(&self, rule_refs: &[&str]) -> &str {
        match self {
            GroupByField::Direct(field) => field,
            GroupByField::Aliased { alias, mapping } => rule_refs
                .iter()
                .find_map(|r| mapping.get(*r))
                .map(String::as_str)
                .unwrap_or(alias),
        }
    }
}
/// The numeric trigger of a compiled correlation: a set of AND-ed
/// comparison predicates, plus the field the aggregated value comes from
/// (when the correlation type aggregates a field).
#[derive(Debug, Clone)]
pub struct CompiledCondition {
    pub field: Option<String>,
    pub predicates: Vec<(ConditionOperator, f64)>,
}
impl CompiledCondition {
    /// True when `value` satisfies every predicate (an empty predicate
    /// list is trivially satisfied).
    pub fn check(&self, value: f64) -> bool {
        for (op, threshold) in &self.predicates {
            let satisfied = match op {
                ConditionOperator::Lt => value < *threshold,
                ConditionOperator::Lte => value <= *threshold,
                ConditionOperator::Gt => value > *threshold,
                ConditionOperator::Gte => value >= *threshold,
                // Float equality uses an epsilon band around the threshold.
                ConditionOperator::Eq => (value - *threshold).abs() < f64::EPSILON,
                ConditionOperator::Neq => (value - *threshold).abs() >= f64::EPSILON,
            };
            if !satisfied {
                return false;
            }
        }
        true
    }
}
/// The grouping key for one correlation bucket: one optional value per
/// group-by field, in the same order as the rule's `group_by` list.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, serde::Deserialize)]
pub struct GroupKey(pub Vec<Option<String>>);
impl GroupKey {
    /// Build a key by resolving each group-by field against `rule_refs`
    /// and reading the value from `event`; missing fields become `None`.
    pub fn extract(event: &Event, group_by: &[GroupByField], rule_refs: &[&str]) -> Self {
        let mut values = Vec::with_capacity(group_by.len());
        for field in group_by {
            let resolved = field.resolve(rule_refs);
            values.push(event.get_field(resolved).and_then(value_to_string));
        }
        GroupKey(values)
    }
    /// Rebuild a key from `(name, value)` pairs, matching each group-by
    /// field by its display name; unmatched fields become `None`.
    pub fn from_pairs(pairs: &[(String, String)], group_by: &[GroupByField]) -> Self {
        let values = group_by
            .iter()
            .map(|field| {
                let wanted = field.name();
                pairs
                    .iter()
                    .find_map(|(k, v)| if k == wanted { Some(v.clone()) } else { None })
            })
            .collect();
        GroupKey(values)
    }
    /// Render the key as `(name, value)` pairs, skipping `None` slots.
    pub fn to_pairs(&self, group_by: &[GroupByField]) -> Vec<(String, String)> {
        let mut pairs = Vec::new();
        for (field, value) in group_by.iter().zip(&self.0) {
            if let Some(v) = value {
                pairs.push((field.name().to_string(), v.clone()));
            }
        }
        pairs
    }
}
fn value_to_string(v: &serde_json::Value) -> Option<String> {
match v {
serde_json::Value::String(s) => Some(s.clone()),
serde_json::Value::Number(n) => Some(n.to_string()),
serde_json::Value::Bool(b) => Some(b.to_string()),
_ => None,
}
}
/// Deflate level for buffered events: favor speed over ratio, since events
/// are compressed on the ingest path in `EventBuffer::push`.
const COMPRESSION_LEVEL: Compression = Compression::fast();
/// A bounded FIFO of deflate-compressed JSON events, stored as
/// `(timestamp, compressed bytes)` pairs. Serialized via
/// `event_buffer_serde` so the bytes survive text formats (base64).
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct EventBuffer {
    #[serde(with = "event_buffer_serde")]
    entries: VecDeque<(i64, Vec<u8>)>,
    // Maximum number of events retained; oldest entries are evicted first.
    max_events: usize,
}
mod event_buffer_serde {
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::collections::VecDeque;
#[derive(Serialize, Deserialize)]
struct Entry {
ts: i64,
#[serde(with = "base64_bytes")]
data: Vec<u8>,
}
mod base64_bytes {
use base64::Engine as _;
use base64::engine::general_purpose::STANDARD;
use serde::{Deserializer, Serializer};
pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
s.serialize_str(&STANDARD.encode(bytes))
}
pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
let s: String = serde::Deserialize::deserialize(d)?;
STANDARD.decode(s).map_err(serde::de::Error::custom)
}
}
pub fn serialize<S: Serializer>(
entries: &VecDeque<(i64, Vec<u8>)>,
s: S,
) -> Result<S::Ok, S::Error> {
let v: Vec<Entry> = entries
.iter()
.map(|(ts, data)| Entry {
ts: *ts,
data: data.clone(),
})
.collect();
v.serialize(s)
}
pub fn deserialize<'de, D: Deserializer<'de>>(
d: D,
) -> Result<VecDeque<(i64, Vec<u8>)>, D::Error> {
let v: Vec<Entry> = Vec::deserialize(d)?;
Ok(v.into_iter().map(|e| (e.ts, e.data)).collect())
}
}
impl EventBuffer {
    /// Create a buffer that retains at most `max_events` compressed events.
    pub fn new(max_events: usize) -> Self {
        EventBuffer {
            // Cap the eager allocation at 64 slots; the deque still grows
            // on demand for larger limits.
            entries: VecDeque::with_capacity(max_events.min(64)),
            max_events,
        }
    }
    /// Compress `event` and append it, evicting the oldest entry when full.
    /// Events that fail to serialize or compress are dropped silently
    /// (buffering is best-effort).
    pub fn push(&mut self, ts: i64, event: &serde_json::Value) {
        // Fix: a zero-capacity buffer must store nothing. Previously the
        // pop-then-push sequence always left one entry behind, violating
        // the `len() <= max_events` invariant.
        if self.max_events == 0 {
            return;
        }
        if let Some(compressed) = compress_event(event) {
            if self.entries.len() >= self.max_events {
                self.entries.pop_front();
            }
            self.entries.push_back((ts, compressed));
        }
    }
    /// Drop entries with timestamp strictly before `cutoff`. Only pops from
    /// the front, so this assumes timestamps arrive non-decreasing.
    pub fn evict(&mut self, cutoff: i64) {
        while self.entries.front().is_some_and(|(t, _)| *t < cutoff) {
            self.entries.pop_front();
        }
    }
    /// Decompress every buffered event, skipping entries that fail to inflate.
    pub fn decompress_all(&self) -> Vec<serde_json::Value> {
        self.entries
            .iter()
            .filter_map(|(_, compressed)| decompress_event(compressed))
            .collect()
    }
    /// True when no events are buffered.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
    /// Remove all buffered events (allocated capacity is retained).
    pub fn clear(&mut self) {
        self.entries.clear();
    }
    /// Total compressed payload size in bytes (excluding per-entry overhead).
    pub fn compressed_bytes(&self) -> usize {
        self.entries.iter().map(|(_, data)| data.len()).sum()
    }
    /// Number of buffered events.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
/// Serialize `event` to JSON and deflate-compress the bytes.
/// Returns `None` if either serialization or compression fails.
fn compress_event(event: &serde_json::Value) -> Option<Vec<u8>> {
    let raw = serde_json::to_vec(event).ok()?;
    let mut sink = DeflateEncoder::new(Vec::new(), COMPRESSION_LEVEL);
    sink.write_all(&raw).ok()?;
    sink.finish().ok()
}
/// Inflate a deflate-compressed JSON payload back into a value.
/// Returns `None` if decompression or JSON parsing fails.
fn decompress_event(compressed: &[u8]) -> Option<serde_json::Value> {
    let mut raw = Vec::new();
    let mut reader = DeflateDecoder::new(compressed);
    reader.read_to_end(&mut raw).ok()?;
    serde_json::from_slice(&raw).ok()
}
/// A lightweight reference to a correlated event: its timestamp plus an
/// optional identifier extracted by [`extract_event_id`].
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct EventRef {
    pub timestamp: i64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
}
/// A bounded FIFO of [`EventRef`]s — the memory-light alternative to
/// [`EventBuffer`] that keeps only timestamp + id instead of full events.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct EventRefBuffer {
    entries: VecDeque<EventRef>,
    // Maximum number of refs retained; oldest entries are evicted first.
    max_events: usize,
}
impl EventRefBuffer {
    /// Create a buffer that retains at most `max_events` references.
    pub fn new(max_events: usize) -> Self {
        EventRefBuffer {
            // Cap the eager allocation at 64 slots; the deque grows on demand.
            entries: VecDeque::with_capacity(max_events.min(64)),
            max_events,
        }
    }
    /// Record a reference to `event`, evicting the oldest when full.
    pub fn push(&mut self, ts: i64, event: &serde_json::Value) {
        // Fix: a zero-capacity buffer must store nothing. Previously the
        // pop-then-push sequence always left one entry behind, violating
        // the `len() <= max_events` invariant.
        if self.max_events == 0 {
            return;
        }
        if self.entries.len() >= self.max_events {
            self.entries.pop_front();
        }
        let id = extract_event_id(event);
        self.entries.push_back(EventRef { timestamp: ts, id });
    }
    /// Drop refs with timestamp strictly before `cutoff`. Only pops from
    /// the front, so this assumes timestamps arrive non-decreasing.
    pub fn evict(&mut self, cutoff: i64) {
        while self.entries.front().is_some_and(|r| r.timestamp < cutoff) {
            self.entries.pop_front();
        }
    }
    /// Snapshot of all buffered references, oldest first.
    pub fn refs(&self) -> Vec<EventRef> {
        self.entries.iter().cloned().collect()
    }
    /// True when no references are buffered.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
    /// Remove all buffered references (allocated capacity is retained).
    pub fn clear(&mut self) {
        self.entries.clear();
    }
    /// Number of buffered references.
    pub fn len(&self) -> usize {
        self.entries.len()
    }
}
/// Best-effort extraction of an event identifier from common id fields.
///
/// Scans a fixed candidate list and returns the first field holding a
/// string or numeric value. Fix: a candidate that exists but holds an
/// unusable type (null, bool, object, ...) no longer aborts the scan —
/// the remaining candidates are still tried. Note the field names are
/// literal top-level keys (`"event.id"` is not a nested path).
fn extract_event_id(event: &serde_json::Value) -> Option<String> {
    const ID_FIELDS: &[&str] = &["id", "_id", "event_id", "EventRecordID", "event.id"];
    ID_FIELDS.iter().find_map(|field| match event.get(field)? {
        serde_json::Value::String(s) => Some(s.clone()),
        serde_json::Value::Number(n) => Some(n.to_string()),
        _ => None,
    })
}
/// Per-group sliding-window storage, one variant per family of
/// correlation types (see [`WindowState::new_for`]).
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub enum WindowState {
    /// Timestamps of matching events (event_count).
    EventCount { timestamps: VecDeque<i64> },
    /// Timestamped field values; distinct values are counted (value_count).
    ValueCount { entries: VecDeque<(i64, String)> },
    /// Hit timestamps keyed by rule ref (temporal / temporal_ordered).
    Temporal {
        rule_hits: HashMap<String, VecDeque<i64>>,
    },
    /// Timestamped numeric samples (value_sum/avg/percentile/median).
    NumericAgg { entries: VecDeque<(i64, f64)> },
}
impl WindowState {
    /// Create the empty window variant appropriate for `corr_type`.
    pub fn new_for(corr_type: CorrelationType) -> Self {
        match corr_type {
            CorrelationType::EventCount => WindowState::EventCount {
                timestamps: VecDeque::new(),
            },
            CorrelationType::ValueCount => WindowState::ValueCount {
                entries: VecDeque::new(),
            },
            // Ordered and unordered temporal correlations share the same
            // per-rule hit store; ordering is only enforced at check time.
            CorrelationType::Temporal | CorrelationType::TemporalOrdered => WindowState::Temporal {
                rule_hits: HashMap::new(),
            },
            CorrelationType::ValueSum
            | CorrelationType::ValueAvg
            | CorrelationType::ValuePercentile
            | CorrelationType::ValueMedian => WindowState::NumericAgg {
                entries: VecDeque::new(),
            },
        }
    }
    /// Drop entries with a timestamp strictly before `cutoff`.
    ///
    /// Only pops from the front of each deque, so this assumes entries were
    /// pushed with non-decreasing timestamps — NOTE(review): confirm callers
    /// guarantee arrival order; out-of-order old entries would survive.
    pub fn evict(&mut self, cutoff: i64) {
        match self {
            WindowState::EventCount { timestamps } => {
                while timestamps.front().is_some_and(|&t| t < cutoff) {
                    timestamps.pop_front();
                }
            }
            WindowState::ValueCount { entries } => {
                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
                    entries.pop_front();
                }
            }
            WindowState::Temporal { rule_hits } => {
                for timestamps in rule_hits.values_mut() {
                    while timestamps.front().is_some_and(|&t| t < cutoff) {
                        timestamps.pop_front();
                    }
                }
                // Drop rules whose hits have all expired so `is_empty`
                // reflects the true window contents.
                rule_hits.retain(|_, ts| !ts.is_empty());
            }
            WindowState::NumericAgg { entries } => {
                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
                    entries.pop_front();
                }
            }
        }
    }
    /// True when the window holds no data at all.
    pub fn is_empty(&self) -> bool {
        match self {
            WindowState::EventCount { timestamps } => timestamps.is_empty(),
            WindowState::ValueCount { entries } => entries.is_empty(),
            WindowState::Temporal { rule_hits } => rule_hits.is_empty(),
            WindowState::NumericAgg { entries } => entries.is_empty(),
        }
    }
    /// Most recent timestamp stored in the window, if any.
    pub fn latest_timestamp(&self) -> Option<i64> {
        match self {
            WindowState::EventCount { timestamps } => timestamps.back().copied(),
            WindowState::ValueCount { entries } => entries.back().map(|(t, _)| *t),
            // Temporal hits are split per rule, so take the max of the backs.
            WindowState::Temporal { rule_hits } => {
                rule_hits.values().filter_map(|ts| ts.back().copied()).max()
            }
            WindowState::NumericAgg { entries } => entries.back().map(|(t, _)| *t),
        }
    }
    /// Remove all window contents.
    pub fn clear(&mut self) {
        match self {
            WindowState::EventCount { timestamps } => timestamps.clear(),
            WindowState::ValueCount { entries } => entries.clear(),
            WindowState::Temporal { rule_hits } => rule_hits.clear(),
            WindowState::NumericAgg { entries } => entries.clear(),
        }
    }
    /// Record an event timestamp; no-op unless this is an `EventCount` window.
    pub fn push_event_count(&mut self, ts: i64) {
        if let WindowState::EventCount { timestamps } = self {
            timestamps.push_back(ts);
        }
    }
    /// Record a field value occurrence; no-op unless this is a `ValueCount` window.
    pub fn push_value_count(&mut self, ts: i64, value: String) {
        if let WindowState::ValueCount { entries } = self {
            entries.push_back((ts, value));
        }
    }
    /// Record a hit for `rule_ref`; no-op unless this is a `Temporal` window.
    pub fn push_temporal(&mut self, ts: i64, rule_ref: &str) {
        if let WindowState::Temporal { rule_hits } = self {
            rule_hits
                .entry(rule_ref.to_string())
                .or_default()
                .push_back(ts);
        }
    }
    /// Record a numeric sample; no-op unless this is a `NumericAgg` window.
    pub fn push_numeric(&mut self, ts: i64, value: f64) {
        if let WindowState::NumericAgg { entries } = self {
            entries.push_back((ts, value));
        }
    }
    /// Evaluate the correlation condition against the current window.
    ///
    /// Computes the window's aggregate (count, distinct count, fired-rule
    /// count, sum, average, percentile, or median depending on
    /// `corr_type`) and returns `Some(aggregate)` when `condition` is
    /// satisfied, `None` otherwise.
    ///
    /// Special cases:
    /// - Temporal with an extended expression: the expression alone decides
    ///   the match; the numeric `condition` is bypassed and the fired-rule
    ///   count is returned.
    /// - TemporalOrdered: an extended expression (if any) acts as an extra
    ///   gate; ordering failure yields 0.0, which then goes through the
    ///   normal condition check.
    /// - ValuePercentile: the first predicate's threshold is reused as the
    ///   percentile rank (default 50) and the interpolated percentile is
    ///   returned directly, without applying `condition.check`.
    /// - A window-variant / correlation-type mismatch yields `None`.
    pub fn check_condition(
        &self,
        condition: &CompiledCondition,
        corr_type: CorrelationType,
        rule_refs: &[String],
        extended_expr: Option<&ConditionExpr>,
    ) -> Option<f64> {
        let value = match (self, corr_type) {
            (WindowState::EventCount { timestamps }, CorrelationType::EventCount) => {
                timestamps.len() as f64
            }
            (WindowState::ValueCount { entries }, CorrelationType::ValueCount) => {
                // Distinct values currently in the window.
                let distinct: HashSet<&String> = entries.iter().map(|(_, v)| v).collect();
                distinct.len() as f64
            }
            (WindowState::Temporal { rule_hits }, CorrelationType::Temporal) => {
                if let Some(expr) = extended_expr {
                    // The extended expression decides the match outright;
                    // the numeric condition below is never reached.
                    if eval_temporal_expr(expr, rule_hits) {
                        let fired: usize = rule_refs
                            .iter()
                            .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
                            .count();
                        return Some(fired as f64);
                    } else {
                        return None;
                    }
                }
                // No extended expression: the aggregate is the number of
                // referenced rules with at least one hit in the window.
                let fired: usize = rule_refs
                    .iter()
                    .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
                    .count();
                fired as f64
            }
            (WindowState::Temporal { rule_hits }, CorrelationType::TemporalOrdered) => {
                // Extended expression (when present) is an additional gate.
                if let Some(expr) = extended_expr
                    && !eval_temporal_expr(expr, rule_hits)
                {
                    return None;
                }
                // 0.0 on ordering failure flows into the condition check
                // below (which a typical `gte: 1` condition rejects).
                if check_temporal_ordered(rule_refs, rule_hits) {
                    rule_refs.len() as f64
                } else {
                    0.0
                }
            }
            (WindowState::NumericAgg { entries }, CorrelationType::ValueSum) => {
                entries.iter().map(|(_, v)| v).sum()
            }
            (WindowState::NumericAgg { entries }, CorrelationType::ValueAvg) => {
                if entries.is_empty() {
                    0.0
                } else {
                    let sum: f64 = entries.iter().map(|(_, v)| v).sum();
                    sum / entries.len() as f64
                }
            }
            (WindowState::NumericAgg { entries }, CorrelationType::ValuePercentile) => {
                if entries.is_empty() {
                    return None;
                }
                // Non-finite samples are dropped so the sort below cannot
                // hit an incomparable (NaN) pair.
                let mut values: Vec<f64> = entries
                    .iter()
                    .map(|(_, v)| *v)
                    .filter(|v| v.is_finite())
                    .collect();
                if values.is_empty() {
                    return None;
                }
                values.sort_by(|a, b| a.partial_cmp(b).expect("NaN filtered"));
                // The first predicate's threshold doubles as the percentile
                // rank; the interpolated value is returned directly without
                // `condition.check`.
                let percentile_rank = condition
                    .predicates
                    .first()
                    .map(|(_, threshold)| *threshold)
                    .unwrap_or(50.0);
                let pval = percentile_linear_interp(&values, percentile_rank);
                return Some(pval);
            }
            (WindowState::NumericAgg { entries }, CorrelationType::ValueMedian) => {
                if entries.is_empty() {
                    0.0
                } else {
                    let mut values: Vec<f64> = entries
                        .iter()
                        .map(|(_, v)| *v)
                        .filter(|v| v.is_finite())
                        .collect();
                    if values.is_empty() {
                        return None;
                    }
                    values.sort_by(|a, b| a.partial_cmp(b).expect("NaN filtered"));
                    let mid = values.len() / 2;
                    // Even sample count: average the two middle values.
                    if values.len().is_multiple_of(2) && values.len() >= 2 {
                        (values[mid - 1] + values[mid]) / 2.0
                    } else {
                        values[mid]
                    }
                }
            }
            // Window variant does not match the correlation type.
            _ => return None,
        };
        if condition.check(value) {
            Some(value)
        } else {
            None
        }
    }
}
/// True when every rule in `rule_refs` fired at least once AND there exists
/// a non-decreasing sequence of hit timestamps following the rule order
/// (equal timestamps count as ordered). An empty `rule_refs` matches.
fn check_temporal_ordered(
    rule_refs: &[String],
    rule_hits: &HashMap<String, VecDeque<i64>>,
) -> bool {
    if rule_refs.is_empty() {
        return true;
    }
    // Quick reject: every referenced rule must have at least one hit.
    let all_fired = rule_refs
        .iter()
        .all(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()));
    if !all_fired {
        return false;
    }
    // Backtracking search for timestamps t0 <= t1 <= ... along the rules.
    fn search(refs: &[String], hits: &HashMap<String, VecDeque<i64>>, floor: i64) -> bool {
        let Some((first, rest)) = refs.split_first() else {
            return true;
        };
        match hits.get(first) {
            Some(stamps) => stamps.iter().any(|&t| t >= floor && search(rest, hits, t)),
            None => false,
        }
    }
    search(rule_refs, rule_hits, i64::MIN)
}
/// Evaluate a boolean condition expression over the rules that currently
/// have hits in the window. Selector expressions are not supported in
/// temporal conditions and evaluate to false.
fn eval_temporal_expr(expr: &ConditionExpr, rule_hits: &HashMap<String, VecDeque<i64>>) -> bool {
    match expr {
        // An identifier is true iff that rule has at least one hit.
        ConditionExpr::Identifier(name) => {
            rule_hits.get(name.as_str()).is_some_and(|ts| !ts.is_empty())
        }
        ConditionExpr::And(parts) => parts.iter().all(|p| eval_temporal_expr(p, rule_hits)),
        ConditionExpr::Or(parts) => parts.iter().any(|p| eval_temporal_expr(p, rule_hits)),
        ConditionExpr::Not(inner) => !eval_temporal_expr(inner, rule_hits),
        ConditionExpr::Selector { .. } => false,
    }
}
/// Percentile of a pre-sorted slice using linear interpolation between the
/// two closest ranks. `percentile` is clamped to [0, 100]; an empty slice
/// yields 0.0.
fn percentile_linear_interp(values: &[f64], percentile: f64) -> f64 {
    let n = values.len();
    match n {
        0 => 0.0,
        1 => values[0],
        _ => {
            let rank = (percentile.clamp(0.0, 100.0) / 100.0) * (n - 1) as f64;
            let lo = rank.floor() as usize;
            let hi = rank.ceil() as usize;
            if lo == hi || hi >= n {
                // Rank fell exactly on a sample (or at the upper edge).
                values[lo.min(n - 1)]
            } else {
                let frac = rank - lo as f64;
                values[lo] + frac * (values[hi] - values[lo])
            }
        }
    }
}
/// Lower a parsed [`CorrelationRule`] into its executable form.
///
/// Resolves group-by aliases, compiles the trigger condition, and reads
/// optional engine knobs from the rule's `rsigma.*` custom attributes
/// (unparseable attribute values are silently ignored).
///
/// # Errors
/// Propagates errors from [`compile_condition`].
pub fn compile_correlation(rule: &CorrelationRule) -> Result<CompiledCorrelation> {
    // Index the declared aliases by name for O(1) lookup below.
    let aliases: HashMap<&str, &FieldAlias> = rule
        .aliases
        .iter()
        .map(|a| (a.alias.as_str(), a))
        .collect();
    let group_by = rule
        .group_by
        .iter()
        .map(|field_name| match aliases.get(field_name.as_str()) {
            Some(alias) => GroupByField::Aliased {
                alias: field_name.clone(),
                mapping: alias.mapping.clone(),
            },
            None => GroupByField::Direct(field_name.clone()),
        })
        .collect::<Vec<GroupByField>>();
    let (condition, extended_expr) = compile_condition(&rule.condition, rule.correlation_type)?;
    let attrs = &rule.custom_attributes;
    let suppress_secs = attrs
        .get("rsigma.suppress")
        .and_then(|v| rsigma_parser::Timespan::parse(v).ok())
        .map(|ts| ts.seconds);
    let action = attrs
        .get("rsigma.action")
        .and_then(|v| v.parse::<crate::correlation_engine::CorrelationAction>().ok());
    let event_mode = attrs
        .get("rsigma.correlation_event_mode")
        .and_then(|v| v.parse::<crate::correlation_engine::CorrelationEventMode>().ok());
    let max_events = attrs
        .get("rsigma.max_correlation_events")
        .and_then(|v| v.parse::<usize>().ok());
    Ok(CompiledCorrelation {
        id: rule.id.clone(),
        name: rule.name.clone(),
        title: rule.title.clone(),
        level: rule.level,
        tags: rule.tags.clone(),
        correlation_type: rule.correlation_type,
        rule_refs: rule.rules.clone(),
        group_by,
        timespan_secs: rule.timespan.seconds,
        condition,
        extended_expr,
        generate: rule.generate,
        suppress_secs,
        action,
        event_mode,
        max_events,
    })
}
/// Compile a rule condition into numeric predicates, returning the
/// extended boolean expression separately when one is present.
///
/// # Errors
/// Returns [`EvalError::CorrelationError`] when an extended condition is
/// used with a non-temporal correlation type.
fn compile_condition(
    cond: &CorrelationCondition,
    corr_type: CorrelationType,
) -> Result<(CompiledCondition, Option<ConditionExpr>)> {
    match cond {
        CorrelationCondition::Threshold { predicates, field } => {
            let compiled = CompiledCondition {
                field: field.clone(),
                predicates: predicates
                    .iter()
                    .map(|(op, count)| (*op, *count as f64))
                    .collect(),
            };
            Ok((compiled, None))
        }
        CorrelationCondition::Extended(expr) => match corr_type {
            CorrelationType::Temporal | CorrelationType::TemporalOrdered => Ok((
                // The expression carries the real logic; the numeric
                // condition becomes a trivial "at least one rule fired".
                CompiledCondition {
                    field: None,
                    predicates: vec![(ConditionOperator::Gte, 1.0)],
                },
                Some(expr.clone()),
            )),
            _ => Err(EvalError::CorrelationError(
                "Extended conditions are only supported for temporal correlation types"
                    .to_string(),
            )),
        },
    }
}
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
#[test]
fn test_group_key_extract() {
let v = json!({"User": "admin", "Host": "srv01"});
let event = Event::from_value(&v);
let group_by = vec![
GroupByField::Direct("User".to_string()),
GroupByField::Direct("Host".to_string()),
];
let key = GroupKey::extract(&event, &group_by, &["rule1"]);
assert_eq!(
key.0,
vec![Some("admin".to_string()), Some("srv01".to_string())]
);
}
#[test]
fn test_group_key_missing_field() {
let v = json!({"User": "admin"});
let event = Event::from_value(&v);
let group_by = vec![
GroupByField::Direct("User".to_string()),
GroupByField::Direct("Host".to_string()),
];
let key = GroupKey::extract(&event, &group_by, &["rule1"]);
assert_eq!(key.0, vec![Some("admin".to_string()), None]);
}
#[test]
fn test_group_key_aliased() {
let v = json!({"source.ip": "10.0.0.1"});
let event = Event::from_value(&v);
let group_by = vec![GroupByField::Aliased {
alias: "internal_ip".to_string(),
mapping: HashMap::from([
("rule_a".to_string(), "source.ip".to_string()),
("rule_b".to_string(), "destination.ip".to_string()),
]),
}];
let key = GroupKey::extract(&event, &group_by, &["rule_a"]);
assert_eq!(key.0, vec![Some("10.0.0.1".to_string())]);
}
#[test]
fn test_condition_check() {
let cond = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 100.0)],
};
assert!(!cond.check(99.0));
assert!(cond.check(100.0));
assert!(cond.check(101.0));
}
#[test]
fn test_condition_check_range() {
let cond = CompiledCondition {
field: None,
predicates: vec![
(ConditionOperator::Gt, 100.0),
(ConditionOperator::Lte, 200.0),
],
};
assert!(!cond.check(100.0));
assert!(cond.check(101.0));
assert!(cond.check(200.0));
assert!(!cond.check(201.0));
}
#[test]
fn test_window_event_count() {
let mut state = WindowState::new_for(CorrelationType::EventCount);
for i in 0..5 {
state.push_event_count(1000 + i);
}
let cond = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 5.0)],
};
assert_eq!(
state.check_condition(&cond, CorrelationType::EventCount, &[], None),
Some(5.0)
);
}
#[test]
fn test_window_event_count_eviction() {
let mut state = WindowState::new_for(CorrelationType::EventCount);
for i in 0..10 {
state.push_event_count(1000 + i);
}
state.evict(1005);
let cond = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 5.0)],
};
assert_eq!(
state.check_condition(&cond, CorrelationType::EventCount, &[], None),
Some(5.0)
);
}
#[test]
fn test_window_value_count() {
let mut state = WindowState::new_for(CorrelationType::ValueCount);
state.push_value_count(1000, "user1".to_string());
state.push_value_count(1001, "user2".to_string());
state.push_value_count(1002, "user1".to_string()); state.push_value_count(1003, "user3".to_string());
let cond = CompiledCondition {
field: Some("User".to_string()),
predicates: vec![(ConditionOperator::Gte, 3.0)],
};
assert_eq!(
state.check_condition(&cond, CorrelationType::ValueCount, &[], None),
Some(3.0)
);
}
#[test]
fn test_window_temporal() {
let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
let mut state = WindowState::new_for(CorrelationType::Temporal);
state.push_temporal(1000, "rule_a");
let cond = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 2.0)],
};
assert!(
state
.check_condition(&cond, CorrelationType::Temporal, &refs, None)
.is_none()
);
state.push_temporal(1001, "rule_b");
assert_eq!(
state.check_condition(&cond, CorrelationType::Temporal, &refs, None),
Some(2.0)
);
}
#[test]
fn test_window_temporal_ordered() {
let refs = vec![
"rule_a".to_string(),
"rule_b".to_string(),
"rule_c".to_string(),
];
let mut state = WindowState::new_for(CorrelationType::TemporalOrdered);
state.push_temporal(1000, "rule_a");
state.push_temporal(1001, "rule_b");
state.push_temporal(1002, "rule_c");
let cond = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 3.0)],
};
assert!(
state
.check_condition(&cond, CorrelationType::TemporalOrdered, &refs, None)
.is_some()
);
}
#[test]
fn test_window_temporal_ordered_wrong_order() {
let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
let mut state = WindowState::new_for(CorrelationType::TemporalOrdered);
state.push_temporal(1000, "rule_b");
state.push_temporal(1001, "rule_a");
let cond = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 2.0)],
};
assert!(
state
.check_condition(&cond, CorrelationType::TemporalOrdered, &refs, None)
.is_none()
);
}
#[test]
fn test_window_value_sum() {
let mut state = WindowState::new_for(CorrelationType::ValueSum);
state.push_numeric(1000, 500.0);
state.push_numeric(1001, 600.0);
let cond = CompiledCondition {
field: Some("bytes_sent".to_string()),
predicates: vec![(ConditionOperator::Gt, 1000.0)],
};
assert_eq!(
state.check_condition(&cond, CorrelationType::ValueSum, &[], None),
Some(1100.0)
);
}
#[test]
fn test_window_value_avg() {
let mut state = WindowState::new_for(CorrelationType::ValueAvg);
state.push_numeric(1000, 100.0);
state.push_numeric(1001, 200.0);
state.push_numeric(1002, 300.0);
let cond = CompiledCondition {
field: Some("bytes".to_string()),
predicates: vec![(ConditionOperator::Gte, 200.0)],
};
assert_eq!(
state.check_condition(&cond, CorrelationType::ValueAvg, &[], None),
Some(200.0)
);
}
#[test]
fn test_window_value_median() {
let mut state = WindowState::new_for(CorrelationType::ValueMedian);
state.push_numeric(1000, 10.0);
state.push_numeric(1001, 20.0);
state.push_numeric(1002, 30.0);
let cond = CompiledCondition {
field: Some("latency".to_string()),
predicates: vec![(ConditionOperator::Gte, 20.0)],
};
assert_eq!(
state.check_condition(&cond, CorrelationType::ValueMedian, &[], None),
Some(20.0)
);
}
#[test]
fn test_compile_correlation_basic() {
use rsigma_parser::parse_sigma_yaml;
let yaml = r#"
title: Base Rule
id: f305fd62-beca-47da-ad95-7690a0620084
logsource:
product: aws
service: cloudtrail
detection:
selection:
eventSource: "s3.amazonaws.com"
condition: selection
level: low
---
title: Multiple AWS bucket enumerations
id: be246094-01d3-4bba-88de-69e582eba0cc
status: experimental
correlation:
type: event_count
rules:
- f305fd62-beca-47da-ad95-7690a0620084
group-by:
- userIdentity.arn
timespan: 1h
condition:
gte: 100
level: high
"#;
let collection = parse_sigma_yaml(yaml).unwrap();
assert_eq!(collection.correlations.len(), 1);
let compiled = compile_correlation(&collection.correlations[0]).unwrap();
assert_eq!(compiled.correlation_type, CorrelationType::EventCount);
assert_eq!(compiled.timespan_secs, 3600);
assert_eq!(compiled.rule_refs.len(), 1);
assert_eq!(compiled.group_by.len(), 1);
assert!(compiled.condition.check(100.0));
assert!(!compiled.condition.check(99.0));
}
#[test]
fn test_eval_temporal_expr_and() {
let mut rule_hits = HashMap::new();
rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
rule_hits.insert("rule_b".to_string(), VecDeque::from([1001]));
let expr = ConditionExpr::And(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]);
assert!(eval_temporal_expr(&expr, &rule_hits));
}
#[test]
fn test_eval_temporal_expr_and_incomplete() {
let mut rule_hits = HashMap::new();
rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
let expr = ConditionExpr::And(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]);
assert!(!eval_temporal_expr(&expr, &rule_hits));
}
#[test]
fn test_eval_temporal_expr_or() {
let mut rule_hits = HashMap::new();
rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
let expr = ConditionExpr::Or(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]);
assert!(eval_temporal_expr(&expr, &rule_hits));
}
#[test]
fn test_eval_temporal_expr_not() {
let rule_hits = HashMap::new();
let expr = ConditionExpr::Not(Box::new(ConditionExpr::Identifier("rule_a".to_string())));
assert!(eval_temporal_expr(&expr, &rule_hits));
}
#[test]
fn test_eval_temporal_expr_complex() {
let mut rule_hits = HashMap::new();
rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
rule_hits.insert("rule_b".to_string(), VecDeque::from([1001]));
let expr = ConditionExpr::And(vec![
ConditionExpr::And(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]),
ConditionExpr::Not(Box::new(ConditionExpr::Identifier("rule_c".to_string()))),
]);
assert!(eval_temporal_expr(&expr, &rule_hits));
}
#[test]
fn test_check_condition_with_extended_expr() {
let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
let mut state = WindowState::new_for(CorrelationType::Temporal);
state.push_temporal(1000, "rule_a");
state.push_temporal(1001, "rule_b");
let cond = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 1.0)],
};
let expr = ConditionExpr::And(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]);
assert!(
state
.check_condition(&cond, CorrelationType::Temporal, &refs, Some(&expr))
.is_some()
);
let mut state2 = WindowState::new_for(CorrelationType::Temporal);
state2.push_temporal(1000, "rule_a");
assert!(
state2
.check_condition(&cond, CorrelationType::Temporal, &refs, Some(&expr))
.is_none()
);
}
#[test]
fn test_percentile_linear_interp_single() {
assert!((percentile_linear_interp(&[42.0], 50.0) - 42.0).abs() < f64::EPSILON);
}
#[test]
fn test_percentile_linear_interp_basic() {
let values = &[1.0, 2.0, 3.0, 4.0, 5.0];
assert!((percentile_linear_interp(values, 0.0) - 1.0).abs() < f64::EPSILON);
assert!((percentile_linear_interp(values, 25.0) - 2.0).abs() < f64::EPSILON);
assert!((percentile_linear_interp(values, 50.0) - 3.0).abs() < f64::EPSILON);
assert!((percentile_linear_interp(values, 75.0) - 4.0).abs() < f64::EPSILON);
assert!((percentile_linear_interp(values, 100.0) - 5.0).abs() < f64::EPSILON);
}
#[test]
fn test_percentile_linear_interp_interpolation() {
let values = &[10.0, 20.0, 30.0, 40.0];
assert!((percentile_linear_interp(values, 50.0) - 25.0).abs() < f64::EPSILON);
}
#[test]
fn test_percentile_linear_interp_1st_percentile() {
let values: Vec<f64> = (1..=100).map(|x| x as f64).collect();
let p1 = percentile_linear_interp(&values, 1.0);
assert!((p1 - 1.99).abs() < 0.01);
}
#[test]
fn test_value_percentile_check_condition() {
let mut state = WindowState::new_for(CorrelationType::ValuePercentile);
for i in 1..=100 {
state.push_numeric(1000 + i, i as f64);
}
let cond = CompiledCondition {
field: Some("latency".to_string()),
predicates: vec![(ConditionOperator::Lte, 50.0)],
};
let result = state.check_condition(&cond, CorrelationType::ValuePercentile, &[], None);
assert!(result.is_some());
let val = result.unwrap();
assert!((val - 50.5).abs() < 1.0, "expected ~50.5, got {val}");
}
#[test]
fn test_percentile_0th_and_100th() {
let values = &[5.0, 10.0, 15.0, 20.0];
assert!((percentile_linear_interp(values, 0.0) - 5.0).abs() < f64::EPSILON);
assert!((percentile_linear_interp(values, 100.0) - 20.0).abs() < f64::EPSILON);
}
#[test]
fn test_percentile_two_values() {
let values = &[10.0, 20.0];
assert!((percentile_linear_interp(values, 50.0) - 15.0).abs() < f64::EPSILON);
assert!((percentile_linear_interp(values, 25.0) - 12.5).abs() < f64::EPSILON);
}
#[test]
fn test_percentile_clamps_out_of_range() {
let values = &[1.0, 2.0, 3.0];
assert!((percentile_linear_interp(values, -10.0) - 1.0).abs() < f64::EPSILON);
assert!((percentile_linear_interp(values, 150.0) - 3.0).abs() < f64::EPSILON);
}
#[test]
fn test_value_percentile_empty_window() {
let state = WindowState::new_for(CorrelationType::ValuePercentile);
let cond = CompiledCondition {
field: Some("latency".to_string()),
predicates: vec![(ConditionOperator::Lte, 50.0)],
};
assert!(
state
.check_condition(&cond, CorrelationType::ValuePercentile, &[], None)
.is_none()
);
}
#[test]
fn test_extended_temporal_or_single_rule() {
let mut rule_hits = HashMap::new();
rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
let expr = ConditionExpr::Or(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]);
assert!(eval_temporal_expr(&expr, &rule_hits));
}
#[test]
fn test_extended_temporal_empty_hits() {
let rule_hits = HashMap::new();
let expr = ConditionExpr::And(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]);
assert!(!eval_temporal_expr(&expr, &rule_hits));
let expr_or = ConditionExpr::Or(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]);
assert!(!eval_temporal_expr(&expr_or, &rule_hits));
}
#[test]
fn test_extended_temporal_with_empty_deque() {
let mut rule_hits = HashMap::new();
rule_hits.insert("rule_a".to_string(), VecDeque::new());
rule_hits.insert("rule_b".to_string(), VecDeque::from([1000]));
let expr = ConditionExpr::And(vec![
ConditionExpr::Identifier("rule_a".to_string()),
ConditionExpr::Identifier("rule_b".to_string()),
]);
assert!(!eval_temporal_expr(&expr, &rule_hits));
}
#[test]
fn test_check_condition_temporal_no_extended_expr() {
let refs = vec![
"rule_a".to_string(),
"rule_b".to_string(),
"rule_c".to_string(),
];
let mut state = WindowState::new_for(CorrelationType::Temporal);
state.push_temporal(1000, "rule_a");
state.push_temporal(1001, "rule_b");
let cond = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 2.0)],
};
assert_eq!(
state.check_condition(&cond, CorrelationType::Temporal, &refs, None),
Some(2.0)
);
let cond3 = CompiledCondition {
field: None,
predicates: vec![(ConditionOperator::Gte, 3.0)],
};
assert!(
state
.check_condition(&cond3, CorrelationType::Temporal, &refs, None)
.is_none()
);
}
#[test]
fn test_event_buffer_push_and_decompress() {
let mut buf = EventBuffer::new(10);
let event = json!({"User": "admin", "action": "login", "src_ip": "10.0.0.1"});
buf.push(1000, &event);
assert_eq!(buf.len(), 1);
assert!(!buf.is_empty());
let events = buf.decompress_all();
assert_eq!(events.len(), 1);
assert_eq!(events[0], event);
}
#[test]
fn test_event_buffer_compression_saves_memory() {
let mut buf = EventBuffer::new(100);
let event = json!({
"User": "admin",
"action": "login",
"src_ip": "192.168.1.100",
"dst_ip": "10.0.0.1",
"EventTime": "2024-07-10T12:30:00Z",
"process": "sshd",
"host": "production-server-01.example.com",
"message": "Accepted password for admin from 192.168.1.100 port 22 ssh2",
"severity": "info",
"tags": ["authentication", "network", "linux"]
});
let raw_size = serde_json::to_vec(&event).unwrap().len();
buf.push(1000, &event);
let compressed_size = buf.compressed_bytes();
assert!(
compressed_size < raw_size,
"Compressed {compressed_size}B should be less than raw {raw_size}B"
);
let events = buf.decompress_all();
assert_eq!(events[0], event);
}
#[test]
fn test_event_buffer_max_events_cap() {
let mut buf = EventBuffer::new(3);
for i in 0..5 {
buf.push(1000 + i, &json!({"idx": i}));
}
assert_eq!(buf.len(), 3);
let events = buf.decompress_all();
assert_eq!(events[0], json!({"idx": 2}));
assert_eq!(events[1], json!({"idx": 3}));
assert_eq!(events[2], json!({"idx": 4}));
}
#[test]
fn test_event_buffer_eviction() {
    // evict(1003) removes everything timestamped before 1003, leaving the
    // two events pushed at 1003 and 1004 (idx 3 and 4).
    let mut buffer = EventBuffer::new(10);
    for idx in 0..5 {
        buffer.push(1000 + idx, &json!({"idx": idx}));
    }
    assert_eq!(buffer.len(), 5);
    buffer.evict(1003);
    assert_eq!(buffer.len(), 2);
    let remaining = buffer.decompress_all();
    assert_eq!(remaining[0], json!({"idx": 3}));
    assert_eq!(remaining[1], json!({"idx": 4}));
}
#[test]
fn test_event_buffer_clear() {
    // clear() must empty the buffer and release its compressed storage.
    let mut buffer = EventBuffer::new(10);
    buffer.push(1000, &json!({"a": 1}));
    buffer.push(1001, &json!({"b": 2}));
    assert_eq!(buffer.len(), 2);
    buffer.clear();
    assert!(buffer.is_empty());
    assert_eq!(buffer.len(), 0);
    // Byte accounting is reset along with the events themselves.
    assert_eq!(buffer.compressed_bytes(), 0);
}
#[test]
fn test_compress_decompress_roundtrip() {
    // Every JSON shape (null, number, string, nested object, mixed array)
    // must come back bit-identical after compress_event -> decompress_event.
    let samples = vec![
        json!(null),
        json!(42),
        json!("hello world"),
        json!({"nested": {"deep": [1, 2, 3]}}),
        json!([1, "two", null, true, {"five": 5}]),
    ];
    for val in samples {
        let packed = compress_event(&val).unwrap();
        let unpacked = decompress_event(&packed).unwrap();
        assert_eq!(unpacked, val, "Roundtrip failed for {val}");
    }
}
#[test]
fn test_event_ref_buffer_push_and_refs() {
    // Each stored ref carries the push timestamp plus an id extracted from
    // the event ("id" string, numeric "_id" stringified, or None if absent).
    let mut buffer = EventRefBuffer::new(10);
    buffer.push(1000, &json!({"id": "evt-1", "data": "hello"}));
    buffer.push(1001, &json!({"_id": 42, "data": "world"}));
    buffer.push(1002, &json!({"data": "no-id"}));
    assert_eq!(buffer.len(), 3);
    let stored = buffer.refs();
    assert_eq!(
        (stored[0].timestamp, stored[0].id.as_deref()),
        (1000, Some("evt-1"))
    );
    assert_eq!(
        (stored[1].timestamp, stored[1].id.as_deref()),
        (1001, Some("42"))
    );
    assert_eq!((stored[2].timestamp, stored[2].id.as_deref()), (1002, None));
}
#[test]
fn test_event_ref_buffer_max_cap() {
    // With capacity 3, five pushes leave only the three newest refs (e-2..e-4).
    let mut buffer = EventRefBuffer::new(3);
    for i in 0..5 {
        buffer.push(1000 + i, &json!({"id": format!("e-{i}")}));
    }
    assert_eq!(buffer.len(), 3);
    let stored = buffer.refs();
    assert_eq!(stored[0].id.as_deref(), Some("e-2"));
    assert_eq!(stored[1].id.as_deref(), Some("e-3"));
    assert_eq!(stored[2].id.as_deref(), Some("e-4"));
}
#[test]
fn test_event_ref_buffer_eviction() {
    // evict(1003) drops refs timestamped before 1003; the survivors are the
    // ones pushed at 1003 and 1004.
    let mut buffer = EventRefBuffer::new(10);
    for i in 0..5 {
        buffer.push(1000 + i, &json!({"id": format!("e-{i}")}));
    }
    buffer.evict(1003);
    assert_eq!(buffer.len(), 2);
    let stored = buffer.refs();
    assert_eq!(stored[0].timestamp, 1003);
    assert_eq!(stored[1].timestamp, 1004);
}
#[test]
fn test_event_ref_buffer_clear() {
    // clear() must leave the ref buffer empty.
    let mut buffer = EventRefBuffer::new(10);
    buffer.push(1000, &json!({"id": "a"}));
    buffer.push(1001, &json!({"id": "b"}));
    assert_eq!(buffer.len(), 2);
    buffer.clear();
    assert!(buffer.is_empty());
    assert_eq!(buffer.len(), 0);
}
#[test]
fn test_extract_event_id_common_fields() {
    // Table-driven: each commonly used id field ("id", "_id", "event_id",
    // "EventRecordID") is recognized and stringified; unknown fields yield None.
    let cases = vec![
        (json!({"id": "abc"}), Some("abc")),
        (json!({"_id": 123}), Some("123")),
        (json!({"event_id": "x-1"}), Some("x-1")),
        (json!({"EventRecordID": 999}), Some("999")),
        (json!({"no_id_field": true}), None),
    ];
    for (event, expected) in cases {
        assert_eq!(extract_event_id(&event), expected.map(str::to_string));
    }
}
#[test]
fn test_compile_correlation_with_custom_attributes() {
    use rsigma_parser::*;
    // The rsigma.* custom attributes on a correlation rule should map onto
    // the compiled options: event mode, event cap, suppression window, action.
    let mut custom_attributes = std::collections::HashMap::new();
    for (key, value) in vec![
        ("rsigma.correlation_event_mode", "refs"),
        ("rsigma.max_correlation_events", "25"),
        ("rsigma.suppress", "5m"),
        ("rsigma.action", "reset"),
    ] {
        custom_attributes.insert(key.to_string(), value.to_string());
    }
    let rule = CorrelationRule {
        title: "Test Corr".to_string(),
        id: Some("corr-1".to_string()),
        name: None,
        status: None,
        description: None,
        author: None,
        date: None,
        modified: None,
        references: vec![],
        tags: vec![],
        level: Some(Level::High),
        correlation_type: CorrelationType::EventCount,
        rules: vec!["rule-1".to_string()],
        group_by: vec!["User".to_string()],
        timespan: Timespan::parse("60s").unwrap(),
        condition: CorrelationCondition::Threshold {
            predicates: vec![(ConditionOperator::Gte, 5)],
            field: None,
        },
        aliases: vec![],
        generate: false,
        custom_attributes,
    };
    let compiled = compile_correlation(&rule).unwrap();
    assert_eq!(
        compiled.event_mode,
        Some(crate::correlation_engine::CorrelationEventMode::Refs)
    );
    assert_eq!(compiled.max_events, Some(25));
    // "5m" parses to 300 seconds of suppression.
    assert_eq!(compiled.suppress_secs, Some(300));
    assert_eq!(
        compiled.action,
        Some(crate::correlation_engine::CorrelationAction::Reset)
    );
}
}