use crate::dataframe::DataFrame;
use polars::prelude::{
    DataFrame as PlDataFrame, DataType, NamedFrom, PolarsError, Series, TimeUnit,
};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};

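/// Builder for configuring and constructing a [`SparkSession`].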
#[derive(Clone)]
pub struct SparkSessionBuilder {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
}

impl Default for SparkSessionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkSessionBuilder {
    pub fn new() -> Self {
        SparkSessionBuilder {
            app_name: None,
            master: None,
            config: HashMap::new(),
        }
    }

    pub fn app_name(mut self, name: impl Into<String>) -> Self {
        self.app_name = Some(name.into());
        self
    }

    pub fn master(mut self, master: impl Into<String>) -> Self {
        self.master = Some(master.into());
        self
    }

    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    pub fn get_or_create(self) -> SparkSession {
        SparkSession::new(self.app_name, self.master, self.config)
    }
}

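/// Shared, thread-safe registry of temporary views; session clones share it.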
pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

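/// Entry point for the DataFrame API, mirroring PySpark's `SparkSession`.
/// Cloning is cheap, and clones share the same temp-view catalog.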
#[derive(Clone)]
pub struct SparkSession {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
    pub(crate) catalog: TempViewCatalog,
}

impl SparkSession {
    pub fn new(
        app_name: Option<String>,
        master: Option<String>,
        config: HashMap<String, String>,
    ) -> Self {
        SparkSession {
            app_name,
            master,
            config,
            catalog: Arc::new(Mutex::new(HashMap::new())),
        }
    }

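    /// Registers `df` as a temporary view under `name`, replacing any
    /// existing view with that name.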
    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
        // A poisoned lock is ignored: the insert is simply dropped rather
        // than propagating a panic from another thread.
        let _ = self
            .catalog
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

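    /// Registers a global temporary view. Simplified here: global views live
    /// in the same session-local catalog rather than a separate
    /// `global_temp` database as in Spark.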
    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
        self.create_or_replace_temp_view(name, df);
    }

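    /// Registers or replaces a global temporary view (same simplification as
    /// [`Self::create_global_temp_view`]).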
    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
        self.create_or_replace_temp_view(name, df);
    }

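    /// Looks up a registered temp view by name; errors if it was never
    /// registered or the catalog lock is poisoned.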
    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        self.catalog
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
            .ok_or_else(|| {
                PolarsError::InvalidOperation(
                    format!(
                        "Table or view '{name}' not found. Register it with create_or_replace_temp_view."
                    )
                    .into(),
                )
            })
    }

    pub fn builder() -> SparkSessionBuilder {
        SparkSessionBuilder::new()
    }

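    /// Whether column resolution is case sensitive, controlled by the
    /// `spark.sql.caseSensitive` config key (defaults to `false`, as in
    /// Spark).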
    pub fn is_case_sensitive(&self) -> bool {
        self.config
            .get("spark.sql.caseSensitive")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false)
    }

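    /// Creates a [`DataFrame`] from `(i64, i64, String)` tuples, one column
    /// per tuple field. Exactly three column names are required.
    ///
    /// Example (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// let spark = SparkSession::builder().app_name("demo").get_or_create();
    /// let df = spark.create_dataframe(
    ///     vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
    ///     vec!["id", "age", "name"],
    /// )?;
    /// assert_eq!(df.count()?, 2);
    /// ```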
    pub fn create_dataframe(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        if column_names.len() != 3 {
            return Err(PolarsError::ComputeError(
                format!(
                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
                    column_names.len()
                )
                .into(),
            ));
        }

        let mut cols: Vec<Series> = Vec::with_capacity(3);

        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
        cols.push(Series::new(column_names[0].into(), col0));

        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
        cols.push(Series::new(column_names[1].into(), col1));

        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
        cols.push(Series::new(column_names[2].into(), col2));

        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

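    /// Wraps an existing Polars `DataFrame`, applying this session's
    /// case-sensitivity setting.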
    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
    }

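    /// Builds a [`DataFrame`] from row-oriented JSON values and a
    /// `(column name, type name)` schema. Supported type names:
    /// `int`/`bigint`/`long`, `double`/`float`, `string`/`str`/`varchar`,
    /// `boolean`/`bool`, `date`, and `timestamp`/`datetime`/`timestamp_ntz`.
    /// Missing or mismatched values become nulls.
    ///
    /// Example (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// use serde_json::json;
    /// let df = spark.create_dataframe_from_rows(
    ///     vec![
    ///         vec![json!(1), json!("Alice")],
    ///         vec![json!(2), json!("Bob")],
    ///     ],
    ///     vec![
    ///         ("id".to_string(), "bigint".to_string()),
    ///         ("name".to_string(), "string".to_string()),
    ///     ],
    /// )?;
    /// ```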
    pub fn create_dataframe_from_rows(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, PolarsError> {
        use chrono::{NaiveDate, NaiveDateTime};

        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());

        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
            let type_lower = type_str.trim().to_lowercase();
            let s = match type_lower.as_str() {
                "int" | "bigint" | "long" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_i64(),
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "double" | "float" | "double_precision" => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "string" | "str" | "varchar" => {
                    let vals: Vec<Option<String>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => Some(s),
                                JsonValue::Null => None,
                                other => Some(other.to_string()),
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "boolean" | "bool" => {
                    let vals: Vec<Option<bool>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Bool(b) => Some(b),
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "date" => {
                    let epoch = NaiveDate::from_ymd_opt(1970, 1, 1)
                        .ok_or_else(|| PolarsError::ComputeError("invalid epoch date".into()))?;
                    // Dates are stored as days since the Unix epoch.
                    let vals: Vec<Option<i32>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                    .ok()
                                    .map(|d| (d - epoch).num_days() as i32),
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                        .cast(&DataType::Date)
                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
                }
                "timestamp" | "datetime" | "timestamp_ntz" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => {
                                    // Accept ISO-8601 with 'T' or space as the separator
                                    // (`%.f` also matches a missing fractional part),
                                    // then bare dates at midnight.
                                    NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S%.f")
                                        .or_else(|_| {
                                            NaiveDateTime::parse_from_str(
                                                &s,
                                                "%Y-%m-%d %H:%M:%S%.f",
                                            )
                                        })
                                        .or_else(|_| {
                                            NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                                .map(|d| d.and_hms_opt(0, 0, 0).unwrap())
                                        })
                                        .ok()
                                        .map(|dt| dt.and_utc().timestamp_micros())
                                }
                                JsonValue::Number(n) => n.as_i64(),
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
                        })?
                }
                _ => {
                    return Err(PolarsError::ComputeError(
                        format!(
                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
                        )
                        .into(),
                    ));
                }
            };
            cols.push(s);
        }

        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

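    /// Reads a CSV file with a header row, inferring the schema from the
    /// first 100 rows.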
    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let path_display = path.display();
        let lf = LazyCsvReader::new(path)
            .with_has_header(true)
            .with_infer_schema_length(Some(100))
            .finish()
            .map_err(|e| {
                PolarsError::ComputeError(
                    format!(
                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
                    )
                    .into(),
                )
            })?;
        let pl_df = lf.collect().map_err(|e| {
            PolarsError::ComputeError(
                format!("read_csv({path_display}): collect failed: {e}").into(),
            )
        })?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

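    /// Reads a Parquet file using Polars' lazy scanner with default scan
    /// arguments.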
    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

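    /// Reads newline-delimited JSON (one object per line), inferring the
    /// schema from the first 100 rows.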
    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        let lf = LazyJsonLineReader::new(path)
            .with_infer_schema_length(NonZeroUsize::new(100))
            .finish()?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

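    /// Executes a SQL query against this session's registered temp views
    /// (requires the `sql` feature).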
    #[cfg(feature = "sql")]
    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
        crate::sql::execute_sql(self, query)
    }

    #[cfg(not(feature = "sql"))]
    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
        ))
    }

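    /// Reads a Delta Lake table at `path` (requires the `delta` feature).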
    #[cfg(feature = "delta")]
    pub fn read_delta(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta(path, self.is_case_sensitive())
    }

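    /// Reads a Delta Lake table at a specific version (time travel);
    /// `None` reads the latest version.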
    #[cfg(feature = "delta")]
    pub fn read_delta_with_version(
        &self,
        path: impl AsRef<Path>,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_with_version(
        &self,
        _path: impl AsRef<Path>,
        _version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

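    /// No-op kept for PySpark API parity; there is no cluster or background
    /// runtime to shut down.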
    pub fn stop(&self) {}
}

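/// PySpark-style reader handle returned by [`SparkSession::read`]; each
/// method delegates to the corresponding `read_*` method on the session.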
pub struct DataFrameReader {
    session: SparkSession,
}

impl DataFrameReader {
    pub fn new(session: SparkSession) -> Self {
        DataFrameReader { session }
    }

    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_csv(path)
    }

    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_parquet(path)
    }

    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_json(path)
    }

    #[cfg(feature = "delta")]
    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_delta(path)
    }
}

impl SparkSession {
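    /// Returns a [`DataFrameReader`] bound to a clone of this session; the
    /// temp-view catalog is shared through its `Arc`, not copied.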
    pub fn read(&self) -> DataFrameReader {
        DataFrameReader::new(self.clone())
    }
}

impl Default for SparkSession {
    fn default() -> Self {
        Self::builder().get_or_create()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spark_session_builder_basic() {
        let spark = SparkSession::builder().app_name("test_app").get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_master() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .master("local[*]")
            .get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
        assert_eq!(spark.master, Some("local[*]".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_config() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .config("spark.executor.memory", "4g")
            .config("spark.driver.memory", "2g")
            .get_or_create();

        assert_eq!(
            spark.config.get("spark.executor.memory"),
            Some(&"4g".to_string())
        );
        assert_eq!(
            spark.config.get("spark.driver.memory"),
            Some(&"2g".to_string())
        );
    }

    #[test]
    fn test_spark_session_default() {
        let spark = SparkSession::default();
        assert!(spark.app_name.is_none());
        assert!(spark.master.is_none());
        assert!(spark.config.is_empty());
    }

    #[test]
    fn test_create_dataframe_success() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data = vec![
            (1i64, 25i64, "Alice".to_string()),
            (2i64, 30i64, "Bob".to_string()),
        ];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"age".to_string()));
        assert!(columns.contains(&"name".to_string()));
    }

    #[test]
    fn test_create_dataframe_wrong_column_count() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data = vec![(1i64, 25i64, "Alice".to_string())];

        // Too few column names.
        let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
        assert!(result.is_err());

        // Too many column names.
        let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_dataframe_empty() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data: Vec<(i64, i64, String)> = vec![];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_create_dataframe_from_polars() {
        use polars::prelude::df;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let polars_df = df!(
            "x" => &[1, 2, 3],
            "y" => &[4, 5, 6]
        )
        .unwrap();

        let df = spark.create_dataframe_from_polars(polars_df);

        assert_eq!(df.count().unwrap(), 3);
        let columns = df.columns().unwrap();
        assert!(columns.contains(&"x".to_string()));
        assert!(columns.contains(&"y".to_string()));
    }

    #[test]
    fn test_read_csv_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_csv("nonexistent_file.csv");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_parquet_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_parquet("nonexistent_file.parquet");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_json_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_json("nonexistent_file.json");

        assert!(result.is_err());
    }

    #[test]
    fn test_sql_returns_error_without_feature_or_unknown_table() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.sql("SELECT * FROM table");

        assert!(result.is_err());
        match result {
            Err(PolarsError::InvalidOperation(msg)) => {
                let s = msg.to_string();
                // With the `sql` feature the unknown table is reported;
                // without it the feature-gate message is returned.
                assert!(
                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
                    "unexpected message: {s}"
                );
            }
            _ => panic!("Expected InvalidOperation error"),
        }
    }

    #[test]
    fn test_spark_session_stop() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        // stop() is a no-op and must not panic.
        spark.stop();
    }

    #[test]
    fn test_dataframe_reader_api() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let reader = spark.read();

        assert!(reader.csv("nonexistent.csv").is_err());
        assert!(reader.parquet("nonexistent.parquet").is_err());
        assert!(reader.json("nonexistent.json").is_err());
    }

    #[test]
    fn test_read_csv_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name,age").unwrap();
        writeln!(temp_file, "1,Alice,25").unwrap();
        writeln!(temp_file, "2,Bob,30").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"name".to_string()));
        assert!(columns.contains(&"age".to_string()));
    }

    #[test]
    fn test_read_json_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_json(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);
    }

    #[test]
    fn test_read_csv_empty_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        // Header row only; no data rows.
        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }
}