use datafusion::execution::SendableRecordBatchStream;
use datafusion::prelude::SessionContext;

use crate::parser::parse_streaming_sql;
use crate::planner::{QueryPlan, StreamingPlan, StreamingPlanner};
use crate::Error;

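/// Result of executing a streaming SQL statement: either a DDL
/// acknowledgment or an executing query.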
#[derive(Debug)]
pub enum StreamingSqlResult {
    /// A DDL statement was planned; no data stream is produced.
    Ddl(DdlResult),
    /// A query was executed and is producing record batches.
    Query(QueryResult),
}

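/// Result of a DDL statement such as `CREATE SOURCE` or `CREATE SINK`.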
#[derive(Debug)]
pub struct DdlResult {
    /// The streaming plan produced for the DDL statement.
    pub plan: StreamingPlan,
}

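/// Result of a query: a live stream of record batches plus the streaming
/// plan, when the statement went through the streaming planner.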
pub struct QueryResult {
    /// The executing stream of record batches.
    pub stream: SendableRecordBatchStream,
    /// The streaming query plan, or `None` for standard SQL passed through
    /// to DataFusion unchanged.
    pub query_plan: Option<QueryPlan>,
}

// Manual Debug impl: `SendableRecordBatchStream` does not implement `Debug`.
impl std::fmt::Debug for QueryResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QueryResult")
            .field("query_plan", &self.query_plan)
            .field("stream", &"<SendableRecordBatchStream>")
            .finish()
    }
}

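/// Parses, plans, and executes a single streaming SQL statement.
///
/// Only the first parsed statement is executed. DDL statements (`CREATE
/// SOURCE`, `CREATE SINK`, lookup-table management) come back as
/// [`StreamingSqlResult::Ddl`]; streaming and standard queries are executed
/// and come back as [`StreamingSqlResult::Query`] with a live stream.
///
/// A usage sketch (not compiled; assumes a context built by
/// `create_streaming_context`, as in the tests below):
///
/// ```ignore
/// let ctx = create_streaming_context();
/// let mut planner = StreamingPlanner::new();
/// match execute_streaming_sql("SELECT 1 AS value", &ctx, &mut planner).await? {
///     StreamingSqlResult::Query(q) => { /* consume q.stream */ }
///     StreamingSqlResult::Ddl(d) => { /* inspect d.plan */ }
/// }
/// ```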
pub async fn execute_streaming_sql(
    sql: &str,
    ctx: &SessionContext,
    planner: &mut StreamingPlanner,
) -> std::result::Result<StreamingSqlResult, Error> {
    let statements = parse_streaming_sql(sql)?;

    if statements.is_empty() {
        return Err(Error::ParseError(
            crate::parser::ParseError::StreamingError("Empty SQL statement".to_string()),
        ));
    }

    // Only the first statement is planned; any trailing statements are ignored.
    let statement = &statements[0];
    let plan = planner.plan(statement)?;

    match plan {
        // Streaming queries go through the planner's logical-plan translation.
        StreamingPlan::Query(query_plan) => {
            let logical_plan = planner.to_logical_plan(&query_plan, ctx).await?;
            let df = ctx.execute_logical_plan(logical_plan).await?;
            let stream = df.execute_stream().await?;

            Ok(StreamingSqlResult::Query(QueryResult {
                stream,
                query_plan: Some(query_plan),
            }))
        }
        // Standard SQL is passed through to DataFusion unchanged.
        StreamingPlan::Standard(stmt) => {
            let sql_str = stmt.to_string();
            let df = ctx.sql(&sql_str).await?;
            let stream = df.execute_stream().await?;

            Ok(StreamingSqlResult::Query(QueryResult {
                stream,
                query_plan: None,
            }))
        }
        // DDL statements produce no data stream; hand the plan back to the caller.
        StreamingPlan::RegisterSource(_)
        | StreamingPlan::RegisterSink(_)
        | StreamingPlan::RegisterLookupTable(_)
        | StreamingPlan::DropLookupTable { .. } => Ok(StreamingSqlResult::Ddl(DdlResult { plan })),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_streaming_context;

    #[tokio::test]
    async fn test_execute_ddl_source() {
        let ctx = create_streaming_context();
        crate::datafusion::register_streaming_functions(&ctx);
        let mut planner = StreamingPlanner::new();

        let result = execute_streaming_sql(
            "CREATE SOURCE events (id INT, name VARCHAR)",
            &ctx,
            &mut planner,
        )
        .await
        .unwrap();

        assert!(matches!(result, StreamingSqlResult::Ddl(_)));
    }

    #[tokio::test]
    async fn test_execute_ddl_sink() {
        let ctx = create_streaming_context();
        crate::datafusion::register_streaming_functions(&ctx);
        let mut planner = StreamingPlanner::new();

        // Register the source the sink will read from.
        execute_streaming_sql(
            "CREATE SOURCE events (id INT, name VARCHAR)",
            &ctx,
            &mut planner,
        )
        .await
        .unwrap();

        let result = execute_streaming_sql("CREATE SINK output FROM events", &ctx, &mut planner)
            .await
            .unwrap();

        assert!(matches!(result, StreamingSqlResult::Ddl(_)));
    }

    #[tokio::test]
    async fn test_execute_empty_sql_error() {
        let ctx = create_streaming_context();
        let mut planner = StreamingPlanner::new();

        let result = execute_streaming_sql("", &ctx, &mut planner).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_standard_passthrough() {
        use futures::StreamExt;

        let ctx = create_streaming_context();
        crate::datafusion::register_streaming_functions(&ctx);
        let mut planner = StreamingPlanner::new();

        let result = execute_streaming_sql("SELECT 1 as value", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                assert!(qr.query_plan.is_none());
                let mut stream = qr.stream;
                let batch = stream.next().await.unwrap().unwrap();
                assert_eq!(batch.num_rows(), 1);
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

    #[tokio::test]
    async fn test_execute_standard_query_with_table() {
        use arrow_array::{Int64Array, RecordBatch, StringArray};
        use arrow_schema::{DataType, Field, Schema};
        use futures::StreamExt;
        use std::sync::Arc;

        let ctx = create_streaming_context();
        crate::datafusion::register_streaming_functions(&ctx);

        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]));

        // Register a channel-backed streaming table for the query to read.
        let source = Arc::new(crate::datafusion::ChannelStreamSource::new(Arc::clone(
            &schema,
        )));
        let sender = source.take_sender().expect("sender available");
        let provider = crate::datafusion::StreamingTableProvider::new("users", source);
        ctx.register_table("users", Arc::new(provider)).unwrap();

        // Send one batch, then drop the sender so the stream terminates.
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(Int64Array::from(vec![1, 2])),
                Arc::new(StringArray::from(vec!["alice", "bob"])),
            ],
        )
        .unwrap();
        sender.send(batch).await.unwrap();
        drop(sender);

        let mut planner = StreamingPlanner::new();
        let result = execute_streaming_sql("SELECT id, name FROM users", &ctx, &mut planner)
            .await
            .unwrap();

        match result {
            StreamingSqlResult::Query(qr) => {
                assert!(qr.query_plan.is_none());
                let mut stream = qr.stream;
                let mut total = 0;
                while let Some(batch) = stream.next().await {
                    total += batch.unwrap().num_rows();
                }
                assert_eq!(total, 2);
            }
            StreamingSqlResult::Ddl(_) => panic!("Expected Query result"),
        }
    }

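    // Guard test: DataFusion itself must plan and execute
    // Timestamp(Nanosecond) + INTERVAL arithmetic natively, preserving
    // nanosecond precision.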
    #[tokio::test]
    async fn test_datafusion_timestamp_plus_interval_native() {
        use arrow_array::{RecordBatch, TimestampNanosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use datafusion::prelude::SessionContext;
        use futures::StreamExt;
        use std::sync::Arc;

        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![Field::new(
            "ts",
            DataType::Timestamp(TimeUnit::Nanosecond, None),
            false,
        )]));

        // 2023-11-14 22:13:20.5 UTC, in nanoseconds since the epoch.
        let base_ns: i64 = 1_700_000_000_500_000_000;
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![Arc::new(TimestampNanosecondArray::from(vec![base_ns]))],
        )
        .unwrap();

        ctx.register_batch("events", batch).unwrap();

        let df = ctx
            .sql("SELECT ts + INTERVAL '5' SECOND AS shifted FROM events")
            .await
            .expect("DataFusion must plan Timestamp(Nanosecond) + INTERVAL natively");

        let result_schema = df.schema().clone();
        let shifted_type = result_schema.field(0).data_type();
        assert!(
            matches!(shifted_type, DataType::Timestamp(_, _)),
            "expected Timestamp return type from Timestamp + INTERVAL, got {shifted_type:?}"
        );

        let mut stream = df.execute_stream().await.unwrap();
        let batch = stream.next().await.unwrap().unwrap();
        let col = batch.column(0);
        let arr = col
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .expect("DataFusion should preserve Nanosecond precision");
        assert_eq!(
            arr.value(0),
            base_ns + 5_000_000_000,
            "Timestamp(Nanosecond) + INTERVAL '5' SECOND should add exactly 5_000_000_000 ns"
        );
    }

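    // Guard test: BETWEEN with an INTERVAL-shifted upper bound must plan and
    // execute natively in DataFusion, with the upper bound inclusive.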
    #[tokio::test]
    async fn test_datafusion_timestamp_between_interval_native() {
        use arrow_array::{Int64Array, RecordBatch, TimestampNanosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use datafusion::prelude::SessionContext;
        use futures::StreamExt;
        use std::sync::Arc;

        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false),
        ]));

        // base_ns is 2023-11-14 22:13:20 UTC; rows sit at +0s, +30s, +2min, +3min.
        let base_ns: i64 = 1_700_000_000_000_000_000;
        let ids = Int64Array::from(vec![1, 2, 3, 4]);
        let tses = TimestampNanosecondArray::from(vec![
            base_ns,
            base_ns + 30_000_000_000,
            base_ns + 120_000_000_000,
            base_ns + 180_000_000_000,
        ]);
        let batch =
            RecordBatch::try_new(Arc::clone(&schema), vec![Arc::new(ids), Arc::new(tses)]).unwrap();
        ctx.register_batch("events", batch).unwrap();

        let sql = "SELECT id FROM events \
                   WHERE ts BETWEEN TIMESTAMP '2023-11-14 22:13:20' \
                   AND TIMESTAMP '2023-11-14 22:13:20' + INTERVAL '2' MINUTE";
        let df = ctx.sql(sql).await.expect("BETWEEN with INTERVAL must plan");
        let mut stream = df.execute_stream().await.unwrap();
        let mut ids: Vec<i64> = Vec::new();
        while let Some(batch) = stream.next().await {
            let batch = batch.unwrap();
            let col = batch
                .column(0)
                .as_any()
                .downcast_ref::<Int64Array>()
                .unwrap();
            for i in 0..batch.num_rows() {
                ids.push(col.value(i));
            }
        }
        ids.sort_unstable();
        assert_eq!(
            ids,
            vec![1, 2, 3],
            "rows within 2-minute window should match (inclusive on upper bound)"
        );
    }

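    // Guard test: CAST(Timestamp - Timestamp AS BIGINT) must plan natively;
    // the difference comes back in nanoseconds and is divided down to ms.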
    #[tokio::test]
    async fn test_datafusion_timestamp_subtraction_cast_to_bigint() {
        use arrow_array::{Int64Array, RecordBatch, TimestampNanosecondArray};
        use arrow_schema::{DataType, Field, Schema, TimeUnit};
        use datafusion::prelude::SessionContext;
        use futures::StreamExt;
        use std::sync::Arc;

        let ctx = SessionContext::new();
        let schema = Arc::new(Schema::new(vec![
            Field::new(
                "a_ts",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                false,
            ),
            Field::new(
                "p_ts",
                DataType::Timestamp(TimeUnit::Nanosecond, None),
                false,
            ),
        ]));
        // The two timestamps differ by exactly 500 ms (500_000_000 ns).
        let base: i64 = 1_700_000_000_000_000_000;
        let batch = RecordBatch::try_new(
            Arc::clone(&schema),
            vec![
                Arc::new(TimestampNanosecondArray::from(vec![base + 500_000_000])),
                Arc::new(TimestampNanosecondArray::from(vec![base])),
            ],
        )
        .unwrap();
        ctx.register_batch("events", batch).unwrap();

        let df = ctx
            .sql("SELECT CAST(a_ts - p_ts AS BIGINT) / 1000000 AS ms FROM events")
            .await
            .expect("CAST(Timestamp - Timestamp AS BIGINT) must plan on DataFusion 52");
        let mut stream = df.execute_stream().await.unwrap();
        let batch = stream.next().await.unwrap().unwrap();
        let col = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .expect("result should be Int64 after the divide");
        assert_eq!(col.value(0), 500, "500 ms difference");
    }
}