use std::io::Write;
use std::num::NonZero;

use dwarfs::section::{CompressAlgo, Header, MagicVersion, SectionIndexEntry, SectionType};
use dwarfs::zerocopy::IntoBytes;
use zerocopy::FromBytes;

use crate::ordered_parallel::OrderedParallel;
use crate::{ErrorInner, Result};

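/// Compression algorithm and level to apply to a section payload.
///
/// `None` stores the payload uncompressed; the other variants are only
/// available when the corresponding cargo feature is enabled.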
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CompressParam {
    None,
    #[cfg(feature = "zstd")]
    Zstd(zstd_safe::CompressionLevel),
    #[cfg(feature = "lzma")]
    Lzma(u32),
}

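/// Streaming DwarFS section writer.
///
/// Section payloads are compressed on a pool of worker threads while finished
/// sections are written to the inner writer `W` in submission order; a section
/// index is accumulated and appended as the final section by [`Writer::finish`].
///
/// A minimal usage sketch (not compiled here; `SectionType::BLOCK` is assumed to
/// be the data-block section type constant):
///
/// ```ignore
/// let mut writer = Writer::new(Vec::new())?;
/// writer.write_section(SectionType::BLOCK, CompressParam::None, b"block data")?;
/// let archive: Vec<u8> = writer.finish()?;
/// ```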
#[derive(Debug)]
pub struct Writer<W: ?Sized> {
    workers: OrderedParallel<Result<Vec<u8>>>,
    initiated_section_count: u32,
    index: IndexBuilder,

    w: W,
}

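/// Accumulates the section index: one [`SectionIndexEntry`] per written section,
/// plus the running byte offset of the next section.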
#[derive(Debug, Default)]
struct IndexBuilder {
    index: Vec<SectionIndexEntry>,
    next_offset: u64,
}

impl IndexBuilder {
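    /// Records a section of `sec_raw_len` bytes at the current offset and advances
    /// the offset, failing once the archive would exceed the 2^48-byte size limit
    /// imposed by index entries.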
    fn push(&mut self, typ: SectionType, sec_raw_len: usize) -> Result<()> {
        let ent = SectionIndexEntry::new(typ, self.next_offset).expect("checked by last write");
        self.next_offset = u64::try_from(sec_raw_len)
            .ok()
            .and_then(|l| l.checked_add(self.next_offset))
            .filter(|&n| n < 1u64 << 48)
            .ok_or(ErrorInner::Limit("archive size exceeds 2^48 bytes"))?;
        self.index.push(ent);
        Ok(())
    }
}

impl<W> Writer<W> {
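    /// Creates a writer over `w`, using [`std::thread::available_parallelism`] to
    /// size the compression worker pool.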
    pub fn new(w: W) -> std::io::Result<Self> {
        let thread_cnt = std::thread::available_parallelism()?;
        Self::new_with_threads(w, thread_cnt)
    }

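    /// Creates a writer over `w` with an explicit number of compression worker
    /// threads.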
    pub fn new_with_threads(w: W, thread_cnt: NonZero<usize>) -> std::io::Result<Self> {
        let workers = OrderedParallel::new("compressor", thread_cnt)?;
        Ok(Self {
            workers,
            initiated_section_count: 0,
            index: IndexBuilder::default(),
            w,
        })
    }
}

impl<W: ?Sized> Writer<W> {
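    /// Returns a shared reference to the underlying writer.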
    pub fn get_ref(&self) -> &W {
        &self.w
    }

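    /// Returns a mutable reference to the underlying writer.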
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.w
    }

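    /// Consumes this `Writer`, returning the underlying writer without finalizing
    /// the archive; use [`Writer::finish`] to write the section index first.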
    pub fn into_inner(self) -> W
    where
        W: Sized,
    {
        self.w
    }
}

impl<W: Write> Writer<W> {
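    /// Returns the number of sections submitted so far, including sections still
    /// being compressed in the background.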
    #[must_use]
    pub fn section_count(&self) -> u32 {
        self.initiated_section_count
    }

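    /// Waits for all pending sections to be compressed and written, then appends
    /// the section index as the final (uncompressed) section and returns the
    /// underlying writer.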
    pub fn finish(mut self) -> Result<W> {
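        // Stop accepting new work and drain the remaining compressed sections in order.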
        self.workers.stop();
        while let Some(iter) = self.workers.wait_and_get() {
            Self::commit_completed(iter, &mut self.w, &mut self.index)?;
        }

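        // The index section's length includes one extra entry describing the index
        // section itself, which is pushed just below.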
        let index_byte_len = self.index.index.as_bytes().len() + size_of::<SectionIndexEntry>();
        self.index
            .push(SectionType::SECTION_INDEX, index_byte_len)?;
        let sec = Self::seal_section(
            self.section_count(),
            SectionType::SECTION_INDEX,
            CompressParam::None,
            self.index.index.as_bytes(),
        )?;
        self.w.write_all(&sec)?;

        Ok(self.w)
    }

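    /// Writes each finished section to `w` and records its type and length in the
    /// index, reading the section type back from the serialized header.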
    fn commit_completed(
        completed: impl Iterator<Item = Result<Vec<u8>>>,
        w: &mut W,
        index: &mut IndexBuilder,
    ) -> Result<()> {
        for ret in completed {
            let sec = ret?;
            let off = std::mem::offset_of!(Header, section_type);
            let typ = SectionType::read_from_prefix(&sec[off..]).unwrap().0;
            w.write_all(&sec)?;
            index.push(typ, sec.len())?;
        }
        Ok(())
    }

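    /// Queues one section for compression on a worker thread and writes out any
    /// earlier sections whose compression has already completed. The payload is
    /// copied so it can be compressed in the background.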
    pub fn write_section(
        &mut self,
        section_type: SectionType,
        compression: CompressParam,
        payload: &[u8],
    ) -> Result<()> {
        assert!(u64::try_from(size_of::<Header>() + payload.len()).is_ok());

        let section_number = self.section_count();
        self.initiated_section_count = self
            .initiated_section_count
            .checked_add(1)
            .ok_or(ErrorInner::Limit("section count exceeds 2^32"))?;

        let payload = payload.to_vec();
        Self::commit_completed(
            self.workers.submit_and_get(move || {
                Self::seal_section(section_number, section_type, compression, &payload)
            }),
            &mut self.w,
            &mut self.index,
        )
    }

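    /// Builds the complete on-disk bytes of one section: compresses the payload
    /// according to `compression` (falling back to storing it uncompressed if the
    /// result would not be smaller), then fills in the header size and checksums.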
    fn seal_section(
        section_number: u32,
        section_type: SectionType,
        compression: CompressParam,
        payload: &[u8],
    ) -> Result<Vec<u8>> {
        let mut buf = vec![0u8; size_of::<Header>() + payload.len()];
        #[cfg_attr(not(feature = "default"), allow(unused_labels))]
        let (compress_algo, compressed_len) = 'compressed: {
            let compressed_buf = &mut buf[size_of::<Header>()..];
            match compression {
                CompressParam::None => {}

                #[cfg(feature = "zstd")]
                #[expect(non_upper_case_globals, reason = "name from C")]
                CompressParam::Zstd(lvl) => {
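                    // zstd reports this error when the destination buffer (sized to the
                    // uncompressed payload) is too small, i.e. compression would not
                    // shrink the data; in that case fall through to storing it as-is.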
                    const ZSTD_error_dstSize_tooSmall: zstd_safe::ErrorCode = -70isize as usize;

                    match zstd_safe::compress(compressed_buf, payload, lvl) {
                        Ok(compressed_len) => {
                            assert!(compressed_len <= payload.len());
                            break 'compressed (CompressAlgo::ZSTD, compressed_len);
                        }
                        Err(ZSTD_error_dstSize_tooSmall) => {}
                        Err(code) => {
                            let err = std::io::Error::new(
                                std::io::ErrorKind::InvalidInput,
                                format!(
                                    "ZSTD compression failed (code={}): {}",
                                    code,
                                    zstd_safe::get_error_name(code),
                                ),
                            );
                            return Err(ErrorInner::Compress(err).into());
                        }
                    }
                }

                #[cfg(feature = "lzma")]
                CompressParam::Lzma(lvl) => {
                    if let Some(compressed_len) = (|| {
                        use liblzma::stream::{Action, Check, Status, Stream};

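                        // One-shot compression into a buffer the size of the uncompressed
                        // payload: if xz cannot fit its output (i.e. would not shrink the
                        // data), return `None` to fall back to storing it uncompressed.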
                        let mut encoder = Stream::new_easy_encoder(lvl, Check::Crc64)?;

                        match encoder.process(payload, compressed_buf, Action::Run)? {
                            Status::Ok if encoder.total_in() == payload.len() as u64 => {}
                            Status::Ok | Status::MemNeeded => return Ok(None),
                            Status::StreamEnd | Status::GetCheck => unreachable!(),
                        }
                        match encoder.process(
                            &[],
                            &mut compressed_buf[encoder.total_out() as usize..],
                            Action::Finish,
                        )? {
                            Status::StreamEnd => {}
                            Status::MemNeeded => return Ok(None),
                            Status::Ok | Status::GetCheck => unreachable!(),
                        }

                        Ok::<_, std::io::Error>(Some(encoder.total_out() as usize))
                    })()
                    .map_err(ErrorInner::Compress)?
                    {
                        break 'compressed (CompressAlgo::LZMA, compressed_len);
                    }
                }
            }
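            // Compression was disabled or did not shrink the payload: store it as-is.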
            compressed_buf.copy_from_slice(payload);
            (CompressAlgo::NONE, payload.len())
        };
        buf.truncate(size_of::<Header>() + compressed_len);
        let (header_buf, compressed_buf) = buf.split_at_mut(size_of::<Header>());

        let mut header = Header {
            magic_version: MagicVersion::LATEST,
            slow_hash: [0u8; 32],
            fast_hash: [0u8; 8],
            section_number: section_number.into(),
            section_type,
            compress_algo,
            payload_size: 0.into(),
        };
        header.update_size_and_checksum(compressed_buf);
        header_buf.copy_from_slice(header.as_bytes());

        Ok(buf)
    }

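    /// Serializes `metadata` and writes it as a `METADATA_V2_SCHEMA` section
    /// followed by a `METADATA_V2` section, both with the given compression.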
    pub fn write_metadata_sections(
        &mut self,
        metadata: &dwarfs::metadata::Metadata,
        compression: CompressParam,
    ) -> Result<()> {
        let (schema, metadata_bytes) = metadata.to_schema_and_bytes()?;
        let schema_bytes = schema.to_bytes()?;
        self.write_section(SectionType::METADATA_V2_SCHEMA, compression, &schema_bytes)?;
        self.write_section(SectionType::METADATA_V2, compression, &metadata_bytes)
    }
}