use std::fs::File;
use std::path::{Path, PathBuf};
use polars::frame::DataFrame;
use polars::prelude::ParquetReader as PlParquetReader;
use polars::prelude::*;
use super::PlotlarsError;
/// Builder that reads a Parquet file into a polars [`DataFrame`].
///
/// Create one with [`ParquetReader::new`], optionally narrow the read with
/// [`ParquetReader::columns`] and [`ParquetReader::n_rows`], then call
/// [`ParquetReader::finish`].
#[derive(Clone, Debug)]
pub struct ParquetReader {
    /// Path of the Parquet file to read.
    path: PathBuf,
    /// Columns to load; `None` means no projection is applied.
    columns: Option<Vec<String>>,
    /// Number of leading rows to load; `None` means no row limit is applied.
    n_rows: Option<usize>,
}
impl ParquetReader {
    /// Creates a reader for the Parquet file at `path`.
    ///
    /// No I/O happens here; the file is only opened by
    /// [`ParquetReader::finish`]. By default no column projection and no row
    /// limit are applied.
    #[must_use]
    pub fn new(path: impl AsRef<Path>) -> Self {
        Self {
            path: path.as_ref().to_path_buf(),
            columns: None,
            n_rows: None,
        }
    }

    /// Restricts the read to the named `columns`.
    ///
    /// Calling this again replaces the previously selected columns.
    #[must_use]
    pub fn columns(mut self, columns: Vec<&str>) -> Self {
        self.columns = Some(columns.into_iter().map(str::to_owned).collect());
        self
    }

    /// Limits the read to the first `n_rows` rows of the file.
    ///
    /// Calling this again replaces the previously set limit.
    #[must_use]
    pub fn n_rows(mut self, n_rows: usize) -> Self {
        self.n_rows = Some(n_rows);
        self
    }

    /// Opens the file and reads it into a [`DataFrame`], applying the
    /// configured column projection and row limit.
    ///
    /// # Errors
    ///
    /// Returns [`PlotlarsError::ParquetParse`], carrying the file path, if the
    /// file cannot be opened or if polars fails to read it.
    pub fn finish(self) -> Result<DataFrame, PlotlarsError> {
        // Rendered once so both failure paths can report which file was involved.
        let path_str = self.path.display().to_string();
        let file = File::open(&self.path).map_err(|e| PlotlarsError::ParquetParse {
            path: path_str.clone(),
            source: Box::new(e),
        })?;

        let mut reader = PlParquetReader::new(file);
        if let Some(n) = self.n_rows {
            // Start at offset 0 and keep `n` rows, i.e. the first `n` rows.
            reader = reader.with_slice(Some((0, n)));
        }
        if let Some(cols) = self.columns {
            reader = reader.with_columns(Some(cols));
        }

        reader.finish().map_err(|e| PlotlarsError::ParquetParse {
            path: path_str,
            source: Box::new(e),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A Parquet fixture on disk that is deleted when the guard is dropped.
    struct TestParquet {
        path: PathBuf,
    }

    impl Drop for TestParquet {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover file in the temp dir is harmless.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    /// Writes a 3-row fixture with columns `a` (integers) and `b` (strings).
    ///
    /// Cargo runs tests in parallel, so each test passes its own `name` to get a
    /// private file. A shared path would let one test truncate the file while
    /// another is reading it. The process id keeps concurrent `cargo test`
    /// invocations apart as well.
    fn create_test_parquet(name: &str) -> TestParquet {
        let path = std::env::temp_dir().join(format!(
            "plotlars_parquet_{}_{}.parquet",
            std::process::id(),
            name
        ));
        let mut df = df!(
            "a" => [1, 2, 3],
            "b" => ["x", "y", "z"]
        )
        .unwrap();
        let file = std::fs::File::create(&path).unwrap();
        ParquetWriter::new(file).finish(&mut df).unwrap();
        TestParquet { path }
    }

    #[test]
    fn read_parquet_default() {
        let fixture = create_test_parquet("default");
        let df = ParquetReader::new(&fixture.path).finish().unwrap();
        assert_eq!(df.height(), 3);
        assert_eq!(df.width(), 2);
    }

    #[test]
    fn read_parquet_select_columns() {
        let fixture = create_test_parquet("select_columns");
        let df = ParquetReader::new(&fixture.path)
            .columns(vec!["a"])
            .finish()
            .unwrap();
        assert_eq!(df.width(), 1);
        assert!(df.column("a").is_ok());
        assert_eq!(df.height(), 3);
    }

    #[test]
    fn read_parquet_n_rows() {
        let fixture = create_test_parquet("n_rows");
        let df = ParquetReader::new(&fixture.path).n_rows(2).finish().unwrap();
        assert_eq!(df.height(), 2);
        assert_eq!(df.width(), 2);
    }

    #[test]
    fn read_parquet_columns_and_n_rows() {
        let fixture = create_test_parquet("columns_and_n_rows");
        let df = ParquetReader::new(&fixture.path)
            .columns(vec!["b"])
            .n_rows(1)
            .finish()
            .unwrap();
        assert_eq!(df.width(), 1);
        assert!(df.column("b").is_ok());
        assert_eq!(df.height(), 1);
    }

    #[test]
    fn read_parquet_file_not_found() {
        let result = ParquetReader::new("nonexistent.parquet").finish();
        assert!(result.is_err());
    }
}