datafusion_expr/logical_plan/
extension.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This module defines the interface for logical nodes
19use crate::{Expr, LogicalPlan};
20use datafusion_common::{DFSchema, DFSchemaRef, Result};
21use std::cmp::Ordering;
22use std::hash::{Hash, Hasher};
23use std::{any::Any, collections::HashSet, fmt, sync::Arc};
24
25use super::InvariantLevel;
26
27/// This defines the interface for [`LogicalPlan`] nodes that can be
28/// used to extend DataFusion with custom relational operators.
29///
30/// The [`UserDefinedLogicalNodeCore`] trait is *the recommended way to implement*
31/// this trait and avoids having implementing some required boiler plate code.
32pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
33    /// Return a reference to self as Any, to support dynamic downcasting
34    ///
35    /// Typically this will look like:
36    ///
37    /// ```
38    /// # use std::any::Any;
39    /// # struct Dummy { }
40    ///
41    /// # impl Dummy {
42    ///   // canonical boiler plate
43    ///   fn as_any(&self) -> &dyn Any {
44    ///      self
45    ///   }
46    /// # }
47    /// ```
48    fn as_any(&self) -> &dyn Any;
49
50    /// Return the plan's name.
51    fn name(&self) -> &str;
52
53    /// Return the logical plan's inputs.
54    fn inputs(&self) -> Vec<&LogicalPlan>;
55
56    /// Return the output schema of this logical plan node.
57    fn schema(&self) -> &DFSchemaRef;
58
59    /// Perform check of invariants for the extension node.
60    fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()>;
61
62    /// Returns all expressions in the current logical plan node. This should
63    /// not include expressions of any inputs (aka non-recursively).
64    ///
65    /// These expressions are used for optimizer
66    /// passes and rewrites. See [`LogicalPlan::expressions`] for more details.
67    fn expressions(&self) -> Vec<Expr>;
68
69    /// A list of output columns (e.g. the names of columns in
70    /// self.schema()) for which predicates can not be pushed below
71    /// this node without changing the output.
72    ///
73    /// By default, this returns all columns and thus prevents any
74    /// predicates from being pushed below this node.
75    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
76        // default (safe) is all columns in the schema.
77        get_all_columns_from_schema(self.schema())
78    }
79
80    /// Write a single line, human readable string to `f` for use in explain plan.
81    ///
82    /// For example: `TopK: k=10`
83    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;
84
85    /// Create a new `UserDefinedLogicalNode` with the specified children
86    /// and expressions. This function is used during optimization
87    /// when the plan is being rewritten and a new instance of the
88    /// `UserDefinedLogicalNode` must be created.
89    ///
90    /// Note that exprs and inputs are in the same order as the result
91    /// of self.inputs and self.exprs.
92    ///
93    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs
94    fn with_exprs_and_inputs(
95        &self,
96        exprs: Vec<Expr>,
97        inputs: Vec<LogicalPlan>,
98    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;
99
100    /// Returns the necessary input columns for this node required to compute
101    /// the columns in the output schema
102    ///
103    /// This is used for projection push-down when DataFusion has determined that
104    /// only a subset of the output columns of this node are needed by its parents.
105    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
106    /// needed.
107    ///
108    /// Return `None`, the default, if this information can not be determined.
109    /// Returns `Some(_)` with the column indices for each child of this node that are
110    /// needed to compute `output_columns`
111    fn necessary_children_exprs(
112        &self,
113        _output_columns: &[usize],
114    ) -> Option<Vec<Vec<usize>>> {
115        None
116    }
117
118    /// Update the hash `state` with this node requirements from
119    /// [`Hash`].
120    ///
121    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
122    /// [`UserDefinedLogicalNode`] directly.
123    ///
124    /// This method is required to support hashing [`LogicalPlan`]s.  To
125    /// implement it, typically the type implementing
126    /// [`UserDefinedLogicalNode`] typically implements [`Hash`] and
127    /// then the following boiler plate is used:
128    ///
129    /// # Example:
130    /// ```
131    /// // User defined node that derives Hash
132    /// #[derive(Hash, Debug, PartialEq, Eq)]
133    /// struct MyNode {
134    ///   val: u64
135    /// }
136    ///
137    /// // impl UserDefinedLogicalNode {
138    /// // ...
139    /// # impl MyNode {
140    ///   // Boiler plate to call the derived Hash impl
141    ///   fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
142    ///     use std::hash::Hash;
143    ///     let mut s = state;
144    ///     self.hash(&mut s);
145    ///   }
146    /// // }
147    /// # }
148    /// ```
149    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Hash`]
150    /// directly because it must remain object safe.
151    fn dyn_hash(&self, state: &mut dyn Hasher);
152
153    /// Compare `other`, respecting requirements from [std::cmp::Eq].
154    ///
155    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
156    /// [`UserDefinedLogicalNode`] directly.
157    ///
158    /// When `other` has an another type than `self`, then the values
159    /// are *not* equal.
160    ///
161    /// This method is required to support Eq on [`LogicalPlan`]s.  To
162    /// implement it, typically the type implementing
163    /// [`UserDefinedLogicalNode`] typically implements [`Eq`] and
164    /// then the following boiler plate is used:
165    ///
166    /// # Example:
167    /// ```
168    /// # use datafusion_expr::UserDefinedLogicalNode;
169    /// // User defined node that derives Eq
170    /// #[derive(Hash, Debug, PartialEq, Eq)]
171    /// struct MyNode {
172    ///   val: u64
173    /// }
174    ///
175    /// // impl UserDefinedLogicalNode {
176    /// // ...
177    /// # impl MyNode {
178    ///   // Boiler plate to call the derived Eq impl
179    ///   fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
180    ///     match other.as_any().downcast_ref::<Self>() {
181    ///       Some(o) => self == o,
182    ///       None => false,
183    ///     }
184    ///   }
185    /// // }
186    /// # }
187    /// ```
188    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Eq`]
189    /// directly because it must remain object safe.
190    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool;
191    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>;
192
193    /// Returns `true` if a limit can be safely pushed down through this
194    /// `UserDefinedLogicalNode` node.
195    ///
196    /// If this method returns `true`, and the query plan contains a limit at
197    /// the output of this node, DataFusion will push the limit to the input
198    /// of this node.
199    fn supports_limit_pushdown(&self) -> bool {
200        false
201    }
202}
203
204impl Hash for dyn UserDefinedLogicalNode {
205    fn hash<H: Hasher>(&self, state: &mut H) {
206        self.dyn_hash(state);
207    }
208}
209
210impl PartialEq for dyn UserDefinedLogicalNode {
211    fn eq(&self, other: &Self) -> bool {
212        self.dyn_eq(other)
213    }
214}
215
216impl PartialOrd for dyn UserDefinedLogicalNode {
217    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
218        self.dyn_ord(other)
219    }
220}
221
222impl Eq for dyn UserDefinedLogicalNode {}
223
224/// This trait facilitates implementation of the [`UserDefinedLogicalNode`].
225///
226/// See the example in
227/// [user_defined_plan.rs](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/user_defined/user_defined_plan.rs)
228/// file for an example of how to use this extension API.
229pub trait UserDefinedLogicalNodeCore:
230    fmt::Debug + Eq + PartialOrd + Hash + Sized + Send + Sync + 'static
231{
232    /// Return the plan's name.
233    fn name(&self) -> &str;
234
235    /// Return the logical plan's inputs.
236    fn inputs(&self) -> Vec<&LogicalPlan>;
237
238    /// Return the output schema of this logical plan node.
239    fn schema(&self) -> &DFSchemaRef;
240
241    /// Perform check of invariants for the extension node.
242    ///
243    /// This is the default implementation for extension nodes.
244    fn check_invariants(
245        &self,
246        _check: InvariantLevel,
247        _plan: &LogicalPlan,
248    ) -> Result<()> {
249        Ok(())
250    }
251
252    /// Returns all expressions in the current logical plan node. This
253    /// should not include expressions of any inputs (aka
254    /// non-recursively). These expressions are used for optimizer
255    /// passes and rewrites.
256    fn expressions(&self) -> Vec<Expr>;
257
258    /// A list of output columns (e.g. the names of columns in
259    /// self.schema()) for which predicates can not be pushed below
260    /// this node without changing the output.
261    ///
262    /// By default, this returns all columns and thus prevents any
263    /// predicates from being pushed below this node.
264    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
265        // default (safe) is all columns in the schema.
266        get_all_columns_from_schema(self.schema())
267    }
268
269    /// Write a single line, human readable string to `f` for use in explain plan.
270    ///
271    /// For example: `TopK: k=10`
272    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;
273
274    /// Create a new `UserDefinedLogicalNode` with the specified children
275    /// and expressions. This function is used during optimization
276    /// when the plan is being rewritten and a new instance of the
277    /// `UserDefinedLogicalNode` must be created.
278    ///
279    /// Note that exprs and inputs are in the same order as the result
280    /// of self.inputs and self.exprs.
281    ///
282    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs
283    fn with_exprs_and_inputs(
284        &self,
285        exprs: Vec<Expr>,
286        inputs: Vec<LogicalPlan>,
287    ) -> Result<Self>;
288
289    /// Returns the necessary input columns for this node required to compute
290    /// the columns in the output schema
291    ///
292    /// This is used for projection push-down when DataFusion has determined that
293    /// only a subset of the output columns of this node are needed by its parents.
294    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
295    /// needed.
296    ///
297    /// Return `None`, the default, if this information can not be determined.
298    /// Returns `Some(_)` with the column indices for each child of this node that are
299    /// needed to compute `output_columns`
300    fn necessary_children_exprs(
301        &self,
302        _output_columns: &[usize],
303    ) -> Option<Vec<Vec<usize>>> {
304        None
305    }
306
307    /// Returns `true` if a limit can be safely pushed down through this
308    /// `UserDefinedLogicalNode` node.
309    ///
310    /// If this method returns `true`, and the query plan contains a limit at
311    /// the output of this node, DataFusion will push the limit to the input
312    /// of this node.
313    fn supports_limit_pushdown(&self) -> bool {
314        false // Disallow limit push-down by default
315    }
316}
317
318/// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode`
319/// to avoid boiler plate for implementing `as_any`, `Hash` and `PartialEq`
320impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
321    fn as_any(&self) -> &dyn Any {
322        self
323    }
324
325    fn name(&self) -> &str {
326        self.name()
327    }
328
329    fn inputs(&self) -> Vec<&LogicalPlan> {
330        self.inputs()
331    }
332
333    fn schema(&self) -> &DFSchemaRef {
334        self.schema()
335    }
336
337    fn check_invariants(&self, check: InvariantLevel, plan: &LogicalPlan) -> Result<()> {
338        self.check_invariants(check, plan)
339    }
340
341    fn expressions(&self) -> Vec<Expr> {
342        self.expressions()
343    }
344
345    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
346        self.prevent_predicate_push_down_columns()
347    }
348
349    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
350        self.fmt_for_explain(f)
351    }
352
353    fn with_exprs_and_inputs(
354        &self,
355        exprs: Vec<Expr>,
356        inputs: Vec<LogicalPlan>,
357    ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
358        Ok(Arc::new(self.with_exprs_and_inputs(exprs, inputs)?))
359    }
360
361    fn necessary_children_exprs(
362        &self,
363        output_columns: &[usize],
364    ) -> Option<Vec<Vec<usize>>> {
365        self.necessary_children_exprs(output_columns)
366    }
367
368    fn dyn_hash(&self, state: &mut dyn Hasher) {
369        let mut s = state;
370        self.hash(&mut s);
371    }
372
373    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
374        match other.as_any().downcast_ref::<Self>() {
375            Some(o) => self == o,
376            None => false,
377        }
378    }
379
380    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> {
381        other
382            .as_any()
383            .downcast_ref::<Self>()
384            .and_then(|other| self.partial_cmp(other))
385    }
386
387    fn supports_limit_pushdown(&self) -> bool {
388        self.supports_limit_pushdown()
389    }
390}
391
392fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> {
393    schema.fields().iter().map(|f| f.name().clone()).collect()
394}