gmt_dos_clients_arrow/
arrow.rs

1use std::{collections::HashMap, fmt::Display, mem::size_of};
2
3use apache_arrow::{
4    array::BufferBuilder,
5    datatypes::{ArrowNativeType, DataType},
6    record_batch::RecordBatch,
7};
8use interface::{Entry, UniqueIdentifier, print_info};
9
10use crate::{
11    ArrowBuffer, BufferDataType, BufferObject, DropOption, FileFormat, LogData, MAX_CAPACITY_BYTE,
12};
13
14mod arrow;
15mod builder;
16// mod get;
17mod iter;
18pub use builder::ArrowBuilder;
19
20/// Apache [Arrow](https://docs.rs/arrow) client
21pub struct Arrow {
22    n_step: usize,
23    capacities: Vec<usize>,
24    buffers: Vec<(Box<dyn BufferObject>, DataType)>,
25    metadata: Option<HashMap<String, String>>,
26    pub(crate) step: usize,
27    pub(crate) n_entry: usize,
28    record: Option<RecordBatch>,
29    batch: Option<Vec<RecordBatch>>,
30    drop_option: DropOption,
31    pub(crate) decimation: usize,
32    pub(crate) count: usize,
33    file_format: FileFormat,
34    pub(crate) batch_size: Option<usize>,
35}
36impl Default for Arrow {
37    fn default() -> Self {
38        Arrow {
39            n_step: 0,
40            capacities: Vec::new(),
41            buffers: Vec::new(),
42            metadata: None,
43            step: 0,
44            n_entry: 0,
45            record: None,
46            batch: None,
47            drop_option: DropOption::NoSave,
48            decimation: 1,
49            count: 0,
50            file_format: Default::default(),
51            batch_size: None,
52        }
53    }
54}
55impl Arrow {
56    /// Creates a new Apache [Arrow](https://docs.rs/arrow) data logger
57    ///
58    ///  - `n_step`: the number of time step
59    pub fn builder(n_step: usize) -> ArrowBuilder {
60        ArrowBuilder::new(n_step)
61    }
62    pub(crate) fn data<T, U>(&mut self) -> Option<&mut LogData<ArrowBuffer<U>>>
63    where
64        T: 'static + ArrowNativeType,
65        U: 'static + UniqueIdentifier<DataType = Vec<T>>,
66    {
67        self.buffers
68            .iter_mut()
69            .find_map(|(b, _)| b.as_mut_any().downcast_mut::<LogData<ArrowBuffer<U>>>())
70    }
71    pub fn pct_complete(&self) -> usize {
72        self.step / self.n_step / self.n_entry
73    }
74    pub fn size(&self) -> usize {
75        self.step / self.n_entry
76    }
77}
78
79impl<T, U> Entry<U> for Arrow
80where
81    T: 'static + BufferDataType + ArrowNativeType + Send + Sync,
82    U: 'static + Send + Sync + UniqueIdentifier<DataType = Vec<T>>,
83{
84    fn entry(&mut self, size: usize) {
85        let mut capacity = size * (1 + self.n_step / self.decimation);
86        //log::info!("Buffer capacity: {}", capacity);
87        if capacity * size_of::<T>() > MAX_CAPACITY_BYTE {
88            capacity = MAX_CAPACITY_BYTE / size_of::<T>();
89            log::info!("Capacity limit of 1GB exceeded, reduced to : {}", capacity);
90        }
91        let buffer: LogData<ArrowBuffer<U>> = LogData::new(BufferBuilder::<T>::new(capacity));
92        
93        // checking if a buffer with the same name already exists
94        let name = buffer.who();
95        if let Some(_) = self.buffers.iter().find(|buffer| buffer.0.who() == name) {
96            log::info!(
97                r#"found existing entry with same name in Arrow buffers, skipping "{name}""#
98            );
99            return;
100        }
101
102        self.buffers.push((Box::new(buffer), T::buffer_data_type()));
103        self.capacities.push(size);
104        self.n_entry += 1;
105    }
106}
107/*
108impl<T, U> Entry<Vec<T>, U> for Arrow
109where
110    T: 'static + BufferDataType + ArrowNativeType + Send + Sync,
111    U: 'static + Send + Sync,
112{
113    fn entry(&mut self, size: usize) {
114        <Arrow as Entry<T, U>>::entry(self, size);
115    }
116}
117 */
118impl Display for Arrow {
119    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
120        if self.n_entry > 0 {
121            writeln!(f, "Arrow logger:")?;
122            writeln!(f, " - data:")?;
123            for ((buffer, _), capacity) in self.buffers.iter().zip(self.capacities.iter()) {
124                writeln!(f, "   - {:>8}:{:>4}", buffer.who(), capacity)?;
125            }
126            write!(
127                f,
128                " - steps #: {}/{}/{}",
129                self.n_step,
130                self.step / self.n_entry,
131                self.count / self.n_entry
132            )?;
133            return Ok(());
134        }
135        if let Some(record) = &self.record {
136            write!(
137                f,
138                "Arrow logger {:?}:\n{:}",
139                (record.num_rows(), record.num_columns()),
140                record
141                    .schema()
142                    .flattened_fields()
143                    .iter()
144                    .step_by(2)
145                    .map(|field| format!(" - {}", field.name()))
146                    .collect::<Vec<_>>()
147                    .join("\n")
148            )?;
149            return Ok(());
150        }
151        Ok(())
152    }
153}
154
155impl Drop for Arrow {
156    fn drop(&mut self) {
157        log::info!("{self}");
158        match self.drop_option {
159            DropOption::Save(ref filename) => {
160                let file_name = filename
161                    .as_ref()
162                    .cloned()
163                    .unwrap_or_else(|| "data".to_string());
164                match self.file_format {
165                    FileFormat::Parquet => {
166                        if let Err(e) = self.to_parquet(file_name) {
167                            print_info("Arrow error", Some(&e));
168                        }
169                    }
170                    #[cfg(feature = "matio-rs")]
171                    FileFormat::Matlab(_) => {
172                        if let Err(e) = self.to_mat(file_name) {
173                            print_info("Arrow error", Some(&e));
174                        }
175                    }
176                }
177            }
178            DropOption::NoSave => {
179                log::info!("Dropping Arrow logger without saving.");
180            }
181        }
182    }
183}