
robin_sparkless/session.rs

1use crate::dataframe::DataFrame;
2use crate::error::EngineError;
3use polars::chunked_array::StructChunked;
4use polars::chunked_array::builder::get_list_builder;
5use polars::prelude::{
6    DataFrame as PlDataFrame, DataType, Field, IntoSeries, NamedFrom, PlSmallStr, PolarsError,
7    Series, TimeUnit,
8};
9use robin_sparkless_expr::UdfRegistry;
10use serde_json::Value as JsonValue;
11use std::cell::RefCell;
12use std::sync::Arc;
13
14/// Parse "array<element_type>" to get inner type string. Returns None if not array<>.
15fn parse_array_element_type(type_str: &str) -> Option<String> {
16    let s = type_str.trim();
17    if !s.to_lowercase().starts_with("array<") || !s.ends_with('>') {
18        return None;
19    }
20    Some(s[6..s.len() - 1].trim().to_string())
21}
22
23/// Parse "struct<field:type,...>" to get field (name, type) pairs. Simple parsing, no nested structs.
24fn parse_struct_fields(type_str: &str) -> Option<Vec<(String, String)>> {
25    let s = type_str.trim();
26    if !s.to_lowercase().starts_with("struct<") || !s.ends_with('>') {
27        return None;
28    }
29    let inner = s[7..s.len() - 1].trim();
30    if inner.is_empty() {
31        return Some(Vec::new());
32    }
33    let mut out = Vec::new();
34    for part in inner.split(',') {
35        let part = part.trim();
36        if let Some(idx) = part.find(':') {
37            let name = part[..idx].trim().to_string();
38            let typ = part[idx + 1..].trim().to_string();
39            out.push((name, typ));
40        }
41    }
42    Some(out)
43}
44
45/// Parse "map<key_type,value_type>" to get (key_type, value_type). Returns None if not map<>.
46/// PySpark: MapType(StringType(), StringType()) -> "map<string,string>".
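///
/// # Example
///
/// Illustrative sketch (private helper, not compiled as a doctest); both parts are lowercased:
///
/// ```ignore
/// assert_eq!(
///     parse_map_key_value_types("map<string,bigint>"),
///     Some(("string".to_string(), "bigint".to_string()))
/// );
/// assert_eq!(parse_map_key_value_types("array<int>"), None);
/// ```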
47fn parse_map_key_value_types(type_str: &str) -> Option<(String, String)> {
48    let s = type_str.trim().to_lowercase();
49    if !s.starts_with("map<") || !s.ends_with('>') {
50        return None;
51    }
52    let inner = s[4..s.len() - 1].trim();
53    let comma = inner.find(',')?;
54    let key_type = inner[..comma].trim().to_string();
55    let value_type = inner[comma + 1..].trim().to_string();
56    Some((key_type, value_type))
57}
58
59/// True if type string is Decimal(precision, scale), e.g. "decimal(10,2)".
60fn is_decimal_type_str(type_str: &str) -> bool {
61    let s = type_str.trim().to_lowercase();
62    s.starts_with("decimal(") && s.contains(')')
63}
64
65/// Map schema type string to Polars DataType (primitives only for nested use).
66/// Decimal(p,s) is mapped to Float64 (Polars dtype-decimal feature not enabled).
67fn json_type_str_to_polars(type_str: &str) -> Option<DataType> {
68    let s = type_str.trim().to_lowercase();
69    if is_decimal_type_str(&s) {
70        return Some(DataType::Float64);
71    }
72    match s.as_str() {
73        "int" | "integer" | "bigint" | "long" => Some(DataType::Int64),
74        "double" | "float" | "double_precision" => Some(DataType::Float64),
75        "string" | "str" | "varchar" => Some(DataType::String),
76        "boolean" | "bool" => Some(DataType::Boolean),
77        _ => None,
78    }
79}
80
81/// Normalize a JSON value to an array for array columns (PySpark parity #625).
82/// Accepts: Array, Object with "0","1",... keys (Python list serialization), String that parses as JSON array.
83/// Returns None for null or when value should be treated as single-element list (#611).
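///
/// # Examples
///
/// Illustrative sketch of the three accepted shapes (private helper, not compiled as a doctest):
///
/// ```ignore
/// use serde_json::json;
///
/// assert_eq!(json_value_to_array(&json!([1, 2])), Some(vec![json!(1), json!(2)]));
/// assert_eq!(
///     json_value_to_array(&json!({"0": "a", "1": "b"})),
///     Some(vec![json!("a"), json!("b")])
/// );
/// assert_eq!(json_value_to_array(&json!("[1, 2]")), Some(vec![json!(1), json!(2)]));
/// assert_eq!(json_value_to_array(&json!(null)), None);
/// ```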
84fn json_value_to_array(v: &JsonValue) -> Option<Vec<JsonValue>> {
85    match v {
86        JsonValue::Null => None,
87        JsonValue::Array(arr) => Some(arr.clone()),
88        JsonValue::Object(obj) => {
89            // Python serialization sometimes sends a list as {"0": x, "1": y}; rebuild it sorted by numeric index.
90            let mut indices: Vec<usize> =
91                obj.keys().filter_map(|k| k.parse::<usize>().ok()).collect();
92            indices.sort_unstable();
93            if indices.is_empty() {
94                return None;
95            }
96            let arr: Vec<JsonValue> = indices
97                .iter()
98                .filter_map(|i| obj.get(&i.to_string()).cloned())
99                .collect();
100            Some(arr)
101        }
102        JsonValue::String(s) => {
103            if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
104                parsed.as_array().cloned()
105            } else {
106                None
107            }
108        }
109        _ => None,
110    }
111}
112
113/// Infer list element type from first non-null array in the column (for schema "list" / "array").
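///
/// # Example
///
/// Illustrative sketch (private helper, not compiled as a doctest):
///
/// ```ignore
/// use serde_json::json;
///
/// let rows = vec![vec![json!(["a", "b"])], vec![json!(["c"])]];
/// assert_eq!(
///     infer_list_element_type(&rows, 0),
///     Some(("string".to_string(), DataType::String))
/// );
/// ```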
114fn infer_list_element_type(rows: &[Vec<JsonValue>], col_idx: usize) -> Option<(String, DataType)> {
115    for row in rows {
116        let Some(v) = row.get(col_idx) else { continue };
117        let Some(arr) = json_value_to_array(v) else { continue };
118        let Some(first) = arr.first() else { continue };
119        return Some(match first {
120            JsonValue::String(_) => ("string".to_string(), DataType::String),
121            JsonValue::Number(n) => {
122                if n.as_i64().is_some() {
123                    ("bigint".to_string(), DataType::Int64)
124                } else {
125                    ("double".to_string(), DataType::Float64)
126                }
127            }
128            JsonValue::Bool(_) => ("boolean".to_string(), DataType::Boolean),
129            JsonValue::Null => continue,
130            _ => ("string".to_string(), DataType::String),
131        });
132    }
133    None
134}
135
136/// Build a length-N Series from `Vec<Option<JsonValue>>` for a given type (recursive for struct/array).
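///
/// # Example
///
/// Illustrative sketch for a primitive type (private helper, not compiled as a doctest):
///
/// ```ignore
/// use serde_json::json;
///
/// let values = vec![Some(json!(1)), None, Some(json!(3))];
/// let s = json_values_to_series(&values, "bigint", "id")?;
/// assert_eq!(s.len(), 3);
/// assert_eq!(s.i64()?.get(0), Some(1));
/// ```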
137fn json_values_to_series(
138    values: &[Option<JsonValue>],
139    type_str: &str,
140    name: &str,
141) -> Result<Series, PolarsError> {
142    use chrono::{NaiveDate, NaiveDateTime};
143    let epoch = crate::date_utils::epoch_naive_date();
144    let type_lower = type_str.trim().to_lowercase();
145
146    if let Some(elem_type) = parse_array_element_type(&type_lower) {
147        let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
148            PolarsError::ComputeError(
149                format!("array element type '{elem_type}' not supported").into(),
150            )
151        })?;
152        let mut builder = get_list_builder(&inner_dtype, 64, values.len(), name.into());
153        for v in values.iter() {
154            if v.as_ref().is_none_or(|x| matches!(x, JsonValue::Null)) {
155                builder.append_null();
156            } else if let Some(arr) = v.as_ref().and_then(json_value_to_array) {
157                // #625: Array, Object with "0","1",..., or string that parses as JSON array (PySpark list parity).
158                let elem_series: Vec<Series> = arr
159                    .iter()
160                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
161                    .collect::<Result<Vec<_>, _>>()?;
162                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
163                let s = Series::from_any_values_and_dtype(
164                    PlSmallStr::EMPTY,
165                    &vals,
166                    &inner_dtype,
167                    false,
168                )
169                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
170                builder.append_series(&s)?;
171            } else {
172                // #611: PySpark accepts single value as one-element list for array columns.
173                let single_arr = [v.clone().unwrap_or(JsonValue::Null)];
174                let elem_series: Vec<Series> = single_arr
175                    .iter()
176                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
177                    .collect::<Result<Vec<_>, _>>()?;
178                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
179                let arr_series = Series::from_any_values_and_dtype(
180                    PlSmallStr::EMPTY,
181                    &vals,
182                    &inner_dtype,
183                    false,
184                )
185                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
186                builder.append_series(&arr_series)?;
187            }
188        }
189        return Ok(builder.finish().into_series());
190    }
191
192    if let Some(fields) = parse_struct_fields(&type_lower) {
193        let mut field_series_vec: Vec<Vec<Option<JsonValue>>> = (0..fields.len())
194            .map(|_| Vec::with_capacity(values.len()))
195            .collect();
196        for v in values.iter() {
197            // #610: Accept string that parses as JSON object or array (e.g. Python tuple serialized as "[1, \"y\"]").
198            let effective: Option<JsonValue> = match v.as_ref() {
199                Some(JsonValue::String(s)) => {
200                    if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
201                        if parsed.is_object() || parsed.is_array() {
202                            Some(parsed)
203                        } else {
204                            v.clone()
205                        }
206                    } else {
207                        v.clone()
208                    }
209                }
210                _ => v.clone(),
211            };
212            if effective
213                .as_ref()
214                .is_none_or(|x| matches!(x, JsonValue::Null))
215            {
216                for fc in &mut field_series_vec {
217                    fc.push(None);
218                }
219            } else if let Some(obj) = effective.as_ref().and_then(|x| x.as_object()) {
220                for (fi, (fname, _)) in fields.iter().enumerate() {
221                    field_series_vec[fi].push(obj.get(fname).cloned());
222                }
223            } else if let Some(arr) = effective.as_ref().and_then(|x| x.as_array()) {
224                for (fi, _) in fields.iter().enumerate() {
225                    field_series_vec[fi].push(arr.get(fi).cloned());
226                }
227            } else {
228                return Err(PolarsError::ComputeError(
229                    "struct value must be object (by field name) or array (by position). \
230                     PySpark accepts dict or tuple/list for struct columns."
231                        .into(),
232                ));
233            }
234        }
235        let series_per_field: Vec<Series> = fields
236            .iter()
237            .enumerate()
238            .map(|(fi, (fname, ftype))| json_values_to_series(&field_series_vec[fi], ftype, fname))
239            .collect::<Result<Vec<_>, _>>()?;
240        let field_refs: Vec<&Series> = series_per_field.iter().collect();
241        let st = StructChunked::from_series(name.into(), values.len(), field_refs.iter().copied())
242            .map_err(|e| PolarsError::ComputeError(format!("struct column: {e}").into()))?
243            .into_series();
244        return Ok(st);
245    }
246
247    match type_lower.as_str() {
248        "int" | "bigint" | "long" => {
249            let vals: Vec<Option<i64>> = values
250                .iter()
251                .map(|ov| {
252                    ov.as_ref().and_then(|v| match v {
253                        JsonValue::Number(n) => n.as_i64(),
254                        JsonValue::Null => None,
255                        _ => None,
256                    })
257                })
258                .collect();
259            Ok(Series::new(name.into(), vals))
260        }
261        "double" | "float" => {
262            let vals: Vec<Option<f64>> = values
263                .iter()
264                .map(|ov| {
265                    ov.as_ref().and_then(|v| match v {
266                        JsonValue::Number(n) => n.as_f64(),
267                        JsonValue::Null => None,
268                        _ => None,
269                    })
270                })
271                .collect();
272            Ok(Series::new(name.into(), vals))
273        }
274        "string" | "str" | "varchar" => {
275            let vals: Vec<Option<&str>> = values
276                .iter()
277                .map(|ov| {
278                    ov.as_ref().and_then(|v| match v {
279                        JsonValue::String(s) => Some(s.as_str()),
280                        JsonValue::Null => None,
281                        _ => None,
282                    })
283                })
284                .collect();
285            let owned: Vec<Option<String>> =
286                vals.into_iter().map(|o| o.map(|s| s.to_string())).collect();
287            Ok(Series::new(name.into(), owned))
288        }
289        "boolean" | "bool" => {
290            let vals: Vec<Option<bool>> = values
291                .iter()
292                .map(|ov| {
293                    ov.as_ref().and_then(|v| match v {
294                        JsonValue::Bool(b) => Some(*b),
295                        JsonValue::Null => None,
296                        _ => None,
297                    })
298                })
299                .collect();
300            Ok(Series::new(name.into(), vals))
301        }
302        "date" => {
303            let vals: Vec<Option<i32>> = values
304                .iter()
305                .map(|ov| {
306                    ov.as_ref().and_then(|v| match v {
307                        JsonValue::String(s) => NaiveDate::parse_from_str(s, "%Y-%m-%d")
308                            .ok()
309                            .map(|d| (d - epoch).num_days() as i32),
310                        JsonValue::Null => None,
311                        _ => None,
312                    })
313                })
314                .collect();
315            let s = Series::new(name.into(), vals);
316            s.cast(&DataType::Date)
317                .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))
318        }
319        "timestamp" | "datetime" | "timestamp_ntz" => {
320            let vals: Vec<Option<i64>> = values
321                .iter()
322                .map(|ov| {
323                    ov.as_ref().and_then(|v| match v {
324                        JsonValue::String(s) => {
325                            let parsed = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
326                                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))
327                                .or_else(|_| {
328                                    NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").map_err(
329                                        |e| PolarsError::ComputeError(e.to_string().into()),
330                                    )
331                                })
332                                .or_else(|_| {
333                                    NaiveDate::parse_from_str(s, "%Y-%m-%d")
334                                        .map_err(|e| {
335                                            PolarsError::ComputeError(e.to_string().into())
336                                        })
337                                        .and_then(|d| {
338                                            d.and_hms_opt(0, 0, 0).ok_or_else(|| {
339                                                PolarsError::ComputeError(
340                                                    "date to datetime (0:0:0)".into(),
341                                                )
342                                            })
343                                        })
344                                });
345                            parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
346                        }
347                        JsonValue::Number(n) => n.as_i64(),
348                        JsonValue::Null => None,
349                        _ => None,
350                    })
351                })
352                .collect();
353            let s = Series::new(name.into(), vals);
354            s.cast(&DataType::Datetime(TimeUnit::Microseconds, None))
355                .map_err(|e| PolarsError::ComputeError(format!("datetime cast: {e}").into()))
356        }
357        _ => Err(PolarsError::ComputeError(
358            format!("json_values_to_series: unsupported type '{type_str}'").into(),
359        )),
360    }
361}
362
363/// Build a single Series from a JsonValue for use as list element or struct field.
364fn json_value_to_series_single(
365    value: &JsonValue,
366    type_str: &str,
367    name: &str,
368) -> Result<Series, PolarsError> {
369    use chrono::NaiveDate;
370    let epoch = crate::date_utils::epoch_naive_date();
371    match (value, type_str.trim().to_lowercase().as_str()) {
372        (JsonValue::Null, _) => Ok(Series::new_null(name.into(), 1)),
373        (JsonValue::Number(n), "int" | "bigint" | "long") => {
374            Ok(Series::new(name.into(), vec![n.as_i64()]))
375        }
376        (JsonValue::Number(n), "double" | "float") => {
377            Ok(Series::new(name.into(), vec![n.as_f64()]))
378        }
379        (JsonValue::Number(n), t) if is_decimal_type_str(t) => {
380            Ok(Series::new(name.into(), vec![n.as_f64()]))
381        }
382        (JsonValue::String(s), "string" | "str" | "varchar") => {
383            Ok(Series::new(name.into(), vec![s.as_str()]))
384        }
385        (JsonValue::Bool(b), "boolean" | "bool") => Ok(Series::new(name.into(), vec![*b])),
386        (JsonValue::String(s), "date") => {
387            let d = NaiveDate::parse_from_str(s, "%Y-%m-%d")
388                .map_err(|e| PolarsError::ComputeError(format!("date parse: {e}").into()))?;
389            let days = (d - epoch).num_days() as i32;
390            let s = Series::new(name.into(), vec![days]).cast(&DataType::Date)?;
391            Ok(s)
392        }
393        _ => Err(PolarsError::ComputeError(
394            format!("json_value_to_series: unsupported {type_str} for {value:?}").into(),
395        )),
396    }
397}
398
399/// Build a struct Series from JsonValue::Object or JsonValue::Array (field-order) or Null.
400#[allow(dead_code)]
401fn json_object_or_array_to_struct_series(
402    value: &JsonValue,
403    fields: &[(String, String)],
404    _name: &str,
405) -> Result<Option<Series>, PolarsError> {
406    use polars::prelude::StructChunked;
407    if matches!(value, JsonValue::Null) {
408        return Ok(None);
409    }
410    // #610: Accept string that parses as JSON object or array.
411    let effective = match value {
412        JsonValue::String(s) => {
413            if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
414                if parsed.is_object() || parsed.is_array() {
415                    parsed
416                } else {
417                    value.clone()
418                }
419            } else {
420                value.clone()
421            }
422        }
423        _ => value.clone(),
424    };
425    let mut field_series: Vec<Series> = Vec::with_capacity(fields.len());
426    for (fname, ftype) in fields {
427        let fval = if let Some(obj) = effective.as_object() {
428            obj.get(fname).unwrap_or(&JsonValue::Null)
429        } else if let Some(arr) = effective.as_array() {
430            let idx = field_series.len();
431            arr.get(idx).unwrap_or(&JsonValue::Null)
432        } else {
433            return Err(PolarsError::ComputeError(
434                "struct value must be object (by field name) or array (by position). \
435                 PySpark accepts dict or tuple/list for struct columns."
436                    .into(),
437            ));
438        };
439        let s = json_value_to_series_single(fval, ftype, fname)?;
440        field_series.push(s);
441    }
442    let field_refs: Vec<&Series> = field_series.iter().collect();
443    let st = StructChunked::from_series(PlSmallStr::EMPTY, 1, field_refs.iter().copied())
444        .map_err(|e| PolarsError::ComputeError(format!("struct from value: {e}").into()))?
445        .into_series();
446    Ok(Some(st))
447}
448
449/// Build a single row's map column value as List(Struct{key, value}) element from a JSON object.
450/// PySpark parity #627: create_dataframe_from_rows accepts dict for map columns.
451fn json_object_to_map_struct_series(
452    obj: &serde_json::Map<String, JsonValue>,
453    key_type: &str,
454    value_type: &str,
455    key_dtype: &DataType,
456    value_dtype: &DataType,
457    _name: &str,
458) -> Result<Series, PolarsError> {
459    if obj.is_empty() {
460        let key_series = Series::new("key".into(), Vec::<String>::new());
461        let value_series = Series::new_empty(PlSmallStr::EMPTY, value_dtype);
462        let st = StructChunked::from_series(
463            PlSmallStr::EMPTY,
464            0,
465            [&key_series, &value_series].iter().copied(),
466        )
467        .map_err(|e| PolarsError::ComputeError(format!("map struct empty: {e}").into()))?
468        .into_series();
469        return Ok(st);
470    }
471    let keys: Vec<String> = obj.keys().cloned().collect();
472    let mut value_series = None::<Series>;
473    for v in obj.values() {
474        let s = json_value_to_series_single(v, value_type, "value")?;
475        value_series = Some(match value_series.take() {
476            None => s,
477            Some(mut acc) => {
478                acc.extend(&s).map_err(|e| {
479                    PolarsError::ComputeError(format!("map value extend: {e}").into())
480                })?;
481                acc
482            }
483        });
484    }
485    let value_series =
486        value_series.unwrap_or_else(|| Series::new_empty(PlSmallStr::EMPTY, value_dtype));
487    let key_series = Series::new("key".into(), keys.clone());
488    let key_series = if matches!(
489        key_type.trim().to_lowercase().as_str(),
490        "string" | "str" | "varchar"
491    ) {
492        key_series
493    } else {
494        key_series
495            .cast(key_dtype)
496            .map_err(|e| PolarsError::ComputeError(format!("map key cast: {e}").into()))?
497    };
498    let st = StructChunked::from_series(
499        PlSmallStr::EMPTY,
500        key_series.len(),
501        [&key_series, &value_series].iter().copied(),
502    )
503    .map_err(|e| PolarsError::ComputeError(format!("map struct: {e}").into()))?
504    .into_series();
505    Ok(st)
506}
507
508use std::collections::{HashMap, HashSet};
509use std::path::Path;
510use std::sync::{Mutex, OnceLock};
511use std::thread_local;
512
513thread_local! {
514    /// Thread-local SparkSession for UDF resolution in call_udf. Set by get_or_create.
515    static THREAD_UDF_SESSION: RefCell<Option<SparkSession>> = const { RefCell::new(None) };
516}
517
518/// Set the thread-local session for UDF resolution (call_udf). Used by get_or_create.
519pub(crate) fn set_thread_udf_session(session: SparkSession) {
520    robin_sparkless_expr::set_thread_udf_context(
521        Arc::new(session.udf_registry.clone()),
522        session.is_case_sensitive(),
523    );
524    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = Some(session));
525}
526
527/// Get the thread-local session for UDF resolution. (call_udf uses expr's thread context; this is kept for compatibility.)
528#[allow(dead_code)]
529pub(crate) fn get_thread_udf_session() -> Option<SparkSession> {
530    THREAD_UDF_SESSION.with(|cell| cell.borrow().clone())
531}
532
533/// Clear the thread-local session used for UDF resolution.
534pub(crate) fn clear_thread_udf_session() {
535    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = None);
536}
537
538/// Catalog of global temporary views (process-scoped). Persists across sessions within the same process.
539/// PySpark: createOrReplaceGlobalTempView / spark.table("global_temp.name").
540static GLOBAL_TEMP_CATALOG: OnceLock<Arc<Mutex<HashMap<String, DataFrame>>>> = OnceLock::new();
541
542fn global_temp_catalog() -> Arc<Mutex<HashMap<String, DataFrame>>> {
543    GLOBAL_TEMP_CATALOG
544        .get_or_init(|| Arc::new(Mutex::new(HashMap::new())))
545        .clone()
546}
547
548/// Builder for creating a SparkSession with configuration options
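///
/// # Example
///
/// A minimal configuration sketch ("example" is an arbitrary app name):
///
/// ```
/// use robin_sparkless::session::SparkSession;
///
/// let spark = SparkSession::builder()
///     .app_name("example")
///     .config("spark.sql.caseSensitive", "true")
///     .get_or_create();
/// assert!(spark.is_case_sensitive());
/// ```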
549#[derive(Clone)]
550pub struct SparkSessionBuilder {
551    app_name: Option<String>,
552    master: Option<String>,
553    config: HashMap<String, String>,
554}
555
556impl Default for SparkSessionBuilder {
557    fn default() -> Self {
558        Self::new()
559    }
560}
561
562impl SparkSessionBuilder {
563    pub fn new() -> Self {
564        SparkSessionBuilder {
565            app_name: None,
566            master: None,
567            config: HashMap::new(),
568        }
569    }
570
571    pub fn app_name(mut self, name: impl Into<String>) -> Self {
572        self.app_name = Some(name.into());
573        self
574    }
575
576    pub fn master(mut self, master: impl Into<String>) -> Self {
577        self.master = Some(master.into());
578        self
579    }
580
581    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
582        self.config.insert(key.into(), value.into());
583        self
584    }
585
586    pub fn get_or_create(self) -> SparkSession {
587        let session = SparkSession::new(self.app_name, self.master, self.config);
588        set_thread_udf_session(session.clone());
589        session
590    }
591
592    /// Apply configuration from a [`SparklessConfig`](crate::config::SparklessConfig).
593    /// Merges warehouse dir, case sensitivity, and extra keys into the builder config.
594    pub fn with_config(mut self, config: &crate::config::SparklessConfig) -> Self {
595        for (k, v) in config.to_session_config() {
596            self.config.insert(k, v);
597        }
598        self
599    }
600}
601
602/// Catalog of temporary view names to DataFrames (session-scoped). Uses Arc<Mutex<>> for Send+Sync (Python bindings).
603pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;
604
605/// Catalog of saved table names to DataFrames (session-scoped). Used by saveAsTable.
606pub type TableCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;
607
608/// Names of databases/schemas created via CREATE DATABASE / CREATE SCHEMA (session-scoped). Persisted when SQL DDL runs.
609pub type DatabaseCatalog = Arc<Mutex<HashSet<String>>>;
610
611/// Main entry point for creating DataFrames and executing queries
612/// Similar to PySpark's SparkSession but using Polars as the backend
613#[derive(Clone)]
614pub struct SparkSession {
615    app_name: Option<String>,
616    master: Option<String>,
617    config: HashMap<String, String>,
618    /// Temporary views: name -> DataFrame. Session-scoped; cleared when session is dropped.
619    pub(crate) catalog: TempViewCatalog,
620    /// Saved tables (saveAsTable): name -> DataFrame. Session-scoped; separate namespace from temp views.
621    pub(crate) tables: TableCatalog,
622    /// Databases/schemas created via CREATE DATABASE / CREATE SCHEMA. Session-scoped; used by listDatabases/databaseExists.
623    pub(crate) databases: DatabaseCatalog,
624    /// UDF registry: Rust UDFs. Session-scoped.
625    pub(crate) udf_registry: UdfRegistry,
626}
627
628impl SparkSession {
629    pub fn new(
630        app_name: Option<String>,
631        master: Option<String>,
632        config: HashMap<String, String>,
633    ) -> Self {
634        SparkSession {
635            app_name,
636            master,
637            config,
638            catalog: Arc::new(Mutex::new(HashMap::new())),
639            tables: Arc::new(Mutex::new(HashMap::new())),
640            databases: Arc::new(Mutex::new(HashSet::new())),
641            udf_registry: UdfRegistry::new(),
642        }
643    }
644
645    /// Register a DataFrame as a temporary view (PySpark: createOrReplaceTempView).
646    /// The view is session-scoped and is dropped when the session is dropped.
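    ///
    /// # Example
    ///
    /// A small round-trip sketch (the view name `people` is arbitrary):
    ///
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().get_or_create();
    /// let df = spark.create_dataframe(
    ///     vec![(1, 25, "Alice".to_string())],
    ///     vec!["id", "age", "name"],
    /// )?;
    /// spark.create_or_replace_temp_view("people", df);
    /// let people = spark.table("people")?;
    /// #     let _ = people;
    /// #     Ok(())
    /// # }
    /// ```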
647    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
648        let _ = self
649            .catalog
650            .lock()
651            .map(|mut m| m.insert(name.to_string(), df));
652    }
653
654    /// Global temp view (PySpark: createGlobalTempView). Persists across sessions within the same process.
655    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
656        let _ = global_temp_catalog()
657            .lock()
658            .map(|mut m| m.insert(name.to_string(), df));
659    }
660
661    /// Global temp view (PySpark: createOrReplaceGlobalTempView). Persists across sessions within the same process.
662    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
663        let _ = global_temp_catalog()
664            .lock()
665            .map(|mut m| m.insert(name.to_string(), df));
666    }
667
668    /// Drop a temporary view by name (PySpark: catalog.dropTempView).
669    /// No error if the view does not exist.
670    pub fn drop_temp_view(&self, name: &str) {
671        let _ = self.catalog.lock().map(|mut m| m.remove(name));
672    }
673
674    /// Drop a global temporary view (PySpark: catalog.dropGlobalTempView). Removes from process-wide catalog.
675    pub fn drop_global_temp_view(&self, name: &str) -> bool {
676        global_temp_catalog()
677            .lock()
678            .map(|mut m| m.remove(name).is_some())
679            .unwrap_or(false)
680    }
681
682    /// Register a DataFrame as a saved table (PySpark: saveAsTable). Inserts into the tables catalog only.
683    pub fn register_table(&self, name: &str, df: DataFrame) {
684        let _ = self
685            .tables
686            .lock()
687            .map(|mut m| m.insert(name.to_string(), df));
688    }
689
690    /// Register a database/schema name (from CREATE DATABASE / CREATE SCHEMA). Persisted in session for listDatabases/databaseExists.
691    pub fn register_database(&self, name: &str) {
692        let _ = self.databases.lock().map(|mut s| {
693            s.insert(name.to_string());
694        });
695    }
696
697    /// List database names: built-in "default", "global_temp", plus any created via CREATE DATABASE / CREATE SCHEMA.
698    pub fn list_database_names(&self) -> Vec<String> {
699        let mut names: Vec<String> = vec!["default".to_string(), "global_temp".to_string()];
700        if let Ok(guard) = self.databases.lock() {
701            let mut created: Vec<String> = guard.iter().cloned().collect();
702            created.sort();
703            names.extend(created);
704        }
705        names
706    }
707
708    /// True if the database name exists (default, global_temp, or created via CREATE DATABASE / CREATE SCHEMA).
709    pub fn database_exists(&self, name: &str) -> bool {
710        if name.eq_ignore_ascii_case("default") || name.eq_ignore_ascii_case("global_temp") {
711            return true;
712        }
713        self.databases
714            .lock()
715            .map(|s| s.iter().any(|n| n.eq_ignore_ascii_case(name)))
716            .unwrap_or(false)
717    }
718
719    /// Get a saved table by name (tables map only). Returns None if not in saved tables (temp views not checked).
720    pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
721        self.tables.lock().ok().and_then(|m| m.get(name).cloned())
722    }
723
724    /// True if the name exists in the saved-tables map (not temp views).
725    pub fn saved_table_exists(&self, name: &str) -> bool {
726        self.tables
727            .lock()
728            .map(|m| m.contains_key(name))
729            .unwrap_or(false)
730    }
731
732    /// Check if a table or temp view exists (PySpark: catalog.tableExists). True if name is in temp views, saved tables, global temp, or warehouse.
733    pub fn table_exists(&self, name: &str) -> bool {
734        // global_temp.xyz
735        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
736            return global_temp_catalog()
737                .lock()
738                .map(|m| m.contains_key(tbl))
739                .unwrap_or(false);
740        }
741        if self
742            .catalog
743            .lock()
744            .map(|m| m.contains_key(name))
745            .unwrap_or(false)
746        {
747            return true;
748        }
749        if self
750            .tables
751            .lock()
752            .map(|m| m.contains_key(name))
753            .unwrap_or(false)
754        {
755            return true;
756        }
757        // Warehouse fallback
758        if let Some(warehouse) = self.warehouse_dir() {
759            let path = Path::new(warehouse).join(name);
760            if path.is_dir() {
761                return true;
762            }
763        }
764        false
765    }
766
767    /// Return global temp view names (process-scoped). PySpark: catalog.listTables(dbName="global_temp").
768    pub fn list_global_temp_view_names(&self) -> Vec<String> {
769        global_temp_catalog()
770            .lock()
771            .map(|m| m.keys().cloned().collect())
772            .unwrap_or_default()
773    }
774
775    /// Return temporary view names in this session.
776    pub fn list_temp_view_names(&self) -> Vec<String> {
777        self.catalog
778            .lock()
779            .map(|m| m.keys().cloned().collect())
780            .unwrap_or_default()
781    }
782
783    /// Return saved table names in this session (saveAsTable / write_delta_table).
784    pub fn list_table_names(&self) -> Vec<String> {
785        self.tables
786            .lock()
787            .map(|m| m.keys().cloned().collect())
788            .unwrap_or_default()
789    }
790
791    /// Drop a saved table by name (removes from tables catalog only). No-op if not present.
792    pub fn drop_table(&self, name: &str) -> bool {
793        self.tables
794            .lock()
795            .map(|mut m| m.remove(name).is_some())
796            .unwrap_or(false)
797    }
798
799    /// Drop a database/schema by name (from DROP SCHEMA / DROP DATABASE). Removes from registered databases only.
800    /// Does not drop "default" or "global_temp". No-op if not present (or if_exists). Returns true if removed.
801    pub fn drop_database(&self, name: &str) -> bool {
802        if name.eq_ignore_ascii_case("default") || name.eq_ignore_ascii_case("global_temp") {
803            return false;
804        }
805        self.databases
806            .lock()
807            .map(|mut s| s.remove(name))
808            .unwrap_or(false)
809    }
810
811    /// Parse "global_temp.xyz" into ("global_temp", "xyz"). Returns None for plain names.
812    fn parse_global_temp_name(name: &str) -> Option<(&str, &str)> {
813        if let Some(dot) = name.find('.') {
814            let (db, tbl) = name.split_at(dot);
815            if db.eq_ignore_ascii_case("global_temp") {
816                return Some((db, tbl.strip_prefix('.').unwrap_or(tbl)));
817            }
818        }
819        None
820    }
821
822    /// Return spark.sql.warehouse.dir from config if set. Enables disk-backed saveAsTable.
823    pub fn warehouse_dir(&self) -> Option<&str> {
824        self.config
825            .get("spark.sql.warehouse.dir")
826            .map(|s| s.as_str())
827            .filter(|s| !s.is_empty())
828    }
829
830    /// Look up a table or temp view by name (PySpark: table(name)).
831    /// Resolution order: (1) global_temp.xyz from global catalog, (2) temp view, (3) saved table, (4) warehouse.
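    ///
    /// # Example
    ///
    /// A sketch of global-temp resolution (the view name `shared` is arbitrary):
    ///
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().get_or_create();
    /// let df = spark.create_dataframe(vec![(1, 2, "x".to_string())], vec!["a", "b", "c"])?;
    /// spark.create_or_replace_global_temp_view("shared", df);
    /// let shared = spark.table("global_temp.shared")?;
    /// #     let _ = shared;
    /// #     Ok(())
    /// # }
    /// ```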
832    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
833        // global_temp.xyz -> global catalog only
834        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
835            if let Some(df) = global_temp_catalog()
836                .lock()
837                .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
838                .get(tbl)
839                .cloned()
840            {
841                return Ok(df);
842            }
843            return Err(PolarsError::InvalidOperation(
844                format!(
845                    "Global temp view '{tbl}' not found. Register it with createOrReplaceGlobalTempView."
846                )
847                .into(),
848            ));
849        }
850        // Session: temp view, saved table
851        if let Some(df) = self
852            .catalog
853            .lock()
854            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
855            .get(name)
856            .cloned()
857        {
858            return Ok(df);
859        }
860        if let Some(df) = self
861            .tables
862            .lock()
863            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
864            .get(name)
865            .cloned()
866        {
867            return Ok(df);
868        }
869        // Warehouse fallback (disk-backed saveAsTable)
870        if let Some(warehouse) = self.warehouse_dir() {
871            let dir = Path::new(warehouse).join(name);
872            if dir.is_dir() {
873                // Read data.parquet (our convention) or the dir (Polars accepts dirs with parquet files)
874                let data_file = dir.join("data.parquet");
875                let read_path = if data_file.is_file() { data_file } else { dir };
876                return self.read_parquet(&read_path);
877            }
878        }
879        Err(PolarsError::InvalidOperation(
880            format!(
881                "Table or view '{name}' not found. Register it with create_or_replace_temp_view or saveAsTable."
882            )
883            .into(),
884        ))
885    }
886
887    pub fn builder() -> SparkSessionBuilder {
888        SparkSessionBuilder::new()
889    }
890
891    /// Create a session from a [`SparklessConfig`](crate::config::SparklessConfig).
892    /// Equivalent to `SparkSession::builder().with_config(config).get_or_create()`.
893    pub fn from_config(config: &crate::config::SparklessConfig) -> SparkSession {
894        Self::builder().with_config(config).get_or_create()
895    }
896
897    /// Return a reference to the session config (for catalog/conf compatibility).
898    pub fn get_config(&self) -> &HashMap<String, String> {
899        &self.config
900    }
901
902    /// Whether column names are case-sensitive (PySpark: spark.sql.caseSensitive).
903    /// Default is false (case-insensitive matching).
904    pub fn is_case_sensitive(&self) -> bool {
905        self.config
906            .get("spark.sql.caseSensitive")
907            .map(|v| v.eq_ignore_ascii_case("true"))
908            .unwrap_or(false)
909    }
910
911    /// Register a Rust UDF. Session-scoped. Use with call_udf. PySpark: spark.udf.register (Python) or equivalent.
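    ///
    /// # Example
    ///
    /// Illustrative sketch (`plus_one` is an arbitrary name); marked `ignore` because the
    /// caller needs the `polars` crate in scope for `Series`:
    ///
    /// ```ignore
    /// use polars::prelude::{IntoSeries, Series};
    /// use robin_sparkless::session::SparkSession;
    ///
    /// let spark = SparkSession::builder().get_or_create();
    /// spark
    ///     .register_udf("plus_one", |cols: &[Series]| {
    ///         let ints = cols[0].i64()?;
    ///         Ok((ints + 1).into_series())
    ///     })
    ///     .expect("register plus_one");
    /// ```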
912    pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
913    where
914        F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
915    {
916        self.udf_registry.register_rust_udf(name, f)
917    }
918
919    /// Create a DataFrame from a vector of tuples (i64, i64, String)
920    ///
921    /// # Example
922    /// ```
923    /// use robin_sparkless::session::SparkSession;
924    ///
925    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
926    /// let spark = SparkSession::builder().app_name("test").get_or_create();
927    /// let df = spark.create_dataframe(
928    ///     vec![
929    ///         (1, 25, "Alice".to_string()),
930    ///         (2, 30, "Bob".to_string()),
931    ///     ],
932    ///     vec!["id", "age", "name"],
933    /// )?;
934    /// #     let _ = df;
935    /// #     Ok(())
936    /// # }
937    /// ```
938    pub fn create_dataframe(
939        &self,
940        data: Vec<(i64, i64, String)>,
941        column_names: Vec<&str>,
942    ) -> Result<DataFrame, PolarsError> {
943        if column_names.len() != 3 {
944            return Err(PolarsError::ComputeError(
945                format!(
946                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
947                    column_names.len()
948                )
949                .into(),
950            ));
951        }
952
953        let mut cols: Vec<Series> = Vec::with_capacity(3);
954
955        // First column: i64
956        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
957        cols.push(Series::new(column_names[0].into(), col0));
958
959        // Second column: i64
960        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
961        cols.push(Series::new(column_names[1].into(), col1));
962
963        // Third column: String
964        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
965        cols.push(Series::new(column_names[2].into(), col2));
966
967        let pl_df = PlDataFrame::new_infer_height(cols.iter().map(|s| s.clone().into()).collect())?;
968        Ok(DataFrame::from_polars_with_options(
969            pl_df,
970            self.is_case_sensitive(),
971        ))
972    }
973
974    /// Same as [`create_dataframe`](Self::create_dataframe) but returns [`EngineError`]. Use in bindings to avoid Polars.
975    pub fn create_dataframe_engine(
976        &self,
977        data: Vec<(i64, i64, String)>,
978        column_names: Vec<&str>,
979    ) -> Result<DataFrame, EngineError> {
980        self.create_dataframe(data, column_names)
981            .map_err(EngineError::from)
982    }
983
984    /// Create a DataFrame from a Polars DataFrame
985    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
986        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
987    }
988
989    /// Infer dtype string from a single JSON value (for schema inference). Returns None for Null.
990    fn infer_dtype_from_json_value(v: &JsonValue) -> Option<String> {
991        match v {
992            JsonValue::Null => None,
993            JsonValue::Bool(_) => Some("boolean".to_string()),
994            JsonValue::Number(n) => {
995                if n.is_i64() {
996                    Some("bigint".to_string())
997                } else {
998                    Some("double".to_string())
999                }
1000            }
1001            JsonValue::String(s) => {
1002                if chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d").is_ok() {
1003                    Some("date".to_string())
1004                } else if chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f").is_ok()
1005                    || chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").is_ok()
1006                {
1007                    Some("timestamp".to_string())
1008                } else {
1009                    Some("string".to_string())
1010                }
1011            }
1012            JsonValue::Array(_) => Some("array".to_string()),
1013            JsonValue::Object(_) => Some("string".to_string()), // struct inference not implemented; treat as string for safety
1014        }
1015    }
1016
1017    /// Infer schema (name, dtype_str) from JSON rows by scanning the first non-null value per column.
1018    /// Used by createDataFrame(data, schema=None) when schema is omitted or only column names given.
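    ///
    /// # Example
    ///
    /// A minimal sketch; the dtype is taken from the first non-null value in each column:
    ///
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// let rows = vec![vec![serde_json::json!(1), serde_json::json!("Alice")]];
    /// let names = vec!["id".to_string(), "name".to_string()];
    /// let schema = SparkSession::infer_schema_from_json_rows(&rows, &names);
    /// assert_eq!(schema[0].1, "bigint");
    /// assert_eq!(schema[1].1, "string");
    /// ```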
1019    pub fn infer_schema_from_json_rows(
1020        rows: &[Vec<JsonValue>],
1021        names: &[String],
1022    ) -> Vec<(String, String)> {
1023        if names.is_empty() {
1024            return Vec::new();
1025        }
1026        let mut schema: Vec<(String, String)> = names
1027            .iter()
1028            .map(|n| (n.clone(), "string".to_string()))
1029            .collect();
1030        for (col_idx, (_, dtype_str)) in schema.iter_mut().enumerate() {
1031            for row in rows {
1032                let v = row.get(col_idx).unwrap_or(&JsonValue::Null);
1033                if let Some(dtype) = Self::infer_dtype_from_json_value(v) {
1034                    *dtype_str = dtype;
1035                    break;
1036                }
1037            }
1038        }
1039        schema
1040    }
1041
1042    /// Create a DataFrame from rows and a schema (arbitrary column count and types).
1043    ///
1044    /// `rows`: each inner vec is one row; length must match schema length. Values are JSON-like (i64, f64, string, bool, null, object, array).
1045    /// `schema`: list of (column_name, dtype_string), e.g. `[("id", "bigint"), ("name", "string")]`.
1046    /// Supported dtype strings: bigint, int, long, double, float, string, str, varchar, boolean, bool, date, timestamp, datetime, list, array, array<element_type>, struct<field:type,...>.
1047    /// When `rows` is empty and `schema` is non-empty, returns an empty DataFrame with that schema (issue #519). Use with `write.format("parquet").saveAsTable(...)` then append; PySpark would fail with "can not infer schema from empty dataset".
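    ///
    /// # Example
    ///
    /// A minimal sketch with two typed columns (column names and values are arbitrary):
    ///
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().get_or_create();
    /// let rows = vec![
    ///     vec![serde_json::json!(1), serde_json::json!("Alice")],
    ///     vec![serde_json::json!(2), serde_json::json!("Bob")],
    /// ];
    /// let schema = vec![
    ///     ("id".to_string(), "bigint".to_string()),
    ///     ("name".to_string(), "string".to_string()),
    /// ];
    /// let df = spark.create_dataframe_from_rows(rows, schema)?;
    /// #     let _ = df;
    /// #     Ok(())
    /// # }
    /// ```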
1048    pub fn create_dataframe_from_rows(
1049        &self,
1050        rows: Vec<Vec<JsonValue>>,
1051        schema: Vec<(String, String)>,
1052    ) -> Result<DataFrame, PolarsError> {
1053        // #624: When schema is empty but rows are not, infer schema from rows (PySpark parity).
1054        let schema = if schema.is_empty() && !rows.is_empty() {
1055            let ncols = rows[0].len();
1056            let names: Vec<String> = (0..ncols).map(|i| format!("c{i}")).collect();
1057            Self::infer_schema_from_json_rows(&rows, &names)
1058        } else {
1059            schema
1060        };
1061
1062        if schema.is_empty() {
1063            if rows.is_empty() {
1064                return Ok(DataFrame::from_polars_with_options(
1065                    PlDataFrame::new(0, vec![])?,
1066                    self.is_case_sensitive(),
1067                ));
1068            }
1069            return Err(PolarsError::InvalidOperation(
1070                "create_dataframe_from_rows: schema must not be empty when rows are not empty"
1071                    .into(),
1072            ));
1073        }
1074        use chrono::{NaiveDate, NaiveDateTime};
1075
1076        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());
1077
1078        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
1079            let type_lower = type_str.trim().to_lowercase();
1080            let s = match type_lower.as_str() {
1081                "int" | "integer" | "bigint" | "long" => {
1082                    let vals: Vec<Option<i64>> = rows
1083                        .iter()
1084                        .map(|row| {
1085                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1086                            match v {
1087                                JsonValue::Number(n) => n.as_i64(),
1088                                JsonValue::Null => None,
1089                                _ => None,
1090                            }
1091                        })
1092                        .collect();
1093                    Series::new(name.as_str().into(), vals)
1094                }
1095                "double" | "float" | "double_precision" => {
1096                    let vals: Vec<Option<f64>> = rows
1097                        .iter()
1098                        .map(|row| {
1099                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1100                            match v {
1101                                JsonValue::Number(n) => n.as_f64(),
1102                                JsonValue::Null => None,
1103                                _ => None,
1104                            }
1105                        })
1106                        .collect();
1107                    Series::new(name.as_str().into(), vals)
1108                }
1109                _ if is_decimal_type_str(&type_lower) => {
1110                    let vals: Vec<Option<f64>> = rows
1111                        .iter()
1112                        .map(|row| {
1113                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1114                            match v {
1115                                JsonValue::Number(n) => n.as_f64(),
1116                                JsonValue::Null => None,
1117                                _ => None,
1118                            }
1119                        })
1120                        .collect();
1121                    Series::new(name.as_str().into(), vals)
1122                }
1123                "string" | "str" | "varchar" => {
1124                    let vals: Vec<Option<String>> = rows
1125                        .iter()
1126                        .map(|row| {
1127                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1128                            match v {
1129                                JsonValue::String(s) => Some(s),
1130                                JsonValue::Null => None,
1131                                other => Some(other.to_string()),
1132                            }
1133                        })
1134                        .collect();
1135                    Series::new(name.as_str().into(), vals)
1136                }
1137                "boolean" | "bool" => {
1138                    let vals: Vec<Option<bool>> = rows
1139                        .iter()
1140                        .map(|row| {
1141                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1142                            match v {
1143                                JsonValue::Bool(b) => Some(b),
1144                                JsonValue::Null => None,
1145                                _ => None,
1146                            }
1147                        })
1148                        .collect();
1149                    Series::new(name.as_str().into(), vals)
1150                }
1151                "date" => {
1152                    let epoch = crate::date_utils::epoch_naive_date();
1153                    let vals: Vec<Option<i32>> = rows
1154                        .iter()
1155                        .map(|row| {
1156                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1157                            match v {
1158                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
1159                                    .ok()
1160                                    .map(|d| (d - epoch).num_days() as i32),
1161                                JsonValue::Null => None,
1162                                _ => None,
1163                            }
1164                        })
1165                        .collect();
1166                    let series = Series::new(name.as_str().into(), vals);
1167                    series
1168                        .cast(&DataType::Date)
1169                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
1170                }
1171                "timestamp" | "datetime" | "timestamp_ntz" => {
1172                    let vals: Vec<Option<i64>> =
1173                        rows.iter()
1174                            .map(|row| {
1175                                let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1176                                match v {
1177                                    JsonValue::String(s) => {
1178                                        let parsed = NaiveDateTime::parse_from_str(
1179                                            &s,
1180                                            "%Y-%m-%dT%H:%M:%S%.f",
1181                                        )
1182                                        .map_err(|e| {
1183                                            PolarsError::ComputeError(e.to_string().into())
1184                                        })
1185                                        .or_else(|_| {
1186                                            NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S")
1187                                                .map_err(|e| {
1188                                                    PolarsError::ComputeError(e.to_string().into())
1189                                                })
1190                                        })
1191                                        .or_else(|_| {
1192                                            NaiveDate::parse_from_str(&s, "%Y-%m-%d")
1193                                                .map_err(|e| {
1194                                                    PolarsError::ComputeError(e.to_string().into())
1195                                                })
1196                                                .and_then(|d| {
1197                                                    d.and_hms_opt(0, 0, 0).ok_or_else(|| {
1198                                                        PolarsError::ComputeError(
1199                                                            "date to datetime (0:0:0)".into(),
1200                                                        )
1201                                                    })
1202                                                })
1203                                        });
1204                                        parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
1205                                    }
1206                                    JsonValue::Number(n) => n.as_i64(),
1207                                    JsonValue::Null => None,
1208                                    _ => None,
1209                                }
1210                            })
1211                            .collect();
1212                    let series = Series::new(name.as_str().into(), vals);
1213                    series
1214                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
1215                        .map_err(|e| {
1216                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
1217                        })?
1218                }
1219                "list" | "array" => {
1220                    // PySpark parity: ("col", "list") or ("col", "array"); infer element type from first non-null array.
1221                    let (elem_type, inner_dtype) = infer_list_element_type(&rows, col_idx)
1222                        .unwrap_or(("bigint".to_string(), DataType::Int64));
1223                    let n = rows.len();
1224                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
1225                    for row in rows.iter() {
1226                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1227                        if let JsonValue::Null = &v {
1228                            builder.append_null();
1229                        } else if let Some(arr) = json_value_to_array(&v) {
1230                            // #625: Array, Object with "0","1",..., or string that parses as JSON array (PySpark list parity).
1231                            let elem_series: Vec<Series> = arr
1232                                .iter()
1233                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
1234                                .collect::<Result<Vec<_>, _>>()?;
1235                            let vals: Vec<_> =
1236                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
1237                            let s = Series::from_any_values_and_dtype(
1238                                PlSmallStr::EMPTY,
1239                                &vals,
1240                                &inner_dtype,
1241                                false,
1242                            )
1243                            .map_err(|e| {
1244                                PolarsError::ComputeError(format!("array elem: {e}").into())
1245                            })?;
1246                            builder.append_series(&s)?;
1247                        } else {
1248                            // #611: PySpark accepts single value as one-element list.
1249                            let single_arr = [v];
1250                            let elem_series: Vec<Series> = single_arr
1251                                .iter()
1252                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
1253                                .collect::<Result<Vec<_>, _>>()?;
1254                            let vals: Vec<_> =
1255                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
1256                            let s = Series::from_any_values_and_dtype(
1257                                PlSmallStr::EMPTY,
1258                                &vals,
1259                                &inner_dtype,
1260                                false,
1261                            )
1262                            .map_err(|e| {
1263                                PolarsError::ComputeError(format!("array elem: {e}").into())
1264                            })?;
1265                            builder.append_series(&s)?;
1266                        }
1267                    }
1268                    builder.finish().into_series()
1269                }
1270                _ if parse_array_element_type(&type_lower).is_some() => {
1271                    let elem_type = parse_array_element_type(&type_lower).unwrap_or_else(|| {
1272                        unreachable!("guard above ensures parse_array_element_type returned Some")
1273                    });
1274                    let inner_dtype = json_type_str_to_polars(&elem_type)
1275                        .ok_or_else(|| {
1276                            PolarsError::ComputeError(
1277                                format!(
1278                                    "create_dataframe_from_rows: array element type '{elem_type}' not supported"
1279                                )
1280                                .into(),
1281                            )
1282                        })?;
1283                    let n = rows.len();
1284                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
1285                    for row in rows.iter() {
1286                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1287                        if let JsonValue::Null = &v {
1288                            builder.append_null();
1289                        } else if let Some(arr) = json_value_to_array(&v) {
1290                            // #625: Array, Object with "0","1",..., or string that parses as JSON array (PySpark list parity).
1291                            let elem_series: Vec<Series> = arr
1292                                .iter()
1293                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
1294                                .collect::<Result<Vec<_>, _>>()?;
1295                            let vals: Vec<_> =
1296                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
1297                            let s = Series::from_any_values_and_dtype(
1298                                PlSmallStr::EMPTY,
1299                                &vals,
1300                                &inner_dtype,
1301                                false,
1302                            )
1303                            .map_err(|e| {
1304                                PolarsError::ComputeError(format!("array elem: {e}").into())
1305                            })?;
1306                            builder.append_series(&s)?;
1307                        } else {
1308                            // #611: PySpark accepts single value as one-element list.
1309                            let single_arr = [v];
1310                            let elem_series: Vec<Series> = single_arr
1311                                .iter()
1312                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
1313                                .collect::<Result<Vec<_>, _>>()?;
1314                            let vals: Vec<_> =
1315                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
1316                            let s = Series::from_any_values_and_dtype(
1317                                PlSmallStr::EMPTY,
1318                                &vals,
1319                                &inner_dtype,
1320                                false,
1321                            )
1322                            .map_err(|e| {
1323                                PolarsError::ComputeError(format!("array elem: {e}").into())
1324                            })?;
1325                            builder.append_series(&s)?;
1326                        }
1327                    }
1328                    builder.finish().into_series()
1329                }
1330                _ if parse_map_key_value_types(&type_lower).is_some() => {
1331                    let (key_type, value_type) = parse_map_key_value_types(&type_lower)
1332                        .unwrap_or_else(|| unreachable!("guard ensures Some"));
1333                    let key_dtype = json_type_str_to_polars(&key_type).ok_or_else(|| {
1334                        PolarsError::ComputeError(
1335                            format!(
1336                                "create_dataframe_from_rows: map key type '{key_type}' not supported"
1337                            )
1338                            .into(),
1339                        )
1340                    })?;
1341                    let value_dtype = json_type_str_to_polars(&value_type).ok_or_else(|| {
1342                        PolarsError::ComputeError(
1343                            format!(
1344                                "create_dataframe_from_rows: map value type '{value_type}' not supported"
1345                            )
1346                            .into(),
1347                        )
1348                    })?;
1349                    let struct_dtype = DataType::Struct(vec![
1350                        Field::new("key".into(), key_dtype.clone()),
1351                        Field::new("value".into(), value_dtype.clone()),
1352                    ]);
1353                    let n = rows.len();
1354                    let mut builder = get_list_builder(&struct_dtype, 64, n, name.as_str().into());
1355                    for row in rows.iter() {
1356                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
1357                        if matches!(v, JsonValue::Null) {
1358                            builder.append_null();
1359                        } else if let Some(obj) = v.as_object() {
1360                            let st = json_object_to_map_struct_series(
1361                                obj,
1362                                &key_type,
1363                                &value_type,
1364                                &key_dtype,
1365                                &value_dtype,
1366                                name,
1367                            )?;
1368                            builder.append_series(&st)?;
1369                        } else {
1370                            return Err(PolarsError::ComputeError(
1371                                format!(
1372                                    "create_dataframe_from_rows: map column '{name}' expects JSON object (dict), got {:?}",
1373                                    v
1374                                )
1375                                .into(),
1376                            ));
1377                        }
1378                    }
1379                    builder.finish().into_series()
1380                }
1381                _ if parse_struct_fields(&type_lower).is_some() => {
1382                    let values: Vec<Option<JsonValue>> =
1383                        rows.iter().map(|row| row.get(col_idx).cloned()).collect();
1384                    json_values_to_series(&values, &type_lower, name)?
1385                }
1386                _ => {
1387                    return Err(PolarsError::ComputeError(
1388                        format!(
1389                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
1390                        )
1391                        .into(),
1392                    ));
1393                }
1394            };
1395            cols.push(s);
1396        }
1397
1398        let pl_df = PlDataFrame::new_infer_height(cols.iter().map(|s| s.clone().into()).collect())?;
1399        Ok(DataFrame::from_polars_with_options(
1400            pl_df,
1401            self.is_case_sensitive(),
1402        ))
1403    }
1404
1405    /// Same as [`create_dataframe_from_rows`](Self::create_dataframe_from_rows) but returns [`EngineError`]. Use in bindings to avoid exposing Polars error types.
1406    pub fn create_dataframe_from_rows_engine(
1407        &self,
1408        rows: Vec<Vec<JsonValue>>,
1409        schema: Vec<(String, String)>,
1410    ) -> Result<DataFrame, EngineError> {
1411        self.create_dataframe_from_rows(rows, schema)
1412            .map_err(EngineError::from)
1413    }
1414
1415    /// Create a DataFrame with a single column `id` (bigint) containing values from start to end (exclusive) with step.
1416    /// PySpark: spark.range(end) or spark.range(start, end, step).
1417    ///
1418    /// - `range(end)` → 0 to end-1, step 1
1419    /// - `range(start, end)` → start to end-1, step 1
1420    /// - `range(start, end, step)` → start, start+step, ... up to but not including end
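    ///
    /// # Example
    /// Illustrative usage (the bounds below are placeholders):
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// // id column: 0, 1, 2, 3, 4
    /// let df_result = spark.range(0, 5, 1);
    /// // Handle the Result as appropriate in your application
    /// ```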
1421    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
1422        if step == 0 {
1423            return Err(PolarsError::InvalidOperation(
1424                "range: step must not be 0".into(),
1425            ));
1426        }
1427        let mut vals: Vec<i64> = Vec::new();
1428        let mut v = start;
1429        if step > 0 {
1430            while v < end {
1431                vals.push(v);
1432                v = v.saturating_add(step);
1433            }
1434        } else {
1435            while v > end {
1436                vals.push(v);
1437                v = v.saturating_add(step);
1438            }
1439        }
1440        let col = Series::new("id".into(), vals);
1441        let pl_df = PlDataFrame::new_infer_height(vec![col.into()])?;
1442        Ok(DataFrame::from_polars_with_options(
1443            pl_df,
1444            self.is_case_sensitive(),
1445        ))
1446    }
1447
1448    /// Read a CSV file.
1449    ///
1450    /// Uses Polars' CSV reader with default options:
1451    /// - Header row is assumed present (has_header = true)
1452    /// - Schema is inferred from first 100 rows
1453    ///
1454    /// # Example
1455    /// ```
1456    /// use robin_sparkless::SparkSession;
1457    ///
1458    /// let spark = SparkSession::builder().app_name("test").get_or_create();
1459    /// let df_result = spark.read_csv("data.csv");
1460    /// // Handle the Result as appropriate in your application
1461    /// ```
1462    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1463        use polars::prelude::*;
1464        let path = path.as_ref();
1465        if !path.exists() {
1466            return Err(PolarsError::ComputeError(
1467                format!("read_csv: file not found: {}", path.display()).into(),
1468            ));
1469        }
1470        let path_display = path.display();
1471        // Use LazyCsvReader: finish() yields a LazyFrame that backs the DataFrame lazily (no eager collect here)
1472        let pl_path = PlRefPath::try_from_path(path).map_err(|e| {
1473            PolarsError::ComputeError(format!("read_csv({path_display}): path: {e}").into())
1474        })?;
1475        let lf = LazyCsvReader::new(pl_path)
1476            .with_has_header(true)
1477            .with_infer_schema_length(Some(100))
1478            .finish()
1479            .map_err(|e| {
1480                PolarsError::ComputeError(
1481                    format!(
1482                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
1483                    )
1484                    .into(),
1485                )
1486            })?;
1487        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
1488            lf,
1489            self.is_case_sensitive(),
1490        ))
1491    }
1492
1493    /// Same as [`read_csv`](Self::read_csv) but returns [`EngineError`]. Use in bindings to avoid exposing Polars error types.
1494    pub fn read_csv_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
1495        self.read_csv(path).map_err(EngineError::from)
1496    }
1497
1498    /// Read a Parquet file.
1499    ///
1500    /// Uses Polars' Parquet reader. Parquet files have embedded schema, so
1501    /// schema inference is automatic.
1502    ///
1503    /// # Example
1504    /// ```
1505    /// use robin_sparkless::SparkSession;
1506    ///
1507    /// let spark = SparkSession::builder().app_name("test").get_or_create();
1508    /// let df_result = spark.read_parquet("data.parquet");
1509    /// // Handle the Result as appropriate in your application
1510    /// ```
1511    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1512        use polars::prelude::*;
1513        let path = path.as_ref();
1514        if !path.exists() {
1515            return Err(PolarsError::ComputeError(
1516                format!("read_parquet: file not found: {}", path.display()).into(),
1517            ));
1518        }
1519        // Use LazyFrame::scan_parquet
1520        let pl_path = PlRefPath::try_from_path(path)
1521            .map_err(|e| PolarsError::ComputeError(format!("read_parquet: path: {e}").into()))?;
1522        let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())?;
1523        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
1524            lf,
1525            self.is_case_sensitive(),
1526        ))
1527    }
1528
1529    /// Same as [`read_parquet`](Self::read_parquet) but returns [`EngineError`]. Use in bindings to avoid exposing Polars error types.
1530    pub fn read_parquet_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
1531        self.read_parquet(path).map_err(EngineError::from)
1532    }
1533
1534    /// Read a JSON file (JSONL format - one JSON object per line).
1535    ///
1536    /// Uses Polars' JSONL reader with default options:
1537    /// - Schema is inferred from first 100 rows
1538    ///
1539    /// # Example
1540    /// ```
1541    /// use robin_sparkless::SparkSession;
1542    ///
1543    /// let spark = SparkSession::builder().app_name("test").get_or_create();
1544    /// let df_result = spark.read_json("data.json");
1545    /// // Handle the Result as appropriate in your application
1546    /// ```
1547    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1548        use polars::prelude::*;
1549        use std::num::NonZeroUsize;
1550        let path = path.as_ref();
1551        if !path.exists() {
1552            return Err(PolarsError::ComputeError(
1553                format!("read_json: file not found: {}", path.display()).into(),
1554            ));
1555        }
1556        // Use LazyJsonLineReader: finish() yields a LazyFrame that backs the DataFrame lazily (no eager collect here)
1557        let pl_path = PlRefPath::try_from_path(path)
1558            .map_err(|e| PolarsError::ComputeError(format!("read_json: path: {e}").into()))?;
1559        let lf = LazyJsonLineReader::new(pl_path)
1560            .with_infer_schema_length(NonZeroUsize::new(100))
1561            .finish()?;
1562        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
1563            lf,
1564            self.is_case_sensitive(),
1565        ))
1566    }
1567
1568    /// Same as [`read_json`](Self::read_json) but returns [`EngineError`]. Use in bindings to avoid exposing Polars error types.
1569    pub fn read_json_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
1570        self.read_json(path).map_err(EngineError::from)
1571    }
1572
1573    /// Execute a SQL query (SELECT only). Tables must be registered with `create_or_replace_temp_view`.
1574    /// Requires the `sql` feature. Supports: SELECT (columns or *), FROM (single table or JOIN),
1575    /// WHERE (basic predicates), GROUP BY + aggregates, ORDER BY, LIMIT.
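    ///
    /// # Example
    /// Illustrative usage; assumes a view named `people` was registered beforehand
    /// via `create_or_replace_temp_view` (the view name and query are placeholders):
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark.sql("SELECT name FROM people WHERE age > 21");
    /// // Handle the Result as appropriate in your application
    /// ```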
1576    #[cfg(feature = "sql")]
1577    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
1578        crate::sql::execute_sql(self, query)
1579    }
1580
1581    /// Execute a SQL query (stub when `sql` feature is disabled).
1582    #[cfg(not(feature = "sql"))]
1583    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
1584        Err(PolarsError::InvalidOperation(
1585            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
1586        ))
1587    }
1588
1589    /// Same as [`table`](Self::table) but returns [`EngineError`]. Use in bindings to avoid exposing Polars error types.
1590    pub fn table_engine(&self, name: &str) -> Result<DataFrame, EngineError> {
1591        self.table(name).map_err(EngineError::from)
1592    }
1593
1594    /// Returns true if the string looks like a filesystem path (has separators or path exists).
1595    fn looks_like_path(s: &str) -> bool {
1596        s.contains('/') || s.contains('\\') || Path::new(s).exists()
1597    }
1598
1599    /// Read a Delta table from path (latest version). Internal; use read_delta(name_or_path: &str) for dispatch.
1600    #[cfg(feature = "delta")]
1601    pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1602        crate::delta::read_delta(path, self.is_case_sensitive())
1603    }
1604
1605    /// Read a Delta table at path, optionally at a specific version. Internal; use read_delta_with_version for dispatch.
1606    #[cfg(feature = "delta")]
1607    pub fn read_delta_path_with_version(
1608        &self,
1609        path: impl AsRef<Path>,
1610        version: Option<i64>,
1611    ) -> Result<DataFrame, PolarsError> {
1612        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
1613    }
1614
1615    /// Read a Delta table or in-memory table by name/path. If name_or_path looks like a path, reads from Delta on disk; else resolves as table name (temp view then saved table).
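    ///
    /// # Example
    /// Illustrative usage (the path is a placeholder):
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark.read_delta("/tmp/delta/events");
    /// // Handle the Result as appropriate in your application
    /// ```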
1616    #[cfg(feature = "delta")]
1617    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
1618        if Self::looks_like_path(name_or_path) {
1619            self.read_delta_path(Path::new(name_or_path))
1620        } else {
1621            self.table(name_or_path)
1622        }
1623    }
1624
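    /// Read a Delta table (by path, optionally at a version) or an in-memory table by name.
    /// For in-memory tables the version is ignored.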
1625    #[cfg(feature = "delta")]
1626    pub fn read_delta_with_version(
1627        &self,
1628        name_or_path: &str,
1629        version: Option<i64>,
1630    ) -> Result<DataFrame, PolarsError> {
1631        if Self::looks_like_path(name_or_path) {
1632            self.read_delta_path_with_version(Path::new(name_or_path), version)
1633        } else {
1634            // In-memory tables have no version; ignore version and return table
1635            self.table(name_or_path)
1636        }
1637    }
1638
1639    /// Stub when `delta` feature is disabled. Still supports reading by table name.
1640    #[cfg(not(feature = "delta"))]
1641    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
1642        if Self::looks_like_path(name_or_path) {
1643            Err(PolarsError::InvalidOperation(
1644                "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
1645            ))
1646        } else {
1647            self.table(name_or_path)
1648        }
1649    }
1650
1651    #[cfg(not(feature = "delta"))]
1652    pub fn read_delta_with_version(
1653        &self,
1654        name_or_path: &str,
1655        version: Option<i64>,
1656    ) -> Result<DataFrame, PolarsError> {
1657        let _ = version;
1658        self.read_delta(name_or_path)
1659    }
1660
1661    /// Path-only read_delta (for DataFrameReader.load/format delta). Requires delta feature.
1662    #[cfg(feature = "delta")]
1663    pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1664        self.read_delta_path(path)
1665    }
1666
1667    #[cfg(not(feature = "delta"))]
1668    pub fn read_delta_from_path(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1669        Err(PolarsError::InvalidOperation(
1670            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
1671        ))
1672    }
1673
1674    /// Stop the session (clean up resources).
1675    pub fn stop(&self) {
1676        // Best-effort cleanup. This is primarily for PySpark parity so that `spark.stop()`
1677        // exists and can be called in teardown.
1678        let _ = self.catalog.lock().map(|mut m| m.clear());
1679        let _ = self.tables.lock().map(|mut m| m.clear());
1680        let _ = self.databases.lock().map(|mut s| s.clear());
1681        let _ = self.udf_registry.clear();
1682        clear_thread_udf_session();
1683    }
1684}
1685
1686/// DataFrameReader for reading various file formats.
1687/// Similar to PySpark's DataFrameReader, with `option`/`options`/`format`/`load`/`table`.
1688pub struct DataFrameReader {
1689    session: SparkSession,
1690    options: HashMap<String, String>,
1691    format: Option<String>,
1692}
1693
1694impl DataFrameReader {
1695    pub fn new(session: SparkSession) -> Self {
1696        DataFrameReader {
1697            session,
1698            options: HashMap::new(),
1699            format: None,
1700        }
1701    }
1702
1703    /// Add a single option (PySpark: option(key, value)). Returns self for chaining.
1704    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1705        self.options.insert(key.into(), value.into());
1706        self
1707    }
1708
1709    /// Add multiple options (PySpark: options(**kwargs)). Returns self for chaining.
1710    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
1711        for (k, v) in opts {
1712            self.options.insert(k, v);
1713        }
1714        self
1715    }
1716
1717    /// Set the format for load() (PySpark: format("parquet") etc).
1718    pub fn format(mut self, fmt: impl Into<String>) -> Self {
1719        self.format = Some(fmt.into());
1720        self
1721    }
1722
1723    /// Set the schema (PySpark: schema(schema)). Stub: accepted for API compatibility but currently ignored.
1724    pub fn schema(self, _schema: impl Into<String>) -> Self {
1725        self
1726    }
1727
1728    /// Load data from path using format (or infer from extension) and options.
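    ///
    /// # Example
    /// Illustrative usage (the path and options are placeholders):
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark
    ///     .read()
    ///     .format("csv")
    ///     .option("header", "true")
    ///     .load("data.csv");
    /// // Handle the Result as appropriate in your application
    /// ```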
1729    pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1730        let path = path.as_ref();
1731        let fmt = self.format.clone().or_else(|| {
1732            path.extension()
1733                .and_then(|e| e.to_str())
1734                .map(|s| s.to_lowercase())
1735        });
1736        match fmt.as_deref() {
1737            Some("parquet") => self.parquet(path),
1738            Some("csv") => self.csv(path),
1739            Some("json") | Some("jsonl") => self.json(path),
1740            #[cfg(feature = "delta")]
1741            Some("delta") => self.session.read_delta_from_path(path),
1742            _ => Err(PolarsError::ComputeError(
1743                format!(
1744                    "load: could not infer format for path '{}'. Use format('parquet'|'csv'|'json') before load.",
1745                    path.display()
1746                )
1747                .into(),
1748            )),
1749        }
1750    }
1751
1752    /// Return the named table/view (PySpark: table(name)).
1753    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
1754        self.session.table(name)
1755    }
1756
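    /// Apply recognized CSV reader options to the Polars reader:
    /// `header`, `inferSchema`, `inferSchemaLength`, `sep` (first byte only), `nullValue`.
    /// Unrecognized options are ignored.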
1757    fn apply_csv_options(
1758        &self,
1759        reader: polars::prelude::LazyCsvReader,
1760    ) -> polars::prelude::LazyCsvReader {
1761        use polars::prelude::NullValues;
1762        let mut r = reader;
1763        if let Some(v) = self.options.get("header") {
1764            let has_header = v.eq_ignore_ascii_case("true") || v == "1";
1765            r = r.with_has_header(has_header);
1766        }
1767        if let Some(v) = self.options.get("inferSchema") {
1768            if v.eq_ignore_ascii_case("true") || v == "1" {
1769                let n = self
1770                    .options
1771                    .get("inferSchemaLength")
1772                    .and_then(|s| s.parse::<usize>().ok())
1773                    .unwrap_or(100);
1774                r = r.with_infer_schema_length(Some(n));
1775            } else {
1776                // inferSchema=false: do not infer types (PySpark parity #543)
1777                r = r.with_infer_schema_length(Some(0));
1778            }
1779        } else if let Some(v) = self.options.get("inferSchemaLength") {
1780            if let Ok(n) = v.parse::<usize>() {
1781                r = r.with_infer_schema_length(Some(n));
1782            }
1783        }
1784        if let Some(sep) = self.options.get("sep") {
1785            if let Some(b) = sep.bytes().next() {
1786                r = r.with_separator(b);
1787            }
1788        }
1789        if let Some(null_val) = self.options.get("nullValue") {
1790            r = r.with_null_values(Some(NullValues::AllColumnsSingle(null_val.clone().into())));
1791        }
1792        r
1793    }
1794
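    /// Apply recognized JSONL reader options to the Polars reader: `inferSchemaLength`.
    /// Unrecognized options are ignored.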
1795    fn apply_json_options(
1796        &self,
1797        reader: polars::prelude::LazyJsonLineReader,
1798    ) -> polars::prelude::LazyJsonLineReader {
1799        use std::num::NonZeroUsize;
1800        let mut r = reader;
1801        if let Some(v) = self.options.get("inferSchemaLength") {
1802            if let Ok(n) = v.parse::<usize>() {
1803                r = r.with_infer_schema_length(NonZeroUsize::new(n));
1804            }
1805        }
1806        r
1807    }
1808
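    /// Read a CSV file at `path`, applying any reader options set via `option`/`options`.
    /// Collects eagerly into a DataFrame.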
1809    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1810        use polars::prelude::*;
1811        let path = path.as_ref();
1812        let path_display = path.display();
1813        let pl_path = PlRefPath::try_from_path(path).map_err(|e| {
1814            PolarsError::ComputeError(format!("csv({path_display}): path: {e}").into())
1815        })?;
1816        let reader = LazyCsvReader::new(pl_path);
1817        let reader = if self.options.is_empty() {
1818            reader
1819                .with_has_header(true)
1820                .with_infer_schema_length(Some(100))
1821        } else {
1822            self.apply_csv_options(
1823                reader
1824                    .with_has_header(true)
1825                    .with_infer_schema_length(Some(100)),
1826            )
1827        };
1828        let lf = reader.finish().map_err(|e| {
1829            PolarsError::ComputeError(format!("read csv({path_display}): {e}").into())
1830        })?;
1831        let pl_df = lf.collect().map_err(|e| {
1832            PolarsError::ComputeError(
1833                format!("read csv({path_display}): collect failed: {e}").into(),
1834            )
1835        })?;
1836        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1837            pl_df,
1838            self.session.is_case_sensitive(),
1839        ))
1840    }
1841
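    /// Read a Parquet file at `path`. Schema comes from the Parquet metadata; reader options are not applied.
    /// Collects eagerly into a DataFrame.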
1842    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1843        use polars::prelude::*;
1844        let path = path.as_ref();
1845        let pl_path = PlRefPath::try_from_path(path)
1846            .map_err(|e| PolarsError::ComputeError(format!("parquet: path: {e}").into()))?;
1847        let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())?;
1848        let pl_df = lf.collect()?;
1849        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1850            pl_df,
1851            self.session.is_case_sensitive(),
1852        ))
1853    }
1854
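    /// Read a JSONL file at `path` (one JSON object per line), applying the `inferSchemaLength` option if set.
    /// Collects eagerly into a DataFrame.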
1855    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1856        use polars::prelude::*;
1857        use std::num::NonZeroUsize;
1858        let path = path.as_ref();
1859        let pl_path = PlRefPath::try_from_path(path)
1860            .map_err(|e| PolarsError::ComputeError(format!("json: path: {e}").into()))?;
1861        let reader = LazyJsonLineReader::new(pl_path);
1862        let reader = if self.options.is_empty() {
1863            reader.with_infer_schema_length(NonZeroUsize::new(100))
1864        } else {
1865            self.apply_json_options(reader.with_infer_schema_length(NonZeroUsize::new(100)))
1866        };
1867        let lf = reader.finish()?;
1868        let pl_df = lf.collect()?;
1869        Ok(crate::dataframe::DataFrame::from_polars_with_options(
1870            pl_df,
1871            self.session.is_case_sensitive(),
1872        ))
1873    }
1874
1875    #[cfg(feature = "delta")]
1876    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1877        self.session.read_delta_from_path(path)
1878    }
1879}
1880
1881impl SparkSession {
1882    /// Get a DataFrameReader for reading files
1883    pub fn read(&self) -> DataFrameReader {
1884        DataFrameReader::new(SparkSession {
1885            app_name: self.app_name.clone(),
1886            master: self.master.clone(),
1887            config: self.config.clone(),
1888            catalog: self.catalog.clone(),
1889            tables: self.tables.clone(),
1890            databases: self.databases.clone(),
1891            udf_registry: self.udf_registry.clone(),
1892        })
1893    }
1894}
1895
1896impl Default for SparkSession {
1897    fn default() -> Self {
1898        Self::builder().get_or_create()
1899    }
1900}
1901
1902#[cfg(test)]
1903mod tests {
1904    use super::*;
1905
1906    #[test]
1907    fn test_spark_session_builder_basic() {
1908        let spark = SparkSession::builder().app_name("test_app").get_or_create();
1909
1910        assert_eq!(spark.app_name, Some("test_app".to_string()));
1911    }
1912
1913    #[test]
1914    fn test_spark_session_builder_with_master() {
1915        let spark = SparkSession::builder()
1916            .app_name("test_app")
1917            .master("local[*]")
1918            .get_or_create();
1919
1920        assert_eq!(spark.app_name, Some("test_app".to_string()));
1921        assert_eq!(spark.master, Some("local[*]".to_string()));
1922    }
1923
1924    #[test]
1925    fn test_spark_session_builder_with_config() {
1926        let spark = SparkSession::builder()
1927            .app_name("test_app")
1928            .config("spark.executor.memory", "4g")
1929            .config("spark.driver.memory", "2g")
1930            .get_or_create();
1931
1932        assert_eq!(
1933            spark.config.get("spark.executor.memory"),
1934            Some(&"4g".to_string())
1935        );
1936        assert_eq!(
1937            spark.config.get("spark.driver.memory"),
1938            Some(&"2g".to_string())
1939        );
1940    }
1941
1942    #[test]
1943    fn test_spark_session_default() {
1944        let spark = SparkSession::default();
1945        assert!(spark.app_name.is_none());
1946        assert!(spark.master.is_none());
1947        assert!(spark.config.is_empty());
1948    }
1949
1950    #[test]
1951    fn test_create_dataframe_success() {
1952        let spark = SparkSession::builder().app_name("test").get_or_create();
1953        let data = vec![
1954            (1i64, 25i64, "Alice".to_string()),
1955            (2i64, 30i64, "Bob".to_string()),
1956        ];
1957
1958        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);
1959
1960        assert!(result.is_ok());
1961        let df = result.unwrap();
1962        assert_eq!(df.count().unwrap(), 2);
1963
1964        let columns = df.columns().unwrap();
1965        assert!(columns.contains(&"id".to_string()));
1966        assert!(columns.contains(&"age".to_string()));
1967        assert!(columns.contains(&"name".to_string()));
1968    }
1969
1970    #[test]
1971    fn test_create_dataframe_wrong_column_count() {
1972        let spark = SparkSession::builder().app_name("test").get_or_create();
1973        let data = vec![(1i64, 25i64, "Alice".to_string())];
1974
1975        // Too few columns
1976        let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
1977        assert!(result.is_err());
1978
1979        // Too many columns
1980        let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
1981        assert!(result.is_err());
1982    }
1983
1984    #[test]
1985    fn test_create_dataframe_from_rows_empty_schema_with_rows_returns_error() {
1986        let spark = SparkSession::builder().app_name("test").get_or_create();
1987        let rows: Vec<Vec<JsonValue>> = vec![vec![]];
1988        let schema: Vec<(String, String)> = vec![];
1989        let result = spark.create_dataframe_from_rows(rows, schema);
1990        match &result {
1991            Err(e) => assert!(e.to_string().contains("schema must not be empty")),
1992            Ok(_) => panic!("expected error for empty schema with non-empty rows"),
1993        }
1994    }
1995
1996    #[test]
1997    fn test_create_dataframe_from_rows_empty_data_with_schema() {
1998        let spark = SparkSession::builder().app_name("test").get_or_create();
1999        let rows: Vec<Vec<JsonValue>> = vec![];
2000        let schema = vec![
2001            ("a".to_string(), "int".to_string()),
2002            ("b".to_string(), "string".to_string()),
2003        ];
2004        let result = spark.create_dataframe_from_rows(rows, schema);
2005        let df = result.unwrap();
2006        assert_eq!(df.count().unwrap(), 0);
2007        assert_eq!(df.collect_inner().unwrap().get_column_names(), &["a", "b"]);
2008    }
2009
2010    #[test]
2011    fn test_create_dataframe_from_rows_empty_schema_empty_data() {
2012        let spark = SparkSession::builder().app_name("test").get_or_create();
2013        let rows: Vec<Vec<JsonValue>> = vec![];
2014        let schema: Vec<(String, String)> = vec![];
2015        let result = spark.create_dataframe_from_rows(rows, schema);
2016        let df = result.unwrap();
2017        assert_eq!(df.count().unwrap(), 0);
2018        assert_eq!(df.collect_inner().unwrap().get_column_names().len(), 0);
2019    }
2020
2021    /// create_dataframe_from_rows: struct column as JSON object (by field name). PySpark parity #600.
2022    #[test]
2023    fn test_create_dataframe_from_rows_struct_as_object() {
2024        use serde_json::json;
2025
2026        let spark = SparkSession::builder().app_name("test").get_or_create();
2027        let schema = vec![
2028            ("id".to_string(), "string".to_string()),
2029            (
2030                "nested".to_string(),
2031                "struct<a:bigint,b:string>".to_string(),
2032            ),
2033        ];
2034        let rows: Vec<Vec<JsonValue>> = vec![
2035            vec![json!("x"), json!({"a": 1, "b": "y"})],
2036            vec![json!("z"), json!({"a": 2, "b": "w"})],
2037        ];
2038        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2039        assert_eq!(df.count().unwrap(), 2);
2040        let collected = df.collect_inner().unwrap();
2041        assert_eq!(collected.get_column_names(), &["id", "nested"]);
2042    }
2043
2044    /// create_dataframe_from_rows: struct column as JSON array (by position). PySpark parity #600.
2045    #[test]
2046    fn test_create_dataframe_from_rows_struct_as_array() {
2047        use serde_json::json;
2048
2049        let spark = SparkSession::builder().app_name("test").get_or_create();
2050        let schema = vec![
2051            ("id".to_string(), "string".to_string()),
2052            (
2053                "nested".to_string(),
2054                "struct<a:bigint,b:string>".to_string(),
2055            ),
2056        ];
2057        let rows: Vec<Vec<JsonValue>> = vec![
2058            vec![json!("x"), json!([1, "y"])],
2059            vec![json!("z"), json!([2, "w"])],
2060        ];
2061        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2062        assert_eq!(df.count().unwrap(), 2);
2063        let collected = df.collect_inner().unwrap();
2064        assert_eq!(collected.get_column_names(), &["id", "nested"]);
2065    }
2066
2067    /// #610: create_dataframe_from_rows accepts struct as string that parses to object or array (Sparkless/Python serialization).
2068    #[test]
2069    fn test_issue_610_struct_value_as_string_object_or_array() {
2070        use serde_json::json;
2071
2072        let spark = SparkSession::builder().app_name("test").get_or_create();
2073        let schema = vec![
2074            ("id".to_string(), "string".to_string()),
2075            (
2076                "nested".to_string(),
2077                "struct<a:bigint,b:string>".to_string(),
2078            ),
2079        ];
2080        // Struct as string that parses to JSON object (e.g. Python dict serialized as string).
2081        let rows_object: Vec<Vec<JsonValue>> =
2082            vec![vec![json!("A"), json!(r#"{"a": 1, "b": "x"}"#)]];
2083        let df1 = spark
2084            .create_dataframe_from_rows(rows_object, schema.clone())
2085            .unwrap();
2086        assert_eq!(df1.count().unwrap(), 1);
2087
2088        // Struct as string that parses to JSON array (e.g. Python tuple (1, "y") serialized as "[1, \"y\"]").
2089        let rows_array: Vec<Vec<JsonValue>> = vec![vec![json!("B"), json!(r#"[1, "y"]"#)]];
2090        let df2 = spark
2091            .create_dataframe_from_rows(rows_array, schema)
2092            .unwrap();
2093        assert_eq!(df2.count().unwrap(), 1);
2094    }
2095
2096    /// #611: create_dataframe_from_rows accepts single value as one-element array (PySpark parity).
2097    #[test]
2098    fn test_issue_611_array_column_single_value_as_one_element() {
2099        use serde_json::json;
2100
2101        let spark = SparkSession::builder().app_name("test").get_or_create();
2102        let schema = vec![
2103            ("id".to_string(), "string".to_string()),
2104            ("arr".to_string(), "array<bigint>".to_string()),
2105        ];
2106        // Single number as one-element list (PySpark accepts this).
2107        let rows: Vec<Vec<JsonValue>> = vec![
2108            vec![json!("x"), json!(42)],
2109            vec![json!("y"), json!([1, 2, 3])],
2110        ];
2111        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2112        assert_eq!(df.count().unwrap(), 2);
2113        let collected = df.collect_inner().unwrap();
2114        let arr_col = collected.column("arr").unwrap();
2115        let list = arr_col.list().unwrap();
2116        let row0 = list.get(0).unwrap();
2117        assert_eq!(
2118            row0.len(),
2119            1,
2120            "#611: single value should become one-element list"
2121        );
2122        let row1 = list.get(1).unwrap();
2123        assert_eq!(row1.len(), 3);
2124    }
2125
2126    /// create_dataframe_from_rows: array column with JSON array and null. PySpark parity #601.
2127    #[test]
2128    fn test_create_dataframe_from_rows_array_column() {
2129        use serde_json::json;
2130
2131        let spark = SparkSession::builder().app_name("test").get_or_create();
2132        let schema = vec![
2133            ("id".to_string(), "string".to_string()),
2134            ("arr".to_string(), "array<bigint>".to_string()),
2135        ];
2136        let rows: Vec<Vec<JsonValue>> = vec![
2137            vec![json!("x"), json!([1, 2, 3])],
2138            vec![json!("y"), json!([4, 5])],
2139            vec![json!("z"), json!(null)],
2140        ];
2141        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2142        assert_eq!(df.count().unwrap(), 3);
2143        let collected = df.collect_inner().unwrap();
2144        assert_eq!(collected.get_column_names(), &["id", "arr"]);
2145
2146        // Issue #601: verify array data round-trips correctly (not just no error).
2147        let arr_col = collected.column("arr").unwrap();
2148        let list = arr_col.list().unwrap();
2149        // Row 0: [1, 2, 3]
2150        let row0 = list.get(0).unwrap();
2151        assert_eq!(row0.len(), 3, "row 0 arr should have 3 elements");
2152        // Row 1: [4, 5]
2153        let row1 = list.get(1).unwrap();
2154        assert_eq!(row1.len(), 2);
2155        // Row 2: null list (representation may be None or empty)
2156        let row2 = list.get(2);
2157        assert!(
2158            row2.is_none() || row2.as_ref().map(|a| a.is_empty()).unwrap_or(false),
2159            "row 2 arr should be null or empty"
2160        );
2161    }
2162
2163    /// Issue #601: PySpark createDataFrame([("x", [1,2,3]), ("y", [4,5])], schema) with ArrayType.
2164    /// Must not fail with "array column value must be null or array" and must produce correct structure.
2165    #[test]
2166    fn test_issue_601_array_column_pyspark_parity() {
2167        use serde_json::json;
2168
2169        let spark = SparkSession::builder().app_name("test").get_or_create();
2170        let schema = vec![
2171            ("id".to_string(), "string".to_string()),
2172            ("arr".to_string(), "array<bigint>".to_string()),
2173        ];
2174        // Exact PySpark example: rows with string id and list of ints.
2175        let rows: Vec<Vec<JsonValue>> = vec![
2176            vec![json!("x"), json!([1, 2, 3])],
2177            vec![json!("y"), json!([4, 5])],
2178        ];
2179        let df = spark
2180            .create_dataframe_from_rows(rows, schema)
2181            .expect("issue #601: create_dataframe_from_rows must accept array column (JSON array)");
2182        let n = df.count().unwrap();
2183        assert_eq!(n, 2, "issue #601: expected 2 rows");
2184        let collected = df.collect_inner().unwrap();
2185        let arr_col = collected.column("arr").unwrap();
2186        let list = arr_col.list().unwrap();
2187        // Verify list lengths match PySpark [1,2,3] and [4,5]
2188        let row0 = list.get(0).unwrap();
2189        assert_eq!(
2190            row0.len(),
2191            3,
2192            "issue #601: first row arr must have 3 elements [1,2,3]"
2193        );
2194        let row1 = list.get(1).unwrap();
2195        assert_eq!(
2196            row1.len(),
2197            2,
2198            "issue #601: second row arr must have 2 elements [4,5]"
2199        );
2200    }
2201
2202    /// #624: When schema is empty but rows are not, infer schema from rows (PySpark parity).
2203    #[test]
2204    fn test_issue_624_empty_schema_inferred_from_rows() {
2205        use serde_json::json;
2206
2207        let spark = SparkSession::builder().app_name("test").get_or_create();
2208        let schema: Vec<(String, String)> = vec![];
2209        let rows: Vec<Vec<JsonValue>> =
2210            vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
2211        let df = spark
2212            .create_dataframe_from_rows(rows, schema)
2213            .expect("#624: empty schema with non-empty rows should infer schema");
2214        assert_eq!(df.count().unwrap(), 2);
2215        let collected = df.collect_inner().unwrap();
2216        assert_eq!(collected.get_column_names(), &["c0", "c1"]);
2217    }
2218
2219    /// #627: create_dataframe_from_rows accepts map column (dict/object). PySpark MapType parity.
2220    #[test]
2221    fn test_create_dataframe_from_rows_map_column() {
2222        use serde_json::json;
2223
2224        let spark = SparkSession::builder().app_name("test").get_or_create();
2225        let schema = vec![
2226            ("id".to_string(), "integer".to_string()),
2227            ("m".to_string(), "map<string,string>".to_string()),
2228        ];
2229        let rows: Vec<Vec<JsonValue>> = vec![
2230            vec![json!(1), json!({"a": "x", "b": "y"})],
2231            vec![json!(2), json!({"c": "z"})],
2232        ];
2233        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2234        assert_eq!(df.count().unwrap(), 2);
2235        let collected = df.collect_inner().unwrap();
2236        assert_eq!(collected.get_column_names(), &["id", "m"]);
2237        let m_col = collected.column("m").unwrap();
2238        let list = m_col.list().unwrap();
2239        let row0 = list.get(0).unwrap();
2240        assert_eq!(row0.len(), 2, "row 0 map should have 2 entries");
2241        let row1 = list.get(1).unwrap();
2242        assert_eq!(row1.len(), 1, "row 1 map should have 1 entry");
2243    }
2244
2245    /// #625: create_dataframe_from_rows accepts array column as JSON array or Object (Python list parity).
2246    #[test]
2247    fn test_issue_625_array_column_list_or_object() {
2248        use serde_json::json;
2249
2250        let spark = SparkSession::builder().app_name("test").get_or_create();
2251        let schema = vec![
2252            ("id".to_string(), "string".to_string()),
2253            ("arr".to_string(), "array<bigint>".to_string()),
2254        ];
2255        // JSON array (Python list) and Object with "0","1",... keys (some serializations).
2256        let rows: Vec<Vec<JsonValue>> = vec![
2257            vec![json!("x"), json!([1, 2, 3])],
2258            vec![json!("y"), json!({"0": 4, "1": 5})],
2259        ];
2260        let df = spark
2261            .create_dataframe_from_rows(rows, schema)
2262            .expect("#625: array column must accept list/array or object representation");
2263        assert_eq!(df.count().unwrap(), 2);
2264        let collected = df.collect_inner().unwrap();
2265        let list = collected.column("arr").unwrap().list().unwrap();
2266        assert_eq!(list.get(0).unwrap().len(), 3);
2267        assert_eq!(list.get(1).unwrap().len(), 2);
2268    }
2269
2270    #[test]
2271    fn test_create_dataframe_empty() {
2272        let spark = SparkSession::builder().app_name("test").get_or_create();
2273        let data: Vec<(i64, i64, String)> = vec![];
2274
2275        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);
2276
2277        assert!(result.is_ok());
2278        let df = result.unwrap();
2279        assert_eq!(df.count().unwrap(), 0);
2280    }
2281
2282    #[test]
2283    fn test_create_dataframe_from_polars() {
2284        use polars::prelude::df;
2285
2286        let spark = SparkSession::builder().app_name("test").get_or_create();
2287        let polars_df = df!(
2288            "x" => &[1, 2, 3],
2289            "y" => &[4, 5, 6]
2290        )
2291        .unwrap();
2292
2293        let df = spark.create_dataframe_from_polars(polars_df);
2294
2295        assert_eq!(df.count().unwrap(), 3);
2296        let columns = df.columns().unwrap();
2297        assert!(columns.contains(&"x".to_string()));
2298        assert!(columns.contains(&"y".to_string()));
2299    }
2300
2301    #[test]
2302    fn test_read_csv_file_not_found() {
2303        let spark = SparkSession::builder().app_name("test").get_or_create();
2304
2305        let result = spark.read_csv("nonexistent_file.csv");
2306
2307        assert!(result.is_err());
2308    }
2309
2310    #[test]
2311    fn test_read_parquet_file_not_found() {
2312        let spark = SparkSession::builder().app_name("test").get_or_create();
2313
2314        let result = spark.read_parquet("nonexistent_file.parquet");
2315
2316        assert!(result.is_err());
2317    }
2318
2319    #[test]
2320    fn test_read_json_file_not_found() {
2321        let spark = SparkSession::builder().app_name("test").get_or_create();
2322
2323        let result = spark.read_json("nonexistent_file.json");
2324
2325        assert!(result.is_err());
2326    }
2327
2328    #[test]
2329    fn test_rust_udf_dataframe() {
2330        use crate::functions::{call_udf, col};
2331        use polars::prelude::DataType;
2332
2333        let spark = SparkSession::builder().app_name("test").get_or_create();
2334        spark
2335            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
2336            .unwrap();
2337        let df = spark
2338            .create_dataframe(
2339                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
2340                vec!["id", "age", "name"],
2341            )
2342            .unwrap();
2343        let col = call_udf("to_str", &[col("id")]).unwrap();
2344        let df2 = df.with_column("id_str", &col).unwrap();
2345        let cols = df2.columns().unwrap();
2346        assert!(cols.contains(&"id_str".to_string()));
2347        let rows = df2.collect_as_json_rows().unwrap();
2348        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
2349        assert_eq!(rows[1].get("id_str").and_then(|v| v.as_str()), Some("2"));
2350    }
2351
2352    #[test]
2353    fn test_case_insensitive_filter_select() {
2354        use crate::expression::lit_i64;
2355        use crate::functions::col;
2356
2357        let spark = SparkSession::builder().app_name("test").get_or_create();
2358        let df = spark
2359            .create_dataframe(
2360                vec![
2361                    (1, 25, "Alice".to_string()),
2362                    (2, 30, "Bob".to_string()),
2363                    (3, 35, "Charlie".to_string()),
2364                ],
2365                vec!["Id", "Age", "Name"],
2366            )
2367            .unwrap();
2368        // Filter with lowercase column names (PySpark default: case-insensitive)
2369        let filtered = df
2370            .filter(col("age").gt(lit_i64(26)).expr().clone())
2371            .unwrap()
2372            .select(vec!["name"])
2373            .unwrap();
2374        assert_eq!(filtered.count().unwrap(), 2);
2375        let rows = filtered.collect_as_json_rows().unwrap();
2376        let names: Vec<&str> = rows
2377            .iter()
2378            .map(|r| r.get("name").and_then(|v| v.as_str()).unwrap())
2379            .collect();
2380        assert!(names.contains(&"Bob"));
2381        assert!(names.contains(&"Charlie"));
2382    }
2383
2384    #[test]
2385    fn test_sql_returns_error_without_feature_or_unknown_table() {
2386        let spark = SparkSession::builder().app_name("test").get_or_create();
2387
2388        let result = spark.sql("SELECT * FROM table");
2389
2390        assert!(result.is_err());
2391        match result {
2392            Err(PolarsError::InvalidOperation(msg)) => {
2393                let s = msg.to_string();
2394                // Without sql feature: "SQL queries require the 'sql' feature"
2395                // With sql feature but no table: "Table or view 'table' not found" or parse error
2396                assert!(
2397                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
2398                    "unexpected message: {s}"
2399                );
2400            }
2401            _ => panic!("Expected InvalidOperation error"),
2402        }
2403    }
2404
2405    #[test]
2406    fn test_spark_session_stop() {
2407        let spark = SparkSession::builder().app_name("test").get_or_create();
2408
2409        // stop() should complete without error
2410        spark.stop();
2411    }
2412
2413    #[test]
2414    fn test_dataframe_reader_api() {
2415        let spark = SparkSession::builder().app_name("test").get_or_create();
2416        let reader = spark.read();
2417
2418        // All readers should return errors for non-existent files
2419        assert!(reader.csv("nonexistent.csv").is_err());
2420        assert!(reader.parquet("nonexistent.parquet").is_err());
2421        assert!(reader.json("nonexistent.json").is_err());
2422    }
2423
2424    #[test]
2425    fn test_read_csv_with_valid_file() {
2426        use std::io::Write;
2427        use tempfile::NamedTempFile;
2428
2429        let spark = SparkSession::builder().app_name("test").get_or_create();
2430
2431        // Create a temporary CSV file
2432        let mut temp_file = NamedTempFile::new().unwrap();
2433        writeln!(temp_file, "id,name,age").unwrap();
2434        writeln!(temp_file, "1,Alice,25").unwrap();
2435        writeln!(temp_file, "2,Bob,30").unwrap();
2436        temp_file.flush().unwrap();
2437
2438        let result = spark.read_csv(temp_file.path());
2439
2440        assert!(result.is_ok());
2441        let df = result.unwrap();
2442        assert_eq!(df.count().unwrap(), 2);
2443
2444        let columns = df.columns().unwrap();
2445        assert!(columns.contains(&"id".to_string()));
2446        assert!(columns.contains(&"name".to_string()));
2447        assert!(columns.contains(&"age".to_string()));
2448    }
2449
2450    #[test]
2451    fn test_read_json_with_valid_file() {
2452        use std::io::Write;
2453        use tempfile::NamedTempFile;
2454
2455        let spark = SparkSession::builder().app_name("test").get_or_create();
2456
2457        // Create a temporary JSONL file
2458        let mut temp_file = NamedTempFile::new().unwrap();
2459        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
2460        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
2461        temp_file.flush().unwrap();
2462
2463        let result = spark.read_json(temp_file.path());
2464
2465        assert!(result.is_ok());
2466        let df = result.unwrap();
2467        assert_eq!(df.count().unwrap(), 2);
2468    }
2469
2470    #[test]
2471    fn test_read_csv_empty_file() {
2472        use std::io::Write;
2473        use tempfile::NamedTempFile;
2474
2475        let spark = SparkSession::builder().app_name("test").get_or_create();
2476
2477        // Create an empty CSV file (just header)
2478        let mut temp_file = NamedTempFile::new().unwrap();
2479        writeln!(temp_file, "id,name").unwrap();
2480        temp_file.flush().unwrap();
2481
2482        let result = spark.read_csv(temp_file.path());
2483
2484        assert!(result.is_ok());
2485        let df = result.unwrap();
2486        assert_eq!(df.count().unwrap(), 0);
2487    }
2488
2489    #[test]
2490    fn test_write_partitioned_parquet() {
2491        use crate::dataframe::{WriteFormat, WriteMode};
2492        use std::fs;
2493        use tempfile::TempDir;
2494
2495        let spark = SparkSession::builder().app_name("test").get_or_create();
2496        let df = spark
2497            .create_dataframe(
2498                vec![
2499                    (1, 25, "Alice".to_string()),
2500                    (2, 30, "Bob".to_string()),
2501                    (3, 25, "Carol".to_string()),
2502                ],
2503                vec!["id", "age", "name"],
2504            )
2505            .unwrap();
2506        let dir = TempDir::new().unwrap();
2507        let path = dir.path().join("out");
2508        df.write()
2509            .mode(WriteMode::Overwrite)
2510            .format(WriteFormat::Parquet)
2511            .partition_by(["age"])
2512            .save(&path)
2513            .unwrap();
2514        assert!(path.is_dir());
2515        let entries: Vec<_> = fs::read_dir(&path).unwrap().collect();
2516        assert_eq!(
2517            entries.len(),
2518            2,
2519            "expected two partition dirs (age=25, age=30)"
2520        );
2521        let names: Vec<String> = entries
2522            .iter()
2523            .filter_map(|e| e.as_ref().ok())
2524            .map(|e| e.file_name().to_string_lossy().into_owned())
2525            .collect();
2526        assert!(names.iter().any(|n| n.starts_with("age=")));
2527        let df_read = spark.read_parquet(&path).unwrap();
2528        assert_eq!(df_read.count().unwrap(), 3);
2529    }
2530
2531    #[test]
2532    fn test_save_as_table_error_if_exists() {
2533        use crate::dataframe::SaveMode;
2534
2535        let spark = SparkSession::builder().app_name("test").get_or_create();
2536        let df = spark
2537            .create_dataframe(
2538                vec![(1, 25, "Alice".to_string())],
2539                vec!["id", "age", "name"],
2540            )
2541            .unwrap();
2542        // First call succeeds
2543        df.write()
2544            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
2545            .unwrap();
2546        assert!(spark.table("t1").is_ok());
2547        assert_eq!(spark.table("t1").unwrap().count().unwrap(), 1);
2548        // Second call with ErrorIfExists fails
2549        let err = df
2550            .write()
2551            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
2552            .unwrap_err();
2553        assert!(err.to_string().contains("already exists"));
2554    }
2555
    #[test]
    fn test_save_as_table_overwrite() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(
                vec![(2, 30, "Bob".to_string()), (3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_over", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 1);
        df2.write()
            .save_as_table(&spark, "t_over", SaveMode::Overwrite)
            .unwrap();
        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 2);
    }

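    /// saveAsTable(Append) adds rows to an existing table.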
    #[test]
    fn test_save_as_table_append() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_append", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_append", SaveMode::Append)
            .unwrap();
        assert_eq!(spark.table("t_append").unwrap().count().unwrap(), 2);
    }

    /// Empty DataFrame with explicit schema: saveAsTable(Overwrite) then append one row (issue #495).
    #[test]
    fn test_save_as_table_empty_df_then_append() {
        use crate::dataframe::SaveMode;
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark
            .create_dataframe_from_rows(vec![], schema.clone())
            .unwrap();
        assert_eq!(empty_df.count().unwrap(), 0);

        empty_df
            .write()
            .save_as_table(&spark, "t_empty_append", SaveMode::Overwrite)
            .unwrap();
        let r1 = spark.table("t_empty_append").unwrap();
        assert_eq!(r1.count().unwrap(), 0);
        let cols = r1.columns().unwrap();
        assert!(cols.contains(&"id".to_string()));
        assert!(cols.contains(&"name".to_string()));

        let one_row = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        one_row
            .write()
            .save_as_table(&spark, "t_empty_append", SaveMode::Append)
            .unwrap();
        let r2 = spark.table("t_empty_append").unwrap();
        assert_eq!(r2.count().unwrap(), 1);
    }

    /// Empty DataFrame with schema: write.format("parquet").save(path) must not fail (issue #519).
    /// PySpark fails with "can not infer schema from empty dataset"; robin-sparkless uses explicit schema.
    #[test]
    fn test_write_parquet_empty_df_with_schema() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark.create_dataframe_from_rows(vec![], schema).unwrap();
        assert_eq!(empty_df.count().unwrap(), 0);

        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("empty.parquet");
        empty_df
            .write()
            .format(crate::dataframe::WriteFormat::Parquet)
            .mode(crate::dataframe::WriteMode::Overwrite)
            .save(&path)
            .unwrap();
        assert!(path.is_file());

        // Read back and verify schema preserved
        let read_df = spark.read().parquet(path.to_str().unwrap()).unwrap();
        assert_eq!(read_df.count().unwrap(), 0);
        let cols = read_df.columns().unwrap();
        assert!(cols.contains(&"id".to_string()));
        assert!(cols.contains(&"name".to_string()));
    }

    /// Empty DataFrame with schema + warehouse: saveAsTable(Overwrite) then append (issue #495 disk path).
    #[test]
    fn test_save_as_table_empty_df_warehouse_then_append() {
        use crate::dataframe::SaveMode;
        use serde_json::json;
        use std::sync::atomic::{AtomicU64, Ordering};
        use tempfile::TempDir;

        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let n = COUNTER.fetch_add(1, Ordering::SeqCst);
        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().join(format!("wh_{n}"));
        std::fs::create_dir_all(&warehouse).unwrap();
        let spark = SparkSession::builder()
            .app_name("test")
            .config(
                "spark.sql.warehouse.dir",
                warehouse.as_os_str().to_str().unwrap(),
            )
            .get_or_create();

        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark
            .create_dataframe_from_rows(vec![], schema.clone())
            .unwrap();
        empty_df
            .write()
            .save_as_table(&spark, "t_empty_wh", SaveMode::Overwrite)
            .unwrap();
        let r1 = spark.table("t_empty_wh").unwrap();
        assert_eq!(r1.count().unwrap(), 0);

        let one_row = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        one_row
            .write()
            .save_as_table(&spark, "t_empty_wh", SaveMode::Append)
            .unwrap();
        let r2 = spark.table("t_empty_wh").unwrap();
        assert_eq!(r2.count().unwrap(), 1);
    }

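    /// saveAsTable(Ignore) leaves an existing table unchanged.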
    #[test]
    fn test_save_as_table_ignore() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_ignore", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_ignore", SaveMode::Ignore)
            .unwrap();
        // Still 1 row (Ignore did not replace the existing table)
        assert_eq!(spark.table("t_ignore").unwrap().count().unwrap(), 1);
    }

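    /// table(name) resolves a temp view before a saved table of the same name (PySpark order).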
    #[test]
    fn test_table_resolution_temp_view_first() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df_saved = spark
            .create_dataframe(
                vec![(1, 25, "Saved".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df_temp = spark
            .create_dataframe(vec![(2, 30, "Temp".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df_saved
            .write()
            .save_as_table(&spark, "x", SaveMode::ErrorIfExists)
            .unwrap();
        spark.create_or_replace_temp_view("x", df_temp);
        // table("x") must return the temp view (PySpark resolution order: temp views before saved tables)
        let t = spark.table("x").unwrap();
        let rows = t.collect_as_json_rows().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Temp"));
    }

    /// Issue #629 reproduction: createDataFrame, createOrReplaceTempView, then table() must resolve the view.
    #[test]
    fn test_issue_629_temp_view_visible_after_create() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("repro").get_or_create();
        let schema = vec![
            ("id".to_string(), "long".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> =
            vec![vec![json!(1), json!("a")], vec![json!(2), json!("b")]];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        spark.create_or_replace_temp_view("my_view", df);
        let result = spark
            .table("my_view")
            .unwrap()
            .collect_as_json_rows()
            .unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].get("id").and_then(|v| v.as_i64()), Some(1));
        assert_eq!(result[0].get("name").and_then(|v| v.as_str()), Some("a"));
        assert_eq!(result[1].get("id").and_then(|v| v.as_i64()), Some(2));
        assert_eq!(result[1].get("name").and_then(|v| v.as_str()), Some("b"));
    }

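    /// drop_table removes a saved table; dropping a missing table returns false.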
    #[test]
    fn test_drop_table() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df.write()
            .save_as_table(&spark, "t_drop", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.table("t_drop").is_ok());
        assert!(spark.drop_table("t_drop"));
        assert!(spark.table("t_drop").is_err());
        // Dropping again is a no-op and returns false
        assert!(!spark.drop_table("t_drop"));
    }

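    /// Global temp views are shared across sessions under "global_temp.<name>" and are not shadowed by local temp views.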
    #[test]
    fn test_global_temp_view_persists_across_sessions() {
        // Session 1: create global temp view
        let spark1 = SparkSession::builder().app_name("s1").get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark1.create_or_replace_global_temp_view("people", df1);
        assert_eq!(
            spark1.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

        // Session 2: different session can see global temp view
        let spark2 = SparkSession::builder().app_name("s2").get_or_create();
        let df2 = spark2.table("global_temp.people").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

        // Local temp view in spark2 does not shadow global_temp
        let df_local = spark2
            .create_dataframe(
                vec![(3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark2.create_or_replace_temp_view("people", df_local);
        // table("people") = local temp view (session resolution)
        assert_eq!(spark2.table("people").unwrap().count().unwrap(), 1);
        // table("global_temp.people") = global temp view (unchanged)
        assert_eq!(
            spark2.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

        // Drop global temp view
        assert!(spark2.drop_global_temp_view("people"));
        assert!(spark2.table("global_temp.people").is_err());
    }

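    /// Tables saved under spark.sql.warehouse.dir are written as parquet and visible to a new session using the same warehouse.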
    #[test]
    fn test_warehouse_persistence_between_sessions() {
        use crate::dataframe::SaveMode;
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().to_str().unwrap();

        // Session 1: save to warehouse
        let spark1 = SparkSession::builder()
            .app_name("w1")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark1, "users", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark1.table("users").unwrap().count().unwrap(), 2);

        // Session 2: new session reads from warehouse
        let spark2 = SparkSession::builder()
            .app_name("w2")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df2 = spark2.table("users").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

        // Verify parquet was written
        let table_path = dir.path().join("users");
        assert!(table_path.is_dir());
        let entries: Vec<_> = fs::read_dir(&table_path).unwrap().collect();
        assert!(!entries.is_empty());
    }
}