Skip to main content

mtp_rs/mtp/
stream.rs

//! Streaming download/upload support.
2
3use crate::ptp::ReceiveStream;
4use crate::Error;
5use bytes::Bytes;
6use std::ops::ControlFlow;
7
/// Progress information for transfers.
///
/// A plain snapshot of how far a transfer has got. It is cheap to copy and
/// is handed by value to progress callbacks such as
/// [`FileDownload::collect_with_progress`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Progress {
    /// Bytes transferred so far.
    pub bytes_transferred: u64,
    /// Total bytes, or `None` when the total is not known.
    pub total_bytes: Option<u64>,
}
16
17impl Progress {
18    /// Progress as a percentage (0.0 to 100.0).
19    #[must_use]
20    pub fn percent(&self) -> f64 {
21        self.fraction() * 100.0
22    }
23
24    /// Progress as a fraction (0.0 to 1.0).
25    #[must_use]
26    pub fn fraction(&self) -> f64 {
27        self.total_bytes.map_or(1.0, |total| {
28            if total == 0 {
29                1.0
30            } else {
31                self.bytes_transferred as f64 / total as f64
32            }
33        })
34    }
35}
36
37/// A file download in progress with true USB streaming.
38///
39/// This struct wraps the low-level `ReceiveStream` and provides convenient
40/// methods for tracking progress. Data is streamed directly from USB as
41/// chunks arrive, without buffering the entire file in memory.
42///
43/// # Important
44///
45/// The MTP session is locked while this download is active. You must consume
46/// the entire download (or drop it) before calling other storage methods.
47///
48/// # Example
49///
50/// ```rust,ignore
51/// let mut download = storage.download_stream(handle).await?;
52/// println!("Downloading {} bytes...", download.size());
53///
54/// while let Some(chunk) = download.next_chunk().await {
55///     let bytes = chunk?;
56///     file.write_all(&bytes).await?;
57///     println!("Progress: {:.1}%", download.progress() * 100.0);
58/// }
59/// ```
60pub struct FileDownload {
61    size: u64,
62    bytes_received: u64,
63    stream: ReceiveStream,
64}
65
66impl FileDownload {
67    /// Create a new FileDownload wrapping a ReceiveStream.
68    pub(crate) fn new(size: u64, stream: ReceiveStream) -> Self {
69        Self {
70            size,
71            bytes_received: 0,
72            stream,
73        }
74    }
75
76    /// Total file size in bytes.
77    #[must_use]
78    pub fn size(&self) -> u64 {
79        self.size
80    }
81
82    /// Bytes received so far.
83    #[must_use]
84    pub fn bytes_received(&self) -> u64 {
85        self.bytes_received
86    }
87
88    /// Progress as a fraction (0.0 to 1.0).
89    #[must_use]
90    pub fn progress(&self) -> f64 {
91        if self.size == 0 {
92            1.0
93        } else {
94            self.bytes_received as f64 / self.size as f64
95        }
96    }
97
98    /// Get the next chunk of data from USB.
99    ///
100    /// Returns `None` when the download is complete.
101    pub async fn next_chunk(&mut self) -> Option<Result<Bytes, Error>> {
102        match self.stream.next_chunk().await {
103            Some(Ok(bytes)) => {
104                self.bytes_received += bytes.len() as u64;
105                Some(Ok(bytes))
106            }
107            Some(Err(e)) => Some(Err(e)),
108            None => None,
109        }
110    }
111
112    /// Consume the download and iterate with a progress callback.
113    ///
114    /// Calls `on_progress` after each chunk. Return `ControlFlow::Break(())`
115    /// to cancel the download.
116    ///
117    /// # Example
118    ///
119    /// ```rust,ignore
120    /// let data = download.collect_with_progress(|progress| {
121    ///     println!("{:.1}%", progress.percent().unwrap_or(0.0));
122    ///     ControlFlow::Continue(())
123    /// }).await?;
124    /// ```
125    pub async fn collect_with_progress<F>(mut self, mut on_progress: F) -> Result<Vec<u8>, Error>
126    where
127        F: FnMut(Progress) -> ControlFlow<()>,
128    {
129        let mut data = Vec::with_capacity(self.size as usize);
130
131        while let Some(result) = self.next_chunk().await {
132            let chunk = result?;
133            data.extend_from_slice(&chunk);
134
135            let progress = Progress {
136                bytes_transferred: self.bytes_received,
137                total_bytes: Some(self.size),
138            };
139
140            if let ControlFlow::Break(()) = on_progress(progress) {
141                return Err(Error::Cancelled);
142            }
143        }
144
145        Ok(data)
146    }
147
148    /// Collect all remaining data into a `Vec<u8>`.
149    ///
150    /// This consumes the download and buffers all data in memory.
151    pub async fn collect(self) -> Result<Vec<u8>, Error> {
152        self.stream.collect().await
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `Progress` value for the table below.
    fn progress(bytes_transferred: u64, total_bytes: Option<u64>) -> Progress {
        Progress {
            bytes_transferred,
            total_bytes,
        }
    }

    #[test]
    fn progress_calculations() {
        // (transferred, total, expected percent, expected fraction)
        let table: [(u64, Option<u64>, f64, f64); 5] = [
            (50, Some(100), 50.0, 0.5),
            (100, Some(100), 100.0, 1.0),
            (25, Some(100), 25.0, 0.25),
            (0, Some(0), 100.0, 1.0), // Empty file
            (50, None, 100.0, 1.0),   // Unknown total defaults to complete
        ];

        for &(transferred, total, pct, frac) in &table {
            let p = progress(transferred, total);
            assert_eq!(p.percent(), pct, "percent failed for {transferred}/{total:?}");
            assert_eq!(p.fraction(), frac, "fraction failed for {transferred}/{total:?}");
        }

        // Very large values must still land at roughly one half.
        let half = progress(u64::MAX / 2, Some(u64::MAX)).fraction();
        assert!(0.49 < half && half < 0.51);
    }
}