Skip to main content

robin_sparkless/
session.rs

1//! Root-owned Session API; delegates to robin-sparkless-polars for execution.
2
3use crate::EngineError;
4use robin_sparkless_core::SparklessConfig;
5use robin_sparkless_polars::{
6    DataFrameReader as PolarsDataFrameReader, PlDataFrame, PolarsError,
7    SparkSession as PolarsSparkSession, SparkSessionBuilder as PolarsSparkSessionBuilder,
8};
9use std::collections::HashMap;
10use std::path::Path;
11
12use crate::dataframe::DataFrame;
13
14/// Root-owned SparkSession; delegates to the Polars backend.
15#[derive(Clone)]
16pub struct SparkSession(pub(crate) PolarsSparkSession);
17
18/// Root-owned SparkSessionBuilder; delegates to the Polars backend.
19#[derive(Clone)]
20pub struct SparkSessionBuilder(pub(crate) PolarsSparkSessionBuilder);
21
22/// Root-owned DataFrameReader; delegates to the Polars backend.
23pub struct DataFrameReader(PolarsDataFrameReader);
24
25impl SparkSessionBuilder {
26    pub fn new() -> Self {
27        SparkSessionBuilder(PolarsSparkSessionBuilder::new())
28    }
29
30    pub fn app_name(self, name: impl Into<String>) -> Self {
31        SparkSessionBuilder(self.0.app_name(name))
32    }
33
34    pub fn master(self, master: impl Into<String>) -> Self {
35        SparkSessionBuilder(self.0.master(master))
36    }
37
38    pub fn config(self, key: impl Into<String>, value: impl Into<String>) -> Self {
39        SparkSessionBuilder(self.0.config(key, value))
40    }
41
42    /// Config key-value pairs set on the builder (for singleton compatibility check).
43    pub fn get_config(&self) -> &HashMap<String, String> {
44        self.0.get_config()
45    }
46
47    pub fn get_or_create(self) -> SparkSession {
48        SparkSession(self.0.get_or_create())
49    }
50
51    pub fn with_config(self, config: &SparklessConfig) -> Self {
52        SparkSessionBuilder(self.0.with_config(config))
53    }
54}
55
56impl Default for SparkSessionBuilder {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl SparkSession {
63    pub fn builder() -> SparkSessionBuilder {
64        SparkSessionBuilder(PolarsSparkSession::builder())
65    }
66
67    pub fn from_config(config: &SparklessConfig) -> SparkSession {
68        SparkSession(PolarsSparkSession::from_config(config))
69    }
70
71    pub fn read(&self) -> DataFrameReader {
72        DataFrameReader(PolarsDataFrameReader::new(self.0.clone()))
73    }
74
75    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
76        self.0.create_or_replace_temp_view(name, df.0)
77    }
78
79    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
80        self.0.create_global_temp_view(name, df.0)
81    }
82
83    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
84        self.0.create_or_replace_global_temp_view(name, df.0)
85    }
86
87    pub fn drop_temp_view(&self, name: &str) {
88        self.0.drop_temp_view(name)
89    }
90
91    pub fn drop_global_temp_view(&self, name: &str) -> bool {
92        self.0.drop_global_temp_view(name)
93    }
94
95    pub fn register_table(&self, name: &str, df: DataFrame) {
96        self.0.register_table(name, df.0)
97    }
98
99    pub fn register_database(&self, name: &str) {
100        self.0.register_database(name)
101    }
102
103    pub fn list_database_names(&self) -> Vec<String> {
104        self.0.list_database_names()
105    }
106
107    pub fn database_exists(&self, name: &str) -> bool {
108        self.0.database_exists(name)
109    }
110
111    pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
112        self.0.get_saved_table(name).map(DataFrame)
113    }
114
115    pub fn saved_table_exists(&self, name: &str) -> bool {
116        self.0.saved_table_exists(name)
117    }
118
119    pub fn table_exists(&self, name: &str) -> bool {
120        self.0.table_exists(name)
121    }
122
123    pub fn list_global_temp_view_names(&self) -> Vec<String> {
124        self.0.list_global_temp_view_names()
125    }
126
127    pub fn list_temp_view_names(&self) -> Vec<String> {
128        self.0.list_temp_view_names()
129    }
130
131    pub fn list_table_names(&self) -> Vec<String> {
132        self.0.list_table_names()
133    }
134
135    pub fn app_name(&self) -> Option<String> {
136        self.0.app_name()
137    }
138
139    pub fn new_session(&self) -> SparkSession {
140        SparkSession(self.0.new_session())
141    }
142
143    pub fn current_database(&self) -> String {
144        self.0.current_database()
145    }
146
147    pub fn set_current_database(&self, name: &str) -> Result<(), EngineError> {
148        self.0.set_current_database(name)
149    }
150
151    pub fn cache_table(&self, name: &str) {
152        self.0.cache_table(name)
153    }
154
155    pub fn uncache_table(&self, name: &str) {
156        self.0.uncache_table(name)
157    }
158
159    pub fn is_cached(&self, name: &str) -> bool {
160        self.0.is_cached(name)
161    }
162
163    pub fn drop_table(&self, name: &str) -> bool {
164        self.0.drop_table(name)
165    }
166
167    pub fn drop_database(&self, name: &str) -> bool {
168        self.0.drop_database(name)
169    }
170
171    pub fn warehouse_dir(&self) -> Option<&str> {
172        self.0.warehouse_dir()
173    }
174
175    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
176        self.0.table(name).map(DataFrame)
177    }
178
179    pub fn get_config(&self) -> &HashMap<String, String> {
180        self.0.get_config()
181    }
182
183    pub fn set_config(&mut self, key: impl Into<String>, value: impl Into<String>) {
184        self.0.set_config(key, value);
185    }
186
187    pub fn is_case_sensitive(&self) -> bool {
188        self.0.is_case_sensitive()
189    }
190
191    pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
192    where
193        F: Fn(
194                &[robin_sparkless_polars::Series],
195            ) -> Result<robin_sparkless_polars::Series, PolarsError>
196            + Send
197            + Sync
198            + 'static,
199    {
200        self.0.register_udf(name, f)
201    }
202
203    pub fn create_dataframe(
204        &self,
205        data: Vec<(i64, i64, String)>,
206        column_names: Vec<&str>,
207    ) -> Result<DataFrame, PolarsError> {
208        self.0.create_dataframe(data, column_names).map(DataFrame)
209    }
210
211    pub fn create_dataframe_engine(
212        &self,
213        data: Vec<(i64, i64, String)>,
214        column_names: Vec<&str>,
215    ) -> Result<DataFrame, EngineError> {
216        self.0
217            .create_dataframe_engine(data, column_names)
218            .map(DataFrame)
219    }
220
221    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
222        DataFrame(self.0.create_dataframe_from_polars(df))
223    }
224
225    pub fn create_dataframe_from_rows(
226        &self,
227        rows: Vec<Vec<serde_json::Value>>,
228        schema: Vec<(String, String)>,
229        verify_schema: bool,
230        schema_was_inferred: bool,
231    ) -> Result<DataFrame, PolarsError> {
232        self.0
233            .create_dataframe_from_rows(rows, schema, verify_schema, schema_was_inferred)
234            .map(DataFrame)
235    }
236
237    pub fn create_dataframe_from_rows_engine(
238        &self,
239        rows: Vec<Vec<serde_json::Value>>,
240        schema: Vec<(String, String)>,
241        verify_schema: bool,
242        schema_was_inferred: bool,
243    ) -> Result<DataFrame, EngineError> {
244        self.0
245            .create_dataframe_from_rows_engine(rows, schema, verify_schema, schema_was_inferred)
246            .map(DataFrame)
247    }
248
249    /// #419: Create a DataFrame with a single column "value" from scalar values (e.g. createDataFrame([1,2,3], "bigint")).
250    pub fn create_dataframe_from_single_column(
251        &self,
252        values: Vec<serde_json::Value>,
253        type_str: &str,
254    ) -> Result<DataFrame, PolarsError> {
255        self.0
256            .create_dataframe_from_single_column(values, type_str)
257            .map(DataFrame)
258    }
259
260    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
261        self.0.range(start, end, step).map(DataFrame)
262    }
263
264    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
265        self.0.read_csv(path).map(DataFrame)
266    }
267
268    pub fn read_csv_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
269        self.0.read_csv_engine(path).map(DataFrame)
270    }
271
272    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
273        self.0.read_parquet(path).map(DataFrame)
274    }
275
276    pub fn read_parquet_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
277        self.0.read_parquet_engine(path).map(DataFrame)
278    }
279
280    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
281        self.0.read_json(path).map(DataFrame)
282    }
283
284    pub fn read_json_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
285        self.0.read_json_engine(path).map(DataFrame)
286    }
287
288    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
289        self.0.sql(query).map(DataFrame)
290    }
291
292    pub fn table_engine(&self, name: &str) -> Result<DataFrame, EngineError> {
293        self.0.table_engine(name).map(DataFrame)
294    }
295
296    #[cfg(feature = "delta")]
297    pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
298        self.0.read_delta_path(path).map(DataFrame)
299    }
300
301    pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
302        self.0.read_delta_from_path(path).map(DataFrame)
303    }
304
305    #[cfg(feature = "delta")]
306    pub fn read_delta_path_with_version(
307        &self,
308        path: impl AsRef<Path>,
309        version: Option<i64>,
310    ) -> Result<DataFrame, PolarsError> {
311        self.0
312            .read_delta_path_with_version(path, version)
313            .map(DataFrame)
314    }
315
316    #[cfg(feature = "delta")]
317    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
318        self.0.read_delta(name_or_path).map(DataFrame)
319    }
320
321    #[cfg(feature = "delta")]
322    pub fn read_delta_with_version(
323        &self,
324        name_or_path: &str,
325        version: Option<i64>,
326    ) -> Result<DataFrame, PolarsError> {
327        self.0
328            .read_delta_with_version(name_or_path, version)
329            .map(DataFrame)
330    }
331
332    pub fn stop(&self) {
333        self.0.stop()
334    }
335
336    /// Get the UDF registry. Used internally for thread context management.
337    pub fn udf_registry(&self) -> &robin_sparkless_polars::UdfRegistry {
338        self.0.udf_registry()
339    }
340}
341
342impl DataFrameReader {
343    pub fn option(self, key: impl Into<String>, value: impl Into<String>) -> Self {
344        DataFrameReader(self.0.option(key, value))
345    }
346
347    pub fn options(self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
348        DataFrameReader(self.0.options(opts))
349    }
350
351    pub fn format(self, fmt: impl Into<String>) -> Self {
352        DataFrameReader(self.0.format(fmt))
353    }
354
355    pub fn schema(self, schema: impl Into<String>) -> Self {
356        DataFrameReader(self.0.schema(schema))
357    }
358
359    pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
360        self.0.load(path).map(DataFrame)
361    }
362
363    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
364        self.0.table(name).map(DataFrame)
365    }
366
367    pub fn csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
368        self.0.csv(path).map(DataFrame)
369    }
370
371    pub fn parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
372        self.0.parquet(path).map(DataFrame)
373    }
374
375    pub fn json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
376        self.0.json(path).map(DataFrame)
377    }
378
379    #[cfg(feature = "delta")]
380    pub fn delta(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
381        self.0.delta(path).map(DataFrame)
382    }
383
384    /// JDBC read: load a table from an external database (e.g. PostgreSQL).
385    /// Requires the `jdbc` or `sqlite` feature. Mirror of PySpark's spark.read.jdbc(url, table, properties).
386    #[cfg(any(
387        feature = "jdbc",
388        feature = "jdbc_mysql",
389        feature = "jdbc_mariadb",
390        feature = "jdbc_mssql",
391        feature = "jdbc_oracle",
392        feature = "jdbc_db2",
393        feature = "sqlite"
394    ))]
395    pub fn jdbc(
396        &self,
397        url: &str,
398        table: &str,
399        properties: &HashMap<String, String>,
400    ) -> Result<DataFrame, PolarsError> {
401        self.0
402            .jdbc_with_properties(url, table, properties)
403            .map(DataFrame)
404            .map_err(|e| PolarsError::ComputeError(e.to_string().into()))
405    }
406}