mtp_rs/mtp/stream.rs
1//! Streaming download/upload support.
2
3use crate::ptp::ReceiveStream;
4use crate::Error;
5use bytes::Bytes;
6use std::ops::ControlFlow;
7
/// Progress information for transfers.
///
/// A plain snapshot of two counters. It is `Copy` and comparable with `==`,
/// so it can be passed around by value and checked in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Progress {
    /// Bytes transferred so far.
    pub bytes_transferred: u64,
    /// Total bytes, if known (`None` when the total size is unavailable).
    pub total_bytes: Option<u64>,
}
16
17impl Progress {
18 /// Progress as a percentage (0.0 to 100.0).
19 #[must_use]
20 pub fn percent(&self) -> f64 {
21 self.fraction() * 100.0
22 }
23
24 /// Progress as a fraction (0.0 to 1.0).
25 #[must_use]
26 pub fn fraction(&self) -> f64 {
27 self.total_bytes.map_or(1.0, |total| {
28 if total == 0 {
29 1.0
30 } else {
31 self.bytes_transferred as f64 / total as f64
32 }
33 })
34 }
35}
36
/// Idle window used when draining pipes during a cancel.
///
/// After the cancel control request is sent, each pipe is read until no new
/// data has arrived for this long, and is then treated as clear. The 300 ms
/// figure is the one libmtp uses, which in turn mirrors Windows behavior.
pub const DEFAULT_CANCEL_TIMEOUT: std::time::Duration = std::time::Duration::from_micros(300_000);
43
44/// A file download in progress with true USB streaming.
45///
46/// This struct wraps the low-level `ReceiveStream` and provides convenient
47/// methods for tracking progress. Data is streamed directly from USB as
48/// chunks arrive, without buffering the entire file in memory.
49///
50/// # Important
51///
52/// The MTP session is locked while this download is active. You must either
53/// consume the entire download or call [`cancel()`](Self::cancel) before
54/// dropping it. Dropping mid-download without cancelling corrupts the USB
55/// session.
56///
57/// # Example
58///
59/// ```rust,no_run
60/// use mtp_rs::mtp::MtpDevice;
61/// use mtp_rs::ObjectHandle;
62/// use tokio::io::AsyncWriteExt;
63///
64/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
65/// # let device = MtpDevice::open_first().await?;
66/// # let storages = device.storages().await?;
67/// # let storage = &storages[0];
68/// # let handle = ObjectHandle(1);
69/// let mut download = storage.download_stream(handle).await?;
70/// println!("Downloading {} bytes...", download.size());
71///
72/// # let mut file = tokio::fs::File::create("output.bin").await?;
73/// while let Some(chunk) = download.next_chunk().await {
74/// let bytes = chunk?;
75/// file.write_all(&bytes).await?;
76/// println!("Progress: {:.1}%", download.progress() * 100.0);
77/// }
78/// # Ok(())
79/// # }
80/// ```
81#[must_use = "dropping a FileDownload mid-transfer corrupts the USB session; \
82 consume it fully or call cancel()"]
83pub struct FileDownload {
84 size: u64,
85 bytes_received: u64,
86 stream: ReceiveStream,
87}
88
89impl FileDownload {
90 /// Create a new FileDownload wrapping a ReceiveStream.
91 pub(crate) fn new(size: u64, stream: ReceiveStream) -> Self {
92 Self {
93 size,
94 bytes_received: 0,
95 stream,
96 }
97 }
98
99 /// Total file size in bytes.
100 #[must_use]
101 pub fn size(&self) -> u64 {
102 self.size
103 }
104
105 /// Bytes received so far.
106 #[must_use]
107 pub fn bytes_received(&self) -> u64 {
108 self.bytes_received
109 }
110
111 /// Progress as a fraction (0.0 to 1.0).
112 #[must_use]
113 pub fn progress(&self) -> f64 {
114 if self.size == 0 {
115 1.0
116 } else {
117 self.bytes_received as f64 / self.size as f64
118 }
119 }
120
121 /// Cancel the in-progress download.
122 ///
123 /// Uses the USB Still Image Class cancel mechanism to stop the transfer
124 /// and drain remaining data, leaving the session clean for the next
125 /// operation.
126 ///
127 /// The `idle_timeout` controls how long to wait during pipe drain before
128 /// assuming the pipe is clear. 1–2 seconds is typically sufficient.
129 ///
130 /// If the download is already complete, this is a no-op.
131 pub async fn cancel(&mut self, idle_timeout: std::time::Duration) -> Result<(), Error> {
132 self.stream.cancel(idle_timeout).await
133 }
134
135 /// Get the next chunk of data from USB.
136 ///
137 /// Returns `None` when the download is complete.
138 pub async fn next_chunk(&mut self) -> Option<Result<Bytes, Error>> {
139 match self.stream.next_chunk().await {
140 Some(Ok(bytes)) => {
141 self.bytes_received += bytes.len() as u64;
142 Some(Ok(bytes))
143 }
144 Some(Err(e)) => Some(Err(e)),
145 None => None,
146 }
147 }
148
149 /// Consume the download and iterate with a progress callback.
150 ///
151 /// Calls `on_progress` after each chunk. Return `ControlFlow::Break(())`
152 /// to cancel the download.
153 ///
154 /// # Example
155 ///
156 /// ```rust,no_run
157 /// use mtp_rs::mtp::MtpDevice;
158 /// use mtp_rs::ObjectHandle;
159 /// use std::ops::ControlFlow;
160 ///
161 /// # async fn example() -> Result<(), mtp_rs::Error> {
162 /// # let device = MtpDevice::open_first().await?;
163 /// # let storages = device.storages().await?;
164 /// # let storage = &storages[0];
165 /// # let handle = ObjectHandle(1);
166 /// let download = storage.download_stream(handle).await?;
167 /// let data = download.collect_with_progress(|progress| {
168 /// println!("{:.1}%", progress.percent());
169 /// ControlFlow::Continue(())
170 /// }).await?;
171 /// # Ok(())
172 /// # }
173 /// ```
174 pub async fn collect_with_progress<F>(mut self, mut on_progress: F) -> Result<Vec<u8>, Error>
175 where
176 F: FnMut(Progress) -> ControlFlow<()>,
177 {
178 let mut data = Vec::with_capacity(self.size as usize);
179
180 while let Some(result) = self.next_chunk().await {
181 let chunk = result?;
182 data.extend_from_slice(&chunk);
183
184 let progress = Progress {
185 bytes_transferred: self.bytes_received,
186 total_bytes: Some(self.size),
187 };
188
189 if let ControlFlow::Break(()) = on_progress(progress) {
190 self.stream.cancel(DEFAULT_CANCEL_TIMEOUT).await?;
191 return Err(Error::Cancelled);
192 }
193 }
194
195 Ok(data)
196 }
197
198 /// Collect all remaining data into a `Vec<u8>`.
199 ///
200 /// This consumes the download and buffers all data in memory.
201 pub async fn collect(self) -> Result<Vec<u8>, Error> {
202 self.stream.collect().await
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use std::ops::ControlFlow;

    /// Table-driven check of `Progress::percent` and `Progress::fraction`.
    /// Covers the zero-total and unknown-total special cases and a
    /// very-large-number case.
    #[test]
    fn progress_calculations() {
        // Each case is (bytes_transferred, total_bytes, expected percent,
        // expected fraction). Exact float equality is used. That is safe
        // because every expected value here is exactly representable and the
        // ratios are of small integers.
        let cases = [
            (50, Some(100), 50.0, 0.5),
            (100, Some(100), 100.0, 1.0),
            (25, Some(100), 25.0, 0.25),
            (0, Some(0), 100.0, 1.0), // Empty file
            (50, None, 100.0, 1.0),   // Unknown total defaults to complete
        ];
        for (transferred, total, expected_pct, expected_frac) in cases {
            let p = Progress {
                bytes_transferred: transferred,
                total_bytes: total,
            };
            assert_eq!(
                p.percent(),
                expected_pct,
                "percent failed for {transferred}/{total:?}"
            );
            assert_eq!(
                p.fraction(),
                expected_frac,
                "fraction failed for {transferred}/{total:?}"
            );
        }

        // Large numbers: the u64 -> f64 conversion loses precision, so only
        // check that the ratio lands near one half rather than comparing exactly.
        let large = Progress {
            bytes_transferred: u64::MAX / 2,
            total_bytes: Some(u64::MAX),
        };
        let frac = large.fraction();
        assert!(frac > 0.49 && frac < 0.51);
    }

    /// Breaking out of `collect_with_progress` must surface `Error::Cancelled`.
    /// It must also issue a transfer cancel for the in-flight GetObject
    /// transaction.
    #[tokio::test]
    async fn test_collect_with_progress_cancel_cleans_up() {
        use crate::ptp::{
            pack_u16, pack_u32, ContainerType, ObjectHandle, OperationCode, PtpSession,
            ResponseCode,
        };
        use crate::transport::mock::MockTransport;
        use std::sync::Arc;

        // Helper to build a response container.
        // Layout: length (u32) | container type (u16) | code (u16) | transaction
        // ID (u32). That is a 12-byte header with no payload.
        fn response(tx_id: u32, code: ResponseCode) -> Vec<u8> {
            let mut buf = Vec::with_capacity(12);
            buf.extend_from_slice(&pack_u32(12));
            buf.extend_from_slice(&pack_u16(ContainerType::Response.to_code()));
            buf.extend_from_slice(&pack_u16(code.into()));
            buf.extend_from_slice(&pack_u32(tx_id));
            buf
        }

        // Helper to build a data container.
        // Uses the same 12-byte header as `response`, followed by `payload`.
        // The length field counts header plus payload.
        fn data(tx_id: u32, code: OperationCode, payload: &[u8]) -> Vec<u8> {
            let len = 12 + payload.len();
            let mut buf = Vec::with_capacity(len);
            buf.extend_from_slice(&pack_u32(len as u32));
            buf.extend_from_slice(&pack_u16(ContainerType::Data.to_code()));
            buf.extend_from_slice(&pack_u16(code.into()));
            buf.extend_from_slice(&pack_u32(tx_id));
            buf.extend_from_slice(payload);
            buf
        }

        // Keep a concrete handle to the mock (for queueing and inspection) and
        // pass a trait-object clone of it to the session.
        let mock = Arc::new(MockTransport::new());
        let transport: Arc<dyn crate::transport::Transport> = Arc::clone(&mock) as _;
        // NOTE(review): assumes the mock replays queued buffers in FIFO order,
        // so this must be queued before `PtpSession::open` runs. Confirm
        // against MockTransport.
        mock.queue_response(response(0, ResponseCode::Ok)); // OpenSession

        // The whole 1000-byte file is queued as a single data container for
        // GetObject, which runs as transaction 1 (OpenSession used 0).
        let file_data = vec![1u8; 1000];
        let file_size = file_data.len() as u64;
        mock.queue_response(data(1, OperationCode::GetObject, &file_data));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());
        let stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();
        let download = FileDownload::new(file_size, stream);

        // Break after first chunk
        let result = download
            .collect_with_progress(|_progress| ControlFlow::Break(()))
            .await;

        assert!(matches!(result, Err(Error::Cancelled)));

        // Verify cancel_transfer was called with the correct transaction ID
        // (1, the GetObject transaction), and only once.
        let cancel_calls = mock.get_cancel_calls();
        assert_eq!(cancel_calls, vec![1]);
    }
}