use crate::dataframe::DataFrame;
use polars::prelude::{
    DataFrame as PlDataFrame, DataType, NamedFrom, PolarsError, Series, TimeUnit,
};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};

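/// Fluent builder for [`SparkSession`], mirroring PySpark's
/// `SparkSession.builder` API: chain `app_name`, `master`, and `config`
/// calls, then finish with `get_or_create`.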
#[derive(Clone)]
pub struct SparkSessionBuilder {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
}

impl Default for SparkSessionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkSessionBuilder {
    pub fn new() -> Self {
        SparkSessionBuilder {
            app_name: None,
            master: None,
            config: HashMap::new(),
        }
    }

    pub fn app_name(mut self, name: impl Into<String>) -> Self {
        self.app_name = Some(name.into());
        self
    }

    pub fn master(mut self, master: impl Into<String>) -> Self {
        self.master = Some(master.into());
        self
    }

    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    pub fn get_or_create(self) -> SparkSession {
        SparkSession::new(self.app_name, self.master, self.config)
    }
}

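/// Shared catalog of temporary views, keyed by view name. The
/// `Arc<Mutex<..>>` lets cloned sessions (and the `DataFrameReader`) see the
/// same registered views.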
pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

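/// Entry point of the API, modeled after PySpark's `SparkSession` but backed
/// by Polars instead of a Spark cluster. A minimal usage sketch (hypothetical
/// caller code, marked `ignore` so it is not run as a doctest):
///
/// ```ignore
/// let spark = SparkSession::builder().app_name("demo").get_or_create();
/// let df = spark.range(0, 5, 1)?;
/// ```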
#[derive(Clone)]
pub struct SparkSession {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
    pub(crate) catalog: TempViewCatalog,
}

impl SparkSession {
    pub fn new(
        app_name: Option<String>,
        master: Option<String>,
        config: HashMap<String, String>,
    ) -> Self {
        SparkSession {
            app_name,
            master,
            config,
            catalog: Arc::new(Mutex::new(HashMap::new())),
        }
    }

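    /// Registers `df` under `name`, replacing any existing view with that
    /// name. A poisoned catalog lock is silently ignored rather than
    /// propagated.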
    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
        let _ = self
            .catalog
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
        // Global and local temp views share the same catalog here; this is an
        // alias for create_or_replace_temp_view.
        self.create_or_replace_temp_view(name, df);
    }

    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
        self.create_or_replace_temp_view(name, df);
    }

    pub fn drop_temp_view(&self, name: &str) {
        let _ = self.catalog.lock().map(|mut m| m.remove(name));
    }

    pub fn drop_global_temp_view(&self, name: &str) {
        self.drop_temp_view(name);
    }

    pub fn table_exists(&self, name: &str) -> bool {
        self.catalog
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
    }

    pub fn list_temp_view_names(&self) -> Vec<String> {
        self.catalog
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

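    /// Looks up a registered view by name, returning a clone of the stored
    /// [`DataFrame`] or an error naming the missing view.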
    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        self.catalog
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
            .ok_or_else(|| {
                PolarsError::InvalidOperation(
                    format!(
                        "Table or view '{name}' not found. Register it with create_or_replace_temp_view."
                    )
                    .into(),
                )
            })
    }

    pub fn builder() -> SparkSessionBuilder {
        SparkSessionBuilder::new()
    }

    pub fn get_config(&self) -> &HashMap<String, String> {
        &self.config
    }

    pub fn is_case_sensitive(&self) -> bool {
        self.config
            .get("spark.sql.caseSensitive")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false)
    }

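    /// Builds a three-column [`DataFrame`] from `(i64, i64, String)` tuples.
    /// `column_names` must contain exactly three names, one per tuple field.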
    pub fn create_dataframe(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        if column_names.len() != 3 {
            return Err(PolarsError::ComputeError(
                format!(
                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
                    column_names.len()
                )
                .into(),
            ));
        }

        let mut cols: Vec<Series> = Vec::with_capacity(3);

        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
        cols.push(Series::new(column_names[0].into(), col0));

        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
        cols.push(Series::new(column_names[1].into(), col1));

        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
        cols.push(Series::new(column_names[2].into(), col2));

        let pl_df = PlDataFrame::new(cols.into_iter().map(|s| s.into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
    }

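    /// Builds a [`DataFrame`] from row-oriented JSON values and a
    /// `(name, type)` schema. Supported type strings: int/bigint/long,
    /// double/float/double_precision, string/str/varchar, boolean/bool,
    /// date (YYYY-MM-DD), and timestamp/datetime/timestamp_ntz. Missing or
    /// mismatched cells become nulls.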
    pub fn create_dataframe_from_rows(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, PolarsError> {
        if schema.is_empty() {
            return Err(PolarsError::InvalidOperation(
                "create_dataframe_from_rows: schema must not be empty".into(),
            ));
        }
        use chrono::{NaiveDate, NaiveDateTime};

        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());

        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
            let type_lower = type_str.trim().to_lowercase();
            let s = match type_lower.as_str() {
                "int" | "bigint" | "long" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_i64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "double" | "float" | "double_precision" => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "string" | "str" | "varchar" => {
                    let vals: Vec<Option<String>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => Some(s),
                                JsonValue::Null => None,
                                // Non-string scalars are stringified rather than rejected.
                                other => Some(other.to_string()),
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "boolean" | "bool" => {
                    let vals: Vec<Option<bool>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Bool(b) => Some(b),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "date" => {
                    // Dates are stored as days since the Unix epoch, then cast
                    // to a proper Date column.
                    let epoch = crate::date_utils::epoch_naive_date();
                    let vals: Vec<Option<i32>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                    .ok()
                                    .map(|d| (d - epoch).num_days() as i32),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Date)
                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
                }
                "timestamp" | "datetime" | "timestamp_ntz" => {
                    // Accepts ISO-8601 "T"-separated strings (with or without
                    // fractional seconds) or a bare date (midnight); numbers
                    // are taken as epoch microseconds.
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => {
                                    let parsed = NaiveDateTime::parse_from_str(
                                        &s,
                                        "%Y-%m-%dT%H:%M:%S%.f",
                                    )
                                    .or_else(|_| {
                                        NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S")
                                    })
                                    .or_else(|_| {
                                        NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                            .map(|d| d.and_hms_opt(0, 0, 0).unwrap())
                                    });
                                    parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                                }
                                JsonValue::Number(n) => n.as_i64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
                        })?
                }
                _ => {
                    return Err(PolarsError::ComputeError(
                        format!(
                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
                        )
                        .into(),
                    ));
                }
            };
            cols.push(s);
        }

        let pl_df = PlDataFrame::new(cols.into_iter().map(|s| s.into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

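    /// Returns a single-column `id` [`DataFrame`] covering `[start, end)`
    /// with the given non-zero `step`, like PySpark's `spark.range`.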
    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
        if step == 0 {
            return Err(PolarsError::InvalidOperation(
                "range: step must not be 0".into(),
            ));
        }
        let mut vals: Vec<i64> = Vec::new();
        let mut v = start;
        if step > 0 {
            while v < end {
                vals.push(v);
                // saturating_add prevents overflow from wrapping past `end`.
                v = v.saturating_add(step);
            }
        } else {
            while v > end {
                vals.push(v);
                v = v.saturating_add(step);
            }
        }
        let col = Series::new("id".into(), vals);
        let pl_df = PlDataFrame::new(vec![col.into()])?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

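    /// Reads a headered CSV file eagerly, inferring the schema from the
    /// first 100 rows. Use [`SparkSession::read`] for option-driven reads.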
    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let path_display = path.display();
        let lf = LazyCsvReader::new(path)
            .with_has_header(true)
            .with_infer_schema_length(Some(100))
            .finish()
            .map_err(|e| {
                PolarsError::ComputeError(
                    format!(
                        "read_csv({path_display}): {e}. Hint: check that the file exists and is valid CSV."
                    )
                    .into(),
                )
            })?;
        let pl_df = lf.collect().map_err(|e| {
            PolarsError::ComputeError(
                format!("read_csv({path_display}): collect failed: {e}").into(),
            )
        })?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        let lf = LazyJsonLineReader::new(path)
            .with_infer_schema_length(NonZeroUsize::new(100))
            .finish()?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

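    /// Runs a SQL query against registered temp views. Only available when
    /// the crate is built with the `sql` feature; without it, the fallback
    /// below returns an `InvalidOperation` error.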
    #[cfg(feature = "sql")]
    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
        crate::sql::execute_sql(self, query)
    }

    #[cfg(not(feature = "sql"))]
    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
        ))
    }

    #[cfg(feature = "delta")]
    pub fn read_delta(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta(path, self.is_case_sensitive())
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_with_version(
        &self,
        path: impl AsRef<Path>,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_with_version(
        &self,
        _path: impl AsRef<Path>,
        _version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

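    /// Stops the session. Provided for source compatibility with PySpark;
    /// there are no cluster resources to release here.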
    pub fn stop(&self) {
        // Intentionally a no-op: nothing to shut down.
    }
}

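/// Option-driven reader returned by [`SparkSession::read`], analogous to
/// PySpark's `DataFrameReader`. A hedged usage sketch (hypothetical path,
/// marked `ignore` so it is not run as a doctest):
///
/// ```ignore
/// let df = spark
///     .read()
///     .option("header", "true")
///     .option("sep", ";")
///     .csv("data.csv")?;
/// ```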
pub struct DataFrameReader {
    session: SparkSession,
    options: HashMap<String, String>,
    format: Option<String>,
}

impl DataFrameReader {
    pub fn new(session: SparkSession) -> Self {
        DataFrameReader {
            session,
            options: HashMap::new(),
            format: None,
        }
    }

    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.options.insert(key.into(), value.into());
        self
    }

    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
        for (k, v) in opts {
            self.options.insert(k, v);
        }
        self
    }

    pub fn format(mut self, fmt: impl Into<String>) -> Self {
        self.format = Some(fmt.into());
        self
    }

    pub fn schema(self, _schema: impl Into<String>) -> Self {
        // Accepted for PySpark API compatibility; schemas are currently
        // inferred, so the argument is ignored.
        self
    }

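    /// Loads `path` using the explicitly set format, or infers it from the
    /// file extension (parquet, csv, json/jsonl, and delta when enabled).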
    pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        let path = path.as_ref();
        let fmt = self.format.clone().or_else(|| {
            path.extension()
                .and_then(|e| e.to_str())
                .map(|s| s.to_lowercase())
        });
        match fmt.as_deref() {
            Some("parquet") => self.parquet(path),
            Some("csv") => self.csv(path),
            Some("json") | Some("jsonl") => self.json(path),
            #[cfg(feature = "delta")]
            Some("delta") => self.session.read_delta(path),
            _ => Err(PolarsError::ComputeError(
                format!(
                    "load: could not infer format for path '{}'. Use format('parquet'|'csv'|'json') before load.",
                    path.display()
                )
                .into(),
            )),
        }
    }

    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        self.session.table(name)
    }

    fn apply_csv_options(
        &self,
        reader: polars::prelude::LazyCsvReader,
    ) -> polars::prelude::LazyCsvReader {
        use polars::prelude::NullValues;
        let mut r = reader;
        // Spark-style reader options: header, inferSchema, inferSchemaLength,
        // sep, and nullValue.
        if let Some(v) = self.options.get("header") {
            let has_header = v.eq_ignore_ascii_case("true") || v == "1";
            r = r.with_has_header(has_header);
        }
        if let Some(v) = self.options.get("inferSchema") {
            if v.eq_ignore_ascii_case("true") || v == "1" {
                let n = self
                    .options
                    .get("inferSchemaLength")
                    .and_then(|s| s.parse::<usize>().ok())
                    .unwrap_or(100);
                r = r.with_infer_schema_length(Some(n));
            }
        } else if let Some(v) = self.options.get("inferSchemaLength") {
            if let Ok(n) = v.parse::<usize>() {
                r = r.with_infer_schema_length(Some(n));
            }
        }
        if let Some(sep) = self.options.get("sep") {
            // Only single-byte separators are supported; the first byte wins.
            if let Some(b) = sep.bytes().next() {
                r = r.with_separator(b);
            }
        }
        if let Some(null_val) = self.options.get("nullValue") {
            r = r.with_null_values(Some(NullValues::AllColumnsSingle(null_val.clone().into())));
        }
        r
    }

    fn apply_json_options(
        &self,
        reader: polars::prelude::LazyJsonLineReader,
    ) -> polars::prelude::LazyJsonLineReader {
        use std::num::NonZeroUsize;
        let mut r = reader;
        if let Some(v) = self.options.get("inferSchemaLength") {
            if let Ok(n) = v.parse::<usize>() {
                r = r.with_infer_schema_length(NonZeroUsize::new(n));
            }
        }
        r
    }

    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let path_display = path.display();
        let reader = LazyCsvReader::new(path);
        let reader = if self.options.is_empty() {
            reader
                .with_has_header(true)
                .with_infer_schema_length(Some(100))
        } else {
            self.apply_csv_options(
                reader
                    .with_has_header(true)
                    .with_infer_schema_length(Some(100)),
            )
        };
        let lf = reader.finish().map_err(|e| {
            PolarsError::ComputeError(format!("read csv({path_display}): {e}").into())
        })?;
        let pl_df = lf.collect().map_err(|e| {
            PolarsError::ComputeError(
                format!("read csv({path_display}): collect failed: {e}").into(),
            )
        })?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        let reader = LazyJsonLineReader::new(path);
        let reader = if self.options.is_empty() {
            reader.with_infer_schema_length(NonZeroUsize::new(100))
        } else {
            self.apply_json_options(reader.with_infer_schema_length(NonZeroUsize::new(100)))
        };
        let lf = reader.finish()?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    #[cfg(feature = "delta")]
    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_delta(path)
    }
}

impl SparkSession {
    /// Returns a [`DataFrameReader`] bound to a clone of this session; the
    /// catalog sits behind an `Arc`, so the reader sees the same temp views.
    pub fn read(&self) -> DataFrameReader {
        DataFrameReader::new(self.clone())
    }
}

impl Default for SparkSession {
    fn default() -> Self {
        Self::builder().get_or_create()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spark_session_builder_basic() {
        let spark = SparkSession::builder().app_name("test_app").get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_master() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .master("local[*]")
            .get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
        assert_eq!(spark.master, Some("local[*]".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_config() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .config("spark.executor.memory", "4g")
            .config("spark.driver.memory", "2g")
            .get_or_create();

        assert_eq!(
            spark.config.get("spark.executor.memory"),
            Some(&"4g".to_string())
        );
        assert_eq!(
            spark.config.get("spark.driver.memory"),
            Some(&"2g".to_string())
        );
    }

    #[test]
    fn test_spark_session_default() {
        let spark = SparkSession::default();
        assert!(spark.app_name.is_none());
        assert!(spark.master.is_none());
        assert!(spark.config.is_empty());
    }

    #[test]
    fn test_create_dataframe_success() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data = vec![
            (1i64, 25i64, "Alice".to_string()),
            (2i64, 30i64, "Bob".to_string()),
        ];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"age".to_string()));
        assert!(columns.contains(&"name".to_string()));
    }

    #[test]
    fn test_create_dataframe_wrong_column_count() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data = vec![(1i64, 25i64, "Alice".to_string())];

        let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
        assert!(result.is_err());

        let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_dataframe_empty() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data: Vec<(i64, i64, String)> = vec![];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_create_dataframe_from_polars() {
        use polars::prelude::df;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let polars_df = df!(
            "x" => &[1, 2, 3],
            "y" => &[4, 5, 6]
        )
        .unwrap();

        let df = spark.create_dataframe_from_polars(polars_df);

        assert_eq!(df.count().unwrap(), 3);
        let columns = df.columns().unwrap();
        assert!(columns.contains(&"x".to_string()));
        assert!(columns.contains(&"y".to_string()));
    }

    #[test]
    fn test_read_csv_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_csv("nonexistent_file.csv");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_parquet_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_parquet("nonexistent_file.parquet");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_json_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_json("nonexistent_file.json");

        assert!(result.is_err());
    }

    #[test]
    fn test_sql_returns_error_without_feature_or_unknown_table() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.sql("SELECT * FROM table");

        assert!(result.is_err());
        match result {
            Err(PolarsError::InvalidOperation(msg)) => {
                let s = msg.to_string();
                assert!(
                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
                    "unexpected message: {s}"
                );
            }
            _ => panic!("Expected InvalidOperation error"),
        }
    }

    #[test]
    fn test_spark_session_stop() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        spark.stop();
    }

    #[test]
    fn test_dataframe_reader_api() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let reader = spark.read();

        assert!(reader.csv("nonexistent.csv").is_err());
        assert!(reader.parquet("nonexistent.parquet").is_err());
        assert!(reader.json("nonexistent.json").is_err());
    }

    #[test]
    fn test_read_csv_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name,age").unwrap();
        writeln!(temp_file, "1,Alice,25").unwrap();
        writeln!(temp_file, "2,Bob,30").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"name".to_string()));
        assert!(columns.contains(&"age".to_string()));
    }

    #[test]
    fn test_read_json_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_json(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);
    }

    #[test]
    fn test_read_csv_empty_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_write_partitioned_parquet() {
        use crate::dataframe::{WriteFormat, WriteMode};
        use std::fs;
        use tempfile::TempDir;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 25, "Carol".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("out");
        df.write()
            .mode(WriteMode::Overwrite)
            .format(WriteFormat::Parquet)
            .partition_by(["age"])
            .save(&path)
            .unwrap();
        assert!(path.is_dir());
        let entries: Vec<_> = fs::read_dir(&path).unwrap().collect();
        assert_eq!(
            entries.len(),
            2,
            "expected two partition dirs (age=25, age=30)"
        );
        let names: Vec<String> = entries
            .iter()
            .filter_map(|e| e.as_ref().ok())
            .map(|e| e.file_name().to_string_lossy().into_owned())
            .collect();
        assert!(names.iter().any(|n| n.starts_with("age=")));
        let df_read = spark.read_parquet(&path).unwrap();
        assert_eq!(df_read.count().unwrap(), 3);
    }
}