// laminar_sql/datafusion/table_provider.rs
//! Streaming table provider for `DataFusion` integration.
//!
//! This module provides `StreamingTableProvider` which implements `DataFusion`'s
//! `TableProvider` trait, allowing streaming sources to be registered as
//! tables in a `SessionContext` and queried with SQL.

use std::any::Any;
use std::sync::Arc;

use arrow_schema::SchemaRef;
use async_trait::async_trait;
use datafusion::catalog::Session;
use datafusion::datasource::TableProvider;
use datafusion::physical_plan::ExecutionPlan;
use datafusion_common::DataFusionError;
use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};

use super::exec::StreamingScanExec;
use super::source::StreamSourceRef;
20
21/// A `DataFusion` table provider backed by a streaming source.
22///
23/// This allows streaming sources to be registered as tables in `DataFusion`'s
24/// `SessionContext` and queried using SQL. The provider handles:
25///
26/// - Schema exposure to `DataFusion`'s catalog
27/// - Projection pushdown to the source
28/// - Filter pushdown when supported by the source
29///
30/// # Usage
31///
32/// ```rust,ignore
33/// let source = Arc::new(ChannelStreamSource::new(schema));
34/// let provider = StreamingTableProvider::new("events", source);
35/// ctx.register_table("events", Arc::new(provider))?;
36///
37/// let df = ctx.sql("SELECT * FROM events WHERE id > 100").await?;
38/// ```
39#[derive(Debug)]
40pub struct StreamingTableProvider {
41    /// Table name
42    name: String,
43    /// The underlying streaming source
44    source: StreamSourceRef,
45}
46
47impl StreamingTableProvider {
48    /// Creates a new streaming table provider.
49    #[must_use]
50    pub fn new(name: impl Into<String>, source: StreamSourceRef) -> Self {
51        Self {
52            name: name.into(),
53            source,
54        }
55    }
56
57    /// Returns the table name.
58    #[must_use]
59    pub fn name(&self) -> &str {
60        &self.name
61    }
62
63    /// Returns the underlying streaming source.
64    #[must_use]
65    pub fn source(&self) -> &StreamSourceRef {
66        &self.source
67    }
68}
69
70#[async_trait]
71impl TableProvider for StreamingTableProvider {
72    fn as_any(&self) -> &dyn Any {
73        self
74    }
75
76    fn schema(&self) -> SchemaRef {
77        self.source.schema()
78    }
79
80    fn table_type(&self) -> TableType {
81        // Streaming tables behave like base tables but are read-only
82        TableType::Base
83    }
84
85    fn supports_filters_pushdown(
86        &self,
87        filters: &[&Expr],
88    ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
89        // Ask the source which filters it can handle
90        let expr_refs: Vec<Expr> = filters.iter().map(|e| (*e).clone()).collect();
91        let supported = self.source.supports_filters(&expr_refs);
92
93        Ok(supported
94            .into_iter()
95            .map(|s| {
96                if s {
97                    TableProviderFilterPushDown::Exact
98                } else {
99                    TableProviderFilterPushDown::Unsupported
100                }
101            })
102            .collect())
103    }
104
105    async fn scan(
106        &self,
107        _state: &dyn Session,
108        projection: Option<&Vec<usize>>,
109        filters: &[Expr],
110        _limit: Option<usize>,
111    ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
112        // Determine which filters the source supports
113        let supported = self.source.supports_filters(filters);
114        let pushed_filters: Vec<Expr> = filters
115            .iter()
116            .zip(supported.iter())
117            .filter_map(|(f, &s)| if s { Some(f.clone()) } else { None })
118            .collect();
119
120        Ok(Arc::new(StreamingScanExec::new(
121            Arc::clone(&self.source),
122            projection.cloned(),
123            pushed_filters,
124        )))
125    }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::source::StreamSource;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::SendableRecordBatchStream;

    /// Minimal `StreamSource` for exercising the provider; `stream` is
    /// deliberately unimplemented since these tests never poll it.
    #[derive(Debug)]
    struct MockSource {
        schema: SchemaRef,
        supports_eq_filter: bool,
    }

    #[async_trait]
    impl StreamSource for MockSource {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn stream(
            &self,
            _projection: Option<Vec<usize>>,
            _filters: Vec<Expr>,
        ) -> Result<SendableRecordBatchStream, DataFusionError> {
            Err(DataFusionError::NotImplemented("mock".to_string()))
        }

        fn supports_filters(&self, filters: &[Expr]) -> Vec<bool> {
            let mut verdicts = Vec::with_capacity(filters.len());
            for expr in filters {
                // When enabled, accept only binary `=` predicates.
                let accepted = self.supports_eq_filter
                    && matches!(expr, Expr::BinaryExpr(e) if e.op == datafusion_expr::Operator::Eq);
                verdicts.push(accepted);
            }
            verdicts
        }
    }

    /// Two-column schema (`id: Int64 NOT NULL`, `name: Utf8 NULL`) shared by the tests.
    fn test_schema() -> SchemaRef {
        let fields = vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ];
        Arc::new(Schema::new(fields))
    }

    /// Convenience constructor: provider over a `MockSource` with the given filter support.
    fn make_provider(name: &str, supports_eq_filter: bool) -> StreamingTableProvider {
        let source: StreamSourceRef = Arc::new(MockSource {
            schema: test_schema(),
            supports_eq_filter,
        });
        StreamingTableProvider::new(name, source)
    }

    /// Builds the expression `id = 1`.
    fn id_eq_one() -> Expr {
        let id = Expr::Column(datafusion_common::Column::new_unqualified("id"));
        let one = Expr::Literal(datafusion_common::ScalarValue::Int64(Some(1)), None);
        Expr::BinaryExpr(datafusion_expr::BinaryExpr {
            left: Box::new(id),
            op: datafusion_expr::Operator::Eq,
            right: Box::new(one),
        })
    }

    #[test]
    fn test_table_provider_schema() {
        let provider = make_provider("test_table", false);

        assert_eq!(provider.schema(), test_schema());
        assert_eq!(provider.name(), "test_table");
    }

    #[test]
    fn test_table_provider_type() {
        let provider = make_provider("test", false);

        assert_eq!(provider.table_type(), TableType::Base);
    }

    #[test]
    fn test_filter_pushdown_unsupported() {
        let provider = make_provider("test", false);

        // A bare literal is never an equality filter, so it must be rejected.
        let filter = Expr::Literal(datafusion_common::ScalarValue::Int64(Some(1)), None);
        let verdicts = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(verdicts.len(), 1);
        assert!(matches!(
            verdicts[0],
            TableProviderFilterPushDown::Unsupported
        ));
    }

    #[test]
    fn test_filter_pushdown_supported() {
        let provider = make_provider("test", true);

        let filter = id_eq_one();
        let verdicts = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(verdicts.len(), 1);
        assert!(matches!(verdicts[0], TableProviderFilterPushDown::Exact));
    }

    #[tokio::test]
    async fn test_scan_creates_exec() {
        use crate::datafusion::create_session_context;

        let provider = make_provider("test", false);
        let ctx = create_session_context();
        let state = ctx.state();

        let exec = provider.scan(&state, None, &[], None).await.unwrap();

        // The plan must be a StreamingScanExec exposing the full schema.
        assert!(exec.as_any().is::<StreamingScanExec>());
        assert_eq!(exec.schema(), test_schema());
    }

    #[tokio::test]
    async fn test_scan_with_projection() {
        use crate::datafusion::create_session_context;

        let provider = make_provider("test", false);
        let ctx = create_session_context();
        let state = ctx.state();

        // Project only the first column (`id`).
        let projection = vec![0];
        let exec = provider
            .scan(&state, Some(&projection), &[], None)
            .await
            .unwrap();

        let schema = exec.schema();
        assert_eq!(schema.fields().len(), 1);
        assert_eq!(schema.field(0).name(), "id");
    }
}