use std::collections::HashMap;
use crate::error::StreamError;
use super::WatermarkAligner;
/// Name-based identifier of an operator in the streaming topology.
///
/// A thin newtype over `String`; equality and hashing are those of the
/// wrapped name, which makes the id usable as a `HashMap` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct OperatorId(pub String);

impl OperatorId {
    /// Builds an id from anything that converts into a `String`.
    pub fn new(name: impl Into<String>) -> Self {
        let owned: String = name.into();
        OperatorId(owned)
    }

    /// Borrows the wrapped name as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

/// Lets string-like values (`&str`, `String`, ...) convert directly into an id.
impl<S: Into<String>> From<S> for OperatorId {
    fn from(s: S) -> Self {
        Self::new(s)
    }
}
/// Folds the watermarks of one operator's inputs into a single output
/// watermark that is never allowed to move backwards.
pub struct OperatorWatermarkAggregator {
    /// Identifier of the owning operator; reported in violation errors.
    operator_id: OperatorId,
    /// Per-input watermark state, keyed by the input's name.
    inputs: WatermarkAligner,
    /// Most recent watermark successfully emitted, `None` before the first one.
    last_emitted: Option<i64>,
}

impl OperatorWatermarkAggregator {
    /// Creates an aggregator for `operator_id` with no recorded inputs and
    /// nothing emitted yet.
    pub fn new(operator_id: impl Into<OperatorId>) -> Self {
        let operator_id = operator_id.into();
        let inputs = WatermarkAligner::new();
        Self {
            operator_id,
            inputs,
            last_emitted: None,
        }
    }

    /// Records `watermark_ms` for `input` and tries to emit a combined watermark.
    ///
    /// Returns `Ok(None)` while the aligner reports fewer than
    /// `expected_inputs` sources. Otherwise the aligner's combined watermark
    /// becomes the candidate; it is stored and returned as `Ok(Some(_))`.
    ///
    /// # Errors
    ///
    /// Returns [`StreamError::WatermarkViolation`] when the candidate is lower
    /// than the previously emitted watermark. `last_emitted` is left untouched
    /// in that case, but the update already applied to the aligner is not
    /// undone.
    pub fn observe(
        &mut self,
        input: &OperatorId,
        watermark_ms: i64,
        expected_inputs: usize,
    ) -> Result<Option<i64>, StreamError> {
        // The input is recorded first, before any readiness or ordering check.
        self.inputs.update(input.as_str(), watermark_ms);

        let seen_inputs = self.inputs.source_count();
        if seen_inputs < expected_inputs {
            return Ok(None);
        }

        // NOTE(review): presumably the minimum across recorded inputs — confirm
        // against `WatermarkAligner::global_watermark`.
        let candidate = self.inputs.global_watermark();
        match self.last_emitted {
            Some(prev) if candidate < prev => Err(StreamError::WatermarkViolation {
                operator_id: self.operator_id.0.clone(),
                reason: format!("candidate output watermark {candidate} < previous {prev}"),
            }),
            _ => {
                self.last_emitted = Some(candidate);
                Ok(Some(candidate))
            }
        }
    }

    /// Most recent watermark emitted by [`observe`](Self::observe), if any.
    pub fn last_emitted(&self) -> Option<i64> {
        self.last_emitted
    }

    /// Identifier of the operator this aggregator belongs to.
    pub fn operator_id(&self) -> &OperatorId {
        &self.operator_id
    }
}
pub struct WatermarkPropagator {
downstream_of: HashMap<OperatorId, Vec<OperatorId>>,
upstream_of: HashMap<OperatorId, Vec<OperatorId>>,
operators: HashMap<OperatorId, OperatorWatermarkAggregator>,
sinks: Vec<OperatorId>,
}
impl Default for WatermarkPropagator {
fn default() -> Self {
Self::new()
}
}
impl WatermarkPropagator {
pub fn new() -> Self {
Self {
downstream_of: HashMap::new(),
upstream_of: HashMap::new(),
operators: HashMap::new(),
sinks: Vec::new(),
}
}
pub fn add_operator(&mut self, operator: OperatorId) {
self.operators
.entry(operator.clone())
.or_insert_with(|| OperatorWatermarkAggregator::new(operator));
}
pub fn add_edge(&mut self, upstream: OperatorId, downstream: OperatorId) {
self.add_operator(upstream.clone());
self.add_operator(downstream.clone());
self.downstream_of
.entry(upstream.clone())
.or_default()
.push(downstream.clone());
self.upstream_of
.entry(downstream)
.or_default()
.push(upstream);
self.recompute_sinks();
}
fn recompute_sinks(&mut self) {
self.sinks = self
.operators
.keys()
.filter(|op| !self.downstream_of.contains_key(*op))
.cloned()
.collect();
}
pub fn push_source(
&mut self,
source: &OperatorId,
watermark_ms: i64,
) -> Result<Option<i64>, StreamError> {
if !self.operators.contains_key(source) {
self.add_operator(source.clone());
}
let agg = self
.operators
.get_mut(source)
.expect("source aggregator just added");
agg.observe(source, watermark_ms, 1)?;
let mut frontier: Vec<OperatorId> =
self.downstream_of.get(source).cloned().unwrap_or_default();
while let Some(op) = frontier.pop() {
let upstreams: Vec<OperatorId> = self.upstream_of.get(&op).cloned().unwrap_or_default();
if upstreams.is_empty() {
continue;
}
let mut readings: Vec<(OperatorId, i64)> = Vec::with_capacity(upstreams.len());
let mut all_ready = true;
for u in &upstreams {
match self.operators.get(u).and_then(|a| a.last_emitted()) {
Some(v) => readings.push((u.clone(), v)),
None => {
all_ready = false;
break;
}
}
}
if !all_ready {
continue;
}
let n_inputs = upstreams.len();
for (u, wm) in &readings {
let agg = self
.operators
.get_mut(&op)
.expect("operator known to topology");
agg.observe(u, *wm, n_inputs)?;
}
if let Some(ds) = self.downstream_of.get(&op) {
frontier.extend(ds.iter().cloned());
}
}
Ok(self.global_watermark())
}
pub fn global_watermark(&self) -> Option<i64> {
if self.sinks.is_empty() {
let mut min_v: Option<i64> = None;
for (_, agg) in self.operators.iter() {
let v = agg.last_emitted()?;
min_v = Some(min_v.map(|m| m.min(v)).unwrap_or(v));
}
return min_v;
}
let mut min_v: Option<i64> = None;
for sink in &self.sinks {
let v = self.operators.get(sink).and_then(|a| a.last_emitted())?;
min_v = Some(min_v.map(|m| m.min(v)).unwrap_or(v));
}
min_v
}
pub fn watermark_of(&self, op: &OperatorId) -> Option<i64> {
self.operators.get(op).and_then(|a| a.last_emitted())
}
pub fn operator_count(&self) -> usize {
self.operators.len()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an `OperatorId` from a literal.
    fn op(name: &str) -> OperatorId {
        OperatorId(name.to_string())
    }

    // ---- OperatorWatermarkAggregator ----

    #[test]
    fn aggregator_emits_min_across_inputs() {
        let mut merge = OperatorWatermarkAggregator::new("merge");
        let (left, right) = (op("a"), op("b"));

        // Only one of the two expected inputs has reported: nothing is emitted.
        let first = agg_observe(&mut merge, &left, 1_000, 2);
        assert_eq!(first, None);

        // Both inputs known: the smaller watermark wins.
        let second = agg_observe(&mut merge, &right, 800, 2);
        assert_eq!(second, Some(800));
    }

    /// Calls `observe` and unwraps the result, as the original tests do inline.
    fn agg_observe(
        agg: &mut OperatorWatermarkAggregator,
        input: &OperatorId,
        watermark_ms: i64,
        expected_inputs: usize,
    ) -> Option<i64> {
        agg.observe(input, watermark_ms, expected_inputs).unwrap()
    }

    #[test]
    fn aggregator_is_monotonic() {
        let mut agg = OperatorWatermarkAggregator::new("op");
        let input = op("a");
        agg.observe(&input, 5_000, 1).unwrap();

        // A lower watermark from the same input must be rejected.
        let failure = agg.observe(&input, 4_000, 1).expect_err("monotonic");
        match failure {
            StreamError::WatermarkViolation { operator_id, .. } => {
                assert_eq!(operator_id, "op");
            }
            other => panic!("expected WatermarkViolation, got {other:?}"),
        }
    }

    #[test]
    fn aggregator_equal_watermark_is_ok() {
        let mut agg = OperatorWatermarkAggregator::new("op");
        let input = op("a");
        agg.observe(&input, 1_000, 1).unwrap();

        // Repeating the same watermark is not a regression.
        let repeated = agg.observe(&input, 1_000, 1).unwrap();
        assert_eq!(repeated, Some(1_000));
    }

    // ---- WatermarkPropagator ----

    #[test]
    fn propagator_single_source_to_single_sink() {
        let source = op("source");
        let sink = op("sink");
        let mut propagator = WatermarkPropagator::new();
        propagator.add_edge(source.clone(), sink.clone());

        let global = propagator.push_source(&source, 1_000).unwrap();

        assert_eq!(global, Some(1_000));
        assert_eq!(propagator.watermark_of(&source), Some(1_000));
        assert_eq!(propagator.watermark_of(&sink), Some(1_000));
    }

    #[test]
    fn propagator_two_sources_take_min() {
        let (src_a, src_b) = (op("a"), op("b"));
        let join = op("j");
        let sink = op("sink");
        let mut propagator = WatermarkPropagator::new();
        propagator.add_edge(src_a.clone(), join.clone());
        propagator.add_edge(src_b.clone(), join.clone());
        propagator.add_edge(join.clone(), sink.clone());

        // The join stays silent until both of its inputs have reported.
        let global = propagator.push_source(&src_a, 1_000).unwrap();
        assert_eq!(propagator.watermark_of(&join), None);
        assert_eq!(global, None);

        // Once the second input arrives, the minimum flows through to the sink.
        let global = propagator.push_source(&src_b, 700).unwrap();
        assert_eq!(propagator.watermark_of(&join), Some(700));
        assert_eq!(propagator.watermark_of(&sink), Some(700));
        assert_eq!(global, Some(700));
    }

    #[test]
    fn propagator_global_watermark_is_min_across_sinks() {
        let source = op("src");
        let sink_a = op("sa");
        let sink_b = op("sb");
        let mut propagator = WatermarkPropagator::new();
        propagator.add_edge(source.clone(), sink_a.clone());
        propagator.add_edge(source.clone(), sink_b.clone());

        let global = propagator.push_source(&source, 5_000).unwrap();

        assert_eq!(global, Some(5_000));
        assert_eq!(propagator.watermark_of(&sink_a), Some(5_000));
        assert_eq!(propagator.watermark_of(&sink_b), Some(5_000));
    }

    #[test]
    fn propagator_is_monotonic_across_topology() {
        let source = op("src");
        let sink = op("snk");
        let mut propagator = WatermarkPropagator::new();
        propagator.add_edge(source.clone(), sink.clone());
        propagator.push_source(&source, 1_000).unwrap();

        // Pushing a lower watermark through the same source must fail.
        let failure = propagator.push_source(&source, 500).expect_err("monotonic");
        assert!(matches!(failure, StreamError::WatermarkViolation { .. }));
    }

    #[test]
    fn propagator_no_topology_treats_each_op_as_sink() {
        let solo = op("solo");
        let mut propagator = WatermarkPropagator::new();

        // No edges were declared; the source is registered on first push.
        let global = propagator.push_source(&solo, 42).unwrap();

        assert_eq!(global, Some(42));
        assert_eq!(propagator.operator_count(), 1);
    }
}