mod conversion;
mod schema;
use crate::conversion::FromSqlite;
pub use crate::schema::*;
use anyhow::{Context, Result};
use fallible_streaming_iterator::FallibleStreamingIterator;
use parquet::file::writer::SerializedFileWriter;
use rusqlite::Connection;
use std::io::Write;
use std::sync::Arc;
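/// Build a [`SerializedFileWriter`] for `out` whose Parquet schema mirrors
/// `cols`: one field per column, ZSTD compression for the whole file, plus
/// any per-column encoding and dictionary settings the schema requests.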
fn mk_writer<W: Write + Send>(
table_name: &str,
cols: &[Column],
out: W,
) -> Result<SerializedFileWriter<W>> {
    let fields = cols
        .iter()
        .map(|col| -> Result<_> { Ok(Arc::new(col.as_parquet()?)) })
        .collect::<Result<Vec<_>>>()?;
let schema = parquet::schema::types::Type::group_type_builder(table_name)
.with_fields(fields)
.build()?;
let mut bldr = parquet::file::properties::WriterProperties::builder()
.set_compression(parquet::basic::Compression::ZSTD(Default::default()));
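    // Layer per-column overrides on top of the file-wide defaults: an explicit
    // encoding, if the column requests one, and dictionary encoding on or off.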
for col in cols {
let path = parquet::schema::types::ColumnPath::new(vec![col.name.clone()]);
        if let Some(enc) = col.encoding() {
            bldr = bldr.set_column_encoding(path.clone(), enc);
        }
bldr = bldr.set_column_dictionary_enabled(path, col.dictionary);
}
let props = bldr.build();
Ok(SerializedFileWriter::new(
out,
Arc::new(schema),
Arc::new(props),
)?)
}
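/// Write an entire SQLite table to `out` as one Parquet file, in row groups
/// of `group_size` rows. Equivalent to [`write_table_with_progress`] with a
/// no-op callback.
///
/// A minimal usage sketch, not a verbatim example: how `cols` is constructed
/// depends on this crate's `schema` module and is elided here.
///
/// ```ignore
/// use rusqlite::Connection;
///
/// let conn = Connection::open("db.sqlite")?;
/// let cols: Vec<Column> = todo!("one Column (name, query, ...) per table column");
/// let out = std::fs::File::create("users.parquet")?;
/// let meta = write_table(&conn, "users", &cols, out, 4096)?;
/// println!("wrote {} rows", meta.num_rows);
/// ```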
pub fn write_table(
conn: &Connection,
table_name: &str,
cols: &[Column],
out: impl Write + Send,
group_size: usize,
) -> Result<parquet::format::FileMetaData> {
write_table_with_progress(conn, table_name, cols, out, group_size, |_| Ok(()))
}
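/// A snapshot of write progress, handed to the callback of
/// [`write_table_with_progress`].
///
/// `n_groups` and `n_rows` count completed row groups and the rows they
/// contain; `n_cols` counts the columns already written within the row group
/// currently being built.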
#[derive(Default, Debug, Copy, Clone)]
pub struct Progress {
pub n_cols: u64,
pub n_rows: u64,
pub n_groups: u64,
}
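/// Like [`write_table`], but `progress_cb` is invoked with a [`Progress`]
/// snapshot before each column of each row group is written; returning an
/// error from the callback aborts the write.
///
/// A sketch of a callback, assuming `conn`, `cols`, and `out` are set up as
/// for [`write_table`]:
///
/// ```ignore
/// let meta = write_table_with_progress(&conn, "users", &cols, out, 4096, |p| {
///     eprintln!("groups={} rows={} cols={}", p.n_groups, p.n_rows, p.n_cols);
///     Ok(())
/// })?;
/// ```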
pub fn write_table_with_progress(
conn: &Connection,
table_name: &str,
cols: &[Column],
out: impl Write + Send,
group_size: usize,
mut progress_cb: impl FnMut(Progress) -> Result<()>,
) -> Result<parquet::format::FileMetaData> {
let mut wtr = mk_writer(table_name, cols, out)?;
    let mut stmts = cols
        .iter()
        .map(|col| conn.prepare(&col.query))
        .collect::<rusqlite::Result<Vec<_>>>()?;
    let mut selects = stmts
        .iter_mut()
        .map(|stmt| stmt.query([]))
        .collect::<rusqlite::Result<Vec<rusqlite::Rows>>>()?;
for s in &mut selects {
s.advance()?;
}
let mut progress = Progress::default();
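    // Keep writing row groups while the first column's cursor still has a
    // current row; all cursors advance in lockstep inside `write_group`.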
    while selects[0].get().is_some() {
        let group_md = write_group(&mut wtr, &mut selects, group_size, |n_cols| {
            progress_cb(Progress { n_cols, ..progress })
        })
        .with_context(|| format!("Group {}", progress.n_groups))?;
        // The final group may be short, so count the rows actually written
        // instead of assuming a full `group_size`.
        progress.n_rows += group_md.num_rows() as u64;
        progress.n_groups += 1;
    }
let metadata = wtr.close()?;
Ok(metadata)
}
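/// Write a single row group of at most `group_size` rows.
///
/// The file writer yields column writers in schema order, which matches the
/// order of `selects` (one open SELECT cursor per column); `write_col` drains
/// each cursor by `group_size` rows so the cursors stay in lockstep.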
fn write_group<W: Write + Send>(
wtr: &mut SerializedFileWriter<W>,
selects: &mut [rusqlite::Rows],
group_size: usize,
mut progress_cb: impl FnMut(u64) -> Result<()>,
) -> Result<Arc<parquet::file::metadata::RowGroupMetaData>> {
let mut group_wtr = wtr.next_row_group()?;
let mut selects_iter = selects.iter_mut();
let mut n_cols_written = 0;
while let Some(mut col_wtr) = group_wtr.next_column()? {
progress_cb(n_cols_written)?;
let select = selects_iter.next().unwrap();
use parquet::column::writer::ColumnWriter::*;
        // Dispatch on the concrete physical type of the column writer; every
        // arm funnels into the same generic `write_col`.
        let result = match col_wtr.untyped() {
            BoolColumnWriter(wtr) => write_col(select, group_size, wtr),
            Int32ColumnWriter(wtr) => write_col(select, group_size, wtr),
            Int64ColumnWriter(wtr) => write_col(select, group_size, wtr),
            Int96ColumnWriter(wtr) => write_col(select, group_size, wtr),
            FloatColumnWriter(wtr) => write_col(select, group_size, wtr),
            DoubleColumnWriter(wtr) => write_col(select, group_size, wtr),
            ByteArrayColumnWriter(wtr) => write_col(select, group_size, wtr),
            FixedLenByteArrayColumnWriter(wtr) => write_col(select, group_size, wtr),
        };
        result.with_context(|| format!("Column {}", n_cols_written))?;
        col_wtr
            .close()
            .with_context(|| format!("Column {}", n_cols_written))?;
n_cols_written += 1;
}
Ok(group_wtr.close()?)
}
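/// Drain up to `group_size` rows from one column's SELECT cursor into a typed
/// Parquet column writer, mapping SQLite NULLs to definition level 0.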
fn write_col<T>(
iter: &mut rusqlite::Rows,
group_size: usize,
wtr: &mut parquet::column::writer::ColumnWriterImpl<T>,
) -> Result<()>
where
T: parquet::data_type::DataType,
T::T: FromSqlite,
{
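    // Parquet definition levels: 0 marks a NULL (no value stored), 1 marks a
    // present value, so `vals` holds only the non-NULL values, in row order.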
let mut defs = vec![];
let mut vals = vec![];
for _ in 0..group_size {
        // Stop early if the cursor is exhausted (short final group).
        let row = match iter.get() {
            Some(row) => row,
            None => break,
        };
        let value = row.get_ref(0)?;
        if value == rusqlite::types::ValueRef::Null {
            defs.push(0);
        } else {
            defs.push(1);
            vals.push(T::T::from_sqlite(value)?);
        }
iter.advance()?;
}
    wtr.write_batch(&vals, Some(&defs), None)?;
Ok(())
}