Skip to main content

datafusion_python/expr/
join.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::fmt::{self, Display, Formatter};
19
20use datafusion::common::NullEquality;
21use datafusion::logical_expr::logical_plan::{Join, JoinConstraint, JoinType};
22use pyo3::IntoPyObjectExt;
23use pyo3::prelude::*;
24
25use crate::common::df_schema::PyDFSchema;
26use crate::expr::PyExpr;
27use crate::expr::logical_node::LogicalNode;
28use crate::sql::logical::PyLogicalPlan;
29
30#[derive(Debug, Clone, PartialEq, Eq, Hash)]
31#[pyclass(from_py_object, frozen, name = "JoinType", module = "datafusion.expr")]
32pub struct PyJoinType {
33    join_type: JoinType,
34}
35
36impl From<JoinType> for PyJoinType {
37    fn from(join_type: JoinType) -> PyJoinType {
38        PyJoinType { join_type }
39    }
40}
41
42impl From<PyJoinType> for JoinType {
43    fn from(join_type: PyJoinType) -> Self {
44        join_type.join_type
45    }
46}
47
48#[pymethods]
49impl PyJoinType {
50    pub fn is_outer(&self) -> bool {
51        self.join_type.is_outer()
52    }
53
54    fn __repr__(&self) -> PyResult<String> {
55        Ok(format!("{}", self.join_type))
56    }
57}
58
59impl Display for PyJoinType {
60    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
61        write!(f, "{}", self.join_type)
62    }
63}
64
65#[derive(Debug, Clone, Copy)]
66#[pyclass(
67    from_py_object,
68    frozen,
69    name = "JoinConstraint",
70    module = "datafusion.expr"
71)]
72pub struct PyJoinConstraint {
73    join_constraint: JoinConstraint,
74}
75
76impl From<JoinConstraint> for PyJoinConstraint {
77    fn from(join_constraint: JoinConstraint) -> PyJoinConstraint {
78        PyJoinConstraint { join_constraint }
79    }
80}
81
82impl From<PyJoinConstraint> for JoinConstraint {
83    fn from(join_constraint: PyJoinConstraint) -> Self {
84        join_constraint.join_constraint
85    }
86}
87
88#[pymethods]
89impl PyJoinConstraint {
90    fn __repr__(&self) -> PyResult<String> {
91        match self.join_constraint {
92            JoinConstraint::On => Ok("On".to_string()),
93            JoinConstraint::Using => Ok("Using".to_string()),
94        }
95    }
96}
97
98#[pyclass(
99    from_py_object,
100    frozen,
101    name = "Join",
102    module = "datafusion.expr",
103    subclass
104)]
105#[derive(Clone)]
106pub struct PyJoin {
107    join: Join,
108}
109
110impl From<Join> for PyJoin {
111    fn from(join: Join) -> PyJoin {
112        PyJoin { join }
113    }
114}
115
116impl From<PyJoin> for Join {
117    fn from(join: PyJoin) -> Self {
118        join.join
119    }
120}
121
122impl Display for PyJoin {
123    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
124        write!(
125            f,
126            "Join
127            Left: {:?}
128            Right: {:?}
129            On: {:?}
130            Filter: {:?}
131            JoinType: {:?}
132            JoinConstraint: {:?}
133            Schema: {:?}
134            NullEquality: {:?}",
135            &self.join.left,
136            &self.join.right,
137            &self.join.on,
138            &self.join.filter,
139            &self.join.join_type,
140            &self.join.join_constraint,
141            &self.join.schema,
142            &self.join.null_equality,
143        )
144    }
145}
146
147#[pymethods]
148impl PyJoin {
149    /// Retrieves the left input `LogicalPlan` to this `Join` node
150    fn left(&self) -> PyResult<PyLogicalPlan> {
151        Ok(self.join.left.as_ref().clone().into())
152    }
153
154    /// Retrieves the right input `LogicalPlan` to this `Join` node
155    fn right(&self) -> PyResult<PyLogicalPlan> {
156        Ok(self.join.right.as_ref().clone().into())
157    }
158
159    /// Retrieves the right input `LogicalPlan` to this `Join` node
160    fn on(&self) -> PyResult<Vec<(PyExpr, PyExpr)>> {
161        Ok(self
162            .join
163            .on
164            .iter()
165            .map(|(l, r)| (PyExpr::from(l.clone()), PyExpr::from(r.clone())))
166            .collect())
167    }
168
169    /// Retrieves the filter `Option<PyExpr>` of this `Join` node
170    fn filter(&self) -> PyResult<Option<PyExpr>> {
171        Ok(self.join.filter.clone().map(Into::into))
172    }
173
174    /// Retrieves the `JoinType` to this `Join` node
175    fn join_type(&self) -> PyResult<PyJoinType> {
176        Ok(self.join.join_type.into())
177    }
178
179    /// Retrieves the `JoinConstraint` to this `Join` node
180    fn join_constraint(&self) -> PyResult<PyJoinConstraint> {
181        Ok(self.join.join_constraint.into())
182    }
183
184    /// Resulting Schema for this `Join` node instance
185    fn schema(&self) -> PyResult<PyDFSchema> {
186        Ok(self.join.schema.as_ref().clone().into())
187    }
188
189    /// If null_equals_null is true, null == null else null != null
190    fn null_equals_null(&self) -> PyResult<bool> {
191        match self.join.null_equality {
192            NullEquality::NullEqualsNothing => Ok(false),
193            NullEquality::NullEqualsNull => Ok(true),
194        }
195    }
196
197    fn __repr__(&self) -> PyResult<String> {
198        Ok(format!("Join({self})"))
199    }
200
201    fn __name__(&self) -> PyResult<String> {
202        Ok("Join".to_string())
203    }
204}
205
206impl LogicalNode for PyJoin {
207    fn inputs(&self) -> Vec<PyLogicalPlan> {
208        vec![
209            PyLogicalPlan::from((*self.join.left).clone()),
210            PyLogicalPlan::from((*self.join.right).clone()),
211        ]
212    }
213
214    fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
215        self.clone().into_bound_py_any(py)
216    }
217}