use std::{
collections::HashMap,
env,
fmt::Display,
fs::{DirBuilder, File},
mem::size_of,
path::Path,
sync::Arc,
};
use apache_arrow::{
array::{Array, BufferBuilder},
compute::concat_batches,
datatypes::{ArrowNativeType, DataType, Field, Schema},
record_batch::{RecordBatch, RecordBatchReader},
};
use interface::{print_info, Entry, UniqueIdentifier};
use parquet::{
arrow::{arrow_reader::ParquetRecordBatchReaderBuilder, ArrowWriter},
file::properties::WriterProperties,
};
use crate::{
ArrowBuffer, ArrowError, BufferDataType, BufferObject, DropOption, FileFormat, LogData, Result,
MAX_CAPACITY_BYTE,
};
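/// [Arrow] logger builder.
///
/// A minimal usage sketch (assuming a `Sig` type implementing
/// `UniqueIdentifier<DataType = Vec<f64>>` for the logged data):
/// ```ignore
/// let logger = Arrow::builder(1_000) // number of simulation steps
///     .filename("data")              // written to "data.parquet" on drop
///     .decimation(10)                // keep 1 sample out of 10
///     .build();
/// ```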
pub struct ArrowBuilder {
n_step: usize,
capacities: Vec<usize>,
buffers: Vec<(Box<dyn BufferObject>, DataType)>,
metadata: Option<HashMap<String, String>>,
n_entry: usize,
drop_option: DropOption,
decimation: usize,
file_format: FileFormat,
batch_size: Option<usize>,
}
impl ArrowBuilder {
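    /// Creates a logger builder for a simulation of `n_step` steps.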
pub fn new(n_step: usize) -> Self {
Self {
n_step,
capacities: Vec::new(),
buffers: Vec::new(),
metadata: None,
n_entry: 0,
drop_option: DropOption::Save(None),
decimation: 1,
file_format: Default::default(),
batch_size: None,
}
}
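    /// Adds an entry for data of type `Vec<T>` identified by `U`, with `size` elements per sample.
    ///
    /// The buffer capacity is `size * (1 + n_step / decimation)`, capped at `MAX_CAPACITY_BYTE` bytes.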
#[deprecated = "replaced by the log method of the InputLogs trait"]
pub fn entry<T: BufferDataType, U>(self, size: usize) -> Self
where
T: 'static + ArrowNativeType + Send + Sync,
U: 'static + Send + Sync + UniqueIdentifier<DataType = Vec<T>>,
{
let mut buffers = self.buffers;
let mut capacity = size * (1 + self.n_step / self.decimation);
if capacity * size_of::<T>() > MAX_CAPACITY_BYTE {
capacity = MAX_CAPACITY_BYTE / size_of::<T>();
log::info!("Capacity limit of 1GB exceeded, reduced to : {}", capacity);
}
let buffer: LogData<ArrowBuffer<U>> = LogData::new(BufferBuilder::<T>::new(capacity));
buffers.push((Box::new(buffer), T::buffer_data_type()));
let mut capacities = self.capacities;
capacities.push(size);
Self {
buffers,
capacities,
n_entry: self.n_entry + 1,
..self
}
}
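    /// Sets the number of samples per [RecordBatch] (default: the total number of logged samples).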
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.batch_size = Some(batch_size);
self
}
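    /// Sets the name of the file the data is saved to when the logger is dropped (default: "data").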
pub fn filename<S: Into<String>>(self, filename: S) -> Self {
Self {
drop_option: DropOption::Save(Some(filename.into())),
..self
}
}
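    /// Discards the data when the logger is dropped instead of saving it.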
pub fn no_save(self) -> Self {
Self {
drop_option: DropOption::NoSave,
..self
}
}
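    /// Sets the format of the file the data is saved to.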
pub fn file_format(self, file_format: FileFormat) -> Self {
Self {
file_format,
..self
}
}
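    /// Attaches key/value metadata to the record schema.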
pub fn metadata(mut self, metadata: HashMap<String, String>) -> Self {
self.metadata = Some(metadata);
self
}
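    /// Sets the downsampling rate: only 1 sample out of `decimation` is logged (default: 1).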
pub fn decimation(self, decimation: usize) -> Self {
Self { decimation, ..self }
}
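    /// Builds the [Arrow] logger.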
pub fn build(self) -> Arrow {
Arrow {
n_step: self.n_step,
capacities: self.capacities,
buffers: self.buffers,
metadata: self.metadata,
step: 0,
n_entry: self.n_entry,
record: None,
batch: None,
drop_option: self.drop_option,
decimation: self.decimation,
count: 0,
file_format: self.file_format,
batch_size: self.batch_size,
}
}
}
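/// Apache Arrow data logger.
///
/// Each entry is logged into its own Arrow buffer and the buffers are gathered
/// into a [RecordBatch]; unless [no_save](ArrowBuilder::no_save) was set, the
/// record is written to a file when the logger is dropped.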
pub struct Arrow {
n_step: usize,
capacities: Vec<usize>,
buffers: Vec<(Box<dyn BufferObject>, DataType)>,
metadata: Option<HashMap<String, String>>,
pub(crate) step: usize,
pub(crate) n_entry: usize,
record: Option<RecordBatch>,
batch: Option<Vec<RecordBatch>>,
drop_option: DropOption,
pub(crate) decimation: usize,
pub(crate) count: usize,
file_format: FileFormat,
pub(crate) batch_size: Option<usize>,
}
impl Default for Arrow {
fn default() -> Self {
Arrow {
n_step: 0,
capacities: Vec::new(),
buffers: Vec::new(),
metadata: None,
step: 0,
n_entry: 0,
record: None,
batch: None,
drop_option: DropOption::NoSave,
decimation: 1,
count: 0,
file_format: Default::default(),
batch_size: None,
}
}
}
impl Arrow {
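    /// Returns an [Arrow] logger builder for a simulation of `n_step` steps.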
pub fn builder(n_step: usize) -> ArrowBuilder {
ArrowBuilder::new(n_step)
}
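    /// Returns a mutable reference to the buffer of entry `U`, if any.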
pub(crate) fn data<T, U>(&mut self) -> Option<&mut LogData<ArrowBuffer<U>>>
where
T: 'static + ArrowNativeType,
U: 'static + UniqueIdentifier<DataType = Vec<T>>,
{
self.buffers
.iter_mut()
.find_map(|(b, _)| b.as_mut_any().downcast_mut::<LogData<ArrowBuffer<U>>>())
}
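    /// Returns the percentage of the simulation that has been logged.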
pub fn pct_complete(&self) -> usize {
        100 * self.step / self.n_step / self.n_entry
}
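    /// Returns the number of samples logged per entry.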
pub fn size(&self) -> usize {
self.step / self.n_entry
}
}
impl<T, U> Entry<U> for Arrow
where
T: 'static + BufferDataType + ArrowNativeType + Send + Sync,
U: 'static + Send + Sync + UniqueIdentifier<DataType = Vec<T>>,
{
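    /// Creates a new logger entry for data of type `Vec<T>` identified by `U`.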
fn entry(&mut self, size: usize) {
let mut capacity = size * (1 + self.n_step / self.decimation);
if capacity * size_of::<T>() > MAX_CAPACITY_BYTE {
capacity = MAX_CAPACITY_BYTE / size_of::<T>();
log::info!("Capacity limit of 1GB exceeded, reduced to : {}", capacity);
}
let buffer: LogData<ArrowBuffer<U>> = LogData::new(BufferBuilder::<T>::new(capacity));
self.buffers.push((Box::new(buffer), T::buffer_data_type()));
self.capacities.push(size);
self.n_entry += 1;
}
}
impl Display for Arrow {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if self.n_entry > 0 {
writeln!(f, "Arrow logger:")?;
writeln!(f, " - data:")?;
for ((buffer, _), capacity) in self.buffers.iter().zip(self.capacities.iter()) {
writeln!(f, " - {:>8}:{:>4}", buffer.who(), capacity)?;
}
write!(
f,
" - steps #: {}/{}/{}",
self.n_step,
self.step / self.n_entry,
self.count / self.n_entry
)?;
}
Ok(())
}
}
impl Drop for Arrow {
    fn drop(&mut self) {
        log::info!("{self}");
        match self.drop_option {
            // delegate to `save`, which dispatches on the file format
            DropOption::Save(_) => {
                self.save();
            }
            DropOption::NoSave => {
                log::info!("Dropping Arrow logger without saving.");
            }
        }
    }
}
impl Arrow {
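    /// Saves the data according to the drop option and file format.
    ///
    /// This is also performed automatically when the logger is dropped.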
pub fn save(&mut self) -> &mut Self {
match self.drop_option {
DropOption::Save(ref filename) => {
let file_name = filename
.as_ref()
.cloned()
.unwrap_or_else(|| "data".to_string());
match self.file_format {
FileFormat::Parquet => {
if let Err(e) = self.to_parquet(file_name) {
print_info("Arrow error", Some(&e));
}
}
#[cfg(feature = "matio-rs")]
FileFormat::Matlab(_) => {
if let Err(e) = self.to_mat(file_name) {
print_info("Arrow error", Some(&e));
}
}
}
}
DropOption::NoSave => {
log::info!("no saving option set");
}
}
self
}
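    /// Moves the current record into the batch queue and returns all the record batches accumulated so far.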
pub fn batch(&mut self) -> Result<&Vec<RecordBatch>> {
self.record()?;
if let Some(record) = self.record.take() {
self.batch.get_or_insert(vec![]).push(record);
}
self.batch.as_ref().ok_or(ArrowError::NoRecord)
}
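    /// Concatenates all the record batches into a single [RecordBatch].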
pub fn concat_batches(&mut self) -> Result<RecordBatch> {
self.batch().and_then(|batches| {
let schema = batches[0].schema();
let record = concat_batches(&schema, batches)?;
Ok(record)
})
}
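    /// Returns the data as a [RecordBatch], building the record from the buffers if it does not already exist.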
pub fn record(&mut self) -> Result<&RecordBatch> {
if self.record.is_none() {
let mut lists: Vec<Arc<dyn Array>> = vec![];
for ((buffer, buffer_data_type), n) in self.buffers.iter_mut().zip(&self.capacities) {
let list = buffer.into_list(
self.batch_size.unwrap_or(self.count / self.n_entry),
*n,
buffer_data_type.clone(),
)?;
lists.push(Arc::new(list));
}
let fields: Vec<_> = self
.buffers
.iter()
.map(|(buffer, data_type)| {
Field::new(
&buffer
.who()
.split("::")
.last()
.unwrap_or("no name")
.replace(">", ""),
DataType::List(Box::new(Field::new("values", data_type.clone(), false))),
false,
)
})
.collect();
let schema = Arc::new(if let Some(metadata) = self.metadata.as_ref() {
Schema::new_with_metadata(fields, metadata.clone())
} else {
Schema::new(fields)
});
self.record = Some(RecordBatch::try_new(Arc::clone(&schema), lists)?);
}
self.record.as_ref().ok_or(ArrowError::NoRecord)
}
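    /// Saves the data to a Parquet file.
    ///
    /// The file is written to the directory given by the `DATA_REPO` environment
    /// variable (default: the current directory) and `path` is given the "parquet" extension.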
pub fn to_parquet<P: AsRef<Path> + std::fmt::Debug>(&mut self, path: P) -> Result<()> {
let batch = self.concat_batches()?;
let root_env = env::var("DATA_REPO").unwrap_or_else(|_| ".".to_string());
let root = Path::new(&root_env).join(&path).with_extension("parquet");
        if let Some(dir) = root.parent() {
            if !dir.is_dir() {
                DirBuilder::new().recursive(true).create(dir)?;
            }
        }
        let file = File::create(&root)?;
        let props = WriterProperties::builder().build();
        let mut writer = ArrowWriter::try_new(file, batch.schema(), Some(props))?;
writer.write(&batch)?;
writer.close()?;
log::info!("Arrow data saved to {root:?}");
Ok(())
}
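    /// Loads a record from a Parquet file.
    ///
    /// The file is read from the directory given by the `DATA_REPO` environment
    /// variable (default: the current directory).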
pub fn from_parquet<P>(path: P) -> Result<Self>
where
P: AsRef<Path>,
{
let root_env = env::var("DATA_REPO").unwrap_or_else(|_| ".".to_string());
let root = Path::new(&root_env);
let filename = root.join(&path).with_extension("parquet");
let file = File::open(&filename)?;
log::info!("Loading {:?}", filename);
let parquet_reader = ParquetRecordBatchReaderBuilder::try_new(file)?
.with_batch_size(2048)
.build()?;
let schema = parquet_reader.schema();
let records: std::result::Result<Vec<_>, apache_arrow::error::ArrowError> =
parquet_reader.collect();
let record = concat_batches(&schema, records?.as_slice())?;
Ok(Arrow {
n_step: 0,
capacities: Vec::new(),
buffers: Vec::new(),
metadata: None,
step: 0,
n_entry: 0,
record: Some(record),
batch: None,
drop_option: DropOption::NoSave,
decimation: 1,
count: 0,
file_format: FileFormat::Parquet,
batch_size: None,
})
}
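    /// Saves the data to a Matlab "mat" file.
    ///
    /// The file is written to the `DATA_REPO` directory (default: the current directory);
    /// with [crate::MatFormat::TimeBased], a `time` vector sampled at the given
    /// frequency is added to the file.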
#[cfg(feature = "matio-rs")]
pub fn to_mat<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
use crate::Get;
use matio_rs::MatFile;
let batch = self.concat_batches()?;
let root_env = env::var("DATA_REPO").unwrap_or_else(|_| ".".to_string());
let root = Path::new(&root_env).join(&path).with_extension("mat");
let mat_file = MatFile::save(&root)?;
let mut n_sample = 0;
for field in batch.schema().fields() {
let name = field.name();
let data: Vec<Vec<f64>> = self.get(name)?;
n_sample = data.len();
            let n_data = data.first().map_or(0, |v| v.len());
mat_file.array(
name,
data.into_iter()
.flatten()
.collect::<Vec<f64>>()
.as_mut_slice(),
vec![n_data as u64, n_sample as u64],
)?;
}
if let FileFormat::Matlab(crate::MatFormat::TimeBased(sampling_frequency)) =
self.file_format
{
let tau = sampling_frequency.recip();
let time: Vec<f64> = (0..n_sample).map(|i| i as f64 * tau).collect();
mat_file.var("time", time.as_slice())?;
}
log::info!("Arrow data saved to {root:?}");
Ok(())
}
}