1use std::collections::HashMap;
11
12use async_trait::async_trait;
13use tracing::{instrument, info_span};
14use pulse_core::{Context, EventTime, Operator, Record, Result, Watermark};
15use chrono::{TimeZone, Utc};
16pub mod time;
17pub mod window;
18pub use time::{WatermarkClock, WatermarkPolicy};
19pub use window::{Window, WindowAssigner, WindowOperator};
20
21#[async_trait]
22pub trait FnMap: Send + Sync {
23 async fn call(&self, value: serde_json::Value) -> Result<Vec<serde_json::Value>>;
24}
25
/// Newtype wrapper around a synchronous closure so it can be used as an [`FnMap`].
pub struct MapFn<F>(pub F);

impl<F> MapFn<F> {
    /// Wraps `f` without calling it.
    pub fn new(f: F) -> Self {
        MapFn(f)
    }
}
32#[async_trait]
33impl<F> FnMap for MapFn<F>
34where
35 F: Fn(serde_json::Value) -> Vec<serde_json::Value> + Send + Sync,
36{
37 async fn call(&self, value: serde_json::Value) -> Result<Vec<serde_json::Value>> {
38 Ok((self.0)(value))
39 }
40}
41
/// Operator that runs a user function over each record's value and emits the results.
pub struct Map<F> {
    /// User transformation, awaited once per element.
    func: F,
}

impl<F> Map<F> {
    /// Builds a `Map` operator around `func`.
    pub fn new(func: F) -> Self {
        Map { func }
    }
}
59
60#[async_trait]
61impl<F> Operator for Map<F>
62where
63 F: FnMap + Send + Sync + 'static,
64{
65 #[instrument(name = "map_on_element", skip_all)]
66 async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
67 let outs = self.func.call(rec.value).await?;
68 pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Map", "receive"]).inc();
69 for v in outs {
70 ctx.collect(Record {
71 event_time: rec.event_time,
72 value: v.clone(),
73 });
74 pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Map", "emit"]).inc();
75 }
76 Ok(())
77 }
78}
79
80#[async_trait]
81pub trait FnFilter: Send + Sync {
82 async fn call(&self, value: &serde_json::Value) -> Result<bool>;
83}
84
/// Newtype wrapper around a synchronous predicate so it can be used as an [`FnFilter`].
pub struct FilterFn<F>(pub F);

impl<F> FilterFn<F> {
    /// Wraps `f` without calling it.
    pub fn new(f: F) -> Self {
        FilterFn(f)
    }
}
91#[async_trait]
92impl<F> FnFilter for FilterFn<F>
93where
94 F: Fn(&serde_json::Value) -> bool + Send + Sync,
95{
96 async fn call(&self, value: &serde_json::Value) -> Result<bool> {
97 Ok((self.0)(value))
98 }
99}
100
/// Operator that forwards only the records accepted by a predicate.
pub struct Filter<F> {
    /// Predicate awaited once per element.
    pred: F,
}

impl<F> Filter<F> {
    /// Builds a `Filter` operator around `pred`.
    pub fn new(pred: F) -> Self {
        Filter { pred }
    }
}
117
118#[async_trait]
119impl<F> Operator for Filter<F>
120where
121 F: FnFilter + Send + Sync + 'static,
122{
123 #[instrument(name = "filter_on_element", skip_all)]
124 async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
125 pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Filter", "receive"]).inc();
126 if self.pred.call(&rec.value).await? {
127 ctx.collect(rec);
128 pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["Filter", "emit"]).inc();
129 }
130 Ok(())
131 }
132}
133
/// Operator that copies one field of each record into a top-level `"key"` entry.
pub struct KeyBy {
    /// Name of the JSON field whose value becomes the key.
    field: String,
}

impl KeyBy {
    /// Builds a `KeyBy` operator that keys records by `field`.
    pub fn new(field: impl Into<String>) -> Self {
        let field = field.into();
        KeyBy { field }
    }
}
150
151#[async_trait]
152impl Operator for KeyBy {
153 #[instrument(name = "keyby_on_element", skip_all)]
154 async fn on_element(&mut self, ctx: &mut dyn Context, mut rec: Record) -> Result<()> {
155 pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["KeyBy", "receive"]).inc();
156 let key = rec
157 .value
158 .get(&self.field)
159 .cloned()
160 .unwrap_or(serde_json::Value::Null);
161 let mut obj = match rec.value {
162 serde_json::Value::Object(o) => o,
163 _ => serde_json::Map::new(),
164 };
165 obj.insert("key".to_string(), key);
166 rec.value = serde_json::Value::Object(obj);
167 ctx.collect(rec);
168 pulse_core::metrics::OP_THROUGHPUT.with_label_values(&["KeyBy", "emit"]).inc();
169 Ok(())
170 }
171}
172
/// Tumbling-window configuration expressed as a window length in milliseconds.
#[derive(Clone, Copy)]
pub struct WindowTumbling {
    /// Window length in milliseconds.
    pub size_ms: i64,
}

impl WindowTumbling {
    /// Creates a window that is `m` minutes long, i.e. `m * 60_000` milliseconds.
    pub fn minutes(m: i64) -> Self {
        const MILLIS_PER_MINUTE: i64 = 60_000;
        let size_ms = m * MILLIS_PER_MINUTE;
        WindowTumbling { size_ms }
    }
}
183
184pub struct Aggregate {
194 pub key_field: String,
195 pub value_field: String,
196 pub op: AggregationKind,
197 windows: HashMap<(i128, serde_json::Value), i64>, }
199
/// Aggregations supported by [`Aggregate`].
#[derive(Clone, Copy)]
pub enum AggregationKind {
    /// Count the elements.
    Count,
}
205
206impl Aggregate {
207 pub fn count_per_window(key_field: impl Into<String>, value_field: impl Into<String>) -> Self {
208 Self {
209 key_field: key_field.into(),
210 value_field: value_field.into(),
211 op: AggregationKind::Count,
212 windows: HashMap::new(),
213 }
214 }
215}
216
217#[async_trait]
218impl Operator for Aggregate {
219 async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
220 let minute_ms = 60_000_i128;
221 let ts_ms = rec.event_time.timestamp_millis() as i128; let win_start_ms = (ts_ms / minute_ms) * minute_ms;
223 let key = rec
224 .value
225 .get(&self.key_field)
226 .cloned()
227 .unwrap_or(serde_json::Value::Null);
228 let entry = self.windows.entry((win_start_ms, key.clone())).or_insert(0);
229 *entry += 1;
230 let mut out = serde_json::Map::new();
232 out.insert("window_start_ms".into(), serde_json::json!(win_start_ms));
233 out.insert("key".into(), key);
234 out.insert("count".into(), serde_json::json!(*entry));
235 ctx.collect(Record {
236 event_time: rec.event_time,
237 value: serde_json::Value::Object(out),
238 });
239 Ok(())
240 }
241 async fn on_watermark(&mut self, _ctx: &mut dyn Context, _wm: Watermark) -> Result<()> {
242 Ok(())
243 }
244}
245
246pub mod prelude {
247 pub use super::{
248 AggKind, Aggregate, AggregationKind, Filter, FnFilter, FnMap, KeyBy, Map, WindowKind, WindowTumbling,
249 WindowedAggregate,
250 };
251}
252
/// Windowing strategy used by [`WindowedAggregate`]; all durations are in milliseconds.
#[derive(Clone, Debug)]
pub enum WindowKind {
    /// Back-to-back windows of `size_ms` each.
    Tumbling { size_ms: i64 },
    /// Windows of `size_ms` that start every `slide_ms`.
    Sliding { size_ms: i64, slide_ms: i64 },
    /// Per-key sessions that end after `gap_ms` without a new element.
    Session { gap_ms: i64 },
}
262
/// Aggregation computed by [`WindowedAggregate`] for each window.
#[derive(Clone, Debug)]
pub enum AggKind {
    /// Number of elements in the window.
    Count,
    /// Sum of the numeric value read from `field`.
    Sum { field: String },
    /// Average of the numeric value read from `field`.
    Avg { field: String },
    /// Number of distinct values read from `field`.
    Distinct { field: String },
}
271
/// Running accumulator kept for one window or session.
#[derive(Clone, Debug, Default)]
enum AggState {
    /// Nothing has been folded in yet.
    #[default]
    Empty,
    /// Number of elements folded in so far.
    Count(i64),
    /// Running total together with the number of elements it covers.
    Sum { sum: f64, count: i64 },
    /// Stringified values collected so far.
    Distinct(std::collections::HashSet<String>),
}
283
284fn as_f64(v: &serde_json::Value) -> f64 {
285 match v {
286 serde_json::Value::Number(n) => n.as_f64().unwrap_or(0.0),
287 serde_json::Value::String(s) => s.parse::<f64>().unwrap_or(0.0),
288 _ => 0.0,
289 }
290}
291
292fn stringify(v: &serde_json::Value) -> String {
293 match v {
294 serde_json::Value::String(s) => s.clone(),
295 other => other.to_string(),
296 }
297}
298
299pub struct WindowedAggregate {
310 pub key_field: String,
311 pub win: WindowKind,
312 pub agg: AggKind,
313 by_window: HashMap<(i128, serde_json::Value), (i128 , AggState)>,
315 sessions: HashMap<serde_json::Value, (i128, i128, AggState)>,
317}
318
319impl WindowedAggregate {
320 pub fn tumbling_count(key_field: impl Into<String>, size_ms: i64) -> Self {
321 Self {
322 key_field: key_field.into(),
323 win: WindowKind::Tumbling { size_ms },
324 agg: AggKind::Count,
325 by_window: HashMap::new(),
326 sessions: HashMap::new(),
327 }
328 }
329 pub fn tumbling_sum(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
330 Self {
331 key_field: key_field.into(),
332 win: WindowKind::Tumbling { size_ms },
333 agg: AggKind::Sum { field: field.into() },
334 by_window: HashMap::new(),
335 sessions: HashMap::new(),
336 }
337 }
338 pub fn tumbling_avg(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
339 Self {
340 key_field: key_field.into(),
341 win: WindowKind::Tumbling { size_ms },
342 agg: AggKind::Avg { field: field.into() },
343 by_window: HashMap::new(),
344 sessions: HashMap::new(),
345 }
346 }
347 pub fn tumbling_distinct(key_field: impl Into<String>, size_ms: i64, field: impl Into<String>) -> Self {
348 Self {
349 key_field: key_field.into(),
350 win: WindowKind::Tumbling { size_ms },
351 agg: AggKind::Distinct { field: field.into() },
352 by_window: HashMap::new(),
353 sessions: HashMap::new(),
354 }
355 }
356
357 pub fn sliding_count(key_field: impl Into<String>, size_ms: i64, slide_ms: i64) -> Self {
358 Self {
359 key_field: key_field.into(),
360 win: WindowKind::Sliding { size_ms, slide_ms },
361 agg: AggKind::Count,
362 by_window: HashMap::new(),
363 sessions: HashMap::new(),
364 }
365 }
366 pub fn sliding_sum(
367 key_field: impl Into<String>,
368 size_ms: i64,
369 slide_ms: i64,
370 field: impl Into<String>,
371 ) -> Self {
372 Self {
373 key_field: key_field.into(),
374 win: WindowKind::Sliding { size_ms, slide_ms },
375 agg: AggKind::Sum { field: field.into() },
376 by_window: HashMap::new(),
377 sessions: HashMap::new(),
378 }
379 }
380 pub fn sliding_avg(
381 key_field: impl Into<String>,
382 size_ms: i64,
383 slide_ms: i64,
384 field: impl Into<String>,
385 ) -> Self {
386 Self {
387 key_field: key_field.into(),
388 win: WindowKind::Sliding { size_ms, slide_ms },
389 agg: AggKind::Avg { field: field.into() },
390 by_window: HashMap::new(),
391 sessions: HashMap::new(),
392 }
393 }
394 pub fn sliding_distinct(
395 key_field: impl Into<String>,
396 size_ms: i64,
397 slide_ms: i64,
398 field: impl Into<String>,
399 ) -> Self {
400 Self {
401 key_field: key_field.into(),
402 win: WindowKind::Sliding { size_ms, slide_ms },
403 agg: AggKind::Distinct { field: field.into() },
404 by_window: HashMap::new(),
405 sessions: HashMap::new(),
406 }
407 }
408
409 pub fn session_count(key_field: impl Into<String>, gap_ms: i64) -> Self {
410 Self {
411 key_field: key_field.into(),
412 win: WindowKind::Session { gap_ms },
413 agg: AggKind::Count,
414 by_window: HashMap::new(),
415 sessions: HashMap::new(),
416 }
417 }
418 pub fn session_sum(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
419 Self {
420 key_field: key_field.into(),
421 win: WindowKind::Session { gap_ms },
422 agg: AggKind::Sum { field: field.into() },
423 by_window: HashMap::new(),
424 sessions: HashMap::new(),
425 }
426 }
427 pub fn session_avg(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
428 Self {
429 key_field: key_field.into(),
430 win: WindowKind::Session { gap_ms },
431 agg: AggKind::Avg { field: field.into() },
432 by_window: HashMap::new(),
433 sessions: HashMap::new(),
434 }
435 }
436 pub fn session_distinct(key_field: impl Into<String>, gap_ms: i64, field: impl Into<String>) -> Self {
437 Self {
438 key_field: key_field.into(),
439 win: WindowKind::Session { gap_ms },
440 agg: AggKind::Distinct { field: field.into() },
441 by_window: HashMap::new(),
442 sessions: HashMap::new(),
443 }
444 }
445}
446
447fn update_state(state: &mut AggState, agg: &AggKind, value: &serde_json::Value) {
448 match agg {
449 AggKind::Count => {
450 *state = match std::mem::take(state) {
451 AggState::Empty => AggState::Count(1),
452 AggState::Count(c) => AggState::Count(c + 1),
453 other => other,
454 };
455 }
456 AggKind::Sum { field } => {
457 let x = as_f64(value.get(field).unwrap_or(&serde_json::Value::Null));
458 *state = match std::mem::take(state) {
459 AggState::Empty => AggState::Sum { sum: x, count: 1 },
460 AggState::Sum { sum, count } => AggState::Sum {
461 sum: sum + x,
462 count: count + 1,
463 },
464 other => other,
465 };
466 }
467 AggKind::Avg { field } => {
468 let x = as_f64(value.get(field).unwrap_or(&serde_json::Value::Null));
469 *state = match std::mem::take(state) {
470 AggState::Empty => AggState::Sum { sum: x, count: 1 },
471 AggState::Sum { sum, count } => AggState::Sum {
472 sum: sum + x,
473 count: count + 1,
474 },
475 other => other,
476 };
477 }
478 AggKind::Distinct { field } => {
479 let s = stringify(value.get(field).unwrap_or(&serde_json::Value::Null));
480 *state = match std::mem::take(state) {
481 AggState::Empty => {
482 let mut set = std::collections::HashSet::new();
483 set.insert(s);
484 AggState::Distinct(set)
485 }
486 AggState::Distinct(mut set) => {
487 set.insert(s);
488 AggState::Distinct(set)
489 }
490 other => other,
491 };
492 }
493 }
494}
495
496fn finalize_value(state: &AggState, agg: &AggKind) -> serde_json::Value {
497 match (state, agg) {
498 (AggState::Count(c), _) => serde_json::json!(*c),
499 (AggState::Sum { sum, .. }, AggKind::Sum { .. }) => serde_json::json!(sum),
500 (AggState::Sum { sum, count }, AggKind::Avg { .. }) => {
501 let avg = if *count > 0 { *sum / (*count as f64) } else { 0.0 };
502 serde_json::json!(avg)
503 }
504 (AggState::Distinct(set), AggKind::Distinct { .. }) => serde_json::json!(set.len() as i64),
505 _ => serde_json::json!(null),
506 }
507}
508
509#[async_trait]
510impl Operator for WindowedAggregate {
511 async fn on_element(&mut self, ctx: &mut dyn Context, rec: Record) -> Result<()> {
512 let ts_ms = rec.event_time.timestamp_millis() as i128; let key = rec
514 .value
515 .get(&self.key_field)
516 .cloned()
517 .unwrap_or(serde_json::Value::Null);
518
519 match self.win {
520 WindowKind::Tumbling { size_ms } => {
521 let start = (ts_ms / (size_ms as i128)) * (size_ms as i128);
522 let end = start + (size_ms as i128);
523 let entry = self
524 .by_window
525 .entry((end, key.clone()))
526 .or_insert((start, AggState::Empty));
527 update_state(&mut entry.1, &self.agg, &rec.value);
528 let _ = ctx
530 .timers()
531 .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
532 .await;
533 }
534 WindowKind::Sliding { size_ms, slide_ms } => {
535 let k = (size_ms / slide_ms) as i128;
536 let anchor = (ts_ms / (slide_ms as i128)) * (slide_ms as i128);
537 for j in 0..k {
538 let start = anchor - (j * (slide_ms as i128));
539 let end = start + (size_ms as i128);
540 if start <= ts_ms && end > ts_ms {
541 let entry = self
542 .by_window
543 .entry((end, key.clone()))
544 .or_insert((start, AggState::Empty));
545 update_state(&mut entry.1, &self.agg, &rec.value);
546 let _ = ctx
547 .timers()
548 .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
549 .await;
550 }
551 }
552 }
553 WindowKind::Session { gap_ms } => {
554 let e = self
555 .sessions
556 .entry(key.clone())
557 .or_insert((ts_ms, ts_ms, AggState::Empty));
558 let (start, last_seen, state) = e;
559 if ts_ms - *last_seen <= (gap_ms as i128) {
560 *last_seen = ts_ms;
561 update_state(state, &self.agg, &rec.value);
562 } else {
563 let mut out = serde_json::Map::new();
565 out.insert("window_start_ms".into(), serde_json::json!(*start));
566 out.insert(
567 "window_end_ms".into(),
568 serde_json::json!(*last_seen + (gap_ms as i128)),
569 );
570 out.insert("key".into(), key.clone());
571 let val = finalize_value(state, &self.agg);
572 match self.agg {
573 AggKind::Count => {
574 out.insert("count".into(), val);
575 }
576 AggKind::Sum { .. } => {
577 out.insert("sum".into(), val);
578 }
579 AggKind::Avg { .. } => {
580 out.insert("avg".into(), val);
581 }
582 AggKind::Distinct { .. } => {
583 out.insert("distinct_count".into(), val);
584 }
585 }
586 ctx.collect(Record {
587 event_time: rec.event_time,
588 value: serde_json::Value::Object(out),
589 });
590 *start = ts_ms;
592 *last_seen = ts_ms;
593 *state = AggState::Empty;
594 update_state(state, &self.agg, &rec.value);
595 }
596 let end = ts_ms + (gap_ms as i128);
598 let _ = ctx
599 .timers()
600 .register_event_time_timer(pulse_core::EventTime(Utc.timestamp_millis_opt(end as i64).unwrap()), None)
601 .await;
602 }
603 }
604 Ok(())
605 }
606
607 async fn on_watermark(&mut self, ctx: &mut dyn Context, wm: Watermark) -> Result<()> {
608 let wm_ms = wm.0 .0.timestamp_millis() as i128;
609
610 match self.win {
611 WindowKind::Tumbling { .. } | WindowKind::Sliding { .. } => {
612 let mut to_emit: Vec<((i128, serde_json::Value), (i128, AggState))> = self
614 .by_window
615 .iter()
616 .filter(|((end, _), _)| *end <= wm_ms)
617 .map(|(k, v)| (k.clone(), v.clone()))
618 .collect();
619 for ((end, key), (start, state)) in to_emit.drain(..) {
620 let mut out = serde_json::Map::new();
621 out.insert("window_start_ms".into(), serde_json::json!(start));
622 out.insert("window_end_ms".into(), serde_json::json!(end));
623 out.insert("key".into(), key.clone());
624 let val = finalize_value(&state, &self.agg);
625 match self.agg {
626 AggKind::Count => {
627 out.insert("count".into(), val);
628 }
629 AggKind::Sum { .. } => {
630 out.insert("sum".into(), val);
631 }
632 AggKind::Avg { .. } => {
633 out.insert("avg".into(), val);
634 }
635 AggKind::Distinct { .. } => {
636 out.insert("distinct_count".into(), val);
637 }
638 }
639 ctx.collect(Record { event_time: wm.0 .0, value: serde_json::Value::Object(out) });
640 self.by_window.remove(&(end, key));
641 }
642 }
643 WindowKind::Session { gap_ms } => {
644 let keys: Vec<_> = self.sessions.keys().cloned().collect();
646 for key in keys {
647 if let Some((start, last_seen, state)) = self.sessions.get(&key).cloned() {
648 if last_seen + (gap_ms as i128) <= wm_ms {
649 let mut out = serde_json::Map::new();
650 out.insert("window_start_ms".into(), serde_json::json!(start));
651 out.insert(
652 "window_end_ms".into(),
653 serde_json::json!(last_seen + (gap_ms as i128)),
654 );
655 out.insert("key".into(), key.clone());
656 let val = finalize_value(&state, &self.agg);
657 match self.agg {
658 AggKind::Count => {
659 out.insert("count".into(), val);
660 }
661 AggKind::Sum { .. } => {
662 out.insert("sum".into(), val);
663 }
664 AggKind::Avg { .. } => {
665 out.insert("avg".into(), val);
666 }
667 AggKind::Distinct { .. } => {
668 out.insert("distinct_count".into(), val);
669 }
670 }
671 ctx.collect(Record { event_time: wm.0 .0, value: serde_json::Value::Object(out) });
672 self.sessions.remove(&key);
673 }
674 }
675 }
676 }
677 }
678 Ok(())
679 }
680
681 async fn on_timer(
682 &mut self,
683 ctx: &mut dyn Context,
684 when: EventTime,
685 _key: Option<Vec<u8>>,
686 ) -> Result<()> {
687 let when_ms = when.0.timestamp_millis() as i128;
689
690 match self.win {
691 WindowKind::Tumbling { .. } | WindowKind::Sliding { .. } => {
692 let mut to_emit: Vec<((i128, serde_json::Value), (i128, AggState))> = self
693 .by_window
694 .iter()
695 .filter(|((end, _), _)| *end <= when_ms)
696 .map(|(k, v)| (k.clone(), v.clone()))
697 .collect();
698 for ((end, key), (start, state)) in to_emit.drain(..) {
699 let mut out = serde_json::Map::new();
700 out.insert("window_start_ms".into(), serde_json::json!(start));
701 out.insert("window_end_ms".into(), serde_json::json!(end));
702 out.insert("key".into(), key.clone());
703 let val = finalize_value(&state, &self.agg);
704 match self.agg {
705 AggKind::Count => {
706 out.insert("count".into(), val);
707 }
708 AggKind::Sum { .. } => {
709 out.insert("sum".into(), val);
710 }
711 AggKind::Avg { .. } => {
712 out.insert("avg".into(), val);
713 }
714 AggKind::Distinct { .. } => {
715 out.insert("distinct_count".into(), val);
716 }
717 }
718 ctx.collect(Record { event_time: when.0, value: serde_json::Value::Object(out) });
719 self.by_window.remove(&(end, key));
720 }
721 }
722 WindowKind::Session { gap_ms } => {
723 let keys: Vec<_> = self.sessions.keys().cloned().collect();
724 for key in keys {
725 if let Some((start, last_seen, state)) = self.sessions.get(&key).cloned() {
726 if last_seen + (gap_ms as i128) <= when_ms {
727 let mut out = serde_json::Map::new();
728 out.insert("window_start_ms".into(), serde_json::json!(start));
729 out.insert(
730 "window_end_ms".into(),
731 serde_json::json!(last_seen + (gap_ms as i128)),
732 );
733 out.insert("key".into(), key.clone());
734 let val = finalize_value(&state, &self.agg);
735 match self.agg {
736 AggKind::Count => {
737 out.insert("count".into(), val);
738 }
739 AggKind::Sum { .. } => {
740 out.insert("sum".into(), val);
741 }
742 AggKind::Avg { .. } => {
743 out.insert("avg".into(), val);
744 }
745 AggKind::Distinct { .. } => {
746 out.insert("distinct_count".into(), val);
747 }
748 }
749 ctx.collect(Record { event_time: when.0, value: serde_json::Value::Object(out) });
750 self.sessions.remove(&key);
751 }
752 }
753 }
754 }
755 }
756 Ok(())
757 }
758}
759
#[cfg(test)]
mod window_tests {
    use super::*;
    use pulse_core::{Context, EventTime, KvState, Record, Result, Timers, Watermark};
    use std::sync::Arc;

    /// Key-value state stub: reads find nothing and writes are accepted and discarded.
    struct TestState;
    #[async_trait]
    impl KvState for TestState {
        async fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }
        async fn put(&self, _key: &[u8], _value: Vec<u8>) -> Result<()> {
            Ok(())
        }
        async fn delete(&self, _key: &[u8]) -> Result<()> {
            Ok(())
        }
        async fn iter_prefix(&self, _prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
            Ok(Vec::new())
        }
        async fn snapshot(&self) -> Result<pulse_core::SnapshotId> {
            Ok("test-snap".to_string())
        }
        async fn restore(&self, _snapshot: pulse_core::SnapshotId) -> Result<()> {
            Ok(())
        }
    }

    /// Timer stub: registrations succeed but no timer ever fires.
    struct TestTimers;
    #[async_trait]
    impl Timers for TestTimers {
        async fn register_event_time_timer(&self, _when: EventTime, _key: Option<Vec<u8>>) -> Result<()> {
            Ok(())
        }
    }

    /// Context that records every collected record in `out`.
    struct TestCtx {
        out: Vec<Record>,
        kv: Arc<dyn KvState>,
        timers: Arc<dyn Timers>,
    }
    #[async_trait]
    impl Context for TestCtx {
        fn collect(&mut self, record: Record) {
            self.out.push(record);
        }
        fn watermark(&mut self, _wm: Watermark) {}
        fn kv(&self) -> Arc<dyn KvState> {
            Arc::clone(&self.kv)
        }
        fn timers(&self) -> Arc<dyn Timers> {
            Arc::clone(&self.timers)
        }
    }

    /// Fresh context with no collected output, backed by the stubs above.
    fn new_ctx() -> TestCtx {
        TestCtx {
            out: Vec::new(),
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        }
    }

    /// Record at `ts_ms` whose value is `{"word": key}`.
    fn record_with(ts_ms: i128, key: &str) -> Record {
        let event_time = Utc.timestamp_millis_opt(ts_ms as i64).unwrap();
        let value = serde_json::json!({"word": key});
        Record { event_time, value }
    }

    #[tokio::test]
    async fn tumbling_count_emits_on_watermark() {
        let mut op = WindowedAggregate::tumbling_count("word", 60_000);
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, record_with(1_000, "a")).await.unwrap();
        op.on_element(&mut ctx, record_with(1_010, "a")).await.unwrap();

        // A watermark exactly at the window end closes the first one-minute window.
        let window_end = Watermark(EventTime(Utc.timestamp_millis_opt(60_000).unwrap()));
        op.on_watermark(&mut ctx, window_end).await.unwrap();

        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["count"], serde_json::json!(2));
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pulse_core::{Context, EventTime, KvState, Record, Result, Timers};
    use std::sync::Arc;

    /// Key-value state stub: reads find nothing and writes are accepted and discarded.
    struct TestState;
    #[async_trait]
    impl KvState for TestState {
        async fn get(&self, _key: &[u8]) -> Result<Option<Vec<u8>>> {
            Ok(None)
        }
        async fn put(&self, _key: &[u8], _value: Vec<u8>) -> Result<()> {
            Ok(())
        }
        async fn delete(&self, _key: &[u8]) -> Result<()> {
            Ok(())
        }
        async fn iter_prefix(&self, _prefix: Option<&[u8]>) -> Result<Vec<(Vec<u8>, Vec<u8>)>> {
            Ok(Vec::new())
        }
        async fn snapshot(&self) -> Result<pulse_core::SnapshotId> {
            Ok("test-snap".to_string())
        }
        async fn restore(&self, _snapshot: pulse_core::SnapshotId) -> Result<()> {
            Ok(())
        }
    }

    /// Timer stub: registrations succeed but no timer ever fires.
    struct TestTimers;
    #[async_trait]
    impl Timers for TestTimers {
        async fn register_event_time_timer(&self, _when: EventTime, _key: Option<Vec<u8>>) -> Result<()> {
            Ok(())
        }
    }

    /// Context that records every collected record in `out`.
    struct TestCtx {
        out: Vec<Record>,
        kv: Arc<dyn KvState>,
        timers: Arc<dyn Timers>,
    }

    #[async_trait]
    impl Context for TestCtx {
        fn collect(&mut self, record: Record) {
            self.out.push(record);
        }
        fn watermark(&mut self, _wm: pulse_core::Watermark) {}
        fn kv(&self) -> Arc<dyn KvState> {
            self.kv.clone()
        }
        fn timers(&self) -> Arc<dyn Timers> {
            self.timers.clone()
        }
    }

    /// Fresh context with no collected output, backed by the stubs above.
    fn new_ctx() -> TestCtx {
        TestCtx {
            out: vec![],
            kv: Arc::new(TestState),
            timers: Arc::new(TestTimers),
        }
    }

    /// Record with value `v` at a fixed event time.
    ///
    /// The timestamp is deliberately constant: `Aggregate` buckets by wall-clock minute
    /// of the event time, so using the current time would make `test_aggregate_count`
    /// fail whenever its two records straddle a minute boundary.
    fn rec(v: serde_json::Value) -> Record {
        Record {
            event_time: Utc.timestamp_millis_opt(1_000).unwrap(),
            value: v,
        }
    }

    #[tokio::test]
    async fn test_map() {
        let mut op = Map::new(MapFn::new(|v| vec![v]));
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"a":1})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
    }

    #[tokio::test]
    async fn test_filter() {
        let mut op = Filter::new(FilterFn::new(|v: &serde_json::Value| {
            v.get("ok").and_then(|x| x.as_bool()).unwrap_or(false)
        }));
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"ok":false})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"ok":true})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
    }

    #[tokio::test]
    async fn test_keyby() {
        let mut op = KeyBy::new("word");
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"word":"hi"})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 1);
        assert_eq!(ctx.out[0].value["key"], serde_json::json!("hi"));
    }

    #[tokio::test]
    async fn test_aggregate_count() {
        let mut op = Aggregate::count_per_window("key", "word");
        let mut ctx = new_ctx();
        op.on_element(&mut ctx, rec(serde_json::json!({"key":"hello"})))
            .await
            .unwrap();
        op.on_element(&mut ctx, rec(serde_json::json!({"key":"hello"})))
            .await
            .unwrap();
        assert_eq!(ctx.out.len(), 2);
        assert_eq!(ctx.out[1].value["count"], serde_json::json!(2));
    }
}
969}