datafusion/test_util/mod.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

//! Utility functions to make testing DataFusion-based crates easier

#[cfg(feature = "parquet")]
pub mod parquet;

pub mod csv;

use futures::Stream;
use std::any::Any;
use std::collections::HashMap;
use std::fs::File;
use std::io::Write;
use std::path::Path;
use std::sync::Arc;
use std::task::{Context, Poll};

use crate::catalog::{TableProvider, TableProviderFactory};
use crate::dataframe::DataFrame;
use crate::datasource::stream::{FileStreamProvider, StreamConfig, StreamTable};
use crate::datasource::{empty::EmptyTable, provider_as_source};
use crate::error::Result;
use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
use crate::physical_plan::ExecutionPlan;
use crate::prelude::{CsvReadOptions, SessionContext};

use crate::execution::SendableRecordBatchStream;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
use datafusion_catalog::Session;
use datafusion_common::TableReference;
use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType};
use std::pin::Pin;

use async_trait::async_trait;

use tempfile::TempDir;
// backwards compatibility
#[cfg(feature = "parquet")]
pub use datafusion_common::test_util::parquet_test_data;
pub use datafusion_common::test_util::{arrow_test_data, get_data_dir};

use crate::execution::RecordBatchStream;
/// Scan an empty data source, mainly used in tests
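///
/// A minimal usage sketch (not compiled as a doctest); the table name and
/// single-column schema below are illustrative only:
///
/// ```ignore
/// use arrow::datatypes::{DataType, Field, Schema};
///
/// // Hypothetical one-column schema for the empty scan.
/// let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
/// let plan = scan_empty(Some("empty"), &schema, None)?.build()?;
/// ```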
pub fn scan_empty(
    name: Option<&str>,
    table_schema: &Schema,
    projection: Option<Vec<usize>>,
) -> Result<LogicalPlanBuilder> {
    let table_schema = Arc::new(table_schema.clone());
    let provider = Arc::new(EmptyTable::new(table_schema));
    let name = TableReference::bare(name.unwrap_or(UNNAMED_TABLE));
    LogicalPlanBuilder::scan(name, provider_as_source(provider), projection)
}

/// Scan an empty data source with a configured number of partitions, mainly used in tests.
pub fn scan_empty_with_partitions(
    name: Option<&str>,
    table_schema: &Schema,
    projection: Option<Vec<usize>>,
    partitions: usize,
) -> Result<LogicalPlanBuilder> {
    let table_schema = Arc::new(table_schema.clone());
    let provider = Arc::new(EmptyTable::new(table_schema).with_partitions(partitions));
    let name = TableReference::bare(name.unwrap_or(UNNAMED_TABLE));
    LogicalPlanBuilder::scan(name, provider_as_source(provider), projection)
}

/// Get the schema for the aggregate_test_* CSV files
pub fn aggr_test_schema() -> SchemaRef {
    let mut f1 = Field::new("c1", DataType::Utf8, false);
    f1.set_metadata(HashMap::from_iter(vec![("testing".into(), "test".into())]));
    let schema = Schema::new(vec![
        f1,
        Field::new("c2", DataType::UInt32, false),
        Field::new("c3", DataType::Int8, false),
        Field::new("c4", DataType::Int16, false),
        Field::new("c5", DataType::Int32, false),
        Field::new("c6", DataType::Int64, false),
        Field::new("c7", DataType::UInt8, false),
        Field::new("c8", DataType::UInt16, false),
        Field::new("c9", DataType::UInt32, false),
        Field::new("c10", DataType::UInt64, false),
        Field::new("c11", DataType::Float32, false),
        Field::new("c12", DataType::Float64, false),
        Field::new("c13", DataType::Utf8, false),
    ]);

    Arc::new(schema)
}

/// Register the aggregate_test_100.csv file with the given [`SessionContext`] under `table_name`
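///
/// A hedged usage sketch (not compiled as a doctest); it assumes the arrow
/// testing data is available to [`arrow_test_data`] (typically via the
/// `ARROW_TEST_DATA` environment variable):
///
/// ```ignore
/// let ctx = SessionContext::new();
/// // Registers arrow-testing's aggregate_test_100.csv under the given name.
/// register_aggregate_csv(&ctx, "aggregate_test_100").await?;
/// let df = ctx.table("aggregate_test_100").await?;
/// ```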
pub async fn register_aggregate_csv(
    ctx: &SessionContext,
    table_name: &str,
) -> Result<()> {
    let schema = aggr_test_schema();
    let testdata = arrow_test_data();
    ctx.register_csv(
        table_name,
        &format!("{testdata}/csv/aggregate_test_100.csv"),
        CsvReadOptions::new().schema(schema.as_ref()),
    )
    .await?;
    Ok(())
}

/// Create a table from the aggregate_test_100.csv file with the specified name
pub async fn test_table_with_name(name: &str) -> Result<DataFrame> {
    let ctx = SessionContext::new();
    register_aggregate_csv(&ctx, name).await?;
    ctx.table(name).await
}

/// Create a table from the aggregate_test_100.csv file with the name "aggregate_test_100"
pub async fn test_table() -> Result<DataFrame> {
    test_table_with_name("aggregate_test_100").await
}

/// Execute SQL and return results
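///
/// A minimal sketch of intended use (not compiled as a doctest); the table
/// registered below is illustrative:
///
/// ```ignore
/// let ctx = SessionContext::new();
/// register_aggregate_csv(&ctx, "t").await?;
/// // Plans the SQL against `ctx` and collects all record batches.
/// let batches = plan_and_collect(&ctx, "SELECT c1, c2 FROM t LIMIT 5").await?;
/// ```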
pub async fn plan_and_collect(
    ctx: &SessionContext,
    sql: &str,
) -> Result<Vec<RecordBatch>> {
    ctx.sql(sql).await?.collect().await
}

/// Generate CSV partitions within the supplied directory
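///
/// A hedged usage sketch (not compiled as a doctest): write two partitions
/// into a temporary directory and register that directory as a headerless CSV
/// table:
///
/// ```ignore
/// let tmp_dir = tempfile::TempDir::new()?;
/// let schema = populate_csv_partitions(&tmp_dir, 2, "csv")?;
///
/// let ctx = SessionContext::new();
/// ctx.register_csv(
///     "test",
///     tmp_dir.path().to_str().unwrap(),
///     // The generated files carry no header row.
///     CsvReadOptions::new().schema(schema.as_ref()).has_header(false),
/// )
/// .await?;
/// ```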
pub fn populate_csv_partitions(
    tmp_dir: &TempDir,
    partition_count: usize,
    file_extension: &str,
) -> Result<SchemaRef> {
    // define schema for data source (csv file)
    let schema = Arc::new(Schema::new(vec![
        Field::new("c1", DataType::UInt32, false),
        Field::new("c2", DataType::UInt64, false),
        Field::new("c3", DataType::Boolean, false),
    ]));

    // generate a partitioned file
    for partition in 0..partition_count {
        let filename = format!("partition-{partition}.{file_extension}");
        let file_path = tmp_dir.path().join(filename);
        let mut file = File::create(file_path)?;

        // generate some data
        for i in 0..=10 {
            let data = format!("{},{},{}\n", partition, i, i % 2 == 0);
            file.write_all(data.as_bytes())?;
        }
    }

    Ok(schema)
}

/// [`TableProviderFactory`] for tests
#[derive(Default, Debug)]
pub struct TestTableFactory {}

#[async_trait]
impl TableProviderFactory for TestTableFactory {
    async fn create(
        &self,
        _: &dyn Session,
        cmd: &CreateExternalTable,
    ) -> Result<Arc<dyn TableProvider>> {
        Ok(Arc::new(TestTableProvider {
            url: cmd.location.to_string(),
            schema: Arc::new(cmd.schema.as_ref().into()),
        }))
    }
}

/// Stub [`TableProvider`] for testing purposes
#[derive(Debug)]
pub struct TestTableProvider {
    /// URL of table files or folder
    pub url: String,
    /// test table schema
    pub schema: SchemaRef,
}

impl TestTableProvider {}

#[async_trait]
impl TableProvider for TestTableProvider {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> SchemaRef {
        Arc::clone(&self.schema)
    }

    fn table_type(&self) -> TableType {
        unimplemented!("TestTableProvider is a stub for testing.")
    }

    async fn scan(
        &self,
        _state: &dyn Session,
        _projection: Option<&Vec<usize>>,
        _filters: &[Expr],
        _limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        unimplemented!("TestTableProvider is a stub for testing.")
    }
}

/// Register an unbounded, sorted, file-backed [`StreamTable`] with the given [`SessionContext`] for testing purposes
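///
/// A hedged sketch of intended use (not compiled as a doctest); the file path,
/// schema, and ordering are illustrative, and `col` is assumed to come from
/// `datafusion_expr`:
///
/// ```ignore
/// use datafusion_expr::col;
///
/// let ctx = SessionContext::new();
/// let schema = aggr_test_schema();
/// // Declare the file as sorted ascending on `c1` (nulls last).
/// let sort_order = vec![vec![col("c1").sort(true, false)]];
/// register_unbounded_file_with_ordering(
///     &ctx,
///     schema,
///     Path::new("/path/to/a_sorted_file.csv"),
///     "unbounded_table",
///     sort_order,
/// )?;
/// ```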
pub fn register_unbounded_file_with_ordering(
    ctx: &SessionContext,
    schema: SchemaRef,
    file_path: &Path,
    table_name: &str,
    file_sort_order: Vec<Vec<SortExpr>>,
) -> Result<()> {
    let source = FileStreamProvider::new_file(schema, file_path.into());
    let config = StreamConfig::new(Arc::new(source)).with_order(file_sort_order);

    // Register table:
    ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?;
    Ok(())
}

/// Creates a bounded stream that emits the same record batch a specified number of times.
/// This is useful for testing purposes.
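///
/// A minimal usage sketch (not compiled as a doctest); the one-column batch is
/// illustrative:
///
/// ```ignore
/// use arrow::array::Int32Array;
///
/// let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
/// let batch =
///     RecordBatch::try_new(schema, vec![Arc::new(Int32Array::from(vec![1, 2, 3]))])?;
/// // The stream yields the same batch 10 times, then terminates.
/// let stream = bounded_stream(batch, 10);
/// ```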
pub fn bounded_stream(
    record_batch: RecordBatch,
    limit: usize,
) -> SendableRecordBatchStream {
    Box::pin(BoundedStream {
        record_batch,
        count: 0,
        limit,
    })
}

struct BoundedStream {
    record_batch: RecordBatch,
    count: usize,
    limit: usize,
}

impl Stream for BoundedStream {
    type Item = Result<RecordBatch, crate::error::DataFusionError>;

    fn poll_next(
        mut self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<Option<Self::Item>> {
        if self.count >= self.limit {
            Poll::Ready(None)
        } else {
            self.count += 1;
            Poll::Ready(Some(Ok(self.record_batch.clone())))
        }
    }
}

impl RecordBatchStream for BoundedStream {
    fn schema(&self) -> SchemaRef {
        self.record_batch.schema()
    }
}