Skip to main content

laminar_sql/datafusion/
watermark_filter.rs

//! Dynamic watermark filter for scan-level late-data pruning.
//!
//! Pushes a `ts >= watermark` predicate down to `StreamingScanExec` so that
//! late rows are dropped before expression evaluation. The filter reads the
//! same shared `Arc<AtomicI64>` watermark that Ring 0 already updates.
6
use std::fmt::{Debug, Formatter};
use std::pin::Pin;
use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow::compute::kernels::cmp::gt_eq;
use arrow_array::cast::AsArray;
use arrow_array::types::TimestampMillisecondType;
use arrow_array::{Int64Array, RecordBatch};
use arrow_schema::{DataType, SchemaRef, TimeUnit};
use datafusion::physical_plan::RecordBatchStream;
use datafusion_common::DataFusionError;
use futures::Stream;
21
/// Dynamic filter that drops rows older than the current watermark.
///
/// Holds a shared watermark atomic (same as [`super::watermark_udf::WatermarkUdf`])
/// and a monotonic generation counter that increments on each watermark
/// advance. The generation lets downstream consumers detect stale state
/// without comparing full watermark values.
pub struct WatermarkDynamicFilter {
    /// Current watermark in epoch milliseconds; a negative value means
    /// "uninitialized" and disables filtering.
    watermark_ms: Arc<AtomicI64>,
    /// Counter incremented once for every successful watermark advance.
    generation: Arc<AtomicU64>,
    /// Name of the event-time column looked up in each record batch.
    time_column: String,
}
33
34impl Debug for WatermarkDynamicFilter {
35    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("WatermarkDynamicFilter")
37            .field("watermark_ms", &self.watermark_ms.load(Ordering::Acquire))
38            .field("generation", &self.generation.load(Ordering::Acquire))
39            .field("time_column", &self.time_column)
40            .finish()
41    }
42}
43
44impl WatermarkDynamicFilter {
45    /// Creates a new watermark filter.
46    ///
47    /// # Arguments
48    ///
49    /// * `watermark_ms` - Shared atomic holding the current watermark
50    ///   in epoch milliseconds. Values < 0 mean "uninitialized".
51    /// * `generation` - Monotonic counter incremented on each advance.
52    /// * `time_column` - Name of the event-time column in record batches.
53    pub fn new(
54        watermark_ms: Arc<AtomicI64>,
55        generation: Arc<AtomicU64>,
56        time_column: String,
57    ) -> Self {
58        Self {
59            watermark_ms,
60            generation,
61            time_column,
62        }
63    }
64
65    /// Advances the watermark if `new_ms` exceeds the current value.
66    ///
67    /// On a successful advance the generation counter is incremented.
68    /// No-op when `new_ms <= current`.
69    pub fn advance_watermark(&self, new_ms: i64) {
70        // Atomic compare-and-swap loop to avoid TOCTOU race where two
71        // concurrent callers both see old < new_ms and clobber each other.
72        loop {
73            let old = self.watermark_ms.load(Ordering::Acquire);
74            if new_ms <= old {
75                break;
76            }
77            if self
78                .watermark_ms
79                .compare_exchange(old, new_ms, Ordering::Release, Ordering::Acquire)
80                .is_ok()
81            {
82                self.generation.fetch_add(1, Ordering::Release);
83                break;
84            }
85        }
86    }
87
88    /// Returns the current generation (monotonically increasing).
89    #[must_use]
90    pub fn generation(&self) -> u64 {
91        self.generation.load(Ordering::Acquire)
92    }
93
94    /// Returns the current watermark in epoch milliseconds.
95    #[must_use]
96    pub fn watermark_ms(&self) -> i64 {
97        self.watermark_ms.load(Ordering::Acquire)
98    }
99
100    /// Filters a record batch, keeping only rows where `time_column >= watermark`.
101    ///
102    /// Returns `Ok(None)` when all rows are filtered out.
103    /// When watermark < 0 (uninitialized), all rows pass through.
104    ///
105    /// Handles both `Int64` (epoch millis) and `Timestamp(Millisecond, _)` columns.
106    ///
107    /// # Errors
108    ///
109    /// Returns an error if the time column is missing or has an unsupported type.
110    pub fn filter_batch(
111        &self,
112        batch: &RecordBatch,
113    ) -> Result<Option<RecordBatch>, DataFusionError> {
114        let wm = self.watermark_ms.load(Ordering::Acquire);
115        if wm < 0 {
116            return Ok(Some(batch.clone()));
117        }
118
119        let schema = batch.schema();
120        let col_idx = schema.index_of(&self.time_column).map_err(|_| {
121            DataFusionError::Plan(format!(
122                "watermark filter: time column '{}' not found in schema",
123                self.time_column
124            ))
125        })?;
126
127        let col = batch.column(col_idx);
128        let mask = match col.data_type() {
129            DataType::Int64 => {
130                let ts_array = col
131                    .as_any()
132                    .downcast_ref::<Int64Array>()
133                    .ok_or_else(|| DataFusionError::Internal("expected Int64Array".to_string()))?;
134                let threshold = Int64Array::new_scalar(wm);
135                gt_eq(ts_array, &threshold)?
136            }
137            DataType::Timestamp(TimeUnit::Millisecond, _) => {
138                let ts_array = col.as_primitive::<TimestampMillisecondType>();
139                let threshold = arrow_array::TimestampMillisecondArray::new_scalar(wm);
140                gt_eq(ts_array, &threshold)?
141            }
142            other => {
143                return Err(DataFusionError::Plan(format!(
144                    "watermark filter: unsupported time column type {other:?}, \
145                     expected Int64 or Timestamp(Millisecond)"
146                )));
147            }
148        };
149
150        let filtered = arrow::compute::filter_record_batch(batch, &mask)?;
151        if filtered.num_rows() == 0 {
152            Ok(None)
153        } else {
154            Ok(Some(filtered))
155        }
156    }
157}
158
159/// Stream wrapper that applies watermark filtering to each batch.
160///
161/// Wraps a `SendableRecordBatchStream` and drops rows older than the
162/// current watermark before passing them downstream. Follows the same
163/// pattern as `ProjectingStream` in `channel_source.rs`.
164pub(crate) struct WatermarkFilterStream {
165    inner: Pin<Box<dyn Stream<Item = Result<RecordBatch, DataFusionError>> + Send>>,
166    filter: Arc<WatermarkDynamicFilter>,
167    schema: SchemaRef,
168}
169
170impl WatermarkFilterStream {
171    /// Creates a new watermark-filtered stream.
172    pub fn new(
173        inner: datafusion::execution::SendableRecordBatchStream,
174        filter: Arc<WatermarkDynamicFilter>,
175        schema: SchemaRef,
176    ) -> Self {
177        Self {
178            inner,
179            filter,
180            schema,
181        }
182    }
183}
184
185impl Debug for WatermarkFilterStream {
186    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187        f.debug_struct("WatermarkFilterStream")
188            .field("filter", &self.filter)
189            .field("schema", &self.schema)
190            .finish_non_exhaustive()
191    }
192}
193
194impl Stream for WatermarkFilterStream {
195    type Item = Result<RecordBatch, DataFusionError>;
196
197    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198        loop {
199            match self.inner.as_mut().poll_next(cx) {
200                Poll::Ready(Some(Ok(batch))) => match self.filter.filter_batch(&batch) {
201                    Ok(Some(filtered)) => return Poll::Ready(Some(Ok(filtered))),
202                    Ok(None) => {
203                        // All rows filtered out — loop to try the next batch
204                    }
205                    Err(e) => return Poll::Ready(Some(Err(e))),
206                },
207                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
208                Poll::Ready(None) => return Poll::Ready(None),
209                Poll::Pending => return Poll::Pending,
210            }
211        }
212    }
213}
214
215impl RecordBatchStream for WatermarkFilterStream {
216    fn schema(&self) -> SchemaRef {
217        Arc::clone(&self.schema)
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::TimestampMillisecondArray;
    use arrow_schema::{Field, Schema};

    /// Returns `0, 1, .., len - 1`, used as the payload column of test batches.
    fn sequential_values(len: usize) -> Vec<i64> {
        (0_i64..).take(len).collect()
    }

    /// Builds a two-column batch whose `ts` column is plain `Int64` epoch millis.
    fn make_int64_batch(timestamps: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ]));
        let values = sequential_values(timestamps.len());
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(Int64Array::from(timestamps)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    /// Builds a two-column batch whose `ts` column is `Timestamp(Millisecond, None)`.
    fn make_timestamp_batch(timestamps: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "ts",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Int64, false),
        ]));
        let values = sequential_values(timestamps.len());
        RecordBatch::try_new(
            schema,
            vec![
                Arc::new(TimestampMillisecondArray::from(timestamps)),
                Arc::new(Int64Array::from(values)),
            ],
        )
        .unwrap()
    }

    /// Builds a filter on column `ts` with the given initial watermark and
    /// generation 0.
    fn make_filter(wm: i64) -> WatermarkDynamicFilter {
        WatermarkDynamicFilter::new(
            Arc::new(AtomicI64::new(wm)),
            Arc::new(AtomicU64::new(0)),
            "ts".to_string(),
        )
    }

    #[test]
    fn test_filter_skips_late_data() {
        let filter = make_filter(250);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 2);
        let ts = result
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ts.value(0), 300);
        assert_eq!(ts.value(1), 400);
    }

    #[test]
    fn test_filter_passes_on_time_data() {
        let filter = make_filter(50);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 4);
    }

    #[test]
    fn test_generation_increments_on_advance() {
        let filter = make_filter(100);
        assert_eq!(filter.generation(), 0);
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 1);
        assert_eq!(filter.watermark_ms(), 200);
        filter.advance_watermark(300);
        assert_eq!(filter.generation(), 2);
    }

    #[test]
    fn test_no_advance_no_generation_change() {
        let filter = make_filter(200);
        assert_eq!(filter.generation(), 0);
        // Same value — no change
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 0);
        // Lower value — no change
        filter.advance_watermark(100);
        assert_eq!(filter.generation(), 0);
        assert_eq!(filter.watermark_ms(), 200);
    }

    #[test]
    fn test_passes_all_when_uninitialized() {
        let filter = make_filter(-1);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 4);
    }

    #[test]
    fn test_empty_batch_returns_none() {
        // Every row is older than the watermark, so the filtered batch is
        // empty and reported as `None`.
        let filter = make_filter(500);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_arrow_timestamp_type() {
        let filter = make_filter(250);
        let batch = make_timestamp_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 2);
        let ts = result
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        assert_eq!(ts.value(0), 300);
        assert_eq!(ts.value(1), 400);
    }
}