// laminar_sql/datafusion/live_source.rs

//! Swappable table provider that eliminates per-cycle catalog churn and
//! enables physical plan caching.
//!
//! Register a [`LiveSourceProvider`] once at pipeline startup. Each cycle,
//! swap batches via [`LiveSourceHandle`], then execute the cached physical
//! plan. The internal `LiveSourceExec` reads from the shared slot at
//! `execute()` time, so the cached plan always sees fresh data.
9use std::any::Any;
10use std::sync::{Arc, Mutex};
11
12use arrow::array::RecordBatch;
13use arrow::datatypes::SchemaRef;
14use async_trait::async_trait;
15use datafusion::catalog::Session;
16use datafusion::datasource::TableProvider;
17use datafusion::error::DataFusionError;
18use datafusion::execution::{SendableRecordBatchStream, TaskContext};
19use datafusion::logical_expr::Expr;
20use datafusion::physical_expr::{EquivalenceProperties, Partitioning};
21use datafusion::physical_plan::execution_plan::{Boundedness, EmissionType};
22use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
23use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties};
24use datafusion_common::Statistics;
25use datafusion_expr::TableType;
26
27// ── TableProvider ────────────────────────────────────────────────────
28
/// Swappable `TableProvider` for streaming micro-batch execution.
///
/// `scan()` returns an internal execution plan that reads from the shared
/// batch slot at `execute()` time — enabling physical plan caching.
pub struct LiveSourceProvider {
    // Shared batch slot; the same allocation is referenced by every
    // `LiveSourceHandle` and `LiveSourceExec` derived from this provider.
    current: Arc<Mutex<Vec<RecordBatch>>>,
    // Schema declared at construction. Batches swapped in are assumed to
    // match it — NOTE(review): not validated here; confirm callers uphold it.
    schema: SchemaRef,
}
37
38impl LiveSourceProvider {
39    /// Creates a provider with the given schema and an empty batch slot.
40    #[must_use]
41    pub fn new(schema: SchemaRef) -> Self {
42        Self {
43            current: Arc::new(Mutex::new(Vec::new())),
44            schema,
45        }
46    }
47
48    /// Returns a handle for swapping batches into this provider.
49    #[must_use]
50    pub fn handle(&self) -> LiveSourceHandle {
51        LiveSourceHandle {
52            slot: Arc::clone(&self.current),
53        }
54    }
55}
56
57impl std::fmt::Debug for LiveSourceProvider {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        f.debug_struct("LiveSourceProvider")
60            .field("schema_fields", &self.schema.fields().len())
61            .finish_non_exhaustive()
62    }
63}
64
65#[async_trait]
66impl TableProvider for LiveSourceProvider {
67    fn as_any(&self) -> &dyn Any {
68        self
69    }
70
71    fn schema(&self) -> SchemaRef {
72        self.schema.clone()
73    }
74
75    fn table_type(&self) -> TableType {
76        TableType::Base
77    }
78
79    async fn scan(
80        &self,
81        _state: &dyn Session,
82        projection: Option<&Vec<usize>>,
83        _filters: &[Expr],
84        _limit: Option<usize>,
85    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
86        Ok(Arc::new(LiveSourceExec::new(
87            Arc::clone(&self.current),
88            self.schema.clone(),
89            projection.cloned(),
90        )))
91    }
92}
93
94// ── ExecutionPlan ────────────────────────────────────────────────────
95
/// Leaf `ExecutionPlan` that reads from a shared batch slot at `execute()`
/// time, not at construction time. This enables physical plan caching:
/// the plan tree is built once, and each `execute()` call sees fresh data.
pub(crate) struct LiveSourceExec {
    // Shared with the owning `LiveSourceProvider` and its handles.
    slot: Arc<Mutex<Vec<RecordBatch>>>,
    // Output schema: the source schema, or its projected subset.
    schema: SchemaRef,
    // Column indices to keep, applied per-batch in `execute()`; `None`
    // passes batches through untouched.
    projection: Option<Vec<usize>>,
    // Cached plan properties (single partition, bounded, final emission).
    properties: PlanProperties,
}
105
106impl LiveSourceExec {
107    fn new(
108        slot: Arc<Mutex<Vec<RecordBatch>>>,
109        source_schema: SchemaRef,
110        projection: Option<Vec<usize>>,
111    ) -> Self {
112        let schema = match &projection {
113            Some(indices) => {
114                let fields: Vec<_> = indices
115                    .iter()
116                    .map(|&i| source_schema.field(i).clone())
117                    .collect();
118                Arc::new(arrow::datatypes::Schema::new(fields))
119            }
120            None => source_schema,
121        };
122        let properties = PlanProperties::new(
123            EquivalenceProperties::new(schema.clone()),
124            Partitioning::UnknownPartitioning(1),
125            EmissionType::Final,
126            Boundedness::Bounded,
127        );
128        Self {
129            slot,
130            schema,
131            projection,
132            properties,
133        }
134    }
135}
136
137impl std::fmt::Debug for LiveSourceExec {
138    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
139        f.debug_struct("LiveSourceExec")
140            .field("schema_fields", &self.schema.fields().len())
141            .finish_non_exhaustive()
142    }
143}
144
145impl DisplayAs for LiveSourceExec {
146    fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        match t {
148            DisplayFormatType::Default | DisplayFormatType::Verbose => {
149                write!(f, "LiveSourceExec: schema={}", self.schema.fields().len())
150            }
151            DisplayFormatType::TreeRender => write!(f, "LiveSourceExec"),
152        }
153    }
154}
155
impl ExecutionPlan for LiveSourceExec {
    fn name(&self) -> &'static str {
        "LiveSourceExec"
    }

    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        // Projected schema if a projection was supplied, source schema otherwise.
        self.schema.clone()
    }

    fn properties(&self) -> &PlanProperties {
        &self.properties
    }

    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
        // Leaf node: data comes from the shared slot, not from child plans.
        vec![]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn ExecutionPlan>>,
    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
        // Accepting an empty child list (returning self unchanged) lets
        // optimizer passes rewrite the tree without special-casing leaves.
        if children.is_empty() {
            Ok(self)
        } else {
            Err(DataFusionError::Plan(
                "LiveSourceExec is a leaf node".to_string(),
            ))
        }
    }

    /// Reads the slot at call time (not plan-build time), so a cached
    /// physical plan sees whatever batches were most recently swapped in.
    fn execute(
        &self,
        partition: usize,
        _context: Arc<TaskContext>,
    ) -> Result<SendableRecordBatchStream, DataFusionError> {
        // Declared partitioning is `UnknownPartitioning(1)`, so only
        // partition 0 is valid.
        if partition != 0 {
            return Err(DataFusionError::Plan(format!(
                "LiveSourceExec only supports partition 0, got {partition}"
            )));
        }

        // Snapshot the slot under the lock; cloning `RecordBatch`es is cheap
        // (columns are Arc-backed). A poisoned mutex means a writer panicked
        // mid-swap, which is treated as corrupt pipeline state.
        let batches = self.slot.lock().expect("LiveSourceExec poisoned").clone();
        let schema = self.schema.clone();
        let projection = self.projection.clone();

        // Stream batches individually — no concat. Apply projection per-batch.
        // An empty slot still yields one empty batch so downstream operators
        // observe the correct (projected) schema.
        let output = futures::stream::iter(if batches.is_empty() {
            vec![Ok(RecordBatch::new_empty(schema))]
        } else if let Some(indices) = projection {
            batches
                .into_iter()
                .map(move |batch| batch.project(&indices).map_err(DataFusionError::from))
                .collect()
        } else {
            batches.into_iter().map(Ok).collect()
        });

        Ok(Box::pin(RecordBatchStreamAdapter::new(
            self.schema.clone(),
            output,
        )))
    }

    fn statistics(&self) -> datafusion_common::Result<Statistics> {
        // Slot contents change every cycle, so nothing useful can be
        // precomputed; report unknown statistics.
        Ok(Statistics::default())
    }
}
227
// NOTE(review): DataFusion normally exposes these via blanket impls on
// `&dyn ExecutionPlan` / `Arc<dyn ExecutionPlan>`; this concrete impl makes
// them callable on `LiveSourceExec` directly — confirm it is still needed.
// The hardcoded answers below must stay in sync with the `PlanProperties`
// built in `LiveSourceExec::new`.
impl datafusion::physical_plan::ExecutionPlanProperties for LiveSourceExec {
    fn output_partitioning(&self) -> &Partitioning {
        // Always `UnknownPartitioning(1)` (set in `new`).
        self.properties.output_partitioning()
    }

    fn output_ordering(&self) -> Option<&datafusion::physical_expr::LexOrdering> {
        // No ordering is ever declared, so this is always `None`.
        self.properties.output_ordering()
    }

    fn boundedness(&self) -> Boundedness {
        // Mirrors the `Boundedness::Bounded` stored in `self.properties`.
        Boundedness::Bounded
    }

    fn pipeline_behavior(&self) -> EmissionType {
        // Mirrors the `EmissionType::Final` stored in `self.properties`.
        EmissionType::Final
    }

    fn equivalence_properties(&self) -> &EquivalenceProperties {
        self.properties.equivalence_properties()
    }
}
249
250// ── Handle ───────────────────────────────────────────────────────────
251
/// Handle for swapping batches into a [`LiveSourceProvider`].
///
/// Cloning is cheap: clones share the same underlying slot, so a swap made
/// through any clone is observed by all the others.
#[derive(Clone)]
pub struct LiveSourceHandle {
    // Same allocation as the provider's `current` slot.
    slot: Arc<Mutex<Vec<RecordBatch>>>,
}
257
258impl LiveSourceHandle {
259    /// Replace current batches.
260    ///
261    /// # Panics
262    ///
263    /// Panics if the internal mutex is poisoned (a thread panicked while
264    /// holding it). A poisoned mutex indicates corrupt pipeline state.
265    pub fn swap(&self, batches: Vec<RecordBatch>) {
266        let mut guard = self.slot.lock().expect("LiveSourceHandle poisoned");
267        guard.clear();
268        guard.extend(batches);
269    }
270
271    /// Clear all pending batches.
272    ///
273    /// # Panics
274    ///
275    /// Panics if the internal mutex is poisoned.
276    pub fn clear(&self) {
277        self.slot.lock().expect("LiveSourceHandle poisoned").clear();
278    }
279}
280
281impl std::fmt::Debug for LiveSourceHandle {
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        f.debug_struct("LiveSourceHandle").finish()
284    }
285}
286
287// ── Tests ────────────────────────────────────────────────────────────
288
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Float64Array, Int64Array, StringArray};
    use arrow::datatypes::{DataType, Field, Schema};

    /// Three-column schema (`id: Int64`, `name: Utf8`, `price: Float64`)
    /// shared by every test in this module.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
            Field::new("price", DataType::Float64, true),
        ]))
    }

    /// Builds one batch against `test_schema()`; slices must be equal length.
    fn make_batch(ids: &[i64], names: &[&str], prices: &[f64]) -> RecordBatch {
        RecordBatch::try_new(
            test_schema(),
            vec![
                Arc::new(Int64Array::from(ids.to_vec())),
                Arc::new(StringArray::from(
                    names.iter().map(|s| Some(*s)).collect::<Vec<_>>(),
                )),
                Arc::new(Float64Array::from(prices.to_vec())),
            ],
        )
        .unwrap()
    }

    fn test_ctx() -> datafusion::prelude::SessionContext {
        // Use a plain context (no streaming validator) for unit tests.
        datafusion::prelude::SessionContext::new()
    }

    /// Runs `sql` and returns the total row count across all result batches.
    async fn count_rows(ctx: &datafusion::prelude::SessionContext, sql: &str) -> usize {
        let df = ctx.sql(sql).await.unwrap();
        df.collect()
            .await
            .unwrap()
            .iter()
            .map(RecordBatch::num_rows)
            .sum()
    }

    // Cloned handles share one slot: a swap through one is visible to the
    // other, and vice versa.
    #[test]
    fn test_handle_swap_and_clear() {
        let provider = LiveSourceProvider::new(test_schema());
        let h1 = provider.handle();
        let h2 = h1.clone();

        h1.swap(vec![make_batch(&[1, 2], &["A", "B"], &[1.0, 2.0])]);
        assert_eq!(h2.slot.lock().unwrap().len(), 1);

        h2.clear();
        assert_eq!(h1.slot.lock().unwrap().len(), 0);
    }

    // Re-running the same SQL re-reads the slot, so repeated queries see
    // current data without re-registering the table.
    #[tokio::test]
    async fn test_scan_reads_fresh_data_each_execute() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(
            &[1, 2, 3],
            &["A", "B", "C"],
            &[10.0, 20.0, 30.0],
        )]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 3);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 3);
    }

    // An empty slot yields zero rows (one empty batch internally).
    #[tokio::test]
    async fn test_scan_empty() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 0);
    }

    // Column projection pushed into the scan keeps only the selected
    // columns, in the requested order.
    #[tokio::test]
    async fn test_projection() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(
            &[1, 2, 3],
            &["A", "B", "C"],
            &[10.0, 20.0, 30.0],
        )]);

        let df = ctx.sql("SELECT id, price FROM t").await.unwrap();
        let result = df.collect().await.unwrap();
        assert_eq!(result.iter().map(RecordBatch::num_rows).sum::<usize>(), 3);
        assert_eq!(result[0].schema().fields().len(), 2);
        assert_eq!(result[0].schema().field(0).name(), "id");
        assert_eq!(result[0].schema().field(1).name(), "price");
    }

    // Swap → query → swap → query: each cycle's query sees only that
    // cycle's batches; clear empties the next result.
    #[tokio::test]
    async fn test_multi_cycle() {
        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(&[1], &["A"], &[10.0])]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 1);

        handle.swap(vec![make_batch(&[2, 3], &["B", "C"], &[20.0, 30.0])]);
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 2);

        handle.clear();
        assert_eq!(count_rows(&ctx, "SELECT * FROM t").await, 0);
    }

    // The core property: a physical plan built ONCE keeps returning fresh
    // data across swaps, because LiveSourceExec reads the slot at execute().
    #[tokio::test]
    async fn test_cached_plan_sees_fresh_data() {
        use datafusion::physical_plan::ExecutionPlanProperties as _;

        let provider = Arc::new(LiveSourceProvider::new(test_schema()));
        let handle = provider.handle();
        let ctx = test_ctx();
        ctx.register_table("t", provider).unwrap();

        handle.swap(vec![make_batch(&[1], &["A"], &[10.0])]);
        let logical = ctx
            .state()
            .create_logical_plan("SELECT * FROM t")
            .await
            .unwrap();
        let physical = ctx.state().create_physical_plan(&logical).await.unwrap();
        assert_eq!(physical.output_partitioning().partition_count(), 1);

        let task_ctx = ctx.task_ctx();
        let r1 = datafusion::physical_plan::collect(physical.clone(), task_ctx.clone())
            .await
            .unwrap();
        assert_eq!(r1.iter().map(RecordBatch::num_rows).sum::<usize>(), 1);

        handle.swap(vec![make_batch(
            &[2, 3, 4],
            &["B", "C", "D"],
            &[20.0, 30.0, 40.0],
        )]);
        let r2 = datafusion::physical_plan::collect(physical.clone(), task_ctx.clone())
            .await
            .unwrap();
        assert_eq!(r2.iter().map(RecordBatch::num_rows).sum::<usize>(), 3);

        handle.clear();
        let r3 = datafusion::physical_plan::collect(physical, task_ctx)
            .await
            .unwrap();
        assert_eq!(r3.iter().map(RecordBatch::num_rows).sum::<usize>(), 0);
    }
}