use crate::dataframe::DataFrame;
use crate::udf_registry::UdfRegistry;
use polars::chunked_array::builder::get_list_builder;
use polars::chunked_array::StructChunked;
use polars::prelude::{
    DataFrame as PlDataFrame, DataType, IntoSeries, NamedFrom, PlSmallStr, PolarsError, Series,
    TimeUnit,
};
use serde_json::Value as JsonValue;
use std::cell::RefCell;

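/// Extracts the element type from an `array<...>` type string,
/// e.g. `array<int>` -> `int`. Returns `None` if the string is not an array type.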
fn parse_array_element_type(type_str: &str) -> Option<String> {
    let s = type_str.trim();
    if !s.to_lowercase().starts_with("array<") || !s.ends_with('>') {
        return None;
    }
    Some(s[6..s.len() - 1].trim().to_string())
}

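/// Parses a `struct<name:type, ...>` type string into `(field_name, field_type)` pairs.
/// Returns `None` if the string is not a struct type.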
fn parse_struct_fields(type_str: &str) -> Option<Vec<(String, String)>> {
    let s = type_str.trim();
    if !s.to_lowercase().starts_with("struct<") || !s.ends_with('>') {
        return None;
    }
    let inner = s[7..s.len() - 1].trim();
    if inner.is_empty() {
        return Some(Vec::new());
    }
    let mut out = Vec::new();
    for part in inner.split(',') {
        let part = part.trim();
        if let Some(idx) = part.find(':') {
            let name = part[..idx].trim().to_string();
            let typ = part[idx + 1..].trim().to_string();
            out.push((name, typ));
        }
    }
    Some(out)
}

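/// Maps a simple Spark-style type name (e.g. `bigint`, `string`) to the
/// corresponding Polars `DataType`. Returns `None` for unsupported types.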
fn json_type_str_to_polars(type_str: &str) -> Option<DataType> {
    match type_str.trim().to_lowercase().as_str() {
        "int" | "bigint" | "long" => Some(DataType::Int64),
        "double" | "float" | "double_precision" => Some(DataType::Float64),
        "string" | "str" | "varchar" => Some(DataType::String),
        "boolean" | "bool" => Some(DataType::Boolean),
        _ => None,
    }
}

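/// Builds a Polars `Series` named `name` from one column of JSON values, using the
/// declared type string to pick the dtype. Handles scalar types (`int`, `double`,
/// `string`, `boolean`, `date`, `timestamp`) as well as `array<...>` and
/// `struct<...>` columns; unsupported types produce a `ComputeError`.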
fn json_values_to_series(
    values: &[Option<JsonValue>],
    type_str: &str,
    name: &str,
) -> Result<Series, PolarsError> {
    use chrono::{NaiveDate, NaiveDateTime};
    let epoch = crate::date_utils::epoch_naive_date();
    let type_lower = type_str.trim().to_lowercase();

    if let Some(elem_type) = parse_array_element_type(&type_lower) {
        let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
            PolarsError::ComputeError(
                format!("array element type '{elem_type}' not supported").into(),
            )
        })?;
        let mut builder = get_list_builder(&inner_dtype, 64, values.len(), name.into());
        for v in values.iter() {
            if v.as_ref().is_none_or(|x| matches!(x, JsonValue::Null)) {
                builder.append_null();
            } else if let Some(arr) = v.as_ref().and_then(|x| x.as_array()) {
                let elem_series: Vec<Series> = arr
                    .iter()
                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                    .collect::<Result<Vec<_>, _>>()?;
                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                let s = Series::from_any_values_and_dtype(
                    PlSmallStr::EMPTY,
                    &vals,
                    &inner_dtype,
                    false,
                )
                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
                builder.append_series(&s)?;
            } else {
                return Err(PolarsError::ComputeError(
                    "array column value must be null or array".into(),
                ));
            }
        }
        return Ok(builder.finish().into_series());
    }

    if let Some(fields) = parse_struct_fields(&type_lower) {
        let mut field_series_vec: Vec<Vec<Option<JsonValue>>> = (0..fields.len())
            .map(|_| Vec::with_capacity(values.len()))
            .collect();
        for v in values.iter() {
            if v.as_ref().is_none_or(|x| matches!(x, JsonValue::Null)) {
                for fc in &mut field_series_vec {
                    fc.push(None);
                }
            } else if let Some(obj) = v.as_ref().and_then(|x| x.as_object()) {
                for (fi, (fname, _)) in fields.iter().enumerate() {
                    field_series_vec[fi].push(obj.get(fname).cloned());
                }
            } else if let Some(arr) = v.as_ref().and_then(|x| x.as_array()) {
                for (fi, _) in fields.iter().enumerate() {
                    field_series_vec[fi].push(arr.get(fi).cloned());
                }
            } else {
                return Err(PolarsError::ComputeError(
                    "struct value must be object or array".into(),
                ));
            }
        }
        let series_per_field: Vec<Series> = fields
            .iter()
            .enumerate()
            .map(|(fi, (fname, ftype))| json_values_to_series(&field_series_vec[fi], ftype, fname))
            .collect::<Result<Vec<_>, _>>()?;
        let field_refs: Vec<&Series> = series_per_field.iter().collect();
        let st = StructChunked::from_series(name.into(), values.len(), field_refs.iter().copied())
            .map_err(|e| PolarsError::ComputeError(format!("struct column: {e}").into()))?
            .into_series();
        return Ok(st);
    }

    match type_lower.as_str() {
        "int" | "bigint" | "long" => {
            let vals: Vec<Option<i64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Number(n) => n.as_i64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "double" | "float" => {
            let vals: Vec<Option<f64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Number(n) => n.as_f64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "string" | "str" | "varchar" => {
            let vals: Vec<Option<&str>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => Some(s.as_str()),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let owned: Vec<Option<String>> =
                vals.into_iter().map(|o| o.map(|s| s.to_string())).collect();
            Ok(Series::new(name.into(), owned))
        }
        "boolean" | "bool" => {
            let vals: Vec<Option<bool>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Bool(b) => Some(*b),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "date" => {
            let vals: Vec<Option<i32>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => NaiveDate::parse_from_str(s, "%Y-%m-%d")
                            .ok()
                            .map(|d| (d - epoch).num_days() as i32),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let s = Series::new(name.into(), vals);
            s.cast(&DataType::Date)
                .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))
        }
        "timestamp" | "datetime" | "timestamp_ntz" => {
            let vals: Vec<Option<i64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => {
                            let parsed = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
                                .or_else(|_| NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S"))
                                .or_else(|_| {
                                    NaiveDate::parse_from_str(s, "%Y-%m-%d")
                                        .map(|d| d.and_hms_opt(0, 0, 0).unwrap())
                                });
                            parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                        }
                        JsonValue::Number(n) => n.as_i64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let s = Series::new(name.into(), vals);
            s.cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                .map_err(|e| PolarsError::ComputeError(format!("datetime cast: {e}").into()))
        }
        _ => Err(PolarsError::ComputeError(
            format!("json_values_to_series: unsupported type '{type_str}'").into(),
        )),
    }
}

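/// Converts a single JSON value into a one-element `Series` of the requested type.
/// Used for the elements of `array<...>` and `struct<...>` columns.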
fn json_value_to_series_single(
    value: &JsonValue,
    type_str: &str,
    name: &str,
) -> Result<Series, PolarsError> {
    use chrono::NaiveDate;
    let epoch = crate::date_utils::epoch_naive_date();
    match (value, type_str.trim().to_lowercase().as_str()) {
        (JsonValue::Null, _) => Ok(Series::new_null(name.into(), 1)),
        (JsonValue::Number(n), "int" | "bigint" | "long") => {
            Ok(Series::new(name.into(), vec![n.as_i64()]))
        }
        (JsonValue::Number(n), "double" | "float") => {
            Ok(Series::new(name.into(), vec![n.as_f64()]))
        }
        (JsonValue::String(s), "string" | "str" | "varchar") => {
            Ok(Series::new(name.into(), vec![s.as_str()]))
        }
        (JsonValue::Bool(b), "boolean" | "bool") => Ok(Series::new(name.into(), vec![*b])),
        (JsonValue::String(s), "date") => {
            let d = NaiveDate::parse_from_str(s, "%Y-%m-%d")
                .map_err(|e| PolarsError::ComputeError(format!("date parse: {e}").into()))?;
            let days = (d - epoch).num_days() as i32;
            let s = Series::new(name.into(), vec![days]).cast(&DataType::Date)?;
            Ok(s)
        }
        _ => Err(PolarsError::ComputeError(
            format!("json_value_to_series: unsupported {type_str} for {value:?}").into(),
        )),
    }
}

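/// Builds a one-row struct `Series` from a JSON object (fields matched by name) or
/// a JSON array (fields matched by position). Returns `Ok(None)` for a JSON null.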
#[allow(dead_code)]
fn json_object_or_array_to_struct_series(
    value: &JsonValue,
    fields: &[(String, String)],
    _name: &str,
) -> Result<Option<Series>, PolarsError> {
    use polars::prelude::StructChunked;
    if matches!(value, JsonValue::Null) {
        return Ok(None);
    }
    let mut field_series: Vec<Series> = Vec::with_capacity(fields.len());
    for (fname, ftype) in fields {
        let fval = if let Some(obj) = value.as_object() {
            obj.get(fname).unwrap_or(&JsonValue::Null)
        } else if let Some(arr) = value.as_array() {
            let idx = field_series.len();
            arr.get(idx).unwrap_or(&JsonValue::Null)
        } else {
            return Err(PolarsError::ComputeError(
                "struct value must be object or array".into(),
            ));
        };
        let s = json_value_to_series_single(fval, ftype, fname)?;
        field_series.push(s);
    }
    let field_refs: Vec<&Series> = field_series.iter().collect();
    let st = StructChunked::from_series(PlSmallStr::EMPTY, 1, field_refs.iter().copied())
        .map_err(|e| PolarsError::ComputeError(format!("struct from value: {e}").into()))?
        .into_series();
    Ok(Some(st))
}

use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex, OnceLock};
use std::thread_local;

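// Thread-local reference to the most recently created `SparkSession` on this
// thread; other parts of the crate use it to resolve the active session
// (e.g. when evaluating registered UDFs).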
thread_local! {
    static THREAD_UDF_SESSION: RefCell<Option<SparkSession>> = const { RefCell::new(None) };
}

pub(crate) fn set_thread_udf_session(session: SparkSession) {
    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = Some(session));
}

pub(crate) fn get_thread_udf_session() -> Option<SparkSession> {
    THREAD_UDF_SESSION.with(|cell| cell.borrow().clone())
}

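/// Process-wide catalog backing `global_temp` views, shared across all sessions
/// (mirrors Spark's global temporary view database).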
static GLOBAL_TEMP_CATALOG: OnceLock<Arc<Mutex<HashMap<String, DataFrame>>>> = OnceLock::new();

fn global_temp_catalog() -> Arc<Mutex<HashMap<String, DataFrame>>> {
    GLOBAL_TEMP_CATALOG
        .get_or_init(|| Arc::new(Mutex::new(HashMap::new())))
        .clone()
}

#[derive(Clone)]
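/// Builder for [`SparkSession`], mirroring PySpark's `SparkSession.builder` API:
/// set `app_name`, `master`, and arbitrary `config` key/value pairs, then call
/// `get_or_create()`.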
pub struct SparkSessionBuilder {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
}

impl Default for SparkSessionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkSessionBuilder {
    pub fn new() -> Self {
        SparkSessionBuilder {
            app_name: None,
            master: None,
            config: HashMap::new(),
        }
    }

    pub fn app_name(mut self, name: impl Into<String>) -> Self {
        self.app_name = Some(name.into());
        self
    }

    pub fn master(mut self, master: impl Into<String>) -> Self {
        self.master = Some(master.into());
        self
    }

    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    pub fn get_or_create(self) -> SparkSession {
        let session = SparkSession::new(self.app_name, self.master, self.config);
        set_thread_udf_session(session.clone());
        session
    }
}

pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

pub type TableCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

#[derive(Clone)]
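/// Entry point of the API, loosely mirroring PySpark's `SparkSession`: it owns the
/// temp-view and table catalogs, the session config, and the UDF registry, and
/// exposes readers (`read_csv`, `read_parquet`, `read_json`, `read()`) and
/// constructors (`create_dataframe`, `create_dataframe_from_rows`, `range`).
///
/// A minimal usage sketch (values and names are illustrative):
///
/// ```ignore
/// let spark = SparkSession::builder()
///     .app_name("example")
///     .config("spark.sql.caseSensitive", "false")
///     .get_or_create();
/// let df = spark.create_dataframe(
///     vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
///     vec!["id", "age", "name"],
/// )?;
/// spark.create_or_replace_temp_view("people", df);
/// let people = spark.table("people")?;
/// assert_eq!(people.count()?, 2);
/// ```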
pub struct SparkSession {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
    pub(crate) catalog: TempViewCatalog,
    pub(crate) tables: TableCatalog,
    pub(crate) udf_registry: UdfRegistry,
    #[cfg(feature = "pyo3")]
    pub(crate) python_udf_batch_size: usize,
    #[cfg(feature = "pyo3")]
    pub(crate) python_udf_max_concurrent_batches: usize,
}

impl SparkSession {
    pub fn new(
        app_name: Option<String>,
        master: Option<String>,
        config: HashMap<String, String>,
    ) -> Self {
        #[cfg(feature = "pyo3")]
        let batch_size = config
            .get("spark.robin.pythonUdf.batchSize")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(usize::MAX);
        #[cfg(feature = "pyo3")]
        let max_concurrent = config
            .get("spark.robin.pythonUdf.maxConcurrentBatches")
            .and_then(|s| s.parse::<usize>().ok())
            .unwrap_or(1);

        SparkSession {
            app_name,
            master,
            config,
            catalog: Arc::new(Mutex::new(HashMap::new())),
            tables: Arc::new(Mutex::new(HashMap::new())),
            udf_registry: UdfRegistry::new(),
            #[cfg(feature = "pyo3")]
            python_udf_batch_size: batch_size,
            #[cfg(feature = "pyo3")]
            python_udf_max_concurrent_batches: max_concurrent,
        }
    }

    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
        let _ = self
            .catalog
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
        let _ = global_temp_catalog()
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
        let _ = global_temp_catalog()
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn drop_temp_view(&self, name: &str) {
        let _ = self.catalog.lock().map(|mut m| m.remove(name));
    }

    pub fn drop_global_temp_view(&self, name: &str) -> bool {
        global_temp_catalog()
            .lock()
            .map(|mut m| m.remove(name).is_some())
            .unwrap_or(false)
    }

    pub fn register_table(&self, name: &str, df: DataFrame) {
        let _ = self
            .tables
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
        self.tables.lock().ok().and_then(|m| m.get(name).cloned())
    }

    pub fn saved_table_exists(&self, name: &str) -> bool {
        self.tables
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
    }

    pub fn table_exists(&self, name: &str) -> bool {
        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
            return global_temp_catalog()
                .lock()
                .map(|m| m.contains_key(tbl))
                .unwrap_or(false);
        }
        if self
            .catalog
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
        {
            return true;
        }
        if self
            .tables
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
        {
            return true;
        }
        if let Some(warehouse) = self.warehouse_dir() {
            let path = Path::new(warehouse).join(name);
            if path.is_dir() {
                return true;
            }
        }
        false
    }

    pub fn list_global_temp_view_names(&self) -> Vec<String> {
        global_temp_catalog()
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn list_temp_view_names(&self) -> Vec<String> {
        self.catalog
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn list_table_names(&self) -> Vec<String> {
        self.tables
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn drop_table(&self, name: &str) -> bool {
        self.tables
            .lock()
            .map(|mut m| m.remove(name).is_some())
            .unwrap_or(false)
    }

    fn parse_global_temp_name(name: &str) -> Option<(&str, &str)> {
        if let Some(dot) = name.find('.') {
            let (db, tbl) = name.split_at(dot);
            if db.eq_ignore_ascii_case("global_temp") {
                return Some((db, tbl.strip_prefix('.').unwrap_or(tbl)));
            }
        }
        None
    }

    pub fn warehouse_dir(&self) -> Option<&str> {
        self.config
            .get("spark.sql.warehouse.dir")
            .map(|s| s.as_str())
            .filter(|s| !s.is_empty())
    }

    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
            if let Some(df) = global_temp_catalog()
                .lock()
                .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
                .get(tbl)
                .cloned()
            {
                return Ok(df);
            }
            return Err(PolarsError::InvalidOperation(
                format!(
                    "Global temp view '{tbl}' not found. Register it with createOrReplaceGlobalTempView."
                )
                .into(),
            ));
        }
        if let Some(df) = self
            .catalog
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
        {
            return Ok(df);
        }
        if let Some(df) = self
            .tables
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
        {
            return Ok(df);
        }
        if let Some(warehouse) = self.warehouse_dir() {
            let dir = Path::new(warehouse).join(name);
            if dir.is_dir() {
                let data_file = dir.join("data.parquet");
                let read_path = if data_file.is_file() { data_file } else { dir };
                return self.read_parquet(&read_path);
            }
        }
        Err(PolarsError::InvalidOperation(
            format!(
                "Table or view '{name}' not found. Register it with create_or_replace_temp_view or saveAsTable."
            )
            .into(),
        ))
    }

    pub fn builder() -> SparkSessionBuilder {
        SparkSessionBuilder::new()
    }

    pub fn get_config(&self) -> &HashMap<String, String> {
        &self.config
    }

    pub fn is_case_sensitive(&self) -> bool {
        self.config
            .get("spark.sql.caseSensitive")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false)
    }

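    /// Registers a Rust UDF under `name` so it can be invoked through
    /// `functions::call_udf`. The closure receives the argument columns as a slice
    /// of `Series` and returns the result column. A sketch, mirroring the unit
    /// test below:
    ///
    /// ```ignore
    /// spark.register_udf("to_str", |cols| cols[0].cast(&DataType::String))?;
    /// let df2 = df.with_column("id_str", &call_udf("to_str", &[col("id")])?)?;
    /// ```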
    pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
    where
        F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
    {
        self.udf_registry.register_rust_udf(name, f)
    }

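    /// Builds a three-column DataFrame from `(i64, i64, String)` tuples; exactly
    /// three column names must be supplied (e.g. `vec!["id", "age", "name"]`).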
    pub fn create_dataframe(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        if column_names.len() != 3 {
            return Err(PolarsError::ComputeError(
                format!(
                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
                    column_names.len()
                )
                .into(),
            ));
        }

        let mut cols: Vec<Series> = Vec::with_capacity(3);

        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
        cols.push(Series::new(column_names[0].into(), col0));

        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
        cols.push(Series::new(column_names[1].into(), col1));

        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
        cols.push(Series::new(column_names[2].into(), col2));

        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

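    /// Wraps an existing Polars `DataFrame` in this crate's `DataFrame`, applying
    /// the session's case-sensitivity setting.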
    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
    }

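    /// Builds a DataFrame from row-oriented JSON values plus an explicit schema of
    /// `(column_name, type_string)` pairs. Supported type strings include `int`,
    /// `double`, `string`, `boolean`, `date`, `timestamp`, `array<...>`, and
    /// `struct<...>`. A sketch (column names and values are illustrative):
    ///
    /// ```ignore
    /// let df = spark.create_dataframe_from_rows(
    ///     vec![vec![json!(1), json!("Alice")], vec![json!(2), json!("Bob")]],
    ///     vec![
    ///         ("id".to_string(), "bigint".to_string()),
    ///         ("name".to_string(), "string".to_string()),
    ///     ],
    /// )?;
    /// ```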
    pub fn create_dataframe_from_rows(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, PolarsError> {
        if schema.is_empty() {
            return Err(PolarsError::InvalidOperation(
                "create_dataframe_from_rows: schema must not be empty".into(),
            ));
        }
        use chrono::{NaiveDate, NaiveDateTime};

        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());

        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
            let type_lower = type_str.trim().to_lowercase();
            let s = match type_lower.as_str() {
                "int" | "bigint" | "long" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_i64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "double" | "float" | "double_precision" => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "string" | "str" | "varchar" => {
                    let vals: Vec<Option<String>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => Some(s),
                                JsonValue::Null => None,
                                other => Some(other.to_string()),
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "boolean" | "bool" => {
                    let vals: Vec<Option<bool>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Bool(b) => Some(b),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "date" => {
                    let epoch = crate::date_utils::epoch_naive_date();
                    let vals: Vec<Option<i32>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                    .ok()
                                    .map(|d| (d - epoch).num_days() as i32),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Date)
                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
                }
                "timestamp" | "datetime" | "timestamp_ntz" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => {
                                    let parsed =
                                        NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S%.f")
                                            .or_else(|_| {
                                                NaiveDateTime::parse_from_str(
                                                    &s,
                                                    "%Y-%m-%dT%H:%M:%S",
                                                )
                                            })
                                            .or_else(|_| {
                                                NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                                    .map(|d| d.and_hms_opt(0, 0, 0).unwrap())
                                            });
                                    parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                                }
                                JsonValue::Number(n) => n.as_i64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
                        })?
                }
                _ if parse_array_element_type(&type_lower).is_some() => {
                    let elem_type = parse_array_element_type(&type_lower).unwrap();
                    let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
                        PolarsError::ComputeError(
                            format!(
                                "create_dataframe_from_rows: array element type '{elem_type}' not supported"
                            )
                            .into(),
                        )
                    })?;
                    let n = rows.len();
                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if let JsonValue::Null = v {
                            builder.append_null();
                        } else if let Some(arr) = v.as_array() {
                            let elem_series: Vec<Series> = arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        } else {
                            return Err(PolarsError::ComputeError(
                                "array column value must be null or array".into(),
                            ));
                        }
                    }
                    builder.finish().into_series()
                }
                _ if parse_struct_fields(&type_lower).is_some() => {
                    let values: Vec<Option<JsonValue>> =
                        rows.iter().map(|row| row.get(col_idx).cloned()).collect();
                    json_values_to_series(&values, &type_lower, name)?
                }
                _ => {
                    return Err(PolarsError::ComputeError(
                        format!(
                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
                        )
                        .into(),
                    ));
                }
            };
            cols.push(s);
        }

        let pl_df = PlDataFrame::new(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

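    /// Creates a single-column DataFrame named `id` with the integers from `start`
    /// (inclusive) to `end` (exclusive) using the given non-zero `step`,
    /// e.g. `range(0, 5, 1)` yields ids 0 through 4.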
    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
        if step == 0 {
            return Err(PolarsError::InvalidOperation(
                "range: step must not be 0".into(),
            ));
        }
        let mut vals: Vec<i64> = Vec::new();
        let mut v = start;
        if step > 0 {
            while v < end {
                vals.push(v);
                v = v.saturating_add(step);
            }
        } else {
            while v > end {
                vals.push(v);
                v = v.saturating_add(step);
            }
        }
        let col = Series::new("id".into(), vals);
        let pl_df = PlDataFrame::new(vec![col.into()])?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

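    /// Reads a CSV file into a DataFrame, assuming a header row and inferring the
    /// schema from the first 100 rows. Use `spark.read().option(...)` via
    /// [`DataFrameReader`] when per-call options such as separators or null values
    /// are needed.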
    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let path_display = path.display();
        let lf = LazyCsvReader::new(path)
            .with_has_header(true)
            .with_infer_schema_length(Some(100))
            .finish()
            .map_err(|e| {
                PolarsError::ComputeError(
                    format!(
                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
                    )
                    .into(),
                )
            })?;
        let pl_df = lf.collect().map_err(|e| {
            PolarsError::ComputeError(
                format!("read_csv({path_display}): collect failed: {e}").into(),
            )
        })?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

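    /// Reads a Parquet file (or a partitioned directory of Parquet files) into a
    /// DataFrame via a lazy scan that is collected eagerly.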
    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

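    /// Reads a newline-delimited JSON (JSON Lines) file into a DataFrame, inferring
    /// the schema from the first 100 lines.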
    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        let lf = LazyJsonLineReader::new(path)
            .with_infer_schema_length(NonZeroUsize::new(100))
            .finish()?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    #[cfg(feature = "sql")]
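    /// Executes a SQL query against the registered temp views and tables.
    /// Available only with the `sql` feature; without it, the fallback below
    /// returns an `InvalidOperation` error.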
    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
        crate::sql::execute_sql(self, query)
    }

    #[cfg(not(feature = "sql"))]
    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
        ))
    }

    fn looks_like_path(s: &str) -> bool {
        s.contains('/') || s.contains('\\') || Path::new(s).exists()
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta(path, self.is_case_sensitive())
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_path_with_version(
        &self,
        path: impl AsRef<Path>,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
    }

    #[cfg(feature = "delta")]
    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            self.read_delta_path(Path::new(name_or_path))
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_with_version(
        &self,
        name_or_path: &str,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            self.read_delta_path_with_version(Path::new(name_or_path), version)
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            Err(PolarsError::InvalidOperation(
                "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
            ))
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_with_version(
        &self,
        name_or_path: &str,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        let _ = version;
        self.read_delta(name_or_path)
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        self.read_delta_path(path)
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_from_path(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

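    /// Stops the session. Currently a no-op kept for PySpark API parity: there is
    /// no cluster or background resource to shut down.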
    pub fn stop(&self) {
    }
}

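/// Fluent reader returned by [`SparkSession::read`], mirroring PySpark's
/// `DataFrameReader`: set a `format` and `option`s, then call `load`, `csv`,
/// `parquet`, `json`, or `table`.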
pub struct DataFrameReader {
    session: SparkSession,
    options: HashMap<String, String>,
    format: Option<String>,
}

impl DataFrameReader {
    pub fn new(session: SparkSession) -> Self {
        DataFrameReader {
            session,
            options: HashMap::new(),
            format: None,
        }
    }

    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.options.insert(key.into(), value.into());
        self
    }

    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
        for (k, v) in opts {
            self.options.insert(k, v);
        }
        self
    }

    pub fn format(mut self, fmt: impl Into<String>) -> Self {
        self.format = Some(fmt.into());
        self
    }

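    /// Accepts a schema string for PySpark API compatibility; the schema is
    /// currently ignored and types are inferred by the underlying reader.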
    pub fn schema(self, _schema: impl Into<String>) -> Self {
        self
    }

    pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        let path = path.as_ref();
        let fmt = self.format.clone().or_else(|| {
            path.extension()
                .and_then(|e| e.to_str())
                .map(|s| s.to_lowercase())
        });
        match fmt.as_deref() {
            Some("parquet") => self.parquet(path),
            Some("csv") => self.csv(path),
            Some("json") | Some("jsonl") => self.json(path),
            #[cfg(feature = "delta")]
            Some("delta") => self.session.read_delta_from_path(path),
            _ => Err(PolarsError::ComputeError(
                format!(
                    "load: could not infer format for path '{}'. Use format('parquet'|'csv'|'json') before load.",
                    path.display()
                )
                .into(),
            )),
        }
    }

    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        self.session.table(name)
    }

    fn apply_csv_options(
        &self,
        reader: polars::prelude::LazyCsvReader,
    ) -> polars::prelude::LazyCsvReader {
        use polars::prelude::NullValues;
        let mut r = reader;
        if let Some(v) = self.options.get("header") {
            let has_header = v.eq_ignore_ascii_case("true") || v == "1";
            r = r.with_has_header(has_header);
        }
        if let Some(v) = self.options.get("inferSchema") {
            if v.eq_ignore_ascii_case("true") || v == "1" {
                let n = self
                    .options
                    .get("inferSchemaLength")
                    .and_then(|s| s.parse::<usize>().ok())
                    .unwrap_or(100);
                r = r.with_infer_schema_length(Some(n));
            }
        } else if let Some(v) = self.options.get("inferSchemaLength") {
            if let Ok(n) = v.parse::<usize>() {
                r = r.with_infer_schema_length(Some(n));
            }
        }
        if let Some(sep) = self.options.get("sep") {
            if let Some(b) = sep.bytes().next() {
                r = r.with_separator(b);
            }
        }
        if let Some(null_val) = self.options.get("nullValue") {
            r = r.with_null_values(Some(NullValues::AllColumnsSingle(null_val.clone().into())));
        }
        r
    }

    fn apply_json_options(
        &self,
        reader: polars::prelude::LazyJsonLineReader,
    ) -> polars::prelude::LazyJsonLineReader {
        use std::num::NonZeroUsize;
        let mut r = reader;
        if let Some(v) = self.options.get("inferSchemaLength") {
            if let Ok(n) = v.parse::<usize>() {
                r = r.with_infer_schema_length(NonZeroUsize::new(n));
            }
        }
        r
    }

    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let path_display = path.display();
        let reader = LazyCsvReader::new(path);
        let reader = if self.options.is_empty() {
            reader
                .with_has_header(true)
                .with_infer_schema_length(Some(100))
        } else {
            self.apply_csv_options(
                reader
                    .with_has_header(true)
                    .with_infer_schema_length(Some(100)),
            )
        };
        let lf = reader.finish().map_err(|e| {
            PolarsError::ComputeError(format!("read csv({path_display}): {e}").into())
        })?;
        let pl_df = lf.collect().map_err(|e| {
            PolarsError::ComputeError(
                format!("read csv({path_display}): collect failed: {e}").into(),
            )
        })?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let lf = LazyFrame::scan_parquet(path, ScanArgsParquet::default())?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        let reader = LazyJsonLineReader::new(path);
        let reader = if self.options.is_empty() {
            reader.with_infer_schema_length(NonZeroUsize::new(100))
        } else {
            self.apply_json_options(reader.with_infer_schema_length(NonZeroUsize::new(100)))
        };
        let lf = reader.finish()?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    #[cfg(feature = "delta")]
    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_delta_from_path(path)
    }
}

impl SparkSession {
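    /// Returns a [`DataFrameReader`] bound to a clone of this session, for
    /// option-driven reads (e.g. `spark.read().option("header", "true").csv(path)`).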
    pub fn read(&self) -> DataFrameReader {
        DataFrameReader::new(SparkSession {
            app_name: self.app_name.clone(),
            master: self.master.clone(),
            config: self.config.clone(),
            catalog: self.catalog.clone(),
            tables: self.tables.clone(),
            udf_registry: self.udf_registry.clone(),
            #[cfg(feature = "pyo3")]
            python_udf_batch_size: self.python_udf_batch_size,
            #[cfg(feature = "pyo3")]
            python_udf_max_concurrent_batches: self.python_udf_max_concurrent_batches,
        })
    }
}

impl Default for SparkSession {
    fn default() -> Self {
        Self::builder().get_or_create()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spark_session_builder_basic() {
        let spark = SparkSession::builder().app_name("test_app").get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_master() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .master("local[*]")
            .get_or_create();

        assert_eq!(spark.app_name, Some("test_app".to_string()));
        assert_eq!(spark.master, Some("local[*]".to_string()));
    }

    #[test]
    fn test_spark_session_builder_with_config() {
        let spark = SparkSession::builder()
            .app_name("test_app")
            .config("spark.executor.memory", "4g")
            .config("spark.driver.memory", "2g")
            .get_or_create();

        assert_eq!(
            spark.config.get("spark.executor.memory"),
            Some(&"4g".to_string())
        );
        assert_eq!(
            spark.config.get("spark.driver.memory"),
            Some(&"2g".to_string())
        );
    }

    #[test]
    fn test_spark_session_default() {
        let spark = SparkSession::default();
        assert!(spark.app_name.is_none());
        assert!(spark.master.is_none());
        assert!(spark.config.is_empty());
    }

    #[test]
    fn test_create_dataframe_success() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data = vec![
            (1i64, 25i64, "Alice".to_string()),
            (2i64, 30i64, "Bob".to_string()),
        ];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"age".to_string()));
        assert!(columns.contains(&"name".to_string()));
    }

    #[test]
    fn test_create_dataframe_wrong_column_count() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data = vec![(1i64, 25i64, "Alice".to_string())];

        let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
        assert!(result.is_err());

        let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_create_dataframe_from_rows_empty_schema_returns_error() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let rows: Vec<Vec<JsonValue>> = vec![vec![]];
        let schema: Vec<(String, String)> = vec![];
        let result = spark.create_dataframe_from_rows(rows, schema);
        match &result {
            Err(e) => assert!(e.to_string().contains("schema must not be empty")),
            Ok(_) => panic!("expected error for empty schema"),
        }
    }

    #[test]
    fn test_create_dataframe_empty() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data: Vec<(i64, i64, String)> = vec![];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_create_dataframe_from_polars() {
        use polars::prelude::df;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let polars_df = df!(
            "x" => &[1, 2, 3],
            "y" => &[4, 5, 6]
        )
        .unwrap();

        let df = spark.create_dataframe_from_polars(polars_df);

        assert_eq!(df.count().unwrap(), 3);
        let columns = df.columns().unwrap();
        assert!(columns.contains(&"x".to_string()));
        assert!(columns.contains(&"y".to_string()));
    }

    #[test]
    fn test_read_csv_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_csv("nonexistent_file.csv");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_parquet_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_parquet("nonexistent_file.parquet");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_json_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_json("nonexistent_file.json");

        assert!(result.is_err());
    }

    #[test]
    fn test_rust_udf_dataframe() {
        use crate::functions::{call_udf, col};
        use polars::prelude::DataType;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        spark
            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
            .unwrap();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let col = call_udf("to_str", &[col("id")]).unwrap();
        let df2 = df.with_column("id_str", &col).unwrap();
        let cols = df2.columns().unwrap();
        assert!(cols.contains(&"id_str".to_string()));
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
        assert_eq!(rows[1].get("id_str").and_then(|v| v.as_str()), Some("2"));
    }

    #[test]
    fn test_case_insensitive_filter_select() {
        use crate::expression::lit_i64;
        use crate::functions::col;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Charlie".to_string()),
                ],
                vec!["Id", "Age", "Name"],
            )
            .unwrap();
        let filtered = df
            .filter(col("age").gt(lit_i64(26)).expr().clone())
            .unwrap()
            .select(vec!["name"])
            .unwrap();
        assert_eq!(filtered.count().unwrap(), 2);
        let rows = filtered.collect_as_json_rows().unwrap();
        let names: Vec<&str> = rows
            .iter()
            .map(|r| r.get("name").and_then(|v| v.as_str()).unwrap())
            .collect();
        assert!(names.contains(&"Bob"));
        assert!(names.contains(&"Charlie"));
    }

    #[test]
    fn test_sql_returns_error_without_feature_or_unknown_table() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.sql("SELECT * FROM table");

        assert!(result.is_err());
        match result {
            Err(PolarsError::InvalidOperation(msg)) => {
                let s = msg.to_string();
                assert!(
                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
                    "unexpected message: {s}"
                );
            }
            _ => panic!("Expected InvalidOperation error"),
        }
    }

    #[test]
    fn test_spark_session_stop() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        spark.stop();
    }

    #[test]
    fn test_dataframe_reader_api() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let reader = spark.read();

        assert!(reader.csv("nonexistent.csv").is_err());
        assert!(reader.parquet("nonexistent.parquet").is_err());
        assert!(reader.json("nonexistent.json").is_err());
    }

    #[test]
    fn test_read_csv_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name,age").unwrap();
        writeln!(temp_file, "1,Alice,25").unwrap();
        writeln!(temp_file, "2,Bob,30").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"name".to_string()));
        assert!(columns.contains(&"age".to_string()));
    }

    #[test]
    fn test_read_json_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_json(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);
    }

    #[test]
    fn test_read_csv_empty_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
        writeln!(temp_file, "id,name").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_write_partitioned_parquet() {
        use crate::dataframe::{WriteFormat, WriteMode};
        use std::fs;
        use tempfile::TempDir;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 25, "Carol".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("out");
        df.write()
            .mode(WriteMode::Overwrite)
            .format(WriteFormat::Parquet)
            .partition_by(["age"])
            .save(&path)
            .unwrap();
        assert!(path.is_dir());
        let entries: Vec<_> = fs::read_dir(&path).unwrap().collect();
        assert_eq!(
            entries.len(),
            2,
            "expected two partition dirs (age=25, age=30)"
        );
        let names: Vec<String> = entries
            .iter()
            .filter_map(|e| e.as_ref().ok())
            .map(|e| e.file_name().to_string_lossy().into_owned())
            .collect();
        assert!(names.iter().any(|n| n.starts_with("age=")));
        let df_read = spark.read_parquet(&path).unwrap();
        assert_eq!(df_read.count().unwrap(), 3);
    }

    #[test]
    fn test_save_as_table_error_if_exists() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df.write()
            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.table("t1").is_ok());
        assert_eq!(spark.table("t1").unwrap().count().unwrap(), 1);
        let err = df
            .write()
            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
            .unwrap_err();
        assert!(err.to_string().contains("already exists"));
    }

    #[test]
    fn test_save_as_table_overwrite() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(
                vec![(2, 30, "Bob".to_string()), (3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_over", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 1);
        df2.write()
            .save_as_table(&spark, "t_over", SaveMode::Overwrite)
            .unwrap();
        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 2);
    }

    #[test]
    fn test_save_as_table_append() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_append", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_append", SaveMode::Append)
            .unwrap();
        assert_eq!(spark.table("t_append").unwrap().count().unwrap(), 2);
    }

    #[test]
    fn test_save_as_table_ignore() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_ignore", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_ignore", SaveMode::Ignore)
            .unwrap();
        assert_eq!(spark.table("t_ignore").unwrap().count().unwrap(), 1);
    }

    #[test]
    fn test_table_resolution_temp_view_first() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df_saved = spark
            .create_dataframe(
                vec![(1, 25, "Saved".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df_temp = spark
            .create_dataframe(vec![(2, 30, "Temp".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df_saved
            .write()
            .save_as_table(&spark, "x", SaveMode::ErrorIfExists)
            .unwrap();
        spark.create_or_replace_temp_view("x", df_temp);
        let t = spark.table("x").unwrap();
        let rows = t.collect_as_json_rows().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Temp"));
    }

    #[test]
    fn test_drop_table() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df.write()
            .save_as_table(&spark, "t_drop", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.table("t_drop").is_ok());
        assert!(spark.drop_table("t_drop"));
        assert!(spark.table("t_drop").is_err());
        assert!(!spark.drop_table("t_drop"));
    }

    #[test]
    fn test_global_temp_view_persists_across_sessions() {
        let spark1 = SparkSession::builder().app_name("s1").get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark1.create_or_replace_global_temp_view("people", df1);
        assert_eq!(
            spark1.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

        let spark2 = SparkSession::builder().app_name("s2").get_or_create();
        let df2 = spark2.table("global_temp.people").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

        let df_local = spark2
            .create_dataframe(
                vec![(3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark2.create_or_replace_temp_view("people", df_local);
        assert_eq!(spark2.table("people").unwrap().count().unwrap(), 1);
        assert_eq!(
            spark2.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

        assert!(spark2.drop_global_temp_view("people"));
        assert!(spark2.table("global_temp.people").is_err());
    }

    #[test]
    fn test_warehouse_persistence_between_sessions() {
        use crate::dataframe::SaveMode;
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().to_str().unwrap();

        let spark1 = SparkSession::builder()
            .app_name("w1")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark1, "users", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark1.table("users").unwrap().count().unwrap(), 2);

        let spark2 = SparkSession::builder()
            .app_name("w2")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df2 = spark2.table("users").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

        let table_path = dir.path().join("users");
        assert!(table_path.is_dir());
        let entries: Vec<_> = fs::read_dir(&table_path).unwrap().collect();
        assert!(!entries.is_empty());
    }
}