use crate::dataframe::DataFrame;
use crate::error::EngineError;
use crate::udf_registry::UdfRegistry;
use polars::chunked_array::StructChunked;
use polars::chunked_array::builder::get_list_builder;
use polars::prelude::{
    DataFrame as PlDataFrame, DataType, Field, IntoSeries, NamedFrom, PlSmallStr, PolarsError,
    Series, TimeUnit,
};
use serde_json::Value as JsonValue;
use std::cell::RefCell;

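/// Extracts the element type from an `array<T>` type string, e.g.
/// `"array<bigint>"` -> `Some("bigint")`. Returns `None` when the string is
/// not an `array<...>` form; the prefix match is case-insensitive.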
fn parse_array_element_type(type_str: &str) -> Option<String> {
    let s = type_str.trim();
    if !s.to_lowercase().starts_with("array<") || !s.ends_with('>') {
        return None;
    }
    Some(s[6..s.len() - 1].trim().to_string())
}

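/// Parses a `struct<name:type,...>` type string into `(name, type)` pairs,
/// e.g. `"struct<a:bigint,b:string>"` -> `[("a", "bigint"), ("b", "string")]`.
/// Fields are split on commas, so nested generic field types are not handled
/// by this parser.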
fn parse_struct_fields(type_str: &str) -> Option<Vec<(String, String)>> {
    let s = type_str.trim();
    if !s.to_lowercase().starts_with("struct<") || !s.ends_with('>') {
        return None;
    }
    let inner = s[7..s.len() - 1].trim();
    if inner.is_empty() {
        return Some(Vec::new());
    }
    let mut out = Vec::new();
    for part in inner.split(',') {
        let part = part.trim();
        if let Some(idx) = part.find(':') {
            let name = part[..idx].trim().to_string();
            let typ = part[idx + 1..].trim().to_string();
            out.push((name, typ));
        }
    }
    Some(out)
}

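/// Parses a `map<K,V>` type string into its key and value type strings, e.g.
/// `"map<string,bigint>"` -> `Some(("string", "bigint"))`. The input is
/// lowercased first and only the first comma is used as the separator.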
fn parse_map_key_value_types(type_str: &str) -> Option<(String, String)> {
    let s = type_str.trim().to_lowercase();
    if !s.starts_with("map<") || !s.ends_with('>') {
        return None;
    }
    let inner = s[4..s.len() - 1].trim();
    let comma = inner.find(',')?;
    let key_type = inner[..comma].trim().to_string();
    let value_type = inner[comma + 1..].trim().to_string();
    Some((key_type, value_type))
}

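/// Returns true for `decimal(p,s)`-style type strings. Decimals are mapped to
/// `Float64` elsewhere in this module rather than to a true decimal dtype.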
fn is_decimal_type_str(type_str: &str) -> bool {
    let s = type_str.trim().to_lowercase();
    s.starts_with("decimal(") && s.contains(')')
}

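/// Maps a simple type name (`int`, `double`, `string`, `boolean`, ...) to the
/// corresponding Polars `DataType`; returns `None` for anything unrecognized.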
fn json_type_str_to_polars(type_str: &str) -> Option<DataType> {
    let s = type_str.trim().to_lowercase();
    if is_decimal_type_str(&s) {
        return Some(DataType::Float64);
    }
    match s.as_str() {
        "int" | "integer" | "bigint" | "long" => Some(DataType::Int64),
        "double" | "float" | "double_precision" => Some(DataType::Float64),
        "string" | "str" | "varchar" => Some(DataType::String),
        "boolean" | "bool" => Some(DataType::Boolean),
        _ => None,
    }
}

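/// Coerces a JSON value into a JSON array when possible: arrays pass through,
/// objects with numeric keys (`{"0": ..., "1": ...}`) are read positionally,
/// and strings are parsed as embedded JSON. Returns `None` otherwise.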
fn json_value_to_array(v: &JsonValue) -> Option<Vec<JsonValue>> {
    match v {
        JsonValue::Null => None,
        JsonValue::Array(arr) => Some(arr.clone()),
        JsonValue::Object(obj) => {
            let mut indices: Vec<usize> =
                obj.keys().filter_map(|k| k.parse::<usize>().ok()).collect();
            indices.sort_unstable();
            if indices.is_empty() {
                return None;
            }
            let arr: Vec<JsonValue> = indices
                .iter()
                .filter_map(|i| obj.get(&i.to_string()).cloned())
                .collect();
            Some(arr)
        }
        JsonValue::String(s) => {
            if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                parsed.as_array().cloned()
            } else {
                None
            }
        }
        _ => None,
    }
}

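/// Infers the element type of a list column from the first non-null element of
/// the first parsable array found in `rows` at `col_idx`, returning both the
/// type-string name and the Polars dtype. Unrecognized element shapes fall
/// back to string.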
fn infer_list_element_type(rows: &[Vec<JsonValue>], col_idx: usize) -> Option<(String, DataType)> {
    for row in rows {
        let v = row.get(col_idx)?;
        let arr = json_value_to_array(v)?;
        let first = arr.first()?;
        return Some(match first {
            JsonValue::String(_) => ("string".to_string(), DataType::String),
            JsonValue::Number(n) => {
                if n.as_i64().is_some() {
                    ("bigint".to_string(), DataType::Int64)
                } else {
                    ("double".to_string(), DataType::Float64)
                }
            }
            JsonValue::Bool(_) => ("boolean".to_string(), DataType::Boolean),
            JsonValue::Null => continue,
            _ => ("string".to_string(), DataType::String),
        });
    }
    None
}

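/// Builds a `Series` of the given type from a column of optional JSON values.
/// Handles `array<...>` and `struct<...>` recursively, plus the scalar types
/// supported by this module; date and timestamp strings are parsed with
/// chrono. A bare scalar in an array column becomes a one-element list.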
fn json_values_to_series(
    values: &[Option<JsonValue>],
    type_str: &str,
    name: &str,
) -> Result<Series, PolarsError> {
    use chrono::{NaiveDate, NaiveDateTime};
    let epoch = crate::date_utils::epoch_naive_date();
    let type_lower = type_str.trim().to_lowercase();

    if let Some(elem_type) = parse_array_element_type(&type_lower) {
        let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
            PolarsError::ComputeError(
                format!("array element type '{elem_type}' not supported").into(),
            )
        })?;
        let mut builder = get_list_builder(&inner_dtype, 64, values.len(), name.into());
        for v in values.iter() {
            if v.as_ref().is_none_or(|x| matches!(x, JsonValue::Null)) {
                builder.append_null();
            } else if let Some(arr) = v.as_ref().and_then(json_value_to_array) {
                let elem_series: Vec<Series> = arr
                    .iter()
                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                    .collect::<Result<Vec<_>, _>>()?;
                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                let s = Series::from_any_values_and_dtype(
                    PlSmallStr::EMPTY,
                    &vals,
                    &inner_dtype,
                    false,
                )
                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
                builder.append_series(&s)?;
            } else {
                let single_arr = [v.clone().unwrap_or(JsonValue::Null)];
                let elem_series: Vec<Series> = single_arr
                    .iter()
                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                    .collect::<Result<Vec<_>, _>>()?;
                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                let arr_series = Series::from_any_values_and_dtype(
                    PlSmallStr::EMPTY,
                    &vals,
                    &inner_dtype,
                    false,
                )
                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
                builder.append_series(&arr_series)?;
            }
        }
        return Ok(builder.finish().into_series());
    }

    if let Some(fields) = parse_struct_fields(&type_lower) {
        let mut field_series_vec: Vec<Vec<Option<JsonValue>>> = (0..fields.len())
            .map(|_| Vec::with_capacity(values.len()))
            .collect();
        for v in values.iter() {
            let effective: Option<JsonValue> = match v.as_ref() {
                Some(JsonValue::String(s)) => {
                    if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                        if parsed.is_object() || parsed.is_array() {
                            Some(parsed)
                        } else {
                            v.clone()
                        }
                    } else {
                        v.clone()
                    }
                }
                _ => v.clone(),
            };
            if effective
                .as_ref()
                .is_none_or(|x| matches!(x, JsonValue::Null))
            {
                for fc in &mut field_series_vec {
                    fc.push(None);
                }
            } else if let Some(obj) = effective.as_ref().and_then(|x| x.as_object()) {
                for (fi, (fname, _)) in fields.iter().enumerate() {
                    field_series_vec[fi].push(obj.get(fname).cloned());
                }
            } else if let Some(arr) = effective.as_ref().and_then(|x| x.as_array()) {
                for (fi, _) in fields.iter().enumerate() {
                    field_series_vec[fi].push(arr.get(fi).cloned());
                }
            } else {
                return Err(PolarsError::ComputeError(
                    "struct value must be object (by field name) or array (by position). \
                     PySpark accepts dict or tuple/list for struct columns."
                        .into(),
                ));
            }
        }
        let series_per_field: Vec<Series> = fields
            .iter()
            .enumerate()
            .map(|(fi, (fname, ftype))| json_values_to_series(&field_series_vec[fi], ftype, fname))
            .collect::<Result<Vec<_>, _>>()?;
        let field_refs: Vec<&Series> = series_per_field.iter().collect();
        let st = StructChunked::from_series(name.into(), values.len(), field_refs.iter().copied())
            .map_err(|e| PolarsError::ComputeError(format!("struct column: {e}").into()))?
            .into_series();
        return Ok(st);
    }

    match type_lower.as_str() {
        "int" | "bigint" | "long" => {
            let vals: Vec<Option<i64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Number(n) => n.as_i64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "double" | "float" => {
            let vals: Vec<Option<f64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Number(n) => n.as_f64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "string" | "str" | "varchar" => {
            let vals: Vec<Option<&str>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => Some(s.as_str()),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let owned: Vec<Option<String>> =
                vals.into_iter().map(|o| o.map(|s| s.to_string())).collect();
            Ok(Series::new(name.into(), owned))
        }
        "boolean" | "bool" => {
            let vals: Vec<Option<bool>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Bool(b) => Some(*b),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "date" => {
            let vals: Vec<Option<i32>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => NaiveDate::parse_from_str(s, "%Y-%m-%d")
                            .ok()
                            .map(|d| (d - epoch).num_days() as i32),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let s = Series::new(name.into(), vals);
            s.cast(&DataType::Date)
                .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))
        }
        "timestamp" | "datetime" | "timestamp_ntz" => {
            let vals: Vec<Option<i64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => {
                            let parsed = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
                                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))
                                .or_else(|_| {
                                    NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").map_err(
                                        |e| PolarsError::ComputeError(e.to_string().into()),
                                    )
                                })
                                .or_else(|_| {
                                    NaiveDate::parse_from_str(s, "%Y-%m-%d")
                                        .map_err(|e| {
                                            PolarsError::ComputeError(e.to_string().into())
                                        })
                                        .and_then(|d| {
                                            d.and_hms_opt(0, 0, 0).ok_or_else(|| {
                                                PolarsError::ComputeError(
                                                    "date to datetime (0:0:0)".into(),
                                                )
                                            })
                                        })
                                });
                            parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                        }
                        JsonValue::Number(n) => n.as_i64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let s = Series::new(name.into(), vals);
            s.cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                .map_err(|e| PolarsError::ComputeError(format!("datetime cast: {e}").into()))
        }
        _ => Err(PolarsError::ComputeError(
            format!("json_values_to_series: unsupported type '{type_str}'").into(),
        )),
    }
}

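/// Converts a single JSON scalar into a one-row `Series` of the requested
/// type. Used for list elements and map values, which are built individually
/// and then appended.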
fn json_value_to_series_single(
    value: &JsonValue,
    type_str: &str,
    name: &str,
) -> Result<Series, PolarsError> {
    use chrono::NaiveDate;
    let epoch = crate::date_utils::epoch_naive_date();
    match (value, type_str.trim().to_lowercase().as_str()) {
        (JsonValue::Null, _) => Ok(Series::new_null(name.into(), 1)),
        (JsonValue::Number(n), "int" | "bigint" | "long") => {
            Ok(Series::new(name.into(), vec![n.as_i64()]))
        }
        (JsonValue::Number(n), "double" | "float") => {
            Ok(Series::new(name.into(), vec![n.as_f64()]))
        }
        (JsonValue::Number(n), t) if is_decimal_type_str(t) => {
            Ok(Series::new(name.into(), vec![n.as_f64()]))
        }
        (JsonValue::String(s), "string" | "str" | "varchar") => {
            Ok(Series::new(name.into(), vec![s.as_str()]))
        }
        (JsonValue::Bool(b), "boolean" | "bool") => Ok(Series::new(name.into(), vec![*b])),
        (JsonValue::String(s), "date") => {
            let d = NaiveDate::parse_from_str(s, "%Y-%m-%d")
                .map_err(|e| PolarsError::ComputeError(format!("date parse: {e}").into()))?;
            let days = (d - epoch).num_days() as i32;
            let s = Series::new(name.into(), vec![days]).cast(&DataType::Date)?;
            Ok(s)
        }
        _ => Err(PolarsError::ComputeError(
            format!("json_value_to_series: unsupported {type_str} for {value:?}").into(),
        )),
    }
}

#[allow(dead_code)]
fn json_object_or_array_to_struct_series(
    value: &JsonValue,
    fields: &[(String, String)],
    _name: &str,
) -> Result<Option<Series>, PolarsError> {
    use polars::prelude::StructChunked;
    if matches!(value, JsonValue::Null) {
        return Ok(None);
    }
    let effective = match value {
        JsonValue::String(s) => {
            if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                if parsed.is_object() || parsed.is_array() {
                    parsed
                } else {
                    value.clone()
                }
            } else {
                value.clone()
            }
        }
        _ => value.clone(),
    };
    let mut field_series: Vec<Series> = Vec::with_capacity(fields.len());
    for (fname, ftype) in fields {
        let fval = if let Some(obj) = effective.as_object() {
            obj.get(fname).unwrap_or(&JsonValue::Null)
        } else if let Some(arr) = effective.as_array() {
            let idx = field_series.len();
            arr.get(idx).unwrap_or(&JsonValue::Null)
        } else {
            return Err(PolarsError::ComputeError(
                "struct value must be object (by field name) or array (by position). \
                 PySpark accepts dict or tuple/list for struct columns."
                    .into(),
            ));
        };
        let s = json_value_to_series_single(fval, ftype, fname)?;
        field_series.push(s);
    }
    let field_refs: Vec<&Series> = field_series.iter().collect();
    let st = StructChunked::from_series(PlSmallStr::EMPTY, 1, field_refs.iter().copied())
        .map_err(|e| PolarsError::ComputeError(format!("struct from value: {e}").into()))?
        .into_series();
    Ok(Some(st))
}

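/// Builds the `{key, value}` struct series backing one map cell from a JSON
/// object. Keys are read as strings and cast to `key_dtype` when the declared
/// key type is not a string type.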
fn json_object_to_map_struct_series(
    obj: &serde_json::Map<String, JsonValue>,
    key_type: &str,
    value_type: &str,
    key_dtype: &DataType,
    value_dtype: &DataType,
    _name: &str,
) -> Result<Series, PolarsError> {
    if obj.is_empty() {
        let key_series = Series::new("key".into(), Vec::<String>::new());
        let value_series = Series::new_empty("value".into(), value_dtype);
        let st = StructChunked::from_series(
            PlSmallStr::EMPTY,
            0,
            [&key_series, &value_series].iter().copied(),
        )
        .map_err(|e| PolarsError::ComputeError(format!("map struct empty: {e}").into()))?
        .into_series();
        return Ok(st);
    }
    let keys: Vec<String> = obj.keys().cloned().collect();
    let mut value_series = None::<Series>;
    for v in obj.values() {
        let s = json_value_to_series_single(v, value_type, "value")?;
        value_series = Some(match value_series.take() {
            None => s,
            Some(mut acc) => {
                acc.extend(&s).map_err(|e| {
                    PolarsError::ComputeError(format!("map value extend: {e}").into())
                })?;
                acc
            }
        });
    }
    let value_series =
        value_series.unwrap_or_else(|| Series::new_empty("value".into(), value_dtype));
    let key_series = Series::new("key".into(), keys);
    let key_series = if matches!(
        key_type.trim().to_lowercase().as_str(),
        "string" | "str" | "varchar"
    ) {
        key_series
    } else {
        key_series
            .cast(key_dtype)
            .map_err(|e| PolarsError::ComputeError(format!("map key cast: {e}").into()))?
    };
    let st = StructChunked::from_series(
        PlSmallStr::EMPTY,
        key_series.len(),
        [&key_series, &value_series].iter().copied(),
    )
    .map_err(|e| PolarsError::ComputeError(format!("map struct: {e}").into()))?
    .into_series();
    Ok(st)
}

use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::{Arc, Mutex, OnceLock};
use std::thread_local;

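// Thread-local handle to the session most recently created on this thread;
// set by `SparkSessionBuilder::get_or_create` and read back via
// `get_thread_udf_session` (e.g. during UDF evaluation).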
thread_local! {
    static THREAD_UDF_SESSION: RefCell<Option<SparkSession>> = const { RefCell::new(None) };
}

pub(crate) fn set_thread_udf_session(session: SparkSession) {
    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = Some(session));
}

pub(crate) fn get_thread_udf_session() -> Option<SparkSession> {
    THREAD_UDF_SESSION.with(|cell| cell.borrow().clone())
}

pub(crate) fn clear_thread_udf_session() {
    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = None);
}

static GLOBAL_TEMP_CATALOG: OnceLock<Arc<Mutex<HashMap<String, DataFrame>>>> = OnceLock::new();

fn global_temp_catalog() -> Arc<Mutex<HashMap<String, DataFrame>>> {
    GLOBAL_TEMP_CATALOG
        .get_or_init(|| Arc::new(Mutex::new(HashMap::new())))
        .clone()
}

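/// Fluent builder for [`SparkSession`], mirroring PySpark's
/// `SparkSession.builder` API.
///
/// A minimal usage sketch (illustrative, not compiled as a doctest):
///
/// ```ignore
/// let spark = SparkSession::builder()
///     .app_name("example")
///     .config("spark.sql.caseSensitive", "true")
///     .get_or_create();
/// ```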
#[derive(Clone)]
pub struct SparkSessionBuilder {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
}

impl Default for SparkSessionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkSessionBuilder {
    pub fn new() -> Self {
        SparkSessionBuilder {
            app_name: None,
            master: None,
            config: HashMap::new(),
        }
    }

    pub fn app_name(mut self, name: impl Into<String>) -> Self {
        self.app_name = Some(name.into());
        self
    }

    pub fn master(mut self, master: impl Into<String>) -> Self {
        self.master = Some(master.into());
        self
    }

    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    pub fn get_or_create(self) -> SparkSession {
        let session = SparkSession::new(self.app_name, self.master, self.config);
        set_thread_udf_session(session.clone());
        session
    }

    pub fn with_config(mut self, config: &crate::config::SparklessConfig) -> Self {
        for (k, v) in config.to_session_config() {
            self.config.insert(k, v);
        }
        self
    }
}

pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

pub type TableCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

pub type DatabaseCatalog = Arc<Mutex<HashSet<String>>>;

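/// Entry point for this engine, mirroring PySpark's `SparkSession`. Holds the
/// session config plus the temp-view, saved-table, and database catalogs
/// (shared via `Arc`, so clones observe the same state) and the UDF registry.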
#[derive(Clone)]
pub struct SparkSession {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
    pub(crate) catalog: TempViewCatalog,
    pub(crate) tables: TableCatalog,
    pub(crate) databases: DatabaseCatalog,
    pub(crate) udf_registry: UdfRegistry,
}

impl SparkSession {
    pub fn new(
        app_name: Option<String>,
        master: Option<String>,
        config: HashMap<String, String>,
    ) -> Self {
        SparkSession {
            app_name,
            master,
            config,
            catalog: Arc::new(Mutex::new(HashMap::new())),
            tables: Arc::new(Mutex::new(HashMap::new())),
            databases: Arc::new(Mutex::new(HashSet::new())),
            udf_registry: UdfRegistry::new(),
        }
    }

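    /// Registers `df` as a session-scoped temporary view named `name`,
    /// replacing any existing view with the same name.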
    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
        let _ = self
            .catalog
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
        let _ = global_temp_catalog()
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
        let _ = global_temp_catalog()
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn drop_temp_view(&self, name: &str) {
        let _ = self.catalog.lock().map(|mut m| m.remove(name));
    }

    pub fn drop_global_temp_view(&self, name: &str) -> bool {
        global_temp_catalog()
            .lock()
            .map(|mut m| m.remove(name).is_some())
            .unwrap_or(false)
    }

    pub fn register_table(&self, name: &str, df: DataFrame) {
        let _ = self
            .tables
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn register_database(&self, name: &str) {
        let _ = self.databases.lock().map(|mut s| {
            s.insert(name.to_string());
        });
    }

    pub fn list_database_names(&self) -> Vec<String> {
        let mut names: Vec<String> = vec!["default".to_string(), "global_temp".to_string()];
        if let Ok(guard) = self.databases.lock() {
            let mut created: Vec<String> = guard.iter().cloned().collect();
            created.sort();
            names.extend(created);
        }
        names
    }

    pub fn database_exists(&self, name: &str) -> bool {
        if name.eq_ignore_ascii_case("default") || name.eq_ignore_ascii_case("global_temp") {
            return true;
        }
        self.databases
            .lock()
            .map(|s| s.iter().any(|n| n.eq_ignore_ascii_case(name)))
            .unwrap_or(false)
    }

    pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
        self.tables.lock().ok().and_then(|m| m.get(name).cloned())
    }

    pub fn saved_table_exists(&self, name: &str) -> bool {
        self.tables
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
    }

    pub fn table_exists(&self, name: &str) -> bool {
        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
            return global_temp_catalog()
                .lock()
                .map(|m| m.contains_key(tbl))
                .unwrap_or(false);
        }
        if self
            .catalog
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
        {
            return true;
        }
        if self
            .tables
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
        {
            return true;
        }
        if let Some(warehouse) = self.warehouse_dir() {
            let path = Path::new(warehouse).join(name);
            if path.is_dir() {
                return true;
            }
        }
        false
    }

    pub fn list_global_temp_view_names(&self) -> Vec<String> {
        global_temp_catalog()
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn list_temp_view_names(&self) -> Vec<String> {
        self.catalog
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn list_table_names(&self) -> Vec<String> {
        self.tables
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn drop_table(&self, name: &str) -> bool {
        self.tables
            .lock()
            .map(|mut m| m.remove(name).is_some())
            .unwrap_or(false)
    }

    pub fn drop_database(&self, name: &str) -> bool {
        if name.eq_ignore_ascii_case("default") || name.eq_ignore_ascii_case("global_temp") {
            return false;
        }
        self.databases
            .lock()
            .map(|mut s| s.remove(name))
            .unwrap_or(false)
    }

    fn parse_global_temp_name(name: &str) -> Option<(&str, &str)> {
        if let Some(dot) = name.find('.') {
            let (db, tbl) = name.split_at(dot);
            if db.eq_ignore_ascii_case("global_temp") {
                return Some((db, tbl.strip_prefix('.').unwrap_or(tbl)));
            }
        }
        None
    }

    pub fn warehouse_dir(&self) -> Option<&str> {
        self.config
            .get("spark.sql.warehouse.dir")
            .map(|s| s.as_str())
            .filter(|s| !s.is_empty())
    }

    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
            if let Some(df) = global_temp_catalog()
                .lock()
                .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
                .get(tbl)
                .cloned()
            {
                return Ok(df);
            }
            return Err(PolarsError::InvalidOperation(
                format!(
                    "Global temp view '{tbl}' not found. Register it with createOrReplaceGlobalTempView."
                )
                .into(),
            ));
        }
        if let Some(df) = self
            .catalog
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
        {
            return Ok(df);
        }
        if let Some(df) = self
            .tables
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
        {
            return Ok(df);
        }
        if let Some(warehouse) = self.warehouse_dir() {
            let dir = Path::new(warehouse).join(name);
            if dir.is_dir() {
                let data_file = dir.join("data.parquet");
                let read_path = if data_file.is_file() { data_file } else { dir };
                return self.read_parquet(&read_path);
            }
        }
        Err(PolarsError::InvalidOperation(
            format!(
                "Table or view '{name}' not found. Register it with create_or_replace_temp_view or saveAsTable."
            )
            .into(),
        ))
    }

    pub fn builder() -> SparkSessionBuilder {
        SparkSessionBuilder::new()
    }

    pub fn from_config(config: &crate::config::SparklessConfig) -> SparkSession {
        Self::builder().with_config(config).get_or_create()
    }

    pub fn get_config(&self) -> &HashMap<String, String> {
        &self.config
    }

    pub fn is_case_sensitive(&self) -> bool {
        self.config
            .get("spark.sql.caseSensitive")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false)
    }

    pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
    where
        F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
    {
        self.udf_registry.register_rust_udf(name, f)
    }

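    /// Creates a `DataFrame` from rows of `(i64, i64, String)` tuples.
    ///
    /// This is a deliberately narrow convenience constructor: exactly three
    /// column names must be supplied, one per tuple field. For arbitrary
    /// schemas, use [`SparkSession::create_dataframe_from_rows`].
    ///
    /// A minimal usage sketch (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// let spark = SparkSession::builder().app_name("demo").get_or_create();
    /// let df = spark.create_dataframe(
    ///     vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
    ///     vec!["id", "age", "name"],
    /// )?;
    /// assert_eq!(df.count()?, 2);
    /// ```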
    pub fn create_dataframe(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        if column_names.len() != 3 {
            return Err(PolarsError::ComputeError(
                format!(
                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
                    column_names.len()
                )
                .into(),
            ));
        }

        let mut cols: Vec<Series> = Vec::with_capacity(3);

        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
        cols.push(Series::new(column_names[0].into(), col0));

        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
        cols.push(Series::new(column_names[1].into(), col1));

        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
        cols.push(Series::new(column_names[2].into(), col2));

        let pl_df = PlDataFrame::new_infer_height(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    pub fn create_dataframe_engine(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, EngineError> {
        self.create_dataframe(data, column_names)
            .map_err(EngineError::from)
    }

    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
    }

    fn infer_dtype_from_json_value(v: &JsonValue) -> Option<String> {
        match v {
            JsonValue::Null => None,
            JsonValue::Bool(_) => Some("boolean".to_string()),
            JsonValue::Number(n) => {
                if n.is_i64() {
                    Some("bigint".to_string())
                } else {
                    Some("double".to_string())
                }
            }
            JsonValue::String(s) => {
                if chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d").is_ok() {
                    Some("date".to_string())
                } else if chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f").is_ok()
                    || chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").is_ok()
                {
                    Some("timestamp".to_string())
                } else {
                    Some("string".to_string())
                }
            }
            JsonValue::Array(_) => Some("array".to_string()),
            JsonValue::Object(_) => Some("string".to_string()),
        }
    }

    pub fn infer_schema_from_json_rows(
        rows: &[Vec<JsonValue>],
        names: &[String],
    ) -> Vec<(String, String)> {
        if names.is_empty() {
            return Vec::new();
        }
        let mut schema: Vec<(String, String)> = names
            .iter()
            .map(|n| (n.clone(), "string".to_string()))
            .collect();
        for (col_idx, (_, dtype_str)) in schema.iter_mut().enumerate() {
            for row in rows {
                let v = row.get(col_idx).unwrap_or(&JsonValue::Null);
                if let Some(dtype) = Self::infer_dtype_from_json_value(v) {
                    *dtype_str = dtype;
                    break;
                }
            }
        }
        schema
    }

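    /// Creates a `DataFrame` from JSON rows and a `(column name, type string)`
    /// schema. Supported type strings include the scalar types above plus
    /// `date`, `timestamp`, `array<T>`, `map<K,V>`, and `struct<...>`. An
    /// empty schema is inferred from the rows (columns named `c0`, `c1`, ...)
    /// when possible; otherwise an empty schema with non-empty rows is an
    /// error.
    ///
    /// A minimal usage sketch (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// use serde_json::json;
    /// let rows = vec![vec![json!("x"), json!([1, 2, 3])]];
    /// let schema = vec![
    ///     ("id".to_string(), "string".to_string()),
    ///     ("arr".to_string(), "array<bigint>".to_string()),
    /// ];
    /// let df = spark.create_dataframe_from_rows(rows, schema)?;
    /// ```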
    pub fn create_dataframe_from_rows(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, PolarsError> {
        let schema = if schema.is_empty() && !rows.is_empty() {
            let ncols = rows[0].len();
            let names: Vec<String> = (0..ncols).map(|i| format!("c{i}")).collect();
            Self::infer_schema_from_json_rows(&rows, &names)
        } else {
            schema
        };

        if schema.is_empty() {
            if rows.is_empty() {
                return Ok(DataFrame::from_polars_with_options(
                    PlDataFrame::new(0, vec![])?,
                    self.is_case_sensitive(),
                ));
            }
            return Err(PolarsError::InvalidOperation(
                "create_dataframe_from_rows: schema must not be empty when rows are not empty"
                    .into(),
            ));
        }
        use chrono::{NaiveDate, NaiveDateTime};

        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());

        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
            let type_lower = type_str.trim().to_lowercase();
            let s = match type_lower.as_str() {
                "int" | "integer" | "bigint" | "long" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_i64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "double" | "float" | "double_precision" => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                _ if is_decimal_type_str(&type_lower) => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "string" | "str" | "varchar" => {
                    let vals: Vec<Option<String>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => Some(s),
                                JsonValue::Null => None,
                                other => Some(other.to_string()),
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "boolean" | "bool" => {
                    let vals: Vec<Option<bool>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Bool(b) => Some(b),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "date" => {
                    let epoch = crate::date_utils::epoch_naive_date();
                    let vals: Vec<Option<i32>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                    .ok()
                                    .map(|d| (d - epoch).num_days() as i32),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Date)
                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
                }
                "timestamp" | "datetime" | "timestamp_ntz" => {
                    let vals: Vec<Option<i64>> =
                        rows.iter()
                            .map(|row| {
                                let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                                match v {
                                    JsonValue::String(s) => {
                                        let parsed = NaiveDateTime::parse_from_str(
                                            &s,
                                            "%Y-%m-%dT%H:%M:%S%.f",
                                        )
                                        .map_err(|e| {
                                            PolarsError::ComputeError(e.to_string().into())
                                        })
                                        .or_else(|_| {
                                            NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S")
                                                .map_err(|e| {
                                                    PolarsError::ComputeError(e.to_string().into())
                                                })
                                        })
                                        .or_else(|_| {
                                            NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                                .map_err(|e| {
                                                    PolarsError::ComputeError(e.to_string().into())
                                                })
                                                .and_then(|d| {
                                                    d.and_hms_opt(0, 0, 0).ok_or_else(|| {
                                                        PolarsError::ComputeError(
                                                            "date to datetime (0:0:0)".into(),
                                                        )
                                                    })
                                                })
                                        });
                                        parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                                    }
                                    JsonValue::Number(n) => n.as_i64(),
                                    JsonValue::Null => None,
                                    _ => None,
                                }
                            })
                            .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
                        })?
                }
                "list" | "array" => {
                    let (elem_type, inner_dtype) = infer_list_element_type(&rows, col_idx)
                        .unwrap_or(("bigint".to_string(), DataType::Int64));
                    let n = rows.len();
                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if let JsonValue::Null = &v {
                            builder.append_null();
                        } else if let Some(arr) = json_value_to_array(&v) {
                            let elem_series: Vec<Series> = arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        } else {
                            let single_arr = [v];
                            let elem_series: Vec<Series> = single_arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        }
                    }
                    builder.finish().into_series()
                }
                _ if parse_array_element_type(&type_lower).is_some() => {
                    let elem_type = parse_array_element_type(&type_lower).unwrap_or_else(|| {
                        unreachable!("guard above ensures parse_array_element_type returned Some")
                    });
                    let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
                        PolarsError::ComputeError(
                            format!(
                                "create_dataframe_from_rows: array element type '{elem_type}' not supported"
                            )
                            .into(),
                        )
                    })?;
                    let n = rows.len();
                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if let JsonValue::Null = &v {
                            builder.append_null();
                        } else if let Some(arr) = json_value_to_array(&v) {
                            let elem_series: Vec<Series> = arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        } else {
                            let single_arr = [v];
                            let elem_series: Vec<Series> = single_arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        }
                    }
                    builder.finish().into_series()
                }
                _ if parse_map_key_value_types(&type_lower).is_some() => {
                    let (key_type, value_type) = parse_map_key_value_types(&type_lower)
                        .unwrap_or_else(|| unreachable!("guard ensures Some"));
                    let key_dtype = json_type_str_to_polars(&key_type).ok_or_else(|| {
                        PolarsError::ComputeError(
                            format!(
                                "create_dataframe_from_rows: map key type '{key_type}' not supported"
                            )
                            .into(),
                        )
                    })?;
                    let value_dtype = json_type_str_to_polars(&value_type).ok_or_else(|| {
                        PolarsError::ComputeError(
                            format!(
                                "create_dataframe_from_rows: map value type '{value_type}' not supported"
                            )
                            .into(),
                        )
                    })?;
                    let struct_dtype = DataType::Struct(vec![
                        Field::new("key".into(), key_dtype.clone()),
                        Field::new("value".into(), value_dtype.clone()),
                    ]);
                    let n = rows.len();
                    let mut builder = get_list_builder(&struct_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if matches!(v, JsonValue::Null) {
                            builder.append_null();
                        } else if let Some(obj) = v.as_object() {
                            let st = json_object_to_map_struct_series(
                                obj,
                                &key_type,
                                &value_type,
                                &key_dtype,
                                &value_dtype,
                                name,
                            )?;
                            builder.append_series(&st)?;
                        } else {
                            return Err(PolarsError::ComputeError(
                                format!(
                                    "create_dataframe_from_rows: map column '{name}' expects JSON object (dict), got {v:?}"
                                )
                                .into(),
                            ));
                        }
                    }
                    builder.finish().into_series()
                }
                _ if parse_struct_fields(&type_lower).is_some() => {
                    let values: Vec<Option<JsonValue>> =
                        rows.iter().map(|row| row.get(col_idx).cloned()).collect();
                    json_values_to_series(&values, &type_lower, name)?
                }
                _ => {
                    return Err(PolarsError::ComputeError(
                        format!(
                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
                        )
                        .into(),
                    ));
                }
            };
            cols.push(s);
        }

        let pl_df = PlDataFrame::new_infer_height(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    pub fn create_dataframe_from_rows_engine(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, EngineError> {
        self.create_dataframe_from_rows(rows, schema)
            .map_err(EngineError::from)
    }

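    /// Returns a single-column `DataFrame` named `id` with the values
    /// `start..end` (end-exclusive) in increments of `step`, like PySpark's
    /// `spark.range`. `step` may be negative; a step of 0 is an error.
    ///
    /// A minimal usage sketch (illustrative, not compiled as a doctest):
    ///
    /// ```ignore
    /// let df = spark.range(0, 5, 1)?; // id: 0, 1, 2, 3, 4
    /// ```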
    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
        if step == 0 {
            return Err(PolarsError::InvalidOperation(
                "range: step must not be 0".into(),
            ));
        }
        let mut vals: Vec<i64> = Vec::new();
        let mut v = start;
        if step > 0 {
            while v < end {
                vals.push(v);
                v = v.saturating_add(step);
            }
        } else {
            while v > end {
                vals.push(v);
                v = v.saturating_add(step);
            }
        }
        let col = Series::new("id".into(), vals);
        let pl_df = PlDataFrame::new_infer_height(vec![col.into()])?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

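    /// Lazily reads a CSV file, assuming a header row and inferring the schema
    /// from the first 100 rows. For header, separator, or null-value options,
    /// use `session.read().option(...).csv(path)` instead.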
    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        if !path.exists() {
            return Err(PolarsError::ComputeError(
                format!("read_csv: file not found: {}", path.display()).into(),
            ));
        }
        let path_display = path.display();
        let pl_path = PlRefPath::try_from_path(path).map_err(|e| {
            PolarsError::ComputeError(format!("read_csv({path_display}): path: {e}").into())
        })?;
        let lf = LazyCsvReader::new(pl_path)
            .with_has_header(true)
            .with_infer_schema_length(Some(100))
            .finish()
            .map_err(|e| {
                PolarsError::ComputeError(
                    format!(
                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
                    )
                    .into(),
                )
            })?;
        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
            lf,
            self.is_case_sensitive(),
        ))
    }

    pub fn read_csv_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
        self.read_csv(path).map_err(EngineError::from)
    }

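    /// Lazily scans a Parquet file (or directory of files) into a `DataFrame`.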
    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        if !path.exists() {
            return Err(PolarsError::ComputeError(
                format!("read_parquet: file not found: {}", path.display()).into(),
            ));
        }
        let pl_path = PlRefPath::try_from_path(path)
            .map_err(|e| PolarsError::ComputeError(format!("read_parquet: path: {e}").into()))?;
        let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())?;
        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
            lf,
            self.is_case_sensitive(),
        ))
    }

    pub fn read_parquet_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
        self.read_parquet(path).map_err(EngineError::from)
    }

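    /// Lazily reads newline-delimited JSON (one object per line), inferring
    /// the schema from the first 100 lines.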
    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        if !path.exists() {
            return Err(PolarsError::ComputeError(
                format!("read_json: file not found: {}", path.display()).into(),
            ));
        }
        let pl_path = PlRefPath::try_from_path(path)
            .map_err(|e| PolarsError::ComputeError(format!("read_json: path: {e}").into()))?;
        let lf = LazyJsonLineReader::new(pl_path)
            .with_infer_schema_length(NonZeroUsize::new(100))
            .finish()?;
        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
            lf,
            self.is_case_sensitive(),
        ))
    }

    pub fn read_json_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
        self.read_json(path).map_err(EngineError::from)
    }

    #[cfg(feature = "sql")]
    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
        crate::sql::execute_sql(self, query)
    }

    #[cfg(not(feature = "sql"))]
    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
        ))
    }

    pub fn table_engine(&self, name: &str) -> Result<DataFrame, EngineError> {
        self.table(name).map_err(EngineError::from)
    }

    fn looks_like_path(s: &str) -> bool {
        s.contains('/') || s.contains('\\') || Path::new(s).exists()
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta(path, self.is_case_sensitive())
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_path_with_version(
        &self,
        path: impl AsRef<Path>,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
    }

    #[cfg(feature = "delta")]
    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            self.read_delta_path(Path::new(name_or_path))
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_with_version(
        &self,
        name_or_path: &str,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            self.read_delta_path_with_version(Path::new(name_or_path), version)
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            Err(PolarsError::InvalidOperation(
                "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
            ))
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_with_version(
        &self,
        name_or_path: &str,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        let _ = version;
        self.read_delta(name_or_path)
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        self.read_delta_path(path)
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_from_path(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

    pub fn stop(&self) {
        let _ = self.catalog.lock().map(|mut m| m.clear());
        let _ = self.tables.lock().map(|mut m| m.clear());
        let _ = self.databases.lock().map(|mut s| s.clear());
        let _ = self.udf_registry.clear();
        clear_thread_udf_session();
    }
}

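/// PySpark-style reader returned by [`SparkSession::read`]: accumulates
/// `format`/`option` settings, then loads via `load`, `csv`, `parquet`, or
/// `json`. Note that the `schema` setter is currently a no-op.
///
/// A minimal usage sketch (illustrative, not compiled as a doctest):
///
/// ```ignore
/// let df = spark
///     .read()
///     .option("header", "true")
///     .option("sep", ";")
///     .csv("data.csv")?;
/// ```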
pub struct DataFrameReader {
    session: SparkSession,
    options: HashMap<String, String>,
    format: Option<String>,
}

impl DataFrameReader {
    pub fn new(session: SparkSession) -> Self {
        DataFrameReader {
            session,
            options: HashMap::new(),
            format: None,
        }
    }

    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.options.insert(key.into(), value.into());
        self
    }

    pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
        for (k, v) in opts {
            self.options.insert(k, v);
        }
        self
    }

    pub fn format(mut self, fmt: impl Into<String>) -> Self {
        self.format = Some(fmt.into());
        self
    }

    pub fn schema(self, _schema: impl Into<String>) -> Self {
        self
    }

    pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        let path = path.as_ref();
        let fmt = self.format.clone().or_else(|| {
            path.extension()
                .and_then(|e| e.to_str())
                .map(|s| s.to_lowercase())
        });
        match fmt.as_deref() {
            Some("parquet") => self.parquet(path),
            Some("csv") => self.csv(path),
            Some("json") | Some("jsonl") => self.json(path),
            #[cfg(feature = "delta")]
            Some("delta") => self.session.read_delta_from_path(path),
            _ => Err(PolarsError::ComputeError(
                format!(
                    "load: could not infer format for path '{}'. Use format('parquet'|'csv'|'json') before load.",
                    path.display()
                )
                .into(),
            )),
        }
    }

    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        self.session.table(name)
    }

    fn apply_csv_options(
        &self,
        reader: polars::prelude::LazyCsvReader,
    ) -> polars::prelude::LazyCsvReader {
        use polars::prelude::NullValues;
        let mut r = reader;
        if let Some(v) = self.options.get("header") {
            let has_header = v.eq_ignore_ascii_case("true") || v == "1";
            r = r.with_has_header(has_header);
        }
        if let Some(v) = self.options.get("inferSchema") {
            if v.eq_ignore_ascii_case("true") || v == "1" {
                let n = self
                    .options
                    .get("inferSchemaLength")
                    .and_then(|s| s.parse::<usize>().ok())
                    .unwrap_or(100);
                r = r.with_infer_schema_length(Some(n));
            } else {
                r = r.with_infer_schema_length(Some(0));
            }
        } else if let Some(v) = self.options.get("inferSchemaLength") {
            if let Ok(n) = v.parse::<usize>() {
                r = r.with_infer_schema_length(Some(n));
            }
        }
        if let Some(sep) = self.options.get("sep") {
            if let Some(b) = sep.bytes().next() {
                r = r.with_separator(b);
            }
        }
        if let Some(null_val) = self.options.get("nullValue") {
            r = r.with_null_values(Some(NullValues::AllColumnsSingle(null_val.clone().into())));
        }
        r
    }

    fn apply_json_options(
        &self,
        reader: polars::prelude::LazyJsonLineReader,
    ) -> polars::prelude::LazyJsonLineReader {
        use std::num::NonZeroUsize;
        let mut r = reader;
        if let Some(v) = self.options.get("inferSchemaLength") {
            if let Ok(n) = v.parse::<usize>() {
                r = r.with_infer_schema_length(NonZeroUsize::new(n));
            }
        }
        r
    }

    pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let path_display = path.display();
        let pl_path = PlRefPath::try_from_path(path).map_err(|e| {
            PolarsError::ComputeError(format!("csv({path_display}): path: {e}").into())
        })?;
        let reader = LazyCsvReader::new(pl_path);
        let reader = if self.options.is_empty() {
            reader
                .with_has_header(true)
                .with_infer_schema_length(Some(100))
        } else {
            self.apply_csv_options(
                reader
                    .with_has_header(true)
                    .with_infer_schema_length(Some(100)),
            )
        };
        let lf = reader.finish().map_err(|e| {
            PolarsError::ComputeError(format!("read csv({path_display}): {e}").into())
        })?;
        let pl_df = lf.collect().map_err(|e| {
            PolarsError::ComputeError(
                format!("read csv({path_display}): collect failed: {e}").into(),
            )
        })?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        let pl_path = PlRefPath::try_from_path(path)
            .map_err(|e| PolarsError::ComputeError(format!("parquet: path: {e}").into()))?;
        let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        let pl_path = PlRefPath::try_from_path(path)
            .map_err(|e| PolarsError::ComputeError(format!("json: path: {e}").into()))?;
        let reader = LazyJsonLineReader::new(pl_path);
        let reader = if self.options.is_empty() {
            reader.with_infer_schema_length(NonZeroUsize::new(100))
        } else {
            self.apply_json_options(reader.with_infer_schema_length(NonZeroUsize::new(100)))
        };
        let lf = reader.finish()?;
        let pl_df = lf.collect()?;
        Ok(crate::dataframe::DataFrame::from_polars_with_options(
            pl_df,
            self.session.is_case_sensitive(),
        ))
    }

    #[cfg(feature = "delta")]
    pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
        self.session.read_delta_from_path(path)
    }
}

impl SparkSession {
    pub fn read(&self) -> DataFrameReader {
        DataFrameReader::new(self.clone())
    }
}

impl Default for SparkSession {
    fn default() -> Self {
        Self::builder().get_or_create()
    }
}

1896#[cfg(test)]
1897mod tests {
1898 use super::*;
1899
1900 #[test]
1901 fn test_spark_session_builder_basic() {
1902 let spark = SparkSession::builder().app_name("test_app").get_or_create();
1903
1904 assert_eq!(spark.app_name, Some("test_app".to_string()));
1905 }
1906
1907 #[test]
1908 fn test_spark_session_builder_with_master() {
1909 let spark = SparkSession::builder()
1910 .app_name("test_app")
1911 .master("local[*]")
1912 .get_or_create();
1913
1914 assert_eq!(spark.app_name, Some("test_app".to_string()));
1915 assert_eq!(spark.master, Some("local[*]".to_string()));
1916 }
1917
1918 #[test]
1919 fn test_spark_session_builder_with_config() {
1920 let spark = SparkSession::builder()
1921 .app_name("test_app")
1922 .config("spark.executor.memory", "4g")
1923 .config("spark.driver.memory", "2g")
1924 .get_or_create();
1925
1926 assert_eq!(
1927 spark.config.get("spark.executor.memory"),
1928 Some(&"4g".to_string())
1929 );
1930 assert_eq!(
1931 spark.config.get("spark.driver.memory"),
1932 Some(&"2g".to_string())
1933 );
1934 }
1935
1936 #[test]
1937 fn test_spark_session_default() {
1938 let spark = SparkSession::default();
1939 assert!(spark.app_name.is_none());
1940 assert!(spark.master.is_none());
1941 assert!(spark.config.is_empty());
1942 }
1943
1944 #[test]
1945 fn test_create_dataframe_success() {
1946 let spark = SparkSession::builder().app_name("test").get_or_create();
1947 let data = vec![
1948 (1i64, 25i64, "Alice".to_string()),
1949 (2i64, 30i64, "Bob".to_string()),
1950 ];
1951
1952 let result = spark.create_dataframe(data, vec!["id", "age", "name"]);
1953
1954 assert!(result.is_ok());
1955 let df = result.unwrap();
1956 assert_eq!(df.count().unwrap(), 2);
1957
1958 let columns = df.columns().unwrap();
1959 assert!(columns.contains(&"id".to_string()));
1960 assert!(columns.contains(&"age".to_string()));
1961 assert!(columns.contains(&"name".to_string()));
1962 }
1963
1964 #[test]
1965 fn test_create_dataframe_wrong_column_count() {
1966 let spark = SparkSession::builder().app_name("test").get_or_create();
1967 let data = vec![(1i64, 25i64, "Alice".to_string())];
1968
1969 let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
1971 assert!(result.is_err());
1972
1973 let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
1975 assert!(result.is_err());
1976 }
1977
1978 #[test]
1979 fn test_create_dataframe_from_rows_empty_schema_with_rows_returns_error() {
1980 let spark = SparkSession::builder().app_name("test").get_or_create();
1981 let rows: Vec<Vec<JsonValue>> = vec![vec![]];
1982 let schema: Vec<(String, String)> = vec![];
1983 let result = spark.create_dataframe_from_rows(rows, schema);
1984 match &result {
1985 Err(e) => assert!(e.to_string().contains("schema must not be empty")),
1986 Ok(_) => panic!("expected error for empty schema with non-empty rows"),
1987 }
1988 }
1989
1990 #[test]
1991 fn test_create_dataframe_from_rows_empty_data_with_schema() {
1992 let spark = SparkSession::builder().app_name("test").get_or_create();
1993 let rows: Vec<Vec<JsonValue>> = vec![];
1994 let schema = vec![
1995 ("a".to_string(), "int".to_string()),
1996 ("b".to_string(), "string".to_string()),
1997 ];
1998 let result = spark.create_dataframe_from_rows(rows, schema);
1999 let df = result.unwrap();
2000 assert_eq!(df.count().unwrap(), 0);
2001 assert_eq!(df.collect_inner().unwrap().get_column_names(), &["a", "b"]);
2002 }
2003
2004 #[test]
2005 fn test_create_dataframe_from_rows_empty_schema_empty_data() {
2006 let spark = SparkSession::builder().app_name("test").get_or_create();
2007 let rows: Vec<Vec<JsonValue>> = vec![];
2008 let schema: Vec<(String, String)> = vec![];
2009 let result = spark.create_dataframe_from_rows(rows, schema);
2010 let df = result.unwrap();
2011 assert_eq!(df.count().unwrap(), 0);
2012 assert_eq!(df.collect_inner().unwrap().get_column_names().len(), 0);
2013 }
2014
2015 #[test]
2017 fn test_create_dataframe_from_rows_struct_as_object() {
2018 use serde_json::json;
2019
2020 let spark = SparkSession::builder().app_name("test").get_or_create();
2021 let schema = vec![
2022 ("id".to_string(), "string".to_string()),
2023 (
2024 "nested".to_string(),
2025 "struct<a:bigint,b:string>".to_string(),
2026 ),
2027 ];
2028 let rows: Vec<Vec<JsonValue>> = vec![
2029 vec![json!("x"), json!({"a": 1, "b": "y"})],
2030 vec![json!("z"), json!({"a": 2, "b": "w"})],
2031 ];
2032 let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2033 assert_eq!(df.count().unwrap(), 2);
2034 let collected = df.collect_inner().unwrap();
2035 assert_eq!(collected.get_column_names(), &["id", "nested"]);
2036 }
2037
2038 #[test]
2040 fn test_create_dataframe_from_rows_struct_as_array() {
2041 use serde_json::json;
2042
2043 let spark = SparkSession::builder().app_name("test").get_or_create();
2044 let schema = vec![
2045 ("id".to_string(), "string".to_string()),
2046 (
2047 "nested".to_string(),
2048 "struct<a:bigint,b:string>".to_string(),
2049 ),
2050 ];
2051 let rows: Vec<Vec<JsonValue>> = vec![
2052 vec![json!("x"), json!([1, "y"])],
2053 vec![json!("z"), json!([2, "w"])],
2054 ];
2055 let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2056 assert_eq!(df.count().unwrap(), 2);
2057 let collected = df.collect_inner().unwrap();
2058 assert_eq!(collected.get_column_names(), &["id", "nested"]);
2059 }
2060
    #[test]
    fn test_issue_610_struct_value_as_string_object_or_array() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            (
                "nested".to_string(),
                "struct<a:bigint,b:string>".to_string(),
            ),
        ];
        let rows_object: Vec<Vec<JsonValue>> =
            vec![vec![json!("A"), json!(r#"{"a": 1, "b": "x"}"#)]];
        let df1 = spark
            .create_dataframe_from_rows(rows_object, schema.clone())
            .unwrap();
        assert_eq!(df1.count().unwrap(), 1);

        let rows_array: Vec<Vec<JsonValue>> = vec![vec![json!("B"), json!(r#"[1, "y"]"#)]];
        let df2 = spark
            .create_dataframe_from_rows(rows_array, schema)
            .unwrap();
        assert_eq!(df2.count().unwrap(), 1);
    }

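    // Issue #611: a bare scalar in an array column becomes a one-element list.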
    #[test]
    fn test_issue_611_array_column_single_value_as_one_element() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("arr".to_string(), "array<bigint>".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!(42)],
            vec![json!("y"), json!([1, 2, 3])],
        ];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
        let arr_col = collected.column("arr").unwrap();
        let list = arr_col.list().unwrap();
        let row0 = list.get(0).unwrap();
        assert_eq!(
            row0.len(),
            1,
            "#611: single value should become one-element list"
        );
        let row1 = list.get(1).unwrap();
        assert_eq!(row1.len(), 3);
    }

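    // Array columns accept one JSON array per row; a null row stays null (or empty).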
    #[test]
    fn test_create_dataframe_from_rows_array_column() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("arr".to_string(), "array<bigint>".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!([1, 2, 3])],
            vec![json!("y"), json!([4, 5])],
            vec![json!("z"), json!(null)],
        ];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        assert_eq!(df.count().unwrap(), 3);
        let collected = df.collect_inner().unwrap();
        assert_eq!(collected.get_column_names(), &["id", "arr"]);

        let arr_col = collected.column("arr").unwrap();
        let list = arr_col.list().unwrap();
        let row0 = list.get(0).unwrap();
        assert_eq!(row0.len(), 3, "row 0 arr should have 3 elements");
        let row1 = list.get(1).unwrap();
        assert_eq!(row1.len(), 2);
        let row2 = list.get(2);
        assert!(
            row2.is_none() || row2.as_ref().map(|a| a.is_empty()).unwrap_or(false),
            "row 2 arr should be null or empty"
        );
    }

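    // Issue #601: array<bigint> columns built from JSON arrays must match PySpark behavior.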
    #[test]
    fn test_issue_601_array_column_pyspark_parity() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("arr".to_string(), "array<bigint>".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!([1, 2, 3])],
            vec![json!("y"), json!([4, 5])],
        ];
        let df = spark
            .create_dataframe_from_rows(rows, schema)
            .expect("issue #601: create_dataframe_from_rows must accept array column (JSON array)");
        let n = df.count().unwrap();
        assert_eq!(n, 2, "issue #601: expected 2 rows");
        let collected = df.collect_inner().unwrap();
        let arr_col = collected.column("arr").unwrap();
        let list = arr_col.list().unwrap();
        let row0 = list.get(0).unwrap();
        assert_eq!(
            row0.len(),
            3,
            "issue #601: first row arr must have 3 elements [1,2,3]"
        );
        let row1 = list.get(1).unwrap();
        assert_eq!(
            row1.len(),
            2,
            "issue #601: second row arr must have 2 elements [4,5]"
        );
    }

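    // Issue #624: an empty schema with non-empty rows infers column names c0, c1, ...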
    #[test]
    fn test_issue_624_empty_schema_inferred_from_rows() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema: Vec<(String, String)> = vec![];
        let rows: Vec<Vec<JsonValue>> =
            vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
        let df = spark
            .create_dataframe_from_rows(rows, schema)
            .expect("#624: empty schema with non-empty rows should infer schema");
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
        assert_eq!(collected.get_column_names(), &["c0", "c1"]);
    }

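    // map<k,v> columns are materialized as lists of key/value entries.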
    #[test]
    fn test_create_dataframe_from_rows_map_column() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "integer".to_string()),
            ("m".to_string(), "map<string,string>".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!(1), json!({"a": "x", "b": "y"})],
            vec![json!(2), json!({"c": "z"})],
        ];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
        assert_eq!(collected.get_column_names(), &["id", "m"]);
        let m_col = collected.column("m").unwrap();
        let list = m_col.list().unwrap();
        let row0 = list.get(0).unwrap();
        assert_eq!(row0.len(), 2, "row 0 map should have 2 entries");
        let row1 = list.get(1).unwrap();
        assert_eq!(row1.len(), 1, "row 1 map should have 1 entry");
    }

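    // Issue #625: an array column accepts either a JSON array or an object with numeric string keys.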
    #[test]
    fn test_issue_625_array_column_list_or_object() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("arr".to_string(), "array<bigint>".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!([1, 2, 3])],
            vec![json!("y"), json!({"0": 4, "1": 5})],
        ];
        let df = spark
            .create_dataframe_from_rows(rows, schema)
            .expect("#625: array column must accept list/array or object representation");
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
        let list = collected.column("arr").unwrap().list().unwrap();
        assert_eq!(list.get(0).unwrap().len(), 3);
        assert_eq!(list.get(1).unwrap().len(), 2);
    }

    #[test]
    fn test_create_dataframe_empty() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data: Vec<(i64, i64, String)> = vec![];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_create_dataframe_from_polars() {
        use polars::prelude::df;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let polars_df = df!(
            "x" => &[1, 2, 3],
            "y" => &[4, 5, 6]
        )
        .unwrap();

        let df = spark.create_dataframe_from_polars(polars_df);

        assert_eq!(df.count().unwrap(), 3);
        let columns = df.columns().unwrap();
        assert!(columns.contains(&"x".to_string()));
        assert!(columns.contains(&"y".to_string()));
    }

    #[test]
    fn test_read_csv_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_csv("nonexistent_file.csv");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_parquet_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_parquet("nonexistent_file.parquet");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_json_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_json("nonexistent_file.json");

        assert!(result.is_err());
    }

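    // A registered Rust UDF is callable through call_udf and usable in with_column.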
    #[test]
    fn test_rust_udf_dataframe() {
        use crate::functions::{call_udf, col};
        use polars::prelude::DataType;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        spark
            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
            .unwrap();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let col = call_udf("to_str", &[col("id")]).unwrap();
        let df2 = df.with_column("id_str", &col).unwrap();
        let cols = df2.columns().unwrap();
        assert!(cols.contains(&"id_str".to_string()));
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
        assert_eq!(rows[1].get("id_str").and_then(|v| v.as_str()), Some("2"));
    }

    #[test]
    fn test_case_insensitive_filter_select() {
        use crate::expression::lit_i64;
        use crate::functions::col;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Charlie".to_string()),
                ],
                vec!["Id", "Age", "Name"],
            )
            .unwrap();
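        // "age" and "name" resolve case-insensitively against the "Age" and "Name" columns.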
        let filtered = df
            .filter(col("age").gt(lit_i64(26)).expr().clone())
            .unwrap()
            .select(vec!["name"])
            .unwrap();
        assert_eq!(filtered.count().unwrap(), 2);
        let rows = filtered.collect_as_json_rows().unwrap();
        let names: Vec<&str> = rows
            .iter()
            .map(|r| r.get("name").and_then(|v| v.as_str()).unwrap())
            .collect();
        assert!(names.contains(&"Bob"));
        assert!(names.contains(&"Charlie"));
    }

    #[test]
    fn test_sql_returns_error_without_feature_or_unknown_table() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.sql("SELECT * FROM table");

        assert!(result.is_err());
        match result {
            Err(PolarsError::InvalidOperation(msg)) => {
                let s = msg.to_string();
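                // Accept the message variants produced with or without the "sql" feature.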
                assert!(
                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
                    "unexpected message: {s}"
                );
            }
            _ => panic!("Expected InvalidOperation error"),
        }
    }

    #[test]
    fn test_spark_session_stop() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        spark.stop();
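        // Just verify that stop() can be called without panicking.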
    }

    #[test]
    fn test_dataframe_reader_api() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let reader = spark.read();

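        // Every reader format surfaces a missing-file error.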
        assert!(reader.csv("nonexistent.csv").is_err());
        assert!(reader.parquet("nonexistent.parquet").is_err());
        assert!(reader.json("nonexistent.json").is_err());
    }

    #[test]
    fn test_read_csv_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
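        // Write a small CSV: a header row plus two records.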
        writeln!(temp_file, "id,name,age").unwrap();
        writeln!(temp_file, "1,Alice,25").unwrap();
        writeln!(temp_file, "2,Bob,30").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);

        let columns = df.columns().unwrap();
        assert!(columns.contains(&"id".to_string()));
        assert!(columns.contains(&"name".to_string()));
        assert!(columns.contains(&"age".to_string()));
    }

    #[test]
    fn test_read_json_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
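        // Write newline-delimited JSON: one object per line.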
        writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
        writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_json(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 2);
    }

    #[test]
    fn test_read_csv_empty_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
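        // Header only: no data rows.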
        writeln!(temp_file, "id,name").unwrap();
        temp_file.flush().unwrap();

        let result = spark.read_csv(temp_file.path());

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

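    // partition_by writes Hive-style key=value directories that read back as one table.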
    #[test]
    fn test_write_partitioned_parquet() {
        use crate::dataframe::{WriteFormat, WriteMode};
        use std::fs;
        use tempfile::TempDir;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 25, "Carol".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("out");
        df.write()
            .mode(WriteMode::Overwrite)
            .format(WriteFormat::Parquet)
            .partition_by(["age"])
            .save(&path)
            .unwrap();
        assert!(path.is_dir());
        let entries: Vec<_> = fs::read_dir(&path).unwrap().collect();
        assert_eq!(
            entries.len(),
            2,
            "expected two partition dirs (age=25, age=30)"
        );
        let names: Vec<String> = entries
            .iter()
            .filter_map(|e| e.as_ref().ok())
            .map(|e| e.file_name().to_string_lossy().into_owned())
            .collect();
        assert!(names.iter().any(|n| n.starts_with("age=")));
        let df_read = spark.read_parquet(&path).unwrap();
        assert_eq!(df_read.count().unwrap(), 3);
    }

    #[test]
    fn test_save_as_table_error_if_exists() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
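        // The first save succeeds and the table becomes visible.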
        df.write()
            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.table("t1").is_ok());
        assert_eq!(spark.table("t1").unwrap().count().unwrap(), 1);
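        // A second ErrorIfExists save against the same name must fail.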
        let err = df
            .write()
            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
            .unwrap_err();
        assert!(err.to_string().contains("already exists"));
    }

    #[test]
    fn test_save_as_table_overwrite() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(
                vec![(2, 30, "Bob".to_string()), (3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_over", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 1);
        df2.write()
            .save_as_table(&spark, "t_over", SaveMode::Overwrite)
            .unwrap();
        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 2);
    }

    #[test]
    fn test_save_as_table_append() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_append", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_append", SaveMode::Append)
            .unwrap();
        assert_eq!(spark.table("t_append").unwrap().count().unwrap(), 2);
    }

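    // An empty DataFrame saved as a table keeps its schema, so a later append works.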
    #[test]
    fn test_save_as_table_empty_df_then_append() {
        use crate::dataframe::SaveMode;
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark
            .create_dataframe_from_rows(vec![], schema.clone())
            .unwrap();
        assert_eq!(empty_df.count().unwrap(), 0);

        empty_df
            .write()
            .save_as_table(&spark, "t_empty_append", SaveMode::Overwrite)
            .unwrap();
        let r1 = spark.table("t_empty_append").unwrap();
        assert_eq!(r1.count().unwrap(), 0);
        let cols = r1.columns().unwrap();
        assert!(cols.contains(&"id".to_string()));
        assert!(cols.contains(&"name".to_string()));

        let one_row = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        one_row
            .write()
            .save_as_table(&spark, "t_empty_append", SaveMode::Append)
            .unwrap();
        let r2 = spark.table("t_empty_append").unwrap();
        assert_eq!(r2.count().unwrap(), 1);
    }

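    // Writing an empty DataFrame to Parquet must still produce a file that round-trips the schema.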
    #[test]
    fn test_write_parquet_empty_df_with_schema() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark.create_dataframe_from_rows(vec![], schema).unwrap();
        assert_eq!(empty_df.count().unwrap(), 0);

        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("empty.parquet");
        empty_df
            .write()
            .format(crate::dataframe::WriteFormat::Parquet)
            .mode(crate::dataframe::WriteMode::Overwrite)
            .save(&path)
            .unwrap();
        assert!(path.is_file());

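        // Reading the file back yields zero rows with the original columns.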
        let read_df = spark.read().parquet(path.to_str().unwrap()).unwrap();
        assert_eq!(read_df.count().unwrap(), 0);
        let cols = read_df.columns().unwrap();
        assert!(cols.contains(&"id".to_string()));
        assert!(cols.contains(&"name".to_string()));
    }

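    // Same as above, but with tables persisted in an on-disk warehouse directory.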
    #[test]
    fn test_save_as_table_empty_df_warehouse_then_append() {
        use crate::dataframe::SaveMode;
        use serde_json::json;
        use std::sync::atomic::{AtomicU64, Ordering};
        use tempfile::TempDir;

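        // Give each run its own warehouse subdirectory.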
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let n = COUNTER.fetch_add(1, Ordering::SeqCst);
        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().join(format!("wh_{n}"));
        std::fs::create_dir_all(&warehouse).unwrap();
        let spark = SparkSession::builder()
            .app_name("test")
            .config(
                "spark.sql.warehouse.dir",
                warehouse.as_os_str().to_str().unwrap(),
            )
            .get_or_create();

        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark
            .create_dataframe_from_rows(vec![], schema.clone())
            .unwrap();
        empty_df
            .write()
            .save_as_table(&spark, "t_empty_wh", SaveMode::Overwrite)
            .unwrap();
        let r1 = spark.table("t_empty_wh").unwrap();
        assert_eq!(r1.count().unwrap(), 0);

        let one_row = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        one_row
            .write()
            .save_as_table(&spark, "t_empty_wh", SaveMode::Append)
            .unwrap();
        let r2 = spark.table("t_empty_wh").unwrap();
        assert_eq!(r2.count().unwrap(), 1);
    }

    #[test]
    fn test_save_as_table_ignore() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_ignore", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_ignore", SaveMode::Ignore)
            .unwrap();
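        // Ignore leaves the existing table untouched, so df1's single row survives.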
        assert_eq!(spark.table("t_ignore").unwrap().count().unwrap(), 1);
    }

    #[test]
    fn test_table_resolution_temp_view_first() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df_saved = spark
            .create_dataframe(
                vec![(1, 25, "Saved".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df_temp = spark
            .create_dataframe(vec![(2, 30, "Temp".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df_saved
            .write()
            .save_as_table(&spark, "x", SaveMode::ErrorIfExists)
            .unwrap();
        spark.create_or_replace_temp_view("x", df_temp);
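        // The temp view shadows the saved table of the same name.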
        let t = spark.table("x").unwrap();
        let rows = t.collect_as_json_rows().unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Temp"));
    }

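    // Issue #629: a temp view must be resolvable immediately after creation.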
    #[test]
    fn test_issue_629_temp_view_visible_after_create() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("repro").get_or_create();
        let schema = vec![
            ("id".to_string(), "long".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> =
            vec![vec![json!(1), json!("a")], vec![json!(2), json!("b")]];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        spark.create_or_replace_temp_view("my_view", df);
        let result = spark
            .table("my_view")
            .unwrap()
            .collect_as_json_rows()
            .unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].get("id").and_then(|v| v.as_i64()), Some(1));
        assert_eq!(result[0].get("name").and_then(|v| v.as_str()), Some("a"));
        assert_eq!(result[1].get("id").and_then(|v| v.as_i64()), Some(2));
        assert_eq!(result[1].get("name").and_then(|v| v.as_str()), Some("b"));
    }

    #[test]
    fn test_drop_table() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df.write()
            .save_as_table(&spark, "t_drop", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.table("t_drop").is_ok());
        assert!(spark.drop_table("t_drop"));
        assert!(spark.table("t_drop").is_err());
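        // Dropping a table that no longer exists returns false.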
        assert!(!spark.drop_table("t_drop"));
    }

    #[test]
    fn test_global_temp_view_persists_across_sessions() {
        let spark1 = SparkSession::builder().app_name("s1").get_or_create();
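        // Session 1 publishes a global temp view.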
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark1.create_or_replace_global_temp_view("people", df1);
        assert_eq!(
            spark1.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

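        // A second session resolves the same global view.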
        let spark2 = SparkSession::builder().app_name("s2").get_or_create();
        let df2 = spark2.table("global_temp.people").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

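        // A session-local temp view with the same name must not shadow the global one.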
        let df_local = spark2
            .create_dataframe(
                vec![(3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark2.create_or_replace_temp_view("people", df_local);
        assert_eq!(spark2.table("people").unwrap().count().unwrap(), 1);
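        // The qualified global_temp name still resolves to the global view.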
        assert_eq!(
            spark2.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

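        // Dropping the global view removes it for every session.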
        assert!(spark2.drop_global_temp_view("people"));
        assert!(spark2.table("global_temp.people").is_err());
    }

    #[test]
    fn test_warehouse_persistence_between_sessions() {
        use crate::dataframe::SaveMode;
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().to_str().unwrap();

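        // Session 1 writes a table into the warehouse directory.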
        let spark1 = SparkSession::builder()
            .app_name("w1")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark1, "users", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark1.table("users").unwrap().count().unwrap(), 2);

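        // A fresh session pointed at the same warehouse sees the persisted table.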
        let spark2 = SparkSession::builder()
            .app_name("w2")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df2 = spark2.table("users").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

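        // The table data lives on disk under <warehouse>/users.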
        let table_path = dir.path().join("users");
        assert!(table_path.is_dir());
        let entries: Vec<_> = fs::read_dir(&table_path).unwrap().collect();
        assert!(!entries.is_empty());
    }
}