// robin_sparkless/session.rs

use crate::dataframe::DataFrame;
use polars::prelude::{
    DataFrame as PlDataFrame, DataType, NamedFrom, PolarsError, Series, TimeUnit,
};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};

/// Builder for creating a SparkSession with configuration options
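///
/// A minimal usage sketch of the builder chain (every method shown is defined below):
///
/// ```
/// use robin_sparkless::session::SparkSession;
///
/// let spark = SparkSession::builder()
///     .app_name("example")
///     .master("local[*]")
///     .config("spark.sql.caseSensitive", "false")
///     .get_or_create();
/// # let _ = spark;
/// ```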
#[derive(Clone)]
pub struct SparkSessionBuilder {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
}

impl Default for SparkSessionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkSessionBuilder {
    pub fn new() -> Self {
        SparkSessionBuilder {
            app_name: None,
            master: None,
            config: HashMap::new(),
        }
    }

    pub fn app_name(mut self, name: impl Into<String>) -> Self {
        self.app_name = Some(name.into());
        self
    }

    pub fn master(mut self, master: impl Into<String>) -> Self {
        self.master = Some(master.into());
        self
    }

    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    pub fn get_or_create(self) -> SparkSession {
        SparkSession::new(self.app_name, self.master, self.config)
    }
}

/// Catalog of temporary view names to DataFrames (session-scoped). Uses Arc<Mutex<>> for Send+Sync (Python bindings).
pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

/// Main entry point for creating DataFrames and executing queries
/// Similar to PySpark's SparkSession but using Polars as the backend
#[derive(Clone)]
pub struct SparkSession {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
    /// Temporary views: name -> DataFrame. Session-scoped; cleared when session is dropped.
    pub(crate) catalog: TempViewCatalog,
}

impl SparkSession {
    pub fn new(
        app_name: Option<String>,
        master: Option<String>,
        config: HashMap<String, String>,
    ) -> Self {
        SparkSession {
            app_name,
            master,
            config,
            catalog: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    /// Register a DataFrame as a temporary view (PySpark: createOrReplaceTempView).
    /// The view is session-scoped and is dropped when the session is dropped.
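    ///
    /// A minimal sketch: register a DataFrame, then look it up again via `table`.
    ///
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().app_name("views").get_or_create();
    /// let df = spark.create_dataframe(
    ///     vec![(1, 25, "Alice".to_string())],
    ///     vec!["id", "age", "name"],
    /// )?;
    /// spark.create_or_replace_temp_view("people", df);
    /// let people = spark.table("people")?;
    /// #     let _ = people;
    /// #     Ok(())
    /// # }
    /// ```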
    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
        let _ = self
            .catalog
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    /// Global temp view (PySpark: createGlobalTempView). Stub: uses same catalog as temp view.
    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
        self.create_or_replace_temp_view(name, df);
    }

    /// Global temp view (PySpark: createOrReplaceGlobalTempView). Stub: uses same catalog as temp view.
    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
        self.create_or_replace_temp_view(name, df);
    }

    /// Look up a temporary view by name (PySpark: table(name)).
    /// Returns an error if the view does not exist.
    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        self.catalog
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
            .ok_or_else(|| {
                PolarsError::InvalidOperation(
                    format!(
                        "Table or view '{name}' not found. Register it with create_or_replace_temp_view."
                    )
                    .into(),
                )
            })
    }

    pub fn builder() -> SparkSessionBuilder {
        SparkSessionBuilder::new()
    }

    /// Whether column names are case-sensitive (PySpark: spark.sql.caseSensitive).
    /// Default is false (case-insensitive matching).
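    ///
    /// Sketch: the flag is read from the session's `config` map, set via the builder.
    ///
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// let spark = SparkSession::builder()
    ///     .config("spark.sql.caseSensitive", "true")
    ///     .get_or_create();
    /// assert!(spark.is_case_sensitive());
    /// ```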
    pub fn is_case_sensitive(&self) -> bool {
        self.config
            .get("spark.sql.caseSensitive")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false)
    }

    /// Create a DataFrame from a vector of tuples (i64, i64, String)
    ///
    /// # Example
    /// ```
    /// use robin_sparkless::session::SparkSession;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df = spark.create_dataframe(
    ///     vec![
    ///         (1, 25, "Alice".to_string()),
    ///         (2, 30, "Bob".to_string()),
    ///     ],
    ///     vec!["id", "age", "name"],
    /// )?;
    /// #     let _ = df;
    /// #     Ok(())
    /// # }
    /// ```
    pub fn create_dataframe(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        if column_names.len() != 3 {
            return Err(PolarsError::ComputeError(
                format!(
                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
                    column_names.len()
                )
                .into(),
            ));
        }

        let mut cols: Vec<Series> = Vec::with_capacity(3);

        // First column: i64
        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
        cols.push(Series::new(column_names[0].into(), col0));

        // Second column: i64
        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
        cols.push(Series::new(column_names[1].into(), col1));

        // Third column: String
        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
        cols.push(Series::new(column_names[2].into(), col2));

        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    /// Create a DataFrame from a Polars DataFrame
    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
    }

    /// Create a DataFrame from rows and a schema (arbitrary column count and types).
    ///
    /// `rows`: each inner vec is one row; length must match schema length. Values are JSON-like (i64, f64, string, bool, null).
    /// `schema`: list of (column_name, dtype_string), e.g. `[("id", "bigint"), ("name", "string")]`.
    /// Supported dtype strings: bigint, int, long, double, float, double_precision, string, str, varchar, boolean, bool, date, timestamp, datetime, timestamp_ntz.
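    ///
    /// A minimal sketch of a call, building the row values with `serde_json::json!`
    /// (the caller is assumed to have `serde_json` available):
    ///
    /// ```
    /// use robin_sparkless::session::SparkSession;
    /// use serde_json::json;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().app_name("rows").get_or_create();
    /// let df = spark.create_dataframe_from_rows(
    ///     vec![
    ///         vec![json!(1), json!("Alice")],
    ///         vec![json!(2), json!("Bob")],
    ///     ],
    ///     vec![
    ///         ("id".to_string(), "bigint".to_string()),
    ///         ("name".to_string(), "string".to_string()),
    ///     ],
    /// )?;
    /// #     let _ = df;
    /// #     Ok(())
    /// # }
    /// ```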
    pub fn create_dataframe_from_rows(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, PolarsError> {
        use chrono::{NaiveDate, NaiveDateTime};

        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());

        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
            let type_lower = type_str.trim().to_lowercase();
            let s = match type_lower.as_str() {
                "int" | "bigint" | "long" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_i64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "double" | "float" | "double_precision" => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "string" | "str" | "varchar" => {
                    let vals: Vec<Option<String>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => Some(s),
                                JsonValue::Null => None,
                                other => Some(other.to_string()),
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "boolean" | "bool" => {
                    let vals: Vec<Option<bool>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Bool(b) => Some(b),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "date" => {
                    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1)
                        .ok_or_else(|| PolarsError::ComputeError("invalid epoch date".into()))?;
                    let vals: Vec<Option<i32>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                    .ok()
                                    .map(|d| (d - epoch).num_days() as i32),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Date)
                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
                }
                "timestamp" | "datetime" | "timestamp_ntz" => {
                    let vals: Vec<Option<i64>> =
                        rows.iter()
                            .map(|row| {
                                let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                                match v {
                                    JsonValue::String(s) => {
                                        let parsed = NaiveDateTime::parse_from_str(
                                            &s,
                                            "%Y-%m-%dT%H:%M:%S%.f",
                                        )
                                        .or_else(|_| {
                                            NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S")
                                        })
                                        .or_else(|_| {
                                            NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                                .map(|d| d.and_hms_opt(0, 0, 0).unwrap())
                                        });
                                        parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                                    }
                                    JsonValue::Number(n) => n.as_i64(),
                                    JsonValue::Null => None,
                                    _ => None,
                                }
                            })
                            .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
                        })?
                }
                _ => {
                    return Err(PolarsError::ComputeError(
                        format!(
                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
                        )
                        .into(),
                    ));
                }
            };
            cols.push(s);
        }

        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    /// Read a CSV file.
    ///
    /// Uses Polars' CSV reader with default options:
    /// - Expects a header row (`has_header = true`)
    /// - Schema is inferred from first 100 rows
    ///
    /// # Example
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark.read_csv("data.csv");
    /// // Handle the Result as appropriate in your application
    /// ```
    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let path_display = path.display();
        // Use LazyCsvReader - call finish() to get LazyFrame, then collect
        let lf = LazyCsvReader::new(path)
            .with_has_header(true)
            .with_infer_schema_length(Some(100))
            .finish()
            .map_err(|e| {
                PolarsError::ComputeError(
                    format!(
                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
                    )
                    .into(),
                )
            })?;
        let pl_df = lf.collect().map_err(|e| {
            PolarsError::ComputeError(
                format!("read_csv({path_display}): collect failed: {e}").into(),
            )
        })?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    /// Read a Parquet file.
    ///
    /// Uses Polars' Parquet reader. Parquet files have embedded schema, so
    /// schema inference is automatic.
    ///
    /// # Example
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark.read_parquet("data.parquet");
    /// // Handle the Result as appropriate in your application
    /// ```
    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        // Use LazyFrame::scan_parquet
        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    /// Read a JSON file (JSONL format - one JSON object per line).
    ///
    /// Uses Polars' JSONL reader with default options:
    /// - Schema is inferred from first 100 rows
    ///
    /// # Example
    /// ```
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("test").get_or_create();
    /// let df_result = spark.read_json("data.json");
    /// // Handle the Result as appropriate in your application
    /// ```
    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        // Use LazyJsonLineReader - call finish() to get LazyFrame, then collect
        let lf = LazyJsonLineReader::new(path)
            .with_infer_schema_length(NonZeroUsize::new(100))
            .finish()?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    /// Execute a SQL query (SELECT only). Tables must be registered with `create_or_replace_temp_view`.
    /// Requires the `sql` feature. Supports: SELECT (columns or *), FROM (single table or JOIN),
    /// WHERE (basic predicates), GROUP BY + aggregates, ORDER BY, LIMIT.
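    ///
    /// A sketch with the `sql` feature enabled (marked `no_run` because the stub build
    /// returns an error instead of executing the query):
    ///
    /// ```no_run
    /// use robin_sparkless::session::SparkSession;
    ///
    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
    /// let spark = SparkSession::builder().app_name("sql").get_or_create();
    /// let df = spark.create_dataframe(
    ///     vec![(1, 25, "Alice".to_string())],
    ///     vec!["id", "age", "name"],
    /// )?;
    /// spark.create_or_replace_temp_view("people", df);
    /// let adults = spark.sql("SELECT name, age FROM people WHERE age >= 18")?;
    /// #     let _ = adults;
    /// #     Ok(())
    /// # }
    /// ```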
    #[cfg(feature = "sql")]
    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
        crate::sql::execute_sql(self, query)
    }

    /// Execute a SQL query (stub when `sql` feature is disabled).
    #[cfg(not(feature = "sql"))]
    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
        ))
    }

    /// Read a Delta table at the given path (latest version).
    /// Requires the `delta` feature. Path can be local (e.g. `/tmp/table`) or `file:///...`.
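    ///
    /// Sketch (not run: `/tmp/events` is a placeholder path to an existing Delta table):
    ///
    /// ```no_run
    /// use robin_sparkless::SparkSession;
    ///
    /// let spark = SparkSession::builder().app_name("delta").get_or_create();
    /// let df_result = spark.read_delta("/tmp/events");
    /// // Handle the Result as appropriate in your application
    /// ```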
    #[cfg(feature = "delta")]
    pub fn read_delta(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta(path, self.is_case_sensitive())
    }

    /// Read a Delta table at the given path, optionally at a specific version (time travel).
    #[cfg(feature = "delta")]
    pub fn read_delta_with_version(
        &self,
        path: impl AsRef<Path>,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
    }

    /// Stub when `delta` feature is disabled.
    #[cfg(not(feature = "delta"))]
    pub fn read_delta(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_with_version(
        &self,
        _path: impl AsRef<Path>,
        _version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

    /// Stop the session (cleanup resources)
    pub fn stop(&self) {
        // Cleanup if needed
    }
}

/// DataFrameReader for reading various file formats
/// Similar to PySpark's DataFrameReader
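///
/// Sketch of the reader API obtained from `SparkSession::read`:
///
/// ```
/// use robin_sparkless::SparkSession;
///
/// let spark = SparkSession::builder().app_name("reader").get_or_create();
/// let df_result = spark.read().csv("data.csv");
/// // Handle the Result as appropriate in your application
/// ```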
pub struct DataFrameReader {
    session: SparkSession,
}

impl DataFrameReader {
    pub fn new(session: SparkSession) -> Self {
        DataFrameReader { session }
    }

    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_csv(path)
    }

    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_parquet(path)
    }

    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_json(path)
    }

    #[cfg(feature = "delta")]
    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_delta(path)
    }
}

impl SparkSession {
    /// Get a DataFrameReader for reading files
    pub fn read(&self) -> DataFrameReader {
        DataFrameReader::new(SparkSession {
            app_name: self.app_name.clone(),
            master: self.master.clone(),
            config: self.config.clone(),
            catalog: self.catalog.clone(),
        })
    }
}

impl Default for SparkSession {
    fn default() -> Self {
        Self::builder().get_or_create()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spark_session_builder_basic() {
        let spark = SparkSession::builder().app_name("test_app").get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_master() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .master("local[*]")
            .get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
        assert_eq!(spark.master, Some("local[*]".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_config() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .config("spark.executor.memory", "4g")
            .config("spark.driver.memory", "2g")
            .get_or_create();

        assert_eq!(
            spark.config.get("spark.executor.memory"),
            Some(&"4g".to_string())
        );
        assert_eq!(
            spark.config.get("spark.driver.memory"),
            Some(&"2g".to_string())
        );
    }

    #[test]
    fn test_spark_session_default() {
        let spark = SparkSession::default();
        assert!(spark.app_name.is_none());
        assert!(spark.master.is_none());
        assert!(spark.config.is_empty());
    }

    #[test]
    fn test_create_dataframe_success() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data = vec![
            (1i64, 25i64, "Alice".to_string()),
            (2i64, 30i64, "Bob".to_string()),
        ];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"age".to_string()));
        assert!(columns.contains(&"name".to_string()));
    }

    #[test]
    fn test_create_dataframe_wrong_column_count() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data = vec![(1i64, 25i64, "Alice".to_string())];

        // Too few columns
        let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
        assert!(result.is_err());

        // Too many columns
        let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_dataframe_empty() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data: Vec<(i64, i64, String)> = vec![];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_create_dataframe_from_polars() {
        use polars::prelude::df;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let polars_df = df!(
            "x" => &[1, 2, 3],
            "y" => &[4, 5, 6]
        )
        .unwrap();

        let df = spark.create_dataframe_from_polars(polars_df);

        assert_eq!(df.count().unwrap(), 3);
        let columns = df.columns().unwrap();
        assert!(columns.contains(&"x".to_string()));
        assert!(columns.contains(&"y".to_string()));
    }

    #[test]
    fn test_read_csv_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_csv("nonexistent_file.csv");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_parquet_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_parquet("nonexistent_file.parquet");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_json_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_json("nonexistent_file.json");

        assert!(result.is_err());
    }

    #[test]
    fn test_sql_returns_error_without_feature_or_unknown_table() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.sql("SELECT * FROM table");

        assert!(result.is_err());
        match result {
            Err(PolarsError::InvalidOperation(msg)) => {
                let s = msg.to_string();
                // Without sql feature: "SQL queries require the 'sql' feature"
                // With sql feature but no table: "Table or view 'table' not found" or parse error
                assert!(
                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
                    "unexpected message: {s}"
                );
            }
            _ => panic!("Expected InvalidOperation error"),
        }
    }

    #[test]
    fn test_spark_session_stop() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        // stop() should complete without error
        spark.stop();
    }

    #[test]
    fn test_dataframe_reader_api() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let reader = spark.read();

        // All readers should return errors for non-existent files
        assert!(reader.csv("nonexistent.csv").is_err());
        assert!(reader.parquet("nonexistent.parquet").is_err());
        assert!(reader.json("nonexistent.json").is_err());
    }

    #[test]
    fn test_read_csv_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        // Create a temporary CSV file
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name,age").unwrap();
        writeln!(temp_file, "1,Alice,25").unwrap();
        writeln!(temp_file, "2,Bob,30").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"name".to_string()));
        assert!(columns.contains(&"age".to_string()));
    }

    #[test]
    fn test_read_json_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        // Create a temporary JSONL file
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_json(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);
    }

    #[test]
    fn test_read_csv_empty_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        // Create an empty CSV file (just header)
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }
}