Skip to main content

robin_sparkless/
session.rs

1use crate::dataframe::DataFrame;
2use crate::udf_registry::UdfRegistry;
3use polars::chunked_array::builder::get_list_builder;
4use polars::chunked_array::StructChunked;
5use polars::prelude::{
6    DataFrame as PlDataFrame, DataType, IntoSeries, NamedFrom, PlSmallStr, PolarsError, Series,
7    TimeUnit,
8};
9use serde_json::Value as JsonValue;
10use std::cell::RefCell;
11
12/// Parse "array<element_type>" to get inner type string. Returns None if not array<>.
13fn parse_array_element_type(type_str: &str) -> Option<String> {
14    let s = type_str.trim();
15    if !s.to_lowercase().starts_with("array<") || !s.ends_with('>') {
16        return None;
17    }
18    Some(s[6..s.len() - 1].trim().to_string())
19}
20
21/// Parse "struct<field:type,...>" to get field (name, type) pairs. Simple parsing, no nested structs.
22fn parse_struct_fields(type_str: &str) -> Option<Vec<(String, String)>> {
23    let s = type_str.trim();
24    if !s.to_lowercase().starts_with("struct<") || !s.ends_with('>') {
25        return None;
26    }
27    let inner = s[7..s.len() - 1].trim();
28    if inner.is_empty() {
29        return Some(Vec::new());
30    }
31    let mut out = Vec::new();
32    for part in inner.split(',') {
33        let part = part.trim();
34        if let Some(idx) = part.find(':') {
35            let name = part[..idx].trim().to_string();
36            let typ = part[idx + 1..].trim().to_string();
37            out.push((name, typ));
38        }
39    }
40    Some(out)
41}
42
43/// Map schema type string to Polars DataType (primitives only for nested use).
44fn json_type_str_to_polars(type_str: &str) -> Option<DataType> {
45    match type_str.trim().to_lowercase().as_str() {
46        "int" | "bigint" | "long" => Some(DataType::Int64),
47        "double" | "float" | "double_precision" => Some(DataType::Float64),
48        "string" | "str" | "varchar" => Some(DataType::String),
49        "boolean" | "bool" => Some(DataType::Boolean),
50        _ => None,
51    }
52}
53
54/// Build a length-N Series from Vec<Option<JsonValue>> for a given type (recursive for struct/array).
55fn json_values_to_series(
56    values: &[Option<JsonValue>],
57    type_str: &str,
58    name: &str,
59) -> Result<Series, PolarsError> {
60    use chrono::{NaiveDate, NaiveDateTime};
61    let epoch = crate::date_utils::epoch_naive_date();
62    let type_lower = type_str.trim().to_lowercase();
63
64    if let Some(elem_type) = parse_array_element_type(&type_lower) {
65        let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
66            PolarsError::ComputeError(
67                format!("array element type '{elem_type}' not supported").into(),
68            )
69        })?;
70        let mut builder = get_list_builder(&inner_dtype, 64, values.len(), name.into());
71        for v in values.iter() {
72            if v.as_ref().is_none_or(|x| matches!(x, JsonValue::Null)) {
73                builder.append_null();
74            } else if let Some(arr) = v.as_ref().and_then(|x| x.as_array()) {
75                let elem_series: Vec<Series> = arr
76                    .iter()
77                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
78                    .collect::<Result<Vec<_>, _>>()?;
79                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
80                let s = Series::from_any_values_and_dtype(
81                    PlSmallStr::EMPTY,
82                    &vals,
83                    &inner_dtype,
84                    false,
85                )
86                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
87                builder.append_series(&s)?;
88            } else {
89                return Err(PolarsError::ComputeError(
90                    "array column value must be null or array".into(),
91                ));
92            }
93        }
94        return Ok(builder.finish().into_series());
95    }
96
97    if let Some(fields) = parse_struct_fields(&type_lower) {
98        let mut field_series_vec: Vec<Vec<Option<JsonValue>>> = (0..fields.len())
99            .map(|_| Vec::with_capacity(values.len()))
100            .collect();
101        for v in values.iter() {
102            if v.as_ref().is_none_or(|x| matches!(x, JsonValue::Null)) {
103                for fc in &mut field_series_vec {
104                    fc.push(None);
105                }
106            } else if let Some(obj) = v.as_ref().and_then(|x| x.as_object()) {
107                for (fi, (fname, _)) in fields.iter().enumerate() {
108                    field_series_vec[fi].push(obj.get(fname).cloned());
109                }
110            } else if let Some(arr) = v.as_ref().and_then(|x| x.as_array()) {
111                for (fi, _) in fields.iter().enumerate() {
112                    field_series_vec[fi].push(arr.get(fi).cloned());
113                }
114            } else {
115                return Err(PolarsError::ComputeError(
116                    "struct value must be object or array".into(),
117                ));
118            }
119        }
120        let series_per_field: Vec<Series> = fields
121            .iter()
122            .enumerate()
123            .map(|(fi, (fname, ftype))| json_values_to_series(&field_series_vec[fi], ftype, fname))
124            .collect::<Result<Vec<_>, _>>()?;
125        let field_refs: Vec<&Series> = series_per_field.iter().collect();
126        let st = StructChunked::from_series(name.into(), values.len(), field_refs.iter().copied())
127            .map_err(|e| PolarsError::ComputeError(format!("struct column: {e}").into()))?
128            .into_series();
129        return Ok(st);
130    }
131
132    match type_lower.as_str() {
133        "int" | "bigint" | "long" => {
134            let vals: Vec<Option<i64>> = values
135                .iter()
136                .map(|ov| {
137                    ov.as_ref().and_then(|v| match v {
138                        JsonValue::Number(n) => n.as_i64(),
139                        JsonValue::Null => None,
140                        _ => None,
141                    })
142                })
143                .collect();
144            Ok(Series::new(name.into(), vals))
145        }
146        "double" | "float" => {
147            let vals: Vec<Option<f64>> = values
148                .iter()
149                .map(|ov| {
150                    ov.as_ref().and_then(|v| match v {
151                        JsonValue::Number(n) => n.as_f64(),
152                        JsonValue::Null => None,
153                        _ => None,
154                    })
155                })
156                .collect();
157            Ok(Series::new(name.into(), vals))
158        }
159        "string" | "str" | "varchar" => {
160            let vals: Vec<Option<&str>> = values
161                .iter()
162                .map(|ov| {
163                    ov.as_ref().and_then(|v| match v {
164                        JsonValue::String(s) => Some(s.as_str()),
165                        JsonValue::Null => None,
166                        _ => None,
167                    })
168                })
169                .collect();
170            let owned: Vec<Option<String>> =
171                vals.into_iter().map(|o| o.map(|s| s.to_string())).collect();
172            Ok(Series::new(name.into(), owned))
173        }
174        "boolean" | "bool" => {
175            let vals: Vec<Option<bool>> = values
176                .iter()
177                .map(|ov| {
178                    ov.as_ref().and_then(|v| match v {
179                        JsonValue::Bool(b) => Some(*b),
180                        JsonValue::Null => None,
181                        _ => None,
182                    })
183                })
184                .collect();
185            Ok(Series::new(name.into(), vals))
186        }
187        "date" => {
188            let vals: Vec<Option<i32>> = values
189                .iter()
190                .map(|ov| {
191                    ov.as_ref().and_then(|v| match v {
192                        JsonValue::String(s) => NaiveDate::parse_from_str(s, "%Y-%m-%d")
193                            .ok()
194                            .map(|d| (d - epoch).num_days() as i32),
195                        JsonValue::Null => None,
196                        _ => None,
197                    })
198                })
199                .collect();
200            let s = Series::new(name.into(), vals);
201            s.cast(&DataType::Date)
202                .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))
203        }
204        "timestamp" | "datetime" | "timestamp_ntz" => {
205            let vals: Vec<Option<i64>> = values
206                .iter()
207                .map(|ov| {
208                    ov.as_ref().and_then(|v| match v {
209                        JsonValue::String(s) => {
210                            let parsed = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
211                                .or_else(|_| NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S"))
212                                .or_else(|_| {
213                                    NaiveDate::parse_from_str(s, "%Y-%m-%d")
214                                        .map(|d| d.and_hms_opt(0, 0, 0).unwrap())
215                                });
216                            parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
217                        }
218                        JsonValue::Number(n) => n.as_i64(),
219                        JsonValue::Null => None,
220                        _ => None,
221                    })
222                })
223                .collect();
224            let s = Series::new(name.into(), vals);
225            s.cast(&DataType::Datetime(TimeUnit::Microseconds, None))
226                .map_err(|e| PolarsError::ComputeError(format!("datetime cast: {e}").into()))
227        }
228        _ => Err(PolarsError::ComputeError(
229            format!("json_values_to_series: unsupported type '{type_str}'").into(),
230        )),
231    }
232}
233
234/// Build a single Series from a JsonValue for use as list element or struct field.
235fn json_value_to_series_single(
236    value: &JsonValue,
237    type_str: &str,
238    name: &str,
239) -> Result<Series, PolarsError> {
240    use chrono::NaiveDate;
241    let epoch = crate::date_utils::epoch_naive_date();
242    match (value, type_str.trim().to_lowercase().as_str()) {
243        (JsonValue::Null, _) => Ok(Series::new_null(name.into(), 1)),
244        (JsonValue::Number(n), "int" | "bigint" | "long") => {
245            Ok(Series::new(name.into(), vec![n.as_i64()]))
246        }
247        (JsonValue::Number(n), "double" | "float") => {
248            Ok(Series::new(name.into(), vec![n.as_f64()]))
249        }
250        (JsonValue::String(s), "string" | "str" | "varchar") => {
251            Ok(Series::new(name.into(), vec![s.as_str()]))
252        }
253        (JsonValue::Bool(b), "boolean" | "bool") => Ok(Series::new(name.into(), vec![*b])),
254        (JsonValue::String(s), "date") => {
255            let d = NaiveDate::parse_from_str(s, "%Y-%m-%d")
256                .map_err(|e| PolarsError::ComputeError(format!("date parse: {e}").into()))?;
257            let days = (d - epoch).num_days() as i32;
258            let s = Series::new(name.into(), vec![days]).cast(&DataType::Date)?;
259            Ok(s)
260        }
261        _ => Err(PolarsError::ComputeError(
262            format!("json_value_to_series: unsupported {type_str} for {value:?}").into(),
263        )),
264    }
265}
266
267/// Build a struct Series from JsonValue::Object or JsonValue::Array (field-order) or Null.
268#[allow(dead_code)]
269fn json_object_or_array_to_struct_series(
270    value: &JsonValue,
271    fields: &[(String, String)],
272    _name: &str,
273) -> Result<Option<Series>, PolarsError> {
274    use polars::prelude::StructChunked;
275    if matches!(value, JsonValue::Null) {
276        return Ok(None);
277    }
278    let mut field_series: Vec<Series> = Vec::with_capacity(fields.len());
279    for (fname, ftype) in fields {
280        let fval = if let Some(obj) = value.as_object() {
281            obj.get(fname).unwrap_or(&JsonValue::Null)
282        } else if let Some(arr) = value.as_array() {
283            let idx = field_series.len();
284            arr.get(idx).unwrap_or(&JsonValue::Null)
285        } else {
286            return Err(PolarsError::ComputeError(
287                "struct value must be object or array".into(),
288            ));
289        };
290        let s = json_value_to_series_single(fval, ftype, fname)?;
291        field_series.push(s);
292    }
293    let field_refs: Vec<&Series> = field_series.iter().collect();
294    let st = StructChunked::from_series(PlSmallStr::EMPTY, 1, field_refs.iter().copied())
295        .map_err(|e| PolarsError::ComputeError(format!("struct from value: {e}").into()))?
296        .into_series();
297    Ok(Some(st))
298}
299
300use std::collections::HashMap;
301use std::path::Path;
302use std::sync::{Arc, Mutex, OnceLock};
303use std::thread_local;
304
305thread_local! {
306    /// Thread-local SparkSession for UDF resolution in call_udf. Set by get_or_create.
307    static THREAD_UDF_SESSION: RefCell<Option<SparkSession>> = const { RefCell::new(None) };
308}
309
310/// Set the thread-local session for UDF resolution (call_udf). Used by get_or_create.
311pub(crate) fn set_thread_udf_session(session: SparkSession) {
312    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = Some(session));
313}
314
315/// Get the thread-local session for UDF resolution. Used by call_udf.
316pub(crate) fn get_thread_udf_session() -> Option<SparkSession> {
317    THREAD_UDF_SESSION.with(|cell| cell.borrow().clone())
318}
319
320/// Catalog of global temporary views (process-scoped). Persists across sessions within the same process.
321/// PySpark: createOrReplaceGlobalTempView / spark.table("global_temp.name").
322static GLOBAL_TEMP_CATALOG: OnceLock<Arc<Mutex<HashMap<String, DataFrame>>>> = OnceLock::new();
323
324fn global_temp_catalog() -> Arc<Mutex<HashMap<String, DataFrame>>> {
325    GLOBAL_TEMP_CATALOG
326        .get_or_init(|| Arc::new(Mutex::new(HashMap::new())))
327        .clone()
328}
329
330/// Builder for creating a SparkSession with configuration options
331#[derive(Clone)]
332pub struct SparkSessionBuilder {
333    app_name: Option<String>,
334    master: Option<String>,
335    config: HashMap<String, String>,
336}
337
338impl Default for SparkSessionBuilder {
339    fn default() -> Self {
340        Self::new()
341    }
342}
343
344impl SparkSessionBuilder {
345    pub fn new() -> Self {
346        SparkSessionBuilder {
347            app_name: None,
348            master: None,
349            config: HashMap::new(),
350        }
351    }
352
353    pub fn app_name(mut self, name: impl Into<String>) -> Self {
354        self.app_name = Some(name.into());
355        self
356    }
357
358    pub fn master(mut self, master: impl Into<String>) -> Self {
359        self.master = Some(master.into());
360        self
361    }
362
363    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
364        self.config.insert(key.into(), value.into());
365        self
366    }
367
368    pub fn get_or_create(self) -> SparkSession {
369        let session = SparkSession::new(self.app_name, self.master, self.config);
370        set_thread_udf_session(session.clone());
371        session
372    }
373}
374
375/// Catalog of temporary view names to DataFrames (session-scoped). Uses Arc<Mutex<>> for Send+Sync (Python bindings).
376pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;
377
378/// Catalog of saved table names to DataFrames (session-scoped). Used by saveAsTable.
379pub type TableCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;
380
381/// Main entry point for creating DataFrames and executing queries
382/// Similar to PySpark's SparkSession but using Polars as the backend
383#[derive(Clone)]
384pub struct SparkSession {
385    app_name: Option<String>,
386    master: Option<String>,
387    config: HashMap<String, String>,
388    /// Temporary views: name -> DataFrame. Session-scoped; cleared when session is dropped.
389    pub(crate) catalog: TempViewCatalog,
390    /// Saved tables (saveAsTable): name -> DataFrame. Session-scoped; separate namespace from temp views.
391    pub(crate) tables: TableCatalog,
392    /// UDF registry: Rust and Python UDFs. Session-scoped.
393    pub(crate) udf_registry: UdfRegistry,
394    /// Python UDF execution batch size for vectorized UDFs (non-grouped). usize::MAX = no chunking.
395    #[cfg(feature = "pyo3")]
396    pub(crate) python_udf_batch_size: usize,
397    /// Maximum concurrent Python UDF batches/groups to execute. 1 = serial.
398    #[cfg(feature = "pyo3")]
399    pub(crate) python_udf_max_concurrent_batches: usize,
400}
401
402impl SparkSession {
403    pub fn new(
404        app_name: Option<String>,
405        master: Option<String>,
406        config: HashMap<String, String>,
407    ) -> Self {
408        #[cfg(feature = "pyo3")]
409        let batch_size = config
410            .get("spark.robin.pythonUdf.batchSize")
411            .and_then(|s| s.parse::<usize>().ok())
412            .unwrap_or(usize::MAX);
413        #[cfg(feature = "pyo3")]
414        let max_concurrent = config
415            .get("spark.robin.pythonUdf.maxConcurrentBatches")
416            .and_then(|s| s.parse::<usize>().ok())
417            .unwrap_or(1);
418
419        SparkSession {
420            app_name,
421            master,
422            config,
423            catalog: Arc::new(Mutex::new(HashMap::new())),
424            tables: Arc::new(Mutex::new(HashMap::new())),
425            udf_registry: UdfRegistry::new(),
426            #[cfg(feature = "pyo3")]
427            python_udf_batch_size: batch_size,
428            #[cfg(feature = "pyo3")]
429            python_udf_max_concurrent_batches: max_concurrent,
430        }
431    }
432
433    /// Register a DataFrame as a temporary view (PySpark: createOrReplaceTempView).
434    /// The view is session-scoped and is dropped when the session is dropped.
435    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
436        let _ = self
437            .catalog
438            .lock()
439            .map(|mut m| m.insert(name.to_string(), df));
440    }
441
442    /// Global temp view (PySpark: createGlobalTempView). Persists across sessions within the same process.
443    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
444        let _ = global_temp_catalog()
445            .lock()
446            .map(|mut m| m.insert(name.to_string(), df));
447    }
448
449    /// Global temp view (PySpark: createOrReplaceGlobalTempView). Persists across sessions within the same process.
450    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
451        let _ = global_temp_catalog()
452            .lock()
453            .map(|mut m| m.insert(name.to_string(), df));
454    }
455
456    /// Drop a temporary view by name (PySpark: catalog.dropTempView).
457    /// No error if the view does not exist.
458    pub fn drop_temp_view(&self, name: &str) {
459        let _ = self.catalog.lock().map(|mut m| m.remove(name));
460    }
461
462    /// Drop a global temporary view (PySpark: catalog.dropGlobalTempView). Removes from process-wide catalog.
463    pub fn drop_global_temp_view(&self, name: &str) -> bool {
464        global_temp_catalog()
465            .lock()
466            .map(|mut m| m.remove(name).is_some())
467            .unwrap_or(false)
468    }
469
470    /// Register a DataFrame as a saved table (PySpark: saveAsTable). Inserts into the tables catalog only.
471    pub fn register_table(&self, name: &str, df: DataFrame) {
472        let _ = self
473            .tables
474            .lock()
475            .map(|mut m| m.insert(name.to_string(), df));
476    }
477
478    /// Get a saved table by name (tables map only). Returns None if not in saved tables (temp views not checked).
479    pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
480        self.tables.lock().ok().and_then(|m| m.get(name).cloned())
481    }
482
483    /// True if the name exists in the saved-tables map (not temp views).
484    pub fn saved_table_exists(&self, name: &str) -> bool {
485        self.tables
486            .lock()
487            .map(|m| m.contains_key(name))
488            .unwrap_or(false)
489    }
490
491    /// Check if a table or temp view exists (PySpark: catalog.tableExists). True if name is in temp views, saved tables, global temp, or warehouse.
492    pub fn table_exists(&self, name: &str) -> bool {
493        // global_temp.xyz
494        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
495            return global_temp_catalog()
496                .lock()
497                .map(|m| m.contains_key(tbl))
498                .unwrap_or(false);
499        }
500        if self
501            .catalog
502            .lock()
503            .map(|m| m.contains_key(name))
504            .unwrap_or(false)
505        {
506            return true;
507        }
508        if self
509            .tables
510            .lock()
511            .map(|m| m.contains_key(name))
512            .unwrap_or(false)
513        {
514            return true;
515        }
516        // Warehouse fallback
517        if let Some(warehouse) = self.warehouse_dir() {
518            let path = Path::new(warehouse).join(name);
519            if path.is_dir() {
520                return true;
521            }
522        }
523        false
524    }
525
526    /// Return global temp view names (process-scoped). PySpark: catalog.listTables(dbName="global_temp").
527    pub fn list_global_temp_view_names(&self) -> Vec<String> {
528        global_temp_catalog()
529            .lock()
530            .map(|m| m.keys().cloned().collect())
531            .unwrap_or_default()
532    }
533
534    /// Return temporary view names in this session.
535    pub fn list_temp_view_names(&self) -> Vec<String> {
536        self.catalog
537            .lock()
538            .map(|m| m.keys().cloned().collect())
539            .unwrap_or_default()
540    }
541
542    /// Return saved table names in this session (saveAsTable / write_delta_table).
543    pub fn list_table_names(&self) -> Vec<String> {
544        self.tables
545            .lock()
546            .map(|m| m.keys().cloned().collect())
547            .unwrap_or_default()
548    }
549
550    /// Drop a saved table by name (removes from tables catalog only). No-op if not present.
551    pub fn drop_table(&self, name: &str) -> bool {
552        self.tables
553            .lock()
554            .map(|mut m| m.remove(name).is_some())
555            .unwrap_or(false)
556    }
557
558    /// Parse "global_temp.xyz" into ("global_temp", "xyz"). Returns None for plain names.
559    fn parse_global_temp_name(name: &str) -> Option<(&str, &str)> {
560        if let Some(dot) = name.find('.') {
561            let (db, tbl) = name.split_at(dot);
562            if db.eq_ignore_ascii_case("global_temp") {
563                return Some((db, tbl.strip_prefix('.').unwrap_or(tbl)));
564            }
565        }
566        None
567    }
568
569    /// Return spark.sql.warehouse.dir from config if set. Enables disk-backed saveAsTable.
570    pub fn warehouse_dir(&self) -> Option<&str> {
571        self.config
572            .get("spark.sql.warehouse.dir")
573            .map(|s| s.as_str())
574            .filter(|s| !s.is_empty())
575    }
576
577    /// Look up a table or temp view by name (PySpark: table(name)).
578    /// Resolution order: (1) global_temp.xyz from global catalog, (2) temp view, (3) saved table, (4) warehouse.
579    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
580        // global_temp.xyz -> global catalog only
581        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
582            if let Some(df) = global_temp_catalog()
583                .lock()
584                .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
585                .get(tbl)
586                .cloned()
587            {
588                return Ok(df);
589            }
590            return Err(PolarsError::InvalidOperation(
591                format!(
592                    "Global temp view '{tbl}' not found. Register it with createOrReplaceGlobalTempView."
593                )
594                .into(),
595            ));
596        }
597        // Session: temp view, saved table
598        if let Some(df) = self
599            .catalog
600            .lock()
601            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
602            .get(name)
603            .cloned()
604        {
605            return Ok(df);
606        }
607        if let Some(df) = self
608            .tables
609            .lock()
610            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
611            .get(name)
612            .cloned()
613        {
614            return Ok(df);
615        }
616        // Warehouse fallback (disk-backed saveAsTable)
617        if let Some(warehouse) = self.warehouse_dir() {
618            let dir = Path::new(warehouse).join(name);
619            if dir.is_dir() {
620                // Read data.parquet (our convention) or the dir (Polars accepts dirs with parquet files)
621                let data_file = dir.join("data.parquet");
622                let read_path = if data_file.is_file() { data_file } else { dir };
623                return self.read_parquet(&read_path);
624            }
625        }
626        Err(PolarsError::InvalidOperation(
627            format!(
628                "Table or view '{name}' not found. Register it with create_or_replace_temp_view or saveAsTable."
629            )
630            .into(),
631        ))
632    }
633
634    pub fn builder() -> SparkSessionBuilder {
635        SparkSessionBuilder::new()
636    }
637
638    /// Return a reference to the session config (for catalog/conf compatibility).
639    pub fn get_config(&self) -> &HashMap<String, String> {
640        &self.config
641    }
642
643    /// Whether column names are case-sensitive (PySpark: spark.sql.caseSensitive).
644    /// Default is false (case-insensitive matching).
645    pub fn is_case_sensitive(&self) -> bool {
646        self.config
647            .get("spark.sql.caseSensitive")
648            .map(|v| v.eq_ignore_ascii_case("true"))
649            .unwrap_or(false)
650    }
651
652    /// Register a Rust UDF. Session-scoped. Use with call_udf. PySpark: spark.udf.register (Python) or equivalent.
653    pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
654    where
655        F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
656    {
657        self.udf_registry.register_rust_udf(name, f)
658    }
659
660    /// Create a DataFrame from a vector of tuples (i64, i64, String)
661    ///
662    /// # Example
663    /// ```
664    /// use robin_sparkless::session::SparkSession;
665    ///
666    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
667    /// let spark = SparkSession::builder().app_name("test").get_or_create();
668    /// let df = spark.create_dataframe(
669    ///     vec![
670    ///         (1, 25, "Alice".to_string()),
671    ///         (2, 30, "Bob".to_string()),
672    ///     ],
673    ///     vec!["id", "age", "name"],
674    /// )?;
675    /// #     let _ = df;
676    /// #     Ok(())
677    /// # }
678    /// ```
679    pub fn create_dataframe(
680        &self,
681        data: Vec<(i64, i64, String)>,
682        column_names: Vec<&str>,
683    ) -> Result<DataFrame, PolarsError> {
684        if column_names.len() != 3 {
685            return Err(PolarsError::ComputeError(
686                format!(
687                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
688                    column_names.len()
689                )
690                .into(),
691            ));
692        }
693
694        let mut cols: Vec<Series> = Vec::with_capacity(3);
695
696        // First column: i64
697        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
698        cols.push(Series::new(column_names[0].into(), col0));
699
700        // Second column: i64
701        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
702        cols.push(Series::new(column_names[1].into(), col1));
703
704        // Third column: String
705        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
706        cols.push(Series::new(column_names[2].into(), col2));
707
708        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
709        Ok(DataFrame::from_polars_with_options(
710            pl_df,
711            self.is_case_sensitive(),
712        ))
713    }
714
715    /// Create a DataFrame from a Polars DataFrame
716    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
717        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
718    }
719
720    /// Create a DataFrame from rows and a schema (arbitrary column count and types).
721    ///
722    /// `rows`: each inner vec is one row; length must match schema length. Values are JSON-like (i64, f64, string, bool, null, object, array).
723    /// `schema`: list of (column_name, dtype_string), e.g. `[("id", "bigint"), ("name", "string")]`.
724    /// Supported dtype strings: bigint, int, long, double, float, string, str, varchar, boolean, bool, date, timestamp, datetime, array<element_type>, struct<field:type,...>.
725    pub fn create_dataframe_from_rows(
726        &self,
727        rows: Vec<Vec<JsonValue>>,
728        schema: Vec<(String, String)>,
729    ) -> Result<DataFrame, PolarsError> {
730        if schema.is_empty() {
731            return Err(PolarsError::InvalidOperation(
732                "create_dataframe_from_rows: schema must not be empty".into(),
733            ));
734        }
735        use chrono::{NaiveDate, NaiveDateTime};
736
737        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());
738
739        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
740            let type_lower = type_str.trim().to_lowercase();
741            let s = match type_lower.as_str() {
742                "int" | "bigint" | "long" => {
743                    let vals: Vec<Option<i64>> = rows
744                        .iter()
745                        .map(|row| {
746                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
747                            match v {
748                                JsonValue::Number(n) => n.as_i64(),
749                                JsonValue::Null => None,
750                                _ => None,
751                            }
752                        })
753                        .collect();
754                    Series::new(name.as_str().into(), vals)
755                }
756                "double" | "float" | "double_precision" => {
757                    let vals: Vec<Option<f64>> = rows
758                        .iter()
759                        .map(|row| {
760                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
761                            match v {
762                                JsonValue::Number(n) => n.as_f64(),
763                                JsonValue::Null => None,
764                                _ => None,
765                            }
766                        })
767                        .collect();
768                    Series::new(name.as_str().into(), vals)
769                }
770                "string" | "str" | "varchar" => {
771                    let vals: Vec<Option<String>> = rows
772                        .iter()
773                        .map(|row| {
774                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
775                            match v {
776                                JsonValue::String(s) => Some(s),
777                                JsonValue::Null => None,
778                                other => Some(other.to_string()),
779                            }
780                        })
781                        .collect();
782                    Series::new(name.as_str().into(), vals)
783                }
784                "boolean" | "bool" => {
785                    let vals: Vec<Option<bool>> = rows
786                        .iter()
787                        .map(|row| {
788                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
789                            match v {
790                                JsonValue::Bool(b) => Some(b),
791                                JsonValue::Null => None,
792                                _ => None,
793                            }
794                        })
795                        .collect();
796                    Series::new(name.as_str().into(), vals)
797                }
798                "date" => {
799                    let epoch = crate::date_utils::epoch_naive_date();
800                    let vals: Vec<Option<i32>> = rows
801                        .iter()
802                        .map(|row| {
803                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
804                            match v {
805                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
806                                    .ok()
807                                    .map(|d| (d - epoch).num_days() as i32),
808                                JsonValue::Null => None,
809                                _ => None,
810                            }
811                        })
812                        .collect();
813                    let series = Series::new(name.as_str().into(), vals);
814                    series
815                        .cast(&DataType::Date)
816                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
817                }
818                "timestamp" | "datetime" | "timestamp_ntz" => {
819                    let vals: Vec<Option<i64>> =
820                        rows.iter()
821                            .map(|row| {
822                                let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
823                                match v {
824                                    JsonValue::String(s) => {
825                                        let parsed = NaiveDateTime::parse_from_str(
826                                            &s,
827                                            "%Y-%m-%dT%H:%M:%S%.f",
828                                        )
829                                        .or_else(|_| {
830                                            NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S")
831                                        })
832                                        .or_else(|_| {
833                                            NaiveDate::parse_from_str(&s, "%Y-%m-%d")
834                                                .map(|d| d.and_hms_opt(0, 0, 0).unwrap())
835                                        });
836                                        parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
837                                    }
838                                    JsonValue::Number(n) => n.as_i64(),
839                                    JsonValue::Null => None,
840                                    _ => None,
841                                }
842                            })
843                            .collect();
844                    let series = Series::new(name.as_str().into(), vals);
845                    series
846                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
847                        .map_err(|e| {
848                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
849                        })?
850                }
851                _ if parse_array_element_type(&type_lower).is_some() => {
852                    let elem_type = parse_array_element_type(&type_lower).unwrap();
853                    let inner_dtype = json_type_str_to_polars(&elem_type)
854                        .ok_or_else(|| {
855                            PolarsError::ComputeError(
856                                format!(
857                                    "create_dataframe_from_rows: array element type '{elem_type}' not supported"
858                                )
859                                .into(),
860                            )
861                        })?;
862                    let n = rows.len();
863                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
864                    for row in rows.iter() {
865                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
866                        if let JsonValue::Null = v {
867                            builder.append_null();
868                        } else if let Some(arr) = v.as_array() {
869                            let elem_series: Vec<Series> = arr
870                                .iter()
871                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
872                                .collect::<Result<Vec<_>, _>>()?;
873                            let vals: Vec<_> =
874                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
875                            let s = Series::from_any_values_and_dtype(
876                                PlSmallStr::EMPTY,
877                                &vals,
878                                &inner_dtype,
879                                false,
880                            )
881                            .map_err(|e| {
882                                PolarsError::ComputeError(format!("array elem: {e}").into())
883                            })?;
884                            builder.append_series(&s)?;
885                        } else {
886                            return Err(PolarsError::ComputeError(
887                                "array column value must be null or array".into(),
888                            ));
889                        }
890                    }
891                    builder.finish().into_series()
892                }
893                _ if parse_struct_fields(&type_lower).is_some() => {
894                    let values: Vec<Option<JsonValue>> =
895                        rows.iter().map(|row| row.get(col_idx).cloned()).collect();
896                    json_values_to_series(&values, &type_lower, name)?
897                }
898                _ => {
899                    return Err(PolarsError::ComputeError(
900                        format!(
901                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
902                        )
903                        .into(),
904                    ));
905                }
906            };
907            cols.push(s);
908        }
909
910        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
911        Ok(DataFrame::from_polars_with_options(
912            pl_df,
913            self.is_case_sensitive(),
914        ))
915    }
916
917    /// Create a DataFrame with a single column `id` (bigint) containing values from start to end (exclusive) with step.
918    /// PySpark: spark.range(end) or spark.range(start, end, step).
919    ///
920    /// - `range(end)` → 0 to end-1, step 1
921    /// - `range(start, end)` → start to end-1, step 1
922    /// - `range(start, end, step)` → start, start+step, ... up to but not including end
923    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
924        if step == 0 {
925            return Err(PolarsError::InvalidOperation(
926                "range: step must not be 0".into(),
927            ));
928        }
929        let mut vals: Vec<i64> = Vec::new();
930        let mut v = start;
931        if step > 0 {
932            while v < end {
933                vals.push(v);
934                v = v.saturating_add(step);
935            }
936        } else {
937            while v > end {
938                vals.push(v);
939                v = v.saturating_add(step);
940            }
941        }
942        let col = Series::new("id".into(), vals);
943        let pl_df = PlDataFrame::new(vec![col.into()])?;
944        Ok(DataFrame::from_polars_with_options(
945            pl_df,
946            self.is_case_sensitive(),
947        ))
948    }
949
950    /// Read a CSV file.
951    ///
952    /// Uses Polars' CSV reader with default options:
953    /// - Header row is inferred (default: true)
954    /// - Schema is inferred from first 100 rows
955    ///
956    /// # Example
957    /// ```
958    /// use robin_sparkless::SparkSession;
959    ///
960    /// let spark = SparkSession::builder().app_name("test").get_or_create();
961    /// let df_result = spark.read_csv("data.csv");
962    /// // Handle the Result as appropriate in your application
963    /// ```
964    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
965        use polars::prelude::*;
966        let path = path.as_ref();
967        let path_display = path.display();
968        // Use LazyCsvReader - call finish() to get LazyFrame, then collect
969        let lf = LazyCsvReader::new(path)
970            .with_has_header(true)
971            .with_infer_schema_length(Some(100))
972            .finish()
973            .map_err(|e| {
974                PolarsError::ComputeError(
975                    format!(
976                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
977                    )
978                    .into(),
979                )
980            })?;
981        let pl_df = lf.collect().map_err(|e| {
982            PolarsError::ComputeError(
983                format!("read_csv({path_display}): collect failed: {e}").into(),
984            )
985        })?;
986        Ok(crate::dataframe::DataFrame::from_polars_with_options(
987            pl_df,
988            self.is_case_sensitive(),
989        ))
990    }
991
992    /// Read a Parquet file.
993    ///
994    /// Uses Polars' Parquet reader. Parquet files have embedded schema, so
995    /// schema inference is automatic.
996    ///
997    /// # Example
998    /// ```
999    /// use robin_sparkless::SparkSession;
1000    ///
1001    /// let spark = SparkSession::builder().app_name("test").get_or_create();
1002    /// let df_result = spark.read_parquet("data.parquet");
1003    /// // Handle the Result as appropriate in your application
1004    /// ```
1005    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1006        use polars::prelude::*;
1007        let path = path.as_ref();
1008        // Use LazyFrame::scan_parquet
1009        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
1010        let pl_df = lf.collect()?;
1011        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1012            pl_df,
1013            self.is_case_sensitive(),
1014        ))
1015    }
1016
1017    /// Read a JSON file (JSONL format - one JSON object per line).
1018    ///
1019    /// Uses Polars' JSONL reader with default options:
1020    /// - Schema is inferred from first 100 rows
1021    ///
1022    /// # Example
1023    /// ```
1024    /// use robin_sparkless::SparkSession;
1025    ///
1026    /// let spark = SparkSession::builder().app_name("test").get_or_create();
1027    /// let df_result = spark.read_json("data.json");
1028    /// // Handle the Result as appropriate in your application
1029    /// ```
1030    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1031        use polars::prelude::*;
1032        use std::num::NonZeroUsize;
1033        let path = path.as_ref();
1034        // Use LazyJsonLineReader - call finish() to get LazyFrame, then collect
1035        let lf = LazyJsonLineReader::new(path)
1036            .with_infer_schema_length(NonZeroUsize::new(100))
1037            .finish()?;
1038        let pl_df = lf.collect()?;
1039        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1040            pl_df,
1041            self.is_case_sensitive(),
1042        ))
1043    }
1044
1045    /// Execute a SQL query (SELECT only). Tables must be registered with `create_or_replace_temp_view`.
1046    /// Requires the `sql` feature. Supports: SELECT (columns or *), FROM (single table or JOIN),
1047    /// WHERE (basic predicates), GROUP BY + aggregates, ORDER BY, LIMIT.
1048    #[cfg(feature = "sql")]
1049    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
1050        crate::sql::execute_sql(self, query)
1051    }
1052
1053    /// Execute a SQL query (stub when `sql` feature is disabled).
1054    #[cfg(not(feature = "sql"))]
1055    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
1056        Err(PolarsError::InvalidOperation(
1057            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
1058        ))
1059    }
1060
1061    /// Returns true if the string looks like a filesystem path (has separators or path exists).
1062    fn looks_like_path(s: &str) -> bool {
1063        s.contains('/') || s.contains('\\') || Path::new(s).exists()
1064    }
1065
1066    /// Read a Delta table from path (latest version). Internal; use read_delta(name_or_path: &str) for dispatch.
1067    #[cfg(feature = "delta")]
1068    pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1069        crate::delta::read_delta(path, self.is_case_sensitive())
1070    }
1071
1072    /// Read Delta table at path, optional version. Internal; use read_delta_str for dispatch.
1073    #[cfg(feature = "delta")]
1074    pub fn read_delta_path_with_version(
1075        &self,
1076        path: impl AsRef<Path>,
1077        version: Option<i64>,
1078    ) -> Result<DataFrame, PolarsError> {
1079        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
1080    }
1081
1082    /// Read a Delta table or in-memory table by name/path. If name_or_path looks like a path, reads from Delta on disk; else resolves as table name (temp view then saved table).
1083    #[cfg(feature = "delta")]
1084    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
1085        if Self::looks_like_path(name_or_path) {
1086            self.read_delta_path(Path::new(name_or_path))
1087        } else {
1088            self.table(name_or_path)
1089        }
1090    }
1091
1092    #[cfg(feature = "delta")]
1093    pub fn read_delta_with_version(
1094        &self,
1095        name_or_path: &str,
1096        version: Option<i64>,
1097    ) -> Result<DataFrame, PolarsError> {
1098        if Self::looks_like_path(name_or_path) {
1099            self.read_delta_path_with_version(Path::new(name_or_path), version)
1100        } else {
1101            // In-memory tables have no version; ignore version and return table
1102            self.table(name_or_path)
1103        }
1104    }
1105
1106    /// Stub when `delta` feature is disabled. Still supports reading by table name.
1107    #[cfg(not(feature = "delta"))]
1108    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
1109        if Self::looks_like_path(name_or_path) {
1110            Err(PolarsError::InvalidOperation(
1111                "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
1112            ))
1113        } else {
1114            self.table(name_or_path)
1115        }
1116    }
1117
1118    #[cfg(not(feature = "delta"))]
1119    pub fn read_delta_with_version(
1120        &self,
1121        name_or_path: &str,
1122        version: Option<i64>,
1123    ) -> Result<DataFrame, PolarsError> {
1124        let _ = version;
1125        self.read_delta(name_or_path)
1126    }
1127
1128    /// Path-only read_delta (for DataFrameReader.load/format delta). Requires delta feature.
1129    #[cfg(feature = "delta")]
1130    pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1131        self.read_delta_path(path)
1132    }
1133
1134    #[cfg(not(feature = "delta"))]
1135    pub fn read_delta_from_path(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1136        Err(PolarsError::InvalidOperation(
1137            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
1138        ))
1139    }
1140
1141    /// Stop the session (cleanup resources)
1142    pub fn stop(&self) {
1143        // Cleanup if needed
1144    }
1145}
1146
1147/// DataFrameReader for reading various file formats
1148/// Similar to PySpark's DataFrameReader with option/options/format/load/table
1149pub struct DataFrameReader {
1150    session: SparkSession,
1151    options: HashMap<String, String>,
1152    format: Option<String>,
1153}
1154
1155impl DataFrameReader {
1156    pub fn new(session: SparkSession) -> Self {
1157        DataFrameReader {
1158            session,
1159            options: HashMap::new(),
1160            format: None,
1161        }
1162    }
1163
1164    /// Add a single option (PySpark: option(key, value)). Returns self for chaining.
1165    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1166        self.options.insert(key.into(), value.into());
1167        self
1168    }
1169
1170    /// Add multiple options (PySpark: options(**kwargs)). Returns self for chaining.
1171    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
1172        for (k, v) in opts {
1173            self.options.insert(k, v);
1174        }
1175        self
1176    }
1177
1178    /// Set the format for load() (PySpark: format("parquet") etc).
1179    pub fn format(mut self, fmt: impl Into<String>) -> Self {
1180        self.format = Some(fmt.into());
1181        self
1182    }
1183
1184    /// Set the schema (PySpark: schema(schema)). Stub: stores but does not apply yet.
1185    pub fn schema(self, _schema: impl Into<String>) -> Self {
1186        self
1187    }
1188
1189    /// Load data from path using format (or infer from extension) and options.
1190    pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1191        let path = path.as_ref();
1192        let fmt = self.format.clone().or_else(|| {
1193            path.extension()
1194                .and_then(|e| e.to_str())
1195                .map(|s| s.to_lowercase())
1196        });
1197        match fmt.as_deref() {
1198            Some("parquet") => self.parquet(path),
1199            Some("csv") => self.csv(path),
1200            Some("json") | Some("jsonl") => self.json(path),
1201            #[cfg(feature = "delta")]
1202            Some("delta") => self.session.read_delta_from_path(path),
1203            _ => Err(PolarsError::ComputeError(
1204                format!(
1205                    "load: could not infer format for path '{}'. Use format('parquet'|'csv'|'json') before load.",
1206                    path.display()
1207                )
1208                .into(),
1209            )),
1210        }
1211    }
1212
1213    /// Return the named table/view (PySpark: table(name)).
1214    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
1215        self.session.table(name)
1216    }
1217
1218    fn apply_csv_options(
1219        &self,
1220        reader: polars::prelude::LazyCsvReader,
1221    ) -> polars::prelude::LazyCsvReader {
1222        use polars::prelude::NullValues;
1223        let mut r = reader;
1224        if let Some(v) = self.options.get("header") {
1225            let has_header = v.eq_ignore_ascii_case("true") || v == "1";
1226            r = r.with_has_header(has_header);
1227        }
1228        if let Some(v) = self.options.get("inferSchema") {
1229            if v.eq_ignore_ascii_case("true") || v == "1" {
1230                let n = self
1231                    .options
1232                    .get("inferSchemaLength")
1233                    .and_then(|s| s.parse::<usize>().ok())
1234                    .unwrap_or(100);
1235                r = r.with_infer_schema_length(Some(n));
1236            }
1237        } else if let Some(v) = self.options.get("inferSchemaLength") {
1238            if let Ok(n) = v.parse::<usize>() {
1239                r = r.with_infer_schema_length(Some(n));
1240            }
1241        }
1242        if let Some(sep) = self.options.get("sep") {
1243            if let Some(b) = sep.bytes().next() {
1244                r = r.with_separator(b);
1245            }
1246        }
1247        if let Some(null_val) = self.options.get("nullValue") {
1248            r = r.with_null_values(Some(NullValues::AllColumnsSingle(null_val.clone().into())));
1249        }
1250        r
1251    }
1252
1253    fn apply_json_options(
1254        &self,
1255        reader: polars::prelude::LazyJsonLineReader,
1256    ) -> polars::prelude::LazyJsonLineReader {
1257        use std::num::NonZeroUsize;
1258        let mut r = reader;
1259        if let Some(v) = self.options.get("inferSchemaLength") {
1260            if let Ok(n) = v.parse::<usize>() {
1261                r = r.with_infer_schema_length(NonZeroUsize::new(n));
1262            }
1263        }
1264        r
1265    }
1266
1267    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1268        use polars::prelude::*;
1269        let path = path.as_ref();
1270        let path_display = path.display();
1271        let reader = LazyCsvReader::new(path);
1272        let reader = if self.options.is_empty() {
1273            reader
1274                .with_has_header(true)
1275                .with_infer_schema_length(Some(100))
1276        } else {
1277            self.apply_csv_options(
1278                reader
1279                    .with_has_header(true)
1280                    .with_infer_schema_length(Some(100)),
1281            )
1282        };
1283        let lf = reader.finish().map_err(|e| {
1284            PolarsError::ComputeError(format!("read csv({path_display}): {e}").into())
1285        })?;
1286        let pl_df = lf.collect().map_err(|e| {
1287            PolarsError::ComputeError(
1288                format!("read csv({path_display}): collect failed: {e}").into(),
1289            )
1290        })?;
1291        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1292            pl_df,
1293            self.session.is_case_sensitive(),
1294        ))
1295    }
1296
1297    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1298        use polars::prelude::*;
1299        let path = path.as_ref();
1300        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
1301        let pl_df = lf.collect()?;
1302        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1303            pl_df,
1304            self.session.is_case_sensitive(),
1305        ))
1306    }
1307
1308    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1309        use polars::prelude::*;
1310        use std::num::NonZeroUsize;
1311        let path = path.as_ref();
1312        let reader = LazyJsonLineReader::new(path);
1313        let reader = if self.options.is_empty() {
1314            reader.with_infer_schema_length(NonZeroUsize::new(100))
1315        } else {
1316            self.apply_json_options(reader.with_infer_schema_length(NonZeroUsize::new(100)))
1317        };
1318        let lf = reader.finish()?;
1319        let pl_df = lf.collect()?;
1320        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1321            pl_df,
1322            self.session.is_case_sensitive(),
1323        ))
1324    }
1325
1326    #[cfg(feature = "delta")]
1327    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1328        self.session.read_delta_from_path(path)
1329    }
1330}
1331
1332impl SparkSession {
1333    /// Get a DataFrameReader for reading files
1334    pub fn read(&self) -> DataFrameReader {
1335        DataFrameReader::new(SparkSession {
1336            app_name: self.app_name.clone(),
1337            master: self.master.clone(),
1338            config: self.config.clone(),
1339            catalog: self.catalog.clone(),
1340            tables: self.tables.clone(),
1341            udf_registry: self.udf_registry.clone(),
1342            #[cfg(feature = "pyo3")]
1343            python_udf_batch_size: self.python_udf_batch_size,
1344            #[cfg(feature = "pyo3")]
1345            python_udf_max_concurrent_batches: self.python_udf_max_concurrent_batches,
1346        })
1347    }
1348}
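
// Usage sketch for the reader API above (hypothetical file paths; error
// handling elided with `?`):
//
//     let spark = SparkSession::builder().app_name("demo").get_or_create();
//     let people = spark.read().csv("data/people.csv")?;            // header + inferred schema
//     let users = spark.read().parquet("warehouse/users.parquet")?;
//     let events = spark.read().json("logs/events.jsonl")?;         // JSON Lines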
1349
1350impl Default for SparkSession {
1351    fn default() -> Self {
1352        Self::builder().get_or_create()
1353    }
1354}
1355
1356#[cfg(test)]
1357mod tests {
1358    use super::*;
1359
1360    #[test]
1361    fn test_spark_session_builder_basic() {
1362        let spark = SparkSession::builder().app_name("test_app").get_or_create();
1363
1364        assert_eq!(spark.app_name, Some("test_app".to_string()));
1365    }
1366
1367    #[test]
1368    fn test_spark_session_builder_with_master() {
1369        let spark = SparkSession::builder()
1370            .app_name("test_app")
1371            .master("local[*]")
1372            .get_or_create();
1373
1374        assert_eq!(spark.app_name, Some("test_app".to_string()));
1375        assert_eq!(spark.master, Some("local[*]".to_string()));
1376    }
1377
1378    #[test]
1379    fn test_spark_session_builder_with_config() {
1380        let spark = SparkSession::builder()
1381            .app_name("test_app")
1382            .config("spark.executor.memory", "4g")
1383            .config("spark.driver.memory", "2g")
1384            .get_or_create();
1385
1386        assert_eq!(
1387            spark.config.get("spark.executor.memory"),
1388            Some(&"4g".to_string())
1389        );
1390        assert_eq!(
1391            spark.config.get("spark.driver.memory"),
1392            Some(&"2g".to_string())
1393        );
1394    }
1395
1396    #[test]
1397    fn test_spark_session_default() {
1398        let spark = SparkSession::default();
1399        assert!(spark.app_name.is_none());
1400        assert!(spark.master.is_none());
1401        assert!(spark.config.is_empty());
1402    }
1403
1404    #[test]
1405    fn test_create_dataframe_success() {
1406        let spark = SparkSession::builder().app_name("test").get_or_create();
1407        let data = vec![
1408            (1i64, 25i64, "Alice".to_string()),
1409            (2i64, 30i64, "Bob".to_string()),
1410        ];
1411
1412        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);
1413
1414        assert!(result.is_ok());
1415        let df = result.unwrap();
1416        assert_eq!(df.count().unwrap(), 2);
1417
1418        let columns = df.columns().unwrap();
1419        assert!(columns.contains(&"id".to_string()));
1420        assert!(columns.contains(&"age".to_string()));
1421        assert!(columns.contains(&"name".to_string()));
1422    }
1423
1424    #[test]
1425    fn test_create_dataframe_wrong_column_count() {
1426        let spark = SparkSession::builder().app_name("test").get_or_create();
1427        let data = vec![(1i64, 25i64, "Alice".to_string())];
1428
1429        // Too few columns
1430        let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
1431        assert!(result.is_err());
1432
1433        // Too many columns
1434        let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
1435        assert!(result.is_err());
1436    }
1437
1438    #[test]
1439    fn test_create_dataframe_from_rows_empty_schema_returns_error() {
1440        let spark = SparkSession::builder().app_name("test").get_or_create();
1441        let rows: Vec<Vec<JsonValue>> = vec![vec![]];
1442        let schema: Vec<(String, String)> = vec![];
1443        let result = spark.create_dataframe_from_rows(rows, schema);
1444        match &result {
1445            Err(e) => assert!(e.to_string().contains("schema must not be empty")),
1446            Ok(_) => panic!("expected error for empty schema"),
1447        }
1448    }
1449
1450    #[test]
1451    fn test_create_dataframe_empty() {
1452        let spark = SparkSession::builder().app_name("test").get_or_create();
1453        let data: Vec<(i64, i64, String)> = vec![];
1454
1455        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);
1456
1457        assert!(result.is_ok());
1458        let df = result.unwrap();
1459        assert_eq!(df.count().unwrap(), 0);
1460    }
1461
1462    #[test]
1463    fn test_create_dataframe_from_polars() {
1464        use polars::prelude::df;
1465
1466        let spark = SparkSession::builder().app_name("test").get_or_create();
1467        let polars_df = df!(
1468            "x" => &[1, 2, 3],
1469            "y" => &[4, 5, 6]
1470        )
1471        .unwrap();
1472
1473        let df = spark.create_dataframe_from_polars(polars_df);
1474
1475        assert_eq!(df.count().unwrap(), 3);
1476        let columns = df.columns().unwrap();
1477        assert!(columns.contains(&"x".to_string()));
1478        assert!(columns.contains(&"y".to_string()));
1479    }
1480
1481    #[test]
1482    fn test_read_csv_file_not_found() {
1483        let spark = SparkSession::builder().app_name("test").get_or_create();
1484
1485        let result = spark.read_csv("nonexistent_file.csv");
1486
1487        assert!(result.is_err());
1488    }
1489
1490    #[test]
1491    fn test_read_parquet_file_not_found() {
1492        let spark = SparkSession::builder().app_name("test").get_or_create();
1493
1494        let result = spark.read_parquet("nonexistent_file.parquet");
1495
1496        assert!(result.is_err());
1497    }
1498
1499    #[test]
1500    fn test_read_json_file_not_found() {
1501        let spark = SparkSession::builder().app_name("test").get_or_create();
1502
1503        let result = spark.read_json("nonexistent_file.json");
1504
1505        assert!(result.is_err());
1506    }
1507
1508    #[test]
1509    fn test_rust_udf_dataframe() {
1510        use crate::functions::{call_udf, col};
1511        use polars::prelude::DataType;
1512
1513        let spark = SparkSession::builder().app_name("test").get_or_create();
1514        spark
1515            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
1516            .unwrap();
1517        let df = spark
1518            .create_dataframe(
1519                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
1520                vec!["id", "age", "name"],
1521            )
1522            .unwrap();
1523        let col = call_udf("to_str", &[col("id")]).unwrap();
1524        let df2 = df.with_column("id_str", &col).unwrap();
1525        let cols = df2.columns().unwrap();
1526        assert!(cols.contains(&"id_str".to_string()));
1527        let rows = df2.collect_as_json_rows().unwrap();
1528        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
1529        assert_eq!(rows[1].get("id_str").and_then(|v| v.as_str()), Some("2"));
1530    }
1531
1532    #[test]
1533    fn test_case_insensitive_filter_select() {
1534        use crate::expression::lit_i64;
1535        use crate::functions::col;
1536
1537        let spark = SparkSession::builder().app_name("test").get_or_create();
1538        let df = spark
1539            .create_dataframe(
1540                vec![
1541                    (1, 25, "Alice".to_string()),
1542                    (2, 30, "Bob".to_string()),
1543                    (3, 35, "Charlie".to_string()),
1544                ],
1545                vec!["Id", "Age", "Name"],
1546            )
1547            .unwrap();
1548        // Filter and select with lowercase column names (PySpark default: column resolution is case-insensitive)
1549        let filtered = df
1550            .filter(col("age").gt(lit_i64(26)).expr().clone())
1551            .unwrap()
1552            .select(vec!["name"])
1553            .unwrap();
1554        assert_eq!(filtered.count().unwrap(), 2);
1555        let rows = filtered.collect_as_json_rows().unwrap();
1556        let names: Vec<&str> = rows
1557            .iter()
1558            .map(|r| r.get("name").and_then(|v| v.as_str()).unwrap())
1559            .collect();
1560        assert!(names.contains(&"Bob"));
1561        assert!(names.contains(&"Charlie"));
1562    }
1563
1564    #[test]
1565    fn test_sql_returns_error_without_feature_or_unknown_table() {
1566        let spark = SparkSession::builder().app_name("test").get_or_create();
1567
1568        let result = spark.sql("SELECT * FROM table");
1569
1570        assert!(result.is_err());
1571        match result {
1572            Err(PolarsError::InvalidOperation(msg)) => {
1573                let s = msg.to_string();
1574                // Without sql feature: "SQL queries require the 'sql' feature"
1575                // With sql feature but no table: "Table or view 'table' not found" or parse error
1576                assert!(
1577                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
1578                    "unexpected message: {s}"
1579                );
1580            }
1581            _ => panic!("Expected InvalidOperation error"),
1582        }
1583    }
1584
1585    #[test]
1586    fn test_spark_session_stop() {
1587        let spark = SparkSession::builder().app_name("test").get_or_create();
1588
1589        // stop() should complete without error
1590        spark.stop();
1591    }
1592
1593    #[test]
1594    fn test_dataframe_reader_api() {
1595        let spark = SparkSession::builder().app_name("test").get_or_create();
1596        let reader = spark.read();
1597
1598        // All readers should return errors for non-existent files
1599        assert!(reader.csv("nonexistent.csv").is_err());
1600        assert!(reader.parquet("nonexistent.parquet").is_err());
1601        assert!(reader.json("nonexistent.json").is_err());
1602    }
1603
1604    #[test]
1605    fn test_read_csv_with_valid_file() {
1606        use std::io::Write;
1607        use tempfile::NamedTempFile;
1608
1609        let spark = SparkSession::builder().app_name("test").get_or_create();
1610
1611        // Create a temporary CSV file
1612        let mut temp_file = NamedTempFile::new().unwrap();
1613        writeln!(temp_file, "id,name,age").unwrap();
1614        writeln!(temp_file, "1,Alice,25").unwrap();
1615        writeln!(temp_file, "2,Bob,30").unwrap();
1616        temp_file.flush().unwrap();
1617
1618        let result = spark.read_csv(temp_file.path());
1619
1620        assert!(result.is_ok());
1621        let df = result.unwrap();
1622        assert_eq!(df.count().unwrap(), 2);
1623
1624        let columns = df.columns().unwrap();
1625        assert!(columns.contains(&"id".to_string()));
1626        assert!(columns.contains(&"name".to_string()));
1627        assert!(columns.contains(&"age".to_string()));
1628    }
1629
1630    #[test]
1631    fn test_read_json_with_valid_file() {
1632        use std::io::Write;
1633        use tempfile::NamedTempFile;
1634
1635        let spark = SparkSession::builder().app_name("test").get_or_create();
1636
1637        // Create a temporary JSONL file
1638        let mut temp_file = NamedTempFile::new().unwrap();
1639        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
1640        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
1641        temp_file.flush().unwrap();
1642
1643        let result = spark.read_json(temp_file.path());
1644
1645        assert!(result.is_ok());
1646        let df = result.unwrap();
1647        assert_eq!(df.count().unwrap(), 2);
1648    }
1649
1650    #[test]
1651    fn test_read_csv_empty_file() {
1652        use std::io::Write;
1653        use tempfile::NamedTempFile;
1654
1655        let spark = SparkSession::builder().app_name("test").get_or_create();
1656
1657        // Create an empty CSV file (just header)
1658        let mut temp_file = NamedTempFile::new().unwrap();
1659        writeln!(temp_file, "id,name").unwrap();
1660        temp_file.flush().unwrap();
1661
1662        let result = spark.read_csv(temp_file.path());
1663
1664        assert!(result.is_ok());
1665        let df = result.unwrap();
1666        assert_eq!(df.count().unwrap(), 0);
1667    }
1668
1669    #[test]
1670    fn test_write_partitioned_parquet() {
1671        use crate::dataframe::{WriteFormat, WriteMode};
1672        use std::fs;
1673        use tempfile::TempDir;
1674
1675        let spark = SparkSession::builder().app_name("test").get_or_create();
1676        let df = spark
1677            .create_dataframe(
1678                vec![
1679                    (1, 25, "Alice".to_string()),
1680                    (2, 30, "Bob".to_string()),
1681                    (3, 25, "Carol".to_string()),
1682                ],
1683                vec!["id", "age", "name"],
1684            )
1685            .unwrap();
1686        let dir = TempDir::new().unwrap();
1687        let path = dir.path().join("out");
1688        df.write()
1689            .mode(WriteMode::Overwrite)
1690            .format(WriteFormat::Parquet)
1691            .partition_by(["age"])
1692            .save(&path)
1693            .unwrap();
1694        assert!(path.is_dir());
1695        let entries: Vec<_> = fs::read_dir(&path).unwrap().collect();
1696        assert_eq!(
1697            entries.len(),
1698            2,
1699            "expected two partition dirs (age=25, age=30)"
1700        );
1701        let names: Vec<String> = entries
1702            .iter()
1703            .filter_map(|e| e.as_ref().ok())
1704            .map(|e| e.file_name().to_string_lossy().into_owned())
1705            .collect();
1706        assert!(names.iter().any(|n| n.starts_with("age=")));
1707        let df_read = spark.read_parquet(&path).unwrap();
1708        assert_eq!(df_read.count().unwrap(), 3);
1709    }
1710
1711    #[test]
1712    fn test_save_as_table_error_if_exists() {
1713        use crate::dataframe::SaveMode;
1714
1715        let spark = SparkSession::builder().app_name("test").get_or_create();
1716        let df = spark
1717            .create_dataframe(
1718                vec![(1, 25, "Alice".to_string())],
1719                vec!["id", "age", "name"],
1720            )
1721            .unwrap();
1722        // First call succeeds
1723        df.write()
1724            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
1725            .unwrap();
1726        assert!(spark.table("t1").is_ok());
1727        assert_eq!(spark.table("t1").unwrap().count().unwrap(), 1);
1728        // Second call with ErrorIfExists fails
1729        let err = df
1730            .write()
1731            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
1732            .unwrap_err();
1733        assert!(err.to_string().contains("already exists"));
1734    }
1735
1736    #[test]
1737    fn test_save_as_table_overwrite() {
1738        use crate::dataframe::SaveMode;
1739
1740        let spark = SparkSession::builder().app_name("test").get_or_create();
1741        let df1 = spark
1742            .create_dataframe(
1743                vec![(1, 25, "Alice".to_string())],
1744                vec!["id", "age", "name"],
1745            )
1746            .unwrap();
1747        let df2 = spark
1748            .create_dataframe(
1749                vec![(2, 30, "Bob".to_string()), (3, 35, "Carol".to_string())],
1750                vec!["id", "age", "name"],
1751            )
1752            .unwrap();
1753        df1.write()
1754            .save_as_table(&spark, "t_over", SaveMode::ErrorIfExists)
1755            .unwrap();
1756        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 1);
1757        df2.write()
1758            .save_as_table(&spark, "t_over", SaveMode::Overwrite)
1759            .unwrap();
1760        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 2);
1761    }
1762
1763    #[test]
1764    fn test_save_as_table_append() {
1765        use crate::dataframe::SaveMode;
1766
1767        let spark = SparkSession::builder().app_name("test").get_or_create();
1768        let df1 = spark
1769            .create_dataframe(
1770                vec![(1, 25, "Alice".to_string())],
1771                vec!["id", "age", "name"],
1772            )
1773            .unwrap();
1774        let df2 = spark
1775            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
1776            .unwrap();
1777        df1.write()
1778            .save_as_table(&spark, "t_append", SaveMode::ErrorIfExists)
1779            .unwrap();
1780        df2.write()
1781            .save_as_table(&spark, "t_append", SaveMode::Append)
1782            .unwrap();
1783        assert_eq!(spark.table("t_append").unwrap().count().unwrap(), 2);
1784    }
1785
1786    #[test]
1787    fn test_save_as_table_ignore() {
1788        use crate::dataframe::SaveMode;
1789
1790        let spark = SparkSession::builder().app_name("test").get_or_create();
1791        let df1 = spark
1792            .create_dataframe(
1793                vec![(1, 25, "Alice".to_string())],
1794                vec!["id", "age", "name"],
1795            )
1796            .unwrap();
1797        let df2 = spark
1798            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
1799            .unwrap();
1800        df1.write()
1801            .save_as_table(&spark, "t_ignore", SaveMode::ErrorIfExists)
1802            .unwrap();
1803        df2.write()
1804            .save_as_table(&spark, "t_ignore", SaveMode::Ignore)
1805            .unwrap();
1806        // Still 1 row (ignore did not replace)
1807        assert_eq!(spark.table("t_ignore").unwrap().count().unwrap(), 1);
1808    }
1809
1810    #[test]
1811    fn test_table_resolution_temp_view_first() {
1812        use crate::dataframe::SaveMode;
1813
1814        let spark = SparkSession::builder().app_name("test").get_or_create();
1815        let df_saved = spark
1816            .create_dataframe(
1817                vec![(1, 25, "Saved".to_string())],
1818                vec!["id", "age", "name"],
1819            )
1820            .unwrap();
1821        let df_temp = spark
1822            .create_dataframe(vec![(2, 30, "Temp".to_string())], vec!["id", "age", "name"])
1823            .unwrap();
1824        df_saved
1825            .write()
1826            .save_as_table(&spark, "x", SaveMode::ErrorIfExists)
1827            .unwrap();
1828        spark.create_or_replace_temp_view("x", df_temp);
1829        // table("x") must return the temp view first (PySpark resolution order: temp views shadow saved tables)
1830        let t = spark.table("x").unwrap();
1831        let rows = t.collect_as_json_rows().unwrap();
1832        assert_eq!(rows.len(), 1);
1833        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Temp"));
1834    }
1835
1836    #[test]
1837    fn test_drop_table() {
1838        use crate::dataframe::SaveMode;
1839
1840        let spark = SparkSession::builder().app_name("test").get_or_create();
1841        let df = spark
1842            .create_dataframe(
1843                vec![(1, 25, "Alice".to_string())],
1844                vec!["id", "age", "name"],
1845            )
1846            .unwrap();
1847        df.write()
1848            .save_as_table(&spark, "t_drop", SaveMode::ErrorIfExists)
1849            .unwrap();
1850        assert!(spark.table("t_drop").is_ok());
1851        assert!(spark.drop_table("t_drop"));
1852        assert!(spark.table("t_drop").is_err());
1853        // dropping again is a no-op and returns false
1854        assert!(!spark.drop_table("t_drop"));
1855    }
1856
1857    #[test]
1858    fn test_global_temp_view_persists_across_sessions() {
1859        // Session 1: create global temp view
1860        let spark1 = SparkSession::builder().app_name("s1").get_or_create();
1861        let df1 = spark1
1862            .create_dataframe(
1863                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
1864                vec!["id", "age", "name"],
1865            )
1866            .unwrap();
1867        spark1.create_or_replace_global_temp_view("people", df1);
1868        assert_eq!(
1869            spark1.table("global_temp.people").unwrap().count().unwrap(),
1870            2
1871        );
1872
1873        // Session 2: different session can see global temp view
1874        let spark2 = SparkSession::builder().app_name("s2").get_or_create();
1875        let df2 = spark2.table("global_temp.people").unwrap();
1876        assert_eq!(df2.count().unwrap(), 2);
1877        let rows = df2.collect_as_json_rows().unwrap();
1878        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));
1879
1880        // Local temp view in spark2 does not shadow global_temp
1881        let df_local = spark2
1882            .create_dataframe(
1883                vec![(3, 35, "Carol".to_string())],
1884                vec!["id", "age", "name"],
1885            )
1886            .unwrap();
1887        spark2.create_or_replace_temp_view("people", df_local);
1888        // table("people") resolves to the local temp view (unqualified names use session-local resolution)
1889        assert_eq!(spark2.table("people").unwrap().count().unwrap(), 1);
1890        // table("global_temp.people") = global temp view (unchanged)
1891        assert_eq!(
1892            spark2.table("global_temp.people").unwrap().count().unwrap(),
1893            2
1894        );
1895
1896        // Drop global temp view
1897        assert!(spark2.drop_global_temp_view("people"));
1898        assert!(spark2.table("global_temp.people").is_err());
1899    }
1900
1901    #[test]
1902    fn test_warehouse_persistence_between_sessions() {
1903        use crate::dataframe::SaveMode;
1904        use std::fs;
1905        use tempfile::TempDir;
1906
1907        let dir = TempDir::new().unwrap();
1908        let warehouse = dir.path().to_str().unwrap();
1909
1910        // Session 1: save to warehouse
1911        let spark1 = SparkSession::builder()
1912            .app_name("w1")
1913            .config("spark.sql.warehouse.dir", warehouse)
1914            .get_or_create();
1915        let df1 = spark1
1916            .create_dataframe(
1917                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
1918                vec!["id", "age", "name"],
1919            )
1920            .unwrap();
1921        df1.write()
1922            .save_as_table(&spark1, "users", SaveMode::ErrorIfExists)
1923            .unwrap();
1924        assert_eq!(spark1.table("users").unwrap().count().unwrap(), 2);
1925
1926        // Session 2: new session reads from warehouse
1927        let spark2 = SparkSession::builder()
1928            .app_name("w2")
1929            .config("spark.sql.warehouse.dir", warehouse)
1930            .get_or_create();
1931        let df2 = spark2.table("users").unwrap();
1932        assert_eq!(df2.count().unwrap(), 2);
1933        let rows = df2.collect_as_json_rows().unwrap();
1934        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));
1935
1936        // Verify Parquet data was written to the warehouse's "users" directory
1937        let table_path = dir.path().join("users");
1938        assert!(table_path.is_dir());
1939        let entries: Vec<_> = fs::read_dir(&table_path).unwrap().collect();
1940        assert!(!entries.is_empty());
1941    }
1942}