dwarfs_enc/
section.rs

1//! DwarFS section writer.
2use std::io::Write;
3use std::num::NonZero;
4
5use dwarfs::section::{CompressAlgo, Header, MagicVersion, SectionIndexEntry, SectionType};
6use dwarfs::zerocopy::IntoBytes;
7use zerocopy::FromBytes;
8
9use crate::ordered_parallel::OrderedParallel;
10use crate::{ErrorInner, Result};
11
12/// The section compression parameter.
/// The section compression parameter.
///
/// Selects how a section payload is compressed before it is written out.
/// Variants other than [`CompressParam::None`] are only available when the
/// corresponding Cargo feature is enabled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum CompressParam {
    /// No compression.
    None,
    /// Compress with a given ZSTD level. Requires feature `zstd`.
    #[cfg(feature = "zstd")]
    Zstd(zstd_safe::CompressionLevel),
    /// Compress with a given LZMA (aka. xz) level. Requires feature `lzma`.
    #[cfg(feature = "lzma")]
    Lzma(u32),
}
25
26/// DwarFS section writer.
/// DwarFS section writer.
///
/// Sections are compressed on worker threads (see [`OrderedParallel`]) and
/// committed to the underlying writer `w` in submission order.
#[derive(Debug)]
pub struct Writer<W: ?Sized> {
    /// Worker pool producing fully-serialized section bytes (header + payload).
    workers: OrderedParallel<Result<Vec<u8>>>,
    /// The total number of sections initiated, including ones that are not written yet.
    initiated_section_count: u32,
    /// Accumulates the `SECTION_INDEX` entries as sections are written.
    index: IndexBuilder,

    /// The underlying output writer; last so `W: ?Sized` is allowed.
    w: W,
}
36
/// Incrementally builds the `SECTION_INDEX` payload.
#[derive(Debug, Default)]
struct IndexBuilder {
    /// One entry per section written so far, in file order.
    index: Vec<SectionIndexEntry>,
    /// File offset at which the next section will start.
    next_offset: u64,
}
42
43impl IndexBuilder {
44    fn push(&mut self, typ: SectionType, sec_raw_len: usize) -> Result<()> {
45        let ent = SectionIndexEntry::new(typ, self.next_offset).expect("checked by last write");
46        self.next_offset = u64::try_from(sec_raw_len)
47            .ok()
48            .and_then(|l| l.checked_add(self.next_offset))
49            .filter(|&n| n < 1u64 << 48)
50            .ok_or(ErrorInner::Limit("archive size exceeds 2^48 bytes"))?;
51        self.index.push(ent);
52        Ok(())
53    }
54}
55
56impl<W> Writer<W> {
57    /// Create a default multi-threaded section writer.
58    pub fn new(w: W) -> std::io::Result<Self> {
59        let thread_cnt = std::thread::available_parallelism()?;
60        Self::new_with_threads(w, thread_cnt)
61    }
62
63    /// Create a section writer with specific parallelism.
64    pub fn new_with_threads(w: W, thread_cnt: NonZero<usize>) -> std::io::Result<Self> {
65        let workers = OrderedParallel::new("compressor", thread_cnt)?;
66        Ok(Self {
67            workers,
68            initiated_section_count: 0,
69            index: IndexBuilder::default(),
70            w,
71        })
72    }
73}
74
impl<W: ?Sized> Writer<W> {
    /// Get a reference to the underlying writer.
    pub fn get_ref(&self) -> &W {
        &self.w
    }

    /// Get a mutable reference to the underlying writer.
    ///
    /// NOTE(review): writing through this reference bypasses section framing
    /// and index bookkeeping — use with care.
    pub fn get_mut(&mut self) -> &mut W {
        &mut self.w
    }

    /// Retrieve the ownership of the underlying writer.
    pub fn into_inner(self) -> W
    where
        W: Sized,
    {
        self.w
    }
}
94
95impl<W: Write> Writer<W> {
    /// Number of sections initiated via `write_section`.
    ///
    /// Includes sections still being compressed on worker threads.
    #[must_use]
    pub fn section_count(&self) -> u32 {
        // Checked by `write_section` not to overflow u32.
        self.initiated_section_count
    }
102
    /// Finalize and seal the DwarFS archive.
    ///
    /// Drains all in-flight compression jobs, appends the `SECTION_INDEX`
    /// section (always uncompressed), and returns the underlying writer.
    pub fn finish(mut self) -> Result<W> {
        // Wait for all preceding sections to complete, so their offsets are recorded.
        self.workers.stop();
        while let Some(iter) = self.workers.wait_and_get() {
            Self::commit_completed(iter, &mut self.w, &mut self.index)?;
        }

        // The index section's own entry is pushed *before* serializing, so it
        // appears in its own payload; hence its byte length is the current
        // entries plus one more for the entry being added.
        // The last length is unused.
        let index_byte_len = self.index.index.as_bytes().len() + size_of::<SectionIndexEntry>();
        self.index
            .push(SectionType::SECTION_INDEX, index_byte_len)?;
        let sec = Self::seal_section(
            self.section_count(),
            SectionType::SECTION_INDEX,
            CompressParam::None,
            self.index.index.as_bytes(),
        )?;
        self.w.write_all(&sec)?;

        Ok(self.w)
    }
125
126    fn commit_completed(
127        completed: impl Iterator<Item = Result<Vec<u8>>>,
128        w: &mut W,
129        index: &mut IndexBuilder,
130    ) -> Result<()> {
131        for ret in completed {
132            let sec = ret?;
133            let off = std::mem::offset_of!(Header, section_type);
134            let typ = SectionType::read_from_prefix(&sec[off..]).unwrap().0;
135            w.write_all(&sec)?;
136            index.push(typ, sec.len())?;
137        }
138        Ok(())
139    }
140
141    /// Write a section with given (uncompressed) payload.
142    pub fn write_section(
143        &mut self,
144        section_type: SectionType,
145        compression: CompressParam,
146        payload: &[u8],
147    ) -> Result<()> {
148        // Should not happen for current machines.
149        assert!(u64::try_from(size_of::<Header>() + payload.len()).is_ok());
150
151        let section_number = self.section_count();
152        self.initiated_section_count = self
153            .initiated_section_count
154            .checked_add(1)
155            .ok_or(ErrorInner::Limit("section count exceeds 2^32"))?;
156
157        let payload = payload.to_vec();
158        Self::commit_completed(
159            self.workers.submit_and_get(move || {
160                Self::seal_section(section_number, section_type, compression, &payload)
161            }),
162            &mut self.w,
163            &mut self.index,
164        )
165    }
166
    /// Compress payload if possible, calculate hashes and fill the section header.
    ///
    /// Returns the complete on-disk bytes of one section: a [`Header`]
    /// followed by the (possibly compressed) payload. If the selected
    /// compressor cannot shrink the payload, the data is stored raw with
    /// `CompressAlgo::NONE` instead.
    fn seal_section(
        section_number: u32,
        section_type: SectionType,
        compression: CompressParam,
        payload: &[u8],
    ) -> Result<Vec<u8>> {
        // Worst case size: header + uncompressed payload. Compressors write
        // directly into the payload region of this buffer.
        let mut buf = vec![0u8; size_of::<Header>() + payload.len()];
        // The label is only used by the compressor arms below.
        // NOTE(review): this gates on feature `default`; presumably it should
        // track the compressor features (`zstd`/`lzma`) — confirm.
        #[cfg_attr(not(feature = "default"), allow(unused_labels))]
        let (compress_algo, compressed_len) = 'compressed: {
            // Output region is exactly `payload.len()` bytes: a compressor
            // that overflows it would not save space anyway.
            let compressed_buf = &mut buf[size_of::<Header>()..];
            match compression {
                CompressParam::None => {}

                #[cfg(feature = "zstd")]
                #[expect(non_upper_case_globals, reason = "name from C")]
                CompressParam::Zstd(lvl) => {
                    // `zstd_safe` does not re-export this C error constant.
                    // See: <https://github.com/gyscos/zstd-rs/issues/276>
                    const ZSTD_error_dstSize_tooSmall: zstd_safe::ErrorCode = -70isize as usize;

                    match zstd_safe::compress(compressed_buf, payload, lvl) {
                        Ok(compressed_len) => {
                            assert!(compressed_len <= payload.len());
                            break 'compressed (CompressAlgo::ZSTD, compressed_len);
                        }
                        // Output buffer too small: data is incompressible at
                        // this level; fall through and store it raw.
                        Err(ZSTD_error_dstSize_tooSmall) => {}
                        Err(code) => {
                            let err = std::io::Error::new(
                                std::io::ErrorKind::InvalidInput,
                                format!(
                                    "ZSTD compression failed (code={}): {}",
                                    code,
                                    zstd_safe::get_error_name(code),
                                ),
                            );
                            return Err(ErrorInner::Compress(err).into());
                        }
                    }
                }

                #[cfg(feature = "lzma")]
                CompressParam::Lzma(lvl) => {
                    // The closure yields `Ok(None)` when the output did not
                    // fit in `compressed_buf` (incompressible data).
                    if let Some(compressed_len) = (|| {
                        use liblzma::stream::{Action, Check, Status, Stream};

                        // The default parameters used by `liblzma::bufread::XzEncoder::new`.
                        // See: <https://docs.rs/liblzma/0.4.1/src/liblzma/bufread.rs.html#35>
                        let mut encoder = Stream::new_easy_encoder(lvl, Check::Crc64)?;

                        match encoder.process(payload, compressed_buf, Action::Run)? {
                            // Treat partial consumption as buffer-too-small.
                            Status::Ok if encoder.total_in() == payload.len() as u64 => {}
                            Status::Ok | Status::MemNeeded => return Ok(None),
                            Status::StreamEnd | Status::GetCheck => unreachable!(),
                        }
                        // Flush the stream trailer into the remaining space.
                        match encoder.process(
                            &[],
                            &mut compressed_buf[encoder.total_out() as usize..],
                            Action::Finish,
                        )? {
                            Status::StreamEnd => {}
                            Status::MemNeeded => return Ok(None),
                            Status::Ok | Status::GetCheck => unreachable!(),
                        }

                        Ok::<_, std::io::Error>(Some(encoder.total_out() as usize))
                    })()
                    .map_err(ErrorInner::Compress)?
                    {
                        break 'compressed (CompressAlgo::LZMA, compressed_len);
                    }
                }
            }
            // No compression requested, or compression did not help: store raw.
            compressed_buf.copy_from_slice(payload);
            (CompressAlgo::NONE, payload.len())
        };
        // Drop the unused tail of the payload region.
        buf.truncate(size_of::<Header>() + compressed_len);
        let (header_buf, compressed_buf) = buf.split_at_mut(size_of::<Header>());

        // Hashes and payload size start zeroed; `update_size_and_checksum`
        // fills them from the final (compressed) payload bytes.
        let mut header = Header {
            magic_version: MagicVersion::LATEST,
            slow_hash: [0u8; 32],
            fast_hash: [0u8; 8],
            section_number: section_number.into(),
            section_type,
            compress_algo,
            payload_size: 0.into(),
        };
        header.update_size_and_checksum(compressed_buf);
        header_buf.copy_from_slice(header.as_bytes());

        Ok(buf)
    }
260
261    /// Write metadata sections `METADATA_V2_{,_SCHEMA}`.
262    pub fn write_metadata_sections(
263        &mut self,
264        metadata: &dwarfs::metadata::Metadata,
265        compression: CompressParam,
266    ) -> Result<()> {
267        let (schema, metadata_bytes) = metadata.to_schema_and_bytes()?;
268        let schema_bytes = schema.to_bytes()?;
269        self.write_section(SectionType::METADATA_V2_SCHEMA, compression, &schema_bytes)?;
270        self.write_section(SectionType::METADATA_V2, compression, &metadata_bytes)
271    }
272}