lzma_rust2/lzip/writer_mt.rs

use std::{
    io::{self, Cursor, Write},
    sync::{
        atomic::{AtomicBool, AtomicU32, Ordering},
        mpsc::SyncSender,
        Arc, Mutex,
    },
};

use super::{LzipOptions, LzipWriter};
use crate::{
    error_invalid_input, set_error,
    work_pool::{WorkPool, WorkPoolConfig},
    work_queue::WorkerHandle,
    AutoFinish, AutoFinisher,
};

/// A work unit for a worker thread.
#[derive(Debug, Clone)]
struct WorkUnit {
    data: Vec<u8>,
    options: LzipOptions,
}

/// A multi-threaded LZIP compressor.
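///
/// Data written to this type is buffered until `member_size` bytes have been collected;
/// each full member is then compressed independently on a worker thread and the finished
/// members are written to the inner writer in order.
///
/// # Examples
///
/// A minimal usage sketch (ignored doctest): how an `LzipOptions` value with a member size
/// is constructed depends on the crate's API, so `options` below is only a placeholder.
///
/// ```ignore
/// use std::io::Write;
///
/// // `options` must be an `LzipOptions` with `member_size` set.
/// let mut writer = LzipWriterMt::new(Vec::new(), options, 4)?;
/// writer.write_all(b"hello lzip")?;
/// let compressed: Vec<u8> = writer.finish()?;
/// ```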
pub struct LzipWriterMt<W: Write> {
    inner: W,
    options: LzipOptions,
    current_work_unit: Vec<u8>,
    member_size: usize,
    work_pool: WorkPool<WorkUnit, Vec<u8>>,
    current_chunk: Cursor<Vec<u8>>,
    pending_write_data: Vec<u8>,
}

impl<W: Write> LzipWriterMt<W> {
    /// Creates a new multi-threaded LZIP writer.
    ///
    /// - `inner`: The writer to write compressed data to.
    /// - `options`: The LZIP options used for compression. The member size must be set when
    ///   using the multi-threaded encoder and is raised to at least the dictionary size. If a
    ///   single member is enough, use the single-threaded encoder instead.
    /// - `num_workers`: The maximum number of worker threads for compression.
    ///   Currently capped at 256 threads.
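    ///
    /// # Examples
    ///
    /// A sketch of the member-size requirement (ignored doctest; it assumes `LzipOptions`
    /// implements `Default` and exposes a public `member_size` field, which may not match
    /// the actual API):
    ///
    /// ```ignore
    /// use std::io;
    ///
    /// let mut options = LzipOptions::default();
    /// options.member_size = None;
    ///
    /// // Without a member size the constructor returns an `InvalidInput` error.
    /// assert!(LzipWriterMt::new(io::sink(), options, 4).is_err());
    /// ```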
    pub fn new(inner: W, options: LzipOptions, num_workers: u32) -> io::Result<Self> {
        let member_size = match options.member_size {
            None => return Err(error_invalid_input("member size must be set")),
            Some(member_size) => member_size.get().max(options.lzma_options.dict_size as u64),
        };

        let member_size = usize::try_from(member_size)
            .map_err(|_| error_invalid_input("member size doesn't fit into usize"))?;

        // We don't know how many work units we'll have ahead of time.
        let num_work = u64::MAX;

        Ok(Self {
            inner,
            options,
            current_work_unit: Vec::with_capacity(member_size.min(1024 * 1024)),
            member_size,
            work_pool: WorkPool::new(
                WorkPoolConfig::new(num_workers, num_work),
                worker_thread_logic,
            ),
            current_chunk: Cursor::new(Vec::new()),
            pending_write_data: Vec::new(),
        })
    }

    /// Sends the current work unit to the workers, writing any available results.
    fn send_work_unit(&mut self) -> io::Result<()> {
        if self.current_work_unit.is_empty() {
            return Ok(());
        }

        self.drain_available_results()?;

        let work_data = core::mem::take(&mut self.current_work_unit);
        let mut single_member_options = self.options.clone();
        single_member_options.member_size = None;

        let mut work_data_opt = Some(work_data);

        self.work_pool.dispatch_next_work(&mut |_seq| {
            let data = work_data_opt.take().ok_or_else(|| {
                io::Error::new(io::ErrorKind::InvalidInput, "work already provided")
            })?;
            Ok(WorkUnit {
                data,
                options: single_member_options.clone(),
            })
        })?;

        self.drain_available_results()?;

        Ok(())
    }

    /// Drains all currently available results from the work pool and writes them.
    fn drain_available_results(&mut self) -> io::Result<()> {
        while let Some(compressed_data) = self.work_pool.try_get_result()? {
            self.inner.write_all(&compressed_data)?;
        }
        Ok(())
    }

    /// Returns a wrapper around `self` that will finish the stream on drop.
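    ///
    /// # Examples
    ///
    /// A sketch of scope-based finishing (ignored doctest; it assumes `options` is a valid
    /// `LzipOptions` with a member size set and that `AutoFinisher` forwards `Write` to the
    /// wrapped writer):
    ///
    /// ```ignore
    /// use std::io::Write;
    ///
    /// let mut output = Vec::new();
    /// {
    ///     let mut writer = LzipWriterMt::new(&mut output, options, 4)?.auto_finish();
    ///     writer.write_all(b"hello lzip")?;
    /// } // Dropping the `AutoFinisher` finishes the stream, ignoring any error.
    /// ```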
    pub fn auto_finish(self) -> AutoFinisher<Self> {
        AutoFinisher(Some(self))
    }

    /// Consumes the `LzipWriterMt` and returns the inner writer without finishing the stream.
    pub fn into_inner(self) -> W {
        self.inner
    }

    /// Finishes the compression and returns the underlying writer.
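    ///
    /// If no data was written at all, a single empty LZIP member is still emitted so that the
    /// output is a valid `.lz` stream.
    ///
    /// # Examples
    ///
    /// A sketch of finishing an empty stream (ignored doctest; `options` stands in for a valid
    /// `LzipOptions` with a member size set):
    ///
    /// ```ignore
    /// let writer = LzipWriterMt::new(Vec::new(), options, 4)?;
    /// let compressed = writer.finish()?;
    /// assert!(!compressed.is_empty()); // Contains a single empty member.
    /// ```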
    pub fn finish(mut self) -> io::Result<W> {
        if !self.current_work_unit.is_empty() {
            self.send_work_unit()?;
        }

        // If no data was provided to compress, write an empty LZIP file (single empty member).
        if self.work_pool.next_index_to_dispatch() == 0 {
            let mut options = self.options.clone();
            options.member_size = None;
            let lzip_writer = LzipWriter::new(Vec::new(), options);
            let empty_member = lzip_writer.finish()?;

            self.inner.write_all(&empty_member)?;
            self.inner.flush()?;

            return Ok(self.inner);
        }

        // Mark the WorkPool as finished so it knows no more work is coming.
        self.work_pool.finish();

        // Wait for all remaining work to complete.
        while let Some(compressed_data) = self.work_pool.get_result(|_| {
            Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "no more work to dispatch",
            ))
        })? {
            self.inner.write_all(&compressed_data)?;
        }

        self.inner.flush()?;

        Ok(self.inner)
    }
}

/// The logic for a single worker thread.
fn worker_thread_logic(
    worker_handle: WorkerHandle<(u64, WorkUnit)>,
    result_tx: SyncSender<(u64, Vec<u8>)>,
    shutdown_flag: Arc<AtomicBool>,
    error_store: Arc<Mutex<Option<io::Error>>>,
    active_workers: Arc<AtomicU32>,
) {
    while !shutdown_flag.load(Ordering::Acquire) {
        let (index, work_unit) = match worker_handle.steal() {
            Some(work) => {
                active_workers.fetch_add(1, Ordering::Release);
                work
            }
            None => {
                // No more work available and queue is closed.
                break;
            }
        };

        let mut compressed_buffer = Vec::new();

        let mut writer = LzipWriter::new(&mut compressed_buffer, work_unit.options);
        let result = match writer.write_all(&work_unit.data) {
            Ok(_) => match writer.finish() {
                Ok(_) => compressed_buffer,
                Err(error) => {
                    active_workers.fetch_sub(1, Ordering::Release);
                    set_error(error, &error_store, &shutdown_flag);
                    return;
                }
            },
            Err(error) => {
                active_workers.fetch_sub(1, Ordering::Release);
                set_error(error, &error_store, &shutdown_flag);
                return;
            }
        };

        if result_tx.send((index, result)).is_err() {
            active_workers.fetch_sub(1, Ordering::Release);
            return;
        }

        active_workers.fetch_sub(1, Ordering::Release);
    }
}

impl<W: Write> Write for LzipWriterMt<W> {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }

        let mut total_written = 0;
        let mut remaining_buf = buf;

        while !remaining_buf.is_empty() {
            let member_remaining = self
                .member_size
                .saturating_sub(self.current_work_unit.len());
            let to_write = remaining_buf.len().min(member_remaining);

            if to_write > 0 {
                self.current_work_unit
                    .extend_from_slice(&remaining_buf[..to_write]);
                total_written += to_write;
                remaining_buf = &remaining_buf[to_write..];
            }

            if self.current_work_unit.len() >= self.member_size {
                self.send_work_unit()?;
            }

            self.drain_available_results()?;
        }

        Ok(total_written)
    }

    fn flush(&mut self) -> io::Result<()> {
        if !self.current_work_unit.is_empty() {
            self.send_work_unit()?;
        }

        // Write out any compressed members that are already available. This does not block
        // waiting for in-flight work units to complete.
        while let Some(compressed_data) = self.work_pool.try_get_result()? {
            self.inner.write_all(&compressed_data)?;
        }

        self.inner.flush()
    }
}

impl<W: Write> AutoFinish for LzipWriterMt<W> {
    fn finish_ignore_error(self) {
        let _ = self.finish();
    }
}