use std::{collections::HashMap, fmt::Display, mem::size_of};
use apache_arrow::{
array::BufferBuilder,
datatypes::{ArrowNativeType, DataType},
record_batch::RecordBatch,
};
use interface::{Entry, UniqueIdentifier, print_info};
use crate::{
ArrowBuffer, BufferDataType, BufferObject, DropOption, FileFormat, LogData, MAX_CAPACITY_BYTE,
};
mod arrow;
mod builder;
pub(crate) mod iter;
pub use builder::ArrowBuilder;
/// Logger that accumulates data into Arrow buffers and can save them when dropped.
///
/// Instances are configured through [`ArrowBuilder`] (see `Arrow::builder`).
pub struct Arrow {
/// Number of steps given to the builder; used to size each buffer's initial capacity.
n_step: usize,
/// The `size` argument of each registered entry, in registration order.
capacities: Vec<usize>,
/// Type-erased data buffers, each paired with the Arrow data type of its elements.
buffers: Vec<(Box<dyn BufferObject>, DataType)>,
/// Optional key/value metadata — presumably attached to the saved file; confirm against the writer.
metadata: Option<HashMap<String, String>>,
/// Step counter; reported divided by `n_entry`, so presumably incremented once per entry per step — TODO confirm.
pub(crate) step: usize,
/// Number of registered entries (incremented on each successful registration).
pub(crate) n_entry: usize,
/// Record batch, if one is available; it is what `Display` reports when no entry is registered.
record: Option<RecordBatch>,
/// Optional collection of record batches — presumably used together with `batch_size`; confirm.
batch: Option<Vec<RecordBatch>>,
/// What to do with the data when the logger is dropped (save to a file or discard).
drop_option: DropOption,
/// Decimation factor; divides `n_step` when sizing the buffers.
pub(crate) decimation: usize,
/// Counter reported by `Display` divided by `n_entry` — presumably counts calls before decimation; TODO confirm.
pub(crate) count: usize,
/// File format used when saving on drop.
file_format: FileFormat,
/// Optional batch size — NOTE(review): semantics are not visible in this part of the file.
pub(crate) batch_size: Option<usize>,
}
impl Default for Arrow {
fn default() -> Self {
Arrow {
n_step: 0,
capacities: Vec::new(),
buffers: Vec::new(),
metadata: None,
step: 0,
n_entry: 0,
record: None,
batch: None,
drop_option: DropOption::NoSave,
decimation: 1,
count: 0,
file_format: Default::default(),
batch_size: None,
}
}
}
impl Arrow {
    /// Creates an [`ArrowBuilder`] for a logger sized for `n_step` steps.
    pub fn builder(n_step: usize) -> ArrowBuilder {
        ArrowBuilder::new(n_step)
    }
    /// Returns a mutable reference to the buffer logging the data identified by `U`.
    ///
    /// The registered buffers are tried in order and the first one that
    /// downcasts to `LogData<ArrowBuffer<U>>` is returned; `None` if there is no
    /// such buffer.
    pub(crate) fn data<T, U>(&mut self) -> Option<&mut LogData<ArrowBuffer<U>>>
    where
        T: 'static + ArrowNativeType,
        U: 'static + UniqueIdentifier<DataType = Vec<T>>,
    {
        self.buffers
            .iter_mut()
            .find_map(|(b, _)| b.as_mut_any().downcast_mut::<LogData<ArrowBuffer<U>>>())
    }
    /// Returns `step` as an integer percentage of `n_step * n_entry`.
    ///
    /// Returns 0 when no entry is registered or when `n_step` is zero, instead
    /// of panicking on a division by zero.
    pub fn pct_complete(&self) -> usize {
        // NOTE(review): assumes `step` advances once per entry per step, so that
        // `n_step * n_entry` is the value of `step` at completion — confirm.
        let total = self.n_step.saturating_mul(self.n_entry);
        self.step
            .saturating_mul(100)
            .checked_div(total)
            .unwrap_or(0)
    }
    /// Returns `step` divided by the number of entries.
    ///
    /// Returns 0 when no entry is registered, instead of panicking on a
    /// division by zero.
    pub fn size(&self) -> usize {
        self.step.checked_div(self.n_entry).unwrap_or(0)
    }
}
impl<T, U> Entry<U> for Arrow
where
    T: 'static + BufferDataType + ArrowNativeType + Send + Sync,
    U: 'static + Send + Sync + UniqueIdentifier<DataType = Vec<T>>,
{
    /// Registers a new buffer for the data identified by `U`.
    ///
    /// `size` is the number of `T` elements logged at each step. The buffer is
    /// created with room for `size * (1 + n_step / decimation)` elements, capped
    /// so that it does not exceed `MAX_CAPACITY_BYTE` bytes. If a buffer
    /// reporting the same name is already registered, the call is logged and
    /// ignored.
    fn entry(&mut self, size: usize) {
        // `max(1)` keeps a zero decimation factor from panicking on the division.
        let n_sample = (self.n_step / self.decimation.max(1)).saturating_add(1);
        // Saturating arithmetic: on overflow the product pins at `usize::MAX`,
        // which still trips the byte limit below instead of wrapping past it.
        let mut capacity = size.saturating_mul(n_sample);
        if capacity.saturating_mul(size_of::<T>()) > MAX_CAPACITY_BYTE {
            capacity = MAX_CAPACITY_BYTE / size_of::<T>();
            log::info!("Capacity limit of 1GB exceeded, reduced to : {}", capacity);
        }
        let buffer: LogData<ArrowBuffer<U>> = LogData::new(BufferBuilder::<T>::new(capacity));
        let name = buffer.who();
        if self.buffers.iter().any(|existing| existing.0.who() == name) {
            log::info!(
                r#"found existing entry with same name in Arrow buffers, skipping "{name}""#
            );
            return;
        }
        self.buffers.push((Box::new(buffer), T::buffer_data_type()));
        self.capacities.push(size);
        self.n_entry += 1;
    }
}
impl Display for Arrow {
    /// Writes a summary of the logger.
    ///
    /// With at least one registered entry, lists each buffer name with its
    /// entry size followed by the step counters. Otherwise, if a record batch is
    /// present, prints its (rows, columns) shape and field names. Writes nothing
    /// when neither is available.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.n_entry > 0 {
            writeln!(f, "Arrow logger:")?;
            writeln!(f, " - data:")?;
            for ((buffer, _), capacity) in self.buffers.iter().zip(self.capacities.iter()) {
                writeln!(f, " - {:>8}:{:>4}", buffer.who(), capacity)?;
            }
            let steps_per_entry = self.step / self.n_entry;
            let count_per_entry = self.count / self.n_entry;
            return write!(
                f,
                " - steps #: {}/{}/{}",
                self.n_step, steps_per_entry, count_per_entry
            );
        }
        match &self.record {
            Some(record) => {
                let shape = (record.num_rows(), record.num_columns());
                // Only every other flattened field is listed — presumably to skip
                // the nested child field of each column; confirm against the schema.
                let field_names = record
                    .schema()
                    .flattened_fields()
                    .iter()
                    .step_by(2)
                    .map(|field| format!(" - {}", field.name()))
                    .collect::<Vec<_>>()
                    .join("\n");
                write!(f, "Arrow logger {:?}:\n{:}", shape, field_names)
            }
            None => Ok(()),
        }
    }
}
impl Drop for Arrow {
    /// Logs a summary of the logger, then saves the data if the drop option asks for it.
    ///
    /// The data is written in the configured file format under the requested
    /// file name, falling back to `"data"` when none was given. A failure to
    /// write is reported through `print_info` rather than propagated.
    fn drop(&mut self) {
        log::info!("{self}");
        // Resolve the destination first; nothing else to do when saving is disabled.
        let file_name = match &self.drop_option {
            DropOption::NoSave => {
                log::info!("Dropping Arrow logger without saving.");
                return;
            }
            DropOption::Save(filename) => filename.clone().unwrap_or_else(|| "data".to_string()),
        };
        match self.file_format {
            FileFormat::Parquet => {
                if let Err(e) = self.to_parquet(file_name) {
                    print_info("Arrow error", Some(&e));
                }
            }
            #[cfg(feature = "matio-rs")]
            FileFormat::Matlab(_) => {
                if let Err(e) = self.to_mat(file_name) {
                    print_info("Arrow error", Some(&e));
                }
            }
        }
    }
}