// laminar_sql/datafusion/mod.rs
//! `DataFusion` integration for SQL processing
//!
//! This module provides the integration layer between `LaminarDB`'s push-based
//! streaming engine and `DataFusion`'s pull-based SQL query execution.
//!
//! # Architecture
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                    Ring 2: Query Planning                       │
//! │  SQL Query → SessionContext → LogicalPlan → ExecutionPlan       │
//! │                                      │                          │
//! │                            StreamingScanExec                    │
//! │                                      │                          │
//! │                              ┌───────▼──────┐                   │
//! │                              │ StreamBridge │ (tokio channel)   │
//! │                              └───────▲──────┘                   │
//! ├──────────────────────────────────────┼──────────────────────────┤
//! │                    Ring 0: Hot Path  │                          │
//! │                                      │                          │
//! │  Source → Reactor.poll() ────────────┘                          │
//! │              (Events with RecordBatch data)                     │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! # Components
//!
//! - [`StreamSource`]: Trait for streaming data sources
//! - [`StreamBridge`]: Channel-based push-to-pull bridge
//! - [`StreamingScanExec`]: `DataFusion` execution plan for streaming scans
//! - [`StreamingTableProvider`]: `DataFusion` table provider for streaming sources
//! - [`ChannelStreamSource`]: Concrete source using channels
//!
//! # Usage
//!
//! ```rust,ignore
//! use laminar_sql::datafusion::{
//!     create_streaming_context, ChannelStreamSource, StreamingTableProvider,
//! };
//! use std::sync::Arc;
//!
//! // Create a streaming context
//! let ctx = create_streaming_context();
//!
//! // Create a channel source
//! let schema = Arc::new(Schema::new(vec![
//!     Field::new("id", DataType::Int64, false),
//!     Field::new("value", DataType::Float64, true),
//! ]));
//! let source = Arc::new(ChannelStreamSource::new(schema));
//! let sender = source.sender();
//!
//! // Register as a table
//! let provider = StreamingTableProvider::new("events", source);
//! ctx.register_table("events", Arc::new(provider))?;
//!
//! // Push data from the Reactor
//! sender.send(batch).await?;
//!
//! // Execute SQL queries
//! let df = ctx.sql("SELECT * FROM events WHERE value > 100").await?;
//! ```
64/// F075: DataFusion aggregate bridge for streaming aggregation.
65///
66/// Bridges DataFusion's `Accumulator` trait with `laminar-core`'s
67/// `DynAccumulator` / `DynAggregatorFactory` traits. This avoids
68/// duplicating aggregation logic.
69pub mod aggregate_bridge;
70mod bridge;
71mod channel_source;
72mod exec;
73/// End-to-end streaming SQL execution
74pub mod execute;
75mod source;
76mod table_provider;
77/// Watermark UDF for current watermark access
78pub mod watermark_udf;
79/// Window function UDFs (TUMBLE, HOP, SESSION)
80pub mod window_udf;
81
82pub use aggregate_bridge::{
83    create_aggregate_factory, lookup_aggregate_udf, result_to_scalar_value, scalar_value_to_result,
84    DataFusionAccumulatorAdapter, DataFusionAggregateFactory,
85};
86pub use bridge::{BridgeSendError, BridgeSender, BridgeStream, BridgeTrySendError, StreamBridge};
87pub use channel_source::ChannelStreamSource;
88pub use exec::StreamingScanExec;
89pub use execute::{execute_streaming_sql, DdlResult, QueryResult, StreamingSqlResult};
90pub use source::{SortColumn, StreamSource, StreamSourceRef};
91pub use table_provider::StreamingTableProvider;
92pub use watermark_udf::WatermarkUdf;
93pub use window_udf::{HopWindowStart, SessionWindowStart, TumbleWindowStart};
94
95use std::sync::atomic::AtomicI64;
96use std::sync::Arc;
97
98use datafusion::prelude::*;
99use datafusion_expr::ScalarUDF;
100
101/// Creates a `DataFusion` session context configured for streaming queries.
102///
103/// The context is configured with:
104/// - Batch size of 8192 (balanced for streaming throughput)
105/// - Single partition (streaming sources are typically not partitioned)
106/// - All streaming UDFs registered (TUMBLE, HOP, SESSION, WATERMARK)
107///
108/// The watermark UDF is initialized with no watermark set (returns NULL).
109/// Use [`register_streaming_functions_with_watermark`] to provide a live
110/// watermark source.
111///
112/// # Example
113///
114/// ```rust,ignore
115/// let ctx = create_streaming_context();
116/// ctx.register_table("events", provider)?;
117/// let df = ctx.sql("SELECT * FROM events").await?;
118/// ```
119#[must_use]
120pub fn create_streaming_context() -> SessionContext {
121    let config = SessionConfig::new()
122        .with_batch_size(8192)
123        .with_target_partitions(1); // Single partition for streaming
124
125    let ctx = SessionContext::new_with_config(config);
126    register_streaming_functions(&ctx);
127    ctx
128}
129
130/// Registers `LaminarDB` streaming UDFs with a session context.
131///
132/// Registers the following scalar functions:
133/// - `tumble(timestamp, interval)` — tumbling window start
134/// - `hop(timestamp, slide, size)` — hopping window start
135/// - `session(timestamp, gap)` — session window pass-through
136/// - `watermark()` — current watermark (returns NULL, no live source)
137///
138/// Use [`register_streaming_functions_with_watermark`] to provide a
139/// live watermark source from Ring 0.
140pub fn register_streaming_functions(ctx: &SessionContext) {
141    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
142    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
143    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
144    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::unset()));
145}
146
147/// Registers streaming UDFs with a live watermark source.
148///
149/// Same as [`register_streaming_functions`] but connects the `watermark()`
150/// UDF to a shared atomic value that Ring 0 updates in real time.
151///
152/// # Arguments
153///
154/// * `ctx` - `DataFusion` session context
155/// * `watermark_ms` - Shared atomic holding the current watermark in
156///   milliseconds since epoch. Values < 0 mean "no watermark" (returns NULL).
157pub fn register_streaming_functions_with_watermark(
158    ctx: &SessionContext,
159    watermark_ms: Arc<AtomicI64>,
160) {
161    ctx.register_udf(ScalarUDF::new_from_impl(TumbleWindowStart::new()));
162    ctx.register_udf(ScalarUDF::new_from_impl(HopWindowStart::new()));
163    ctx.register_udf(ScalarUDF::new_from_impl(SessionWindowStart::new()));
164    ctx.register_udf(ScalarUDF::new_from_impl(WatermarkUdf::new(watermark_ms)));
165}
166
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::{Float64Array, Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::FunctionRegistry;
    use futures::StreamExt;
    use std::sync::Arc;

    /// Two-column test schema: non-null `id: Int64`, nullable `value: Float64`.
    fn test_schema() -> Arc<Schema> {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("value", DataType::Float64, true),
        ]))
    }

    /// Take the sender from a `ChannelStreamSource`, panicking if already taken.
    fn take_test_sender(source: &ChannelStreamSource) -> super::bridge::BridgeSender {
        source.take_sender().expect("sender already taken")
    }

    /// Builds a `RecordBatch` matching [`test_schema`] from parallel id/value vectors.
    fn test_batch(schema: &Arc<Schema>, ids: Vec<i64>, values: Vec<f64>) -> RecordBatch {
        RecordBatch::try_new(
            Arc::clone(schema),
            vec![
                Arc::new(Int64Array::from(ids)),
                Arc::new(Float64Array::from(values)),
            ],
        )
        .unwrap()
    }

    #[test]
    fn test_create_streaming_context() {
        let ctx = create_streaming_context();
        let state = ctx.state();
        let config = state.config();

        assert_eq!(config.batch_size(), 8192);
        assert_eq!(config.target_partitions(), 1);
    }

    #[tokio::test]
    async fn test_full_query_pipeline() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        // Create source and take the sender (important for channel closure)
        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Send test data
        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        sender
            .send(test_batch(&schema, vec![4, 5], vec![40.0, 50.0]))
            .await
            .unwrap();
        drop(sender); // Close the channel

        // Execute query
        let df = ctx.sql("SELECT * FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        // Verify results
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 5);
    }

    #[tokio::test]
    async fn test_query_with_projection() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2], vec![100.0, 200.0]))
            .await
            .unwrap();
        drop(sender);

        // Query only the id column
        let df = ctx.sql("SELECT id FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        assert_eq!(batches.len(), 1);
        assert_eq!(batches[0].num_columns(), 1);
        assert_eq!(batches[0].schema().field(0).name(), "id");
    }

    #[tokio::test]
    async fn test_query_with_filter() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(
                &schema,
                vec![1, 2, 3, 4, 5],
                vec![10.0, 20.0, 30.0, 40.0, 50.0],
            ))
            .await
            .unwrap();
        drop(sender);

        // Filter for value > 25
        let df = ctx
            .sql("SELECT * FROM events WHERE value > 25")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();

        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3); // 30, 40, 50
    }

    #[tokio::test]
    async fn test_unbounded_aggregation_rejected() {
        // Aggregations on unbounded streams should be rejected by `DataFusion`.
        // Streaming aggregations require windows, which are implemented in F006.
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![1, 2, 3], vec![10.0, 20.0, 30.0]))
            .await
            .unwrap();
        drop(sender);

        // Aggregate query on unbounded stream should fail at execution
        let df = ctx.sql("SELECT COUNT(*) as cnt FROM events").await.unwrap();

        // Execution should fail because we can't aggregate an infinite stream
        let result = df.collect().await;
        assert!(
            result.is_err(),
            "Aggregation on unbounded stream should fail"
        );
    }

    #[tokio::test]
    async fn test_query_with_order_by() {
        let ctx = create_streaming_context();
        let schema = test_schema();

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        sender
            .send(test_batch(&schema, vec![3, 1, 2], vec![30.0, 10.0, 20.0]))
            .await
            .unwrap();
        drop(sender);

        // Query with ORDER BY (`DataFusion` handles this with Sort operator)
        let df = ctx.sql("SELECT id, value FROM events").await.unwrap();
        let batches = df.collect().await.unwrap();

        // Verify we got results (ordering may vary due to streaming nature)
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);
    }

    #[tokio::test]
    async fn test_bridge_throughput() {
        // Benchmark-style test for bridge performance
        let schema = test_schema();
        let bridge = StreamBridge::new(Arc::clone(&schema), 10000);
        let sender = bridge.sender();
        let mut stream = bridge.into_stream();

        let batch_count = 1000;
        let batch = test_batch(&schema, vec![1, 2, 3, 4, 5], vec![1.0, 2.0, 3.0, 4.0, 5.0]);

        // Spawn sender task
        let send_task = tokio::spawn(async move {
            for _ in 0..batch_count {
                sender.send(batch.clone()).await.unwrap();
            }
        });

        // Receive all batches
        let mut received = 0;
        while let Some(result) = stream.next().await {
            result.unwrap();
            received += 1;
            if received == batch_count {
                break;
            }
        }

        send_task.await.unwrap();
        assert_eq!(received, batch_count);
    }

    // ── F005B Integration Tests ──────────────────────────────────────────

    #[test]
    fn test_streaming_functions_registered() {
        let ctx = create_streaming_context();
        // Verify all 4 UDFs are registered
        assert!(ctx.udf("tumble").is_ok(), "tumble UDF not registered");
        assert!(ctx.udf("hop").is_ok(), "hop UDF not registered");
        assert!(ctx.udf("session").is_ok(), "session UDF not registered");
        assert!(ctx.udf("watermark").is_ok(), "watermark UDF not registered");
    }

    #[test]
    fn test_streaming_functions_with_watermark() {
        use std::sync::atomic::AtomicI64;

        let ctx = SessionContext::new();
        let wm = Arc::new(AtomicI64::new(42_000));
        register_streaming_functions_with_watermark(&ctx, wm);

        assert!(ctx.udf("tumble").is_ok());
        assert!(ctx.udf("watermark").is_ok());
    }

    #[tokio::test]
    async fn test_tumble_udf_via_datafusion() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        // Create schema with timestamp and value columns
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Send events across two 5-minute windows:
        // Window [0, 300_000): timestamps 60_000, 120_000
        // Window [300_000, 600_000): timestamps 360_000
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    60_000i64, 120_000, 360_000,
                ])),
                Arc::new(Float64Array::from(vec![10.0, 20.0, 30.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        // Verify the tumble UDF computes correct window starts via DataFusion
        // (GROUP BY aggregation and ORDER BY on unbounded streams are handled by Ring 0)
        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as window_start, \
                 value \
                 FROM events",
            )
            .await
            .unwrap();

        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        assert_eq!(total_rows, 3);

        // Verify the window_start values (single batch, order preserved)
        let ws_col = batches[0]
            .column(0)
            .as_any()
            .downcast_ref::<TimestampMillisecondArray>()
            .expect("window_start should be TimestampMillisecond");
        // 60_000 and 120_000 → window [0, 300_000), start = 0
        assert_eq!(ws_col.value(0), 0);
        assert_eq!(ws_col.value(1), 0);
        // 360_000 → window [300_000, 600_000), start = 300_000
        assert_eq!(ws_col.value(2), 300_000);
    }

    #[tokio::test]
    async fn test_logical_plan_from_windowed_query() {
        use arrow_schema::TimeUnit;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(schema));
        let _sender = source.take_sender();
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Create a LogicalPlan for a windowed query
        let df = ctx
            .sql(
                "SELECT tumble(event_time, INTERVAL '5' MINUTE) as w, \
                 COUNT(*) as cnt \
                 FROM events \
                 GROUP BY tumble(event_time, INTERVAL '5' MINUTE)",
            )
            .await;

        // Should succeed in creating the logical plan (UDFs are registered)
        assert!(df.is_ok(), "Failed to create logical plan: {df:?}");
    }

    #[tokio::test]
    async fn test_end_to_end_execute_streaming_sql() {
        use crate::planner::StreamingPlanner;

        let ctx = create_streaming_context();

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("items", source);
        ctx.register_table("items", Arc::new(provider)).unwrap();

        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2, 3])),
                Arc::new(arrow_array::StringArray::from(vec!["a", "b", "c"])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let mut planner = StreamingPlanner::new();
        let result = execute_streaming_sql("SELECT id FROM items WHERE id > 1", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                let mut stream = qr.stream;
                let mut total = 0;
                while let Some(batch) = stream.next().await {
                    total += batch.unwrap().num_rows();
                }
                assert_eq!(total, 2); // id=2, id=3
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    #[tokio::test]
    async fn test_watermark_function_in_filter() {
        use arrow_array::TimestampMillisecondArray;
        use arrow_schema::TimeUnit;
        use std::sync::atomic::AtomicI64;

        // Create context with a specific watermark value
        let config = SessionConfig::new()
            .with_batch_size(8192)
            .with_target_partitions(1);
        let ctx = SessionContext::new_with_config(config);
        let wm = Arc::new(AtomicI64::new(200_000)); // watermark at 200s
        register_streaming_functions_with_watermark(&ctx, wm);

        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "event_time",
                DataType::Timestamp(TimeUnit::Millisecond, None),
                false,
            ),
            Field::new("value", DataType::Float64, false),
        ]));

        let source = Arc::new(ChannelStreamSource::new(Arc::clone(&schema)));
        let sender = take_test_sender(&source);
        let provider = StreamingTableProvider::new("events", source);
        ctx.register_table("events", Arc::new(provider)).unwrap();

        // Events: 100s, 200s, 300s - watermark is at 200s
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampMillisecondArray::from(vec![
                    100_000i64, 200_000, 300_000,
                ])),
                Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        // Filter events after watermark
        let df = ctx
            .sql("SELECT value FROM events WHERE event_time > watermark()")
            .await
            .unwrap();
        let batches = df.collect().await.unwrap();
        let total_rows: usize = batches.iter().map(RecordBatch::num_rows).sum();
        // Only event at 300s is after watermark (200s)
        assert_eq!(total_rows, 1);
    }
}