1use crate::EngineError;
4use robin_sparkless_core::SparklessConfig;
5use robin_sparkless_polars::{
6 DataFrameReader as PolarsDataFrameReader, PlDataFrame, PolarsError,
7 SparkSession as PolarsSparkSession, SparkSessionBuilder as PolarsSparkSessionBuilder,
8};
9use std::collections::HashMap;
10use std::path::Path;
11
12use crate::dataframe::DataFrame;
13
/// Session handle exposed by this crate: a newtype over the Polars-backed `SparkSession`.
///
/// The wrapped value is reachable only from inside this crate (`pub(crate)`); the methods
/// on this type forward to it.
#[derive(Clone)]
pub struct SparkSession(pub(crate) PolarsSparkSession);
17
/// Builder for [`SparkSession`]: a newtype over the Polars-backed `SparkSessionBuilder`.
///
/// The wrapped builder is reachable only from inside this crate (`pub(crate)`).
#[derive(Clone)]
pub struct SparkSessionBuilder(pub(crate) PolarsSparkSessionBuilder);
21
/// Reader handle: a newtype over the Polars-backed `DataFrameReader`.
///
/// The wrapped reader is private to this module.
pub struct DataFrameReader(PolarsDataFrameReader);
24
25impl SparkSessionBuilder {
26 pub fn new() -> Self {
27 SparkSessionBuilder(PolarsSparkSessionBuilder::new())
28 }
29
30 pub fn app_name(self, name: impl Into<String>) -> Self {
31 SparkSessionBuilder(self.0.app_name(name))
32 }
33
34 pub fn master(self, master: impl Into<String>) -> Self {
35 SparkSessionBuilder(self.0.master(master))
36 }
37
38 pub fn config(self, key: impl Into<String>, value: impl Into<String>) -> Self {
39 SparkSessionBuilder(self.0.config(key, value))
40 }
41
42 pub fn get_config(&self) -> &HashMap<String, String> {
44 self.0.get_config()
45 }
46
47 pub fn get_or_create(self) -> SparkSession {
48 SparkSession(self.0.get_or_create())
49 }
50
51 pub fn with_config(self, config: &SparklessConfig) -> Self {
52 SparkSessionBuilder(self.0.with_config(config))
53 }
54}
55
56impl Default for SparkSessionBuilder {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl SparkSession {
63 pub fn builder() -> SparkSessionBuilder {
64 SparkSessionBuilder(PolarsSparkSession::builder())
65 }
66
67 pub fn from_config(config: &SparklessConfig) -> SparkSession {
68 SparkSession(PolarsSparkSession::from_config(config))
69 }
70
71 pub fn read(&self) -> DataFrameReader {
72 DataFrameReader(PolarsDataFrameReader::new(self.0.clone()))
73 }
74
75 pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
76 self.0.create_or_replace_temp_view(name, df.0)
77 }
78
79 pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
80 self.0.create_global_temp_view(name, df.0)
81 }
82
83 pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
84 self.0.create_or_replace_global_temp_view(name, df.0)
85 }
86
87 pub fn drop_temp_view(&self, name: &str) {
88 self.0.drop_temp_view(name)
89 }
90
91 pub fn drop_global_temp_view(&self, name: &str) -> bool {
92 self.0.drop_global_temp_view(name)
93 }
94
95 pub fn register_table(&self, name: &str, df: DataFrame) {
96 self.0.register_table(name, df.0)
97 }
98
99 pub fn register_database(&self, name: &str) {
100 self.0.register_database(name)
101 }
102
103 pub fn list_database_names(&self) -> Vec<String> {
104 self.0.list_database_names()
105 }
106
107 pub fn database_exists(&self, name: &str) -> bool {
108 self.0.database_exists(name)
109 }
110
111 pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
112 self.0.get_saved_table(name).map(DataFrame)
113 }
114
115 pub fn saved_table_exists(&self, name: &str) -> bool {
116 self.0.saved_table_exists(name)
117 }
118
119 pub fn table_exists(&self, name: &str) -> bool {
120 self.0.table_exists(name)
121 }
122
123 pub fn list_global_temp_view_names(&self) -> Vec<String> {
124 self.0.list_global_temp_view_names()
125 }
126
127 pub fn list_temp_view_names(&self) -> Vec<String> {
128 self.0.list_temp_view_names()
129 }
130
131 pub fn list_table_names(&self) -> Vec<String> {
132 self.0.list_table_names()
133 }
134
135 pub fn app_name(&self) -> Option<String> {
136 self.0.app_name()
137 }
138
139 pub fn new_session(&self) -> SparkSession {
140 SparkSession(self.0.new_session())
141 }
142
143 pub fn current_database(&self) -> String {
144 self.0.current_database()
145 }
146
147 pub fn set_current_database(&self, name: &str) -> Result<(), EngineError> {
148 self.0.set_current_database(name)
149 }
150
151 pub fn cache_table(&self, name: &str) {
152 self.0.cache_table(name)
153 }
154
155 pub fn uncache_table(&self, name: &str) {
156 self.0.uncache_table(name)
157 }
158
159 pub fn is_cached(&self, name: &str) -> bool {
160 self.0.is_cached(name)
161 }
162
163 pub fn drop_table(&self, name: &str) -> bool {
164 self.0.drop_table(name)
165 }
166
167 pub fn drop_database(&self, name: &str) -> bool {
168 self.0.drop_database(name)
169 }
170
171 pub fn warehouse_dir(&self) -> Option<&str> {
172 self.0.warehouse_dir()
173 }
174
175 pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
176 self.0.table(name).map(DataFrame)
177 }
178
179 pub fn get_config(&self) -> &HashMap<String, String> {
180 self.0.get_config()
181 }
182
183 pub fn set_config(&mut self, key: impl Into<String>, value: impl Into<String>) {
184 self.0.set_config(key, value);
185 }
186
187 pub fn is_case_sensitive(&self) -> bool {
188 self.0.is_case_sensitive()
189 }
190
191 pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
192 where
193 F: Fn(
194 &[robin_sparkless_polars::Series],
195 ) -> Result<robin_sparkless_polars::Series, PolarsError>
196 + Send
197 + Sync
198 + 'static,
199 {
200 self.0.register_udf(name, f)
201 }
202
203 pub fn create_dataframe(
204 &self,
205 data: Vec<(i64, i64, String)>,
206 column_names: Vec<&str>,
207 ) -> Result<DataFrame, PolarsError> {
208 self.0.create_dataframe(data, column_names).map(DataFrame)
209 }
210
211 pub fn create_dataframe_engine(
212 &self,
213 data: Vec<(i64, i64, String)>,
214 column_names: Vec<&str>,
215 ) -> Result<DataFrame, EngineError> {
216 self.0
217 .create_dataframe_engine(data, column_names)
218 .map(DataFrame)
219 }
220
221 pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
222 DataFrame(self.0.create_dataframe_from_polars(df))
223 }
224
225 pub fn create_dataframe_from_rows(
226 &self,
227 rows: Vec<Vec<serde_json::Value>>,
228 schema: Vec<(String, String)>,
229 verify_schema: bool,
230 schema_was_inferred: bool,
231 ) -> Result<DataFrame, PolarsError> {
232 self.0
233 .create_dataframe_from_rows(rows, schema, verify_schema, schema_was_inferred)
234 .map(DataFrame)
235 }
236
237 pub fn create_dataframe_from_rows_engine(
238 &self,
239 rows: Vec<Vec<serde_json::Value>>,
240 schema: Vec<(String, String)>,
241 verify_schema: bool,
242 schema_was_inferred: bool,
243 ) -> Result<DataFrame, EngineError> {
244 self.0
245 .create_dataframe_from_rows_engine(rows, schema, verify_schema, schema_was_inferred)
246 .map(DataFrame)
247 }
248
249 pub fn create_dataframe_from_single_column(
251 &self,
252 values: Vec<serde_json::Value>,
253 type_str: &str,
254 ) -> Result<DataFrame, PolarsError> {
255 self.0
256 .create_dataframe_from_single_column(values, type_str)
257 .map(DataFrame)
258 }
259
260 pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
261 self.0.range(start, end, step).map(DataFrame)
262 }
263
264 pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
265 self.0.read_csv(path).map(DataFrame)
266 }
267
268 pub fn read_csv_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
269 self.0.read_csv_engine(path).map(DataFrame)
270 }
271
272 pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
273 self.0.read_parquet(path).map(DataFrame)
274 }
275
276 pub fn read_parquet_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
277 self.0.read_parquet_engine(path).map(DataFrame)
278 }
279
280 pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
281 self.0.read_json(path).map(DataFrame)
282 }
283
284 pub fn read_json_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
285 self.0.read_json_engine(path).map(DataFrame)
286 }
287
288 pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
289 self.0.sql(query).map(DataFrame)
290 }
291
292 pub fn table_engine(&self, name: &str) -> Result<DataFrame, EngineError> {
293 self.0.table_engine(name).map(DataFrame)
294 }
295
296 #[cfg(feature = "delta")]
297 pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
298 self.0.read_delta_path(path).map(DataFrame)
299 }
300
301 pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
302 self.0.read_delta_from_path(path).map(DataFrame)
303 }
304
305 #[cfg(feature = "delta")]
306 pub fn read_delta_path_with_version(
307 &self,
308 path: impl AsRef<Path>,
309 version: Option<i64>,
310 ) -> Result<DataFrame, PolarsError> {
311 self.0
312 .read_delta_path_with_version(path, version)
313 .map(DataFrame)
314 }
315
316 #[cfg(feature = "delta")]
317 pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
318 self.0.read_delta(name_or_path).map(DataFrame)
319 }
320
321 #[cfg(feature = "delta")]
322 pub fn read_delta_with_version(
323 &self,
324 name_or_path: &str,
325 version: Option<i64>,
326 ) -> Result<DataFrame, PolarsError> {
327 self.0
328 .read_delta_with_version(name_or_path, version)
329 .map(DataFrame)
330 }
331
332 pub fn stop(&self) {
333 self.0.stop()
334 }
335
336 pub fn udf_registry(&self) -> &robin_sparkless_polars::UdfRegistry {
338 self.0.udf_registry()
339 }
340}
341
342impl DataFrameReader {
343 pub fn option(self, key: impl Into<String>, value: impl Into<String>) -> Self {
344 DataFrameReader(self.0.option(key, value))
345 }
346
347 pub fn options(self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
348 DataFrameReader(self.0.options(opts))
349 }
350
351 pub fn format(self, fmt: impl Into<String>) -> Self {
352 DataFrameReader(self.0.format(fmt))
353 }
354
355 pub fn schema(self, schema: impl Into<String>) -> Self {
356 DataFrameReader(self.0.schema(schema))
357 }
358
359 pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
360 self.0.load(path).map(DataFrame)
361 }
362
363 pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
364 self.0.table(name).map(DataFrame)
365 }
366
367 pub fn csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
368 self.0.csv(path).map(DataFrame)
369 }
370
371 pub fn parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
372 self.0.parquet(path).map(DataFrame)
373 }
374
375 pub fn json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
376 self.0.json(path).map(DataFrame)
377 }
378
379 #[cfg(feature = "delta")]
380 pub fn delta(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
381 self.0.delta(path).map(DataFrame)
382 }
383
384 #[cfg(any(
387 feature = "jdbc",
388 feature = "jdbc_mysql",
389 feature = "jdbc_mariadb",
390 feature = "jdbc_mssql",
391 feature = "jdbc_oracle",
392 feature = "jdbc_db2",
393 feature = "sqlite"
394 ))]
395 pub fn jdbc(
396 &self,
397 url: &str,
398 table: &str,
399 properties: &HashMap<String, String>,
400 ) -> Result<DataFrame, PolarsError> {
401 self.0
402 .jdbc_with_properties(url, table, properties)
403 .map(DataFrame)
404 .map_err(|e| PolarsError::ComputeError(e.to_string().into()))
405 }
406}