use std::fs::File;
use std::io::BufWriter;
use std::path::Path;
use std::sync::Arc;
use arrow_array::builder::{
BooleanBuilder, Float64Builder, Int64Builder, StringBuilder, TimestampMicrosecondBuilder,
};
use arrow_array::{ArrayRef, RecordBatch};
use arrow_schema::{DataType, Field, Schema, TimeUnit};
use chrono::{DateTime, Utc};
use narwhal_core::{ColumnHeader, Row, Value};
use parquet::arrow::ArrowWriter;
use parquet::basic::{Compression, ZstdLevel};
use parquet::file::properties::WriterProperties;
use super::error::ExportError;
use super::format::ParquetCompression;
/// Maximum number of leading rows inspected per column when inferring its
/// logical type.
const SCHEMA_SAMPLE: usize = 100;
pub(super) fn write_parquet(
columns: &[ColumnHeader],
rows: &[Row],
path: &Path,
compression: ParquetCompression,
) -> Result<(), ExportError> {
if columns.is_empty() {
return Err(ExportError::Serialise(
"parquet export needs at least one column — run a query first".to_owned(),
));
}
let logical_types: Vec<LogicalType> = columns
.iter()
.enumerate()
.map(|(idx, col)| infer_column_type(idx, col, rows))
.collect();
let fields: Vec<Field> = columns
.iter()
.zip(logical_types.iter())
.map(|(col, ty)| Field::new(&col.name, ty.arrow_data_type(), true))
.collect();
let schema = Arc::new(Schema::new(fields));
let mut builders: Vec<ColumnBuilder> = logical_types
.iter()
.zip(columns.iter())
.map(|(ty, col)| ColumnBuilder::new(*ty, &col.name))
.collect();
for row in rows {
for (idx, value) in row.0.iter().enumerate() {
if let Some(builder) = builders.get_mut(idx) {
builder.append_value(value);
}
}
for builder in builders.iter_mut().skip(row.0.len()) {
builder.append_null();
}
}
let arrays: Vec<ArrayRef> = builders
.into_iter()
.map(ColumnBuilder::finish)
.collect::<Result<Vec<_>, _>>()?;
let batch = RecordBatch::try_new(Arc::clone(&schema), arrays)
.map_err(|e| ExportError::Serialise(format!("parquet record batch: {e}")))?;
let props = WriterProperties::builder()
.set_compression(compression_codec(compression))
.build();
let staging = staging_path(path);
if let Some(parent) = staging.parent() {
if !parent.as_os_str().is_empty() {
std::fs::create_dir_all(parent)?;
}
}
let file = File::create(&staging)?;
let mut writer = ArrowWriter::try_new(BufWriter::new(file), schema, Some(props))
.map_err(|e| ExportError::Serialise(format!("parquet writer init: {e}")))?;
if let Err(error) = writer.write(&batch) {
let _ = std::fs::remove_file(&staging);
return Err(ExportError::Serialise(format!(
"parquet write batch: {error}"
)));
}
if let Err(error) = writer.close() {
let _ = std::fs::remove_file(&staging);
return Err(ExportError::Serialise(format!(
"parquet writer close: {error}"
)));
}
if let Err(error) = std::fs::rename(&staging, path) {
let _ = std::fs::remove_file(&staging);
return Err(ExportError::Io(error));
}
Ok(())
}
/// Returns the staging path used while writing `target`: a sibling file named
/// `.<file name>.tmp`.
///
/// When `target` has no final file-name component, `narwhal-export` is used as
/// the base name instead.
fn staging_path(target: &Path) -> std::path::PathBuf {
    let base_name = match target.file_name() {
        Some(name) => name.to_string_lossy().into_owned(),
        None => String::from("narwhal-export"),
    };
    target.with_file_name(format!(".{base_name}.tmp"))
}
/// Translates the exporter's compression setting into the Parquet codec.
///
/// Zstd uses the crate's default compression level.
fn compression_codec(compression: ParquetCompression) -> Compression {
    match compression {
        ParquetCompression::None => Compression::UNCOMPRESSED,
        ParquetCompression::Snappy => Compression::SNAPPY,
        ParquetCompression::Zstd => {
            let level = ZstdLevel::default();
            Compression::ZSTD(level)
        }
    }
}
/// Column type categories the exporter can emit; each maps to exactly one
/// Arrow data type (see `arrow_data_type`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum LogicalType {
    /// Boolean column.
    Bool,
    /// 64-bit signed integer column.
    Int64,
    /// 64-bit floating-point column.
    Float64,
    /// UTF-8 string column; also the fallback when nothing narrower fits.
    Utf8,
    /// Microsecond-precision timestamp column tagged with the `UTC` time zone.
    Timestamp,
}
impl LogicalType {
    /// Returns the Arrow data type used to store a column of this logical type.
    fn arrow_data_type(self) -> DataType {
        match self {
            Self::Utf8 => DataType::Utf8,
            Self::Bool => DataType::Boolean,
            Self::Int64 => DataType::Int64,
            Self::Float64 => DataType::Float64,
            Self::Timestamp => DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())),
        }
    }

    /// Returns the narrowest logical type able to hold both `self` and `other`.
    ///
    /// Identical types are kept. Among `Bool`, `Int64` and `Float64` the wider
    /// numeric type wins. Every other combination falls back to `Utf8`.
    fn widen(self, other: Self) -> Self {
        if self == other {
            return self;
        }
        match (self, other) {
            (Self::Bool | Self::Int64, Self::Float64)
            | (Self::Float64, Self::Bool | Self::Int64) => Self::Float64,
            (Self::Bool, Self::Int64) | (Self::Int64, Self::Bool) => Self::Int64,
            _ => Self::Utf8,
        }
    }
}
/// Infers the logical type of column `idx`.
///
/// Starts from the hint derived from the header's declared type (if any) and
/// widens it with the type of each non-null value found in the first
/// `SCHEMA_SAMPLE` rows. Rows too short to contain the column are skipped.
/// Scanning stops early once the result is `Utf8`, since nothing widens past
/// it. With no hint and no non-null samples the column defaults to `Utf8`.
fn infer_column_type(idx: usize, header: &ColumnHeader, rows: &[Row]) -> LogicalType {
    let mut current = type_hint_from_header(&header.data_type);
    for row in rows.iter().take(SCHEMA_SAMPLE) {
        if let Some(value) = row.0.get(idx) {
            if let Some(observed) = type_from_value(value) {
                let merged = current.map_or(observed, |previous| previous.widen(observed));
                current = Some(merged);
                if merged == LogicalType::Utf8 {
                    break;
                }
            }
        }
    }
    current.unwrap_or(LogicalType::Utf8)
}
/// Derives an initial logical-type hint from a column's declared type name.
///
/// Matching is case-insensitive and substring-based, checked in this order:
/// boolean, integer (`int` / `serial`), floating point (`real`, `float`,
/// `double`, `decimal`, `numeric`, `money`), then temporal (`timestamp`,
/// `date`, `time`). Returns `None` when nothing matches.
///
/// Type names containing `interval`, `point` or `range` never produce a hint.
/// They are not scalar numbers or timestamps, yet `interval`, `point` and
/// `int4range` contain `int`, and `daterange` contains `date`. A wrong hint
/// survives whenever the sampled rows are all null, after which real values
/// are dropped as type mismatches.
fn type_hint_from_header(data_type: &str) -> Option<LogicalType> {
    const NON_SCALAR_MARKERS: [&str; 3] = ["interval", "point", "range"];
    const FLOAT_MARKERS: [&str; 6] = ["real", "float", "double", "decimal", "numeric", "money"];

    let lower = data_type.to_ascii_lowercase();
    if NON_SCALAR_MARKERS
        .iter()
        .any(|marker| lower.contains(marker))
    {
        return None;
    }

    if lower.contains("bool") {
        Some(LogicalType::Bool)
    } else if lower.contains("int") || lower.contains("serial") {
        Some(LogicalType::Int64)
    } else if FLOAT_MARKERS.iter().any(|marker| lower.contains(marker)) {
        Some(LogicalType::Float64)
    } else if lower.contains("timestamp") || lower.contains("date") || lower.contains("time") {
        Some(LogicalType::Timestamp)
    } else {
        None
    }
}
/// Maps a single value to the logical type needed to store it.
///
/// Nulls carry no type information and yield `None`. Date, date-time and
/// timestamp values all map to `Timestamp`. Any variant not listed explicitly
/// is treated as `Utf8`.
const fn type_from_value(value: &Value) -> Option<LogicalType> {
    let logical = match value {
        Value::Null => return None,
        Value::Date(_) | Value::DateTime(_) | Value::Timestamp(_) => LogicalType::Timestamp,
        Value::Float(_) => LogicalType::Float64,
        Value::Int(_) => LogicalType::Int64,
        Value::Bool(_) => LogicalType::Bool,
        _ => LogicalType::Utf8,
    };
    Some(logical)
}
/// A typed Arrow array builder for one output column.
///
/// Each variant pairs the concrete Arrow builder (`inner`) with the column
/// name (`col`); the name is only used when logging dropped values.
enum ColumnBuilder {
    /// Builder for a boolean column.
    Bool {
        inner: BooleanBuilder,
        col: String,
    },
    /// Builder for a 64-bit integer column.
    Int64 {
        inner: Int64Builder,
        col: String,
    },
    /// Builder for a 64-bit float column.
    Float64 {
        inner: Float64Builder,
        col: String,
    },
    /// Builder for a UTF-8 string column.
    Utf8 {
        inner: StringBuilder,
        col: String,
    },
    /// Builder for a microsecond timestamp column.
    Timestamp {
        inner: TimestampMicrosecondBuilder,
        col: String,
    },
}
impl ColumnBuilder {
    /// Creates an empty builder of the given logical type for column `col`.
    ///
    /// Timestamp builders are tagged with the `UTC` time zone.
    fn new(logical: LogicalType, col: &str) -> Self {
        let col = String::from(col);
        match logical {
            LogicalType::Bool => {
                let inner = BooleanBuilder::new();
                Self::Bool { inner, col }
            }
            LogicalType::Int64 => {
                let inner = Int64Builder::new();
                Self::Int64 { inner, col }
            }
            LogicalType::Float64 => {
                let inner = Float64Builder::new();
                Self::Float64 { inner, col }
            }
            LogicalType::Utf8 => {
                let inner = StringBuilder::new();
                Self::Utf8 { inner, col }
            }
            LogicalType::Timestamp => {
                let inner = TimestampMicrosecondBuilder::new().with_timezone(Arc::from("UTC"));
                Self::Timestamp { inner, col }
            }
        }
    }

    /// Appends a null slot to the column.
    fn append_null(&mut self) {
        match self {
            Self::Utf8 { inner, .. } => inner.append_null(),
            Self::Bool { inner, .. } => inner.append_null(),
            Self::Int64 { inner, .. } => inner.append_null(),
            Self::Float64 { inner, .. } => inner.append_null(),
            Self::Timestamp { inner, .. } => inner.append_null(),
        }
    }

    /// Appends `value`, converting it to the column's type where a conversion
    /// is defined.
    ///
    /// `Value::Null` always becomes a null slot. A value with no conversion
    /// for this column type is logged as a warning and stored as null.
    fn append_value(&mut self, value: &Value) {
        match (self, value) {
            // Nulls are handled first so they never reach a typed arm.
            (builder, Value::Null) => builder.append_null(),

            (Self::Bool { inner, .. }, Value::Bool(flag)) => inner.append_value(*flag),
            (Self::Bool { inner, .. }, Value::Int(number)) => inner.append_value(*number != 0),

            (Self::Int64 { inner, .. }, Value::Int(number)) => inner.append_value(*number),
            (Self::Int64 { inner, .. }, Value::Bool(flag)) => {
                inner.append_value(i64::from(*flag));
            }

            (Self::Float64 { inner, .. }, Value::Float(number)) => inner.append_value(*number),
            (Self::Float64 { inner, .. }, Value::Int(number)) => {
                #[allow(clippy::cast_precision_loss)]
                let widened = *number as f64;
                inner.append_value(widened);
            }
            (Self::Float64 { inner, .. }, Value::Bool(flag)) => {
                inner.append_value(f64::from(u8::from(*flag)));
            }

            // Anything can be stored as text through its rendered form.
            (Self::Utf8 { inner, .. }, other) => inner.append_value(other.render()),

            (Self::Timestamp { inner, .. }, Value::Timestamp(ts)) => {
                inner.append_value(ts.timestamp_micros());
            }
            (Self::Timestamp { inner, .. }, Value::DateTime(naive)) => {
                // Naive date-times are interpreted as UTC.
                let micros =
                    DateTime::<Utc>::from_naive_utc_and_offset(*naive, Utc).timestamp_micros();
                inner.append_value(micros);
            }
            (Self::Timestamp { inner, .. }, Value::Date(date)) => {
                // Dates are stored as midnight UTC of that day.
                match date.and_hms_opt(0, 0, 0) {
                    Some(midnight) => {
                        let micros = DateTime::<Utc>::from_naive_utc_and_offset(midnight, Utc)
                            .timestamp_micros();
                        inner.append_value(micros);
                    }
                    None => inner.append_null(),
                }
            }

            (typed, _) => {
                tracing::warn!(
                    target: "narwhal::export::parquet",
                    column = %typed.column_name(),
                    expected = ?typed.logical_type(),
                    got = ?type_from_value(value),
                    "parquet: dropped value due to type inference mismatch"
                );
                typed.append_null();
            }
        }
    }

    /// Consumes the builder and returns the finished Arrow array.
    ///
    /// The current implementation never returns `Err`.
    fn finish(self) -> Result<ArrayRef, ExportError> {
        let array: ArrayRef = match self {
            Self::Bool { mut inner, .. } => Arc::new(inner.finish()),
            Self::Int64 { mut inner, .. } => Arc::new(inner.finish()),
            Self::Float64 { mut inner, .. } => Arc::new(inner.finish()),
            Self::Utf8 { mut inner, .. } => Arc::new(inner.finish()),
            Self::Timestamp { mut inner, .. } => Arc::new(inner.finish()),
        };
        Ok(array)
    }

    /// Name of the column this builder belongs to.
    fn column_name(&self) -> &str {
        let (Self::Bool { col, .. }
        | Self::Int64 { col, .. }
        | Self::Float64 { col, .. }
        | Self::Utf8 { col, .. }
        | Self::Timestamp { col, .. }) = self;
        col
    }

    /// Logical type this builder was created for.
    const fn logical_type(&self) -> LogicalType {
        match self {
            Self::Utf8 { .. } => LogicalType::Utf8,
            Self::Timestamp { .. } => LogicalType::Timestamp,
            Self::Float64 { .. } => LogicalType::Float64,
            Self::Int64 { .. } => LogicalType::Int64,
            Self::Bool { .. } => LogicalType::Bool,
        }
    }
}