1use crate::EngineError;
4use robin_sparkless_core::SparklessConfig;
5use robin_sparkless_polars::{
6 DataFrameReader as PolarsDataFrameReader, PlDataFrame, PolarsError,
7 SparkSession as PolarsSparkSession, SparkSessionBuilder as PolarsSparkSessionBuilder,
8};
9use std::collections::HashMap;
10use std::path::Path;
11
12use crate::dataframe::DataFrame;
13
/// Public session handle exposed by this crate.
///
/// A thin newtype over the Polars-backed `robin_sparkless_polars::SparkSession`
/// (imported here as `PolarsSparkSession`). The wrapped value is `pub(crate)`,
/// so other modules in this crate can reach the inner session directly.
/// `Clone` clones the inner session; whether that shares state or copies it is
/// decided by the inner type and is not visible here.
#[derive(Clone)]
pub struct SparkSession(pub(crate) PolarsSparkSession);
17
/// Builder for [`SparkSession`], wrapping the Polars-backed session builder.
///
/// The inner builder is `pub(crate)` so crate-internal code can construct
/// this wrapper directly around an existing Polars builder.
pub struct SparkSessionBuilder(pub(crate) PolarsSparkSessionBuilder);
20
/// Reader for loading DataFrames, wrapping the Polars-backed `DataFrameReader`.
///
/// Unlike the session wrappers, the inner reader is private to this module,
/// so instances can only be built from code in this file.
pub struct DataFrameReader(PolarsDataFrameReader);
23
24impl SparkSessionBuilder {
25 pub fn new() -> Self {
26 SparkSessionBuilder(PolarsSparkSessionBuilder::new())
27 }
28
29 pub fn app_name(self, name: impl Into<String>) -> Self {
30 SparkSessionBuilder(self.0.app_name(name))
31 }
32
33 pub fn master(self, master: impl Into<String>) -> Self {
34 SparkSessionBuilder(self.0.master(master))
35 }
36
37 pub fn config(self, key: impl Into<String>, value: impl Into<String>) -> Self {
38 SparkSessionBuilder(self.0.config(key, value))
39 }
40
41 pub fn get_or_create(self) -> SparkSession {
42 SparkSession(self.0.get_or_create())
43 }
44
45 pub fn with_config(self, config: &SparklessConfig) -> Self {
46 SparkSessionBuilder(self.0.with_config(config))
47 }
48}
49
50impl Default for SparkSessionBuilder {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56impl SparkSession {
57 pub fn builder() -> SparkSessionBuilder {
58 SparkSessionBuilder(PolarsSparkSession::builder())
59 }
60
61 pub fn from_config(config: &SparklessConfig) -> SparkSession {
62 SparkSession(PolarsSparkSession::from_config(config))
63 }
64
65 pub fn read(&self) -> DataFrameReader {
66 DataFrameReader(PolarsDataFrameReader::new(self.0.clone()))
67 }
68
69 pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
70 self.0.create_or_replace_temp_view(name, df.0)
71 }
72
73 pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
74 self.0.create_global_temp_view(name, df.0)
75 }
76
77 pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
78 self.0.create_or_replace_global_temp_view(name, df.0)
79 }
80
81 pub fn drop_temp_view(&self, name: &str) {
82 self.0.drop_temp_view(name)
83 }
84
85 pub fn drop_global_temp_view(&self, name: &str) -> bool {
86 self.0.drop_global_temp_view(name)
87 }
88
89 pub fn register_table(&self, name: &str, df: DataFrame) {
90 self.0.register_table(name, df.0)
91 }
92
93 pub fn register_database(&self, name: &str) {
94 self.0.register_database(name)
95 }
96
97 pub fn list_database_names(&self) -> Vec<String> {
98 self.0.list_database_names()
99 }
100
101 pub fn database_exists(&self, name: &str) -> bool {
102 self.0.database_exists(name)
103 }
104
105 pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
106 self.0.get_saved_table(name).map(DataFrame)
107 }
108
109 pub fn saved_table_exists(&self, name: &str) -> bool {
110 self.0.saved_table_exists(name)
111 }
112
113 pub fn table_exists(&self, name: &str) -> bool {
114 self.0.table_exists(name)
115 }
116
117 pub fn list_global_temp_view_names(&self) -> Vec<String> {
118 self.0.list_global_temp_view_names()
119 }
120
121 pub fn list_temp_view_names(&self) -> Vec<String> {
122 self.0.list_temp_view_names()
123 }
124
125 pub fn list_table_names(&self) -> Vec<String> {
126 self.0.list_table_names()
127 }
128
129 pub fn drop_table(&self, name: &str) -> bool {
130 self.0.drop_table(name)
131 }
132
133 pub fn drop_database(&self, name: &str) -> bool {
134 self.0.drop_database(name)
135 }
136
137 pub fn warehouse_dir(&self) -> Option<&str> {
138 self.0.warehouse_dir()
139 }
140
141 pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
142 self.0.table(name).map(DataFrame)
143 }
144
145 pub fn get_config(&self) -> &HashMap<String, String> {
146 self.0.get_config()
147 }
148
149 pub fn is_case_sensitive(&self) -> bool {
150 self.0.is_case_sensitive()
151 }
152
153 pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
154 where
155 F: Fn(
156 &[robin_sparkless_polars::Series],
157 ) -> Result<robin_sparkless_polars::Series, PolarsError>
158 + Send
159 + Sync
160 + 'static,
161 {
162 self.0.register_udf(name, f)
163 }
164
165 pub fn create_dataframe(
166 &self,
167 data: Vec<(i64, i64, String)>,
168 column_names: Vec<&str>,
169 ) -> Result<DataFrame, PolarsError> {
170 self.0.create_dataframe(data, column_names).map(DataFrame)
171 }
172
173 pub fn create_dataframe_engine(
174 &self,
175 data: Vec<(i64, i64, String)>,
176 column_names: Vec<&str>,
177 ) -> Result<DataFrame, EngineError> {
178 self.0
179 .create_dataframe_engine(data, column_names)
180 .map(DataFrame)
181 }
182
183 pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
184 DataFrame(self.0.create_dataframe_from_polars(df))
185 }
186
187 pub fn create_dataframe_from_rows(
188 &self,
189 rows: Vec<Vec<serde_json::Value>>,
190 schema: Vec<(String, String)>,
191 ) -> Result<DataFrame, PolarsError> {
192 self.0
193 .create_dataframe_from_rows(rows, schema)
194 .map(DataFrame)
195 }
196
197 pub fn create_dataframe_from_rows_engine(
198 &self,
199 rows: Vec<Vec<serde_json::Value>>,
200 schema: Vec<(String, String)>,
201 ) -> Result<DataFrame, EngineError> {
202 self.0
203 .create_dataframe_from_rows_engine(rows, schema)
204 .map(DataFrame)
205 }
206
207 pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
208 self.0.range(start, end, step).map(DataFrame)
209 }
210
211 pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
212 self.0.read_csv(path).map(DataFrame)
213 }
214
215 pub fn read_csv_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
216 self.0.read_csv_engine(path).map(DataFrame)
217 }
218
219 pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
220 self.0.read_parquet(path).map(DataFrame)
221 }
222
223 pub fn read_parquet_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
224 self.0.read_parquet_engine(path).map(DataFrame)
225 }
226
227 pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
228 self.0.read_json(path).map(DataFrame)
229 }
230
231 pub fn read_json_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
232 self.0.read_json_engine(path).map(DataFrame)
233 }
234
235 pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
236 self.0.sql(query).map(DataFrame)
237 }
238
239 pub fn table_engine(&self, name: &str) -> Result<DataFrame, EngineError> {
240 self.0.table_engine(name).map(DataFrame)
241 }
242
243 #[cfg(feature = "delta")]
244 pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
245 self.0.read_delta_path(path).map(DataFrame)
246 }
247
248 pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
249 self.0.read_delta_from_path(path).map(DataFrame)
250 }
251
252 #[cfg(feature = "delta")]
253 pub fn read_delta_path_with_version(
254 &self,
255 path: impl AsRef<Path>,
256 version: Option<i64>,
257 ) -> Result<DataFrame, PolarsError> {
258 self.0
259 .read_delta_path_with_version(path, version)
260 .map(DataFrame)
261 }
262
263 #[cfg(feature = "delta")]
264 pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
265 self.0.read_delta(name_or_path).map(DataFrame)
266 }
267
268 #[cfg(feature = "delta")]
269 pub fn read_delta_with_version(
270 &self,
271 name_or_path: &str,
272 version: Option<i64>,
273 ) -> Result<DataFrame, PolarsError> {
274 self.0
275 .read_delta_with_version(name_or_path, version)
276 .map(DataFrame)
277 }
278
279 pub fn stop(&self) {
280 self.0.stop()
281 }
282}
283
284impl DataFrameReader {
285 pub fn option(self, key: impl Into<String>, value: impl Into<String>) -> Self {
286 DataFrameReader(self.0.option(key, value))
287 }
288
289 pub fn options(self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
290 DataFrameReader(self.0.options(opts))
291 }
292
293 pub fn format(self, fmt: impl Into<String>) -> Self {
294 DataFrameReader(self.0.format(fmt))
295 }
296
297 pub fn schema(self, schema: impl Into<String>) -> Self {
298 DataFrameReader(self.0.schema(schema))
299 }
300
301 pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
302 self.0.load(path).map(DataFrame)
303 }
304
305 pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
306 self.0.table(name).map(DataFrame)
307 }
308
309 pub fn csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
310 self.0.csv(path).map(DataFrame)
311 }
312
313 pub fn parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
314 self.0.parquet(path).map(DataFrame)
315 }
316
317 pub fn json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
318 self.0.json(path).map(DataFrame)
319 }
320
321 #[cfg(feature = "delta")]
322 pub fn delta(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
323 self.0.delta(path).map(DataFrame)
324 }
325}