pulse_ops/
lib.rs

1//! pulse-ops: standard operators built on top of pulse-core.
2//!
3//! Included operators:
4//! - `Map`: one-to-many mapping of JSON payloads
5//! - `Filter`: predicate-based filtering
6//! - `KeyBy`: materialize a `key` field from an existing field
7//! - `Aggregate` (simplified): per-minute running count updates
8//! - `WindowedAggregate`: configurable windows (tumbling/sliding/session) with count/sum/avg/distinct
9
10use std::collections::HashMap;
11
12use async_trait::async_trait;
13use tracing::{instrument, info_span};
14use pulse_core::{Context, EventTime, Operator, Record, Result, Watermark};
15use chrono::{TimeZone, Utc};
16pub mod time;
17pub mod window;
18pub use time::{WatermarkClock, WatermarkPolicy};
19pub use window::{Window, WindowAssigner, WindowOperator};
20
21#[async_trait]
22pub trait FnMap: Send + Sync {
23    async fn call(&self, value: serde_json::Value) -> Result<Vec<serde_json::Value>>;
24}
25
26pub struct MapFn<F>(pub F);
27impl<F> MapFn<F> {
28    pub fn new(f: F) -> Self {
29        Self(f)
30    }
31}
32#[async_trait]
33impl<F> FnMap for MapFn<F>
34where
35    F: Fn(serde_json::Value) -> Vec<serde_json::Value> + Send + Sync,
36{
37    async fn call(&self, value: serde_json::Value) -> Result<Vec<serde_json::Value>> {
38        Ok((self.0)(value))
39    }
40}
41
/// Map operator: applies a user function that returns zero or more outputs per input.
///
/// Every output record carries the event time of the input record that produced it.
///
/// Example
/// ```no_run
/// use pulse_ops::{Map, MapFn};
/// let map = Map::new(MapFn::new(|v: serde_json::Value| vec![v]));
/// # let _ = map;
/// ```
pub struct Map<F> {
    func: F,
}
impl<F> Map<F> {
    /// Create a map operator around `func` (typically a [`MapFn`]).
    pub fn new(func: F) -> Self {
        Self { func }
    }
}
59
60#[async_trait]
61impl<F> Operator for Map<F>
62where
63    F: FnMap + Send + Sync + 'static,
64{
65    #[instrument(name = "map_on_element", skip_all)]
66    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
67        let outs = self.func.call(rec.value).await?;
68        pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Map", "receive"]).inc();
69        for v in outs {
70            ctx.collect(Record {
71                event_time: rec.event_time,
72                value: v.clone(),
73            });
74            pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Map", "emit"]).inc();
75        }
76        Ok(())
77    }
78}
79
80#[async_trait]
81pub trait FnFilter: Send + Sync {
82    async fn call(&self, value: &serde_json::Value) -> Result<bool>;
83}
84
85pub struct FilterFn<F>(pub F);
86impl<F> FilterFn<F> {
87    pub fn new(f: F) -> Self {
88        Self(f)
89    }
90}
91#[async_trait]
92impl<F> FnFilter for FilterFn<F>
93where
94    F: Fn(&serde_json::Value) -> bool + Send + Sync,
95{
96    async fn call(&self, value: &serde_json::Value) -> Result<bool> {
97        Ok((self.0)(value))
98    }
99}
100
/// Filter operator: keeps inputs that satisfy the predicate.
///
/// Records that pass are forwarded unchanged; the rest are dropped.
///
/// Example
/// ```no_run
/// use pulse_ops::{Filter, FilterFn};
/// let filter = Filter::new(FilterFn::new(|v: &serde_json::Value| v.get("ok").and_then(|x| x.as_bool()).unwrap_or(false)));
/// # let _ = filter;
/// ```
pub struct Filter<F> {
    pred: F,
}
impl<F> Filter<F> {
    /// Build a filter around `pred` (typically a [`FilterFn`]).
    pub fn new(pred: F) -> Self {
        Filter { pred }
    }
}
117
118#[async_trait]
119impl<F> Operator for Filter<F>
120where
121    F: FnFilter + Send + Sync + 'static,
122{
123    #[instrument(name = "filter_on_element", skip_all)]
124    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
125        pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Filter", "receive"]).inc();
126        if self.pred.call(&rec.value).await? {
127            ctx.collect(rec);
128            pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Filter", "emit"]).inc();
129        }
130        Ok(())
131    }
132}
133
/// KeyBy operator: copies an existing field into a canonical `key` field.
///
/// Example
/// ```no_run
/// use pulse_ops::KeyBy;
/// let key_by = KeyBy::new("word");
/// # let _ = key_by;
/// ```
pub struct KeyBy {
    field: String,
}
impl KeyBy {
    /// Create a `KeyBy` that reads the key from the payload field `field`.
    pub fn new(field: impl Into<String>) -> Self {
        let field = field.into();
        KeyBy { field }
    }
}
150
151#[async_trait]
152impl Operator for KeyBy {
153    #[instrument(name = "keyby_on_element", skip_all)]
154    async fn on_element(&mut self, ctx: &mut dyn Context, mut rec: Record) -> Result<()> {
155        pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["KeyBy", "receive"]).inc();
156        let key = rec
157            .value
158            .get(&self.field)
159            .cloned()
160            .unwrap_or(serde_json::Value::Null);
161        let mut obj = match rec.value {
162            serde_json::Value::Object(o) => o,
163            _ => serde_json::Map::new(),
164        };
165        obj.insert("key".to_string(), key);
166        rec.value = serde_json::Value::Object(obj);
167        ctx.collect(rec);
168        pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["KeyBy", "emit"]).inc();
169        Ok(())
170    }
171}
172
/// Fixed-size tumbling window helper (legacy from the simple Aggregate).
#[derive(Clone, Copy)]
pub struct WindowTumbling {
    /// Window length in milliseconds.
    pub size_ms: i64,
}
impl WindowTumbling {
    /// A tumbling window spanning `m` minutes.
    pub fn minutes(m: i64) -> Self {
        const MS_PER_MINUTE: i64 = 60_000;
        WindowTumbling {
            size_ms: MS_PER_MINUTE * m,
        }
    }
}
183
184/// Simple aggregate that maintains a per-minute count per `key_field`.
185/// Simple aggregate that maintains a per-minute count per `key_field`.
186///
187/// Example
188/// ```no_run
189/// use pulse_ops::Aggregate;
190/// let agg = Aggregate::count_per_window("key", "word");
191/// # let _ = agg;
192/// ```
193pub struct Aggregate {
194    pub key_field: String,
195    pub value_field: String,
196    pub op: AggregationKind,
197    windows: HashMap<(i128, serde_json::Value), i64>, // (window_start, key) -> count
198}
199
/// Supported aggregation kinds for the simple `Aggregate`.
#[derive(Clone, Copy)]
pub enum AggregationKind {
    /// Count records per (window, key).
    Count,
}
205
206impl Aggregate {
207    pub fn count_per_window(key_field: impl Into<String>, value_field: impl Into<String>) -> Self {
208        Self {
209            key_field: key_field.into(),
210            value_field: value_field.into(),
211            op: AggregationKind::Count,
212            windows: HashMap::new(),
213        }
214    }
215}
216
217#[async_trait]
218impl Operator for Aggregate {
219    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
220        let minute_ms = 60_000_i128;
221        let ts_ms = rec.event_time.timestamp_millis() as i128; // ms
222        let win_start_ms = (ts_ms / minute_ms) * minute_ms;
223        let key = rec
224            .value
225            .get(&self.key_field)
226            .cloned()
227            .unwrap_or(serde_json::Value::Null);
228        let entry = self.windows.entry((win_start_ms, key.clone())).or_insert(0);
229        *entry += 1;
230        // Emit current count as an update
231        let mut out = serde_json::Map::new();
232        out.insert("window_start_ms".into(), serde_json::json!(win_start_ms));
233        out.insert("key".into(), key);
234        out.insert("count".into(), serde_json::json!(*entry));
235        ctx.collect(Record {
236            event_time: rec.event_time,
237            value: serde_json::Value::Object(out),
238        });
239        Ok(())
240    }
241    async fn on_watermark(&mut self, _ctx: &mut dyn Context, _wm: Watermark) -> Result<()> {
242        Ok(())
243    }
244}
245
246pub mod prelude {
247    pub use super::{
248        AggKind, Aggregate, AggregationKind, Filter, FnFilter, FnMap, KeyBy, Map, WindowKind, WindowTumbling,
249        WindowedAggregate,
250    };
251}
252
253// ===== Windowed, configurable aggregations =====
254
/// Kinds of windows supported by `WindowedAggregate`.
///
/// All lengths are milliseconds of event time.
#[derive(Clone, Debug)]
pub enum WindowKind {
    /// Fixed, non-overlapping windows of `size_ms`.
    Tumbling { size_ms: i64 },
    /// Windows of `size_ms` starting every `slide_ms`; one record may fall into several.
    Sliding { size_ms: i64, slide_ms: i64 },
    /// Per-key windows that close after `gap_ms` without a new record.
    Session { gap_ms: i64 },
}
262
/// Supported aggregation kinds for `WindowedAggregate`.
#[derive(Clone, Debug)]
pub enum AggKind {
    /// Number of records in the window.
    Count,
    /// Sum of the numeric value of `field`.
    Sum { field: String },
    /// Arithmetic mean of the numeric value of `field`.
    Avg { field: String },
    /// Number of distinct (stringified) values of `field`.
    Distinct { field: String },
}
271
/// Running aggregation state for a single window and key.
#[derive(Clone, Debug, Default)]
enum AggState {
    /// No record folded in yet.
    #[default]
    Empty,
    /// Record count for `AggKind::Count`.
    Count(i64),
    /// Running sum plus record count; shared by `AggKind::Sum` and `AggKind::Avg`.
    Sum {
        sum: f64,
        count: i64,
    },
    /// Stringified values seen so far for `AggKind::Distinct`.
    Distinct(std::collections::HashSet<String>),
}
283
284fn as_f64(v: &serde_json::Value) -> f64 {
285    match v {
286        serde_json::Value::Number(n) => n.as_f64().unwrap_or(0.0),
287        serde_json::Value::String(s) => s.parse::<f64>().unwrap_or(0.0),
288        _ => 0.0,
289    }
290}
291
292fn stringify(v: &serde_json::Value) -> String {
293    match v {
294        serde_json::Value::String(s) => s.clone(),
295        other => other.to_string(),
296    }
297}
298
299/// A stateful windowed aggregation operator supporting different windows & aggregations.
300/// A stateful windowed aggregation operator supporting different windows & aggregations.
301///
302/// Examples
303/// ```no_run
304/// use pulse_ops::WindowedAggregate;
305/// // Tumbling count of words per 60s window
306/// let op = WindowedAggregate::tumbling_count("word", 60_000);
307/// # let _ = op;
308/// ```
309pub struct WindowedAggregate {
310    pub key_field: String,
311    pub win: WindowKind,
312    pub agg: AggKind,
313    // For tumbling/sliding: (end_ms, key) -> state, and track start_ms via map
314    by_window: HashMap<(i128, serde_json::Value), (i128 /*start_ms*/, AggState)>,
315    // For session: key -> (start_ms, last_seen_ms, state)
316    sessions: HashMap<serde_json::Value, (i128, i128, AggState)>,
317}
318
319impl WindowedAggregate {
320    pub fn tumbling_count(key_field: impl Into<String>, size_ms: i64) -> Self {
321        Self {
322            key_field: key_field.into(),
323            win: WindowKind::Tumbling { size_ms },
324            agg: AggKind::Count,
325            by_window: HashMap::new(),
326            sessions: HashMap::new(),
327        }
328    }
329    pub fn tumbling_sum(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
330        Self {
331            key_field: key_field.into(),
332            win: WindowKind::Tumbling { size_ms },
333            agg: AggKind::Sum { field: field.into() },
334            by_window: HashMap::new(),
335            sessions: HashMap::new(),
336        }
337    }
338    pub fn tumbling_avg(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
339        Self {
340            key_field: key_field.into(),
341            win: WindowKind::Tumbling { size_ms },
342            agg: AggKind::Avg { field: field.into() },
343            by_window: HashMap::new(),
344            sessions: HashMap::new(),
345        }
346    }
347    pub fn tumbling_distinct(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
348        Self {
349            key_field: key_field.into(),
350            win: WindowKind::Tumbling { size_ms },
351            agg: AggKind::Distinct { field: field.into() },
352            by_window: HashMap::new(),
353            sessions: HashMap::new(),
354        }
355    }
356
357    pub fn sliding_count(key_field: impl Into<String>, size_ms: i64, slide_ms: i64) -> Self {
358        Self {
359            key_field: key_field.into(),
360            win: WindowKind::Sliding { size_ms, slide_ms },
361            agg: AggKind::Count,
362            by_window: HashMap::new(),
363            sessions: HashMap::new(),
364        }
365    }
366    pub fn sliding_sum(
367        key_field: impl Into<String>,
368        size_ms: i64,
369        slide_ms: i64,
370        field: impl Into<String>,
371    ) -> Self {
372        Self {
373            key_field: key_field.into(),
374            win: WindowKind::Sliding { size_ms, slide_ms },
375            agg: AggKind::Sum { field: field.into() },
376            by_window: HashMap::new(),
377            sessions: HashMap::new(),
378        }
379    }
380    pub fn sliding_avg(
381        key_field: impl Into<String>,
382        size_ms: i64,
383        slide_ms: i64,
384        field: impl Into<String>,
385    ) -> Self {
386        Self {
387            key_field: key_field.into(),
388            win: WindowKind::Sliding { size_ms, slide_ms },
389            agg: AggKind::Avg { field: field.into() },
390            by_window: HashMap::new(),
391            sessions: HashMap::new(),
392        }
393    }
394    pub fn sliding_distinct(
395        key_field: impl Into<String>,
396        size_ms: i64,
397        slide_ms: i64,
398        field: impl Into<String>,
399    ) -> Self {
400        Self {
401            key_field: key_field.into(),
402            win: WindowKind::Sliding { size_ms, slide_ms },
403            agg: AggKind::Distinct { field: field.into() },
404            by_window: HashMap::new(),
405            sessions: HashMap::new(),
406        }
407    }
408
409    pub fn session_count(key_field: impl Into<String>, gap_ms: i64) -> Self {
410        Self {
411            key_field: key_field.into(),
412            win: WindowKind::Session { gap_ms },
413            agg: AggKind::Count,
414            by_window: HashMap::new(),
415            sessions: HashMap::new(),
416        }
417    }
418    pub fn session_sum(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
419        Self {
420            key_field: key_field.into(),
421            win: WindowKind::Session { gap_ms },
422            agg: AggKind::Sum { field: field.into() },
423            by_window: HashMap::new(),
424            sessions: HashMap::new(),
425        }
426    }
427    pub fn session_avg(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
428        Self {
429            key_field: key_field.into(),
430            win: WindowKind::Session { gap_ms },
431            agg: AggKind::Avg { field: field.into() },
432            by_window: HashMap::new(),
433            sessions: HashMap::new(),
434        }
435    }
436    pub fn session_distinct(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
437        Self {
438            key_field: key_field.into(),
439            win: WindowKind::Session { gap_ms },
440            agg: AggKind::Distinct { field: field.into() },
441            by_window: HashMap::new(),
442            sessions: HashMap::new(),
443        }
444    }
445}
446
447fn update_state(state: &mut AggState, agg: &AggKind, value: &serde_json::Value) {
448    match agg {
449        AggKind::Count => {
450            *state = match std::mem::take(state) {
451                AggState::Empty => AggState::Count(1),
452                AggState::Count(c) => AggState::Count(c + 1),
453                other => other,
454            };
455        }
456        AggKind::Sum { field } => {
457            let x = as_f64(value.get(field).unwrap_or(&serde_json::Value::Null));
458            *state = match std::mem::take(state) {
459                AggState::Empty => AggState::Sum { sum: x, count: 1 },
460                AggState::Sum { sum, count } => AggState::Sum {
461                    sum: sum + x,
462                    count: count + 1,
463                },
464                other => other,
465            };
466        }
467        AggKind::Avg { field } => {
468            let x = as_f64(value.get(field).unwrap_or(&serde_json::Value::Null));
469            *state = match std::mem::take(state) {
470                AggState::Empty => AggState::Sum { sum: x, count: 1 },
471                AggState::Sum { sum, count } => AggState::Sum {
472                    sum: sum + x,
473                    count: count + 1,
474                },
475                other => other,
476            };
477        }
478        AggKind::Distinct { field } => {
479            let s = stringify(value.get(field).unwrap_or(&serde_json::Value::Null));
480            *state = match std::mem::take(state) {
481                AggState::Empty => {
482                    let mut set = std::collections::HashSet::new();
483                    set.insert(s);
484                    AggState::Distinct(set)
485                }
486                AggState::Distinct(mut set) => {
487                    set.insert(s);
488                    AggState::Distinct(set)
489                }
490                other => other,
491            };
492        }
493    }
494}
495
496fn finalize_value(state: &AggState, agg: &AggKind) -> serde_json::Value {
497    match (state, agg) {
498        (AggState::Count(c), _) => serde_json::json!(*c),
499        (AggState::Sum { sum, .. }, AggKind::Sum { .. }) => serde_json::json!(sum),
500        (AggState::Sum { sum, count }, AggKind::Avg { .. }) => {
501            let avg = if *count > 0 { *sum / (*count as f64) } else { 0.0 };
502            serde_json::json!(avg)
503        }
504        (AggState::Distinct(set), AggKind::Distinct { .. }) => serde_json::json!(set.len() as i64),
505        _ => serde_json::json!(null),
506    }
507}
508
509#[async_trait]
510impl Operator for WindowedAggregate {
511    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
512        let ts_ms = rec.event_time.timestamp_millis() as i128; // ms
513        let key = rec
514            .value
515            .get(&self.key_field)
516            .cloned()
517            .unwrap_or(serde_json::Value::Null);
518
519        match self.win {
520            WindowKind::Tumbling { size_ms } => {
521                let start = (ts_ms / (size_ms as i128)) * (size_ms as i128);
522                let end = start + (size_ms as i128);
523                let entry = self
524                    .by_window
525                    .entry((end, key.clone()))
526                    .or_insert((start, AggState::Empty));
527                update_state(&mut entry.1, &self.agg, &rec.value);
528                // Optional: schedule a timer at end
529                let _ = ctx
530                    .timers()
531                    .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
532                    .await;
533            }
534            WindowKind::Sliding { size_ms, slide_ms } => {
535                let k = (size_ms / slide_ms) as i128;
536                let anchor = (ts_ms / (slide_ms as i128)) * (slide_ms as i128);
537                for j in 0..k {
538                    let start = anchor - (j * (slide_ms as i128));
539                    let end = start + (size_ms as i128);
540                    if start <= ts_ms && end > ts_ms {
541                        let entry = self
542                            .by_window
543                            .entry((end, key.clone()))
544                            .or_insert((start, AggState::Empty));
545                        update_state(&mut entry.1, &self.agg, &rec.value);
546                        let _ = ctx
547                            .timers()
548                            .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
549                            .await;
550                    }
551                }
552            }
553            WindowKind::Session { gap_ms } => {
554                let e = self
555                    .sessions
556                    .entry(key.clone())
557                    .or_insert((ts_ms, ts_ms, AggState::Empty));
558                let (start, last_seen, state) = e;
559                if ts_ms - *last_seen <= (gap_ms as i128) {
560                    *last_seen = ts_ms;
561                    update_state(state, &self.agg, &rec.value);
562                } else {
563                    // close previous session
564                    let mut out = serde_json::Map::new();
565                    out.insert("window_start_ms".into(), serde_json::json!(*start));
566                    out.insert(
567                        "window_end_ms".into(),
568                        serde_json::json!(*last_seen + (gap_ms as i128)),
569                    );
570                    out.insert("key".into(), key.clone());
571                    let val = finalize_value(state, &self.agg);
572                    match self.agg {
573                        AggKind::Count => {
574                            out.insert("count".into(), val);
575                        }
576                        AggKind::Sum { .. } => {
577                            out.insert("sum".into(), val);
578                        }
579                        AggKind::Avg { .. } => {
580                            out.insert("avg".into(), val);
581                        }
582                        AggKind::Distinct { .. } => {
583                            out.insert("distinct_count".into(), val);
584                        }
585                    }
586                    ctx.collect(Record {
587                        event_time: rec.event_time,
588                        value: serde_json::Value::Object(out),
589                    });
590                    // start new
591                    *start = ts_ms;
592                    *last_seen = ts_ms;
593                    *state = AggState::Empty;
594                    update_state(state, &self.agg, &rec.value);
595                }
596                // schedule close timer
597                let end = ts_ms + (gap_ms as i128);
598                let _ = ctx
599                    .timers()
600                    .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
601                    .await;
602            }
603        }
604        Ok(())
605    }
606
607    async fn on_watermark(&mut self, ctx: &mut dyn Context, wm: Watermark) -> Result<()> {
608    let wm_ms = wm.0 .0.timestamp_millis() as i128;
609
610        match self.win {
611            WindowKind::Tumbling { .. } | WindowKind::Sliding { .. } => {
612                // Emit and clear all windows with end <= wm
613                let mut to_emit: Vec<((i128, serde_json::Value), (i128, AggState))> = self
614                    .by_window
615                    .iter()
616                    .filter(|((end, _), _)| *end <= wm_ms)
617                    .map(|(k, v)| (k.clone(), v.clone()))
618                    .collect();
619                for ((end, key), (start, state)) in to_emit.drain(..) {
620                    let mut out = serde_json::Map::new();
621                    out.insert("window_start_ms".into(), serde_json::json!(start));
622                    out.insert("window_end_ms".into(), serde_json::json!(end));
623                    out.insert("key".into(), key.clone());
624                    let val = finalize_value(&state, &self.agg);
625                    match self.agg {
626                        AggKind::Count => {
627                            out.insert("count".into(), val);
628                        }
629                        AggKind::Sum { .. } => {
630                            out.insert("sum".into(), val);
631                        }
632                        AggKind::Avg { .. } => {
633                            out.insert("avg".into(), val);
634                        }
635                        AggKind::Distinct { .. } => {
636                            out.insert("distinct_count".into(), val);
637                        }
638                    }
639                    ctx.collect(Record { event_time: wm.0 .0, value: serde_json::Value::Object(out) });
640                    self.by_window.remove(&(end, key));
641                }
642            }
643            WindowKind::Session { gap_ms } => {
644                // Close sessions whose inactivity + gap <= wm
645                let keys: Vec<_> = self.sessions.keys().cloned().collect();
646                for key in keys {
647                    if let Some((start, last_seen, state)) = self.sessions.get(&key).cloned() {
648                        if last_seen + (gap_ms as i128) <= wm_ms {
649                            let mut out = serde_json::Map::new();
650                            out.insert("window_start_ms".into(), serde_json::json!(start));
651                            out.insert(
652                                "window_end_ms".into(),
653                                serde_json::json!(last_seen + (gap_ms as i128)),
654                            );
655                            out.insert("key".into(), key.clone());
656                            let val = finalize_value(&state, &self.agg);
657                            match self.agg {
658                                AggKind::Count => {
659                                    out.insert("count".into(), val);
660                                }
661                                AggKind::Sum { .. } => {
662                                    out.insert("sum".into(), val);
663                                }
664                                AggKind::Avg { .. } => {
665                                    out.insert("avg".into(), val);
666                                }
667                                AggKind::Distinct { .. } => {
668                                    out.insert("distinct_count".into(), val);
669                                }
670                            }
671                            ctx.collect(Record { event_time: wm.0 .0, value: serde_json::Value::Object(out) });
672                            self.sessions.remove(&key);
673                        }
674                    }
675                }
676            }
677        }
678        Ok(())
679    }
680
681    async fn on_timer(
682        &mut self,
683        ctx: &mut dyn Context,
684        when: EventTime,
685        _key: Option<Vec<u8>>,
686    ) -> Result<()> {
687        // Treat timers same as watermarks for emission, but only for windows ending at `when`.
688        let when_ms = when.0.timestamp_millis() as i128;
689
690        match self.win {
691            WindowKind::Tumbling { .. } | WindowKind::Sliding { .. } => {
692                let mut to_emit: Vec<((i128, serde_json::Value), (i128, AggState))> = self
693                    .by_window
694                    .iter()
695                    .filter(|((end, _), _)| *end <= when_ms)
696                    .map(|(k, v)| (k.clone(), v.clone()))
697                    .collect();
698                for ((end, key), (start, state)) in to_emit.drain(..) {
699                    let mut out = serde_json::Map::new();
700                    out.insert("window_start_ms".into(), serde_json::json!(start));
701                    out.insert("window_end_ms".into(), serde_json::json!(end));
702                    out.insert("key".into(), key.clone());
703                    let val = finalize_value(&state, &self.agg);
704                    match self.agg {
705                        AggKind::Count => {
706                            out.insert("count".into(), val);
707                        }
708                        AggKind::Sum { .. } => {
709                            out.insert("sum".into(), val);
710                        }
711                        AggKind::Avg { .. } => {
712                            out.insert("avg".into(), val);
713                        }
714                        AggKind::Distinct { .. } => {
715                            out.insert("distinct_count".into(), val);
716                        }
717                    }
718                    ctx.collect(Record { event_time: when.0, value: serde_json::Value::Object(out) });
719                    self.by_window.remove(&(end, key));
720                }
721            }
722            WindowKind::Session { gap_ms } => {
723                let keys: Vec<_> = self.sessions.keys().cloned().collect();
724                for key in keys {
725                    if let Some((start, last_seen, state)) = self.sessions.get(&key).cloned() {
726                        if last_seen + (gap_ms as i128) <= when_ms {
727                            let mut out = serde_json::Map::new();
728                            out.insert("window_start_ms".into(), serde_json::json!(start));
729                            out.insert(
730                                "window_end_ms".into(),
731                                serde_json::json!(last_seen + (gap_ms as i128)),
732                            );
733                            out.insert("key".into(), key.clone());
734                            let val = finalize_value(&state, &self.agg);
735                            match self.agg {
736                                AggKind::Count => {
737                                    out.insert("count".into(), val);
738                                }
739                                AggKind::Sum { .. } => {
740                                    out.insert("sum".into(), val);
741                                }
742                                AggKind::Avg { .. } => {
743                                    out.insert("avg".into(), val);
744                                }
745                                AggKind::Distinct { .. } => {
746                                    out.insert("distinct_count".into(), val);
747                                }
748                            }
749                            ctx.collect(Record { event_time: when.0, value: serde_json::Value::Object(out) });
750                            self.sessions.remove(&key);
751                        }
752                    }
753                }
754            }
755        }
756        Ok(())
757    }
758}
759
#[cfg(test)]
mod window_tests {
    use super::*;
    use pulse_core::{Context, EventTime, KvState, Record, Result, Timers, Watermark};
    use std::sync::Arc;

    /// State backend stub: persists nothing and reports every key as missing.
    struct TestState;
    #[async_trait]
    impl KvState for TestState {
        async fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }
        async fn put(&self, _key: &[u8], _value: Vec<u8>) -> Result<()> {
            Ok(())
        }
        async fn delete(&self, _key: &[u8]) -> Result<()> {
            Ok(())
        }
        async fn iter_prefix(&self, _prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
            Ok(Vec::new())
        }
        async fn snapshot(&self) -> Result<pulse_core::SnapshotId> {
            Ok("test-snap".to_string())
        }
        async fn restore(&self, _snapshot: pulse_core::SnapshotId) -> Result<()> {
            Ok(())
        }
    }

    /// Timer service stub that accepts and ignores every registration.
    struct TestTimers;
    #[async_trait]
    impl Timers for TestTimers {
        async fn register_event_time_timer(&self, _when: EventTime, _key: Option<Vec<u8>>) -> Result<()> {
            Ok(())
        }
    }

    /// Context that records every emitted record in `out`.
    struct TestCtx {
        out: Vec<Record>,
        kv: Arc<dyn KvState>,
        timers: Arc<dyn Timers>,
    }
    impl TestCtx {
        /// A context with no output yet, backed by the stub state and timers.
        fn new() -> Self {
            TestCtx {
                out: Vec::new(),
                kv: Arc::new(TestState),
                timers: Arc::new(TestTimers),
            }
        }
    }
    #[async_trait]
    impl Context for TestCtx {
        fn collect(&mut self, record: Record) {
            self.out.push(record);
        }
        fn watermark(&mut self, _wm: Watermark) {}
        fn kv(&self) -> Arc<dyn KvState> {
            self.kv.clone()
        }
        fn timers(&self) -> Arc<dyn Timers> {
            self.timers.clone()
        }
    }

    /// A `{"word": key}` record at epoch millisecond `ts_ms`.
    fn record_with(ts_ms: i128, key: &str) -> Record {
        Record {
            event_time: Utc.timestamp_millis_opt(ts_ms as i64).unwrap(),
            value: serde_json::json!({"word": key}),
        }
    }

    #[tokio::test]
    async fn tumbling_count_emits_on_watermark() {
        let mut op = WindowedAggregate::tumbling_count("word", 60_000);
        let mut ctx = TestCtx::new();
        for ts in [1_000_i128, 1_010] {
            op.on_element(&mut ctx, record_with(ts, "a")).await.unwrap();
        }
        // A watermark at the end of window [0, 60000) closes it.
        let wm = Watermark(EventTime(Utc.timestamp_millis_opt(60_000).unwrap()));
        op.on_watermark(&mut ctx, wm).await.unwrap();
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["count"], serde_json::json!(2));
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `Map`, `Filter`, `KeyBy` and the simplified `Aggregate`,
    //! driven through a minimal in-memory `Context`.

    use super::*;
    use pulse_core::{Context, EventTime, KvState, Record, Result, Timers};
    use std::sync::Arc;

    /// No-op key/value state: every `get` misses, writes and deletes are discarded,
    /// and snapshots return a fixed id.
    struct TestState;
    #[async_trait]
    impl KvState for TestState {
        async fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }
        async fn put(&self, _key: &[u8], _value: Vec<u8>) -> Result<()> {
            Ok(())
        }
        async fn delete(&self, _key: &[u8]) -> Result<()> {
            Ok(())
        }
        async fn iter_prefix(&self, _prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
            Ok(Vec::new())
        }
        async fn snapshot(&self) -> Result<pulse_core::SnapshotId> {
            Ok("test-snap".to_string())
        }
        async fn restore(&self, _snapshot: pulse_core::SnapshotId) -> Result<()> {
            Ok(())
        }
    }

    /// Timer service that accepts and ignores every registration.
    struct TestTimers;
    #[async_trait]
    impl Timers for TestTimers {
        async fn register_event_time_timer(&self, _when: EventTime, _key: Option<Vec<u8>>) -> Result<()> {
            Ok(())
        }
    }

    /// Test context: records passed to `collect` are appended to `out`;
    /// forwarded watermarks are dropped.
    struct TestCtx {
        out: Vec<Record>,
        kv: Arc<dyn KvState>,
        timers: Arc<dyn Timers>,
    }

    #[async_trait]
    impl Context for TestCtx {
        fn collect(&mut self, record: Record) {
            self.out.push(record);
        }
        fn watermark(&mut self, _wm: pulse_core::Watermark) {}
        fn kv(&self) -> Arc<dyn KvState> {
            self.kv.clone()
        }
        fn timers(&self) -> Arc<dyn Timers> {
            self.timers.clone()
        }
    }

    /// Wraps `v` in a record with a fixed event time (1 000 ms after the Unix epoch).
    ///
    /// A fixed timestamp keeps the tests deterministic. With `Utc::now()`, two records
    /// built back-to-back could straddle a minute boundary. If `Aggregate` buckets by
    /// event time (the crate docs describe per-minute counts), that would split
    /// `test_aggregate_count`'s count across two windows.
    fn rec(v: serde_json::Value) -> Record {
        Record {
            event_time: Utc.timestamp_millis_opt(1_000).unwrap(),
            value: v,
        }
    }

    /// The identity map emits exactly one record carrying the input payload.
    #[tokio::test]
    async fn test_map() {
        let mut op = Map::new(MapFn::new(|v: serde_json::Value| vec![v]));
        let mut ctx = TestCtx {
            out: vec![],
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        };
        op.on_element(&mut ctx, rec(serde_json::json!({"a":1})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value, serde_json::json!({"a":1}));
    }

    /// Only the record whose `ok` field is `true` passes the predicate.
    #[tokio::test]
    async fn test_filter() {
        let mut op = Filter::new(FilterFn::new(|v: &serde_json::Value| {
            v.get("ok").and_then(|x| x.as_bool()).unwrap_or(false)
        }));
        let mut ctx = TestCtx {
            out: vec![],
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        };
        op.on_element(&mut ctx, rec(serde_json::json!({"ok":false})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"ok":true})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value, serde_json::json!({"ok":true}));
    }

    /// `KeyBy("word")` copies the `word` field into a `key` field.
    #[tokio::test]
    async fn test_keyby() {
        let mut op = KeyBy::new("word");
        let mut ctx = TestCtx {
            out: vec![],
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        };
        op.on_element(&mut ctx, rec(serde_json::json!({"word":"hi"})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["key"], serde_json::json!("hi"));
    }

    /// Two same-key records in the same minute produce running counts; the second is 2.
    #[tokio::test]
    async fn test_aggregate_count() {
        let mut op = Aggregate::count_per_window("key", "word");
        let mut ctx = TestCtx {
            out: vec![],
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        };
        op.on_element(&mut ctx, rec(serde_json::json!({"key":"hello"})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"key":"hello"})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 2);
        assert_eq!(ctx.out[1].value["count"], serde_json::json!(2));
    }
}