// laminar_sql/datafusion/watermark_filter.rs
//! Dynamic watermark filter for scan-level late-data pruning.
//!
//! Pushes a `ts >= watermark` predicate down to `StreamingScanExec` so
//! late rows are dropped before expression evaluation. The shared
//! `Arc<AtomicI64>` watermark is the same one Ring 0 already updates.

7use std::fmt::{Debug, Formatter};
8use std::pin::Pin;
9use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use arrow::compute::kernels::cmp::gt_eq;
14use arrow_array::cast::AsArray;
15use arrow_array::types::TimestampMillisecondType;
16use arrow_array::{Int64Array, RecordBatch};
17use arrow_schema::{DataType, SchemaRef, TimeUnit};
18use datafusion::physical_plan::RecordBatchStream;
19use datafusion_common::DataFusionError;
20use futures::Stream;
21
/// Dynamic filter that drops rows older than the current watermark.
///
/// Holds a shared watermark atomic (same as [`super::watermark_udf::WatermarkUdf`])
/// and a monotonic generation counter that increments on each watermark
/// advance. The generation lets downstream consumers detect stale state
/// without comparing full watermark values.
pub struct WatermarkDynamicFilter {
    /// Current watermark in epoch milliseconds; values < 0 mean "uninitialized"
    /// (see `filter_batch`, which passes everything through in that state).
    watermark_ms: Arc<AtomicI64>,
    /// Bumped once each time `advance_watermark` actually raises the watermark.
    generation: Arc<AtomicU64>,
    /// Name of the event-time column looked up in each `RecordBatch`.
    time_column: String,
}
33
34impl Debug for WatermarkDynamicFilter {
35    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36        f.debug_struct("WatermarkDynamicFilter")
37            .field("watermark_ms", &self.watermark_ms.load(Ordering::Acquire))
38            .field("generation", &self.generation.load(Ordering::Acquire))
39            .field("time_column", &self.time_column)
40            .finish()
41    }
42}
43
44impl WatermarkDynamicFilter {
45    /// Creates a new watermark filter.
46    ///
47    /// # Arguments
48    ///
49    /// * `watermark_ms` - Shared atomic holding the current watermark
50    ///   in epoch milliseconds. Values < 0 mean "uninitialized".
51    /// * `generation` - Monotonic counter incremented on each advance.
52    /// * `time_column` - Name of the event-time column in record batches.
53    pub fn new(
54        watermark_ms: Arc<AtomicI64>,
55        generation: Arc<AtomicU64>,
56        time_column: String,
57    ) -> Self {
58        Self {
59            watermark_ms,
60            generation,
61            time_column,
62        }
63    }
64
65    /// Advances the watermark if `new_ms` exceeds the current value.
66    ///
67    /// On a successful advance the generation counter is incremented.
68    /// No-op when `new_ms <= current`.
69    pub fn advance_watermark(&self, new_ms: i64) {
70        let old = self.watermark_ms.load(Ordering::Acquire);
71        if new_ms > old {
72            self.watermark_ms.store(new_ms, Ordering::Release);
73            self.generation.fetch_add(1, Ordering::Release);
74        }
75    }
76
77    /// Returns the current generation (monotonically increasing).
78    #[must_use]
79    pub fn generation(&self) -> u64 {
80        self.generation.load(Ordering::Acquire)
81    }
82
83    /// Returns the current watermark in epoch milliseconds.
84    #[must_use]
85    pub fn watermark_ms(&self) -> i64 {
86        self.watermark_ms.load(Ordering::Acquire)
87    }
88
89    /// Filters a record batch, keeping only rows where `time_column >= watermark`.
90    ///
91    /// Returns `Ok(None)` when all rows are filtered out.
92    /// When watermark < 0 (uninitialized), all rows pass through.
93    ///
94    /// Handles both `Int64` (epoch millis) and `Timestamp(Millisecond, _)` columns.
95    ///
96    /// # Errors
97    ///
98    /// Returns an error if the time column is missing or has an unsupported type.
99    pub fn filter_batch(
100        &self,
101        batch: &RecordBatch,
102    ) -> Result<Option<RecordBatch>, DataFusionError> {
103        let wm = self.watermark_ms.load(Ordering::Acquire);
104        if wm < 0 {
105            return Ok(Some(batch.clone()));
106        }
107
108        let schema = batch.schema();
109        let col_idx = schema.index_of(&self.time_column).map_err(|_| {
110            DataFusionError::Plan(format!(
111                "watermark filter: time column '{}' not found in schema",
112                self.time_column
113            ))
114        })?;
115
116        let col = batch.column(col_idx);
117        let mask = match col.data_type() {
118            DataType::Int64 => {
119                let ts_array = col
120                    .as_any()
121                    .downcast_ref::<Int64Array>()
122                    .ok_or_else(|| DataFusionError::Internal("expected Int64Array".to_string()))?;
123                let threshold = Int64Array::new_scalar(wm);
124                gt_eq(ts_array, &threshold)?
125            }
126            DataType::Timestamp(TimeUnit::Millisecond, _) => {
127                let ts_array = col.as_primitive::<TimestampMillisecondType>();
128                let threshold = arrow_array::TimestampMillisecondArray::new_scalar(wm);
129                gt_eq(ts_array, &threshold)?
130            }
131            other => {
132                return Err(DataFusionError::Plan(format!(
133                    "watermark filter: unsupported time column type {other:?}, \
134                     expected Int64 or Timestamp(Millisecond)"
135                )));
136            }
137        };
138
139        let filtered = arrow::compute::filter_record_batch(batch, &mask)?;
140        if filtered.num_rows() == 0 {
141            Ok(None)
142        } else {
143            Ok(Some(filtered))
144        }
145    }
146}
147
/// Stream wrapper that applies watermark filtering to each batch.
///
/// Wraps a `SendableRecordBatchStream` and drops rows older than the
/// current watermark before passing them downstream. Follows the same
/// pattern as `ProjectingStream` in `channel_source.rs`.
pub(crate) struct WatermarkFilterStream {
    /// Upstream batch source, boxed as a plain `Stream` of batch results.
    inner: Pin<Box<dyn Stream<Item = Result<RecordBatch, DataFusionError>> + Send>>,
    /// Shared filter consulted for every polled batch.
    filter: Arc<WatermarkDynamicFilter>,
    /// Schema reported to downstream consumers; filtering does not change it.
    schema: SchemaRef,
}
158
159impl WatermarkFilterStream {
160    /// Creates a new watermark-filtered stream.
161    pub fn new(
162        inner: datafusion::execution::SendableRecordBatchStream,
163        filter: Arc<WatermarkDynamicFilter>,
164        schema: SchemaRef,
165    ) -> Self {
166        Self {
167            inner,
168            filter,
169            schema,
170        }
171    }
172}
173
174impl Debug for WatermarkFilterStream {
175    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
176        f.debug_struct("WatermarkFilterStream")
177            .field("filter", &self.filter)
178            .field("schema", &self.schema)
179            .finish_non_exhaustive()
180    }
181}
182
183impl Stream for WatermarkFilterStream {
184    type Item = Result<RecordBatch, DataFusionError>;
185
186    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187        loop {
188            match self.inner.as_mut().poll_next(cx) {
189                Poll::Ready(Some(Ok(batch))) => match self.filter.filter_batch(&batch) {
190                    Ok(Some(filtered)) => return Poll::Ready(Some(Ok(filtered))),
191                    Ok(None) => {
192                        // All rows filtered out — loop to try the next batch
193                    }
194                    Err(e) => return Poll::Ready(Some(Err(e))),
195                },
196                Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
197                Poll::Ready(None) => return Poll::Ready(None),
198                Poll::Pending => return Poll::Pending,
199            }
200        }
201    }
202}
203
204impl RecordBatchStream for WatermarkFilterStream {
205    fn schema(&self) -> SchemaRef {
206        Arc::clone(&self.schema)
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::TimestampMillisecondArray;
    use arrow_schema::{Field, Schema};

    /// Builds a two-column batch: `ts: Int64` plus a `value` column that is
    /// just the row index (present so filtering exercises multiple columns).
    fn make_int64_batch(timestamps: Vec<i64>) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Int64, false),
            Field::new("value", DataType::Int64, false),
        ]));
        let row_count = i64::try_from(timestamps.len()).unwrap();
        let values: Vec<i64> = (0..row_count).collect();
        let columns: Vec<arrow_array::ArrayRef> = vec![
            Arc::new(Int64Array::from(timestamps)),
            Arc::new(Int64Array::from(values)),
        ];
        RecordBatch::try_new(schema, columns).unwrap()
    }

    /// Same as `make_int64_batch` but with `ts: Timestamp(Millisecond)`.
    fn make_timestamp_batch(timestamps: Vec<i64>) -> RecordBatch {
        let ts_field = Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Millisecond, None),
            false,
        );
        let schema = Arc::new(Schema::new(vec![
            ts_field,
            Field::new("value", DataType::Int64, false),
        ]));
        let row_count = i64::try_from(timestamps.len()).unwrap();
        let values: Vec<i64> = (0..row_count).collect();
        let columns: Vec<arrow_array::ArrayRef> = vec![
            Arc::new(TimestampMillisecondArray::from(timestamps)),
            Arc::new(Int64Array::from(values)),
        ];
        RecordBatch::try_new(schema, columns).unwrap()
    }

    /// Convenience constructor: filter on column `ts` with watermark `wm`.
    fn make_filter(wm: i64) -> WatermarkDynamicFilter {
        let watermark = Arc::new(AtomicI64::new(wm));
        let generation = Arc::new(AtomicU64::new(0));
        WatermarkDynamicFilter::new(watermark, generation, "ts".to_string())
    }

    #[test]
    fn test_filter_skips_late_data() {
        let filter = make_filter(250);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let kept = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(kept.num_rows(), 2);
        let ts = kept
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        let surviving: Vec<i64> = (0..ts.len()).map(|i| ts.value(i)).collect();
        assert_eq!(surviving, vec![300, 400]);
    }

    #[test]
    fn test_filter_passes_on_time_data() {
        // Watermark below every timestamp: nothing is late.
        let filter = make_filter(50);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let kept = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(kept.num_rows(), 4);
    }

    #[test]
    fn test_generation_increments_on_advance() {
        let filter = make_filter(100);
        assert_eq!(filter.generation(), 0);
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 1);
        assert_eq!(filter.watermark_ms(), 200);
        filter.advance_watermark(300);
        assert_eq!(filter.generation(), 2);
    }

    #[test]
    fn test_no_advance_no_generation_change() {
        let filter = make_filter(200);
        assert_eq!(filter.generation(), 0);
        // Equal value is not an advance.
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 0);
        // Lower value is not an advance either.
        filter.advance_watermark(100);
        assert_eq!(filter.generation(), 0);
        assert_eq!(filter.watermark_ms(), 200);
    }

    #[test]
    fn test_passes_all_when_uninitialized() {
        // Negative watermark means "uninitialized": everything passes.
        let filter = make_filter(-1);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let kept = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(kept.num_rows(), 4);
    }

    #[test]
    fn test_empty_batch_returns_none() {
        // Watermark above every timestamp: all rows dropped -> None.
        let filter = make_filter(500);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        assert!(filter.filter_batch(&batch).unwrap().is_none());
    }

    #[test]
    fn test_arrow_timestamp_type() {
        let filter = make_filter(250);
        let batch = make_timestamp_batch(vec![100, 200, 300, 400]);
        let kept = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(kept.num_rows(), 2);
        let ts = kept
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        let surviving: Vec<i64> = (0..ts.len()).map(|i| ts.value(i)).collect();
        assert_eq!(surviving, vec![300, 400]);
    }
}
339}