1#![deny(missing_docs)]
27
28pub mod ema;
29pub mod error;
30pub mod throttle;
31pub mod units;
32
33#[cfg(feature = "cli")]
34pub mod cursor;
35#[cfg(all(feature = "cli", unix))]
36pub mod signals;
37
38pub use error::PvError;
39pub use units::UnitSystem;
40
41use std::io::{Read, Write};
42use std::time::{Duration, Instant};
43
/// A snapshot of a copy in progress (or just finished), handed to a [`Reporter`].
///
/// Marked `#[non_exhaustive]`, so code outside this crate can read the fields but cannot
/// build a `Progress` with a struct literal.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct Progress {
    /// Bytes read from the source and written to the sink so far.
    pub bytes_done: u64,
    /// Expected total size, if one was configured via [`PvBuilder::total_bytes`].
    pub bytes_total: Option<u64>,
    /// Throughput in bytes per second. Periodic reports carry a smoothed estimate; the final
    /// report carries the average over the whole copy.
    pub rate: f64,
    /// Estimated time remaining. `None` when it cannot be computed (no known total, total
    /// already reached, or no positive rate) and always `None` on the final report.
    pub eta: Option<Duration>,
    /// Wall-clock time since the copy started.
    pub elapsed: Duration,
}
63
/// Receives [`Progress`] snapshots while [`Pv::copy`] runs.
///
/// The `Send` supertrait is what allows a [`Pv`] (which owns a `Box<dyn Reporter>`) to be
/// `Send` itself.
pub trait Reporter: Send {
    /// Called after a chunk is copied once at least the configured interval has elapsed
    /// since the previous periodic report, and once more after the source is exhausted.
    fn report(&mut self, progress: &Progress);
}
73
74#[derive(Debug, Default)]
76pub struct NoopReporter;
77
78impl Reporter for NoopReporter {
79 fn report(&mut self, _progress: &Progress) {}
80}
81
/// A configured copy job: moves bytes from a reader to a writer while reporting progress
/// and, optionally, throttling throughput.
///
/// Build one with [`PvBuilder`] and run it with [`Pv::copy`].
#[non_exhaustive]
pub struct Pv {
    /// Expected total size, used for ETA estimates.
    total_bytes: Option<u64>,
    /// Throughput cap in bytes per second, if any.
    rate_limit: Option<u64>,
    /// Size of the copy buffer in bytes.
    buffer_size: usize,
    /// Minimum time between periodic progress reports.
    interval: Duration,
    /// Optional display name.
    name: Option<String>,
    /// Receiver of progress snapshots.
    reporter: Box<dyn Reporter>,
}
92
93impl Pv {
94 pub fn copy<R: Read + ?Sized, W: Write + ?Sized>(
101 mut self,
102 reader: &mut R,
103 writer: &mut W,
104 ) -> Result<u64, PvError> {
105 let start = Instant::now();
106 let mut buf = vec![0u8; self.buffer_size];
107 let mut bytes_done: u64 = 0;
108 let throttle = self.rate_limit.map(throttle::TokenBucket::new);
109 let mut ema = ema::Ema::new();
110 let mut last_tick = start;
111 let mut last_bytes: u64 = 0;
112
113 loop {
114 if let Some(tb) = &throttle {
116 tb.maybe_sleep(bytes_done);
117 }
118
119 let n = reader.read(&mut buf).map_err(PvError::from)?;
120 if n == 0 {
121 break;
122 }
123 writer.write_all(&buf[..n]).map_err(PvError::from)?;
124 bytes_done += n as u64;
125
126 let now = Instant::now();
128 if now.duration_since(last_tick) >= self.interval {
129 let dt = now.duration_since(last_tick).as_secs_f64().max(1e-9);
130 let dn = (bytes_done - last_bytes) as f64;
131 let sample = dn / dt;
132 let smoothed = ema.update(sample);
133 let elapsed = now.duration_since(start);
134 let eta = match (self.total_bytes, smoothed > 0.0) {
135 (Some(total), true) if total > bytes_done => Some(Duration::from_secs_f64(
136 (total - bytes_done) as f64 / smoothed,
137 )),
138 _ => None,
139 };
140 let progress = Progress {
141 bytes_done,
142 bytes_total: self.total_bytes,
143 rate: smoothed,
144 eta,
145 elapsed,
146 };
147 self.reporter.report(&progress);
148 last_tick = now;
149 last_bytes = bytes_done;
150 }
151 }
152
153 let elapsed = start.elapsed();
155 let avg_rate = if elapsed.as_secs_f64() > 0.0 {
156 bytes_done as f64 / elapsed.as_secs_f64()
157 } else {
158 0.0
159 };
160 let progress = Progress {
161 bytes_done,
162 bytes_total: self.total_bytes,
163 rate: avg_rate,
164 eta: None,
165 elapsed,
166 };
167 self.reporter.report(&progress);
168
169 writer.flush().map_err(PvError::from)?;
170 Ok(bytes_done)
171 }
172
173 #[must_use]
175 pub fn total_bytes(&self) -> Option<u64> {
176 self.total_bytes
177 }
178
179 #[must_use]
181 pub fn rate_limit(&self) -> Option<u64> {
182 self.rate_limit
183 }
184
185 #[must_use]
187 pub fn buffer_size(&self) -> usize {
188 self.buffer_size
189 }
190
191 #[must_use]
193 pub fn interval(&self) -> Duration {
194 self.interval
195 }
196
197 #[must_use]
199 pub fn name(&self) -> Option<&str> {
200 self.name.as_deref()
201 }
202}
203
204impl std::fmt::Debug for Pv {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("Pv")
207 .field("total_bytes", &self.total_bytes)
208 .field("rate_limit", &self.rate_limit)
209 .field("buffer_size", &self.buffer_size)
210 .field("interval", &self.interval)
211 .field("name", &self.name)
212 .field("reporter", &"<dyn Reporter>")
213 .finish()
214 }
215}
216
/// Builder for [`Pv`].
///
/// Start from [`PvBuilder::new`] (or `Default`), chain setters, then call
/// [`PvBuilder::build`].
pub struct PvBuilder {
    /// Expected total size, used for ETA estimates.
    total_bytes: Option<u64>,
    /// Throughput cap in bytes per second, if any.
    rate_limit: Option<u64>,
    /// Size of the copy buffer in bytes.
    buffer_size: usize,
    /// Minimum time between periodic progress reports.
    interval: Duration,
    /// Optional display name.
    name: Option<String>,
    /// Reporter to install; `None` means `build` falls back to `NoopReporter`.
    reporter: Option<Box<dyn Reporter>>,
}
228
229impl PvBuilder {
230 #[must_use]
238 pub fn new() -> Self {
239 PvBuilder {
240 total_bytes: None,
241 rate_limit: None,
242 buffer_size: 1 << 20,
243 interval: Duration::from_secs(1),
244 name: None,
245 reporter: None,
246 }
247 }
248
249 #[must_use]
251 pub fn total_bytes(mut self, n: u64) -> Self {
252 self.total_bytes = Some(n);
253 self
254 }
255
256 #[must_use]
258 pub fn rate_limit(mut self, bytes_per_sec: u64) -> Self {
259 self.rate_limit = Some(bytes_per_sec);
260 self
261 }
262
263 #[must_use]
265 pub fn buffer_size(mut self, n: usize) -> Self {
266 self.buffer_size = n;
267 self
268 }
269
270 #[must_use]
272 pub fn interval(mut self, d: Duration) -> Self {
273 self.interval = d;
274 self
275 }
276
277 #[must_use]
279 pub fn name(mut self, name: impl Into<String>) -> Self {
280 self.name = Some(name.into());
281 self
282 }
283
284 #[must_use]
286 pub fn reporter(mut self, r: Box<dyn Reporter>) -> Self {
287 self.reporter = Some(r);
288 self
289 }
290
291 #[must_use]
293 pub fn build(self) -> Pv {
294 Pv {
295 total_bytes: self.total_bytes,
296 rate_limit: self.rate_limit,
297 buffer_size: self.buffer_size,
298 interval: self.interval,
299 name: self.name,
300 reporter: self.reporter.unwrap_or_else(|| Box::new(NoopReporter)),
301 }
302 }
303}
304
305impl Default for PvBuilder {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
311impl std::fmt::Debug for PvBuilder {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 f.debug_struct("PvBuilder")
314 .field("total_bytes", &self.total_bytes)
315 .field("rate_limit", &self.rate_limit)
316 .field("buffer_size", &self.buffer_size)
317 .field("interval", &self.interval)
318 .field("name", &self.name)
319 .field(
320 "reporter",
321 &self.reporter.as_ref().map(|_| "<dyn Reporter>"),
322 )
323 .finish()
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use static_assertions::assert_impl_all;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc, Mutex};

    assert_impl_all!(Pv: Send);
    assert_impl_all!(PvBuilder: Send);
    assert_impl_all!(Progress: Send, Sync);
    assert_impl_all!(PvError: Send, Sync);

    #[test]
    fn copy_through_in_memory_buffer() {
        let pv = PvBuilder::new().build();
        let src = vec![0xABu8; 4096];
        let mut reader = std::io::Cursor::new(src.clone());
        let mut writer = Vec::new();
        let n = pv.copy(&mut reader, &mut writer).unwrap();
        assert_eq!(n, src.len() as u64);
        assert_eq!(writer, src);
    }

    #[test]
    fn builder_setters_are_independent() {
        let p1 = PvBuilder::new()
            .total_bytes(1024)
            .rate_limit(500)
            .buffer_size(4096)
            .interval(Duration::from_millis(100))
            .name("a")
            .build();
        let p2 = PvBuilder::new()
            .name("a")
            .interval(Duration::from_millis(100))
            .buffer_size(4096)
            .rate_limit(500)
            .total_bytes(1024)
            .build();
        assert_eq!(p1.total_bytes(), p2.total_bytes());
        assert_eq!(p1.rate_limit(), p2.rate_limit());
        assert_eq!(p1.buffer_size(), p2.buffer_size());
        assert_eq!(p1.interval(), p2.interval());
        assert_eq!(p1.name(), p2.name());
    }

    #[test]
    fn custom_reporter_receives_updates() {
        // Counts calls and remembers the most recent snapshot so its contents can be checked;
        // a bare call count is always >= 1 because the final report fires unconditionally.
        struct Recorder {
            calls: Arc<AtomicUsize>,
            last: Arc<Mutex<Option<Progress>>>,
        }
        impl Reporter for Recorder {
            fn report(&mut self, progress: &Progress) {
                self.calls.fetch_add(1, Ordering::Relaxed);
                *self.last.lock().unwrap() = Some(progress.clone());
            }
        }

        const LEN: usize = 65536;
        let calls = Arc::new(AtomicUsize::new(0));
        let last = Arc::new(Mutex::new(None));
        let pv = PvBuilder::new()
            .interval(Duration::from_millis(1))
            .total_bytes(LEN as u64)
            .reporter(Box::new(Recorder {
                calls: Arc::clone(&calls),
                last: Arc::clone(&last),
            }))
            .build();
        let mut reader = std::io::Cursor::new(vec![0u8; LEN]);
        let mut writer = Vec::new();
        let n = pv.copy(&mut reader, &mut writer).unwrap();

        assert_eq!(n, LEN as u64);
        assert!(calls.load(Ordering::Relaxed) >= 1);
        let final_progress = last
            .lock()
            .unwrap()
            .take()
            .expect("copy always emits a final report on success");
        assert_eq!(final_progress.bytes_done, LEN as u64);
        assert_eq!(final_progress.bytes_total, Some(LEN as u64));
        assert!(final_progress.eta.is_none());
    }
}