1use std::collections::HashMap;
2
3use chrono::{DateTime, Duration, Utc};
4use pulse_core::KvState;
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
8pub enum WindowAssigner {
9 Tumbling { size: Duration },
10 Sliding { size: Duration, slide: Duration },
11}
12
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
14pub struct Window {
15 pub start: DateTime<Utc>,
16 pub end: DateTime<Utc>,
17}
18
19impl WindowAssigner {
20 fn assign(&self, ts: DateTime<Utc>) -> Vec<Window> {
21 match *self {
22 WindowAssigner::Tumbling { size } => {
23 let epoch = DateTime::<Utc>::from_timestamp(0, 0).unwrap();
24 let since = ts - epoch;
25 let buckets = since.num_milliseconds() / size.num_milliseconds();
26 let start = epoch + Duration::milliseconds(buckets * size.num_milliseconds());
27 let end = start + size;
28 vec![Window { start, end }]
29 }
30 WindowAssigner::Sliding { size, slide } => {
31 let epoch = DateTime::<Utc>::from_timestamp(0, 0).unwrap();
32 let since = ts - epoch;
33 let k = (size.num_milliseconds() / slide.num_milliseconds()) as i64;
34 let anchor_ms =
35 (since.num_milliseconds() / slide.num_milliseconds()) * slide.num_milliseconds();
36 let mut out = Vec::new();
37 for j in 0..k {
38 let start = epoch + Duration::milliseconds(anchor_ms - j * slide.num_milliseconds());
39 let end = start + size;
40 if start <= ts && ts < end {
41 out.push(Window { start, end });
42 }
43 }
44 out
45 }
46 }
47 }
48}
49
50pub struct WindowOperator<S> {
53 assigner: WindowAssigner,
54 state: HashMap<Window, S>,
55 reduce: Box<dyn Fn(&mut S, &serde_json::Value) + Send + Sync>,
56 init: Box<dyn Fn() -> S + Send + Sync>,
57 backend: Option<std::sync::Arc<dyn KvState>>, ns_prefix: Vec<u8>,
60}
61
62impl<S> WindowOperator<S>
63where
64 S: Clone + Default + Serialize + for<'de> Deserialize<'de> + Send + Sync + 'static,
65{
66 pub fn new<Init, Red>(assigner: WindowAssigner, init: Init, reduce: Red) -> Self
67 where
68 Init: Fn() -> S + Send + Sync + 'static,
69 Red: Fn(&mut S, &serde_json::Value) + Send + Sync + 'static,
70 {
71 Self {
72 assigner,
73 state: HashMap::new(),
74 init: Box::new(init),
75 reduce: Box::new(reduce),
76 backend: None,
77 ns_prefix: b"window:".to_vec(),
78 }
79 }
80
81 pub fn with_backend(mut self, backend: std::sync::Arc<dyn KvState>, ns_prefix: impl AsRef<[u8]>) -> Self {
83 self.backend = Some(backend);
84 self.ns_prefix = ns_prefix.as_ref().to_vec();
85 self
86 }
87
88 fn make_key(&self, w: &Window) -> Vec<u8> {
89 let start = w.start.timestamp_millis();
91 let end = w.end.timestamp_millis();
92 let mut k = self.ns_prefix.clone();
93 k.extend_from_slice(start.to_string().as_bytes());
94 k.push(b'|');
95 k.extend_from_slice(end.to_string().as_bytes());
96 k
97 }
98
99 pub fn on_element(&mut self, ts: DateTime<Utc>, value: &serde_json::Value) {
100 for w in self.assigner.assign(ts) {
101 let entry = self.state.entry(w.clone()).or_insert_with(|| (self.init)());
102 (self.reduce)(entry, value);
103 if let Some(backend) = &self.backend {
105 if let Ok(bytes) = serde_json::to_vec(entry) {
106 let key = self.make_key(&w);
107 let _ = futures::executor::block_on(backend.put(&key, bytes));
109 }
110 }
111 }
112 }
113
114 pub fn on_watermark(&mut self, watermark: DateTime<Utc>) -> Vec<(Window, S)> {
115 let mut to_emit = Vec::new();
116 let keys: Vec<_> = self.state.keys().cloned().collect();
117 for w in keys {
118 if watermark >= w.end {
119 if let Some(s) = self.state.remove(&w) {
120 to_emit.push((w.clone(), s));
121 }
122 if let Some(backend) = &self.backend {
124 let key = self.make_key(&w);
125 let _ = futures::executor::block_on(backend.delete(&key));
126 }
127 }
128 }
129 to_emit
130 }
131
132 pub async fn restore_from_backend(&mut self) -> pulse_core::Result<()>
134 where
135 S: for<'de> Deserialize<'de>,
136 {
137 if let Some(backend) = &self.backend {
138 let entries = backend.iter_prefix(Some(&self.ns_prefix)).await?;
139 for (k, v) in entries {
140 if let Ok(state) = serde_json::from_slice::<S>(&v) {
141 let s = String::from_utf8_lossy(&k[self.ns_prefix.len()..]);
143 if let Some((a, b)) = s.split_once('|') {
144 if let (Ok(start_ms), Ok(end_ms)) = (a.parse::<i64>(), b.parse::<i64>()) {
145 let w = Window {
146 start: DateTime::<Utc>::from_timestamp_millis(start_ms).unwrap(),
147 end: DateTime::<Utc>::from_timestamp_millis(end_ms).unwrap(),
148 };
149 self.state.insert(w, state);
150 }
151 }
152 }
153 }
154 }
155 Ok(())
156 }
157}
158
#[cfg(test)]
mod persist_tests {
    use super::*;
    use crate::window::WindowAssigner;
    use chrono::Duration;
    use pulse_state::InMemoryState;
    use std::sync::Arc;

    /// Builds a summing operator over the `n` field, persisted under the
    /// `win:count:` namespace of `backend`.
    fn counting_operator(
        assigner: WindowAssigner,
        backend: Arc<InMemoryState>,
    ) -> WindowOperator<i64> {
        WindowOperator::new(assigner, || 0i64, |acc, v| *acc += v["n"].as_i64().unwrap_or(0))
            .with_backend(backend, b"win:count:")
    }

    /// State written by one operator can be restored by a second operator
    /// that shares the same backend and namespace.
    #[tokio::test]
    async fn window_state_persists_and_restores() {
        let assigner = WindowAssigner::Tumbling {
            size: Duration::seconds(60),
        };
        let backend = Arc::new(InMemoryState::default());
        let event_time = DateTime::<Utc>::from_timestamp(1_700_000_000, 0).unwrap();

        // First operator: aggregate one element, which is mirrored to the backend.
        let mut writer = counting_operator(assigner, backend.clone());
        writer.on_element(event_time, &serde_json::json!({"n": 2}));
        let _snap = backend.snapshot().await.unwrap();

        // Second operator: starts empty and loads the state from the backend.
        let mut reader = counting_operator(assigner, backend.clone());
        reader.restore_from_backend().await.unwrap();

        let emitted = reader.on_watermark(event_time + Duration::seconds(60));
        assert_eq!(emitted.len(), 1);
        assert_eq!(emitted[0].1, 2);
    }
}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory operator that sums the `n` field of each element.
    fn sum_operator(assigner: WindowAssigner) -> WindowOperator<i64> {
        WindowOperator::new(assigner, || 0i64, |acc, v| *acc += v["n"].as_i64().unwrap_or(0))
    }

    /// Reference instant shared by all tests.
    fn base_time() -> DateTime<Utc> {
        DateTime::<Utc>::from_timestamp(1_700_000_000, 0).unwrap()
    }

    #[test]
    fn tumbling_emits_on_watermark_with_lateness() {
        let mut op = sum_operator(WindowAssigner::Tumbling {
            size: Duration::seconds(60),
        });

        let t0 = base_time();
        let t1 = t0 + Duration::seconds(10);
        let t2 = t0 + Duration::seconds(70);

        // Two elements in the first window are summed and emitted together.
        op.on_element(t0, &serde_json::json!({"n": 1}));
        op.on_element(t1, &serde_json::json!({"n": 2}));
        let first = op.on_watermark(t0 + Duration::seconds(60));
        assert_eq!(first.len(), 1);
        assert_eq!(first[0].1, 3);

        // A later element lands in a new window with its own accumulator.
        op.on_element(t2, &serde_json::json!({"n": 5}));
        let second = op.on_watermark(t2 + Duration::seconds(60));
        assert_eq!(second.len(), 1);
        assert_eq!(second[0].1, 5);
    }

    #[test]
    fn sliding_emits_multiple_overlaps() {
        let assigner = WindowAssigner::Sliding {
            size: Duration::seconds(60),
            slide: Duration::seconds(15),
        };
        let mut op = sum_operator(assigner);

        let t = base_time() + Duration::seconds(30);
        op.on_element(t, &serde_json::json!({"n": 1}));

        // Advancing the watermark to the latest window end must flush every
        // window the element was assigned to.
        let wins = assigner.assign(t);
        let max_end = wins.iter().map(|w| w.end).max().unwrap();
        let out = op.on_watermark(max_end);
        assert_eq!(out.len(), wins.len());

        let sums: Vec<i64> = out.iter().map(|(_, s)| *s).collect();
        assert!(sums.iter().all(|&x| x == 1));
    }

    #[test]
    fn out_of_order_data_waits_until_watermark() {
        let assigner = WindowAssigner::Tumbling {
            size: Duration::seconds(60),
        };
        let mut op = sum_operator(assigner);

        let base = base_time();
        let late = base + Duration::seconds(10);

        // The newer element arrives first, then the older one.
        op.on_element(base + Duration::seconds(75), &serde_json::json!({"n": 7}));
        op.on_element(late, &serde_json::json!({"n": 3}));

        let end_of_late = assigner
            .assign(late)
            .iter()
            .map(|w| w.end)
            .max()
            .unwrap();

        // One second before the window end nothing is emitted.
        let before = op.on_watermark(end_of_late - Duration::seconds(1));
        assert!(before.is_empty());

        // At the window end only the older window is emitted.
        let at_end = op.on_watermark(end_of_late);
        assert_eq!(at_end.len(), 1);
        assert_eq!(at_end[0].1, 3);
    }
}