laminar_sql/datafusion/
table_provider.rs1use std::any::Any;
8use std::sync::Arc;
9
10use arrow_schema::SchemaRef;
11use async_trait::async_trait;
12use datafusion::catalog::Session;
13use datafusion::datasource::TableProvider;
14use datafusion::physical_plan::ExecutionPlan;
15use datafusion_common::DataFusionError;
16use datafusion_expr::{Expr, TableProviderFilterPushDown, TableType};
17
18use super::exec::StreamingScanExec;
19use super::source::StreamSourceRef;
20
21#[derive(Debug)]
40pub struct StreamingTableProvider {
41 name: String,
43 source: StreamSourceRef,
45}
46
47impl StreamingTableProvider {
48 #[must_use]
55 pub fn new(name: impl Into<String>, source: StreamSourceRef) -> Self {
56 Self {
57 name: name.into(),
58 source,
59 }
60 }
61
62 #[must_use]
64 pub fn name(&self) -> &str {
65 &self.name
66 }
67
68 #[must_use]
70 pub fn source(&self) -> &StreamSourceRef {
71 &self.source
72 }
73}
74
75#[async_trait]
76impl TableProvider for StreamingTableProvider {
77 fn as_any(&self) -> &dyn Any {
78 self
79 }
80
81 fn schema(&self) -> SchemaRef {
82 self.source.schema()
83 }
84
85 fn table_type(&self) -> TableType {
86 TableType::Base
88 }
89
90 fn supports_filters_pushdown(
91 &self,
92 filters: &[&Expr],
93 ) -> Result<Vec<TableProviderFilterPushDown>, DataFusionError> {
94 let expr_refs: Vec<Expr> = filters.iter().map(|e| (*e).clone()).collect();
96 let supported = self.source.supports_filters(&expr_refs);
97
98 Ok(supported
99 .into_iter()
100 .map(|s| {
101 if s {
102 TableProviderFilterPushDown::Exact
103 } else {
104 TableProviderFilterPushDown::Unsupported
105 }
106 })
107 .collect())
108 }
109
110 async fn scan(
111 &self,
112 _state: &dyn Session,
113 projection: Option<&Vec<usize>>,
114 filters: &[Expr],
115 _limit: Option<usize>,
116 ) -> Result<Arc<dyn ExecutionPlan>, DataFusionError> {
117 let supported = self.source.supports_filters(filters);
119 let pushed_filters: Vec<Expr> = filters
120 .iter()
121 .zip(supported.iter())
122 .filter_map(|(f, &s)| if s { Some(f.clone()) } else { None })
123 .collect();
124
125 Ok(Arc::new(StreamingScanExec::new(
126 Arc::clone(&self.source),
127 projection.cloned(),
128 pushed_filters,
129 )))
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::datafusion::create_session_context;
    use crate::datafusion::source::StreamSource;
    use arrow_schema::{DataType, Field, Schema};
    use datafusion::execution::SendableRecordBatchStream;
    use datafusion_common::{Column, ScalarValue};
    use datafusion_expr::{BinaryExpr, Operator};

    /// Minimal [`StreamSource`] used to exercise the provider without real data.
    #[derive(Debug)]
    struct MockSource {
        schema: SchemaRef,
        /// When `true`, binary `=` expressions are reported as supported filters.
        supports_eq_filter: bool,
    }

    #[async_trait]
    impl StreamSource for MockSource {
        fn schema(&self) -> SchemaRef {
            Arc::clone(&self.schema)
        }

        fn stream(
            &self,
            _projection: Option<Vec<usize>>,
            _filters: Vec<Expr>,
        ) -> Result<SendableRecordBatchStream, DataFusionError> {
            Err(DataFusionError::NotImplemented("mock".to_string()))
        }

        fn supports_filters(&self, filters: &[Expr]) -> Vec<bool> {
            filters
                .iter()
                .map(|f| {
                    self.supports_eq_filter
                        && matches!(f, Expr::BinaryExpr(e) if e.op == Operator::Eq)
                })
                .collect()
        }
    }

    /// Two-column schema: non-null `id: Int64`, nullable `name: Utf8`.
    fn test_schema() -> SchemaRef {
        Arc::new(Schema::new(vec![
            Field::new("id", DataType::Int64, false),
            Field::new("name", DataType::Utf8, true),
        ]))
    }

    /// Builds a provider named `name` over a [`MockSource`] using [`test_schema`].
    fn mock_provider(name: &str, supports_eq_filter: bool) -> StreamingTableProvider {
        let source: StreamSourceRef = Arc::new(MockSource {
            schema: test_schema(),
            supports_eq_filter,
        });
        StreamingTableProvider::new(name, source)
    }

    /// The predicate `id = 1`.
    fn id_eq_one() -> Expr {
        Expr::BinaryExpr(BinaryExpr {
            left: Box::new(Expr::Column(Column::new_unqualified("id"))),
            op: Operator::Eq,
            right: Box::new(Expr::Literal(ScalarValue::Int64(Some(1)), None)),
        })
    }

    #[test]
    fn test_table_provider_schema() {
        let provider = mock_provider("test_table", false);

        assert_eq!(provider.schema(), test_schema());
        assert_eq!(provider.name(), "test_table");
    }

    #[test]
    fn test_table_provider_type() {
        let provider = mock_provider("test", false);

        assert_eq!(provider.table_type(), TableType::Base);
    }

    #[test]
    fn test_filter_pushdown_unsupported() {
        let provider = mock_provider("test", false);

        let filter = Expr::Literal(ScalarValue::Int64(Some(1)), None);
        let result = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(result.len(), 1);
        assert!(matches!(
            result[0],
            TableProviderFilterPushDown::Unsupported
        ));
    }

    #[test]
    fn test_filter_pushdown_supported() {
        let provider = mock_provider("test", true);

        let filter = id_eq_one();
        let result = provider.supports_filters_pushdown(&[&filter]).unwrap();

        assert_eq!(result.len(), 1);
        assert!(matches!(result[0], TableProviderFilterPushDown::Exact));
    }

    #[test]
    fn test_filter_pushdown_mixed() {
        let provider = mock_provider("test", true);

        let supported = id_eq_one();
        let unsupported = Expr::Literal(ScalarValue::Int64(Some(1)), None);
        let result = provider
            .supports_filters_pushdown(&[&supported, &unsupported])
            .unwrap();

        assert_eq!(result.len(), 2);
        assert!(matches!(result[0], TableProviderFilterPushDown::Exact));
        assert!(matches!(
            result[1],
            TableProviderFilterPushDown::Unsupported
        ));
    }

    #[test]
    fn test_filter_pushdown_empty() {
        let provider = mock_provider("test", true);

        let result = provider.supports_filters_pushdown(&[]).unwrap();

        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_scan_creates_exec() {
        let provider = mock_provider("test", false);
        let session_state = create_session_context().state();

        let exec = provider
            .scan(&session_state, None, &[], None)
            .await
            .unwrap();

        assert!(exec.as_any().is::<StreamingScanExec>());
        assert_eq!(exec.schema(), test_schema());
    }

    #[tokio::test]
    async fn test_scan_with_projection() {
        let provider = mock_provider("test", false);
        let session_state = create_session_context().state();

        let projection = vec![0];
        let exec = provider
            .scan(&session_state, Some(&projection), &[], None)
            .await
            .unwrap();

        let output_schema = exec.schema();
        assert_eq!(output_schema.fields().len(), 1);
        assert_eq!(output_schema.field(0).name(), "id");
    }
}