1use std::collections::HashMap;
2
3use chrono::{DateTime, Duration, Utc};
4use serde::{Deserialize, Serialize};
5use pulse_core::KvState;
6
7#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
8pub enum WindowAssigner {
9 Tumbling { size: Duration },
10 Sliding { size: Duration, slide: Duration },
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
14pub struct Window {
15 pub start: DateTime<Utc>,
16 pub end: DateTime<Utc>,
17}
18
19impl WindowAssigner {
20 fn assign(&self, ts: DateTime<Utc>) -> Vec<Window> {
21 match *self {
22 WindowAssigner::Tumbling { size } => {
23 let epoch = DateTime::<Utc>::from_timestamp(0, 0).unwrap();
24 let since = ts - epoch;
25 let buckets = since.num_milliseconds() / size.num_milliseconds();
26 let start = epoch + Duration::milliseconds(buckets * size.num_milliseconds());
27 let end = start + size;
28 vec![Window { start, end }]
29 }
30 WindowAssigner::Sliding { size, slide } => {
31 let epoch = DateTime::<Utc>::from_timestamp(0, 0).unwrap();
32 let since = ts - epoch;
33 let k = (size.num_milliseconds() / slide.num_milliseconds()) as i64;
34 let anchor_ms = (since.num_milliseconds() / slide.num_milliseconds()) * slide.num_milliseconds();
35 let mut out = Vec::new();
36 for j in 0..k {
37 let start = epoch
38 + Duration::milliseconds(anchor_ms - j * slide.num_milliseconds());
39 let end = start + size;
40 if start <= ts && ts < end {
41 out.push(Window { start, end });
42 }
43 }
44 out
45 }
46 }
47 }
48}
49
50pub struct WindowOperator<S> {
53 assigner: WindowAssigner,
54 state: HashMap<Window, S>,
55 reduce: Box<dyn Fn(&mut S, &serde_json::Value) + Send + Sync>,
56 init: Box<dyn Fn() -> S + Send + Sync>,
57 backend: Option<std::sync::Arc<dyn KvState>>, ns_prefix: Vec<u8>,
60}
61
62impl<S> WindowOperator<S>
63where
64 S: Clone + Default + Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
65{
66 pub fn new<Init, Red>(assigner: WindowAssigner, init: Init, reduce: Red) -> Self
67 where
68 Init: Fn() -> S + Send + Sync + 'static,
69 Red: Fn(&mut S, &serde_json::Value) + Send + Sync + 'static,
70 {
71 Self {
72 assigner,
73 state: HashMap::new(),
74 init: Box::new(init),
75 reduce: Box::new(reduce),
76 backend: None,
77 ns_prefix: b"window:".to_vec(),
78 }
79 }
80
81 pub fn with_backend(mut self, backend: std::sync::Arc<dyn KvState>, ns_prefix: impl AsRef<[u8]>) -> Self {
83 self.backend = Some(backend);
84 self.ns_prefix = ns_prefix.as_ref().to_vec();
85 self
86 }
87
88 fn make_key(&self, w: &Window) -> Vec<u8> {
89 let start = w.start.timestamp_millis();
91 let end = w.end.timestamp_millis();
92 let mut k = self.ns_prefix.clone();
93 k.extend_from_slice(start.to_string().as_bytes());
94 k.push(b'|');
95 k.extend_from_slice(end.to_string().as_bytes());
96 k
97 }
98
99 pub fn on_element(&mut self, ts: DateTime<Utc>, value: &serde_json::Value) {
100 for w in self.assigner.assign(ts) {
101 let entry = self.state.entry(w.clone()).or_insert_with(|| (self.init)());
102 (self.reduce)(entry, value);
103 if let Some(backend) = &self.backend {
105 if let Ok(bytes) = serde_json::to_vec(entry) {
106 let key = self.make_key(&w);
107 let _ = futures::executor::block_on(backend.put(&key, bytes));
109 }
110 }
111 }
112 }
113
114 pub fn on_watermark(&mut self, watermark: DateTime<Utc>) -> Vec<(Window, S)> {
115 let mut to_emit = Vec::new();
116 let keys: Vec<_> = self.state.keys().cloned().collect();
117 for w in keys {
118 if watermark >= w.end {
119 if let Some(s) = self.state.remove(&w) {
120 to_emit.push((w.clone(), s));
121 }
122 if let Some(backend) = &self.backend {
124 let key = self.make_key(&w);
125 let _ = futures::executor::block_on(backend.delete(&key));
126 }
127 }
128 }
129 to_emit
130 }
131
132 pub async fn restore_from_backend(&mut self) -> pulse_core::Result<()>
134 where
135 S: for<'de> Deserialize<'de>,
136 {
137 if let Some(backend) = &self.backend {
138 let entries = backend.iter_prefix(Some(&self.ns_prefix)).await?;
139 for (k, v) in entries {
140 if let Ok(state) = serde_json::from_slice::<S>(&v) {
141 let s = String::from_utf8_lossy(&k[self.ns_prefix.len()..]);
143 if let Some((a, b)) = s.split_once('|') {
144 if let (Ok(start_ms), Ok(end_ms)) = (a.parse::<i64>(), b.parse::<i64>()) {
145 let w = Window {
146 start: DateTime::<Utc>::from_timestamp_millis(start_ms).unwrap(),
147 end: DateTime::<Utc>::from_timestamp_millis(end_ms).unwrap(),
148 };
149 self.state.insert(w, state);
150 }
151 }
152 }
153 }
154 }
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod persist_tests {
    use super::*;
    use crate::window::WindowAssigner;
    use chrono::Duration;
    use pulse_state::InMemoryState;
    use std::sync::Arc;

    /// State written through the backend by one operator must be recoverable
    /// by a second operator that shares the same backend and key prefix.
    #[tokio::test]
    async fn window_state_persists_and_restores() {
        let assigner = WindowAssigner::Tumbling { size: Duration::seconds(60) };
        let store = Arc::new(InMemoryState::default());
        let event_time = DateTime::<Utc>::from_timestamp(1_700_000_000, 0).unwrap();

        // First operator: accumulate one element and let it reach the store.
        let mut writer =
            WindowOperator::new(assigner, || 0i64, |acc, v| *acc += v["n"].as_i64().unwrap_or(0))
                .with_backend(store.clone(), b"win:count:");
        writer.on_element(event_time, &serde_json::json!({"n": 2}));
        let _snap = store.snapshot().await.unwrap();

        // Second operator: starts empty and must pick the state up again.
        let mut reader =
            WindowOperator::new(assigner, || 0i64, |acc, v| *acc += v["n"].as_i64().unwrap_or(0))
                .with_backend(store.clone(), b"win:count:");
        reader.restore_from_backend().await.unwrap();

        let watermark = event_time + Duration::seconds(60);
        let emitted = reader.on_watermark(watermark);
        assert_eq!(emitted.len(), 1);
        assert_eq!(emitted[0].1, 2);
    }
}
190
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an operator that sums the integer field `"n"` of each element.
    fn summing_operator(assigner: WindowAssigner) -> WindowOperator<i64> {
        WindowOperator::new(assigner, || 0i64, |acc, v| *acc += v["n"].as_i64().unwrap_or(0))
    }

    /// Fixed reference instant shared by all tests.
    fn reference_time() -> DateTime<Utc> {
        DateTime::<Utc>::from_timestamp(1_700_000_000, 0).unwrap()
    }

    #[test]
    fn tumbling_emits_on_watermark_with_lateness() {
        let mut op = summing_operator(WindowAssigner::Tumbling { size: Duration::seconds(60) });

        let first = reference_time();
        let second = first + Duration::seconds(10);
        let third = first + Duration::seconds(70);

        // The first two elements are summed and emitted together.
        op.on_element(first, &serde_json::json!({"n": 1}));
        op.on_element(second, &serde_json::json!({"n": 2}));
        let emitted_early = op.on_watermark(first + Duration::seconds(60));
        assert_eq!(emitted_early.len(), 1);
        assert_eq!(emitted_early[0].1, 3);

        // The third element arrives afterwards and is emitted on its own.
        op.on_element(third, &serde_json::json!({"n": 5}));
        let emitted_late = op.on_watermark(third + Duration::seconds(60));
        assert_eq!(emitted_late.len(), 1);
        assert_eq!(emitted_late[0].1, 5);
    }

    #[test]
    fn sliding_emits_multiple_overlaps() {
        let assigner = WindowAssigner::Sliding {
            size: Duration::seconds(60),
            slide: Duration::seconds(15),
        };
        let mut op = summing_operator(assigner);

        let event_time = reference_time() + Duration::seconds(30);
        op.on_element(event_time, &serde_json::json!({"n": 1}));

        // A watermark at the latest window end closes every overlapping window.
        let expected_windows = assigner.assign(event_time);
        let last_end = expected_windows.iter().map(|w| w.end).max().unwrap();
        let emitted = op.on_watermark(last_end);
        assert_eq!(emitted.len(), expected_windows.len());

        let sums: Vec<i64> = emitted.iter().map(|(_, total)| *total).collect();
        assert!(sums.iter().all(|&total| total == 1));
    }

    #[test]
    fn out_of_order_data_waits_until_watermark() {
        let assigner = WindowAssigner::Tumbling { size: Duration::seconds(60) };
        let mut op = summing_operator(assigner);

        let origin = reference_time();
        let late_time = origin + Duration::seconds(10);

        // The later element is delivered first, the earlier one second.
        op.on_element(origin + Duration::seconds(75), &serde_json::json!({"n": 7}));
        op.on_element(late_time, &serde_json::json!({"n": 3}));

        let late_window_end = assigner.assign(late_time).iter().map(|w| w.end).max().unwrap();

        // One second before the window closes nothing may be emitted.
        let before_close = op.on_watermark(late_window_end - Duration::seconds(1));
        assert!(before_close.is_empty());

        // Exactly at the window end only the earlier window is emitted.
        let at_close = op.on_watermark(late_window_end);
        assert_eq!(at_close.len(), 1);
        assert_eq!(at_close[0].1, 3);
    }
}