Skip to main content

rsigma_eval/correlation/
window.rs

1use std::collections::{HashMap, HashSet, VecDeque};
2
3use rsigma_parser::{ConditionExpr, CorrelationType};
4use serde::Serialize;
5
6use super::CompiledCondition;
7
8// =============================================================================
9// Window State
10// =============================================================================
11
12/// Per-group mutable state within a time window.
13///
14/// Each variant matches the type of aggregation being performed.
15#[derive(Debug, Clone, Serialize, serde::Deserialize)]
16pub enum WindowState {
17    /// For `event_count`: timestamps of matching events.
18    EventCount { timestamps: VecDeque<i64> },
19    /// For `value_count`: (timestamp, field_value) pairs.
20    ValueCount { entries: VecDeque<(i64, String)> },
21    /// For `temporal` / `temporal_ordered`: rule_ref -> list of hit timestamps.
22    Temporal {
23        rule_hits: HashMap<String, VecDeque<i64>>,
24    },
25    /// For `value_sum`, `value_avg`, `value_percentile`, `value_median`:
26    /// (timestamp, numeric_value) pairs.
27    NumericAgg { entries: VecDeque<(i64, f64)> },
28}
29
30impl WindowState {
31    /// Create a new empty window state for the given correlation type.
32    pub fn new_for(corr_type: CorrelationType) -> Self {
33        match corr_type {
34            CorrelationType::EventCount => WindowState::EventCount {
35                timestamps: VecDeque::new(),
36            },
37            CorrelationType::ValueCount => WindowState::ValueCount {
38                entries: VecDeque::new(),
39            },
40            CorrelationType::Temporal | CorrelationType::TemporalOrdered => WindowState::Temporal {
41                rule_hits: HashMap::new(),
42            },
43            CorrelationType::ValueSum
44            | CorrelationType::ValueAvg
45            | CorrelationType::ValuePercentile
46            | CorrelationType::ValueMedian => WindowState::NumericAgg {
47                entries: VecDeque::new(),
48            },
49        }
50    }
51
52    /// Remove all entries older than the cutoff timestamp.
53    pub fn evict(&mut self, cutoff: i64) {
54        match self {
55            WindowState::EventCount { timestamps } => {
56                while timestamps.front().is_some_and(|&t| t < cutoff) {
57                    timestamps.pop_front();
58                }
59            }
60            WindowState::ValueCount { entries } => {
61                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
62                    entries.pop_front();
63                }
64            }
65            WindowState::Temporal { rule_hits } => {
66                for timestamps in rule_hits.values_mut() {
67                    while timestamps.front().is_some_and(|&t| t < cutoff) {
68                        timestamps.pop_front();
69                    }
70                }
71                // Remove empty rule entries
72                rule_hits.retain(|_, ts| !ts.is_empty());
73            }
74            WindowState::NumericAgg { entries } => {
75                while entries.front().is_some_and(|(t, _)| *t < cutoff) {
76                    entries.pop_front();
77                }
78            }
79        }
80    }
81
82    /// Returns true if this state has no entries.
83    pub fn is_empty(&self) -> bool {
84        match self {
85            WindowState::EventCount { timestamps } => timestamps.is_empty(),
86            WindowState::ValueCount { entries } => entries.is_empty(),
87            WindowState::Temporal { rule_hits } => rule_hits.is_empty(),
88            WindowState::NumericAgg { entries } => entries.is_empty(),
89        }
90    }
91
92    /// Returns the most recent timestamp in this window, or `None` if empty.
93    pub fn latest_timestamp(&self) -> Option<i64> {
94        match self {
95            WindowState::EventCount { timestamps } => timestamps.back().copied(),
96            WindowState::ValueCount { entries } => entries.back().map(|(t, _)| *t),
97            WindowState::Temporal { rule_hits } => {
98                rule_hits.values().filter_map(|ts| ts.back().copied()).max()
99            }
100            WindowState::NumericAgg { entries } => entries.back().map(|(t, _)| *t),
101        }
102    }
103
104    /// Clear all entries from the window state (used by `CorrelationAction::Reset`).
105    pub fn clear(&mut self) {
106        match self {
107            WindowState::EventCount { timestamps } => timestamps.clear(),
108            WindowState::ValueCount { entries } => entries.clear(),
109            WindowState::Temporal { rule_hits } => rule_hits.clear(),
110            WindowState::NumericAgg { entries } => entries.clear(),
111        }
112    }
113
114    /// Record an event_count hit.
115    pub fn push_event_count(&mut self, ts: i64) {
116        if let WindowState::EventCount { timestamps } = self {
117            timestamps.push_back(ts);
118        }
119    }
120
121    /// Record a value_count hit with the field value.
122    pub fn push_value_count(&mut self, ts: i64, value: String) {
123        if let WindowState::ValueCount { entries } = self {
124            entries.push_back((ts, value));
125        }
126    }
127
128    /// Record a temporal hit for a specific rule reference.
129    pub fn push_temporal(&mut self, ts: i64, rule_ref: &str) {
130        if let WindowState::Temporal { rule_hits } = self {
131            rule_hits
132                .entry(rule_ref.to_string())
133                .or_default()
134                .push_back(ts);
135        }
136    }
137
138    /// Record a numeric aggregation value.
139    pub fn push_numeric(&mut self, ts: i64, value: f64) {
140        if let WindowState::NumericAgg { entries } = self {
141            entries.push_back((ts, value));
142        }
143    }
144
145    /// Evaluate the window state against the correlation condition.
146    ///
147    /// Returns `Some(aggregated_value)` if the condition is satisfied,
148    /// `None` otherwise.
149    ///
150    /// For temporal correlations with an extended expression, the expression
151    /// is evaluated against the set of rules that have fired in the window.
152    pub fn check_condition(
153        &self,
154        condition: &CompiledCondition,
155        corr_type: CorrelationType,
156        rule_refs: &[String],
157        extended_expr: Option<&ConditionExpr>,
158    ) -> Option<f64> {
159        let value = match (self, corr_type) {
160            (WindowState::EventCount { timestamps }, CorrelationType::EventCount) => {
161                timestamps.len() as f64
162            }
163            (WindowState::ValueCount { entries }, CorrelationType::ValueCount) => {
164                // Count distinct values
165                let distinct: HashSet<&String> = entries.iter().map(|(_, v)| v).collect();
166                distinct.len() as f64
167            }
168            (WindowState::Temporal { rule_hits }, CorrelationType::Temporal) => {
169                // If an extended expression is provided, evaluate it
170                if let Some(expr) = extended_expr {
171                    if eval_temporal_expr(expr, rule_hits) {
172                        // Return the count of fired rules as the value
173                        let fired: usize = rule_refs
174                            .iter()
175                            .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
176                            .count();
177                        return Some(fired as f64);
178                    } else {
179                        return None;
180                    }
181                }
182                // Default: count how many distinct referenced rules have fired
183                let fired: usize = rule_refs
184                    .iter()
185                    .filter(|r| rule_hits.get(r.as_str()).is_some_and(|ts| !ts.is_empty()))
186                    .count();
187                fired as f64
188            }
189            (WindowState::Temporal { rule_hits }, CorrelationType::TemporalOrdered) => {
190                // If an extended expression is provided, evaluate it first
191                if let Some(expr) = extended_expr
192                    && !eval_temporal_expr(expr, rule_hits)
193                {
194                    return None;
195                }
196                // Check if all referenced rules fired in order
197                if check_temporal_ordered(rule_refs, rule_hits) {
198                    rule_refs.len() as f64
199                } else {
200                    0.0
201                }
202            }
203            (WindowState::NumericAgg { entries }, CorrelationType::ValueSum) => {
204                entries.iter().map(|(_, v)| v).sum()
205            }
206            (WindowState::NumericAgg { entries }, CorrelationType::ValueAvg) => {
207                if entries.is_empty() {
208                    0.0
209                } else {
210                    let sum: f64 = entries.iter().map(|(_, v)| v).sum();
211                    sum / entries.len() as f64
212                }
213            }
214            (WindowState::NumericAgg { entries }, CorrelationType::ValuePercentile) => {
215                // Proper percentile calculation using linear interpolation.
216                // The condition threshold represents a percentile rank (0-100).
217                // We compute the value at that percentile from the window data.
218                if entries.is_empty() {
219                    return None;
220                }
221                let mut values: Vec<f64> = entries
222                    .iter()
223                    .map(|(_, v)| *v)
224                    .filter(|v| v.is_finite())
225                    .collect();
226                if values.is_empty() {
227                    return None;
228                }
229                values.sort_by(|a, b| a.total_cmp(b));
230                let percentile_rank = condition.percentile.map(|p| p as f64).unwrap_or(50.0);
231                let pval = percentile_linear_interp(&values, percentile_rank);
232                return Some(pval);
233            }
234            (WindowState::NumericAgg { entries }, CorrelationType::ValueMedian) => {
235                if entries.is_empty() {
236                    0.0
237                } else {
238                    let mut values: Vec<f64> = entries
239                        .iter()
240                        .map(|(_, v)| *v)
241                        .filter(|v| v.is_finite())
242                        .collect();
243                    if values.is_empty() {
244                        return None;
245                    }
246                    values.sort_by(|a, b| a.total_cmp(b));
247                    let mid = values.len() / 2;
248                    if values.len().is_multiple_of(2) && values.len() >= 2 {
249                        (values[mid - 1] + values[mid]) / 2.0
250                    } else {
251                        values[mid]
252                    }
253                }
254            }
255            _ => return None, // mismatched state/type
256        };
257
258        if condition.check(value) {
259            Some(value)
260        } else {
261            None
262        }
263    }
264}
265
/// Check whether every referenced rule fired, in order, within the window.
///
/// Returns `true` when one hit timestamp can be chosen per entry of
/// `rule_refs` such that the chosen timestamps are non-decreasing when read in
/// `rule_refs` order. Equal timestamps count as ordered. An empty `rule_refs`
/// is trivially satisfied; a rule that is missing from `rule_hits`, or whose
/// queue is empty, makes the check fail.
///
/// The search is greedy: for each rule in turn it picks the smallest hit that
/// is not earlier than the previous pick. Choosing the earliest feasible hit
/// never rules out a sequence that a later choice would allow, so this gives
/// the same answer as an exhaustive search while visiting each stored hit at
/// most once per reference. The queues are not assumed to be sorted.
fn check_temporal_ordered(
    rule_refs: &[String],
    rule_hits: &HashMap<String, VecDeque<i64>>,
) -> bool {
    // Lower bound that the next rule's chosen timestamp must respect.
    let mut floor = i64::MIN;

    for rule in rule_refs {
        let Some(hits) = rule_hits.get(rule.as_str()) else {
            return false;
        };
        // Take the minimum over all eligible hits rather than the first match,
        // so out-of-order insertion does not affect the result.
        let Some(earliest) = hits.iter().copied().filter(|&ts| ts >= floor).min() else {
            return false;
        };
        floor = earliest;
    }

    true
}
310
311/// Evaluate a boolean condition expression against the set of rules that have
312/// fired within the temporal window.
313///
314/// Each `Identifier` in the expression is treated as a rule reference — it's
315/// `true` if that rule has at least one hit in `rule_hits`.
316pub(super) fn eval_temporal_expr(
317    expr: &ConditionExpr,
318    rule_hits: &HashMap<String, VecDeque<i64>>,
319) -> bool {
320    match expr {
321        ConditionExpr::Identifier(name) => rule_hits
322            .get(name.as_str())
323            .is_some_and(|ts| !ts.is_empty()),
324        ConditionExpr::And(children) => children.iter().all(|c| eval_temporal_expr(c, rule_hits)),
325        ConditionExpr::Or(children) => children.iter().any(|c| eval_temporal_expr(c, rule_hits)),
326        ConditionExpr::Not(child) => !eval_temporal_expr(child, rule_hits),
327        ConditionExpr::Selector { .. } => {
328            // Selectors are not meaningful for temporal condition evaluation
329            false
330        }
331    }
332}
333
334/// Compute the value at a given percentile rank using linear interpolation.
335///
336/// Returns 0.0 if `values` is empty.
337/// `values` must be sorted in ascending order.
338/// `percentile` is from 0.0 to 100.0.
339pub(super) fn percentile_linear_interp(values: &[f64], percentile: f64) -> f64 {
340    if values.is_empty() {
341        return 0.0;
342    }
343    let n = values.len();
344    if n == 1 {
345        return values[0];
346    }
347
348    // Clamp percentile to [0, 100]
349    let p = percentile.clamp(0.0, 100.0) / 100.0;
350
351    // Use the "C = 1" interpolation method (most common in statistics)
352    // rank = p * (n - 1)
353    let rank = p * (n - 1) as f64;
354    let lower = rank.floor() as usize;
355    let upper = rank.ceil() as usize;
356    let fraction = rank - lower as f64;
357
358    if lower == upper || upper >= n {
359        values[lower.min(n - 1)]
360    } else {
361        values[lower] + fraction * (values[upper] - values[lower])
362    }
363}