1use std::fmt::{Debug, Formatter};
8use std::pin::Pin;
9use std::sync::atomic::{AtomicI64, AtomicU64, Ordering};
10use std::sync::Arc;
11use std::task::{Context, Poll};
12
13use arrow_array::RecordBatch;
14use arrow_schema::SchemaRef;
15use datafusion::physical_plan::RecordBatchStream;
16use datafusion_common::DataFusionError;
17use futures::Stream;
18use laminar_core::time::{filter_batch_by_timestamp, ThresholdOp};
19
/// Watermark-driven row filter whose state lives in shared atomics.
///
/// The watermark and the generation counter are held behind `Arc`s, so whoever
/// created them can keep their own handles to the same values.
pub struct WatermarkDynamicFilter {
    /// Current watermark. The unit tests compare it against millisecond
    /// timestamps; a negative value makes `filter_batch` pass batches through.
    watermark_ms: Arc<AtomicI64>,
    /// Counter incremented every time `advance_watermark` actually moves the
    /// watermark forward.
    generation: Arc<AtomicU64>,
    /// Name of the timestamp column handed to the batch-filtering helper.
    time_column: String,
}
31
32impl Debug for WatermarkDynamicFilter {
33 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("WatermarkDynamicFilter")
35 .field("watermark_ms", &self.watermark_ms.load(Ordering::Acquire))
36 .field("generation", &self.generation.load(Ordering::Acquire))
37 .field("time_column", &self.time_column)
38 .finish()
39 }
40}
41
42impl WatermarkDynamicFilter {
43 pub fn new(
52 watermark_ms: Arc<AtomicI64>,
53 generation: Arc<AtomicU64>,
54 time_column: String,
55 ) -> Self {
56 Self {
57 watermark_ms,
58 generation,
59 time_column,
60 }
61 }
62
63 pub fn advance_watermark(&self, new_ms: i64) {
68 loop {
71 let old = self.watermark_ms.load(Ordering::Acquire);
72 if new_ms <= old {
73 break;
74 }
75 if self
76 .watermark_ms
77 .compare_exchange(old, new_ms, Ordering::Release, Ordering::Acquire)
78 .is_ok()
79 {
80 self.generation.fetch_add(1, Ordering::Release);
81 break;
82 }
83 }
84 }
85
86 #[must_use]
88 pub fn generation(&self) -> u64 {
89 self.generation.load(Ordering::Acquire)
90 }
91
92 #[must_use]
94 pub fn watermark_ms(&self) -> i64 {
95 self.watermark_ms.load(Ordering::Acquire)
96 }
97
98 pub fn filter_batch(
107 &self,
108 batch: &RecordBatch,
109 ) -> Result<Option<RecordBatch>, DataFusionError> {
110 let wm = self.watermark_ms.load(Ordering::Acquire);
111 if wm < 0 {
112 return Ok(Some(batch.clone()));
113 }
114
115 filter_batch_by_timestamp(batch, &self.time_column, wm, ThresholdOp::GreaterEq)
116 .map_err(|e| DataFusionError::Plan(format!("watermark filter: {e}")))
117 }
118}
119
120pub(crate) struct WatermarkFilterStream {
126 inner: Pin<Box<dyn Stream<Item = Result<RecordBatch, DataFusionError>> + Send>>,
127 filter: Arc<WatermarkDynamicFilter>,
128 schema: SchemaRef,
129}
130
131impl WatermarkFilterStream {
132 pub fn new(
134 inner: datafusion::execution::SendableRecordBatchStream,
135 filter: Arc<WatermarkDynamicFilter>,
136 schema: SchemaRef,
137 ) -> Self {
138 Self {
139 inner,
140 filter,
141 schema,
142 }
143 }
144}
145
146impl Debug for WatermarkFilterStream {
147 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("WatermarkFilterStream")
149 .field("filter", &self.filter)
150 .field("schema", &self.schema)
151 .finish_non_exhaustive()
152 }
153}
154
155impl Stream for WatermarkFilterStream {
156 type Item = Result<RecordBatch, DataFusionError>;
157
158 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
159 loop {
160 match self.inner.as_mut().poll_next(cx) {
161 Poll::Ready(Some(Ok(batch))) => match self.filter.filter_batch(&batch) {
162 Ok(Some(filtered)) => return Poll::Ready(Some(Ok(filtered))),
163 Ok(None) => {
164 }
166 Err(e) => return Poll::Ready(Some(Err(e))),
167 },
168 Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
169 Poll::Ready(None) => return Poll::Ready(None),
170 Poll::Pending => return Poll::Pending,
171 }
172 }
173 }
174}
175
176impl RecordBatchStream for WatermarkFilterStream {
177 fn schema(&self) -> SchemaRef {
178 Arc::clone(&self.schema)
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Int64Array, TimestampMillisecondArray, TimestampNanosecondArray};
    use arrow_schema::{DataType, Field, Schema, TimeUnit};

    /// Two-column fixture schema: a non-nullable `ts` timestamp in `unit`
    /// (no timezone) and a non-nullable `Int64` `value` column.
    fn fixture_schema(unit: TimeUnit) -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("ts", DataType::Timestamp(unit, None), false),
            Field::new("value", DataType::Int64, false),
        ]))
    }

    /// Contents of the `value` column: the row indices `0..len`.
    fn row_index_column(len: usize) -> Int64Array {
        Int64Array::from((0_i64..).take(len).collect::<Vec<i64>>())
    }

    /// Builds a batch whose `ts` column holds millisecond timestamps.
    fn make_millis_batch(timestamps: Vec<i64>) -> RecordBatch {
        let values = row_index_column(timestamps.len());
        let ts = TimestampMillisecondArray::from(timestamps);
        RecordBatch::try_new(
            fixture_schema(TimeUnit::Millisecond),
            vec![Arc::new(ts), Arc::new(values)],
        )
        .unwrap()
    }

    /// Builds a batch whose `ts` column holds nanosecond timestamps.
    fn make_nanos_batch(timestamps: Vec<i64>) -> RecordBatch {
        let values = row_index_column(timestamps.len());
        let ts = TimestampNanosecondArray::from(timestamps);
        RecordBatch::try_new(
            fixture_schema(TimeUnit::Nanosecond),
            vec![Arc::new(ts), Arc::new(values)],
        )
        .unwrap()
    }

    /// Filter on column `ts` starting at watermark `wm` and generation 0.
    fn make_filter(wm: i64) -> WatermarkDynamicFilter {
        let watermark = Arc::new(AtomicI64::new(wm));
        let generation = Arc::new(AtomicU64::new(0));
        WatermarkDynamicFilter::new(watermark, generation, "ts".to_string())
    }

    #[test]
    fn test_filter_skips_late_data() {
        let filter = make_filter(250);
        let input = make_millis_batch(vec![100, 200, 300, 400]);

        let kept = filter.filter_batch(&input).unwrap().unwrap();

        assert_eq!(kept.num_rows(), 2);
        let ts = kept
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .unwrap();
        assert_eq!(ts.value(0), 300);
        assert_eq!(ts.value(1), 400);
    }

    #[test]
    fn test_filter_passes_on_time_data() {
        let filter = make_filter(50);
        let input = make_millis_batch(vec![100, 200, 300, 400]);

        let kept = filter.filter_batch(&input).unwrap().unwrap();

        assert_eq!(kept.num_rows(), 4);
    }

    #[test]
    fn test_generation_increments_on_advance() {
        let filter = make_filter(100);
        assert_eq!(filter.generation(), 0);

        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 1);
        assert_eq!(filter.watermark_ms(), 200);

        filter.advance_watermark(300);
        assert_eq!(filter.generation(), 2);
    }

    #[test]
    fn test_no_advance_no_generation_change() {
        let filter = make_filter(200);
        assert_eq!(filter.generation(), 0);

        // Equal to the current watermark: not an advance.
        filter.advance_watermark(200);
        assert_eq!(filter.generation(), 0);

        // Older than the current watermark: ignored.
        filter.advance_watermark(100);
        assert_eq!(filter.generation(), 0);
        assert_eq!(filter.watermark_ms(), 200);
    }

    #[test]
    fn test_passes_all_when_uninitialized() {
        let filter = make_filter(-1);
        let input = make_millis_batch(vec![100, 200, 300, 400]);

        let kept = filter.filter_batch(&input).unwrap().unwrap();

        assert_eq!(kept.num_rows(), 4);
    }

    #[test]
    fn test_empty_batch_returns_none() {
        let filter = make_filter(500);
        let input = make_millis_batch(vec![100, 200, 300, 400]);

        let outcome = filter.filter_batch(&input).unwrap();

        assert!(outcome.is_none());
    }

    #[test]
    fn test_nanosecond_timestamp_rescaled_to_watermark() {
        let filter = make_filter(250);
        let input = make_nanos_batch(vec![100_000_000, 200_000_000, 300_000_000, 400_000_000]);

        let kept = filter.filter_batch(&input).unwrap().unwrap();

        assert_eq!(kept.num_rows(), 2);
        let ts = kept
            .column(0)
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .unwrap();
        assert_eq!(ts.value(0), 300_000_000);
        assert_eq!(ts.value(1), 400_000_000);
    }
}