pub mod aggregate_bridge;
mod bridge;
mod channel_source;
pub mod complex_type_lambda;
pub mod complex_type_udf;
mod exec;
pub mod execute;
pub mod format_bridge_udf;
pub mod json_extensions;
pub mod json_path;
pub mod json_tvf;
pub mod json_types;
pub mod json_udaf;
pub mod json_udf;
pub mod lookup_join;
pub mod proctime_udf;
mod source;
mod table_provider;
pub mod watermark_udf;
pub mod window_udf;

pub use aggregate_bridge::{
    create_aggregate_factory, lookup_aggregate_udf, result_to_scalar_value, scalar_value_to_result,
    DataFusionAccumulatorAdapter, DataFusionAggregateFactory,
};
pub use bridge::{BridgeSendError, BridgeSender, BridgeStream, BridgeTrySendError, StreamBridge};
pub use channel_source::ChannelStreamSource;
pub use complex_type_lambda::{
    register_lambda_functions, ArrayFilter, ArrayReduce, ArrayTransform, MapFilter,
    MapTransformValues,
};
pub use complex_type_udf::{
    register_complex_type_functions, MapContainsKey, MapFromArrays, MapKeys, MapValues, StructDrop,
    StructExtract, StructMerge, StructRename, StructSet,
};
pub use exec::StreamingScanExec;
pub use execute::{execute_streaming_sql, DdlResult, QueryResult, StreamingSqlResult};
pub use format_bridge_udf::{FromJsonUdf, ParseEpochUdf, ParseTimestampUdf, ToJsonUdf};
pub use json_extensions::{
    register_json_extensions, JsonInferSchema, JsonToColumns, JsonbDeepMerge, JsonbExcept,
    JsonbFlatten, JsonbMerge, JsonbPick, JsonbRenameKeys, JsonbStripNulls, JsonbUnflatten,
};
pub use json_path::{CompiledJsonPath, JsonPathStep, JsonbPathExistsUdf, JsonbPathMatchUdf};
pub use json_tvf::{
    register_json_table_functions, JsonbArrayElementsTextTvf, JsonbArrayElementsTvf,
    JsonbEachTextTvf, JsonbEachTvf, JsonbObjectKeysTvf,
};
pub use json_udaf::{JsonAgg, JsonObjectAgg};
pub use json_udf::{
    JsonBuildArray, JsonBuildObject, JsonTypeof, JsonbContainedBy, JsonbContains, JsonbExists,
    JsonbExistsAll, JsonbExistsAny, JsonbGet, JsonbGetIdx, JsonbGetPath, JsonbGetPathText,
    JsonbGetText, JsonbGetTextIdx, ToJsonb,
};
pub use proctime_udf::ProcTimeUdf;
pub use source::{SortColumn, StreamSource, StreamSourceRef};
pub use table_provider::StreamingTableProvider;
pub use watermark_udf::WatermarkUdf;
pub use window_udf::{CumulateWindowStart, HopWindowStart, SessionWindowStart, TumbleWindowStart};

use std::sync::atomic::AtomicI64;
use std::sync::Arc;

use datafusion::prelude::*;
use datafusion_expr::ScalarUDF;

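/// Creates a [`SessionContext`] tuned for streaming execution: 8192-row
/// batches and a single target partition, with all streaming UDFs registered
/// via [`register_streaming_functions`].
///
/// A minimal usage sketch (illustrative only; the `events` table and its
/// [`StreamingTableProvider`] are assumptions, not provided by this function):
///
/// ```ignore
/// let ctx = create_streaming_context();
/// ctx.register_table("events", Arc::new(provider))?;
/// let df = ctx.sql("SELECT * FROM events").await?;
/// ```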
#[must_use]
pub fn create_streaming_context() -> SessionContext {
    let config = SessionConfig::new()
        .with_batch_size(8192)
        .with_target_partitions(1);

    let ctx = SessionContext::new_with_config(config);
    register_streaming_functions(&ctx);
    ctx
}

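/// Registers the streaming UDFs on `ctx`: the window-start functions
/// ([`TumbleWindowStart`], [`HopWindowStart`], [`SessionWindowStart`],
/// [`CumulateWindowStart`]), a [`WatermarkUdf`] in its unset state,
/// [`ProcTimeUdf`], and the JSON, JSON-extension, complex-type, and lambda
/// function families.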
pub fn register_streaming_functions(ctx: &SessionContext) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(CumulateWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::unset()));
    ctx.register_udf(ScalarUDF::new_from_impl(ProcTimeUdf::new()));
    register_json_functions(ctx);
    register_json_extensions(ctx);
    register_complex_type_functions(ctx);
    register_lambda_functions(ctx);
}

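/// Like [`register_streaming_functions`], but `watermark()` reads the current
/// watermark (epoch milliseconds) from the shared `watermark_ms` counter
/// instead of being unset.
///
/// A sketch of the intended wiring (the `Relaxed` ordering here is an
/// illustrative assumption, not a documented requirement):
///
/// ```ignore
/// let wm = Arc::new(AtomicI64::new(0));
/// register_streaming_functions_with_watermark(&ctx, Arc::clone(&wm));
/// // The runtime advances the watermark as event time progresses:
/// wm.store(60_000, std::sync::atomic::Ordering::Relaxed);
/// ```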
pub fn register_streaming_functions_with_watermark(
    ctx: &SessionContext,
    watermark_ms: Arc<AtomicI64>,
) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(CumulateWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::new(watermark_ms)));
    ctx.register_udf(ScalarUDF::new_from_impl(ProcTimeUdf::new()));
    register_json_functions(ctx);
    register_json_extensions(ctx);
    register_complex_type_functions(ctx);
    register_lambda_functions(ctx);
}

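/// Registers the JSON scalar UDFs (access, existence, containment, type
/// inspection, construction), the [`JsonAgg`] and [`JsonObjectAgg`]
/// aggregates, the format-bridge helpers, the jsonpath predicate UDFs, and
/// the JSON table functions.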
pub fn register_json_functions(ctx: &SessionContext) {
    // Field and element access.
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGet::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetIdx::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetText::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetTextIdx::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetPath::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetPathText::new()));

    // Key existence.
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExists::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExistsAny::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExistsAll::new()));

    // Containment.
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbContains::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbContainedBy::new()));

    // Type inspection and construction.
    ctx.register_udf(ScalarUDF::new_from_impl(JsonTypeof::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonBuildObject::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonBuildArray::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ToJsonb::new()));

    // Aggregates.
    ctx.register_udaf(datafusion_expr::AggregateUDF::new_from_impl(JsonAgg::new()));
    ctx.register_udaf(datafusion_expr::AggregateUDF::new_from_impl(
        JsonObjectAgg::new(),
    ));

    // Format-bridge helpers.
    ctx.register_udf(ScalarUDF::new_from_impl(ParseEpochUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ParseTimestampUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ToJsonUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(FromJsonUdf::new()));

    // jsonpath predicates.
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbPathExistsUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbPathMatchUdf::new()));

    // Table functions.
    register_json_table_functions(ctx);
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::FunctionRegistry;
    use futures::StreamExt;
    use std::sync::Arc;

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    /// Takes the single batch sender out of a test source, panicking if it
    /// was already taken.
    fn take_test_sender(source: &ChannelStreamSource) -> BridgeSender {
        source.take_sender().expect("sender already taken")
    }

    fn test_batch(schema: &Arc<Schema>, ids: Vec<i64>, values: Vec<f64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Float64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_create_streaming_context() {
        let ctx = create_streaming_context();
        let state = ctx.state();
        let config = state.config();

        assert_eq!(config.batch_size(), 8192);
        assert_eq!(config.target_partitions(), 1);
    }

    #[tokio::test]
    async fn test_full_query_pipeline() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        sender
            .send(test_batch(&schema, vec![4, 5], vec![40.0, 50.0]))
            .await
            .unwrap();
        drop(sender);

        let df = ctx.sql("SELECT * FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 5);
    }

    #[tokio::test]
    async fn test_query_with_projection() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100.0, 200.0]))
            .await
            .unwrap();
        drop(sender);

        let df = ctx.sql("SELECT id FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_columns(), 1);
        assert_eq!(batches[0].schema().field(0).name(), "id");
    }

    #[tokio::test]
    async fn test_query_with_filter() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(
                &schema,
                vec![1, 2, 3, 4, 5],
                vec![10.0, 20.0, 30.0, 40.0, 50.0],
            ))
            .await
            .unwrap();
        drop(sender);

        let df = ctx
            .sql("SELECT * FROM events WHERE value > 25")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_unbounded_aggregation_rejected() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        drop(sender);

        // A plain COUNT(*) over an unbounded stream can never finalize; the
        // plan builds, but execution should fail.
        let df = ctx.sql("SELECT COUNT(*) as cnt FROM events").await.unwrap();

        let result = df.collect().await;
        assert!(
            result.is_err(),
            "Aggregation on unbounded stream should fail"
        );
    }

    #[tokio::test]
    async fn test_query_with_order_by() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![3, 1, 2], vec![30.0, 10.0, 20.0]))
            .await
            .unwrap();
        drop(sender);

        // Rows are sent out of order; the test only checks that all of them
        // arrive through the projection.
        let df = ctx.sql("SELECT id, value FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_bridge_throughput() {
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10000);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let batch_count = 1000;
        let batch = test_batch(&schema, vec![1, 2, 3, 4, 5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        let send_task = tokio::spawn(async move {
            for _ in 0..batch_count {
                sender.send(batch.clone()).await.unwrap();
            }
        });

        let mut received = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            received += 1;
            if received == batch_count {
                break;
            }
        }

        send_task.await.unwrap();
        assert_eq!(received, batch_count);
    }

    #[test]
    fn test_streaming_functions_registered() {
        let ctx = create_streaming_context();

        assert!(ctx.udf("tumble").is_ok(), "tumble UDF not registered");
        assert!(ctx.udf("hop").is_ok(), "hop UDF not registered");
        assert!(ctx.udf("session").is_ok(), "session UDF not registered");
        assert!(ctx.udf("watermark").is_ok(), "watermark UDF not registered");
    }

    #[test]
    fn test_streaming_functions_with_watermark() {
        use std::sync::atomic::AtomicI64;

        let ctx = SessionContext::new();
        let wm = Arc::new(AtomicI64::new(42_000));
        register_streaming_functions_with_watermark(&ctx, wm);

        assert!(ctx.udf("tumble").is_ok());
        assert!(ctx.udf("watermark").is_ok());
    }

    #[tokio::test]
    async fn test_tumble_udf_via_datafusion() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Events at 1, 2, and 6 minutes: with 5-minute tumbling windows the
        // first two fall in [0, 5min) and the third in [5min, 10min).
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    60_000i64, 120_000, 360_000,
                ])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as window_start, \
                 value \
                 FROM events",
            )
            .await
            .unwrap();

        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);

        let ws_col = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .expect("window_start should be TimestampMillisecond");
        assert_eq!(ws_col.value(0), 0);
        assert_eq!(ws_col.value(1), 0);
        assert_eq!(ws_col.value(2), 300_000);
    }

    #[tokio::test]
    async fn test_logical_plan_from_windowed_query() {
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(schema));
        let _sender = source.take_sender();
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as w, \
                 COUNT(*) as cnt \
                 FROM events \
                 GROUP BY tumble(event_time, INTERVAL '5' MINUTE)",
            )
            .await;

        assert!(df.is_ok(), "Failed to create logical plan: {df:?}");
    }

    #[tokio::test]
    async fn test_end_to_end_execute_streaming_sql() {
        use crate::planner::StreamingPlanner;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("items", source);
        ctx.register_table("items", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let mut planner = StreamingPlanner::new();
        let result = execute_streaming_sql("SELECT id FROM items WHERE id > 1", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                let mut stream = qr.stream;
                let mut total = 0;
                while let Some(batch) = stream.next().await {
                    total += batch.unwrap().num_rows();
                }
                assert_eq!(total, 2);
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    #[tokio::test]
    async fn test_watermark_function_in_filter() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;
        use std::sync::atomic::AtomicI64;

        let config = SessionConfig::new()
            .with_batch_size(8192)
            .with_target_partitions(1);
        let ctx = SessionContext::new_with_config(config);

        // Watermark fixed at 200s: only the 300s event is strictly later.
        let wm = Arc::new(AtomicI64::new(200_000));
        register_streaming_functions_with_watermark(&ctx, wm);

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    100_000i64, 200_000, 300_000,
                ])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let df = ctx
            .sql("SELECT value FROM events WHERE event_time > watermark()")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 1);
    }
}