Skip to main content

oxirs_stream/aggregation/
exactly_once.rs

//! Exactly-once aggregation under operator parallelism.
//!
//! [`ExactlyOnceAggregator`] composes
//! [`crate::state::exactly_once::ExactlyOnceProcessor`] with a per-partition
//! aggregation state to guarantee that:
//!
//! 1. Each event is folded into the aggregate exactly once, even under
//!    re-delivery.
//! 2. Aggregate state is checkpointable per partition (snapshots are
//!    written to and restored from the underlying
//!    [`crate::state::distributed_state::StateBackend`]).
//! 3. State recovery after a failure restores the same aggregate values that
//!    were emitted before the crash.
//!
//! ## Usage outline
//!
//! ```ignore
//! let backend: Arc<dyn StateBackend> = Arc::new(InMemoryStateBackend::new());
//! let mut agg = ExactlyOnceAggregator::new(
//!     ExactlyOnceAggregatorConfig::default(),
//!     backend.clone(),
//! );
//! agg.fold(MessageId::new("p", 0, 1), partition_key, |prev| match prev {
//!     Some(PartitionAggregateValue::Count(c)) => PartitionAggregateValue::Count(c + 1),
//!     _ => PartitionAggregateValue::Count(1),
//! })?;
//! ```
//!
//! `partition_key` is any `&str` key (the operator-parallel shard). The update
//! closure receives the key's current value (if any) and returns its new value.
26
27use std::collections::HashMap;
28use std::sync::Arc;
29
30use crate::error::StreamError;
31use crate::state::distributed_state::StateBackend;
32use crate::state::exactly_once::{DeduplicationConfig, ExactlyOnceProcessor, MessageId};
33
34// ─── Aggregate value types ───────────────────────────────────────────────────
35
/// Aggregate values supported by [`PartitionAggregateState`].
///
/// One value is stored per partition key.
#[derive(Debug, Clone, PartialEq)]
pub enum PartitionAggregateValue {
    /// Count aggregate.
    Count(u64),
    /// Sum aggregate.
    Sum(f64),
    /// Minimum aggregate.
    Min(f64),
    /// Maximum aggregate.
    Max(f64),
    /// Mean tracked as `(sum, count)`.
    Mean {
        /// Running sum of the folded values.
        sum: f64,
        /// Number of folded values.
        count: u64,
    },
}

impl PartitionAggregateValue {
    /// Return `true` if the value is the identity element (initial state) of
    /// its aggregate.
    ///
    /// The identity of each variant is:
    /// * `Count` — `0`.
    /// * `Sum` — `0.0` (either sign of zero, since `-0.0 == 0.0`).
    /// * `Min` — `+∞`, because `min(+∞, x) == x`.
    /// * `Max` — `-∞`, because `max(-∞, x) == x`.
    /// * `Mean` — any value whose `count` is `0`.
    pub fn is_initial(&self) -> bool {
        // Exhaustive on purpose: adding a variant must force this to be revisited.
        match *self {
            PartitionAggregateValue::Count(c) => c == 0,
            PartitionAggregateValue::Sum(s) => s == 0.0,
            PartitionAggregateValue::Min(m) => m == f64::INFINITY,
            PartitionAggregateValue::Max(m) => m == f64::NEG_INFINITY,
            PartitionAggregateValue::Mean { count, .. } => count == 0,
        }
    }
}
61
62// ─── Per-partition state ─────────────────────────────────────────────────────
63
64/// Per-partition aggregate state.
65///
66/// `K` is the partition/group key (typically a `String`).  `V` is the typed
67/// aggregate value held for that key.
68#[derive(Debug, Clone)]
69pub struct PartitionAggregateState {
70    inner: HashMap<String, PartitionAggregateValue>,
71}
72
73impl Default for PartitionAggregateState {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl PartitionAggregateState {
80    /// Empty state.
81    pub fn new() -> Self {
82        Self {
83            inner: HashMap::new(),
84        }
85    }
86
87    /// Look up the current value for `key`.
88    pub fn get(&self, key: &str) -> Option<&PartitionAggregateValue> {
89        self.inner.get(key)
90    }
91
92    /// Insert or replace the value for `key`.
93    pub fn put(&mut self, key: impl Into<String>, value: PartitionAggregateValue) {
94        self.inner.insert(key.into(), value);
95    }
96
97    /// Number of keys currently tracked.
98    pub fn len(&self) -> usize {
99        self.inner.len()
100    }
101
102    /// True iff no keys are tracked.
103    pub fn is_empty(&self) -> bool {
104        self.inner.is_empty()
105    }
106
107    /// Iterate over all (key, value) pairs.
108    pub fn iter(&self) -> impl Iterator<Item = (&String, &PartitionAggregateValue)> {
109        self.inner.iter()
110    }
111
112    // ─── Encoding ────────────────────────────────────────────────────────────
113
114    /// Encode this state to a deterministic byte vector.
115    ///
116    /// Format
117    ///
118    /// * `[u32 len]` — number of entries.
119    /// * `len` × `[u32 key_len][key…][u8 tag][payload]`.
120    /// * Tag/payload:
121    ///   * `0x01` Count: `[u64]`.
122    ///   * `0x02` Sum:   `[f64]`.
123    ///   * `0x03` Min:   `[f64]`.
124    ///   * `0x04` Max:   `[f64]`.
125    ///   * `0x05` Mean:  `[f64 sum][u64 count]`.
126    ///
127    /// All integers are little-endian.
128    pub fn encode(&self) -> Vec<u8> {
129        let mut out = Vec::new();
130        out.extend_from_slice(&(self.inner.len() as u32).to_le_bytes());
131        // Sort by key for deterministic output.
132        let mut keys: Vec<&String> = self.inner.keys().collect();
133        keys.sort();
134        for k in keys {
135            let v = match self.inner.get(k) {
136                Some(v) => v,
137                None => continue,
138            };
139            out.extend_from_slice(&(k.len() as u32).to_le_bytes());
140            out.extend_from_slice(k.as_bytes());
141            match v {
142                PartitionAggregateValue::Count(c) => {
143                    out.push(0x01);
144                    out.extend_from_slice(&c.to_le_bytes());
145                }
146                PartitionAggregateValue::Sum(s) => {
147                    out.push(0x02);
148                    out.extend_from_slice(&s.to_le_bytes());
149                }
150                PartitionAggregateValue::Min(m) => {
151                    out.push(0x03);
152                    out.extend_from_slice(&m.to_le_bytes());
153                }
154                PartitionAggregateValue::Max(m) => {
155                    out.push(0x04);
156                    out.extend_from_slice(&m.to_le_bytes());
157                }
158                PartitionAggregateValue::Mean { sum, count } => {
159                    out.push(0x05);
160                    out.extend_from_slice(&sum.to_le_bytes());
161                    out.extend_from_slice(&count.to_le_bytes());
162                }
163            }
164        }
165        out
166    }
167
168    /// Decode a state previously produced by [`Self::encode`].
169    pub fn decode(buf: &[u8]) -> Result<Self, StreamError> {
170        let read_u32 = |buf: &[u8], offset: usize| -> Result<(u32, usize), StreamError> {
171            if buf.len() < offset + 4 {
172                return Err(StreamError::Deserialization(
173                    "PartitionAggregateState: truncated u32".to_string(),
174                ));
175            }
176            let mut a = [0u8; 4];
177            a.copy_from_slice(&buf[offset..offset + 4]);
178            Ok((u32::from_le_bytes(a), offset + 4))
179        };
180        let read_u64 = |buf: &[u8], offset: usize| -> Result<(u64, usize), StreamError> {
181            if buf.len() < offset + 8 {
182                return Err(StreamError::Deserialization(
183                    "PartitionAggregateState: truncated u64".to_string(),
184                ));
185            }
186            let mut a = [0u8; 8];
187            a.copy_from_slice(&buf[offset..offset + 8]);
188            Ok((u64::from_le_bytes(a), offset + 8))
189        };
190        let read_f64 = |buf: &[u8], offset: usize| -> Result<(f64, usize), StreamError> {
191            if buf.len() < offset + 8 {
192                return Err(StreamError::Deserialization(
193                    "PartitionAggregateState: truncated f64".to_string(),
194                ));
195            }
196            let mut a = [0u8; 8];
197            a.copy_from_slice(&buf[offset..offset + 8]);
198            Ok((f64::from_le_bytes(a), offset + 8))
199        };
200
201        let mut state = PartitionAggregateState::new();
202        let (n, mut p) = read_u32(buf, 0)?;
203        for _ in 0..n {
204            let (klen, np) = read_u32(buf, p)?;
205            p = np;
206            let kend = p + klen as usize;
207            if buf.len() < kend {
208                return Err(StreamError::Deserialization(
209                    "PartitionAggregateState: truncated key".to_string(),
210                ));
211            }
212            let key = std::str::from_utf8(&buf[p..kend])
213                .map_err(|e| StreamError::Deserialization(format!("bad utf8: {e}")))?
214                .to_string();
215            p = kend;
216            if buf.len() < p + 1 {
217                return Err(StreamError::Deserialization(
218                    "PartitionAggregateState: missing tag".to_string(),
219                ));
220            }
221            let tag = buf[p];
222            p += 1;
223            let v = match tag {
224                0x01 => {
225                    let (c, np) = read_u64(buf, p)?;
226                    p = np;
227                    PartitionAggregateValue::Count(c)
228                }
229                0x02 => {
230                    let (s, np) = read_f64(buf, p)?;
231                    p = np;
232                    PartitionAggregateValue::Sum(s)
233                }
234                0x03 => {
235                    let (m, np) = read_f64(buf, p)?;
236                    p = np;
237                    PartitionAggregateValue::Min(m)
238                }
239                0x04 => {
240                    let (m, np) = read_f64(buf, p)?;
241                    p = np;
242                    PartitionAggregateValue::Max(m)
243                }
244                0x05 => {
245                    let (s, np) = read_f64(buf, p)?;
246                    let (c, np) = read_u64(buf, np)?;
247                    p = np;
248                    PartitionAggregateValue::Mean { sum: s, count: c }
249                }
250                t => {
251                    return Err(StreamError::Deserialization(format!(
252                        "unknown PartitionAggregateValue tag {t}"
253                    )));
254                }
255            };
256            state.put(key, v);
257        }
258        Ok(state)
259    }
260}
261
262// ─── Aggregator config / stats ───────────────────────────────────────────────
263
264/// Configuration for [`ExactlyOnceAggregator`].
265#[derive(Debug, Clone)]
266pub struct ExactlyOnceAggregatorConfig {
267    pub dedup: DeduplicationConfig,
268    /// Logical name used for the state key inside the backend.
269    pub state_key: String,
270}
271
272impl Default for ExactlyOnceAggregatorConfig {
273    fn default() -> Self {
274        Self {
275            dedup: DeduplicationConfig::default(),
276            state_key: "aggregator/state".to_string(),
277        }
278    }
279}
280
/// Runtime statistics for an [`ExactlyOnceAggregator`].
///
/// All counters start at zero via `Default`.
#[derive(Debug, Clone, Default)]
pub struct ExactlyOnceAggregatorStats {
    /// Events whose update was applied by `fold`.
    pub events_folded: u64,
    /// `fold` calls the processor did not apply (re-delivered ids).
    pub duplicates_filtered: u64,
    /// Successful `checkpoint` calls.
    pub checkpoints_taken: u64,
}
288
289// ─── ExactlyOnceAggregator ───────────────────────────────────────────────────
290
291/// Aggregator wrapper that guarantees exactly-once fold semantics.
292pub struct ExactlyOnceAggregator {
293    config: ExactlyOnceAggregatorConfig,
294    backend: Arc<dyn StateBackend>,
295    processor: ExactlyOnceProcessor,
296    state: PartitionAggregateState,
297    stats: ExactlyOnceAggregatorStats,
298}
299
300impl ExactlyOnceAggregator {
301    /// Create a new aggregator backed by `backend`.
302    pub fn new(config: ExactlyOnceAggregatorConfig, backend: Arc<dyn StateBackend>) -> Self {
303        let processor = ExactlyOnceProcessor::new(config.dedup.clone(), backend.clone());
304        Self {
305            config,
306            backend,
307            processor,
308            state: PartitionAggregateState::new(),
309            stats: ExactlyOnceAggregatorStats::default(),
310        }
311    }
312
313    /// Fold a single event into the aggregate state.
314    ///
315    /// `id` uniquely identifies the message (used for dedup).  The `update`
316    /// closure produces the *new* aggregate value for the partition key from
317    /// the previous value.
318    pub fn fold<F>(
319        &mut self,
320        id: MessageId,
321        partition_key: &str,
322        update: F,
323    ) -> Result<Option<PartitionAggregateValue>, StreamError>
324    where
325        F: FnOnce(Option<&PartitionAggregateValue>) -> PartitionAggregateValue,
326    {
327        let prev = self.state.get(partition_key).cloned();
328        let new_value_for_state = update(prev.as_ref());
329        let key_for_state = partition_key.to_string();
330        let value_for_dedup_apply = new_value_for_state.clone();
331        let state_key_bytes = self.config.state_key.as_bytes().to_vec();
332
333        // Encode the *post-update* state for the transaction.
334        let mut updated = self.state.clone();
335        updated.put(key_for_state.clone(), new_value_for_state.clone());
336        let encoded = updated.encode();
337
338        let result = self.processor.process(id, |txn| {
339            txn.add_state_change(state_key_bytes, encoded);
340            Ok(value_for_dedup_apply)
341        })?;
342
343        match result {
344            Some(applied) => {
345                self.state.put(key_for_state, applied.clone());
346                self.stats.events_folded += 1;
347                Ok(Some(applied))
348            }
349            None => {
350                self.stats.duplicates_filtered += 1;
351                Ok(None)
352            }
353        }
354    }
355
356    /// Look up the current aggregate value for a partition key.
357    pub fn get(&self, partition_key: &str) -> Option<&PartitionAggregateValue> {
358        self.state.get(partition_key)
359    }
360
361    /// Manually overwrite the value for a partition (used during recovery).
362    pub fn set(&mut self, partition_key: &str, value: PartitionAggregateValue) {
363        self.state.put(partition_key.to_string(), value);
364    }
365
366    /// Snapshot the current state into the backend (idempotent).
367    pub fn checkpoint(&mut self) -> Result<(), StreamError> {
368        let encoded = self.state.encode();
369        self.backend
370            .put(self.config.state_key.as_bytes(), &encoded)?;
371        self.stats.checkpoints_taken += 1;
372        Ok(())
373    }
374
375    /// Restore aggregate state from the backend (no-op if absent).
376    pub fn restore(&mut self) -> Result<(), StreamError> {
377        match self.backend.get(self.config.state_key.as_bytes())? {
378            Some(bytes) => {
379                let state = PartitionAggregateState::decode(&bytes)?;
380                self.state = state;
381                Ok(())
382            }
383            None => Ok(()),
384        }
385    }
386
387    /// Drop the in-memory aggregate state (test/recovery helper).
388    pub fn clear(&mut self) {
389        self.state = PartitionAggregateState::new();
390    }
391
392    /// Snapshot statistics.
393    pub fn stats(&self) -> &ExactlyOnceAggregatorStats {
394        &self.stats
395    }
396
397    /// Number of partitions currently tracked.
398    pub fn partition_count(&self) -> usize {
399        self.state.len()
400    }
401}
402
403// ─── Tests ───────────────────────────────────────────────────────────────────
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::distributed_state::InMemoryStateBackend;
    use crate::state::exactly_once::MessageId;

    /// Aggregator over a brand-new in-memory backend with the default config.
    fn fresh_aggregator() -> ExactlyOnceAggregator {
        let backend: Arc<dyn StateBackend> = Arc::new(InMemoryStateBackend::new());
        ExactlyOnceAggregator::new(ExactlyOnceAggregatorConfig::default(), backend)
    }

    /// Update function for counting tests: bump an existing `Count`,
    /// otherwise start a fresh one at 1.
    fn increment_count(prev: Option<&PartitionAggregateValue>) -> PartitionAggregateValue {
        if let Some(PartitionAggregateValue::Count(c)) = prev {
            PartitionAggregateValue::Count(*c + 1)
        } else {
            PartitionAggregateValue::Count(1)
        }
    }

    #[test]
    fn fold_increments_count_exactly_once() {
        let mut agg = fresh_aggregator();
        let id = MessageId::new("p", 0, 1);

        let first = agg.fold(id.clone(), "k", increment_count).expect("fold ok");
        assert_eq!(first, Some(PartitionAggregateValue::Count(1)));

        // Re-delivering the same id must not double-count.
        let replay = agg.fold(id, "k", increment_count).expect("fold ok");
        assert_eq!(replay, None);
        assert_eq!(agg.get("k"), Some(&PartitionAggregateValue::Count(1)));
        assert_eq!(agg.stats.duplicates_filtered, 1);
    }

    #[test]
    fn checkpoint_restore_roundtrip() {
        let mut agg = fresh_aggregator();
        for seq in 1..=5u64 {
            let delta = seq as f64;
            agg.fold(MessageId::new("p", 0, seq), "k1", |prev| match prev {
                Some(PartitionAggregateValue::Sum(s)) => PartitionAggregateValue::Sum(s + delta),
                _ => PartitionAggregateValue::Sum(delta),
            })
            .expect("fold ok");
        }
        // 1 + 2 + 3 + 4 + 5 = 15.
        assert_eq!(agg.get("k1"), Some(&PartitionAggregateValue::Sum(15.0)));

        agg.checkpoint().expect("checkpoint ok");
        agg.clear();
        assert!(agg.get("k1").is_none());
        agg.restore().expect("restore ok");
        assert_eq!(agg.get("k1"), Some(&PartitionAggregateValue::Sum(15.0)));
    }

    #[test]
    fn separate_partitions_isolated() {
        let mut agg = fresh_aggregator();
        for (seq, key, count) in [(1, "a", 1), (2, "b", 7)] {
            agg.fold(MessageId::new("p", 0, seq), key, |_| {
                PartitionAggregateValue::Count(count)
            })
            .expect("ok");
        }
        assert_eq!(agg.get("a"), Some(&PartitionAggregateValue::Count(1)));
        assert_eq!(agg.get("b"), Some(&PartitionAggregateValue::Count(7)));
        assert_eq!(agg.partition_count(), 2);
    }

    #[test]
    fn encode_decode_round_trip() {
        let fixtures = [
            ("a", PartitionAggregateValue::Count(42)),
            ("b", PartitionAggregateValue::Sum(3.5)),
            ("c", PartitionAggregateValue::Min(-1.0)),
            ("d", PartitionAggregateValue::Max(99.0)),
        ];
        let mut original = PartitionAggregateState::new();
        for (key, value) in fixtures.iter() {
            original.put(*key, value.clone());
        }
        original.put(
            "mean_e",
            PartitionAggregateValue::Mean {
                sum: 100.0,
                count: 4,
            },
        );

        let decoded = PartitionAggregateState::decode(&original.encode()).expect("decode");
        assert_eq!(decoded.len(), 5);
        for (key, value) in fixtures.iter() {
            assert_eq!(decoded.get(key), Some(value));
        }
        match decoded.get("mean_e") {
            Some(PartitionAggregateValue::Mean { sum, count }) => {
                assert!((sum - 100.0).abs() < 1e-9);
                assert_eq!(*count, 4);
            }
            other => panic!("expected Mean, got {other:?}"),
        }
    }

    #[test]
    fn checkpoint_after_dedup_does_not_double_apply() {
        let mut agg = fresh_aggregator();
        agg.fold(MessageId::new("p", 0, 1), "k", |_| PartitionAggregateValue::Count(5))
            .expect("ok");
        agg.checkpoint().expect("ok");

        // Simulate crash recovery: a brand-new aggregator over the same backend.
        let mut recovered = ExactlyOnceAggregator::new(
            ExactlyOnceAggregatorConfig::default(),
            Arc::clone(&agg.backend),
        );
        recovered.restore().expect("ok");
        assert_eq!(recovered.get("k"), Some(&PartitionAggregateValue::Count(5)));
    }
}