datafusion_python/
table.rs1use std::any::Any;
19use std::sync::Arc;
20
21use arrow::datatypes::SchemaRef;
22use arrow::pyarrow::ToPyArrow;
23use async_trait::async_trait;
24use datafusion::catalog::Session;
25use datafusion::common::Column;
26use datafusion::datasource::{TableProvider, TableType};
27use datafusion::logical_expr::{Expr, LogicalPlanBuilder, TableProviderFilterPushDown};
28use datafusion::physical_plan::ExecutionPlan;
29use datafusion::prelude::DataFrame;
30use pyo3::prelude::*;
31
32use crate::dataframe::PyDataFrame;
33use crate::dataset::Dataset;
34use crate::utils::table_provider_from_pycapsule;
35
36#[pyclass(frozen, name = "RawTable", module = "datafusion.catalog", subclass)]
40#[derive(Clone)]
41pub struct PyTable {
42 pub table: Arc<dyn TableProvider>,
43}
44
45impl PyTable {
46 pub fn table(&self) -> Arc<dyn TableProvider> {
47 self.table.clone()
48 }
49}
50
51#[pymethods]
52impl PyTable {
53 #[new]
63 pub fn new(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
64 if let Ok(py_table) = obj.extract::<PyTable>() {
65 Ok(py_table)
66 } else if let Ok(py_table) = obj
67 .getattr("_inner")
68 .and_then(|inner| inner.extract::<PyTable>())
69 {
70 Ok(py_table)
71 } else if let Ok(py_df) = obj.extract::<PyDataFrame>() {
72 let provider = py_df.inner_df().as_ref().clone().into_view();
73 Ok(PyTable::from(provider))
74 } else if let Ok(py_df) = obj
75 .getattr("df")
76 .and_then(|inner| inner.extract::<PyDataFrame>())
77 {
78 let provider = py_df.inner_df().as_ref().clone().into_view();
79 Ok(PyTable::from(provider))
80 } else if let Some(provider) = table_provider_from_pycapsule(obj)? {
81 Ok(PyTable::from(provider))
82 } else {
83 let py = obj.py();
84 let provider = Arc::new(Dataset::new(obj, py)?) as Arc<dyn TableProvider>;
85 Ok(PyTable::from(provider))
86 }
87 }
88
89 #[getter]
91 fn schema<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
92 self.table.schema().to_pyarrow(py)
93 }
94
95 #[getter]
97 fn kind(&self) -> &str {
98 match self.table.table_type() {
99 TableType::Base => "physical",
100 TableType::View => "view",
101 TableType::Temporary => "temporary",
102 }
103 }
104
105 fn __repr__(&self) -> PyResult<String> {
106 let kind = self.kind();
107 Ok(format!("Table(kind={kind})"))
108 }
109}
110
111impl From<Arc<dyn TableProvider>> for PyTable {
112 fn from(table: Arc<dyn TableProvider>) -> Self {
113 Self { table }
114 }
115}
116
117#[derive(Clone, Debug)]
118pub(crate) struct TempViewTable {
119 df: Arc<DataFrame>,
120}
121
122impl TempViewTable {
127 pub(crate) fn new(df: Arc<DataFrame>) -> Self {
128 Self { df }
129 }
130}
131
132#[async_trait]
133impl TableProvider for TempViewTable {
134 fn as_any(&self) -> &dyn Any {
135 self
136 }
137
138 fn schema(&self) -> SchemaRef {
139 Arc::new(self.df.schema().as_arrow().clone())
140 }
141
142 fn table_type(&self) -> TableType {
143 TableType::Temporary
144 }
145
146 async fn scan(
147 &self,
148 state: &dyn Session,
149 projection: Option<&Vec<usize>>,
150 filters: &[Expr],
151 limit: Option<usize>,
152 ) -> datafusion::common::Result<Arc<dyn ExecutionPlan>> {
153 let filter = filters.iter().cloned().reduce(|acc, new| acc.and(new));
154 let plan = self.df.logical_plan().clone();
155 let mut plan = LogicalPlanBuilder::from(plan);
156
157 if let Some(filter) = filter {
158 plan = plan.filter(filter)?;
159 }
160
161 let mut plan = if let Some(projection) = projection {
162 let current_projection = (0..plan.schema().fields().len()).collect::<Vec<usize>>();
164 if projection == ¤t_projection {
165 plan
166 } else {
167 let fields: Vec<Expr> = projection
168 .iter()
169 .map(|i| {
170 Expr::Column(Column::from(
171 self.df.logical_plan().schema().qualified_field(*i),
172 ))
173 })
174 .collect();
175 plan.project(fields)?
176 }
177 } else {
178 plan
179 };
180
181 if let Some(limit) = limit {
182 plan = plan.limit(0, Some(limit))?;
183 }
184
185 state.create_physical_plan(&plan.build()?).await
186 }
187
188 fn supports_filters_pushdown(
189 &self,
190 filters: &[&Expr],
191 ) -> datafusion::common::Result<Vec<TableProviderFilterPushDown>> {
192 Ok(vec![TableProviderFilterPushDown::Exact; filters.len()])
193 }
194}