use std::io::Write;
use std::sync::Arc;
use std::thread::JoinHandle;

use crate::erratum::join;
use anyhow::{anyhow, bail, Result};
use arrow2::array::Array;
use arrow2::chunk::Chunk;
use arrow2::datatypes::Schema;
use arrow2::datatypes::{Field as ArrowField, Metadata};
use arrow2::error::Error as ArrowError;
use arrow2::io::parquet::write::{
    CompressionOptions, FileWriter, RowGroupIterator, Version, WriteOptions,
};
use crossbeam_channel::{SendError, Sender};
use log::info;

use crate::table::TableField;

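/// Writes Arrow chunks out as parquet across several output writers in
/// parallel: each submitted batch is picked up by exactly one writer thread,
/// so rows are spread over the outputs rather than duplicated.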
pub struct Writer<W> {
    schema: Box<[TableField]>,
    threads: Vec<JoinHandle<Result<W>>>,
    tx: Option<Sender<Result<Chunk<Arc<dyn Array>>, ArrowError>>>,
}

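/// Spawns a thread that drains `rx`, encoding each chunk it receives as a
/// parquet row group and writing it to `inner`; the writer is handed back
/// once the channel closes.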
fn out_thread<W: Write + Send + 'static>(
    mut inner: W,
    schema: &[TableField],
    rx: impl IntoIterator<Item = Result<Chunk<Arc<dyn Array>>, ArrowError>> + Send + 'static,
) -> Result<JoinHandle<Result<W>>> {
    let arrow_schema = Schema::from(
        schema
            .iter()
            .map(|f| ArrowField {
                name: f.name.to_string(),
                data_type: f.kind.to_arrow(),
                is_nullable: f.nullable,
                metadata: Metadata::default(),
            })
            .collect::<Vec<_>>(),
    );

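    // Zstd(None) uses the library's default compression level; statistics are
    // kept so downstream readers can prune row groups.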
    let write_options = WriteOptions {
        write_statistics: true,
        compression: CompressionOptions::Zstd(None),
        version: Version::V2,
        data_pagesize_limit: None,
    };
    let encodings = schema.iter().map(|f| vec![f.encoding]).collect();

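    // The writer thread owns the output and keeps pulling chunks until every
    // sender has been dropped, then finalises the parquet footer.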
    Ok(std::thread::spawn(move || -> Result<W> {
        let rg_iter =
            RowGroupIterator::try_new(rx.into_iter(), &arrow_schema, write_options, encodings)?;
        let mut writer = FileWriter::try_new(&mut inner, arrow_schema, write_options)?;

        for rg in rg_iter {
            writer.write(rg?)?;
        }

        writer.end(None)?;
        Ok(inner)
    }))
}

impl<W: Write + Send + 'static> Writer<W> {
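    /// Builds one writer thread per output in `inner`; `schema` describes the
    /// columns that every submitted batch must follow.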
    pub fn new(
        inner: impl IntoIterator<Item = W, IntoIter = impl Iterator<Item = W> + ExactSizeIterator>,
        schema: &[TableField],
    ) -> Result<Self> {
        let inner = inner.into_iter();

        let (tx, rx) = crossbeam_channel::bounded(inner.len());

        let threads = inner
            .map(|inner| out_thread(inner, schema, rx.clone()))
            .collect::<Result<Vec<_>>>()?;

        Ok(Self {
            schema: schema.to_vec().into_boxed_slice(),
            threads,
            tx: Some(tx),
        })
    }

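    /// Looks up a column by name, returning its index in the schema alongside
    /// the field definition.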
    pub fn find_field(&self, name: &str) -> Option<(usize, &TableField)> {
        self.schema.iter().enumerate().find(|(_, f)| f.name == name)
    }

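    /// Queues a batch of column arrays (one array per schema field) for
    /// writing; blocks while the bounded channel is full.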
    pub fn submit_batch(&mut self, batch: impl IntoIterator<Item = Arc<dyn Array>>) -> Result<()> {
        let result = Chunk::try_new(batch.into_iter().collect())?;

        let tx = self
            .tx
            .as_mut()
            .ok_or_else(|| anyhow!("previously failed"))?;

        if let Err(SendError(_)) = tx.send(Ok(result)) {
            drop(self.tx.take());

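            // A send can only fail once every receiver is gone, i.e. all the
            // writer threads have exited; join them so any real error surfaces.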
            join_all(&mut self.threads)?;

            bail!("all of the threads have gone, but none have bothered to tell us why");
        }

        Ok(())
    }

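    /// Closes the channel, waits for every writer thread to flush its parquet
    /// footer, and returns the underlying outputs.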
    pub fn finish(mut self) -> Result<Vec<W>> {
        if self.threads.is_empty() {
            bail!("had previously failed");
        }

        info!("finishing...");
        drop(self.tx);

        join_all(&mut self.threads)
    }
}

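/// Joins the remaining writer threads, propagating the first error any of
/// them returned; note that results come back in reverse order.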
fn join_all<T>(threads: &mut Vec<JoinHandle<Result<T>>>) -> Result<Vec<T>> {
    let mut ret = Vec::with_capacity(threads.len());
    while let Some(thread) = threads.pop() {
        ret.push(join(thread)?);
    }
    Ok(ret)
}