1use crate::dataframe::{DataFrame, GroupedData, JoinType as PlJoinType};
15use crate::error::polars_to_core_error;
16use crate::expr_ir::expr_ir_to_expr;
17use crate::session::{DataFrameReader, SparkSession};
18use polars::prelude::PolarsError;
19use robin_sparkless_core::engine::{
20 DataFrameBackend, DataFrameReaderBackend, GroupedDataBackend, SparkSessionBackend,
21};
22use robin_sparkless_core::error::EngineError as CoreEngineError;
23use robin_sparkless_core::expr::ExprIr;
24use robin_sparkless_core::schema::StructType;
25use std::path::Path;
26
27fn map_err(e: PolarsError) -> CoreEngineError {
28 polars_to_core_error(e)
29}
30
31#[inline]
33fn to_core(e: robin_sparkless_core::EngineError) -> CoreEngineError {
34 e
35}
36
37fn downcast_df<'a>(
38 other: &'a dyn DataFrameBackend,
39 op: &str,
40) -> Result<&'a DataFrame, CoreEngineError> {
41 other.as_any().downcast_ref::<DataFrame>().ok_or_else(|| {
42 CoreEngineError::User(format!("{} only supported with same backend (Polars)", op))
43 })
44}
45
46impl DataFrameBackend for DataFrame {
47 fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
48 self
49 }
50
51 fn filter(&self, condition: &ExprIr) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
52 let expr = expr_ir_to_expr(condition).map_err(to_core)?;
53 let df = self.filter(expr).map_err(map_err)?;
54 Ok(Box::new(df))
55 }
56
57 fn select(&self, exprs: &[ExprIr]) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
58 let exprs: Vec<_> = exprs
59 .iter()
60 .map(expr_ir_to_expr)
61 .collect::<Result<Vec<_>, _>>()
62 .map_err(to_core)?;
63 let df = self.select_exprs(exprs).map_err(map_err)?;
64 Ok(Box::new(df))
65 }
66
67 fn select_columns(
68 &self,
69 columns: &[&str],
70 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
71 let df = self.select(columns.to_vec()).map_err(map_err)?;
72 Ok(Box::new(df))
73 }
74
75 fn with_column(
76 &self,
77 name: &str,
78 expr: &ExprIr,
79 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
80 let e = expr_ir_to_expr(expr).map_err(to_core)?;
81 let df = self.with_column_expr(name, e).map_err(map_err)?;
82 Ok(Box::new(df))
83 }
84
85 fn join(
86 &self,
87 other: &dyn DataFrameBackend,
88 on: &[&str],
89 how: robin_sparkless_core::engine::JoinType,
90 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
91 let right = downcast_df(other, "join")?;
92 let pl_how = match how {
93 robin_sparkless_core::engine::JoinType::Inner => PlJoinType::Inner,
94 robin_sparkless_core::engine::JoinType::Left => PlJoinType::Left,
95 robin_sparkless_core::engine::JoinType::Right => PlJoinType::Right,
96 robin_sparkless_core::engine::JoinType::Full => PlJoinType::Outer,
97 robin_sparkless_core::engine::JoinType::LeftAnti => PlJoinType::LeftAnti,
98 robin_sparkless_core::engine::JoinType::LeftSemi => PlJoinType::LeftSemi,
99 robin_sparkless_core::engine::JoinType::Cross => {
100 let df = self.cross_join(right).map_err(map_err)?;
101 return Ok(Box::new(df));
102 }
103 };
104 let df = self.join(right, on.to_vec(), pl_how).map_err(map_err)?;
105 Ok(Box::new(df))
106 }
107
108 fn group_by(
109 &self,
110 column_names: &[&str],
111 ) -> Result<Box<dyn GroupedDataBackend>, CoreEngineError> {
112 let g = self.group_by(column_names.to_vec()).map_err(map_err)?;
113 Ok(Box::new(g))
114 }
115
116 fn order_by(
117 &self,
118 column_names: &[&str],
119 ascending: &[bool],
120 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
121 let asc: Vec<bool> = ascending.to_vec();
122 let df = self.order_by(column_names.to_vec(), asc).map_err(map_err)?;
123 Ok(Box::new(df))
124 }
125
126 fn limit(&self, n: usize) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
127 let df = self.limit(n).map_err(map_err)?;
128 Ok(Box::new(df))
129 }
130
131 fn union(
132 &self,
133 other: &dyn DataFrameBackend,
134 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
135 let right = downcast_df(other, "union")?;
136 let df = self.union(right).map_err(map_err)?;
137 Ok(Box::new(df))
138 }
139
140 fn union_by_name(
141 &self,
142 other: &dyn DataFrameBackend,
143 allow_missing_columns: bool,
144 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
145 let right = downcast_df(other, "union_by_name")?;
146 let df = self
147 .union_by_name(right, allow_missing_columns)
148 .map_err(map_err)?;
149 Ok(Box::new(df))
150 }
151
152 fn distinct(
153 &self,
154 subset: Option<Vec<&str>>,
155 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
156 let df = self.distinct(subset).map_err(map_err)?;
157 Ok(Box::new(df))
158 }
159
160 fn drop_columns(&self, columns: &[&str]) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
161 let df = self.drop(columns.to_vec()).map_err(map_err)?;
162 Ok(Box::new(df))
163 }
164
165 fn with_column_renamed(
166 &self,
167 old_name: &str,
168 new_name: &str,
169 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
170 let df = self
171 .with_column_renamed(old_name, new_name)
172 .map_err(map_err)?;
173 Ok(Box::new(df))
174 }
175
176 fn cross_join(
177 &self,
178 other: &dyn DataFrameBackend,
179 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
180 let right = downcast_df(other, "cross_join")?;
181 let df = self.cross_join(right).map_err(map_err)?;
182 Ok(Box::new(df))
183 }
184
185 fn collect(&self) -> Result<robin_sparkless_core::engine::CollectedRows, CoreEngineError> {
186 self.collect_as_json_rows().map_err(map_err)
187 }
188
189 fn schema(&self) -> Result<StructType, CoreEngineError> {
190 DataFrame::schema(self).map_err(map_err)
191 }
192
193 fn columns(&self) -> Result<Vec<String>, CoreEngineError> {
194 DataFrame::columns(self).map_err(map_err)
195 }
196
197 fn count(&self) -> Result<u64, CoreEngineError> {
198 let n = DataFrame::count(self).map_err(map_err)?;
199 Ok(n as u64)
200 }
201}
202
203impl GroupedDataBackend for GroupedData {
204 fn agg(&self, exprs: &[ExprIr]) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
205 let pl_exprs: Vec<_> = exprs
206 .iter()
207 .map(expr_ir_to_expr)
208 .collect::<Result<Vec<_>, _>>()
209 .map_err(to_core)?;
210 let df = self.agg(pl_exprs).map_err(map_err)?;
211 Ok(Box::new(df))
212 }
213
214 fn count(&self) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
215 let df = self.count().map_err(map_err)?;
216 Ok(Box::new(df))
217 }
218
219 fn sum(&self, column: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
220 let df = self.sum(column).map_err(map_err)?;
221 Ok(Box::new(df))
222 }
223
224 fn min(&self, column: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
225 let df = self.min(column).map_err(map_err)?;
226 Ok(Box::new(df))
227 }
228
229 fn max(&self, column: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
230 let df = self.max(column).map_err(map_err)?;
231 Ok(Box::new(df))
232 }
233
234 fn mean(&self, column: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
235 let df = self.avg(&[column]).map_err(map_err)?;
236 Ok(Box::new(df))
237 }
238
239 fn avg(&self, columns: &[&str]) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
240 let df = self.avg(columns).map_err(map_err)?;
241 Ok(Box::new(df))
242 }
243}
244
245impl DataFrameReaderBackend for DataFrameReader {
246 fn csv(&self, path: &Path) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
247 let df = self.csv(path).map_err(map_err)?;
248 Ok(Box::new(df))
249 }
250
251 fn parquet(&self, path: &Path) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
252 let df = self.parquet(path).map_err(map_err)?;
253 Ok(Box::new(df))
254 }
255
256 fn json(&self, path: &Path) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
257 let df = self.json(path).map_err(map_err)?;
258 Ok(Box::new(df))
259 }
260
261 fn table(&self, name: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
262 let df = self.table(name).map_err(map_err)?;
263 Ok(Box::new(df))
264 }
265}
266
267impl SparkSessionBackend for SparkSession {
268 fn read(&self) -> Box<dyn DataFrameReaderBackend> {
269 Box::new(DataFrameReader::new(self.clone()))
270 }
271
272 fn table(&self, name: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
273 let df = self.table(name).map_err(map_err)?;
274 Ok(Box::new(df))
275 }
276
277 fn create_dataframe_from_rows(
278 &self,
279 rows: Vec<Vec<serde_json::Value>>,
280 schema: Vec<(String, String)>,
281 verify_schema: bool,
282 schema_was_inferred: bool,
283 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
284 let df = self
285 .create_dataframe_from_rows(rows, schema, verify_schema, schema_was_inferred)
286 .map_err(map_err)?;
287 Ok(Box::new(df))
288 }
289
290 fn create_dataframe(
291 &self,
292 data: Vec<(i64, i64, String)>,
293 column_names: Vec<&str>,
294 ) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
295 let df = self.create_dataframe(data, column_names).map_err(map_err)?;
296 Ok(Box::new(df))
297 }
298
299 fn sql(&self, query: &str) -> Result<Box<dyn DataFrameBackend>, CoreEngineError> {
300 let df = self.sql(query).map_err(map_err)?;
301 Ok(Box::new(df))
302 }
303
304 fn register_table(&self, name: &str, df: &dyn DataFrameBackend) {
305 let polars_df = df
306 .as_any()
307 .downcast_ref::<DataFrame>()
308 .expect("register_table only supported with same backend (Polars)");
309 SparkSession::register_table(self, name, polars_df.clone());
310 }
311
312 fn is_case_sensitive(&self) -> bool {
313 self.is_case_sensitive()
314 }
315
316 fn get_config(&self) -> &std::collections::HashMap<String, String> {
317 self.get_config()
318 }
319}