1use std::fmt::{Debug, Formatter};
8use std::pin::Pin;
9use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use arrow::compute::kernels::cmp::gt_eq;
14use arrow_array::cast::AsArray;
15use arrow_array::types::TimestampMillisecondType;
16use arrow_array::{Int64Array, RecordBatch};
17use arrow_schema::{DataType, SchemaRef, TimeUnit};
18use datafusion::physical_plan::RecordBatchStream;
19use datafusion_common::DataFusionError;
20use futures::Stream;
21
/// Row filter driven by a shared, atomically updated watermark.
///
/// The watermark and its generation counter are held behind `Arc`s, so the
/// code that constructs the filter can keep its own handles to the same
/// atomics.
pub struct WatermarkDynamicFilter {
    /// Current watermark value; compared against the time column in
    /// `filter_batch`. A negative value disables filtering there.
    watermark_ms: Arc<AtomicI64>,
    /// Counter incremented each time `advance_watermark` raises the watermark.
    generation: Arc<AtomicU64>,
    /// Name of the column looked up in each batch's schema.
    time_column: String,
}
33
34impl Debug for WatermarkDynamicFilter {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("WatermarkDynamicFilter")
37 .field("watermark_ms", &self.watermark_ms.load(Ordering::Acquire))
38 .field("generation", &self.generation.load(Ordering::Acquire))
39 .field("time_column", &self.time_column)
40 .finish()
41 }
42}
43
44impl WatermarkDynamicFilter {
45 pub fn new(
54 watermark_ms: Arc<AtomicI64>,
55 generation: Arc<AtomicU64>,
56 time_column: String,
57 ) -> Self {
58 Self {
59 watermark_ms,
60 generation,
61 time_column,
62 }
63 }
64
65 pub fn advance_watermark(&self, new_ms: i64) {
70 loop {
73 let old = self.watermark_ms.load(Ordering::Acquire);
74 if new_ms <= old {
75 break;
76 }
77 if self
78 .watermark_ms
79 .compare_exchange(old, new_ms, Ordering::Release, Ordering::Acquire)
80 .is_ok()
81 {
82 self.generation.fetch_add(1, Ordering::Release);
83 break;
84 }
85 }
86 }
87
88 #[must_use]
90 pub fn generation(&self) -> u64 {
91 self.generation.load(Ordering::Acquire)
92 }
93
94 #[must_use]
96 pub fn watermark_ms(&self) -> i64 {
97 self.watermark_ms.load(Ordering::Acquire)
98 }
99
100 pub fn filter_batch(
111 &self,
112 batch: &RecordBatch,
113 ) -> Result<Option<RecordBatch>, DataFusionError> {
114 let wm = self.watermark_ms.load(Ordering::Acquire);
115 if wm < 0 {
116 return Ok(Some(batch.clone()));
117 }
118
119 let schema = batch.schema();
120 let col_idx = schema.index_of(&self.time_column).map_err(|_| {
121 DataFusionError::Plan(format!(
122 "watermark filter: time column '{}' not found in schema",
123 self.time_column
124 ))
125 })?;
126
127 let col = batch.column(col_idx);
128 let mask = match col.data_type() {
129 DataType::Int64 => {
130 let ts_array = col
131 .as_any()
132 .downcast_ref::<Int64Array>()
133 .ok_or_else(|| DataFusionError::Internal("expected Int64Array".to_string()))?;
134 let threshold = Int64Array::new_scalar(wm);
135 gt_eq(ts_array, &threshold)?
136 }
137 DataType::Timestamp(TimeUnit::Millisecond, _) => {
138 let ts_array = col.as_primitive::<TimestampMillisecondType>();
139 let threshold = arrow_array::TimestampMillisecondArray::new_scalar(wm);
140 gt_eq(ts_array, &threshold)?
141 }
142 other => {
143 return Err(DataFusionError::Plan(format!(
144 "watermark filter: unsupported time column type {other:?}, \
145 expected Int64 or Timestamp(Millisecond)"
146 )));
147 }
148 };
149
150 let filtered = arrow::compute::filter_record_batch(batch, &mask)?;
151 if filtered.num_rows() == 0 {
152 Ok(None)
153 } else {
154 Ok(Some(filtered))
155 }
156 }
157}
158
159pub(crate) struct WatermarkFilterStream {
165 inner: Pin<Box<dyn Stream<Item = Result<RecordBatch, DataFusionError>> + Send>>,
166 filter: Arc<WatermarkDynamicFilter>,
167 schema: SchemaRef,
168}
169
170impl WatermarkFilterStream {
171 pub fn new(
173 inner: datafusion::execution::SendableRecordBatchStream,
174 filter: Arc<WatermarkDynamicFilter>,
175 schema: SchemaRef,
176 ) -> Self {
177 Self {
178 inner,
179 filter,
180 schema,
181 }
182 }
183}
184
185impl Debug for WatermarkFilterStream {
186 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
187 f.debug_struct("WatermarkFilterStream")
188 .field("filter", &self.filter)
189 .field("schema", &self.schema)
190 .finish_non_exhaustive()
191 }
192}
193
194impl Stream for WatermarkFilterStream {
195 type Item = Result<RecordBatch, DataFusionError>;
196
197 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198 loop {
199 match self.inner.as_mut().poll_next(cx) {
200 Poll::Ready(Some(Ok(batch))) => match self.filter.filter_batch(&batch) {
201 Ok(Some(filtered)) => return Poll::Ready(Some(Ok(filtered))),
202 Ok(None) => {
203 }
205 Err(e) => return Poll::Ready(Some(Err(e))),
206 },
207 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
208 Poll::Ready(None) => return Poll::Ready(None),
209 Poll::Pending => return Poll::Pending,
210 }
211 }
212 }
213}
214
215impl RecordBatchStream for WatermarkFilterStream {
216 fn schema(&self) -> SchemaRef {
217 Arc::clone(&self.schema)
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::TimestampMillisecondArray;
    use arrow_schema::{Field, Schema};

    /// Assembles a two-column batch: the supplied `ts` column plus a `value`
    /// column holding `0..rows`.
    fn assemble_batch(
        ts_type: DataType,
        ts_column: arrow_array::ArrayRef,
        rows: usize,
    ) -> RecordBatch {
        let schema = Schema::new(vec![
            Field::new("ts", ts_type, false),
            Field::new("value", DataType::Int64, false),
        ]);
        let values = Int64Array::from_iter_values(0..i64::try_from(rows).unwrap());
        RecordBatch::try_new(Arc::new(schema), vec![ts_column, Arc::new(values)]).unwrap()
    }

    /// Batch whose `ts` column is a plain `Int64` array.
    fn make_int64_batch(timestamps: Vec<i64>) -> RecordBatch {
        let rows = timestamps.len();
        assemble_batch(
            DataType::Int64,
            Arc::new(Int64Array::from(timestamps)),
            rows,
        )
    }

    /// Batch whose `ts` column is `Timestamp(Millisecond, None)`.
    fn make_timestamp_batch(timestamps: Vec<i64>) -> RecordBatch {
        let rows = timestamps.len();
        assemble_batch(
            DataType::Timestamp(TimeUnit::Millisecond, None),
            Arc::new(TimestampMillisecondArray::from(timestamps)),
            rows,
        )
    }

    /// Filter on column `ts` starting at watermark `wm` and generation 0.
    fn make_filter(wm: i64) -> WatermarkDynamicFilter {
        let watermark = Arc::new(AtomicI64::new(wm));
        let generation = Arc::new(AtomicU64::new(0));
        WatermarkDynamicFilter::new(watermark, generation, "ts".to_string())
    }

    #[test]
    fn test_filter_skips_late_data() {
        let filter = make_filter(250);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);

        let kept = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(kept.num_rows(), 2);

        let ts = kept
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        let surviving: Vec<i64> = ts.values().iter().copied().collect();
        assert_eq!(surviving, vec![300, 400]);
    }

    #[test]
    fn test_filter_passes_on_time_data() {
        let filter = make_filter(50);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);

        let kept = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(kept.num_rows(), 4);
    }

    #[test]
    fn test_generation_increments_on_advance() {
        let filter = make_filter(100);
        assert_eq!(filter.generation(), 0);

        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 1);
        assert_eq!(filter.watermark_ms(), 200);

        filter.advance_watermark(300);
        assert_eq!(filter.generation(), 2);
    }

    #[test]
    fn test_no_advance_no_generation_change() {
        let filter = make_filter(200);
        assert_eq!(filter.generation(), 0);

        // Equal to the current watermark: nothing changes.
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 0);

        // Below the current watermark: nothing changes either.
        filter.advance_watermark(100);
        assert_eq!(filter.generation(), 0);
        assert_eq!(filter.watermark_ms(), 200);
    }

    #[test]
    fn test_passes_all_when_uninitialized() {
        let filter = make_filter(-1);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);

        let kept = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(kept.num_rows(), 4);
    }

    #[test]
    fn test_empty_batch_returns_none() {
        let filter = make_filter(500);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);

        // Every row is below the watermark, so nothing is left to return.
        assert!(filter.filter_batch(&batch).unwrap().is_none());
    }

    #[test]
    fn test_arrow_timestamp_type() {
        let filter = make_filter(250);
        let batch = make_timestamp_batch(vec![100, 200, 300, 400]);

        let kept = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(kept.num_rows(), 2);

        let ts = kept
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        let surviving: Vec<i64> = ts.values().iter().copied().collect();
        assert_eq!(surviving, vec![300, 400]);
    }
}