use std::sync::Arc;
use crate::datasource::file_format::{
format_as_file_type, parquet::ParquetFormatFactory,
};
use super::{
DataFrame, DataFrameWriteOptions, DataFusionError, LogicalPlanBuilder, RecordBatch,
};
use datafusion_common::config::TableParquetOptions;
use datafusion_common::not_impl_err;
use datafusion_expr::dml::InsertOp;
impl DataFrame {
    /// Execute this `DataFrame` and write the resulting batches as Parquet to
    /// `path`, returning the batches produced by the write operation.
    ///
    /// * `options` — generic write options (insert mode, sort order,
    ///   partitioning, single-file output).
    /// * `writer_options` — Parquet-specific settings; when `None`, default
    ///   Parquet options are used.
    ///
    /// # Errors
    /// Returns a "not implemented" error for any insert mode other than
    /// [`InsertOp::Append`], and propagates any planning or execution failure.
    pub async fn write_parquet(
        self,
        path: &str,
        options: DataFrameWriteOptions,
        writer_options: Option<TableParquetOptions>,
    ) -> Result<Vec<RecordBatch>, DataFusionError> {
        // Only append semantics are supported when writing Parquet.
        if options.insert_op != InsertOp::Append {
            return not_impl_err!(
                "{} is not implemented for DataFrame::write_parquet.",
                options.insert_op
            );
        }

        // Build the Parquet format factory, honoring caller-supplied writer
        // options when present.
        let format = match writer_options {
            Some(parquet_opts) => {
                Arc::new(ParquetFormatFactory::new_with_options(parquet_opts))
            }
            None => Arc::new(ParquetFormatFactory::new()),
        };
        let file_type = format_as_file_type(format);
        let copy_options = options.build_sink_options();

        // Apply the requested output ordering, if any, before the copy.
        let mut plan = self.plan;
        if !options.sort_by.is_empty() {
            plan = LogicalPlanBuilder::from(plan)
                .sort(options.sort_by)?
                .build()?;
        }

        // Wrap the plan in a COPY TO node targeting `path`.
        let plan = LogicalPlanBuilder::copy_to(
            plan,
            path.into(),
            file_type,
            copy_options,
            options.partition_by,
        )?
        .build()?;

        // Executing the COPY plan performs the write; collect its output.
        let writer_df = DataFrame {
            session_state: self.session_state,
            plan,
            projection_requires_validation: self.projection_requires_validation,
        };
        writer_df.collect().await
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::super::Result;
    use super::*;
    use crate::arrow::util::pretty;
    use crate::execution::context::SessionContext;
    use crate::execution::options::ParquetReadOptions;
    use crate::test_util::{self, register_aggregate_csv};
    use datafusion_common::file_options::parquet_writer::parse_compression_string;
    use datafusion_execution::config::SessionConfig;
    use datafusion_expr::{col, lit};
    #[cfg(feature = "parquet_encryption")]
    use datafusion_common::config::ConfigFileEncryptionProperties;
    use object_store::local::LocalFileSystem;
    use parquet::file::reader::FileReader;
    use tempfile::TempDir;
    use url::Url;

    /// A filter applied through the DataFrame API on a Parquet-backed view
    /// should be pushed down into the physical plan: the explain output must
    /// contain a `FilterExec` node for the predicate.
    #[tokio::test]
    async fn filter_pushdown_dataframe() -> Result<()> {
        let ctx = SessionContext::new();
        ctx.register_parquet(
            "test",
            &format!(
                "{}/alltypes_plain.snappy.parquet",
                test_util::parquet_test_data()
            ),
            ParquetReadOptions::default(),
        )
        .await?;
        // Register a view over the Parquet table so the filter crosses a
        // view boundary before pushdown.
        ctx.register_table("t1", ctx.table("test").await?.into_view())?;
        let df = ctx
            .table("t1")
            .await?
            .filter(col("id").eq(lit(1)))?
            .select_columns(&["bool_col", "int_col"])?;
        let plan = df.explain(false, false)?.collect().await?;
        let formatted = pretty::pretty_format_batches(&plan)?.to_string();
        assert!(formatted.contains("FilterExec: id@0 = 1"), "{formatted}");
        Ok(())
    }

    /// Writes the test table once per supported compression codec and reads
    /// the file's metadata back to verify the codec was actually applied.
    #[tokio::test]
    async fn write_parquet_with_compression() -> Result<()> {
        let test_df = test_util::test_table().await?;
        let output_path = "file://local/test.parquet";
        let test_compressions = vec![
            "snappy",
            "brotli(1)",
            "lz4",
            "lz4_raw",
            "gzip(6)",
            "zstd(1)",
        ];
        for compression in test_compressions.into_iter() {
            let df = test_df.clone();
            // Fresh temp dir per codec; map the `file://local` URL onto it so
            // `output_path` resolves inside the temp dir.
            let tmp_dir = TempDir::new()?;
            let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
            let local_url = Url::parse("file://local").unwrap();
            let ctx = &test_df.session_state;
            ctx.runtime_env().register_object_store(&local_url, local);
            let mut options = TableParquetOptions::default();
            options.global.compression = Some(compression.to_string());
            df.write_parquet(
                output_path,
                DataFrameWriteOptions::new().with_single_file_output(true),
                Some(options),
            )
            .await?;
            // Inspect the written file's footer: the first column chunk of
            // the first row group must report the requested codec.
            let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?;
            let reader =
                parquet::file::serialized_reader::SerializedFileReader::new(file)
                    .unwrap();
            let parquet_metadata = reader.metadata();
            let written_compression =
                parquet_metadata.row_group(0).column(0).compression();
            assert_eq!(written_compression, parse_compression_string(compression)?);
        }
        Ok(())
    }

    /// With a small batch size configured, writing with
    /// `max_row_group_size = rg_size` must produce row groups of exactly that
    /// many rows (checked via the first row group's metadata).
    #[tokio::test]
    async fn write_parquet_with_small_rg_size() -> Result<()> {
        // Batch size 10 keeps input batches small relative to the row-group
        // sizes under test (1..10).
        let ctx = SessionContext::new_with_config(SessionConfig::from_string_hash_map(
            &HashMap::from_iter(
                [("datafusion.execution.batch_size", "10")]
                    .iter()
                    .map(|(s1, s2)| ((*s1).to_string(), (*s2).to_string())),
            ),
        )?);
        register_aggregate_csv(&ctx, "aggregate_test_100").await?;
        let test_df = ctx.table("aggregate_test_100").await?;
        let output_path = "file://local/test.parquet";
        for rg_size in 1..10 {
            let df = test_df.clone();
            // Fresh temp dir per iteration, mapped to the `file://local` URL.
            let tmp_dir = TempDir::new()?;
            let local = Arc::new(LocalFileSystem::new_with_prefix(&tmp_dir)?);
            let local_url = Url::parse("file://local").unwrap();
            let ctx = &test_df.session_state;
            ctx.runtime_env().register_object_store(&local_url, local);
            let mut options = TableParquetOptions::default();
            options.global.max_row_group_size = rg_size;
            // Exercise the parallel single-file writer path as well.
            options.global.allow_single_file_parallelism = true;
            df.write_parquet(
                output_path,
                DataFrameWriteOptions::new().with_single_file_output(true),
                Some(options),
            )
            .await?;
            let file = std::fs::File::open(tmp_dir.path().join("test.parquet"))?;
            let reader =
                parquet::file::serialized_reader::SerializedFileReader::new(file)
                    .unwrap();
            let parquet_metadata = reader.metadata();
            let written_rows = parquet_metadata.row_group(0).num_rows();
            assert_eq!(written_rows as usize, rg_size);
        }
        Ok(())
    }

    /// Round-trips the test table through an encrypted Parquet file: write
    /// with footer + per-column encryption keys, then read back with the
    /// matching decryption properties and verify row counts and a filtered
    /// projection. Parameterized over the parallel single-file writer.
    #[rstest::rstest]
    #[cfg(feature = "parquet_encryption")]
    #[tokio::test]
    async fn roundtrip_parquet_with_encryption(
        #[values(false, true)] allow_single_file_parallelism: bool,
    ) -> Result<()> {
        use parquet::encryption::decrypt::FileDecryptionProperties;
        use parquet::encryption::encrypt::FileEncryptionProperties;
        let test_df = test_util::test_table().await?;
        let schema = test_df.schema();
        // 16-byte keys: one for the footer, one shared by every column.
        let footer_key = b"0123456789012345".to_vec();
        let column_key = b"1234567890123450".to_vec();
        let mut encrypt = FileEncryptionProperties::builder(footer_key.clone());
        let mut decrypt = FileDecryptionProperties::builder(footer_key.clone());
        for field in schema.fields().iter() {
            encrypt = encrypt.with_column_key(field.name().as_str(), column_key.clone());
            decrypt = decrypt.with_column_key(field.name().as_str(), column_key.clone());
        }
        let encrypt = encrypt.build()?;
        let decrypt = decrypt.build()?;
        let df = test_df.clone();
        let tmp_dir = TempDir::new()?;
        let tempfile = tmp_dir.path().join("roundtrip.parquet");
        let tempfile_str = tempfile.into_os_string().into_string().unwrap();
        let mut options = TableParquetOptions::default();
        options.crypto.file_encryption =
            Some(ConfigFileEncryptionProperties::from(&encrypt));
        options.global.allow_single_file_parallelism = allow_single_file_parallelism;
        df.write_parquet(
            tempfile_str.as_str(),
            DataFrameWriteOptions::new().with_single_file_output(true),
            Some(options),
        )
        .await?;
        let num_rows_written = test_df.count().await?;
        // Read back in a fresh context so nothing is cached from the write.
        let ctx: SessionContext = SessionContext::new();
        let read_options =
            ParquetReadOptions::default().file_decryption_properties((&decrypt).into());
        ctx.register_parquet("roundtrip_parquet", &tempfile_str, read_options.clone())
            .await?;
        let df_enc = ctx.sql("SELECT * FROM roundtrip_parquet").await?;
        let num_rows_read = df_enc.count().await?;
        assert_eq!(num_rows_read, num_rows_written);
        // Also verify projection + filter work on the encrypted file.
        let encrypted_parquet_df = ctx.read_parquet(tempfile_str, read_options).await?;
        let selected = encrypted_parquet_df
            .clone()
            .select_columns(&["c1", "c2", "c3"])?
            .filter(col("c2").gt(lit(4)))?;
        let num_rows_selected = selected.count().await?;
        // Expected match count for `c2 > 4` on the test fixture —
        // NOTE(review): value assumed from the fixture data; confirm if the
        // fixture changes.
        assert_eq!(num_rows_selected, 14);
        Ok(())
    }

    /// `with_single_file_output(true)` must produce exactly one Parquet file
    /// at the given path (even without a `.parquet` extension), containing
    /// all rows in a single row group.
    #[tokio::test]
    async fn test_file_output_mode_single_file() -> Result<()> {
        use arrow::array::Int32Array;
        use arrow::datatypes::{DataType, Field, Schema};
        use arrow::record_batch::RecordBatch;
        let ctx = SessionContext::new();
        let tmp_dir = TempDir::new()?;
        // Deliberately extension-less: single-file mode must still write one file.
        let output_path = tmp_dir.path().join("data_no_ext");
        let output_path_str = output_path.to_str().unwrap();
        let df = ctx.read_batch(RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?)?;
        df.write_parquet(
            output_path_str,
            DataFrameWriteOptions::new().with_single_file_output(true),
            None,
        )
        .await?;
        assert!(
            output_path.is_file(),
            "Expected single file at {:?}, but got is_file={}, is_dir={}",
            output_path,
            output_path.is_file(),
            output_path.is_dir()
        );
        let file = std::fs::File::open(&output_path)?;
        let reader = parquet::file::reader::SerializedFileReader::new(file)?;
        let metadata = reader.metadata();
        assert_eq!(metadata.num_row_groups(), 1);
        assert_eq!(metadata.file_metadata().num_rows(), 3);
        Ok(())
    }

    /// Default (automatic) output mode: a path ending in `.parquet` yields a
    /// single file, while an extension-less path yields a directory of files.
    #[tokio::test]
    async fn test_file_output_mode_automatic() -> Result<()> {
        use arrow::array::Int32Array;
        use arrow::datatypes::{DataType, Field, Schema};
        use arrow::record_batch::RecordBatch;
        let ctx = SessionContext::new();
        let tmp_dir = TempDir::new()?;
        let schema =
            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema,
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        // Case 1: path with `.parquet` extension → single file.
        let output_with_ext = tmp_dir.path().join("data.parquet");
        let df = ctx.read_batch(batch.clone())?;
        df.write_parquet(
            output_with_ext.to_str().unwrap(),
            DataFrameWriteOptions::new(), None,
        )
        .await?;
        assert!(
            output_with_ext.is_file(),
            "Path with extension should be a single file, got is_file={}, is_dir={}",
            output_with_ext.is_file(),
            output_with_ext.is_dir()
        );
        // Case 2: extension-less path → directory output.
        let output_no_ext = tmp_dir.path().join("data_dir");
        let df = ctx.read_batch(batch)?;
        df.write_parquet(
            output_no_ext.to_str().unwrap(),
            DataFrameWriteOptions::new(), None,
        )
        .await?;
        assert!(
            output_no_ext.is_dir(),
            "Path without extension should be a directory, got is_file={}, is_dir={}",
            output_no_ext.is_file(),
            output_no_ext.is_dir()
        );
        Ok(())
    }

    /// `with_single_file_output(false)` must produce a directory of files
    /// even when the path ends in `.parquet`, and that directory must be
    /// non-empty.
    #[tokio::test]
    async fn test_file_output_mode_directory() -> Result<()> {
        use arrow::array::Int32Array;
        use arrow::datatypes::{DataType, Field, Schema};
        use arrow::record_batch::RecordBatch;
        let ctx = SessionContext::new();
        let tmp_dir = TempDir::new()?;
        // `.parquet` extension present, but directory mode is forced below.
        let output_path = tmp_dir.path().join("output.parquet");
        let output_path_str = output_path.to_str().unwrap();
        let df = ctx.read_batch(RecordBatch::try_new(
            Arc::new(Schema::new(vec![Field::new("id", DataType::Int32, false)])),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?)?;
        df.write_parquet(
            output_path_str,
            DataFrameWriteOptions::new().with_single_file_output(false),
            None,
        )
        .await?;
        assert!(
            output_path.is_dir(),
            "Expected directory at {:?}, but got is_file={}, is_dir={}",
            output_path,
            output_path.is_file(),
            output_path.is_dir()
        );
        let entries: Vec<_> = std::fs::read_dir(&output_path)?
            .filter_map(|e| e.ok())
            .collect();
        assert!(
            !entries.is_empty(),
            "Directory should contain at least one file"
        );
        Ok(())
    }
}