Skip to main content

rsigma_eval/
correlation.rs

1//! Compiled correlation types, group key, window state, and compilation.
2//!
3//! Transforms the parser's `CorrelationRule` AST into an optimized
4//! `CompiledCorrelation` with associated `WindowState` for stateful evaluation.
5
6use std::collections::{HashMap, HashSet, VecDeque};
7use std::io::{Read as IoRead, Write as IoWrite};
8use std::sync::Arc;
9
10use flate2::Compression;
11use flate2::read::DeflateDecoder;
12use flate2::write::DeflateEncoder;
13use serde::Serialize;
14
15use rsigma_parser::{
16    ConditionExpr, ConditionOperator, CorrelationCondition, CorrelationRule, CorrelationType,
17    FieldAlias, Level,
18};
19
20use crate::error::{EvalError, Result};
21use crate::event::{Event, EventValue};
22
23// =============================================================================
24// Compiled types
25// =============================================================================
26
/// Compiled form of a `CorrelationRule`, ready for stateful evaluation.
///
/// Produced by [`compile_correlation`]. Per-rule overrides (`suppress_secs`,
/// `action`, `event_mode`, `max_events`) are `None` when the rule does not set
/// them, in which case the engine-level defaults apply.
#[derive(Debug, Clone)]
pub struct CompiledCorrelation {
    /// Rule ID, if the correlation rule has one.
    pub id: Option<String>,
    /// Rule name, if the correlation rule has one.
    pub name: Option<String>,
    /// Rule title.
    pub title: String,
    /// Severity level, if set.
    pub level: Option<Level>,
    /// Tags attached to the rule.
    pub tags: Vec<String>,
    /// Aggregation type; selects the matching [`WindowState`] variant.
    pub correlation_type: CorrelationType,
    /// IDs or names of referenced rules (detection or other correlations).
    pub rule_refs: Vec<String>,
    /// Resolved group-by fields (may include aliases).
    pub group_by: Vec<GroupByField>,
    /// Time window in seconds.
    pub timespan_secs: u64,
    /// Compiled threshold condition.
    pub condition: CompiledCondition,
    /// Extended boolean condition expression for temporal correlations.
    /// When set, evaluates this expression against fired rules instead of
    /// a simple threshold count.
    pub extended_expr: Option<ConditionExpr>,
    /// Whether referenced detection rules should also generate standalone matches.
    pub generate: bool,
    /// Per-correlation suppression window in seconds, resolved from the
    /// `rsigma.suppress` custom attribute. `None` means use engine default.
    pub suppress_secs: Option<u64>,
    /// Per-correlation action on match, resolved from the `rsigma.action`
    /// custom attribute. `None` means use engine default.
    pub action: Option<crate::correlation_engine::CorrelationAction>,
    /// Event inclusion mode for this correlation.
    /// `None` means use the engine default (`CorrelationConfig.correlation_event_mode`).
    pub event_mode: Option<crate::correlation_engine::CorrelationEventMode>,
    /// Maximum events to store per window group for event inclusion.
    /// `None` means use the engine default (`CorrelationConfig.max_correlation_events`).
    pub max_events: Option<usize>,
    /// Custom attributes from the original Sigma correlation rule (merged).
    /// Wrapped in `Arc` so that per-match cloning is a pointer bump.
    pub custom_attributes: Arc<HashMap<String, serde_json::Value>>,
}
66
/// A group-by field, potentially aliased per referenced rule.
///
/// Built by [`compile_correlation`]: a `group_by` entry that names one of the
/// rule's aliases becomes `Aliased`; any other entry becomes `Direct`.
#[derive(Debug, Clone)]
pub enum GroupByField {
    /// Simple field name, same across all referenced rules.
    Direct(String),
    /// Aliased: maps rule_ref -> actual field name in that rule's events.
    Aliased {
        /// Alias name; used as the output/display name of the field.
        alias: String,
        /// Rule reference (ID or name) -> field name to read from that rule's events.
        mapping: HashMap<String, String>,
    },
}
78
79impl GroupByField {
80    /// Get the display name of this group-by field.
81    pub fn name(&self) -> &str {
82        match self {
83            GroupByField::Direct(s) => s,
84            GroupByField::Aliased { alias, .. } => alias,
85        }
86    }
87
88    /// Resolve the actual field name to look up in an event, given which
89    /// rule (by ID or name) produced the detection match.
90    ///
91    /// Tries to find the rule in the alias mapping by any of the provided
92    /// identifiers (ID, name, etc.).
93    pub fn resolve(&self, rule_refs: &[&str]) -> &str {
94        match self {
95            GroupByField::Direct(s) => s,
96            GroupByField::Aliased { alias, mapping } => {
97                for r in rule_refs {
98                    if let Some(field) = mapping.get(*r) {
99                        return field.as_str();
100                    }
101                }
102                alias
103            }
104        }
105    }
106}
107
/// Compiled threshold condition with one or more predicates (supports ranges).
///
/// A range such as `gte: 10` + `lt: 20` is represented as two predicates that
/// must both hold (see [`CompiledCondition::check`]).
#[derive(Debug, Clone)]
pub struct CompiledCondition {
    /// Optional field name(s) for value_count, value_sum, value_avg, value_percentile.
    /// When multiple fields are present, value_count counts distinct tuples.
    pub field: Option<Vec<String>>,
    /// One or more predicates to satisfy (all must be true for the condition to match).
    pub predicates: Vec<(ConditionOperator, f64)>,
    /// Percentile rank (0-100) for `value_percentile` type.
    /// `None` means use the default (50).
    pub percentile: Option<u64>,
}
120
121impl CompiledCondition {
122    /// Check if the given value satisfies all predicates.
123    pub fn check(&self, value: f64) -> bool {
124        self.predicates.iter().all(|(op, threshold)| match op {
125            ConditionOperator::Lt => value < *threshold,
126            ConditionOperator::Lte => value <= *threshold,
127            ConditionOperator::Gt => value > *threshold,
128            ConditionOperator::Gte => value >= *threshold,
129            ConditionOperator::Eq => (value - *threshold).abs() < f64::EPSILON,
130            ConditionOperator::Neq => (value - *threshold).abs() >= f64::EPSILON,
131        })
132    }
133}
134
135// =============================================================================
136// Group Key
137// =============================================================================
138
/// Composite key for group-by partitioning.
///
/// Each element corresponds to a `GroupByField` value extracted from an event,
/// in the same order as the correlation's `group_by` list.
/// `None` means the field was absent from the event (or had a value type that
/// is not converted to a string for grouping).
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, serde::Deserialize)]
pub struct GroupKey(pub Vec<Option<String>>);
145
146impl GroupKey {
147    /// Extract a group key from an event given the group-by fields and the
148    /// rule reference identifiers (ID, name, etc.) that produced the detection match.
149    pub fn extract(event: &impl Event, group_by: &[GroupByField], rule_refs: &[&str]) -> Self {
150        let values = group_by
151            .iter()
152            .map(|field| {
153                let field_name = field.resolve(rule_refs);
154                event
155                    .get_field(field_name)
156                    .and_then(|v| value_to_string(&v))
157            })
158            .collect();
159        GroupKey(values)
160    }
161
162    /// Build a group key from explicit field-value pairs (for chaining).
163    pub fn from_pairs(pairs: &[(String, String)], group_by: &[GroupByField]) -> Self {
164        let values = group_by
165            .iter()
166            .map(|field| {
167                let name = field.name();
168                pairs
169                    .iter()
170                    .find(|(k, _)| k == name)
171                    .map(|(_, v)| v.clone())
172            })
173            .collect();
174        GroupKey(values)
175    }
176
177    /// Convert to field-name/value pairs for output.
178    pub fn to_pairs(&self, group_by: &[GroupByField]) -> Vec<(String, String)> {
179        group_by
180            .iter()
181            .zip(self.0.iter())
182            .filter_map(|(field, value)| {
183                value
184                    .as_ref()
185                    .map(|v| (field.name().to_string(), v.clone()))
186            })
187            .collect()
188    }
189}
190
191/// Convert an [`EventValue`] to a string for group-key purposes.
192fn value_to_string(v: &EventValue) -> Option<String> {
193    match v {
194        EventValue::Str(s) => Some(s.to_string()),
195        EventValue::Int(n) => Some(n.to_string()),
196        EventValue::Float(f) => Some(f.to_string()),
197        EventValue::Bool(b) => Some(b.to_string()),
198        _ => None,
199    }
200}
201
202// =============================================================================
203// Compressed Event Buffer
204// =============================================================================
205
/// Deflate compression level for stored events: `Compression::fast()` (level 1),
/// trading compression ratio for minimal latency on the ingestion path.
/// NOTE(review): the ratio achieved on JSON depends on the payload; "~2-3x" is a
/// rough expectation, not something measured in this module.
const COMPRESSION_LEVEL: Compression = Compression::fast();
209
/// Compressed event storage for correlation event inclusion.
///
/// Stores event JSON payloads as individually deflate-compressed blobs alongside
/// their timestamps. Keeping one blob per event enables per-event eviction
/// (matching `WindowState` eviction) while keeping memory usage low.
///
/// # Memory Model
///
/// Each stored entry is an `(i64, Vec<u8>)` pair: on 64-bit targets that is
/// 32 bytes inline in the deque (8 for the timestamp, 24 for the `Vec` header),
/// plus the compressed payload on the heap and allocator overhead. How much a
/// JSON event shrinks under deflate depends on its content; several-fold is a
/// rough expectation for typical 500B–5KB events.
///
/// The buffer enforces a hard cap (`max_events`) so memory per buffer is bounded
/// at roughly `max_events × (avg_compressed_size + 32)` bytes (presumably one
/// buffer per group key — confirm against the engine).
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct EventBuffer {
    /// (timestamp, deflate-compressed event JSON) pairs in insertion order.
    /// Eviction assumes this is also non-decreasing timestamp order.
    #[serde(with = "event_buffer_serde")]
    entries: VecDeque<(i64, Vec<u8>)>,
    /// Maximum number of events to retain. When exceeded, the oldest event is
    /// evicted regardless of the time window.
    max_events: usize,
}
233
234/// Custom serde for EventBuffer entries: encodes compressed bytes as base64
235/// instead of JSON number arrays, cutting snapshot size ~3x.
236mod event_buffer_serde {
237    use serde::{Deserialize, Deserializer, Serialize, Serializer};
238    use std::collections::VecDeque;
239
240    #[derive(Serialize, Deserialize)]
241    struct Entry {
242        ts: i64,
243        #[serde(with = "base64_bytes")]
244        data: Vec<u8>,
245    }
246
247    mod base64_bytes {
248        use base64::Engine as _;
249        use base64::engine::general_purpose::STANDARD;
250        use serde::{Deserializer, Serializer};
251
252        pub fn serialize<S: Serializer>(bytes: &Vec<u8>, s: S) -> Result<S::Ok, S::Error> {
253            s.serialize_str(&STANDARD.encode(bytes))
254        }
255
256        pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<Vec<u8>, D::Error> {
257            let s: String = serde::Deserialize::deserialize(d)?;
258            STANDARD.decode(s).map_err(serde::de::Error::custom)
259        }
260    }
261
262    pub fn serialize<S: Serializer>(
263        entries: &VecDeque<(i64, Vec<u8>)>,
264        s: S,
265    ) -> Result<S::Ok, S::Error> {
266        let v: Vec<Entry> = entries
267            .iter()
268            .map(|(ts, data)| Entry {
269                ts: *ts,
270                data: data.clone(),
271            })
272            .collect();
273        v.serialize(s)
274    }
275
276    pub fn deserialize<'de, D: Deserializer<'de>>(
277        d: D,
278    ) -> Result<VecDeque<(i64, Vec<u8>)>, D::Error> {
279        let v: Vec<Entry> = Vec::deserialize(d)?;
280        Ok(v.into_iter().map(|e| (e.ts, e.data)).collect())
281    }
282}
283
284impl EventBuffer {
285    /// Create a new event buffer with the given capacity cap.
286    pub fn new(max_events: usize) -> Self {
287        EventBuffer {
288            entries: VecDeque::with_capacity(max_events.min(64)),
289            max_events,
290        }
291    }
292
293    /// Compress and store an event. Evicts the oldest entry if at capacity.
294    pub fn push(&mut self, ts: i64, event: &serde_json::Value) {
295        // Compress the event JSON with deflate
296        if let Some(compressed) = compress_event(event) {
297            if self.entries.len() >= self.max_events {
298                self.entries.pop_front();
299            }
300            self.entries.push_back((ts, compressed));
301        }
302    }
303
304    /// Remove all entries older than the cutoff timestamp.
305    pub fn evict(&mut self, cutoff: i64) {
306        while self.entries.front().is_some_and(|(t, _)| *t < cutoff) {
307            self.entries.pop_front();
308        }
309    }
310
311    /// Decompress and return all stored events.
312    pub fn decompress_all(&self) -> Vec<serde_json::Value> {
313        self.entries
314            .iter()
315            .filter_map(|(_, compressed)| decompress_event(compressed))
316            .collect()
317    }
318
319    /// Returns true if there are no stored events.
320    pub fn is_empty(&self) -> bool {
321        self.entries.is_empty()
322    }
323
324    /// Clear all stored events.
325    pub fn clear(&mut self) {
326        self.entries.clear();
327    }
328
329    /// Total compressed bytes stored (for monitoring/diagnostics).
330    pub fn compressed_bytes(&self) -> usize {
331        self.entries.iter().map(|(_, data)| data.len()).sum()
332    }
333
334    /// Number of stored events.
335    pub fn len(&self) -> usize {
336        self.entries.len()
337    }
338}
339
340/// Compress an event JSON value using deflate.
341fn compress_event(event: &serde_json::Value) -> Option<Vec<u8>> {
342    let json_bytes = serde_json::to_vec(event).ok()?;
343    let mut encoder = DeflateEncoder::new(Vec::new(), COMPRESSION_LEVEL);
344    encoder.write_all(&json_bytes).ok()?;
345    encoder.finish().ok()
346}
347
348/// Decompress a deflate-compressed event back to a JSON value.
349fn decompress_event(compressed: &[u8]) -> Option<serde_json::Value> {
350    let mut decoder = DeflateDecoder::new(compressed);
351    let mut json_bytes = Vec::new();
352    decoder.read_to_end(&mut json_bytes).ok()?;
353    serde_json::from_slice(&json_bytes).ok()
354}
355
356// =============================================================================
357// Event Reference (lightweight mode)
358// =============================================================================
359
/// A lightweight event reference: timestamp plus optional event ID.
///
/// Used in `Refs` mode for memory-efficient correlation event tracking.
/// Each ref is typically 32 bytes inline on 64-bit targets (an `i64` plus an
/// `Option<String>`), plus a heap allocation for the ID string when present.
/// This is much less than a compressed event payload, making this mode suitable
/// for high-volume correlations where only traceability is needed.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct EventRef {
    /// Event timestamp (epoch seconds).
    pub timestamp: i64,
    /// Event ID extracted from common fields (`id`, `_id`, `event_id`, etc.).
    /// Omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
}
374
/// Lightweight event reference buffer for `Refs` mode.
///
/// Stores only timestamps and optional event IDs — no event payload,
/// no compression. This is the minimal-memory alternative to `EventBuffer`.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub struct EventRefBuffer {
    /// Event references in insertion order; eviction assumes this is also
    /// non-decreasing timestamp order.
    entries: VecDeque<EventRef>,
    /// Maximum number of refs to retain; the oldest ref is dropped when exceeded.
    max_events: usize,
}
386
387impl EventRefBuffer {
388    /// Create a new ref buffer with the given capacity cap.
389    pub fn new(max_events: usize) -> Self {
390        EventRefBuffer {
391            entries: VecDeque::with_capacity(max_events.min(64)),
392            max_events,
393        }
394    }
395
396    /// Store a reference to an event. Evicts the oldest ref if at capacity.
397    pub fn push(&mut self, ts: i64, event: &serde_json::Value) {
398        if self.entries.len() >= self.max_events {
399            self.entries.pop_front();
400        }
401        let id = extract_event_id(event);
402        self.entries.push_back(EventRef { timestamp: ts, id });
403    }
404
405    /// Remove all refs older than the cutoff timestamp.
406    pub fn evict(&mut self, cutoff: i64) {
407        while self.entries.front().is_some_and(|r| r.timestamp < cutoff) {
408            self.entries.pop_front();
409        }
410    }
411
412    /// Return cloned refs.
413    pub fn refs(&self) -> Vec<EventRef> {
414        self.entries.iter().cloned().collect()
415    }
416
417    /// Returns true if there are no stored refs.
418    pub fn is_empty(&self) -> bool {
419        self.entries.is_empty()
420    }
421
422    /// Clear all stored refs.
423    pub fn clear(&mut self) {
424        self.entries.clear();
425    }
426
427    /// Number of stored refs.
428    pub fn len(&self) -> usize {
429        self.entries.len()
430    }
431}
432
433/// Try to extract an event ID from common fields.
434///
435/// Checks (in order): `id`, `_id`, `event_id`, `EventRecordID`, `event.id`.
436/// Returns the first found value as a string.
437fn extract_event_id(event: &serde_json::Value) -> Option<String> {
438    const ID_FIELDS: &[&str] = &["id", "_id", "event_id", "EventRecordID", "event.id"];
439    for field in ID_FIELDS {
440        if let Some(val) = event.get(field) {
441            return match val {
442                serde_json::Value::String(s) => Some(s.clone()),
443                serde_json::Value::Number(n) => Some(n.to_string()),
444                _ => None,
445            };
446        }
447    }
448    None
449}
450
451// =============================================================================
452// Window State
453// =============================================================================
454
/// Per-group mutable state within a time window.
///
/// Each variant matches the type of aggregation being performed (see
/// [`WindowState::new_for`]). Entries are appended at the back as hits arrive,
/// and eviction pops from the front. Eviction therefore assumes entries are in
/// non-decreasing timestamp order.
#[derive(Debug, Clone, Serialize, serde::Deserialize)]
pub enum WindowState {
    /// For `event_count`: timestamps of matching events.
    EventCount { timestamps: VecDeque<i64> },
    /// For `value_count`: (timestamp, field_value) pairs.
    ValueCount { entries: VecDeque<(i64, String)> },
    /// For `temporal` / `temporal_ordered`: rule_ref -> list of hit timestamps.
    Temporal {
        rule_hits: HashMap<String, VecDeque<i64>>,
    },
    /// For `value_sum`, `value_avg`, `value_percentile`, `value_median`:
    /// (timestamp, numeric_value) pairs.
    NumericAgg { entries: VecDeque<(i64, f64)> },
}
472
impl WindowState {
    /// Create a new empty window state for the given correlation type.
    ///
    /// `temporal` and `temporal_ordered` share the `Temporal` variant; all
    /// numeric aggregations share `NumericAgg`.
    pub fn new_for(corr_type: CorrelationType) -> Self {
        match corr_type {
            CorrelationType::EventCount => WindowState::EventCount {
                timestamps: VecDeque::new(),
            },
            CorrelationType::ValueCount => WindowState::ValueCount {
                entries: VecDeque::new(),
            },
            CorrelationType::Temporal | CorrelationType::TemporalOrdered => WindowState::Temporal {
                rule_hits: HashMap::new(),
            },
            CorrelationType::ValueSum
            | CorrelationType::ValueAvg
            | CorrelationType::ValuePercentile
            | CorrelationType::ValueMedian => WindowState::NumericAgg {
                entries: VecDeque::new(),
            },
        }
    }

    /// Remove all entries older than the cutoff timestamp.
    ///
    /// Pops from the front only, so this relies on entries being in
    /// non-decreasing timestamp order; it stops at the first entry `>= cutoff`.
    pub fn evict(&mut self, cutoff: i64) {
        match self {
            WindowState::EventCount { timestamps } => {
                while timestamps.front().is_some_and(|&t| t < cutoff) {
                    timestamps.pop_front();
                }
            }
            WindowState::ValueCount { entries } => {
                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
                    entries.pop_front();
                }
            }
            WindowState::Temporal { rule_hits } => {
                for timestamps in rule_hits.values_mut() {
                    while timestamps.front().is_some_and(|&t| t < cutoff) {
                        timestamps.pop_front();
                    }
                }
                // Remove empty rule entries so `is_empty` reflects a fully
                // evicted window and the map does not accumulate dead keys.
                rule_hits.retain(|_, ts| !ts.is_empty());
            }
            WindowState::NumericAgg { entries } => {
                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
                    entries.pop_front();
                }
            }
        }
    }

    /// Returns true if this state has no entries.
    pub fn is_empty(&self) -> bool {
        match self {
            WindowState::EventCount { timestamps } => timestamps.is_empty(),
            WindowState::ValueCount { entries } => entries.is_empty(),
            WindowState::Temporal { rule_hits } => rule_hits.is_empty(),
            WindowState::NumericAgg { entries } => entries.is_empty(),
        }
    }

    /// Returns the most recently appended timestamp in this window, or `None`
    /// if empty. For `Temporal`, this is the maximum of each rule's last hit.
    pub fn latest_timestamp(&self) -> Option<i64> {
        match self {
            WindowState::EventCount { timestamps } => timestamps.back().copied(),
            WindowState::ValueCount { entries } => entries.back().map(|(t, _)| *t),
            WindowState::Temporal { rule_hits } => {
                rule_hits.values().filter_map(|ts| ts.back().copied()).max()
            }
            WindowState::NumericAgg { entries } => entries.back().map(|(t, _)| *t),
        }
    }

    /// Clear all entries from the window state (used by `CorrelationAction::Reset`).
    pub fn clear(&mut self) {
        match self {
            WindowState::EventCount { timestamps } => timestamps.clear(),
            WindowState::ValueCount { entries } => entries.clear(),
            WindowState::Temporal { rule_hits } => rule_hits.clear(),
            WindowState::NumericAgg { entries } => entries.clear(),
        }
    }

    /// Record an event_count hit. No-op if this is not an `EventCount` state.
    pub fn push_event_count(&mut self, ts: i64) {
        if let WindowState::EventCount { timestamps } = self {
            timestamps.push_back(ts);
        }
    }

    /// Record a value_count hit with the field value. No-op if this is not a
    /// `ValueCount` state.
    pub fn push_value_count(&mut self, ts: i64, value: String) {
        if let WindowState::ValueCount { entries } = self {
            entries.push_back((ts, value));
        }
    }

    /// Record a temporal hit for a specific rule reference. No-op if this is
    /// not a `Temporal` state.
    pub fn push_temporal(&mut self, ts: i64, rule_ref: &str) {
        if let WindowState::Temporal { rule_hits } = self {
            rule_hits
                .entry(rule_ref.to_string())
                .or_default()
                .push_back(ts);
        }
    }

    /// Record a numeric aggregation value. No-op if this is not a
    /// `NumericAgg` state.
    pub fn push_numeric(&mut self, ts: i64, value: f64) {
        if let WindowState::NumericAgg { entries } = self {
            entries.push_back((ts, value));
        }
    }

    /// Evaluate the window state against the correlation condition.
    ///
    /// Returns `Some(aggregated_value)` if the condition is satisfied,
    /// `None` otherwise — including when the state variant does not match
    /// `corr_type`.
    ///
    /// For temporal correlations with an extended expression, the expression
    /// is evaluated against the set of rules that have fired in the window.
    pub fn check_condition(
        &self,
        condition: &CompiledCondition,
        corr_type: CorrelationType,
        rule_refs: &[String],
        extended_expr: Option<&ConditionExpr>,
    ) -> Option<f64> {
        let value = match (self, corr_type) {
            (WindowState::EventCount { timestamps }, CorrelationType::EventCount) => {
                timestamps.len() as f64
            }
            (WindowState::ValueCount { entries }, CorrelationType::ValueCount) => {
                // Count distinct values
                let distinct: HashSet<&String> = entries.iter().map(|(_, v)| v).collect();
                distinct.len() as f64
            }
            (WindowState::Temporal { rule_hits }, CorrelationType::Temporal) => {
                // If an extended expression is provided, it replaces the
                // threshold: `condition.check` is NOT applied on this path.
                // The expression's truth alone decides the match.
                if let Some(expr) = extended_expr {
                    if eval_temporal_expr(expr, rule_hits) {
                        // Return the count of fired rules as the value
                        let fired: usize = rule_refs
                            .iter()
                            .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
                            .count();
                        return Some(fired as f64);
                    } else {
                        return None;
                    }
                }
                // Default: count how many distinct referenced rules have fired
                let fired: usize = rule_refs
                    .iter()
                    .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
                    .count();
                fired as f64
            }
            (WindowState::Temporal { rule_hits }, CorrelationType::TemporalOrdered) => {
                // If an extended expression is provided, it acts as an extra
                // gate. The ordering check and threshold still apply afterwards.
                if let Some(expr) = extended_expr
                    && !eval_temporal_expr(expr, rule_hits)
                {
                    return None;
                }
                // Check if all referenced rules fired in order
                if check_temporal_ordered(rule_refs, rule_hits) {
                    rule_refs.len() as f64
                } else {
                    0.0
                }
            }
            (WindowState::NumericAgg { entries }, CorrelationType::ValueSum) => {
                // NOTE: unlike percentile/median, non-finite values are not filtered here.
                entries.iter().map(|(_, v)| v).sum()
            }
            (WindowState::NumericAgg { entries }, CorrelationType::ValueAvg) => {
                if entries.is_empty() {
                    0.0
                } else {
                    let sum: f64 = entries.iter().map(|(_, v)| v).sum();
                    sum / entries.len() as f64
                }
            }
            (WindowState::NumericAgg { entries }, CorrelationType::ValuePercentile) => {
                // Linear-interpolated percentile of the finite values in the
                // window. The percentile rank comes from `condition.percentile`
                // (default 50).
                // NOTE(review): this arm returns early, so `condition.predicates`
                // are never applied to the percentile value, whereas every other
                // numeric arm goes through `condition.check`. If the predicates are
                // meant to be comparison bounds (as for `value_median`), any
                // non-empty window fires; confirm the intended semantics against
                // `compile_condition`.
                if entries.is_empty() {
                    return None;
                }
                let mut values: Vec<f64> = entries
                    .iter()
                    .map(|(_, v)| *v)
                    .filter(|v| v.is_finite())
                    .collect();
                if values.is_empty() {
                    return None;
                }
                values.sort_by(|a, b| a.partial_cmp(b).expect("NaN filtered"));
                let percentile_rank = condition.percentile.map(|p| p as f64).unwrap_or(50.0);
                let pval = percentile_linear_interp(&values, percentile_rank);
                return Some(pval);
            }
            (WindowState::NumericAgg { entries }, CorrelationType::ValueMedian) => {
                // An empty window yields 0.0 (still checked against the condition),
                // while a window holding only non-finite values yields `None`.
                if entries.is_empty() {
                    0.0
                } else {
                    let mut values: Vec<f64> = entries
                        .iter()
                        .map(|(_, v)| *v)
                        .filter(|v| v.is_finite())
                        .collect();
                    if values.is_empty() {
                        return None;
                    }
                    values.sort_by(|a, b| a.partial_cmp(b).expect("NaN filtered"));
                    let mid = values.len() / 2;
                    if values.len().is_multiple_of(2) && values.len() >= 2 {
                        (values[mid - 1] + values[mid]) / 2.0
                    } else {
                        values[mid]
                    }
                }
            }
            _ => return None, // mismatched state/type
        };

        if condition.check(value) {
            Some(value)
        } else {
            None
        }
    }
}
708
/// Check if all referenced rules fired in the correct order within the window.
///
/// For `temporal_ordered`, each rule in `rule_refs` must have at least one hit,
/// and there must exist one timestamp per rule — taken in `rule_refs` order —
/// forming a non-decreasing sequence (equal timestamps count as in order).
/// An empty `rule_refs` is trivially ordered.
///
/// Hit lists need not be sorted. The search is greedy: for each rule it picks
/// the earliest hit not before the previously chosen one. Choosing the earliest
/// feasible hit never rules out a later step, so this finds a valid sequence
/// whenever one exists. It runs in O(total hits), whereas an exhaustive
/// backtracking search is exponential in the number of rules on failing windows.
fn check_temporal_ordered(
    rule_refs: &[String],
    rule_hits: &HashMap<String, VecDeque<i64>>,
) -> bool {
    let mut floor = i64::MIN;
    for rule in rule_refs {
        // A missing or empty hit list yields `None`, failing the check.
        let earliest = rule_hits
            .get(rule.as_str())
            .and_then(|hits| hits.iter().copied().filter(|&ts| ts >= floor).min());
        match earliest {
            Some(ts) => floor = ts,
            None => return false,
        }
    }
    true
}
753
754/// Evaluate a boolean condition expression against the set of rules that have
755/// fired within the temporal window.
756///
757/// Each `Identifier` in the expression is treated as a rule reference — it's
758/// `true` if that rule has at least one hit in `rule_hits`.
759fn eval_temporal_expr(expr: &ConditionExpr, rule_hits: &HashMap<String, VecDeque<i64>>) -> bool {
760    match expr {
761        ConditionExpr::Identifier(name) => rule_hits
762            .get(name.as_str())
763            .is_some_and(|ts| !ts.is_empty()),
764        ConditionExpr::And(children) => children.iter().all(|c| eval_temporal_expr(c, rule_hits)),
765        ConditionExpr::Or(children) => children.iter().any(|c| eval_temporal_expr(c, rule_hits)),
766        ConditionExpr::Not(child) => !eval_temporal_expr(child, rule_hits),
767        ConditionExpr::Selector { .. } => {
768            // Selectors are not meaningful for temporal condition evaluation
769            false
770        }
771    }
772}
773
/// Value at percentile rank `percentile` (0.0–100.0, clamped) of an
/// ascending-sorted slice, linearly interpolating between neighbours.
///
/// Uses the "C = 1" convention (position = p × (n − 1)), so rank 0 maps to the
/// minimum and rank 100 to the maximum. Returns 0.0 for an empty slice and the
/// sole element for a one-element slice.
fn percentile_linear_interp(values: &[f64], percentile: f64) -> f64 {
    match values {
        [] => 0.0,
        [only] => *only,
        _ => {
            let last = values.len() - 1;
            let share = percentile.clamp(0.0, 100.0) / 100.0;
            let position = share * last as f64;
            let lo = position.floor() as usize;
            let hi = position.ceil() as usize;
            if lo == hi || hi > last {
                values[lo.min(last)]
            } else {
                let weight = position - lo as f64;
                values[lo] + weight * (values[hi] - values[lo])
            }
        }
    }
}
804
805// =============================================================================
806// Compilation
807// =============================================================================
808
809/// Compile a parsed `CorrelationRule` into a `CompiledCorrelation`.
810pub fn compile_correlation(rule: &CorrelationRule) -> Result<CompiledCorrelation> {
811    // Build group-by fields, resolving aliases
812    let alias_map: HashMap<&str, &FieldAlias> =
813        rule.aliases.iter().map(|a| (a.alias.as_str(), a)).collect();
814
815    let group_by: Vec<GroupByField> = rule
816        .group_by
817        .iter()
818        .map(|field_name| {
819            if let Some(alias) = alias_map.get(field_name.as_str()) {
820                GroupByField::Aliased {
821                    alias: field_name.clone(),
822                    mapping: alias.mapping.clone(),
823                }
824            } else {
825                GroupByField::Direct(field_name.clone())
826            }
827        })
828        .collect();
829
830    // Compile condition
831    let (condition, extended_expr) = compile_condition(&rule.condition, rule.correlation_type)?;
832
833    // Resolve per-correlation overrides from custom attributes.
834    // These mirror the engine-level `rsigma.*` attributes but apply only
835    // to this correlation rule, taking precedence over engine defaults.
836    let suppress_secs = rule
837        .custom_attributes
838        .get("rsigma.suppress")
839        .and_then(|v| v.as_str())
840        .and_then(|s| rsigma_parser::Timespan::parse(s).ok())
841        .map(|ts| ts.seconds);
842
843    let action = rule
844        .custom_attributes
845        .get("rsigma.action")
846        .and_then(|v| v.as_str())
847        .and_then(|s| {
848            s.parse::<crate::correlation_engine::CorrelationAction>()
849                .ok()
850        });
851
852    let event_mode = rule
853        .custom_attributes
854        .get("rsigma.correlation_event_mode")
855        .and_then(|v| v.as_str())
856        .and_then(|s| {
857            s.parse::<crate::correlation_engine::CorrelationEventMode>()
858                .ok()
859        });
860
861    let max_events = rule
862        .custom_attributes
863        .get("rsigma.max_correlation_events")
864        .and_then(|v| v.as_str())
865        .and_then(|s| s.parse::<usize>().ok());
866
867    let custom_attributes = Arc::new(crate::compiler::yaml_to_json_map(&rule.custom_attributes));
868
869    Ok(CompiledCorrelation {
870        id: rule.id.clone(),
871        name: rule.name.clone(),
872        title: rule.title.clone(),
873        level: rule.level,
874        tags: rule.tags.clone(),
875        correlation_type: rule.correlation_type,
876        rule_refs: rule.rules.clone(),
877        group_by,
878        timespan_secs: rule.timespan.seconds,
879        condition,
880        extended_expr,
881        generate: rule.generate,
882        suppress_secs,
883        action,
884        event_mode,
885        max_events,
886        custom_attributes,
887    })
888}
889
890/// Compile a `CorrelationCondition` into a `CompiledCondition` and optional expression.
891fn compile_condition(
892    cond: &CorrelationCondition,
893    corr_type: CorrelationType,
894) -> Result<(CompiledCondition, Option<ConditionExpr>)> {
895    match cond {
896        CorrelationCondition::Threshold {
897            predicates,
898            field,
899            percentile,
900        } => Ok((
901            CompiledCondition {
902                field: field.clone(),
903                predicates: predicates
904                    .iter()
905                    .map(|(op, count)| (*op, *count as f64))
906                    .collect(),
907                percentile: *percentile,
908            },
909            None,
910        )),
911        CorrelationCondition::Extended(expr) => {
912            match corr_type {
913                CorrelationType::Temporal | CorrelationType::TemporalOrdered => {
914                    // For extended conditions, the threshold is a dummy (gte: 1)
915                    // since the actual evaluation is done via the expression tree.
916                    Ok((
917                        CompiledCondition {
918                            field: None,
919                            predicates: vec![(ConditionOperator::Gte, 1.0)],
920                            percentile: None,
921                        },
922                        Some(expr.clone()),
923                    ))
924                }
925                _ => Err(EvalError::CorrelationError(
926                    "Extended conditions are only supported for temporal correlation types"
927                        .to_string(),
928                )),
929            }
930        }
931    }
932}
933
934#[cfg(test)]
935mod tests {
936    use super::*;
937    use crate::event::JsonEvent;
938    use serde_json::json;
939
940    #[test]
941    fn test_group_key_extract() {
942        let v = json!({"User": "admin", "Host": "srv01"});
943        let event = JsonEvent::borrow(&v);
944        let group_by = vec![
945            GroupByField::Direct("User".to_string()),
946            GroupByField::Direct("Host".to_string()),
947        ];
948        let key = GroupKey::extract(&event, &group_by, &["rule1"]);
949        assert_eq!(
950            key.0,
951            vec![Some("admin".to_string()), Some("srv01".to_string())]
952        );
953    }
954
955    #[test]
956    fn test_group_key_missing_field() {
957        let v = json!({"User": "admin"});
958        let event = JsonEvent::borrow(&v);
959        let group_by = vec![
960            GroupByField::Direct("User".to_string()),
961            GroupByField::Direct("Host".to_string()),
962        ];
963        let key = GroupKey::extract(&event, &group_by, &["rule1"]);
964        assert_eq!(key.0, vec![Some("admin".to_string()), None]);
965    }
966
967    #[test]
968    fn test_group_key_aliased() {
969        let v = json!({"source.ip": "10.0.0.1"});
970        let event = JsonEvent::borrow(&v);
971        let group_by = vec![GroupByField::Aliased {
972            alias: "internal_ip".to_string(),
973            mapping: HashMap::from([
974                ("rule_a".to_string(), "source.ip".to_string()),
975                ("rule_b".to_string(), "destination.ip".to_string()),
976            ]),
977        }];
978        let key = GroupKey::extract(&event, &group_by, &["rule_a"]);
979        assert_eq!(key.0, vec![Some("10.0.0.1".to_string())]);
980    }
981
982    #[test]
983    fn test_condition_check() {
984        let cond = CompiledCondition {
985            field: None,
986            predicates: vec![(ConditionOperator::Gte, 100.0)],
987            percentile: None,
988        };
989        assert!(!cond.check(99.0));
990        assert!(cond.check(100.0));
991        assert!(cond.check(101.0));
992    }
993
994    #[test]
995    fn test_condition_check_range() {
996        let cond = CompiledCondition {
997            field: None,
998            predicates: vec![
999                (ConditionOperator::Gt, 100.0),
1000                (ConditionOperator::Lte, 200.0),
1001            ],
1002            percentile: None,
1003        };
1004        assert!(!cond.check(100.0));
1005        assert!(cond.check(101.0));
1006        assert!(cond.check(200.0));
1007        assert!(!cond.check(201.0));
1008    }
1009
1010    #[test]
1011    fn test_window_event_count() {
1012        let mut state = WindowState::new_for(CorrelationType::EventCount);
1013        for i in 0..5 {
1014            state.push_event_count(1000 + i);
1015        }
1016        let cond = CompiledCondition {
1017            field: None,
1018            predicates: vec![(ConditionOperator::Gte, 5.0)],
1019            percentile: None,
1020        };
1021        assert_eq!(
1022            state.check_condition(&cond, CorrelationType::EventCount, &[], None),
1023            Some(5.0)
1024        );
1025    }
1026
1027    #[test]
1028    fn test_window_event_count_eviction() {
1029        let mut state = WindowState::new_for(CorrelationType::EventCount);
1030        for i in 0..10 {
1031            state.push_event_count(1000 + i);
1032        }
1033        // Evict events before ts=1005
1034        state.evict(1005);
1035        let cond = CompiledCondition {
1036            field: None,
1037            predicates: vec![(ConditionOperator::Gte, 5.0)],
1038            percentile: None,
1039        };
1040        assert_eq!(
1041            state.check_condition(&cond, CorrelationType::EventCount, &[], None),
1042            Some(5.0)
1043        );
1044    }
1045
1046    #[test]
1047    fn test_window_value_count() {
1048        let mut state = WindowState::new_for(CorrelationType::ValueCount);
1049        state.push_value_count(1000, "user1".to_string());
1050        state.push_value_count(1001, "user2".to_string());
1051        state.push_value_count(1002, "user1".to_string()); // duplicate
1052        state.push_value_count(1003, "user3".to_string());
1053
1054        let cond = CompiledCondition {
1055            field: Some(vec!["User".to_string()]),
1056            predicates: vec![(ConditionOperator::Gte, 3.0)],
1057            percentile: None,
1058        };
1059        assert_eq!(
1060            state.check_condition(&cond, CorrelationType::ValueCount, &[], None),
1061            Some(3.0)
1062        );
1063    }
1064
1065    #[test]
1066    fn test_window_temporal() {
1067        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
1068        let mut state = WindowState::new_for(CorrelationType::Temporal);
1069        state.push_temporal(1000, "rule_a");
1070        // Only rule_a fired — condition: all refs must fire
1071        let cond = CompiledCondition {
1072            field: None,
1073            predicates: vec![(ConditionOperator::Gte, 2.0)],
1074            percentile: None,
1075        };
1076        assert!(
1077            state
1078                .check_condition(&cond, CorrelationType::Temporal, &refs, None)
1079                .is_none()
1080        );
1081
1082        // Now rule_b fires too
1083        state.push_temporal(1001, "rule_b");
1084        assert_eq!(
1085            state.check_condition(&cond, CorrelationType::Temporal, &refs, None),
1086            Some(2.0)
1087        );
1088    }
1089
1090    #[test]
1091    fn test_window_temporal_ordered() {
1092        let refs = vec![
1093            "rule_a".to_string(),
1094            "rule_b".to_string(),
1095            "rule_c".to_string(),
1096        ];
1097        let mut state = WindowState::new_for(CorrelationType::TemporalOrdered);
1098        // Fire in order: a, b, c
1099        state.push_temporal(1000, "rule_a");
1100        state.push_temporal(1001, "rule_b");
1101        state.push_temporal(1002, "rule_c");
1102
1103        let cond = CompiledCondition {
1104            field: None,
1105            predicates: vec![(ConditionOperator::Gte, 3.0)],
1106            percentile: None,
1107        };
1108        assert!(
1109            state
1110                .check_condition(&cond, CorrelationType::TemporalOrdered, &refs, None)
1111                .is_some()
1112        );
1113    }
1114
1115    #[test]
1116    fn test_window_temporal_ordered_wrong_order() {
1117        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
1118        let mut state = WindowState::new_for(CorrelationType::TemporalOrdered);
1119        // Fire in wrong order: b before a
1120        state.push_temporal(1000, "rule_b");
1121        state.push_temporal(1001, "rule_a");
1122
1123        let cond = CompiledCondition {
1124            field: None,
1125            predicates: vec![(ConditionOperator::Gte, 2.0)],
1126            percentile: None,
1127        };
1128        assert!(
1129            state
1130                .check_condition(&cond, CorrelationType::TemporalOrdered, &refs, None)
1131                .is_none()
1132        );
1133    }
1134
1135    #[test]
1136    fn test_window_value_sum() {
1137        let mut state = WindowState::new_for(CorrelationType::ValueSum);
1138        state.push_numeric(1000, 500.0);
1139        state.push_numeric(1001, 600.0);
1140
1141        let cond = CompiledCondition {
1142            field: Some(vec!["bytes_sent".to_string()]),
1143            predicates: vec![(ConditionOperator::Gt, 1000.0)],
1144            percentile: None,
1145        };
1146        assert_eq!(
1147            state.check_condition(&cond, CorrelationType::ValueSum, &[], None),
1148            Some(1100.0)
1149        );
1150    }
1151
1152    #[test]
1153    fn test_window_value_avg() {
1154        let mut state = WindowState::new_for(CorrelationType::ValueAvg);
1155        state.push_numeric(1000, 100.0);
1156        state.push_numeric(1001, 200.0);
1157        state.push_numeric(1002, 300.0);
1158
1159        let cond = CompiledCondition {
1160            field: Some(vec!["bytes".to_string()]),
1161            predicates: vec![(ConditionOperator::Gte, 200.0)],
1162            percentile: None,
1163        };
1164        assert_eq!(
1165            state.check_condition(&cond, CorrelationType::ValueAvg, &[], None),
1166            Some(200.0)
1167        );
1168    }
1169
1170    #[test]
1171    fn test_window_value_median() {
1172        let mut state = WindowState::new_for(CorrelationType::ValueMedian);
1173        state.push_numeric(1000, 10.0);
1174        state.push_numeric(1001, 20.0);
1175        state.push_numeric(1002, 30.0);
1176
1177        let cond = CompiledCondition {
1178            field: Some(vec!["latency".to_string()]),
1179            predicates: vec![(ConditionOperator::Gte, 20.0)],
1180            percentile: None,
1181        };
1182        assert_eq!(
1183            state.check_condition(&cond, CorrelationType::ValueMedian, &[], None),
1184            Some(20.0)
1185        );
1186    }
1187
1188    #[test]
1189    fn test_compile_correlation_basic() {
1190        use rsigma_parser::parse_sigma_yaml;
1191
1192        let yaml = r#"
1193title: Base Rule
1194id: f305fd62-beca-47da-ad95-7690a0620084
1195logsource:
1196    product: aws
1197    service: cloudtrail
1198detection:
1199    selection:
1200        eventSource: "s3.amazonaws.com"
1201    condition: selection
1202level: low
1203---
1204title: Multiple AWS bucket enumerations
1205id: be246094-01d3-4bba-88de-69e582eba0cc
1206status: experimental
1207correlation:
1208    type: event_count
1209    rules:
1210        - f305fd62-beca-47da-ad95-7690a0620084
1211    group-by:
1212        - userIdentity.arn
1213    timespan: 1h
1214    condition:
1215        gte: 100
1216level: high
1217"#;
1218        let collection = parse_sigma_yaml(yaml).unwrap();
1219        assert_eq!(collection.correlations.len(), 1);
1220
1221        let compiled = compile_correlation(&collection.correlations[0]).unwrap();
1222        assert_eq!(compiled.correlation_type, CorrelationType::EventCount);
1223        assert_eq!(compiled.timespan_secs, 3600);
1224        assert_eq!(compiled.rule_refs.len(), 1);
1225        assert_eq!(compiled.group_by.len(), 1);
1226        assert!(compiled.condition.check(100.0));
1227        assert!(!compiled.condition.check(99.0));
1228    }
1229
1230    // =========================================================================
1231    // Extended temporal condition tests
1232    // =========================================================================
1233
1234    #[test]
1235    fn test_eval_temporal_expr_and() {
1236        let mut rule_hits = HashMap::new();
1237        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1238        rule_hits.insert("rule_b".to_string(), VecDeque::from([1001]));
1239
1240        let expr = ConditionExpr::And(vec![
1241            ConditionExpr::Identifier("rule_a".to_string()),
1242            ConditionExpr::Identifier("rule_b".to_string()),
1243        ]);
1244        assert!(eval_temporal_expr(&expr, &rule_hits));
1245    }
1246
1247    #[test]
1248    fn test_eval_temporal_expr_and_incomplete() {
1249        let mut rule_hits = HashMap::new();
1250        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1251        // rule_b not fired
1252
1253        let expr = ConditionExpr::And(vec![
1254            ConditionExpr::Identifier("rule_a".to_string()),
1255            ConditionExpr::Identifier("rule_b".to_string()),
1256        ]);
1257        assert!(!eval_temporal_expr(&expr, &rule_hits));
1258    }
1259
1260    #[test]
1261    fn test_eval_temporal_expr_or() {
1262        let mut rule_hits = HashMap::new();
1263        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1264
1265        let expr = ConditionExpr::Or(vec![
1266            ConditionExpr::Identifier("rule_a".to_string()),
1267            ConditionExpr::Identifier("rule_b".to_string()),
1268        ]);
1269        assert!(eval_temporal_expr(&expr, &rule_hits));
1270    }
1271
1272    #[test]
1273    fn test_eval_temporal_expr_not() {
1274        let rule_hits = HashMap::new();
1275
1276        let expr = ConditionExpr::Not(Box::new(ConditionExpr::Identifier("rule_a".to_string())));
1277        assert!(eval_temporal_expr(&expr, &rule_hits));
1278    }
1279
1280    #[test]
1281    fn test_eval_temporal_expr_complex() {
1282        let mut rule_hits = HashMap::new();
1283        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1284        rule_hits.insert("rule_b".to_string(), VecDeque::from([1001]));
1285        // rule_c NOT fired
1286
1287        // (rule_a and rule_b) and not rule_c
1288        let expr = ConditionExpr::And(vec![
1289            ConditionExpr::And(vec![
1290                ConditionExpr::Identifier("rule_a".to_string()),
1291                ConditionExpr::Identifier("rule_b".to_string()),
1292            ]),
1293            ConditionExpr::Not(Box::new(ConditionExpr::Identifier("rule_c".to_string()))),
1294        ]);
1295        assert!(eval_temporal_expr(&expr, &rule_hits));
1296    }
1297
1298    #[test]
1299    fn test_check_condition_with_extended_expr() {
1300        let refs = vec!["rule_a".to_string(), "rule_b".to_string()];
1301        let mut state = WindowState::new_for(CorrelationType::Temporal);
1302        state.push_temporal(1000, "rule_a");
1303        state.push_temporal(1001, "rule_b");
1304
1305        let cond = CompiledCondition {
1306            field: None,
1307            predicates: vec![(ConditionOperator::Gte, 1.0)],
1308            percentile: None,
1309        };
1310        let expr = ConditionExpr::And(vec![
1311            ConditionExpr::Identifier("rule_a".to_string()),
1312            ConditionExpr::Identifier("rule_b".to_string()),
1313        ]);
1314
1315        // With expression: should match (both rules fired)
1316        assert!(
1317            state
1318                .check_condition(&cond, CorrelationType::Temporal, &refs, Some(&expr))
1319                .is_some()
1320        );
1321
1322        // Now test with only rule_a: expression should fail
1323        let mut state2 = WindowState::new_for(CorrelationType::Temporal);
1324        state2.push_temporal(1000, "rule_a");
1325        assert!(
1326            state2
1327                .check_condition(&cond, CorrelationType::Temporal, &refs, Some(&expr))
1328                .is_none()
1329        );
1330    }
1331
1332    // =========================================================================
1333    // Percentile linear interpolation tests
1334    // =========================================================================
1335
1336    #[test]
1337    fn test_percentile_linear_interp_single() {
1338        assert!((percentile_linear_interp(&[42.0], 50.0) - 42.0).abs() < f64::EPSILON);
1339    }
1340
1341    #[test]
1342    fn test_percentile_linear_interp_basic() {
1343        // Values: [1, 2, 3, 4, 5]
1344        let values = &[1.0, 2.0, 3.0, 4.0, 5.0];
1345        // 0th percentile = 1.0
1346        assert!((percentile_linear_interp(values, 0.0) - 1.0).abs() < f64::EPSILON);
1347        // 25th percentile = 2.0
1348        assert!((percentile_linear_interp(values, 25.0) - 2.0).abs() < f64::EPSILON);
1349        // 50th percentile = 3.0
1350        assert!((percentile_linear_interp(values, 50.0) - 3.0).abs() < f64::EPSILON);
1351        // 75th percentile = 4.0
1352        assert!((percentile_linear_interp(values, 75.0) - 4.0).abs() < f64::EPSILON);
1353        // 100th percentile = 5.0
1354        assert!((percentile_linear_interp(values, 100.0) - 5.0).abs() < f64::EPSILON);
1355    }
1356
1357    #[test]
1358    fn test_percentile_linear_interp_interpolation() {
1359        // Values: [10, 20, 30, 40]
1360        let values = &[10.0, 20.0, 30.0, 40.0];
1361        // 50th percentile: rank = 0.5 * 3 = 1.5, interp between 20 and 30 = 25
1362        assert!((percentile_linear_interp(values, 50.0) - 25.0).abs() < f64::EPSILON);
1363    }
1364
1365    #[test]
1366    fn test_percentile_linear_interp_1st_percentile() {
1367        // Values: [1, 2, 3, ..., 100]
1368        let values: Vec<f64> = (1..=100).map(|x| x as f64).collect();
1369        // 1st percentile = 1.0 + 0.01 * 99 * (2.0 - 1.0) ~ 1.99
1370        let p1 = percentile_linear_interp(&values, 1.0);
1371        assert!((p1 - 1.99).abs() < 0.01);
1372    }
1373
1374    #[test]
1375    fn test_value_percentile_check_condition() {
1376        let mut state = WindowState::new_for(CorrelationType::ValuePercentile);
1377        // Push 100 values: 1.0, 2.0, ..., 100.0
1378        for i in 1..=100 {
1379            state.push_numeric(1000 + i, i as f64);
1380        }
1381
1382        let cond = CompiledCondition {
1383            field: Some(vec!["latency".to_string()]),
1384            predicates: vec![(ConditionOperator::Lte, 50.0)],
1385            percentile: None,
1386        };
1387        // 50th percentile of 1..100 should be ~50.5
1388        let result = state.check_condition(&cond, CorrelationType::ValuePercentile, &[], None);
1389        assert!(result.is_some());
1390        let val = result.unwrap();
1391        assert!((val - 50.5).abs() < 1.0, "expected ~50.5, got {val}");
1392    }
1393
1394    #[test]
1395    fn test_percentile_0th_and_100th() {
1396        let values = &[5.0, 10.0, 15.0, 20.0];
1397        assert!((percentile_linear_interp(values, 0.0) - 5.0).abs() < f64::EPSILON);
1398        assert!((percentile_linear_interp(values, 100.0) - 20.0).abs() < f64::EPSILON);
1399    }
1400
1401    #[test]
1402    fn test_percentile_two_values() {
1403        let values = &[10.0, 20.0];
1404        // 50th percentile between 10 and 20 = 15
1405        assert!((percentile_linear_interp(values, 50.0) - 15.0).abs() < f64::EPSILON);
1406        // 25th percentile = 12.5
1407        assert!((percentile_linear_interp(values, 25.0) - 12.5).abs() < f64::EPSILON);
1408    }
1409
1410    #[test]
1411    fn test_percentile_clamps_out_of_range() {
1412        let values = &[1.0, 2.0, 3.0];
1413        // Negative percentile clamps to 0
1414        assert!((percentile_linear_interp(values, -10.0) - 1.0).abs() < f64::EPSILON);
1415        // > 100 clamps to 100
1416        assert!((percentile_linear_interp(values, 150.0) - 3.0).abs() < f64::EPSILON);
1417    }
1418
1419    #[test]
1420    fn test_value_percentile_empty_window() {
1421        let state = WindowState::new_for(CorrelationType::ValuePercentile);
1422        let cond = CompiledCondition {
1423            field: Some(vec!["latency".to_string()]),
1424            predicates: vec![(ConditionOperator::Lte, 50.0)],
1425            percentile: None,
1426        };
1427        // Empty window should return None
1428        assert!(
1429            state
1430                .check_condition(&cond, CorrelationType::ValuePercentile, &[], None)
1431                .is_none()
1432        );
1433    }
1434
1435    #[test]
1436    fn test_extended_temporal_or_single_rule() {
1437        // "rule_a or rule_b" — only rule_a fired
1438        let mut rule_hits = HashMap::new();
1439        rule_hits.insert("rule_a".to_string(), VecDeque::from([1000]));
1440
1441        let expr = ConditionExpr::Or(vec![
1442            ConditionExpr::Identifier("rule_a".to_string()),
1443            ConditionExpr::Identifier("rule_b".to_string()),
1444        ]);
1445        assert!(eval_temporal_expr(&expr, &rule_hits));
1446    }
1447
1448    #[test]
1449    fn test_extended_temporal_empty_hits() {
1450        let rule_hits = HashMap::new();
1451
1452        // "rule_a and rule_b" — nothing fired
1453        let expr = ConditionExpr::And(vec![
1454            ConditionExpr::Identifier("rule_a".to_string()),
1455            ConditionExpr::Identifier("rule_b".to_string()),
1456        ]);
1457        assert!(!eval_temporal_expr(&expr, &rule_hits));
1458
1459        // "rule_a or rule_b" — nothing fired
1460        let expr_or = ConditionExpr::Or(vec![
1461            ConditionExpr::Identifier("rule_a".to_string()),
1462            ConditionExpr::Identifier("rule_b".to_string()),
1463        ]);
1464        assert!(!eval_temporal_expr(&expr_or, &rule_hits));
1465    }
1466
1467    #[test]
1468    fn test_extended_temporal_with_empty_deque() {
1469        // Rule exists in map but with empty deque (all evicted)
1470        let mut rule_hits = HashMap::new();
1471        rule_hits.insert("rule_a".to_string(), VecDeque::new());
1472        rule_hits.insert("rule_b".to_string(), VecDeque::from([1000]));
1473
1474        let expr = ConditionExpr::And(vec![
1475            ConditionExpr::Identifier("rule_a".to_string()),
1476            ConditionExpr::Identifier("rule_b".to_string()),
1477        ]);
1478        // rule_a has empty deque — should be treated as not fired
1479        assert!(!eval_temporal_expr(&expr, &rule_hits));
1480    }
1481
1482    #[test]
1483    fn test_check_condition_temporal_no_extended_expr() {
1484        // Standard temporal without extended expr: uses threshold count
1485        let refs = vec![
1486            "rule_a".to_string(),
1487            "rule_b".to_string(),
1488            "rule_c".to_string(),
1489        ];
1490        let mut state = WindowState::new_for(CorrelationType::Temporal);
1491        state.push_temporal(1000, "rule_a");
1492        state.push_temporal(1001, "rule_b");
1493
1494        // Threshold: at least 2 rules must fire
1495        let cond = CompiledCondition {
1496            field: None,
1497            predicates: vec![(ConditionOperator::Gte, 2.0)],
1498            percentile: None,
1499        };
1500        // Without extended expr: 2 of 3 rules fired, meets gte 2
1501        assert_eq!(
1502            state.check_condition(&cond, CorrelationType::Temporal, &refs, None),
1503            Some(2.0)
1504        );
1505
1506        // With threshold 3: not enough
1507        let cond3 = CompiledCondition {
1508            field: None,
1509            predicates: vec![(ConditionOperator::Gte, 3.0)],
1510            percentile: None,
1511        };
1512        assert!(
1513            state
1514                .check_condition(&cond3, CorrelationType::Temporal, &refs, None)
1515                .is_none()
1516        );
1517    }
1518
1519    // =========================================================================
1520    // EventBuffer tests
1521    // =========================================================================
1522
1523    #[test]
1524    fn test_event_buffer_push_and_decompress() {
1525        let mut buf = EventBuffer::new(10);
1526        let event = json!({"User": "admin", "action": "login", "src_ip": "10.0.0.1"});
1527        buf.push(1000, &event);
1528
1529        assert_eq!(buf.len(), 1);
1530        assert!(!buf.is_empty());
1531
1532        let events = buf.decompress_all();
1533        assert_eq!(events.len(), 1);
1534        assert_eq!(events[0], event);
1535    }
1536
1537    #[test]
1538    fn test_event_buffer_compression_saves_memory() {
1539        let mut buf = EventBuffer::new(100);
1540        // Push a realistic-sized event (~500 bytes JSON)
1541        let event = json!({
1542            "User": "admin",
1543            "action": "login",
1544            "src_ip": "192.168.1.100",
1545            "dst_ip": "10.0.0.1",
1546            "EventTime": "2024-07-10T12:30:00Z",
1547            "process": "sshd",
1548            "host": "production-server-01.example.com",
1549            "message": "Accepted password for admin from 192.168.1.100 port 22 ssh2",
1550            "severity": "info",
1551            "tags": ["authentication", "network", "linux"]
1552        });
1553
1554        let raw_size = serde_json::to_vec(&event).unwrap().len();
1555        buf.push(1000, &event);
1556        let compressed_size = buf.compressed_bytes();
1557
1558        // Compressed should be notably smaller than raw
1559        assert!(
1560            compressed_size < raw_size,
1561            "Compressed {compressed_size}B should be less than raw {raw_size}B"
1562        );
1563
1564        // Verify roundtrip
1565        let events = buf.decompress_all();
1566        assert_eq!(events[0], event);
1567    }
1568
1569    #[test]
1570    fn test_event_buffer_max_events_cap() {
1571        let mut buf = EventBuffer::new(3);
1572
1573        for i in 0..5 {
1574            buf.push(1000 + i, &json!({"idx": i}));
1575        }
1576
1577        // Only the last 3 should remain
1578        assert_eq!(buf.len(), 3);
1579        let events = buf.decompress_all();
1580        assert_eq!(events[0], json!({"idx": 2}));
1581        assert_eq!(events[1], json!({"idx": 3}));
1582        assert_eq!(events[2], json!({"idx": 4}));
1583    }
1584
1585    #[test]
1586    fn test_event_buffer_eviction() {
1587        let mut buf = EventBuffer::new(10);
1588        for i in 0..5 {
1589            buf.push(1000 + i, &json!({"idx": i}));
1590        }
1591        assert_eq!(buf.len(), 5);
1592
1593        // Evict everything before ts 1003
1594        buf.evict(1003);
1595        assert_eq!(buf.len(), 2);
1596
1597        let events = buf.decompress_all();
1598        assert_eq!(events[0], json!({"idx": 3}));
1599        assert_eq!(events[1], json!({"idx": 4}));
1600    }
1601
1602    #[test]
1603    fn test_event_buffer_clear() {
1604        let mut buf = EventBuffer::new(10);
1605        buf.push(1000, &json!({"a": 1}));
1606        buf.push(1001, &json!({"b": 2}));
1607        assert_eq!(buf.len(), 2);
1608
1609        buf.clear();
1610        assert!(buf.is_empty());
1611        assert_eq!(buf.len(), 0);
1612        assert_eq!(buf.compressed_bytes(), 0);
1613    }
1614
1615    #[test]
1616    fn test_compress_decompress_roundtrip() {
1617        // Test various JSON shapes
1618        let values = vec![
1619            json!(null),
1620            json!(42),
1621            json!("hello world"),
1622            json!({"nested": {"deep": [1, 2, 3]}}),
1623            json!([1, "two", null, true, {"five": 5}]),
1624        ];
1625        for val in values {
1626            let compressed = compress_event(&val).unwrap();
1627            let decompressed = decompress_event(&compressed).unwrap();
1628            assert_eq!(decompressed, val, "Roundtrip failed for {val}");
1629        }
1630    }
1631
1632    // =========================================================================
1633    // EventRefBuffer tests
1634    // =========================================================================
1635
1636    #[test]
1637    fn test_event_ref_buffer_push_and_refs() {
1638        let mut buf = EventRefBuffer::new(10);
1639        buf.push(1000, &json!({"id": "evt-1", "data": "hello"}));
1640        buf.push(1001, &json!({"_id": 42, "data": "world"}));
1641        buf.push(1002, &json!({"data": "no-id"}));
1642
1643        assert_eq!(buf.len(), 3);
1644        let refs = buf.refs();
1645        assert_eq!(refs[0].timestamp, 1000);
1646        assert_eq!(refs[0].id, Some("evt-1".to_string()));
1647        assert_eq!(refs[1].timestamp, 1001);
1648        assert_eq!(refs[1].id, Some("42".to_string()));
1649        assert_eq!(refs[2].timestamp, 1002);
1650        assert_eq!(refs[2].id, None);
1651    }
1652
1653    #[test]
1654    fn test_event_ref_buffer_max_cap() {
1655        let mut buf = EventRefBuffer::new(3);
1656        for i in 0..5 {
1657            buf.push(1000 + i, &json!({"id": format!("e-{i}")}));
1658        }
1659        assert_eq!(buf.len(), 3);
1660        let refs = buf.refs();
1661        assert_eq!(refs[0].id, Some("e-2".to_string()));
1662        assert_eq!(refs[1].id, Some("e-3".to_string()));
1663        assert_eq!(refs[2].id, Some("e-4".to_string()));
1664    }
1665
1666    #[test]
1667    fn test_event_ref_buffer_eviction() {
1668        let mut buf = EventRefBuffer::new(10);
1669        for i in 0..5 {
1670            buf.push(1000 + i, &json!({"id": format!("e-{i}")}));
1671        }
1672        buf.evict(1003);
1673        assert_eq!(buf.len(), 2);
1674        let refs = buf.refs();
1675        assert_eq!(refs[0].timestamp, 1003);
1676        assert_eq!(refs[1].timestamp, 1004);
1677    }
1678
1679    #[test]
1680    fn test_event_ref_buffer_clear() {
1681        let mut buf = EventRefBuffer::new(10);
1682        buf.push(1000, &json!({"id": "a"}));
1683        buf.push(1001, &json!({"id": "b"}));
1684        assert_eq!(buf.len(), 2);
1685
1686        buf.clear();
1687        assert!(buf.is_empty());
1688        assert_eq!(buf.len(), 0);
1689    }
1690
1691    #[test]
1692    fn test_extract_event_id_common_fields() {
1693        assert_eq!(
1694            extract_event_id(&json!({"id": "abc"})),
1695            Some("abc".to_string())
1696        );
1697        assert_eq!(
1698            extract_event_id(&json!({"_id": 123})),
1699            Some("123".to_string())
1700        );
1701        assert_eq!(
1702            extract_event_id(&json!({"event_id": "x-1"})),
1703            Some("x-1".to_string())
1704        );
1705        assert_eq!(
1706            extract_event_id(&json!({"EventRecordID": 999})),
1707            Some("999".to_string())
1708        );
1709        assert_eq!(extract_event_id(&json!({"no_id_field": true})), None);
1710    }
1711
1712    #[test]
1713    fn test_compile_correlation_with_custom_attributes() {
1714        use rsigma_parser::*;
1715
1716        let mut custom_attributes: HashMap<String, serde_yaml::Value> =
1717            std::collections::HashMap::new();
1718        custom_attributes.insert(
1719            "rsigma.correlation_event_mode".to_string(),
1720            serde_yaml::Value::String("refs".to_string()),
1721        );
1722        custom_attributes.insert(
1723            "rsigma.max_correlation_events".to_string(),
1724            serde_yaml::Value::String("25".to_string()),
1725        );
1726        custom_attributes.insert(
1727            "rsigma.suppress".to_string(),
1728            serde_yaml::Value::String("5m".to_string()),
1729        );
1730        custom_attributes.insert(
1731            "rsigma.action".to_string(),
1732            serde_yaml::Value::String("reset".to_string()),
1733        );
1734
1735        let rule = CorrelationRule {
1736            title: "Test Corr".to_string(),
1737            id: Some("corr-1".to_string()),
1738            name: None,
1739            status: None,
1740            description: None,
1741            author: None,
1742            date: None,
1743            modified: None,
1744            related: vec![],
1745            references: vec![],
1746            taxonomy: None,
1747            license: None,
1748            tags: vec![],
1749            fields: vec![],
1750            falsepositives: vec![],
1751            level: Some(Level::High),
1752            scope: vec![],
1753            correlation_type: CorrelationType::EventCount,
1754            rules: vec!["rule-1".to_string()],
1755            group_by: vec!["User".to_string()],
1756            timespan: Timespan::parse("60s").unwrap(),
1757            condition: CorrelationCondition::Threshold {
1758                predicates: vec![(ConditionOperator::Gte, 5)],
1759                field: None,
1760                percentile: None,
1761            },
1762            aliases: vec![],
1763            generate: false,
1764            custom_attributes,
1765        };
1766
1767        let compiled = compile_correlation(&rule).unwrap();
1768
1769        // Per-correlation overrides should be resolved from custom_attributes
1770        assert_eq!(
1771            compiled.event_mode,
1772            Some(crate::correlation_engine::CorrelationEventMode::Refs)
1773        );
1774        assert_eq!(compiled.max_events, Some(25));
1775        assert_eq!(compiled.suppress_secs, Some(300)); // 5m = 300s
1776        assert_eq!(
1777            compiled.action,
1778            Some(crate::correlation_engine::CorrelationAction::Reset)
1779        );
1780    }
1781}