Skip to main content

robin_sparkless/
session.rs

1use crate::dataframe::DataFrame;
2use polars::prelude::{
3    DataFrame as PlDataFrame, DataType, NamedFrom, PolarsError, Series, TimeUnit,
4};
5use serde_json::Value as JsonValue;
6use std::collections::HashMap;
7use std::path::Path;
8use std::sync::{Arc, Mutex};
9
/// Builder for [`SparkSession`] (PySpark: `SparkSession.builder`).
///
/// Collects an optional application name, an optional master URL and arbitrary
/// string key/value configuration, then produces a session via
/// [`SparkSessionBuilder::get_or_create`].
#[derive(Clone, Debug)]
pub struct SparkSessionBuilder {
    /// Application name (PySpark: `appName`).
    app_name: Option<String>,
    /// Master URL such as `local[*]`; stored on the session as given.
    master: Option<String>,
    /// Configuration entries, e.g. `spark.sql.caseSensitive`.
    config: HashMap<String, String>,
}
17
18impl Default for SparkSessionBuilder {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl SparkSessionBuilder {
25    pub fn new() -> Self {
26        SparkSessionBuilder {
27            app_name: None,
28            master: None,
29            config: HashMap::new(),
30        }
31    }
32
33    pub fn app_name(mut self, name: impl Into<String>) -> Self {
34        self.app_name = Some(name.into());
35        self
36    }
37
38    pub fn master(mut self, master: impl Into<String>) -> Self {
39        self.master = Some(master.into());
40        self
41    }
42
43    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
44        self.config.insert(key.into(), value.into());
45        self
46    }
47
48    pub fn get_or_create(self) -> SparkSession {
49        SparkSession::new(self.app_name, self.master, self.config)
50    }
51}
52
53/// Catalog of temporary view names to DataFrames (session-scoped). Uses Arc<Mutex<>> for Send+Sync (Python bindings).
54pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;
55
56/// Main entry point for creating DataFrames and executing queries
57/// Similar to PySpark's SparkSession but using Polars as the backend
58#[derive(Clone)]
59pub struct SparkSession {
60    app_name: Option<String>,
61    master: Option<String>,
62    config: HashMap<String, String>,
63    /// Temporary views: name -> DataFrame. Session-scoped; cleared when session is dropped.
64    pub(crate) catalog: TempViewCatalog,
65}
66
67impl SparkSession {
68    pub fn new(
69        app_name: Option<String>,
70        master: Option<String>,
71        config: HashMap<String, String>,
72    ) -> Self {
73        SparkSession {
74            app_name,
75            master,
76            config,
77            catalog: Arc::new(Mutex::new(HashMap::new())),
78        }
79    }
80
81    /// Register a DataFrame as a temporary view (PySpark: createOrReplaceTempView).
82    /// The view is session-scoped and is dropped when the session is dropped.
83    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
84        let _ = self
85            .catalog
86            .lock()
87            .map(|mut m| m.insert(name.to_string(), df));
88    }
89
90    /// Global temp view (PySpark: createGlobalTempView). Stub: uses same catalog as temp view.
91    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
92        self.create_or_replace_temp_view(name, df);
93    }
94
95    /// Global temp view (PySpark: createOrReplaceGlobalTempView). Stub: uses same catalog as temp view.
96    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
97        self.create_or_replace_temp_view(name, df);
98    }
99
100    /// Drop a temporary view by name (PySpark: catalog.dropTempView).
101    /// No error if the view does not exist.
102    pub fn drop_temp_view(&self, name: &str) {
103        let _ = self.catalog.lock().map(|mut m| m.remove(name));
104    }
105
106    /// Drop a global temporary view (PySpark: catalog.dropGlobalTempView). Stub: same catalog as temp view.
107    pub fn drop_global_temp_view(&self, name: &str) {
108        self.drop_temp_view(name);
109    }
110
111    /// Check if a temporary view exists.
112    pub fn table_exists(&self, name: &str) -> bool {
113        self.catalog
114            .lock()
115            .map(|m| m.contains_key(name))
116            .unwrap_or(false)
117    }
118
119    /// Return temporary view names in this session.
120    pub fn list_temp_view_names(&self) -> Vec<String> {
121        self.catalog
122            .lock()
123            .map(|m| m.keys().cloned().collect())
124            .unwrap_or_default()
125    }
126
127    /// Look up a temporary view by name (PySpark: table(name)).
128    /// Returns an error if the view does not exist.
129    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
130        self.catalog
131            .lock()
132            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
133            .get(name)
134            .cloned()
135            .ok_or_else(|| {
136                PolarsError::InvalidOperation(
137                    format!(
138                        "Table or view '{name}' not found. Register it with create_or_replace_temp_view."
139                    )
140                    .into(),
141                )
142            })
143    }
144
145    pub fn builder() -> SparkSessionBuilder {
146        SparkSessionBuilder::new()
147    }
148
149    /// Return a reference to the session config (for catalog/conf compatibility).
150    pub fn get_config(&self) -> &HashMap<String, String> {
151        &self.config
152    }
153
154    /// Whether column names are case-sensitive (PySpark: spark.sql.caseSensitive).
155    /// Default is false (case-insensitive matching).
156    pub fn is_case_sensitive(&self) -> bool {
157        self.config
158            .get("spark.sql.caseSensitive")
159            .map(|v| v.eq_ignore_ascii_case("true"))
160            .unwrap_or(false)
161    }
162
163    /// Create a DataFrame from a vector of tuples (i64, i64, String)
164    ///
165    /// # Example
166    /// ```
167    /// use robin_sparkless::session::SparkSession;
168    ///
169    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
170    /// let spark = SparkSession::builder().app_name("test").get_or_create();
171    /// let df = spark.create_dataframe(
172    ///     vec![
173    ///         (1, 25, "Alice".to_string()),
174    ///         (2, 30, "Bob".to_string()),
175    ///     ],
176    ///     vec!["id", "age", "name"],
177    /// )?;
178    /// #     let _ = df;
179    /// #     Ok(())
180    /// # }
181    /// ```
182    pub fn create_dataframe(
183        &self,
184        data: Vec<(i64, i64, String)>,
185        column_names: Vec<&str>,
186    ) -> Result<DataFrame, PolarsError> {
187        if column_names.len() != 3 {
188            return Err(PolarsError::ComputeError(
189                format!(
190                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
191                    column_names.len()
192                )
193                .into(),
194            ));
195        }
196
197        let mut cols: Vec<Series> = Vec::with_capacity(3);
198
199        // First column: i64
200        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
201        cols.push(Series::new(column_names[0].into(), col0));
202
203        // Second column: i64
204        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
205        cols.push(Series::new(column_names[1].into(), col1));
206
207        // Third column: String
208        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
209        cols.push(Series::new(column_names[2].into(), col2));
210
211        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
212        Ok(DataFrame::from_polars_with_options(
213            pl_df,
214            self.is_case_sensitive(),
215        ))
216    }
217
218    /// Create a DataFrame from a Polars DataFrame
219    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
220        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
221    }
222
223    /// Create a DataFrame from rows and a schema (arbitrary column count and types).
224    ///
225    /// `rows`: each inner vec is one row; length must match schema length. Values are JSON-like (i64, f64, string, bool, null).
226    /// `schema`: list of (column_name, dtype_string), e.g. `[("id", "bigint"), ("name", "string")]`.
227    /// Supported dtype strings: bigint, int, long, double, float, string, str, varchar, boolean, bool, date, timestamp, datetime.
228    pub fn create_dataframe_from_rows(
229        &self,
230        rows: Vec<Vec<JsonValue>>,
231        schema: Vec<(String, String)>,
232    ) -> Result<DataFrame, PolarsError> {
233        if schema.is_empty() {
234            return Err(PolarsError::InvalidOperation(
235                "create_dataframe_from_rows: schema must not be empty".into(),
236            ));
237        }
238        use chrono::{NaiveDate, NaiveDateTime};
239
240        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());
241
242        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
243            let type_lower = type_str.trim().to_lowercase();
244            let s = match type_lower.as_str() {
245                "int" | "bigint" | "long" => {
246                    let vals: Vec<Option<i64>> = rows
247                        .iter()
248                        .map(|row| {
249                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
250                            match v {
251                                JsonValue::Number(n) => n.as_i64(),
252                                JsonValue::Null => None,
253                                _ => None,
254                            }
255                        })
256                        .collect();
257                    Series::new(name.as_str().into(), vals)
258                }
259                "double" | "float" | "double_precision" => {
260                    let vals: Vec<Option<f64>> = rows
261                        .iter()
262                        .map(|row| {
263                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
264                            match v {
265                                JsonValue::Number(n) => n.as_f64(),
266                                JsonValue::Null => None,
267                                _ => None,
268                            }
269                        })
270                        .collect();
271                    Series::new(name.as_str().into(), vals)
272                }
273                "string" | "str" | "varchar" => {
274                    let vals: Vec<Option<String>> = rows
275                        .iter()
276                        .map(|row| {
277                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
278                            match v {
279                                JsonValue::String(s) => Some(s),
280                                JsonValue::Null => None,
281                                other => Some(other.to_string()),
282                            }
283                        })
284                        .collect();
285                    Series::new(name.as_str().into(), vals)
286                }
287                "boolean" | "bool" => {
288                    let vals: Vec<Option<bool>> = rows
289                        .iter()
290                        .map(|row| {
291                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
292                            match v {
293                                JsonValue::Bool(b) => Some(b),
294                                JsonValue::Null => None,
295                                _ => None,
296                            }
297                        })
298                        .collect();
299                    Series::new(name.as_str().into(), vals)
300                }
301                "date" => {
302                    let epoch = crate::date_utils::epoch_naive_date();
303                    let vals: Vec<Option<i32>> = rows
304                        .iter()
305                        .map(|row| {
306                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
307                            match v {
308                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
309                                    .ok()
310                                    .map(|d| (d - epoch).num_days() as i32),
311                                JsonValue::Null => None,
312                                _ => None,
313                            }
314                        })
315                        .collect();
316                    let series = Series::new(name.as_str().into(), vals);
317                    series
318                        .cast(&DataType::Date)
319                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
320                }
321                "timestamp" | "datetime" | "timestamp_ntz" => {
322                    let vals: Vec<Option<i64>> =
323                        rows.iter()
324                            .map(|row| {
325                                let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
326                                match v {
327                                    JsonValue::String(s) => {
328                                        let parsed = NaiveDateTime::parse_from_str(
329                                            &s,
330                                            "%Y-%m-%dT%H:%M:%S%.f",
331                                        )
332                                        .or_else(|_| {
333                                            NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S")
334                                        })
335                                        .or_else(|_| {
336                                            NaiveDate::parse_from_str(&s, "%Y-%m-%d")
337                                                .map(|d| d.and_hms_opt(0, 0, 0).unwrap())
338                                        });
339                                        parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
340                                    }
341                                    JsonValue::Number(n) => n.as_i64(),
342                                    JsonValue::Null => None,
343                                    _ => None,
344                                }
345                            })
346                            .collect();
347                    let series = Series::new(name.as_str().into(), vals);
348                    series
349                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
350                        .map_err(|e| {
351                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
352                        })?
353                }
354                _ => {
355                    return Err(PolarsError::ComputeError(
356                        format!(
357                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
358                        )
359                        .into(),
360                    ));
361                }
362            };
363            cols.push(s);
364        }
365
366        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
367        Ok(DataFrame::from_polars_with_options(
368            pl_df,
369            self.is_case_sensitive(),
370        ))
371    }
372
373    /// Create a DataFrame with a single column `id` (bigint) containing values from start to end (exclusive) with step.
374    /// PySpark: spark.range(end) or spark.range(start, end, step).
375    ///
376    /// - `range(end)` → 0 to end-1, step 1
377    /// - `range(start, end)` → start to end-1, step 1
378    /// - `range(start, end, step)` → start, start+step, ... up to but not including end
379    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
380        if step == 0 {
381            return Err(PolarsError::InvalidOperation(
382                "range: step must not be 0".into(),
383            ));
384        }
385        let mut vals: Vec<i64> = Vec::new();
386        let mut v = start;
387        if step > 0 {
388            while v < end {
389                vals.push(v);
390                v = v.saturating_add(step);
391            }
392        } else {
393            while v > end {
394                vals.push(v);
395                v = v.saturating_add(step);
396            }
397        }
398        let col = Series::new("id".into(), vals);
399        let pl_df = PlDataFrame::new(vec![col.into()])?;
400        Ok(DataFrame::from_polars_with_options(
401            pl_df,
402            self.is_case_sensitive(),
403        ))
404    }
405
406    /// Read a CSV file.
407    ///
408    /// Uses Polars' CSV reader with default options:
409    /// - Header row is inferred (default: true)
410    /// - Schema is inferred from first 100 rows
411    ///
412    /// # Example
413    /// ```
414    /// use robin_sparkless::SparkSession;
415    ///
416    /// let spark = SparkSession::builder().app_name("test").get_or_create();
417    /// let df_result = spark.read_csv("data.csv");
418    /// // Handle the Result as appropriate in your application
419    /// ```
420    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
421        use polars::prelude::*;
422        let path = path.as_ref();
423        let path_display = path.display();
424        // Use LazyCsvReader - call finish() to get LazyFrame, then collect
425        let lf = LazyCsvReader::new(path)
426            .with_has_header(true)
427            .with_infer_schema_length(Some(100))
428            .finish()
429            .map_err(|e| {
430                PolarsError::ComputeError(
431                    format!(
432                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
433                    )
434                    .into(),
435                )
436            })?;
437        let pl_df = lf.collect().map_err(|e| {
438            PolarsError::ComputeError(
439                format!("read_csv({path_display}): collect failed: {e}").into(),
440            )
441        })?;
442        Ok(crate::dataframe::DataFrame::from_polars_with_options(
443            pl_df,
444            self.is_case_sensitive(),
445        ))
446    }
447
448    /// Read a Parquet file.
449    ///
450    /// Uses Polars' Parquet reader. Parquet files have embedded schema, so
451    /// schema inference is automatic.
452    ///
453    /// # Example
454    /// ```
455    /// use robin_sparkless::SparkSession;
456    ///
457    /// let spark = SparkSession::builder().app_name("test").get_or_create();
458    /// let df_result = spark.read_parquet("data.parquet");
459    /// // Handle the Result as appropriate in your application
460    /// ```
461    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
462        use polars::prelude::*;
463        let path = path.as_ref();
464        // Use LazyFrame::scan_parquet
465        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
466        let pl_df = lf.collect()?;
467        Ok(crate::dataframe::DataFrame::from_polars_with_options(
468            pl_df,
469            self.is_case_sensitive(),
470        ))
471    }
472
473    /// Read a JSON file (JSONL format - one JSON object per line).
474    ///
475    /// Uses Polars' JSONL reader with default options:
476    /// - Schema is inferred from first 100 rows
477    ///
478    /// # Example
479    /// ```
480    /// use robin_sparkless::SparkSession;
481    ///
482    /// let spark = SparkSession::builder().app_name("test").get_or_create();
483    /// let df_result = spark.read_json("data.json");
484    /// // Handle the Result as appropriate in your application
485    /// ```
486    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
487        use polars::prelude::*;
488        use std::num::NonZeroUsize;
489        let path = path.as_ref();
490        // Use LazyJsonLineReader - call finish() to get LazyFrame, then collect
491        let lf = LazyJsonLineReader::new(path)
492            .with_infer_schema_length(NonZeroUsize::new(100))
493            .finish()?;
494        let pl_df = lf.collect()?;
495        Ok(crate::dataframe::DataFrame::from_polars_with_options(
496            pl_df,
497            self.is_case_sensitive(),
498        ))
499    }
500
501    /// Execute a SQL query (SELECT only). Tables must be registered with `create_or_replace_temp_view`.
502    /// Requires the `sql` feature. Supports: SELECT (columns or *), FROM (single table or JOIN),
503    /// WHERE (basic predicates), GROUP BY + aggregates, ORDER BY, LIMIT.
504    #[cfg(feature = "sql")]
505    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
506        crate::sql::execute_sql(self, query)
507    }
508
509    /// Execute a SQL query (stub when `sql` feature is disabled).
510    #[cfg(not(feature = "sql"))]
511    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
512        Err(PolarsError::InvalidOperation(
513            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
514        ))
515    }
516
517    /// Read a Delta table at the given path (latest version).
518    /// Requires the `delta` feature. Path can be local (e.g. `/tmp/table`) or `file:///...`.
519    #[cfg(feature = "delta")]
520    pub fn read_delta(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
521        crate::delta::read_delta(path, self.is_case_sensitive())
522    }
523
524    /// Read a Delta table at the given path, optionally at a specific version (time travel).
525    #[cfg(feature = "delta")]
526    pub fn read_delta_with_version(
527        &self,
528        path: impl AsRef<Path>,
529        version: Option<i64>,
530    ) -> Result<DataFrame, PolarsError> {
531        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
532    }
533
534    /// Stub when `delta` feature is disabled.
535    #[cfg(not(feature = "delta"))]
536    pub fn read_delta(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
537        Err(PolarsError::InvalidOperation(
538            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
539        ))
540    }
541
542    #[cfg(not(feature = "delta"))]
543    pub fn read_delta_with_version(
544        &self,
545        _path: impl AsRef<Path>,
546        _version: Option<i64>,
547    ) -> Result<DataFrame, PolarsError> {
548        Err(PolarsError::InvalidOperation(
549            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
550        ))
551    }
552
553    /// Stop the session (cleanup resources)
554    pub fn stop(&self) {
555        // Cleanup if needed
556    }
557}
558
559/// DataFrameReader for reading various file formats
560/// Similar to PySpark's DataFrameReader with option/options/format/load/table
561pub struct DataFrameReader {
562    session: SparkSession,
563    options: HashMap<String, String>,
564    format: Option<String>,
565}
566
567impl DataFrameReader {
568    pub fn new(session: SparkSession) -> Self {
569        DataFrameReader {
570            session,
571            options: HashMap::new(),
572            format: None,
573        }
574    }
575
576    /// Add a single option (PySpark: option(key, value)). Returns self for chaining.
577    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
578        self.options.insert(key.into(), value.into());
579        self
580    }
581
582    /// Add multiple options (PySpark: options(**kwargs)). Returns self for chaining.
583    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
584        for (k, v) in opts {
585            self.options.insert(k, v);
586        }
587        self
588    }
589
590    /// Set the format for load() (PySpark: format("parquet") etc).
591    pub fn format(mut self, fmt: impl Into<String>) -> Self {
592        self.format = Some(fmt.into());
593        self
594    }
595
596    /// Set the schema (PySpark: schema(schema)). Stub: stores but does not apply yet.
597    pub fn schema(self, _schema: impl Into<String>) -> Self {
598        self
599    }
600
601    /// Load data from path using format (or infer from extension) and options.
602    pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
603        let path = path.as_ref();
604        let fmt = self.format.clone().or_else(|| {
605            path.extension()
606                .and_then(|e| e.to_str())
607                .map(|s| s.to_lowercase())
608        });
609        match fmt.as_deref() {
610            Some("parquet") => self.parquet(path),
611            Some("csv") => self.csv(path),
612            Some("json") | Some("jsonl") => self.json(path),
613            #[cfg(feature = "delta")]
614            Some("delta") => self.session.read_delta(path),
615            _ => Err(PolarsError::ComputeError(
616                format!(
617                    "load: could not infer format for path '{}'. Use format('parquet'|'csv'|'json') before load.",
618                    path.display()
619                )
620                .into(),
621            )),
622        }
623    }
624
625    /// Return the named table/view (PySpark: table(name)).
626    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
627        self.session.table(name)
628    }
629
630    fn apply_csv_options(
631        &self,
632        reader: polars::prelude::LazyCsvReader,
633    ) -> polars::prelude::LazyCsvReader {
634        use polars::prelude::NullValues;
635        let mut r = reader;
636        if let Some(v) = self.options.get("header") {
637            let has_header = v.eq_ignore_ascii_case("true") || v == "1";
638            r = r.with_has_header(has_header);
639        }
640        if let Some(v) = self.options.get("inferSchema") {
641            if v.eq_ignore_ascii_case("true") || v == "1" {
642                let n = self
643                    .options
644                    .get("inferSchemaLength")
645                    .and_then(|s| s.parse::<usize>().ok())
646                    .unwrap_or(100);
647                r = r.with_infer_schema_length(Some(n));
648            }
649        } else if let Some(v) = self.options.get("inferSchemaLength") {
650            if let Ok(n) = v.parse::<usize>() {
651                r = r.with_infer_schema_length(Some(n));
652            }
653        }
654        if let Some(sep) = self.options.get("sep") {
655            if let Some(b) = sep.bytes().next() {
656                r = r.with_separator(b);
657            }
658        }
659        if let Some(null_val) = self.options.get("nullValue") {
660            r = r.with_null_values(Some(NullValues::AllColumnsSingle(null_val.clone().into())));
661        }
662        r
663    }
664
665    fn apply_json_options(
666        &self,
667        reader: polars::prelude::LazyJsonLineReader,
668    ) -> polars::prelude::LazyJsonLineReader {
669        use std::num::NonZeroUsize;
670        let mut r = reader;
671        if let Some(v) = self.options.get("inferSchemaLength") {
672            if let Ok(n) = v.parse::<usize>() {
673                r = r.with_infer_schema_length(NonZeroUsize::new(n));
674            }
675        }
676        r
677    }
678
679    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
680        use polars::prelude::*;
681        let path = path.as_ref();
682        let path_display = path.display();
683        let reader = LazyCsvReader::new(path);
684        let reader = if self.options.is_empty() {
685            reader
686                .with_has_header(true)
687                .with_infer_schema_length(Some(100))
688        } else {
689            self.apply_csv_options(
690                reader
691                    .with_has_header(true)
692                    .with_infer_schema_length(Some(100)),
693            )
694        };
695        let lf = reader.finish().map_err(|e| {
696            PolarsError::ComputeError(format!("read csv({path_display}): {e}").into())
697        })?;
698        let pl_df = lf.collect().map_err(|e| {
699            PolarsError::ComputeError(
700                format!("read csv({path_display}): collect failed: {e}").into(),
701            )
702        })?;
703        Ok(crate::dataframe::DataFrame::from_polars_with_options(
704            pl_df,
705            self.session.is_case_sensitive(),
706        ))
707    }
708
709    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
710        use polars::prelude::*;
711        let path = path.as_ref();
712        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
713        let pl_df = lf.collect()?;
714        Ok(crate::dataframe::DataFrame::from_polars_with_options(
715            pl_df,
716            self.session.is_case_sensitive(),
717        ))
718    }
719
720    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
721        use polars::prelude::*;
722        use std::num::NonZeroUsize;
723        let path = path.as_ref();
724        let reader = LazyJsonLineReader::new(path);
725        let reader = if self.options.is_empty() {
726            reader.with_infer_schema_length(NonZeroUsize::new(100))
727        } else {
728            self.apply_json_options(reader.with_infer_schema_length(NonZeroUsize::new(100)))
729        };
730        let lf = reader.finish()?;
731        let pl_df = lf.collect()?;
732        Ok(crate::dataframe::DataFrame::from_polars_with_options(
733            pl_df,
734            self.session.is_case_sensitive(),
735        ))
736    }
737
738    #[cfg(feature = "delta")]
739    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
740        self.session.read_delta(path)
741    }
742}
743
744impl SparkSession {
745    /// Get a DataFrameReader for reading files
746    pub fn read(&self) -> DataFrameReader {
747        DataFrameReader::new(SparkSession {
748            app_name: self.app_name.clone(),
749            master: self.master.clone(),
750            config: self.config.clone(),
751            catalog: self.catalog.clone(),
752        })
753    }
754}
755
756impl Default for SparkSession {
757    fn default() -> Self {
758        Self::builder().get_or_create()
759    }
760}
761
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the session used by most tests below (app name `"test"`).
    fn new_session() -> SparkSession {
        SparkSession::builder().app_name("test").get_or_create()
    }

    #[test]
    fn test_spark_session_builder_basic() {
        let spark = SparkSession::builder().app_name("test_app").get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_master() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .master("local[*]")
            .get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
        assert_eq!(spark.master, Some("local[*]".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_config() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .config("spark.executor.memory", "4g")
            .config("spark.driver.memory", "2g")
            .get_or_create();

        assert_eq!(
            spark.config.get("spark.executor.memory"),
            Some(&"4g".to_string())
        );
        assert_eq!(
            spark.config.get("spark.driver.memory"),
            Some(&"2g".to_string())
        );
    }

    #[test]
    fn test_spark_session_default() {
        let spark = SparkSession::default();
        assert!(spark.app_name.is_none());
        assert!(spark.master.is_none());
        assert!(spark.config.is_empty());
    }

    #[test]
    fn test_create_dataframe_success() {
        let spark = new_session();
        let data = vec![
            (1i64, 25i64, "Alice".to_string()),
            (2i64, 30i64, "Bob".to_string()),
        ];

        // `expect` (instead of `assert!(is_ok())`) prints the actual error on failure.
        let df = spark
            .create_dataframe(data, vec!["id", "age", "name"])
            .expect("create_dataframe should accept one name per tuple field");
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"age".to_string()));
        assert!(columns.contains(&"name".to_string()));
    }

    #[test]
    fn test_create_dataframe_wrong_column_count() {
        let spark = new_session();
        let data = vec![(1i64, 25i64, "Alice".to_string())];

        // Too few columns
        let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
        assert!(result.is_err());

        // Too many columns
        let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_dataframe_empty() {
        let spark = new_session();
        let data: Vec<(i64, i64, String)> = vec![];

        let df = spark
            .create_dataframe(data, vec!["id", "age", "name"])
            .expect("create_dataframe should accept an empty row set");
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_create_dataframe_from_polars() {
        use polars::prelude::df;

        let spark = new_session();
        let polars_df = df!(
            "x" => &[1, 2, 3],
            "y" => &[4, 5, 6]
        )
        .unwrap();

        let df = spark.create_dataframe_from_polars(polars_df);

        assert_eq!(df.count().unwrap(), 3);
        let columns = df.columns().unwrap();
        assert!(columns.contains(&"x".to_string()));
        assert!(columns.contains(&"y".to_string()));
    }

    #[test]
    fn test_read_csv_file_not_found() {
        let spark = new_session();

        let result = spark.read_csv("nonexistent_file.csv");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_parquet_file_not_found() {
        let spark = new_session();

        let result = spark.read_parquet("nonexistent_file.parquet");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_json_file_not_found() {
        let spark = new_session();

        let result = spark.read_json("nonexistent_file.json");

        assert!(result.is_err());
    }

    #[test]
    fn test_sql_returns_error_without_feature_or_unknown_table() {
        let spark = new_session();

        let err = spark
            .sql("SELECT * FROM table")
            .err()
            .expect("querying an unregistered table should fail");

        match err {
            PolarsError::InvalidOperation(msg) => {
                let s = msg.to_string();
                // Without sql feature: "SQL queries require the 'sql' feature"
                // With sql feature but no table: "Table or view 'table' not found" or parse error
                assert!(
                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
                    "unexpected message: {s}"
                );
            }
            // Report the variant we actually got so a failure is diagnosable.
            other => panic!("Expected InvalidOperation error, got {other:?}"),
        }
    }

    #[test]
    fn test_spark_session_stop() {
        let spark = new_session();

        // stop() should complete without error
        spark.stop();
    }

    #[test]
    fn test_dataframe_reader_api() {
        let spark = new_session();
        let reader = spark.read();

        // All readers should return errors for non-existent files
        assert!(reader.csv("nonexistent.csv").is_err());
        assert!(reader.parquet("nonexistent.parquet").is_err());
        assert!(reader.json("nonexistent.json").is_err());
    }

    #[test]
    fn test_read_csv_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = new_session();

        // Create a temporary CSV file
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name,age").unwrap();
        writeln!(temp_file, "1,Alice,25").unwrap();
        writeln!(temp_file, "2,Bob,30").unwrap();
        temp_file.flush().unwrap();

        let df = spark
            .read_csv(temp_file.path())
            .expect("reading a well-formed CSV file should succeed");
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"name".to_string()));
        assert!(columns.contains(&"age".to_string()));
    }

    #[test]
    fn test_read_json_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = new_session();

        // Create a temporary JSONL file
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
        temp_file.flush().unwrap();

        let df = spark
            .read_json(temp_file.path())
            .expect("reading a well-formed JSONL file should succeed");
        assert_eq!(df.count().unwrap(), 2);
    }

    #[test]
    fn test_read_csv_empty_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = new_session();

        // Create a CSV file with a header row and no data rows
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name").unwrap();
        temp_file.flush().unwrap();

        let df = spark
            .read_csv(temp_file.path())
            .expect("reading a header-only CSV file should succeed");
        assert_eq!(df.count().unwrap(), 0);
    }
}