// laminar_sql/datafusion/table_provider.rs
1//! Streaming table provider for `DataFusion` integration
2//!
3//! This module provides `StreamingTableProvider` which implements `DataFusion`'s
4//! `TableProvider` trait, allowing streaming sources to be registered as
5//! tables in a `SessionContext` and queried with SQL.
6
7use std::any::Any;
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use datafusion::catalog::Session;
13use datafusion::datasource::TableProvider;
14use datafusion::physical_plan::ExecutionPlan;
15use datafusion_common::DataFusionError;
16use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
17
18use super::exec::StreamingScanExec;
19use super::source::StreamSourceRef;
20
/// A `DataFusion` table provider backed by a streaming source.
///
/// This allows streaming sources to be registered as tables in `DataFusion`'s
/// `SessionContext` and queried using SQL. The provider handles:
///
/// - Schema exposure to `DataFusion`'s catalog
/// - Projection pushdown to the source
/// - Filter pushdown when supported by the source
///
/// # Usage
///
/// ```rust,ignore
/// let source = Arc::new(ChannelStreamSource::new(schema));
/// let provider = StreamingTableProvider::new("events", source);
/// ctx.register_table("events", Arc::new(provider))?;
///
/// let df = ctx.sql("SELECT * FROM events WHERE id > 100").await?;
/// ```
#[derive(Debug)]
pub struct StreamingTableProvider {
    /// Table name, used for display/debugging only — registration with a
    /// `SessionContext` determines the SQL-visible name.
    name: String,
    /// The underlying streaming source; schema and filter-support queries
    /// are delegated to it.
    source: StreamSourceRef,
}
46
47impl StreamingTableProvider {
48    /// Creates a new streaming table provider.
49    ///
50    /// # Arguments
51    ///
52    /// * `name` - Name of the table (used for display/debugging)
53    /// * `source` - The streaming source backing this table
54    #[must_use]
55    pub fn new(name: impl Into<String>, source: StreamSourceRef) -> Self {
56        Self {
57            name: name.into(),
58            source,
59        }
60    }
61
62    /// Returns the table name.
63    #[must_use]
64    pub fn name(&self) -> &str {
65        &self.name
66    }
67
68    /// Returns the underlying streaming source.
69    #[must_use]
70    pub fn source(&self) -> &StreamSourceRef {
71        &self.source
72    }
73}
74
75#[async_trait]
76impl TableProvider for StreamingTableProvider {
77    fn as_any(&self) -> &dyn Any {
78        self
79    }
80
81    fn schema(&self) -> SchemaRef {
82        self.source.schema()
83    }
84
85    fn table_type(&self) -> TableType {
86        // Streaming tables behave like base tables but are read-only
87        TableType::Base
88    }
89
90    fn supports_filters_pushdown(
91        &self,
92        filters: &[&Expr],
93    ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
94        // Ask the source which filters it can handle
95        let expr_refs: Vec<Expr> = filters.iter().map(|e| (*e).clone()).collect();
96        let supported = self.source.supports_filters(&expr_refs);
97
98        Ok(supported
99            .into_iter()
100            .map(|s| {
101                if s {
102                    TableProviderFilterPushDown::Exact
103                } else {
104                    TableProviderFilterPushDown::Unsupported
105                }
106            })
107            .collect())
108    }
109
110    async fn scan(
111        &self,
112        _state: &dyn Session,
113        projection: Option<&Vec<usize>>,
114        filters: &[Expr],
115        _limit: Option<usize>,
116    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
117        // Determine which filters the source supports
118        let supported = self.source.supports_filters(filters);
119        let pushed_filters: Vec<Expr> = filters
120            .iter()
121            .zip(supported.iter())
122            .filter_map(|(f, &s)| if s { Some(f.clone()) } else { None })
123            .collect();
124
125        Ok(Arc::new(StreamingScanExec::new(
126            Arc::clone(&self.source),
127            projection.cloned(),
128            pushed_filters,
129        )))
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::source::StreamSource;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::SendableRecordBatchStream;

    /// Minimal `StreamSource` for exercising the provider without real I/O.
    #[derive(Debug)]
    struct MockSource {
        // Schema reported to the provider.
        schema: SchemaRef,
        // When true, claim support for equality (`=`) filters only.
        supports_eq_filter: bool,
    }

    #[async_trait]
    impl StreamSource for MockSource {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        // Streaming is never driven by these tests, so the mock just errors.
        fn stream(
            &self,
            _projection: Option<Vec<usize>>,
            _filters: Vec<Expr>,
        ) -> Result<SendableRecordBatchStream, DataFusionError> {
            Err(DataFusionError::NotImplemented("mock".to_string()))
        }

        // Per-filter support verdicts: equality binary expressions only,
        // and only when `supports_eq_filter` is enabled.
        fn supports_filters(&self, filters: &[Expr]) -> Vec<bool> {
            filters
                .iter()
                .map(|f| {
                    if self.supports_eq_filter {
                        // Only support equality filters for testing
                        matches!(f, Expr::BinaryExpr(e) if e.op == datafusion_expr::Operator::Eq)
                    } else {
                        false
                    }
                })
                .collect()
        }
    }

    /// Two-column schema (non-null `id`, nullable `name`) shared by all tests.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    // The provider must surface the source's schema and the configured name.
    #[test]
    fn test_table_provider_schema() {
        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema: Arc::clone(&schema),
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test_table", source);

        assert_eq!(provider.schema(), schema);
        assert_eq!(provider.name(), "test_table");
    }

    // Streaming tables are reported to the catalog as base tables.
    #[test]
    fn test_table_provider_type() {
        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test", source);

        assert_eq!(provider.table_type(), TableType::Base);
    }

    // A source that supports no filters must map every filter to Unsupported.
    #[test]
    fn test_filter_pushdown_unsupported() {
        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test", source);

        let filter = Expr::Literal(datafusion_common::ScalarValue::Int64(Some(1)), None);
        let result = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(result.len(), 1);
        assert!(matches!(
            result[0],
            TableProviderFilterPushDown::Unsupported
        ));
    }

    // An equality filter on an eq-supporting source maps to Exact pushdown.
    #[test]
    fn test_filter_pushdown_supported() {
        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter: true,
        });
        let provider = StreamingTableProvider::new("test", source);

        // Create an equality filter: id = 1
        let filter = Expr::BinaryExpr(datafusion_expr::BinaryExpr {
            left: Box::new(Expr::Column(datafusion_common::Column::new_unqualified(
                "id",
            ))),
            op: datafusion_expr::Operator::Eq,
            right: Box::new(Expr::Literal(
                datafusion_common::ScalarValue::Int64(Some(1)),
                None,
            )),
        });
        let result = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], TableProviderFilterPushDown::Exact));
    }

    // scan() with no projection yields a StreamingScanExec exposing the
    // full source schema.
    #[tokio::test]
    async fn test_scan_creates_exec() {
        use crate::datafusion::create_session_context;

        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema: Arc::clone(&schema),
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test", source);

        let ctx = create_session_context();
        let session_state = ctx.state();

        let exec = provider
            .scan(&session_state, None, &[], None)
            .await
            .unwrap();

        // Verify it's a StreamingScanExec
        assert!(exec.as_any().is::<StreamingScanExec>());
        assert_eq!(exec.schema(), schema);
    }

    // scan() with a projection must narrow the exec's output schema to the
    // projected columns.
    #[tokio::test]
    async fn test_scan_with_projection() {
        use crate::datafusion::create_session_context;

        let schema = test_schema();
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter: false,
        });
        let provider = StreamingTableProvider::new("test", source);

        let ctx = create_session_context();
        let session_state = ctx.state();

        let projection = vec![0]; // Only id column
        let exec = provider
            .scan(&session_state, Some(&projection), &[], None)
            .await
            .unwrap();

        let output_schema = exec.schema();
        assert_eq!(output_schema.fields().len(), 1);
        assert_eq!(output_schema.field(0).name(), "id");
    }
}