1use std::{error::Error, path::PathBuf};
21
/// Asserts that record batches pretty-print to exactly the expected lines.
///
/// `$EXPECTED_LINES` is iterated with `.iter()` and every (copied) item is converted into
/// a `String`, so a slice or array of `&str` works. `$CHUNKS` is rendered with
/// `pretty_format_batches_with_options` using `DEFAULT_FORMAT_OPTIONS`; the rendered text
/// is trimmed and split into lines before being compared line-for-line.
///
/// # Panics
///
/// Panics (via `unwrap`) if formatting fails, and panics with both line vectors printed
/// using `{:#?}` if the lines differ.
#[macro_export]
macro_rules! assert_batches_eq {
    ($EXPECTED_LINES: expr, $CHUNKS: expr) => {
        let expected_lines: Vec<String> = $EXPECTED_LINES
            .iter()
            .map(|&line| line.into())
            .collect();

        let rendered = $crate::arrow::util::pretty::pretty_format_batches_with_options(
            $CHUNKS,
            &$crate::format::DEFAULT_FORMAT_OPTIONS,
        )
        .unwrap()
        .to_string();
        // Surrounding whitespace is ignored; comparison is per line.
        let trimmed = rendered.trim();
        let actual_lines: Vec<&str> = trimmed.lines().collect();

        assert_eq!(
            expected_lines, actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    };
}
75
/// Asserts that record batches pretty-print to the expected lines, ignoring row order.
///
/// Behaves like `assert_batches_eq!`, except that both the expected and the actual line
/// vectors are partially sorted before comparison. The first two lines and the last line
/// stay in place, and every line in between is sorted with `sort_unstable`. This is only
/// done when there are more than three lines.
///
/// NOTE(review): index 2 (in pretty-printed tables presumably the `+---+` separator under
/// the header) is inside the sorted range; it ends up first only because `+` sorts before
/// `|` — confirm if the table format ever changes.
///
/// # Panics
///
/// Panics (via `unwrap`) if formatting fails, or if the sorted lines differ.
#[macro_export]
macro_rules! assert_batches_sorted_eq {
    ($EXPECTED_LINES: expr, $CHUNKS: expr) => {
        let mut expected_lines: Vec<String> = $EXPECTED_LINES
            .iter()
            .map(|&line| line.into())
            .collect();

        // Keep the two leading lines and the trailing line fixed; sort everything between.
        let expected_len = expected_lines.len();
        if expected_len > 3 {
            expected_lines[2..expected_len - 1].sort_unstable();
        }

        let rendered = $crate::arrow::util::pretty::pretty_format_batches_with_options(
            $CHUNKS,
            &$crate::format::DEFAULT_FORMAT_OPTIONS,
        )
        .unwrap()
        .to_string();

        let mut actual_lines: Vec<&str> = rendered.trim().lines().collect();
        let actual_len = actual_lines.len();
        if actual_len > 3 {
            actual_lines[2..actual_len - 1].sort_unstable();
        }

        assert_eq!(
            expected_lines, actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    };
}
120
/// Asserts that `$ACTUAL` contains `$EXPECTED` as a substring.
///
/// Both arguments are evaluated in order (`$ACTUAL` first) and converted into an owned
/// `String` with `.into()`, so `&str`, `String`, or anything else implementing
/// `Into<String>` is accepted.
///
/// # Panics
///
/// Panics with both values printed in full when the substring is absent.
#[macro_export]
macro_rules! assert_contains {
    ($ACTUAL: expr, $EXPECTED: expr) => {
        let haystack: String = $ACTUAL.into();
        let needle: String = $EXPECTED.into();
        let found = haystack.contains(needle.as_str());
        assert!(
            found,
            "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}",
            needle,
            haystack
        );
    };
}
143
/// Asserts that `$ACTUAL` does NOT contain `$UNEXPECTED` as a substring.
///
/// Both arguments are evaluated in order (`$ACTUAL` first) and converted into an owned
/// `String` with `.into()`, so `&str`, `String`, or anything else implementing
/// `Into<String>` is accepted.
///
/// # Panics
///
/// Panics with both values printed in full when the substring is present.
#[macro_export]
macro_rules! assert_not_contains {
    ($ACTUAL: expr, $UNEXPECTED: expr) => {
        let haystack: String = $ACTUAL.into();
        let forbidden: String = $UNEXPECTED.into();
        let present = haystack.contains(forbidden.as_str());
        assert!(
            !present,
            "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}",
            forbidden,
            haystack
        );
    };
}
166
167pub fn datafusion_test_data() -> String {
181 match get_data_dir("DATAFUSION_TEST_DATA", "../../datafusion/core/tests/data") {
182 Ok(pb) => pb.display().to_string(),
183 Err(err) => panic!("failed to get arrow data dir: {err}"),
184 }
185}
186
187pub fn arrow_test_data() -> String {
202 match get_data_dir("ARROW_TEST_DATA", "../../testing/data") {
203 Ok(pb) => pb.display().to_string(),
204 Err(err) => panic!("failed to get arrow data dir: {err}"),
205 }
206}
207
#[cfg(feature = "parquet")]
/// Returns the path of the Parquet test data directory as a `String`.
///
/// Only compiled with the `parquet` feature. The location is resolved by `get_data_dir`.
/// A non-blank `PARQUET_TEST_DATA` environment variable takes precedence. Otherwise
/// `../../parquet-testing/data`, relative to this crate's `CARGO_MANIFEST_DIR`, is used.
///
/// # Panics
///
/// Panics if the directory cannot be found.
pub fn parquet_test_data() -> String {
    get_data_dir("PARQUET_TEST_DATA", "../../parquet-testing/data")
        .map(|dir| dir.display().to_string())
        .unwrap_or_else(|err| panic!("failed to get parquet data dir: {err}"))
}
230
231pub fn get_data_dir(
241 udf_env: &str,
242 submodule_data: &str,
243) -> Result<PathBuf, Box<dyn Error>> {
244 if let Ok(dir) = std::env::var(udf_env) {
246 let trimmed = dir.trim().to_string();
247 if !trimmed.is_empty() {
248 let pb = PathBuf::from(trimmed);
249 if pb.is_dir() {
250 return Ok(pb);
251 } else {
252 return Err(format!(
253 "the data dir `{}` defined by env {} not found",
254 pb.display(),
255 udf_env
256 )
257 .into());
258 }
259 }
260 }
261
262 let dir = env!("CARGO_MANIFEST_DIR");
268
269 let pb = PathBuf::from(dir).join(submodule_data);
270 if pb.is_dir() {
271 Ok(pb)
272 } else {
273 Err(format!(
274 "env `{}` is undefined or has empty value, and the pre-defined data dir `{}` not found\n\
275 HINT: try running `git submodule update --init`",
276 udf_env,
277 pb.display(),
278 ).into())
279 }
280}
281
/// Creates an `Arc`-wrapped Arrow array of the named type from `$values`.
///
/// The first argument selects the concrete array type. It must be one of `Boolean`,
/// `Int8`, `Int16`, `Int32`, `Int64`, `UInt8`, `UInt16`, `UInt32`, `UInt64`, `Float16`,
/// `Float32`, `Float64` or `Utf8`. The second argument is anything that array type
/// implements `From` for, e.g. `vec![1, 2, 3]`, `vec![Some(1.0), None]` or
/// `vec!["a", "b"]`.
///
/// Array types are reached through `$crate::arrow` (this crate's re-export, as the other
/// exported macros in this module do), so callers do not need their own `arrow` in scope.
#[macro_export]
macro_rules! create_array {
    (Boolean, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::BooleanArray::from($values))
    };
    (Int8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int8Array::from($values))
    };
    (Int16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int16Array::from($values))
    };
    (Int32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int32Array::from($values))
    };
    (Int64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int64Array::from($values))
    };
    (UInt8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt8Array::from($values))
    };
    (UInt16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt16Array::from($values))
    };
    (UInt32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt32Array::from($values))
    };
    (UInt64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt64Array::from($values))
    };
    (Float16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float16Array::from($values))
    };
    (Float32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float32Array::from($values))
    };
    (Float64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float64Array::from($values))
    };
    (Utf8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::StringArray::from($values))
    };
}
324
/// Builds an Arrow `RecordBatch` from `(name, type, values)` column triples.
///
/// Each triple contributes a nullable `Field` named `name` with `DataType::<type>`. Its
/// column is built by `create_array!(type, values)`, so `type` must be one of the names
/// that macro accepts. A trailing comma after the last triple is allowed.
///
/// The macro evaluates to the `Result` returned by `RecordBatch::try_new`. It is NOT
/// unwrapped, so callers typically follow it with `?`.
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, $values: expr)),* $(,)?) => {
        {
            let schema = ::std::sync::Arc::new($crate::arrow::datatypes::Schema::new(vec![
                $(
                    $crate::arrow::datatypes::Field::new(
                        $name,
                        $crate::arrow::datatypes::DataType::$type,
                        true,
                    ),
                )*
            ]));

            $crate::arrow::array::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, $values),
                )*],
            )
        }
    };
}
358
#[cfg(test)]
mod tests {
    use crate::cast::{as_float64_array, as_int32_array, as_string_array};
    use crate::error::Result;

    use super::*;
    use std::env;

    /// Walks `get_data_dir` through every env-var state: missing dir, blank values,
    /// an explicit existing dir, and unset.
    #[test]
    fn test_data_dir() {
        // Env var name used only by this test, so mutating it does not affect others.
        let var_name = "get_data_dir";
        let cwd = env::current_dir().unwrap();

        let present_dir = cwd.join("..");
        let present = present_dir.display().to_string();
        let missing = cwd.join("non-existing-dir").display().to_string();

        // A configured but missing dir is an error, even though the fallback exists.
        env::set_var(var_name, &missing);
        assert!(get_data_dir(var_name, &present).is_err());

        // Empty and whitespace-only values are ignored in favour of the fallback.
        for blank in ["", " "] {
            env::set_var(var_name, blank);
            let resolved = get_data_dir(var_name, &present);
            assert!(resolved.is_ok());
            assert_eq!(resolved.unwrap(), present_dir);
        }

        // An explicitly configured existing dir is returned as-is.
        env::set_var(var_name, &present);
        let resolved = get_data_dir(var_name, &present);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), present_dir);

        // With the variable unset, only the fallback path decides the outcome.
        env::remove_var(var_name);
        assert!(get_data_dir(var_name, &missing).is_err());

        let resolved = get_data_dir(var_name, &present);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), present_dir);
    }

    /// The default test-data locations exist in a checkout with submodules initialised.
    #[test]
    #[cfg(feature = "parquet")]
    fn test_happy() {
        for dir in [arrow_test_data(), parquet_test_data()] {
            assert!(PathBuf::from(dir).is_dir());
        }
    }

    /// `record_batch!` produces the expected schema shape, values, and null mask.
    #[test]
    fn test_create_record_batch() -> Result<()> {
        use arrow::array::Array;

        let batch = record_batch!(
            ("a", Int32, vec![1, 2, 3, 4]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0), None]),
            ("c", Utf8, vec!["alpha", "beta", "gamma", "delta"])
        )?;

        assert_eq!(3, batch.num_columns());
        assert_eq!(4, batch.num_rows());

        let ints: Vec<_> = as_int32_array(batch.column(0))?
            .values()
            .iter()
            .copied()
            .collect();
        assert_eq!(ints, vec![1, 2, 3, 4]);

        let float_column = as_float64_array(batch.column(1))?;
        // This test expects null slots to hold 0.0 in the values buffer.
        let floats: Vec<_> = float_column.values().iter().copied().collect();
        assert_eq!(floats, vec![4.0, 0.0, 5.0, 0.0]);

        let validity: Vec<_> = float_column.nulls().unwrap().iter().collect();
        assert_eq!(validity, vec![true, false, true, false]);

        let strings: Vec<_> = as_string_array(batch.column(2))?.iter().flatten().collect();
        assert_eq!(strings, vec!["alpha", "beta", "gamma", "delta"]);

        Ok(())
    }
}