Skip to main content

rsigma_eval/
correlation.rs

1//! Compiled correlation types, group key, window state, and compilation.
2//!
3//! Transforms the parser's `CorrelationRule` AST into an optimized
4//! `CompiledCorrelation` with associated `WindowState` for stateful evaluation.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::io::{Read as IoRead, Write as IoWrite};
8
9use flate2::Compression;
10use flate2::read::DeflateDecoder;
11use flate2::write::DeflateEncoder;
12use serde::Serialize;
13
14use rsigma_parser::{
15    ConditionExpr, ConditionOperator, CorrelationCondition, CorrelationRule, CorrelationType,
16    FieldAlias, Level,
17};
18
19use crate::error::{EvalError, Result};
20use crate::event::Event;
21
22// =============================================================================
23// Compiled types
24// =============================================================================
25
26/// Compiled form of a `CorrelationRule`, ready for stateful evaluation.
27#[derive(Debug, Clone)]
28pub struct CompiledCorrelation {
29    pub id: Option<String>,
30    pub name: Option<String>,
31    pub title: String,
32    pub level: Option<Level>,
33    pub tags: Vec<String>,
34    pub correlation_type: CorrelationType,
35    /// IDs or names of referenced rules (detection or other correlations).
36    pub rule_refs: Vec<String>,
37    /// Resolved group-by fields (may include aliases).
38    pub group_by: Vec<GroupByField>,
39    /// Time window in seconds.
40    pub timespan_secs: u64,
41    /// Compiled threshold condition.
42    pub condition: CompiledCondition,
43    /// Extended boolean condition expression for temporal correlations.
44    /// When set, evaluates this expression against fired rules instead of
45    /// a simple threshold count.
46    pub extended_expr: Option<ConditionExpr>,
47    /// Whether referenced detection rules should also generate standalone matches.
48    pub generate: bool,
49    /// Per-correlation suppression window in seconds, resolved from the
50    /// `rsigma.suppress` custom attribute. `None` means use engine default.
51    pub suppress_secs: Option<u64>,
52    /// Per-correlation action on match, resolved from the `rsigma.action`
53    /// custom attribute. `None` means use engine default.
54    pub action: Option<crate::correlation_engine::CorrelationAction>,
55    /// Event inclusion mode for this correlation.
56    /// `None` means use the engine default (`CorrelationConfig.correlation_event_mode`).
57    pub event_mode: Option<crate::correlation_engine::CorrelationEventMode>,
58    /// Maximum events to store per window group for event inclusion.
59    /// `None` means use the engine default (`CorrelationConfig.max_correlation_events`).
60    pub max_events: Option<usize>,
61}
62
/// One group-by dimension of a correlation, optionally aliased per referenced rule.
#[derive(Debug, Clone)]
pub enum GroupByField {
    /// The same field name is used by every referenced rule.
    Direct(String),
    /// An alias whose concrete field name differs per rule:
    /// `mapping` goes from rule reference (ID or name) to the event field name.
    Aliased {
        alias: String,
        mapping: HashMap<String, String>,
    },
}

impl GroupByField {
    /// Display name of this dimension: the field itself, or the alias name.
    pub fn name(&self) -> &str {
        let (GroupByField::Direct(label) | GroupByField::Aliased { alias: label, .. }) = self;
        label
    }

    /// Resolve the event field to read for a match produced by a given rule.
    ///
    /// `rule_refs` holds the identifiers of that rule (ID, name, …). They are
    /// tried in order against the alias mapping, and the first hit wins. If none
    /// of them is mapped, the alias name itself is returned. `Direct` fields
    /// always resolve to themselves.
    pub fn resolve(&self, rule_refs: &[&str]) -> &str {
        match self {
            GroupByField::Direct(field) => field.as_str(),
            GroupByField::Aliased { alias, mapping } => rule_refs
                .iter()
                .find_map(|candidate| mapping.get(*candidate))
                .map_or(alias.as_str(), String::as_str),
        }
    }
}
103
104/// Compiled threshold condition with one or more predicates (supports ranges).
105#[derive(Debug, Clone)]
106pub struct CompiledCondition {
107    /// Optional field name for value_count, value_sum, value_avg, value_percentile.
108    pub field: Option<String>,
109    /// One or more predicates to satisfy (all must be true for the condition to match).
110    pub predicates: Vec<(ConditionOperator, f64)>,
111}
112
113impl CompiledCondition {
114    /// Check if the given value satisfies all predicates.
115    pub fn check(&self, value: f64) -> bool {
116        self.predicates.iter().all(|(op, threshold)| match op {
117            ConditionOperator::Lt => value < *threshold,
118            ConditionOperator::Lte => value <= *threshold,
119            ConditionOperator::Gt => value > *threshold,
120            ConditionOperator::Gte => value >= *threshold,
121            ConditionOperator::Eq => (value - *threshold).abs() < f64::EPSILON,
122            ConditionOperator::Neq => (value - *threshold).abs() >= f64::EPSILON,
123        })
124    }
125}
126
127// =============================================================================
128// Group Key
129// =============================================================================
130
131/// Composite key for group-by partitioning.
132///
133/// Each element corresponds to a `GroupByField` value extracted from an event.
134/// `None` means the field was absent from the event.
135#[derive(Debug, Clone, Hash, Eq, PartialEq)]
136pub struct GroupKey(pub Vec<Option<String>>);
137
138impl GroupKey {
139    /// Extract a group key from an event given the group-by fields and the
140    /// rule reference identifiers (ID, name, etc.) that produced the detection match.
141    pub fn extract(event: &Event, group_by: &[GroupByField], rule_refs: &[&str]) -> Self {
142        let values = group_by
143            .iter()
144            .map(|field| {
145                let field_name = field.resolve(rule_refs);
146                event.get_field(field_name).and_then(value_to_string)
147            })
148            .collect();
149        GroupKey(values)
150    }
151
152    /// Build a group key from explicit field-value pairs (for chaining).
153    pub fn from_pairs(pairs: &[(String, String)], group_by: &[GroupByField]) -> Self {
154        let values = group_by
155            .iter()
156            .map(|field| {
157                let name = field.name();
158                pairs
159                    .iter()
160                    .find(|(k, _)| k == name)
161                    .map(|(_, v)| v.clone())
162            })
163            .collect();
164        GroupKey(values)
165    }
166
167    /// Convert to field-name/value pairs for output.
168    pub fn to_pairs(&self, group_by: &[GroupByField]) -> Vec<(String, String)> {
169        group_by
170            .iter()
171            .zip(self.0.iter())
172            .filter_map(|(field, value)| {
173                value
174                    .as_ref()
175                    .map(|v| (field.name().to_string(), v.clone()))
176            })
177            .collect()
178    }
179}
180
181/// Convert a JSON value to a string for group-key purposes.
182fn value_to_string(v: &serde_json::Value) -> Option<String> {
183    match v {
184        serde_json::Value::String(s) => Some(s.clone()),
185        serde_json::Value::Number(n) => Some(n.to_string()),
186        serde_json::Value::Bool(b) => Some(b.to_string()),
187        _ => None,
188    }
189}
190
191// =============================================================================
192// Compressed Event Buffer
193// =============================================================================
194
195/// Default compression level — fast compression (level 1) for minimal latency.
196/// Deflate level 1 still achieves ~2-3x compression on JSON while being very fast.
197const COMPRESSION_LEVEL: Compression = Compression::fast();
198
199/// Compressed event storage for correlation event inclusion.
200///
201/// Stores event JSON payloads as individually deflate-compressed blobs alongside
202/// their timestamps. This enables per-event eviction (matching `WindowState`
203/// eviction) while keeping memory usage low.
204///
205/// # Memory Model
206///
207/// Each stored event costs approximately `compressed_size + 24` bytes
208/// (8 for timestamp, 16 for Vec overhead). Typical JSON events (500B–5KB)
209/// compress to 100B–1KB with deflate, giving 3–5x memory savings.
210///
211/// The buffer enforces a hard cap (`max_events`) so memory is bounded at:
212///   `max_events × (avg_compressed_size + 24)` bytes per group key.
213#[derive(Debug, Clone)]
214pub struct EventBuffer {
215    /// (timestamp, deflate-compressed event JSON) pairs, ordered by timestamp.
216    entries: VecDeque<(i64, Vec<u8>)>,
217    /// Maximum number of events to retain. When exceeded, the oldest event is
218    /// evicted regardless of the time window.
219    max_events: usize,
220}
221
222impl EventBuffer {
223    /// Create a new event buffer with the given capacity cap.
224    pub fn new(max_events: usize) -> Self {
225        EventBuffer {
226            entries: VecDeque::with_capacity(max_events.min(64)),
227            max_events,
228        }
229    }
230
231    /// Compress and store an event. Evicts the oldest entry if at capacity.
232    pub fn push(&mut self, ts: i64, event: &serde_json::Value) {
233        // Compress the event JSON with deflate
234        if let Some(compressed) = compress_event(event) {
235            if self.entries.len() >= self.max_events {
236                self.entries.pop_front();
237            }
238            self.entries.push_back((ts, compressed));
239        }
240    }
241
242    /// Remove all entries older than the cutoff timestamp.
243    pub fn evict(&mut self, cutoff: i64) {
244        while self.entries.front().is_some_and(|(t, _)| *t < cutoff) {
245            self.entries.pop_front();
246        }
247    }
248
249    /// Decompress and return all stored events.
250    pub fn decompress_all(&self) -> Vec<serde_json::Value> {
251        self.entries
252            .iter()
253            .filter_map(|(_, compressed)| decompress_event(compressed))
254            .collect()
255    }
256
257    /// Returns true if there are no stored events.
258    pub fn is_empty(&self) -> bool {
259        self.entries.is_empty()
260    }
261
262    /// Clear all stored events.
263    pub fn clear(&mut self) {
264        self.entries.clear();
265    }
266
267    /// Total compressed bytes stored (for monitoring/diagnostics).
268    pub fn compressed_bytes(&self) -> usize {
269        self.entries.iter().map(|(_, data)| data.len()).sum()
270    }
271
272    /// Number of stored events.
273    pub fn len(&self) -> usize {
274        self.entries.len()
275    }
276}
277
278/// Compress an event JSON value using deflate.
279fn compress_event(event: &serde_json::Value) -> Option<Vec<u8>> {
280    let json_bytes = serde_json::to_vec(event).ok()?;
281    let mut encoder = DeflateEncoder::new(Vec::new(), COMPRESSION_LEVEL);
282    encoder.write_all(&json_bytes).ok()?;
283    encoder.finish().ok()
284}
285
286/// Decompress a deflate-compressed event back to a JSON value.
287fn decompress_event(compressed: &[u8]) -> Option<serde_json::Value> {
288    let mut decoder = DeflateDecoder::new(compressed);
289    let mut json_bytes = Vec::new();
290    decoder.read_to_end(&mut json_bytes).ok()?;
291    serde_json::from_slice(&json_bytes).ok()
292}
293
294// =============================================================================
295// Event Reference (lightweight mode)
296// =============================================================================
297
298/// A lightweight event reference: timestamp plus optional event ID.
299///
300/// Used in `Refs` mode for memory-efficient correlation event tracking.
301/// Each ref costs ~40 bytes (vs. 100–1000+ bytes for compressed events),
302/// making this mode suitable for high-volume correlations where only
303/// traceability is needed.
304#[derive(Debug, Clone, Serialize)]
305pub struct EventRef {
306    /// Event timestamp (epoch seconds).
307    pub timestamp: i64,
308    /// Event ID extracted from common fields (`id`, `_id`, `event_id`, etc.).
309    #[serde(skip_serializing_if = "Option::is_none")]
310    pub id: Option<String>,
311}
312
313/// Lightweight event reference buffer for `Refs` mode.
314///
315/// Stores only timestamps and optional event IDs — no event payload,
316/// no compression. This is the minimal-memory alternative to `EventBuffer`.
317#[derive(Debug, Clone)]
318pub struct EventRefBuffer {
319    /// Event references, ordered by timestamp.
320    entries: VecDeque<EventRef>,
321    /// Maximum number of refs to retain.
322    max_events: usize,
323}
324
325impl EventRefBuffer {
326    /// Create a new ref buffer with the given capacity cap.
327    pub fn new(max_events: usize) -> Self {
328        EventRefBuffer {
329            entries: VecDeque::with_capacity(max_events.min(64)),
330            max_events,
331        }
332    }
333
334    /// Store a reference to an event. Evicts the oldest ref if at capacity.
335    pub fn push(&mut self, ts: i64, event: &serde_json::Value) {
336        if self.entries.len() >= self.max_events {
337            self.entries.pop_front();
338        }
339        let id = extract_event_id(event);
340        self.entries.push_back(EventRef { timestamp: ts, id });
341    }
342
343    /// Remove all refs older than the cutoff timestamp.
344    pub fn evict(&mut self, cutoff: i64) {
345        while self.entries.front().is_some_and(|r| r.timestamp < cutoff) {
346            self.entries.pop_front();
347        }
348    }
349
350    /// Return cloned refs.
351    pub fn refs(&self) -> Vec<EventRef> {
352        self.entries.iter().cloned().collect()
353    }
354
355    /// Returns true if there are no stored refs.
356    pub fn is_empty(&self) -> bool {
357        self.entries.is_empty()
358    }
359
360    /// Clear all stored refs.
361    pub fn clear(&mut self) {
362        self.entries.clear();
363    }
364
365    /// Number of stored refs.
366    pub fn len(&self) -> usize {
367        self.entries.len()
368    }
369}
370
371/// Try to extract an event ID from common fields.
372///
373/// Checks (in order): `id`, `_id`, `event_id`, `EventRecordID`, `event.id`.
374/// Returns the first found value as a string.
375fn extract_event_id(event: &serde_json::Value) -> Option<String> {
376    const ID_FIELDS: &[&str] = &["id", "_id", "event_id", "EventRecordID", "event.id"];
377    for field in ID_FIELDS {
378        if let Some(val) = event.get(field) {
379            return match val {
380                serde_json::Value::String(s) => Some(s.clone()),
381                serde_json::Value::Number(n) => Some(n.to_string()),
382                _ => None,
383            };
384        }
385    }
386    None
387}
388
389// =============================================================================
390// Window State
391// =============================================================================
392
/// Mutable per-group state held for one correlation's time window.
///
/// The variant is chosen by `WindowState::new_for` from the correlation type.
/// Every variant stores timestamps so that old entries can be evicted.
#[derive(Debug, Clone)]
pub enum WindowState {
    /// `event_count`: one timestamp per matching event.
    EventCount { timestamps: VecDeque<i64> },
    /// `value_count`: `(timestamp, field value)` per matching event.
    ValueCount { entries: VecDeque<(i64, String)> },
    /// `temporal` / `temporal_ordered`: hit timestamps keyed by rule reference.
    Temporal {
        rule_hits: HashMap<String, VecDeque<i64>>,
    },
    /// `value_sum`, `value_avg`, `value_percentile`, `value_median`:
    /// `(timestamp, numeric value)` per matching event.
    NumericAgg { entries: VecDeque<(i64, f64)> },
}
410
411impl WindowState {
412    /// Create a new empty window state for the given correlation type.
413    pub fn new_for(corr_type: CorrelationType) -> Self {
414        match corr_type {
415            CorrelationType::EventCount => WindowState::EventCount {
416                timestamps: VecDeque::new(),
417            },
418            CorrelationType::ValueCount => WindowState::ValueCount {
419                entries: VecDeque::new(),
420            },
421            CorrelationType::Temporal | CorrelationType::TemporalOrdered => WindowState::Temporal {
422                rule_hits: HashMap::new(),
423            },
424            CorrelationType::ValueSum
425            | CorrelationType::ValueAvg
426            | CorrelationType::ValuePercentile
427            | CorrelationType::ValueMedian => WindowState::NumericAgg {
428                entries: VecDeque::new(),
429            },
430        }
431    }
432
433    /// Remove all entries older than the cutoff timestamp.
434    pub fn evict(&mut self, cutoff: i64) {
435        match self {
436            WindowState::EventCount { timestamps } => {
437                while timestamps.front().is_some_and(|&t| t < cutoff) {
438                    timestamps.pop_front();
439                }
440            }
441            WindowState::ValueCount { entries } => {
442                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
443                    entries.pop_front();
444                }
445            }
446            WindowState::Temporal { rule_hits } => {
447                for timestamps in rule_hits.values_mut() {
448                    while timestamps.front().is_some_and(|&t| t < cutoff) {
449                        timestamps.pop_front();
450                    }
451                }
452                // Remove empty rule entries
453                rule_hits.retain(|_, ts| !ts.is_empty());
454            }
455            WindowState::NumericAgg { entries } => {
456                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
457                    entries.pop_front();
458                }
459            }
460        }
461    }
462
463    /// Returns true if this state has no entries.
464    pub fn is_empty(&self) -> bool {
465        match self {
466            WindowState::EventCount { timestamps } => timestamps.is_empty(),
467            WindowState::ValueCount { entries } => entries.is_empty(),
468            WindowState::Temporal { rule_hits } => rule_hits.is_empty(),
469            WindowState::NumericAgg { entries } => entries.is_empty(),
470        }
471    }
472
473    /// Returns the most recent timestamp in this window, or `None` if empty.
474    pub fn latest_timestamp(&self) -> Option<i64> {
475        match self {
476            WindowState::EventCount { timestamps } => timestamps.back().copied(),
477            WindowState::ValueCount { entries } => entries.back().map(|(t, _)| *t),
478            WindowState::Temporal { rule_hits } => {
479                rule_hits.values().filter_map(|ts| ts.back().copied()).max()
480            }
481            WindowState::NumericAgg { entries } => entries.back().map(|(t, _)| *t),
482        }
483    }
484
485    /// Clear all entries from the window state (used by `CorrelationAction::Reset`).
486    pub fn clear(&mut self) {
487        match self {
488            WindowState::EventCount { timestamps } => timestamps.clear(),
489            WindowState::ValueCount { entries } => entries.clear(),
490            WindowState::Temporal { rule_hits } => rule_hits.clear(),
491            WindowState::NumericAgg { entries } => entries.clear(),
492        }
493    }
494
495    /// Record an event_count hit.
496    pub fn push_event_count(&mut self, ts: i64) {
497        if let WindowState::EventCount { timestamps } = self {
498            timestamps.push_back(ts);
499        }
500    }
501
502    /// Record a value_count hit with the field value.
503    pub fn push_value_count(&mut self, ts: i64, value: String) {
504        if let WindowState::ValueCount { entries } = self {
505            entries.push_back((ts, value));
506        }
507    }
508
509    /// Record a temporal hit for a specific rule reference.
510    pub fn push_temporal(&mut self, ts: i64, rule_ref: &str) {
511        if let WindowState::Temporal { rule_hits } = self {
512            rule_hits
513                .entry(rule_ref.to_string())
514                .or_default()
515                .push_back(ts);
516        }
517    }
518
519    /// Record a numeric aggregation value.
520    pub fn push_numeric(&mut self, ts: i64, value: f64) {
521        if let WindowState::NumericAgg { entries } = self {
522            entries.push_back((ts, value));
523        }
524    }
525
526    /// Evaluate the window state against the correlation condition.
527    ///
528    /// Returns `Some(aggregated_value)` if the condition is satisfied,
529    /// `None` otherwise.
530    ///
531    /// For temporal correlations with an extended expression, the expression
532    /// is evaluated against the set of rules that have fired in the window.
533    pub fn check_condition(
534        &self,
535        condition: &CompiledCondition,
536        corr_type: CorrelationType,
537        rule_refs: &[String],
538        extended_expr: Option<&ConditionExpr>,
539    ) -> Option<f64> {
540        let value = match (self, corr_type) {
541            (WindowState::EventCount { timestamps }, CorrelationType::EventCount) => {
542                timestamps.len() as f64
543            }
544            (WindowState::ValueCount { entries }, CorrelationType::ValueCount) => {
545                // Count distinct values
546                let distinct: HashSet<&String> = entries.iter().map(|(_, v)| v).collect();
547                distinct.len() as f64
548            }
549            (WindowState::Temporal { rule_hits }, CorrelationType::Temporal) => {
550                // If an extended expression is provided, evaluate it
551                if let Some(expr) = extended_expr {
552                    if eval_temporal_expr(expr, rule_hits) {
553                        // Return the count of fired rules as the value
554                        let fired: usize = rule_refs
555                            .iter()
556                            .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
557                            .count();
558                        return Some(fired as f64);
559                    } else {
560                        return None;
561                    }
562                }
563                // Default: count how many distinct referenced rules have fired
564                let fired: usize = rule_refs
565                    .iter()
566                    .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
567                    .count();
568                fired as f64
569            }
570            (WindowState::Temporal { rule_hits }, CorrelationType::TemporalOrdered) => {
571                // If an extended expression is provided, evaluate it first
572                if let Some(expr) = extended_expr
573                    && !eval_temporal_expr(expr, rule_hits)
574                {
575                    return None;
576                }
577                // Check if all referenced rules fired in order
578                if check_temporal_ordered(rule_refs, rule_hits) {
579                    rule_refs.len() as f64
580                } else {
581                    0.0
582                }
583            }
584            (WindowState::NumericAgg { entries }, CorrelationType::ValueSum) => {
585                entries.iter().map(|(_, v)| v).sum()
586            }
587            (WindowState::NumericAgg { entries }, CorrelationType::ValueAvg) => {
588                if entries.is_empty() {
589                    0.0
590                } else {
591                    let sum: f64 = entries.iter().map(|(_, v)| v).sum();
592                    sum / entries.len() as f64
593                }
594            }
595            (WindowState::NumericAgg { entries }, CorrelationType::ValuePercentile) => {
596                // Proper percentile calculation using linear interpolation.
597                // The condition threshold represents a percentile rank (0-100).
598                // We compute the value at that percentile from the window data.
599                if entries.is_empty() {
600                    return None;
601                }
602                let mut values: Vec<f64> = entries
603                    .iter()
604                    .map(|(_, v)| *v)
605                    .filter(|v| v.is_finite())
606                    .collect();
607                if values.is_empty() {
608                    return None;
609                }
610                values.sort_by(|a, b| a.partial_cmp(b).expect("NaN filtered"));
611                // Extract the percentile rank from the condition's first predicate
612                let percentile_rank = condition
613                    .predicates
614                    .first()
615                    .map(|(_, threshold)| *threshold)
616                    .unwrap_or(50.0);
617                let pval = percentile_linear_interp(&values, percentile_rank);
618                return Some(pval);
619            }
620            (WindowState::NumericAgg { entries }, CorrelationType::ValueMedian) => {
621                if entries.is_empty() {
622                    0.0
623                } else {
624                    let mut values: Vec<f64> = entries
625                        .iter()
626                        .map(|(_, v)| *v)
627                        .filter(|v| v.is_finite())
628                        .collect();
629                    if values.is_empty() {
630                        return None;
631                    }
632                    values.sort_by(|a, b| a.partial_cmp(b).expect("NaN filtered"));
633                    let mid = values.len() / 2;
634                    if values.len().is_multiple_of(2) && values.len() >= 2 {
635                        (values[mid - 1] + values[mid]) / 2.0
636                    } else {
637                        values[mid]
638                    }
639                }
640            }
641            _ => return None, // mismatched state/type
642        };
643
644        if condition.check(value) {
645            Some(value)
646        } else {
647            None
648        }
649    }
650}
651
/// Check whether all referenced rules fired in the required order.
///
/// For `temporal_ordered`, every rule in `rule_refs` must have at least one
/// hit. It must also be possible to pick one timestamp per rule such that the
/// picks are non-decreasing in `rule_refs` order. An empty `rule_refs` is
/// trivially satisfied.
///
/// This uses a greedy scan: for each rule, take its earliest hit that is not
/// before the previous pick. Taking the earliest feasible hit keeps every later
/// option open. The scan therefore succeeds exactly when some ordered sequence
/// exists, which is the same answer exhaustive backtracking gives. The
/// difference is cost: O(total hits) here, versus exponential time when no
/// sequence exists. Each rule's hits need not be sorted.
fn check_temporal_ordered(
    rule_refs: &[String],
    rule_hits: &HashMap<String, VecDeque<i64>>,
) -> bool {
    rule_refs
        .iter()
        .try_fold(i64::MIN, |previous, rule| {
            rule_hits
                .get(rule.as_str())?
                .iter()
                .copied()
                .filter(|&ts| ts >= previous)
                .min()
        })
        .is_some()
}
696
697/// Evaluate a boolean condition expression against the set of rules that have
698/// fired within the temporal window.
699///
700/// Each `Identifier` in the expression is treated as a rule reference — it's
701/// `true` if that rule has at least one hit in `rule_hits`.
702fn eval_temporal_expr(expr: &ConditionExpr, rule_hits: &HashMap<String, VecDeque<i64>>) -> bool {
703    match expr {
704        ConditionExpr::Identifier(name) => rule_hits
705            .get(name.as_str())
706            .is_some_and(|ts| !ts.is_empty()),
707        ConditionExpr::And(children) => children.iter().all(|c| eval_temporal_expr(c, rule_hits)),
708        ConditionExpr::Or(children) => children.iter().any(|c| eval_temporal_expr(c, rule_hits)),
709        ConditionExpr::Not(child) => !eval_temporal_expr(child, rule_hits),
710        ConditionExpr::Selector { .. } => {
711            // Selectors are not meaningful for temporal condition evaluation
712            false
713        }
714    }
715}
716
/// Value at `percentile` (0–100, clamped) of ascending-sorted `values`, using
/// linear interpolation between closest ranks (rank = p × (n − 1)).
///
/// Returns `0.0` for an empty slice and the sole element for a single-element
/// slice.
fn percentile_linear_interp(values: &[f64], percentile: f64) -> f64 {
    match values {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = values.len() - 1;
            let position = percentile.clamp(0.0, 100.0) / 100.0 * last as f64;
            let lo = position.floor() as usize;
            let hi = position.ceil() as usize;
            if lo == hi || hi > last {
                values[lo.min(last)]
            } else {
                let frac = position - lo as f64;
                values[lo] + frac * (values[hi] - values[lo])
            }
        }
    }
}
747
748// =============================================================================
749// Compilation
750// =============================================================================
751
752/// Compile a parsed `CorrelationRule` into a `CompiledCorrelation`.
753pub fn compile_correlation(rule: &CorrelationRule) -> Result<CompiledCorrelation> {
754    // Build group-by fields, resolving aliases
755    let alias_map: HashMap<&str, &FieldAlias> =
756        rule.aliases.iter().map(|a| (a.alias.as_str(), a)).collect();
757
758    let group_by: Vec<GroupByField> = rule
759        .group_by
760        .iter()
761        .map(|field_name| {
762            if let Some(alias) = alias_map.get(field_name.as_str()) {
763                GroupByField::Aliased {
764                    alias: field_name.clone(),
765                    mapping: alias.mapping.clone(),
766                }
767            } else {
768                GroupByField::Direct(field_name.clone())
769            }
770        })
771        .collect();
772
773    // Compile condition
774    let (condition, extended_expr) = compile_condition(&rule.condition, rule.correlation_type)?;
775
776    // Resolve per-correlation overrides from custom attributes.
777    // These mirror the engine-level `rsigma.*` attributes but apply only
778    // to this correlation rule, taking precedence over engine defaults.
779    let suppress_secs = rule
780        .custom_attributes
781        .get("rsigma.suppress")
782        .and_then(|v| rsigma_parser::Timespan::parse(v).ok())
783        .map(|ts| ts.seconds);
784
785    let action = rule.custom_attributes.get("rsigma.action").and_then(|v| {
786        v.parse::<crate::correlation_engine::CorrelationAction>()
787            .ok()
788    });
789
790    let event_mode = rule
791        .custom_attributes
792        .get("rsigma.correlation_event_mode")
793        .and_then(|v| {
794            v.parse::<crate::correlation_engine::CorrelationEventMode>()
795                .ok()
796        });
797
798    let max_events = rule
799        .custom_attributes
800        .get("rsigma.max_correlation_events")
801        .and_then(|v| v.parse::<usize>().ok());
802
803    Ok(CompiledCorrelation {
804        id: rule.id.clone(),
805        name: rule.name.clone(),
806        title: rule.title.clone(),
807        level: rule.level,
808        tags: rule.tags.clone(),
809        correlation_type: rule.correlation_type,
810        rule_refs: rule.rules.clone(),
811        group_by,
812        timespan_secs: rule.timespan.seconds,
813        condition,
814        extended_expr,
815        generate: rule.generate,
816        suppress_secs,
817        action,
818        event_mode,
819        max_events,
820    })
821}
822
823/// Compile a `CorrelationCondition` into a `CompiledCondition` and optional expression.
824fn compile_condition(
825    cond: &CorrelationCondition,
826    corr_type: CorrelationType,
827) -> Result<(CompiledCondition, Option<ConditionExpr>)> {
828    match cond {
829        CorrelationCondition::Threshold { predicates, field } => Ok((
830            CompiledCondition {
831                field: field.clone(),
832                predicates: predicates
833                    .iter()
834                    .map(|(op, count)| (*op, *count as f64))
835                    .collect(),
836            },
837            None,
838        )),
839        CorrelationCondition::Extended(expr) => {
840            match corr_type {
841                CorrelationType::Temporal | CorrelationType::TemporalOrdered => {
842                    // For extended conditions, the threshold is a dummy (gte: 1)
843                    // since the actual evaluation is done via the expression tree.
844                    Ok((
845                        CompiledCondition {
846                            field: None,
847                            predicates: vec![(ConditionOperator::Gte, 1.0)],
848                        },
849                        Some(expr.clone()),
850                    ))
851                }
852                _ => Err(EvalError::CorrelationError(
853                    "Extended conditions are only supported for temporal correlation types"
854                        .to_string(),
855                )),
856            }
857        }
858    }
859}
860
// Unit tests for this module: group-key extraction, threshold checks,
// per-type window state, correlation compilation, extended temporal
// expressions, percentile interpolation, and the event / event-ref buffers.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // =========================================================================
    // GroupKey extraction tests
    // =========================================================================

    #[test]
    fn test_group_key_extract() {
        let v = json!({"User": "admin", "Host": "srv01"});
        let event = Event::from_value(&v);
        let group_by = vec![
            GroupByField::Direct("User".to_string()),
            GroupByField::Direct("Host".to_string()),
        ];
        let key = GroupKey::extract(&event, &group_by, &["rule1"]);
        assert_eq!(
            key.0,
            vec![Some("admin".to_string()), Some("srv01".to_string())]
        );
    }

    #[test]
    fn test_group_key_missing_field() {
        let v = json!({"User": "admin"});
        let event = Event::from_value(&v);
        let group_by = vec![
            GroupByField::Direct("User".to_string()),
            GroupByField::Direct("Host".to_string()),
        ];
        // "Host" is absent from the event, so its key slot is expected to be None.
        let key = GroupKey::extract(&event, &group_by, &["rule1"]);
        assert_eq!(key.0, vec![Some("admin".to_string()), None]);
    }

    #[test]
    fn test_group_key_aliased() {
        let v = json!({"source.ip": "10.0.0.1"});
        let event = Event::from_value(&v);
        let group_by = vec![GroupByField::Aliased {
            alias: "internal_ip".to_string(),
            mapping: HashMap::from([
                ("rule_a".to_string(), "source.ip".to_string()),
                ("rule_b".to_string(), "destination.ip".to_string()),
            ]),
        }];
        // The event is attributed to `rule_a`, whose mapping points the alias
        // at `source.ip`.
        let key = GroupKey::extract(&event, &group_by, &["rule_a"]);
        assert_eq!(key.0, vec![Some("10.0.0.1".to_string())]);
    }

    // =========================================================================
    // CompiledCondition threshold tests
    // =========================================================================

    #[test]
    fn test_condition_check() {
        let cond = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 100.0)],
        };
        assert!(!cond.check(99.0));
        assert!(cond.check(100.0));
        assert!(cond.check(101.0));
    }

    #[test]
    fn test_condition_check_range() {
        // Multiple predicates must all hold: here the range (100, 200].
        let cond = CompiledCondition {
            field: None,
            predicates: vec![
                (ConditionOperator::Gt, 100.0),
                (ConditionOperator::Lte, 200.0),
            ],
        };
        assert!(!cond.check(100.0));
        assert!(cond.check(101.0));
        assert!(cond.check(200.0));
        assert!(!cond.check(201.0));
    }

    // =========================================================================
    // WindowState tests
    // =========================================================================

    #[test]
    fn test_window_event_count() {
        let mut state = WindowState::new_for(CorrelationType::EventCount);
        for i in 0..5 {
            state.push_event_count(1000 + i);
        }
        let cond = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 5.0)],
        };
        assert_eq!(
            state.check_condition(&cond, CorrelationType::EventCount, &[], None),
            Some(5.0)
        );
    }

    #[test]
    fn test_window_event_count_eviction() {
        let mut state = WindowState::new_for(CorrelationType::EventCount);
        for i in 0..10 {
            state.push_event_count(1000 + i);
        }
        // Evict events before ts=1005
        state.evict(1005);
        let cond = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 5.0)],
        };
        assert_eq!(
            state.check_condition(&cond, CorrelationType::EventCount, &[], None),
            Some(5.0)
        );
    }

    #[test]
    fn test_window_value_count() {
        let mut state = WindowState::new_for(CorrelationType::ValueCount);
        state.push_value_count(1000, "user1".to_string());
        state.push_value_count(1001, "user2".to_string());
        state.push_value_count(1002, "user1".to_string()); // duplicate
        state.push_value_count(1003, "user3".to_string());

        // Four pushes but only three distinct values; the expected metric is 3.
        let cond = CompiledCondition {
            field: Some("User".to_string()),
            predicates: vec![(ConditionOperator::Gte, 3.0)],
        };
        assert_eq!(
            state.check_condition(&cond, CorrelationType::ValueCount, &[], None),
            Some(3.0)
        );
    }

    #[test]
    fn test_window_temporal() {
        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
        let mut state = WindowState::new_for(CorrelationType::Temporal);
        state.push_temporal(1000, "rule_a");
        // Only rule_a fired — condition: all refs must fire
        let cond = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 2.0)],
        };
        assert!(
            state
                .check_condition(&cond, CorrelationType::Temporal, &refs, None)
                .is_none()
        );

        // Now rule_b fires too
        state.push_temporal(1001, "rule_b");
        assert_eq!(
            state.check_condition(&cond, CorrelationType::Temporal, &refs, None),
            Some(2.0)
        );
    }

    #[test]
    fn test_window_temporal_ordered() {
        let refs = vec![
            "rule_a".to_string(),
            "rule_b".to_string(),
            "rule_c".to_string(),
        ];
        let mut state = WindowState::new_for(CorrelationType::TemporalOrdered);
        // Fire in order: a, b, c
        state.push_temporal(1000, "rule_a");
        state.push_temporal(1001, "rule_b");
        state.push_temporal(1002, "rule_c");

        let cond = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 3.0)],
        };
        assert!(
            state
                .check_condition(&cond, CorrelationType::TemporalOrdered, &refs, None)
                .is_some()
        );
    }

    #[test]
    fn test_window_temporal_ordered_wrong_order() {
        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
        let mut state = WindowState::new_for(CorrelationType::TemporalOrdered);
        // Fire in wrong order: b before a
        state.push_temporal(1000, "rule_b");
        state.push_temporal(1001, "rule_a");

        let cond = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 2.0)],
        };
        assert!(
            state
                .check_condition(&cond, CorrelationType::TemporalOrdered, &refs, None)
                .is_none()
        );
    }

    #[test]
    fn test_window_value_sum() {
        let mut state = WindowState::new_for(CorrelationType::ValueSum);
        state.push_numeric(1000, 500.0);
        state.push_numeric(1001, 600.0);

        // 500 + 600 = 1100, which exceeds the > 1000 threshold.
        let cond = CompiledCondition {
            field: Some("bytes_sent".to_string()),
            predicates: vec![(ConditionOperator::Gt, 1000.0)],
        };
        assert_eq!(
            state.check_condition(&cond, CorrelationType::ValueSum, &[], None),
            Some(1100.0)
        );
    }

    #[test]
    fn test_window_value_avg() {
        let mut state = WindowState::new_for(CorrelationType::ValueAvg);
        state.push_numeric(1000, 100.0);
        state.push_numeric(1001, 200.0);
        state.push_numeric(1002, 300.0);

        // Mean of 100, 200, 300 is 200.
        let cond = CompiledCondition {
            field: Some("bytes".to_string()),
            predicates: vec![(ConditionOperator::Gte, 200.0)],
        };
        assert_eq!(
            state.check_condition(&cond, CorrelationType::ValueAvg, &[], None),
            Some(200.0)
        );
    }

    #[test]
    fn test_window_value_median() {
        let mut state = WindowState::new_for(CorrelationType::ValueMedian);
        state.push_numeric(1000, 10.0);
        state.push_numeric(1001, 20.0);
        state.push_numeric(1002, 30.0);

        // Median of 10, 20, 30 is 20.
        let cond = CompiledCondition {
            field: Some("latency".to_string()),
            predicates: vec![(ConditionOperator::Gte, 20.0)],
        };
        assert_eq!(
            state.check_condition(&cond, CorrelationType::ValueMedian, &[], None),
            Some(20.0)
        );
    }

    // =========================================================================
    // Compilation tests
    // =========================================================================

    #[test]
    fn test_compile_correlation_basic() {
        use rsigma_parser::parse_sigma_yaml;

        // Multi-document YAML: one base detection rule followed by an
        // event_count correlation that references it by id.
        let yaml = r#"
title: Base Rule
id: f305fd62-beca-47da-ad95-7690a0620084
logsource:
    product: aws
    service: cloudtrail
detection:
    selection:
        eventSource: "s3.amazonaws.com"
    condition: selection
level: low
---
title: Multiple AWS bucket enumerations
id: be246094-01d3-4bba-88de-69e582eba0cc
status: experimental
correlation:
    type: event_count
    rules:
        - f305fd62-beca-47da-ad95-7690a0620084
    group-by:
        - userIdentity.arn
    timespan: 1h
    condition:
        gte: 100
level: high
"#;
        let collection = parse_sigma_yaml(yaml).unwrap();
        assert_eq!(collection.correlations.len(), 1);

        let compiled = compile_correlation(&collection.correlations[0]).unwrap();
        assert_eq!(compiled.correlation_type, CorrelationType::EventCount);
        // timespan 1h = 3600s
        assert_eq!(compiled.timespan_secs, 3600);
        assert_eq!(compiled.rule_refs.len(), 1);
        assert_eq!(compiled.group_by.len(), 1);
        assert!(compiled.condition.check(100.0));
        assert!(!compiled.condition.check(99.0));
    }

    // =========================================================================
    // Extended temporal condition tests
    // =========================================================================

    #[test]
    fn test_eval_temporal_expr_and() {
        let mut rule_hits = HashMap::new();
        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
        rule_hits.insert("rule_b".to_string(), VecDeque::from([1001]));

        let expr = ConditionExpr::And(vec![
            ConditionExpr::Identifier("rule_a".to_string()),
            ConditionExpr::Identifier("rule_b".to_string()),
        ]);
        assert!(eval_temporal_expr(&expr, &rule_hits));
    }

    #[test]
    fn test_eval_temporal_expr_and_incomplete() {
        let mut rule_hits = HashMap::new();
        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
        // rule_b not fired

        let expr = ConditionExpr::And(vec![
            ConditionExpr::Identifier("rule_a".to_string()),
            ConditionExpr::Identifier("rule_b".to_string()),
        ]);
        assert!(!eval_temporal_expr(&expr, &rule_hits));
    }

    #[test]
    fn test_eval_temporal_expr_or() {
        let mut rule_hits = HashMap::new();
        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));

        let expr = ConditionExpr::Or(vec![
            ConditionExpr::Identifier("rule_a".to_string()),
            ConditionExpr::Identifier("rule_b".to_string()),
        ]);
        assert!(eval_temporal_expr(&expr, &rule_hits));
    }

    #[test]
    fn test_eval_temporal_expr_not() {
        // No hits at all, so "not rule_a" holds.
        let rule_hits = HashMap::new();

        let expr = ConditionExpr::Not(Box::new(ConditionExpr::Identifier("rule_a".to_string())));
        assert!(eval_temporal_expr(&expr, &rule_hits));
    }

    #[test]
    fn test_eval_temporal_expr_complex() {
        let mut rule_hits = HashMap::new();
        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
        rule_hits.insert("rule_b".to_string(), VecDeque::from([1001]));
        // rule_c NOT fired

        // (rule_a and rule_b) and not rule_c
        let expr = ConditionExpr::And(vec![
            ConditionExpr::And(vec![
                ConditionExpr::Identifier("rule_a".to_string()),
                ConditionExpr::Identifier("rule_b".to_string()),
            ]),
            ConditionExpr::Not(Box::new(ConditionExpr::Identifier("rule_c".to_string()))),
        ]);
        assert!(eval_temporal_expr(&expr, &rule_hits));
    }

    #[test]
    fn test_check_condition_with_extended_expr() {
        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
        let mut state = WindowState::new_for(CorrelationType::Temporal);
        state.push_temporal(1000, "rule_a");
        state.push_temporal(1001, "rule_b");

        // Same dummy `gte: 1` threshold that compile_condition produces for
        // extended conditions.
        let cond = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 1.0)],
        };
        let expr = ConditionExpr::And(vec![
            ConditionExpr::Identifier("rule_a".to_string()),
            ConditionExpr::Identifier("rule_b".to_string()),
        ]);

        // With expression: should match (both rules fired)
        assert!(
            state
                .check_condition(&cond, CorrelationType::Temporal, &refs, Some(&expr))
                .is_some()
        );

        // Now test with only rule_a: expression should fail
        let mut state2 = WindowState::new_for(CorrelationType::Temporal);
        state2.push_temporal(1000, "rule_a");
        assert!(
            state2
                .check_condition(&cond, CorrelationType::Temporal, &refs, Some(&expr))
                .is_none()
        );
    }

    // =========================================================================
    // Percentile linear interpolation tests
    // =========================================================================

    #[test]
    fn test_percentile_linear_interp_single() {
        assert!((percentile_linear_interp(&[42.0], 50.0) - 42.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_percentile_linear_interp_basic() {
        // Values: [1, 2, 3, 4, 5]
        let values = &[1.0, 2.0, 3.0, 4.0, 5.0];
        // 0th percentile = 1.0
        assert!((percentile_linear_interp(values, 0.0) - 1.0).abs() < f64::EPSILON);
        // 25th percentile = 2.0
        assert!((percentile_linear_interp(values, 25.0) - 2.0).abs() < f64::EPSILON);
        // 50th percentile = 3.0
        assert!((percentile_linear_interp(values, 50.0) - 3.0).abs() < f64::EPSILON);
        // 75th percentile = 4.0
        assert!((percentile_linear_interp(values, 75.0) - 4.0).abs() < f64::EPSILON);
        // 100th percentile = 5.0
        assert!((percentile_linear_interp(values, 100.0) - 5.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_percentile_linear_interp_interpolation() {
        // Values: [10, 20, 30, 40]
        let values = &[10.0, 20.0, 30.0, 40.0];
        // 50th percentile: rank = 0.5 * 3 = 1.5, interp between 20 and 30 = 25
        assert!((percentile_linear_interp(values, 50.0) - 25.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_percentile_linear_interp_1st_percentile() {
        // Values: [1, 2, 3, ..., 100]
        let values: Vec<f64> = (1..=100).map(|x| x as f64).collect();
        // 1st percentile = 1.0 + 0.01 * 99 * (2.0 - 1.0) ~ 1.99
        let p1 = percentile_linear_interp(&values, 1.0);
        assert!((p1 - 1.99).abs() < 0.01);
    }

    #[test]
    fn test_value_percentile_check_condition() {
        let mut state = WindowState::new_for(CorrelationType::ValuePercentile);
        // Push 100 values: 1.0, 2.0, ..., 100.0
        for i in 1..=100 {
            state.push_numeric(1000 + i, i as f64);
        }

        let cond = CompiledCondition {
            field: Some("latency".to_string()),
            // The condition threshold is used as the percentile rank
            predicates: vec![(ConditionOperator::Lte, 50.0)],
        };
        // 50th percentile of 1..100 should be ~50.5
        let result = state.check_condition(&cond, CorrelationType::ValuePercentile, &[], None);
        assert!(result.is_some());
        let val = result.unwrap();
        assert!((val - 50.5).abs() < 1.0, "expected ~50.5, got {val}");
    }

    #[test]
    fn test_percentile_0th_and_100th() {
        let values = &[5.0, 10.0, 15.0, 20.0];
        assert!((percentile_linear_interp(values, 0.0) - 5.0).abs() < f64::EPSILON);
        assert!((percentile_linear_interp(values, 100.0) - 20.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_percentile_two_values() {
        let values = &[10.0, 20.0];
        // 50th percentile between 10 and 20 = 15
        assert!((percentile_linear_interp(values, 50.0) - 15.0).abs() < f64::EPSILON);
        // 25th percentile = 12.5
        assert!((percentile_linear_interp(values, 25.0) - 12.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_percentile_clamps_out_of_range() {
        let values = &[1.0, 2.0, 3.0];
        // Negative percentile clamps to 0
        assert!((percentile_linear_interp(values, -10.0) - 1.0).abs() < f64::EPSILON);
        // > 100 clamps to 100
        assert!((percentile_linear_interp(values, 150.0) - 3.0).abs() < f64::EPSILON);
    }

    #[test]
    fn test_value_percentile_empty_window() {
        let state = WindowState::new_for(CorrelationType::ValuePercentile);
        let cond = CompiledCondition {
            field: Some("latency".to_string()),
            predicates: vec![(ConditionOperator::Lte, 50.0)],
        };
        // Empty window should return None
        assert!(
            state
                .check_condition(&cond, CorrelationType::ValuePercentile, &[], None)
                .is_none()
        );
    }

    #[test]
    fn test_extended_temporal_or_single_rule() {
        // "rule_a or rule_b" — only rule_a fired
        let mut rule_hits = HashMap::new();
        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));

        let expr = ConditionExpr::Or(vec![
            ConditionExpr::Identifier("rule_a".to_string()),
            ConditionExpr::Identifier("rule_b".to_string()),
        ]);
        assert!(eval_temporal_expr(&expr, &rule_hits));
    }

    #[test]
    fn test_extended_temporal_empty_hits() {
        let rule_hits = HashMap::new();

        // "rule_a and rule_b" — nothing fired
        let expr = ConditionExpr::And(vec![
            ConditionExpr::Identifier("rule_a".to_string()),
            ConditionExpr::Identifier("rule_b".to_string()),
        ]);
        assert!(!eval_temporal_expr(&expr, &rule_hits));

        // "rule_a or rule_b" — nothing fired
        let expr_or = ConditionExpr::Or(vec![
            ConditionExpr::Identifier("rule_a".to_string()),
            ConditionExpr::Identifier("rule_b".to_string()),
        ]);
        assert!(!eval_temporal_expr(&expr_or, &rule_hits));
    }

    #[test]
    fn test_extended_temporal_with_empty_deque() {
        // Rule exists in map but with empty deque (all evicted)
        let mut rule_hits = HashMap::new();
        rule_hits.insert("rule_a".to_string(), VecDeque::new());
        rule_hits.insert("rule_b".to_string(), VecDeque::from([1000]));

        let expr = ConditionExpr::And(vec![
            ConditionExpr::Identifier("rule_a".to_string()),
            ConditionExpr::Identifier("rule_b".to_string()),
        ]);
        // rule_a has empty deque — should be treated as not fired
        assert!(!eval_temporal_expr(&expr, &rule_hits));
    }

    #[test]
    fn test_check_condition_temporal_no_extended_expr() {
        // Standard temporal without extended expr: uses threshold count
        let refs = vec![
            "rule_a".to_string(),
            "rule_b".to_string(),
            "rule_c".to_string(),
        ];
        let mut state = WindowState::new_for(CorrelationType::Temporal);
        state.push_temporal(1000, "rule_a");
        state.push_temporal(1001, "rule_b");

        // Threshold: at least 2 rules must fire
        let cond = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 2.0)],
        };
        // Without extended expr: 2 of 3 rules fired, meets gte 2
        assert_eq!(
            state.check_condition(&cond, CorrelationType::Temporal, &refs, None),
            Some(2.0)
        );

        // With threshold 3: not enough
        let cond3 = CompiledCondition {
            field: None,
            predicates: vec![(ConditionOperator::Gte, 3.0)],
        };
        assert!(
            state
                .check_condition(&cond3, CorrelationType::Temporal, &refs, None)
                .is_none()
        );
    }

    // =========================================================================
    // EventBuffer tests
    // =========================================================================

    #[test]
    fn test_event_buffer_push_and_decompress() {
        let mut buf = EventBuffer::new(10);
        let event = json!({"User": "admin", "action": "login", "src_ip": "10.0.0.1"});
        buf.push(1000, &event);

        assert_eq!(buf.len(), 1);
        assert!(!buf.is_empty());

        let events = buf.decompress_all();
        assert_eq!(events.len(), 1);
        assert_eq!(events[0], event);
    }

    #[test]
    fn test_event_buffer_compression_saves_memory() {
        let mut buf = EventBuffer::new(100);
        // Push a realistic-sized event (~500 bytes JSON)
        let event = json!({
            "User": "admin",
            "action": "login",
            "src_ip": "192.168.1.100",
            "dst_ip": "10.0.0.1",
            "EventTime": "2024-07-10T12:30:00Z",
            "process": "sshd",
            "host": "production-server-01.example.com",
            "message": "Accepted password for admin from 192.168.1.100 port 22 ssh2",
            "severity": "info",
            "tags": ["authentication", "network", "linux"]
        });

        let raw_size = serde_json::to_vec(&event).unwrap().len();
        buf.push(1000, &event);
        let compressed_size = buf.compressed_bytes();

        // Compressed should be notably smaller than raw
        assert!(
            compressed_size < raw_size,
            "Compressed {compressed_size}B should be less than raw {raw_size}B"
        );

        // Verify roundtrip
        let events = buf.decompress_all();
        assert_eq!(events[0], event);
    }

    #[test]
    fn test_event_buffer_max_events_cap() {
        let mut buf = EventBuffer::new(3);

        for i in 0..5 {
            buf.push(1000 + i, &json!({"idx": i}));
        }

        // Only the last 3 should remain
        assert_eq!(buf.len(), 3);
        let events = buf.decompress_all();
        assert_eq!(events[0], json!({"idx": 2}));
        assert_eq!(events[1], json!({"idx": 3}));
        assert_eq!(events[2], json!({"idx": 4}));
    }

    #[test]
    fn test_event_buffer_eviction() {
        let mut buf = EventBuffer::new(10);
        for i in 0..5 {
            buf.push(1000 + i, &json!({"idx": i}));
        }
        assert_eq!(buf.len(), 5);

        // Evict everything before ts 1003
        buf.evict(1003);
        assert_eq!(buf.len(), 2);

        let events = buf.decompress_all();
        assert_eq!(events[0], json!({"idx": 3}));
        assert_eq!(events[1], json!({"idx": 4}));
    }

    #[test]
    fn test_event_buffer_clear() {
        let mut buf = EventBuffer::new(10);
        buf.push(1000, &json!({"a": 1}));
        buf.push(1001, &json!({"b": 2}));
        assert_eq!(buf.len(), 2);

        // Clearing should also reset the compressed byte count.
        buf.clear();
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
        assert_eq!(buf.compressed_bytes(), 0);
    }

    #[test]
    fn test_compress_decompress_roundtrip() {
        // Test various JSON shapes
        let values = vec![
            json!(null),
            json!(42),
            json!("hello world"),
            json!({"nested": {"deep": [1, 2, 3]}}),
            json!([1, "two", null, true, {"five": 5}]),
        ];
        for val in values {
            let compressed = compress_event(&val).unwrap();
            let decompressed = decompress_event(&compressed).unwrap();
            assert_eq!(decompressed, val, "Roundtrip failed for {val}");
        }
    }

    // =========================================================================
    // EventRefBuffer tests
    // =========================================================================

    #[test]
    fn test_event_ref_buffer_push_and_refs() {
        let mut buf = EventRefBuffer::new(10);
        buf.push(1000, &json!({"id": "evt-1", "data": "hello"}));
        buf.push(1001, &json!({"_id": 42, "data": "world"}));
        buf.push(1002, &json!({"data": "no-id"}));

        // Numeric ids are expected back as strings; events without a
        // recognised id field yield None.
        assert_eq!(buf.len(), 3);
        let refs = buf.refs();
        assert_eq!(refs[0].timestamp, 1000);
        assert_eq!(refs[0].id, Some("evt-1".to_string()));
        assert_eq!(refs[1].timestamp, 1001);
        assert_eq!(refs[1].id, Some("42".to_string()));
        assert_eq!(refs[2].timestamp, 1002);
        assert_eq!(refs[2].id, None);
    }

    #[test]
    fn test_event_ref_buffer_max_cap() {
        let mut buf = EventRefBuffer::new(3);
        for i in 0..5 {
            buf.push(1000 + i, &json!({"id": format!("e-{i}")}));
        }
        // Capacity is 3, so only the three most recent refs remain.
        assert_eq!(buf.len(), 3);
        let refs = buf.refs();
        assert_eq!(refs[0].id, Some("e-2".to_string()));
        assert_eq!(refs[1].id, Some("e-3".to_string()));
        assert_eq!(refs[2].id, Some("e-4".to_string()));
    }

    #[test]
    fn test_event_ref_buffer_eviction() {
        let mut buf = EventRefBuffer::new(10);
        for i in 0..5 {
            buf.push(1000 + i, &json!({"id": format!("e-{i}")}));
        }
        // Evict everything before ts 1003
        buf.evict(1003);
        assert_eq!(buf.len(), 2);
        let refs = buf.refs();
        assert_eq!(refs[0].timestamp, 1003);
        assert_eq!(refs[1].timestamp, 1004);
    }

    #[test]
    fn test_event_ref_buffer_clear() {
        let mut buf = EventRefBuffer::new(10);
        buf.push(1000, &json!({"id": "a"}));
        buf.push(1001, &json!({"id": "b"}));
        assert_eq!(buf.len(), 2);

        buf.clear();
        assert!(buf.is_empty());
        assert_eq!(buf.len(), 0);
    }

    #[test]
    fn test_extract_event_id_common_fields() {
        // Each of these field names is expected to be recognised as an event id.
        assert_eq!(
            extract_event_id(&json!({"id": "abc"})),
            Some("abc".to_string())
        );
        assert_eq!(
            extract_event_id(&json!({"_id": 123})),
            Some("123".to_string())
        );
        assert_eq!(
            extract_event_id(&json!({"event_id": "x-1"})),
            Some("x-1".to_string())
        );
        assert_eq!(
            extract_event_id(&json!({"EventRecordID": 999})),
            Some("999".to_string())
        );
        assert_eq!(extract_event_id(&json!({"no_id_field": true})), None);
    }

    #[test]
    fn test_compile_correlation_with_custom_attributes() {
        use rsigma_parser::*;

        // Build the rule struct directly (no YAML) so the `rsigma.*` custom
        // attributes can be set explicitly.
        let mut custom_attributes = std::collections::HashMap::new();
        custom_attributes.insert(
            "rsigma.correlation_event_mode".to_string(),
            "refs".to_string(),
        );
        custom_attributes.insert(
            "rsigma.max_correlation_events".to_string(),
            "25".to_string(),
        );
        custom_attributes.insert("rsigma.suppress".to_string(), "5m".to_string());
        custom_attributes.insert("rsigma.action".to_string(), "reset".to_string());

        let rule = CorrelationRule {
            title: "Test Corr".to_string(),
            id: Some("corr-1".to_string()),
            name: None,
            status: None,
            description: None,
            author: None,
            date: None,
            modified: None,
            references: vec![],
            tags: vec![],
            level: Some(Level::High),
            correlation_type: CorrelationType::EventCount,
            rules: vec!["rule-1".to_string()],
            group_by: vec!["User".to_string()],
            timespan: Timespan::parse("60s").unwrap(),
            condition: CorrelationCondition::Threshold {
                predicates: vec![(ConditionOperator::Gte, 5)],
                field: None,
            },
            aliases: vec![],
            generate: false,
            custom_attributes,
        };

        let compiled = compile_correlation(&rule).unwrap();

        // Per-correlation overrides should be resolved from custom_attributes
        assert_eq!(
            compiled.event_mode,
            Some(crate::correlation_engine::CorrelationEventMode::Refs)
        );
        assert_eq!(compiled.max_events, Some(25));
        assert_eq!(compiled.suppress_secs, Some(300)); // 5m = 300s
        assert_eq!(
            compiled.action,
            Some(crate::correlation_engine::CorrelationAction::Reset)
        );
    }
}