robin_sparkless/session.rs

use crate::dataframe::DataFrame;
use crate::error::EngineError;
use crate::udf_registry::UdfRegistry;
use polars::chunked_array::StructChunked;
use polars::chunked_array::builder::get_list_builder;
use polars::prelude::{
    DataFrame as PlDataFrame, DataType, Field, IntoSeries, NamedFrom, PlSmallStr, PolarsError,
    Series, TimeUnit,
};
use serde_json::Value as JsonValue;
use std::cell::RefCell;

/// Parse "array<element_type>" to get inner type string. Returns None if not array<>.
fn parse_array_element_type(type_str: &str) -> Option<String> {
    let s = type_str.trim();
    if !s.to_lowercase().starts_with("array<") || !s.ends_with('>') {
        return None;
    }
    Some(s[6..s.len() - 1].trim().to_string())
}

/// Parse "struct<field:type,...>" to get field (name, type) pairs. Simple parsing, no nested structs.
fn parse_struct_fields(type_str: &str) -> Option<Vec<(String, String)>> {
    let s = type_str.trim();
    if !s.to_lowercase().starts_with("struct<") || !s.ends_with('>') {
        return None;
    }
    let inner = s[7..s.len() - 1].trim();
    if inner.is_empty() {
        return Some(Vec::new());
    }
    let mut out = Vec::new();
    for part in inner.split(',') {
        let part = part.trim();
        if let Some(idx) = part.find(':') {
            let name = part[..idx].trim().to_string();
            let typ = part[idx + 1..].trim().to_string();
            out.push((name, typ));
        }
    }
    Some(out)
}

/// Parse "map<key_type,value_type>" to get (key_type, value_type). Returns None if not map<>.
/// PySpark: MapType(StringType(), StringType()) -> "map<string,string>".
fn parse_map_key_value_types(type_str: &str) -> Option<(String, String)> {
    let s = type_str.trim().to_lowercase();
    if !s.starts_with("map<") || !s.ends_with('>') {
        return None;
    }
    let inner = s[4..s.len() - 1].trim();
    let comma = inner.find(',')?;
    let key_type = inner[..comma].trim().to_string();
    let value_type = inner[comma + 1..].trim().to_string();
    Some((key_type, value_type))
}

/// True if type string is Decimal(precision, scale), e.g. "decimal(10,2)".
fn is_decimal_type_str(type_str: &str) -> bool {
    let s = type_str.trim().to_lowercase();
    s.starts_with("decimal(") && s.contains(')')
}

/// Map schema type string to Polars DataType (primitives only for nested use).
/// Decimal(p,s) is mapped to Float64 (Polars dtype-decimal feature not enabled).
fn json_type_str_to_polars(type_str: &str) -> Option<DataType> {
    let s = type_str.trim().to_lowercase();
    if is_decimal_type_str(&s) {
        return Some(DataType::Float64);
    }
    match s.as_str() {
        "int" | "integer" | "bigint" | "long" => Some(DataType::Int64),
        "double" | "float" | "double_precision" => Some(DataType::Float64),
        "string" | "str" | "varchar" => Some(DataType::String),
        "boolean" | "bool" => Some(DataType::Boolean),
        _ => None,
    }
}

/// Normalize a JSON value to an array for array columns (PySpark parity #625).
/// Accepts: Array, Object with "0","1",... keys (Python list serialization), String that parses as JSON array.
/// Returns None for null or when value should be treated as single-element list (#611).
fn json_value_to_array(v: &JsonValue) -> Option<Vec<JsonValue>> {
    match v {
        JsonValue::Null => None,
        JsonValue::Array(arr) => Some(arr.clone()),
        JsonValue::Object(obj) => {
            // Python/serialization sometimes sends list as {"0": x, "1": y}. Build sorted by index.
            let mut indices: Vec<usize> =
                obj.keys().filter_map(|k| k.parse::<usize>().ok()).collect();
            indices.sort_unstable();
            if indices.is_empty() {
                return None;
            }
            let arr: Vec<JsonValue> = indices
                .iter()
                .filter_map(|i| obj.get(&i.to_string()).cloned())
                .collect();
            Some(arr)
        }
        JsonValue::String(s) => {
            if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                parsed.as_array().cloned()
            } else {
                None
            }
        }
        _ => None,
    }
}

/// Infer list element type from first non-null array in the column (for schema "list" / "array").
fn infer_list_element_type(rows: &[Vec<JsonValue>], col_idx: usize) -> Option<(String, DataType)> {
    for row in rows {
        let v = row.get(col_idx)?;
        let arr = json_value_to_array(v)?;
        let first = arr.first()?;
        return Some(match first {
            JsonValue::String(_) => ("string".to_string(), DataType::String),
            JsonValue::Number(n) => {
                if n.as_i64().is_some() {
                    ("bigint".to_string(), DataType::Int64)
                } else {
                    ("double".to_string(), DataType::Float64)
                }
            }
            JsonValue::Bool(_) => ("boolean".to_string(), DataType::Boolean),
            JsonValue::Null => continue,
            _ => ("string".to_string(), DataType::String),
        });
    }
    None
}

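// Illustrative examples (not part of the original source): exercise the type-string
// parsers and JSON normalization helpers above, assuming the behavior documented there.
#[cfg(test)]
mod schema_helper_examples {
    use super::*;
    use serde_json::json;

    #[test]
    fn parses_spark_style_type_strings() {
        assert_eq!(parse_array_element_type("array<int>"), Some("int".to_string()));
        assert_eq!(
            parse_struct_fields("struct<a:int,b:string>"),
            Some(vec![
                ("a".to_string(), "int".to_string()),
                ("b".to_string(), "string".to_string()),
            ])
        );
        assert_eq!(
            parse_map_key_value_types("map<string,bigint>"),
            Some(("string".to_string(), "bigint".to_string()))
        );
        assert!(is_decimal_type_str("decimal(10,2)"));
        assert_eq!(json_type_str_to_polars("boolean"), Some(DataType::Boolean));
    }

    #[test]
    fn normalizes_json_values_to_arrays() {
        // Plain array, index-keyed object, and stringified array all normalize the same way.
        let expected = vec![json!(1), json!(2)];
        assert_eq!(json_value_to_array(&json!([1, 2])), Some(expected.clone()));
        assert_eq!(json_value_to_array(&json!({"0": 1, "1": 2})), Some(expected.clone()));
        assert_eq!(json_value_to_array(&json!("[1, 2]")), Some(expected));
        // Element type inference picks the first non-null element's type.
        let rows = vec![vec![json!([1, 2])]];
        assert_eq!(
            infer_list_element_type(&rows, 0),
            Some(("bigint".to_string(), DataType::Int64))
        );
    }
}
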
/// Build a length-N Series from `Vec<Option<JsonValue>>` for a given type (recursive for struct/array).
fn json_values_to_series(
    values: &[Option<JsonValue>],
    type_str: &str,
    name: &str,
) -> Result<Series, PolarsError> {
    use chrono::{NaiveDate, NaiveDateTime};
    let epoch = crate::date_utils::epoch_naive_date();
    let type_lower = type_str.trim().to_lowercase();

    if let Some(elem_type) = parse_array_element_type(&type_lower) {
        let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
            PolarsError::ComputeError(
                format!("array element type '{elem_type}' not supported").into(),
            )
        })?;
        let mut builder = get_list_builder(&inner_dtype, 64, values.len(), name.into());
        for v in values.iter() {
            if v.as_ref().is_none_or(|x| matches!(x, JsonValue::Null)) {
                builder.append_null();
            } else if let Some(arr) = v.as_ref().and_then(json_value_to_array) {
                // #625: Array, Object with "0","1",..., or string that parses as JSON array (PySpark list parity).
                let elem_series: Vec<Series> = arr
                    .iter()
                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                    .collect::<Result<Vec<_>, _>>()?;
                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                let s = Series::from_any_values_and_dtype(
                    PlSmallStr::EMPTY,
                    &vals,
                    &inner_dtype,
                    false,
                )
                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
                builder.append_series(&s)?;
            } else {
                // #611: PySpark accepts single value as one-element list for array columns.
                let single_arr = [v.clone().unwrap_or(JsonValue::Null)];
                let elem_series: Vec<Series> = single_arr
                    .iter()
                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                    .collect::<Result<Vec<_>, _>>()?;
                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                let arr_series = Series::from_any_values_and_dtype(
                    PlSmallStr::EMPTY,
                    &vals,
                    &inner_dtype,
                    false,
                )
                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
                builder.append_series(&arr_series)?;
            }
        }
        return Ok(builder.finish().into_series());
    }

    if let Some(fields) = parse_struct_fields(&type_lower) {
        let mut field_series_vec: Vec<Vec<Option<JsonValue>>> = (0..fields.len())
            .map(|_| Vec::with_capacity(values.len()))
            .collect();
        for v in values.iter() {
            // #610: Accept string that parses as JSON object or array (e.g. Python tuple serialized as "[1, \"y\"]").
            let effective: Option<JsonValue> = match v.as_ref() {
                Some(JsonValue::String(s)) => {
                    if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                        if parsed.is_object() || parsed.is_array() {
                            Some(parsed)
                        } else {
                            v.clone()
                        }
                    } else {
                        v.clone()
                    }
                }
                _ => v.clone(),
            };
            if effective
                .as_ref()
                .is_none_or(|x| matches!(x, JsonValue::Null))
            {
                for fc in &mut field_series_vec {
                    fc.push(None);
                }
            } else if let Some(obj) = effective.as_ref().and_then(|x| x.as_object()) {
                for (fi, (fname, _)) in fields.iter().enumerate() {
                    field_series_vec[fi].push(obj.get(fname).cloned());
                }
            } else if let Some(arr) = effective.as_ref().and_then(|x| x.as_array()) {
                for (fi, _) in fields.iter().enumerate() {
                    field_series_vec[fi].push(arr.get(fi).cloned());
                }
            } else {
                return Err(PolarsError::ComputeError(
                    "struct value must be object (by field name) or array (by position). \
                     PySpark accepts dict or tuple/list for struct columns."
                        .into(),
                ));
            }
        }
        let series_per_field: Vec<Series> = fields
            .iter()
            .enumerate()
            .map(|(fi, (fname, ftype))| json_values_to_series(&field_series_vec[fi], ftype, fname))
            .collect::<Result<Vec<_>, _>>()?;
        let field_refs: Vec<&Series> = series_per_field.iter().collect();
        let st = StructChunked::from_series(name.into(), values.len(), field_refs.iter().copied())
            .map_err(|e| PolarsError::ComputeError(format!("struct column: {e}").into()))?
            .into_series();
        return Ok(st);
    }

    match type_lower.as_str() {
        "int" | "bigint" | "long" => {
            let vals: Vec<Option<i64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Number(n) => n.as_i64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "double" | "float" => {
            let vals: Vec<Option<f64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Number(n) => n.as_f64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "string" | "str" | "varchar" => {
            let vals: Vec<Option<String>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => Some(s.clone()),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "boolean" | "bool" => {
            let vals: Vec<Option<bool>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Bool(b) => Some(*b),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "date" => {
            let vals: Vec<Option<i32>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => NaiveDate::parse_from_str(s, "%Y-%m-%d")
                            .ok()
                            .map(|d| (d - epoch).num_days() as i32),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let s = Series::new(name.into(), vals);
            s.cast(&DataType::Date)
                .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))
        }
        "timestamp" | "datetime" | "timestamp_ntz" => {
            let vals: Vec<Option<i64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => {
                            let parsed = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
                                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))
                                .or_else(|_| {
                                    NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").map_err(
                                        |e| PolarsError::ComputeError(e.to_string().into()),
                                    )
                                })
                                .or_else(|_| {
                                    NaiveDate::parse_from_str(s, "%Y-%m-%d")
                                        .map_err(|e| {
                                            PolarsError::ComputeError(e.to_string().into())
                                        })
                                        .and_then(|d| {
                                            d.and_hms_opt(0, 0, 0).ok_or_else(|| {
                                                PolarsError::ComputeError(
                                                    "date to datetime (0:0:0)".into(),
                                                )
                                            })
                                        })
                                });
                            parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                        }
                        JsonValue::Number(n) => n.as_i64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let s = Series::new(name.into(), vals);
            s.cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                .map_err(|e| PolarsError::ComputeError(format!("datetime cast: {e}").into()))
        }
        _ => Err(PolarsError::ComputeError(
            format!("json_values_to_series: unsupported type '{type_str}'").into(),
        )),
    }
}

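// Illustrative example (not part of the original source): json_values_to_series builds a
// typed column from optional JSON values; missing/null entries become nulls of that dtype.
#[cfg(test)]
mod json_values_to_series_examples {
    use super::*;

    #[test]
    fn builds_a_nullable_bigint_series() {
        let values = vec![Some(serde_json::json!(1)), None, Some(serde_json::json!(3))];
        let s = json_values_to_series(&values, "bigint", "n").unwrap();
        assert_eq!(s.len(), 3);
        assert_eq!(s.dtype(), &DataType::Int64);
        assert_eq!(s.null_count(), 1);
    }
}
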
/// Build a single Series from a JsonValue for use as list element or struct field.
fn json_value_to_series_single(
    value: &JsonValue,
    type_str: &str,
    name: &str,
) -> Result<Series, PolarsError> {
    use chrono::NaiveDate;
    let epoch = crate::date_utils::epoch_naive_date();
    match (value, type_str.trim().to_lowercase().as_str()) {
        (JsonValue::Null, _) => Ok(Series::new_null(name.into(), 1)),
        (JsonValue::Number(n), "int" | "bigint" | "long") => {
            Ok(Series::new(name.into(), vec![n.as_i64()]))
        }
        (JsonValue::Number(n), "double" | "float") => {
            Ok(Series::new(name.into(), vec![n.as_f64()]))
        }
        (JsonValue::Number(n), t) if is_decimal_type_str(t) => {
            Ok(Series::new(name.into(), vec![n.as_f64()]))
        }
        (JsonValue::String(s), "string" | "str" | "varchar") => {
            Ok(Series::new(name.into(), vec![s.as_str()]))
        }
        (JsonValue::Bool(b), "boolean" | "bool") => Ok(Series::new(name.into(), vec![*b])),
        (JsonValue::String(s), "date") => {
            let d = NaiveDate::parse_from_str(s, "%Y-%m-%d")
                .map_err(|e| PolarsError::ComputeError(format!("date parse: {e}").into()))?;
            let days = (d - epoch).num_days() as i32;
            let s = Series::new(name.into(), vec![days]).cast(&DataType::Date)?;
            Ok(s)
        }
        _ => Err(PolarsError::ComputeError(
            format!("json_value_to_series: unsupported {type_str} for {value:?}").into(),
        )),
    }
}

/// Build a struct Series from JsonValue::Object or JsonValue::Array (field-order) or Null.
#[allow(dead_code)]
fn json_object_or_array_to_struct_series(
    value: &JsonValue,
    fields: &[(String, String)],
    _name: &str,
) -> Result<Option<Series>, PolarsError> {
    use polars::prelude::StructChunked;
    if matches!(value, JsonValue::Null) {
        return Ok(None);
    }
    // #610: Accept string that parses as JSON object or array.
    let effective = match value {
        JsonValue::String(s) => {
            if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                if parsed.is_object() || parsed.is_array() {
                    parsed
                } else {
                    value.clone()
                }
            } else {
                value.clone()
            }
        }
        _ => value.clone(),
    };
    let mut field_series: Vec<Series> = Vec::with_capacity(fields.len());
    for (fname, ftype) in fields {
        let fval = if let Some(obj) = effective.as_object() {
            obj.get(fname).unwrap_or(&JsonValue::Null)
        } else if let Some(arr) = effective.as_array() {
            let idx = field_series.len();
            arr.get(idx).unwrap_or(&JsonValue::Null)
        } else {
            return Err(PolarsError::ComputeError(
                "struct value must be object (by field name) or array (by position). \
                 PySpark accepts dict or tuple/list for struct columns."
                    .into(),
            ));
        };
        let s = json_value_to_series_single(fval, ftype, fname)?;
        field_series.push(s);
    }
    let field_refs: Vec<&Series> = field_series.iter().collect();
    let st = StructChunked::from_series(PlSmallStr::EMPTY, 1, field_refs.iter().copied())
        .map_err(|e| PolarsError::ComputeError(format!("struct from value: {e}").into()))?
        .into_series();
    Ok(Some(st))
}

/// Build a single row's map column value as List(Struct{key, value}) element from a JSON object.
/// PySpark parity #627: create_dataframe_from_rows accepts dict for map columns.
fn json_object_to_map_struct_series(
    obj: &serde_json::Map<String, JsonValue>,
    key_type: &str,
    value_type: &str,
    key_dtype: &DataType,
    value_dtype: &DataType,
    _name: &str,
) -> Result<Series, PolarsError> {
    if obj.is_empty() {
        let key_series = Series::new("key".into(), Vec::<String>::new());
        let value_series = Series::new_empty(PlSmallStr::EMPTY, value_dtype);
        let st = StructChunked::from_series(
            PlSmallStr::EMPTY,
            0,
            [&key_series, &value_series].iter().copied(),
        )
        .map_err(|e| PolarsError::ComputeError(format!("map struct empty: {e}").into()))?
        .into_series();
        return Ok(st);
    }
    let keys: Vec<String> = obj.keys().cloned().collect();
    let mut value_series = None::<Series>;
    for v in obj.values() {
        let s = json_value_to_series_single(v, value_type, "value")?;
        value_series = Some(match value_series.take() {
            None => s,
            Some(mut acc) => {
                acc.extend(&s).map_err(|e| {
                    PolarsError::ComputeError(format!("map value extend: {e}").into())
                })?;
                acc
            }
        });
    }
    let value_series =
        value_series.unwrap_or_else(|| Series::new_empty(PlSmallStr::EMPTY, value_dtype));
    let key_series = Series::new("key".into(), keys);
    let key_series = if matches!(
        key_type.trim().to_lowercase().as_str(),
        "string" | "str" | "varchar"
    ) {
        key_series
    } else {
        key_series
            .cast(key_dtype)
            .map_err(|e| PolarsError::ComputeError(format!("map key cast: {e}").into()))?
    };
    let st = StructChunked::from_series(
        PlSmallStr::EMPTY,
        key_series.len(),
        [&key_series, &value_series].iter().copied(),
    )
    .map_err(|e| PolarsError::ComputeError(format!("map struct: {e}").into()))?
    .into_series();
    Ok(st)
}

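// Illustrative example (not part of the original source): a JSON object becomes one
// struct{key, value} row per key, matching the List(Struct) map encoding used above.
#[cfg(test)]
mod map_struct_examples {
    use super::*;

    #[test]
    fn builds_key_value_rows_from_a_json_object() {
        let value = serde_json::json!({"a": 1, "b": 2});
        let obj = value.as_object().unwrap();
        let s = json_object_to_map_struct_series(
            obj,
            "string",
            "bigint",
            &DataType::String,
            &DataType::Int64,
            "m",
        )
        .unwrap();
        // Two keys -> two struct rows.
        assert_eq!(s.len(), 2);
    }
}
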
use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::{Arc, Mutex, OnceLock};
use std::thread_local;

thread_local! {
    /// Thread-local SparkSession for UDF resolution in call_udf. Set by get_or_create.
    static THREAD_UDF_SESSION: RefCell<Option<SparkSession>> = const { RefCell::new(None) };
}

/// Set the thread-local session for UDF resolution (call_udf). Used by get_or_create.
pub(crate) fn set_thread_udf_session(session: SparkSession) {
    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = Some(session));
}

/// Get the thread-local session for UDF resolution. Used by call_udf.
pub(crate) fn get_thread_udf_session() -> Option<SparkSession> {
    THREAD_UDF_SESSION.with(|cell| cell.borrow().clone())
}

/// Clear the thread-local session used for UDF resolution.
pub(crate) fn clear_thread_udf_session() {
    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = None);
}

/// Catalog of global temporary views (process-scoped). Persists across sessions within the same process.
/// PySpark: createOrReplaceGlobalTempView / spark.table("global_temp.name").
static GLOBAL_TEMP_CATALOG: OnceLock<Arc<Mutex<HashMap<String, DataFrame>>>> = OnceLock::new();

fn global_temp_catalog() -> Arc<Mutex<HashMap<String, DataFrame>>> {
    GLOBAL_TEMP_CATALOG
        .get_or_init(|| Arc::new(Mutex::new(HashMap::new())))
        .clone()
}

/// Builder for creating a SparkSession with configuration options
#[derive(Clone)]
pub struct SparkSessionBuilder {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
}

impl Default for SparkSessionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkSessionBuilder {
    pub fn new() -> Self {
        SparkSessionBuilder {
            app_name: None,
            master: None,
            config: HashMap::new(),
        }
    }

    pub fn app_name(mut self, name: impl Into<String>) -> Self {
        self.app_name = Some(name.into());
        self
    }

    pub fn master(mut self, master: impl Into<String>) -> Self {
        self.master = Some(master.into());
        self
    }

    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    pub fn get_or_create(self) -> SparkSession {
        let session = SparkSession::new(self.app_name, self.master, self.config);
        set_thread_udf_session(session.clone());
        session
    }

    /// Apply configuration from a [`SparklessConfig`](crate::config::SparklessConfig).
    /// Merges warehouse dir, case sensitivity, and extra keys into the builder config.
    pub fn with_config(mut self, config: &crate::config::SparklessConfig) -> Self {
        for (k, v) in config.to_session_config() {
            self.config.insert(k, v);
        }
        self
    }
}

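// Illustrative example (not part of the original source): typical builder usage, mirroring
// PySpark's `SparkSession.builder` pattern, with a config key read back from the session.
#[cfg(test)]
mod builder_examples {
    use super::*;

    #[test]
    fn builds_a_configured_session() {
        let spark = SparkSession::builder()
            .app_name("example")
            .config("spark.sql.caseSensitive", "true")
            .get_or_create();
        assert!(spark.is_case_sensitive());
        assert_eq!(
            spark
                .get_config()
                .get("spark.sql.caseSensitive")
                .map(String::as_str),
            Some("true")
        );
    }
}
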
/// Catalog of temporary view names to DataFrames (session-scoped). Uses Arc<Mutex<>> for Send+Sync (Python bindings).
pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

/// Catalog of saved table names to DataFrames (session-scoped). Used by saveAsTable.
pub type TableCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

/// Names of databases/schemas created via CREATE DATABASE / CREATE SCHEMA (session-scoped). Persisted when SQL DDL runs.
pub type DatabaseCatalog = Arc<Mutex<HashSet<String>>>;

/// Main entry point for creating DataFrames and executing queries
/// Similar to PySpark's SparkSession but using Polars as the backend
#[derive(Clone)]
pub struct SparkSession {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
    /// Temporary views: name -> DataFrame. Session-scoped; cleared when session is dropped.
    pub(crate) catalog: TempViewCatalog,
    /// Saved tables (saveAsTable): name -> DataFrame. Session-scoped; separate namespace from temp views.
    pub(crate) tables: TableCatalog,
    /// Databases/schemas created via CREATE DATABASE / CREATE SCHEMA. Session-scoped; used by listDatabases/databaseExists.
    pub(crate) databases: DatabaseCatalog,
    /// UDF registry: Rust UDFs. Session-scoped.
    pub(crate) udf_registry: UdfRegistry,
}

impl SparkSession {
    pub fn new(
        app_name: Option<String>,
        master: Option<String>,
        config: HashMap<String, String>,
    ) -> Self {
        SparkSession {
            app_name,
            master,
            config,
            catalog: Arc::new(Mutex::new(HashMap::new())),
            tables: Arc::new(Mutex::new(HashMap::new())),
            databases: Arc::new(Mutex::new(HashSet::new())),
            udf_registry: UdfRegistry::new(),
        }
    }

    /// Register a DataFrame as a temporary view (PySpark: createOrReplaceTempView).
    /// The view is session-scoped and is dropped when the session is dropped.
    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
        let _ = self
            .catalog
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    /// Global temp view (PySpark: createGlobalTempView). Persists across sessions within the same process.
    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
        let _ = global_temp_catalog()
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    /// Global temp view (PySpark: createOrReplaceGlobalTempView). Persists across sessions within the same process.
    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
        let _ = global_temp_catalog()
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    /// Drop a temporary view by name (PySpark: catalog.dropTempView).
    /// No error if the view does not exist.
    pub fn drop_temp_view(&self, name: &str) {
        let _ = self.catalog.lock().map(|mut m| m.remove(name));
    }

    /// Drop a global temporary view (PySpark: catalog.dropGlobalTempView). Removes from process-wide catalog.
    pub fn drop_global_temp_view(&self, name: &str) -> bool {
        global_temp_catalog()
            .lock()
            .map(|mut m| m.remove(name).is_some())
            .unwrap_or(false)
    }

    /// Register a DataFrame as a saved table (PySpark: saveAsTable). Inserts into the tables catalog only.
    pub fn register_table(&self, name: &str, df: DataFrame) {
        let _ = self
            .tables
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    /// Register a database/schema name (from CREATE DATABASE / CREATE SCHEMA). Persisted in session for listDatabases/databaseExists.
    pub fn register_database(&self, name: &str) {
        let _ = self.databases.lock().map(|mut s| {
            s.insert(name.to_string());
        });
    }

    /// List database names: built-in "default", "global_temp", plus any created via CREATE DATABASE / CREATE SCHEMA.
    pub fn list_database_names(&self) -> Vec<String> {
        let mut names: Vec<String> = vec!["default".to_string(), "global_temp".to_string()];
        if let Ok(guard) = self.databases.lock() {
            let mut created: Vec<String> = guard.iter().cloned().collect();
            created.sort();
            names.extend(created);
        }
        names
    }

    /// True if the database name exists (default, global_temp, or created via CREATE DATABASE / CREATE SCHEMA).
    pub fn database_exists(&self, name: &str) -> bool {
        if name.eq_ignore_ascii_case("default") || name.eq_ignore_ascii_case("global_temp") {
            return true;
        }
        self.databases
            .lock()
            .map(|s| s.iter().any(|n| n.eq_ignore_ascii_case(name)))
            .unwrap_or(false)
    }

    /// Get a saved table by name (tables map only). Returns None if not in saved tables (temp views not checked).
    pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
        self.tables.lock().ok().and_then(|m| m.get(name).cloned())
    }

    /// True if the name exists in the saved-tables map (not temp views).
    pub fn saved_table_exists(&self, name: &str) -> bool {
        self.tables
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
    }

    /// Check if a table or temp view exists (PySpark: catalog.tableExists). True if name is in temp views, saved tables, global temp, or warehouse.
    pub fn table_exists(&self, name: &str) -> bool {
        // global_temp.xyz
        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
            return global_temp_catalog()
                .lock()
                .map(|m| m.contains_key(tbl))
                .unwrap_or(false);
        }
        if self
            .catalog
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
        {
            return true;
        }
        if self
            .tables
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
        {
            return true;
        }
        // Warehouse fallback
        if let Some(warehouse) = self.warehouse_dir() {
            let path = Path::new(warehouse).join(name);
            if path.is_dir() {
                return true;
            }
        }
        false
    }

    /// Return global temp view names (process-scoped). PySpark: catalog.listTables(dbName="global_temp").
    pub fn list_global_temp_view_names(&self) -> Vec<String> {
        global_temp_catalog()
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// Return temporary view names in this session.
    pub fn list_temp_view_names(&self) -> Vec<String> {
        self.catalog
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// Return saved table names in this session (saveAsTable / write_delta_table).
    pub fn list_table_names(&self) -> Vec<String> {
        self.tables
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    /// Drop a saved table by name (removes from tables catalog only). No-op if not present.
    pub fn drop_table(&self, name: &str) -> bool {
        self.tables
            .lock()
            .map(|mut m| m.remove(name).is_some())
            .unwrap_or(false)
    }

    /// Drop a database/schema by name (from DROP SCHEMA / DROP DATABASE). Removes from registered databases only.
    /// Does not drop "default" or "global_temp". No-op if not present (or if_exists). Returns true if removed.
    pub fn drop_database(&self, name: &str) -> bool {
        if name.eq_ignore_ascii_case("default") || name.eq_ignore_ascii_case("global_temp") {
            return false;
        }
        self.databases
            .lock()
            .map(|mut s| s.remove(name))
            .unwrap_or(false)
    }

    /// Parse "global_temp.xyz" into ("global_temp", "xyz"). Returns None for plain names.
    fn parse_global_temp_name(name: &str) -> Option<(&str, &str)> {
        if let Some(dot) = name.find('.') {
            let (db, tbl) = name.split_at(dot);
            if db.eq_ignore_ascii_case("global_temp") {
                return Some((db, tbl.strip_prefix('.').unwrap_or(tbl)));
            }
        }
        None
    }

    /// Return spark.sql.warehouse.dir from config if set. Enables disk-backed saveAsTable.
    pub fn warehouse_dir(&self) -> Option<&str> {
        self.config
            .get("spark.sql.warehouse.dir")
            .map(|s| s.as_str())
            .filter(|s| !s.is_empty())
    }

    /// Look up a table or temp view by name (PySpark: table(name)).
    /// Resolution order: (1) global_temp.xyz from global catalog, (2) temp view, (3) saved table, (4) warehouse.
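    ///
    /// # Example
    /// Illustrative sketch: register a temp view and read it back by name.
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().app_name("views").get_or_create();
    /// let df = spark.create_dataframe(
    ///     vec![(1, 25, "Alice".to_string())],
    ///     vec!["id", "age", "name"],
    /// )?;
    /// spark.create_or_replace_temp_view("people", df);
    /// let people = spark.table("people")?;
    /// #     let _ = people;
    /// #     Ok(())
    /// # }
    /// ```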
    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        // global_temp.xyz -> global catalog only
        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
            if let Some(df) = global_temp_catalog()
                .lock()
                .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
                .get(tbl)
                .cloned()
            {
                return Ok(df);
            }
            return Err(PolarsError::InvalidOperation(
                format!(
                    "Global temp view '{tbl}' not found. Register it with createOrReplaceGlobalTempView."
                )
                .into(),
            ));
        }
        // Session: temp view, saved table
        if let Some(df) = self
            .catalog
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
        {
            return Ok(df);
        }
        if let Some(df) = self
            .tables
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
        {
            return Ok(df);
        }
        // Warehouse fallback (disk-backed saveAsTable)
        if let Some(warehouse) = self.warehouse_dir() {
            let dir = Path::new(warehouse).join(name);
            if dir.is_dir() {
                // Read data.parquet (our convention) or the dir (Polars accepts dirs with parquet files)
                let data_file = dir.join("data.parquet");
                let read_path = if data_file.is_file() { data_file } else { dir };
                return self.read_parquet(&read_path);
            }
        }
        Err(PolarsError::InvalidOperation(
            format!(
                "Table or view '{name}' not found. Register it with create_or_replace_temp_view or saveAsTable."
            )
            .into(),
        ))
    }

    pub fn builder() -> SparkSessionBuilder {
        SparkSessionBuilder::new()
    }

    /// Create a session from a [`SparklessConfig`](crate::config::SparklessConfig).
    /// Equivalent to `SparkSession::builder().with_config(config).get_or_create()`.
    pub fn from_config(config: &crate::config::SparklessConfig) -> SparkSession {
        Self::builder().with_config(config).get_or_create()
    }

    /// Return a reference to the session config (for catalog/conf compatibility).
    pub fn get_config(&self) -> &HashMap<String, String> {
        &self.config
    }

    /// Whether column names are case-sensitive (PySpark: spark.sql.caseSensitive).
    /// Default is false (case-insensitive matching).
    pub fn is_case_sensitive(&self) -> bool {
        self.config
            .get("spark.sql.caseSensitive")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false)
    }

    /// Register a Rust UDF. Session-scoped. Use with call_udf. PySpark: spark.udf.register (Python) or equivalent.
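    ///
    /// # Example
    /// Illustrative sketch: register a pass-through UDF that returns its first input column.
    /// ```no_run
    /// use polars::prelude::{PolarsError, Series};
    /// use robin_sparkless::session::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("udf").get_or_create();
    /// spark
    ///     .register_udf("first_col", |cols: &[Series]| -> Result<Series, PolarsError> {
    ///         // Assumes at least one input column is passed at call time.
    ///         Ok(cols[0].clone())
    ///     })
    ///     .unwrap();
    /// ```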
    pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
    where
        F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
    {
        self.udf_registry.register_rust_udf(name, f)
    }

    /// Create a DataFrame from a vector of tuples (i64, i64, String)
    ///
    /// # Example
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df = spark.create_dataframe(
    ///     vec![
    ///         (1, 25, "Alice".to_string()),
    ///         (2, 30, "Bob".to_string()),
    ///     ],
    ///     vec!["id", "age", "name"],
    /// )?;
    /// #     let _ = df;
    /// #     Ok(())
    /// # }
    /// ```
    pub fn create_dataframe(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        if column_names.len() != 3 {
            return Err(PolarsError::ComputeError(
                format!(
                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
                    column_names.len()
                )
                .into(),
            ));
        }

        let mut cols: Vec<Series> = Vec::with_capacity(3);

        // First column: i64
        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
        cols.push(Series::new(column_names[0].into(), col0));

        // Second column: i64
        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
        cols.push(Series::new(column_names[1].into(), col1));

        // Third column: String
        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
        cols.push(Series::new(column_names[2].into(), col2));

        let pl_df = PlDataFrame::new_infer_height(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    /// Same as [`create_dataframe`](Self::create_dataframe) but returns [`EngineError`]. Use in bindings to avoid Polars.
    pub fn create_dataframe_engine(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, EngineError> {
        self.create_dataframe(data, column_names)
            .map_err(EngineError::from)
    }

    /// Create a DataFrame from a Polars DataFrame
    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
    }

    /// Infer dtype string from a single JSON value (for schema inference). Returns None for Null.
    fn infer_dtype_from_json_value(v: &JsonValue) -> Option<String> {
        match v {
            JsonValue::Null => None,
            JsonValue::Bool(_) => Some("boolean".to_string()),
            JsonValue::Number(n) => {
                if n.is_i64() {
                    Some("bigint".to_string())
                } else {
                    Some("double".to_string())
                }
            }
            JsonValue::String(s) => {
                if chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d").is_ok() {
                    Some("date".to_string())
                } else if chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f").is_ok()
                    || chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").is_ok()
                {
                    Some("timestamp".to_string())
                } else {
                    Some("string".to_string())
                }
            }
            JsonValue::Array(_) => Some("array".to_string()),
            JsonValue::Object(_) => Some("string".to_string()), // struct inference not implemented; treat as string for safety
        }
    }

    /// Infer schema (name, dtype_str) from JSON rows by scanning the first non-null value per column.
    /// Used by createDataFrame(data, schema=None) when schema is omitted or only column names given.
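    ///
    /// # Example
    /// Illustrative sketch: integers infer as `bigint`; other strings fall back to `string`.
    /// ```
    /// use robin_sparkless::session::SparkSession;
    /// use serde_json::json;
    ///
    /// let rows = vec![vec![json!(1), json!("a")], vec![json!(2), json!(null)]];
    /// let names = vec!["id".to_string(), "label".to_string()];
    /// let schema = SparkSession::infer_schema_from_json_rows(&rows, &names);
    /// assert_eq!(schema, vec![
    ///     ("id".to_string(), "bigint".to_string()),
    ///     ("label".to_string(), "string".to_string()),
    /// ]);
    /// ```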
    pub fn infer_schema_from_json_rows(
        rows: &[Vec<JsonValue>],
        names: &[String],
    ) -> Vec<(String, String)> {
        if names.is_empty() {
            return Vec::new();
        }
        let mut schema: Vec<(String, String)> = names
            .iter()
            .map(|n| (n.clone(), "string".to_string()))
            .collect();
        for (col_idx, (_, dtype_str)) in schema.iter_mut().enumerate() {
            for row in rows {
                let v = row.get(col_idx).unwrap_or(&JsonValue::Null);
                if let Some(dtype) = Self::infer_dtype_from_json_value(v) {
                    *dtype_str = dtype;
                    break;
                }
            }
        }
        schema
    }

    /// Create a DataFrame from rows and a schema (arbitrary column count and types).
    ///
    /// `rows`: each inner vec is one row; length must match schema length. Values are JSON-like (i64, f64, string, bool, null, object, array).
    /// `schema`: list of (column_name, dtype_string), e.g. `[("id", "bigint"), ("name", "string")]`.
    /// Supported dtype strings: bigint, int, long, double, float, string, str, varchar, boolean, bool, date, timestamp, datetime, list, array, array<element_type>, struct<field:type,...>.
    /// When `rows` is empty and `schema` is non-empty, returns an empty DataFrame with that schema (issue #519). Use with `write.format("parquet").saveAsTable(...)` then append; PySpark would fail with "can not infer schema from empty dataset".
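    ///
    /// # Example
    /// Illustrative sketch of rows-plus-schema input (marked `no_run`):
    /// ```no_run
    /// use robin_sparkless::session::SparkSession;
    /// use serde_json::json;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().app_name("rows").get_or_create();
    /// let rows = vec![
    ///     vec![json!(1), json!("Alice"), json!([85, 92])],
    ///     vec![json!(2), json!("Bob"), json!(null)],
    /// ];
    /// let schema = vec![
    ///     ("id".to_string(), "bigint".to_string()),
    ///     ("name".to_string(), "string".to_string()),
    ///     ("scores".to_string(), "array<int>".to_string()),
    /// ];
    /// let df = spark.create_dataframe_from_rows(rows, schema)?;
    /// #     let _ = df;
    /// #     Ok(())
    /// # }
    /// ```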
    pub fn create_dataframe_from_rows(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, PolarsError> {
        // #624: When schema is empty but rows are not, infer schema from rows (PySpark parity).
        let schema = if schema.is_empty() && !rows.is_empty() {
            let ncols = rows[0].len();
            let names: Vec<String> = (0..ncols).map(|i| format!("c{i}")).collect();
            Self::infer_schema_from_json_rows(&rows, &names)
        } else {
            schema
        };

        if schema.is_empty() {
            if rows.is_empty() {
                return Ok(DataFrame::from_polars_with_options(
                    PlDataFrame::new(0, vec![])?,
                    self.is_case_sensitive(),
                ));
            }
            return Err(PolarsError::InvalidOperation(
                "create_dataframe_from_rows: schema must not be empty when rows are not empty"
                    .into(),
            ));
        }
        use chrono::{NaiveDate, NaiveDateTime};

        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());

        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
            let type_lower = type_str.trim().to_lowercase();
            let s = match type_lower.as_str() {
                "int" | "integer" | "bigint" | "long" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_i64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "double" | "float" | "double_precision" => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                _ if is_decimal_type_str(&type_lower) => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "string" | "str" | "varchar" => {
                    let vals: Vec<Option<String>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => Some(s),
                                JsonValue::Null => None,
                                other => Some(other.to_string()),
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "boolean" | "bool" => {
                    let vals: Vec<Option<bool>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Bool(b) => Some(b),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "date" => {
                    let epoch = crate::date_utils::epoch_naive_date();
                    let vals: Vec<Option<i32>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                    .ok()
                                    .map(|d| (d - epoch).num_days() as i32),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Date)
                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
                }
                "timestamp" | "datetime" | "timestamp_ntz" => {
                    let vals: Vec<Option<i64>> =
                        rows.iter()
                            .map(|row| {
                                let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                                match v {
                                    JsonValue::String(s) => {
                                        let parsed = NaiveDateTime::parse_from_str(
                                            &s,
                                            "%Y-%m-%dT%H:%M:%S%.f",
                                        )
                                        .map_err(|e| {
                                            PolarsError::ComputeError(e.to_string().into())
                                        })
                                        .or_else(|_| {
                                            NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S")
                                                .map_err(|e| {
                                                    PolarsError::ComputeError(e.to_string().into())
                                                })
                                        })
                                        .or_else(|_| {
                                            NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                                .map_err(|e| {
                                                    PolarsError::ComputeError(e.to_string().into())
                                                })
                                                .and_then(|d| {
                                                    d.and_hms_opt(0, 0, 0).ok_or_else(|| {
                                                        PolarsError::ComputeError(
                                                            "date to datetime (0:0:0)".into(),
                                                        )
                                                    })
                                                })
                                        });
                                        parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                                    }
                                    JsonValue::Number(n) => n.as_i64(),
                                    JsonValue::Null => None,
                                    _ => None,
                                }
                            })
                            .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
                        })?
                }
                "list" | "array" => {
                    // PySpark parity: ("col", "list") or ("col", "array"); infer element type from first non-null array.
                    let (elem_type, inner_dtype) = infer_list_element_type(&rows, col_idx)
                        .unwrap_or(("bigint".to_string(), DataType::Int64));
                    let n = rows.len();
                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if let JsonValue::Null = &v {
                            builder.append_null();
                        } else if let Some(arr) = json_value_to_array(&v) {
                            // #625: Array, Object with "0","1",..., or string that parses as JSON array (PySpark list parity).
                            let elem_series: Vec<Series> = arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        } else {
                            // #611: PySpark accepts single value as one-element list.
                            let single_arr = [v];
                            let elem_series: Vec<Series> = single_arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        }
                    }
                    builder.finish().into_series()
                }
                _ if parse_array_element_type(&type_lower).is_some() => {
                    let elem_type = parse_array_element_type(&type_lower).unwrap_or_else(|| {
                        unreachable!("guard above ensures parse_array_element_type returned Some")
                    });
                    let inner_dtype = json_type_str_to_polars(&elem_type)
                        .ok_or_else(|| {
                            PolarsError::ComputeError(
                                format!(
                                    "create_dataframe_from_rows: array element type '{elem_type}' not supported"
                                )
                                .into(),
                            )
                        })?;
                    let n = rows.len();
                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if let JsonValue::Null = &v {
                            builder.append_null();
                        } else if let Some(arr) = json_value_to_array(&v) {
                            // #625: Array, Object with "0","1",..., or string that parses as JSON array (PySpark list parity).
                            let elem_series: Vec<Series> = arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        } else {
                            // #611: PySpark accepts single value as one-element list.
                            let single_arr = [v];
                            let elem_series: Vec<Series> = single_arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
1318                            })?;
1319                            builder.append_series(&s)?;
1320                        }
1321                    }
1322                    builder.finish().into_series()
1323                }
1324                _ if parse_map_key_value_types(&type_lower).is_some() => {
1325                    let (key_type, value_type) = parse_map_key_value_types(&type_lower)
1326                        .unwrap_or_else(|| unreachable!("guard ensures Some"));
1327                    let key_dtype = json_type_str_to_polars(&key_type).ok_or_else(|| {
1328                        PolarsError::ComputeError(
1329                            format!(
1330                                "create_dataframe_from_rows: map key type '{key_type}' not supported"
1331                            )
1332                            .into(),
1333                        )
1334                    })?;
1335                    let value_dtype = json_type_str_to_polars(&value_type).ok_or_else(|| {
1336                        PolarsError::ComputeError(
1337                            format!(
1338                                "create_dataframe_from_rows: map value type '{value_type}' not supported"
1339                            )
1340                            .into(),
1341                        )
1342                    })?;
1343                    let struct_dtype = DataType::Struct(vec![
1344                        Field::new("key".into(), key_dtype.clone()),
1345                        Field::new("value".into(), value_dtype.clone()),
1346                    ]);
1347                    let n = rows.len();
1348                    let mut builder = get_list_builder(&struct_dtype, 64, n, name.as_str().into());
1349                    for row in rows.iter() {
1350                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1351                        if matches!(v, JsonValue::Null) {
1352                            builder.append_null();
1353                        } else if let Some(obj) = v.as_object() {
1354                            let st = json_object_to_map_struct_series(
1355                                obj,
1356                                &key_type,
1357                                &value_type,
1358                                &key_dtype,
1359                                &value_dtype,
1360                                name,
1361                            )?;
1362                            builder.append_series(&st)?;
1363                        } else {
1364                            return Err(PolarsError::ComputeError(
1365                                format!(
1366                                    "create_dataframe_from_rows: map column '{name}' expects JSON object (dict), got {:?}",
1367                                    v
1368                                )
1369                                .into(),
1370                            ));
1371                        }
1372                    }
1373                    builder.finish().into_series()
1374                }
1375                _ if parse_struct_fields(&type_lower).is_some() => {
1376                    let values: Vec<Option<JsonValue>> =
1377                        rows.iter().map(|row| row.get(col_idx).cloned()).collect();
1378                    json_values_to_series(&values, &type_lower, name)?
1379                }
1380                _ => {
1381                    return Err(PolarsError::ComputeError(
1382                        format!(
1383                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
1384                        )
1385                        .into(),
1386                    ));
1387                }
1388            };
1389            cols.push(s);
1390        }
1391
1392        let pl_df = PlDataFrame::new_infer_height(cols.iter().map(|s| s.clone().into()).collect())?;
1393        Ok(DataFrame::from_polars_with_options(
1394            pl_df,
1395            self.is_case_sensitive(),
1396        ))
1397    }
1398
1399    /// Same as [`create_dataframe_from_rows`](Self::create_dataframe_from_rows) but returns [`EngineError`]. Use in bindings to avoid handling Polars error types.
1400    pub fn create_dataframe_from_rows_engine(
1401        &self,
1402        rows: Vec<Vec<JsonValue>>,
1403        schema: Vec<(String, String)>,
1404    ) -> Result<DataFrame, EngineError> {
1405        self.create_dataframe_from_rows(rows, schema)
1406            .map_err(EngineError::from)
1407    }
1408
1409    /// Create a DataFrame with a single column `id` (bigint) containing values from start to end (exclusive) with step.
1410    /// PySpark: spark.range(end) or spark.range(start, end, step).
1411    ///
1412    /// - `range(end)` → 0 to end-1, step 1
1413    /// - `range(start, end)` → start to end-1, step 1
1414    /// - `range(start, end, step)` → start, start+step, ... up to but not including end
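    ///
    /// # Example
    /// A minimal sketch of the three-argument form:
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// // Produces an `id` column with 0, 1, 2, 3, 4
    /// let df_result = spark.range(0, 5, 1);
    /// // Handle the Result as appropriate in your application
    /// ```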
1415    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
1416        if step == 0 {
1417            return Err(PolarsError::InvalidOperation(
1418                "range: step must not be 0".into(),
1419            ));
1420        }
1421        let mut vals: Vec<i64> = Vec::new();
1422        let mut v = start;
1423        if step > 0 {
1424            while v < end {
1425                vals.push(v);
1426                v = v.saturating_add(step);
1427            }
1428        } else {
1429            while v > end {
1430                vals.push(v);
1431                v = v.saturating_add(step);
1432            }
1433        }
1434        let col = Series::new("id".into(), vals);
1435        let pl_df = PlDataFrame::new_infer_height(vec![col.into()])?;
1436        Ok(DataFrame::from_polars_with_options(
1437            pl_df,
1438            self.is_case_sensitive(),
1439        ))
1440    }
1441
1442    /// Read a CSV file.
1443    ///
1444    /// Uses Polars' CSV reader with default options:
1445    /// - The first row is treated as a header (has_header = true)
1446    /// - Schema is inferred from the first 100 rows
1447    ///
1448    /// # Example
1449    /// ```
1450    /// use robin_sparkless::SparkSession;
1451    ///
1452    /// let spark = SparkSession::builder().app_name("test").get_or_create();
1453    /// let df_result = spark.read_csv("data.csv");
1454    /// // Handle the Result as appropriate in your application
1455    /// ```
1456    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1457        use polars::prelude::*;
1458        let path = path.as_ref();
1459        if !path.exists() {
1460            return Err(PolarsError::ComputeError(
1461                format!("read_csv: file not found: {}", path.display()).into(),
1462            ));
1463        }
1464        let path_display = path.display();
1465        // Use LazyCsvReader - call finish() to get LazyFrame, then collect
1466        let pl_path = PlRefPath::try_from_path(path).map_err(|e| {
1467            PolarsError::ComputeError(format!("read_csv({path_display}): path: {e}").into())
1468        })?;
1469        let lf = LazyCsvReader::new(pl_path)
1470            .with_has_header(true)
1471            .with_infer_schema_length(Some(100))
1472            .finish()
1473            .map_err(|e| {
1474                PolarsError::ComputeError(
1475                    format!(
1476                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
1477                    )
1478                    .into(),
1479                )
1480            })?;
1481        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
1482            lf,
1483            self.is_case_sensitive(),
1484        ))
1485    }
1486
1487    /// Same as [`read_csv`](Self::read_csv) but returns [`EngineError`]. Use in bindings to avoid handling Polars error types.
1488    pub fn read_csv_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
1489        self.read_csv(path).map_err(EngineError::from)
1490    }
1491
1492    /// Read a Parquet file.
1493    ///
1494    /// Uses Polars' Parquet reader. Parquet files carry an embedded schema, so
1495    /// no separate schema inference is needed.
1496    ///
1497    /// # Example
1498    /// ```
1499    /// use robin_sparkless::SparkSession;
1500    ///
1501    /// let spark = SparkSession::builder().app_name("test").get_or_create();
1502    /// let df_result = spark.read_parquet("data.parquet");
1503    /// // Handle the Result as appropriate in your application
1504    /// ```
1505    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1506        use polars::prelude::*;
1507        let path = path.as_ref();
1508        if !path.exists() {
1509            return Err(PolarsError::ComputeError(
1510                format!("read_parquet: file not found: {}", path.display()).into(),
1511            ));
1512        }
1513        // Use LazyFrame::scan_parquet
1514        let pl_path = PlRefPath::try_from_path(path)
1515            .map_err(|e| PolarsError::ComputeError(format!("read_parquet: path: {e}").into()))?;
1516        let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())?;
1517        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
1518            lf,
1519            self.is_case_sensitive(),
1520        ))
1521    }
1522
1523    /// Same as [`read_parquet`](Self::read_parquet) but returns [`EngineError`]. Use in bindings to avoid handling Polars error types.
1524    pub fn read_parquet_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
1525        self.read_parquet(path).map_err(EngineError::from)
1526    }
1527
1528    /// Read a JSON file (JSONL format - one JSON object per line).
1529    ///
1530    /// Uses Polars' JSONL reader with default options:
1531    /// - Schema is inferred from the first 100 rows
1532    ///
1533    /// # Example
1534    /// ```
1535    /// use robin_sparkless::SparkSession;
1536    ///
1537    /// let spark = SparkSession::builder().app_name("test").get_or_create();
1538    /// let df_result = spark.read_json("data.json");
1539    /// // Handle the Result as appropriate in your application
1540    /// ```
1541    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1542        use polars::prelude::*;
1543        use std::num::NonZeroUsize;
1544        let path = path.as_ref();
1545        if !path.exists() {
1546            return Err(PolarsError::ComputeError(
1547                format!("read_json: file not found: {}", path.display()).into(),
1548            ));
1549        }
1550        // Use LazyJsonLineReader - call finish() to get LazyFrame, then collect
1551        let pl_path = PlRefPath::try_from_path(path)
1552            .map_err(|e| PolarsError::ComputeError(format!("read_json: path: {e}").into()))?;
1553        let lf = LazyJsonLineReader::new(pl_path)
1554            .with_infer_schema_length(NonZeroUsize::new(100))
1555            .finish()?;
1556        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
1557            lf,
1558            self.is_case_sensitive(),
1559        ))
1560    }
1561
1562    /// Same as [`read_json`](Self::read_json) but returns [`EngineError`]. Use in bindings to avoid handling Polars error types.
1563    pub fn read_json_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
1564        self.read_json(path).map_err(EngineError::from)
1565    }
1566
1567    /// Execute a SQL query (SELECT only). Tables must be registered with `create_or_replace_temp_view`.
1568    /// Requires the `sql` feature. Supports: SELECT (columns or *), FROM (single table or JOIN),
1569    /// WHERE (basic predicates), GROUP BY + aggregates, ORDER BY, LIMIT.
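    ///
    /// # Example
    /// A minimal sketch; `people` is a placeholder name for a previously registered view:
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark.sql("SELECT name FROM people WHERE age > 21");
    /// // Returns an error if `people` has not been registered as a temp view
    /// ```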
1570    #[cfg(feature = "sql")]
1571    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
1572        crate::sql::execute_sql(self, query)
1573    }
1574
1575    /// Execute a SQL query (stub when `sql` feature is disabled).
1576    #[cfg(not(feature = "sql"))]
1577    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
1578        Err(PolarsError::InvalidOperation(
1579            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
1580        ))
1581    }
1582
1583    /// Same as [`table`](Self::table) but returns [`EngineError`]. Use in bindings to avoid handling Polars error types.
1584    pub fn table_engine(&self, name: &str) -> Result<DataFrame, EngineError> {
1585        self.table(name).map_err(EngineError::from)
1586    }
1587
1588    /// Returns true if the string looks like a filesystem path (has separators or path exists).
1589    fn looks_like_path(s: &str) -> bool {
1590        s.contains('/') || s.contains('\\') || Path::new(s).exists()
1591    }
1592
1593    /// Read a Delta table from path (latest version). Internal; use read_delta(name_or_path: &str) for dispatch.
1594    #[cfg(feature = "delta")]
1595    pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1596        crate::delta::read_delta(path, self.is_case_sensitive())
1597    }
1598
1599    /// Read a Delta table at path, with an optional version. Internal; use read_delta_with_version for dispatch.
1600    #[cfg(feature = "delta")]
1601    pub fn read_delta_path_with_version(
1602        &self,
1603        path: impl AsRef<Path>,
1604        version: Option<i64>,
1605    ) -> Result<DataFrame, PolarsError> {
1606        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
1607    }
1608
1609    /// Read a Delta table or in-memory table by name/path. If name_or_path looks like a path, reads from Delta on disk; else resolves as table name (temp view then saved table).
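    ///
    /// # Example
    /// A minimal sketch; `events` and the path below are placeholders:
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// // By registered table/view name:
    /// let by_name = spark.read_delta("events");
    /// // By filesystem path (contains a separator, so it is read from disk):
    /// let by_path = spark.read_delta("/tmp/delta/events");
    /// // Handle the Results as appropriate in your application
    /// ```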
1610    #[cfg(feature = "delta")]
1611    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
1612        if Self::looks_like_path(name_or_path) {
1613            self.read_delta_path(Path::new(name_or_path))
1614        } else {
1615            self.table(name_or_path)
1616        }
1617    }
1618
1619    #[cfg(feature = "delta")]
1620    pub fn read_delta_with_version(
1621        &self,
1622        name_or_path: &str,
1623        version: Option<i64>,
1624    ) -> Result<DataFrame, PolarsError> {
1625        if Self::looks_like_path(name_or_path) {
1626            self.read_delta_path_with_version(Path::new(name_or_path), version)
1627        } else {
1628            // In-memory tables have no version; ignore version and return table
1629            self.table(name_or_path)
1630        }
1631    }
1632
1633    /// Stub when `delta` feature is disabled. Still supports reading by table name.
1634    #[cfg(not(feature = "delta"))]
1635    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
1636        if Self::looks_like_path(name_or_path) {
1637            Err(PolarsError::InvalidOperation(
1638                "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
1639            ))
1640        } else {
1641            self.table(name_or_path)
1642        }
1643    }
1644
1645    #[cfg(not(feature = "delta"))]
1646    pub fn read_delta_with_version(
1647        &self,
1648        name_or_path: &str,
1649        version: Option<i64>,
1650    ) -> Result<DataFrame, PolarsError> {
1651        let _ = version;
1652        self.read_delta(name_or_path)
1653    }
1654
1655    /// Path-only read_delta (used by DataFrameReader::load when the format is "delta"). Requires the delta feature.
1656    #[cfg(feature = "delta")]
1657    pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1658        self.read_delta_path(path)
1659    }
1660
1661    #[cfg(not(feature = "delta"))]
1662    pub fn read_delta_from_path(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1663        Err(PolarsError::InvalidOperation(
1664            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
1665        ))
1666    }
1667
1668    /// Stop the session and clean up resources.
1669    pub fn stop(&self) {
1670        // Best-effort cleanup. This is primarily for PySpark parity so that `spark.stop()`
1671        // exists and can be called in teardown.
1672        let _ = self.catalog.lock().map(|mut m| m.clear());
1673        let _ = self.tables.lock().map(|mut m| m.clear());
1674        let _ = self.databases.lock().map(|mut s| s.clear());
1675        let _ = self.udf_registry.clear();
1676        clear_thread_udf_session();
1677    }
1678}
1679
1680/// DataFrameReader for reading various file formats.
1681/// Similar to PySpark's DataFrameReader, with option/options/format/load/table.
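///
/// # Example
/// A minimal sketch of the builder-style API ("data.csv" is a placeholder path):
/// ```
/// use robin_sparkless::SparkSession;
///
/// let spark = SparkSession::builder().app_name("test").get_or_create();
/// let df_result = spark.read().option("header", "true").csv("data.csv");
/// // Handle the Result as appropriate in your application
/// ```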
1682pub struct DataFrameReader {
1683    session: SparkSession,
1684    options: HashMap<String, String>,
1685    format: Option<String>,
1686}
1687
1688impl DataFrameReader {
1689    pub fn new(session: SparkSession) -> Self {
1690        DataFrameReader {
1691            session,
1692            options: HashMap::new(),
1693            format: None,
1694        }
1695    }
1696
1697    /// Add a single option (PySpark: option(key, value)). Returns self for chaining.
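    ///
    /// # Example
    /// A sketch using the CSV option keys handled by `apply_csv_options` (header, inferSchema,
    /// inferSchemaLength, sep, nullValue); "data.csv" is a placeholder path:
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark
    ///     .read()
    ///     .option("sep", ";")
    ///     .option("nullValue", "NA")
    ///     .csv("data.csv");
    /// // Handle the Result as appropriate in your application
    /// ```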
1698    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1699        self.options.insert(key.into(), value.into());
1700        self
1701    }
1702
1703    /// Add multiple options (PySpark: options(**kwargs)). Returns self for chaining.
1704    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
1705        for (k, v) in opts {
1706            self.options.insert(k, v);
1707        }
1708        self
1709    }
1710
1711    /// Set the format for load() (PySpark: format("parquet"), etc.).
1712    pub fn format(mut self, fmt: impl Into<String>) -> Self {
1713        self.format = Some(fmt.into());
1714        self
1715    }
1716
1717    /// Set the schema (PySpark: schema(schema)). Stub: accepted but currently ignored (not stored or applied).
1718    pub fn schema(self, _schema: impl Into<String>) -> Self {
1719        self
1720    }
1721
1722    /// Load data from path using format (or infer from extension) and options.
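    ///
    /// # Example
    /// A minimal sketch; with no explicit format(), the file extension of the path decides:
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark.read().format("parquet").load("data.parquet");
    /// // Handle the Result as appropriate in your application
    /// ```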
1723    pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1724        let path = path.as_ref();
1725        let fmt = self.format.clone().or_else(|| {
1726            path.extension()
1727                .and_then(|e| e.to_str())
1728                .map(|s| s.to_lowercase())
1729        });
1730        match fmt.as_deref() {
1731            Some("parquet") => self.parquet(path),
1732            Some("csv") => self.csv(path),
1733            Some("json") | Some("jsonl") => self.json(path),
1734            #[cfg(feature = "delta")]
1735            Some("delta") => self.session.read_delta_from_path(path),
1736            _ => Err(PolarsError::ComputeError(
1737                format!(
1738                    "load: could not infer format for path '{}'. Use format('parquet'|'csv'|'json') before load.",
1739                    path.display()
1740                )
1741                .into(),
1742            )),
1743        }
1744    }
1745
1746    /// Return the named table/view (PySpark: table(name)).
1747    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
1748        self.session.table(name)
1749    }
1750
1751    fn apply_csv_options(
1752        &self,
1753        reader: polars::prelude::LazyCsvReader,
1754    ) -> polars::prelude::LazyCsvReader {
1755        use polars::prelude::NullValues;
1756        let mut r = reader;
1757        if let Some(v) = self.options.get("header") {
1758            let has_header = v.eq_ignore_ascii_case("true") || v == "1";
1759            r = r.with_has_header(has_header);
1760        }
1761        if let Some(v) = self.options.get("inferSchema") {
1762            if v.eq_ignore_ascii_case("true") || v == "1" {
1763                let n = self
1764                    .options
1765                    .get("inferSchemaLength")
1766                    .and_then(|s| s.parse::<usize>().ok())
1767                    .unwrap_or(100);
1768                r = r.with_infer_schema_length(Some(n));
1769            } else {
1770                // inferSchema=false: do not infer types (PySpark parity #543)
1771                r = r.with_infer_schema_length(Some(0));
1772            }
1773        } else if let Some(v) = self.options.get("inferSchemaLength") {
1774            if let Ok(n) = v.parse::<usize>() {
1775                r = r.with_infer_schema_length(Some(n));
1776            }
1777        }
1778        if let Some(sep) = self.options.get("sep") {
1779            if let Some(b) = sep.bytes().next() {
1780                r = r.with_separator(b);
1781            }
1782        }
1783        if let Some(null_val) = self.options.get("nullValue") {
1784            r = r.with_null_values(Some(NullValues::AllColumnsSingle(null_val.clone().into())));
1785        }
1786        r
1787    }
1788
1789    fn apply_json_options(
1790        &self,
1791        reader: polars::prelude::LazyJsonLineReader,
1792    ) -> polars::prelude::LazyJsonLineReader {
1793        use std::num::NonZeroUsize;
1794        let mut r = reader;
1795        if let Some(v) = self.options.get("inferSchemaLength") {
1796            if let Ok(n) = v.parse::<usize>() {
1797                r = r.with_infer_schema_length(NonZeroUsize::new(n));
1798            }
1799        }
1800        r
1801    }
1802
1803    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1804        use polars::prelude::*;
1805        let path = path.as_ref();
1806        let path_display = path.display();
1807        let pl_path = PlRefPath::try_from_path(path).map_err(|e| {
1808            PolarsError::ComputeError(format!("csv({path_display}): path: {e}").into())
1809        })?;
1810        let reader = LazyCsvReader::new(pl_path);
1811        let reader = if self.options.is_empty() {
1812            reader
1813                .with_has_header(true)
1814                .with_infer_schema_length(Some(100))
1815        } else {
1816            self.apply_csv_options(
1817                reader
1818                    .with_has_header(true)
1819                    .with_infer_schema_length(Some(100)),
1820            )
1821        };
1822        let lf = reader.finish().map_err(|e| {
1823            PolarsError::ComputeError(format!("read csv({path_display}): {e}").into())
1824        })?;
1825        let pl_df = lf.collect().map_err(|e| {
1826            PolarsError::ComputeError(
1827                format!("read csv({path_display}): collect failed: {e}").into(),
1828            )
1829        })?;
1830        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1831            pl_df,
1832            self.session.is_case_sensitive(),
1833        ))
1834    }
1835
1836    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1837        use polars::prelude::*;
1838        let path = path.as_ref();
1839        let pl_path = PlRefPath::try_from_path(path)
1840            .map_err(|e| PolarsError::ComputeError(format!("parquet: path: {e}").into()))?;
1841        let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())?;
1842        let pl_df = lf.collect()?;
1843        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1844            pl_df,
1845            self.session.is_case_sensitive(),
1846        ))
1847    }
1848
1849    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1850        use polars::prelude::*;
1851        use std::num::NonZeroUsize;
1852        let path = path.as_ref();
1853        let pl_path = PlRefPath::try_from_path(path)
1854            .map_err(|e| PolarsError::ComputeError(format!("json: path: {e}").into()))?;
1855        let reader = LazyJsonLineReader::new(pl_path);
1856        let reader = if self.options.is_empty() {
1857            reader.with_infer_schema_length(NonZeroUsize::new(100))
1858        } else {
1859            self.apply_json_options(reader.with_infer_schema_length(NonZeroUsize::new(100)))
1860        };
1861        let lf = reader.finish()?;
1862        let pl_df = lf.collect()?;
1863        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1864            pl_df,
1865            self.session.is_case_sensitive(),
1866        ))
1867    }
1868
1869    #[cfg(feature = "delta")]
1870    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1871        self.session.read_delta_from_path(path)
1872    }
1873}
1874
1875impl SparkSession {
1876    /// Get a DataFrameReader for reading files
1877    pub fn read(&self) -> DataFrameReader {
1878        DataFrameReader::new(SparkSession {
1879            app_name: self.app_name.clone(),
1880            master: self.master.clone(),
1881            config: self.config.clone(),
1882            catalog: self.catalog.clone(),
1883            tables: self.tables.clone(),
1884            databases: self.databases.clone(),
1885            udf_registry: self.udf_registry.clone(),
1886        })
1887    }
1888}
1889
1890impl Default for SparkSession {
1891    fn default() -> Self {
1892        Self::builder().get_or_create()
1893    }
1894}
1895
1896#[cfg(test)]
1897mod tests {
1898    use super::*;
1899
1900    #[test]
1901    fn test_spark_session_builder_basic() {
1902        let spark = SparkSession::builder().app_name("test_app").get_or_create();
1903
1904        assert_eq!(spark.app_name, Some("test_app".to_string()));
1905    }
1906
1907    #[test]
1908    fn test_spark_session_builder_with_master() {
1909        let spark = SparkSession::builder()
1910            .app_name("test_app")
1911            .master("local[*]")
1912            .get_or_create();
1913
1914        assert_eq!(spark.app_name, Some("test_app".to_string()));
1915        assert_eq!(spark.master, Some("local[*]".to_string()));
1916    }
1917
1918    #[test]
1919    fn test_spark_session_builder_with_config() {
1920        let spark = SparkSession::builder()
1921            .app_name("test_app")
1922            .config("spark.executor.memory", "4g")
1923            .config("spark.driver.memory", "2g")
1924            .get_or_create();
1925
1926        assert_eq!(
1927            spark.config.get("spark.executor.memory"),
1928            Some(&"4g".to_string())
1929        );
1930        assert_eq!(
1931            spark.config.get("spark.driver.memory"),
1932            Some(&"2g".to_string())
1933        );
1934    }
1935
1936    #[test]
1937    fn test_spark_session_default() {
1938        let spark = SparkSession::default();
1939        assert!(spark.app_name.is_none());
1940        assert!(spark.master.is_none());
1941        assert!(spark.config.is_empty());
1942    }
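
    /// Minimal sketch exercising `range`: positive step, negative step, and the step == 0 error.
    #[test]
    fn test_range_basic_and_zero_step() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        // 0, 1, 2, 3, 4 (end is exclusive)
        let df = spark.range(0, 5, 1).unwrap();
        assert_eq!(df.count().unwrap(), 5);
        // 5, 3, 1 (counts down toward, but not including, end)
        let df_desc = spark.range(5, 0, -2).unwrap();
        assert_eq!(df_desc.count().unwrap(), 3);
        // step == 0 is rejected
        assert!(spark.range(0, 5, 0).is_err());
    }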
1943
1944    #[test]
1945    fn test_create_dataframe_success() {
1946        let spark = SparkSession::builder().app_name("test").get_or_create();
1947        let data = vec![
1948            (1i64, 25i64, "Alice".to_string()),
1949            (2i64, 30i64, "Bob".to_string()),
1950        ];
1951
1952        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);
1953
1954        assert!(result.is_ok());
1955        let df = result.unwrap();
1956        assert_eq!(df.count().unwrap(), 2);
1957
1958        let columns = df.columns().unwrap();
1959        assert!(columns.contains(&"id".to_string()));
1960        assert!(columns.contains(&"age".to_string()));
1961        assert!(columns.contains(&"name".to_string()));
1962    }
1963
1964    #[test]
1965    fn test_create_dataframe_wrong_column_count() {
1966        let spark = SparkSession::builder().app_name("test").get_or_create();
1967        let data = vec![(1i64, 25i64, "Alice".to_string())];
1968
1969        // Too few columns
1970        let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
1971        assert!(result.is_err());
1972
1973        // Too many columns
1974        let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
1975        assert!(result.is_err());
1976    }
1977
1978    #[test]
1979    fn test_create_dataframe_from_rows_empty_schema_with_rows_returns_error() {
1980        let spark = SparkSession::builder().app_name("test").get_or_create();
1981        let rows: Vec<Vec<JsonValue>> = vec![vec![]];
1982        let schema: Vec<(String, String)> = vec![];
1983        let result = spark.create_dataframe_from_rows(rows, schema);
1984        match &result {
1985            Err(e) => assert!(e.to_string().contains("schema must not be empty")),
1986            Ok(_) => panic!("expected error for empty schema with non-empty rows"),
1987        }
1988    }
1989
1990    #[test]
1991    fn test_create_dataframe_from_rows_empty_data_with_schema() {
1992        let spark = SparkSession::builder().app_name("test").get_or_create();
1993        let rows: Vec<Vec<JsonValue>> = vec![];
1994        let schema = vec![
1995            ("a".to_string(), "int".to_string()),
1996            ("b".to_string(), "string".to_string()),
1997        ];
1998        let result = spark.create_dataframe_from_rows(rows, schema);
1999        let df = result.unwrap();
2000        assert_eq!(df.count().unwrap(), 0);
2001        assert_eq!(df.collect_inner().unwrap().get_column_names(), &["a", "b"]);
2002    }
2003
2004    #[test]
2005    fn test_create_dataframe_from_rows_empty_schema_empty_data() {
2006        let spark = SparkSession::builder().app_name("test").get_or_create();
2007        let rows: Vec<Vec<JsonValue>> = vec![];
2008        let schema: Vec<(String, String)> = vec![];
2009        let result = spark.create_dataframe_from_rows(rows, schema);
2010        let df = result.unwrap();
2011        assert_eq!(df.count().unwrap(), 0);
2012        assert_eq!(df.collect_inner().unwrap().get_column_names().len(), 0);
2013    }
2014
2015    /// create_dataframe_from_rows: struct column as JSON object (by field name). PySpark parity #600.
2016    #[test]
2017    fn test_create_dataframe_from_rows_struct_as_object() {
2018        use serde_json::json;
2019
2020        let spark = SparkSession::builder().app_name("test").get_or_create();
2021        let schema = vec![
2022            ("id".to_string(), "string".to_string()),
2023            (
2024                "nested".to_string(),
2025                "struct<a:bigint,b:string>".to_string(),
2026            ),
2027        ];
2028        let rows: Vec<Vec<JsonValue>> = vec![
2029            vec![json!("x"), json!({"a": 1, "b": "y"})],
2030            vec![json!("z"), json!({"a": 2, "b": "w"})],
2031        ];
2032        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2033        assert_eq!(df.count().unwrap(), 2);
2034        let collected = df.collect_inner().unwrap();
2035        assert_eq!(collected.get_column_names(), &["id", "nested"]);
2036    }
2037
2038    /// create_dataframe_from_rows: struct column as JSON array (by position). PySpark parity #600.
2039    #[test]
2040    fn test_create_dataframe_from_rows_struct_as_array() {
2041        use serde_json::json;
2042
2043        let spark = SparkSession::builder().app_name("test").get_or_create();
2044        let schema = vec![
2045            ("id".to_string(), "string".to_string()),
2046            (
2047                "nested".to_string(),
2048                "struct<a:bigint,b:string>".to_string(),
2049            ),
2050        ];
2051        let rows: Vec<Vec<JsonValue>> = vec![
2052            vec![json!("x"), json!([1, "y"])],
2053            vec![json!("z"), json!([2, "w"])],
2054        ];
2055        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2056        assert_eq!(df.count().unwrap(), 2);
2057        let collected = df.collect_inner().unwrap();
2058        assert_eq!(collected.get_column_names(), &["id", "nested"]);
2059    }
2060
2061    /// #610: create_dataframe_from_rows accepts struct as string that parses to object or array (Sparkless/Python serialization).
2062    #[test]
2063    fn test_issue_610_struct_value_as_string_object_or_array() {
2064        use serde_json::json;
2065
2066        let spark = SparkSession::builder().app_name("test").get_or_create();
2067        let schema = vec![
2068            ("id".to_string(), "string".to_string()),
2069            (
2070                "nested".to_string(),
2071                "struct<a:bigint,b:string>".to_string(),
2072            ),
2073        ];
2074        // Struct as string that parses to JSON object (e.g. Python dict serialized as string).
2075        let rows_object: Vec<Vec<JsonValue>> =
2076            vec![vec![json!("A"), json!(r#"{"a": 1, "b": "x"}"#)]];
2077        let df1 = spark
2078            .create_dataframe_from_rows(rows_object, schema.clone())
2079            .unwrap();
2080        assert_eq!(df1.count().unwrap(), 1);
2081
2082        // Struct as string that parses to JSON array (e.g. Python tuple (1, "y") serialized as "[1, \"y\"]").
2083        let rows_array: Vec<Vec<JsonValue>> = vec![vec![json!("B"), json!(r#"[1, "y"]"#)]];
2084        let df2 = spark
2085            .create_dataframe_from_rows(rows_array, schema)
2086            .unwrap();
2087        assert_eq!(df2.count().unwrap(), 1);
2088    }
2089
2090    /// #611: create_dataframe_from_rows accepts single value as one-element array (PySpark parity).
2091    #[test]
2092    fn test_issue_611_array_column_single_value_as_one_element() {
2093        use serde_json::json;
2094
2095        let spark = SparkSession::builder().app_name("test").get_or_create();
2096        let schema = vec![
2097            ("id".to_string(), "string".to_string()),
2098            ("arr".to_string(), "array<bigint>".to_string()),
2099        ];
2100        // Single number as one-element list (PySpark accepts this).
2101        let rows: Vec<Vec<JsonValue>> = vec![
2102            vec![json!("x"), json!(42)],
2103            vec![json!("y"), json!([1, 2, 3])],
2104        ];
2105        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2106        assert_eq!(df.count().unwrap(), 2);
2107        let collected = df.collect_inner().unwrap();
2108        let arr_col = collected.column("arr").unwrap();
2109        let list = arr_col.list().unwrap();
2110        let row0 = list.get(0).unwrap();
2111        assert_eq!(
2112            row0.len(),
2113            1,
2114            "#611: single value should become one-element list"
2115        );
2116        let row1 = list.get(1).unwrap();
2117        assert_eq!(row1.len(), 3);
2118    }
2119
2120    /// create_dataframe_from_rows: array column with JSON array and null. PySpark parity #601.
2121    #[test]
2122    fn test_create_dataframe_from_rows_array_column() {
2123        use serde_json::json;
2124
2125        let spark = SparkSession::builder().app_name("test").get_or_create();
2126        let schema = vec![
2127            ("id".to_string(), "string".to_string()),
2128            ("arr".to_string(), "array<bigint>".to_string()),
2129        ];
2130        let rows: Vec<Vec<JsonValue>> = vec![
2131            vec![json!("x"), json!([1, 2, 3])],
2132            vec![json!("y"), json!([4, 5])],
2133            vec![json!("z"), json!(null)],
2134        ];
2135        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2136        assert_eq!(df.count().unwrap(), 3);
2137        let collected = df.collect_inner().unwrap();
2138        assert_eq!(collected.get_column_names(), &["id", "arr"]);
2139
2140        // Issue #601: verify array data round-trips correctly (not just no error).
2141        let arr_col = collected.column("arr").unwrap();
2142        let list = arr_col.list().unwrap();
2143        // Row 0: [1, 2, 3]
2144        let row0 = list.get(0).unwrap();
2145        assert_eq!(row0.len(), 3, "row 0 arr should have 3 elements");
2146        // Row 1: [4, 5]
2147        let row1 = list.get(1).unwrap();
2148        assert_eq!(row1.len(), 2);
2149        // Row 2: null list (representation may be None or empty)
2150        let row2 = list.get(2);
2151        assert!(
2152            row2.is_none() || row2.as_ref().map(|a| a.is_empty()).unwrap_or(false),
2153            "row 2 arr should be null or empty"
2154        );
2155    }
2156
2157    /// Issue #601: PySpark createDataFrame([("x", [1,2,3]), ("y", [4,5])], schema) with ArrayType.
2158    /// Must not fail with "array column value must be null or array" and must produce correct structure.
2159    #[test]
2160    fn test_issue_601_array_column_pyspark_parity() {
2161        use serde_json::json;
2162
2163        let spark = SparkSession::builder().app_name("test").get_or_create();
2164        let schema = vec![
2165            ("id".to_string(), "string".to_string()),
2166            ("arr".to_string(), "array<bigint>".to_string()),
2167        ];
2168        // Exact PySpark example: rows with string id and list of ints.
2169        let rows: Vec<Vec<JsonValue>> = vec![
2170            vec![json!("x"), json!([1, 2, 3])],
2171            vec![json!("y"), json!([4, 5])],
2172        ];
2173        let df = spark
2174            .create_dataframe_from_rows(rows, schema)
2175            .expect("issue #601: create_dataframe_from_rows must accept array column (JSON array)");
2176        let n = df.count().unwrap();
2177        assert_eq!(n, 2, "issue #601: expected 2 rows");
2178        let collected = df.collect_inner().unwrap();
2179        let arr_col = collected.column("arr").unwrap();
2180        let list = arr_col.list().unwrap();
2181        // Verify list lengths match PySpark [1,2,3] and [4,5]
2182        let row0 = list.get(0).unwrap();
2183        assert_eq!(
2184            row0.len(),
2185            3,
2186            "issue #601: first row arr must have 3 elements [1,2,3]"
2187        );
2188        let row1 = list.get(1).unwrap();
2189        assert_eq!(
2190            row1.len(),
2191            2,
2192            "issue #601: second row arr must have 2 elements [4,5]"
2193        );
2194    }
2195
2196    /// #624: When schema is empty but rows are not, infer schema from rows (PySpark parity).
2197    #[test]
2198    fn test_issue_624_empty_schema_inferred_from_rows() {
2199        use serde_json::json;
2200
2201        let spark = SparkSession::builder().app_name("test").get_or_create();
2202        let schema: Vec<(String, String)> = vec![];
2203        let rows: Vec<Vec<JsonValue>> =
2204            vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
2205        let df = spark
2206            .create_dataframe_from_rows(rows, schema)
2207            .expect("#624: empty schema with non-empty rows should infer schema");
2208        assert_eq!(df.count().unwrap(), 2);
2209        let collected = df.collect_inner().unwrap();
2210        assert_eq!(collected.get_column_names(), &["c0", "c1"]);
2211    }
2212
2213    /// #627: create_dataframe_from_rows accepts map column (dict/object). PySpark MapType parity.
2214    #[test]
2215    fn test_create_dataframe_from_rows_map_column() {
2216        use serde_json::json;
2217
2218        let spark = SparkSession::builder().app_name("test").get_or_create();
2219        let schema = vec![
2220            ("id".to_string(), "integer".to_string()),
2221            ("m".to_string(), "map<string,string>".to_string()),
2222        ];
2223        let rows: Vec<Vec<JsonValue>> = vec![
2224            vec![json!(1), json!({"a": "x", "b": "y"})],
2225            vec![json!(2), json!({"c": "z"})],
2226        ];
2227        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2228        assert_eq!(df.count().unwrap(), 2);
2229        let collected = df.collect_inner().unwrap();
2230        assert_eq!(collected.get_column_names(), &["id", "m"]);
2231        let m_col = collected.column("m").unwrap();
2232        let list = m_col.list().unwrap();
2233        let row0 = list.get(0).unwrap();
2234        assert_eq!(row0.len(), 2, "row 0 map should have 2 entries");
2235        let row1 = list.get(1).unwrap();
2236        assert_eq!(row1.len(), 1, "row 1 map should have 1 entry");
2237    }
2238
2239    /// #625: create_dataframe_from_rows accepts array column as JSON array or Object (Python list parity).
2240    #[test]
2241    fn test_issue_625_array_column_list_or_object() {
2242        use serde_json::json;
2243
2244        let spark = SparkSession::builder().app_name("test").get_or_create();
2245        let schema = vec![
2246            ("id".to_string(), "string".to_string()),
2247            ("arr".to_string(), "array<bigint>".to_string()),
2248        ];
2249        // JSON array (Python list) and Object with "0","1","2" keys (some serializations).
2250        let rows: Vec<Vec<JsonValue>> = vec![
2251            vec![json!("x"), json!([1, 2, 3])],
2252            vec![json!("y"), json!({"0": 4, "1": 5})],
2253        ];
2254        let df = spark
2255            .create_dataframe_from_rows(rows, schema)
2256            .expect("#625: array column must accept list/array or object representation");
2257        assert_eq!(df.count().unwrap(), 2);
2258        let collected = df.collect_inner().unwrap();
2259        let list = collected.column("arr").unwrap().list().unwrap();
2260        assert_eq!(list.get(0).unwrap().len(), 3);
2261        assert_eq!(list.get(1).unwrap().len(), 2);
2262    }
2263
2264    #[test]
2265    fn test_create_dataframe_empty() {
2266        let spark = SparkSession::builder().app_name("test").get_or_create();
2267        let data: Vec<(i64, i64, String)> = vec![];
2268
2269        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);
2270
2271        assert!(result.is_ok());
2272        let df = result.unwrap();
2273        assert_eq!(df.count().unwrap(), 0);
2274    }
2275
2276    #[test]
2277    fn test_create_dataframe_from_polars() {
2278        use polars::prelude::df;
2279
2280        let spark = SparkSession::builder().app_name("test").get_or_create();
2281        let polars_df = df!(
2282            "x" => &[1, 2, 3],
2283            "y" => &[4, 5, 6]
2284        )
2285        .unwrap();
2286
2287        let df = spark.create_dataframe_from_polars(polars_df);
2288
2289        assert_eq!(df.count().unwrap(), 3);
2290        let columns = df.columns().unwrap();
2291        assert!(columns.contains(&"x".to_string()));
2292        assert!(columns.contains(&"y".to_string()));
2293    }
2294
2295    #[test]
2296    fn test_read_csv_file_not_found() {
2297        let spark = SparkSession::builder().app_name("test").get_or_create();
2298
2299        let result = spark.read_csv("nonexistent_file.csv");
2300
2301        assert!(result.is_err());
2302    }
2303
2304    #[test]
2305    fn test_read_parquet_file_not_found() {
2306        let spark = SparkSession::builder().app_name("test").get_or_create();
2307
2308        let result = spark.read_parquet("nonexistent_file.parquet");
2309
2310        assert!(result.is_err());
2311    }
2312
2313    #[test]
2314    fn test_read_json_file_not_found() {
2315        let spark = SparkSession::builder().app_name("test").get_or_create();
2316
2317        let result = spark.read_json("nonexistent_file.json");
2318
2319        assert!(result.is_err());
2320    }
2321
2322    #[test]
2323    fn test_rust_udf_dataframe() {
2324        use crate::functions::{call_udf, col};
2325        use polars::prelude::DataType;
2326
2327        let spark = SparkSession::builder().app_name("test").get_or_create();
2328        spark
2329            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
2330            .unwrap();
2331        let df = spark
2332            .create_dataframe(
2333                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
2334                vec!["id", "age", "name"],
2335            )
2336            .unwrap();
2337        let col = call_udf("to_str", &[col("id")]).unwrap();
2338        let df2 = df.with_column("id_str", &col).unwrap();
2339        let cols = df2.columns().unwrap();
2340        assert!(cols.contains(&"id_str".to_string()));
2341        let rows = df2.collect_as_json_rows().unwrap();
2342        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
2343        assert_eq!(rows[1].get("id_str").and_then(|v| v.as_str()), Some("2"));
2344    }
2345
2346    #[test]
2347    fn test_case_insensitive_filter_select() {
2348        use crate::expression::lit_i64;
2349        use crate::functions::col;
2350
2351        let spark = SparkSession::builder().app_name("test").get_or_create();
2352        let df = spark
2353            .create_dataframe(
2354                vec![
2355                    (1, 25, "Alice".to_string()),
2356                    (2, 30, "Bob".to_string()),
2357                    (3, 35, "Charlie".to_string()),
2358                ],
2359                vec!["Id", "Age", "Name"],
2360            )
2361            .unwrap();
2362        // Filter with lowercase column names (PySpark default: case-insensitive)
2363        let filtered = df
2364            .filter(col("age").gt(lit_i64(26)).expr().clone())
2365            .unwrap()
2366            .select(vec!["name"])
2367            .unwrap();
2368        assert_eq!(filtered.count().unwrap(), 2);
2369        let rows = filtered.collect_as_json_rows().unwrap();
2370        let names: Vec<&str> = rows
2371            .iter()
2372            .map(|r| r.get("name").and_then(|v| v.as_str()).unwrap())
2373            .collect();
2374        assert!(names.contains(&"Bob"));
2375        assert!(names.contains(&"Charlie"));
2376    }
2377
2378    #[test]
2379    fn test_sql_returns_error_without_feature_or_unknown_table() {
2380        let spark = SparkSession::builder().app_name("test").get_or_create();
2381
2382        let result = spark.sql("SELECT * FROM table");
2383
2384        assert!(result.is_err());
2385        match result {
2386            Err(PolarsError::InvalidOperation(msg)) => {
2387                let s = msg.to_string();
2388                // Without sql feature: "SQL queries require the 'sql' feature"
2389                // With sql feature but no table: "Table or view 'table' not found" or parse error
2390                assert!(
2391                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
2392                    "unexpected message: {s}"
2393                );
2394            }
2395            _ => panic!("Expected InvalidOperation error"),
2396        }
2397    }
2398
2399    #[test]
2400    fn test_spark_session_stop() {
2401        let spark = SparkSession::builder().app_name("test").get_or_create();
2402
2403        // stop() should complete without error
2404        spark.stop();
2405    }
2406
2407    #[test]
2408    fn test_dataframe_reader_api() {
2409        let spark = SparkSession::builder().app_name("test").get_or_create();
2410        let reader = spark.read();
2411
2412        // All readers should return errors for non-existent files
2413        assert!(reader.csv("nonexistent.csv").is_err());
2414        assert!(reader.parquet("nonexistent.parquet").is_err());
2415        assert!(reader.json("nonexistent.json").is_err());
2416    }
2417
2418    #[test]
2419    fn test_read_csv_with_valid_file() {
2420        use std::io::Write;
2421        use tempfile::NamedTempFile;
2422
2423        let spark = SparkSession::builder().app_name("test").get_or_create();
2424
2425        // Create a temporary CSV file
2426        let mut temp_file = NamedTempFile::new().unwrap();
2427        writeln!(temp_file, "id,name,age").unwrap();
2428        writeln!(temp_file, "1,Alice,25").unwrap();
2429        writeln!(temp_file, "2,Bob,30").unwrap();
2430        temp_file.flush().unwrap();
2431
2432        let result = spark.read_csv(temp_file.path());
2433
2434        assert!(result.is_ok());
2435        let df = result.unwrap();
2436        assert_eq!(df.count().unwrap(), 2);
2437
2438        let columns = df.columns().unwrap();
2439        assert!(columns.contains(&"id".to_string()));
2440        assert!(columns.contains(&"name".to_string()));
2441        assert!(columns.contains(&"age".to_string()));
2442    }
2443
2444    #[test]
2445    fn test_read_json_with_valid_file() {
2446        use std::io::Write;
2447        use tempfile::NamedTempFile;
2448
2449        let spark = SparkSession::builder().app_name("test").get_or_create();
2450
2451        // Create a temporary JSONL file
2452        let mut temp_file = NamedTempFile::new().unwrap();
2453        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
2454        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
2455        temp_file.flush().unwrap();
2456
2457        let result = spark.read_json(temp_file.path());
2458
2459        assert!(result.is_ok());
2460        let df = result.unwrap();
2461        assert_eq!(df.count().unwrap(), 2);
2462    }
2463
2464    #[test]
2465    fn test_read_csv_empty_file() {
2466        use std::io::Write;
2467        use tempfile::NamedTempFile;
2468
2469        let spark = SparkSession::builder().app_name("test").get_or_create();
2470
2471        // Create an empty CSV file (just header)
2472        let mut temp_file = NamedTempFile::new().unwrap();
2473        writeln!(temp_file, "id,name").unwrap();
2474        temp_file.flush().unwrap();
2475
2476        let result = spark.read_csv(temp_file.path());
2477
2478        assert!(result.is_ok());
2479        let df = result.unwrap();
2480        assert_eq!(df.count().unwrap(), 0);
2481    }
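
    /// Minimal sketch exercising DataFrameReader CSV options (`header`, `sep`) against a temp file.
    #[test]
    fn test_reader_csv_with_separator_option() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        // Semicolon-separated CSV with a header row.
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id;name").unwrap();
        writeln!(temp_file, "1;Alice").unwrap();
        writeln!(temp_file, "2;Bob").unwrap();
        temp_file.flush().unwrap();

        let df = spark
            .read()
            .option("header", "true")
            .option("sep", ";")
            .csv(temp_file.path())
            .unwrap();
        assert_eq!(df.count().unwrap(), 2);
        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"name".to_string()));
    }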
2482
2483    #[test]
2484    fn test_write_partitioned_parquet() {
2485        use crate::dataframe::{WriteFormat, WriteMode};
2486        use std::fs;
2487        use tempfile::TempDir;
2488
2489        let spark = SparkSession::builder().app_name("test").get_or_create();
2490        let df = spark
2491            .create_dataframe(
2492                vec![
2493                    (1, 25, "Alice".to_string()),
2494                    (2, 30, "Bob".to_string()),
2495                    (3, 25, "Carol".to_string()),
2496                ],
2497                vec!["id", "age", "name"],
2498            )
2499            .unwrap();
2500        let dir = TempDir::new().unwrap();
2501        let path = dir.path().join("out");
2502        df.write()
2503            .mode(WriteMode::Overwrite)
2504            .format(WriteFormat::Parquet)
2505            .partition_by(["age"])
2506            .save(&path)
2507            .unwrap();
2508        assert!(path.is_dir());
2509        let entries: Vec<_> = fs::read_dir(&path).unwrap().collect();
2510        assert_eq!(
2511            entries.len(),
2512            2,
2513            "expected two partition dirs (age=25, age=30)"
2514        );
2515        let names: Vec<String> = entries
2516            .iter()
2517            .filter_map(|e| e.as_ref().ok())
2518            .map(|e| e.file_name().to_string_lossy().into_owned())
2519            .collect();
2520        assert!(names.iter().any(|n| n.starts_with("age=")));
2521        let df_read = spark.read_parquet(&path).unwrap();
2522        assert_eq!(df_read.count().unwrap(), 3);
2523    }
2524
2525    #[test]
2526    fn test_save_as_table_error_if_exists() {
2527        use crate::dataframe::SaveMode;
2528
2529        let spark = SparkSession::builder().app_name("test").get_or_create();
2530        let df = spark
2531            .create_dataframe(
2532                vec![(1, 25, "Alice".to_string())],
2533                vec!["id", "age", "name"],
2534            )
2535            .unwrap();
2536        // First call succeeds
2537        df.write()
2538            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
2539            .unwrap();
2540        assert!(spark.table("t1").is_ok());
2541        assert_eq!(spark.table("t1").unwrap().count().unwrap(), 1);
2542        // Second call with ErrorIfExists fails
2543        let err = df
2544            .write()
2545            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
2546            .unwrap_err();
2547        assert!(err.to_string().contains("already exists"));
2548    }
2549
2550    #[test]
2551    fn test_save_as_table_overwrite() {
2552        use crate::dataframe::SaveMode;
2553
2554        let spark = SparkSession::builder().app_name("test").get_or_create();
2555        let df1 = spark
2556            .create_dataframe(
2557                vec![(1, 25, "Alice".to_string())],
2558                vec!["id", "age", "name"],
2559            )
2560            .unwrap();
2561        let df2 = spark
2562            .create_dataframe(
2563                vec![(2, 30, "Bob".to_string()), (3, 35, "Carol".to_string())],
2564                vec!["id", "age", "name"],
2565            )
2566            .unwrap();
2567        df1.write()
2568            .save_as_table(&spark, "t_over", SaveMode::ErrorIfExists)
2569            .unwrap();
2570        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 1);
2571        df2.write()
2572            .save_as_table(&spark, "t_over", SaveMode::Overwrite)
2573            .unwrap();
2574        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 2);
2575    }
2576
    /// saveAsTable with Append adds rows to the existing table.
    #[test]
    fn test_save_as_table_append() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_append", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_append", SaveMode::Append)
            .unwrap();
        assert_eq!(spark.table("t_append").unwrap().count().unwrap(), 2);
    }

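    // Hypothetical sketch (not in the original suite): in PySpark, Append to a
    // table that does not exist yet simply creates it. Whether robin-sparkless
    // matches that behavior is an assumption, so this is a hedged sketch with a
    // made-up table name ("t_append_missing") and is marked #[ignore].
    #[test]
    #[ignore]
    fn sketch_save_as_table_append_creates_missing_table() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(vec![(1, 25, "Alice".to_string())], vec!["id", "age", "name"])
            .unwrap();
        // Assumption: Append on a missing table behaves like a fresh write.
        df.write()
            .save_as_table(&spark, "t_append_missing", SaveMode::Append)
            .unwrap();
        assert_eq!(spark.table("t_append_missing").unwrap().count().unwrap(), 1);
    }
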
    /// Empty DataFrame with explicit schema: saveAsTable(Overwrite) then append one row (issue #495).
    #[test]
    fn test_save_as_table_empty_df_then_append() {
        use crate::dataframe::SaveMode;
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark
            .create_dataframe_from_rows(vec![], schema.clone())
            .unwrap();
        assert_eq!(empty_df.count().unwrap(), 0);

        empty_df
            .write()
            .save_as_table(&spark, "t_empty_append", SaveMode::Overwrite)
            .unwrap();
        let r1 = spark.table("t_empty_append").unwrap();
        assert_eq!(r1.count().unwrap(), 0);
        let cols = r1.columns().unwrap();
        assert!(cols.contains(&"id".to_string()));
        assert!(cols.contains(&"name".to_string()));

        let one_row = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        one_row
            .write()
            .save_as_table(&spark, "t_empty_append", SaveMode::Append)
            .unwrap();
        let r2 = spark.table("t_empty_append").unwrap();
        assert_eq!(r2.count().unwrap(), 1);
    }

    /// Empty DataFrame with schema: write.format("parquet").save(path) must not fail (issue #519).
    /// PySpark fails with "can not infer schema from empty dataset"; robin-sparkless uses explicit schema.
    #[test]
    fn test_write_parquet_empty_df_with_schema() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark.create_dataframe_from_rows(vec![], schema).unwrap();
        assert_eq!(empty_df.count().unwrap(), 0);

        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("empty.parquet");
        empty_df
            .write()
            .format(crate::dataframe::WriteFormat::Parquet)
            .mode(crate::dataframe::WriteMode::Overwrite)
            .save(&path)
            .unwrap();
        assert!(path.is_file());

        // Read back and verify schema preserved
        let read_df = spark.read().parquet(path.to_str().unwrap()).unwrap();
        assert_eq!(read_df.count().unwrap(), 0);
        let cols = read_df.columns().unwrap();
        assert!(cols.contains(&"id".to_string()));
        assert!(cols.contains(&"name".to_string()));
    }

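    // Hypothetical companion sketch (not in the original suite): a non-empty
    // round trip through the same save/read path, checking row values rather
    // than just the schema. The file name "roundtrip.parquet" is made up; only
    // APIs already used above are assumed. Marked #[ignore] as a sketch.
    #[test]
    #[ignore]
    fn sketch_write_parquet_roundtrip_preserves_values() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let df = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("roundtrip.parquet");
        df.write()
            .format(crate::dataframe::WriteFormat::Parquet)
            .mode(crate::dataframe::WriteMode::Overwrite)
            .save(&path)
            .unwrap();
        // Reading back is assumed to preserve both column values as written.
        let rows = spark
            .read()
            .parquet(path.to_str().unwrap())
            .unwrap()
            .collect_as_json_rows()
            .unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("id").and_then(|v| v.as_i64()), Some(1));
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("a"));
    }
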
    /// Empty DataFrame with schema + warehouse: saveAsTable(Overwrite) then append (issue #495 disk path).
    #[test]
    fn test_save_as_table_empty_df_warehouse_then_append() {
        use crate::dataframe::SaveMode;
        use serde_json::json;
        use std::sync::atomic::{AtomicU64, Ordering};
        use tempfile::TempDir;

        // Build a distinct warehouse subdirectory inside a fresh TempDir.
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let n = COUNTER.fetch_add(1, Ordering::SeqCst);
        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().join(format!("wh_{n}"));
        std::fs::create_dir_all(&warehouse).unwrap();
        let spark = SparkSession::builder()
            .app_name("test")
            .config(
                "spark.sql.warehouse.dir",
                warehouse.as_os_str().to_str().unwrap(),
            )
            .get_or_create();

        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark
            .create_dataframe_from_rows(vec![], schema.clone())
            .unwrap();
        empty_df
            .write()
            .save_as_table(&spark, "t_empty_wh", SaveMode::Overwrite)
            .unwrap();
        let r1 = spark.table("t_empty_wh").unwrap();
        assert_eq!(r1.count().unwrap(), 0);

        let one_row = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        one_row
            .write()
            .save_as_table(&spark, "t_empty_wh", SaveMode::Append)
            .unwrap();
        let r2 = spark.table("t_empty_wh").unwrap();
        assert_eq!(r2.count().unwrap(), 1);
    }

    /// saveAsTable with Ignore: if the table already exists, the write is a no-op.
    #[test]
    fn test_save_as_table_ignore() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_ignore", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_ignore", SaveMode::Ignore)
            .unwrap();
        // Still 1 row: Ignore left the existing table untouched.
        assert_eq!(spark.table("t_ignore").unwrap().count().unwrap(), 1);
    }

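    // Hypothetical sketch (not in the original suite): in PySpark, Ignore writes
    // the table when it does not exist and only becomes a no-op once it does.
    // Assuming robin-sparkless mirrors that; the table name "t_ignore_missing"
    // is made up here. Marked #[ignore] because the behavior is assumed.
    #[test]
    #[ignore]
    fn sketch_save_as_table_ignore_writes_when_missing() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(vec![(1, 25, "Alice".to_string())], vec!["id", "age", "name"])
            .unwrap();
        // Assumption: Ignore on a missing table behaves like a normal first write.
        df.write()
            .save_as_table(&spark, "t_ignore_missing", SaveMode::Ignore)
            .unwrap();
        assert_eq!(spark.table("t_ignore_missing").unwrap().count().unwrap(), 1);
    }
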
    /// Name resolution matches PySpark: a temp view shadows a saved table of the same name.
    #[test]
    fn test_table_resolution_temp_view_first() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df_saved = spark
            .create_dataframe(
                vec![(1, 25, "Saved".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df_temp = spark
            .create_dataframe(vec![(2, 30, "Temp".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df_saved
            .write()
            .save_as_table(&spark, "x", SaveMode::ErrorIfExists)
            .unwrap();
        spark.create_or_replace_temp_view("x", df_temp);
        // table("x") must return the temp view, not the saved table (PySpark resolution order).
        let t = spark.table("x").unwrap();
        let rows = t.collect_as_json_rows().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Temp"));
    }

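    // Hypothetical sketch (not in the original suite): create_or_replace_temp_view
    // is assumed to replace an existing view of the same name, as its name
    // suggests. The view name "v_replace" is made up. Marked #[ignore].
    #[test]
    #[ignore]
    fn sketch_create_or_replace_temp_view_replaces_existing() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df_a = spark
            .create_dataframe(vec![(1, 25, "A".to_string())], vec!["id", "age", "name"])
            .unwrap();
        let df_b = spark
            .create_dataframe(
                vec![(2, 30, "B".to_string()), (3, 35, "C".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark.create_or_replace_temp_view("v_replace", df_a);
        assert_eq!(spark.table("v_replace").unwrap().count().unwrap(), 1);
        // Assumption: replacing the view swaps in the two-row DataFrame.
        spark.create_or_replace_temp_view("v_replace", df_b);
        assert_eq!(spark.table("v_replace").unwrap().count().unwrap(), 2);
    }
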
    /// #629: Exact reproduction – createDataFrame, createOrReplaceTempView, then table() must resolve.
    #[test]
    fn test_issue_629_temp_view_visible_after_create() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("repro").get_or_create();
        let schema = vec![
            ("id".to_string(), "long".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> =
            vec![vec![json!(1), json!("a")], vec![json!(2), json!("b")]];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        spark.create_or_replace_temp_view("my_view", df);
        let result = spark
            .table("my_view")
            .unwrap()
            .collect_as_json_rows()
            .unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].get("id").and_then(|v| v.as_i64()), Some(1));
        assert_eq!(result[0].get("name").and_then(|v| v.as_str()), Some("a"));
        assert_eq!(result[1].get("id").and_then(|v| v.as_i64()), Some(2));
        assert_eq!(result[1].get("name").and_then(|v| v.as_str()), Some("b"));
    }

    /// drop_table removes a saved table; dropping a missing table returns false.
    #[test]
    fn test_drop_table() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df.write()
            .save_as_table(&spark, "t_drop", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.table("t_drop").is_ok());
        assert!(spark.drop_table("t_drop"));
        assert!(spark.table("t_drop").is_err());
        // Dropping again is a no-op and returns false.
        assert!(!spark.drop_table("t_drop"));
    }

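    // Hypothetical sketch (not in the original suite): once a table has been
    // dropped, saving under the same name with ErrorIfExists is assumed to
    // succeed again, since the name is free. The table name "t_drop_recreate"
    // is made up for this sketch. Marked #[ignore] as unverified.
    #[test]
    #[ignore]
    fn sketch_save_as_table_after_drop_succeeds() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(vec![(1, 25, "Alice".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df.write()
            .save_as_table(&spark, "t_drop_recreate", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.drop_table("t_drop_recreate"));
        // Assumption: the name is free again, so ErrorIfExists does not error.
        df.write()
            .save_as_table(&spark, "t_drop_recreate", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark.table("t_drop_recreate").unwrap().count().unwrap(), 1);
    }
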
    /// Global temp views live in the global_temp namespace and are visible across sessions.
    #[test]
    fn test_global_temp_view_persists_across_sessions() {
        // Session 1: create global temp view
        let spark1 = SparkSession::builder().app_name("s1").get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark1.create_or_replace_global_temp_view("people", df1);
        assert_eq!(
            spark1.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

        // Session 2: a different session can see the global temp view
        let spark2 = SparkSession::builder().app_name("s2").get_or_create();
        let df2 = spark2.table("global_temp.people").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

        // A local temp view in spark2 does not shadow the global_temp namespace
        let df_local = spark2
            .create_dataframe(
                vec![(3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark2.create_or_replace_temp_view("people", df_local);
        // table("people") resolves to the local temp view (session resolution)
        assert_eq!(spark2.table("people").unwrap().count().unwrap(), 1);
        // table("global_temp.people") still resolves to the global temp view (unchanged)
        assert_eq!(
            spark2.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

        // Drop the global temp view; it is no longer resolvable
        assert!(spark2.drop_global_temp_view("people"));
        assert!(spark2.table("global_temp.people").is_err());
    }

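    // Hypothetical sketch (not in the original suite): create_or_replace_global_temp_view
    // is assumed to replace an existing global view of the same name, and the
    // replacement should be visible from another session. The view name
    // "g_replace" is made up. Marked #[ignore] as an unverified sketch.
    #[test]
    #[ignore]
    fn sketch_global_temp_view_replace_visible_across_sessions() {
        let spark1 = SparkSession::builder().app_name("g1").get_or_create();
        let df_a = spark1
            .create_dataframe(vec![(1, 25, "A".to_string())], vec!["id", "age", "name"])
            .unwrap();
        let df_b = spark1
            .create_dataframe(
                vec![(2, 30, "B".to_string()), (3, 35, "C".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark1.create_or_replace_global_temp_view("g_replace", df_a);
        assert_eq!(
            spark1.table("global_temp.g_replace").unwrap().count().unwrap(),
            1
        );
        // Assumption: replacing the global view swaps in the two-row DataFrame.
        spark1.create_or_replace_global_temp_view("g_replace", df_b);
        let spark2 = SparkSession::builder().app_name("g2").get_or_create();
        assert_eq!(
            spark2.table("global_temp.g_replace").unwrap().count().unwrap(),
            2
        );
        // Clean up so other global_temp tests are unaffected.
        assert!(spark2.drop_global_temp_view("g_replace"));
    }
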
    /// Tables saved under spark.sql.warehouse.dir are written as parquet and can
    /// be read back by a new session pointed at the same warehouse.
    #[test]
    fn test_warehouse_persistence_between_sessions() {
        use crate::dataframe::SaveMode;
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().to_str().unwrap();

        // Session 1: save to warehouse
        let spark1 = SparkSession::builder()
            .app_name("w1")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark1, "users", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark1.table("users").unwrap().count().unwrap(), 2);

        // Session 2: a new session reads the table from the same warehouse
        let spark2 = SparkSession::builder()
            .app_name("w2")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df2 = spark2.table("users").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

        // Verify parquet was written under <warehouse>/users
        let table_path = dir.path().join("users");
        assert!(table_path.is_dir());
        let entries: Vec<_> = fs::read_dir(&table_path).unwrap().collect();
        assert!(!entries.is_empty());
    }
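
    // Hypothetical sketch (not in the original suite): overwriting a
    // warehouse-backed table from a second session is assumed to rewrite the
    // stored data, so a third session pointed at the same warehouse sees the
    // new contents. The table name "users_overwrite" is made up. #[ignore]d.
    #[test]
    #[ignore]
    fn sketch_warehouse_overwrite_visible_to_new_session() {
        use crate::dataframe::SaveMode;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().to_str().unwrap();

        let spark1 = SparkSession::builder()
            .app_name("w1")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df1 = spark1
            .create_dataframe(vec![(1, 25, "Alice".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark1, "users_overwrite", SaveMode::ErrorIfExists)
            .unwrap();

        // Session 2 overwrites the table with two rows.
        let spark2 = SparkSession::builder()
            .app_name("w2")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df2 = spark2
            .create_dataframe(
                vec![(2, 30, "Bob".to_string()), (3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df2.write()
            .save_as_table(&spark2, "users_overwrite", SaveMode::Overwrite)
            .unwrap();

        // Assumption: a third session on the same warehouse sees the overwritten data.
        let spark3 = SparkSession::builder()
            .app_name("w3")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        assert_eq!(spark3.table("users_overwrite").unwrap().count().unwrap(), 2);
    }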
}