polars_arrow/io/ipc/write/
writer.rs1use std::io::Write;
2use std::sync::Arc;
3
4use arrow_format::ipc::planus::Builder;
5use polars_error::{PolarsResult, polars_bail};
6
7use super::super::{ARROW_MAGIC_V2, IpcField};
8use super::common::{DictionaryTracker, EncodedData, WriteOptions};
9use super::common_sync::{write_continuation, write_message};
10use super::{default_ipc_fields, schema, schema_to_bytes};
11use crate::array::Array;
12use crate::datatypes::*;
13use crate::io::ipc::write::common::encode_chunk_amortized;
14use crate::record_batch::RecordBatchT;
15
/// Lifecycle of a [`FileWriter`].
///
/// A writer begins in `None`, moves to `Started` once the header and schema have
/// been written, and ends in `Finished` after the footer has been written.
#[derive(Copy, Clone, Eq, PartialEq)]
pub(crate) enum State {
    /// Nothing has been written to the underlying writer yet.
    None,
    /// Header and schema are written; dictionaries and record batches may be appended.
    Started,
    /// The footer has been written; the file is complete.
    Finished,
}
22
23pub struct FileWriter<W: Write> {
25 pub(crate) writer: W,
27 pub(crate) options: WriteOptions,
29 pub(crate) schema: ArrowSchemaRef,
31 pub(crate) ipc_fields: Vec<IpcField>,
32 pub(crate) block_offsets: usize,
34 pub(crate) dictionary_blocks: Vec<arrow_format::ipc::Block>,
36 pub(crate) record_blocks: Vec<arrow_format::ipc::Block>,
38 pub(crate) state: State,
40 pub(crate) dictionary_tracker: DictionaryTracker,
42 pub(crate) encoded_message: EncodedData,
44 pub(crate) custom_schema_metadata: Option<Arc<Metadata>>,
46}
47
48impl<W: Write> FileWriter<W> {
49 pub fn try_new(
51 writer: W,
52 schema: ArrowSchemaRef,
53 ipc_fields: Option<Vec<IpcField>>,
54 options: WriteOptions,
55 ) -> PolarsResult<Self> {
56 let mut slf = Self::new(writer, schema, ipc_fields, options);
57 slf.start()?;
58
59 Ok(slf)
60 }
61
62 pub fn new(
64 writer: W,
65 schema: ArrowSchemaRef,
66 ipc_fields: Option<Vec<IpcField>>,
67 options: WriteOptions,
68 ) -> Self {
69 let ipc_fields = if let Some(ipc_fields) = ipc_fields {
70 ipc_fields
71 } else {
72 default_ipc_fields(schema.iter_values())
73 };
74
75 Self {
76 writer,
77 options,
78 schema,
79 ipc_fields,
80 block_offsets: 0,
81 dictionary_blocks: vec![],
82 record_blocks: vec![],
83 state: State::None,
84 dictionary_tracker: DictionaryTracker {
85 dictionaries: Default::default(),
86 cannot_replace: true,
87 },
88 encoded_message: Default::default(),
89 custom_schema_metadata: None,
90 }
91 }
92
93 pub fn into_inner(self) -> W {
95 self.writer
96 }
97
98 pub fn get_scratches(&mut self) -> EncodedData {
101 std::mem::take(&mut self.encoded_message)
102 }
103 pub fn set_scratches(&mut self, scratches: EncodedData) {
106 self.encoded_message = scratches;
107 }
108
109 pub fn start(&mut self) -> PolarsResult<()> {
113 if self.state != State::None {
114 polars_bail!(oos = "The IPC file can only be started once");
115 }
116 self.writer.write_all(&ARROW_MAGIC_V2[..])?;
118 self.writer.write_all(&[0, 0])?;
120 let encoded_message = EncodedData {
123 ipc_message: schema_to_bytes(
124 &self.schema,
125 &self.ipc_fields,
126 None,
128 ),
129 arrow_data: vec![],
130 };
131
132 let (meta, data) = write_message(&mut self.writer, &encoded_message)?;
133 self.block_offsets += meta + data + 8; self.state = State::Started;
135 Ok(())
136 }
137
138 pub fn write(
140 &mut self,
141 chunk: &RecordBatchT<Box<dyn Array>>,
142 ipc_fields: Option<&[IpcField]>,
143 ) -> PolarsResult<()> {
144 if self.state != State::Started {
145 polars_bail!(
146 oos = "The IPC file must be started before it can be written to. Call `start` before `write`"
147 );
148 }
149
150 let ipc_fields = if let Some(ipc_fields) = ipc_fields {
151 ipc_fields
152 } else {
153 self.ipc_fields.as_ref()
154 };
155 let encoded_dictionaries = encode_chunk_amortized(
156 chunk,
157 ipc_fields,
158 &mut self.dictionary_tracker,
159 &self.options,
160 &mut self.encoded_message,
161 )?;
162
163 let encoded_message = std::mem::take(&mut self.encoded_message);
164 self.write_encoded(&encoded_dictionaries[..], &encoded_message)?;
165 self.encoded_message = encoded_message;
166
167 Ok(())
168 }
169
170 pub fn write_encoded(
171 &mut self,
172 encoded_dictionaries: &[EncodedData],
173 encoded_message: &EncodedData,
174 ) -> PolarsResult<()> {
175 if self.state != State::Started {
176 polars_bail!(
177 oos = "The IPC file must be started before it can be written to. Call `start` before `write`"
178 );
179 }
180
181 for encoded_dictionary in encoded_dictionaries {
183 let (meta, data) = write_message(&mut self.writer, encoded_dictionary)?;
184
185 let block = arrow_format::ipc::Block {
186 offset: self.block_offsets as i64,
187 meta_data_length: meta as i32,
188 body_length: data as i64,
189 };
190 self.dictionary_blocks.push(block);
191 self.block_offsets += meta + data;
192 }
193
194 self.write_encoded_record_batch(encoded_message)?;
195
196 Ok(())
197 }
198
199 pub fn write_encoded_record_batch(
200 &mut self,
201 encoded_message: &EncodedData,
202 ) -> PolarsResult<()> {
203 let (meta, data) = write_message(&mut self.writer, encoded_message)?;
204 let block = arrow_format::ipc::Block {
206 offset: self.block_offsets as i64,
207 meta_data_length: meta as i32, body_length: data as i64,
209 };
210 self.record_blocks.push(block);
211 self.block_offsets += meta + data;
212
213 Ok(())
214 }
215
216 pub fn finish(&mut self) -> PolarsResult<()> {
218 if self.state != State::Started {
219 polars_bail!(
220 oos = "The IPC file must be started before it can be finished. Call `start` before `finish`"
221 );
222 }
223
224 write_continuation(&mut self.writer, 0)?;
226
227 let schema = schema::serialize_schema(
228 &self.schema,
229 &self.ipc_fields,
230 self.custom_schema_metadata.as_deref(),
231 );
232
233 let root = arrow_format::ipc::Footer {
234 version: arrow_format::ipc::MetadataVersion::V5,
235 schema: Some(Box::new(schema)),
236 dictionaries: Some(std::mem::take(&mut self.dictionary_blocks)),
237 record_batches: Some(std::mem::take(&mut self.record_blocks)),
238 custom_metadata: None,
239 };
240 let mut builder = Builder::new();
241 let footer_data = builder.finish(&root, None);
242 self.writer.write_all(footer_data)?;
243 self.writer
244 .write_all(&(footer_data.len() as i32).to_le_bytes())?;
245 self.writer.write_all(&ARROW_MAGIC_V2)?;
246 self.writer.flush()?;
247 self.state = State::Finished;
248
249 Ok(())
250 }
251
252 pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
254 self.custom_schema_metadata = Some(custom_metadata);
255 }
256}