// datafusion_expr/logical_plan/extension.rs
1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements. See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership. The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License. You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied. See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! This module defines the interface for logical nodes
19use crate::{Expr, LogicalPlan};
20use datafusion_common::{DFSchema, DFSchemaRef, Result};
21use std::cmp::Ordering;
22use std::hash::{Hash, Hasher};
23use std::{any::Any, collections::HashSet, fmt, sync::Arc};
24
25use super::InvariantLevel;
26
/// This defines the interface for [`LogicalPlan`] nodes that can be
/// used to extend DataFusion with custom relational operators.
///
/// The [`UserDefinedLogicalNodeCore`] trait is *the recommended way to implement*
/// this trait and avoids having to implement some required boiler plate code.
pub trait UserDefinedLogicalNode: fmt::Debug + Send + Sync {
    /// Return a reference to self as Any, to support dynamic downcasting
    ///
    /// Typically this will look like:
    ///
    /// ```
    /// # use std::any::Any;
    /// # struct Dummy { }
    ///
    /// # impl Dummy {
    /// // canonical boiler plate
    /// fn as_any(&self) -> &dyn Any {
    ///     self
    /// }
    /// # }
    /// ```
    fn as_any(&self) -> &dyn Any;

    /// Return the plan's name.
    fn name(&self) -> &str;

    /// Return the logical plan's inputs.
    fn inputs(&self) -> Vec<&LogicalPlan>;

    /// Return the output schema of this logical plan node.
    fn schema(&self) -> &DFSchemaRef;

    /// Perform check of invariants for the extension node.
    fn check_invariants(&self, check: InvariantLevel) -> Result<()>;

    /// Returns all expressions in the current logical plan node. This should
    /// not include expressions of any inputs (aka non-recursively).
    ///
    /// These expressions are used for optimizer
    /// passes and rewrites. See [`LogicalPlan::expressions`] for more details.
    fn expressions(&self) -> Vec<Expr>;

    /// A list of output columns (e.g. the names of columns in
    /// self.schema()) for which predicates can not be pushed below
    /// this node without changing the output.
    ///
    /// By default, this returns all columns and thus prevents any
    /// predicates from being pushed below this node.
    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
        // default (safe) is all columns in the schema.
        get_all_columns_from_schema(self.schema())
    }

    /// Write a single line, human readable string to `f` for use in explain plan.
    ///
    /// For example: `TopK: k=10`
    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;

    /// Create a new `UserDefinedLogicalNode` with the specified children
    /// and expressions. This function is used during optimization
    /// when the plan is being rewritten and a new instance of the
    /// `UserDefinedLogicalNode` must be created.
    ///
    /// Note that exprs and inputs are in the same order as the result
    /// of self.inputs and self.exprs.
    ///
    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs`
    fn with_exprs_and_inputs(
        &self,
        exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> Result<Arc<dyn UserDefinedLogicalNode>>;

    /// Returns the necessary input columns for this node required to compute
    /// the columns in the output schema
    ///
    /// This is used for projection push-down when DataFusion has determined that
    /// only a subset of the output columns of this node are needed by its parents.
    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
    /// needed.
    ///
    /// Return `None`, the default, if this information can not be determined.
    /// Returns `Some(_)` with the column indices for each child of this node that are
    /// needed to compute `output_columns`
    fn necessary_children_exprs(
        &self,
        _output_columns: &[usize],
    ) -> Option<Vec<Vec<usize>>> {
        None
    }

    /// Update the hash `state` with this node requirements from
    /// [`Hash`].
    ///
    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
    /// [`UserDefinedLogicalNode`] directly.
    ///
    /// This method is required to support hashing [`LogicalPlan`]s. To
    /// implement it, typically the type implementing
    /// [`UserDefinedLogicalNode`] implements [`Hash`] and
    /// then the following boiler plate is used:
    ///
    /// # Example:
    /// ```
    /// // User defined node that derives Hash
    /// #[derive(Hash, Debug, PartialEq, Eq)]
    /// struct MyNode {
    ///     val: u64,
    /// }
    ///
    /// // impl UserDefinedLogicalNode {
    /// // ...
    /// # impl MyNode {
    /// // Boiler plate to call the derived Hash impl
    /// fn dyn_hash(&self, state: &mut dyn std::hash::Hasher) {
    ///     use std::hash::Hash;
    ///     let mut s = state;
    ///     self.hash(&mut s);
    /// }
    /// // }
    /// # }
    /// ```
    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Hash`]
    /// directly because it must remain object safe.
    fn dyn_hash(&self, state: &mut dyn Hasher);

    /// Compare `other`, respecting requirements from [Eq].
    ///
    /// Note: consider using [`UserDefinedLogicalNodeCore`] instead of
    /// [`UserDefinedLogicalNode`] directly.
    ///
    /// When `other` has another type than `self`, then the values
    /// are *not* equal.
    ///
    /// This method is required to support Eq on [`LogicalPlan`]s. To
    /// implement it, typically the type implementing
    /// [`UserDefinedLogicalNode`] implements [`Eq`] and
    /// then the following boiler plate is used:
    ///
    /// # Example:
    /// ```
    /// # use datafusion_expr::UserDefinedLogicalNode;
    /// // User defined node that derives Eq
    /// #[derive(Hash, Debug, PartialEq, Eq)]
    /// struct MyNode {
    ///     val: u64,
    /// }
    ///
    /// // impl UserDefinedLogicalNode {
    /// // ...
    /// # impl MyNode {
    /// // Boiler plate to call the derived Eq impl
    /// fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
    ///     match other.as_any().downcast_ref::<Self>() {
    ///         Some(o) => self == o,
    ///         None => false,
    ///     }
    /// }
    /// // }
    /// # }
    /// ```
    /// Note: [`UserDefinedLogicalNode`] is not constrained by [`Eq`]
    /// directly because it must remain object safe.
    fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool;

    /// Compare `other`, respecting requirements from [PartialOrd].
    /// Must return `Some(Equal)` if and only if `self.dyn_eq(other)`.
    fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering>;

    /// Returns `true` if a limit can be safely pushed down through this
    /// `UserDefinedLogicalNode` node.
    ///
    /// If this method returns `true`, and the query plan contains a limit at
    /// the output of this node, DataFusion will push the limit to the input
    /// of this node.
    fn supports_limit_pushdown(&self) -> bool {
        false
    }
}
206
207impl Hash for dyn UserDefinedLogicalNode {
208 fn hash<H: Hasher>(&self, state: &mut H) {
209 self.dyn_hash(state);
210 }
211}
212
213impl PartialEq for dyn UserDefinedLogicalNode {
214 fn eq(&self, other: &Self) -> bool {
215 self.dyn_eq(other)
216 }
217}
218
219impl PartialOrd for dyn UserDefinedLogicalNode {
220 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
221 self.dyn_ord(other)
222 }
223}
224
// Marker impl: sound because `dyn_eq` is documented (see trait) to respect
// the requirements of `Eq` for the underlying concrete types.
impl Eq for dyn UserDefinedLogicalNode {}
226
/// This trait facilitates implementation of the [`UserDefinedLogicalNode`].
///
/// See the example in
/// [user_defined_plan.rs](https://github.com/apache/datafusion/blob/main/datafusion/core/tests/user_defined/user_defined_plan.rs)
/// file for an example of how to use this extension API.
pub trait UserDefinedLogicalNodeCore:
    fmt::Debug + Eq + PartialOrd + Hash + Sized + Send + Sync + 'static
{
    /// Return the plan's name.
    fn name(&self) -> &str;

    /// Return the logical plan's inputs.
    fn inputs(&self) -> Vec<&LogicalPlan>;

    /// Return the output schema of this logical plan node.
    fn schema(&self) -> &DFSchemaRef;

    /// Perform check of invariants for the extension node.
    ///
    /// This is the default implementation for extension nodes.
    fn check_invariants(&self, _check: InvariantLevel) -> Result<()> {
        Ok(())
    }

    /// Returns all expressions in the current logical plan node. This
    /// should not include expressions of any inputs (aka
    /// non-recursively). These expressions are used for optimizer
    /// passes and rewrites.
    fn expressions(&self) -> Vec<Expr>;

    /// A list of output columns (e.g. the names of columns in
    /// self.schema()) for which predicates can not be pushed below
    /// this node without changing the output.
    ///
    /// By default, this returns all columns and thus prevents any
    /// predicates from being pushed below this node.
    fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
        // default (safe) is all columns in the schema.
        get_all_columns_from_schema(self.schema())
    }

    /// Write a single line, human readable string to `f` for use in explain plan.
    ///
    /// For example: `TopK: k=10`
    fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result;

    /// Create a new `UserDefinedLogicalNode` with the specified children
    /// and expressions. This function is used during optimization
    /// when the plan is being rewritten and a new instance of the
    /// `UserDefinedLogicalNode` must be created.
    ///
    /// Note that exprs and inputs are in the same order as the result
    /// of self.inputs and self.exprs.
    ///
    /// So, `self.with_exprs_and_inputs(exprs, ..).expressions() == exprs`
    fn with_exprs_and_inputs(
        &self,
        exprs: Vec<Expr>,
        inputs: Vec<LogicalPlan>,
    ) -> Result<Self>;

    /// Returns the necessary input columns for this node required to compute
    /// the columns in the output schema
    ///
    /// This is used for projection push-down when DataFusion has determined that
    /// only a subset of the output columns of this node are needed by its parents.
    /// This API is used to tell DataFusion which, if any, of the input columns are no longer
    /// needed.
    ///
    /// Return `None`, the default, if this information can not be determined.
    /// Returns `Some(_)` with the column indices for each child of this node that are
    /// needed to compute `output_columns`
    fn necessary_children_exprs(
        &self,
        _output_columns: &[usize],
    ) -> Option<Vec<Vec<usize>>> {
        None
    }

    /// Returns `true` if a limit can be safely pushed down through this
    /// `UserDefinedLogicalNode` node.
    ///
    /// If this method returns `true`, and the query plan contains a limit at
    /// the output of this node, DataFusion will push the limit to the input
    /// of this node.
    fn supports_limit_pushdown(&self) -> bool {
        false // Disallow limit push-down by default
    }
}
316
317/// Automatically derive UserDefinedLogicalNode to `UserDefinedLogicalNode`
318/// to avoid boiler plate for implementing `as_any`, `Hash`, `PartialEq` and `PartialOrd`.
319impl<T: UserDefinedLogicalNodeCore> UserDefinedLogicalNode for T {
320 fn as_any(&self) -> &dyn Any {
321 self
322 }
323
324 fn name(&self) -> &str {
325 self.name()
326 }
327
328 fn inputs(&self) -> Vec<&LogicalPlan> {
329 self.inputs()
330 }
331
332 fn schema(&self) -> &DFSchemaRef {
333 self.schema()
334 }
335
336 fn check_invariants(&self, check: InvariantLevel) -> Result<()> {
337 self.check_invariants(check)
338 }
339
340 fn expressions(&self) -> Vec<Expr> {
341 self.expressions()
342 }
343
344 fn prevent_predicate_push_down_columns(&self) -> HashSet<String> {
345 self.prevent_predicate_push_down_columns()
346 }
347
348 fn fmt_for_explain(&self, f: &mut fmt::Formatter) -> fmt::Result {
349 self.fmt_for_explain(f)
350 }
351
352 fn with_exprs_and_inputs(
353 &self,
354 exprs: Vec<Expr>,
355 inputs: Vec<LogicalPlan>,
356 ) -> Result<Arc<dyn UserDefinedLogicalNode>> {
357 Ok(Arc::new(self.with_exprs_and_inputs(exprs, inputs)?))
358 }
359
360 fn necessary_children_exprs(
361 &self,
362 output_columns: &[usize],
363 ) -> Option<Vec<Vec<usize>>> {
364 self.necessary_children_exprs(output_columns)
365 }
366
367 fn dyn_hash(&self, state: &mut dyn Hasher) {
368 let mut s = state;
369 self.hash(&mut s);
370 }
371
372 fn dyn_eq(&self, other: &dyn UserDefinedLogicalNode) -> bool {
373 match other.as_any().downcast_ref::<Self>() {
374 Some(o) => self == o,
375 None => false,
376 }
377 }
378
379 fn dyn_ord(&self, other: &dyn UserDefinedLogicalNode) -> Option<Ordering> {
380 other
381 .as_any()
382 .downcast_ref::<Self>()
383 .and_then(|other| self.partial_cmp(other))
384 }
385
386 fn supports_limit_pushdown(&self) -> bool {
387 self.supports_limit_pushdown()
388 }
389}
390
391fn get_all_columns_from_schema(schema: &DFSchema) -> HashSet<String> {
392 schema.fields().iter().map(|f| f.name().clone()).collect()
393}