pub mod aggregate_bridge;
mod bridge;
mod channel_source;
mod exec;
pub mod execute;
mod source;
mod table_provider;
pub mod watermark_udf;
pub mod window_udf;

pub use aggregate_bridge::{
    create_aggregate_factory, lookup_aggregate_udf, result_to_scalar_value, scalar_value_to_result,
    DataFusionAccumulatorAdapter, DataFusionAggregateFactory,
};
pub use bridge::{BridgeSendError, BridgeSender, BridgeStream, BridgeTrySendError, StreamBridge};
pub use channel_source::ChannelStreamSource;
pub use exec::StreamingScanExec;
pub use execute::{execute_streaming_sql, DdlResult, QueryResult, StreamingSqlResult};
pub use source::{SortColumn, StreamSource, StreamSourceRef};
pub use table_provider::StreamingTableProvider;
pub use watermark_udf::WatermarkUdf;
pub use window_udf::{HopWindowStart, SessionWindowStart, TumbleWindowStart};

use std::sync::atomic::AtomicI64;
use std::sync::Arc;

use datafusion::prelude::*;
use datafusion_expr::ScalarUDF;
#[must_use]
pub fn create_streaming_context() -> SessionContext {
    let config = SessionConfig::new()
        .with_batch_size(8192)
        .with_target_partitions(1);
    let ctx = SessionContext::new_with_config(config);
    register_streaming_functions(&ctx);
    ctx
}

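/// Registers the streaming window UDFs (`tumble`, `hop`, `session`) and a
/// `watermark()` UDF in its unset state on an existing context.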
pub fn register_streaming_functions(ctx: &SessionContext) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::unset()));
}

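/// Like [`register_streaming_functions`], but backs `watermark()` with a
/// shared `AtomicI64` holding the current watermark in epoch milliseconds,
/// so the runtime can advance it while queries execute.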
pub fn register_streaming_functions_with_watermark(
    ctx: &SessionContext,
    watermark_ms: Arc<AtomicI64>,
) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::new(watermark_ms)));
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::FunctionRegistry;
    use futures::StreamExt;
    use std::sync::Arc;

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    fn take_test_sender(source: &ChannelStreamSource) -> super::bridge::BridgeSender {
        source.take_sender().expect("sender already taken")
    }

    fn test_batch(schema: &Arc<Schema>, ids: Vec<i64>, values: Vec<f64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Float64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_create_streaming_context() {
        let ctx = create_streaming_context();
        let state = ctx.state();
        let config = state.config();

        assert_eq!(config.batch_size(), 8192);
        assert_eq!(config.target_partitions(), 1);
    }

    #[tokio::test]
    async fn test_full_query_pipeline() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        sender
            .send(test_batch(&schema, vec![4, 5], vec![40.0, 50.0]))
            .await
            .unwrap();
        // Dropping the sender closes the channel, so the scan sees
        // end-of-stream and the collect below can complete.
        drop(sender);

        let df = ctx.sql("SELECT * FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 5);
    }

    #[tokio::test]
    async fn test_query_with_projection() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100.0, 200.0]))
            .await
            .unwrap();
        drop(sender);

        let df = ctx.sql("SELECT id FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_columns(), 1);
        assert_eq!(batches[0].schema().field(0).name(), "id");
    }

    #[tokio::test]
    async fn test_query_with_filter() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(
                &schema,
                vec![1, 2, 3, 4, 5],
                vec![10.0, 20.0, 30.0, 40.0, 50.0],
            ))
            .await
            .unwrap();
        drop(sender);

        let df = ctx
            .sql("SELECT * FROM events WHERE value > 25")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_unbounded_aggregation_rejected() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        drop(sender);

        let df = ctx.sql("SELECT COUNT(*) as cnt FROM events").await.unwrap();

        let result = df.collect().await;
        assert!(
            result.is_err(),
            "Aggregation on unbounded stream should fail"
        );
    }

    #[tokio::test]
    async fn test_query_with_order_by() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![3, 1, 2], vec![30.0, 10.0, 20.0]))
            .await
            .unwrap();
        drop(sender);

        let df = ctx.sql("SELECT id, value FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_bridge_throughput() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10000);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let batch_count = 1000;
        let batch = test_batch(&schema, vec![1, 2, 3, 4, 5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let send_task = tokio::spawn(async move {
            for _ in 0..batch_count {
                sender.send(batch.clone()).await.unwrap();
            }
        });

        let mut received = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            received += 1;
            if received == batch_count {
                break;
            }
        }

        send_task.await.unwrap();
        assert_eq!(received, batch_count);
    }

    #[test]
    fn test_streaming_functions_registered() {
        let ctx = create_streaming_context();
        assert!(ctx.udf("tumble").is_ok(), "tumble UDF not registered");
        assert!(ctx.udf("hop").is_ok(), "hop UDF not registered");
        assert!(ctx.udf("session").is_ok(), "session UDF not registered");
        assert!(ctx.udf("watermark").is_ok(), "watermark UDF not registered");
    }

    #[test]
    fn test_streaming_functions_with_watermark() {
        use std::sync::atomic::AtomicI64;

        let ctx = SessionContext::new();
        let wm = Arc::new(AtomicI64::new(42_000));
        register_streaming_functions_with_watermark(&ctx, wm);

        assert!(ctx.udf("tumble").is_ok());
        assert!(ctx.udf("watermark").is_ok());
    }

    #[tokio::test]
    async fn test_tumble_udf_via_datafusion() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    60_000i64, 120_000, 360_000,
                ])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as window_start, \
                 value \
                 FROM events",
            )
            .await
            .unwrap();

        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);

        let ws_col = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .expect("window_start should be TimestampMillisecond");
        // 60s and 120s fall in the [0, 300s) window; 360s falls in [300s, 600s).
        assert_eq!(ws_col.value(0), 0);
        assert_eq!(ws_col.value(1), 0);
        assert_eq!(ws_col.value(2), 300_000);
    }

    #[tokio::test]
    async fn test_logical_plan_from_windowed_query() {
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(schema));
        let _sender = source.take_sender();
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as w, \
                 COUNT(*) as cnt \
                 FROM events \
                 GROUP BY tumble(event_time, INTERVAL '5' MINUTE)",
            )
            .await;

        assert!(df.is_ok(), "Failed to create logical plan: {df:?}");
    }

    #[tokio::test]
    async fn test_end_to_end_execute_streaming_sql() {
        use crate::planner::StreamingPlanner;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("items", source);
        ctx.register_table("items", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let mut planner = StreamingPlanner::new();
        let result = execute_streaming_sql("SELECT id FROM items WHERE id > 1", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                let mut stream = qr.stream;
                let mut total = 0;
                while let Some(batch) = stream.next().await {
                    total += batch.unwrap().num_rows();
                }
                assert_eq!(total, 2);
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    #[tokio::test]
    async fn test_watermark_function_in_filter() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;
        use std::sync::atomic::AtomicI64;

        let config = SessionConfig::new()
            .with_batch_size(8192)
            .with_target_partitions(1);
        let ctx = SessionContext::new_with_config(config);
        // Pin the watermark at 200s; only rows with a strictly later
        // event_time should survive the filter below.
        let wm = Arc::new(AtomicI64::new(200_000));
        register_streaming_functions_with_watermark(&ctx, wm);

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    100_000i64, 200_000, 300_000,
                ])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let df = ctx
            .sql("SELECT value FROM events WHERE event_time > watermark()")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 1);
    }
}