Skip to main content

rsigma_eval/correlation/
window.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3use rsigma_parser::{ConditionExpr, CorrelationType};
4use serde::Serialize;
5
6use super::CompiledCondition;
7
// =============================================================================
// Window State
// =============================================================================
11
12/// Per-group mutable state within a time window.
13///
14/// Each variant matches the type of aggregation being performed.
15#[derive(Debug, Clone, Serialize, serde::Deserialize)]
16pub enum WindowState {
17    /// For `event_count`: timestamps of matching events.
18    EventCount { timestamps: VecDeque<i64> },
19    /// For `value_count`: (timestamp, field_value) pairs.
20    ValueCount { entries: VecDeque<(i64, String)> },
21    /// For `temporal` / `temporal_ordered`: rule_ref -> list of hit timestamps.
22    Temporal {
23        rule_hits: HashMap<String, VecDeque<i64>>,
24    },
25    /// For `value_sum`, `value_avg`, `value_percentile`, `value_median`:
26    /// (timestamp, numeric_value) pairs.
27    NumericAgg { entries: VecDeque<(i64, f64)> },
28}
29
30impl WindowState {
31    /// Create a new empty window state for the given correlation type.
32    pub fn new_for(corr_type: CorrelationType) -> Self {
33        match corr_type {
34            CorrelationType::EventCount => WindowState::EventCount {
35                timestamps: VecDeque::new(),
36            },
37            CorrelationType::ValueCount => WindowState::ValueCount {
38                entries: VecDeque::new(),
39            },
40            CorrelationType::Temporal | CorrelationType::TemporalOrdered => WindowState::Temporal {
41                rule_hits: HashMap::new(),
42            },
43            CorrelationType::ValueSum
44            | CorrelationType::ValueAvg
45            | CorrelationType::ValuePercentile
46            | CorrelationType::ValueMedian => WindowState::NumericAgg {
47                entries: VecDeque::new(),
48            },
49        }
50    }
51
52    /// Remove all entries older than the cutoff timestamp.
53    pub fn evict(&mut self, cutoff: i64) {
54        match self {
55            WindowState::EventCount { timestamps } => {
56                while timestamps.front().is_some_and(|&t| t < cutoff) {
57                    timestamps.pop_front();
58                }
59            }
60            WindowState::ValueCount { entries } => {
61                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
62                    entries.pop_front();
63                }
64            }
65            WindowState::Temporal { rule_hits } => {
66                for timestamps in rule_hits.values_mut() {
67                    while timestamps.front().is_some_and(|&t| t < cutoff) {
68                        timestamps.pop_front();
69                    }
70                }
71                // Remove empty rule entries
72                rule_hits.retain(|_, ts| !ts.is_empty());
73            }
74            WindowState::NumericAgg { entries } => {
75                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
76                    entries.pop_front();
77                }
78            }
79        }
80    }
81
82    /// Returns true if this state has no entries.
83    pub fn is_empty(&self) -> bool {
84        match self {
85            WindowState::EventCount { timestamps } => timestamps.is_empty(),
86            WindowState::ValueCount { entries } => entries.is_empty(),
87            WindowState::Temporal { rule_hits } => rule_hits.is_empty(),
88            WindowState::NumericAgg { entries } => entries.is_empty(),
89        }
90    }
91
92    /// Returns the most recent timestamp in this window, or `None` if empty.
93    pub fn latest_timestamp(&self) -> Option<i64> {
94        match self {
95            WindowState::EventCount { timestamps } => timestamps.back().copied(),
96            WindowState::ValueCount { entries } => entries.back().map(|(t, _)| *t),
97            WindowState::Temporal { rule_hits } => {
98                rule_hits.values().filter_map(|ts| ts.back().copied()).max()
99            }
100            WindowState::NumericAgg { entries } => entries.back().map(|(t, _)| *t),
101        }
102    }
103
104    /// Clear all entries from the window state (used by `CorrelationAction::Reset`).
105    pub fn clear(&mut self) {
106        match self {
107            WindowState::EventCount { timestamps } => timestamps.clear(),
108            WindowState::ValueCount { entries } => entries.clear(),
109            WindowState::Temporal { rule_hits } => rule_hits.clear(),
110            WindowState::NumericAgg { entries } => entries.clear(),
111        }
112    }
113
114    /// Record an event_count hit.
115    pub fn push_event_count(&mut self, ts: i64) {
116        if let WindowState::EventCount { timestamps } = self {
117            timestamps.push_back(ts);
118        }
119    }
120
121    /// Record a value_count hit with the field value.
122    pub fn push_value_count(&mut self, ts: i64, value: String) {
123        if let WindowState::ValueCount { entries } = self {
124            entries.push_back((ts, value));
125        }
126    }
127
128    /// Record a temporal hit for a specific rule reference.
129    pub fn push_temporal(&mut self, ts: i64, rule_ref: &str) {
130        if let WindowState::Temporal { rule_hits } = self {
131            rule_hits
132                .entry(rule_ref.to_string())
133                .or_default()
134                .push_back(ts);
135        }
136    }
137
138    /// Record a numeric aggregation value.
139    pub fn push_numeric(&mut self, ts: i64, value: f64) {
140        if let WindowState::NumericAgg { entries } = self {
141            entries.push_back((ts, value));
142        }
143    }
144
145    /// Evaluate the window state against the correlation condition.
146    ///
147    /// Returns `Some(aggregated_value)` if the condition is satisfied,
148    /// `None` otherwise.
149    ///
150    /// For temporal correlations with an extended expression, the expression
151    /// is evaluated against the set of rules that have fired in the window.
152    pub fn check_condition(
153        &self,
154        condition: &CompiledCondition,
155        corr_type: CorrelationType,
156        rule_refs: &[String],
157        extended_expr: Option<&ConditionExpr>,
158    ) -> Option<f64> {
159        let value = match (self, corr_type) {
160            (WindowState::EventCount { timestamps }, CorrelationType::EventCount) => {
161                timestamps.len() as f64
162            }
163            (WindowState::ValueCount { entries }, CorrelationType::ValueCount) => {
164                // Count distinct values
165                let distinct: HashSet<&String> = entries.iter().map(|(_, v)| v).collect();
166                distinct.len() as f64
167            }
168            (WindowState::Temporal { rule_hits }, CorrelationType::Temporal) => {
169                // If an extended expression is provided, evaluate it
170                if let Some(expr) = extended_expr {
171                    if eval_temporal_expr(expr, rule_hits) {
172                        // Return the count of fired rules as the value
173                        let fired: usize = rule_refs
174                            .iter()
175                            .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
176                            .count();
177                        return Some(fired as f64);
178                    } else {
179                        return None;
180                    }
181                }
182                // Default: count how many distinct referenced rules have fired
183                let fired: usize = rule_refs
184                    .iter()
185                    .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
186                    .count();
187                fired as f64
188            }
189            (WindowState::Temporal { rule_hits }, CorrelationType::TemporalOrdered) => {
190                // If an extended expression is provided, evaluate it first
191                if let Some(expr) = extended_expr
192                    && !eval_temporal_expr(expr, rule_hits)
193                {
194                    return None;
195                }
196                // Check if all referenced rules fired in order
197                if check_temporal_ordered(rule_refs, rule_hits) {
198                    rule_refs.len() as f64
199                } else {
200                    0.0
201                }
202            }
203            (WindowState::NumericAgg { entries }, CorrelationType::ValueSum) => {
204                entries.iter().map(|(_, v)| v).sum()
205            }
206            (WindowState::NumericAgg { entries }, CorrelationType::ValueAvg) => {
207                if entries.is_empty() {
208                    0.0
209                } else {
210                    let sum: f64 = entries.iter().map(|(_, v)| v).sum();
211                    sum / entries.len() as f64
212                }
213            }
214            (WindowState::NumericAgg { entries }, CorrelationType::ValuePercentile) => {
215                // Proper percentile calculation using linear interpolation.
216                // The condition threshold represents a percentile rank (0-100).
217                // We compute the value at that percentile from the window data.
218                if entries.is_empty() {
219                    return None;
220                }
221                let mut values: Vec<f64> = entries
222                    .iter()
223                    .map(|(_, v)| *v)
224                    .filter(|v| v.is_finite())
225                    .collect();
226                if values.is_empty() {
227                    return None;
228                }
229                values.sort_by(|a, b| a.total_cmp(b));
230                let percentile_rank = condition.percentile.map(|p| p as f64).unwrap_or(50.0);
231                let pval = percentile_linear_interp(&values, percentile_rank);
232                return Some(pval);
233            }
234            (WindowState::NumericAgg { entries }, CorrelationType::ValueMedian) => {
235                // An empty window has no median. Returning `0.0` here would
236                // spuriously satisfy predicates like `lte: 0` or `eq: 0`, so
237                // match the percentile branch and skip evaluation.
238                if entries.is_empty() {
239                    return None;
240                }
241                let mut values: Vec<f64> = entries
242                    .iter()
243                    .map(|(_, v)| *v)
244                    .filter(|v| v.is_finite())
245                    .collect();
246                if values.is_empty() {
247                    return None;
248                }
249                values.sort_by(|a, b| a.total_cmp(b));
250                let mid = values.len() / 2;
251                if values.len().is_multiple_of(2) && values.len() >= 2 {
252                    (values[mid - 1] + values[mid]) / 2.0
253                } else {
254                    values[mid]
255                }
256            }
257            _ => return None, // mismatched state/type
258        };
259
260        if condition.check(value) {
261            Some(value)
262        } else {
263            None
264        }
265    }
266}
267
/// Check if all referenced rules fired in the correct order within the window.
///
/// For `temporal_ordered`, each rule must have at least one hit, and there
/// must exist a sequence of timestamps (one per rule) that is non-decreasing
/// and follows the rule ordering. An empty `rule_refs` is trivially ordered.
///
/// The search is greedy: for each rule in order, take the earliest hit that is
/// not before the timestamp chosen for the previous rule. Picking the earliest
/// admissible hit never rules out a later choice, so this finds a valid
/// sequence whenever one exists while visiting each rule's hits only once,
/// instead of backtracking over every combination. The hit queues are scanned
/// in full, so they do not need to be sorted.
fn check_temporal_ordered(
    rule_refs: &[String],
    rule_hits: &HashMap<String, VecDeque<i64>>,
) -> bool {
    rule_refs
        .iter()
        .try_fold(i64::MIN, |earliest_allowed, rule| {
            // A missing rule, or one with no hit at or after the previous
            // pick, yields `None` and aborts the fold.
            rule_hits
                .get(rule.as_str())?
                .iter()
                .copied()
                .filter(|&ts| ts >= earliest_allowed)
                .min()
        })
        .is_some()
}
312
313/// Evaluate a boolean condition expression against the set of rules that have
314/// fired within the temporal window.
315///
316/// Each `Identifier` in the expression is treated as a rule reference — it's
317/// `true` if that rule has at least one hit in `rule_hits`.
318pub(super) fn eval_temporal_expr(
319    expr: &ConditionExpr,
320    rule_hits: &HashMap<String, VecDeque<i64>>,
321) -> bool {
322    match expr {
323        ConditionExpr::Identifier(name) => rule_hits
324            .get(name.as_str())
325            .is_some_and(|ts| !ts.is_empty()),
326        ConditionExpr::And(children) => children.iter().all(|c| eval_temporal_expr(c, rule_hits)),
327        ConditionExpr::Or(children) => children.iter().any(|c| eval_temporal_expr(c, rule_hits)),
328        ConditionExpr::Not(child) => !eval_temporal_expr(child, rule_hits),
329        ConditionExpr::Selector { .. } => {
330            // Selectors are not meaningful for temporal condition evaluation
331            false
332        }
333    }
334}
335
336/// Compute the value at a given percentile rank using linear interpolation.
337///
338/// Returns 0.0 if `values` is empty.
339/// `values` must be sorted in ascending order.
340/// `percentile` is from 0.0 to 100.0.
341pub(super) fn percentile_linear_interp(values: &[f64], percentile: f64) -> f64 {
342    if values.is_empty() {
343        return 0.0;
344    }
345    let n = values.len();
346    if n == 1 {
347        return values[0];
348    }
349
350    // Clamp percentile to [0, 100]
351    let p = percentile.clamp(0.0, 100.0) / 100.0;
352
353    // Use the "C = 1" interpolation method (most common in statistics)
354    // rank = p * (n - 1)
355    let rank = p * (n - 1) as f64;
356    let lower = rank.floor() as usize;
357    let upper = rank.ceil() as usize;
358    let fraction = rank - lower as f64;
359
360    if lower == upper || upper >= n {
361        values[lower.min(n - 1)]
362    } else {
363        values[lower] + fraction * (values[upper] - values[lower])
364    }
365}