datafusion_common/
test_util.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18//! Utility functions to make testing DataFusion based crates easier
19
20use std::{error::Error, path::PathBuf};
21
/// Compares the pretty-formatted output of record batches with an expected
/// list of lines. This is a macro so that assertion failures are reported
/// on the line of the calling test.
///
/// Designed so that failure output can be directly copy/pasted
/// into the test code as expected results.
///
/// Expects to be called about like this:
///
/// `assert_batches_eq!(expected_lines: &[&str], batches: &[RecordBatch])`
///
/// The expansion is a single block, so the macro can be used as a statement
/// as well as anywhere an expression of type `()` is accepted (for example
/// as a `match` arm).
///
/// # Panics
///
/// Panics if the batches can not be formatted, or if the formatted lines
/// differ from `expected_lines`.
///
/// # Example
/// ```
/// # use std::sync::Arc;
/// # use arrow::record_batch::RecordBatch;
/// # use arrow::array::{ArrayRef, Int32Array};
/// # use datafusion_common::assert_batches_eq;
/// let col: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
/// let batch = RecordBatch::try_from_iter([("column", col)]).unwrap();
/// // Expected output is a vec of strings
/// let expected = vec![
///     "+--------+",
///     "| column |",
///     "+--------+",
///     "| 1      |",
///     "| 2      |",
///     "+--------+",
/// ];
/// // compare the formatted output of the record batch with the expected output
/// assert_batches_eq!(expected, &[batch]);
/// ```
#[macro_export]
macro_rules! assert_batches_eq {
    ($EXPECTED_LINES: expr, $CHUNKS: expr) => {{
        let expected_lines: Vec<String> =
            $EXPECTED_LINES.iter().map(|&s| s.into()).collect();

        let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options(
            $CHUNKS,
            &$crate::format::DEFAULT_FORMAT_OPTIONS,
        )
        .unwrap()
        .to_string();

        // Leading/trailing whitespace is dropped before splitting into lines
        let actual_lines: Vec<&str> = formatted.trim().lines().collect();

        assert_eq!(
            expected_lines, actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    }};
}
75
/// Compares the pretty-formatted output of record batches with an expected
/// list of lines in a way that row order does not matter.
/// This is a macro so that assertion failures are reported on the line of
/// the calling test.
///
/// Both the expected and the actual lines are sorted (leaving the first two
/// lines and the last line in place) before they are compared.
///
/// See [`assert_batches_eq`] for more details and example.
///
/// Expects to be called about like this:
///
/// `assert_batches_sorted_eq!(expected_lines: &[&str], batches: &[RecordBatch])`
///
/// The expansion is a single block, so the macro can be used as a statement
/// as well as anywhere an expression of type `()` is accepted.
///
/// # Panics
///
/// Panics if the batches can not be formatted, or if the sorted formatted
/// lines differ from the sorted `expected_lines`.
#[macro_export]
macro_rules! assert_batches_sorted_eq {
    ($EXPECTED_LINES: expr, $CHUNKS: expr) => {{
        let mut expected_lines: Vec<String> =
            $EXPECTED_LINES.iter().map(|&s| s.into()).collect();

        // Sort everything except the first two lines and the last line.
        // With the table layout shown in the `assert_batches_eq` example the
        // range starts at the separator under the header; that line begins
        // with `+`, which orders before the `|` that starts every data row,
        // so it stays first.
        let num_lines = expected_lines.len();
        if num_lines > 3 {
            expected_lines.as_mut_slice()[2..num_lines - 1].sort_unstable()
        }

        let formatted = $crate::arrow::util::pretty::pretty_format_batches_with_options(
            $CHUNKS,
            &$crate::format::DEFAULT_FORMAT_OPTIONS,
        )
        .unwrap()
        .to_string();

        // `str::lines` splits on both `\n` and `\r\n`, so no extra handling
        // of Windows line endings is needed here.
        let mut actual_lines: Vec<&str> = formatted.trim().lines().collect();

        // Apply exactly the same partial sort to the actual output
        let num_lines = actual_lines.len();
        if num_lines > 3 {
            actual_lines.as_mut_slice()[2..num_lines - 1].sort_unstable()
        }

        assert_eq!(
            expected_lines, actual_lines,
            "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
            expected_lines, actual_lines
        );
    }};
}
120
/// A macro to assert that one string is contained within another with
/// a nice error message if it is not.
///
/// Usage: `assert_contains!(actual, expected)`
///
/// Is a macro so test error
/// messages are on the same line as the failure.
///
/// Both arguments must be convertible into Strings ([`Into`]<[`String`]>).
///
/// The expansion is a single block, so the macro can be used as a statement
/// as well as anywhere an expression of type `()` is accepted (for example
/// as a `match` arm).
///
/// # Panics
///
/// Panics if `expected` is not a substring of `actual`; the message prints
/// both values.
#[macro_export]
macro_rules! assert_contains {
    ($ACTUAL: expr, $EXPECTED: expr) => {{
        // Convert once up front so each argument expression is evaluated
        // exactly one time.
        let actual_value: String = $ACTUAL.into();
        let expected_value: String = $EXPECTED.into();
        assert!(
            actual_value.contains(&expected_value),
            "Can not find expected in actual.\n\nExpected:\n{}\n\nActual:\n{}",
            expected_value,
            actual_value
        );
    }};
}
143
/// A macro to assert that one string is NOT contained within another with
/// a nice error message if it is.
///
/// Usage: `assert_not_contains!(actual, unexpected)`
///
/// Is a macro so test error
/// messages are on the same line as the failure.
///
/// Both arguments must be convertible into Strings ([`Into`]<[`String`]>).
///
/// The expansion is a single block, so the macro can be used as a statement
/// as well as anywhere an expression of type `()` is accepted (for example
/// as a `match` arm).
///
/// # Panics
///
/// Panics if `unexpected` is a substring of `actual`; the message prints
/// both values.
#[macro_export]
macro_rules! assert_not_contains {
    ($ACTUAL: expr, $UNEXPECTED: expr) => {{
        // Convert once up front so each argument expression is evaluated
        // exactly one time.
        let actual_value: String = $ACTUAL.into();
        let unexpected_value: String = $UNEXPECTED.into();
        assert!(
            !actual_value.contains(&unexpected_value),
            "Found unexpected in actual.\n\nUnexpected:\n{}\n\nActual:\n{}",
            unexpected_value,
            actual_value
        );
    }};
}
166
167/// Returns the datafusion test data directory, which is by default rooted at `datafusion/core/tests/data`.
168///
169/// The default can be overridden by the optional environment
170/// variable `DATAFUSION_TEST_DATA`
171///
172/// panics when the directory can not be found.
173///
174/// Example:
175/// ```
176/// let testdata = datafusion_common::test_util::datafusion_test_data();
177/// let csvdata = format!("{}/window_1.csv", testdata);
178/// assert!(std::path::PathBuf::from(csvdata).exists());
179/// ```
180pub fn datafusion_test_data() -> String {
181    match get_data_dir("DATAFUSION_TEST_DATA", "../../datafusion/core/tests/data") {
182        Ok(pb) => pb.display().to_string(),
183        Err(err) => panic!("failed to get arrow data dir: {err}"),
184    }
185}
186
187/// Returns the arrow test data directory, which is by default stored
188/// in a git submodule rooted at `testing/data`.
189///
190/// The default can be overridden by the optional environment
191/// variable `ARROW_TEST_DATA`
192///
193/// panics when the directory can not be found.
194///
195/// Example:
196/// ```
197/// let testdata = datafusion_common::test_util::arrow_test_data();
198/// let csvdata = format!("{}/csv/aggregate_test_100.csv", testdata);
199/// assert!(std::path::PathBuf::from(csvdata).exists());
200/// ```
201pub fn arrow_test_data() -> String {
202    match get_data_dir("ARROW_TEST_DATA", "../../testing/data") {
203        Ok(pb) => pb.display().to_string(),
204        Err(err) => panic!("failed to get arrow data dir: {err}"),
205    }
206}
207
/// Locates the parquet test data directory and returns it as a string. By
/// default the data lives in a git submodule rooted at
/// `parquet-testing/data`.
///
/// Set the optional environment variable `PARQUET_TEST_DATA` to point at a
/// different location.
///
/// # Panics
///
/// Panics when the directory can not be found.
///
/// Example:
/// ```
/// let testdata = datafusion_common::test_util::parquet_test_data();
/// let filename = format!("{}/binary.parquet", testdata);
/// assert!(std::path::PathBuf::from(filename).exists());
/// ```
#[cfg(feature = "parquet")]
pub fn parquet_test_data() -> String {
    let located = get_data_dir("PARQUET_TEST_DATA", "../../parquet-testing/data");
    let dir = match located {
        Ok(dir) => dir,
        Err(err) => panic!("failed to get parquet data dir: {err}"),
    };
    dir.display().to_string()
}
230
231/// Returns a directory path for finding test data.
232///
233/// udf_env: name of an environment variable
234///
235/// submodule_dir: fallback path (relative to CARGO_MANIFEST_DIR)
236///
237///  Returns either:
238/// The path referred to in `udf_env` if that variable is set and refers to a directory
239/// The submodule_data directory relative to CARGO_MANIFEST_PATH
240pub fn get_data_dir(
241    udf_env: &str,
242    submodule_data: &str,
243) -> Result<PathBuf, Box<dyn Error>> {
244    // Try user defined env.
245    if let Ok(dir) = std::env::var(udf_env) {
246        let trimmed = dir.trim().to_string();
247        if !trimmed.is_empty() {
248            let pb = PathBuf::from(trimmed);
249            if pb.is_dir() {
250                return Ok(pb);
251            } else {
252                return Err(format!(
253                    "the data dir `{}` defined by env {} not found",
254                    pb.display(),
255                    udf_env
256                )
257                .into());
258            }
259        }
260    }
261
262    // The env is undefined or its value is trimmed to empty, let's try default dir.
263
264    // env "CARGO_MANIFEST_DIR" is "the directory containing the manifest of your package",
265    // set by `cargo run` or `cargo test`, see:
266    // https://doc.rust-lang.org/cargo/reference/environment-variables.html
267    let dir = env!("CARGO_MANIFEST_DIR");
268
269    let pb = PathBuf::from(dir).join(submodule_data);
270    if pb.is_dir() {
271        Ok(pb)
272    } else {
273        Err(format!(
274            "env `{}` is undefined or has empty value, and the pre-defined data dir `{}` not found\n\
275             HINT: try running `git submodule update --init`",
276            udf_env,
277            pb.display(),
278        ).into())
279    }
280}
281
/// Creates an arrow array of the named type from `$values`, wrapped in an
/// [`Arc`](std::sync::Arc).
///
/// Usage: `create_array!(Int32, vec![1, 2, 3])`
///
/// The first argument selects the concrete array type and must be one of:
/// `Boolean`, `Int8`, `Int16`, `Int32`, `Int64`, `UInt8`, `UInt16`,
/// `UInt32`, `UInt64`, `Float16`, `Float32`, `Float64` or `Utf8` (which
/// builds a `StringArray`). The second argument is handed to that array
/// type's `From` implementation.
///
/// The array types are reached through `$crate::arrow`, so callers do not
/// need to depend on the `arrow` crate themselves.
#[macro_export]
macro_rules! create_array {
    (Boolean, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::BooleanArray::from($values))
    };
    (Int8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int8Array::from($values))
    };
    (Int16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int16Array::from($values))
    };
    (Int32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int32Array::from($values))
    };
    (Int64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Int64Array::from($values))
    };
    (UInt8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt8Array::from($values))
    };
    (UInt16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt16Array::from($values))
    };
    (UInt32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt32Array::from($values))
    };
    (UInt64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::UInt64Array::from($values))
    };
    (Float16, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float16Array::from($values))
    };
    (Float32, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float32Array::from($values))
    };
    (Float64, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::Float64Array::from($values))
    };
    (Utf8, $values: expr) => {
        ::std::sync::Arc::new($crate::arrow::array::StringArray::from($values))
    };
}
324
/// Creates a record batch from literal slice of values, suitable for rapid
/// testing and development.
///
/// Each column is given as a `(name, type, values)` tuple, where `type` is
/// one of the identifiers accepted by [`create_array`] and also names the
/// `DataType` variant used for the field. Every field is declared nullable.
/// A trailing comma after the last column is allowed.
///
/// The macro evaluates to the `Result` returned by `RecordBatch::try_new`;
/// it does not unwrap it, so the caller decides how to handle a failure.
///
/// The arrow types are reached through `$crate::arrow`, so callers do not
/// need to depend on the `arrow` crate themselves.
///
/// Example:
/// ```
/// use datafusion_common::{record_batch, create_array};
/// let batch = record_batch!(
///     ("a", Int32, vec![1, 2, 3]),
///     ("b", Float64, vec![Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, vec!["alpha", "beta", "gamma"])
/// );
/// ```
#[macro_export]
macro_rules! record_batch {
    ($(($name: expr, $type: ident, $values: expr)),* $(,)?) => {
        {
            // One nullable field per column, in the order the columns were given
            let schema = ::std::sync::Arc::new($crate::arrow::datatypes::Schema::new(vec![
                $(
                    $crate::arrow::datatypes::Field::new(
                        $name,
                        $crate::arrow::datatypes::DataType::$type,
                        true,
                    ),
                )*
            ]));

            $crate::arrow::array::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, $values),
                )*]
            )
        }
    }
}
358
#[cfg(test)]
mod tests {
    use crate::cast::{as_float64_array, as_int32_array, as_string_array};
    use crate::error::Result;

    use super::*;
    use std::env;

    /// Walks `get_data_dir` through each state of the override variable:
    /// pointing at a missing directory, empty, blank, pointing at an
    /// existing directory, and finally unset.
    #[test]
    fn test_data_dir() {
        let udf_env = "get_data_dir";
        let cwd = env::current_dir().unwrap();

        // Built from the current directory, so joining it onto the manifest
        // dir is expected to yield this same path back.
        let existing_pb = cwd.join("..");
        let existing = existing_pb.display().to_string();
        let missing = cwd.join("non-existing-dir").display().to_string();

        // Override names a missing directory: an error, the fallback is not used
        env::set_var(udf_env, &missing);
        assert!(get_data_dir(udf_env, &existing).is_err());

        // Empty or whitespace-only override: the fallback is used
        for blank in ["", " "] {
            env::set_var(udf_env, blank);
            let resolved = get_data_dir(udf_env, &existing);
            assert!(resolved.is_ok());
            assert_eq!(resolved.unwrap(), existing_pb);
        }

        // Override names an existing directory: it is returned
        env::set_var(udf_env, &existing);
        let resolved = get_data_dir(udf_env, &existing);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), existing_pb);

        // Override unset: the outcome depends only on the fallback
        env::remove_var(udf_env);
        assert!(get_data_dir(udf_env, &missing).is_err());

        let resolved = get_data_dir(udf_env, &existing);
        assert!(resolved.is_ok());
        assert_eq!(resolved.unwrap(), existing_pb);
    }

    /// Both default data directories resolve to directories on disk.
    #[test]
    #[cfg(feature = "parquet")]
    fn test_happy() {
        let arrow_dir = PathBuf::from(arrow_test_data());
        assert!(arrow_dir.is_dir());

        let parquet_dir = PathBuf::from(parquet_test_data());
        assert!(parquet_dir.is_dir());
    }

    /// `record_batch!` builds the requested columns with the given values
    /// and null positions.
    #[test]
    fn test_create_record_batch() -> Result<()> {
        use arrow::array::Array;

        let batch = record_batch!(
            ("a", Int32, vec![1, 2, 3, 4]),
            ("b", Float64, vec![Some(4.0), None, Some(5.0), None]),
            ("c", Utf8, vec!["alpha", "beta", "gamma", "delta"])
        )?;

        assert_eq!(3, batch.num_columns());
        assert_eq!(4, batch.num_rows());

        let ints = as_int32_array(batch.column(0))?;
        let int_values: Vec<_> = ints.values().iter().copied().collect();
        assert_eq!(int_values, vec![1, 2, 3, 4]);

        let floats = as_float64_array(batch.column(1))?;
        // The raw value buffer holds 0.0 in the slots that are null
        let float_values: Vec<_> = floats.values().iter().copied().collect();
        assert_eq!(float_values, vec![4.0, 0.0, 5.0, 0.0]);

        let validity: Vec<_> = floats.nulls().unwrap().iter().collect();
        assert_eq!(validity, vec![true, false, true, false]);

        let strings = as_string_array(batch.column(2))?;
        let string_values: Vec<_> = strings.iter().flatten().collect();
        assert_eq!(string_values, vec!["alpha", "beta", "gamma", "delta"]);

        Ok(())
    }
}
456}