
laminar_sql/datafusion/mod.rs

//! `DataFusion` integration for SQL processing
//!
//! This module provides the integration layer between `LaminarDB`'s push-based
//! streaming engine and `DataFusion`'s pull-based SQL query execution.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                    Ring 2: Query Planning                        │
//! │  SQL Query → SessionContext → LogicalPlan → ExecutionPlan       │
//! │                                      │                          │
//! │                            StreamingScanExec                    │
//! │                                      │                          │
//! │                              ┌───────▼──────┐                   │
//! │                              │ StreamBridge │ (tokio channel)   │
//! │                              └───────▲──────┘                   │
//! ├──────────────────────────────────────┼──────────────────────────┤
//! │                    Ring 0: Hot Path   │                          │
//! │                                      │                          │
//! │  Source → Reactor.poll() ────────────┘                          │
//! │              (Events with RecordBatch data)                     │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! # Components
//!
//! - [`StreamSource`]: Trait for streaming data sources
//! - [`StreamBridge`]: Channel-based push-to-pull bridge
//! - [`StreamingScanExec`]: `DataFusion` execution plan for streaming scans
//! - [`StreamingTableProvider`]: `DataFusion` table provider for streaming sources
//! - [`ChannelStreamSource`]: Concrete source using channels
//!
//! # Usage
//!
//! ```rust,ignore
//! use laminar_sql::datafusion::{
//!     create_streaming_context, ChannelStreamSource, StreamingTableProvider,
//! };
//! use std::sync::Arc;
//!
//! // Create a streaming context
//! let ctx = create_streaming_context();
//!
//! // Create a channel source
//! let schema = Arc::new(Schema::new(vec![
//!     Field::new("id", DataType::Int64, false),
//!     Field::new("value", DataType::Float64, true),
//! ]));
//! let source = Arc::new(ChannelStreamSource::new(schema));
//! let sender = source.sender();
//!
//! // Register as a table
//! let provider = StreamingTableProvider::new("events", source);
//! ctx.register_table("events", Arc::new(provider))?;
//!
//! // Push data from the Reactor
//! sender.send(batch).await?;
//!
//! // Execute SQL queries
//! let df = ctx.sql("SELECT * FROM events WHERE value > 100").await?;
//! ```

/// DataFusion aggregate bridge for streaming aggregation.
///
/// Bridges DataFusion's `Accumulator` trait with `laminar-core`'s
/// `DynAccumulator` / `DynAggregatorFactory` traits. This avoids
/// duplicating aggregation logic.
pub mod aggregate_bridge;
mod bridge;
mod channel_source;
/// Lambda higher-order functions for arrays and maps (F-SCHEMA-015 Tier 3)
pub mod complex_type_lambda;
/// Array, Struct, and Map scalar UDFs (F-SCHEMA-015)
pub mod complex_type_udf;
mod exec;
/// End-to-end streaming SQL execution
pub mod execute;
/// Format bridge UDFs for inline format conversion
pub mod format_bridge_udf;
/// LaminarDB streaming JSON extension UDFs (F-SCHEMA-013)
pub mod json_extensions;
/// SQL/JSON path query compiler and scalar UDFs
pub mod json_path;
/// JSON table-valued functions (array/object expansion)
pub mod json_tvf;
/// JSONB binary format types for JSON UDF evaluation
pub mod json_types;
/// PostgreSQL-compatible JSON aggregate UDAFs
pub mod json_udaf;
/// PostgreSQL-compatible JSON scalar UDFs
pub mod json_udf;
/// Live source provider for streaming execution with plan caching
pub mod live_source;
/// Lookup join plan node for DataFusion.
pub mod lookup_join;
/// Physical execution plan and extension planner for lookup joins.
pub mod lookup_join_exec;
/// Processing-time UDF for `PROCTIME()` support
pub mod proctime_udf;
mod source;
mod table_provider;
/// Dynamic watermark filter for scan-level late-data pruning
pub mod watermark_filter;
/// Watermark UDF for current watermark access
pub mod watermark_udf;
/// Window function UDFs (TUMBLE, HOP, SESSION, CUMULATE)
pub mod window_udf;

pub use aggregate_bridge::{
    create_aggregate_factory, lookup_aggregate_udf, result_to_scalar_value, scalar_value_to_result,
    DataFusionAccumulatorAdapter, DataFusionAggregateFactory,
};
pub use bridge::{BridgeSendError, BridgeSender, BridgeStream, BridgeTrySendError, StreamBridge};
pub use channel_source::ChannelStreamSource;
pub use complex_type_lambda::{
    register_lambda_functions, ArrayFilter, ArrayReduce, ArrayTransform, MapFilter,
    MapTransformValues,
};
pub use complex_type_udf::{
    register_complex_type_functions, MapContainsKey, MapFromArrays, MapKeys, MapValues, StructDrop,
    StructExtract, StructMerge, StructRename, StructSet,
};
pub use exec::StreamingScanExec;
pub use execute::{execute_streaming_sql, DdlResult, QueryResult, StreamingSqlResult};
pub use format_bridge_udf::{FromJsonUdf, ParseEpochUdf, ParseTimestampUdf, ToJsonUdf};
pub use json_extensions::{
    register_json_extensions, JsonInferSchema, JsonToColumns, JsonbDeepMerge, JsonbExcept,
    JsonbFlatten, JsonbMerge, JsonbPick, JsonbRenameKeys, JsonbStripNulls, JsonbUnflatten,
};
pub use json_path::{CompiledJsonPath, JsonPathStep, JsonbPathExistsUdf, JsonbPathMatchUdf};
pub use json_tvf::{
    register_json_table_functions, JsonbArrayElementsTextTvf, JsonbArrayElementsTvf,
    JsonbEachTextTvf, JsonbEachTvf, JsonbObjectKeysTvf,
};
pub use json_udaf::{JsonAgg, JsonObjectAgg};
pub use json_udf::{
    JsonBuildArray, JsonBuildObject, JsonTypeof, JsonbContainedBy, JsonbContains, JsonbExists,
    JsonbExistsAll, JsonbExistsAny, JsonbGet, JsonbGetIdx, JsonbGetPath, JsonbGetPathText,
    JsonbGetText, JsonbGetTextIdx, ToJsonb,
};
pub use live_source::{LiveSourceHandle, LiveSourceProvider};
pub use lookup_join_exec::{
    LookupJoinExec, LookupJoinExtensionPlanner, LookupSnapshot, LookupTableRegistry,
    PartialLookupJoinExec, PartialLookupState, RegisteredLookup, VersionedLookupJoinExec,
    VersionedLookupState,
};
pub use proctime_udf::ProcTimeUdf;
pub use source::{SortColumn, StreamSource, StreamSourceRef};
pub use table_provider::StreamingTableProvider;
pub use watermark_filter::WatermarkDynamicFilter;
pub use watermark_udf::WatermarkUdf;
pub use window_udf::{CumulateWindowStart, HopWindowStart, SessionWindowStart, TumbleWindowStart};

use std::sync::atomic::AtomicI64;
use std::sync::Arc;

use datafusion::execution::SessionStateBuilder;
use datafusion::prelude::*;
use datafusion_expr::ScalarUDF;

use crate::planner::streaming_optimizer::{StreamingPhysicalValidator, StreamingValidatorMode};

/// Returns a base `SessionConfig` with identifier normalization disabled.
///
/// DataFusion's default behaviour lowercases all unquoted SQL identifiers
/// (per the SQL standard). LaminarDB disables this so that mixed-case
/// column names from external sources (Kafka, CDC, WebSocket) can be
/// referenced without double-quoting.
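///
/// # Example
///
/// A minimal sketch; the `events` table and its mixed-case `UserId` column
/// are assumed to be registered elsewhere.
///
/// ```rust,ignore
/// let ctx = SessionContext::new_with_config(base_session_config());
/// // With normalization off, the mixed-case identifier resolves unquoted:
/// let df = ctx.sql("SELECT UserId FROM events").await?;
/// // With DataFusion's default config this would require: SELECT "UserId" ...
/// ```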
#[must_use]
pub fn base_session_config() -> SessionConfig {
    let mut config = SessionConfig::new();
    config.options_mut().sql_parser.enable_ident_normalization = false;
    // Single partition for streaming micro-batch execution. Multi-partition
    // plans contain stateful operators (RepartitionExec) that cannot be
    // reused across cycles, causing panics on cached physical plans.
    config = config.with_target_partitions(1);
    config
}

/// Creates a `DataFusion` session context with identifier normalization
/// disabled.
///
/// Suitable for ad-hoc / non-streaming queries (filters, lookups).
/// For streaming workloads prefer [`create_streaming_context`].
#[must_use]
pub fn create_session_context() -> SessionContext {
    SessionContext::new_with_config(base_session_config())
}

/// Creates a `DataFusion` session context configured for streaming queries.
///
/// The context is configured with:
/// - Batch size of 8192 (balanced for streaming throughput)
/// - Single partition (streaming sources are typically not partitioned)
/// - Identifier normalization disabled (mixed-case columns work unquoted)
/// - All streaming UDFs registered (TUMBLE, HOP, SESSION, CUMULATE, WATERMARK, PROCTIME)
/// - `StreamingPhysicalValidator` in `Reject` mode (blocks unsafe plans)
///
/// The watermark UDF is initialized with no watermark set (returns NULL).
/// Use [`register_streaming_functions_with_watermark`] to provide a live
/// watermark source.
///
/// # Example
///
/// ```rust,ignore
/// let ctx = create_streaming_context();
/// ctx.register_table("events", provider)?;
/// let df = ctx.sql("SELECT * FROM events").await?;
/// ```
#[must_use]
pub fn create_streaming_context() -> SessionContext {
    create_streaming_context_with_validator(StreamingValidatorMode::Reject)
}

/// Creates a streaming context with a configurable validator mode.
///
/// Same as [`create_streaming_context`] but allows choosing how the
/// [`StreamingPhysicalValidator`] handles plan violations.
///
/// Use [`StreamingValidatorMode::Off`] to get the previous behaviour
/// (no plan-time validation).
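///
/// # Example
///
/// Choosing the mode explicitly:
///
/// ```rust,ignore
/// // Equivalent to `create_streaming_context()`:
/// let strict = create_streaming_context_with_validator(StreamingValidatorMode::Reject);
///
/// // Previous behaviour, no plan-time validation:
/// let permissive = create_streaming_context_with_validator(StreamingValidatorMode::Off);
/// ```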
#[must_use]
pub fn create_streaming_context_with_validator(mode: StreamingValidatorMode) -> SessionContext {
    let config = base_session_config().with_batch_size(8192);

    let ctx = if matches!(mode, StreamingValidatorMode::Off) {
        SessionContext::new_with_config(config)
    } else {
        // Build a default state to get the standard optimizer rules, then
        // prepend our streaming validator so it fires before DataFusion's
        // built-in SanityCheckPlan (which produces generic error messages).
        let default_state = SessionStateBuilder::new()
            .with_config(config.clone())
            .with_default_features()
            .build();
        let mut rules: Vec<
            Arc<dyn datafusion::physical_optimizer::PhysicalOptimizerRule + Send + Sync>,
        > = vec![Arc::new(StreamingPhysicalValidator::new(mode))];
        rules.extend(default_state.physical_optimizers().iter().cloned());

        let state = SessionStateBuilder::new()
            .with_config(config)
            .with_default_features()
            .with_physical_optimizer_rules(rules)
            .build();
        SessionContext::new_with_state(state)
    };

    register_streaming_functions(&ctx);
    ctx
}

/// Registers `LaminarDB` streaming UDFs with a session context.
///
/// Registers the following scalar functions:
/// - `tumble(timestamp, interval)` — tumbling window start
/// - `hop(timestamp, slide, size)` — hopping window start
/// - `session(timestamp, gap)` — session window pass-through
/// - `watermark()` — current watermark (returns NULL, no live source)
///
/// The cumulate and processing-time (`PROCTIME()`) UDFs, along with the
/// JSON, JSON-extension, complex-type, and lambda function families, are
/// registered as well.
///
/// Use [`register_streaming_functions_with_watermark`] to provide a
/// live watermark source.
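///
/// # Example
///
/// A minimal sketch assuming an `events` table with an `event_time` column
/// is already registered:
///
/// ```rust,ignore
/// let ctx = create_session_context();
/// register_streaming_functions(&ctx);
///
/// // Window-start assignment is now available from SQL:
/// let df = ctx
///     .sql("SELECT tumble(event_time, INTERVAL '5' MINUTE) AS w, value FROM events")
///     .await?;
/// ```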
pub fn register_streaming_functions(ctx: &SessionContext) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(CumulateWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::unset()));
    ctx.register_udf(ScalarUDF::new_from_impl(ProcTimeUdf::new()));
    register_json_functions(ctx);
    register_json_extensions(ctx);
    register_complex_type_functions(ctx);
    register_lambda_functions(ctx);
}

/// Registers streaming UDFs with a live watermark source.
///
/// Same as [`register_streaming_functions`] but connects the `watermark()`
/// UDF to a shared atomic value that Ring 0 updates in real time.
///
/// # Arguments
///
/// * `ctx` - `DataFusion` session context
/// * `watermark_ms` - Shared atomic holding the current watermark in
///   milliseconds since epoch. Values < 0 mean "no watermark" (returns NULL).
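///
/// # Example
///
/// A minimal sketch of sharing a watermark with the UDF; the memory
/// ordering shown is illustrative, not prescribed by this module.
///
/// ```rust,ignore
/// use std::sync::atomic::{AtomicI64, Ordering};
///
/// let ctx = create_session_context();
/// let wm = Arc::new(AtomicI64::new(-1)); // < 0: no watermark yet
/// register_streaming_functions_with_watermark(&ctx, Arc::clone(&wm));
///
/// // Ring 0 advances the watermark as event time progresses:
/// wm.store(200_000, Ordering::Relaxed);
/// // `watermark()` in SQL now reflects 200_000 ms since epoch.
/// ```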
pub fn register_streaming_functions_with_watermark(
    ctx: &SessionContext,
    watermark_ms: Arc<AtomicI64>,
) {
    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(CumulateWindowStart::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::new(watermark_ms)));
    ctx.register_udf(ScalarUDF::new_from_impl(ProcTimeUdf::new()));
    register_json_functions(ctx);
    register_json_extensions(ctx);
    register_complex_type_functions(ctx);
    register_lambda_functions(ctx);
}

/// Registers all PostgreSQL-compatible JSON UDFs and UDAFs
/// with the given `SessionContext`.
pub fn register_json_functions(ctx: &SessionContext) {
    // Extraction operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGet::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetIdx::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetText::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetTextIdx::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetPath::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbGetPathText::new()));

    // Existence operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExists::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExistsAny::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbExistsAll::new()));

    // Containment operators
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbContains::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbContainedBy::new()));

    // Interrogation / construction
    ctx.register_udf(ScalarUDF::new_from_impl(JsonTypeof::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonBuildObject::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonBuildArray::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ToJsonb::new()));

    // Aggregates
    ctx.register_udaf(datafusion_expr::AggregateUDF::new_from_impl(JsonAgg::new()));
    ctx.register_udaf(datafusion_expr::AggregateUDF::new_from_impl(
        JsonObjectAgg::new(),
    ));

    // Format bridge functions
    ctx.register_udf(ScalarUDF::new_from_impl(ParseEpochUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ParseTimestampUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(ToJsonUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(FromJsonUdf::new()));

    // JSON path query functions (scalar)
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbPathExistsUdf::new()));
    ctx.register_udf(ScalarUDF::new_from_impl(JsonbPathMatchUdf::new()));

    // JSON table-valued functions
    register_json_table_functions(ctx);
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::FunctionRegistry;
    use futures::StreamExt;
    use std::sync::Arc;

    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    /// Take the sender from a `ChannelStreamSource`, panicking if already taken.
    fn take_test_sender(source: &ChannelStreamSource) -> super::bridge::BridgeSender {
        source.take_sender().expect("sender already taken")
    }

    fn test_batch(schema: &Arc<Schema>, ids: Vec<i64>, values: Vec<f64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Float64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_create_streaming_context() {
        let ctx = create_streaming_context();
        let state = ctx.state();
        let config = state.config();

        assert_eq!(config.batch_size(), 8192);
        assert_eq!(config.target_partitions(), 1);
    }

    #[tokio::test]
    async fn test_full_query_pipeline() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        // Create source and take the sender (important for channel closure)
        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Send test data
        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        sender
            .send(test_batch(&schema, vec![4, 5], vec![40.0, 50.0]))
            .await
            .unwrap();
        drop(sender); // Close the channel

        // Execute query
        let df = ctx.sql("SELECT * FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        // Verify results
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 5);
    }

    #[tokio::test]
    async fn test_query_with_projection() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100.0, 200.0]))
            .await
            .unwrap();
        drop(sender);

        // Query only the id column
        let df = ctx.sql("SELECT id FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_columns(), 1);
        assert_eq!(batches[0].schema().field(0).name(), "id");
    }

    #[tokio::test]
    async fn test_query_with_filter() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(
                &schema,
                vec![1, 2, 3, 4, 5],
                vec![10.0, 20.0, 30.0, 40.0, 50.0],
            ))
            .await
            .unwrap();
        drop(sender);

        // Filter for value > 25
        let df = ctx
            .sql("SELECT * FROM events WHERE value > 25")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3); // 30, 40, 50
    }

    #[tokio::test]
    async fn test_unbounded_aggregation_rejected() {
        // Aggregations on unbounded streams are rejected by `DataFusion`;
        // streaming aggregations require windows, which are handled by Ring 0.
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        drop(sender);

        // Aggregate query on unbounded stream should fail at execution
        let df = ctx.sql("SELECT COUNT(*) as cnt FROM events").await.unwrap();

        // Execution should fail because we can't aggregate an infinite stream
        let result = df.collect().await;
        assert!(
            result.is_err(),
            "Aggregation on unbounded stream should fail"
        );
    }

    #[tokio::test]
    async fn test_query_with_order_by() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![3, 1, 2], vec![30.0, 10.0, 20.0]))
            .await
            .unwrap();
        drop(sender);

        // Plain projection query; ORDER BY on an unbounded stream is handled
        // by Ring 0, so no explicit sort is issued here.
        let df = ctx.sql("SELECT id, value FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        // Verify we got results (ordering may vary due to streaming nature)
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_bridge_throughput() {
        // Benchmark-style test for bridge performance
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10000);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let batch_count = 1000;
        let batch = test_batch(&schema, vec![1, 2, 3, 4, 5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        // Spawn sender task
        let send_task = tokio::spawn(async move {
            for _ in 0..batch_count {
                sender.send(batch.clone()).await.unwrap();
            }
        });

        // Receive all batches
        let mut received = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            received += 1;
            if received == batch_count {
                break;
            }
        }

        send_task.await.unwrap();
        assert_eq!(received, batch_count);
    }

    // ── Integration Tests ──────────────────────────────────────────

    #[test]
    fn test_streaming_functions_registered() {
        let ctx = create_streaming_context();
        // Verify the core streaming UDFs are registered
        assert!(ctx.udf("tumble").is_ok(), "tumble UDF not registered");
        assert!(ctx.udf("hop").is_ok(), "hop UDF not registered");
        assert!(ctx.udf("session").is_ok(), "session UDF not registered");
        assert!(ctx.udf("watermark").is_ok(), "watermark UDF not registered");
    }

    #[test]
    fn test_streaming_functions_with_watermark() {
        use std::sync::atomic::AtomicI64;

        let ctx = create_session_context();
        let wm = Arc::new(AtomicI64::new(42_000));
        register_streaming_functions_with_watermark(&ctx, wm);

        assert!(ctx.udf("tumble").is_ok());
        assert!(ctx.udf("watermark").is_ok());
    }

    #[tokio::test]
    async fn test_tumble_udf_via_datafusion() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        // Create schema with timestamp and value columns
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Send events across two 5-minute windows:
        // Window [0, 300_000): timestamps 60_000, 120_000
        // Window [300_000, 600_000): timestamp 360_000
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    60_000i64, 120_000, 360_000,
                ])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        // Verify the tumble UDF computes correct window starts via DataFusion
        // (GROUP BY aggregation and ORDER BY on unbounded streams are handled by Ring 0)
        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as window_start, \
                 value \
                 FROM events",
            )
            .await
            .unwrap();

        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);

        // Verify the window_start values (single batch, order preserved)
        let ws_col = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .expect("window_start should be TimestampMillisecond");
        // 60_000 and 120_000 → window [0, 300_000), start = 0
        assert_eq!(ws_col.value(0), 0);
        assert_eq!(ws_col.value(1), 0);
        // 360_000 → window [300_000, 600_000), start = 300_000
        assert_eq!(ws_col.value(2), 300_000);
    }

    #[tokio::test]
    async fn test_logical_plan_from_windowed_query() {
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(schema));
        let _sender = source.take_sender();
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Create a LogicalPlan for a windowed query
        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as w, \
                 COUNT(*) as cnt \
                 FROM events \
                 GROUP BY tumble(event_time, INTERVAL '5' MINUTE)",
            )
            .await;

        // Should succeed in creating the logical plan (UDFs are registered)
        assert!(df.is_ok(), "Failed to create logical plan: {df:?}");
    }

    #[tokio::test]
    async fn test_end_to_end_execute_streaming_sql() {
        use crate::planner::StreamingPlanner;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("items", source);
        ctx.register_table("items", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let mut planner = StreamingPlanner::new();
        let result = execute_streaming_sql("SELECT id FROM items WHERE id > 1", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                let mut stream = qr.stream;
                let mut total = 0;
                while let Some(batch) = stream.next().await {
                    total += batch.unwrap().num_rows();
                }
                assert_eq!(total, 2); // id=2, id=3
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    #[tokio::test]
    async fn test_watermark_function_in_filter() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;
        use std::sync::atomic::AtomicI64;

        // Create context with a specific watermark value
        let config = base_session_config()
            .with_batch_size(8192)
            .with_target_partitions(1);
        let ctx = SessionContext::new_with_config(config);
        let wm = Arc::new(AtomicI64::new(200_000)); // watermark at 200s
        register_streaming_functions_with_watermark(&ctx, wm);

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Events: 100s, 200s, 300s - watermark is at 200s
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    100_000i64, 200_000, 300_000,
                ])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        // Filter events after watermark
        let df = ctx
            .sql("SELECT value FROM events WHERE event_time > watermark()")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        // Only event at 300s is after watermark (200s)
        assert_eq!(total_rows, 1);
    }

    #[tokio::test]
    async fn test_date_trunc_available() {
        let ctx = create_streaming_context();
        let df = ctx
            .sql("SELECT date_trunc('hour', TIMESTAMP '2026-01-15 14:30:00')")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }

    #[tokio::test]
    async fn test_date_bin_available() {
        let ctx = create_streaming_context();
        let df = ctx
            .sql(
                "SELECT date_bin(\
                 INTERVAL '15 minutes', \
                 TIMESTAMP '2026-01-15 14:32:00', \
                 TIMESTAMP '2026-01-01 00:00:00')",
            )
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_rows(), 1);
    }

    #[tokio::test]
    async fn test_unnest_literal_array() {
        let ctx = create_streaming_context();
        let df = ctx
            .sql("SELECT unnest(make_array(1, 2, 3)) AS val")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_unnest_from_table_with_array_col() {
        let ctx = create_streaming_context();
        // Register a table with an array column
        ctx.sql(
            "CREATE TABLE arr_table (id INT, tags INT[]) \
             AS VALUES (1, make_array(10, 20)), (2, make_array(30))",
        )
        .await
        .unwrap();
        let df = ctx
            .sql("SELECT id, unnest(tags) AS tag FROM arr_table")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        // Row 1: [10,20] → 2 rows, Row 2: [30] → 1 row = 3 total
        assert_eq!(total_rows, 3);
    }
}