laminar_sql/datafusion/
table_provider.rs1use std::any::Any;
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use datafusion::catalog::Session;
13use datafusion::datasource::TableProvider;
14use datafusion::physical_plan::ExecutionPlan;
15use datafusion_common::DataFusionError;
16use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
17
18use super::exec::StreamingScanExec;
19use super::source::StreamSourceRef;
20
21#[derive(Debug)]
40pub struct StreamingTableProvider {
41 name: String,
43 source: StreamSourceRef,
45}
46
47impl StreamingTableProvider {
48 #[must_use]
50 pub fn new(name: impl Into<String>, source: StreamSourceRef) -> Self {
51 Self {
52 name: name.into(),
53 source,
54 }
55 }
56
57 #[must_use]
59 pub fn name(&self) -> &str {
60 &self.name
61 }
62
63 #[must_use]
65 pub fn source(&self) -> &StreamSourceRef {
66 &self.source
67 }
68}
69
70#[async_trait]
71impl TableProvider for StreamingTableProvider {
72 fn as_any(&self) -> &dyn Any {
73 self
74 }
75
76 fn schema(&self) -> SchemaRef {
77 self.source.schema()
78 }
79
80 fn table_type(&self) -> TableType {
81 TableType::Base
83 }
84
85 fn supports_filters_pushdown(
86 &self,
87 filters: &[&Expr],
88 ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
89 let expr_refs: Vec<Expr> = filters.iter().map(|e| (*e).clone()).collect();
91 let supported = self.source.supports_filters(&expr_refs);
92
93 Ok(supported
94 .into_iter()
95 .map(|s| {
96 if s {
97 TableProviderFilterPushDown::Exact
98 } else {
99 TableProviderFilterPushDown::Unsupported
100 }
101 })
102 .collect())
103 }
104
105 async fn scan(
106 &self,
107 _state: &dyn Session,
108 projection: Option<&Vec<usize>>,
109 filters: &[Expr],
110 _limit: Option<usize>,
111 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
112 let supported = self.source.supports_filters(filters);
114 let pushed_filters: Vec<Expr> = filters
115 .iter()
116 .zip(supported.iter())
117 .filter_map(|(f, &s)| if s { Some(f.clone()) } else { None })
118 .collect();
119
120 Ok(Arc::new(StreamingScanExec::new(
121 Arc::clone(&self.source),
122 projection.cloned(),
123 pushed_filters,
124 )))
125 }
126}
127
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::datafusion::source::StreamSource;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{BinaryExpr, Operator};

    /// Test double for [`StreamSource`].
    ///
    /// `stream` always fails with `NotImplemented`. When `supports_eq_filter` is
    /// set, `supports_filters` accepts top-level `=` comparisons and rejects
    /// everything else; otherwise it rejects every filter.
    #[derive(Debug)]
    struct MockSource {
        schema: SchemaRef,
        supports_eq_filter: bool,
    }

    #[async_trait]
    impl StreamSource for MockSource {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn stream(
            &self,
            _projection: Option<Vec<usize>>,
            _filters: Vec<Expr>,
        ) -> Result<SendableRecordBatchStream, DataFusionError> {
            Err(DataFusionError::NotImplemented("mock".to_string()))
        }

        fn supports_filters(&self, filters: &[Expr]) -> Vec<bool> {
            filters
                .iter()
                .map(|expr| self.supports_eq_filter && is_equality(expr))
                .collect()
        }
    }

    /// True when `expr` is a binary expression whose operator is `=`.
    fn is_equality(expr: &Expr) -> bool {
        matches!(expr, Expr::BinaryExpr(bin) if bin.op == Operator::Eq)
    }

    /// Two-column schema: non-nullable `id: Int64`, nullable `name: Utf8`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    /// Wraps a fresh [`MockSource`] with `schema` in a provider named `name`.
    fn provider_over(
        name: &str,
        schema: SchemaRef,
        supports_eq_filter: bool,
    ) -> StreamingTableProvider {
        let source: StreamSourceRef = Arc::new(MockSource {
            schema,
            supports_eq_filter,
        });
        StreamingTableProvider::new(name, source)
    }

    #[test]
    fn test_table_provider_schema() {
        let schema = test_schema();
        let provider = provider_over("test_table", Arc::clone(&schema), false);

        assert_eq!(provider.schema(), schema);
        assert_eq!(provider.name(), "test_table");
    }

    #[test]
    fn test_table_provider_type() {
        let provider = provider_over("test", test_schema(), false);

        assert_eq!(provider.table_type(), TableType::Base);
    }

    #[test]
    fn test_filter_pushdown_unsupported() {
        let provider = provider_over("test", test_schema(), false);
        let literal = Expr::Literal(ScalarValue::Int64(Some(1)), None);

        let verdicts = provider.supports_filters_pushdown(&[&literal]).unwrap();

        assert_eq!(verdicts.len(), 1);
        assert!(matches!(
            verdicts[0],
            TableProviderFilterPushDown::Unsupported
        ));
    }

    #[test]
    fn test_filter_pushdown_supported() {
        let provider = provider_over("test", test_schema(), true);
        let id_eq_one = Expr::BinaryExpr(BinaryExpr::new(
            Box::new(Expr::Column(Column::new_unqualified("id"))),
            Operator::Eq,
            Box::new(Expr::Literal(ScalarValue::Int64(Some(1)), None)),
        ));

        let verdicts = provider.supports_filters_pushdown(&[&id_eq_one]).unwrap();

        assert_eq!(verdicts.len(), 1);
        assert!(matches!(verdicts[0], TableProviderFilterPushDown::Exact));
    }

    #[tokio::test]
    async fn test_scan_creates_exec() {
        let schema = test_schema();
        let provider = provider_over("test", Arc::clone(&schema), false);
        let state = create_session_context().state();

        let plan = provider.scan(&state, None, &[], None).await.unwrap();

        assert!(plan.as_any().is::<StreamingScanExec>());
        assert_eq!(plan.schema(), schema);
    }

    #[tokio::test]
    async fn test_scan_with_projection() {
        let provider = provider_over("test", test_schema(), false);
        let state = create_session_context().state();
        let projection = vec![0];

        let plan = provider
            .scan(&state, Some(&projection), &[], None)
            .await
            .unwrap();

        let projected = plan.schema();
        assert_eq!(projected.fields().len(), 1);
        assert_eq!(projected.field(0).name(), "id");
    }
}
295}