pulse_ops/
lib.rs

1//! pulse-ops: standard operators built on top of pulse-core.
2//!
3//! Included operators:
4//! - `Map`: one-to-many mapping of JSON payloads
5//! - `Filter`: predicate-based filtering
6//! - `KeyBy`: materialize a `key` field from an existing field
7//! - `Aggregate` (simplified): per-minute running count updates
8//! - `WindowedAggregate`: configurable windows (tumbling/sliding/session) with count/sum/avg/distinct
9
10use std::collections::HashMap;
11
12use async_trait::async_trait;
13use chrono::{TimeZone, Utc};
14use pulse_core::{Context, EventTime, Operator, Record, Result, Watermark};
15use tracing::{info_span, instrument};
16pub mod time;
17pub mod window;
18pub use time::{WatermarkClock, WatermarkPolicy};
19pub use window::{Window, WindowAssigner, WindowOperator};
20
21#[async_trait]
22pub trait FnMap: Send + Sync {
23    async fn call(&self, value: serde_json::Value) -> Result<Vec<serde_json::Value>>;
24}
25
/// Wrapper that lets a plain synchronous closure act as an [`FnMap`].
pub struct MapFn<F>(pub F);

impl<F> MapFn<F> {
    /// Wraps `f`; equivalent to writing `MapFn(f)`.
    pub fn new(f: F) -> Self {
        MapFn(f)
    }
}
32#[async_trait]
33impl<F> FnMap for MapFn<F>
34where
35    F: Fn(serde_json::Value) -> Vec<serde_json::Value> + Send + Sync,
36{
37    async fn call(&self, value: serde_json::Value) -> Result<Vec<serde_json::Value>> {
38        Ok((self.0)(value))
39    }
40}
41
/// Map operator: applies a user function that returns zero or more outputs per input.
///
/// Every output record inherits the event time of the input record that produced it.
///
/// # Examples
/// ```no_run
/// use pulse_ops::{Map, MapFn};
/// let map = Map::new(MapFn::new(|v: serde_json::Value| vec![v]));
/// # let _ = map;
/// ```
pub struct Map<F> {
    /// The user transformation (an [`FnMap`] implementation when used as an operator).
    func: F,
}
impl<F> Map<F> {
    /// Wraps `func` in a map operator.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}
59
60#[async_trait]
61impl<F> Operator for Map<F>
62where
63    F: FnMap + Send + Sync + 'static,
64{
65    #[instrument(name = "map_on_element", skip_all)]
66    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
67        let outs = self.func.call(rec.value).await?;
68        pulse_core::metrics::OP_THROUGHPUT
69            .with_label_values(&["Map", "receive"])
70            .inc();
71        for v in outs {
72            ctx.collect(Record {
73                event_time: rec.event_time,
74                value: v.clone(),
75            });
76            pulse_core::metrics::OP_THROUGHPUT
77                .with_label_values(&["Map", "emit"])
78                .inc();
79        }
80        Ok(())
81    }
82}
83
84#[async_trait]
85pub trait FnFilter: Send + Sync {
86    async fn call(&self, value: &serde_json::Value) -> Result<bool>;
87}
88
/// Wrapper that lets a plain synchronous predicate act as an [`FnFilter`].
pub struct FilterFn<F>(pub F);

impl<F> FilterFn<F> {
    /// Wraps `f`; equivalent to writing `FilterFn(f)`.
    pub fn new(f: F) -> Self {
        FilterFn(f)
    }
}
95#[async_trait]
96impl<F> FnFilter for FilterFn<F>
97where
98    F: Fn(&serde_json::Value) -> bool + Send + Sync,
99{
100    async fn call(&self, value: &serde_json::Value) -> Result<bool> {
101        Ok((self.0)(value))
102    }
103}
104
/// Filter operator: keeps inputs that satisfy the predicate.
///
/// Records that pass are forwarded unchanged; all others are dropped.
///
/// # Examples
/// ```no_run
/// use pulse_ops::{Filter, FilterFn};
/// let filter = Filter::new(FilterFn::new(|v: &serde_json::Value| v.get("ok").and_then(|x| x.as_bool()).unwrap_or(false)));
/// # let _ = filter;
/// ```
pub struct Filter<F> {
    /// Predicate deciding whether a record is forwarded.
    pred: F,
}

impl<F> Filter<F> {
    /// Creates a filter that forwards the records accepted by `pred`.
    pub fn new(pred: F) -> Self {
        Filter { pred }
    }
}
121
122#[async_trait]
123impl<F> Operator for Filter<F>
124where
125    F: FnFilter + Send + Sync + 'static,
126{
127    #[instrument(name = "filter_on_element", skip_all)]
128    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
129        pulse_core::metrics::OP_THROUGHPUT
130            .with_label_values(&["Filter", "receive"])
131            .inc();
132        if self.pred.call(&rec.value).await? {
133            ctx.collect(rec);
134            pulse_core::metrics::OP_THROUGHPUT
135                .with_label_values(&["Filter", "emit"])
136                .inc();
137        }
138        Ok(())
139    }
140}
141
/// KeyBy operator: copies an existing field into a canonical `key` field.
///
/// # Examples
/// ```no_run
/// use pulse_ops::KeyBy;
/// let key_by = KeyBy::new("word");
/// # let _ = key_by;
/// ```
pub struct KeyBy {
    /// Name of the payload field whose value becomes `key`.
    field: String,
}

impl KeyBy {
    /// Creates an operator that keys records by the payload field `field`.
    pub fn new(field: impl Into<String>) -> Self {
        let field = field.into();
        KeyBy { field }
    }
}
158
159#[async_trait]
160impl Operator for KeyBy {
161    #[instrument(name = "keyby_on_element", skip_all)]
162    async fn on_element(&mut self, ctx: &mut dyn Context, mut rec: Record) -> Result<()> {
163        pulse_core::metrics::OP_THROUGHPUT
164            .with_label_values(&["KeyBy", "receive"])
165            .inc();
166        let key = rec
167            .value
168            .get(&self.field)
169            .cloned()
170            .unwrap_or(serde_json::Value::Null);
171        let mut obj = match rec.value {
172            serde_json::Value::Object(o) => o,
173            _ => serde_json::Map::new(),
174        };
175        obj.insert("key".to_string(), key);
176        rec.value = serde_json::Value::Object(obj);
177        ctx.collect(rec);
178        pulse_core::metrics::OP_THROUGHPUT
179            .with_label_values(&["KeyBy", "emit"])
180            .inc();
181        Ok(())
182    }
183}
184
/// Fixed-size tumbling window helper (legacy from the simple `Aggregate`).
#[derive(Clone, Copy)]
pub struct WindowTumbling {
    /// Window length in milliseconds.
    pub size_ms: i64,
}

impl WindowTumbling {
    /// Builds a window lasting `m` minutes.
    pub fn minutes(m: i64) -> Self {
        const MS_PER_MINUTE: i64 = 60_000;
        WindowTumbling {
            size_ms: MS_PER_MINUTE * m,
        }
    }
}
195
196/// Simple aggregate that maintains a per-minute count per `key_field`.
197/// Simple aggregate that maintains a per-minute count per `key_field`.
198///
199/// Example
200/// ```no_run
201/// use pulse_ops::Aggregate;
202/// let agg = Aggregate::count_per_window("key", "word");
203/// # let _ = agg;
204/// ```
205pub struct Aggregate {
206    pub key_field: String,
207    pub value_field: String,
208    pub op: AggregationKind,
209    windows: HashMap<(i128, serde_json::Value), i64>, // (window_start, key) -> count
210}
211
212/// Supported aggregation kinds for the simple `Aggregate`.
213#[derive(Clone, Copy)]
214pub enum AggregationKind {
215    Count,
216}
217
218impl Aggregate {
219    pub fn count_per_window(key_field: impl Into<String>, value_field: impl Into<String>) -> Self {
220        Self {
221            key_field: key_field.into(),
222            value_field: value_field.into(),
223            op: AggregationKind::Count,
224            windows: HashMap::new(),
225        }
226    }
227}
228
229#[async_trait]
230impl Operator for Aggregate {
231    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
232        let minute_ms = 60_000_i128;
233        let ts_ms = rec.event_time.timestamp_millis() as i128; // ms
234        let win_start_ms = (ts_ms / minute_ms) * minute_ms;
235        let key = rec
236            .value
237            .get(&self.key_field)
238            .cloned()
239            .unwrap_or(serde_json::Value::Null);
240        let entry = self.windows.entry((win_start_ms, key.clone())).or_insert(0);
241        *entry += 1;
242        // Emit current count as an update
243        let mut out = serde_json::Map::new();
244        out.insert("window_start_ms".into(), serde_json::json!(win_start_ms));
245        out.insert("key".into(), key);
246        out.insert("count".into(), serde_json::json!(*entry));
247        ctx.collect(Record {
248            event_time: rec.event_time,
249            value: serde_json::Value::Object(out),
250        });
251        Ok(())
252    }
253    async fn on_watermark(&mut self, _ctx: &mut dyn Context, _wm: Watermark) -> Result<()> {
254        Ok(())
255    }
256}
257
258pub mod prelude {
259    pub use super::{
260        AggKind, Aggregate, AggregationKind, Filter, FnFilter, FnMap, KeyBy, Map, WindowKind, WindowTumbling,
261        WindowedAggregate,
262    };
263}
264
265// ===== Windowed, configurable aggregations =====
266
/// Kinds of windows supported by `WindowedAggregate`.
#[derive(Clone, Debug)]
pub enum WindowKind {
    /// Fixed, non-overlapping windows of `size_ms` milliseconds.
    Tumbling { size_ms: i64 },
    /// Windows of `size_ms` milliseconds whose starts are aligned to multiples of `slide_ms`.
    Sliding { size_ms: i64, slide_ms: i64 },
    /// Per-key windows that close once `gap_ms` milliseconds pass without a record.
    Session { gap_ms: i64 },
}
274
/// Supported aggregation kinds for `WindowedAggregate`.
#[derive(Clone, Debug)]
pub enum AggKind {
    /// Number of records in the window.
    Count,
    /// Sum of `field` (values are coerced to numbers; non-numeric ones count as 0).
    Sum { field: String },
    /// Arithmetic mean of `field`, coerced the same way as `Sum`.
    Avg { field: String },
    /// Number of distinct values of `field`, compared by their string form.
    Distinct { field: String },
}
283
/// Per-window accumulator backing each `AggKind`.
#[derive(Clone, Debug, Default)]
enum AggState {
    /// Nothing has been aggregated yet.
    #[default]
    Empty,
    /// Number of records seen (`AggKind::Count`).
    Count(i64),
    /// Running sum plus record count; shared by `AggKind::Sum` and `AggKind::Avg`
    /// (the count is what turns the sum into an average).
    Sum { sum: f64, count: i64 },
    /// Distinct stringified values seen (`AggKind::Distinct`).
    Distinct(std::collections::HashSet<String>),
}
295
296fn as_f64(v: &serde_json::Value) -> f64 {
297    match v {
298        serde_json::Value::Number(n) => n.as_f64().unwrap_or(0.0),
299        serde_json::Value::String(s) => s.parse::<f64>().unwrap_or(0.0),
300        _ => 0.0,
301    }
302}
303
304fn stringify(v: &serde_json::Value) -> String {
305    match v {
306        serde_json::Value::String(s) => s.clone(),
307        other => other.to_string(),
308    }
309}
310
311/// A stateful windowed aggregation operator supporting different windows & aggregations.
312/// A stateful windowed aggregation operator supporting different windows & aggregations.
313///
314/// Examples
315/// ```no_run
316/// use pulse_ops::WindowedAggregate;
317/// // Tumbling count of words per 60s window
318/// let op = WindowedAggregate::tumbling_count("word", 60_000);
319/// # let _ = op;
320/// ```
321pub struct WindowedAggregate {
322    pub key_field: String,
323    pub win: WindowKind,
324    pub agg: AggKind,
325    // For tumbling/sliding: (end_ms, key) -> state, and track start_ms via map
326    by_window: HashMap<(i128, serde_json::Value), (i128 /*start_ms*/, AggState)>,
327    // For session: key -> (start_ms, last_seen_ms, state)
328    sessions: HashMap<serde_json::Value, (i128, i128, AggState)>,
329    // Allowed lateness in milliseconds: postpone closing windows until wm - lateness >= end
330    allowed_lateness_ms: i64,
331    // Last observed watermark in ms to evaluate late events
332    last_wm_ms: Option<i128>,
333    late_policy: LateDataPolicy,
334}
335
/// Policy for records older than the last watermark minus the allowed lateness.
#[derive(Clone, Debug)]
enum LateDataPolicy {
    /// Silently discard the late record.
    Drop,
}
340
341impl WindowedAggregate {
342    pub fn tumbling_count(key_field: impl Into<String>, size_ms: i64) -> Self {
343        Self {
344            key_field: key_field.into(),
345            win: WindowKind::Tumbling { size_ms },
346            agg: AggKind::Count,
347            by_window: HashMap::new(),
348            sessions: HashMap::new(),
349            allowed_lateness_ms: 0,
350            last_wm_ms: None,
351            late_policy: LateDataPolicy::Drop,
352        }
353    }
354    pub fn tumbling_sum(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
355        Self {
356            key_field: key_field.into(),
357            win: WindowKind::Tumbling { size_ms },
358            agg: AggKind::Sum { field: field.into() },
359            by_window: HashMap::new(),
360            sessions: HashMap::new(),
361            allowed_lateness_ms: 0,
362            last_wm_ms: None,
363            late_policy: LateDataPolicy::Drop,
364        }
365    }
366    pub fn tumbling_avg(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
367        Self {
368            key_field: key_field.into(),
369            win: WindowKind::Tumbling { size_ms },
370            agg: AggKind::Avg { field: field.into() },
371            by_window: HashMap::new(),
372            sessions: HashMap::new(),
373            allowed_lateness_ms: 0,
374            last_wm_ms: None,
375            late_policy: LateDataPolicy::Drop,
376        }
377    }
378    pub fn tumbling_distinct(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
379        Self {
380            key_field: key_field.into(),
381            win: WindowKind::Tumbling { size_ms },
382            agg: AggKind::Distinct { field: field.into() },
383            by_window: HashMap::new(),
384            sessions: HashMap::new(),
385            allowed_lateness_ms: 0,
386            last_wm_ms: None,
387            late_policy: LateDataPolicy::Drop,
388        }
389    }
390
391    pub fn sliding_count(key_field: impl Into<String>, size_ms: i64, slide_ms: i64) -> Self {
392        Self {
393            key_field: key_field.into(),
394            win: WindowKind::Sliding { size_ms, slide_ms },
395            agg: AggKind::Count,
396            by_window: HashMap::new(),
397            sessions: HashMap::new(),
398            allowed_lateness_ms: 0,
399            last_wm_ms: None,
400            late_policy: LateDataPolicy::Drop,
401        }
402    }
403    pub fn sliding_sum(
404        key_field: impl Into<String>,
405        size_ms: i64,
406        slide_ms: i64,
407        field: impl Into<String>,
408    ) -> Self {
409        Self {
410            key_field: key_field.into(),
411            win: WindowKind::Sliding { size_ms, slide_ms },
412            agg: AggKind::Sum { field: field.into() },
413            by_window: HashMap::new(),
414            sessions: HashMap::new(),
415            allowed_lateness_ms: 0,
416            last_wm_ms: None,
417            late_policy: LateDataPolicy::Drop,
418        }
419    }
420    pub fn sliding_avg(
421        key_field: impl Into<String>,
422        size_ms: i64,
423        slide_ms: i64,
424        field: impl Into<String>,
425    ) -> Self {
426        Self {
427            key_field: key_field.into(),
428            win: WindowKind::Sliding { size_ms, slide_ms },
429            agg: AggKind::Avg { field: field.into() },
430            by_window: HashMap::new(),
431            sessions: HashMap::new(),
432            allowed_lateness_ms: 0,
433            last_wm_ms: None,
434            late_policy: LateDataPolicy::Drop,
435        }
436    }
437    pub fn sliding_distinct(
438        key_field: impl Into<String>,
439        size_ms: i64,
440        slide_ms: i64,
441        field: impl Into<String>,
442    ) -> Self {
443        Self {
444            key_field: key_field.into(),
445            win: WindowKind::Sliding { size_ms, slide_ms },
446            agg: AggKind::Distinct { field: field.into() },
447            by_window: HashMap::new(),
448            sessions: HashMap::new(),
449            allowed_lateness_ms: 0,
450            last_wm_ms: None,
451            late_policy: LateDataPolicy::Drop,
452        }
453    }
454
455    pub fn session_count(key_field: impl Into<String>, gap_ms: i64) -> Self {
456        Self {
457            key_field: key_field.into(),
458            win: WindowKind::Session { gap_ms },
459            agg: AggKind::Count,
460            by_window: HashMap::new(),
461            sessions: HashMap::new(),
462            allowed_lateness_ms: 0,
463            last_wm_ms: None,
464            late_policy: LateDataPolicy::Drop,
465        }
466    }
467    pub fn session_sum(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
468        Self {
469            key_field: key_field.into(),
470            win: WindowKind::Session { gap_ms },
471            agg: AggKind::Sum { field: field.into() },
472            by_window: HashMap::new(),
473            sessions: HashMap::new(),
474            allowed_lateness_ms: 0,
475            last_wm_ms: None,
476            late_policy: LateDataPolicy::Drop,
477        }
478    }
479    pub fn session_avg(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
480        Self {
481            key_field: key_field.into(),
482            win: WindowKind::Session { gap_ms },
483            agg: AggKind::Avg { field: field.into() },
484            by_window: HashMap::new(),
485            sessions: HashMap::new(),
486            allowed_lateness_ms: 0,
487            last_wm_ms: None,
488            late_policy: LateDataPolicy::Drop,
489        }
490    }
491    pub fn session_distinct(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
492        Self {
493            key_field: key_field.into(),
494            win: WindowKind::Session { gap_ms },
495            agg: AggKind::Distinct { field: field.into() },
496            by_window: HashMap::new(),
497            sessions: HashMap::new(),
498            allowed_lateness_ms: 0,
499            last_wm_ms: None,
500            late_policy: LateDataPolicy::Drop,
501        }
502    }
503
504    pub fn with_allowed_lateness(mut self, ms: i64) -> Self {
505        self.allowed_lateness_ms = ms.max(0);
506        self
507    }
508}
509
510fn update_state(state: &mut AggState, agg: &AggKind, value: &serde_json::Value) {
511    match agg {
512        AggKind::Count => {
513            *state = match std::mem::take(state) {
514                AggState::Empty => AggState::Count(1),
515                AggState::Count(c) => AggState::Count(c + 1),
516                other => other,
517            };
518        }
519        AggKind::Sum { field } => {
520            let x = as_f64(value.get(field).unwrap_or(&serde_json::Value::Null));
521            *state = match std::mem::take(state) {
522                AggState::Empty => AggState::Sum { sum: x, count: 1 },
523                AggState::Sum { sum, count } => AggState::Sum {
524                    sum: sum + x,
525                    count: count + 1,
526                },
527                other => other,
528            };
529        }
530        AggKind::Avg { field } => {
531            let x = as_f64(value.get(field).unwrap_or(&serde_json::Value::Null));
532            *state = match std::mem::take(state) {
533                AggState::Empty => AggState::Sum { sum: x, count: 1 },
534                AggState::Sum { sum, count } => AggState::Sum {
535                    sum: sum + x,
536                    count: count + 1,
537                },
538                other => other,
539            };
540        }
541        AggKind::Distinct { field } => {
542            let s = stringify(value.get(field).unwrap_or(&serde_json::Value::Null));
543            *state = match std::mem::take(state) {
544                AggState::Empty => {
545                    let mut set = std::collections::HashSet::new();
546                    set.insert(s);
547                    AggState::Distinct(set)
548                }
549                AggState::Distinct(mut set) => {
550                    set.insert(s);
551                    AggState::Distinct(set)
552                }
553                other => other,
554            };
555        }
556    }
557}
558
559fn finalize_value(state: &AggState, agg: &AggKind) -> serde_json::Value {
560    match (state, agg) {
561        (AggState::Count(c), _) => serde_json::json!(*c),
562        (AggState::Sum { sum, .. }, AggKind::Sum { .. }) => serde_json::json!(sum),
563        (AggState::Sum { sum, count }, AggKind::Avg { .. }) => {
564            let avg = if *count > 0 { *sum / (*count as f64) } else { 0.0 };
565            serde_json::json!(avg)
566        }
567        (AggState::Distinct(set), AggKind::Distinct { .. }) => serde_json::json!(set.len() as i64),
568        _ => serde_json::json!(null),
569    }
570}
571
572#[async_trait]
573impl Operator for WindowedAggregate {
574    async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
575        let ts_ms = rec.event_time.timestamp_millis() as i128; // ms
576                                                               // Late data handling: if we have a watermark and this event is older than (wm - allowed_lateness), drop
577        if let Some(wm) = self.last_wm_ms {
578            if ts_ms < (wm - (self.allowed_lateness_ms as i128)) {
579                match self.late_policy {
580                    LateDataPolicy::Drop => return Ok(()),
581                }
582            }
583        }
584        let key = rec
585            .value
586            .get(&self.key_field)
587            .cloned()
588            .unwrap_or(serde_json::Value::Null);
589
590        match self.win {
591            WindowKind::Tumbling { size_ms } => {
592                let start = (ts_ms / (size_ms as i128)) * (size_ms as i128);
593                let end = start + (size_ms as i128);
594                let entry = self
595                    .by_window
596                    .entry((end, key.clone()))
597                    .or_insert((start, AggState::Empty));
598                update_state(&mut entry.1, &self.agg, &rec.value);
599                // Optional: schedule a timer at end
600                let _ = ctx
601                    .timers()
602                    .register_event_time_timer(
603                        pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()),
604                        None,
605                    )
606                    .await;
607            }
608            WindowKind::Sliding { size_ms, slide_ms } => {
609                let k = (size_ms / slide_ms) as i128;
610                let anchor = (ts_ms / (slide_ms as i128)) * (slide_ms as i128);
611                for j in 0..k {
612                    let start = anchor - (j * (slide_ms as i128));
613                    let end = start + (size_ms as i128);
614                    if start <= ts_ms && end > ts_ms {
615                        let entry = self
616                            .by_window
617                            .entry((end, key.clone()))
618                            .or_insert((start, AggState::Empty));
619                        update_state(&mut entry.1, &self.agg, &rec.value);
620                        let _ = ctx
621                            .timers()
622                            .register_event_time_timer(
623                                pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()),
624                                None,
625                            )
626                            .await;
627                    }
628                }
629            }
630            WindowKind::Session { gap_ms } => {
631                let e = self
632                    .sessions
633                    .entry(key.clone())
634                    .or_insert((ts_ms, ts_ms, AggState::Empty));
635                let (start, last_seen, state) = e;
636                if ts_ms - *last_seen <= (gap_ms as i128) {
637                    *last_seen = ts_ms;
638                    update_state(state, &self.agg, &rec.value);
639                } else {
640                    // close previous session
641                    let mut out = serde_json::Map::new();
642                    out.insert("window_start_ms".into(), serde_json::json!(*start));
643                    out.insert(
644                        "window_end_ms".into(),
645                        serde_json::json!(*last_seen + (gap_ms as i128)),
646                    );
647                    out.insert("key".into(), key.clone());
648                    let val = finalize_value(state, &self.agg);
649                    match self.agg {
650                        AggKind::Count => {
651                            out.insert("count".into(), val);
652                        }
653                        AggKind::Sum { .. } => {
654                            out.insert("sum".into(), val);
655                        }
656                        AggKind::Avg { .. } => {
657                            out.insert("avg".into(), val);
658                        }
659                        AggKind::Distinct { .. } => {
660                            out.insert("distinct_count".into(), val);
661                        }
662                    }
663                    ctx.collect(Record {
664                        event_time: rec.event_time,
665                        value: serde_json::Value::Object(out),
666                    });
667                    // start new
668                    *start = ts_ms;
669                    *last_seen = ts_ms;
670                    *state = AggState::Empty;
671                    update_state(state, &self.agg, &rec.value);
672                }
673                // schedule close timer
674                let end = ts_ms + (gap_ms as i128);
675                let _ = ctx
676                    .timers()
677                    .register_event_time_timer(
678                        pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()),
679                        None,
680                    )
681                    .await;
682            }
683        }
684        Ok(())
685    }
686
687    async fn on_watermark(&mut self, ctx: &mut dyn Context, wm: Watermark) -> Result<()> {
688        let wm_ms_raw = wm.0 .0.timestamp_millis() as i128;
689        let wm_ms = wm_ms_raw - (self.allowed_lateness_ms as i128);
690        self.last_wm_ms = Some(wm_ms_raw);
691
692        match self.win {
693            WindowKind::Tumbling { .. } | WindowKind::Sliding { .. } => {
694                // Emit and clear all windows with end <= wm
695                let mut to_emit: Vec<((i128, serde_json::Value), (i128, AggState))> = self
696                    .by_window
697                    .iter()
698                    .filter(|((end, _), _)| *end <= wm_ms)
699                    .map(|(k, v)| (k.clone(), v.clone()))
700                    .collect();
701                for ((end, key), (start, state)) in to_emit.drain(..) {
702                    let mut out = serde_json::Map::new();
703                    out.insert("window_start_ms".into(), serde_json::json!(start));
704                    out.insert("window_end_ms".into(), serde_json::json!(end));
705                    out.insert("key".into(), key.clone());
706                    let val = finalize_value(&state, &self.agg);
707                    match self.agg {
708                        AggKind::Count => {
709                            out.insert("count".into(), val);
710                        }
711                        AggKind::Sum { .. } => {
712                            out.insert("sum".into(), val);
713                        }
714                        AggKind::Avg { .. } => {
715                            out.insert("avg".into(), val);
716                        }
717                        AggKind::Distinct { .. } => {
718                            out.insert("distinct_count".into(), val);
719                        }
720                    }
721                    ctx.collect(Record {
722                        event_time: wm.0 .0,
723                        value: serde_json::Value::Object(out),
724                    });
725                    self.by_window.remove(&(end, key));
726                }
727            }
728            WindowKind::Session { gap_ms } => {
729                // Close sessions whose inactivity + gap <= wm
730                let keys: Vec<_> = self.sessions.keys().cloned().collect();
731                for key in keys {
732                    if let Some((start, last_seen, state)) = self.sessions.get(&key).cloned() {
733                        if last_seen + (gap_ms as i128) <= wm_ms {
734                            let mut out = serde_json::Map::new();
735                            out.insert("window_start_ms".into(), serde_json::json!(start));
736                            out.insert(
737                                "window_end_ms".into(),
738                                serde_json::json!(last_seen + (gap_ms as i128)),
739                            );
740                            out.insert("key".into(), key.clone());
741                            let val = finalize_value(&state, &self.agg);
742                            match self.agg {
743                                AggKind::Count => {
744                                    out.insert("count".into(), val);
745                                }
746                                AggKind::Sum { .. } => {
747                                    out.insert("sum".into(), val);
748                                }
749                                AggKind::Avg { .. } => {
750                                    out.insert("avg".into(), val);
751                                }
752                                AggKind::Distinct { .. } => {
753                                    out.insert("distinct_count".into(), val);
754                                }
755                            }
756                            ctx.collect(Record {
757                                event_time: wm.0 .0,
758                                value: serde_json::Value::Object(out),
759                            });
760                            self.sessions.remove(&key);
761                        }
762                    }
763                }
764            }
765        }
766        Ok(())
767    }
768
769    async fn on_timer(
770        &mut self,
771        ctx: &mut dyn Context,
772        when: EventTime,
773        _key: Option<Vec<u8>>,
774    ) -> Result<()> {
775        // Treat timers same as watermarks for emission, but apply allowed lateness shift.
776        let when_ms = when.0.timestamp_millis() as i128 - (self.allowed_lateness_ms as i128);
777
778        match self.win {
779            WindowKind::Tumbling { .. } | WindowKind::Sliding { .. } => {
780                let mut to_emit: Vec<((i128, serde_json::Value), (i128, AggState))> = self
781                    .by_window
782                    .iter()
783                    .filter(|((end, _), _)| *end <= when_ms)
784                    .map(|(k, v)| (k.clone(), v.clone()))
785                    .collect();
786                for ((end, key), (start, state)) in to_emit.drain(..) {
787                    let mut out = serde_json::Map::new();
788                    out.insert("window_start_ms".into(), serde_json::json!(start));
789                    out.insert("window_end_ms".into(), serde_json::json!(end));
790                    out.insert("key".into(), key.clone());
791                    let val = finalize_value(&state, &self.agg);
792                    match self.agg {
793                        AggKind::Count => {
794                            out.insert("count".into(), val);
795                        }
796                        AggKind::Sum { .. } => {
797                            out.insert("sum".into(), val);
798                        }
799                        AggKind::Avg { .. } => {
800                            out.insert("avg".into(), val);
801                        }
802                        AggKind::Distinct { .. } => {
803                            out.insert("distinct_count".into(), val);
804                        }
805                    }
806                    ctx.collect(Record {
807                        event_time: when.0,
808                        value: serde_json::Value::Object(out),
809                    });
810                    self.by_window.remove(&(end, key));
811                }
812            }
813            WindowKind::Session { gap_ms } => {
814                let keys: Vec<_> = self.sessions.keys().cloned().collect();
815                for key in keys {
816                    if let Some((start, last_seen, state)) = self.sessions.get(&key).cloned() {
817                        if last_seen + (gap_ms as i128) <= when_ms {
818                            let mut out = serde_json::Map::new();
819                            out.insert("window_start_ms".into(), serde_json::json!(start));
820                            out.insert(
821                                "window_end_ms".into(),
822                                serde_json::json!(last_seen + (gap_ms as i128)),
823                            );
824                            out.insert("key".into(), key.clone());
825                            let val = finalize_value(&state, &self.agg);
826                            match self.agg {
827                                AggKind::Count => {
828                                    out.insert("count".into(), val);
829                                }
830                                AggKind::Sum { .. } => {
831                                    out.insert("sum".into(), val);
832                                }
833                                AggKind::Avg { .. } => {
834                                    out.insert("avg".into(), val);
835                                }
836                                AggKind::Distinct { .. } => {
837                                    out.insert("distinct_count".into(), val);
838                                }
839                            }
840                            ctx.collect(Record {
841                                event_time: when.0,
842                                value: serde_json::Value::Object(out),
843                            });
844                            self.sessions.remove(&key);
845                        }
846                    }
847                }
848            }
849        }
850        Ok(())
851    }
852}
853
#[cfg(test)]
mod window_tests {
    // Event-time tests for `WindowedAggregate`, driven by explicit timestamps so the
    // window a record falls into never depends on the wall clock.
    use super::*;
    use pulse_core::{Context, EventTime, KvState, Record, Result, Timers, Watermark};
    use std::sync::Arc;

    /// State backend stub: every read misses and every write is discarded.
    struct TestState;

    #[async_trait]
    impl KvState for TestState {
        async fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }

        async fn put(&self, _key: &[u8], _value: Vec<u8>) -> Result<()> {
            Ok(())
        }

        async fn delete(&self, _key: &[u8]) -> Result<()> {
            Ok(())
        }

        async fn iter_prefix(&self, _prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
            Ok(Vec::new())
        }

        async fn snapshot(&self) -> Result<pulse_core::SnapshotId> {
            Ok("test-snap".to_string())
        }

        async fn restore(&self, _snapshot: pulse_core::SnapshotId) -> Result<()> {
            Ok(())
        }
    }

    /// Timer service stub: accepts every registration and never fires.
    struct TestTimers;

    #[async_trait]
    impl Timers for TestTimers {
        async fn register_event_time_timer(&self, _when: EventTime, _key: Option<Vec<u8>>) -> Result<()> {
            Ok(())
        }
    }

    /// Operator context that captures collected records in `out` and ignores watermarks.
    struct TestCtx {
        out: Vec<Record>,
        kv: Arc<dyn KvState>,
        timers: Arc<dyn Timers>,
    }

    #[async_trait]
    impl Context for TestCtx {
        fn collect(&mut self, record: Record) {
            self.out.push(record);
        }

        fn watermark(&mut self, _wm: Watermark) {}

        fn kv(&self) -> Arc<dyn KvState> {
            self.kv.clone()
        }

        fn timers(&self) -> Arc<dyn Timers> {
            self.timers.clone()
        }
    }

    /// Returns an empty capturing context backed by the no-op state and timer stubs.
    fn make_ctx() -> TestCtx {
        let kv: Arc<dyn KvState> = Arc::new(TestState);
        let timers: Arc<dyn Timers> = Arc::new(TestTimers);
        TestCtx {
            out: Vec::new(),
            kv,
            timers,
        }
    }

    /// Builds a `{"word": word}` record stamped `millis` ms after the Unix epoch.
    fn word_at(millis: i64, word: &str) -> Record {
        let event_time = Utc.timestamp_millis_opt(millis).unwrap();
        Record {
            event_time,
            value: serde_json::json!({ "word": word }),
        }
    }

    #[tokio::test]
    async fn tumbling_count_emits_on_watermark() {
        let mut op = WindowedAggregate::tumbling_count("word", 60_000);
        let mut ctx = make_ctx();

        for millis in [1_000, 1_010] {
            op.on_element(&mut ctx, word_at(millis, "a")).await.unwrap();
        }

        // A watermark at 60s closes the window [0, 60000) holding both records.
        let closing = Watermark(EventTime(Utc.timestamp_millis_opt(60_000).unwrap()));
        op.on_watermark(&mut ctx, closing).await.unwrap();

        let emitted = &ctx.out;
        assert_eq!(emitted.len(), 1);
        assert_eq!(emitted[0].value["count"], serde_json::json!(2));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pulse_core::{Context, EventTime, KvState, Record, Result, Timers, Watermark};
    use std::sync::Arc;

    // Every record and watermark in this module uses a fixed event time. Stamping records
    // with `Utc::now()` and later deriving the closing watermark from a second `Utc::now()`
    // made the minute-window tests flaky: if the wall clock crossed a minute boundary in
    // between, the records landed in a different window than the one being closed.

    /// Event time (ms since the Unix epoch) of every test record; it lies inside the first
    /// one-minute tumbling window `[0, 60_000)`.
    const EVENT_MS: i64 = 1_000;
    /// Exclusive end of the one-minute tumbling window that contains `EVENT_MS`.
    const WINDOW_END_MS: i64 = 60_000;

    /// State backend stub: every read misses and every write is discarded.
    struct TestState;
    #[async_trait]
    impl KvState for TestState {
        async fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }
        async fn put(&self, _key: &[u8], _value: Vec<u8>) -> Result<()> {
            Ok(())
        }
        async fn delete(&self, _key: &[u8]) -> Result<()> {
            Ok(())
        }
        async fn iter_prefix(&self, _prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
            Ok(Vec::new())
        }
        async fn snapshot(&self) -> Result<pulse_core::SnapshotId> {
            Ok("test-snap".to_string())
        }
        async fn restore(&self, _snapshot: pulse_core::SnapshotId) -> Result<()> {
            Ok(())
        }
    }

    /// Timer service stub: accepts every registration and never fires.
    struct TestTimers;
    #[async_trait]
    impl Timers for TestTimers {
        async fn register_event_time_timer(&self, _when: EventTime, _key: Option<Vec<u8>>) -> Result<()> {
            Ok(())
        }
    }

    /// Operator context that captures collected records in `out` and ignores watermarks.
    struct TestCtx {
        out: Vec<Record>,
        kv: Arc<dyn KvState>,
        timers: Arc<dyn Timers>,
    }

    #[async_trait]
    impl Context for TestCtx {
        fn collect(&mut self, record: Record) {
            self.out.push(record);
        }
        fn watermark(&mut self, _wm: Watermark) {}
        fn kv(&self) -> Arc<dyn KvState> {
            self.kv.clone()
        }
        fn timers(&self) -> Arc<dyn Timers> {
            self.timers.clone()
        }
    }

    /// Returns an empty capturing context backed by the no-op state and timer stubs.
    fn new_ctx() -> TestCtx {
        TestCtx {
            out: vec![],
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        }
    }

    /// Builds a record carrying `v`, stamped at the fixed event time `EVENT_MS`.
    fn rec(v: serde_json::Value) -> Record {
        Record {
            event_time: Utc.timestamp_millis_opt(EVENT_MS).unwrap(),
            value: v,
        }
    }

    /// Builds a watermark at `ts_ms` milliseconds since the Unix epoch.
    fn wm_at(ts_ms: i64) -> Watermark {
        Watermark(EventTime(Utc.timestamp_millis_opt(ts_ms).unwrap()))
    }

    #[tokio::test]
    async fn test_map() {
        let mut op = Map::new(MapFn::new(|v| vec![v]));
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"a":1})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
    }

    #[tokio::test]
    async fn test_filter() {
        let mut op = Filter::new(FilterFn::new(|v: &serde_json::Value| {
            v.get("ok").and_then(|x| x.as_bool()).unwrap_or(false)
        }));
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"ok":false})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"ok":true})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
    }

    #[tokio::test]
    async fn test_keyby() {
        let mut op = KeyBy::new("word");
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"word":"hi"})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["key"], serde_json::json!("hi"));
    }

    #[tokio::test]
    async fn test_aggregate_count() {
        let mut op = Aggregate::count_per_window("key", "word");
        let mut ctx = new_ctx();
        // Both records share the same fixed event time, so they fall in the same minute.
        op.on_element(&mut ctx, rec(serde_json::json!({"key":"hello"})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"key":"hello"})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 2);
        assert_eq!(ctx.out[1].value["count"], serde_json::json!(2));
    }

    #[tokio::test]
    async fn windowed_allowed_lateness_defers_emission() {
        let mut op = WindowedAggregate::tumbling_count("word", 60_000).with_allowed_lateness(30_000);
        let mut ctx = new_ctx();
        // Two events in the first minute window [0, 60000).
        op.on_element(&mut ctx, rec(serde_json::json!({"word":"a"})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"word":"a"})))
            .await
            .unwrap();
        // A watermark at the window end must NOT emit, because of the 30s allowed lateness.
        op.on_watermark(&mut ctx, wm_at(WINDOW_END_MS))
            .await
            .unwrap();
        assert!(ctx.out.is_empty());
        // Once the lateness period has passed, the window is emitted with both events.
        op.on_watermark(&mut ctx, wm_at(WINDOW_END_MS + 30_000))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["count"], serde_json::json!(2));
    }

    #[tokio::test]
    async fn windowed_agg_avg_and_distinct() {
        let mut avg_op = WindowedAggregate::tumbling_avg("key", 60_000, "x");
        let mut distinct_op = WindowedAggregate::tumbling_distinct("key", 60_000, "s");
        let mut ctx = new_ctx();
        // Feed two records into the same window.
        avg_op
            .on_element(&mut ctx, rec(serde_json::json!({"key":"k","x": 1})))
            .await
            .unwrap();
        avg_op
            .on_element(&mut ctx, rec(serde_json::json!({"key":"k","x": 3})))
            .await
            .unwrap();
        // The watermark at the window end closes it.
        avg_op.on_watermark(&mut ctx, wm_at(WINDOW_END_MS)).await.unwrap();
        // Expect exactly one output, with avg = (1 + 3) / 2 = 2.0.
        assert_eq!(ctx.out.len(), 1);
        let avg = ctx.out[0].value["avg"]
            .as_f64()
            .expect("avg should be a JSON number");
        assert!((avg - 2.0).abs() < 1e-9, "expected avg 2.0, got {avg}");
        // Reset the captured output before exercising the distinct aggregate.
        ctx.out.clear();
        distinct_op
            .on_element(&mut ctx, rec(serde_json::json!({"key":"k","s":"a"})))
            .await
            .unwrap();
        distinct_op
            .on_element(&mut ctx, rec(serde_json::json!({"key":"k","s":"a"})))
            .await
            .unwrap();
        distinct_op
            .on_element(&mut ctx, rec(serde_json::json!({"key":"k","s":"b"})))
            .await
            .unwrap();
        distinct_op
            .on_watermark(&mut ctx, wm_at(WINDOW_END_MS))
            .await
            .unwrap();
        // "a" appears twice and "b" once, so there are two distinct values.
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["distinct_count"].as_i64(), Some(2));
    }
}