1use std::io::Write;
2
3use arrow::datatypes::Metadata;
4use arrow::io::ipc::IpcField;
5use arrow::io::ipc::write::{self, EncodedData, WriteOptions};
6use polars_core::prelude::*;
7#[cfg(feature = "serde")]
8use serde::{Deserialize, Serialize};
9
10use crate::prelude::*;
11use crate::shared::schema_to_arrow_checked;
12
/// Options describing how a [`DataFrame`] should be written as an Arrow IPC file.
///
/// Convert them into a concrete writer with [`IpcWriterOptions::to_writer`].
// NOTE: field order is part of the derived serde / JsonSchema / Hash representation;
// do not reorder fields.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub struct IpcWriterOptions {
    /// Compression handed to the Arrow IPC `WriteOptions`; `None` requests no compression.
    pub compression: Option<IpcCompression>,
    /// Arrow compatibility level the output should target.
    pub compat_level: CompatLevel,
    /// Requested number of rows per record batch; `None` leaves it unspecified.
    // NOTE(review): not read by the writers in this file — presumably consumed by
    // another writer in the parent module; confirm.
    pub record_batch_size: Option<usize>,
    /// Whether per-record-batch statistics are requested.
    // Missing in serialized input => `false` (bool's `Default`), which keeps older
    // serialized options deserializable.
    #[cfg_attr(feature = "serde", serde(default))]
    pub record_batch_statistics: bool,
}
27
28impl Default for IpcWriterOptions {
29 fn default() -> Self {
30 Self {
31 compression: None,
32 compat_level: CompatLevel::newest(),
33 record_batch_size: None,
34 record_batch_statistics: false,
35 }
36 }
37}
38
39impl IpcWriterOptions {
40 pub fn to_writer<W: Write>(&self, writer: W) -> IpcWriter<W> {
41 IpcWriter::new(writer)
42 .with_compression(self.compression)
43 .with_record_batch_size(self.record_batch_size)
44 .with_record_batch_statistics(self.record_batch_statistics)
45 }
46}
47
48#[must_use]
75pub struct IpcWriter<W> {
76 pub(super) writer: W,
77 pub(super) compression: Option<IpcCompression>,
78 pub(super) compat_level: CompatLevel,
80 pub(super) record_batch_size: Option<usize>,
81 pub(super) record_batch_statistics: bool,
82 pub(super) parallel: bool,
83 pub(super) custom_schema_metadata: Option<Arc<Metadata>>,
84}
85
86impl<W: Write> IpcWriter<W> {
87 pub fn with_compression(mut self, compression: Option<IpcCompression>) -> Self {
89 self.compression = compression;
90 self
91 }
92
93 pub fn with_compat_level(mut self, compat_level: CompatLevel) -> Self {
94 self.compat_level = compat_level;
95 self
96 }
97
98 pub fn with_record_batch_size(mut self, record_batch_size: Option<usize>) -> Self {
99 self.record_batch_size = record_batch_size;
100 self
101 }
102
103 pub fn with_record_batch_statistics(mut self, record_batch_statistics: bool) -> Self {
104 self.record_batch_statistics = record_batch_statistics;
105 self
106 }
107
108 pub fn with_parallel(mut self, parallel: bool) -> Self {
109 self.parallel = parallel;
110 self
111 }
112
113 pub fn batched(
114 self,
115 schema: &Schema,
116 ipc_fields: Vec<IpcField>,
117 ) -> PolarsResult<BatchedWriter<W>> {
118 let schema = schema_to_arrow_checked(schema, self.compat_level, "ipc")?;
119 let mut writer = write::FileWriter::new(
120 self.writer,
121 Arc::new(schema),
122 Some(ipc_fields),
123 WriteOptions {
124 compression: self.compression.map(|c| c.into()),
125 },
126 );
127 writer.start()?;
128
129 Ok(BatchedWriter {
130 writer,
131 compat_level: self.compat_level,
132 })
133 }
134
135 pub fn set_custom_schema_metadata(&mut self, custom_metadata: Arc<Metadata>) {
137 self.custom_schema_metadata = Some(custom_metadata);
138 }
139}
140
141impl<W> SerWriter<W> for IpcWriter<W>
142where
143 W: Write,
144{
145 fn new(writer: W) -> Self {
146 IpcWriter {
147 writer,
148 compression: None,
149 compat_level: CompatLevel::newest(),
150 record_batch_size: None,
151 record_batch_statistics: false,
152 parallel: true,
153 custom_schema_metadata: None,
154 }
155 }
156
157 fn finish(&mut self, df: &mut DataFrame) -> PolarsResult<()> {
158 let schema = schema_to_arrow_checked(df.schema(), self.compat_level, "ipc")?;
159 let mut ipc_writer = write::FileWriter::try_new(
160 &mut self.writer,
161 Arc::new(schema),
162 None,
163 WriteOptions {
164 compression: self.compression.map(|c| c.into()),
165 },
166 )?;
167 if let Some(custom_metadata) = &self.custom_schema_metadata {
168 ipc_writer.set_custom_schema_metadata(Arc::clone(custom_metadata));
169 }
170
171 if self.parallel {
172 df.align_chunks_par();
173 } else {
174 df.align_chunks();
175 }
176 let iter = df.iter_chunks(self.compat_level, true);
177
178 for batch in iter {
179 ipc_writer.write(&batch, None)?
180 }
181 ipc_writer.finish()?;
182 Ok(())
183 }
184}
185
186pub struct BatchedWriter<W: Write> {
187 writer: write::FileWriter<W>,
188 compat_level: CompatLevel,
189}
190
191impl<W: Write> BatchedWriter<W> {
192 pub fn write_batch(&mut self, df: &DataFrame) -> PolarsResult<()> {
197 let iter = df.iter_chunks(self.compat_level, true);
198 for batch in iter {
199 self.writer.write(&batch, None)?
200 }
201 Ok(())
202 }
203
204 pub fn write_encoded(
209 &mut self,
210 dictionaries: &[EncodedData],
211 message: &EncodedData,
212 ) -> PolarsResult<()> {
213 self.writer.write_encoded(dictionaries, message)
214 }
215
216 pub fn write_encoded_dictionaries(
217 &mut self,
218 encoded_dictionaries: &[EncodedData],
219 ) -> PolarsResult<()> {
220 self.writer.write_encoded_dictionaries(encoded_dictionaries)
221 }
222
223 pub fn finish(&mut self) -> PolarsResult<()> {
225 self.writer.finish()?;
226 Ok(())
227 }
228}
229
/// Compression codec for Arrow IPC output; converted into `write::Compression`.
// NOTE: variant order may be part of the derived serde representation (index-based
// formats); do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum IpcCompression {
    /// LZ4 compression.
    LZ4,
    /// Zstandard compression at the given level.
    ZSTD(polars_utils::compression::ZstdLevel),
}
240
241impl Default for IpcCompression {
242 fn default() -> Self {
243 Self::ZSTD(Default::default())
244 }
245}
246
247impl From<IpcCompression> for write::Compression {
248 fn from(value: IpcCompression) -> Self {
249 match value {
250 IpcCompression::LZ4 => write::Compression::LZ4,
251 IpcCompression::ZSTD(level) => write::Compression::ZSTD(level),
252 }
253 }
254}