use crate::dataframe::DataFrame;
use crate::error::EngineError;
use polars::chunked_array::StructChunked;
use polars::chunked_array::builder::get_list_builder;
use polars::prelude::{
    DataFrame as PlDataFrame, DataType, Field, IntoSeries, NamedFrom, PlSmallStr, PolarsError,
    Series, TimeUnit,
};
use robin_sparkless_expr::UdfRegistry;
use serde_json::Value as JsonValue;
use std::cell::RefCell;
use std::sync::Arc;

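/// Extracts the element type from an `array<...>` type string, e.g.
/// `"array<bigint>"` -> `Some("bigint")`; anything else returns `None`.
///
/// A minimal sketch of the expected behavior (not a compiled doctest, since
/// this helper is private):
///
/// ```ignore
/// assert_eq!(parse_array_element_type("array<bigint>").as_deref(), Some("bigint"));
/// assert_eq!(parse_array_element_type("bigint"), None);
/// ```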
fn parse_array_element_type(type_str: &str) -> Option<String> {
    let s = type_str.trim();
    if !s.to_lowercase().starts_with("array<") || !s.ends_with('>') {
        return None;
    }
    Some(s[6..s.len() - 1].trim().to_string())
}

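/// Parses a `struct<name:type,...>` type string into `(name, type)` pairs.
/// Note that the split on `,` is flat, so nested generic types such as
/// `struct<a:array<int>,b:string>` are not handled by this helper.
///
/// A minimal sketch of the expected behavior:
///
/// ```ignore
/// let fields = parse_struct_fields("struct<a:bigint,b:string>").unwrap();
/// assert_eq!(
///     fields,
///     vec![
///         ("a".to_string(), "bigint".to_string()),
///         ("b".to_string(), "string".to_string()),
///     ]
/// );
/// ```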
fn parse_struct_fields(type_str: &str) -> Option<Vec<(String, String)>> {
    let s = type_str.trim();
    if !s.to_lowercase().starts_with("struct<") || !s.ends_with('>') {
        return None;
    }
    let inner = s[7..s.len() - 1].trim();
    if inner.is_empty() {
        return Some(Vec::new());
    }
    let mut out = Vec::new();
    for part in inner.split(',') {
        let part = part.trim();
        if let Some(idx) = part.find(':') {
            let name = part[..idx].trim().to_string();
            let typ = part[idx + 1..].trim().to_string();
            out.push((name, typ));
        }
    }
    Some(out)
}

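/// Parses a `map<K,V>` type string into `(key_type, value_type)`. The input
/// is lowercased first, and only the first `,` is honored, so value types
/// that themselves contain commas are not supported here.
///
/// A minimal sketch of the expected behavior:
///
/// ```ignore
/// let (k, v) = parse_map_key_value_types("map<string,bigint>").unwrap();
/// assert_eq!((k.as_str(), v.as_str()), ("string", "bigint"));
/// ```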
fn parse_map_key_value_types(type_str: &str) -> Option<(String, String)> {
    let s = type_str.trim().to_lowercase();
    if !s.starts_with("map<") || !s.ends_with('>') {
        return None;
    }
    let inner = s[4..s.len() - 1].trim();
    let comma = inner.find(',')?;
    let key_type = inner[..comma].trim().to_string();
    let value_type = inner[comma + 1..].trim().to_string();
    Some((key_type, value_type))
}

fn is_decimal_type_str(type_str: &str) -> bool {
    let s = type_str.trim().to_lowercase();
    s.starts_with("decimal(") && s.contains(')')
}

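/// Maps a Spark-style type string to a Polars `DataType`. Decimals are
/// widened to `Float64` here; unknown types return `None` so callers can
/// surface their own error.
///
/// A minimal sketch of the mapping:
///
/// ```ignore
/// assert_eq!(json_type_str_to_polars("bigint"), Some(DataType::Int64));
/// assert_eq!(json_type_str_to_polars("decimal(10,2)"), Some(DataType::Float64));
/// assert_eq!(json_type_str_to_polars("binary"), None);
/// ```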
fn json_type_str_to_polars(type_str: &str) -> Option<DataType> {
    let s = type_str.trim().to_lowercase();
    if is_decimal_type_str(&s) {
        return Some(DataType::Float64);
    }
    match s.as_str() {
        "int" | "integer" | "bigint" | "long" => Some(DataType::Int64),
        "double" | "float" | "double_precision" => Some(DataType::Float64),
        "string" | "str" | "varchar" => Some(DataType::String),
        "boolean" | "bool" => Some(DataType::Boolean),
        _ => None,
    }
}

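/// Coerces a JSON value into a vector of elements. Accepts a real JSON
/// array, an object with integer-like keys (treated as index -> element,
/// sorted by index), or a string containing serialized JSON that parses to
/// an array.
///
/// A minimal sketch of the object coercion:
///
/// ```ignore
/// let v: JsonValue = serde_json::json!({"0": "a", "1": "b"});
/// let arr = json_value_to_array(&v).unwrap();
/// assert_eq!(arr.len(), 2);
/// ```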
fn json_value_to_array(v: &JsonValue) -> Option<Vec<JsonValue>> {
    match v {
        JsonValue::Null => None,
        JsonValue::Array(arr) => Some(arr.clone()),
        JsonValue::Object(obj) => {
            let mut indices: Vec<usize> =
                obj.keys().filter_map(|k| k.parse::<usize>().ok()).collect();
            indices.sort_unstable();
            if indices.is_empty() {
                return None;
            }
            let arr: Vec<JsonValue> = indices
                .iter()
                .filter_map(|i| obj.get(&i.to_string()).cloned())
                .collect();
            Some(arr)
        }
        JsonValue::String(s) => {
            if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                parsed.as_array().cloned()
            } else {
                None
            }
        }
        _ => None,
    }
}

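/// Infers the element type of a list column by scanning rows until one
/// yields a non-null first element, returning both the Spark-style type
/// string and the matching Polars dtype.
///
/// A minimal sketch (hypothetical rows value):
///
/// ```ignore
/// let rows = vec![vec![serde_json::json!([1, 2])]];
/// assert_eq!(
///     infer_list_element_type(&rows, 0),
///     Some(("bigint".to_string(), DataType::Int64))
/// );
/// ```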
fn infer_list_element_type(rows: &[Vec<JsonValue>], col_idx: usize) -> Option<(String, DataType)> {
    for row in rows {
        let v = row.get(col_idx)?;
        let arr = json_value_to_array(v)?;
        let first = arr.first()?;
        return Some(match first {
            JsonValue::String(_) => ("string".to_string(), DataType::String),
            JsonValue::Number(n) => {
                if n.as_i64().is_some() {
                    ("bigint".to_string(), DataType::Int64)
                } else {
                    ("double".to_string(), DataType::Float64)
                }
            }
            JsonValue::Bool(_) => ("boolean".to_string(), DataType::Boolean),
            JsonValue::Null => continue,
            _ => ("string".to_string(), DataType::String),
        });
    }
    None
}

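/// Builds a Polars `Series` from one column of JSON values, dispatching on
/// the Spark-style type string: `array<...>` and `struct<...>` are handled
/// structurally first, then scalar types (ints, floats, strings, booleans,
/// dates, timestamps) fall through to the `match` below.
///
/// A minimal sketch of a scalar call (hypothetical values):
///
/// ```ignore
/// let vals = vec![Some(serde_json::json!(1)), None];
/// let s = json_values_to_series(&vals, "bigint", "x")?;
/// assert_eq!(s.len(), 2);
/// ```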
fn json_values_to_series(
    values: &[Option<JsonValue>],
    type_str: &str,
    name: &str,
) -> Result<Series, PolarsError> {
    use chrono::{NaiveDate, NaiveDateTime};
    let epoch = crate::date_utils::epoch_naive_date();
    let type_lower = type_str.trim().to_lowercase();

    if let Some(elem_type) = parse_array_element_type(&type_lower) {
        let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
            PolarsError::ComputeError(
                format!("array element type '{elem_type}' not supported").into(),
            )
        })?;
        let mut builder = get_list_builder(&inner_dtype, 64, values.len(), name.into());
        for v in values.iter() {
            if v.as_ref().is_none_or(|x| matches!(x, JsonValue::Null)) {
                builder.append_null();
            } else if let Some(arr) = v.as_ref().and_then(json_value_to_array) {
                let elem_series: Vec<Series> = arr
                    .iter()
                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                    .collect::<Result<Vec<_>, _>>()?;
                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                let s = Series::from_any_values_and_dtype(
                    PlSmallStr::EMPTY,
                    &vals,
                    &inner_dtype,
                    false,
                )
                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
                builder.append_series(&s)?;
            } else {
                let single_arr = [v.clone().unwrap_or(JsonValue::Null)];
                let elem_series: Vec<Series> = single_arr
                    .iter()
                    .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                    .collect::<Result<Vec<_>, _>>()?;
                let vals: Vec<_> = elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                let arr_series = Series::from_any_values_and_dtype(
                    PlSmallStr::EMPTY,
                    &vals,
                    &inner_dtype,
                    false,
                )
                .map_err(|e| PolarsError::ComputeError(format!("array elem: {e}").into()))?;
                builder.append_series(&arr_series)?;
            }
        }
        return Ok(builder.finish().into_series());
    }

    if let Some(fields) = parse_struct_fields(&type_lower) {
        let mut field_series_vec: Vec<Vec<Option<JsonValue>>> = (0..fields.len())
            .map(|_| Vec::with_capacity(values.len()))
            .collect();
        for v in values.iter() {
            let effective: Option<JsonValue> = match v.as_ref() {
                Some(JsonValue::String(s)) => {
                    if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                        if parsed.is_object() || parsed.is_array() {
                            Some(parsed)
                        } else {
                            v.clone()
                        }
                    } else {
                        v.clone()
                    }
                }
                _ => v.clone(),
            };
            if effective
                .as_ref()
                .is_none_or(|x| matches!(x, JsonValue::Null))
            {
                for fc in &mut field_series_vec {
                    fc.push(None);
                }
            } else if let Some(obj) = effective.as_ref().and_then(|x| x.as_object()) {
                for (fi, (fname, _)) in fields.iter().enumerate() {
                    field_series_vec[fi].push(obj.get(fname).cloned());
                }
            } else if let Some(arr) = effective.as_ref().and_then(|x| x.as_array()) {
                for (fi, _) in fields.iter().enumerate() {
                    field_series_vec[fi].push(arr.get(fi).cloned());
                }
            } else {
                return Err(PolarsError::ComputeError(
                    "struct value must be object (by field name) or array (by position). \
                     PySpark accepts dict or tuple/list for struct columns."
                        .into(),
                ));
            }
        }
        let series_per_field: Vec<Series> = fields
            .iter()
            .enumerate()
            .map(|(fi, (fname, ftype))| json_values_to_series(&field_series_vec[fi], ftype, fname))
            .collect::<Result<Vec<_>, _>>()?;
        let field_refs: Vec<&Series> = series_per_field.iter().collect();
        let st = StructChunked::from_series(name.into(), values.len(), field_refs.iter().copied())
            .map_err(|e| PolarsError::ComputeError(format!("struct column: {e}").into()))?
            .into_series();
        return Ok(st);
    }

    match type_lower.as_str() {
        "int" | "bigint" | "long" => {
            let vals: Vec<Option<i64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Number(n) => n.as_i64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "double" | "float" => {
            let vals: Vec<Option<f64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Number(n) => n.as_f64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "string" | "str" | "varchar" => {
            let vals: Vec<Option<&str>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => Some(s.as_str()),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let owned: Vec<Option<String>> =
                vals.into_iter().map(|o| o.map(|s| s.to_string())).collect();
            Ok(Series::new(name.into(), owned))
        }
        "boolean" | "bool" => {
            let vals: Vec<Option<bool>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::Bool(b) => Some(*b),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            Ok(Series::new(name.into(), vals))
        }
        "date" => {
            let vals: Vec<Option<i32>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => NaiveDate::parse_from_str(s, "%Y-%m-%d")
                            .ok()
                            .map(|d| (d - epoch).num_days() as i32),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let s = Series::new(name.into(), vals);
            s.cast(&DataType::Date)
                .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))
        }
        "timestamp" | "datetime" | "timestamp_ntz" => {
            let vals: Vec<Option<i64>> = values
                .iter()
                .map(|ov| {
                    ov.as_ref().and_then(|v| match v {
                        JsonValue::String(s) => {
                            let parsed = NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f")
                                .map_err(|e| PolarsError::ComputeError(e.to_string().into()))
                                .or_else(|_| {
                                    NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").map_err(
                                        |e| PolarsError::ComputeError(e.to_string().into()),
                                    )
                                })
                                .or_else(|_| {
                                    NaiveDate::parse_from_str(s, "%Y-%m-%d")
                                        .map_err(|e| {
                                            PolarsError::ComputeError(e.to_string().into())
                                        })
                                        .and_then(|d| {
                                            d.and_hms_opt(0, 0, 0).ok_or_else(|| {
                                                PolarsError::ComputeError(
                                                    "date to datetime (0:0:0)".into(),
                                                )
                                            })
                                        })
                                });
                            parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                        }
                        JsonValue::Number(n) => n.as_i64(),
                        JsonValue::Null => None,
                        _ => None,
                    })
                })
                .collect();
            let s = Series::new(name.into(), vals);
            s.cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                .map_err(|e| PolarsError::ComputeError(format!("datetime cast: {e}").into()))
        }
        _ => Err(PolarsError::ComputeError(
            format!("json_values_to_series: unsupported type '{type_str}'").into(),
        )),
    }
}

fn json_value_to_series_single(
    value: &JsonValue,
    type_str: &str,
    name: &str,
) -> Result<Series, PolarsError> {
    use chrono::NaiveDate;
    let epoch = crate::date_utils::epoch_naive_date();
    match (value, type_str.trim().to_lowercase().as_str()) {
        (JsonValue::Null, _) => Ok(Series::new_null(name.into(), 1)),
        (JsonValue::Number(n), "int" | "bigint" | "long") => {
            Ok(Series::new(name.into(), vec![n.as_i64()]))
        }
        (JsonValue::Number(n), "double" | "float") => {
            Ok(Series::new(name.into(), vec![n.as_f64()]))
        }
        (JsonValue::Number(n), t) if is_decimal_type_str(t) => {
            Ok(Series::new(name.into(), vec![n.as_f64()]))
        }
        (JsonValue::String(s), "string" | "str" | "varchar") => {
            Ok(Series::new(name.into(), vec![s.as_str()]))
        }
        (JsonValue::Bool(b), "boolean" | "bool") => Ok(Series::new(name.into(), vec![*b])),
        (JsonValue::String(s), "date") => {
            let d = NaiveDate::parse_from_str(s, "%Y-%m-%d")
                .map_err(|e| PolarsError::ComputeError(format!("date parse: {e}").into()))?;
            let days = (d - epoch).num_days() as i32;
            let s = Series::new(name.into(), vec![days]).cast(&DataType::Date)?;
            Ok(s)
        }
        _ => Err(PolarsError::ComputeError(
            format!("json_value_to_series: unsupported {type_str} for {value:?}").into(),
        )),
    }
}

#[allow(dead_code)]
fn json_object_or_array_to_struct_series(
    value: &JsonValue,
    fields: &[(String, String)],
    _name: &str,
) -> Result<Option<Series>, PolarsError> {
    use polars::prelude::StructChunked;
    if matches!(value, JsonValue::Null) {
        return Ok(None);
    }
    let effective = match value {
        JsonValue::String(s) => {
            if let Ok(parsed) = serde_json::from_str::<JsonValue>(s) {
                if parsed.is_object() || parsed.is_array() {
                    parsed
                } else {
                    value.clone()
                }
            } else {
                value.clone()
            }
        }
        _ => value.clone(),
    };
    let mut field_series: Vec<Series> = Vec::with_capacity(fields.len());
    for (fname, ftype) in fields {
        let fval = if let Some(obj) = effective.as_object() {
            obj.get(fname).unwrap_or(&JsonValue::Null)
        } else if let Some(arr) = effective.as_array() {
            let idx = field_series.len();
            arr.get(idx).unwrap_or(&JsonValue::Null)
        } else {
            return Err(PolarsError::ComputeError(
                "struct value must be object (by field name) or array (by position). \
                 PySpark accepts dict or tuple/list for struct columns."
                    .into(),
            ));
        };
        let s = json_value_to_series_single(fval, ftype, fname)?;
        field_series.push(s);
    }
    let field_refs: Vec<&Series> = field_series.iter().collect();
    let st = StructChunked::from_series(PlSmallStr::EMPTY, 1, field_refs.iter().copied())
        .map_err(|e| PolarsError::ComputeError(format!("struct from value: {e}").into()))?
        .into_series();
    Ok(Some(st))
}

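/// Builds the inner `{key, value}` struct series for one row of a map
/// column. Keys come from the JSON object's keys; values are built
/// element-wise and concatenated, then the keys are cast when the declared
/// key type is not a string type.
///
/// A minimal sketch of a call site (hypothetical values):
///
/// ```ignore
/// let v = serde_json::json!({"a": 1});
/// let obj = v.as_object().unwrap();
/// let st = json_object_to_map_struct_series(
///     obj, "string", "bigint", &DataType::String, &DataType::Int64, "m",
/// )?;
/// assert_eq!(st.len(), 1);
/// ```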
fn json_object_to_map_struct_series(
    obj: &serde_json::Map<String, JsonValue>,
    key_type: &str,
    value_type: &str,
    key_dtype: &DataType,
    value_dtype: &DataType,
    _name: &str,
) -> Result<Series, PolarsError> {
    if obj.is_empty() {
        let key_series = Series::new("key".into(), Vec::<String>::new());
        // Name the value field "value" so the empty branch matches the
        // non-empty branch and the `{key, value}` struct dtype declared by
        // the caller.
        let value_series = Series::new_empty("value".into(), value_dtype);
        let st = StructChunked::from_series(
            PlSmallStr::EMPTY,
            0,
            [&key_series, &value_series].iter().copied(),
        )
        .map_err(|e| PolarsError::ComputeError(format!("map struct empty: {e}").into()))?
        .into_series();
        return Ok(st);
    }
    let keys: Vec<String> = obj.keys().cloned().collect();
    let mut value_series = None::<Series>;
    for v in obj.values() {
        let s = json_value_to_series_single(v, value_type, "value")?;
        value_series = Some(match value_series.take() {
            None => s,
            Some(mut acc) => {
                acc.extend(&s).map_err(|e| {
                    PolarsError::ComputeError(format!("map value extend: {e}").into())
                })?;
                acc
            }
        });
    }
    let value_series =
        value_series.unwrap_or_else(|| Series::new_empty("value".into(), value_dtype));
    let key_series = Series::new("key".into(), keys);
    let key_lower = key_type.trim().to_lowercase();
    let key_series = if matches!(key_lower.as_str(), "string" | "str" | "varchar") {
        key_series
    } else {
        key_series
            .cast(key_dtype)
            .map_err(|e| PolarsError::ComputeError(format!("map key cast: {e}").into()))?
    };
    let st = StructChunked::from_series(
        PlSmallStr::EMPTY,
        key_series.len(),
        [&key_series, &value_series].iter().copied(),
    )
    .map_err(|e| PolarsError::ComputeError(format!("map struct: {e}").into()))?
    .into_series();
    Ok(st)
}

use std::collections::{HashMap, HashSet};
use std::path::Path;
use std::sync::{Mutex, OnceLock};
use std::thread_local;

thread_local! {
    static THREAD_UDF_SESSION: RefCell<Option<SparkSession>> = const { RefCell::new(None) };
}

pub(crate) fn set_thread_udf_session(session: SparkSession) {
    robin_sparkless_expr::set_thread_udf_context(
        Arc::new(session.udf_registry.clone()),
        session.is_case_sensitive(),
    );
    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = Some(session));
}

#[allow(dead_code)]
pub(crate) fn get_thread_udf_session() -> Option<SparkSession> {
    THREAD_UDF_SESSION.with(|cell| cell.borrow().clone())
}

pub(crate) fn clear_thread_udf_session() {
    THREAD_UDF_SESSION.with(|cell| *cell.borrow_mut() = None);
}

static GLOBAL_TEMP_CATALOG: OnceLock<Arc<Mutex<HashMap<String, DataFrame>>>> = OnceLock::new();

fn global_temp_catalog() -> Arc<Mutex<HashMap<String, DataFrame>>> {
    GLOBAL_TEMP_CATALOG
        .get_or_init(|| Arc::new(Mutex::new(HashMap::new())))
        .clone()
}

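/// Fluent builder mirroring PySpark's `SparkSession.builder` API.
///
/// A minimal usage sketch:
///
/// ```ignore
/// let spark = SparkSessionBuilder::new()
///     .app_name("my_app")
///     .config("spark.sql.caseSensitive", "true")
///     .get_or_create();
/// assert!(spark.is_case_sensitive());
/// ```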
#[derive(Clone)]
pub struct SparkSessionBuilder {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
}

impl Default for SparkSessionBuilder {
    fn default() -> Self {
        Self::new()
    }
}

impl SparkSessionBuilder {
    pub fn new() -> Self {
        SparkSessionBuilder {
            app_name: None,
            master: None,
            config: HashMap::new(),
        }
    }

    pub fn app_name(mut self, name: impl Into<String>) -> Self {
        self.app_name = Some(name.into());
        self
    }

    pub fn master(mut self, master: impl Into<String>) -> Self {
        self.master = Some(master.into());
        self
    }

    pub fn config(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.config.insert(key.into(), value.into());
        self
    }

    pub fn get_or_create(self) -> SparkSession {
        let session = SparkSession::new(self.app_name, self.master, self.config);
        set_thread_udf_session(session.clone());
        session
    }

    pub fn with_config(mut self, config: &crate::config::SparklessConfig) -> Self {
        for (k, v) in config.to_session_config() {
            self.config.insert(k, v);
        }
        self
    }
}

pub type TempViewCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

pub type TableCatalog = Arc<Mutex<HashMap<String, DataFrame>>>;

pub type DatabaseCatalog = Arc<Mutex<HashSet<String>>>;

#[derive(Clone)]
pub struct SparkSession {
    app_name: Option<String>,
    master: Option<String>,
    config: HashMap<String, String>,
    pub(crate) catalog: TempViewCatalog,
    pub(crate) tables: TableCatalog,
    pub(crate) databases: DatabaseCatalog,
    pub(crate) udf_registry: UdfRegistry,
}

impl SparkSession {
    pub fn new(
        app_name: Option<String>,
        master: Option<String>,
        config: HashMap<String, String>,
    ) -> Self {
        SparkSession {
            app_name,
            master,
            config,
            catalog: Arc::new(Mutex::new(HashMap::new())),
            tables: Arc::new(Mutex::new(HashMap::new())),
            databases: Arc::new(Mutex::new(HashSet::new())),
            udf_registry: UdfRegistry::new(),
        }
    }

    pub fn create_or_replace_temp_view(&self, name: &str, df: DataFrame) {
        let _ = self
            .catalog
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn create_global_temp_view(&self, name: &str, df: DataFrame) {
        let _ = global_temp_catalog()
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn create_or_replace_global_temp_view(&self, name: &str, df: DataFrame) {
        let _ = global_temp_catalog()
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn drop_temp_view(&self, name: &str) {
        let _ = self.catalog.lock().map(|mut m| m.remove(name));
    }

    pub fn drop_global_temp_view(&self, name: &str) -> bool {
        global_temp_catalog()
            .lock()
            .map(|mut m| m.remove(name).is_some())
            .unwrap_or(false)
    }

    pub fn register_table(&self, name: &str, df: DataFrame) {
        let _ = self
            .tables
            .lock()
            .map(|mut m| m.insert(name.to_string(), df));
    }

    pub fn register_database(&self, name: &str) {
        let _ = self.databases.lock().map(|mut s| {
            s.insert(name.to_string());
        });
    }

    pub fn list_database_names(&self) -> Vec<String> {
        let mut names: Vec<String> = vec!["default".to_string(), "global_temp".to_string()];
        if let Ok(guard) = self.databases.lock() {
            let mut created: Vec<String> = guard.iter().cloned().collect();
            created.sort();
            names.extend(created);
        }
        names
    }

    pub fn database_exists(&self, name: &str) -> bool {
        if name.eq_ignore_ascii_case("default") || name.eq_ignore_ascii_case("global_temp") {
            return true;
        }
        self.databases
            .lock()
            .map(|s| s.iter().any(|n| n.eq_ignore_ascii_case(name)))
            .unwrap_or(false)
    }

    pub fn get_saved_table(&self, name: &str) -> Option<DataFrame> {
        self.tables.lock().ok().and_then(|m| m.get(name).cloned())
    }

    pub fn saved_table_exists(&self, name: &str) -> bool {
        self.tables
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
    }

    pub fn table_exists(&self, name: &str) -> bool {
        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
            return global_temp_catalog()
                .lock()
                .map(|m| m.contains_key(tbl))
                .unwrap_or(false);
        }
        if self
            .catalog
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
        {
            return true;
        }
        if self
            .tables
            .lock()
            .map(|m| m.contains_key(name))
            .unwrap_or(false)
        {
            return true;
        }
        if let Some(warehouse) = self.warehouse_dir() {
            let path = Path::new(warehouse).join(name);
            if path.is_dir() {
                return true;
            }
        }
        false
    }

    pub fn list_global_temp_view_names(&self) -> Vec<String> {
        global_temp_catalog()
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn list_temp_view_names(&self) -> Vec<String> {
        self.catalog
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn list_table_names(&self) -> Vec<String> {
        self.tables
            .lock()
            .map(|m| m.keys().cloned().collect())
            .unwrap_or_default()
    }

    pub fn drop_table(&self, name: &str) -> bool {
        self.tables
            .lock()
            .map(|mut m| m.remove(name).is_some())
            .unwrap_or(false)
    }

    pub fn drop_database(&self, name: &str) -> bool {
        if name.eq_ignore_ascii_case("default") || name.eq_ignore_ascii_case("global_temp") {
            return false;
        }
        self.databases
            .lock()
            .map(|mut s| s.remove(name))
            .unwrap_or(false)
    }

    fn parse_global_temp_name(name: &str) -> Option<(&str, &str)> {
        if let Some(dot) = name.find('.') {
            let (db, tbl) = name.split_at(dot);
            if db.eq_ignore_ascii_case("global_temp") {
                return Some((db, tbl.strip_prefix('.').unwrap_or(tbl)));
            }
        }
        None
    }

    pub fn warehouse_dir(&self) -> Option<&str> {
        self.config
            .get("spark.sql.warehouse.dir")
            .map(|s| s.as_str())
            .filter(|s| !s.is_empty())
    }

    pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
        if let Some((_db, tbl)) = Self::parse_global_temp_name(name) {
            if let Some(df) = global_temp_catalog()
                .lock()
                .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
                .get(tbl)
                .cloned()
            {
                return Ok(df);
            }
            return Err(PolarsError::InvalidOperation(
                format!(
                    "Global temp view '{tbl}' not found. Register it with createOrReplaceGlobalTempView."
                )
                .into(),
            ));
        }
        if let Some(df) = self
            .catalog
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
        {
            return Ok(df);
        }
        if let Some(df) = self
            .tables
            .lock()
            .map_err(|_| PolarsError::InvalidOperation("catalog lock poisoned".into()))?
            .get(name)
            .cloned()
        {
            return Ok(df);
        }
        if let Some(warehouse) = self.warehouse_dir() {
            let dir = Path::new(warehouse).join(name);
            if dir.is_dir() {
                let data_file = dir.join("data.parquet");
                let read_path = if data_file.is_file() { data_file } else { dir };
                return self.read_parquet(&read_path);
            }
        }
        Err(PolarsError::InvalidOperation(
            format!(
                "Table or view '{name}' not found. Register it with create_or_replace_temp_view or saveAsTable."
            )
            .into(),
        ))
    }

    pub fn builder() -> SparkSessionBuilder {
        SparkSessionBuilder::new()
    }

    pub fn from_config(config: &crate::config::SparklessConfig) -> SparkSession {
        Self::builder().with_config(config).get_or_create()
    }

    pub fn get_config(&self) -> &HashMap<String, String> {
        &self.config
    }

    pub fn is_case_sensitive(&self) -> bool {
        self.config
            .get("spark.sql.caseSensitive")
            .map(|v| v.eq_ignore_ascii_case("true"))
            .unwrap_or(false)
    }

    pub fn register_udf<F>(&self, name: &str, f: F) -> Result<(), PolarsError>
    where
        F: Fn(&[Series]) -> Result<Series, PolarsError> + Send + Sync + 'static,
    {
        self.udf_registry.register_rust_udf(name, f)
    }

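    /// Creates a three-column DataFrame from `(i64, i64, String)` tuples, the
    /// fixed shape this convenience method supports; exactly three column
    /// names are required.
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// let data = vec![(1, 25, "Alice".to_string())];
    /// let df = spark.create_dataframe(data, vec!["id", "age", "name"])?;
    /// ```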
    pub fn create_dataframe(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, PolarsError> {
        if column_names.len() != 3 {
            return Err(PolarsError::ComputeError(
                format!(
                    "create_dataframe: expected 3 column names for (i64, i64, String) tuples, got {}. Hint: provide exactly 3 names, e.g. [\"id\", \"age\", \"name\"].",
                    column_names.len()
                )
                .into(),
            ));
        }

        let mut cols: Vec<Series> = Vec::with_capacity(3);

        let col0: Vec<i64> = data.iter().map(|t| t.0).collect();
        cols.push(Series::new(column_names[0].into(), col0));

        let col1: Vec<i64> = data.iter().map(|t| t.1).collect();
        cols.push(Series::new(column_names[1].into(), col1));

        let col2: Vec<String> = data.iter().map(|t| t.2.clone()).collect();
        cols.push(Series::new(column_names[2].into(), col2));

        let pl_df = PlDataFrame::new_infer_height(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    pub fn create_dataframe_engine(
        &self,
        data: Vec<(i64, i64, String)>,
        column_names: Vec<&str>,
    ) -> Result<DataFrame, EngineError> {
        self.create_dataframe(data, column_names)
            .map_err(EngineError::from)
    }

    pub fn create_dataframe_from_polars(&self, df: PlDataFrame) -> DataFrame {
        DataFrame::from_polars_with_options(df, self.is_case_sensitive())
    }

    fn infer_dtype_from_json_value(v: &JsonValue) -> Option<String> {
        match v {
            JsonValue::Null => None,
            JsonValue::Bool(_) => Some("boolean".to_string()),
            JsonValue::Number(n) => {
                if n.is_i64() {
                    Some("bigint".to_string())
                } else {
                    Some("double".to_string())
                }
            }
            JsonValue::String(s) => {
                if chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d").is_ok() {
                    Some("date".to_string())
                } else if chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f").is_ok()
                    || chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S").is_ok()
                {
                    Some("timestamp".to_string())
                } else {
                    Some("string".to_string())
                }
            }
            JsonValue::Array(_) => Some("array".to_string()),
            JsonValue::Object(_) => Some("string".to_string()),
        }
    }

    pub fn infer_schema_from_json_rows(
        rows: &[Vec<JsonValue>],
        names: &[String],
    ) -> Vec<(String, String)> {
        if names.is_empty() {
            return Vec::new();
        }
        let mut schema: Vec<(String, String)> = names
            .iter()
            .map(|n| (n.clone(), "string".to_string()))
            .collect();
        for (col_idx, (_, dtype_str)) in schema.iter_mut().enumerate() {
            for row in rows {
                let v = row.get(col_idx).unwrap_or(&JsonValue::Null);
                if let Some(dtype) = Self::infer_dtype_from_json_value(v) {
                    *dtype_str = dtype;
                    break;
                }
            }
        }
        schema
    }

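    /// Builds a DataFrame from row-major JSON values and a `(name, type)`
    /// schema. An empty schema with non-empty rows triggers inference via
    /// `infer_schema_from_json_rows` using generated names `c0`, `c1`, ...
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// let rows = vec![vec![serde_json::json!(1), serde_json::json!("a")]];
    /// let schema = vec![
    ///     ("id".to_string(), "bigint".to_string()),
    ///     ("name".to_string(), "string".to_string()),
    /// ];
    /// let df = spark.create_dataframe_from_rows(rows, schema)?;
    /// ```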
    pub fn create_dataframe_from_rows(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, PolarsError> {
        let schema = if schema.is_empty() && !rows.is_empty() {
            let ncols = rows[0].len();
            let names: Vec<String> = (0..ncols).map(|i| format!("c{i}")).collect();
            Self::infer_schema_from_json_rows(&rows, &names)
        } else {
            schema
        };

        if schema.is_empty() {
            if rows.is_empty() {
                return Ok(DataFrame::from_polars_with_options(
                    PlDataFrame::new(0, vec![])?,
                    self.is_case_sensitive(),
                ));
            }
            return Err(PolarsError::InvalidOperation(
                "create_dataframe_from_rows: schema must not be empty when rows are not empty"
                    .into(),
            ));
        }
        use chrono::{NaiveDate, NaiveDateTime};

        let mut cols: Vec<Series> = Vec::with_capacity(schema.len());

        for (col_idx, (name, type_str)) in schema.iter().enumerate() {
            let type_lower = type_str.trim().to_lowercase();
            let s = match type_lower.as_str() {
                "int" | "integer" | "bigint" | "long" => {
                    let vals: Vec<Option<i64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_i64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "double" | "float" | "double_precision" => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                _ if is_decimal_type_str(&type_lower) => {
                    let vals: Vec<Option<f64>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Number(n) => n.as_f64(),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "string" | "str" | "varchar" => {
                    let vals: Vec<Option<String>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => Some(s),
                                JsonValue::Null => None,
                                other => Some(other.to_string()),
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "boolean" | "bool" => {
                    let vals: Vec<Option<bool>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::Bool(b) => Some(b),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    Series::new(name.as_str().into(), vals)
                }
                "date" => {
                    let epoch = crate::date_utils::epoch_naive_date();
                    let vals: Vec<Option<i32>> = rows
                        .iter()
                        .map(|row| {
                            let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                            match v {
                                JsonValue::String(s) => NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                    .ok()
                                    .map(|d| (d - epoch).num_days() as i32),
                                JsonValue::Null => None,
                                _ => None,
                            }
                        })
                        .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Date)
                        .map_err(|e| PolarsError::ComputeError(format!("date cast: {e}").into()))?
                }
                "timestamp" | "datetime" | "timestamp_ntz" => {
                    let vals: Vec<Option<i64>> =
                        rows.iter()
                            .map(|row| {
                                let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                                match v {
                                    JsonValue::String(s) => {
                                        let parsed = NaiveDateTime::parse_from_str(
                                            &s,
                                            "%Y-%m-%dT%H:%M:%S%.f",
                                        )
                                        .map_err(|e| {
                                            PolarsError::ComputeError(e.to_string().into())
                                        })
                                        .or_else(|_| {
                                            NaiveDateTime::parse_from_str(&s, "%Y-%m-%dT%H:%M:%S")
                                                .map_err(|e| {
                                                    PolarsError::ComputeError(e.to_string().into())
                                                })
                                        })
                                        .or_else(|_| {
                                            NaiveDate::parse_from_str(&s, "%Y-%m-%d")
                                                .map_err(|e| {
                                                    PolarsError::ComputeError(e.to_string().into())
                                                })
                                                .and_then(|d| {
                                                    d.and_hms_opt(0, 0, 0).ok_or_else(|| {
                                                        PolarsError::ComputeError(
                                                            "date to datetime (0:0:0)".into(),
                                                        )
                                                    })
                                                })
                                        });
                                        parsed.ok().map(|dt| dt.and_utc().timestamp_micros())
                                    }
                                    JsonValue::Number(n) => n.as_i64(),
                                    JsonValue::Null => None,
                                    _ => None,
                                }
                            })
                            .collect();
                    let series = Series::new(name.as_str().into(), vals);
                    series
                        .cast(&DataType::Datetime(TimeUnit::Microseconds, None))
                        .map_err(|e| {
                            PolarsError::ComputeError(format!("datetime cast: {e}").into())
                        })?
                }
                "list" | "array" => {
                    let (elem_type, inner_dtype) = infer_list_element_type(&rows, col_idx)
                        .unwrap_or(("bigint".to_string(), DataType::Int64));
                    let n = rows.len();
                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if let JsonValue::Null = &v {
                            builder.append_null();
                        } else if let Some(arr) = json_value_to_array(&v) {
                            let elem_series: Vec<Series> = arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        } else {
                            let single_arr = [v];
                            let elem_series: Vec<Series> = single_arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        }
                    }
                    builder.finish().into_series()
                }
                _ if parse_array_element_type(&type_lower).is_some() => {
                    let elem_type = parse_array_element_type(&type_lower).unwrap_or_else(|| {
                        unreachable!("guard above ensures parse_array_element_type returned Some")
                    });
                    let inner_dtype = json_type_str_to_polars(&elem_type).ok_or_else(|| {
                        PolarsError::ComputeError(
                            format!(
                                "create_dataframe_from_rows: array element type '{elem_type}' not supported"
                            )
                            .into(),
                        )
                    })?;
                    let n = rows.len();
                    let mut builder = get_list_builder(&inner_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if let JsonValue::Null = &v {
                            builder.append_null();
                        } else if let Some(arr) = json_value_to_array(&v) {
                            let elem_series: Vec<Series> = arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        } else {
                            let single_arr = [v];
                            let elem_series: Vec<Series> = single_arr
                                .iter()
                                .map(|e| json_value_to_series_single(e, &elem_type, "elem"))
                                .collect::<Result<Vec<_>, _>>()?;
                            let vals: Vec<_> =
                                elem_series.iter().filter_map(|s| s.get(0).ok()).collect();
                            let s = Series::from_any_values_and_dtype(
                                PlSmallStr::EMPTY,
                                &vals,
                                &inner_dtype,
                                false,
                            )
                            .map_err(|e| {
                                PolarsError::ComputeError(format!("array elem: {e}").into())
                            })?;
                            builder.append_series(&s)?;
                        }
                    }
                    builder.finish().into_series()
                }
                _ if parse_map_key_value_types(&type_lower).is_some() => {
                    let (key_type, value_type) = parse_map_key_value_types(&type_lower)
                        .unwrap_or_else(|| unreachable!("guard ensures Some"));
                    let key_dtype = json_type_str_to_polars(&key_type).ok_or_else(|| {
                        PolarsError::ComputeError(
                            format!(
                                "create_dataframe_from_rows: map key type '{key_type}' not supported"
                            )
                            .into(),
                        )
                    })?;
                    let value_dtype = json_type_str_to_polars(&value_type).ok_or_else(|| {
                        PolarsError::ComputeError(
                            format!(
                                "create_dataframe_from_rows: map value type '{value_type}' not supported"
                            )
                            .into(),
                        )
                    })?;
                    let struct_dtype = DataType::Struct(vec![
                        Field::new("key".into(), key_dtype.clone()),
                        Field::new("value".into(), value_dtype.clone()),
                    ]);
                    let n = rows.len();
                    let mut builder = get_list_builder(&struct_dtype, 64, n, name.as_str().into());
                    for row in rows.iter() {
                        let v = row.get(col_idx).cloned().unwrap_or(JsonValue::Null);
                        if matches!(v, JsonValue::Null) {
                            builder.append_null();
                        } else if let Some(obj) = v.as_object() {
                            let st = json_object_to_map_struct_series(
                                obj,
                                &key_type,
                                &value_type,
                                &key_dtype,
                                &value_dtype,
                                name,
                            )?;
                            builder.append_series(&st)?;
                        } else {
                            return Err(PolarsError::ComputeError(
                                format!(
                                    "create_dataframe_from_rows: map column '{name}' expects JSON object (dict), got {v:?}"
                                )
                                .into(),
                            ));
                        }
                    }
                    builder.finish().into_series()
                }
                _ if parse_struct_fields(&type_lower).is_some() => {
                    let values: Vec<Option<JsonValue>> =
                        rows.iter().map(|row| row.get(col_idx).cloned()).collect();
                    json_values_to_series(&values, &type_lower, name)?
                }
                _ => {
                    return Err(PolarsError::ComputeError(
                        format!(
                            "create_dataframe_from_rows: unsupported type '{type_str}' for column '{name}'"
                        )
                        .into(),
                    ));
                }
            };
            cols.push(s);
        }

        let pl_df = PlDataFrame::new_infer_height(cols.iter().map(|s| s.clone().into()).collect())?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

    pub fn create_dataframe_from_rows_engine(
        &self,
        rows: Vec<Vec<JsonValue>>,
        schema: Vec<(String, String)>,
    ) -> Result<DataFrame, EngineError> {
        self.create_dataframe_from_rows(rows, schema)
            .map_err(EngineError::from)
    }

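    /// Mirrors PySpark's `spark.range(start, end, step)`: a single `id`
    /// column covering `[start, end)` with the given non-zero step (negative
    /// steps count down).
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// let df = spark.range(0, 5, 1)?; // ids 0..=4
    /// assert_eq!(df.count()?, 5);
    /// ```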
    pub fn range(&self, start: i64, end: i64, step: i64) -> Result<DataFrame, PolarsError> {
        if step == 0 {
            return Err(PolarsError::InvalidOperation(
                "range: step must not be 0".into(),
            ));
        }
        let mut vals: Vec<i64> = Vec::new();
        let mut v = start;
        if step > 0 {
            while v < end {
                vals.push(v);
                v = v.saturating_add(step);
            }
        } else {
            while v > end {
                vals.push(v);
                v = v.saturating_add(step);
            }
        }
        let col = Series::new("id".into(), vals);
        let pl_df = PlDataFrame::new_infer_height(vec![col.into()])?;
        Ok(DataFrame::from_polars_with_options(
            pl_df,
            self.is_case_sensitive(),
        ))
    }

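    /// Reads a CSV file lazily with headers enabled and schema inference over
    /// the first 100 rows. For per-call options (separator, null values,
    /// header toggling) use `spark.read().option(...).csv(path)` instead.
    ///
    /// A minimal usage sketch (hypothetical path):
    ///
    /// ```ignore
    /// let df = spark.read_csv("data/people.csv")?;
    /// ```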
    pub fn read_csv(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        if !path.exists() {
            return Err(PolarsError::ComputeError(
                format!("read_csv: file not found: {}", path.display()).into(),
            ));
        }
        let path_display = path.display();
        let pl_path = PlRefPath::try_from_path(path).map_err(|e| {
            PolarsError::ComputeError(format!("read_csv({path_display}): path: {e}").into())
        })?;
        let lf = LazyCsvReader::new(pl_path)
            .with_has_header(true)
            .with_infer_schema_length(Some(100))
            .finish()
            .map_err(|e| {
                PolarsError::ComputeError(
                    format!(
                        "read_csv({path_display}): {e} Hint: check that the file exists and is valid CSV."
                    )
                    .into(),
                )
            })?;
        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
            lf,
            self.is_case_sensitive(),
        ))
    }

    pub fn read_csv_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
        self.read_csv(path).map_err(EngineError::from)
    }

    pub fn read_parquet(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        let path = path.as_ref();
        if !path.exists() {
            return Err(PolarsError::ComputeError(
                format!("read_parquet: file not found: {}", path.display()).into(),
            ));
        }
        let pl_path = PlRefPath::try_from_path(path)
            .map_err(|e| PolarsError::ComputeError(format!("read_parquet: path: {e}").into()))?;
        let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())?;
        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
            lf,
            self.is_case_sensitive(),
        ))
    }

    pub fn read_parquet_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
        self.read_parquet(path).map_err(EngineError::from)
    }

    pub fn read_json(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        use polars::prelude::*;
        use std::num::NonZeroUsize;
        let path = path.as_ref();
        if !path.exists() {
            return Err(PolarsError::ComputeError(
                format!("read_json: file not found: {}", path.display()).into(),
            ));
        }
        let pl_path = PlRefPath::try_from_path(path)
            .map_err(|e| PolarsError::ComputeError(format!("read_json: path: {e}").into()))?;
        let lf = LazyJsonLineReader::new(pl_path)
            .with_infer_schema_length(NonZeroUsize::new(100))
            .finish()?;
        Ok(crate::dataframe::DataFrame::from_lazy_with_options(
            lf,
            self.is_case_sensitive(),
        ))
    }

    pub fn read_json_engine(&self, path: impl AsRef<Path>) -> Result<DataFrame, EngineError> {
        self.read_json(path).map_err(EngineError::from)
    }

    #[cfg(feature = "sql")]
    pub fn sql(&self, query: &str) -> Result<DataFrame, PolarsError> {
        crate::sql::execute_sql(self, query)
    }

    #[cfg(not(feature = "sql"))]
    pub fn sql(&self, _query: &str) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "SQL queries require the 'sql' feature. Build with --features sql.".into(),
        ))
    }

    pub fn table_engine(&self, name: &str) -> Result<DataFrame, EngineError> {
        self.table(name).map_err(EngineError::from)
    }

    fn looks_like_path(s: &str) -> bool {
        s.contains('/') || s.contains('\\') || Path::new(s).exists()
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta(path, self.is_case_sensitive())
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_path_with_version(
        &self,
        path: impl AsRef<Path>,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        crate::delta::read_delta_with_version(path, version, self.is_case_sensitive())
    }

    #[cfg(feature = "delta")]
    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            self.read_delta_path(Path::new(name_or_path))
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_with_version(
        &self,
        name_or_path: &str,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            self.read_delta_path_with_version(Path::new(name_or_path), version)
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta(&self, name_or_path: &str) -> Result<DataFrame, PolarsError> {
        if Self::looks_like_path(name_or_path) {
            Err(PolarsError::InvalidOperation(
                "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
            ))
        } else {
            self.table(name_or_path)
        }
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_with_version(
        &self,
        name_or_path: &str,
        version: Option<i64>,
    ) -> Result<DataFrame, PolarsError> {
        let _ = version;
        self.read_delta(name_or_path)
    }

    #[cfg(feature = "delta")]
    pub fn read_delta_from_path(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        self.read_delta_path(path)
    }

    #[cfg(not(feature = "delta"))]
    pub fn read_delta_from_path(&self, _path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
        Err(PolarsError::InvalidOperation(
            "Delta Lake requires the 'delta' feature. Build with --features delta.".into(),
        ))
    }

    pub fn stop(&self) {
        let _ = self.catalog.lock().map(|mut m| m.clear());
        let _ = self.tables.lock().map(|mut m| m.clear());
        let _ = self.databases.lock().map(|mut s| s.clear());
        let _ = self.udf_registry.clear();
        clear_thread_udf_session();
    }
}

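/// Option-driven reader mirroring PySpark's `spark.read` API: accumulate
/// options, pick a format, then load.
///
/// A minimal usage sketch (hypothetical path):
///
/// ```ignore
/// let df = spark
///     .read()
///     .option("header", "true")
///     .format("csv")
///     .load("data/people.csv")?;
/// ```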
1686pub struct DataFrameReader {
1689 session: SparkSession,
1690 options: HashMap<String, String>,
1691 format: Option<String>,
1692}
1693
1694impl DataFrameReader {
1695 pub fn new(session: SparkSession) -> Self {
1696 DataFrameReader {
1697 session,
1698 options: HashMap::new(),
1699 format: None,
1700 }
1701 }
1702
1703 pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1705 self.options.insert(key.into(), value.into());
1706 self
1707 }
1708
1709 pub fn options(mut self, opts: impl IntoIterator<Item = (String, String)>) -> Self {
1711 for (k, v) in opts {
1712 self.options.insert(k, v);
1713 }
1714 self
1715 }
1716
1717 pub fn format(mut self, fmt: impl Into<String>) -> Self {
1719 self.format = Some(fmt.into());
1720 self
1721 }
1722
1723 pub fn schema(self, _schema: impl Into<String>) -> Self {
1725 self
1726 }
1727
1728 pub fn load(&self, path: impl AsRef<Path>) -> Result<DataFrame, PolarsError> {
1730 let path = path.as_ref();
1731 let fmt = self.format.clone().or_else(|| {
1732 path.extension()
1733 .and_then(|e| e.to_str())
1734 .map(|s| s.to_lowercase())
1735 });
1736 match fmt.as_deref() {
1737 Some("parquet") => self.parquet(path),
1738 Some("csv") => self.csv(path),
1739 Some("json") | Some("jsonl") => self.json(path),
1740 #[cfg(feature = "delta")]
1741 Some("delta") => self.session.read_delta_from_path(path),
1742 _ => Err(PolarsError::ComputeError(
1743 format!(
1744 "load: could not infer format for path '{}'. Use format('parquet'|'csv'|'json') before load.",
1745 path.display()
1746 )
1747 .into(),
1748 )),
1749 }
1750 }
1751
1752 pub fn table(&self, name: &str) -> Result<DataFrame, PolarsError> {
1754 self.session.table(name)
1755 }
1756
1757 fn apply_csv_options(
1758 &self,
1759 reader: polars::prelude::LazyCsvReader,
1760 ) -> polars::prelude::LazyCsvReader {
1761 use polars::prelude::NullValues;
1762 let mut r = reader;
1763 if let Some(v) = self.options.get("header") {
1764 let has_header = v.eq_ignore_ascii_case("true") || v == "1";
1765 r = r.with_has_header(has_header);
1766 }
1767 if let Some(v) = self.options.get("inferSchema") {
1768 if v.eq_ignore_ascii_case("true") || v == "1" {
1769 let n = self
1770 .options
1771 .get("inferSchemaLength")
1772 .and_then(|s| s.parse::<usize>().ok())
1773 .unwrap_or(100);
1774 r = r.with_infer_schema_length(Some(n));
1775 } else {
1776 r = r.with_infer_schema_length(Some(0));
1778 }
1779 } else if let Some(v) = self.options.get("inferSchemaLength") {
1780 if let Ok(n) = v.parse::<usize>() {
1781 r = r.with_infer_schema_length(Some(n));
1782 }
1783 }
1784 if let Some(sep) = self.options.get("sep") {
1785 if let Some(b) = sep.bytes().next() {
1786 r = r.with_separator(b);
1787 }
1788 }
1789 if let Some(null_val) = self.options.get("nullValue") {
1790 r = r.with_null_values(Some(NullValues::AllColumnsSingle(null_val.clone().into())));
1791 }
1792 r
1793 }
1794
1795 fn apply_json_options(
1796 &self,
1797 reader: polars::prelude::LazyJsonLineReader,
1798 ) -> polars::prelude::LazyJsonLineReader {
1799 use std::num::NonZeroUsize;
1800 let mut r = reader;
1801 if let Some(v) = self.options.get("inferSchemaLength") {
1802 if let Ok(n) = v.parse::<usize>() {
1803 r = r.with_infer_schema_length(NonZeroUsize::new(n));
1804 }
1805 }
1806 r
1807 }
1808
1809 pub fn csv(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1810 use polars::prelude::*;
1811 let path = path.as_ref();
1812 let path_display = path.display();
1813 let pl_path = PlRefPath::try_from_path(path).map_err(|e| {
1814 PolarsError::ComputeError(format!("csv({path_display}): path: {e}").into())
1815 })?;
1816 let reader = LazyCsvReader::new(pl_path);
1817 let reader = if self.options.is_empty() {
1818 reader
1819 .with_has_header(true)
1820 .with_infer_schema_length(Some(100))
1821 } else {
1822 self.apply_csv_options(
1823 reader
1824 .with_has_header(true)
1825 .with_infer_schema_length(Some(100)),
1826 )
1827 };
1828 let lf = reader.finish().map_err(|e| {
1829 PolarsError::ComputeError(format!("read csv({path_display}): {e}").into())
1830 })?;
1831 let pl_df = lf.collect().map_err(|e| {
1832 PolarsError::ComputeError(
1833 format!("read csv({path_display}): collect failed: {e}").into(),
1834 )
1835 })?;
1836 Ok(crate::dataframe::DataFrame::from_polars_with_options(
1837 pl_df,
1838 self.session.is_case_sensitive(),
1839 ))
1840 }
1841
1842 pub fn parquet(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1843 use polars::prelude::*;
1844 let path = path.as_ref();
1845 let pl_path = PlRefPath::try_from_path(path)
1846 .map_err(|e| PolarsError::ComputeError(format!("parquet: path: {e}").into()))?;
1847 let lf = LazyFrame::scan_parquet(pl_path, ScanArgsParquet::default())?;
1848 let pl_df = lf.collect()?;
1849 Ok(crate::dataframe::DataFrame::from_polars_with_options(
1850 pl_df,
1851 self.session.is_case_sensitive(),
1852 ))
1853 }
1854
1855 pub fn json(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1856 use polars::prelude::*;
1857 use std::num::NonZeroUsize;
1858 let path = path.as_ref();
1859 let pl_path = PlRefPath::try_from_path(path)
1860 .map_err(|e| PolarsError::ComputeError(format!("json: path: {e}").into()))?;
1861 let reader = LazyJsonLineReader::new(pl_path);
1862 let reader = if self.options.is_empty() {
1863 reader.with_infer_schema_length(NonZeroUsize::new(100))
1864 } else {
1865 self.apply_json_options(reader.with_infer_schema_length(NonZeroUsize::new(100)))
1866 };
1867 let lf = reader.finish()?;
1868 let pl_df = lf.collect()?;
1869 Ok(crate::dataframe::DataFrame::from_polars_with_options(
1870 pl_df,
1871 self.session.is_case_sensitive(),
1872 ))
1873 }
1874
1875 #[cfg(feature = "delta")]
1876 pub fn delta(&self, path: impl AsRef<std::path::Path>) -> Result<DataFrame, PolarsError> {
1877 self.session.read_delta_from_path(path)
1878 }
1879}
1880
1881impl SparkSession {
1882 pub fn read(&self) -> DataFrameReader {
1884 DataFrameReader::new(SparkSession {
1885 app_name: self.app_name.clone(),
1886 master: self.master.clone(),
1887 config: self.config.clone(),
1888 catalog: self.catalog.clone(),
1889 tables: self.tables.clone(),
1890 databases: self.databases.clone(),
1891 udf_registry: self.udf_registry.clone(),
1892 })
1893 }
1894}
1895
1896impl Default for SparkSession {
1897 fn default() -> Self {
1898 Self::builder().get_or_create()
1899 }
1900}
1901
1902#[cfg(test)]
1903mod tests {
1904 use super::*;
1905
1906 #[test]
1907 fn test_spark_session_builder_basic() {
1908 let spark = SparkSession::builder().app_name("test_app").get_or_create();
1909
1910 assert_eq!(spark.app_name, Some("test_app".to_string()));
1911 }
1912
1913 #[test]
1914 fn test_spark_session_builder_with_master() {
1915 let spark = SparkSession::builder()
1916 .app_name("test_app")
1917 .master("local[*]")
1918 .get_or_create();
1919
1920 assert_eq!(spark.app_name, Some("test_app".to_string()));
1921 assert_eq!(spark.master, Some("local[*]".to_string()));
1922 }
1923
1924 #[test]
1925 fn test_spark_session_builder_with_config() {
1926 let spark = SparkSession::builder()
1927 .app_name("test_app")
1928 .config("spark.executor.memory", "4g")
1929 .config("spark.driver.memory", "2g")
1930 .get_or_create();
1931
1932 assert_eq!(
1933 spark.config.get("spark.executor.memory"),
1934 Some(&"4g".to_string())
1935 );
1936 assert_eq!(
1937 spark.config.get("spark.driver.memory"),
1938 Some(&"2g".to_string())
1939 );
1940 }
1941
1942 #[test]
1943 fn test_spark_session_default() {
1944 let spark = SparkSession::default();
1945 assert!(spark.app_name.is_none());
1946 assert!(spark.master.is_none());
1947 assert!(spark.config.is_empty());
1948 }
1949
1950 #[test]
1951 fn test_create_dataframe_success() {
1952 let spark = SparkSession::builder().app_name("test").get_or_create();
1953 let data = vec![
1954 (1i64, 25i64, "Alice".to_string()),
1955 (2i64, 30i64, "Bob".to_string()),
1956 ];
1957
1958 let result = spark.create_dataframe(data, vec!["id", "age", "name"]);
1959
1960 assert!(result.is_ok());
1961 let df = result.unwrap();
1962 assert_eq!(df.count().unwrap(), 2);
1963
1964 let columns = df.columns().unwrap();
1965 assert!(columns.contains(&"id".to_string()));
1966 assert!(columns.contains(&"age".to_string()));
1967 assert!(columns.contains(&"name".to_string()));
1968 }
1969
1970 #[test]
1971 fn test_create_dataframe_wrong_column_count() {
1972 let spark = SparkSession::builder().app_name("test").get_or_create();
1973 let data = vec![(1i64, 25i64, "Alice".to_string())];
1974
1975 let result = spark.create_dataframe(data.clone(), vec!["id", "age"]);
1977 assert!(result.is_err());
1978
1979 let result = spark.create_dataframe(data, vec!["id", "age", "name", "extra"]);
1981 assert!(result.is_err());
1982 }
1983
1984 #[test]
1985 fn test_create_dataframe_from_rows_empty_schema_with_rows_returns_error() {
1986 let spark = SparkSession::builder().app_name("test").get_or_create();
1987 let rows: Vec<Vec<JsonValue>> = vec![vec![]];
1988 let schema: Vec<(String, String)> = vec![];
1989 let result = spark.create_dataframe_from_rows(rows, schema);
1990 match &result {
1991 Err(e) => assert!(e.to_string().contains("schema must not be empty")),
1992 Ok(_) => panic!("expected error for empty schema with non-empty rows"),
1993 }
1994 }
1995
1996 #[test]
1997 fn test_create_dataframe_from_rows_empty_data_with_schema() {
1998 let spark = SparkSession::builder().app_name("test").get_or_create();
1999 let rows: Vec<Vec<JsonValue>> = vec![];
2000 let schema = vec![
2001 ("a".to_string(), "int".to_string()),
2002 ("b".to_string(), "string".to_string()),
2003 ];
2004 let result = spark.create_dataframe_from_rows(rows, schema);
2005 let df = result.unwrap();
2006 assert_eq!(df.count().unwrap(), 0);
2007 assert_eq!(df.collect_inner().unwrap().get_column_names(), &["a", "b"]);
2008 }
2009
2010 #[test]
2011 fn test_create_dataframe_from_rows_empty_schema_empty_data() {
2012 let spark = SparkSession::builder().app_name("test").get_or_create();
2013 let rows: Vec<Vec<JsonValue>> = vec![];
2014 let schema: Vec<(String, String)> = vec![];
2015 let result = spark.create_dataframe_from_rows(rows, schema);
2016 let df = result.unwrap();
2017 assert_eq!(df.count().unwrap(), 0);
2018 assert_eq!(df.collect_inner().unwrap().get_column_names().len(), 0);
2019 }
2020
2021 #[test]
2023 fn test_create_dataframe_from_rows_struct_as_object() {
2024 use serde_json::json;
2025
2026 let spark = SparkSession::builder().app_name("test").get_or_create();
2027 let schema = vec![
2028 ("id".to_string(), "string".to_string()),
2029 (
2030 "nested".to_string(),
2031 "struct<a:bigint,b:string>".to_string(),
2032 ),
2033 ];
2034 let rows: Vec<Vec<JsonValue>> = vec![
2035 vec![json!("x"), json!({"a": 1, "b": "y"})],
2036 vec![json!("z"), json!({"a": 2, "b": "w"})],
2037 ];
2038 let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
2039 assert_eq!(df.count().unwrap(), 2);
2040 let collected = df.collect_inner().unwrap();
2041 assert_eq!(collected.get_column_names(), &["id", "nested"]);
2042 }
2043
2044 #[test]
2046 fn test_create_dataframe_from_rows_struct_as_array() {
2047 use serde_json::json;
2048
2049 let spark = SparkSession::builder().app_name("test").get_or_create();
2050 let schema = vec![
2051 ("id".to_string(), "string".to_string()),
2052 (
2053 "nested".to_string(),
2054 "struct<a:bigint,b:string>".to_string(),
2055 ),
2056 ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!([1, "y"])],
            vec![json!("z"), json!([2, "w"])],
        ];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
        assert_eq!(collected.get_column_names(), &["id", "nested"]);
    }

    #[test]
    fn test_issue_610_struct_value_as_string_object_or_array() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("nested".to_string(), "struct<a:bigint,b:string>".to_string()),
        ];
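        // Struct value arrives as a JSON string containing an object.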
        let rows_object: Vec<Vec<JsonValue>> =
            vec![vec![json!("A"), json!(r#"{"a": 1, "b": "x"}"#)]];
        let df1 = spark
            .create_dataframe_from_rows(rows_object, schema.clone())
            .unwrap();
        assert_eq!(df1.count().unwrap(), 1);

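        // Struct value arrives as a JSON string containing a positional array.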
        let rows_array: Vec<Vec<JsonValue>> = vec![vec![json!("B"), json!(r#"[1, "y"]"#)]];
        let df2 = spark
            .create_dataframe_from_rows(rows_array, schema)
            .unwrap();
        assert_eq!(df2.count().unwrap(), 1);
    }

    #[test]
    fn test_issue_611_array_column_single_value_as_one_element() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("arr".to_string(), "array<bigint>".to_string()),
        ];
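        // Row "x" supplies a bare scalar for the array<bigint> column.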
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!(42)],
            vec![json!("y"), json!([1, 2, 3])],
        ];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
        let arr_col = collected.column("arr").unwrap();
        let list = arr_col.list().unwrap();
        let row0 = list.get(0).unwrap();
        assert_eq!(
            row0.len(),
            1,
            "#611: single value should become one-element list"
        );
        let row1 = list.get(1).unwrap();
        assert_eq!(row1.len(), 3);
    }

    #[test]
    fn test_create_dataframe_from_rows_array_column() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("arr".to_string(), "array<bigint>".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!([1, 2, 3])],
            vec![json!("y"), json!([4, 5])],
            vec![json!("z"), json!(null)],
        ];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        assert_eq!(df.count().unwrap(), 3);
        let collected = df.collect_inner().unwrap();
        assert_eq!(collected.get_column_names(), &["id", "arr"]);

        let arr_col = collected.column("arr").unwrap();
        let list = arr_col.list().unwrap();
        let row0 = list.get(0).unwrap();
        assert_eq!(row0.len(), 3, "row 0 arr should have 3 elements");
        let row1 = list.get(1).unwrap();
        assert_eq!(row1.len(), 2);
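        // A JSON null in an array column may surface as a null row or an empty list.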
        let row2 = list.get(2);
        assert!(
            row2.is_none() || row2.as_ref().map(|a| a.is_empty()).unwrap_or(false),
            "row 2 arr should be null or empty"
        );
    }

    #[test]
    fn test_issue_601_array_column_pyspark_parity() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("arr".to_string(), "array<bigint>".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!([1, 2, 3])],
            vec![json!("y"), json!([4, 5])],
        ];
        let df = spark
            .create_dataframe_from_rows(rows, schema)
            .expect("issue #601: create_dataframe_from_rows must accept array column (JSON array)");
        let n = df.count().unwrap();
        assert_eq!(n, 2, "issue #601: expected 2 rows");
        let collected = df.collect_inner().unwrap();
        let arr_col = collected.column("arr").unwrap();
        let list = arr_col.list().unwrap();
        let row0 = list.get(0).unwrap();
        assert_eq!(
            row0.len(),
            3,
            "issue #601: first row arr must have 3 elements [1,2,3]"
        );
        let row1 = list.get(1).unwrap();
        assert_eq!(
            row1.len(),
            2,
            "issue #601: second row arr must have 2 elements [4,5]"
        );
    }

    #[test]
    fn test_issue_624_empty_schema_inferred_from_rows() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema: Vec<(String, String)> = vec![];
        let rows: Vec<Vec<JsonValue>> =
            vec![vec![json!("a"), json!(1)], vec![json!("b"), json!(2)]];
        let df = spark
            .create_dataframe_from_rows(rows, schema)
            .expect("#624: empty schema with non-empty rows should infer schema");
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
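        // Inferred columns get positional names c0, c1, ...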
2216 assert_eq!(collected.get_column_names(), &["c0", "c1"]);
2217 }
2218
    #[test]
    fn test_create_dataframe_from_rows_map_column() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "integer".to_string()),
            ("m".to_string(), "map<string,string>".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!(1), json!({"a": "x", "b": "y"})],
            vec![json!(2), json!({"c": "z"})],
        ];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
        assert_eq!(collected.get_column_names(), &["id", "m"]);
        let m_col = collected.column("m").unwrap();
        let list = m_col.list().unwrap();
        let row0 = list.get(0).unwrap();
        assert_eq!(row0.len(), 2, "row 0 map should have 2 entries");
        let row1 = list.get(1).unwrap();
        assert_eq!(row1.len(), 1, "row 1 map should have 1 entry");
    }

    #[test]
    fn test_issue_625_array_column_list_or_object() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "string".to_string()),
            ("arr".to_string(), "array<bigint>".to_string()),
        ];
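        // Row "y" encodes its array as a JSON object keyed by stringified indices.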
        let rows: Vec<Vec<JsonValue>> = vec![
            vec![json!("x"), json!([1, 2, 3])],
            vec![json!("y"), json!({"0": 4, "1": 5})],
        ];
        let df = spark
            .create_dataframe_from_rows(rows, schema)
            .expect("#625: array column must accept list/array or object representation");
        assert_eq!(df.count().unwrap(), 2);
        let collected = df.collect_inner().unwrap();
        let list = collected.column("arr").unwrap().list().unwrap();
        assert_eq!(list.get(0).unwrap().len(), 3);
        assert_eq!(list.get(1).unwrap().len(), 2);
    }

    #[test]
    fn test_create_dataframe_empty() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let data: Vec<(i64, i64, String)> = vec![];

        let result = spark.create_dataframe(data, vec!["id", "age", "name"]);

        assert!(result.is_ok());
        let df = result.unwrap();
        assert_eq!(df.count().unwrap(), 0);
    }

    #[test]
    fn test_create_dataframe_from_polars() {
        use polars::prelude::df;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let polars_df = df!(
            "x" => &[1, 2, 3],
            "y" => &[4, 5, 6]
        )
        .unwrap();

        let df = spark.create_dataframe_from_polars(polars_df);

        assert_eq!(df.count().unwrap(), 3);
        let columns = df.columns().unwrap();
        assert!(columns.contains(&"x".to_string()));
        assert!(columns.contains(&"y".to_string()));
    }

    #[test]
    fn test_read_csv_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_csv("nonexistent_file.csv");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_parquet_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_parquet("nonexistent_file.parquet");

        assert!(result.is_err());
    }

    #[test]
    fn test_read_json_file_not_found() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

        let result = spark.read_json("nonexistent_file.json");

        assert!(result.is_err());
    }

    #[test]
    fn test_rust_udf_dataframe() {
        use crate::functions::{call_udf, col};
        use polars::prelude::DataType;

        let spark = SparkSession::builder().app_name("test").get_or_create();
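        // Register a session-scoped UDF that casts its first argument to a string.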
        spark
            .register_udf("to_str", |cols| cols[0].cast(&DataType::String))
            .unwrap();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let col = call_udf("to_str", &[col("id")]).unwrap();
        let df2 = df.with_column("id_str", &col).unwrap();
        let cols = df2.columns().unwrap();
        assert!(cols.contains(&"id_str".to_string()));
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("id_str").and_then(|v| v.as_str()), Some("1"));
        assert_eq!(rows[1].get("id_str").and_then(|v| v.as_str()), Some("2"));
    }

    #[test]
    fn test_case_insensitive_filter_select() {
        use crate::expression::lit_i64;
        use crate::functions::col;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 35, "Charlie".to_string()),
                ],
                vec!["Id", "Age", "Name"],
            )
            .unwrap();
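        // Columns were created as "Id"/"Age"/"Name"; the lookups below use lowercase.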
        let filtered = df
            .filter(col("age").gt(lit_i64(26)).expr().clone())
            .unwrap()
            .select(vec!["name"])
            .unwrap();
        assert_eq!(filtered.count().unwrap(), 2);
        let rows = filtered.collect_as_json_rows().unwrap();
        let names: Vec<&str> = rows
            .iter()
            .map(|r| r.get("name").and_then(|v| v.as_str()).unwrap())
            .collect();
        assert!(names.contains(&"Bob"));
        assert!(names.contains(&"Charlie"));
    }

    #[test]
    fn test_sql_returns_error_without_feature_or_unknown_table() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

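        // Without the SQL feature enabled (or with an unregistered table) this must fail.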
        let result = spark.sql("SELECT * FROM table");

        assert!(result.is_err());
        match result {
            Err(PolarsError::InvalidOperation(msg)) => {
                let s = msg.to_string();
                assert!(
                    s.contains("SQL") || s.contains("Table") || s.contains("feature"),
                    "unexpected message: {s}"
                );
            }
            _ => panic!("Expected InvalidOperation error"),
        }
    }

    #[test]
    fn test_spark_session_stop() {
        let spark = SparkSession::builder().app_name("test").get_or_create();

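        // No assertions: stopping a session must simply not panic.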
        spark.stop();
    }

    #[test]
    fn test_dataframe_reader_api() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let reader = spark.read();

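        // Every format reader surfaces a missing-file error.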
        assert!(reader.csv("nonexistent.csv").is_err());
        assert!(reader.parquet("nonexistent.parquet").is_err());
        assert!(reader.json("nonexistent.json").is_err());
    }

    #[test]
    fn test_read_csv_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
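        // CSV fixture: a header plus two data rows.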
2433 writeln!(temp_file, "id,name,age").unwrap();
2434 writeln!(temp_file, "1,Alice,25").unwrap();
2435 writeln!(temp_file, "2,Bob,30").unwrap();
2436 temp_file.flush().unwrap();
2437
2438 let result = spark.read_csv(temp_file.path());
2439
2440 assert!(result.is_ok());
2441 let df = result.unwrap();
2442 assert_eq!(df.count().unwrap(), 2);
2443
2444 let columns = df.columns().unwrap();
2445 assert!(columns.contains(&"id".to_string()));
2446 assert!(columns.contains(&"name".to_string()));
2447 assert!(columns.contains(&"age".to_string()));
2448 }
2449
    #[test]
    fn test_read_json_with_valid_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
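        // NDJSON fixture: one JSON object per line.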
2459 writeln!(temp_file, r#"{{"id":1,"name":"Alice"}}"#).unwrap();
2460 writeln!(temp_file, r#"{{"id":2,"name":"Bob"}}"#).unwrap();
2461 temp_file.flush().unwrap();
2462
2463 let result = spark.read_json(temp_file.path());
2464
2465 assert!(result.is_ok());
2466 let df = result.unwrap();
2467 assert_eq!(df.count().unwrap(), 2);
2468 }
2469
    #[test]
    fn test_read_csv_empty_file() {
        use std::io::Write;
        use tempfile::NamedTempFile;

        let spark = SparkSession::builder().app_name("test").get_or_create();

        let mut temp_file = NamedTempFile::new().unwrap();
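        // Header-only CSV: column names but zero data rows.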
2479 writeln!(temp_file, "id,name").unwrap();
2480 temp_file.flush().unwrap();
2481
2482 let result = spark.read_csv(temp_file.path());
2483
2484 assert!(result.is_ok());
2485 let df = result.unwrap();
2486 assert_eq!(df.count().unwrap(), 0);
2487 }
2488
    #[test]
    fn test_write_partitioned_parquet() {
        use crate::dataframe::{WriteFormat, WriteMode};
        use std::fs;
        use tempfile::TempDir;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![
                    (1, 25, "Alice".to_string()),
                    (2, 30, "Bob".to_string()),
                    (3, 25, "Carol".to_string()),
                ],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("out");
        df.write()
            .mode(WriteMode::Overwrite)
            .format(WriteFormat::Parquet)
            .partition_by(["age"])
            .save(&path)
            .unwrap();
        assert!(path.is_dir());
        let entries: Vec<_> = fs::read_dir(&path).unwrap().collect();
        assert_eq!(
            entries.len(),
            2,
            "expected two partition dirs (age=25, age=30)"
        );
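        // Partition directories use Hive-style age=<value> names.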
        let names: Vec<String> = entries
            .iter()
            .filter_map(|e| e.as_ref().ok())
            .map(|e| e.file_name().to_string_lossy().into_owned())
            .collect();
        assert!(names.iter().any(|n| n.starts_with("age=")));
        let df_read = spark.read_parquet(&path).unwrap();
        assert_eq!(df_read.count().unwrap(), 3);
    }

    #[test]
    fn test_save_as_table_error_if_exists() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df.write()
            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.table("t1").is_ok());
        assert_eq!(spark.table("t1").unwrap().count().unwrap(), 1);
        let err = df
            .write()
            .save_as_table(&spark, "t1", SaveMode::ErrorIfExists)
            .unwrap_err();
        assert!(err.to_string().contains("already exists"));
    }

    #[test]
    fn test_save_as_table_overwrite() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(
                vec![(2, 30, "Bob".to_string()), (3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_over", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 1);
        df2.write()
            .save_as_table(&spark, "t_over", SaveMode::Overwrite)
            .unwrap();
        assert_eq!(spark.table("t_over").unwrap().count().unwrap(), 2);
    }

    #[test]
    fn test_save_as_table_append() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_append", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_append", SaveMode::Append)
            .unwrap();
        assert_eq!(spark.table("t_append").unwrap().count().unwrap(), 2);
    }

    #[test]
    fn test_save_as_table_empty_df_then_append() {
        use crate::dataframe::SaveMode;
        use serde_json::json;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark
            .create_dataframe_from_rows(vec![], schema.clone())
            .unwrap();
        assert_eq!(empty_df.count().unwrap(), 0);

        empty_df
            .write()
            .save_as_table(&spark, "t_empty_append", SaveMode::Overwrite)
            .unwrap();
        let r1 = spark.table("t_empty_append").unwrap();
        assert_eq!(r1.count().unwrap(), 0);
        let cols = r1.columns().unwrap();
        assert!(cols.contains(&"id".to_string()));
        assert!(cols.contains(&"name".to_string()));

        let one_row = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        one_row
            .write()
            .save_as_table(&spark, "t_empty_append", SaveMode::Append)
            .unwrap();
        let r2 = spark.table("t_empty_append").unwrap();
        assert_eq!(r2.count().unwrap(), 1);
    }

    #[test]
    fn test_write_parquet_empty_df_with_schema() {
        let spark = SparkSession::builder().app_name("test").get_or_create();
        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark.create_dataframe_from_rows(vec![], schema).unwrap();
        assert_eq!(empty_df.count().unwrap(), 0);

        let dir = tempfile::TempDir::new().unwrap();
        let path = dir.path().join("empty.parquet");
        empty_df
            .write()
            .format(crate::dataframe::WriteFormat::Parquet)
            .mode(crate::dataframe::WriteMode::Overwrite)
            .save(&path)
            .unwrap();
        assert!(path.is_file());

        let read_df = spark.read().parquet(path.to_str().unwrap()).unwrap();
        assert_eq!(read_df.count().unwrap(), 0);
        let cols = read_df.columns().unwrap();
        assert!(cols.contains(&"id".to_string()));
        assert!(cols.contains(&"name".to_string()));
    }

    #[test]
    fn test_save_as_table_empty_df_warehouse_then_append() {
        use crate::dataframe::SaveMode;
        use serde_json::json;
        use std::sync::atomic::{AtomicU64, Ordering};
        use tempfile::TempDir;

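        // Give each run its own warehouse subdirectory.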
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        let n = COUNTER.fetch_add(1, Ordering::SeqCst);
        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().join(format!("wh_{n}"));
        std::fs::create_dir_all(&warehouse).unwrap();
        let spark = SparkSession::builder()
            .app_name("test")
            .config(
                "spark.sql.warehouse.dir",
                warehouse.as_os_str().to_str().unwrap(),
            )
            .get_or_create();

        let schema = vec![
            ("id".to_string(), "bigint".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let empty_df = spark
            .create_dataframe_from_rows(vec![], schema.clone())
            .unwrap();
        empty_df
            .write()
            .save_as_table(&spark, "t_empty_wh", SaveMode::Overwrite)
            .unwrap();
        let r1 = spark.table("t_empty_wh").unwrap();
        assert_eq!(r1.count().unwrap(), 0);

        let one_row = spark
            .create_dataframe_from_rows(vec![vec![json!(1), json!("a")]], schema)
            .unwrap();
        one_row
            .write()
            .save_as_table(&spark, "t_empty_wh", SaveMode::Append)
            .unwrap();
        let r2 = spark.table("t_empty_wh").unwrap();
        assert_eq!(r2.count().unwrap(), 1);
    }

    #[test]
    fn test_save_as_table_ignore() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df1 = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df2 = spark
            .create_dataframe(vec![(2, 30, "Bob".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df1.write()
            .save_as_table(&spark, "t_ignore", SaveMode::ErrorIfExists)
            .unwrap();
        df2.write()
            .save_as_table(&spark, "t_ignore", SaveMode::Ignore)
            .unwrap();
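        // Ignore leaves the existing table untouched.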
2739 assert_eq!(spark.table("t_ignore").unwrap().count().unwrap(), 1);
2741 }
2742
    #[test]
    fn test_table_resolution_temp_view_first() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df_saved = spark
            .create_dataframe(
                vec![(1, 25, "Saved".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        let df_temp = spark
            .create_dataframe(vec![(2, 30, "Temp".to_string())], vec!["id", "age", "name"])
            .unwrap();
        df_saved
            .write()
            .save_as_table(&spark, "x", SaveMode::ErrorIfExists)
            .unwrap();
        spark.create_or_replace_temp_view("x", df_temp);
2762 let t = spark.table("x").unwrap();
2764 let rows = t.collect_as_json_rows().unwrap();
2765 assert_eq!(rows.len(), 1);
2766 assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Temp"));
2767 }
2768
    #[test]
    fn test_issue_629_temp_view_visible_after_create() {
        use serde_json::json;

        let spark = SparkSession::builder().app_name("repro").get_or_create();
        let schema = vec![
            ("id".to_string(), "long".to_string()),
            ("name".to_string(), "string".to_string()),
        ];
        let rows: Vec<Vec<JsonValue>> =
            vec![vec![json!(1), json!("a")], vec![json!(2), json!("b")]];
        let df = spark.create_dataframe_from_rows(rows, schema).unwrap();
        spark.create_or_replace_temp_view("my_view", df);
        let result = spark
            .table("my_view")
            .unwrap()
            .collect_as_json_rows()
            .unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].get("id").and_then(|v| v.as_i64()), Some(1));
        assert_eq!(result[0].get("name").and_then(|v| v.as_str()), Some("a"));
        assert_eq!(result[1].get("id").and_then(|v| v.as_i64()), Some(2));
        assert_eq!(result[1].get("name").and_then(|v| v.as_str()), Some("b"));
    }

    #[test]
    fn test_drop_table() {
        use crate::dataframe::SaveMode;

        let spark = SparkSession::builder().app_name("test").get_or_create();
        let df = spark
            .create_dataframe(
                vec![(1, 25, "Alice".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df.write()
            .save_as_table(&spark, "t_drop", SaveMode::ErrorIfExists)
            .unwrap();
        assert!(spark.table("t_drop").is_ok());
        assert!(spark.drop_table("t_drop"));
        assert!(spark.table("t_drop").is_err());
        assert!(!spark.drop_table("t_drop"));
    }

    #[test]
    fn test_global_temp_view_persists_across_sessions() {
        let spark1 = SparkSession::builder().app_name("s1").get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark1.create_or_replace_global_temp_view("people", df1);
        assert_eq!(
            spark1.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

        let spark2 = SparkSession::builder().app_name("s2").get_or_create();
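        // A brand-new session must still see the global temp view.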
2834 let df2 = spark2.table("global_temp.people").unwrap();
2835 assert_eq!(df2.count().unwrap(), 2);
2836 let rows = df2.collect_as_json_rows().unwrap();
2837 assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));
2838
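        // A session-local view named "people" must not shadow global_temp.people.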
        let df_local = spark2
            .create_dataframe(
                vec![(3, 35, "Carol".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        spark2.create_or_replace_temp_view("people", df_local);
        assert_eq!(spark2.table("people").unwrap().count().unwrap(), 1);
        assert_eq!(
            spark2.table("global_temp.people").unwrap().count().unwrap(),
            2
        );

        assert!(spark2.drop_global_temp_view("people"));
        assert!(spark2.table("global_temp.people").is_err());
    }

    #[test]
    fn test_warehouse_persistence_between_sessions() {
        use crate::dataframe::SaveMode;
        use std::fs;
        use tempfile::TempDir;

        let dir = TempDir::new().unwrap();
        let warehouse = dir.path().to_str().unwrap();

        let spark1 = SparkSession::builder()
            .app_name("w1")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df1 = spark1
            .create_dataframe(
                vec![(1, 25, "Alice".to_string()), (2, 30, "Bob".to_string())],
                vec!["id", "age", "name"],
            )
            .unwrap();
        df1.write()
            .save_as_table(&spark1, "users", SaveMode::ErrorIfExists)
            .unwrap();
        assert_eq!(spark1.table("users").unwrap().count().unwrap(), 2);

        let spark2 = SparkSession::builder()
            .app_name("w2")
            .config("spark.sql.warehouse.dir", warehouse)
            .get_or_create();
        let df2 = spark2.table("users").unwrap();
        assert_eq!(df2.count().unwrap(), 2);
        let rows = df2.collect_as_json_rows().unwrap();
        assert_eq!(rows[0].get("name").and_then(|v| v.as_str()), Some("Alice"));

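        // The table should be materialized on disk under <warehouse>/users.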
        let table_path = dir.path().join("users");
        assert!(table_path.is_dir());
        let entries: Vec<_> = fs::read_dir(&table_path).unwrap().collect();
        assert!(!entries.is_empty());
    }
}