1use std::fmt::{Debug, Formatter};
8use std::pin::Pin;
9use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use arrow::compute::kernels::cmp::gt_eq;
14use arrow_array::cast::AsArray;
15use arrow_array::types::TimestampMillisecondType;
16use arrow_array::{Int64Array, RecordBatch};
17use arrow_schema::{DataType, SchemaRef, TimeUnit};
18use datafusion::physical_plan::RecordBatchStream;
19use datafusion_common::DataFusionError;
20use futures::Stream;
21
/// Filter that drops rows whose time column is older than a shared watermark.
///
/// The watermark and generation live behind `Arc`ed atomics, so the values can
/// be shared with (and observed by) whoever else holds clones of those `Arc`s.
pub struct WatermarkDynamicFilter {
    /// Current watermark in milliseconds; a negative value disables filtering.
    watermark_ms: Arc<AtomicI64>,
    /// Counter incremented each time the watermark is advanced through this filter.
    generation: Arc<AtomicU64>,
    /// Name of the batch column that is compared against the watermark.
    time_column: String,
}
33
34impl Debug for WatermarkDynamicFilter {
35 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
36 f.debug_struct("WatermarkDynamicFilter")
37 .field("watermark_ms", &self.watermark_ms.load(Ordering::Acquire))
38 .field("generation", &self.generation.load(Ordering::Acquire))
39 .field("time_column", &self.time_column)
40 .finish()
41 }
42}
43
44impl WatermarkDynamicFilter {
45 pub fn new(
54 watermark_ms: Arc<AtomicI64>,
55 generation: Arc<AtomicU64>,
56 time_column: String,
57 ) -> Self {
58 Self {
59 watermark_ms,
60 generation,
61 time_column,
62 }
63 }
64
65 pub fn advance_watermark(&self, new_ms: i64) {
70 let old = self.watermark_ms.load(Ordering::Acquire);
71 if new_ms > old {
72 self.watermark_ms.store(new_ms, Ordering::Release);
73 self.generation.fetch_add(1, Ordering::Release);
74 }
75 }
76
77 #[must_use]
79 pub fn generation(&self) -> u64 {
80 self.generation.load(Ordering::Acquire)
81 }
82
83 #[must_use]
85 pub fn watermark_ms(&self) -> i64 {
86 self.watermark_ms.load(Ordering::Acquire)
87 }
88
89 pub fn filter_batch(
100 &self,
101 batch: &RecordBatch,
102 ) -> Result<Option<RecordBatch>, DataFusionError> {
103 let wm = self.watermark_ms.load(Ordering::Acquire);
104 if wm < 0 {
105 return Ok(Some(batch.clone()));
106 }
107
108 let schema = batch.schema();
109 let col_idx = schema.index_of(&self.time_column).map_err(|_| {
110 DataFusionError::Plan(format!(
111 "watermark filter: time column '{}' not found in schema",
112 self.time_column
113 ))
114 })?;
115
116 let col = batch.column(col_idx);
117 let mask = match col.data_type() {
118 DataType::Int64 => {
119 let ts_array = col
120 .as_any()
121 .downcast_ref::<Int64Array>()
122 .ok_or_else(|| DataFusionError::Internal("expected Int64Array".to_string()))?;
123 let threshold = Int64Array::new_scalar(wm);
124 gt_eq(ts_array, &threshold)?
125 }
126 DataType::Timestamp(TimeUnit::Millisecond, _) => {
127 let ts_array = col.as_primitive::<TimestampMillisecondType>();
128 let threshold = arrow_array::TimestampMillisecondArray::new_scalar(wm);
129 gt_eq(ts_array, &threshold)?
130 }
131 other => {
132 return Err(DataFusionError::Plan(format!(
133 "watermark filter: unsupported time column type {other:?}, \
134 expected Int64 or Timestamp(Millisecond)"
135 )));
136 }
137 };
138
139 let filtered = arrow::compute::filter_record_batch(batch, &mask)?;
140 if filtered.num_rows() == 0 {
141 Ok(None)
142 } else {
143 Ok(Some(filtered))
144 }
145 }
146}
147
/// Stream adapter that runs every batch of an inner stream through a
/// [`WatermarkDynamicFilter`] before yielding it.
pub(crate) struct WatermarkFilterStream {
    /// Upstream source of record batches.
    inner: Pin<Box<dyn Stream<Item = Result<RecordBatch, DataFusionError>> + Send>>,
    /// Filter applied to each batch produced by `inner`.
    filter: Arc<WatermarkDynamicFilter>,
    /// Schema reported through `RecordBatchStream::schema`.
    schema: SchemaRef,
}
158
159impl WatermarkFilterStream {
160 pub fn new(
162 inner: datafusion::execution::SendableRecordBatchStream,
163 filter: Arc<WatermarkDynamicFilter>,
164 schema: SchemaRef,
165 ) -> Self {
166 Self {
167 inner,
168 filter,
169 schema,
170 }
171 }
172}
173
174impl Debug for WatermarkFilterStream {
175 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
176 f.debug_struct("WatermarkFilterStream")
177 .field("filter", &self.filter)
178 .field("schema", &self.schema)
179 .finish_non_exhaustive()
180 }
181}
182
183impl Stream for WatermarkFilterStream {
184 type Item = Result<RecordBatch, DataFusionError>;
185
186 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
187 loop {
188 match self.inner.as_mut().poll_next(cx) {
189 Poll::Ready(Some(Ok(batch))) => match self.filter.filter_batch(&batch) {
190 Ok(Some(filtered)) => return Poll::Ready(Some(Ok(filtered))),
191 Ok(None) => {
192 }
194 Err(e) => return Poll::Ready(Some(Err(e))),
195 },
196 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
197 Poll::Ready(None) => return Poll::Ready(None),
198 Poll::Pending => return Poll::Pending,
199 }
200 }
201 }
202}
203
204impl RecordBatchStream for WatermarkFilterStream {
205 fn schema(&self) -> SchemaRef {
206 Arc::clone(&self.schema)
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{ArrayRef, Int32Array, TimestampMillisecondArray};
    use arrow_schema::{Field, Schema};

    /// Builds a two-column batch: `ts` (the given type/array) and `value`
    /// (`Int64`, holding the row index).
    fn make_batch(ts_type: DataType, ts: ArrayRef) -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![
            Field::new("ts", ts_type, false),
            Field::new("value", DataType::Int64, false),
        ]));
        // Row indices 0..len, produced without a usize -> i64 cast.
        let values: Vec<i64> = (0_i64..).take(ts.len()).collect();
        RecordBatch::try_new(schema, vec![ts, Arc::new(Int64Array::from(values))]).unwrap()
    }

    /// Batch whose `ts` column is a plain `Int64`.
    fn make_int64_batch(timestamps: Vec<i64>) -> RecordBatch {
        make_batch(DataType::Int64, Arc::new(Int64Array::from(timestamps)))
    }

    /// Batch whose `ts` column is `Timestamp(Millisecond, None)`.
    fn make_timestamp_batch(timestamps: Vec<i64>) -> RecordBatch {
        make_batch(
            DataType::Timestamp(TimeUnit::Millisecond, None),
            Arc::new(TimestampMillisecondArray::from(timestamps)),
        )
    }

    /// Filter on column `ts` with the given initial watermark and generation 0.
    fn make_filter(wm: i64) -> WatermarkDynamicFilter {
        WatermarkDynamicFilter::new(
            Arc::new(AtomicI64::new(wm)),
            Arc::new(AtomicU64::new(0)),
            "ts".to_string(),
        )
    }

    #[test]
    fn test_filter_skips_late_data() {
        let filter = make_filter(250);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 2);
        let ts = result
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ts.value(0), 300);
        assert_eq!(ts.value(1), 400);
    }

    #[test]
    fn test_filter_passes_on_time_data() {
        let filter = make_filter(50);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 4);
    }

    /// Rows exactly at the watermark are kept (the comparison is `>=`).
    #[test]
    fn test_filter_keeps_rows_equal_to_watermark() {
        let filter = make_filter(200);
        let batch = make_int64_batch(vec![100, 200, 300]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 2);
        let ts = result
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(ts.value(0), 200);
        assert_eq!(ts.value(1), 300);
    }

    #[test]
    fn test_generation_increments_on_advance() {
        let filter = make_filter(100);
        assert_eq!(filter.generation(), 0);
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 1);
        assert_eq!(filter.watermark_ms(), 200);
        filter.advance_watermark(300);
        assert_eq!(filter.generation(), 2);
    }

    #[test]
    fn test_no_advance_no_generation_change() {
        let filter = make_filter(200);
        assert_eq!(filter.generation(), 0);
        // Equal value: not an advance.
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 0);
        // Smaller value: the watermark must not regress.
        filter.advance_watermark(100);
        assert_eq!(filter.generation(), 0);
        assert_eq!(filter.watermark_ms(), 200);
    }

    #[test]
    fn test_passes_all_when_uninitialized() {
        let filter = make_filter(-1);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 4);
    }

    /// A batch whose rows are all older than the watermark yields `None`.
    #[test]
    fn test_empty_batch_returns_none() {
        let filter = make_filter(500);
        let batch = make_int64_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn test_arrow_timestamp_type() {
        let filter = make_filter(250);
        let batch = make_timestamp_batch(vec![100, 200, 300, 400]);
        let result = filter.filter_batch(&batch).unwrap().unwrap();
        assert_eq!(result.num_rows(), 2);
        let ts = result
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        assert_eq!(ts.value(0), 300);
        assert_eq!(ts.value(1), 400);
    }

    /// A configured time column that is absent from the batch is a plan error.
    #[test]
    fn test_missing_time_column_errors() {
        let filter = WatermarkDynamicFilter::new(
            Arc::new(AtomicI64::new(0)),
            Arc::new(AtomicU64::new(0)),
            "event_time".to_string(),
        );
        let batch = make_int64_batch(vec![1, 2, 3]);
        let err = filter.filter_batch(&batch).unwrap_err();
        assert!(matches!(
            &err,
            DataFusionError::Plan(msg) if msg.contains("event_time")
        ));
    }

    /// A time column that is neither Int64 nor Timestamp(Millisecond) is a plan error.
    #[test]
    fn test_unsupported_time_column_type_errors() {
        let batch = make_batch(
            DataType::Int32,
            Arc::new(Int32Array::from(vec![1, 2, 3])),
        );
        let err = make_filter(0).filter_batch(&batch).unwrap_err();
        assert!(matches!(&err, DataFusionError::Plan(_)));
    }
}