use crate::{
io::{Data, Read},
print_error, Entry, UniqueIdentifier, Update, Who,
};
use arrow::{
array::{Array, ArrayData, BufferBuilder, ListArray, PrimitiveArray},
buffer::Buffer,
datatypes::{ArrowNativeType, ArrowPrimitiveType, DataType, Field, Schema, ToByteSlice},
record_batch::RecordBatch,
};
use parquet::{
arrow::{arrow_writer::ArrowWriter, ArrowReader, ParquetFileArrowReader},
file::{properties::WriterProperties, reader::SerializedFileReader},
};
use std::{
any::Any, collections::HashMap, env, fmt::Display, fs::File, marker::PhantomData, mem::size_of,
path::Path, sync::Arc,
};
/// Errors raised by the Arrow logger.
#[derive(Debug, thiserror::Error)]
pub enum ArrowError {
    /// Opening/creating the output file failed.
    #[error("cannot open a parquet file")]
    ArrowToFile(#[from] std::io::Error),
    /// Building the Arrow arrays or record batch failed.
    #[error("cannot build Arrow data")]
    ArrowError(#[from] arrow::error::ArrowError),
    /// Reading or writing the Parquet file failed.
    #[error("cannot save data to Parquet")]
    ParquetError(#[from] parquet::errors::ParquetError),
    /// No record batch has been assembled (or the source file was empty).
    #[error("no record available")]
    NoRecord,
    /// The requested field name is absent from the record schema.
    #[error("Field {0} not found")]
    FieldNotFound(String),
    /// The field exists but could not be downcast to the requested type.
    #[error("Parsing field {0} failed")]
    ParseField(String),
    /// Writing the MATLAB file failed (only with the `matio-rs` feature).
    #[cfg(feature = "matio-rs")]
    #[error("failed to save data to mat file")]
    MatFile(#[from] matio_rs::MatioError),
}
/// Module-local result alias.
type Result<T> = std::result::Result<T, ArrowError>;
// Per-entry buffer ceiling: 2 << 29 = 2^30 bytes = 1 GiB.
const MAX_CAPACITY_BYTE: usize = 2 << 29;
/// Output file format used when the logger saves its data on drop.
pub enum FileFormat {
    /// Apache Parquet file (the default).
    Parquet,
    /// MATLAB `.mat` file; saving only happens with the `matio-rs` feature enabled.
    Matlab(MatFormat),
}
impl Default for FileFormat {
    fn default() -> Self {
        Self::Parquet
    }
}
/// Layout of the data written to a MATLAB file.
pub enum MatFormat {
    /// Data indexed by sample number (the default).
    SampleBased,
    /// Data accompanied by a `time` vector derived from this sampling
    /// frequency (Hz); time step is the reciprocal of the value.
    TimeBased(f64),
}
impl Default for MatFormat {
    fn default() -> Self {
        Self::SampleBased
    }
}
/// Type-erased interface over one logging buffer.
trait BufferObject: Send + Sync {
    /// Name of the entry (derived from its type name).
    fn who(&self) -> String;
    /// Upcast for downcasting back to the concrete buffer type.
    fn as_any(&self) -> &dyn Any;
    /// Mutable upcast for downcasting back to the concrete buffer type.
    fn as_mut_any(&mut self) -> &mut dyn Any;
    /// Drains the buffer into a [`ListArray`] of `n_step` rows of `n` values.
    fn into_list(&mut self, n_step: usize, n: usize, data_type: DataType) -> Result<ListArray>;
}
/// Marker type pairing a client data identifier `U` with an Arrow buffer builder.
struct ArrowBuffer<U: UniqueIdentifier>(PhantomData<U>);
// An entry identified by `U` (carrying `Vec<T>`) is buffered as a `BufferBuilder<T>`.
impl<T: ArrowNativeType, U: UniqueIdentifier<Data = Vec<T>>> UniqueIdentifier for ArrowBuffer<U> {
    type Data = BufferBuilder<T>;
}
impl<T, U> BufferObject for Data<ArrowBuffer<U>>
where
    T: ArrowNativeType,
    U: 'static + Send + Sync + UniqueIdentifier<Data = Vec<T>>,
{
    fn who(&self) -> String {
        Who::who(self)
    }
    fn as_any(&self) -> &dyn Any {
        self
    }
    fn as_mut_any(&mut self) -> &mut dyn Any {
        self
    }
    fn into_list(&mut self, n_step: usize, n: usize, data_type: DataType) -> Result<ListArray> {
        let buffer = &mut *self;
        // Flat primitive array holding every logged value back to back;
        // `finish()` drains the underlying builder.
        let data = ArrayData::builder(data_type.clone())
            .len(buffer.len())
            .add_buffer(buffer.finish())
            .build()?;
        // Row boundaries at multiples of `n`: n_step + 1 offsets delimit n_step rows.
        // NOTE(review): offsets are `i32`, so more than i32::MAX total values
        // would overflow — confirm logged volumes stay below that.
        let offsets = (0..).step_by(n).take(n_step + 1).collect::<Vec<i32>>();
        let list = ArrayData::builder(DataType::List(Box::new(Field::new(
            "values", data_type, false,
        ))))
        .len(n_step)
        .add_buffer(Buffer::from(&offsets.to_byte_slice()))
        .add_child_data(data)
        .build()?;
        Ok(ListArray::from(list))
    }
}
/// Maps a Rust primitive to its Arrow array type and [`DataType`] variant.
#[doc(hidden)]
pub trait BufferDataType {
    /// The `arrow::datatypes::*Type` marker matching `Self`.
    type ArrayType;
    /// The Arrow [`DataType`] variant matching `Self`.
    fn buffer_data_type() -> DataType;
}
use paste::paste;
// Implements `BufferDataType` for each `(rust type, Arrow variant)` pair,
// pasting the variant name into the matching `arrow::datatypes::<Variant>Type`.
macro_rules! impl_buffer_types {
    ( $( ($rs:ty,$arw:expr) ),+ ) => {
        $(
            paste! {
                impl BufferDataType for $rs {
                    type ArrayType = arrow::datatypes::[<$arw Type>];
                    fn buffer_data_type() -> DataType {
                        arrow::datatypes::DataType::$arw
                    }
                }
            }
        )+
    };
}
// Every primitive numeric type maps to its Arrow counterpart.
impl_buffer_types! {
    (f64,Float64),
    (f32,Float32),
    (i64,Int64),
    (i32,Int32),
    (i16,Int16),
    (i8 ,Int8),
    (u64,UInt64),
    (u32,UInt32),
    (u16,UInt16),
    (u8 ,UInt8)
}
/// What to do with the collected data when the logger is dropped.
enum DropOption {
    /// Save to disk, optionally under the given file stem (defaults to "data").
    Save(Option<String>),
    /// Discard the data.
    NoSave,
}
/// Builder for the [`Arrow`] logger.
pub struct ArrowBuilder {
    // Total number of steps the logger is expected to run for.
    n_step: usize,
    // Elements per step for each registered entry.
    capacities: Vec<usize>,
    // One (type-erased buffer, element data type) pair per entry.
    buffers: Vec<(Box<dyn BufferObject>, DataType)>,
    // Optional key/value metadata embedded in the Parquet schema.
    metadata: Option<HashMap<String, String>>,
    // Number of registered entries.
    n_entry: usize,
    // Whether (and where) to save on drop.
    drop_option: DropOption,
    // Keep one sample out of every `decimation`.
    decimation: usize,
    // Output format used on save.
    file_format: FileFormat,
}
impl ArrowBuilder {
    /// Creates a builder for a logger expected to run for `n_step` steps;
    /// by default the data is saved to Parquet (file stem "data") on drop.
    pub fn new(n_step: usize) -> Self {
        Self {
            n_step,
            capacities: Vec::new(),
            buffers: Vec::new(),
            metadata: None,
            n_entry: 0,
            drop_option: DropOption::Save(None),
            decimation: 1,
            file_format: Default::default(),
        }
    }
    /// Registers a log entry for data of identifier `U` with `size` elements per step.
    #[deprecated = "replaced by the log method of the InputLogs trait"]
    pub fn entry<T: BufferDataType, U>(mut self, size: usize) -> Self
    where
        T: 'static + ArrowNativeType + Send + Sync,
        U: 'static + Send + Sync + UniqueIdentifier<Data = Vec<T>>,
    {
        // One slot per (decimated) step, clamped to the 1 GiB buffer ceiling.
        let requested = size * (1 + self.n_step / self.decimation);
        let capacity = if requested * size_of::<T>() > MAX_CAPACITY_BYTE {
            let clamped = MAX_CAPACITY_BYTE / size_of::<T>();
            log::info!("Capacity limit of 1GB exceeded, reduced to : {}", clamped);
            clamped
        } else {
            requested
        };
        let buffer: Data<ArrowBuffer<U>> = Data::new(BufferBuilder::<T>::new(capacity));
        self.buffers.push((Box::new(buffer), T::buffer_data_type()));
        self.capacities.push(size);
        self.n_entry += 1;
        self
    }
    /// Sets the file stem the data is saved under on drop.
    pub fn filename<S: Into<String>>(mut self, filename: S) -> Self {
        self.drop_option = DropOption::Save(Some(filename.into()));
        self
    }
    /// Disables saving on drop.
    pub fn no_save(mut self) -> Self {
        self.drop_option = DropOption::NoSave;
        self
    }
    /// Sets the output file format (Parquet or MATLAB).
    pub fn file_format(mut self, file_format: FileFormat) -> Self {
        self.file_format = file_format;
        self
    }
    /// Attaches key/value metadata to the record schema.
    pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
        self.metadata = Some(metadata);
        self
    }
    /// Keeps one sample out of every `decimation`.
    pub fn decimation(mut self, decimation: usize) -> Self {
        self.decimation = decimation;
        self
    }
    /// Consumes the builder and returns the [`Arrow`] logger.
    pub fn build(self) -> Arrow {
        Arrow {
            n_step: self.n_step,
            capacities: self.capacities,
            buffers: self.buffers,
            metadata: self.metadata,
            step: 0,
            n_entry: self.n_entry,
            record: None,
            drop_option: self.drop_option,
            decimation: self.decimation,
            count: 0,
            file_format: self.file_format,
        }
    }
}
/// Data logger that accumulates client outputs in Arrow buffers and can save
/// them to Parquet or MATLAB files.
pub struct Arrow {
    // Total number of steps the logger is expected to run for.
    n_step: usize,
    // Elements per step for each entry.
    capacities: Vec<usize>,
    // One (type-erased buffer, element data type) pair per entry.
    buffers: Vec<(Box<dyn BufferObject>, DataType)>,
    // Optional key/value metadata embedded in the schema.
    metadata: Option<HashMap<String, String>>,
    // Total number of `read` calls across all entries.
    step: usize,
    // Number of registered entries.
    n_entry: usize,
    // Lazily built, cached record batch (see `record`).
    record: Option<RecordBatch>,
    // Whether (and where) to save on drop.
    drop_option: DropOption,
    // Keep one sample out of every `decimation`.
    decimation: usize,
    // Number of samples actually appended (post-decimation), across all entries.
    count: usize,
    // Output format used on save.
    file_format: FileFormat,
}
impl Default for Arrow {
    /// An empty logger with no entries that discards its data on drop.
    fn default() -> Self {
        Self {
            capacities: Vec::new(),
            buffers: Vec::new(),
            metadata: None,
            record: None,
            drop_option: DropOption::NoSave,
            file_format: Default::default(),
            n_step: 0,
            step: 0,
            n_entry: 0,
            count: 0,
            decimation: 1,
        }
    }
}
impl Arrow {
    /// Returns an [`ArrowBuilder`] for a logger running `n_step` steps.
    pub fn builder(n_step: usize) -> ArrowBuilder {
        ArrowBuilder::new(n_step)
    }
    /// Finds the buffer registered for identifier `U`, if any.
    fn data<T, U>(&mut self) -> Option<&mut Data<ArrowBuffer<U>>>
    where
        T: 'static + ArrowNativeType,
        U: 'static + UniqueIdentifier<Data = Vec<T>>,
    {
        self.buffers
            .iter_mut()
            .find_map(|(b, _)| b.as_mut_any().downcast_mut::<Data<ArrowBuffer<U>>>())
    }
    /// Percentage of the expected run completed so far (0-100).
    ///
    /// Bug fix: the previous `step / n_step / n_entry` was an integer
    /// fraction that stayed 0 until the very end; scale by 100 first.
    ///
    /// # Panics
    /// Divides by `n_step * n_entry`; panics if either is 0.
    pub fn pct_complete(&self) -> usize {
        100 * self.step / (self.n_step * self.n_entry)
    }
    /// Number of steps received per entry.
    ///
    /// # Panics
    /// Panics if no entry has been registered (`n_entry == 0`).
    pub fn size(&self) -> usize {
        self.step / self.n_entry
    }
}
impl<T, U> Entry<U> for Arrow
where
    T: 'static + BufferDataType + ArrowNativeType + Send + Sync,
    U: 'static + Send + Sync + UniqueIdentifier<Data = Vec<T>>,
{
    /// Registers a new log entry of `size` elements per step.
    fn entry(&mut self, size: usize) {
        // One slot per (decimated) step, clamped to the 1 GiB buffer ceiling.
        let requested = size * (1 + self.n_step / self.decimation);
        let capacity = if requested * size_of::<T>() > MAX_CAPACITY_BYTE {
            let clamped = MAX_CAPACITY_BYTE / size_of::<T>();
            log::info!("Capacity limit of 1GB exceeded, reduced to : {}", clamped);
            clamped
        } else {
            requested
        };
        let buffer: Data<ArrowBuffer<U>> = Data::new(BufferBuilder::<T>::new(capacity));
        self.buffers.push((Box::new(buffer), T::buffer_data_type()));
        self.capacities.push(size);
        self.n_entry += 1;
    }
}
impl Display for Arrow {
    /// Prints a summary of the logged entries; silent when nothing is registered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Nothing registered: print nothing (also avoids dividing by n_entry = 0).
        if self.n_entry == 0 {
            return Ok(());
        }
        writeln!(f, "Arrow logger:")?;
        writeln!(f, " - data:")?;
        for ((buffer, _), capacity) in self.buffers.iter().zip(self.capacities.iter()) {
            writeln!(f, " - {:>8}:{:>4}", buffer.who(), capacity)?;
        }
        write!(
            f,
            " - steps #: {}/{}/{}",
            self.n_step,
            self.step / self.n_entry,
            self.count / self.n_entry
        )
    }
}
impl Drop for Arrow {
    /// Logs a summary and, unless `no_save` was requested, persists the data
    /// in the configured file format.
    fn drop(&mut self) {
        log::info!("{self}");
        match self.drop_option {
            DropOption::Save(ref filename) => {
                // Default file stem when none was supplied.
                let file_name = filename
                    .as_ref()
                    .cloned()
                    .unwrap_or_else(|| "data".to_string());
                match self.file_format {
                    FileFormat::Parquet => {
                        // Errors cannot propagate out of drop; report them instead.
                        if let Err(e) = self.to_parquet(file_name) {
                            print_error("Arrow error", &e);
                        }
                    }
                    FileFormat::Matlab(_) =>
                    {
                        // Without the `matio-rs` feature this arm is a silent no-op.
                        #[cfg(feature = "matio-rs")]
                        if let Err(e) = self.to_mat(file_name) {
                            print_error("Arrow error", &e);
                        }
                    }
                }
            }
            DropOption::NoSave => {
                log::info!("Dropping Arrow logger without saving.");
            }
        }
    }
}
impl Arrow {
    /// Returns the logged data consolidated into a single [`RecordBatch`].
    ///
    /// The batch is built lazily on the first call (draining the internal
    /// buffers) and cached; later calls return the cached record.
    ///
    /// # Errors
    /// Fails if the Arrow arrays or the record batch cannot be assembled.
    ///
    /// # Panics
    /// Divides by `n_entry`; calling with no registered entries panics.
    pub fn record(&mut self) -> Result<&RecordBatch> {
        if self.record.is_none() {
            let mut lists: Vec<Arc<dyn Array>> = vec![];
            for ((buffer, buffer_data_type), n) in self.buffers.iter_mut().zip(&self.capacities) {
                let list =
                    buffer.into_list(self.count / self.n_entry, *n, buffer_data_type.clone())?;
                lists.push(Arc::new(list));
            }
            // Field names come from each entry's type name, keeping only the
            // last `::` segment and stripping closing angle brackets.
            let fields: Vec<_> = self
                .buffers
                .iter()
                .map(|(buffer, data_type)| {
                    Field::new(
                        &buffer
                            .who()
                            .split("::")
                            .last()
                            .unwrap_or("no name")
                            .replace(">", ""),
                        DataType::List(Box::new(Field::new("values", data_type.clone(), false))),
                        false,
                    )
                })
                .collect();
            let schema = Arc::new(if let Some(metadata) = self.metadata.as_ref() {
                Schema::new_with_metadata(fields, metadata.clone())
            } else {
                Schema::new(fields)
            });
            self.record = Some(RecordBatch::try_new(Arc::clone(&schema), lists)?);
        }
        self.record.as_ref().ok_or(ArrowError::NoRecord)
    }
    /// Saves the logged data to `path` (extension forced to ".parquet"),
    /// relative to the `DATA_REPO` directory when that variable is set.
    ///
    /// # Errors
    /// Fails if the record cannot be built or the file cannot be written.
    pub fn to_parquet<P: AsRef<Path> + std::fmt::Debug>(&mut self, path: P) -> Result<()> {
        let batch = self.record()?;
        let root_env = env::var("DATA_REPO").unwrap_or_else(|_| ".".to_string());
        let root = Path::new(&root_env).join(&path).with_extension("parquet");
        let file = File::create(&root)?;
        let props = WriterProperties::builder().build();
        // `schema()` already returns an owned `SchemaRef`; no extra clone needed.
        let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props))?;
        writer.write(batch)?;
        writer.close()?;
        log::info!("Arrow data saved to {root:?}");
        Ok(())
    }
    /// Loads a logger back from a Parquet file at `path` (extension forced to
    /// ".parquet"), relative to `DATA_REPO` when set. Only the record is
    /// restored; the returned logger has no entries and does not save on drop.
    ///
    /// # Errors
    /// Fails on I/O or Parquet errors, or with [`ArrowError::NoRecord`] when
    /// the file holds no record batch (previously these paths panicked).
    pub fn from_parquet<P>(path: P) -> Result<Self>
    where
        P: AsRef<Path>,
    {
        let root_env = env::var("DATA_REPO").unwrap_or_else(|_| ".".to_string());
        let filename = Path::new(&root_env).join(&path).with_extension("parquet");
        let file = File::open(&filename)?;
        log::info!("Loading {:?}", filename);
        let file_reader = SerializedFileReader::new(file)?;
        let mut arrow_reader = ParquetFileArrowReader::new(Arc::new(file_reader));
        // Propagate reader construction errors instead of panicking (was `unwrap`).
        let records = arrow_reader
            .get_record_reader(2048)?
            .collect::<std::result::Result<Vec<RecordBatch>, arrow::error::ArrowError>>()?;
        // An empty file has no schema to concatenate against (was `unwrap`).
        let schema = records.first().ok_or(ArrowError::NoRecord)?.schema();
        Ok(Arrow {
            record: Some(RecordBatch::concat(&schema, &records)?),
            ..Default::default()
        })
    }
    /// Saves the logged data to `path` as a MATLAB file (extension forced to
    /// ".mat"), relative to `DATA_REPO` when set. With a
    /// [`MatFormat::TimeBased`] format, a `time` vector is written as well.
    ///
    /// # Errors
    /// Fails if the record cannot be built or the mat file cannot be written.
    #[cfg(feature = "matio-rs")]
    pub fn to_mat<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
        use matio_rs::{MatFile, MatVar, Save};
        let batch = self.record()?;
        let root_env = env::var("DATA_REPO").unwrap_or_else(|_| ".".to_string());
        let root = Path::new(&root_env).join(&path).with_extension("mat");
        let mat_file = MatFile::save(&root)?;
        let mut n_sample = 0;
        for field in batch.schema().fields() {
            let name = field.name();
            let data: Vec<Vec<f64>> = self.get(name)?;
            n_sample = data.len();
            // NOTE(review): indexes data[0] — a field with zero samples would
            // panic here; confirm fields are never empty at save time.
            let n_data = data[0].len();
            mat_file.write(MatVar::<Vec<f64>>::array(
                name,
                data.into_iter()
                    .flatten()
                    .collect::<Vec<f64>>()
                    .as_mut_slice(),
                (n_data, n_sample),
            )?);
        }
        if let FileFormat::Matlab(MatFormat::TimeBased(sampling_frequency)) = self.file_format {
            // Time step is the reciprocal of the sampling frequency (Hz).
            let tau = sampling_frequency.recip();
            let time: Vec<f64> = (0..n_sample).map(|i| i as f64 * tau).collect();
            mat_file.write(MatVar::<Vec<f64>>::new("time", time.as_slice())?);
        }
        log::info!("Arrow data saved to {root:?}");
        Ok(())
    }
}
/// Retrieves logged entries as `Vec<Vec<T>>` (one inner `Vec` per logged step).
pub trait Get<T>
where
    T: BufferDataType,
    <T as BufferDataType>::ArrayType: ArrowPrimitiveType,
    Vec<T>: FromIterator<<<T as BufferDataType>::ArrayType as ArrowPrimitiveType>::Native>,
{
    /// Returns all the samples of the field `field_name`.
    fn get<S>(&mut self, field_name: S) -> Result<Vec<Vec<T>>>
    where
        S: AsRef<str>,
        String: From<S>;
    /// Returns the samples of `field_name`, skipping the first `skip` and
    /// keeping at most `take` of the remainder (`None` keeps them all).
    fn get_skip_take<S>(
        &mut self,
        field_name: S,
        skip: usize,
        take: Option<usize>,
    ) -> Result<Vec<Vec<T>>>
    where
        S: AsRef<str>,
        String: From<S>;
    /// Returns the samples of `field_name` after the first `skip`.
    fn get_skip<S>(&mut self, field_name: S, skip: usize) -> Result<Vec<Vec<T>>>
    where
        S: AsRef<str>,
        String: From<S>,
    {
        self.get_skip_take(field_name, skip, None)
    }
    /// Returns the first `take` samples of `field_name`.
    fn get_take<S>(&mut self, field_name: S, take: usize) -> Result<Vec<Vec<T>>>
    where
        S: AsRef<str>,
        String: From<S>,
    {
        self.get_skip_take(field_name, 0, Some(take))
    }
}
// Unused lifetime parameter `'a` removed from the original impl header.
impl<T> Get<T> for Arrow
where
    T: BufferDataType,
    <T as BufferDataType>::ArrayType: ArrowPrimitiveType,
    Vec<T>: FromIterator<<<T as BufferDataType>::ArrayType as ArrowPrimitiveType>::Native>,
{
    /// Returns all the samples of `field_name`.
    ///
    /// The body duplicated `get_skip_take` verbatim (minus the skip/take
    /// adaptors); it now delegates with `skip = 0`, `take = None`, which is
    /// behaviorally identical.
    fn get<S>(&mut self, field_name: S) -> Result<Vec<Vec<T>>>
    where
        S: AsRef<str>,
        String: From<S>,
    {
        self.get_skip_take(field_name, 0, None)
    }
    /// Returns the samples of `field_name`, skipping the first `skip` and
    /// keeping at most `take` (`None` keeps them all).
    ///
    /// # Errors
    /// [`ArrowError::FieldNotFound`] when the field is absent from the schema,
    /// [`ArrowError::ParseField`] when it cannot be downcast to `T`'s array type.
    fn get_skip_take<S>(
        &mut self,
        field_name: S,
        skip: usize,
        take: Option<usize>,
    ) -> Result<Vec<Vec<T>>>
    where
        S: AsRef<str>,
        String: From<S>,
    {
        let record = self.record()?;
        let (idx, _) = match record.schema().column_with_name(field_name.as_ref()) {
            Some(found) => found,
            None => return Err(ArrowError::FieldNotFound(field_name.into())),
        };
        record
            .column(idx)
            .as_any()
            .downcast_ref::<ListArray>()
            .and_then(|rows| {
                rows.iter()
                    .skip(skip)
                    .take(take.unwrap_or(usize::MAX))
                    .map(|row| {
                        row.and_then(|values| {
                            values
                                .as_any()
                                .downcast_ref::<PrimitiveArray<<T as BufferDataType>::ArrayType>>()
                                .and_then(|values| values.iter().collect::<Option<Vec<T>>>())
                        })
                    })
                    .collect::<Option<Vec<Vec<T>>>>()
            })
            .ok_or_else(|| ArrowError::ParseField(field_name.into()))
    }
}
// The logger has no per-step computation of its own.
impl Update for Arrow {}
impl<T, U> Read<U> for Arrow
where
    T: ArrowNativeType,
    U: 'static + UniqueIdentifier<Data = Vec<T>>,
{
    /// Appends `data` to the buffer registered for `U`, honouring decimation.
    fn read(&mut self, data: Arc<Data<U>>) {
        // 1-based sample index shared by all entries of the same step.
        // NOTE(review): float division makes `n_entry == 0` yield NaN, which
        // casts to 0 (so r == 1) instead of panicking — confirm intentional.
        let r = 1 + (self.step as f64 / self.n_entry as f64).floor() as usize;
        self.step += 1;
        // Skip samples falling between decimation points.
        if r % self.decimation > 0 {
            return;
        }
        // Silently ignores identifiers with no registered buffer.
        if let Some(buffer) = self.data::<T, U>() {
            buffer.append_slice(data.as_slice());
            self.count += 1;
        }
    }
}