// laminar_sql/datafusion/watermark_filter.rs

//! Dynamic watermark filter for scan-level late-data pruning
//!
//! Pushes a `ts >= watermark` predicate down to `StreamingScanExec` so
//! late rows are dropped before expression evaluation. The shared
//! `Arc<AtomicI64>` watermark is the same one Ring 0 already updates.

use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion::physical_plan::RecordBatchStream;
use datafusion_common::DataFusionError;
use futures::Stream;
use laminar_core::time::{filter_batch_by_timestamp, ThresholdOp};

/// Dynamic filter that drops rows older than the current watermark.
///
/// Holds a shared watermark atomic (same as [`super::watermark_udf::WatermarkUdf`])
/// and a monotonic generation counter that increments on each watermark
/// advance. The generation lets downstream consumers detect stale state
/// without comparing full watermark values.
pub struct WatermarkDynamicFilter {
    /// Current watermark in epoch milliseconds; values < 0 mean
    /// "uninitialized" (see `filter_batch`, which passes everything through).
    watermark_ms: Arc<AtomicI64>,
    /// Bumped once per successful `advance_watermark` call.
    generation: Arc<AtomicU64>,
    /// Name of the event-time column looked up in each `RecordBatch`.
    time_column: String,
}

32impl Debug for WatermarkDynamicFilter {
33    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("WatermarkDynamicFilter")
35            .field("watermark_ms", &self.watermark_ms.load(Ordering::Acquire))
36            .field("generation", &self.generation.load(Ordering::Acquire))
37            .field("time_column", &self.time_column)
38            .finish()
39    }
40}
41
42impl WatermarkDynamicFilter {
43    /// Creates a new watermark filter.
44    ///
45    /// # Arguments
46    ///
47    /// * `watermark_ms` - Shared atomic holding the current watermark
48    ///   in epoch milliseconds. Values < 0 mean "uninitialized".
49    /// * `generation` - Monotonic counter incremented on each advance.
50    /// * `time_column` - Name of the event-time column in record batches.
51    pub fn new(
52        watermark_ms: Arc<AtomicI64>,
53        generation: Arc<AtomicU64>,
54        time_column: String,
55    ) -> Self {
56        Self {
57            watermark_ms,
58            generation,
59            time_column,
60        }
61    }
62
63    /// Advances the watermark if `new_ms` exceeds the current value.
64    ///
65    /// On a successful advance the generation counter is incremented.
66    /// No-op when `new_ms <= current`.
67    pub fn advance_watermark(&self, new_ms: i64) {
68        // Atomic compare-and-swap loop to avoid TOCTOU race where two
69        // concurrent callers both see old < new_ms and clobber each other.
70        loop {
71            let old = self.watermark_ms.load(Ordering::Acquire);
72            if new_ms <= old {
73                break;
74            }
75            if self
76                .watermark_ms
77                .compare_exchange(old, new_ms, Ordering::Release, Ordering::Acquire)
78                .is_ok()
79            {
80                self.generation.fetch_add(1, Ordering::Release);
81                break;
82            }
83        }
84    }
85
86    /// Returns the current generation (monotonically increasing).
87    #[must_use]
88    pub fn generation(&self) -> u64 {
89        self.generation.load(Ordering::Acquire)
90    }
91
92    /// Returns the current watermark in epoch milliseconds.
93    #[must_use]
94    pub fn watermark_ms(&self) -> i64 {
95        self.watermark_ms.load(Ordering::Acquire)
96    }
97
98    /// Keep only rows where `time_column >= watermark`. Returns
99    /// `Ok(None)` when nothing survives; passes through untouched
100    /// while the watermark is uninitialised (< 0).
101    ///
102    /// # Errors
103    ///
104    /// `DataFusionError::Plan` when `time_column` is missing or isn't
105    /// a `Timestamp(_)` (propagated from [`filter_batch_by_timestamp`]).
106    pub fn filter_batch(
107        &self,
108        batch: &RecordBatch,
109    ) -> Result<Option<RecordBatch>, DataFusionError> {
110        let wm = self.watermark_ms.load(Ordering::Acquire);
111        if wm < 0 {
112            return Ok(Some(batch.clone()));
113        }
114
115        filter_batch_by_timestamp(batch, &self.time_column, wm, ThresholdOp::GreaterEq)
116            .map_err(|e| DataFusionError::Plan(format!("watermark filter: {e}")))
117    }
118}
119
/// Stream wrapper that applies watermark filtering to each batch.
///
/// Wraps a `SendableRecordBatchStream` and drops rows older than the
/// current watermark before passing them downstream. Follows the same
/// pattern as `ProjectingStream` in `channel_source.rs`.
pub(crate) struct WatermarkFilterStream {
    /// Upstream batch source, erased to a plain `Stream` of results.
    inner: Pin<Box<dyn Stream<Item = Result<RecordBatch, DataFusionError>> + Send>>,
    /// Shared filter holding the live watermark atomic.
    filter: Arc<WatermarkDynamicFilter>,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
}

131impl WatermarkFilterStream {
132    /// Creates a new watermark-filtered stream.
133    pub fn new(
134        inner: datafusion::execution::SendableRecordBatchStream,
135        filter: Arc<WatermarkDynamicFilter>,
136        schema: SchemaRef,
137    ) -> Self {
138        Self {
139            inner,
140            filter,
141            schema,
142        }
143    }
144}
145
146impl Debug for WatermarkFilterStream {
147    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
148        f.debug_struct("WatermarkFilterStream")
149            .field("filter", &self.filter)
150            .field("schema", &self.schema)
151            .finish_non_exhaustive()
152    }
153}
154
155impl Stream for WatermarkFilterStream {
156    type Item = Result<RecordBatch, DataFusionError>;
157
158    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
159        loop {
160            match self.inner.as_mut().poll_next(cx) {
161                Poll::Ready(Some(Ok(batch))) => match self.filter.filter_batch(&batch) {
162                    Ok(Some(filtered)) => return Poll::Ready(Some(Ok(filtered))),
163                    Ok(None) => {
164                        // All rows filtered out — loop to try the next batch
165                    }
166                    Err(e) => return Poll::Ready(Some(Err(e))),
167                },
168                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
169                Poll::Ready(None) => return Poll::Ready(None),
170                Poll::Pending => return Poll::Pending,
171            }
172        }
173    }
174}
175
176impl RecordBatchStream for WatermarkFilterStream {
177    fn schema(&self) -> SchemaRef {
178        Arc::clone(&self.schema)
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, TimestampMillisecondArray, TimestampNanosecondArray};
    use arrow_schema::{DataType, Field, Schema, TimeUnit};

    /// Builds the sequential `value` column (0..n) paired with timestamps
    /// in both batch constructors below.
    fn make_values(n: usize) -> Int64Array {
        #[allow(clippy::cast_possible_wrap)]
        let values: Vec<i64> = (0..n as i64).collect();
        Int64Array::from(values)
    }

    /// Two-column batch: millisecond `ts` plus a sequential `value` column.
    fn make_millis_batch(timestamps: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Int64, false),
        ]));
        let values = make_values(timestamps.len());
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampMillisecondArray::from(timestamps)),
                Arc::new(values),
            ],
        )
        .unwrap()
    }

    /// Two-column batch: nanosecond `ts` plus a sequential `value` column.
    fn make_nanos_batch(timestamps: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false),
            Field::new("value", DataType::Int64, false),
        ]));
        let values = make_values(timestamps.len());
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampNanosecondArray::from(timestamps)),
                Arc::new(values),
            ],
        )
        .unwrap()
    }

    /// Filter over column `ts` with the watermark preset to `wm` ms.
    fn make_filter(wm: i64) -> WatermarkDynamicFilter {
        WatermarkDynamicFilter::new(
            Arc::new(AtomicI64::new(wm)),
            Arc::new(AtomicU64::new(0)),
            "ts".to_string(),
        )
    }

    #[test]
    fn test_filter_skips_late_data() {
        let filter = make_filter(250);
        let batch = make_millis_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 2);
        let ts = result
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        assert_eq!(ts.value(0), 300);
        assert_eq!(ts.value(1), 400);
    }

    #[test]
    fn test_filter_passes_on_time_data() {
        let filter = make_filter(50);
        let batch = make_millis_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 4);
    }

    #[test]
    fn test_generation_increments_on_advance() {
        let filter = make_filter(100);
        assert_eq!(filter.generation(), 0);
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 1);
        assert_eq!(filter.watermark_ms(), 200);
        filter.advance_watermark(300);
        assert_eq!(filter.generation(), 2);
    }

    #[test]
    fn test_no_advance_no_generation_change() {
        let filter = make_filter(200);
        assert_eq!(filter.generation(), 0);
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 0);
        filter.advance_watermark(100);
        assert_eq!(filter.generation(), 0);
        assert_eq!(filter.watermark_ms(), 200);
    }

    #[test]
    fn test_passes_all_when_uninitialized() {
        let filter = make_filter(-1);
        let batch = make_millis_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 4);
    }

    /// All four rows are older than the watermark, so the entire batch is
    /// filtered out and `filter_batch` reports `None`.
    /// (Renamed from `test_empty_batch_returns_none`: the *input* batch is
    /// not empty — it is the filtered result that is.)
    #[test]
    fn test_fully_late_batch_returns_none() {
        let filter = make_filter(500);
        let batch = make_millis_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap();
        assert!(result.is_none());
    }

    /// Watermark 250 ms → threshold 250_000_000 ns for a Nanosecond column.
    #[test]
    fn test_nanosecond_timestamp_rescaled_to_watermark() {
        let filter = make_filter(250);
        let batch = make_nanos_batch(vec![100_000_000, 200_000_000, 300_000_000, 400_000_000]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 2);
        let ts = result
            .column(0)
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .unwrap();
        assert_eq!(ts.value(0), 300_000_000);
        assert_eq!(ts.value(1), 400_000_000);
    }
}