pack_it/
write.rs

1use std::io::Write;
2use std::sync::Arc;
3use std::thread::JoinHandle;
4
5use crate::erratum::join;
6use anyhow::{anyhow, bail, Result};
7use arrow2::array::Array;
8use arrow2::chunk::Chunk;
9use arrow2::datatypes::Schema;
10use arrow2::datatypes::{Field as ArrowField, Metadata};
11use arrow2::error::Error as ArrowError;
12use arrow2::io::parquet::write::{
13    CompressionOptions, FileWriter, RowGroupIterator, Version, WriteOptions,
14};
15use crossbeam_channel::{SendError, Sender};
16use log::info;
17
18use crate::table::TableField;
19
20pub struct Writer<W> {
21    schema: Box<[TableField]>,
22    threads: Vec<JoinHandle<Result<W>>>,
23    tx: Option<Sender<Result<Chunk<Arc<dyn Array>>, ArrowError>>>,
24}
25
26fn out_thread<W: Write + Send + 'static>(
27    mut inner: W,
28    schema: &[TableField],
29    rx: impl IntoIterator<Item = Result<Chunk<Arc<dyn Array>>, ArrowError>> + Send + 'static,
30) -> Result<JoinHandle<Result<W>>> {
31    let arrow_schema = Schema::from(
32        schema
33            .iter()
34            .map(|f| ArrowField {
35                name: f.name.to_string(),
36                data_type: f.kind.to_arrow(),
37                is_nullable: f.nullable,
38                metadata: Metadata::default(),
39            })
40            .collect::<Vec<_>>(),
41    );
42
43    let write_options = WriteOptions {
44        write_statistics: true,
45        compression: CompressionOptions::Zstd(None),
46        version: Version::V2,
47        // TODO: do we want to override this?
48        data_pagesize_limit: None,
49    };
50    let encodings = schema.iter().map(|f| vec![f.encoding]).collect();
51
52    Ok(std::thread::spawn(move || -> Result<W> {
53        let rg_iter =
54            RowGroupIterator::try_new(rx.into_iter(), &arrow_schema, write_options, encodings)?;
55        let mut writer = FileWriter::try_new(&mut inner, arrow_schema, write_options)?;
56
57        for rg in rg_iter {
58            writer.write(rg?)?;
59        }
60
61        writer.end(None)?;
62        Ok(inner)
63    }))
64}
65
66impl<W: Write + Send + 'static> Writer<W> {
67    pub fn new(
68        inner: impl IntoIterator<Item = W, IntoIter = impl Iterator<Item = W> + ExactSizeIterator>,
69        schema: &[TableField],
70    ) -> Result<Self> {
71        let inner = inner.into_iter();
72
73        let (tx, rx) = crossbeam_channel::bounded(inner.len());
74
75        let threads = inner
76            .into_iter()
77            .map(|inner| out_thread(inner, schema, rx.clone()))
78            .collect::<Result<Vec<_>>>()?;
79
80        Ok(Self {
81            schema: schema.to_vec().into_boxed_slice(),
82            threads,
83            tx: Some(tx),
84        })
85    }
86
87    pub fn find_field(&self, name: &str) -> Option<(usize, &TableField)> {
88        self.schema.iter().enumerate().find(|(_, f)| f.name == name)
89    }
90
91    pub fn submit_batch(&mut self, batch: impl IntoIterator<Item = Arc<dyn Array>>) -> Result<()> {
92        let result = Chunk::try_new(batch.into_iter().collect())?;
93
94        let tx = self
95            .tx
96            .as_mut()
97            .ok_or_else(|| anyhow!("previously failed"))?;
98
99        if let Err(SendError(_)) = tx.send(Ok(result)) {
100            // all of the writers have failed, so we need to die
101            // (this doesn't catch the case where one writer has died)
102            drop(self.tx.take());
103
104            // this should fail
105            join_all(&mut self.threads)?;
106
107            bail!("all of the threads have gone, but none have bothered to tell us why");
108        }
109
110        Ok(())
111    }
112
113    pub fn finish(mut self) -> Result<Vec<W>> {
114        if self.threads.is_empty() {
115            bail!("had previously failed");
116        }
117
118        info!("finishing...");
119        drop(self.tx);
120
121        join_all(&mut self.threads)
122    }
123}
124
125fn join_all<T>(threads: &mut Vec<JoinHandle<Result<T>>>) -> Result<Vec<T>> {
126    let mut ret = Vec::with_capacity(threads.len());
127    while let Some(thread) = threads.pop() {
128        ret.push(join(thread)?);
129    }
130    Ok(ret)
131}