pulse_ops/
lib.rs

1//! pulse-ops: standard operators built on top of pulse-core.
2//!
3//! Included operators:
4//! - `Map`: one-to-many mapping of JSON payloads
5//! - `Filter`: predicate-based filtering
6//! - `KeyBy`: materialize a `key` field from an existing field
7//! - `Aggregate` (simplified): per-minute running count updates
8//! - `WindowedAggregate`: configurable windows (tumbling/sliding/session) with count/sum/avg/distinct
9
10use std::collections::HashMap;
11
12use async_trait::async_trait;
13use tracing::{instrument, info_span};
14use pulse_core::{Context, EventTime, Operator, Record, Result, Watermark};
15use chrono::{TimeZone, Utc};
16pub mod time;
17pub mod window;
18pub use time::{WatermarkClock, WatermarkPolicy};
19pub use window::{Window, WindowAssigner, WindowOperator};
20
21#[async_trait]
22pub trait FnMap: Send + Sync {
23    async fn call(&self, value: serde_json::Value) -> Result<Vec<serde_json::Value>>;
24}
25
/// Newtype wrapper around a mapping closure or function; the wrapped value is
/// exposed as the public tuple field.
pub struct MapFn<F>(pub F);

impl<F> MapFn<F> {
    /// Wraps `f` without any further processing.
    pub fn new(f: F) -> Self {
        MapFn(f)
    }
}
32#[async_trait]
33impl<F> FnMap for MapFn<F>
34where
35    F: Fn(serde_json::Value) -> Vec<serde_json::Value> + Send + Sync,
36{
37    async fn call(&self, value: serde_json::Value) -> Result<Vec<serde_json::Value>> {
38        Ok((self.0)(value))
39    }
40}
41
/// Map operator: applies a user function that returns zero or more outputs per input.
///
/// Example
/// ```no_run
/// use pulse_ops::{Map, MapFn};
/// let map = Map::new(MapFn::new(|v: serde_json::Value| vec![v]));
/// # let _ = map;
/// ```
pub struct Map<F> {
    /// User-supplied mapping function.
    func: F,
}

impl<F> Map<F> {
    /// Creates a map operator around `func`.
    pub fn new(func: F) -> Self {
        Map { func }
    }
}
59
60#[async_trait]
61impl<F> Operator for Map<F>
62where
63    F: FnMap + Send + Sync + 'static,
64{
65    #[instrument(name = "map_on_element", skip_all)]
66    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
67        let outs = self.func.call(rec.value).await?;
68        pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Map", "receive"]).inc();
69        for v in outs {
70            ctx.collect(Record {
71                event_time: rec.event_time,
72                value: v.clone(),
73            });
74            pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Map", "emit"]).inc();
75        }
76        Ok(())
77    }
78}
79
80#[async_trait]
81pub trait FnFilter: Send + Sync {
82    async fn call(&self, value: &serde_json::Value) -> Result<bool>;
83}
84
/// Newtype wrapper around a predicate closure or function; the wrapped value
/// is exposed as the public tuple field.
pub struct FilterFn<F>(pub F);

impl<F> FilterFn<F> {
    /// Wraps `f` without any further processing.
    pub fn new(f: F) -> Self {
        FilterFn(f)
    }
}
91#[async_trait]
92impl<F> FnFilter for FilterFn<F>
93where
94    F: Fn(&serde_json::Value) -> bool + Send + Sync,
95{
96    async fn call(&self, value: &serde_json::Value) -> Result<bool> {
97        Ok((self.0)(value))
98    }
99}
100
/// Filter operator: keeps inputs that satisfy the predicate.
///
/// Example
/// ```no_run
/// use pulse_ops::{Filter, FilterFn};
/// let filter = Filter::new(FilterFn::new(|v: &serde_json::Value| v.get("ok").and_then(|x| x.as_bool()).unwrap_or(false)));
/// # let _ = filter;
/// ```
pub struct Filter<F> {
    /// User-supplied predicate.
    pred: F,
}

impl<F> Filter<F> {
    /// Creates a filter operator around `pred`.
    pub fn new(pred: F) -> Self {
        Filter { pred }
    }
}
117
118#[async_trait]
119impl<F> Operator for Filter<F>
120where
121    F: FnFilter + Send + Sync + 'static,
122{
123    #[instrument(name = "filter_on_element", skip_all)]
124    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
125        pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Filter", "receive"]).inc();
126        if self.pred.call(&rec.value).await? {
127            ctx.collect(rec);
128            pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Filter", "emit"]).inc();
129        }
130        Ok(())
131    }
132}
133
/// KeyBy operator: copies an existing field into a canonical `key` field.
///
/// Example
/// ```no_run
/// use pulse_ops::KeyBy;
/// let key_by = KeyBy::new("word");
/// # let _ = key_by;
/// ```
pub struct KeyBy {
    /// Name of the payload field whose value becomes the `key`.
    field: String,
}

impl KeyBy {
    /// Creates a key-by operator that reads the key from `field`.
    pub fn new(field: impl Into<String>) -> Self {
        let field = field.into();
        KeyBy { field }
    }
}
150
151#[async_trait]
152impl Operator for KeyBy {
153    #[instrument(name = "keyby_on_element", skip_all)]
154    async fn on_element(&mut self, ctx: &mut dyn Context, mut rec: Record) -> Result<()> {
155        pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["KeyBy", "receive"]).inc();
156        let key = rec
157            .value
158            .get(&self.field)
159            .cloned()
160            .unwrap_or(serde_json::Value::Null);
161        let mut obj = match rec.value {
162            serde_json::Value::Object(o) => o,
163            _ => serde_json::Map::new(),
164        };
165        obj.insert("key".to_string(), key);
166        rec.value = serde_json::Value::Object(obj);
167        ctx.collect(rec);
168        pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["KeyBy", "emit"]).inc();
169        Ok(())
170    }
171}
172
/// Fixed-size tumbling window helper (legacy from the simple Aggregate).
#[derive(Clone, Copy)]
pub struct WindowTumbling {
    /// Window length in milliseconds.
    pub size_ms: i64,
}
impl WindowTumbling {
    /// Builds a window that is `m` minutes long.
    ///
    /// The conversion to milliseconds saturates at the `i64` bounds instead of
    /// overflowing (which would panic in debug builds and wrap in release builds).
    pub fn minutes(m: i64) -> Self {
        Self { size_ms: m.saturating_mul(60_000) }
    }
}
183
184/// Simple aggregate that maintains a per-minute count per `key_field`.
185/// Simple aggregate that maintains a per-minute count per `key_field`.
186///
187/// Example
188/// ```no_run
189/// use pulse_ops::Aggregate;
190/// let agg = Aggregate::count_per_window("key", "word");
191/// # let _ = agg;
192/// ```
193pub struct Aggregate {
194    pub key_field: String,
195    pub value_field: String,
196    pub op: AggregationKind,
197    windows: HashMap<(i128, serde_json::Value), i64>, // (window_start, key) -> count
198}
199
/// Supported aggregation kinds for the simple `Aggregate`.
#[derive(Clone, Copy)]
pub enum AggregationKind {
    /// Count the records seen per key and window.
    Count,
}
205
206impl Aggregate {
207    pub fn count_per_window(key_field: impl Into<String>, value_field: impl Into<String>) -> Self {
208        Self {
209            key_field: key_field.into(),
210            value_field: value_field.into(),
211            op: AggregationKind::Count,
212            windows: HashMap::new(),
213        }
214    }
215}
216
217#[async_trait]
218impl Operator for Aggregate {
219    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
220        let minute_ms = 60_000_i128;
221        let ts_ms = rec.event_time.timestamp_millis() as i128; // ms
222        let win_start_ms = (ts_ms / minute_ms) * minute_ms;
223        let key = rec
224            .value
225            .get(&self.key_field)
226            .cloned()
227            .unwrap_or(serde_json::Value::Null);
228        let entry = self.windows.entry((win_start_ms, key.clone())).or_insert(0);
229        *entry += 1;
230        // Emit current count as an update
231        let mut out = serde_json::Map::new();
232        out.insert("window_start_ms".into(), serde_json::json!(win_start_ms));
233        out.insert("key".into(), key);
234        out.insert("count".into(), serde_json::json!(*entry));
235        ctx.collect(Record {
236            event_time: rec.event_time,
237            value: serde_json::Value::Object(out),
238        });
239        Ok(())
240    }
241    async fn on_watermark(&mut self, _ctx: &mut dyn Context, _wm: Watermark) -> Result<()> {
242        Ok(())
243    }
244}
245
246pub mod prelude {
247    pub use super::{
248        AggKind, Aggregate, AggregationKind, Filter, FnFilter, FnMap, KeyBy, Map, WindowKind, WindowTumbling,
249        WindowedAggregate,
250    };
251}
252
253// ===== Windowed, configurable aggregations =====
254
/// Kinds of windows supported by `WindowedAggregate`.
#[derive(Clone, Debug)]
pub enum WindowKind {
    /// Fixed-length, non-overlapping windows of `size_ms` milliseconds.
    Tumbling { size_ms: i64 },
    /// Windows of `size_ms` milliseconds that start every `slide_ms` milliseconds.
    Sliding { size_ms: i64, slide_ms: i64 },
    /// Per-key sessions that close after `gap_ms` milliseconds without events.
    Session { gap_ms: i64 },
}
262
/// Supported aggregation kinds for `WindowedAggregate`.
#[derive(Clone, Debug)]
pub enum AggKind {
    /// Number of records in the window.
    Count,
    /// Sum of the numeric payload field `field`.
    Sum { field: String },
    /// Arithmetic mean of the numeric payload field `field`.
    Avg { field: String },
    /// Number of distinct values of the payload field `field`.
    Distinct { field: String },
}
271
/// Accumulator kept per window (or session) while it is open.
#[derive(Clone, Debug, Default)]
enum AggState {
    /// No record has been folded in yet.
    #[default]
    Empty,
    /// Running record count.
    Count(i64),
    /// Running sum together with the number of samples; the sample count is
    /// what turns the sum into an average.
    Sum {
        sum: f64,
        count: i64,
    },
    /// Stringified values seen so far.
    Distinct(std::collections::HashSet<String>),
}
283
284fn as_f64(v: &serde_json::Value) -> f64 {
285    match v {
286        serde_json::Value::Number(n) => n.as_f64().unwrap_or(0.0),
287        serde_json::Value::String(s) => s.parse::<f64>().unwrap_or(0.0),
288        _ => 0.0,
289    }
290}
291
292fn stringify(v: &serde_json::Value) -> String {
293    match v {
294        serde_json::Value::String(s) => s.clone(),
295        other => other.to_string(),
296    }
297}
298
299/// A stateful windowed aggregation operator supporting different windows & aggregations.
300/// A stateful windowed aggregation operator supporting different windows & aggregations.
301///
302/// Examples
303/// ```no_run
304/// use pulse_ops::WindowedAggregate;
305/// // Tumbling count of words per 60s window
306/// let op = WindowedAggregate::tumbling_count("word", 60_000);
307/// # let _ = op;
308/// ```
309pub struct WindowedAggregate {
310    pub key_field: String,
311    pub win: WindowKind,
312    pub agg: AggKind,
313    // For tumbling/sliding: (end_ms, key) -> state, and track start_ms via map
314    by_window: HashMap<(i128, serde_json::Value), (i128 /*start_ms*/, AggState)>,
315    // For session: key -> (start_ms, last_seen_ms, state)
316    sessions: HashMap<serde_json::Value, (i128, i128, AggState)>,
317    // Allowed lateness in milliseconds: postpone closing windows until wm - lateness >= end
318    allowed_lateness_ms: i64,
319    // Last observed watermark in ms to evaluate late events
320    last_wm_ms: Option<i128>,
321    late_policy: LateDataPolicy,
322}
323
/// Policy applied to events that arrive after their lateness bound.
#[derive(Clone, Debug)]
enum LateDataPolicy {
    /// Discard the late event without emitting anything.
    Drop,
}
328
329impl WindowedAggregate {
330    pub fn tumbling_count(key_field: impl Into<String>, size_ms: i64) -> Self {
331        Self {
332            key_field: key_field.into(),
333            win: WindowKind::Tumbling { size_ms },
334            agg: AggKind::Count,
335            by_window: HashMap::new(),
336            sessions: HashMap::new(),
337            allowed_lateness_ms: 0,
338            last_wm_ms: None,
339            late_policy: LateDataPolicy::Drop,
340        }
341    }
342    pub fn tumbling_sum(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
343        Self {
344            key_field: key_field.into(),
345            win: WindowKind::Tumbling { size_ms },
346            agg: AggKind::Sum { field: field.into() },
347            by_window: HashMap::new(),
348            sessions: HashMap::new(),
349            allowed_lateness_ms: 0,
350            last_wm_ms: None,
351            late_policy: LateDataPolicy::Drop,
352        }
353    }
354    pub fn tumbling_avg(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
355        Self {
356            key_field: key_field.into(),
357            win: WindowKind::Tumbling { size_ms },
358            agg: AggKind::Avg { field: field.into() },
359            by_window: HashMap::new(),
360            sessions: HashMap::new(),
361            allowed_lateness_ms: 0,
362            last_wm_ms: None,
363            late_policy: LateDataPolicy::Drop,
364        }
365    }
366    pub fn tumbling_distinct(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
367        Self {
368            key_field: key_field.into(),
369            win: WindowKind::Tumbling { size_ms },
370            agg: AggKind::Distinct { field: field.into() },
371            by_window: HashMap::new(),
372            sessions: HashMap::new(),
373            allowed_lateness_ms: 0,
374            last_wm_ms: None,
375            late_policy: LateDataPolicy::Drop,
376        }
377    }
378
379    pub fn sliding_count(key_field: impl Into<String>, size_ms: i64, slide_ms: i64) -> Self {
380        Self {
381            key_field: key_field.into(),
382            win: WindowKind::Sliding { size_ms, slide_ms },
383            agg: AggKind::Count,
384            by_window: HashMap::new(),
385            sessions: HashMap::new(),
386            allowed_lateness_ms: 0,
387            last_wm_ms: None,
388            late_policy: LateDataPolicy::Drop,
389        }
390    }
391    pub fn sliding_sum(
392        key_field: impl Into<String>,
393        size_ms: i64,
394        slide_ms: i64,
395        field: impl Into<String>,
396    ) -> Self {
397        Self {
398            key_field: key_field.into(),
399            win: WindowKind::Sliding { size_ms, slide_ms },
400            agg: AggKind::Sum { field: field.into() },
401            by_window: HashMap::new(),
402            sessions: HashMap::new(),
403            allowed_lateness_ms: 0,
404            last_wm_ms: None,
405            late_policy: LateDataPolicy::Drop,
406        }
407    }
408    pub fn sliding_avg(
409        key_field: impl Into<String>,
410        size_ms: i64,
411        slide_ms: i64,
412        field: impl Into<String>,
413    ) -> Self {
414        Self {
415            key_field: key_field.into(),
416            win: WindowKind::Sliding { size_ms, slide_ms },
417            agg: AggKind::Avg { field: field.into() },
418            by_window: HashMap::new(),
419            sessions: HashMap::new(),
420            allowed_lateness_ms: 0,
421            last_wm_ms: None,
422            late_policy: LateDataPolicy::Drop,
423        }
424    }
425    pub fn sliding_distinct(
426        key_field: impl Into<String>,
427        size_ms: i64,
428        slide_ms: i64,
429        field: impl Into<String>,
430    ) -> Self {
431        Self {
432            key_field: key_field.into(),
433            win: WindowKind::Sliding { size_ms, slide_ms },
434            agg: AggKind::Distinct { field: field.into() },
435            by_window: HashMap::new(),
436            sessions: HashMap::new(),
437            allowed_lateness_ms: 0,
438            last_wm_ms: None,
439            late_policy: LateDataPolicy::Drop,
440        }
441    }
442
443    pub fn session_count(key_field: impl Into<String>, gap_ms: i64) -> Self {
444        Self {
445            key_field: key_field.into(),
446            win: WindowKind::Session { gap_ms },
447            agg: AggKind::Count,
448            by_window: HashMap::new(),
449            sessions: HashMap::new(),
450            allowed_lateness_ms: 0,
451            last_wm_ms: None,
452            late_policy: LateDataPolicy::Drop,
453        }
454    }
455    pub fn session_sum(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
456        Self {
457            key_field: key_field.into(),
458            win: WindowKind::Session { gap_ms },
459            agg: AggKind::Sum { field: field.into() },
460            by_window: HashMap::new(),
461            sessions: HashMap::new(),
462            allowed_lateness_ms: 0,
463            last_wm_ms: None,
464            late_policy: LateDataPolicy::Drop,
465        }
466    }
467    pub fn session_avg(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
468        Self {
469            key_field: key_field.into(),
470            win: WindowKind::Session { gap_ms },
471            agg: AggKind::Avg { field: field.into() },
472            by_window: HashMap::new(),
473            sessions: HashMap::new(),
474            allowed_lateness_ms: 0,
475            last_wm_ms: None,
476            late_policy: LateDataPolicy::Drop,
477        }
478    }
479    pub fn session_distinct(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
480        Self {
481            key_field: key_field.into(),
482            win: WindowKind::Session { gap_ms },
483            agg: AggKind::Distinct { field: field.into() },
484            by_window: HashMap::new(),
485            sessions: HashMap::new(),
486            allowed_lateness_ms: 0,
487            last_wm_ms: None,
488            late_policy: LateDataPolicy::Drop,
489        }
490    }
491
492    pub fn with_allowed_lateness(mut self, ms: i64) -> Self {
493        self.allowed_lateness_ms = ms.max(0);
494        self
495    }
496}
497
498fn update_state(state: &mut AggState, agg: &AggKind, value: &serde_json::Value) {
499    match agg {
500        AggKind::Count => {
501            *state = match std::mem::take(state) {
502                AggState::Empty => AggState::Count(1),
503                AggState::Count(c) => AggState::Count(c + 1),
504                other => other,
505            };
506        }
507        AggKind::Sum { field } => {
508            let x = as_f64(value.get(field).unwrap_or(&serde_json::Value::Null));
509            *state = match std::mem::take(state) {
510                AggState::Empty => AggState::Sum { sum: x, count: 1 },
511                AggState::Sum { sum, count } => AggState::Sum {
512                    sum: sum + x,
513                    count: count + 1,
514                },
515                other => other,
516            };
517        }
518        AggKind::Avg { field } => {
519            let x = as_f64(value.get(field).unwrap_or(&serde_json::Value::Null));
520            *state = match std::mem::take(state) {
521                AggState::Empty => AggState::Sum { sum: x, count: 1 },
522                AggState::Sum { sum, count } => AggState::Sum {
523                    sum: sum + x,
524                    count: count + 1,
525                },
526                other => other,
527            };
528        }
529        AggKind::Distinct { field } => {
530            let s = stringify(value.get(field).unwrap_or(&serde_json::Value::Null));
531            *state = match std::mem::take(state) {
532                AggState::Empty => {
533                    let mut set = std::collections::HashSet::new();
534                    set.insert(s);
535                    AggState::Distinct(set)
536                }
537                AggState::Distinct(mut set) => {
538                    set.insert(s);
539                    AggState::Distinct(set)
540                }
541                other => other,
542            };
543        }
544    }
545}
546
547fn finalize_value(state: &AggState, agg: &AggKind) -> serde_json::Value {
548    match (state, agg) {
549        (AggState::Count(c), _) => serde_json::json!(*c),
550        (AggState::Sum { sum, .. }, AggKind::Sum { .. }) => serde_json::json!(sum),
551        (AggState::Sum { sum, count }, AggKind::Avg { .. }) => {
552            let avg = if *count > 0 { *sum / (*count as f64) } else { 0.0 };
553            serde_json::json!(avg)
554        }
555        (AggState::Distinct(set), AggKind::Distinct { .. }) => serde_json::json!(set.len() as i64),
556        _ => serde_json::json!(null),
557    }
558}
559
560#[async_trait]
561impl Operator for WindowedAggregate {
562    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
563        let ts_ms = rec.event_time.timestamp_millis() as i128; // ms
564        // Late data handling: if we have a watermark and this event is older than (wm - allowed_lateness), drop
565        if let Some(wm) = self.last_wm_ms {
566            if ts_ms < (wm - (self.allowed_lateness_ms as i128)) {
567                match self.late_policy {
568                    LateDataPolicy::Drop => return Ok(()),
569                }
570            }
571        }
572        let key = rec
573            .value
574            .get(&self.key_field)
575            .cloned()
576            .unwrap_or(serde_json::Value::Null);
577
578        match self.win {
579            WindowKind::Tumbling { size_ms } => {
580                let start = (ts_ms / (size_ms as i128)) * (size_ms as i128);
581                let end = start + (size_ms as i128);
582                let entry = self
583                    .by_window
584                    .entry((end, key.clone()))
585                    .or_insert((start, AggState::Empty));
586                update_state(&mut entry.1, &self.agg, &rec.value);
587                // Optional: schedule a timer at end
588                let _ = ctx
589                    .timers()
590                    .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
591                    .await;
592            }
593            WindowKind::Sliding { size_ms, slide_ms } => {
594                let k = (size_ms / slide_ms) as i128;
595                let anchor = (ts_ms / (slide_ms as i128)) * (slide_ms as i128);
596                for j in 0..k {
597                    let start = anchor - (j * (slide_ms as i128));
598                    let end = start + (size_ms as i128);
599                    if start <= ts_ms && end > ts_ms {
600                        let entry = self
601                            .by_window
602                            .entry((end, key.clone()))
603                            .or_insert((start, AggState::Empty));
604                        update_state(&mut entry.1, &self.agg, &rec.value);
605                        let _ = ctx
606                            .timers()
607                            .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
608                            .await;
609                    }
610                }
611            }
612            WindowKind::Session { gap_ms } => {
613                let e = self
614                    .sessions
615                    .entry(key.clone())
616                    .or_insert((ts_ms, ts_ms, AggState::Empty));
617                let (start, last_seen, state) = e;
618                if ts_ms - *last_seen <= (gap_ms as i128) {
619                    *last_seen = ts_ms;
620                    update_state(state, &self.agg, &rec.value);
621                } else {
622                    // close previous session
623                    let mut out = serde_json::Map::new();
624                    out.insert("window_start_ms".into(), serde_json::json!(*start));
625                    out.insert(
626                        "window_end_ms".into(),
627                        serde_json::json!(*last_seen + (gap_ms as i128)),
628                    );
629                    out.insert("key".into(), key.clone());
630                    let val = finalize_value(state, &self.agg);
631                    match self.agg {
632                        AggKind::Count => {
633                            out.insert("count".into(), val);
634                        }
635                        AggKind::Sum { .. } => {
636                            out.insert("sum".into(), val);
637                        }
638                        AggKind::Avg { .. } => {
639                            out.insert("avg".into(), val);
640                        }
641                        AggKind::Distinct { .. } => {
642                            out.insert("distinct_count".into(), val);
643                        }
644                    }
645                    ctx.collect(Record {
646                        event_time: rec.event_time,
647                        value: serde_json::Value::Object(out),
648                    });
649                    // start new
650                    *start = ts_ms;
651                    *last_seen = ts_ms;
652                    *state = AggState::Empty;
653                    update_state(state, &self.agg, &rec.value);
654                }
655                // schedule close timer
656                let end = ts_ms + (gap_ms as i128);
657                let _ = ctx
658                    .timers()
659                    .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
660                    .await;
661            }
662        }
663        Ok(())
664    }
665
666    async fn on_watermark(&mut self, ctx: &mut dyn Context, wm: Watermark) -> Result<()> {
667        let wm_ms_raw = wm.0 .0.timestamp_millis() as i128;
668        let wm_ms = wm_ms_raw - (self.allowed_lateness_ms as i128);
669        self.last_wm_ms = Some(wm_ms_raw);
670
671        match self.win {
672            WindowKind::Tumbling { .. } | WindowKind::Sliding { .. } => {
673                // Emit and clear all windows with end <= wm
674                let mut to_emit: Vec<((i128, serde_json::Value), (i128, AggState))> = self
675                    .by_window
676                    .iter()
677                    .filter(|((end, _), _)| *end <= wm_ms)
678                    .map(|(k, v)| (k.clone(), v.clone()))
679                    .collect();
680                for ((end, key), (start, state)) in to_emit.drain(..) {
681                    let mut out = serde_json::Map::new();
682                    out.insert("window_start_ms".into(), serde_json::json!(start));
683                    out.insert("window_end_ms".into(), serde_json::json!(end));
684                    out.insert("key".into(), key.clone());
685                    let val = finalize_value(&state, &self.agg);
686                    match self.agg {
687                        AggKind::Count => {
688                            out.insert("count".into(), val);
689                        }
690                        AggKind::Sum { .. } => {
691                            out.insert("sum".into(), val);
692                        }
693                        AggKind::Avg { .. } => {
694                            out.insert("avg".into(), val);
695                        }
696                        AggKind::Distinct { .. } => {
697                            out.insert("distinct_count".into(), val);
698                        }
699                    }
700                    ctx.collect(Record { event_time: wm.0 .0, value: serde_json::Value::Object(out) });
701                    self.by_window.remove(&(end, key));
702                }
703            }
704            WindowKind::Session { gap_ms } => {
705                // Close sessions whose inactivity + gap <= wm
706                let keys: Vec<_> = self.sessions.keys().cloned().collect();
707                for key in keys {
708                    if let Some((start, last_seen, state)) = self.sessions.get(&key).cloned() {
709                        if last_seen + (gap_ms as i128) <= wm_ms {
710                            let mut out = serde_json::Map::new();
711                            out.insert("window_start_ms".into(), serde_json::json!(start));
712                            out.insert(
713                                "window_end_ms".into(),
714                                serde_json::json!(last_seen + (gap_ms as i128)),
715                            );
716                            out.insert("key".into(), key.clone());
717                            let val = finalize_value(&state, &self.agg);
718                            match self.agg {
719                                AggKind::Count => {
720                                    out.insert("count".into(), val);
721                                }
722                                AggKind::Sum { .. } => {
723                                    out.insert("sum".into(), val);
724                                }
725                                AggKind::Avg { .. } => {
726                                    out.insert("avg".into(), val);
727                                }
728                                AggKind::Distinct { .. } => {
729                                    out.insert("distinct_count".into(), val);
730                                }
731                            }
732                            ctx.collect(Record { event_time: wm.0 .0, value: serde_json::Value::Object(out) });
733                            self.sessions.remove(&key);
734                        }
735                    }
736                }
737            }
738        }
739        Ok(())
740    }
741
742    async fn on_timer(
743        &mut self,
744        ctx: &mut dyn Context,
745        when: EventTime,
746        _key: Option<Vec<u8>>,
747    ) -> Result<()> {
748        // Treat timers same as watermarks for emission, but apply allowed lateness shift.
749        let when_ms = when.0.timestamp_millis() as i128 - (self.allowed_lateness_ms as i128);
750
751        match self.win {
752            WindowKind::Tumbling { .. } | WindowKind::Sliding { .. } => {
753                let mut to_emit: Vec<((i128, serde_json::Value), (i128, AggState))> = self
754                    .by_window
755                    .iter()
756                    .filter(|((end, _), _)| *end <= when_ms)
757                    .map(|(k, v)| (k.clone(), v.clone()))
758                    .collect();
759                for ((end, key), (start, state)) in to_emit.drain(..) {
760                    let mut out = serde_json::Map::new();
761                    out.insert("window_start_ms".into(), serde_json::json!(start));
762                    out.insert("window_end_ms".into(), serde_json::json!(end));
763                    out.insert("key".into(), key.clone());
764                    let val = finalize_value(&state, &self.agg);
765                    match self.agg {
766                        AggKind::Count => {
767                            out.insert("count".into(), val);
768                        }
769                        AggKind::Sum { .. } => {
770                            out.insert("sum".into(), val);
771                        }
772                        AggKind::Avg { .. } => {
773                            out.insert("avg".into(), val);
774                        }
775                        AggKind::Distinct { .. } => {
776                            out.insert("distinct_count".into(), val);
777                        }
778                    }
779                    ctx.collect(Record { event_time: when.0, value: serde_json::Value::Object(out) });
780                    self.by_window.remove(&(end, key));
781                }
782            }
783            WindowKind::Session { gap_ms } => {
784                let keys: Vec<_> = self.sessions.keys().cloned().collect();
785                for key in keys {
786                    if let Some((start, last_seen, state)) = self.sessions.get(&key).cloned() {
787                        if last_seen + (gap_ms as i128) <= when_ms {
788                            let mut out = serde_json::Map::new();
789                            out.insert("window_start_ms".into(), serde_json::json!(start));
790                            out.insert(
791                                "window_end_ms".into(),
792                                serde_json::json!(last_seen + (gap_ms as i128)),
793                            );
794                            out.insert("key".into(), key.clone());
795                            let val = finalize_value(&state, &self.agg);
796                            match self.agg {
797                                AggKind::Count => {
798                                    out.insert("count".into(), val);
799                                }
800                                AggKind::Sum { .. } => {
801                                    out.insert("sum".into(), val);
802                                }
803                                AggKind::Avg { .. } => {
804                                    out.insert("avg".into(), val);
805                                }
806                                AggKind::Distinct { .. } => {
807                                    out.insert("distinct_count".into(), val);
808                                }
809                            }
810                            ctx.collect(Record { event_time: when.0, value: serde_json::Value::Object(out) });
811                            self.sessions.remove(&key);
812                        }
813                    }
814                }
815            }
816        }
817        Ok(())
818    }
819}
820
// Watermark-driven emission test for `WindowedAggregate`, run against
// in-memory stand-ins for the pulse-core runtime traits.
#[cfg(test)]
mod window_tests {
    use super::*;
    use pulse_core::{Context, EventTime, KvState, Record, Result, Timers, Watermark};
    use std::sync::Arc;

    /// Key-value state stub: lookups and prefix scans come back empty,
    /// writes and deletes are accepted and discarded.
    struct TestState;

    #[async_trait]
    impl KvState for TestState {
        async fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn put(&self, _key: &[u8], _value: Vec<u8>) -> Result<()> {
            Ok(())
        }

        async fn delete(&self, _key: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn iter_prefix(&self, _prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
            Ok(vec![])
        }

        async fn snapshot(&self) -> Result<pulse_core::SnapshotId> {
            Ok("test-snap".to_string())
        }

        async fn restore(&self, _snapshot: pulse_core::SnapshotId) -> Result<()> {
            Ok(())
        }
    }

    /// Timer service stub: every registration succeeds and is ignored.
    struct TestTimers;

    #[async_trait]
    impl Timers for TestTimers {
        async fn register_event_time_timer(&self, _when: EventTime, _key: Option<Vec<u8>>) -> Result<()> {
            Ok(())
        }
    }

    /// Operator context that records every collected record in `out`.
    struct TestCtx {
        out: Vec<Record>,
        kv: Arc<dyn KvState>,
        timers: Arc<dyn Timers>,
    }

    #[async_trait]
    impl Context for TestCtx {
        fn collect(&mut self, record: Record) {
            self.out.push(record);
        }

        // Forwarded watermarks are not inspected by this test.
        fn watermark(&mut self, _wm: Watermark) {}

        fn kv(&self) -> Arc<dyn KvState> {
            Arc::clone(&self.kv)
        }

        fn timers(&self) -> Arc<dyn Timers> {
            Arc::clone(&self.timers)
        }
    }

    /// Builds a record `{"word": key}` whose event time is `ts_ms`
    /// milliseconds after the Unix epoch.
    fn record_with(ts_ms: i128, key: &str) -> Record {
        let event_time = Utc.timestamp_millis_opt(ts_ms as i64).unwrap();
        let value = serde_json::json!({ "word": key });
        Record { event_time, value }
    }

    #[tokio::test]
    async fn tumbling_count_emits_on_watermark() {
        let mut ctx = TestCtx {
            out: Vec::new(),
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        };
        let mut op = WindowedAggregate::tumbling_count("word", 60_000);

        // Both events fall into the first window, 0..60000.
        for ts_ms in [1_000, 1_010] {
            op.on_element(&mut ctx, record_with(ts_ms, "a")).await.unwrap();
        }

        // A watermark at the end of that window triggers emission.
        let window_end = Utc.timestamp_millis_opt(60_000).unwrap();
        op.on_watermark(&mut ctx, Watermark(EventTime(window_end)))
            .await
            .unwrap();

        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["count"], serde_json::json!(2));
    }
}
// Unit tests for the standard operators (`Map`, `Filter`, `KeyBy`,
// `Aggregate`, `WindowedAggregate`), run against in-memory stand-ins for the
// pulse-core runtime traits.
//
// Every record is stamped with a fixed event time rather than the wall clock,
// so window boundaries are known up front and the tests cannot become flaky
// when a minute boundary passes while a test is running.
#[cfg(test)]
mod tests {
    use super::*;
    use pulse_core::{Context, EventTime, KvState, Record, Result, Timers, Watermark};
    use std::sync::Arc;

    /// Event time (milliseconds since the Unix epoch) given to every test record.
    const EVENT_TS_MS: i64 = 1_000;

    /// End of the 60-second window that contains `EVENT_TS_MS`, computed as
    /// the window start rounded down to a multiple of 60 000 ms plus one
    /// window length.
    const WINDOW_END_MS: i64 = (EVENT_TS_MS / 60_000) * 60_000 + 60_000;

    /// Key-value state stub: lookups and prefix scans come back empty,
    /// writes and deletes are accepted and discarded.
    struct TestState;
    #[async_trait]
    impl KvState for TestState {
        async fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }
        async fn put(&self, _key: &[u8], _value: Vec<u8>) -> Result<()> {
            Ok(())
        }
        async fn delete(&self, _key: &[u8]) -> Result<()> {
            Ok(())
        }
        async fn iter_prefix(&self, _prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
            Ok(Vec::new())
        }
        async fn snapshot(&self) -> Result<pulse_core::SnapshotId> {
            Ok("test-snap".to_string())
        }
        async fn restore(&self, _snapshot: pulse_core::SnapshotId) -> Result<()> {
            Ok(())
        }
    }

    /// Timer service stub: every registration succeeds and is ignored.
    struct TestTimers;
    #[async_trait]
    impl Timers for TestTimers {
        async fn register_event_time_timer(&self, _when: EventTime, _key: Option<Vec<u8>>) -> Result<()> {
            Ok(())
        }
    }

    /// Operator context that records every collected record in `out`.
    struct TestCtx {
        out: Vec<Record>,
        kv: Arc<dyn KvState>,
        timers: Arc<dyn Timers>,
    }

    #[async_trait]
    impl Context for TestCtx {
        fn collect(&mut self, record: Record) {
            self.out.push(record);
        }
        // Forwarded watermarks are not inspected by these tests.
        fn watermark(&mut self, _wm: Watermark) {}
        fn kv(&self) -> Arc<dyn KvState> {
            self.kv.clone()
        }
        fn timers(&self) -> Arc<dyn Timers> {
            self.timers.clone()
        }
    }

    /// Creates a context with no collected output, backed by the stubs above.
    fn new_ctx() -> TestCtx {
        TestCtx {
            out: vec![],
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        }
    }

    /// Wraps `v` in a record stamped with the fixed `EVENT_TS_MS` event time.
    fn rec(v: serde_json::Value) -> Record {
        Record {
            event_time: Utc.timestamp_millis_opt(EVENT_TS_MS).unwrap(),
            value: v,
        }
    }

    /// Builds a watermark at `ts_ms` milliseconds since the Unix epoch.
    fn watermark_at(ts_ms: i64) -> Watermark {
        Watermark(EventTime(Utc.timestamp_millis_opt(ts_ms).unwrap()))
    }

    #[tokio::test]
    async fn test_map() {
        let mut op = Map::new(MapFn::new(|v| vec![v]));
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"a":1})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
    }

    #[tokio::test]
    async fn test_filter() {
        let mut op = Filter::new(FilterFn::new(|v: &serde_json::Value| {
            v.get("ok").and_then(|x| x.as_bool()).unwrap_or(false)
        }));
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"ok":false})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"ok":true})))
            .await
            .unwrap();
        // Only the record whose `ok` field is true passes the predicate.
        assert_eq!(ctx.out.len(), 1);
    }

    #[tokio::test]
    async fn test_keyby() {
        let mut op = KeyBy::new("word");
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"word":"hi"})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["key"], serde_json::json!("hi"));
    }

    #[tokio::test]
    async fn test_aggregate_count() {
        let mut op = Aggregate::count_per_window("key", "word");
        let mut ctx = new_ctx();
        // Both records carry the same fixed event time.
        op.on_element(&mut ctx, rec(serde_json::json!({"key":"hello"})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"key":"hello"})))
            .await
            .unwrap();
        // One update per input; the second carries the running count of 2.
        assert_eq!(ctx.out.len(), 2);
        assert_eq!(ctx.out[1].value["count"], serde_json::json!(2));
    }

    #[tokio::test]
    async fn windowed_allowed_lateness_defers_emission() {
        let mut op = WindowedAggregate::tumbling_count("word", 60_000).with_allowed_lateness(30_000);
        let mut ctx = new_ctx();
        // Two events in the first minute window.
        op.on_element(&mut ctx, rec(serde_json::json!({"word":"a"}))).await.unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"word":"a"}))).await.unwrap();
        // Watermark at window end should NOT emit due to allowed lateness of 30s.
        op.on_watermark(&mut ctx, watermark_at(WINDOW_END_MS)).await.unwrap();
        assert!(ctx.out.is_empty());
        // After lateness passes, emission should occur.
        op.on_watermark(&mut ctx, watermark_at(WINDOW_END_MS + 30_000)).await.unwrap();
        assert!(!ctx.out.is_empty());
    }

    #[tokio::test]
    async fn windowed_agg_avg_and_distinct() {
        let mut avg_op = WindowedAggregate::tumbling_avg("key", 60_000, "x");
        let mut distinct_op = WindowedAggregate::tumbling_distinct("key", 60_000, "s");
        let mut ctx = new_ctx();
        // Feed two records into the same window.
        avg_op.on_element(&mut ctx, rec(serde_json::json!({"key":"k","x": 1}))).await.unwrap();
        avg_op.on_element(&mut ctx, rec(serde_json::json!({"key":"k","x": 3}))).await.unwrap();
        // Watermark at the end of that window.
        avg_op.on_watermark(&mut ctx, watermark_at(WINDOW_END_MS)).await.unwrap();
        // Expect an output carrying an `avg` field (its value is not checked here).
        assert!(ctx.out.iter().any(|r| r.value.get("avg").is_some()));
        // Reset output for distinct.
        ctx.out.clear();
        distinct_op.on_element(&mut ctx, rec(serde_json::json!({"key":"k","s":"a"}))).await.unwrap();
        distinct_op.on_element(&mut ctx, rec(serde_json::json!({"key":"k","s":"a"}))).await.unwrap();
        distinct_op.on_element(&mut ctx, rec(serde_json::json!({"key":"k","s":"b"}))).await.unwrap();
        // A fresh watermark is built here rather than reusing the earlier value.
        distinct_op.on_watermark(&mut ctx, watermark_at(WINDOW_END_MS)).await.unwrap();
        // Expect distinct_count = 2 ("a" twice, "b" once).
        assert!(ctx.out.iter().any(|r| r.value.get("distinct_count").and_then(|v| v.as_i64()).unwrap_or(0) == 2));
    }
}