datafusion_python/sql/
logical.rs1use std::sync::Arc;
19
20use crate::errors::PyDataFusionResult;
21use crate::expr::aggregate::PyAggregate;
22use crate::expr::analyze::PyAnalyze;
23use crate::expr::distinct::PyDistinct;
24use crate::expr::empty_relation::PyEmptyRelation;
25use crate::expr::explain::PyExplain;
26use crate::expr::extension::PyExtension;
27use crate::expr::filter::PyFilter;
28use crate::expr::join::PyJoin;
29use crate::expr::limit::PyLimit;
30use crate::expr::projection::PyProjection;
31use crate::expr::sort::PySort;
32use crate::expr::subquery::PySubquery;
33use crate::expr::subquery_alias::PySubqueryAlias;
34use crate::expr::table_scan::PyTableScan;
35use crate::expr::unnest::PyUnnest;
36use crate::expr::window::PyWindowExpr;
37use crate::{context::PySessionContext, errors::py_unsupported_variant_err};
38use datafusion::logical_expr::LogicalPlan;
39use datafusion_proto::logical_plan::{AsLogicalPlan, DefaultLogicalExtensionCodec};
40use prost::Message;
41use pyo3::{exceptions::PyRuntimeError, prelude::*, types::PyBytes};
42
43use crate::expr::logical_node::LogicalNode;
44
45#[pyclass(name = "LogicalPlan", module = "datafusion", subclass)]
46#[derive(Debug, Clone)]
47pub struct PyLogicalPlan {
48 pub(crate) plan: Arc<LogicalPlan>,
49}
50
51impl PyLogicalPlan {
52 pub fn new(plan: LogicalPlan) -> Self {
54 Self {
55 plan: Arc::new(plan),
56 }
57 }
58
59 pub fn plan(&self) -> Arc<LogicalPlan> {
60 self.plan.clone()
61 }
62}
63
64#[pymethods]
65impl PyLogicalPlan {
66 pub fn to_variant<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
68 match self.plan.as_ref() {
69 LogicalPlan::Aggregate(plan) => PyAggregate::from(plan.clone()).to_variant(py),
70 LogicalPlan::Analyze(plan) => PyAnalyze::from(plan.clone()).to_variant(py),
71 LogicalPlan::Distinct(plan) => PyDistinct::from(plan.clone()).to_variant(py),
72 LogicalPlan::EmptyRelation(plan) => PyEmptyRelation::from(plan.clone()).to_variant(py),
73 LogicalPlan::Explain(plan) => PyExplain::from(plan.clone()).to_variant(py),
74 LogicalPlan::Extension(plan) => PyExtension::from(plan.clone()).to_variant(py),
75 LogicalPlan::Filter(plan) => PyFilter::from(plan.clone()).to_variant(py),
76 LogicalPlan::Join(plan) => PyJoin::from(plan.clone()).to_variant(py),
77 LogicalPlan::Limit(plan) => PyLimit::from(plan.clone()).to_variant(py),
78 LogicalPlan::Projection(plan) => PyProjection::from(plan.clone()).to_variant(py),
79 LogicalPlan::Sort(plan) => PySort::from(plan.clone()).to_variant(py),
80 LogicalPlan::TableScan(plan) => PyTableScan::from(plan.clone()).to_variant(py),
81 LogicalPlan::Subquery(plan) => PySubquery::from(plan.clone()).to_variant(py),
82 LogicalPlan::SubqueryAlias(plan) => PySubqueryAlias::from(plan.clone()).to_variant(py),
83 LogicalPlan::Unnest(plan) => PyUnnest::from(plan.clone()).to_variant(py),
84 LogicalPlan::Window(plan) => PyWindowExpr::from(plan.clone()).to_variant(py),
85 LogicalPlan::Repartition(_)
86 | LogicalPlan::Union(_)
87 | LogicalPlan::Statement(_)
88 | LogicalPlan::Values(_)
89 | LogicalPlan::Dml(_)
90 | LogicalPlan::Ddl(_)
91 | LogicalPlan::Copy(_)
92 | LogicalPlan::DescribeTable(_)
93 | LogicalPlan::RecursiveQuery(_) => Err(py_unsupported_variant_err(format!(
94 "Conversion of variant not implemented: {:?}",
95 self.plan
96 ))),
97 }
98 }
99
100 fn inputs(&self) -> Vec<PyLogicalPlan> {
102 let mut inputs = vec![];
103 for input in self.plan.inputs() {
104 inputs.push(input.to_owned().into());
105 }
106 inputs
107 }
108
109 fn __repr__(&self) -> PyResult<String> {
110 Ok(format!("{:?}", self.plan))
111 }
112
113 fn display(&self) -> String {
114 format!("{}", self.plan.display())
115 }
116
117 fn display_indent(&self) -> String {
118 format!("{}", self.plan.display_indent())
119 }
120
121 fn display_indent_schema(&self) -> String {
122 format!("{}", self.plan.display_indent_schema())
123 }
124
125 fn display_graphviz(&self) -> String {
126 format!("{}", self.plan.display_graphviz())
127 }
128
129 pub fn to_proto<'py>(&'py self, py: Python<'py>) -> PyDataFusionResult<Bound<'py, PyBytes>> {
130 let codec = DefaultLogicalExtensionCodec {};
131 let proto =
132 datafusion_proto::protobuf::LogicalPlanNode::try_from_logical_plan(&self.plan, &codec)?;
133
134 let bytes = proto.encode_to_vec();
135 Ok(PyBytes::new(py, &bytes))
136 }
137
138 #[staticmethod]
139 pub fn from_proto(
140 ctx: PySessionContext,
141 proto_msg: Bound<'_, PyBytes>,
142 ) -> PyDataFusionResult<Self> {
143 let bytes: &[u8] = proto_msg.extract()?;
144 let proto_plan =
145 datafusion_proto::protobuf::LogicalPlanNode::decode(bytes).map_err(|e| {
146 PyRuntimeError::new_err(format!(
147 "Unable to decode logical node from serialized bytes: {}",
148 e
149 ))
150 })?;
151
152 let codec = DefaultLogicalExtensionCodec {};
153 let plan = proto_plan.try_into_logical_plan(&ctx.ctx, &codec)?;
154 Ok(Self::new(plan))
155 }
156}
157
158impl From<PyLogicalPlan> for LogicalPlan {
159 fn from(logical_plan: PyLogicalPlan) -> LogicalPlan {
160 logical_plan.plan.as_ref().clone()
161 }
162}
163
164impl From<LogicalPlan> for PyLogicalPlan {
165 fn from(logical_plan: LogicalPlan) -> PyLogicalPlan {
166 PyLogicalPlan {
167 plan: Arc::new(logical_plan),
168 }
169 }
170}