Skip to main content

shape_runtime/
window_manager.rs

1//! Generic window manager for time-series aggregations
2//!
3//! Provides windowing operations for streaming data:
4//! - Tumbling (fixed non-overlapping)
5//! - Sliding (overlapping)
6//! - Session (gap-based)
7//! - Count-based
8//! - Cumulative
9//!
10//! This module is industry-agnostic and works with any timestamped data.
11
12use chrono::{DateTime, Duration, Utc};
13use shape_value::ValueWord;
14use std::collections::HashMap;
15
16use shape_ast::error::Result;
/// Window type for aggregations.
///
/// Selects how `WindowManager` groups incoming data points into windows.
#[derive(Debug, Clone)]
pub enum WindowType {
    /// Fixed non-overlapping windows of length `size`
    Tumbling { size: Duration },
    /// Overlapping windows of length `size`, with a new window opened every `slide`
    Sliding { size: Duration, slide: Duration },
    /// Windows based on inactivity gaps: a session is closed once the time
    /// since its previous point exceeds `gap`
    Session { gap: Duration },
    /// Count-based windows (every `size` records)
    Count { size: usize },
    /// Cumulative from start
    Cumulative,
}
31
32impl WindowType {
33    /// Create a tumbling window
34    pub fn tumbling(size: Duration) -> Self {
35        WindowType::Tumbling { size }
36    }
37
38    /// Create a sliding window
39    pub fn sliding(size: Duration, slide: Duration) -> Self {
40        WindowType::Sliding { size, slide }
41    }
42
43    /// Create a session window
44    pub fn session(gap: Duration) -> Self {
45        WindowType::Session { gap }
46    }
47
48    /// Create a count-based window
49    pub fn count(size: usize) -> Self {
50        WindowType::Count { size }
51    }
52
53    /// Create a cumulative window
54    pub fn cumulative() -> Self {
55        WindowType::Cumulative
56    }
57}
58
/// A single data point in a window
#[derive(Debug, Clone)]
pub struct WindowDataPoint {
    /// Time associated with this point; used to assign it to windows
    pub timestamp: DateTime<Utc>,
    /// Named field values carried by this point; aggregations look fields up by name
    pub fields: HashMap<String, ValueWord>,
}
65
/// A completed window with aggregated data
#[derive(Debug, Clone)]
pub struct WindowResult {
    /// Window start time
    pub start: DateTime<Utc>,
    /// Window end time. As produced by `WindowManager`, this is the timestamp
    /// of the last point added to the window (or `start` if none was added),
    /// not the nominal window boundary.
    pub end: DateTime<Utc>,
    /// Number of data points in window
    pub count: usize,
    /// Aggregated values, keyed by each aggregation's output name. A value may
    /// be NaN when the window held no numeric values for the aggregated field.
    pub aggregates: HashMap<String, f64>,
}
78
/// Aggregation function type
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Sum of the values
    Sum,
    /// Arithmetic mean of the values
    Avg,
    /// Smallest value
    Min,
    /// Largest value
    Max,
    /// Number of numeric values found for the field
    Count,
    /// First value in insertion order
    First,
    /// Last value in insertion order
    Last,
    /// Population standard deviation (square root of `Variance`)
    StdDev,
    /// Population variance (squared deviations divided by the number of values)
    Variance,
}
92
/// Aggregation specification
#[derive(Debug, Clone)]
pub struct AggregateSpec {
    /// Name of the input field to read from each data point
    pub field: String,
    /// Function applied to the field's values within a window
    pub function: AggregateFunction,
    /// Key under which the result is stored in `WindowResult::aggregates`
    pub output_name: String,
}
100
/// Window state for tracking active windows
#[derive(Debug)]
struct WindowState {
    /// Start time of the window
    start: DateTime<Utc>,
    /// Data points collected so far, in arrival order
    data: Vec<WindowDataPoint>,
    /// Timestamp of the most recently added point; `None` until a point is added
    last_timestamp: Option<DateTime<Utc>>,
}
108
/// Generic window manager for streaming aggregations
///
/// Feed points with `process`, collect finished windows with
/// `take_completed`, and call `flush` at end of stream to emit windows that
/// are still open.
pub struct WindowManager {
    /// Window type configuration
    window_type: WindowType,
    /// Aggregation specifications
    aggregates: Vec<AggregateSpec>,
    /// Active windows (used by sliding windows)
    active_windows: Vec<WindowState>,
    /// Current open window (used by tumbling, session and count-based windows)
    current_window: Option<WindowState>,
    /// Number of points in the current count-based window
    current_count: usize,
    /// Cumulative data (for cumulative windows)
    cumulative_data: Vec<WindowDataPoint>,
    /// Completed windows waiting to be emitted
    completed_windows: Vec<WindowResult>,
}
126
127impl WindowManager {
128    /// Create a new window manager
129    pub fn new(window_type: WindowType) -> Self {
130        Self {
131            window_type,
132            aggregates: Vec::new(),
133            active_windows: Vec::new(),
134            current_window: None,
135            current_count: 0,
136            cumulative_data: Vec::new(),
137            completed_windows: Vec::new(),
138        }
139    }
140
141    /// Add an aggregation specification
142    pub fn aggregate(
143        &mut self,
144        field: &str,
145        function: AggregateFunction,
146        output_name: &str,
147    ) -> &mut Self {
148        self.aggregates.push(AggregateSpec {
149            field: field.to_string(),
150            function,
151            output_name: output_name.to_string(),
152        });
153        self
154    }
155
156    /// Process a data point
157    pub fn process(
158        &mut self,
159        timestamp: DateTime<Utc>,
160        fields: HashMap<String, ValueWord>,
161    ) -> Result<()> {
162        let data_point = WindowDataPoint { timestamp, fields };
163
164        match &self.window_type {
165            WindowType::Tumbling { size } => {
166                self.process_tumbling(&data_point, *size)?;
167            }
168            WindowType::Sliding { size, slide } => {
169                self.process_sliding(&data_point, *size, *slide)?;
170            }
171            WindowType::Session { gap } => {
172                self.process_session(&data_point, *gap)?;
173            }
174            WindowType::Count { size } => {
175                self.process_count(&data_point, *size)?;
176            }
177            WindowType::Cumulative => {
178                self.process_cumulative(&data_point)?;
179            }
180        }
181
182        Ok(())
183    }
184
185    /// Process a tumbling window data point
186    fn process_tumbling(&mut self, data_point: &WindowDataPoint, size: Duration) -> Result<()> {
187        let window_start = self.align_to_window(data_point.timestamp, size);
188
189        // Check if we need to close the current window
190        let should_close = self
191            .current_window
192            .as_ref()
193            .map(|w| data_point.timestamp >= w.start + size)
194            .unwrap_or(false);
195
196        if should_close {
197            // Take the window out to compute result
198            if let Some(window) = self.current_window.take() {
199                let result = self.compute_window_result(&window)?;
200                self.completed_windows.push(result);
201            }
202        }
203
204        // Add to current window or create new one
205        match &mut self.current_window {
206            Some(window) => {
207                window.data.push(data_point.clone());
208                window.last_timestamp = Some(data_point.timestamp);
209            }
210            None => {
211                self.current_window = Some(WindowState {
212                    start: window_start,
213                    data: vec![data_point.clone()],
214                    last_timestamp: Some(data_point.timestamp),
215                });
216            }
217        }
218
219        Ok(())
220    }
221
222    /// Process a sliding window data point
223    fn process_sliding(
224        &mut self,
225        data_point: &WindowDataPoint,
226        size: Duration,
227        slide: Duration,
228    ) -> Result<()> {
229        // Add point to all applicable windows
230        let ts = data_point.timestamp;
231
232        // Create new windows as needed
233        let window_start = self.align_to_window(ts, slide);
234
235        // Check if we need to create a new window
236        let needs_new_window = self.active_windows.is_empty()
237            || self
238                .active_windows
239                .last()
240                .map(|w| ts >= w.start + slide)
241                .unwrap_or(true);
242
243        if needs_new_window {
244            self.active_windows.push(WindowState {
245                start: window_start,
246                data: Vec::new(),
247                last_timestamp: None,
248            });
249        }
250
251        // Add point to all windows that contain this timestamp
252        for window in &mut self.active_windows {
253            if ts >= window.start && ts < window.start + size {
254                window.data.push(data_point.clone());
255                window.last_timestamp = Some(ts);
256            }
257        }
258
259        // Close windows that have ended
260        let mut closed_indices = Vec::new();
261        for (i, window) in self.active_windows.iter().enumerate() {
262            if ts >= window.start + size {
263                let result = self.compute_window_result(window)?;
264                self.completed_windows.push(result);
265                closed_indices.push(i);
266            }
267        }
268
269        // Remove closed windows (in reverse to maintain indices)
270        for i in closed_indices.into_iter().rev() {
271            self.active_windows.remove(i);
272        }
273
274        Ok(())
275    }
276
277    /// Process a session window data point
278    fn process_session(&mut self, data_point: &WindowDataPoint, gap: Duration) -> Result<()> {
279        // Check if we need to close the current session due to gap
280        let should_close = self
281            .current_window
282            .as_ref()
283            .and_then(|w| w.last_timestamp)
284            .map(|last_ts| data_point.timestamp - last_ts > gap)
285            .unwrap_or(false);
286
287        if should_close {
288            if let Some(window) = self.current_window.take() {
289                let result = self.compute_window_result(&window)?;
290                self.completed_windows.push(result);
291            }
292        }
293
294        // Add to current session or start new one
295        match &mut self.current_window {
296            Some(window) => {
297                window.data.push(data_point.clone());
298                window.last_timestamp = Some(data_point.timestamp);
299            }
300            None => {
301                self.current_window = Some(WindowState {
302                    start: data_point.timestamp,
303                    data: vec![data_point.clone()],
304                    last_timestamp: Some(data_point.timestamp),
305                });
306            }
307        }
308
309        Ok(())
310    }
311
312    /// Process a count-based window data point
313    fn process_count(&mut self, data_point: &WindowDataPoint, size: usize) -> Result<()> {
314        if self.current_window.is_none() {
315            self.current_window = Some(WindowState {
316                start: data_point.timestamp,
317                data: Vec::new(),
318                last_timestamp: None,
319            });
320        }
321
322        // Add to current window
323        if let Some(window) = &mut self.current_window {
324            window.data.push(data_point.clone());
325            window.last_timestamp = Some(data_point.timestamp);
326        }
327        self.current_count += 1;
328
329        // Check if window is complete
330        if self.current_count >= size {
331            if let Some(window) = self.current_window.take() {
332                let result = self.compute_window_result(&window)?;
333                self.completed_windows.push(result);
334            }
335            self.current_count = 0;
336        }
337
338        Ok(())
339    }
340
341    /// Process a cumulative window data point
342    fn process_cumulative(&mut self, data_point: &WindowDataPoint) -> Result<()> {
343        self.cumulative_data.push(data_point.clone());
344
345        // Create a window result for the cumulative state
346        let start = self
347            .cumulative_data
348            .first()
349            .map(|d| d.timestamp)
350            .unwrap_or(data_point.timestamp);
351        let end = data_point.timestamp;
352
353        let window = WindowState {
354            start,
355            data: self.cumulative_data.clone(),
356            last_timestamp: Some(end),
357        };
358
359        let result = self.compute_window_result(&window)?;
360        self.completed_windows.push(result);
361
362        Ok(())
363    }
364
365    /// Align timestamp to window boundary
366    fn align_to_window(&self, ts: DateTime<Utc>, size: Duration) -> DateTime<Utc> {
367        let epoch = DateTime::UNIX_EPOCH;
368        let since_epoch = ts - epoch;
369        let size_millis = size.num_milliseconds();
370
371        if size_millis == 0 {
372            return ts;
373        }
374
375        let aligned_millis = (since_epoch.num_milliseconds() / size_millis) * size_millis;
376        epoch + Duration::milliseconds(aligned_millis)
377    }
378
379    /// Compute aggregations for a window
380    fn compute_window_result(&self, window: &WindowState) -> Result<WindowResult> {
381        let mut aggregates = HashMap::new();
382
383        for spec in &self.aggregates {
384            let values: Vec<f64> = window
385                .data
386                .iter()
387                .filter_map(|d| d.fields.get(&spec.field).and_then(|v| v.as_f64()))
388                .collect();
389
390            let result = self.compute_aggregate(&values, spec.function)?;
391            aggregates.insert(spec.output_name.clone(), result);
392        }
393
394        let end = window.last_timestamp.unwrap_or(window.start);
395
396        Ok(WindowResult {
397            start: window.start,
398            end,
399            count: window.data.len(),
400            aggregates,
401        })
402    }
403
404    /// Compute a single aggregate value
405    fn compute_aggregate(&self, values: &[f64], function: AggregateFunction) -> Result<f64> {
406        if values.is_empty() {
407            return Ok(f64::NAN);
408        }
409
410        Ok(match function {
411            AggregateFunction::Sum => values.iter().sum(),
412            AggregateFunction::Avg => values.iter().sum::<f64>() / values.len() as f64,
413            AggregateFunction::Min => values.iter().cloned().fold(f64::INFINITY, f64::min),
414            AggregateFunction::Max => values.iter().cloned().fold(f64::NEG_INFINITY, f64::max),
415            AggregateFunction::Count => values.len() as f64,
416            AggregateFunction::First => values.first().copied().unwrap_or(f64::NAN),
417            AggregateFunction::Last => values.last().copied().unwrap_or(f64::NAN),
418            AggregateFunction::StdDev => {
419                let mean = values.iter().sum::<f64>() / values.len() as f64;
420                let variance =
421                    values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64;
422                variance.sqrt()
423            }
424            AggregateFunction::Variance => {
425                let mean = values.iter().sum::<f64>() / values.len() as f64;
426                values.iter().map(|v| (v - mean).powi(2)).sum::<f64>() / values.len() as f64
427            }
428        })
429    }
430
431    /// Take completed windows
432    pub fn take_completed(&mut self) -> Vec<WindowResult> {
433        std::mem::take(&mut self.completed_windows)
434    }
435
436    /// Flush any remaining windows (call at end of stream)
437    pub fn flush(&mut self) -> Result<Vec<WindowResult>> {
438        // Close any active windows
439        if let Some(ref window) = self.current_window {
440            let result = self.compute_window_result(window)?;
441            self.completed_windows.push(result);
442        }
443
444        for window in &self.active_windows {
445            let result = self.compute_window_result(window)?;
446            self.completed_windows.push(result);
447        }
448
449        self.current_window = None;
450        self.active_windows.clear();
451
452        Ok(self.take_completed())
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Build the `(timestamp, fields)` pair expected by `WindowManager::process`,
    /// with a single numeric field named "value".
    fn make_data_point(
        timestamp: DateTime<Utc>,
        value: f64,
    ) -> (DateTime<Utc>, HashMap<String, ValueWord>) {
        let mut fields = HashMap::new();
        fields.insert("value".to_string(), ValueWord::from_f64(value));
        (timestamp, fields)
    }

    /// A tumbling window is emitted only once a point arrives past its end,
    /// and its aggregates cover exactly the points that fell inside it.
    #[test]
    fn test_tumbling_window() {
        let mut manager = WindowManager::new(WindowType::tumbling(Duration::seconds(10)));
        manager.aggregate("value", AggregateFunction::Sum, "sum");
        manager.aggregate("value", AggregateFunction::Avg, "avg");

        // Use a fixed base time that aligns well with 10-second windows
        let base = DateTime::from_timestamp(1000000000, 0).unwrap(); // A nice round timestamp

        // Add points in first window (0-9 seconds)
        for i in 0..5 {
            let (ts, fields) = make_data_point(base + Duration::seconds(i), 10.0);
            manager.process(ts, fields).unwrap();
        }

        // Should have no completed windows yet (all within first 10-sec window)
        assert!(
            manager.take_completed().is_empty(),
            "Expected no completed windows within first window"
        );

        // Add point in next window (at 15 seconds, triggers close of first window)
        let (ts, fields) = make_data_point(base + Duration::seconds(15), 20.0);
        manager.process(ts, fields).unwrap();

        let completed = manager.take_completed();
        assert_eq!(completed.len(), 1, "Expected exactly 1 completed window");
        assert_eq!(completed[0].count, 5, "Expected 5 data points in window");
        assert_eq!(completed[0].aggregates.get("sum"), Some(&50.0));
        assert_eq!(completed[0].aggregates.get("avg"), Some(&10.0));
    }

    /// A count-based window of size 3 is emitted as soon as the third point arrives.
    #[test]
    fn test_count_window() {
        let mut manager = WindowManager::new(WindowType::count(3));
        manager.aggregate("value", AggregateFunction::Sum, "sum");

        let base = DateTime::from_timestamp(1000000000, 0).unwrap();

        for i in 0..3 {
            let (ts, fields) = make_data_point(base + Duration::seconds(i as i64), (i + 1) as f64);
            manager.process(ts, fields).unwrap();
        }

        let completed = manager.take_completed();
        assert_eq!(completed.len(), 1);
        assert_eq!(completed[0].count, 3);
        assert_eq!(completed[0].aggregates.get("sum"), Some(&6.0)); // 1 + 2 + 3
    }

    /// A pause longer than the session gap closes the running session.
    #[test]
    fn test_session_window() {
        let mut manager = WindowManager::new(WindowType::session(Duration::seconds(5)));
        manager.aggregate("value", AggregateFunction::Count, "count");

        let base = DateTime::from_timestamp(1000000000, 0).unwrap();

        // First session: 3 points close together
        for i in 0..3 {
            let (ts, fields) = make_data_point(base + Duration::seconds(i), 1.0);
            manager.process(ts, fields).unwrap();
        }

        // Gap > 5 seconds, starts new session
        let (ts, fields) = make_data_point(base + Duration::seconds(10), 1.0);
        manager.process(ts, fields).unwrap();

        let completed = manager.take_completed();
        assert_eq!(completed.len(), 1); // First session closed
        assert_eq!(completed[0].count, 3);
    }

    /// Min, Max and StdDev over a known five-value sample.
    #[test]
    fn test_aggregate_functions() {
        let mut manager = WindowManager::new(WindowType::count(5));
        manager.aggregate("value", AggregateFunction::Min, "min");
        manager.aggregate("value", AggregateFunction::Max, "max");
        manager.aggregate("value", AggregateFunction::StdDev, "std");

        let base = DateTime::from_timestamp(1000000000, 0).unwrap();
        let values = [1.0, 2.0, 3.0, 4.0, 5.0];

        for (i, v) in values.iter().enumerate() {
            let (ts, fields) = make_data_point(base + Duration::seconds(i as i64), *v);
            manager.process(ts, fields).unwrap();
        }

        let completed = manager.take_completed();
        assert_eq!(completed.len(), 1);
        assert_eq!(completed[0].aggregates.get("min"), Some(&1.0));
        assert_eq!(completed[0].aggregates.get("max"), Some(&5.0));
        // Population standard deviation of [1,2,3,4,5] is sqrt(2) ≈ 1.414
        let std = completed[0].aggregates.get("std").unwrap();
        assert!((std - 1.414).abs() < 0.01);
    }

    /// `flush` emits a window that is still open when the stream ends.
    #[test]
    fn test_flush() {
        let mut manager = WindowManager::new(WindowType::tumbling(Duration::seconds(10)));
        manager.aggregate("value", AggregateFunction::Sum, "sum");

        let base = DateTime::from_timestamp(1000000000, 0).unwrap();
        let (ts, fields) = make_data_point(base, 42.0);
        manager.process(ts, fields).unwrap();

        // Flush should emit partial window
        let results = manager.flush().unwrap();
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].aggregates.get("sum"), Some(&42.0));
    }
}
579}