Skip to main content

robin_sparkless_polars/
engine_backend.rs

//! Implementations of `robin-sparkless-core` engine traits for the Polars backend.
//!
//! This module is the **Polars adapter layer**: it implements the engine-agnostic
//! traits defined in `robin_sparkless_core::engine` (`SparkSessionBackend`,
//! `DataFrameReaderBackend`, `DataFrameBackend`, `GroupedDataBackend`) in terms of
//! the concrete Polars-backed types from this crate (`SparkSession`, `DataFrame`,
//! `GroupedData`, `DataFrameReader`).
//!
//! High-level code in the root crate should depend on these traits (via
//! `robin_sparkless_core::engine`) rather than on Polars directly; this keeps
//! the execution engine swappable while the public expression IR and engine
//! interfaces remain stable.

14use crate::dataframe::{DataFrame, GroupedData, JoinType as PlJoinType};
15use crate::error::polars_to_core_error;
16use crate::expr_ir::expr_ir_to_expr;
17use crate::session::{DataFrameReader, SparkSession};
18use polars::prelude::PolarsError;
19use robin_sparkless_core::engine::{
20    DataFrameBackend, DataFrameReaderBackend, GroupedDataBackend, SparkSessionBackend,
21};
22use robin_sparkless_core::error::EngineError as CoreEngineError;
23use robin_sparkless_core::expr::ExprIr;
24use robin_sparkless_core::schema::StructType;
25use std::path::Path;
26
27fn map_err(e: PolarsError) -> CoreEngineError {
28    polars_to_core_error(e)
29}
30
31/// Core and polars EngineError are the same type; use for clarity in ? chains.
32#[inline]
33fn to_core(e: robin_sparkless_core::EngineError) -> CoreEngineError {
34    e
35}
36
37fn downcast_df<'a>(
38    other: &'a dyn DataFrameBackend,
39    op: &str,
40) -> Result<&'a DataFrame, CoreEngineError> {
41    other.as_any().downcast_ref::<DataFrame>().ok_or_else(|| {
42        CoreEngineError::User(format!("{} only supported with same backend (Polars)", op))
43    })
44}
45
46impl DataFrameBackend for DataFrame {
47    fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
48        self
49    }
50
51    fn filter(&self, condition: &ExprIr) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
52        let expr = expr_ir_to_expr(condition).map_err(to_core)?;
53        let df = self.filter(expr).map_err(map_err)?;
54        Ok(Box::new(df))
55    }
56
57    fn select(&self, exprs: &[ExprIr]) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
58        let exprs: Vec<_> = exprs
59            .iter()
60            .map(expr_ir_to_expr)
61            .collect::<Result<Vec<_>, _>>()
62            .map_err(to_core)?;
63        let df = self.select_exprs(exprs).map_err(map_err)?;
64        Ok(Box::new(df))
65    }
66
67    fn select_columns(
68        &self,
69        columns: &[&str],
70    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
71        let df = self.select(columns.to_vec()).map_err(map_err)?;
72        Ok(Box::new(df))
73    }
74
75    fn with_column(
76        &self,
77        name: &str,
78        expr: &ExprIr,
79    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
80        let e = expr_ir_to_expr(expr).map_err(to_core)?;
81        let df = self.with_column_expr(name, e).map_err(map_err)?;
82        Ok(Box::new(df))
83    }
84
85    fn join(
86        &self,
87        other: &dyn DataFrameBackend,
88        on: &[&str],
89        how: robin_sparkless_core::engine::JoinType,
90    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
91        let right = downcast_df(other, "join")?;
92        let pl_how = match how {
93            robin_sparkless_core::engine::JoinType::Inner => PlJoinType::Inner,
94            robin_sparkless_core::engine::JoinType::Left => PlJoinType::Left,
95            robin_sparkless_core::engine::JoinType::Right => PlJoinType::Right,
96            robin_sparkless_core::engine::JoinType::Full => PlJoinType::Outer,
97            robin_sparkless_core::engine::JoinType::LeftAnti => PlJoinType::LeftAnti,
98            robin_sparkless_core::engine::JoinType::LeftSemi => PlJoinType::LeftSemi,
99            robin_sparkless_core::engine::JoinType::Cross => {
100                let df = self.cross_join(right).map_err(map_err)?;
101                return Ok(Box::new(df));
102            }
103        };
104        let df = self.join(right, on.to_vec(), pl_how).map_err(map_err)?;
105        Ok(Box::new(df))
106    }
107
108    fn group_by(
109        &self,
110        column_names: &[&str],
111    ) -> Result<Box<dyn GroupedDataBackend>, CoreEngineError> {
112        let g = self.group_by(column_names.to_vec()).map_err(map_err)?;
113        Ok(Box::new(g))
114    }
115
116    fn order_by(
117        &self,
118        column_names: &[&str],
119        ascending: &[bool],
120    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
121        let asc: Vec<bool> = ascending.to_vec();
122        let df = self.order_by(column_names.to_vec(), asc).map_err(map_err)?;
123        Ok(Box::new(df))
124    }
125
126    fn limit(&self, n: usize) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
127        let df = self.limit(n).map_err(map_err)?;
128        Ok(Box::new(df))
129    }
130
131    fn union(
132        &self,
133        other: &dyn DataFrameBackend,
134    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
135        let right = downcast_df(other, "union")?;
136        let df = self.union(right).map_err(map_err)?;
137        Ok(Box::new(df))
138    }
139
140    fn union_by_name(
141        &self,
142        other: &dyn DataFrameBackend,
143        allow_missing_columns: bool,
144    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
145        let right = downcast_df(other, "union_by_name")?;
146        let df = self
147            .union_by_name(right, allow_missing_columns)
148            .map_err(map_err)?;
149        Ok(Box::new(df))
150    }
151
152    fn distinct(
153        &self,
154        subset: Option<Vec<&str>>,
155    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
156        let df = self.distinct(subset).map_err(map_err)?;
157        Ok(Box::new(df))
158    }
159
160    fn drop_columns(&self, columns: &[&str]) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
161        let df = self.drop(columns.to_vec()).map_err(map_err)?;
162        Ok(Box::new(df))
163    }
164
165    fn with_column_renamed(
166        &self,
167        old_name: &str,
168        new_name: &str,
169    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
170        let df = self
171            .with_column_renamed(old_name, new_name)
172            .map_err(map_err)?;
173        Ok(Box::new(df))
174    }
175
176    fn cross_join(
177        &self,
178        other: &dyn DataFrameBackend,
179    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
180        let right = downcast_df(other, "cross_join")?;
181        let df = self.cross_join(right).map_err(map_err)?;
182        Ok(Box::new(df))
183    }
184
185    fn collect(&self) -> Result<robin_sparkless_core::engine::CollectedRows, CoreEngineError> {
186        self.collect_as_json_rows().map_err(map_err)
187    }
188
189    fn schema(&self) -> Result<StructType, CoreEngineError> {
190        DataFrame::schema(self).map_err(map_err)
191    }
192
193    fn columns(&self) -> Result<Vec<String>, CoreEngineError> {
194        DataFrame::columns(self).map_err(map_err)
195    }
196
197    fn count(&self) -> Result<u64, CoreEngineError> {
198        let n = DataFrame::count(self).map_err(map_err)?;
199        Ok(n as u64)
200    }
201}
202
203impl GroupedDataBackend for GroupedData {
204    fn agg(&self, exprs: &[ExprIr]) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
205        let pl_exprs: Vec<_> = exprs
206            .iter()
207            .map(expr_ir_to_expr)
208            .collect::<Result<Vec<_>, _>>()
209            .map_err(to_core)?;
210        let df = self.agg(pl_exprs).map_err(map_err)?;
211        Ok(Box::new(df))
212    }
213
214    fn count(&self) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
215        let df = self.count().map_err(map_err)?;
216        Ok(Box::new(df))
217    }
218
219    fn sum(&self, column: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
220        let df = self.sum(column).map_err(map_err)?;
221        Ok(Box::new(df))
222    }
223
224    fn min(&self, column: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
225        let df = self.min(column).map_err(map_err)?;
226        Ok(Box::new(df))
227    }
228
229    fn max(&self, column: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
230        let df = self.max(column).map_err(map_err)?;
231        Ok(Box::new(df))
232    }
233
234    fn mean(&self, column: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
235        let df = self.avg(&[column]).map_err(map_err)?;
236        Ok(Box::new(df))
237    }
238
239    fn avg(&self, columns: &[&str]) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
240        let df = self.avg(columns).map_err(map_err)?;
241        Ok(Box::new(df))
242    }
243}
244
245impl DataFrameReaderBackend for DataFrameReader {
246    fn csv(&self, path: &Path) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
247        let df = self.csv(path).map_err(map_err)?;
248        Ok(Box::new(df))
249    }
250
251    fn parquet(&self, path: &Path) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
252        let df = self.parquet(path).map_err(map_err)?;
253        Ok(Box::new(df))
254    }
255
256    fn json(&self, path: &Path) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
257        let df = self.json(path).map_err(map_err)?;
258        Ok(Box::new(df))
259    }
260
261    fn table(&self, name: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
262        let df = self.table(name).map_err(map_err)?;
263        Ok(Box::new(df))
264    }
265}
266
267impl SparkSessionBackend for SparkSession {
268    fn read(&self) -> Box<dyn DataFrameReaderBackend> {
269        Box::new(DataFrameReader::new(self.clone()))
270    }
271
272    fn table(&self, name: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
273        let df = self.table(name).map_err(map_err)?;
274        Ok(Box::new(df))
275    }
276
277    fn create_dataframe_from_rows(
278        &self,
279        rows: Vec<Vec<serde_json::Value>>,
280        schema: Vec<(String, String)>,
281        verify_schema: bool,
282        schema_was_inferred: bool,
283    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
284        let df = self
285            .create_dataframe_from_rows(rows, schema, verify_schema, schema_was_inferred)
286            .map_err(map_err)?;
287        Ok(Box::new(df))
288    }
289
290    fn create_dataframe(
291        &self,
292        data: Vec<(i64, i64, String)>,
293        column_names: Vec<&str>,
294    ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
295        let df = self.create_dataframe(data, column_names).map_err(map_err)?;
296        Ok(Box::new(df))
297    }
298
299    fn sql(&self, query: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
300        let df = self.sql(query).map_err(map_err)?;
301        Ok(Box::new(df))
302    }
303
304    fn register_table(&self, name: &str, df: &dyn DataFrameBackend) {
305        let polars_df = df
306            .as_any()
307            .downcast_ref::<DataFrame>()
308            .expect("register_table only supported with same backend (Polars)");
309        SparkSession::register_table(self, name, polars_df.clone());
310    }
311
312    fn is_case_sensitive(&self) -> bool {
313        self.is_case_sensitive()
314    }
315
316    fn get_config(&self) -> &std::collections::HashMap<String, String> {
317        self.get_config()
318    }
319}