use crate::{
Client, Error, FileMetaStream, Result,
telemetry::{
self, FILE_SOURCE_BYTES_DOWNLOADED, FILE_SOURCE_FILES_READ, FILE_SOURCE_READ_DURATION_MS,
FILE_SOURCE_READ_ERRORS, FILE_SOURCE_ROWS_READ, telemetry_labels,
},
};
use arrow::array::RecordBatch;
use futures::{
StreamExt, TryStreamExt,
stream::{self, BoxStream},
};
use parquet::arrow::async_reader::ParquetRecordBatchStreamBuilder;
use std::{io::Cursor, path::Path, sync::Arc, time::Instant};
use tokio::fs::File;
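/// Default number of rows per [`RecordBatch`] when no batch size is supplied.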
pub const DEFAULT_BATCH_SIZE: usize = 8192;
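/// A boxed, `'static` stream of [`RecordBatch`] results.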
pub type RecordBatchStream = BoxStream<'static, Result<RecordBatch>>;
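/// Streams [`RecordBatch`]es from local parquet files.
///
/// Files are opened in the order given, with up to two files prepared
/// concurrently; `batch_size` defaults to [`DEFAULT_BATCH_SIZE`] and `label`
/// is attached to emitted telemetry as the `source_name` label.
///
/// Illustrative usage (a sketch; the path and surrounding async context are
/// assumptions):
///
/// ```ignore
/// use futures::StreamExt;
///
/// let mut batches = source(vec!["data/events.parquet"], None, Some("events".to_string()));
/// while let Some(batch) = batches.next().await {
///     println!("rows: {}", batch?.num_rows());
/// }
/// ```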
pub fn source<I, P>(paths: I, batch_size: Option<usize>, label: Option<String>) -> RecordBatchStream
where
I: IntoIterator<Item = P>,
P: AsRef<Path>,
{
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
let paths: Vec<_> = paths
.into_iter()
.map(|p| p.as_ref().to_path_buf())
.collect();
let label_clone = label.clone();
stream::iter(paths)
.map(move |path| {
let name_label = label_clone.clone();
async move {
let start = Instant::now();
let file = File::open(&path).await?;
let builder = ParquetRecordBatchStreamBuilder::new(file).await?;
let stream = builder.with_batch_size(batch_size).build()?;
let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
                telemetry::increment_counter(
                    FILE_SOURCE_FILES_READ,
                    1,
                    telemetry_labels!("source_name" => name_label.as_ref()),
                );
                telemetry::record_histogram(
                    FILE_SOURCE_READ_DURATION_MS,
                    duration_ms,
                    telemetry_labels!("source_name" => name_label.as_ref()),
                );
Ok(stream)
}
})
        .buffered(2)
        .flat_map(move |result| {
let name_label = label.clone();
match result {
Ok(stream) => stream
.inspect(move |result| match result {
Ok(batch) => {
let label = telemetry_labels!("source_name" => name_label.as_ref());
telemetry::increment_counter(
FILE_SOURCE_ROWS_READ,
batch.num_rows() as u64,
label,
)
}
Err(_) => {
let labels = telemetry_labels!("source_name" => name_label.as_ref(), "error_type" => "read");
telemetry::increment_counter(FILE_SOURCE_READ_ERRORS, 1, labels)
}
})
.map_err(Error::from)
.boxed(),
Err(err) => {
let labels = telemetry_labels!("source_name" => name_label.as_ref(), "error_type" => "open");
telemetry::increment_counter(FILE_SOURCE_READ_ERRORS, 1, labels);
stream::once(async { Err(err) }).boxed()
}
}
})
.boxed()
}
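/// Downloads a single parquet object from S3 into memory and returns a
/// [`RecordBatchStream`] over its contents.
///
/// Download duration and size are recorded, and empty or truncated objects
/// (smaller than the minimal parquet framing) are rejected before decoding.
///
/// Illustrative usage (a sketch; the bucket and key are assumptions):
///
/// ```ignore
/// let batches = source_s3_file(&client, "my-bucket", "data/events.parquet", None, None).await?;
/// ```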
pub async fn source_s3_file(
client: &Client,
bucket: impl Into<String>,
key: impl Into<String>,
batch_size: Option<usize>,
label: Option<String>,
) -> Result<RecordBatchStream> {
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
let key = key.into();
let bucket = bucket.into();
let download_start = Instant::now();
let bytes = crate::get_file(client, bucket.as_str(), &key).await?;
let download_duration_ms = download_start.elapsed().as_secs_f64() * 1000.0;
let bytes_len = bytes.len();
telemetry::record_histogram(
FILE_SOURCE_READ_DURATION_MS,
download_duration_ms,
telemetry_labels!("bucket" => bucket.as_str(), "source_name" => label.as_ref()),
);
telemetry::record_histogram(
FILE_SOURCE_BYTES_DOWNLOADED,
bytes_len as f64,
telemetry_labels!("bucket" => bucket.as_str(), "source_name" => label.as_ref()),
);
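    // Minimum parquet framing: leading "PAR1" magic (4 bytes) plus footer
    // length (4 bytes) and trailing "PAR1" magic (4 bytes).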
const MIN_PARQUET_SIZE: usize = 12;
if bytes.is_empty() {
telemetry::increment_counter(
FILE_SOURCE_READ_ERRORS,
1,
telemetry_labels!("bucket" => bucket.as_str(), "error_type" => "empty_file", "source_name" => label.as_ref()),
);
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!("Empty parquet file: {key}"),
)));
}
if bytes.len() < MIN_PARQUET_SIZE {
telemetry::increment_counter(
FILE_SOURCE_READ_ERRORS,
1,
telemetry_labels!("bucket" => bucket.as_str(), "error_type" => "invalid_size", "source_name" => label.as_ref()),
);
return Err(Error::Io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
format!(
"Parquet file missing too small ({} bytes) minimum fixed-length values: {key}",
bytes.len()
),
)));
}
let cursor = Cursor::new(bytes);
let builder = ParquetRecordBatchStreamBuilder::new(cursor).await?;
let stream = builder.with_batch_size(batch_size).build()?;
telemetry::increment_counter(
FILE_SOURCE_FILES_READ,
1,
telemetry_labels!("bucket" => bucket.as_str(), "source_name" => label.as_ref()),
);
Ok(stream
.inspect(move |result| match result {
Ok(batch) => {
telemetry::increment_counter(
FILE_SOURCE_ROWS_READ,
batch.num_rows() as u64,
telemetry_labels!("bucket" => bucket.as_str(), "source_name" => label.as_ref()),
);
}
Err(_) => {
telemetry::increment_counter(
FILE_SOURCE_READ_ERRORS,
1,
telemetry_labels!("bucket" => bucket.as_str(), "error_type" => "read", "source_name" => label.as_ref()),
);
}
})
.map_err(Error::from)
.boxed())
}
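/// Streams record batches from every object described by `metas`, downloading
/// and reading one file at a time so batches arrive in the order of `metas`.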
pub fn source_s3_files(
client: &Client,
bucket: impl Into<String>,
metas: FileMetaStream,
batch_size: Option<usize>,
label: Option<String>,
) -> RecordBatchStream {
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
let bucket = bucket.into();
let client = Arc::new(client.clone());
metas
.map(move |meta_result| {
let client = Arc::clone(&client);
let bucket = bucket.clone();
let label = label.clone();
async move {
let meta = meta_result?;
source_s3_file(&client, bucket, meta.key, Some(batch_size), label).await
}
})
        .buffered(1)
        .flat_map(|result| match result {
Ok(stream) => stream.boxed(),
Err(err) => stream::once(async { Err(err) }).boxed(),
})
.boxed()
}
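/// Like [`source_s3_files`], but downloads up to `workers` objects
/// concurrently with `buffer_unordered`, so batches may arrive out of order
/// relative to `metas`.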
pub fn source_s3_files_unordered(
client: &Client,
bucket: impl Into<String>,
workers: usize,
metas: FileMetaStream,
batch_size: Option<usize>,
label: Option<String>,
) -> RecordBatchStream {
let batch_size = batch_size.unwrap_or(DEFAULT_BATCH_SIZE);
let bucket = bucket.into();
let client = Arc::new(client.clone());
metas
.map(move |meta_result| {
let client = Arc::clone(&client);
let bucket = bucket.clone();
let label = label.clone();
async move {
let meta = meta_result?;
source_s3_file(&client, bucket, meta.key, Some(batch_size), label).await
}
})
        .buffer_unordered(workers)
        .flat_map(|result| match result {
Ok(stream) => stream.boxed(),
Err(err) => stream::once(async { Err(err) }).boxed(),
})
.boxed()
}
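/// Drains the stream and deserializes every [`RecordBatch`] into a `Vec<T>`
/// via `serde_arrow`. The full result set is held in memory, so prefer
/// [`deserialize_stream`] for large inputs.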
pub async fn deserialize_to_vec<T>(mut stream: RecordBatchStream) -> Result<Vec<T>>
where
T: for<'de> serde::Deserialize<'de>,
{
let mut stream_results = Vec::new();
while let Some(batch) = stream.next().await.transpose()? {
let records: Vec<T> =
serde_arrow::from_record_batch(&batch).map_err(|e| Error::SerdeArrow(e.to_string()))?;
stream_results.extend(records);
}
Ok(stream_results)
}
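/// Lazily deserializes each [`RecordBatch`] into `T` values via `serde_arrow`,
/// yielding a flattened stream of rows instead of collecting them up front.
///
/// Illustrative usage (a sketch; the `Row` type is an assumption):
///
/// ```ignore
/// #[derive(serde::Deserialize)]
/// struct Row {
///     id: i64,
///     name: String,
/// }
///
/// let mut rows = deserialize_stream::<Row>(batches);
/// while let Some(row) = rows.next().await {
///     let row = row?;
///     // ...
/// }
/// ```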
pub fn deserialize_stream<T>(stream: RecordBatchStream) -> BoxStream<'static, Result<T>>
where
T: for<'de> serde::Deserialize<'de> + Send + 'static,
{
    stream
        .flat_map(|batch_result| match batch_result {
            Ok(batch) => match serde_arrow::from_record_batch::<Vec<T>>(&batch) {
                Ok(records) => stream::iter(records).map(Ok).boxed(),
                Err(e) => {
                    stream::once(async move { Err(Error::SerdeArrow(e.to_string())) }).boxed()
                }
            },
            Err(e) => stream::once(async move { Err(e) }).boxed(),
        })
        .boxed()
}
#[cfg(test)]
mod tests {
use super::*;
use arrow::array::{Float64Array, Int64Array, RecordBatch, StringArray};
use arrow::datatypes::{DataType, Field, Schema};
use parquet::arrow::ArrowWriter;
use std::sync::Arc;
use tempfile::{NamedTempFile, TempDir};
async fn create_test_parquet_file() -> NamedTempFile {
let temp_file = NamedTempFile::new().unwrap();
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("name", DataType::Utf8, false),
Field::new("value", DataType::Float64, false),
]));
let id_array = Int64Array::from(vec![1, 2, 3, 4, 5]);
let name_array = StringArray::from(vec!["Alice", "Bob", "Charlie", "Diana", "Eve"]);
let value_array = Float64Array::from(vec![1.5, 2.5, 3.5, 4.5, 5.5]);
let batch = RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(id_array),
Arc::new(name_array),
Arc::new(value_array),
],
)
.unwrap();
let file = std::fs::File::create(temp_file.path()).unwrap();
let mut writer = ArrowWriter::try_new(file, schema, None).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
temp_file
}
async fn create_test_parquet_files(count: usize) -> (TempDir, Vec<std::path::PathBuf>) {
let temp_dir = TempDir::new().unwrap();
let mut paths = Vec::new();
for i in 0..count {
let file_path = temp_dir.path().join(format!("test_{}.parquet", i));
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("batch", DataType::Int64, false),
]));
let id_array =
Int64Array::from(vec![i as i64 * 10, i as i64 * 10 + 1, i as i64 * 10 + 2]);
let batch_array = Int64Array::from(vec![i as i64, i as i64, i as i64]);
let batch = RecordBatch::try_new(
schema.clone(),
vec![Arc::new(id_array), Arc::new(batch_array)],
)
.unwrap();
let file = std::fs::File::create(&file_path).unwrap();
let mut writer = ArrowWriter::try_new(file, schema, None).unwrap();
writer.write(&batch).unwrap();
writer.close().unwrap();
paths.push(file_path);
}
(temp_dir, paths)
}
#[tokio::test]
async fn test_source_local_single_file() {
let temp_file = create_test_parquet_file().await;
let paths = vec![temp_file.path()];
let mut stream = source(paths, None, None);
let mut total_rows = 0;
while let Some(batch_result) = stream.next().await {
let batch = batch_result.unwrap();
total_rows += batch.num_rows();
}
assert_eq!(total_rows, 5);
}
#[tokio::test]
async fn test_source_local_multiple_files() {
let (_temp_dir, paths) = create_test_parquet_files(3).await;
let mut stream = source(paths, None, None);
let mut total_rows = 0;
while let Some(batch_result) = stream.next().await {
let batch = batch_result.unwrap();
total_rows += batch.num_rows();
}
assert_eq!(total_rows, 9);
}
#[tokio::test]
async fn test_source_local_custom_batch_size() {
let temp_file = create_test_parquet_file().await;
let paths = vec![temp_file.path()];
let mut stream = source(paths, Some(2), None);
let mut batch_count = 0;
let mut total_rows = 0;
while let Some(batch_result) = stream.next().await {
let batch = batch_result.unwrap();
batch_count += 1;
total_rows += batch.num_rows();
}
assert_eq!(total_rows, 5);
assert!(batch_count >= 2);
}
#[tokio::test]
async fn test_source_local_empty_paths() {
let paths: Vec<&str> = vec![];
let mut stream = source(paths, None, None);
let result = stream.next().await;
assert!(result.is_none());
}
#[tokio::test]
async fn test_source_local_nonexistent_file() {
let paths = vec!["/tmp/nonexistent_file_12345.parquet"];
let mut stream = source(paths, None, None);
let result = stream.next().await;
assert!(result.is_some());
assert!(result.unwrap().is_err());
}
#[tokio::test]
async fn test_source_local_verify_data_integrity() {
let temp_file = create_test_parquet_file().await;
let paths = vec![temp_file.path()];
let mut stream = source(paths, None, None);
let batch_result = stream.next().await;
assert!(batch_result.is_some());
let batch = batch_result.unwrap().unwrap();
assert_eq!(batch.num_rows(), 5);
assert_eq!(batch.num_columns(), 3);
let schema = batch.schema();
assert_eq!(schema.field(0).name(), "id");
assert_eq!(schema.field(1).name(), "name");
assert_eq!(schema.field(2).name(), "value");
let id_col = batch
.column(0)
.as_any()
.downcast_ref::<Int64Array>()
.unwrap();
let name_col = batch
.column(1)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
let value_col = batch
.column(2)
.as_any()
.downcast_ref::<Float64Array>()
.unwrap();
assert_eq!(id_col.value(0), 1);
assert_eq!(name_col.value(0), "Alice");
assert_eq!(value_col.value(0), 1.5);
}
#[tokio::test]
async fn test_default_batch_size_constant() {
assert_eq!(DEFAULT_BATCH_SIZE, 8192);
}
#[tokio::test]
async fn test_source_with_metrics() {
let temp_file = create_test_parquet_file().await;
let paths = vec![temp_file.path()];
let mut stream = source(paths, None, Some("test_read_metric".to_string()));
let mut total_rows = 0;
while let Some(batch_result) = stream.next().await {
let batch = batch_result.unwrap();
total_rows += batch.num_rows();
}
assert_eq!(total_rows, 5);
}
}