Skip to main content

datafusion_substrait/logical_plan/consumer/
substrait_consumer.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use super::{
19    from_aggregate_rel, from_cast, from_cross_rel, from_exchange_rel, from_fetch_rel,
20    from_field_reference, from_filter_rel, from_if_then, from_join_rel, from_literal,
21    from_project_rel, from_read_rel, from_scalar_function, from_set_rel,
22    from_singular_or_list, from_sort_rel, from_subquery, from_substrait_rel,
23    from_substrait_rex, from_window_function,
24};
25use crate::extensions::Extensions;
26use async_trait::async_trait;
27use datafusion::arrow::datatypes::DataType;
28use datafusion::catalog::TableProvider;
29use datafusion::common::{
30    DFSchema, ScalarValue, TableReference, not_impl_err, substrait_err,
31};
32use datafusion::execution::{FunctionRegistry, SessionState};
33use datafusion::logical_expr::{Expr, Extension, LogicalPlan};
34use std::sync::Arc;
35use substrait::proto;
36use substrait::proto::expression as substrait_expression;
37use substrait::proto::expression::{
38    Enum, FieldReference, IfThen, Literal, MultiOrList, Nested, ScalarFunction,
39    SingularOrList, SwitchExpression, WindowFunction,
40};
41use substrait::proto::{
42    AggregateRel, ConsistentPartitionWindowRel, CrossRel, DynamicParameter, ExchangeRel,
43    Expression, ExtensionLeafRel, ExtensionMultiRel, ExtensionSingleRel, FetchRel,
44    FilterRel, JoinRel, ProjectRel, ReadRel, Rel, SetRel, SortRel, r#type,
45};
46
47#[async_trait]
48/// This trait is used to consume Substrait plans, converting them into DataFusion Logical Plans.
49/// It can be implemented by users to allow for custom handling of relations, expressions, etc.
50///
51/// Combined with the [crate::logical_plan::producer::SubstraitProducer] this allows for fully
52/// customizable Substrait serde.
53///
54/// # Example Usage
55///
56/// ```
57/// # use async_trait::async_trait;
58/// # use datafusion::catalog::TableProvider;
59/// # use datafusion::common::{not_impl_err, substrait_err, DFSchema, ScalarValue, TableReference};
60/// # use datafusion::error::Result;
61/// # use datafusion::execution::{FunctionRegistry, SessionState};
62/// # use datafusion::logical_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
63/// # use std::sync::Arc;
64/// # use substrait::proto;
65/// # use substrait::proto::{ExtensionLeafRel, FilterRel, ProjectRel};
66/// # use datafusion::arrow::datatypes::DataType;
67/// # use datafusion::logical_expr::expr::ScalarFunction;
68/// # use datafusion_substrait::extensions::Extensions;
69/// # use datafusion_substrait::logical_plan::consumer::{
70/// #     from_project_rel, from_substrait_rel, from_substrait_rex, SubstraitConsumer
71/// # };
72///
73/// struct CustomSubstraitConsumer {
74///     extensions: Arc<Extensions>,
75///     state: Arc<SessionState>,
76/// }
77///
78/// #[async_trait]
79/// impl SubstraitConsumer for CustomSubstraitConsumer {
80///     async fn resolve_table_ref(
81///         &self,
82///         table_ref: &TableReference,
83///     ) -> Result<Option<Arc<dyn TableProvider>>> {
84///         let table = table_ref.table().to_string();
85///         let schema = self.state.schema_for_ref(table_ref.clone())?;
86///         let table_provider = schema.table(&table).await?;
87///         Ok(table_provider)
88///     }
89///
90///     fn get_extensions(&self) -> &Extensions {
91///         self.extensions.as_ref()
92///     }
93///
94///     fn get_function_registry(&self) -> &impl FunctionRegistry {
95///         self.state.as_ref()
96///     }
97///
98///     // You can reuse existing consumer code to assist in handling advanced extensions
99///     async fn consume_project(&self, rel: &ProjectRel) -> Result<LogicalPlan> {
100///         let df_plan = from_project_rel(self, rel).await?;
101///         if let Some(advanced_extension) = rel.advanced_extension.as_ref() {
102///             not_impl_err!(
103///                 "decode and handle an advanced extension: {:?}",
104///                 advanced_extension
105///             )
106///         } else {
107///             Ok(df_plan)
108///         }
109///     }
110///
111///     // You can implement a fully custom consumer method if you need special handling
112///     async fn consume_filter(&self, rel: &FilterRel) -> Result<LogicalPlan> {
113///         let input = self.consume_rel(rel.input.as_ref().unwrap()).await?;
114///         let expression =
115///             self.consume_expression(rel.condition.as_ref().unwrap(), input.schema())
116///                 .await?;
117///         // though this one is quite boring
118///         LogicalPlanBuilder::from(input).filter(expression)?.build()
119///     }
120///
121///     // You can add handlers for extension relations
122///     async fn consume_extension_leaf(
123///         &self,
124///         rel: &ExtensionLeafRel,
125///     ) -> Result<LogicalPlan> {
126///         not_impl_err!(
127///             "handle protobuf Any {} as you need",
128///             rel.detail.as_ref().unwrap().type_url
129///         )
130///     }
131///
132///     // and handlers for user-define types
133///     fn consume_user_defined_type(&self, typ: &proto::r#type::UserDefined) -> Result<DataType> {
134///         let type_string = self.extensions.types.get(&typ.type_reference).unwrap();
135///         match type_string.as_str() {
136///             "u!foo" => not_impl_err!("handle foo conversion"),
137///             "u!bar" => not_impl_err!("handle bar conversion"),
138///             _ => substrait_err!("unexpected type")
139///         }
140///     }
141///
142///     // and user-defined literals
143///     fn consume_user_defined_literal(&self, literal: &proto::expression::literal::UserDefined) -> Result<ScalarValue> {
144///         let type_string = self.extensions.types.get(&literal.type_reference).unwrap();
145///         match type_string.as_str() {
146///             "u!foo" => not_impl_err!("handle foo conversion"),
147///             "u!bar" => not_impl_err!("handle bar conversion"),
148///             _ => substrait_err!("unexpected type")
149///         }
150///     }
151/// }
152/// ```
153pub trait SubstraitConsumer: Send + Sync + Sized {
154    async fn resolve_table_ref(
155        &self,
156        table_ref: &TableReference,
157    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>>;
158
159    // TODO: Remove these two methods
160    //   Ideally, the abstract consumer should not place any constraints on implementations.
161    //   The functionality for which the Extensions and FunctionRegistry is needed should be abstracted
162    //   out into methods on the trait. As an example, resolve_table_reference is such a method.
163    //   See: https://github.com/apache/datafusion/issues/13863
164    fn get_extensions(&self) -> &Extensions;
165    fn get_function_registry(&self) -> &impl FunctionRegistry;
166
167    // Relation Methods
168    // There is one method per Substrait relation to allow for easy overriding of consumer behaviour.
169    // These methods have default implementations calling the common handler code, to allow for users
170    // to re-use common handling logic.
171
172    /// All [Rel]s to be converted pass through this method.
173    /// You can provide your own implementation if you wish to customize the conversion behaviour.
174    async fn consume_rel(&self, rel: &Rel) -> datafusion::common::Result<LogicalPlan> {
175        from_substrait_rel(self, rel).await
176    }
177
178    async fn consume_read(
179        &self,
180        rel: &ReadRel,
181    ) -> datafusion::common::Result<LogicalPlan> {
182        from_read_rel(self, rel).await
183    }
184
185    async fn consume_filter(
186        &self,
187        rel: &FilterRel,
188    ) -> datafusion::common::Result<LogicalPlan> {
189        from_filter_rel(self, rel).await
190    }
191
192    async fn consume_fetch(
193        &self,
194        rel: &FetchRel,
195    ) -> datafusion::common::Result<LogicalPlan> {
196        from_fetch_rel(self, rel).await
197    }
198
199    async fn consume_aggregate(
200        &self,
201        rel: &AggregateRel,
202    ) -> datafusion::common::Result<LogicalPlan> {
203        from_aggregate_rel(self, rel).await
204    }
205
206    async fn consume_sort(
207        &self,
208        rel: &SortRel,
209    ) -> datafusion::common::Result<LogicalPlan> {
210        from_sort_rel(self, rel).await
211    }
212
213    async fn consume_join(
214        &self,
215        rel: &JoinRel,
216    ) -> datafusion::common::Result<LogicalPlan> {
217        from_join_rel(self, rel).await
218    }
219
220    async fn consume_project(
221        &self,
222        rel: &ProjectRel,
223    ) -> datafusion::common::Result<LogicalPlan> {
224        from_project_rel(self, rel).await
225    }
226
227    async fn consume_set(&self, rel: &SetRel) -> datafusion::common::Result<LogicalPlan> {
228        from_set_rel(self, rel).await
229    }
230
231    async fn consume_cross(
232        &self,
233        rel: &CrossRel,
234    ) -> datafusion::common::Result<LogicalPlan> {
235        from_cross_rel(self, rel).await
236    }
237
238    async fn consume_consistent_partition_window(
239        &self,
240        _rel: &ConsistentPartitionWindowRel,
241    ) -> datafusion::common::Result<LogicalPlan> {
242        not_impl_err!("Consistent Partition Window Rel not supported")
243    }
244
245    async fn consume_exchange(
246        &self,
247        rel: &ExchangeRel,
248    ) -> datafusion::common::Result<LogicalPlan> {
249        from_exchange_rel(self, rel).await
250    }
251
252    // Expression Methods
253    // There is one method per Substrait expression to allow for easy overriding of consumer behaviour
254    // These methods have default implementations calling the common handler code, to allow for users
255    // to re-use common handling logic.
256
257    /// All [Expression]s to be converted pass through this method.
258    /// You can provide your own implementation if you wish to customize the conversion behaviour.
259    async fn consume_expression(
260        &self,
261        expr: &Expression,
262        input_schema: &DFSchema,
263    ) -> datafusion::common::Result<Expr> {
264        from_substrait_rex(self, expr, input_schema).await
265    }
266
267    async fn consume_literal(&self, expr: &Literal) -> datafusion::common::Result<Expr> {
268        from_literal(self, expr).await
269    }
270
271    async fn consume_field_reference(
272        &self,
273        expr: &FieldReference,
274        input_schema: &DFSchema,
275    ) -> datafusion::common::Result<Expr> {
276        from_field_reference(self, expr, input_schema).await
277    }
278
279    async fn consume_scalar_function(
280        &self,
281        expr: &ScalarFunction,
282        input_schema: &DFSchema,
283    ) -> datafusion::common::Result<Expr> {
284        from_scalar_function(self, expr, input_schema).await
285    }
286
287    async fn consume_window_function(
288        &self,
289        expr: &WindowFunction,
290        input_schema: &DFSchema,
291    ) -> datafusion::common::Result<Expr> {
292        from_window_function(self, expr, input_schema).await
293    }
294
295    async fn consume_if_then(
296        &self,
297        expr: &IfThen,
298        input_schema: &DFSchema,
299    ) -> datafusion::common::Result<Expr> {
300        from_if_then(self, expr, input_schema).await
301    }
302
303    async fn consume_switch(
304        &self,
305        _expr: &SwitchExpression,
306        _input_schema: &DFSchema,
307    ) -> datafusion::common::Result<Expr> {
308        not_impl_err!("Switch expression not supported")
309    }
310
311    async fn consume_singular_or_list(
312        &self,
313        expr: &SingularOrList,
314        input_schema: &DFSchema,
315    ) -> datafusion::common::Result<Expr> {
316        from_singular_or_list(self, expr, input_schema).await
317    }
318
319    async fn consume_multi_or_list(
320        &self,
321        _expr: &MultiOrList,
322        _input_schema: &DFSchema,
323    ) -> datafusion::common::Result<Expr> {
324        not_impl_err!("Multi Or List expression not supported")
325    }
326
327    async fn consume_cast(
328        &self,
329        expr: &substrait_expression::Cast,
330        input_schema: &DFSchema,
331    ) -> datafusion::common::Result<Expr> {
332        from_cast(self, expr, input_schema).await
333    }
334
335    async fn consume_subquery(
336        &self,
337        expr: &substrait_expression::Subquery,
338        input_schema: &DFSchema,
339    ) -> datafusion::common::Result<Expr> {
340        from_subquery(self, expr, input_schema).await
341    }
342
343    async fn consume_nested(
344        &self,
345        _expr: &Nested,
346        _input_schema: &DFSchema,
347    ) -> datafusion::common::Result<Expr> {
348        not_impl_err!("Nested expression not supported")
349    }
350
351    async fn consume_enum(
352        &self,
353        _expr: &Enum,
354        _input_schema: &DFSchema,
355    ) -> datafusion::common::Result<Expr> {
356        not_impl_err!("Enum expression not supported")
357    }
358
359    async fn consume_dynamic_parameter(
360        &self,
361        _expr: &DynamicParameter,
362        _input_schema: &DFSchema,
363    ) -> datafusion::common::Result<Expr> {
364        not_impl_err!("Dynamic Parameter expression not supported")
365    }
366
367    // User-Defined Functionality
368
369    // The details of extension relations, and how to handle them, are fully up to users to specify.
370    // The following methods allow users to customize the consumer behaviour
371
372    async fn consume_extension_leaf(
373        &self,
374        rel: &ExtensionLeafRel,
375    ) -> datafusion::common::Result<LogicalPlan> {
376        if let Some(detail) = rel.detail.as_ref() {
377            return substrait_err!(
378                "Missing handler for ExtensionLeafRel: {}",
379                detail.type_url
380            );
381        }
382        substrait_err!("Missing handler for ExtensionLeafRel")
383    }
384
385    async fn consume_extension_single(
386        &self,
387        rel: &ExtensionSingleRel,
388    ) -> datafusion::common::Result<LogicalPlan> {
389        if let Some(detail) = rel.detail.as_ref() {
390            return substrait_err!(
391                "Missing handler for ExtensionSingleRel: {}",
392                detail.type_url
393            );
394        }
395        substrait_err!("Missing handler for ExtensionSingleRel")
396    }
397
398    async fn consume_extension_multi(
399        &self,
400        rel: &ExtensionMultiRel,
401    ) -> datafusion::common::Result<LogicalPlan> {
402        if let Some(detail) = rel.detail.as_ref() {
403            return substrait_err!(
404                "Missing handler for ExtensionMultiRel: {}",
405                detail.type_url
406            );
407        }
408        substrait_err!("Missing handler for ExtensionMultiRel")
409    }
410
411    // Users can bring their own types to Substrait which require custom handling
412
413    fn consume_user_defined_type(
414        &self,
415        user_defined_type: &r#type::UserDefined,
416    ) -> datafusion::common::Result<DataType> {
417        substrait_err!(
418            "Missing handler for user-defined type: {}",
419            user_defined_type.type_reference
420        )
421    }
422
423    fn consume_user_defined_literal(
424        &self,
425        user_defined_literal: &proto::expression::literal::UserDefined,
426    ) -> datafusion::common::Result<ScalarValue> {
427        substrait_err!(
428            "Missing handler for user-defined literals {}",
429            user_defined_literal.type_reference
430        )
431    }
432}
433
434/// Default SubstraitConsumer for converting standard Substrait without user-defined extensions.
435///
436/// Used as the consumer in [crate::logical_plan::consumer::from_substrait_plan]
437pub struct DefaultSubstraitConsumer<'a> {
438    pub(super) extensions: &'a Extensions,
439    pub(super) state: &'a SessionState,
440}
441
442impl<'a> DefaultSubstraitConsumer<'a> {
443    pub fn new(extensions: &'a Extensions, state: &'a SessionState) -> Self {
444        DefaultSubstraitConsumer { extensions, state }
445    }
446}
447
448#[async_trait]
449impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
450    async fn resolve_table_ref(
451        &self,
452        table_ref: &TableReference,
453    ) -> datafusion::common::Result<Option<Arc<dyn TableProvider>>> {
454        let table = table_ref.table().to_string();
455        let schema = self.state.schema_for_ref(table_ref.clone())?;
456        let table_provider = schema.table(&table).await?;
457        Ok(table_provider)
458    }
459
460    fn get_extensions(&self) -> &Extensions {
461        self.extensions
462    }
463
464    fn get_function_registry(&self) -> &impl FunctionRegistry {
465        self.state
466    }
467
468    async fn consume_extension_leaf(
469        &self,
470        rel: &ExtensionLeafRel,
471    ) -> datafusion::common::Result<LogicalPlan> {
472        let Some(ext_detail) = &rel.detail else {
473            return substrait_err!("Unexpected empty detail in ExtensionLeafRel");
474        };
475        let plan = self
476            .state
477            .serializer_registry()
478            .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
479        Ok(LogicalPlan::Extension(Extension { node: plan }))
480    }
481
482    async fn consume_extension_single(
483        &self,
484        rel: &ExtensionSingleRel,
485    ) -> datafusion::common::Result<LogicalPlan> {
486        let Some(ext_detail) = &rel.detail else {
487            return substrait_err!("Unexpected empty detail in ExtensionSingleRel");
488        };
489        let plan = self
490            .state
491            .serializer_registry()
492            .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
493        let Some(input_rel) = &rel.input else {
494            return substrait_err!(
495                "ExtensionSingleRel missing input rel, try using ExtensionLeafRel instead"
496            );
497        };
498        let input_plan = self.consume_rel(input_rel).await?;
499        let plan = plan.with_exprs_and_inputs(plan.expressions(), vec![input_plan])?;
500        Ok(LogicalPlan::Extension(Extension { node: plan }))
501    }
502
503    async fn consume_extension_multi(
504        &self,
505        rel: &ExtensionMultiRel,
506    ) -> datafusion::common::Result<LogicalPlan> {
507        let Some(ext_detail) = &rel.detail else {
508            return substrait_err!("Unexpected empty detail in ExtensionMultiRel");
509        };
510        let plan = self
511            .state
512            .serializer_registry()
513            .deserialize_logical_plan(&ext_detail.type_url, &ext_detail.value)?;
514        let mut inputs = Vec::with_capacity(rel.inputs.len());
515        for input in &rel.inputs {
516            let input_plan = self.consume_rel(input).await?;
517            inputs.push(input_plan);
518        }
519        let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
520        Ok(LogicalPlan::Extension(Extension { node: plan }))
521    }
522}