
laminar_sql/datafusion/mod.rs

//! `DataFusion` integration for SQL processing
//!
//! This module provides the integration layer between `LaminarDB`'s push-based
//! streaming engine and `DataFusion`'s pull-based SQL query execution.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                    Ring 2: Query Planning                       │
//! │  SQL Query → SessionContext → LogicalPlan → ExecutionPlan       │
//! │                                      │                          │
//! │                            StreamingScanExec                    │
//! │                                      │                          │
//! │                              ┌───────▼──────┐                   │
//! │                              │ StreamBridge │ (tokio channel)   │
//! │                              └───────▲──────┘                   │
//! ├──────────────────────────────────────┼──────────────────────────┤
//! │                    Ring 0: Hot Path  │                          │
//! │                                      │                          │
//! │  Source → Reactor.poll() ────────────┘                          │
//! │              (Events with RecordBatch data)                     │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! # Components
//!
//! - [`StreamSource`]: Trait for streaming data sources
//! - [`StreamBridge`]: Channel-based push-to-pull bridge
//! - [`StreamingScanExec`]: `DataFusion` execution plan for streaming scans
//! - [`StreamingTableProvider`]: `DataFusion` table provider for streaming sources
//! - [`ChannelStreamSource`]: Concrete source using channels
//!
//! # Usage
//!
//! ```rust,ignore
//! use laminar_sql::datafusion::{
//!     create_streaming_context, ChannelStreamSource, StreamingTableProvider,
//! };
//! use arrow_schema::{DataType, Field, Schema};
//! use std::sync::Arc;
//!
//! // Create a streaming context
//! let ctx = create_streaming_context();
//!
//! // Create a channel source and take its sender
//! let schema = Arc::new(Schema::new(vec![
//!     Field::new("id", DataType::Int64, false),
//!     Field::new("value", DataType::Float64, true),
//! ]));
//! let source = Arc::new(ChannelStreamSource::new(schema));
//! let sender = source.take_sender().expect("sender already taken");
//!
//! // Register as a table
//! let provider = StreamingTableProvider::new("events", source);
//! ctx.register_table("events", Arc::new(provider))?;
//!
//! // Push data from the Reactor
//! sender.send(batch).await?;
//!
//! // Execute SQL queries
//! let df = ctx.sql("SELECT * FROM events WHERE value > 100").await?;
//! ```
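//!
//! The push-to-pull hand-off can also be driven directly through
//! [`StreamBridge`]. A minimal sketch, mirroring this module's test suite
//! (the capacity of 1024 is illustrative):
//!
//! ```rust,ignore
//! use futures::StreamExt;
//!
//! let bridge = StreamBridge::new(Arc::clone(&schema), 1024);
//! let sender = bridge.sender();
//! let mut stream = bridge.into_stream();
//!
//! sender.send(batch).await?;      // Ring 0 pushes a RecordBatch
//! let next = stream.next().await; // DataFusion pulls it
//! ```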

/// DataFusion aggregate bridge for streaming aggregation.
///
/// Bridges DataFusion's `Accumulator` trait with `laminar-core`'s
/// `DynAccumulator` / `DynAggregatorFactory` traits. This avoids
/// duplicating aggregation logic.
pub mod aggregate_bridge;
mod bridge;
mod channel_source;
/// Lambda higher-order functions for arrays and maps (F-SCHEMA-015 Tier 3)
pub mod complex_type_lambda;
/// Array, Struct, and Map scalar UDFs (F-SCHEMA-015)
pub mod complex_type_udf;
mod exec;
/// End-to-end streaming SQL execution
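/// (see [`execute_streaming_sql`]).
///
/// A minimal sketch, mirroring the test suite at the bottom of this file
/// ([`crate::planner::StreamingPlanner`] drives the streaming plan):
///
/// ```rust,ignore
/// use crate::planner::StreamingPlanner;
///
/// let mut planner = StreamingPlanner::new();
/// match execute_streaming_sql("SELECT id FROM items", &ctx, &mut planner).await? {
///     StreamingSqlResult::Query(qr) => { /* consume qr.stream (RecordBatch stream) */ }
///     StreamingSqlResult::Ddl(_) => { /* DDL acknowledged */ }
/// }
/// ```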
pub mod execute;
/// Format bridge UDFs for inline format conversion
pub mod format_bridge_udf;
/// LaminarDB streaming JSON extension UDFs (F-SCHEMA-013)
pub mod json_extensions;
/// SQL/JSON path query compiler and scalar UDFs
pub mod json_path;
/// JSON table-valued functions (array/object expansion)
pub mod json_tvf;
/// JSONB binary format types for JSON UDF evaluation
pub mod json_types;
/// PostgreSQL-compatible JSON aggregate UDAFs
pub mod json_udaf;
/// PostgreSQL-compatible JSON scalar UDFs
pub mod json_udf;
/// Lookup join plan node for DataFusion.
pub mod lookup_join;
/// Processing-time UDF for `PROCTIME()` support
pub mod proctime_udf;
mod source;
mod table_provider;
/// Watermark UDF for current watermark access
pub mod watermark_udf;
/// Window function UDFs (TUMBLE, HOP, SESSION, CUMULATE)
pub mod window_udf;

pub use aggregate_bridge::{
    create_aggregate_factory, lookup_aggregate_udf, result_to_scalar_value, scalar_value_to_result,
    DataFusionAccumulatorAdapter, DataFusionAggregateFactory,
};
pub use bridge::{BridgeSendError, BridgeSender, BridgeStream, BridgeTrySendError, StreamBridge};
pub use channel_source::ChannelStreamSource;
pub use complex_type_lambda::{
    register_lambda_functions, ArrayFilter, ArrayReduce, ArrayTransform, MapFilter,
    MapTransformValues,
};
pub use complex_type_udf::{
    register_complex_type_functions, MapContainsKey, MapFromArrays, MapKeys, MapValues, StructDrop,
    StructExtract, StructMerge, StructRename, StructSet,
};
pub use exec::StreamingScanExec;
pub use execute::{execute_streaming_sql, DdlResult, QueryResult, StreamingSqlResult};
pub use format_bridge_udf::{FromJsonUdf, ParseEpochUdf, ParseTimestampUdf, ToJsonUdf};
pub use json_extensions::{
    register_json_extensions, JsonInferSchema, JsonToColumns, JsonbDeepMerge, JsonbExcept,
    JsonbFlatten, JsonbMerge, JsonbPick, JsonbRenameKeys, JsonbStripNulls, JsonbUnflatten,
};
pub use json_path::{CompiledJsonPath, JsonPathStep, JsonbPathExistsUdf, JsonbPathMatchUdf};
pub use json_tvf::{
    register_json_table_functions, JsonbArrayElementsTextTvf, JsonbArrayElementsTvf,
    JsonbEachTextTvf, JsonbEachTvf, JsonbObjectKeysTvf,
};
pub use json_udaf::{JsonAgg, JsonObjectAgg};
pub use json_udf::{
    JsonBuildArray, JsonBuildObject, JsonTypeof, JsonbContainedBy, JsonbContains, JsonbExists,
    JsonbExistsAll, JsonbExistsAny, JsonbGet, JsonbGetIdx, JsonbGetPath, JsonbGetPathText,
    JsonbGetText, JsonbGetTextIdx, ToJsonb,
};
pub use proctime_udf::ProcTimeUdf;
pub use source::{SortColumn, StreamSource, StreamSourceRef};
pub use table_provider::StreamingTableProvider;
pub use watermark_udf::WatermarkUdf;
pub use window_udf::{CumulateWindowStart, HopWindowStart, SessionWindowStart, TumbleWindowStart};

use std::sync::atomic::AtomicI64;
use std::sync::Arc;

use datafusion::prelude::*;
use datafusion_expr::ScalarUDF;

/// Creates a `DataFusion` session context configured for streaming queries.
///
/// The context is configured with:
/// - Batch size of 8192 (balanced for streaming throughput)
/// - Single partition (streaming sources are typically not partitioned)
/// - All streaming UDFs registered (TUMBLE, HOP, SESSION, CUMULATE,
///   WATERMARK, PROCTIME), plus the JSON, complex-type, and lambda functions
///
/// The watermark UDF is initialized with no watermark set (returns NULL).
/// Use [`register_streaming_functions_with_watermark`] to provide a live
/// watermark source.
///
/// # Example
///
/// ```rust,ignore
/// let ctx = create_streaming_context();
/// ctx.register_table("events", provider)?;
/// let df = ctx.sql("SELECT * FROM events").await?;
/// ```
#[must_use]
pub fn create_streaming_context() -> SessionContext {
    let config = SessionConfig::new()
        .with_batch_size(8192)
        .with_target_partitions(1); // Single partition for streaming

    let ctx = SessionContext::new_with_config(config);
    register_streaming_functions(&ctx);
    ctx
}

/// Registers `LaminarDB` streaming UDFs with a session context.
///
/// Registers the following scalar functions:
/// - `tumble(timestamp, interval)` — tumbling window start
/// - `hop(timestamp, slide, size)` — hopping window start
/// - `session(timestamp, gap)` — session window pass-through
/// - `watermark()` — current watermark (returns NULL, no live source)
///
/// The cumulative window start, processing-time (`PROCTIME()`), JSON,
/// complex-type, and lambda functions are registered as well.
///
/// Use [`register_streaming_functions_with_watermark`] to provide a
/// live watermark source from Ring 0.
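///
/// # Example
///
/// A minimal sketch, assuming an `events` table with an `event_time`
/// column is already registered:
///
/// ```rust,ignore
/// let ctx = SessionContext::new();
/// register_streaming_functions(&ctx);
/// // tumble() assigns each row the start of its 5-minute window.
/// let df = ctx
///     .sql("SELECT tumble(event_time, INTERVAL '5' MINUTE) AS w, value FROM events")
///     .await?;
/// ```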
pub fn register_streaming_functions(ctx: &SessionContext) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(CumulateWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::unset()));
    ctx.register_udf(ScalarUDF::new_from_impl(ProcTimeUdf::new()));
    register_json_functions(ctx);
    register_json_extensions(ctx);
    register_complex_type_functions(ctx);
    register_lambda_functions(ctx);
}

/// Registers streaming UDFs with a live watermark source.
///
/// Same as [`register_streaming_functions`] but connects the `watermark()`
/// UDF to a shared atomic value that Ring 0 updates in real time.
///
/// # Arguments
///
/// * `ctx` - `DataFusion` session context
/// * `watermark_ms` - Shared atomic holding the current watermark in
///   milliseconds since epoch. Values < 0 mean "no watermark" (returns NULL).
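///
/// # Example
///
/// A minimal sketch; the `Ordering` import and the 200s value are
/// illustrative:
///
/// ```rust,ignore
/// use std::sync::atomic::{AtomicI64, Ordering};
///
/// let ctx = SessionContext::new();
/// let wm = Arc::new(AtomicI64::new(-1)); // no watermark yet: watermark() returns NULL
/// register_streaming_functions_with_watermark(&ctx, Arc::clone(&wm));
///
/// // Ring 0 advances the shared atomic as events flow:
/// wm.store(200_000, Ordering::Relaxed); // watermark() now reports 200s
/// ```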
pub fn register_streaming_functions_with_watermark(
    ctx: &SessionContext,
    watermark_ms: Arc<AtomicI64>,
) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(CumulateWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::new(watermark_ms)));
    ctx.register_udf(ScalarUDF::new_from_impl(ProcTimeUdf::new()));
    register_json_functions(ctx);
    register_json_extensions(ctx);
    register_complex_type_functions(ctx);
    register_lambda_functions(ctx);
}

/// Registers all PostgreSQL-compatible JSON UDFs and UDAFs
/// with the given `SessionContext`.
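///
/// # Example
///
/// A minimal sketch, assuming the UDFs register under their PostgreSQL
/// snake_case names (e.g. `json_build_object`):
///
/// ```rust,ignore
/// let ctx = SessionContext::new();
/// register_json_functions(&ctx);
/// let df = ctx
///     .sql("SELECT json_build_object('id', id) AS obj FROM events")
///     .await?;
/// ```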
pub fn register_json_functions(ctx: &SessionContext) {
    // Extraction operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGet::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetIdx::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetText::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetTextIdx::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetPath::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetPathText::new()));

    // Existence operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExists::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExistsAny::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExistsAll::new()));

    // Containment operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbContains::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbContainedBy::new()));

    // Interrogation / construction
    ctx.register_udf(ScalarUDF::new_from_impl(JsonTypeof::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonBuildObject::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonBuildArray::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ToJsonb::new()));

    // Aggregates
    ctx.register_udaf(datafusion_expr::AggregateUDF::new_from_impl(JsonAgg::new()));
    ctx.register_udaf(datafusion_expr::AggregateUDF::new_from_impl(
        JsonObjectAgg::new(),
    ));

    // Format bridge functions
    ctx.register_udf(ScalarUDF::new_from_impl(ParseEpochUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ParseTimestampUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ToJsonUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(FromJsonUdf::new()));

    // JSON path query functions (scalar)
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbPathExistsUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbPathMatchUdf::new()));

    // JSON table-valued functions
    register_json_table_functions(ctx);
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::FunctionRegistry;
    use futures::StreamExt;
    use std::sync::Arc;

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    /// Take the sender from a `ChannelStreamSource`, panicking if already taken.
    fn take_test_sender(source: &ChannelStreamSource) -> super::bridge::BridgeSender {
        source.take_sender().expect("sender already taken")
    }

    fn test_batch(schema: &Arc<Schema>, ids: Vec<i64>, values: Vec<f64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Float64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_create_streaming_context() {
        let ctx = create_streaming_context();
        let state = ctx.state();
        let config = state.config();

        assert_eq!(config.batch_size(), 8192);
        assert_eq!(config.target_partitions(), 1);
    }

    #[tokio::test]
    async fn test_full_query_pipeline() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        // Create source and take the sender (important for channel closure)
        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Send test data
        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        sender
            .send(test_batch(&schema, vec![4, 5], vec![40.0, 50.0]))
            .await
            .unwrap();
        drop(sender); // Close the channel

        // Execute query
        let df = ctx.sql("SELECT * FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        // Verify results
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 5);
    }

    #[tokio::test]
    async fn test_query_with_projection() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100.0, 200.0]))
            .await
            .unwrap();
        drop(sender);

        // Query only the id column
        let df = ctx.sql("SELECT id FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_columns(), 1);
        assert_eq!(batches[0].schema().field(0).name(), "id");
    }

    #[tokio::test]
    async fn test_query_with_filter() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(
                &schema,
                vec![1, 2, 3, 4, 5],
                vec![10.0, 20.0, 30.0, 40.0, 50.0],
            ))
            .await
            .unwrap();
        drop(sender);

        // Filter for value > 25
        let df = ctx
            .sql("SELECT * FROM events WHERE value > 25")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3); // 30, 40, 50
    }

    #[tokio::test]
    async fn test_unbounded_aggregation_rejected() {
        // Aggregations on unbounded streams should be rejected by `DataFusion`.
        // Streaming aggregations require windows, which are handled by Ring 0.
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        drop(sender);

        // Aggregate query on unbounded stream should fail at execution
        let df = ctx.sql("SELECT COUNT(*) as cnt FROM events").await.unwrap();

        // Execution should fail because we can't aggregate an infinite stream
        let result = df.collect().await;
        assert!(
            result.is_err(),
            "Aggregation on unbounded stream should fail"
        );
    }

    #[tokio::test]
    async fn test_query_with_order_by() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![3, 1, 2], vec![30.0, 10.0, 20.0]))
            .await
            .unwrap();
        drop(sender);

        // No ORDER BY in the SQL: sorting an unbounded stream is handled
        // by Ring 0, not by `DataFusion`'s Sort operator.
        let df = ctx.sql("SELECT id, value FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        // Verify we got results (ordering may vary due to streaming nature)
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_bridge_throughput() {
        // Benchmark-style test for bridge performance
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10000);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let batch_count = 1000;
        let batch = test_batch(&schema, vec![1, 2, 3, 4, 5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        // Spawn sender task
        let send_task = tokio::spawn(async move {
            for _ in 0..batch_count {
                sender.send(batch.clone()).await.unwrap();
            }
        });

        // Receive all batches
        let mut received = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            received += 1;
            if received == batch_count {
                break;
            }
        }

        send_task.await.unwrap();
        assert_eq!(received, batch_count);
    }

    // ── Integration Tests ──────────────────────────────────────────

    #[test]
    fn test_streaming_functions_registered() {
        let ctx = create_streaming_context();
        // Verify the core streaming UDFs are registered
        assert!(ctx.udf("tumble").is_ok(), "tumble UDF not registered");
        assert!(ctx.udf("hop").is_ok(), "hop UDF not registered");
        assert!(ctx.udf("session").is_ok(), "session UDF not registered");
        assert!(ctx.udf("watermark").is_ok(), "watermark UDF not registered");
    }

    #[test]
    fn test_streaming_functions_with_watermark() {
        use std::sync::atomic::AtomicI64;

        let ctx = SessionContext::new();
        let wm = Arc::new(AtomicI64::new(42_000));
        register_streaming_functions_with_watermark(&ctx, wm);

        assert!(ctx.udf("tumble").is_ok());
        assert!(ctx.udf("watermark").is_ok());
    }

    #[tokio::test]
    async fn test_tumble_udf_via_datafusion() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        // Create schema with timestamp and value columns
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Send events across two 5-minute windows:
        // Window [0, 300_000): timestamps 60_000, 120_000
        // Window [300_000, 600_000): timestamp 360_000
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    60_000i64, 120_000, 360_000,
                ])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        // Verify the tumble UDF computes correct window starts via DataFusion
        // (GROUP BY aggregation and ORDER BY on unbounded streams are handled by Ring 0)
        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as window_start, \
                 value \
                 FROM events",
            )
            .await
            .unwrap();

        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);

        // Verify the window_start values (single batch, order preserved)
        let ws_col = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .expect("window_start should be TimestampMillisecond");
        // 60_000 and 120_000 → window [0, 300_000), start = 0
        assert_eq!(ws_col.value(0), 0);
        assert_eq!(ws_col.value(1), 0);
        // 360_000 → window [300_000, 600_000), start = 300_000
        assert_eq!(ws_col.value(2), 300_000);
    }

    #[tokio::test]
    async fn test_logical_plan_from_windowed_query() {
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(schema));
        let _sender = source.take_sender();
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Create a LogicalPlan for a windowed query
        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as w, \
                 COUNT(*) as cnt \
                 FROM events \
                 GROUP BY tumble(event_time, INTERVAL '5' MINUTE)",
            )
            .await;

        // Should succeed in creating the logical plan (UDFs are registered)
        assert!(df.is_ok(), "Failed to create logical plan: {df:?}");
    }

    #[tokio::test]
    async fn test_end_to_end_execute_streaming_sql() {
        use crate::planner::StreamingPlanner;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("items", source);
        ctx.register_table("items", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let mut planner = StreamingPlanner::new();
        let result = execute_streaming_sql("SELECT id FROM items WHERE id > 1", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                let mut stream = qr.stream;
                let mut total = 0;
                while let Some(batch) = stream.next().await {
                    total += batch.unwrap().num_rows();
                }
                assert_eq!(total, 2); // id=2, id=3
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    #[tokio::test]
    async fn test_watermark_function_in_filter() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;
        use std::sync::atomic::AtomicI64;

        // Create context with a specific watermark value
        let config = SessionConfig::new()
            .with_batch_size(8192)
            .with_target_partitions(1);
        let ctx = SessionContext::new_with_config(config);
        let wm = Arc::new(AtomicI64::new(200_000)); // watermark at 200s
        register_streaming_functions_with_watermark(&ctx, wm);

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Events at 100s, 200s, 300s; the watermark is at 200s
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    100_000i64, 200_000, 300_000,
                ])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        // Filter events after watermark
        let df = ctx
            .sql("SELECT value FROM events WHERE event_time > watermark()")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        // Only the event at 300s is after the watermark (200s)
        assert_eq!(total_rows, 1);
    }
}