use datafusion_expr::logical_plan::{Join, JoinConstraint, JoinType};
use pyo3::prelude::*;
use std::fmt::{self, Display, Formatter};
use crate::common::df_schema::PyDFSchema;
use crate::expr::{logical_node::LogicalNode, PyExpr};
use crate::sql::logical::PyLogicalPlan;
/// Python-facing wrapper around DataFusion's [`JoinType`].
///
/// Registered with PyO3 under the name `JoinType` in the
/// `datafusion.expr` module.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[pyclass(name = "JoinType", module = "datafusion.expr")]
pub struct PyJoinType {
// The wrapped DataFusion join type; all behavior delegates to it.
join_type: JoinType,
}
impl From<JoinType> for PyJoinType {
fn from(join_type: JoinType) -> PyJoinType {
PyJoinType { join_type }
}
}
impl From<PyJoinType> for JoinType {
fn from(join_type: PyJoinType) -> Self {
join_type.join_type
}
}
#[pymethods]
impl PyJoinType {
    /// Reports whether this is an outer join, as decided by the wrapped
    /// DataFusion `JoinType::is_outer`.
    pub fn is_outer(&self) -> bool {
        let kind = self.join_type;
        kind.is_outer()
    }
}
impl Display for PyJoinType {
    /// Writes the `Display` rendering of the wrapped [`JoinType`].
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let inner = &self.join_type;
        write!(f, "{inner}")
    }
}
/// Python-facing wrapper around DataFusion's [`JoinConstraint`].
///
/// Registered with PyO3 under the name `JoinConstraint` in the
/// `datafusion.expr` module.
#[derive(Debug, Clone, Copy)]
#[pyclass(name = "JoinConstraint", module = "datafusion.expr")]
pub struct PyJoinConstraint {
// The wrapped DataFusion join constraint.
join_constraint: JoinConstraint,
}
impl From<JoinConstraint> for PyJoinConstraint {
fn from(join_constraint: JoinConstraint) -> PyJoinConstraint {
PyJoinConstraint { join_constraint }
}
}
impl From<PyJoinConstraint> for JoinConstraint {
fn from(join_constraint: PyJoinConstraint) -> Self {
join_constraint.join_constraint
}
}
/// Python-facing wrapper around DataFusion's [`Join`] logical plan node.
///
/// Registered with PyO3 under the name `Join` in the `datafusion.expr`
/// module; the `subclass` option allows Python classes to inherit from it.
#[pyclass(name = "Join", module = "datafusion.expr", subclass)]
#[derive(Clone)]
pub struct PyJoin {
// The wrapped DataFusion join node; every accessor reads from it.
join: Join,
}
impl From<Join> for PyJoin {
fn from(join: Join) -> PyJoin {
PyJoin { join }
}
}
impl From<PyJoin> for Join {
fn from(join: PyJoin) -> Self {
join.join
}
}
impl Display for PyJoin {
    /// Writes a multi-line dump of the join: a `Join` header followed by one
    /// labelled line per field, each rendered with its `Debug` formatting.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        let node = &self.join;
        // The line breaks inside the literal are part of the emitted text, so
        // its continuation lines are deliberately kept flush-left.
        write!(
            f,
            "Join
Left: {:?}
Right: {:?}
On: {:?}
Filter: {:?}
JoinType: {:?}
JoinConstraint: {:?}
Schema: {:?}
NullEqualsNull: {:?}",
            node.left,
            node.right,
            node.on,
            node.filter,
            node.join_type,
            node.join_constraint,
            node.schema,
            node.null_equals_null,
        )
    }
}
#[pymethods]
impl PyJoin {
fn left(&self) -> PyResult<PyLogicalPlan> {
Ok(self.join.left.as_ref().clone().into())
}
fn right(&self) -> PyResult<PyLogicalPlan> {
Ok(self.join.right.as_ref().clone().into())
}
fn on(&self) -> PyResult<Vec<(PyExpr, PyExpr)>> {
Ok(self
.join
.on
.iter()
.map(|(l, r)| (PyExpr::from(l.clone()), PyExpr::from(r.clone())))
.collect())
}
fn filter(&self) -> PyResult<Option<PyExpr>> {
Ok(self.join.filter.clone().map(Into::into))
}
fn join_type(&self) -> PyResult<PyJoinType> {
Ok(self.join.join_type.into())
}
fn join_constraint(&self) -> PyResult<PyJoinConstraint> {
Ok(self.join.join_constraint.into())
}
fn schema(&self) -> PyResult<PyDFSchema> {
Ok(self.join.schema.as_ref().clone().into())
}
fn null_equals_null(&self) -> PyResult<bool> {
Ok(self.join.null_equals_null)
}
fn __repr__(&self) -> PyResult<String> {
Ok(format!("Join({})", self))
}
fn __name__(&self) -> PyResult<String> {
Ok("Join".to_string())
}
}
impl LogicalNode for PyJoin {
    /// Returns copies of the join's two input plans, left first, then right.
    fn inputs(&self) -> Vec<PyLogicalPlan> {
        let left = PyLogicalPlan::from(self.join.left.as_ref().clone());
        let right = PyLogicalPlan::from(self.join.right.as_ref().clone());
        vec![left, right]
    }

    /// Converts a clone of this node into a Python object.
    fn to_variant(&self, py: Python) -> PyResult<PyObject> {
        let owned = self.clone();
        Ok(owned.into_py(py))
    }
}