Skip to main content

mtp_rs/mtp/
stream.rs

1//! Streaming download/upload support.
2
3use crate::ptp::ReceiveStream;
4use crate::Error;
5use bytes::Bytes;
6use std::ops::ControlFlow;
7
/// Progress information for transfers.
///
/// Reported, for example, by [`FileDownload::collect_with_progress`] after
/// each received chunk.
#[derive(Debug, Clone)]
pub struct Progress {
    /// Bytes transferred so far.
    pub bytes_transferred: u64,
    /// Total bytes (if known).
    pub total_bytes: Option<u64>,
}

impl Progress {
    /// Progress as a percentage (0.0 to 100.0).
    ///
    /// This is [`fraction()`](Self::fraction) scaled by 100, so it follows the
    /// same rules for unknown, empty, and over-reported totals.
    #[must_use]
    pub fn percent(&self) -> f64 {
        self.fraction() * 100.0
    }

    /// Progress as a fraction (0.0 to 1.0).
    ///
    /// Returns 1.0 when the total is unknown (`None`) or zero, because there
    /// is no meaningful ratio to report in either case. If more bytes were
    /// transferred than the announced total (e.g. the size was under-reported),
    /// the result is clamped to 1.0 so it never leaves the documented range.
    #[must_use]
    pub fn fraction(&self) -> f64 {
        match self.total_bytes {
            None | Some(0) => 1.0,
            Some(total) => (self.bytes_transferred as f64 / total as f64).min(1.0),
        }
    }
}
36
/// Idle timeout applied while draining pipes after a cancel request.
///
/// Once the cancel control request has gone out, this is how long to wait
/// for further data on each pipe before treating that pipe as clear. The
/// 300 ms value matches libmtp, which in turn mirrors Windows behavior.
pub const DEFAULT_CANCEL_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(300);
43
44/// A file download in progress with true USB streaming.
45///
46/// This struct wraps the low-level `ReceiveStream` and provides convenient
47/// methods for tracking progress. Data is streamed directly from USB as
48/// chunks arrive, without buffering the entire file in memory.
49///
50/// # Important
51///
52/// The MTP session is locked while this download is active. You must either
53/// consume the entire download or call [`cancel()`](Self::cancel) before
54/// dropping it. Dropping mid-download without cancelling corrupts the USB
55/// session.
56///
57/// # Example
58///
59/// ```rust,no_run
60/// use mtp_rs::mtp::MtpDevice;
61/// use mtp_rs::ObjectHandle;
62/// use tokio::io::AsyncWriteExt;
63///
64/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
65/// # let device = MtpDevice::open_first().await?;
66/// # let storages = device.storages().await?;
67/// # let storage = &storages[0];
68/// # let handle = ObjectHandle(1);
69/// let mut download = storage.download_stream(handle).await?;
70/// println!("Downloading {} bytes...", download.size());
71///
72/// # let mut file = tokio::fs::File::create("output.bin").await?;
73/// while let Some(chunk) = download.next_chunk().await {
74///     let bytes = chunk?;
75///     file.write_all(&bytes).await?;
76///     println!("Progress: {:.1}%", download.progress() * 100.0);
77/// }
78/// # Ok(())
79/// # }
80/// ```
81#[must_use = "dropping a FileDownload mid-transfer corrupts the USB session; \
82               consume it fully or call cancel()"]
83pub struct FileDownload {
84    size: u64,
85    bytes_received: u64,
86    stream: ReceiveStream,
87}
88
89impl FileDownload {
90    /// Create a new FileDownload wrapping a ReceiveStream.
91    pub(crate) fn new(size: u64, stream: ReceiveStream) -> Self {
92        Self {
93            size,
94            bytes_received: 0,
95            stream,
96        }
97    }
98
99    /// Total file size in bytes.
100    #[must_use]
101    pub fn size(&self) -> u64 {
102        self.size
103    }
104
105    /// Bytes received so far.
106    #[must_use]
107    pub fn bytes_received(&self) -> u64 {
108        self.bytes_received
109    }
110
111    /// Progress as a fraction (0.0 to 1.0).
112    #[must_use]
113    pub fn progress(&self) -> f64 {
114        if self.size == 0 {
115            1.0
116        } else {
117            self.bytes_received as f64 / self.size as f64
118        }
119    }
120
121    /// Cancel the in-progress download.
122    ///
123    /// Uses the USB Still Image Class cancel mechanism to stop the transfer
124    /// and drain remaining data, leaving the session clean for the next
125    /// operation.
126    ///
127    /// The `idle_timeout` controls how long to wait during pipe drain before
128    /// assuming the pipe is clear. 1–2 seconds is typically sufficient.
129    ///
130    /// If the download is already complete, this is a no-op.
131    pub async fn cancel(&mut self, idle_timeout: std::time::Duration) -> Result<(), Error> {
132        self.stream.cancel(idle_timeout).await
133    }
134
135    /// Get the next chunk of data from USB.
136    ///
137    /// Returns `None` when the download is complete.
138    pub async fn next_chunk(&mut self) -> Option<Result<Bytes, Error>> {
139        match self.stream.next_chunk().await {
140            Some(Ok(bytes)) => {
141                self.bytes_received += bytes.len() as u64;
142                Some(Ok(bytes))
143            }
144            Some(Err(e)) => Some(Err(e)),
145            None => None,
146        }
147    }
148
149    /// Consume the download and iterate with a progress callback.
150    ///
151    /// Calls `on_progress` after each chunk. Return `ControlFlow::Break(())`
152    /// to cancel the download.
153    ///
154    /// # Example
155    ///
156    /// ```rust,no_run
157    /// use mtp_rs::mtp::MtpDevice;
158    /// use mtp_rs::ObjectHandle;
159    /// use std::ops::ControlFlow;
160    ///
161    /// # async fn example() -> Result<(), mtp_rs::Error> {
162    /// # let device = MtpDevice::open_first().await?;
163    /// # let storages = device.storages().await?;
164    /// # let storage = &storages[0];
165    /// # let handle = ObjectHandle(1);
166    /// let download = storage.download_stream(handle).await?;
167    /// let data = download.collect_with_progress(|progress| {
168    ///     println!("{:.1}%", progress.percent());
169    ///     ControlFlow::Continue(())
170    /// }).await?;
171    /// # Ok(())
172    /// # }
173    /// ```
174    pub async fn collect_with_progress<F>(mut self, mut on_progress: F) -> Result<Vec<u8>, Error>
175    where
176        F: FnMut(Progress) -> ControlFlow<()>,
177    {
178        let mut data = Vec::with_capacity(self.size as usize);
179
180        while let Some(result) = self.next_chunk().await {
181            let chunk = result?;
182            data.extend_from_slice(&chunk);
183
184            let progress = Progress {
185                bytes_transferred: self.bytes_received,
186                total_bytes: Some(self.size),
187            };
188
189            if let ControlFlow::Break(()) = on_progress(progress) {
190                self.stream.cancel(DEFAULT_CANCEL_TIMEOUT).await?;
191                return Err(Error::Cancelled);
192            }
193        }
194
195        Ok(data)
196    }
197
198    /// Collect all remaining data into a `Vec<u8>`.
199    ///
200    /// This consumes the download and buffers all data in memory.
201    pub async fn collect(self) -> Result<Vec<u8>, Error> {
202        self.stream.collect().await
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::ControlFlow;

    #[test]
    fn progress_calculations() {
        // Each row: (bytes transferred, total bytes, expected percent, expected fraction).
        let cases = [
            (50, Some(100), 50.0, 0.5),
            (100, Some(100), 100.0, 1.0),
            (25, Some(100), 25.0, 0.25),
            (0, Some(0), 100.0, 1.0), // Empty file
            (50, None, 100.0, 1.0),   // Unknown total defaults to complete
        ];
        for &(transferred, total, expected_pct, expected_frac) in &cases {
            let p = Progress {
                bytes_transferred: transferred,
                total_bytes: total,
            };
            assert_eq!(
                p.percent(),
                expected_pct,
                "percent failed for {transferred}/{total:?}"
            );
            assert_eq!(
                p.fraction(),
                expected_frac,
                "fraction failed for {transferred}/{total:?}"
            );
        }

        // Values near u64::MAX must not overflow and should land near one half.
        let frac = Progress {
            bytes_transferred: u64::MAX / 2,
            total_bytes: Some(u64::MAX),
        }
        .fraction();
        assert!(frac > 0.49 && frac < 0.51);
    }

    #[tokio::test]
    async fn test_collect_with_progress_cancel_cleans_up() {
        use crate::ptp::{
            pack_u16, pack_u32, ContainerType, ObjectHandle, OperationCode, PtpSession,
            ResponseCode,
        };
        use crate::transport::mock::MockTransport;
        use std::sync::Arc;

        // Appends the 12-byte PTP container header: length, type, code, transaction id.
        fn push_header(buf: &mut Vec<u8>, len: u32, kind: ContainerType, code: u16, tx_id: u32) {
            buf.extend_from_slice(&pack_u32(len));
            buf.extend_from_slice(&pack_u16(kind.to_code()));
            buf.extend_from_slice(&pack_u16(code));
            buf.extend_from_slice(&pack_u32(tx_id));
        }

        // A response container is a bare header with no payload.
        fn response(tx_id: u32, code: ResponseCode) -> Vec<u8> {
            let mut buf = Vec::with_capacity(12);
            push_header(&mut buf, 12, ContainerType::Response, code.into(), tx_id);
            buf
        }

        // A data container is a header followed by the payload bytes.
        fn data(tx_id: u32, code: OperationCode, payload: &[u8]) -> Vec<u8> {
            let len = 12 + payload.len();
            let mut buf = Vec::with_capacity(len);
            push_header(&mut buf, len as u32, ContainerType::Data, code.into(), tx_id);
            buf.extend_from_slice(payload);
            buf
        }

        let mock = Arc::new(MockTransport::new());
        let transport: Arc<dyn crate::transport::Transport> = Arc::clone(&mock) as _;

        // Transaction 0: OpenSession succeeds.
        mock.queue_response(response(0, ResponseCode::Ok));

        // Transaction 1: GetObject delivers a single 1000-byte data container.
        let file_data = vec![1u8; 1000];
        let file_size = file_data.len() as u64;
        mock.queue_response(data(1, OperationCode::GetObject, &file_data));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());
        let stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();

        // Request a stop as soon as the first chunk is reported.
        let result = FileDownload::new(file_size, stream)
            .collect_with_progress(|_| ControlFlow::Break(()))
            .await;
        assert!(matches!(result, Err(Error::Cancelled)));

        // The cancel request must target the GetObject transaction (id 1).
        assert_eq!(mock.get_cancel_calls(), vec![1]);
    }
}