Skip to main content

rsigma_eval/
correlation.rs

1//! Compiled correlation types, group key, window state, and compilation.
2//!
3//! Transforms the parser's `CorrelationRule` AST into an optimized
4//! `CompiledCorrelation` with associated `WindowState` for stateful evaluation.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::io::{Read as IoRead, Write as IoWrite};
8
9use flate2::Compression;
10use flate2::read::DeflateDecoder;
11use flate2::write::DeflateEncoder;
12use serde::Serialize;
13
14use rsigma_parser::{
15    ConditionExpr, ConditionOperator, CorrelationCondition, CorrelationRule, CorrelationType,
16    FieldAlias, Level,
17};
18
19use crate::error::{EvalError, Result};
20use crate::event::Event;
21
22// =============================================================================
23// Compiled types
24// =============================================================================
25
26/// Compiled form of a `CorrelationRule`, ready for stateful evaluation.
27#[derive(Debug, Clone)]
28pub struct CompiledCorrelation {
29    pub id: Option<String>,
30    pub name: Option<String>,
31    pub title: String,
32    pub level: Option<Level>,
33    pub tags: Vec<String>,
34    pub correlation_type: CorrelationType,
35    /// IDs or names of referenced rules (detection or other correlations).
36    pub rule_refs: Vec<String>,
37    /// Resolved group-by fields (may include aliases).
38    pub group_by: Vec<GroupByField>,
39    /// Time window in seconds.
40    pub timespan_secs: u64,
41    /// Compiled threshold condition.
42    pub condition: CompiledCondition,
43    /// Extended boolean condition expression for temporal correlations.
44    /// When set, evaluates this expression against fired rules instead of
45    /// a simple threshold count.
46    pub extended_expr: Option<ConditionExpr>,
47    /// Whether referenced detection rules should also generate standalone matches.
48    pub generate: bool,
49    /// Per-correlation suppression window in seconds, resolved from the
50    /// `rsigma.suppress` custom attribute. `None` means use engine default.
51    pub suppress_secs: Option<u64>,
52    /// Per-correlation action on match, resolved from the `rsigma.action`
53    /// custom attribute. `None` means use engine default.
54    pub action: Option<crate::correlation_engine::CorrelationAction>,
55    /// Event inclusion mode for this correlation.
56    /// `None` means use the engine default (`CorrelationConfig.correlation_event_mode`).
57    pub event_mode: Option<crate::correlation_engine::CorrelationEventMode>,
58    /// Maximum events to store per window group for event inclusion.
59    /// `None` means use the engine default (`CorrelationConfig.max_correlation_events`).
60    pub max_events: Option<usize>,
61}
62
/// One component of a group-by key, optionally aliased per referenced rule.
#[derive(Debug, Clone)]
pub enum GroupByField {
    /// A plain field name, shared by every referenced rule.
    Direct(String),
    /// An alias whose concrete event field depends on the rule that matched.
    /// `mapping` goes from rule reference (ID or name) to the field to read.
    Aliased {
        alias: String,
        mapping: HashMap<String, String>,
    },
}

impl GroupByField {
    /// Display name of the field: the field itself for `Direct`, the alias
    /// for `Aliased`.
    pub fn name(&self) -> &str {
        match self {
            Self::Direct(label) | Self::Aliased { alias: label, .. } => label,
        }
    }

    /// Returns the event field to read, given the identifiers (ID, name, ...)
    /// of the rule that produced the detection match.
    ///
    /// For an aliased field, the first identifier present in the mapping
    /// wins; when none is present, the alias name itself is returned.
    pub fn resolve(&self, rule_refs: &[&str]) -> &str {
        match self {
            Self::Direct(field) => field,
            Self::Aliased { alias, mapping } => rule_refs
                .iter()
                .find_map(|candidate| mapping.get(*candidate))
                .map_or(alias.as_str(), String::as_str),
        }
    }
}
103
104/// Compiled threshold condition with one or more predicates (supports ranges).
105#[derive(Debug, Clone)]
106pub struct CompiledCondition {
107    /// Optional field name for value_count, value_sum, value_avg, value_percentile.
108    pub field: Option<String>,
109    /// One or more predicates to satisfy (all must be true for the condition to match).
110    pub predicates: Vec<(ConditionOperator, f64)>,
111}
112
113impl CompiledCondition {
114    /// Check if the given value satisfies all predicates.
115    pub fn check(&self, value: f64) -> bool {
116        self.predicates.iter().all(|(op, threshold)| match op {
117            ConditionOperator::Lt => value < *threshold,
118            ConditionOperator::Lte => value <= *threshold,
119            ConditionOperator::Gt => value > *threshold,
120            ConditionOperator::Gte => value >= *threshold,
121            ConditionOperator::Eq => (value - *threshold).abs() < f64::EPSILON,
122            ConditionOperator::Neq => (value - *threshold).abs() >= f64::EPSILON,
123        })
124    }
125}
126
127// =============================================================================
128// Group Key
129// =============================================================================
130
131/// Composite key for group-by partitioning.
132///
133/// Each element corresponds to a `GroupByField` value extracted from an event.
134/// `None` means the field was absent from the event.
135#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, serde::Deserialize)]
136pub struct GroupKey(pub Vec<Option<String>>);
137
138impl GroupKey {
139    /// Extract a group key from an event given the group-by fields and the
140    /// rule reference identifiers (ID, name, etc.) that produced the detection match.
141    pub fn extract(event: &Event, group_by: &[GroupByField], rule_refs: &[&str]) -> Self {
142        let values = group_by
143            .iter()
144            .map(|field| {
145                let field_name = field.resolve(rule_refs);
146                event.get_field(field_name).and_then(value_to_string)
147            })
148            .collect();
149        GroupKey(values)
150    }
151
152    /// Build a group key from explicit field-value pairs (for chaining).
153    pub fn from_pairs(pairs: &[(String, String)], group_by: &[GroupByField]) -> Self {
154        let values = group_by
155            .iter()
156            .map(|field| {
157                let name = field.name();
158                pairs
159                    .iter()
160                    .find(|(k, _)| k == name)
161                    .map(|(_, v)| v.clone())
162            })
163            .collect();
164        GroupKey(values)
165    }
166
167    /// Convert to field-name/value pairs for output.
168    pub fn to_pairs(&self, group_by: &[GroupByField]) -> Vec<(String, String)> {
169        group_by
170            .iter()
171            .zip(self.0.iter())
172            .filter_map(|(field, value)| {
173                value
174                    .as_ref()
175                    .map(|v| (field.name().to_string(), v.clone()))
176            })
177            .collect()
178    }
179}
180
181/// Convert a JSON value to a string for group-key purposes.
182fn value_to_string(v: &serde_json::Value) -> Option<String> {
183    match v {
184        serde_json::Value::String(s) => Some(s.clone()),
185        serde_json::Value::Number(n) => Some(n.to_string()),
186        serde_json::Value::Bool(b) => Some(b.to_string()),
187        _ => None,
188    }
189}
190
191// =============================================================================
192// Compressed Event Buffer
193// =============================================================================
194
195/// Default compression level — fast compression (level 1) for minimal latency.
196/// Deflate level 1 still achieves ~2-3x compression on JSON while being very fast.
197const COMPRESSION_LEVEL: Compression = Compression::fast();
198
199/// Compressed event storage for correlation event inclusion.
200///
201/// Stores event JSON payloads as individually deflate-compressed blobs alongside
202/// their timestamps. This enables per-event eviction (matching `WindowState`
203/// eviction) while keeping memory usage low.
204///
205/// # Memory Model
206///
207/// Each stored event costs approximately `compressed_size + 24` bytes
208/// (8 for timestamp, 16 for Vec overhead). Typical JSON events (500B–5KB)
209/// compress to 100B–1KB with deflate, giving 3–5x memory savings.
210///
211/// The buffer enforces a hard cap (`max_events`) so memory is bounded at:
212///   `max_events × (avg_compressed_size + 24)` bytes per group key.
213#[derive(Debug, Clone, Serialize, serde::Deserialize)]
214pub struct EventBuffer {
215    /// (timestamp, deflate-compressed event JSON) pairs, ordered by timestamp.
216    #[serde(with = "event_buffer_serde")]
217    entries: VecDeque<(i64, Vec<u8>)>,
218    /// Maximum number of events to retain. When exceeded, the oldest event is
219    /// evicted regardless of the time window.
220    max_events: usize,
221}
222
223/// Custom serde for EventBuffer entries: encodes compressed bytes as base64
224/// instead of JSON number arrays, cutting snapshot size ~3x.
225mod event_buffer_serde {
226    use serde::{Deserialize, Deserializer, Serialize, Serializer};
227    use std::collections::VecDeque;
228
229    #[derive(Serialize, Deserialize)]
230    struct Entry {
231        ts: i64,
232        #[serde(with = "base64_bytes")]
233        data: Vec<u8>,
234    }
235
236    mod base64_bytes {
237        use base64::Engine as _;
238        use base64::engine::general_purpose::STANDARD;
239        use serde::{Deserializer, Serializer};
240
241        pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
242            s.serialize_str(&STANDARD.encode(bytes))
243        }
244
245        pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
246            let s: String = serde::Deserialize::deserialize(d)?;
247            STANDARD.decode(s).map_err(serde::de::Error::custom)
248        }
249    }
250
251    pub fn serialize<S: Serializer>(
252        entries: &VecDeque<(i64, Vec<u8>)>,
253        s: S,
254    ) -> Result<S::Ok, S::Error> {
255        let v: Vec<Entry> = entries
256            .iter()
257            .map(|(ts, data)| Entry {
258                ts: *ts,
259                data: data.clone(),
260            })
261            .collect();
262        v.serialize(s)
263    }
264
265    pub fn deserialize<'de, D: Deserializer<'de>>(
266        d: D,
267    ) -> Result<VecDeque<(i64, Vec<u8>)>, D::Error> {
268        let v: Vec<Entry> = Vec::deserialize(d)?;
269        Ok(v.into_iter().map(|e| (e.ts, e.data)).collect())
270    }
271}
272
273impl EventBuffer {
274    /// Create a new event buffer with the given capacity cap.
275    pub fn new(max_events: usize) -> Self {
276        EventBuffer {
277            entries: VecDeque::with_capacity(max_events.min(64)),
278            max_events,
279        }
280    }
281
282    /// Compress and store an event. Evicts the oldest entry if at capacity.
283    pub fn push(&mut self, ts: i64, event: &serde_json::Value) {
284        // Compress the event JSON with deflate
285        if let Some(compressed) = compress_event(event) {
286            if self.entries.len() >= self.max_events {
287                self.entries.pop_front();
288            }
289            self.entries.push_back((ts, compressed));
290        }
291    }
292
293    /// Remove all entries older than the cutoff timestamp.
294    pub fn evict(&mut self, cutoff: i64) {
295        while self.entries.front().is_some_and(|(t, _)| *t < cutoff) {
296            self.entries.pop_front();
297        }
298    }
299
300    /// Decompress and return all stored events.
301    pub fn decompress_all(&self) -> Vec<serde_json::Value> {
302        self.entries
303            .iter()
304            .filter_map(|(_, compressed)| decompress_event(compressed))
305            .collect()
306    }
307
308    /// Returns true if there are no stored events.
309    pub fn is_empty(&self) -> bool {
310        self.entries.is_empty()
311    }
312
313    /// Clear all stored events.
314    pub fn clear(&mut self) {
315        self.entries.clear();
316    }
317
318    /// Total compressed bytes stored (for monitoring/diagnostics).
319    pub fn compressed_bytes(&self) -> usize {
320        self.entries.iter().map(|(_, data)| data.len()).sum()
321    }
322
323    /// Number of stored events.
324    pub fn len(&self) -> usize {
325        self.entries.len()
326    }
327}
328
329/// Compress an event JSON value using deflate.
330fn compress_event(event: &serde_json::Value) -> Option<Vec<u8>> {
331    let json_bytes = serde_json::to_vec(event).ok()?;
332    let mut encoder = DeflateEncoder::new(Vec::new(), COMPRESSION_LEVEL);
333    encoder.write_all(&json_bytes).ok()?;
334    encoder.finish().ok()
335}
336
337/// Decompress a deflate-compressed event back to a JSON value.
338fn decompress_event(compressed: &[u8]) -> Option<serde_json::Value> {
339    let mut decoder = DeflateDecoder::new(compressed);
340    let mut json_bytes = Vec::new();
341    decoder.read_to_end(&mut json_bytes).ok()?;
342    serde_json::from_slice(&json_bytes).ok()
343}
344
345// =============================================================================
346// Event Reference (lightweight mode)
347// =============================================================================
348
349/// A lightweight event reference: timestamp plus optional event ID.
350///
351/// Used in `Refs` mode for memory-efficient correlation event tracking.
352/// Each ref costs ~40 bytes (vs. 100–1000+ bytes for compressed events),
353/// making this mode suitable for high-volume correlations where only
354/// traceability is needed.
355#[derive(Debug, Clone, Serialize, serde::Deserialize)]
356pub struct EventRef {
357    /// Event timestamp (epoch seconds).
358    pub timestamp: i64,
359    /// Event ID extracted from common fields (`id`, `_id`, `event_id`, etc.).
360    #[serde(skip_serializing_if = "Option::is_none")]
361    pub id: Option<String>,
362}
363
364/// Lightweight event reference buffer for `Refs` mode.
365///
366/// Stores only timestamps and optional event IDs — no event payload,
367/// no compression. This is the minimal-memory alternative to `EventBuffer`.
368#[derive(Debug, Clone, Serialize, serde::Deserialize)]
369pub struct EventRefBuffer {
370    /// Event references, ordered by timestamp.
371    entries: VecDeque<EventRef>,
372    /// Maximum number of refs to retain.
373    max_events: usize,
374}
375
376impl EventRefBuffer {
377    /// Create a new ref buffer with the given capacity cap.
378    pub fn new(max_events: usize) -> Self {
379        EventRefBuffer {
380            entries: VecDeque::with_capacity(max_events.min(64)),
381            max_events,
382        }
383    }
384
385    /// Store a reference to an event. Evicts the oldest ref if at capacity.
386    pub fn push(&mut self, ts: i64, event: &serde_json::Value) {
387        if self.entries.len() >= self.max_events {
388            self.entries.pop_front();
389        }
390        let id = extract_event_id(event);
391        self.entries.push_back(EventRef { timestamp: ts, id });
392    }
393
394    /// Remove all refs older than the cutoff timestamp.
395    pub fn evict(&mut self, cutoff: i64) {
396        while self.entries.front().is_some_and(|r| r.timestamp < cutoff) {
397            self.entries.pop_front();
398        }
399    }
400
401    /// Return cloned refs.
402    pub fn refs(&self) -> Vec<EventRef> {
403        self.entries.iter().cloned().collect()
404    }
405
406    /// Returns true if there are no stored refs.
407    pub fn is_empty(&self) -> bool {
408        self.entries.is_empty()
409    }
410
411    /// Clear all stored refs.
412    pub fn clear(&mut self) {
413        self.entries.clear();
414    }
415
416    /// Number of stored refs.
417    pub fn len(&self) -> usize {
418        self.entries.len()
419    }
420}
421
422/// Try to extract an event ID from common fields.
423///
424/// Checks (in order): `id`, `_id`, `event_id`, `EventRecordID`, `event.id`.
425/// Returns the first found value as a string.
426fn extract_event_id(event: &serde_json::Value) -> Option<String> {
427    const ID_FIELDS: &[&str] = &["id", "_id", "event_id", "EventRecordID", "event.id"];
428    for field in ID_FIELDS {
429        if let Some(val) = event.get(field) {
430            return match val {
431                serde_json::Value::String(s) => Some(s.clone()),
432                serde_json::Value::Number(n) => Some(n.to_string()),
433                _ => None,
434            };
435        }
436    }
437    None
438}
439
440// =============================================================================
441// Window State
442// =============================================================================
443
444/// Per-group mutable state within a time window.
445///
446/// Each variant matches the type of aggregation being performed.
447#[derive(Debug, Clone, Serialize, serde::Deserialize)]
448pub enum WindowState {
449    /// For `event_count`: timestamps of matching events.
450    EventCount { timestamps: VecDeque<i64> },
451    /// For `value_count`: (timestamp, field_value) pairs.
452    ValueCount { entries: VecDeque<(i64, String)> },
453    /// For `temporal` / `temporal_ordered`: rule_ref -> list of hit timestamps.
454    Temporal {
455        rule_hits: HashMap<String, VecDeque<i64>>,
456    },
457    /// For `value_sum`, `value_avg`, `value_percentile`, `value_median`:
458    /// (timestamp, numeric_value) pairs.
459    NumericAgg { entries: VecDeque<(i64, f64)> },
460}
461
462impl WindowState {
463    /// Create a new empty window state for the given correlation type.
464    pub fn new_for(corr_type: CorrelationType) -> Self {
465        match corr_type {
466            CorrelationType::EventCount => WindowState::EventCount {
467                timestamps: VecDeque::new(),
468            },
469            CorrelationType::ValueCount => WindowState::ValueCount {
470                entries: VecDeque::new(),
471            },
472            CorrelationType::Temporal | CorrelationType::TemporalOrdered => WindowState::Temporal {
473                rule_hits: HashMap::new(),
474            },
475            CorrelationType::ValueSum
476            | CorrelationType::ValueAvg
477            | CorrelationType::ValuePercentile
478            | CorrelationType::ValueMedian => WindowState::NumericAgg {
479                entries: VecDeque::new(),
480            },
481        }
482    }
483
484    /// Remove all entries older than the cutoff timestamp.
485    pub fn evict(&mut self, cutoff: i64) {
486        match self {
487            WindowState::EventCount { timestamps } => {
488                while timestamps.front().is_some_and(|&t| t < cutoff) {
489                    timestamps.pop_front();
490                }
491            }
492            WindowState::ValueCount { entries } => {
493                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
494                    entries.pop_front();
495                }
496            }
497            WindowState::Temporal { rule_hits } => {
498                for timestamps in rule_hits.values_mut() {
499                    while timestamps.front().is_some_and(|&t| t < cutoff) {
500                        timestamps.pop_front();
501                    }
502                }
503                // Remove empty rule entries
504                rule_hits.retain(|_, ts| !ts.is_empty());
505            }
506            WindowState::NumericAgg { entries } => {
507                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
508                    entries.pop_front();
509                }
510            }
511        }
512    }
513
514    /// Returns true if this state has no entries.
515    pub fn is_empty(&self) -> bool {
516        match self {
517            WindowState::EventCount { timestamps } => timestamps.is_empty(),
518            WindowState::ValueCount { entries } => entries.is_empty(),
519            WindowState::Temporal { rule_hits } => rule_hits.is_empty(),
520            WindowState::NumericAgg { entries } => entries.is_empty(),
521        }
522    }
523
524    /// Returns the most recent timestamp in this window, or `None` if empty.
525    pub fn latest_timestamp(&self) -> Option<i64> {
526        match self {
527            WindowState::EventCount { timestamps } => timestamps.back().copied(),
528            WindowState::ValueCount { entries } => entries.back().map(|(t, _)| *t),
529            WindowState::Temporal { rule_hits } => {
530                rule_hits.values().filter_map(|ts| ts.back().copied()).max()
531            }
532            WindowState::NumericAgg { entries } => entries.back().map(|(t, _)| *t),
533        }
534    }
535
536    /// Clear all entries from the window state (used by `CorrelationAction::Reset`).
537    pub fn clear(&mut self) {
538        match self {
539            WindowState::EventCount { timestamps } => timestamps.clear(),
540            WindowState::ValueCount { entries } => entries.clear(),
541            WindowState::Temporal { rule_hits } => rule_hits.clear(),
542            WindowState::NumericAgg { entries } => entries.clear(),
543        }
544    }
545
546    /// Record an event_count hit.
547    pub fn push_event_count(&mut self, ts: i64) {
548        if let WindowState::EventCount { timestamps } = self {
549            timestamps.push_back(ts);
550        }
551    }
552
553    /// Record a value_count hit with the field value.
554    pub fn push_value_count(&mut self, ts: i64, value: String) {
555        if let WindowState::ValueCount { entries } = self {
556            entries.push_back((ts, value));
557        }
558    }
559
560    /// Record a temporal hit for a specific rule reference.
561    pub fn push_temporal(&mut self, ts: i64, rule_ref: &str) {
562        if let WindowState::Temporal { rule_hits } = self {
563            rule_hits
564                .entry(rule_ref.to_string())
565                .or_default()
566                .push_back(ts);
567        }
568    }
569
570    /// Record a numeric aggregation value.
571    pub fn push_numeric(&mut self, ts: i64, value: f64) {
572        if let WindowState::NumericAgg { entries } = self {
573            entries.push_back((ts, value));
574        }
575    }
576
577    /// Evaluate the window state against the correlation condition.
578    ///
579    /// Returns `Some(aggregated_value)` if the condition is satisfied,
580    /// `None` otherwise.
581    ///
582    /// For temporal correlations with an extended expression, the expression
583    /// is evaluated against the set of rules that have fired in the window.
584    pub fn check_condition(
585        &self,
586        condition: &CompiledCondition,
587        corr_type: CorrelationType,
588        rule_refs: &[String],
589        extended_expr: Option<&ConditionExpr>,
590    ) -> Option<f64> {
591        let value = match (self, corr_type) {
592            (WindowState::EventCount { timestamps }, CorrelationType::EventCount) => {
593                timestamps.len() as f64
594            }
595            (WindowState::ValueCount { entries }, CorrelationType::ValueCount) => {
596                // Count distinct values
597                let distinct: HashSet<&String> = entries.iter().map(|(_, v)| v).collect();
598                distinct.len() as f64
599            }
600            (WindowState::Temporal { rule_hits }, CorrelationType::Temporal) => {
601                // If an extended expression is provided, evaluate it
602                if let Some(expr) = extended_expr {
603                    if eval_temporal_expr(expr, rule_hits) {
604                        // Return the count of fired rules as the value
605                        let fired: usize = rule_refs
606                            .iter()
607                            .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
608                            .count();
609                        return Some(fired as f64);
610                    } else {
611                        return None;
612                    }
613                }
614                // Default: count how many distinct referenced rules have fired
615                let fired: usize = rule_refs
616                    .iter()
617                    .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
618                    .count();
619                fired as f64
620            }
621            (WindowState::Temporal { rule_hits }, CorrelationType::TemporalOrdered) => {
622                // If an extended expression is provided, evaluate it first
623                if let Some(expr) = extended_expr
624                    && !eval_temporal_expr(expr, rule_hits)
625                {
626                    return None;
627                }
628                // Check if all referenced rules fired in order
629                if check_temporal_ordered(rule_refs, rule_hits) {
630                    rule_refs.len() as f64
631                } else {
632                    0.0
633                }
634            }
635            (WindowState::NumericAgg { entries }, CorrelationType::ValueSum) => {
636                entries.iter().map(|(_, v)| v).sum()
637            }
638            (WindowState::NumericAgg { entries }, CorrelationType::ValueAvg) => {
639                if entries.is_empty() {
640                    0.0
641                } else {
642                    let sum: f64 = entries.iter().map(|(_, v)| v).sum();
643                    sum / entries.len() as f64
644                }
645            }
646            (WindowState::NumericAgg { entries }, CorrelationType::ValuePercentile) => {
647                // Proper percentile calculation using linear interpolation.
648                // The condition threshold represents a percentile rank (0-100).
649                // We compute the value at that percentile from the window data.
650                if entries.is_empty() {
651                    return None;
652                }
653                let mut values: Vec<f64> = entries
654                    .iter()
655                    .map(|(_, v)| *v)
656                    .filter(|v| v.is_finite())
657                    .collect();
658                if values.is_empty() {
659                    return None;
660                }
661                values.sort_by(|a, b| a.partial_cmp(b).expect("NaN filtered"));
662                // Extract the percentile rank from the condition's first predicate
663                let percentile_rank = condition
664                    .predicates
665                    .first()
666                    .map(|(_, threshold)| *threshold)
667                    .unwrap_or(50.0);
668                let pval = percentile_linear_interp(&values, percentile_rank);
669                return Some(pval);
670            }
671            (WindowState::NumericAgg { entries }, CorrelationType::ValueMedian) => {
672                if entries.is_empty() {
673                    0.0
674                } else {
675                    let mut values: Vec<f64> = entries
676                        .iter()
677                        .map(|(_, v)| *v)
678                        .filter(|v| v.is_finite())
679                        .collect();
680                    if values.is_empty() {
681                        return None;
682                    }
683                    values.sort_by(|a, b| a.partial_cmp(b).expect("NaN filtered"));
684                    let mid = values.len() / 2;
685                    if values.len().is_multiple_of(2) && values.len() >= 2 {
686                        (values[mid - 1] + values[mid]) / 2.0
687                    } else {
688                        values[mid]
689                    }
690                }
691            }
692            _ => return None, // mismatched state/type
693        };
694
695        if condition.check(value) {
696            Some(value)
697        } else {
698            None
699        }
700    }
701}
702
/// Checks whether all referenced rules fired in order within the window.
///
/// For `temporal_ordered`, each rule must have at least one hit. There must
/// also be a choice of one hit timestamp per rule that is non-decreasing in
/// `rule_refs` order; equal timestamps count as ordered. An empty
/// `rule_refs` is trivially ordered.
///
/// Greedy selection is enough: for each rule in turn, pick its earliest hit
/// that is not before the previous pick. A smaller pick never rules out a
/// later choice, so if the greedy walk fails, no valid sequence exists. This
/// keeps the check linear in the number of hits scanned. The previous
/// backtracking search was exponential in the number of rules when no order
/// existed. Hit lists do not need to be sorted.
fn check_temporal_ordered(
    rule_refs: &[String],
    rule_hits: &HashMap<String, VecDeque<i64>>,
) -> bool {
    let mut prev = i64::MIN;
    for rule in rule_refs {
        let Some(hits) = rule_hits.get(rule.as_str()) else {
            return false;
        };
        match hits.iter().copied().filter(|&t| t >= prev).min() {
            Some(t) => prev = t,
            // Either the rule never fired, or all its hits precede the
            // previous rule's earliest usable hit.
            None => return false,
        }
    }
    true
}
747
748/// Evaluate a boolean condition expression against the set of rules that have
749/// fired within the temporal window.
750///
751/// Each `Identifier` in the expression is treated as a rule reference — it's
752/// `true` if that rule has at least one hit in `rule_hits`.
753fn eval_temporal_expr(expr: &ConditionExpr, rule_hits: &HashMap<String, VecDeque<i64>>) -> bool {
754    match expr {
755        ConditionExpr::Identifier(name) => rule_hits
756            .get(name.as_str())
757            .is_some_and(|ts| !ts.is_empty()),
758        ConditionExpr::And(children) => children.iter().all(|c| eval_temporal_expr(c, rule_hits)),
759        ConditionExpr::Or(children) => children.iter().any(|c| eval_temporal_expr(c, rule_hits)),
760        ConditionExpr::Not(child) => !eval_temporal_expr(child, rule_hits),
761        ConditionExpr::Selector { .. } => {
762            // Selectors are not meaningful for temporal condition evaluation
763            false
764        }
765    }
766}
767
/// Returns the value at `percentile` (0–100, clamped) of the ascending-sorted
/// `values`, interpolating linearly between the two nearest ranks.
///
/// The fractional rank is `p / 100 × (n − 1)` (the "C = 1" method). An empty
/// slice yields `0.0`, and a single element is returned unchanged.
fn percentile_linear_interp(values: &[f64], percentile: f64) -> f64 {
    match values {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = values.len() - 1;
            let pos = percentile.clamp(0.0, 100.0) / 100.0 * last as f64;
            let lo = pos.floor() as usize;
            let hi = pos.ceil() as usize;
            if lo == hi || hi > last {
                values[lo.min(last)]
            } else {
                let frac = pos - lo as f64;
                values[lo] + frac * (values[hi] - values[lo])
            }
        }
    }
}
798
799// =============================================================================
800// Compilation
801// =============================================================================
802
803/// Compile a parsed `CorrelationRule` into a `CompiledCorrelation`.
804pub fn compile_correlation(rule: &CorrelationRule) -> Result<CompiledCorrelation> {
805    // Build group-by fields, resolving aliases
806    let alias_map: HashMap<&str, &FieldAlias> =
807        rule.aliases.iter().map(|a| (a.alias.as_str(), a)).collect();
808
809    let group_by: Vec<GroupByField> = rule
810        .group_by
811        .iter()
812        .map(|field_name| {
813            if let Some(alias) = alias_map.get(field_name.as_str()) {
814                GroupByField::Aliased {
815                    alias: field_name.clone(),
816                    mapping: alias.mapping.clone(),
817                }
818            } else {
819                GroupByField::Direct(field_name.clone())
820            }
821        })
822        .collect();
823
824    // Compile condition
825    let (condition, extended_expr) = compile_condition(&rule.condition, rule.correlation_type)?;
826
827    // Resolve per-correlation overrides from custom attributes.
828    // These mirror the engine-level `rsigma.*` attributes but apply only
829    // to this correlation rule, taking precedence over engine defaults.
830    let suppress_secs = rule
831        .custom_attributes
832        .get("rsigma.suppress")
833        .and_then(|v| rsigma_parser::Timespan::parse(v).ok())
834        .map(|ts| ts.seconds);
835
836    let action = rule.custom_attributes.get("rsigma.action").and_then(|v| {
837        v.parse::<crate::correlation_engine::CorrelationAction>()
838            .ok()
839    });
840
841    let event_mode = rule
842        .custom_attributes
843        .get("rsigma.correlation_event_mode")
844        .and_then(|v| {
845            v.parse::<crate::correlation_engine::CorrelationEventMode>()
846                .ok()
847        });
848
849    let max_events = rule
850        .custom_attributes
851        .get("rsigma.max_correlation_events")
852        .and_then(|v| v.parse::<usize>().ok());
853
854    Ok(CompiledCorrelation {
855        id: rule.id.clone(),
856        name: rule.name.clone(),
857        title: rule.title.clone(),
858        level: rule.level,
859        tags: rule.tags.clone(),
860        correlation_type: rule.correlation_type,
861        rule_refs: rule.rules.clone(),
862        group_by,
863        timespan_secs: rule.timespan.seconds,
864        condition,
865        extended_expr,
866        generate: rule.generate,
867        suppress_secs,
868        action,
869        event_mode,
870        max_events,
871    })
872}
873
874/// Compile a `CorrelationCondition` into a `CompiledCondition` and optional expression.
875fn compile_condition(
876    cond: &CorrelationCondition,
877    corr_type: CorrelationType,
878) -> Result<(CompiledCondition, Option<ConditionExpr>)> {
879    match cond {
880        CorrelationCondition::Threshold { predicates, field } => Ok((
881            CompiledCondition {
882                field: field.clone(),
883                predicates: predicates
884                    .iter()
885                    .map(|(op, count)| (*op, *count as f64))
886                    .collect(),
887            },
888            None,
889        )),
890        CorrelationCondition::Extended(expr) => {
891            match corr_type {
892                CorrelationType::Temporal | CorrelationType::TemporalOrdered => {
893                    // For extended conditions, the threshold is a dummy (gte: 1)
894                    // since the actual evaluation is done via the expression tree.
895                    Ok((
896                        CompiledCondition {
897                            field: None,
898                            predicates: vec![(ConditionOperator::Gte, 1.0)],
899                        },
900                        Some(expr.clone()),
901                    ))
902                }
903                _ => Err(EvalError::CorrelationError(
904                    "Extended conditions are only supported for temporal correlation types"
905                        .to_string(),
906                )),
907            }
908        }
909    }
910}
911
912#[cfg(test)]
913mod tests {
914    use super::*;
915    use serde_json::json;
916
917    #[test]
918    fn test_group_key_extract() {
919        let v = json!({"User": "admin", "Host": "srv01"});
920        let event = Event::from_value(&v);
921        let group_by = vec![
922            GroupByField::Direct("User".to_string()),
923            GroupByField::Direct("Host".to_string()),
924        ];
925        let key = GroupKey::extract(&event, &group_by, &["rule1"]);
926        assert_eq!(
927            key.0,
928            vec![Some("admin".to_string()), Some("srv01".to_string())]
929        );
930    }
931
932    #[test]
933    fn test_group_key_missing_field() {
934        let v = json!({"User": "admin"});
935        let event = Event::from_value(&v);
936        let group_by = vec![
937            GroupByField::Direct("User".to_string()),
938            GroupByField::Direct("Host".to_string()),
939        ];
940        let key = GroupKey::extract(&event, &group_by, &["rule1"]);
941        assert_eq!(key.0, vec![Some("admin".to_string()), None]);
942    }
943
944    #[test]
945    fn test_group_key_aliased() {
946        let v = json!({"source.ip": "10.0.0.1"});
947        let event = Event::from_value(&v);
948        let group_by = vec![GroupByField::Aliased {
949            alias: "internal_ip".to_string(),
950            mapping: HashMap::from([
951                ("rule_a".to_string(), "source.ip".to_string()),
952                ("rule_b".to_string(), "destination.ip".to_string()),
953            ]),
954        }];
955        let key = GroupKey::extract(&event, &group_by, &["rule_a"]);
956        assert_eq!(key.0, vec![Some("10.0.0.1".to_string())]);
957    }
958
959    #[test]
960    fn test_condition_check() {
961        let cond = CompiledCondition {
962            field: None,
963            predicates: vec![(ConditionOperator::Gte, 100.0)],
964        };
965        assert!(!cond.check(99.0));
966        assert!(cond.check(100.0));
967        assert!(cond.check(101.0));
968    }
969
970    #[test]
971    fn test_condition_check_range() {
972        let cond = CompiledCondition {
973            field: None,
974            predicates: vec![
975                (ConditionOperator::Gt, 100.0),
976                (ConditionOperator::Lte, 200.0),
977            ],
978        };
979        assert!(!cond.check(100.0));
980        assert!(cond.check(101.0));
981        assert!(cond.check(200.0));
982        assert!(!cond.check(201.0));
983    }
984
985    #[test]
986    fn test_window_event_count() {
987        let mut state = WindowState::new_for(CorrelationType::EventCount);
988        for i in 0..5 {
989            state.push_event_count(1000 + i);
990        }
991        let cond = CompiledCondition {
992            field: None,
993            predicates: vec![(ConditionOperator::Gte, 5.0)],
994        };
995        assert_eq!(
996            state.check_condition(&cond, CorrelationType::EventCount, &[], None),
997            Some(5.0)
998        );
999    }
1000
1001    #[test]
1002    fn test_window_event_count_eviction() {
1003        let mut state = WindowState::new_for(CorrelationType::EventCount);
1004        for i in 0..10 {
1005            state.push_event_count(1000 + i);
1006        }
1007        // Evict events before ts=1005
1008        state.evict(1005);
1009        let cond = CompiledCondition {
1010            field: None,
1011            predicates: vec![(ConditionOperator::Gte, 5.0)],
1012        };
1013        assert_eq!(
1014            state.check_condition(&cond, CorrelationType::EventCount, &[], None),
1015            Some(5.0)
1016        );
1017    }
1018
1019    #[test]
1020    fn test_window_value_count() {
1021        let mut state = WindowState::new_for(CorrelationType::ValueCount);
1022        state.push_value_count(1000, "user1".to_string());
1023        state.push_value_count(1001, "user2".to_string());
1024        state.push_value_count(1002, "user1".to_string()); // duplicate
1025        state.push_value_count(1003, "user3".to_string());
1026
1027        let cond = CompiledCondition {
1028            field: Some("User".to_string()),
1029            predicates: vec![(ConditionOperator::Gte, 3.0)],
1030        };
1031        assert_eq!(
1032            state.check_condition(&cond, CorrelationType::ValueCount, &[], None),
1033            Some(3.0)
1034        );
1035    }
1036
1037    #[test]
1038    fn test_window_temporal() {
1039        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
1040        let mut state = WindowState::new_for(CorrelationType::Temporal);
1041        state.push_temporal(1000, "rule_a");
1042        // Only rule_a fired — condition: all refs must fire
1043        let cond = CompiledCondition {
1044            field: None,
1045            predicates: vec![(ConditionOperator::Gte, 2.0)],
1046        };
1047        assert!(
1048            state
1049                .check_condition(&cond, CorrelationType::Temporal, &refs, None)
1050                .is_none()
1051        );
1052
1053        // Now rule_b fires too
1054        state.push_temporal(1001, "rule_b");
1055        assert_eq!(
1056            state.check_condition(&cond, CorrelationType::Temporal, &refs, None),
1057            Some(2.0)
1058        );
1059    }
1060
1061    #[test]
1062    fn test_window_temporal_ordered() {
1063        let refs = vec![
1064            "rule_a".to_string(),
1065            "rule_b".to_string(),
1066            "rule_c".to_string(),
1067        ];
1068        let mut state = WindowState::new_for(CorrelationType::TemporalOrdered);
1069        // Fire in order: a, b, c
1070        state.push_temporal(1000, "rule_a");
1071        state.push_temporal(1001, "rule_b");
1072        state.push_temporal(1002, "rule_c");
1073
1074        let cond = CompiledCondition {
1075            field: None,
1076            predicates: vec![(ConditionOperator::Gte, 3.0)],
1077        };
1078        assert!(
1079            state
1080                .check_condition(&cond, CorrelationType::TemporalOrdered, &refs, None)
1081                .is_some()
1082        );
1083    }
1084
1085    #[test]
1086    fn test_window_temporal_ordered_wrong_order() {
1087        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
1088        let mut state = WindowState::new_for(CorrelationType::TemporalOrdered);
1089        // Fire in wrong order: b before a
1090        state.push_temporal(1000, "rule_b");
1091        state.push_temporal(1001, "rule_a");
1092
1093        let cond = CompiledCondition {
1094            field: None,
1095            predicates: vec![(ConditionOperator::Gte, 2.0)],
1096        };
1097        assert!(
1098            state
1099                .check_condition(&cond, CorrelationType::TemporalOrdered, &refs, None)
1100                .is_none()
1101        );
1102    }
1103
1104    #[test]
1105    fn test_window_value_sum() {
1106        let mut state = WindowState::new_for(CorrelationType::ValueSum);
1107        state.push_numeric(1000, 500.0);
1108        state.push_numeric(1001, 600.0);
1109
1110        let cond = CompiledCondition {
1111            field: Some("bytes_sent".to_string()),
1112            predicates: vec![(ConditionOperator::Gt, 1000.0)],
1113        };
1114        assert_eq!(
1115            state.check_condition(&cond, CorrelationType::ValueSum, &[], None),
1116            Some(1100.0)
1117        );
1118    }
1119
1120    #[test]
1121    fn test_window_value_avg() {
1122        let mut state = WindowState::new_for(CorrelationType::ValueAvg);
1123        state.push_numeric(1000, 100.0);
1124        state.push_numeric(1001, 200.0);
1125        state.push_numeric(1002, 300.0);
1126
1127        let cond = CompiledCondition {
1128            field: Some("bytes".to_string()),
1129            predicates: vec![(ConditionOperator::Gte, 200.0)],
1130        };
1131        assert_eq!(
1132            state.check_condition(&cond, CorrelationType::ValueAvg, &[], None),
1133            Some(200.0)
1134        );
1135    }
1136
1137    #[test]
1138    fn test_window_value_median() {
1139        let mut state = WindowState::new_for(CorrelationType::ValueMedian);
1140        state.push_numeric(1000, 10.0);
1141        state.push_numeric(1001, 20.0);
1142        state.push_numeric(1002, 30.0);
1143
1144        let cond = CompiledCondition {
1145            field: Some("latency".to_string()),
1146            predicates: vec![(ConditionOperator::Gte, 20.0)],
1147        };
1148        assert_eq!(
1149            state.check_condition(&cond, CorrelationType::ValueMedian, &[], None),
1150            Some(20.0)
1151        );
1152    }
1153
1154    #[test]
1155    fn test_compile_correlation_basic() {
1156        use rsigma_parser::parse_sigma_yaml;
1157
1158        let yaml = r#"
1159title: Base Rule
1160id: f305fd62-beca-47da-ad95-7690a0620084
1161logsource:
1162    product: aws
1163    service: cloudtrail
1164detection:
1165    selection:
1166        eventSource: "s3.amazonaws.com"
1167    condition: selection
1168level: low
1169---
1170title: Multiple AWS bucket enumerations
1171id: be246094-01d3-4bba-88de-69e582eba0cc
1172status: experimental
1173correlation:
1174    type: event_count
1175    rules:
1176        - f305fd62-beca-47da-ad95-7690a0620084
1177    group-by:
1178        - userIdentity.arn
1179    timespan: 1h
1180    condition:
1181        gte: 100
1182level: high
1183"#;
1184        let collection = parse_sigma_yaml(yaml).unwrap();
1185        assert_eq!(collection.correlations.len(), 1);
1186
1187        let compiled = compile_correlation(&collection.correlations[0]).unwrap();
1188        assert_eq!(compiled.correlation_type, CorrelationType::EventCount);
1189        assert_eq!(compiled.timespan_secs, 3600);
1190        assert_eq!(compiled.rule_refs.len(), 1);
1191        assert_eq!(compiled.group_by.len(), 1);
1192        assert!(compiled.condition.check(100.0));
1193        assert!(!compiled.condition.check(99.0));
1194    }
1195
1196    // =========================================================================
1197    // Extended temporal condition tests
1198    // =========================================================================
1199
1200    #[test]
1201    fn test_eval_temporal_expr_and() {
1202        let mut rule_hits = HashMap::new();
1203        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1204        rule_hits.insert("rule_b".to_string(), VecDeque::from([1001]));
1205
1206        let expr = ConditionExpr::And(vec![
1207            ConditionExpr::Identifier("rule_a".to_string()),
1208            ConditionExpr::Identifier("rule_b".to_string()),
1209        ]);
1210        assert!(eval_temporal_expr(&expr, &rule_hits));
1211    }
1212
1213    #[test]
1214    fn test_eval_temporal_expr_and_incomplete() {
1215        let mut rule_hits = HashMap::new();
1216        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1217        // rule_b not fired
1218
1219        let expr = ConditionExpr::And(vec![
1220            ConditionExpr::Identifier("rule_a".to_string()),
1221            ConditionExpr::Identifier("rule_b".to_string()),
1222        ]);
1223        assert!(!eval_temporal_expr(&expr, &rule_hits));
1224    }
1225
1226    #[test]
1227    fn test_eval_temporal_expr_or() {
1228        let mut rule_hits = HashMap::new();
1229        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1230
1231        let expr = ConditionExpr::Or(vec![
1232            ConditionExpr::Identifier("rule_a".to_string()),
1233            ConditionExpr::Identifier("rule_b".to_string()),
1234        ]);
1235        assert!(eval_temporal_expr(&expr, &rule_hits));
1236    }
1237
1238    #[test]
1239    fn test_eval_temporal_expr_not() {
1240        let rule_hits = HashMap::new();
1241
1242        let expr = ConditionExpr::Not(Box::new(ConditionExpr::Identifier("rule_a".to_string())));
1243        assert!(eval_temporal_expr(&expr, &rule_hits));
1244    }
1245
1246    #[test]
1247    fn test_eval_temporal_expr_complex() {
1248        let mut rule_hits = HashMap::new();
1249        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1250        rule_hits.insert("rule_b".to_string(), VecDeque::from([1001]));
1251        // rule_c NOT fired
1252
1253        // (rule_a and rule_b) and not rule_c
1254        let expr = ConditionExpr::And(vec![
1255            ConditionExpr::And(vec![
1256                ConditionExpr::Identifier("rule_a".to_string()),
1257                ConditionExpr::Identifier("rule_b".to_string()),
1258            ]),
1259            ConditionExpr::Not(Box::new(ConditionExpr::Identifier("rule_c".to_string()))),
1260        ]);
1261        assert!(eval_temporal_expr(&expr, &rule_hits));
1262    }
1263
1264    #[test]
1265    fn test_check_condition_with_extended_expr() {
1266        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
1267        let mut state = WindowState::new_for(CorrelationType::Temporal);
1268        state.push_temporal(1000, "rule_a");
1269        state.push_temporal(1001, "rule_b");
1270
1271        let cond = CompiledCondition {
1272            field: None,
1273            predicates: vec![(ConditionOperator::Gte, 1.0)],
1274        };
1275        let expr = ConditionExpr::And(vec![
1276            ConditionExpr::Identifier("rule_a".to_string()),
1277            ConditionExpr::Identifier("rule_b".to_string()),
1278        ]);
1279
1280        // With expression: should match (both rules fired)
1281        assert!(
1282            state
1283                .check_condition(&cond, CorrelationType::Temporal, &refs, Some(&expr))
1284                .is_some()
1285        );
1286
1287        // Now test with only rule_a: expression should fail
1288        let mut state2 = WindowState::new_for(CorrelationType::Temporal);
1289        state2.push_temporal(1000, "rule_a");
1290        assert!(
1291            state2
1292                .check_condition(&cond, CorrelationType::Temporal, &refs, Some(&expr))
1293                .is_none()
1294        );
1295    }
1296
1297    // =========================================================================
1298    // Percentile linear interpolation tests
1299    // =========================================================================
1300
1301    #[test]
1302    fn test_percentile_linear_interp_single() {
1303        assert!((percentile_linear_interp(&[42.0], 50.0) - 42.0).abs() < f64::EPSILON);
1304    }
1305
1306    #[test]
1307    fn test_percentile_linear_interp_basic() {
1308        // Values: [1, 2, 3, 4, 5]
1309        let values = &[1.0, 2.0, 3.0, 4.0, 5.0];
1310        // 0th percentile = 1.0
1311        assert!((percentile_linear_interp(values, 0.0) - 1.0).abs() < f64::EPSILON);
1312        // 25th percentile = 2.0
1313        assert!((percentile_linear_interp(values, 25.0) - 2.0).abs() < f64::EPSILON);
1314        // 50th percentile = 3.0
1315        assert!((percentile_linear_interp(values, 50.0) - 3.0).abs() < f64::EPSILON);
1316        // 75th percentile = 4.0
1317        assert!((percentile_linear_interp(values, 75.0) - 4.0).abs() < f64::EPSILON);
1318        // 100th percentile = 5.0
1319        assert!((percentile_linear_interp(values, 100.0) - 5.0).abs() < f64::EPSILON);
1320    }
1321
1322    #[test]
1323    fn test_percentile_linear_interp_interpolation() {
1324        // Values: [10, 20, 30, 40]
1325        let values = &[10.0, 20.0, 30.0, 40.0];
1326        // 50th percentile: rank = 0.5 * 3 = 1.5, interp between 20 and 30 = 25
1327        assert!((percentile_linear_interp(values, 50.0) - 25.0).abs() < f64::EPSILON);
1328    }
1329
1330    #[test]
1331    fn test_percentile_linear_interp_1st_percentile() {
1332        // Values: [1, 2, 3, ..., 100]
1333        let values: Vec<f64> = (1..=100).map(|x| x as f64).collect();
1334        // 1st percentile = 1.0 + 0.01 * 99 * (2.0 - 1.0) ~ 1.99
1335        let p1 = percentile_linear_interp(&values, 1.0);
1336        assert!((p1 - 1.99).abs() < 0.01);
1337    }
1338
1339    #[test]
1340    fn test_value_percentile_check_condition() {
1341        let mut state = WindowState::new_for(CorrelationType::ValuePercentile);
1342        // Push 100 values: 1.0, 2.0, ..., 100.0
1343        for i in 1..=100 {
1344            state.push_numeric(1000 + i, i as f64);
1345        }
1346
1347        let cond = CompiledCondition {
1348            field: Some("latency".to_string()),
1349            // The condition threshold is used as the percentile rank
1350            predicates: vec![(ConditionOperator::Lte, 50.0)],
1351        };
1352        // 50th percentile of 1..100 should be ~50.5
1353        let result = state.check_condition(&cond, CorrelationType::ValuePercentile, &[], None);
1354        assert!(result.is_some());
1355        let val = result.unwrap();
1356        assert!((val - 50.5).abs() < 1.0, "expected ~50.5, got {val}");
1357    }
1358
1359    #[test]
1360    fn test_percentile_0th_and_100th() {
1361        let values = &[5.0, 10.0, 15.0, 20.0];
1362        assert!((percentile_linear_interp(values, 0.0) - 5.0).abs() < f64::EPSILON);
1363        assert!((percentile_linear_interp(values, 100.0) - 20.0).abs() < f64::EPSILON);
1364    }
1365
1366    #[test]
1367    fn test_percentile_two_values() {
1368        let values = &[10.0, 20.0];
1369        // 50th percentile between 10 and 20 = 15
1370        assert!((percentile_linear_interp(values, 50.0) - 15.0).abs() < f64::EPSILON);
1371        // 25th percentile = 12.5
1372        assert!((percentile_linear_interp(values, 25.0) - 12.5).abs() < f64::EPSILON);
1373    }
1374
1375    #[test]
1376    fn test_percentile_clamps_out_of_range() {
1377        let values = &[1.0, 2.0, 3.0];
1378        // Negative percentile clamps to 0
1379        assert!((percentile_linear_interp(values, -10.0) - 1.0).abs() < f64::EPSILON);
1380        // > 100 clamps to 100
1381        assert!((percentile_linear_interp(values, 150.0) - 3.0).abs() < f64::EPSILON);
1382    }
1383
1384    #[test]
1385    fn test_value_percentile_empty_window() {
1386        let state = WindowState::new_for(CorrelationType::ValuePercentile);
1387        let cond = CompiledCondition {
1388            field: Some("latency".to_string()),
1389            predicates: vec![(ConditionOperator::Lte, 50.0)],
1390        };
1391        // Empty window should return None
1392        assert!(
1393            state
1394                .check_condition(&cond, CorrelationType::ValuePercentile, &[], None)
1395                .is_none()
1396        );
1397    }
1398
1399    #[test]
1400    fn test_extended_temporal_or_single_rule() {
1401        // "rule_a or rule_b" — only rule_a fired
1402        let mut rule_hits = HashMap::new();
1403        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1404
1405        let expr = ConditionExpr::Or(vec![
1406            ConditionExpr::Identifier("rule_a".to_string()),
1407            ConditionExpr::Identifier("rule_b".to_string()),
1408        ]);
1409        assert!(eval_temporal_expr(&expr, &rule_hits));
1410    }
1411
1412    #[test]
1413    fn test_extended_temporal_empty_hits() {
1414        let rule_hits = HashMap::new();
1415
1416        // "rule_a and rule_b" — nothing fired
1417        let expr = ConditionExpr::And(vec![
1418            ConditionExpr::Identifier("rule_a".to_string()),
1419            ConditionExpr::Identifier("rule_b".to_string()),
1420        ]);
1421        assert!(!eval_temporal_expr(&expr, &rule_hits));
1422
1423        // "rule_a or rule_b" — nothing fired
1424        let expr_or = ConditionExpr::Or(vec![
1425            ConditionExpr::Identifier("rule_a".to_string()),
1426            ConditionExpr::Identifier("rule_b".to_string()),
1427        ]);
1428        assert!(!eval_temporal_expr(&expr_or, &rule_hits));
1429    }
1430
1431    #[test]
1432    fn test_extended_temporal_with_empty_deque() {
1433        // Rule exists in map but with empty deque (all evicted)
1434        let mut rule_hits = HashMap::new();
1435        rule_hits.insert("rule_a".to_string(), VecDeque::new());
1436        rule_hits.insert("rule_b".to_string(), VecDeque::from([1000]));
1437
1438        let expr = ConditionExpr::And(vec![
1439            ConditionExpr::Identifier("rule_a".to_string()),
1440            ConditionExpr::Identifier("rule_b".to_string()),
1441        ]);
1442        // rule_a has empty deque — should be treated as not fired
1443        assert!(!eval_temporal_expr(&expr, &rule_hits));
1444    }
1445
1446    #[test]
1447    fn test_check_condition_temporal_no_extended_expr() {
1448        // Standard temporal without extended expr: uses threshold count
1449        let refs = vec![
1450            "rule_a".to_string(),
1451            "rule_b".to_string(),
1452            "rule_c".to_string(),
1453        ];
1454        let mut state = WindowState::new_for(CorrelationType::Temporal);
1455        state.push_temporal(1000, "rule_a");
1456        state.push_temporal(1001, "rule_b");
1457
1458        // Threshold: at least 2 rules must fire
1459        let cond = CompiledCondition {
1460            field: None,
1461            predicates: vec![(ConditionOperator::Gte, 2.0)],
1462        };
1463        // Without extended expr: 2 of 3 rules fired, meets gte 2
1464        assert_eq!(
1465            state.check_condition(&cond, CorrelationType::Temporal, &refs, None),
1466            Some(2.0)
1467        );
1468
1469        // With threshold 3: not enough
1470        let cond3 = CompiledCondition {
1471            field: None,
1472            predicates: vec![(ConditionOperator::Gte, 3.0)],
1473        };
1474        assert!(
1475            state
1476                .check_condition(&cond3, CorrelationType::Temporal, &refs, None)
1477                .is_none()
1478        );
1479    }
1480
1481    // =========================================================================
1482    // EventBuffer tests
1483    // =========================================================================
1484
1485    #[test]
1486    fn test_event_buffer_push_and_decompress() {
1487        let mut buf = EventBuffer::new(10);
1488        let event = json!({"User": "admin", "action": "login", "src_ip": "10.0.0.1"});
1489        buf.push(1000, &event);
1490
1491        assert_eq!(buf.len(), 1);
1492        assert!(!buf.is_empty());
1493
1494        let events = buf.decompress_all();
1495        assert_eq!(events.len(), 1);
1496        assert_eq!(events[0], event);
1497    }
1498
1499    #[test]
1500    fn test_event_buffer_compression_saves_memory() {
1501        let mut buf = EventBuffer::new(100);
1502        // Push a realistic-sized event (~500 bytes JSON)
1503        let event = json!({
1504            "User": "admin",
1505            "action": "login",
1506            "src_ip": "192.168.1.100",
1507            "dst_ip": "10.0.0.1",
1508            "EventTime": "2024-07-10T12:30:00Z",
1509            "process": "sshd",
1510            "host": "production-server-01.example.com",
1511            "message": "Accepted password for admin from 192.168.1.100 port 22 ssh2",
1512            "severity": "info",
1513            "tags": ["authentication", "network", "linux"]
1514        });
1515
1516        let raw_size = serde_json::to_vec(&event).unwrap().len();
1517        buf.push(1000, &event);
1518        let compressed_size = buf.compressed_bytes();
1519
1520        // Compressed should be notably smaller than raw
1521        assert!(
1522            compressed_size < raw_size,
1523            "Compressed {compressed_size}B should be less than raw {raw_size}B"
1524        );
1525
1526        // Verify roundtrip
1527        let events = buf.decompress_all();
1528        assert_eq!(events[0], event);
1529    }
1530
1531    #[test]
1532    fn test_event_buffer_max_events_cap() {
1533        let mut buf = EventBuffer::new(3);
1534
1535        for i in 0..5 {
1536            buf.push(1000 + i, &json!({"idx": i}));
1537        }
1538
1539        // Only the last 3 should remain
1540        assert_eq!(buf.len(), 3);
1541        let events = buf.decompress_all();
1542        assert_eq!(events[0], json!({"idx": 2}));
1543        assert_eq!(events[1], json!({"idx": 3}));
1544        assert_eq!(events[2], json!({"idx": 4}));
1545    }
1546
1547    #[test]
1548    fn test_event_buffer_eviction() {
1549        let mut buf = EventBuffer::new(10);
1550        for i in 0..5 {
1551            buf.push(1000 + i, &json!({"idx": i}));
1552        }
1553        assert_eq!(buf.len(), 5);
1554
1555        // Evict everything before ts 1003
1556        buf.evict(1003);
1557        assert_eq!(buf.len(), 2);
1558
1559        let events = buf.decompress_all();
1560        assert_eq!(events[0], json!({"idx": 3}));
1561        assert_eq!(events[1], json!({"idx": 4}));
1562    }
1563
1564    #[test]
1565    fn test_event_buffer_clear() {
1566        let mut buf = EventBuffer::new(10);
1567        buf.push(1000, &json!({"a": 1}));
1568        buf.push(1001, &json!({"b": 2}));
1569        assert_eq!(buf.len(), 2);
1570
1571        buf.clear();
1572        assert!(buf.is_empty());
1573        assert_eq!(buf.len(), 0);
1574        assert_eq!(buf.compressed_bytes(), 0);
1575    }
1576
1577    #[test]
1578    fn test_compress_decompress_roundtrip() {
1579        // Test various JSON shapes
1580        let values = vec![
1581            json!(null),
1582            json!(42),
1583            json!("hello world"),
1584            json!({"nested": {"deep": [1, 2, 3]}}),
1585            json!([1, "two", null, true, {"five": 5}]),
1586        ];
1587        for val in values {
1588            let compressed = compress_event(&val).unwrap();
1589            let decompressed = decompress_event(&compressed).unwrap();
1590            assert_eq!(decompressed, val, "Roundtrip failed for {val}");
1591        }
1592    }
1593
1594    // =========================================================================
1595    // EventRefBuffer tests
1596    // =========================================================================
1597
1598    #[test]
1599    fn test_event_ref_buffer_push_and_refs() {
1600        let mut buf = EventRefBuffer::new(10);
1601        buf.push(1000, &json!({"id": "evt-1", "data": "hello"}));
1602        buf.push(1001, &json!({"_id": 42, "data": "world"}));
1603        buf.push(1002, &json!({"data": "no-id"}));
1604
1605        assert_eq!(buf.len(), 3);
1606        let refs = buf.refs();
1607        assert_eq!(refs[0].timestamp, 1000);
1608        assert_eq!(refs[0].id, Some("evt-1".to_string()));
1609        assert_eq!(refs[1].timestamp, 1001);
1610        assert_eq!(refs[1].id, Some("42".to_string()));
1611        assert_eq!(refs[2].timestamp, 1002);
1612        assert_eq!(refs[2].id, None);
1613    }
1614
1615    #[test]
1616    fn test_event_ref_buffer_max_cap() {
1617        let mut buf = EventRefBuffer::new(3);
1618        for i in 0..5 {
1619            buf.push(1000 + i, &json!({"id": format!("e-{i}")}));
1620        }
1621        assert_eq!(buf.len(), 3);
1622        let refs = buf.refs();
1623        assert_eq!(refs[0].id, Some("e-2".to_string()));
1624        assert_eq!(refs[1].id, Some("e-3".to_string()));
1625        assert_eq!(refs[2].id, Some("e-4".to_string()));
1626    }
1627
1628    #[test]
1629    fn test_event_ref_buffer_eviction() {
1630        let mut buf = EventRefBuffer::new(10);
1631        for i in 0..5 {
1632            buf.push(1000 + i, &json!({"id": format!("e-{i}")}));
1633        }
1634        buf.evict(1003);
1635        assert_eq!(buf.len(), 2);
1636        let refs = buf.refs();
1637        assert_eq!(refs[0].timestamp, 1003);
1638        assert_eq!(refs[1].timestamp, 1004);
1639    }
1640
1641    #[test]
1642    fn test_event_ref_buffer_clear() {
1643        let mut buf = EventRefBuffer::new(10);
1644        buf.push(1000, &json!({"id": "a"}));
1645        buf.push(1001, &json!({"id": "b"}));
1646        assert_eq!(buf.len(), 2);
1647
1648        buf.clear();
1649        assert!(buf.is_empty());
1650        assert_eq!(buf.len(), 0);
1651    }
1652
1653    #[test]
1654    fn test_extract_event_id_common_fields() {
1655        assert_eq!(
1656            extract_event_id(&json!({"id": "abc"})),
1657            Some("abc".to_string())
1658        );
1659        assert_eq!(
1660            extract_event_id(&json!({"_id": 123})),
1661            Some("123".to_string())
1662        );
1663        assert_eq!(
1664            extract_event_id(&json!({"event_id": "x-1"})),
1665            Some("x-1".to_string())
1666        );
1667        assert_eq!(
1668            extract_event_id(&json!({"EventRecordID": 999})),
1669            Some("999".to_string())
1670        );
1671        assert_eq!(extract_event_id(&json!({"no_id_field": true})), None);
1672    }
1673
1674    #[test]
1675    fn test_compile_correlation_with_custom_attributes() {
1676        use rsigma_parser::*;
1677
1678        let mut custom_attributes = std::collections::HashMap::new();
1679        custom_attributes.insert(
1680            "rsigma.correlation_event_mode".to_string(),
1681            "refs".to_string(),
1682        );
1683        custom_attributes.insert(
1684            "rsigma.max_correlation_events".to_string(),
1685            "25".to_string(),
1686        );
1687        custom_attributes.insert("rsigma.suppress".to_string(), "5m".to_string());
1688        custom_attributes.insert("rsigma.action".to_string(), "reset".to_string());
1689
1690        let rule = CorrelationRule {
1691            title: "Test Corr".to_string(),
1692            id: Some("corr-1".to_string()),
1693            name: None,
1694            status: None,
1695            description: None,
1696            author: None,
1697            date: None,
1698            modified: None,
1699            references: vec![],
1700            tags: vec![],
1701            level: Some(Level::High),
1702            correlation_type: CorrelationType::EventCount,
1703            rules: vec!["rule-1".to_string()],
1704            group_by: vec!["User".to_string()],
1705            timespan: Timespan::parse("60s").unwrap(),
1706            condition: CorrelationCondition::Threshold {
1707                predicates: vec![(ConditionOperator::Gte, 5)],
1708                field: None,
1709            },
1710            aliases: vec![],
1711            generate: false,
1712            custom_attributes,
1713        };
1714
1715        let compiled = compile_correlation(&rule).unwrap();
1716
1717        // Per-correlation overrides should be resolved from custom_attributes
1718        assert_eq!(
1719            compiled.event_mode,
1720            Some(crate::correlation_engine::CorrelationEventMode::Refs)
1721        );
1722        assert_eq!(compiled.max_events, Some(25));
1723        assert_eq!(compiled.suppress_secs, Some(300)); // 5m = 300s
1724        assert_eq!(
1725            compiled.action,
1726            Some(crate::correlation_engine::CorrelationAction::Reset)
1727        );
1728    }
1729}