datafusion_python/sql/
logical.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use std::sync::Arc;
19
20use crate::errors::PyDataFusionResult;
21use crate::expr::aggregate::PyAggregate;
22use crate::expr::analyze::PyAnalyze;
23use crate::expr::distinct::PyDistinct;
24use crate::expr::empty_relation::PyEmptyRelation;
25use crate::expr::explain::PyExplain;
26use crate::expr::extension::PyExtension;
27use crate::expr::filter::PyFilter;
28use crate::expr::join::PyJoin;
29use crate::expr::limit::PyLimit;
30use crate::expr::projection::PyProjection;
31use crate::expr::sort::PySort;
32use crate::expr::subquery::PySubquery;
33use crate::expr::subquery_alias::PySubqueryAlias;
34use crate::expr::table_scan::PyTableScan;
35use crate::expr::unnest::PyUnnest;
36use crate::expr::window::PyWindowExpr;
37use crate::{context::PySessionContext, errors::py_unsupported_variant_err};
38use datafusion::logical_expr::LogicalPlan;
39use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec};
40use prost::Message;
41use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes};
42
43use crate::expr::logical_node::LogicalNode;
44
/// Wrapper around a DataFusion `LogicalPlan`, exposed to Python as
/// `datafusion.LogicalPlan`.
///
/// The plan is held behind an `Arc`, so cloning the wrapper clones the
/// handle rather than the plan it points to.
#[pyclass(name = "LogicalPlan", module = "datafusion", subclass)]
#[derive(Debug, Clone)]
pub struct PyLogicalPlan {
    /// Reference-counted handle to the wrapped plan; visible crate-wide.
    pub(crate) plan: Arc<LogicalPlan>,
}
50
51impl PyLogicalPlan {
52    /// creates a new PyLogicalPlan
53    pub fn new(plan: LogicalPlan) -> Self {
54        Self {
55            plan: Arc::new(plan),
56        }
57    }
58
59    pub fn plan(&self) -> Arc<LogicalPlan> {
60        self.plan.clone()
61    }
62}
63
64#[pymethods]
65impl PyLogicalPlan {
66    /// Return the specific logical operator
67    pub fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
68        match self.plan.as_ref() {
69            LogicalPlan::Aggregate(plan) => PyAggregate::from(plan.clone()).to_variant(py),
70            LogicalPlan::Analyze(plan) => PyAnalyze::from(plan.clone()).to_variant(py),
71            LogicalPlan::Distinct(plan) => PyDistinct::from(plan.clone()).to_variant(py),
72            LogicalPlan::EmptyRelation(plan) => PyEmptyRelation::from(plan.clone()).to_variant(py),
73            LogicalPlan::Explain(plan) => PyExplain::from(plan.clone()).to_variant(py),
74            LogicalPlan::Extension(plan) => PyExtension::from(plan.clone()).to_variant(py),
75            LogicalPlan::Filter(plan) => PyFilter::from(plan.clone()).to_variant(py),
76            LogicalPlan::Join(plan) => PyJoin::from(plan.clone()).to_variant(py),
77            LogicalPlan::Limit(plan) => PyLimit::from(plan.clone()).to_variant(py),
78            LogicalPlan::Projection(plan) => PyProjection::from(plan.clone()).to_variant(py),
79            LogicalPlan::Sort(plan) => PySort::from(plan.clone()).to_variant(py),
80            LogicalPlan::TableScan(plan) => PyTableScan::from(plan.clone()).to_variant(py),
81            LogicalPlan::Subquery(plan) => PySubquery::from(plan.clone()).to_variant(py),
82            LogicalPlan::SubqueryAlias(plan) => PySubqueryAlias::from(plan.clone()).to_variant(py),
83            LogicalPlan::Unnest(plan) => PyUnnest::from(plan.clone()).to_variant(py),
84            LogicalPlan::Window(plan) => PyWindowExpr::from(plan.clone()).to_variant(py),
85            LogicalPlan::Repartition(_)
86            | LogicalPlan::Union(_)
87            | LogicalPlan::Statement(_)
88            | LogicalPlan::Values(_)
89            | LogicalPlan::Dml(_)
90            | LogicalPlan::Ddl(_)
91            | LogicalPlan::Copy(_)
92            | LogicalPlan::DescribeTable(_)
93            | LogicalPlan::RecursiveQuery(_) => Err(py_unsupported_variant_err(format!(
94                "Conversion of variant not implemented: {:?}",
95                self.plan
96            ))),
97        }
98    }
99
100    /// Get the inputs to this plan
101    fn inputs(&self) -> Vec<PyLogicalPlan> {
102        let mut inputs = vec![];
103        for input in self.plan.inputs() {
104            inputs.push(input.to_owned().into());
105        }
106        inputs
107    }
108
109    fn __repr__(&self) -> PyResult<String> {
110        Ok(format!("{:?}", self.plan))
111    }
112
113    fn display(&self) -> String {
114        format!("{}", self.plan.display())
115    }
116
117    fn display_indent(&self) -> String {
118        format!("{}", self.plan.display_indent())
119    }
120
121    fn display_indent_schema(&self) -> String {
122        format!("{}", self.plan.display_indent_schema())
123    }
124
125    fn display_graphviz(&self) -> String {
126        format!("{}", self.plan.display_graphviz())
127    }
128
129    pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
130        let codec = DefaultLogicalExtensionCodec {};
131        let proto =
132            datafusion_proto::protobuf::LogicalPlanNode::try_from_logical_plan(&self.plan, &codec)?;
133
134        let bytes = proto.encode_to_vec();
135        Ok(PyBytes::new(py, &bytes))
136    }
137
138    #[staticmethod]
139    pub fn from_proto(
140        ctx: PySessionContext,
141        proto_msg: Bound<'_, PyBytes>,
142    ) -> PyDataFusionResult<Self> {
143        let bytes: &[u8] = proto_msg.extract()?;
144        let proto_plan =
145            datafusion_proto::protobuf::LogicalPlanNode::decode(bytes).map_err(|e| {
146                PyRuntimeError::new_err(format!(
147                    "Unable to decode logical node from serialized bytes: {}",
148                    e
149                ))
150            })?;
151
152        let codec = DefaultLogicalExtensionCodec {};
153        let plan = proto_plan.try_into_logical_plan(&ctx.ctx, &codec)?;
154        Ok(Self::new(plan))
155    }
156}
157
158impl From<PyLogicalPlan> for LogicalPlan {
159    fn from(logical_plan: PyLogicalPlan) -> LogicalPlan {
160        logical_plan.plan.as_ref().clone()
161    }
162}
163
164impl From<LogicalPlan> for PyLogicalPlan {
165    fn from(logical_plan: LogicalPlan) -> PyLogicalPlan {
166        PyLogicalPlan {
167            plan: Arc::new(logical_plan),
168        }
169    }
170}