use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use parquet::basic::Type as PhysicalType;
use parquet::errors::ParquetError;
use parquet::file::properties::WriterProperties;
use parquet::file::reader::{FileReader, SerializedFileReader};
use parquet::file::writer::SerializedFileWriter;
use parquet::schema::parser::parse_message_type;
use parquet::schema::types::Type;
use std::fs::File;
use std::path::Path;
use std::sync::Arc;
/// Key name reserved for array-shape metadata.
///
/// NOTE(review): not referenced anywhere in the visible code — the sidecar
/// JSON produced by `impl_parquet_io!` uses the literal key `"shape"`.
/// Presumably intended for key-value metadata embedded in the Parquet file
/// itself; confirm before relying on it.
const SHAPE_METADATA_KEY: &str = "numrs2_shape";
/// Key name reserved for element-dtype metadata.
///
/// NOTE(review): not referenced anywhere in the visible code — the sidecar
/// JSON produced by `impl_parquet_io!` uses the literal key `"dtype"`.
const DTYPE_METADATA_KEY: &str = "numrs2_dtype";
/// Writes `array` to the Parquet file at `path`.
///
/// This is a thin front end over [`ParquetWritable::write_to_parquet`]: the
/// path is borrowed as a `&Path` and the call is forwarded to the
/// implementation for the element type `T`.
///
/// # Arguments
///
/// * `array` - the array to store.
/// * `path` - destination file path.
/// * `props` - optional writer properties, passed through unchanged.
///
/// # Errors
///
/// Returns whatever error the `ParquetWritable` implementation for `T`
/// reports.
pub fn write_parquet<T, P>(
    array: &Array<T>,
    path: P,
    props: Option<WriterProperties>,
) -> Result<()>
where
    T: Clone + ParquetWritable,
    P: AsRef<Path>,
{
    let target: &Path = path.as_ref();
    <T as ParquetWritable>::write_to_parquet(array, target, props)
}
/// Reads an array of element type `T` from the Parquet file at `path`.
///
/// This is a thin front end over [`ParquetReadable::read_from_parquet`]: the
/// path is borrowed as a `&Path` and the call is forwarded to the
/// implementation for the element type `T`.
///
/// # Errors
///
/// Returns whatever error the `ParquetReadable` implementation for `T`
/// reports.
pub fn read_parquet<T, P>(path: P) -> Result<Array<T>>
where
    T: Clone + ParquetReadable,
    P: AsRef<Path>,
{
    let source: &Path = path.as_ref();
    <T as ParquetReadable>::read_from_parquet(source)
}
/// Element types whose arrays can be written to a Parquet file.
///
/// [`write_parquet`] dispatches to this trait based on the array's element
/// type.
pub trait ParquetWritable: Clone {
    /// Writes `array` to the file at `path`.
    ///
    /// `props` optionally supplies Parquet writer properties; how `None` is
    /// handled is up to the implementation.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the array cannot be written.
    fn write_to_parquet(
        array: &Array<Self>,
        path: &Path,
        props: Option<WriterProperties>,
    ) -> Result<()>;
}
/// Element types whose arrays can be read back from a Parquet file.
///
/// [`read_parquet`] dispatches to this trait based on the requested element
/// type.
pub trait ParquetReadable: Clone {
    /// Reads an array of `Self` elements from the file at `path`.
    ///
    /// # Errors
    ///
    /// Implementations return an error when the array cannot be read.
    fn read_from_parquet(path: &Path) -> Result<Array<Self>>;
}
/// Converts a [`ParquetError`] into a [`NumRs2Error::IOError`].
///
/// The resulting message is the error's `Display` output prefixed with
/// `"Parquet error: "`.
fn parquet_err_to_numrs2(e: ParquetError) -> NumRs2Error {
    let message = format!("Parquet error: {}", e);
    NumRs2Error::IOError(message)
}
/// Generates [`ParquetWritable`] and [`ParquetReadable`] implementations for
/// one primitive element type.
///
/// Arguments:
///
/// * `$type` - Rust element type of the array (e.g. `u32`).
/// * `$parquet_type` - `parquet::data_type` marker type used to obtain the
///   typed column writer (e.g. `parquet::data_type::Int32Type`).
/// * `$native` - Rust type of the values stored in that physical column
///   (e.g. `i32`). Elements are converted with `as`, so unsigned types are
///   stored as their same-width signed bit pattern.
/// * `$physical_type` - physical type name used in the schema string
///   (e.g. `"INT32"`); must agree with `$parquet_type`.
/// * `$type_name` - dtype string recorded in the sidecar metadata file.
///
/// On-disk layout produced by the writer:
///
/// * `<path>`: a Parquet file with a single row group and a single
///   `REQUIRED` column named `values` holding the output of `array.to_vec()`.
/// * `<path>` with its extension replaced by `parquet.meta`: a JSON object
///   with the keys `"shape"` and `"dtype"`.
macro_rules! impl_parquet_io {
    ($type:ty, $parquet_type:ty, $native:ty, $physical_type:expr, $type_name:expr) => {
        impl ParquetWritable for $type {
            /// Writes `array` as a single-column Parquet file plus a JSON
            /// sidecar describing its shape and dtype.
            ///
            /// When `props` is `None`, SNAPPY compression is used.
            ///
            /// # Errors
            ///
            /// Returns `NumRs2Error::IOError` if the file cannot be created,
            /// if any Parquet operation fails, or if the sidecar metadata
            /// file cannot be written.
            fn write_to_parquet(
                array: &Array<Self>,
                path: &Path,
                props: Option<WriterProperties>,
            ) -> Result<()> {
                let file = File::create(path)
                    .map_err(|e| NumRs2Error::IOError(format!("Failed to create file: {}", e)))?;

                let schema_str = format!(
                    "message numrs2_array {{
                        REQUIRED {} values;
                    }}",
                    $physical_type
                );
                let schema =
                    Arc::new(parse_message_type(&schema_str).map_err(parquet_err_to_numrs2)?);

                let props = props.unwrap_or_else(|| {
                    WriterProperties::builder()
                        .set_compression(parquet::basic::Compression::SNAPPY)
                        .build()
                });
                let mut writer = SerializedFileWriter::new(file, schema, Arc::new(props))
                    .map_err(parquet_err_to_numrs2)?;

                // Convert the elements to the physical column type. For the
                // unsigned types this reinterprets the bits as the signed
                // type of the same width; the original dtype is preserved in
                // the sidecar metadata.
                let values: Vec<$native> = array
                    .to_vec()
                    .into_iter()
                    .map(|v: $type| v as $native)
                    .collect();
                let shape = array.shape();

                let mut row_group_writer =
                    writer.next_row_group().map_err(parquet_err_to_numrs2)?;

                // The schema above always declares exactly one column, so a
                // missing column writer is reported instead of silently
                // producing a file without data.
                let mut col_writer = row_group_writer
                    .next_column()
                    .map_err(parquet_err_to_numrs2)?
                    .ok_or_else(|| {
                        NumRs2Error::IOError(
                            "Parquet error: schema has no column to write".to_string(),
                        )
                    })?;
                // The column is REQUIRED and flat, so no definition or
                // repetition levels are supplied.
                col_writer
                    .typed::<$parquet_type>()
                    .write_batch(&values, None, None)
                    .map_err(parquet_err_to_numrs2)?;
                col_writer.close().map_err(parquet_err_to_numrs2)?;

                row_group_writer.close().map_err(parquet_err_to_numrs2)?;
                writer.close().map_err(parquet_err_to_numrs2)?;

                // The flat column loses the array's dimensions, so shape and
                // dtype are stored next to the Parquet file.
                let metadata_path = path.with_extension("parquet.meta");
                let metadata = serde_json::json!({
                    "shape": shape,
                    "dtype": $type_name,
                });
                std::fs::write(&metadata_path, metadata.to_string()).map_err(|e| {
                    NumRs2Error::IOError(format!("Failed to write metadata: {}", e))
                })?;

                Ok(())
            }
        }

        impl ParquetReadable for $type {
            /// Validates the sidecar metadata and the Parquet file, then
            /// reports that reading is not implemented.
            ///
            /// NOTE: this function currently never returns `Ok`; decoding the
            /// column data has not been implemented yet.
            ///
            /// # Errors
            ///
            /// * `NumRs2Error::IOError` if the sidecar or the Parquet file
            ///   cannot be opened, or the Parquet file cannot be parsed.
            /// * `NumRs2Error::DeserializationError` if the sidecar is not
            ///   valid JSON or its `"shape"` entry is missing or invalid.
            /// * `NumRs2Error::IOError` ("not fully implemented") otherwise.
            fn read_from_parquet(path: &Path) -> Result<Array<Self>> {
                let metadata_path = path.with_extension("parquet.meta");
                let metadata_str = std::fs::read_to_string(&metadata_path).map_err(|e| {
                    NumRs2Error::IOError(format!("Failed to read metadata: {}", e))
                })?;
                let metadata: serde_json::Value =
                    serde_json::from_str(&metadata_str).map_err(|e| {
                        NumRs2Error::DeserializationError(format!("Invalid metadata: {}", e))
                    })?;

                // Parsed only for validation until reading is implemented.
                let _shape: Vec<usize> = metadata["shape"]
                    .as_array()
                    .ok_or_else(|| {
                        NumRs2Error::DeserializationError("Missing shape in metadata".to_string())
                    })?
                    .iter()
                    .map(|v| {
                        v.as_u64()
                            // Checked conversion: a dimension that does not
                            // fit in `usize` is rejected rather than truncated.
                            .and_then(|x| <usize as std::convert::TryFrom<u64>>::try_from(x).ok())
                            .ok_or_else(|| {
                                NumRs2Error::DeserializationError(
                                    "Invalid shape value".to_string(),
                                )
                            })
                    })
                    .collect::<Result<Vec<_>>>()?;

                let file = File::open(path)
                    .map_err(|e| NumRs2Error::IOError(format!("Failed to open file: {}", e)))?;
                // Opened only to surface missing or corrupt files before the
                // "not implemented" error below.
                let _reader = SerializedFileReader::new(file).map_err(parquet_err_to_numrs2)?;

                Err(NumRs2Error::IOError(
                    "Parquet reading not fully implemented yet - use Arrow format instead"
                        .to_string(),
                ))
            }
        }
    };
}

impl_parquet_io!(f64, parquet::data_type::DoubleType, f64, "DOUBLE", "f64");
impl_parquet_io!(f32, parquet::data_type::FloatType, f32, "FLOAT", "f32");
impl_parquet_io!(i32, parquet::data_type::Int32Type, i32, "INT32", "i32");
impl_parquet_io!(i64, parquet::data_type::Int64Type, i64, "INT64", "i64");
// Unsigned types share the signed physical column of the same width.
impl_parquet_io!(u32, parquet::data_type::Int32Type, i32, "INT32", "u32");
impl_parquet_io!(u64, parquet::data_type::Int64Type, i64, "INT64", "u64");
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Writing an array must succeed, create the Parquet file, and leave a
    /// sidecar metadata file recording the array's shape and dtype.
    #[test]
    fn test_parquet_metadata() {
        let temp_dir = TempDir::new().expect("Failed to create temp dir");
        let path = temp_dir.path().join("test.parquet");
        let array = Array::from_vec(vec![1.0_f64, 2.0, 3.0, 4.0]).reshape(&[2, 2]);

        let result = write_parquet(&array, &path, None);
        assert!(result.is_ok(), "write_parquet should succeed");
        assert!(path.exists(), "Parquet file should have been created");

        let metadata_path = path.with_extension("parquet.meta");
        let metadata_str =
            std::fs::read_to_string(&metadata_path).expect("Failed to read metadata file");
        let metadata: serde_json::Value =
            serde_json::from_str(&metadata_str).expect("Metadata should be valid JSON");

        assert_eq!(metadata["shape"], serde_json::json!([2, 2]));
        assert_eq!(metadata["dtype"], "f64");
    }
}