polars_arrow/io/ipc/write/
writer.rs

1use std::io::Write;
2use std::sync::Arc;
3
4use arrow_format::ipc::planus::Builder;
5use polars_error::{PolarsResult, polars_bail};
6
7use super::super::{ARROW_MAGIC_V2, IpcField};
8use super::common::{DictionaryTracker, EncodedData, WriteOptions};
9use super::common_sync::{write_continuation, write_message};
10use super::{default_ipc_fields, schema, schema_to_bytes};
11use crate::array::Array;
12use crate::datatypes::*;
13use crate::io::ipc::write::common::encode_chunk_amortized;
14use crate::record_batch::RecordBatchT;
15
/// Lifecycle of a [`FileWriter`].
///
/// The writer only moves forward: `None` -> `Started` -> `Finished`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum State {
    /// The writer has been created but `start` has not completed yet.
    None,
    /// `start` has completed; record batches may be written.
    Started,
    /// `finish` has completed; the footer has been written.
    Finished,
}
22
/// Arrow file writer
///
/// Writes the Arrow IPC file layout to `writer`: a header and schema message
/// (`start`), encoded dictionaries and record batches (`write` /
/// `write_encoded`), and finally a footer listing every written block
/// (`finish`).
pub struct FileWriter<W: Write> {
    /// The object to write to
    pub(crate) writer: W,
    /// IPC write options
    pub(crate) options: WriteOptions,
    /// A reference to the schema; serialized into the schema message by
    /// `start` and into the footer by `finish`
    pub(crate) schema: ArrowSchemaRef,
    /// Per-field IPC information, serialized together with the schema and
    /// used by `write` when the caller does not pass its own fields
    pub(crate) ipc_fields: Vec<IpcField>,
    /// Running count of bytes written so far (header included). Its value at
    /// the time a message is written is stored as that block's `offset` in
    /// the footer, for random access
    pub(crate) block_offsets: usize,
    /// Dictionary blocks that will be written as part of the IPC footer
    pub(crate) dictionary_blocks: Vec<arrow_format::ipc::Block>,
    /// Record blocks that will be written as part of the IPC footer
    pub(crate) record_blocks: Vec<arrow_format::ipc::Block>,
    /// Lifecycle state of the writer: not started, started, or finished
    /// (footer written)
    pub(crate) state: State,
    /// Keeps track of dictionaries that have been written
    pub(crate) dictionary_tracker: DictionaryTracker,
    /// Buffer/scratch that is reused between writes
    pub(crate) encoded_message: EncodedData,
    /// Custom schema-level metadata; serialized into the footer's schema by
    /// `finish`
    pub(crate) custom_schema_metadata: Option<Arc<Metadata>>,
}
47
48impl<W: Write> FileWriter<W> {
49    /// Creates a new [`FileWriter`] and writes the header to `writer`
50    pub fn try_new(
51        writer: W,
52        schema: ArrowSchemaRef,
53        ipc_fields: Option<Vec<IpcField>>,
54        options: WriteOptions,
55    ) -> PolarsResult<Self> {
56        let mut slf = Self::new(writer, schema, ipc_fields, options);
57        slf.start()?;
58
59        Ok(slf)
60    }
61
62    /// Creates a new [`FileWriter`].
63    pub fn new(
64        writer: W,
65        schema: ArrowSchemaRef,
66        ipc_fields: Option<Vec<IpcField>>,
67        options: WriteOptions,
68    ) -> Self {
69        let ipc_fields = if let Some(ipc_fields) = ipc_fields {
70            ipc_fields
71        } else {
72            default_ipc_fields(schema.iter_values())
73        };
74
75        Self {
76            writer,
77            options,
78            schema,
79            ipc_fields,
80            block_offsets: 0,
81            dictionary_blocks: vec![],
82            record_blocks: vec![],
83            state: State::None,
84            dictionary_tracker: DictionaryTracker {
85                dictionaries: Default::default(),
86                cannot_replace: true,
87            },
88            encoded_message: Default::default(),
89            custom_schema_metadata: None,
90        }
91    }
92
93    /// Consumes itself into the inner writer
94    pub fn into_inner(self) -> W {
95        self.writer
96    }
97
    /// Get the inner memory scratches so they can be reused in a new writer.
    /// This can be utilized to save memory allocations for performance reasons.
    ///
    /// The scratch buffer is moved out of this writer, which is left holding
    /// a default-constructed [`EncodedData`].
    pub fn get_scratches(&mut self) -> EncodedData {
        std::mem::take(&mut self.encoded_message)
    }
    /// Set the inner memory scratches so they can be reused in a new writer.
    /// This can be utilized to save memory allocations for performance reasons.
    ///
    /// `scratches` replaces the scratch buffer currently held by this writer.
    pub fn set_scratches(&mut self, scratches: EncodedData) {
        self.encoded_message = scratches;
    }
108
109    /// Writes the header and first (schema) message to the file.
110    /// # Errors
111    /// Errors if the file has been started or has finished.
112    pub fn start(&mut self) -> PolarsResult<()> {
113        if self.state != State::None {
114            polars_bail!(oos = "The IPC file can only be started once");
115        }
116        // write magic to header
117        self.writer.write_all(&ARROW_MAGIC_V2[..])?;
118        // create an 8-byte boundary after the header
119        self.writer.write_all(&[0, 0])?;
120        // write the schema, set the written bytes to the schema
121
122        let encoded_message = EncodedData {
123            ipc_message: schema_to_bytes(
124                &self.schema,
125                &self.ipc_fields,
126                // No need to pass metadata here, as it is already written to the footer in `finish`
127                None,
128            ),
129            arrow_data: vec![],
130        };
131
132        let (meta, data) = write_message(&mut self.writer, &encoded_message)?;
133        self.block_offsets += meta + data + 8; // 8 <=> arrow magic + 2 bytes for alignment
134        self.state = State::Started;
135        Ok(())
136    }
137
138    /// Writes [`RecordBatchT`] to the file
139    pub fn write(
140        &mut self,
141        chunk: &RecordBatchT<Box<dyn Array>>,
142        ipc_fields: Option<&[IpcField]>,
143    ) -> PolarsResult<()> {
144        if self.state != State::Started {
145            polars_bail!(
146                oos = "The IPC file must be started before it can be written to. Call `start` before `write`"
147            );
148        }
149
150        let ipc_fields = if let Some(ipc_fields) = ipc_fields {
151            ipc_fields
152        } else {
153            self.ipc_fields.as_ref()
154        };
155        let encoded_dictionaries = encode_chunk_amortized(
156            chunk,
157            ipc_fields,
158            &mut self.dictionary_tracker,
159            &self.options,
160            &mut self.encoded_message,
161        )?;
162
163        let encoded_message = std::mem::take(&mut self.encoded_message);
164        self.write_encoded(&encoded_dictionaries[..], &encoded_message)?;
165        self.encoded_message = encoded_message;
166
167        Ok(())
168    }
169
170    pub fn write_encoded(
171        &mut self,
172        encoded_dictionaries: &[EncodedData],
173        encoded_message: &EncodedData,
174    ) -> PolarsResult<()> {
175        if self.state != State::Started {
176            polars_bail!(
177                oos = "The IPC file must be started before it can be written to. Call `start` before `write`"
178            );
179        }
180
181        // add all dictionaries
182        for encoded_dictionary in encoded_dictionaries {
183            let (meta, data) = write_message(&mut self.writer, encoded_dictionary)?;
184
185            let block = arrow_format::ipc::Block {
186                offset: self.block_offsets as i64,
187                meta_data_length: meta as i32,
188                body_length: data as i64,
189            };
190            self.dictionary_blocks.push(block);
191            self.block_offsets += meta + data;
192        }
193
194        self.write_encoded_record_batch(encoded_message)?;
195
196        Ok(())
197    }
198
199    pub fn write_encoded_record_batch(
200        &mut self,
201        encoded_message: &EncodedData,
202    ) -> PolarsResult<()> {
203        let (meta, data) = write_message(&mut self.writer, encoded_message)?;
204        // add a record block for the footer
205        let block = arrow_format::ipc::Block {
206            offset: self.block_offsets as i64,
207            meta_data_length: meta as i32, // TODO: is this still applicable?
208            body_length: data as i64,
209        };
210        self.record_blocks.push(block);
211        self.block_offsets += meta + data;
212
213        Ok(())
214    }
215
216    /// Write footer and closing tag, then mark the writer as done
217    pub fn finish(&mut self) -> PolarsResult<()> {
218        if self.state != State::Started {
219            polars_bail!(
220                oos = "The IPC file must be started before it can be finished. Call `start` before `finish`"
221            );
222        }
223
224        // write EOS
225        write_continuation(&mut self.writer, 0)?;
226
227        let schema = schema::serialize_schema(
228            &self.schema,
229            &self.ipc_fields,
230            self.custom_schema_metadata.as_deref(),
231        );
232
233        let root = arrow_format::ipc::Footer {
234            version: arrow_format::ipc::MetadataVersion::V5,
235            schema: Some(Box::new(schema)),
236            dictionaries: Some(std::mem::take(&mut self.dictionary_blocks)),
237            record_batches: Some(std::mem::take(&mut self.record_blocks)),
238            custom_metadata: None,
239        };
240        let mut builder = Builder::new();
241        let footer_data = builder.finish(&root, None);
242        self.writer.write_all(footer_data)?;
243        self.writer
244            .write_all(&(footer_data.len() as i32).to_le_bytes())?;
245        self.writer.write_all(&ARROW_MAGIC_V2)?;
246        self.writer.flush()?;
247        self.state = State::Finished;
248
249        Ok(())
250    }
251
252    /// Sets custom schema metadata. Must be called before `start` is called
253    pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
254        self.custom_schema_metadata = Some(custom_metadata);
255    }
256}