laminar_sql/datafusion/execute.rs

//! Streaming SQL execution via DataFusion.

use datafusion::execution::SendableRecordBatchStream;
use datafusion::prelude::SessionContext;

use crate::parser::parse_streaming_sql;
use crate::planner::{QueryPlan, StreamingPlan, StreamingPlanner};
use crate::Error;

10/// Result of executing a streaming SQL statement.
11#[derive(Debug)]
12pub enum StreamingSqlResult {
13    /// DDL statement result (CREATE SOURCE, CREATE SINK)
14    Ddl(DdlResult),
15    /// Query execution result with optional streaming metadata
16    Query(QueryResult),
17}
18
19/// Result of a DDL statement execution.
20#[derive(Debug)]
21pub struct DdlResult {
22    /// The streaming plan describing what was created or registered
23    pub plan: StreamingPlan,
24}
25
26/// Result of a query execution.
27///
28/// Contains both the `DataFusion` record batch stream and optional
29/// streaming metadata (window config, join config, emit clause) from
30/// the `QueryPlan`. Ring 0 operators use the `query_plan` to configure
31/// windowing and join behavior.
32pub struct QueryResult {
33    /// Record batch stream from `DataFusion` execution
34    pub stream: SendableRecordBatchStream,
35    /// Streaming query metadata (window config, join config, etc.)
36    ///
37    /// `None` for standard SQL pass-through queries.
38    /// `Some` for queries with streaming features (windows, joins).
39    pub query_plan: Option<QueryPlan>,
40}
41
42impl std::fmt::Debug for QueryResult {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        f.debug_struct("QueryResult")
45            .field("query_plan", &self.query_plan)
46            .field("stream", &"<SendableRecordBatchStream>")
47            .finish()
48    }
49}
50
51/// Executes a streaming SQL statement end-to-end.
52///
53/// This function performs the full pipeline:
54/// 1. Parse SQL with streaming extensions (CREATE SOURCE/SINK, windows, etc.)
55/// 2. Plan via [`StreamingPlanner`]
56/// 3. For DDL: return the streaming plan as [`DdlResult`]
57/// 4. For queries with streaming features: create `LogicalPlan` via
58///    `DataFusion`, execute, and return stream + [`QueryPlan`] metadata
59/// 5. For standard SQL: pass through to `DataFusion` directly
60///
61/// # Arguments
62///
63/// * `sql` - The SQL statement to execute
64/// * `ctx` - `DataFusion` session context (should have streaming functions registered)
65/// * `planner` - Streaming planner with registered sources/sinks
66///
67/// # Errors
68///
69/// Returns [`Error`] if parsing, planning, or execution fails.
70pub async fn execute_streaming_sql(
71    sql: &str,
72    ctx: &SessionContext,
73    planner: &mut StreamingPlanner,
74) -> std::result::Result<StreamingSqlResult, Error> {
75    let statements = parse_streaming_sql(sql)?;
76
77    if statements.is_empty() {
78        return Err(Error::ParseError(
79            crate::parser::ParseError::StreamingError("Empty SQL statement".to_string()),
80        ));
81    }
82
83    // Process the first statement
84    let statement = &statements[0];
85    let plan = planner.plan(statement)?;
86
87    match plan {
88        StreamingPlan::Query(query_plan) => {
89            let logical_plan = planner.to_logical_plan(&query_plan, ctx).await?;
90            let df = ctx.execute_logical_plan(logical_plan).await?;
91            let stream = df.execute_stream().await?;
92
93            Ok(StreamingSqlResult::Query(QueryResult {
94                stream,
95                query_plan: Some(query_plan),
96            }))
97        }
98        StreamingPlan::Standard(stmt) => {
99            let sql_str = stmt.to_string();
100            let df = ctx.sql(&sql_str).await?;
101            let stream = df.execute_stream().await?;
102
103            Ok(StreamingSqlResult::Query(QueryResult {
104                stream,
105                query_plan: None,
106            }))
107        }
108        StreamingPlan::RegisterSource(_)
109        | StreamingPlan::RegisterSink(_)
110        | StreamingPlan::RegisterLookupTable(_)
111        | StreamingPlan::DropLookupTable { .. } => Ok(StreamingSqlResult::Ddl(DdlResult { plan })),
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_streaming_context;

    #[tokio::test]
    async fn test_execute_ddl_source() {
        let ctx = create_streaming_context();
        crate::datafusion::register_streaming_functions(&ctx);
        let mut planner = StreamingPlanner::new();

        let result = execute_streaming_sql(
            "CREATE SOURCE events (id INT, name VARCHAR)",
            &ctx,
            &mut planner,
        )
        .await
        .unwrap();

        assert!(matches!(result, StreamingSqlResult::Ddl(_)));
    }

    #[tokio::test]
    async fn test_execute_ddl_sink() {
        let ctx = create_streaming_context();
        crate::datafusion::register_streaming_functions(&ctx);
        let mut planner = StreamingPlanner::new();

        // Register source first
        execute_streaming_sql(
            "CREATE SOURCE events (id INT, name VARCHAR)",
            &ctx,
            &mut planner,
        )
        .await
        .unwrap();

        let result = execute_streaming_sql("CREATE SINK output FROM events", &ctx, &mut planner)
            .await
            .unwrap();

        assert!(matches!(result, StreamingSqlResult::Ddl(_)));
    }

    #[tokio::test]
    async fn test_execute_empty_sql_error() {
        let ctx = create_streaming_context();
        let mut planner = StreamingPlanner::new();

        let result = execute_streaming_sql("", &ctx, &mut planner).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_standard_passthrough() {
        use futures::StreamExt;

        let ctx = create_streaming_context();
        crate::datafusion::register_streaming_functions(&ctx);
        let mut planner = StreamingPlanner::new();

        // Simple SELECT 1 goes through DataFusion directly
        let result = execute_streaming_sql("SELECT 1 as value", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                assert!(qr.query_plan.is_none());
                let mut stream = qr.stream;
                let batch = stream.next().await.unwrap().unwrap();
                assert_eq!(batch.num_rows(), 1);
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    #[tokio::test]
    async fn test_execute_standard_query_with_table() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use futures::StreamExt;
        use std::sync::Arc;

        let ctx = create_streaming_context();
        crate::datafusion::register_streaming_functions(&ctx);

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        let source = Arc::new(crate::datafusion::ChannelStreamSource::new(Arc::clone(
            &schema,
        )));
        let sender = source.take_sender().expect("sender available");
        let provider = crate::datafusion::StreamingTableProvider::new("users", source);
        ctx.register_table("users", Arc::new(provider)).unwrap();

        // Send data and close channel
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["alice", "bob"])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let mut planner = StreamingPlanner::new();
        let result = execute_streaming_sql("SELECT id, name FROM users", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                assert!(qr.query_plan.is_none()); // Standard query
                let mut stream = qr.stream;
                let mut total = 0;
                while let Some(batch) = stream.next().await {
                    total += batch.unwrap().num_rows();
                }
                assert_eq!(total, 2);
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    /// Regression: DataFusion must plan `Timestamp(Nanosecond) + INTERVAL`
    /// natively. If this breaks on a DF upgrade, the OTel example's
    /// interval-join predicates stop working.
    #[tokio::test]
    async fn test_datafusion_timestamp_plus_interval_native() {
        use arrow_array::{RecordBatch, TimestampNanosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use datafusion::prelude::SessionContext;
        use futures::StreamExt;
        use std::sync::Arc;

        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Nanosecond, None),
            false,
        )]));

        // Year 2023 → ns. Add 5s; expect +5_000_000_000 ns.
        let base_ns: i64 = 1_700_000_000_500_000_000;
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(TimestampNanosecondArray::from(vec![base_ns]))],
        )
        .unwrap();

        ctx.register_batch("events", batch).unwrap();

        let df = ctx
            .sql("SELECT ts + INTERVAL '5' SECOND AS shifted FROM events")
            .await
            .expect("DataFusion must plan Timestamp(Nanosecond) + INTERVAL natively");

        let result_schema = df.schema().clone();
        let shifted_type = result_schema.field(0).data_type();
        assert!(
            matches!(shifted_type, DataType::Timestamp(_, _)),
            "expected Timestamp return type from Timestamp + INTERVAL, got {shifted_type:?}"
        );

        let mut stream = df.execute_stream().await.unwrap();
        let batch = stream.next().await.unwrap().unwrap();
        let col = batch.column(0);
        let arr = col
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .expect("DataFusion should preserve Nanosecond precision");
        assert_eq!(
            arr.value(0),
            base_ns + 5_000_000_000,
            "Timestamp(Nanosecond) + INTERVAL '5' SECOND should add exactly 5_000_000_000 ns"
        );
    }

    /// Regression: `BETWEEN t AND t + INTERVAL 'N' MINUTE` is the
    /// exact shape the OTel example's interval join uses.
    #[tokio::test]
    async fn test_datafusion_timestamp_between_interval_native() {
        use arrow_array::{Int64Array, RecordBatch, TimestampNanosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use datafusion::prelude::SessionContext;
        use futures::StreamExt;
        use std::sync::Arc;

        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false),
        ]));

        let base_ns: i64 = 1_700_000_000_000_000_000;
        let ids = Int64Array::from(vec![1, 2, 3, 4]);
        let tses = TimestampNanosecondArray::from(vec![
            base_ns,                   // inside the window
            base_ns + 30_000_000_000,  // 30s later, inside
            base_ns + 120_000_000_000, // exactly 2 min later, inclusive
            base_ns + 180_000_000_000, // 3 min later, outside
        ]);
        let batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids), Arc::new(tses)]).unwrap();
        ctx.register_batch("events", batch).unwrap();

        // Matches the exact shape of the OTel example's join predicate.
        let sql = "SELECT id FROM events \
             WHERE ts BETWEEN TIMESTAMP '2023-11-14 22:13:20' \
                      AND TIMESTAMP '2023-11-14 22:13:20' + INTERVAL '2' MINUTE";
        let df = ctx.sql(sql).await.expect("BETWEEN with INTERVAL must plan");
        let mut stream = df.execute_stream().await.unwrap();
        let mut ids: Vec<i64> = Vec::new();
        while let Some(batch) = stream.next().await {
            let batch = batch.unwrap();
            let col = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap();
            for i in 0..batch.num_rows() {
                ids.push(col.value(i));
            }
        }
        ids.sort_unstable();
        assert_eq!(
            ids,
            vec![1, 2, 3],
            "rows within 2-minute window should match (inclusive on upper bound)"
        );
    }

    /// Regression for the OTel example's `time_to_response_ms`: timestamp
    /// subtraction returns `Duration(Nanosecond)` (not `Interval`), and
    /// `CAST(Duration AS BIGINT) / 1_000_000` gives milliseconds.
    #[tokio::test]
    async fn test_datafusion_timestamp_subtraction_cast_to_bigint() {
        use arrow_array::{Int64Array, RecordBatch, TimestampNanosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use datafusion::prelude::SessionContext;
        use futures::StreamExt;
        use std::sync::Arc;

        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "a_ts",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                false,
            ),
            Field::new(
                "p_ts",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                false,
            ),
        ]));
        let base: i64 = 1_700_000_000_000_000_000;
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampNanosecondArray::from(vec![base + 500_000_000])),
                Arc::new(TimestampNanosecondArray::from(vec![base])),
            ],
        )
        .unwrap();
        ctx.register_batch("events", batch).unwrap();

        let df = ctx
            .sql("SELECT CAST(a_ts - p_ts AS BIGINT) / 1000000 AS ms FROM events")
            .await
            .expect("CAST(Timestamp - Timestamp AS BIGINT) must plan on DataFusion 52");
        let mut stream = df.execute_stream().await.unwrap();
        let batch = stream.next().await.unwrap().unwrap();
        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("result should be Int64 after the divide");
        assert_eq!(col.value(0), 500, "500 ms difference");
    }
}