1use super::{
19 from_aggregate_rel, from_cast, from_cross_rel, from_exchange_rel, from_fetch_rel,
20 from_field_reference, from_filter_rel, from_if_then, from_join_rel, from_literal,
21 from_project_rel, from_read_rel, from_scalar_function, from_set_rel,
22 from_singular_or_list, from_sort_rel, from_subquery, from_substrait_rel,
23 from_substrait_rex, from_window_function,
24};
25use crate::extensions::Extensions;
26use async_trait::async_trait;
27use datafusion::arrow::datatypes::DataType;
28use datafusion::catalog::TableProvider;
29use datafusion::common::{
30 DFSchema, ScalarValue, TableReference, not_impl_err, substrait_err,
31};
32use datafusion::execution::{FunctionRegistry, SessionState};
33use datafusion::logical_expr::{Expr, Extension, LogicalPlan};
34use std::sync::Arc;
35use substrait::proto;
36use substrait::proto::expression as substrait_expression;
37use substrait::proto::expression::{
38 Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction,
39 SingularOrList, SwitchExpression, WindowFunction,
40};
41use substrait::proto::{
42 AggregateRel, ConsistentPartitionWindowRel, CrossRel, DynamicParameter, ExchangeRel,
43 Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, FetchRel,
44 FilterRel, JoinRel, ProjectRel, ReadRel, Rel, SetRel, SortRel, r#type,
45};
46
47#[async_trait]
48pub trait SubstraitConsumer: Send + Sync + Sized {
154 async fn resolve_table_ref(
155 &self,
156 table_ref: &TableReference,
157 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>>;
158
159 fn get_extensions(&self) -> &Extensions;
165 fn get_function_registry(&self) -> &impl FunctionRegistry;
166
167 async fn consume_rel(&self, rel: &Rel) -> datafusion::common::Result<LogicalPlan> {
175 from_substrait_rel(self, rel).await
176 }
177
178 async fn consume_read(
179 &self,
180 rel: &ReadRel,
181 ) -> datafusion::common::Result<LogicalPlan> {
182 from_read_rel(self, rel).await
183 }
184
185 async fn consume_filter(
186 &self,
187 rel: &FilterRel,
188 ) -> datafusion::common::Result<LogicalPlan> {
189 from_filter_rel(self, rel).await
190 }
191
192 async fn consume_fetch(
193 &self,
194 rel: &FetchRel,
195 ) -> datafusion::common::Result<LogicalPlan> {
196 from_fetch_rel(self, rel).await
197 }
198
199 async fn consume_aggregate(
200 &self,
201 rel: &AggregateRel,
202 ) -> datafusion::common::Result<LogicalPlan> {
203 from_aggregate_rel(self, rel).await
204 }
205
206 async fn consume_sort(
207 &self,
208 rel: &SortRel,
209 ) -> datafusion::common::Result<LogicalPlan> {
210 from_sort_rel(self, rel).await
211 }
212
213 async fn consume_join(
214 &self,
215 rel: &JoinRel,
216 ) -> datafusion::common::Result<LogicalPlan> {
217 from_join_rel(self, rel).await
218 }
219
220 async fn consume_project(
221 &self,
222 rel: &ProjectRel,
223 ) -> datafusion::common::Result<LogicalPlan> {
224 from_project_rel(self, rel).await
225 }
226
227 async fn consume_set(&self, rel: &SetRel) -> datafusion::common::Result<LogicalPlan> {
228 from_set_rel(self, rel).await
229 }
230
231 async fn consume_cross(
232 &self,
233 rel: &CrossRel,
234 ) -> datafusion::common::Result<LogicalPlan> {
235 from_cross_rel(self, rel).await
236 }
237
238 async fn consume_consistent_partition_window(
239 &self,
240 _rel: &ConsistentPartitionWindowRel,
241 ) -> datafusion::common::Result<LogicalPlan> {
242 not_impl_err!("Consistent Partition Window Rel not supported")
243 }
244
245 async fn consume_exchange(
246 &self,
247 rel: &ExchangeRel,
248 ) -> datafusion::common::Result<LogicalPlan> {
249 from_exchange_rel(self, rel).await
250 }
251
252 async fn consume_expression(
260 &self,
261 expr: &Expression,
262 input_schema: &DFSchema,
263 ) -> datafusion::common::Result<Expr> {
264 from_substrait_rex(self, expr, input_schema).await
265 }
266
267 async fn consume_literal(&self, expr: &Literal) -> datafusion::common::Result<Expr> {
268 from_literal(self, expr).await
269 }
270
271 async fn consume_field_reference(
272 &self,
273 expr: &FieldReference,
274 input_schema: &DFSchema,
275 ) -> datafusion::common::Result<Expr> {
276 from_field_reference(self, expr, input_schema).await
277 }
278
279 async fn consume_scalar_function(
280 &self,
281 expr: &ScalarFunction,
282 input_schema: &DFSchema,
283 ) -> datafusion::common::Result<Expr> {
284 from_scalar_function(self, expr, input_schema).await
285 }
286
287 async fn consume_window_function(
288 &self,
289 expr: &WindowFunction,
290 input_schema: &DFSchema,
291 ) -> datafusion::common::Result<Expr> {
292 from_window_function(self, expr, input_schema).await
293 }
294
295 async fn consume_if_then(
296 &self,
297 expr: &IfThen,
298 input_schema: &DFSchema,
299 ) -> datafusion::common::Result<Expr> {
300 from_if_then(self, expr, input_schema).await
301 }
302
303 async fn consume_switch(
304 &self,
305 _expr: &SwitchExpression,
306 _input_schema: &DFSchema,
307 ) -> datafusion::common::Result<Expr> {
308 not_impl_err!("Switch expression not supported")
309 }
310
311 async fn consume_singular_or_list(
312 &self,
313 expr: &SingularOrList,
314 input_schema: &DFSchema,
315 ) -> datafusion::common::Result<Expr> {
316 from_singular_or_list(self, expr, input_schema).await
317 }
318
319 async fn consume_multi_or_list(
320 &self,
321 _expr: &MultiOrList,
322 _input_schema: &DFSchema,
323 ) -> datafusion::common::Result<Expr> {
324 not_impl_err!("Multi Or List expression not supported")
325 }
326
327 async fn consume_cast(
328 &self,
329 expr: &substrait_expression::Cast,
330 input_schema: &DFSchema,
331 ) -> datafusion::common::Result<Expr> {
332 from_cast(self, expr, input_schema).await
333 }
334
335 async fn consume_subquery(
336 &self,
337 expr: &substrait_expression::Subquery,
338 input_schema: &DFSchema,
339 ) -> datafusion::common::Result<Expr> {
340 from_subquery(self, expr, input_schema).await
341 }
342
343 async fn consume_nested(
344 &self,
345 _expr: &Nested,
346 _input_schema: &DFSchema,
347 ) -> datafusion::common::Result<Expr> {
348 not_impl_err!("Nested expression not supported")
349 }
350
351 async fn consume_enum(
352 &self,
353 _expr: &Enum,
354 _input_schema: &DFSchema,
355 ) -> datafusion::common::Result<Expr> {
356 not_impl_err!("Enum expression not supported")
357 }
358
359 async fn consume_dynamic_parameter(
360 &self,
361 _expr: &DynamicParameter,
362 _input_schema: &DFSchema,
363 ) -> datafusion::common::Result<Expr> {
364 not_impl_err!("Dynamic Parameter expression not supported")
365 }
366
367 async fn consume_extension_leaf(
373 &self,
374 rel: &ExtensionLeafRel,
375 ) -> datafusion::common::Result<LogicalPlan> {
376 if let Some(detail) = rel.detail.as_ref() {
377 return substrait_err!(
378 "Missing handler for ExtensionLeafRel: {}",
379 detail.type_url
380 );
381 }
382 substrait_err!("Missing handler for ExtensionLeafRel")
383 }
384
385 async fn consume_extension_single(
386 &self,
387 rel: &ExtensionSingleRel,
388 ) -> datafusion::common::Result<LogicalPlan> {
389 if let Some(detail) = rel.detail.as_ref() {
390 return substrait_err!(
391 "Missing handler for ExtensionSingleRel: {}",
392 detail.type_url
393 );
394 }
395 substrait_err!("Missing handler for ExtensionSingleRel")
396 }
397
398 async fn consume_extension_multi(
399 &self,
400 rel: &ExtensionMultiRel,
401 ) -> datafusion::common::Result<LogicalPlan> {
402 if let Some(detail) = rel.detail.as_ref() {
403 return substrait_err!(
404 "Missing handler for ExtensionMultiRel: {}",
405 detail.type_url
406 );
407 }
408 substrait_err!("Missing handler for ExtensionMultiRel")
409 }
410
411 fn consume_user_defined_type(
414 &self,
415 user_defined_type: &r#type::UserDefined,
416 ) -> datafusion::common::Result<DataType> {
417 substrait_err!(
418 "Missing handler for user-defined type: {}",
419 user_defined_type.type_reference
420 )
421 }
422
423 fn consume_user_defined_literal(
424 &self,
425 user_defined_literal: &proto::expression::literal::UserDefined,
426 ) -> datafusion::common::Result<ScalarValue> {
427 substrait_err!(
428 "Missing handler for user-defined literals {}",
429 user_defined_literal.type_reference
430 )
431 }
432}
433
434pub struct DefaultSubstraitConsumer<'a> {
438 pub(super) extensions: &'a Extensions,
439 pub(super) state: &'a SessionState,
440}
441
442impl<'a> DefaultSubstraitConsumer<'a> {
443 pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self {
444 DefaultSubstraitConsumer { extensions, state }
445 }
446}
447
448#[async_trait]
449impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
450 async fn resolve_table_ref(
451 &self,
452 table_ref: &TableReference,
453 ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
454 let table = table_ref.table().to_string();
455 let schema = self.state.schema_for_ref(table_ref.clone())?;
456 let table_provider = schema.table(&table).await?;
457 Ok(table_provider)
458 }
459
460 fn get_extensions(&self) -> &Extensions {
461 self.extensions
462 }
463
464 fn get_function_registry(&self) -> &impl FunctionRegistry {
465 self.state
466 }
467
468 async fn consume_extension_leaf(
469 &self,
470 rel: &ExtensionLeafRel,
471 ) -> datafusion::common::Result<LogicalPlan> {
472 let Some(ext_detail) = &rel.detail else {
473 return substrait_err!("Unexpected empty detail in ExtensionLeafRel");
474 };
475 let plan = self
476 .state
477 .serializer_registry()
478 .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
479 Ok(LogicalPlan::Extension(Extension { node: plan }))
480 }
481
482 async fn consume_extension_single(
483 &self,
484 rel: &ExtensionSingleRel,
485 ) -> datafusion::common::Result<LogicalPlan> {
486 let Some(ext_detail) = &rel.detail else {
487 return substrait_err!("Unexpected empty detail in ExtensionSingleRel");
488 };
489 let plan = self
490 .state
491 .serializer_registry()
492 .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
493 let Some(input_rel) = &rel.input else {
494 return substrait_err!(
495 "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead"
496 );
497 };
498 let input_plan = self.consume_rel(input_rel).await?;
499 let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?;
500 Ok(LogicalPlan::Extension(Extension { node: plan }))
501 }
502
503 async fn consume_extension_multi(
504 &self,
505 rel: &ExtensionMultiRel,
506 ) -> datafusion::common::Result<LogicalPlan> {
507 let Some(ext_detail) = &rel.detail else {
508 return substrait_err!("Unexpected empty detail in ExtensionMultiRel");
509 };
510 let plan = self
511 .state
512 .serializer_registry()
513 .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
514 let mut inputs = Vec::with_capacity(rel.inputs.len());
515 for input in &rel.inputs {
516 let input_plan = self.consume_rel(input).await?;
517 inputs.push(input_plan);
518 }
519 let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
520 Ok(LogicalPlan::Extension(Extension { node: plan }))
521 }
522}