oxirs_stream/aggregation/
exactly_once.rs1use std::collections::HashMap;
28use std::sync::Arc;
29
30use crate::error::StreamError;
31use crate::state::distributed_state::StateBackend;
32use crate::state::exactly_once::{DeduplicationConfig, ExactlyOnceProcessor, MessageId};
33
/// Aggregate value tracked for a single partition key.
///
/// Each variant carries the running state of one kind of aggregate.
#[derive(Debug, Clone, PartialEq)]
pub enum PartitionAggregateValue {
    /// Running count.
    Count(u64),
    /// Running sum.
    Sum(f64),
    /// Running minimum.
    Min(f64),
    /// Running maximum.
    Max(f64),
    /// Running sum and sample count kept side by side
    /// (presumably so the mean can be derived as `sum / count` — confirm with callers).
    Mean {
        /// Sum of the samples.
        sum: f64,
        /// Number of samples.
        count: u64,
    },
}

impl PartitionAggregateValue {
    /// Reports whether this value is still in its "nothing folded yet" form.
    ///
    /// That is the case for a zero `Count`, a `Sum` equal to zero, and a
    /// `Mean` whose `count` is zero (its `sum` is not inspected). `Min` and
    /// `Max` are never considered initial.
    pub fn is_initial(&self) -> bool {
        match self {
            Self::Count(events) => *events == 0,
            Self::Sum(total) => *total == 0.0,
            Self::Mean { count, .. } => *count == 0,
            Self::Min(_) | Self::Max(_) => false,
        }
    }
}
61
62#[derive(Debug, Clone)]
69pub struct PartitionAggregateState {
70 inner: HashMap<String, PartitionAggregateValue>,
71}
72
73impl Default for PartitionAggregateState {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl PartitionAggregateState {
80 pub fn new() -> Self {
82 Self {
83 inner: HashMap::new(),
84 }
85 }
86
87 pub fn get(&self, key: &str) -> Option<&PartitionAggregateValue> {
89 self.inner.get(key)
90 }
91
92 pub fn put(&mut self, key: impl Into<String>, value: PartitionAggregateValue) {
94 self.inner.insert(key.into(), value);
95 }
96
97 pub fn len(&self) -> usize {
99 self.inner.len()
100 }
101
102 pub fn is_empty(&self) -> bool {
104 self.inner.is_empty()
105 }
106
107 pub fn iter(&self) -> impl Iterator<Item = (&String, &PartitionAggregateValue)> {
109 self.inner.iter()
110 }
111
112 pub fn encode(&self) -> Vec<u8> {
129 let mut out = Vec::new();
130 out.extend_from_slice(&(self.inner.len() as u32).to_le_bytes());
131 let mut keys: Vec<&String> = self.inner.keys().collect();
133 keys.sort();
134 for k in keys {
135 let v = match self.inner.get(k) {
136 Some(v) => v,
137 None => continue,
138 };
139 out.extend_from_slice(&(k.len() as u32).to_le_bytes());
140 out.extend_from_slice(k.as_bytes());
141 match v {
142 PartitionAggregateValue::Count(c) => {
143 out.push(0x01);
144 out.extend_from_slice(&c.to_le_bytes());
145 }
146 PartitionAggregateValue::Sum(s) => {
147 out.push(0x02);
148 out.extend_from_slice(&s.to_le_bytes());
149 }
150 PartitionAggregateValue::Min(m) => {
151 out.push(0x03);
152 out.extend_from_slice(&m.to_le_bytes());
153 }
154 PartitionAggregateValue::Max(m) => {
155 out.push(0x04);
156 out.extend_from_slice(&m.to_le_bytes());
157 }
158 PartitionAggregateValue::Mean { sum, count } => {
159 out.push(0x05);
160 out.extend_from_slice(&sum.to_le_bytes());
161 out.extend_from_slice(&count.to_le_bytes());
162 }
163 }
164 }
165 out
166 }
167
168 pub fn decode(buf: &[u8]) -> Result<Self, StreamError> {
170 let read_u32 = |buf: &[u8], offset: usize| -> Result<(u32, usize), StreamError> {
171 if buf.len() < offset + 4 {
172 return Err(StreamError::Deserialization(
173 "PartitionAggregateState: truncated u32".to_string(),
174 ));
175 }
176 let mut a = [0u8; 4];
177 a.copy_from_slice(&buf[offset..offset + 4]);
178 Ok((u32::from_le_bytes(a), offset + 4))
179 };
180 let read_u64 = |buf: &[u8], offset: usize| -> Result<(u64, usize), StreamError> {
181 if buf.len() < offset + 8 {
182 return Err(StreamError::Deserialization(
183 "PartitionAggregateState: truncated u64".to_string(),
184 ));
185 }
186 let mut a = [0u8; 8];
187 a.copy_from_slice(&buf[offset..offset + 8]);
188 Ok((u64::from_le_bytes(a), offset + 8))
189 };
190 let read_f64 = |buf: &[u8], offset: usize| -> Result<(f64, usize), StreamError> {
191 if buf.len() < offset + 8 {
192 return Err(StreamError::Deserialization(
193 "PartitionAggregateState: truncated f64".to_string(),
194 ));
195 }
196 let mut a = [0u8; 8];
197 a.copy_from_slice(&buf[offset..offset + 8]);
198 Ok((f64::from_le_bytes(a), offset + 8))
199 };
200
201 let mut state = PartitionAggregateState::new();
202 let (n, mut p) = read_u32(buf, 0)?;
203 for _ in 0..n {
204 let (klen, np) = read_u32(buf, p)?;
205 p = np;
206 let kend = p + klen as usize;
207 if buf.len() < kend {
208 return Err(StreamError::Deserialization(
209 "PartitionAggregateState: truncated key".to_string(),
210 ));
211 }
212 let key = std::str::from_utf8(&buf[p..kend])
213 .map_err(|e| StreamError::Deserialization(format!("bad utf8: {e}")))?
214 .to_string();
215 p = kend;
216 if buf.len() < p + 1 {
217 return Err(StreamError::Deserialization(
218 "PartitionAggregateState: missing tag".to_string(),
219 ));
220 }
221 let tag = buf[p];
222 p += 1;
223 let v = match tag {
224 0x01 => {
225 let (c, np) = read_u64(buf, p)?;
226 p = np;
227 PartitionAggregateValue::Count(c)
228 }
229 0x02 => {
230 let (s, np) = read_f64(buf, p)?;
231 p = np;
232 PartitionAggregateValue::Sum(s)
233 }
234 0x03 => {
235 let (m, np) = read_f64(buf, p)?;
236 p = np;
237 PartitionAggregateValue::Min(m)
238 }
239 0x04 => {
240 let (m, np) = read_f64(buf, p)?;
241 p = np;
242 PartitionAggregateValue::Max(m)
243 }
244 0x05 => {
245 let (s, np) = read_f64(buf, p)?;
246 let (c, np) = read_u64(buf, np)?;
247 p = np;
248 PartitionAggregateValue::Mean { sum: s, count: c }
249 }
250 t => {
251 return Err(StreamError::Deserialization(format!(
252 "unknown PartitionAggregateValue tag {t}"
253 )));
254 }
255 };
256 state.put(key, v);
257 }
258 Ok(state)
259 }
260}
261
262#[derive(Debug, Clone)]
266pub struct ExactlyOnceAggregatorConfig {
267 pub dedup: DeduplicationConfig,
268 pub state_key: String,
270}
271
272impl Default for ExactlyOnceAggregatorConfig {
273 fn default() -> Self {
274 Self {
275 dedup: DeduplicationConfig::default(),
276 state_key: "aggregator/state".to_string(),
277 }
278 }
279}
280
/// Running counters describing what an aggregator has done so far.
///
/// All counters start at zero.
#[derive(Debug, Clone, Default)]
pub struct ExactlyOnceAggregatorStats {
    /// Number of fold calls whose update was applied.
    pub events_folded: u64,
    /// Number of fold calls that were skipped as duplicates.
    pub duplicates_filtered: u64,
    /// Number of successful checkpoints.
    pub checkpoints_taken: u64,
}
288
289pub struct ExactlyOnceAggregator {
293 config: ExactlyOnceAggregatorConfig,
294 backend: Arc<dyn StateBackend>,
295 processor: ExactlyOnceProcessor,
296 state: PartitionAggregateState,
297 stats: ExactlyOnceAggregatorStats,
298}
299
300impl ExactlyOnceAggregator {
301 pub fn new(config: ExactlyOnceAggregatorConfig, backend: Arc<dyn StateBackend>) -> Self {
303 let processor = ExactlyOnceProcessor::new(config.dedup.clone(), backend.clone());
304 Self {
305 config,
306 backend,
307 processor,
308 state: PartitionAggregateState::new(),
309 stats: ExactlyOnceAggregatorStats::default(),
310 }
311 }
312
313 pub fn fold<F>(
319 &mut self,
320 id: MessageId,
321 partition_key: &str,
322 update: F,
323 ) -> Result<Option<PartitionAggregateValue>, StreamError>
324 where
325 F: FnOnce(Option<&PartitionAggregateValue>) -> PartitionAggregateValue,
326 {
327 let prev = self.state.get(partition_key).cloned();
328 let new_value_for_state = update(prev.as_ref());
329 let key_for_state = partition_key.to_string();
330 let value_for_dedup_apply = new_value_for_state.clone();
331 let state_key_bytes = self.config.state_key.as_bytes().to_vec();
332
333 let mut updated = self.state.clone();
335 updated.put(key_for_state.clone(), new_value_for_state.clone());
336 let encoded = updated.encode();
337
338 let result = self.processor.process(id, |txn| {
339 txn.add_state_change(state_key_bytes, encoded);
340 Ok(value_for_dedup_apply)
341 })?;
342
343 match result {
344 Some(applied) => {
345 self.state.put(key_for_state, applied.clone());
346 self.stats.events_folded += 1;
347 Ok(Some(applied))
348 }
349 None => {
350 self.stats.duplicates_filtered += 1;
351 Ok(None)
352 }
353 }
354 }
355
356 pub fn get(&self, partition_key: &str) -> Option<&PartitionAggregateValue> {
358 self.state.get(partition_key)
359 }
360
361 pub fn set(&mut self, partition_key: &str, value: PartitionAggregateValue) {
363 self.state.put(partition_key.to_string(), value);
364 }
365
366 pub fn checkpoint(&mut self) -> Result<(), StreamError> {
368 let encoded = self.state.encode();
369 self.backend
370 .put(self.config.state_key.as_bytes(), &encoded)?;
371 self.stats.checkpoints_taken += 1;
372 Ok(())
373 }
374
375 pub fn restore(&mut self) -> Result<(), StreamError> {
377 match self.backend.get(self.config.state_key.as_bytes())? {
378 Some(bytes) => {
379 let state = PartitionAggregateState::decode(&bytes)?;
380 self.state = state;
381 Ok(())
382 }
383 None => Ok(()),
384 }
385 }
386
387 pub fn clear(&mut self) {
389 self.state = PartitionAggregateState::new();
390 }
391
392 pub fn stats(&self) -> &ExactlyOnceAggregatorStats {
394 &self.stats
395 }
396
397 pub fn partition_count(&self) -> usize {
399 self.state.len()
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::distributed_state::InMemoryStateBackend;
    use crate::state::exactly_once::MessageId;

    /// Aggregator with default config on top of a fresh in-memory backend.
    fn fresh_aggregator() -> ExactlyOnceAggregator {
        let config = ExactlyOnceAggregatorConfig::default();
        let backend: Arc<dyn StateBackend> = Arc::new(InMemoryStateBackend::new());
        ExactlyOnceAggregator::new(config, backend)
    }

    /// Update function: increments an existing `Count`, otherwise starts at 1.
    fn bump_count(prev: Option<&PartitionAggregateValue>) -> PartitionAggregateValue {
        match prev {
            Some(PartitionAggregateValue::Count(c)) => PartitionAggregateValue::Count(*c + 1),
            _ => PartitionAggregateValue::Count(1),
        }
    }

    #[test]
    fn fold_increments_count_exactly_once() {
        let mut agg = fresh_aggregator();
        let id = MessageId::new("p", 0, 1);

        let first = agg.fold(id.clone(), "k", bump_count).expect("fold ok");
        assert_eq!(first, Some(PartitionAggregateValue::Count(1)));

        // Same message id again: must be filtered and leave the count alone.
        let second = agg.fold(id, "k", bump_count).expect("fold ok");
        assert_eq!(second, None);
        assert_eq!(agg.get("k"), Some(&PartitionAggregateValue::Count(1)));
        assert_eq!(agg.stats.duplicates_filtered, 1);
    }

    #[test]
    fn checkpoint_restore_roundtrip() {
        let mut agg = fresh_aggregator();
        for i in 1..=5u64 {
            let add_i = |prev: Option<&PartitionAggregateValue>| match prev {
                Some(PartitionAggregateValue::Sum(s)) => PartitionAggregateValue::Sum(s + i as f64),
                _ => PartitionAggregateValue::Sum(i as f64),
            };
            agg.fold(MessageId::new("p", 0, i), "k1", add_i)
                .expect("fold ok");
        }
        let expected = PartitionAggregateValue::Sum(15.0);
        assert_eq!(agg.get("k1"), Some(&expected));

        agg.checkpoint().expect("checkpoint ok");
        agg.clear();
        assert!(agg.get("k1").is_none());
        agg.restore().expect("restore ok");
        assert_eq!(agg.get("k1"), Some(&expected));
    }

    #[test]
    fn separate_partitions_isolated() {
        let mut agg = fresh_aggregator();
        let inputs = [("a", 1u64, 1u64), ("b", 2, 7)];
        for (key, seq, count) in inputs {
            agg.fold(MessageId::new("p", 0, seq), key, |_| {
                PartitionAggregateValue::Count(count)
            })
            .expect("ok");
        }
        assert_eq!(agg.get("a"), Some(&PartitionAggregateValue::Count(1)));
        assert_eq!(agg.get("b"), Some(&PartitionAggregateValue::Count(7)));
        assert_eq!(agg.partition_count(), 2);
    }

    #[test]
    fn encode_decode_round_trip() {
        let exact = [
            ("a", PartitionAggregateValue::Count(42)),
            ("b", PartitionAggregateValue::Sum(3.5)),
            ("c", PartitionAggregateValue::Min(-1.0)),
            ("d", PartitionAggregateValue::Max(99.0)),
        ];

        let mut s = PartitionAggregateState::new();
        for (key, value) in exact.iter() {
            s.put(*key, value.clone());
        }
        s.put(
            "mean_e",
            PartitionAggregateValue::Mean {
                sum: 100.0,
                count: 4,
            },
        );

        let bytes = s.encode();
        let decoded = PartitionAggregateState::decode(&bytes).expect("decode");
        assert_eq!(decoded.len(), 5);
        for (key, value) in exact.iter() {
            assert_eq!(decoded.get(key), Some(value));
        }
        match decoded.get("mean_e") {
            Some(PartitionAggregateValue::Mean { sum, count }) => {
                assert!((sum - 100.0).abs() < 1e-9);
                assert_eq!(*count, 4);
            }
            other => panic!("expected Mean, got {other:?}"),
        }
    }

    #[test]
    fn checkpoint_after_dedup_does_not_double_apply() {
        let mut agg = fresh_aggregator();
        agg.fold(MessageId::new("p", 0, 1), "k", |_| {
            PartitionAggregateValue::Count(5)
        })
        .expect("ok");
        agg.checkpoint().expect("ok");

        // A second aggregator on the same backend must see the checkpointed value.
        let backend = agg.backend.clone();
        let mut recovered =
            ExactlyOnceAggregator::new(ExactlyOnceAggregatorConfig::default(), backend);
        recovered.restore().expect("ok");
        assert_eq!(recovered.get("k"), Some(&PartitionAggregateValue::Count(5)));
    }
}