1use std::any::Any;
10use std::sync::{Arc, Mutex};
11
12use arrow::array::RecordBatch;
13use arrow::datatypes::SchemaRef;
14use async_trait::async_trait;
15use datafusion::catalog::Session;
16use datafusion::datasource::TableProvider;
17use datafusion::error::DataFusionError;
18use datafusion::execution::{SendableRecordBatchStream, TaskContext};
19use datafusion::logical_expr::Expr;
20use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
21use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
22use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
23use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
24use datafusion_common::Statistics;
25use datafusion_expr::TableType;
26
/// A `TableProvider` whose contents can be replaced at runtime.
///
/// Scans read the shared `current` slot at execute time, so data swapped in
/// through a [`LiveSourceHandle`] is visible to subsequent (and cached-plan)
/// query executions.
pub struct LiveSourceProvider {
    // Batch slot shared with every handle and every exec node built by `scan`.
    current: Arc<Mutex<Vec<RecordBatch>>>,
    // Fixed schema; all swapped-in batches are expected to match it.
    schema: SchemaRef,
}
37
38impl LiveSourceProvider {
39 #[must_use]
41 pub fn new(schema: SchemaRef) -> Self {
42 Self {
43 current: Arc::new(Mutex::new(Vec::new())),
44 schema,
45 }
46 }
47
48 #[must_use]
50 pub fn handle(&self) -> LiveSourceHandle {
51 LiveSourceHandle {
52 slot: Arc::clone(&self.current),
53 }
54 }
55}
56
57impl std::fmt::Debug for LiveSourceProvider {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 f.debug_struct("LiveSourceProvider")
60 .field("schema_fields", &self.schema.fields().len())
61 .finish_non_exhaustive()
62 }
63}
64
65#[async_trait]
66impl TableProvider for LiveSourceProvider {
67 fn as_any(&self) -> &dyn Any {
68 self
69 }
70
71 fn schema(&self) -> SchemaRef {
72 self.schema.clone()
73 }
74
75 fn table_type(&self) -> TableType {
76 TableType::Base
77 }
78
79 async fn scan(
80 &self,
81 _state: &dyn Session,
82 projection: Option<&Vec<usize>>,
83 _filters: &[Expr],
84 _limit: Option<usize>,
85 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
86 Ok(Arc::new(LiveSourceExec::new(
87 Arc::clone(&self.current),
88 self.schema.clone(),
89 projection.cloned(),
90 )))
91 }
92}
93
/// Leaf `ExecutionPlan` that snapshots the provider's shared batch slot at
/// `execute` time, applying an optional column projection.
pub(crate) struct LiveSourceExec {
    // Shared with `LiveSourceProvider` / `LiveSourceHandle`; read on execute.
    slot: Arc<Mutex<Vec<RecordBatch>>>,
    // Output schema (already narrowed when `projection` is `Some`).
    schema: SchemaRef,
    // Column indices into the source schema, if the scan was projected.
    projection: Option<Vec<usize>>,
    // Cached plan properties: single unknown partition, bounded, final emission.
    properties: PlanProperties,
}
105
106impl LiveSourceExec {
107 fn new(
108 slot: Arc<Mutex<Vec<RecordBatch>>>,
109 source_schema: SchemaRef,
110 projection: Option<Vec<usize>>,
111 ) -> Self {
112 let schema = match &projection {
113 Some(indices) => {
114 let fields: Vec<_> = indices
115 .iter()
116 .map(|&i| source_schema.field(i).clone())
117 .collect();
118 Arc::new(arrow::datatypes::Schema::new(fields))
119 }
120 None => source_schema,
121 };
122 let properties = PlanProperties::new(
123 EquivalenceProperties::new(schema.clone()),
124 Partitioning::UnknownPartitioning(1),
125 EmissionType::Final,
126 Boundedness::Bounded,
127 );
128 Self {
129 slot,
130 schema,
131 projection,
132 properties,
133 }
134 }
135}
136
137impl std::fmt::Debug for LiveSourceExec {
138 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139 f.debug_struct("LiveSourceExec")
140 .field("schema_fields", &self.schema.fields().len())
141 .finish_non_exhaustive()
142 }
143}
144
145impl DisplayAs for LiveSourceExec {
146 fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 match t {
148 DisplayFormatType::Default | DisplayFormatType::Verbose => {
149 write!(f, "LiveSourceExec: schema={}", self.schema.fields().len())
150 }
151 DisplayFormatType::TreeRender => write!(f, "LiveSourceExec"),
152 }
153 }
154}
155
156impl ExecutionPlan for LiveSourceExec {
157 fn name(&self) -> &'static str {
158 "LiveSourceExec"
159 }
160
161 fn as_any(&self) -> &dyn Any {
162 self
163 }
164
165 fn schema(&self) -> SchemaRef {
166 self.schema.clone()
167 }
168
169 fn properties(&self) -> &PlanProperties {
170 &self.properties
171 }
172
173 fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
174 vec![]
175 }
176
177 fn with_new_children(
178 self: Arc<Self>,
179 children: Vec<Arc<dyn ExecutionPlan>>,
180 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
181 if children.is_empty() {
182 Ok(self)
183 } else {
184 Err(DataFusionError::Plan(
185 "LiveSourceExec is a leaf node".to_string(),
186 ))
187 }
188 }
189
190 fn execute(
191 &self,
192 partition: usize,
193 _context: Arc<TaskContext>,
194 ) -> Result<SendableRecordBatchStream, DataFusionError> {
195 if partition != 0 {
196 return Err(DataFusionError::Plan(format!(
197 "LiveSourceExec only supports partition 0, got {partition}"
198 )));
199 }
200
201 let batches = self.slot.lock().expect("LiveSourceExec poisoned").clone();
202 let schema = self.schema.clone();
203 let projection = self.projection.clone();
204
205 let output = futures::stream::iter(if batches.is_empty() {
207 vec![Ok(RecordBatch::new_empty(schema))]
208 } else if let Some(indices) = projection {
209 batches
210 .into_iter()
211 .map(move |batch| batch.project(&indices).map_err(DataFusionError::from))
212 .collect()
213 } else {
214 batches.into_iter().map(Ok).collect()
215 });
216
217 Ok(Box::pin(RecordBatchStreamAdapter::new(
218 self.schema.clone(),
219 output,
220 )))
221 }
222
223 fn statistics(&self) -> datafusion_common::Result<Statistics> {
224 Ok(Statistics::default())
225 }
226}
227
// NOTE(review): DataFusion also exposes `ExecutionPlanProperties` for
// `&dyn ExecutionPlan` / `Arc<dyn ExecutionPlan>` via blanket impls, so this
// concrete impl looks redundant for dynamic callers — presumably it exists so
// the concrete type can be queried directly; confirm before removing.
impl datafusion::physical_plan::ExecutionPlanProperties for LiveSourceExec {
    fn output_partitioning(&self) -> &Partitioning {
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&datafusion::physical_expr::LexOrdering> {
        self.properties.output_ordering()
    }

    // Hardcoded to the same values `LiveSourceExec::new` bakes into
    // `self.properties`.
    fn boundedness(&self) -> Boundedness {
        Boundedness::Bounded
    }

    fn pipeline_behavior(&self) -> EmissionType {
        EmissionType::Final
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
249
/// Cloneable writer for the batch slot shared with a `LiveSourceProvider`.
///
/// Updates made through `swap`/`clear` take effect the next time a plan over
/// the provider is executed.
#[derive(Clone)]
pub struct LiveSourceHandle {
    // Same allocation as `LiveSourceProvider::current`.
    slot: Arc<Mutex<Vec<RecordBatch>>>,
}
257
258impl LiveSourceHandle {
259 pub fn swap(&self, batches: Vec<RecordBatch>) {
266 let mut guard = self.slot.lock().expect("LiveSourceHandle poisoned");
267 guard.clear();
268 guard.extend(batches);
269 }
270
271 pub fn clear(&self) {
277 self.slot.lock().expect("LiveSourceHandle poisoned").clear();
278 }
279}
280
281impl std::fmt::Debug for LiveSourceHandle {
282 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283 f.debug_struct("LiveSourceHandle").finish()
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    // Three-column schema (id: Int64, name: Utf8, price: Float64) shared by
    // every test below.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("price", DataType::Float64, true),
        ]))
    }

    // Builds a batch over `test_schema()` from parallel column slices.
    fn make_batch(ids: &[i64], names: &[&str], prices: &[f64]) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(Int64Array::from(ids.to_vec())),
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(prices.to_vec())),
            ],
        )
        .unwrap()
    }

    // Fresh session context per test.
    fn test_ctx() -> datafusion::prelude::SessionContext {
        datafusion::prelude::SessionContext::new()
    }

    // Runs `sql` and returns the total row count across all result batches.
    async fn count_rows(ctx: &datafusion::prelude::SessionContext, sql: &str) -> usize {
        let df = ctx.sql(sql).await.unwrap();
        df.collect()
            .await
            .unwrap()
            .iter()
            .map(RecordBatch::num_rows)
            .sum()
    }

    // Cloned handles share one slot: a swap through one clone is visible
    // through the other, and so is a clear.
    #[test]
    fn test_handle_swap_and_clear() {
        let provider = LiveSourceProvider::new(test_schema());
        let h1 = provider.handle();
        let h2 = h1.clone();

        h1.swap(vec![make_batch(&[1, 2], &["A", "B"], &[1.0, 2.0])]);
        assert_eq!(h2.slot.lock().unwrap().len(), 1);

        h2.clear();
        assert_eq!(h1.slot.lock().unwrap().len(), 0);
    }

    // Re-running the same SQL re-reads the slot; both executions see the data.
    #[tokio::test]
    async fn test_scan_reads_fresh_data_each_execute() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(
            &[1, 2, 3],
            &["A", "B", "C"],
            &[10.0, 20.0, 30.0],
        )]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 3);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 3);
    }

    // An empty slot scans as zero rows rather than erroring.
    #[tokio::test]
    async fn test_scan_empty() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 0);
    }

    // Column projection narrows the output schema and preserves column order.
    #[tokio::test]
    async fn test_projection() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(
            &[1, 2, 3],
            &["A", "B", "C"],
            &[10.0, 20.0, 30.0],
        )]);

        let df = ctx.sql("SELECT id, price FROM t").await.unwrap();
        let result = df.collect().await.unwrap();
        assert_eq!(result.iter().map(RecordBatch::num_rows).sum::<usize>(), 3);
        assert_eq!(result[0].schema().fields().len(), 2);
        assert_eq!(result[0].schema().field(0).name(), "id");
        assert_eq!(result[0].schema().field(1).name(), "price");
    }

    // swap / swap / clear across several query cycles: each query reflects
    // the latest slot contents.
    #[tokio::test]
    async fn test_multi_cycle() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(&[1], &["A"], &[10.0])]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 1);

        handle.swap(vec![make_batch(&[2, 3], &["B", "C"], &[20.0, 30.0])]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 2);

        handle.clear();
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 0);
    }

    // The key property of the design: a physical plan built once keeps
    // re-reading the slot, so swapping data after planning still changes
    // what collect() returns.
    #[tokio::test]
    async fn test_cached_plan_sees_fresh_data() {
        use datafusion::physical_plan::ExecutionPlanProperties as _;

        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(&[1], &["A"], &[10.0])]);
        let logical = ctx
            .state()
            .create_logical_plan("SELECT * FROM t")
            .await
            .unwrap();
        let physical = ctx.state().create_physical_plan(&logical).await.unwrap();
        assert_eq!(physical.output_partitioning().partition_count(), 1);

        let task_ctx = ctx.task_ctx();
        let r1 = datafusion::physical_plan::collect(physical.clone(), task_ctx.clone())
            .await
            .unwrap();
        assert_eq!(r1.iter().map(RecordBatch::num_rows).sum::<usize>(), 1);

        handle.swap(vec![make_batch(
            &[2, 3, 4],
            &["B", "C", "D"],
            &[20.0, 30.0, 40.0],
        )]);
        let r2 = datafusion::physical_plan::collect(physical.clone(), task_ctx.clone())
            .await
            .unwrap();
        assert_eq!(r2.iter().map(RecordBatch::num_rows).sum::<usize>(), 3);

        handle.clear();
        let r3 = datafusion::physical_plan::collect(physical, task_ctx)
            .await
            .unwrap();
        assert_eq!(r3.iter().map(RecordBatch::num_rows).sum::<usize>(), 0);
    }
}
447}