Skip to main content

mtp_rs/ptp/session/
streaming.rs

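//! Streaming transfer operations.
//!
//! This module contains the `ReceiveStream` struct and the `PtpSession` methods for
//! streaming data transfers. Downloads hand data to the caller chunk by chunk as it
//! arrives from the device. Uploads accept a stream of chunks, but the data phase is
//! assembled into a single buffer before it is sent, because MTP devices expect the
//! whole data container in one USB transfer.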
6
7use crate::ptp::{
8    container_type, pack_u16, pack_u32, unpack_u32, CommandContainer, ContainerType, ObjectHandle,
9    OperationCode, ResponseCode, ResponseContainer,
10};
11use crate::transport::Transport;
12use crate::Error;
13use bytes::Bytes;
14use futures::lock::OwnedMutexGuard;
15use futures::Stream;
16use std::sync::Arc;
17
18use super::{PtpSession, HEADER_SIZE};
19
20impl PtpSession {
21    // =========================================================================
22    // Streaming operations
23    // =========================================================================
24
25    /// Execute operation with streaming data receive.
26    ///
27    /// Returns a Stream that yields data chunks as they arrive from USB.
28    /// The stream yields `Bytes` chunks (typically up to 64KB each).
29    ///
30    /// # Important
31    ///
32    /// The caller must consume the entire stream before calling any other
33    /// session methods. The MTP session is locked while the stream is active.
34    ///
35    /// # Arguments
36    ///
37    /// * `operation` - The operation code to execute
38    /// * `params` - Operation parameters
39    ///
40    /// # Returns
41    ///
42    /// A `ReceiveStream` that yields `Result<Bytes, Error>` chunks.
43    pub async fn execute_with_receive_stream(
44        self: &Arc<Self>,
45        operation: OperationCode,
46        params: &[u32],
47    ) -> Result<ReceiveStream, Error> {
48        // Clone the Arc for the lock
49        let lock = Arc::clone(&self.operation_lock);
50        let guard = lock.lock_owned().await;
51
52        let tx_id = self.next_transaction_id();
53
54        // Send command
55        let cmd = CommandContainer {
56            code: operation,
57            transaction_id: tx_id,
58            params: params.to_vec(),
59        };
60        self.transport.send_bulk(&cmd.to_bytes()).await?;
61
62        Ok(ReceiveStream {
63            transport: Arc::clone(&self.transport),
64            _guard: guard,
65            transaction_id: tx_id,
66            operation,
67            buffer: Vec::new(),
68            container_length: 0,
69            payload_yielded: 0,
70            header_parsed: false,
71            done: false,
72        })
73    }
74
75    /// Execute operation with streaming data send.
76    ///
77    /// Accepts a Stream of data chunks to send. The total_size must be
78    /// known upfront (MTP protocol requirement).
79    ///
80    /// # Arguments
81    ///
82    /// * `operation` - The operation code
83    /// * `params` - Operation parameters
84    /// * `total_size` - Total bytes that will be sent (REQUIRED by MTP protocol)
85    /// * `data` - Stream of data chunks to send
86    ///
87    /// # Important
88    ///
89    /// The `total_size` must match the actual total bytes in the stream.
90    /// MTP requires knowing the size before transfer begins.
91    pub async fn execute_with_send_stream<S>(
92        &self,
93        operation: OperationCode,
94        params: &[u32],
95        total_size: u64,
96        mut data: S,
97    ) -> Result<ResponseContainer, Error>
98    where
99        S: Stream<Item = Result<Bytes, std::io::Error>> + Unpin,
100    {
101        use futures::StreamExt;
102
103        let _guard = self.operation_lock.lock().await;
104        let tx_id = self.next_transaction_id();
105
106        // Send command
107        let cmd = CommandContainer {
108            code: operation,
109            transaction_id: tx_id,
110            params: params.to_vec(),
111        };
112        self.transport.send_bulk(&cmd.to_bytes()).await?;
113
114        // Build complete data container (header + all payload)
115        // MTP devices expect the entire data container in a single USB transfer
116        let container_length = HEADER_SIZE as u64 + total_size;
117        let mut buffer = Vec::with_capacity(container_length as usize);
118
119        // Add header
120        if container_length <= u32::MAX as u64 {
121            buffer.extend_from_slice(&pack_u32(container_length as u32));
122        } else {
123            buffer.extend_from_slice(&pack_u32(0xFFFFFFFF));
124        }
125        buffer.extend_from_slice(&pack_u16(ContainerType::Data.to_code()));
126        buffer.extend_from_slice(&pack_u16(operation.into()));
127        buffer.extend_from_slice(&pack_u32(tx_id));
128
129        // Collect all chunks into buffer
130        while let Some(chunk_result) = data.next().await {
131            let chunk = chunk_result.map_err(Error::Io)?;
132            buffer.extend_from_slice(&chunk);
133        }
134
135        // Send entire data container as one USB transfer
136        self.transport.send_bulk(&buffer).await?;
137
138        // Receive response
139        let response_bytes = self.transport.receive_bulk(512).await?;
140        let response = ResponseContainer::from_bytes(&response_bytes)?;
141
142        if response.transaction_id != tx_id {
143            return Err(Error::invalid_data(format!(
144                "Transaction ID mismatch: expected {}, got {}",
145                tx_id, response.transaction_id
146            )));
147        }
148
149        Ok(response)
150    }
151
152    /// Download an object as a stream of chunks.
153    ///
154    /// This is a convenience method that calls `execute_with_receive_stream`
155    /// with GetObject operation.
156    ///
157    /// # Important
158    ///
159    /// The caller must consume the entire stream before calling any other
160    /// session methods. The MTP session is locked while the stream is active.
161    pub async fn get_object_stream(
162        self: &Arc<Self>,
163        handle: ObjectHandle,
164    ) -> Result<ReceiveStream, Error> {
165        self.execute_with_receive_stream(OperationCode::GetObject, &[handle.0])
166            .await
167    }
168
169    /// Upload an object from a stream.
170    ///
171    /// This is a convenience method that streams object data directly to USB.
172    ///
173    /// # Arguments
174    ///
175    /// * `total_size` - Total bytes that will be sent
176    /// * `data` - Stream of data chunks to send
177    pub async fn send_object_stream<S>(&self, total_size: u64, data: S) -> Result<(), Error>
178    where
179        S: Stream<Item = Result<Bytes, std::io::Error>> + Unpin,
180    {
181        let response = self
182            .execute_with_send_stream(OperationCode::SendObject, &[], total_size, data)
183            .await?;
184        Self::check_response(&response, OperationCode::SendObject)?;
185        Ok(())
186    }
187}
188
/// A stream of data chunks received from USB during a download operation.
///
/// Created by [`PtpSession::execute_with_receive_stream`] (and
/// [`PtpSession::get_object_stream`]). You can:
///
/// - pull chunks one at a time with [`ReceiveStream::next_chunk`];
/// - gather everything with [`ReceiveStream::collect`];
/// - adapt it into a `futures::Stream` with [`receive_stream_to_stream`].
///
/// Chunks are handed to the caller as data arrives from the device, rather
/// than after the whole transfer has completed.
///
/// # Important
///
/// The session's operation lock is held (via `_guard`) for as long as this
/// value exists. You must consume the entire stream (or drop it) before
/// calling other session methods.
///
/// NOTE(review): no `Drop` impl is visible in this file. Dropping the stream
/// before the response has been read releases the lock without draining the
/// rest of the data phase or the response container. Confirm whether the next
/// operation can cope with those leftover bytes.
pub struct ReceiveStream {
    // Field order matters: fields drop in declaration order, so `transport`
    // is released before the lock guard. Keep this in mind before reordering.
    /// Transport the data phase and response are read from (shared with the session).
    transport: Arc<dyn Transport>,
    /// Owned guard on the session's operation lock; the lock is released when this is dropped.
    _guard: OwnedMutexGuard<()>,
    /// Transaction ID of the command; the response container must echo it.
    transaction_id: u32,
    /// Operation code, reported in `Error::Protocol` when the response is not `Ok`.
    operation: OperationCode,
    /// Bytes read from the transport for the current container, possibly followed by the start of the next one.
    buffer: Vec<u8>,
    /// Length of the current data container in bytes, header included, as declared in its header.
    container_length: usize,
    /// Payload bytes of the current container already returned to the caller.
    payload_yielded: usize,
    /// Whether the current data container's header has been parsed.
    header_parsed: bool,
    /// Set once the response has been handled or a fatal error reported; `next_chunk` then returns `None`.
    done: bool,
}
218
219impl ReceiveStream {
220    /// Get the transaction ID for this operation.
221    #[must_use]
222    pub fn transaction_id(&self) -> u32 {
223        self.transaction_id
224    }
225
226    /// Poll for the next chunk of data.
227    ///
228    /// This is the async version of the Stream trait's poll_next.
229    pub async fn next_chunk(&mut self) -> Option<Result<Bytes, Error>> {
230        if self.done {
231            return None;
232        }
233
234        loop {
235            // If we have buffered data beyond what we've already yielded, yield it
236            if self.header_parsed {
237                let payload_start = HEADER_SIZE + self.payload_yielded;
238                let payload_end = std::cmp::min(self.buffer.len(), self.container_length);
239
240                if payload_start < payload_end {
241                    // We have new data to yield
242                    let chunk_data = self.buffer[payload_start..payload_end].to_vec();
243                    self.payload_yielded += chunk_data.len();
244
245                    // Check if this container is complete
246                    if self.buffer.len() >= self.container_length {
247                        // Remove this container from buffer
248                        self.buffer.drain(..self.container_length);
249                        self.header_parsed = false;
250                        self.container_length = 0;
251                        self.payload_yielded = 0;
252                    }
253
254                    if !chunk_data.is_empty() {
255                        return Some(Ok(Bytes::from(chunk_data)));
256                    }
257                } else if self.buffer.len() >= self.container_length {
258                    // Container complete but no new data (shouldn't happen, but handle it)
259                    self.buffer.drain(..self.container_length);
260                    self.header_parsed = false;
261                    self.container_length = 0;
262                    self.payload_yielded = 0;
263                }
264            }
265
266            // Need more data from USB
267            match self.transport.receive_bulk(64 * 1024).await {
268                Ok(bytes) => {
269                    if bytes.is_empty() {
270                        return Some(Err(Error::invalid_data("Empty response from device")));
271                    }
272                    self.buffer.extend_from_slice(&bytes);
273                }
274                Err(e) => {
275                    self.done = true;
276                    return Some(Err(e));
277                }
278            }
279
280            // Try to parse container header if we haven't yet
281            if !self.header_parsed && self.buffer.len() >= HEADER_SIZE {
282                let ct = match container_type(&self.buffer) {
283                    Ok(ct) => ct,
284                    Err(e) => {
285                        self.done = true;
286                        return Some(Err(e));
287                    }
288                };
289
290                match ct {
291                    ContainerType::Data => {
292                        let length = match unpack_u32(&self.buffer[0..4]) {
293                            Ok(l) => l as usize,
294                            Err(e) => {
295                                self.done = true;
296                                return Some(Err(e));
297                            }
298                        };
299                        self.container_length = length;
300                        self.header_parsed = true;
301                    }
302                    ContainerType::Response => {
303                        // End of data transfer
304                        let response = match ResponseContainer::from_bytes(&self.buffer) {
305                            Ok(r) => r,
306                            Err(e) => {
307                                self.done = true;
308                                return Some(Err(e));
309                            }
310                        };
311
312                        self.done = true;
313
314                        // Check transaction ID
315                        if response.transaction_id != self.transaction_id {
316                            return Some(Err(Error::invalid_data(format!(
317                                "Transaction ID mismatch: expected {}, got {}",
318                                self.transaction_id, response.transaction_id
319                            ))));
320                        }
321
322                        // Check response code
323                        if response.code != ResponseCode::Ok {
324                            return Some(Err(Error::Protocol {
325                                code: response.code,
326                                operation: self.operation,
327                            }));
328                        }
329
330                        return None;
331                    }
332                    _ => {
333                        self.done = true;
334                        return Some(Err(Error::invalid_data(format!(
335                            "Unexpected container type: {:?}",
336                            ct
337                        ))));
338                    }
339                }
340            }
341        }
342    }
343
344    /// Collect all remaining data into a `Vec<u8>`.
345    ///
346    /// This consumes the stream and buffers all data in memory.
347    pub async fn collect(mut self) -> Result<Vec<u8>, Error> {
348        let mut data = Vec::new();
349        while let Some(result) = self.next_chunk().await {
350            let chunk = result?;
351            data.extend_from_slice(&chunk);
352        }
353        Ok(data)
354    }
355}
356
357/// Convert a ReceiveStream into a futures::Stream using async iteration.
358///
359/// This creates a proper Stream that can be used with StreamExt methods.
360pub fn receive_stream_to_stream(recv: ReceiveStream) -> impl Stream<Item = Result<Bytes, Error>> {
361    futures::stream::unfold(recv, |mut recv| async move {
362        recv.next_chunk().await.map(|result| (result, recv))
363    })
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptp::session::tests::{
        data_container, mock_transport, ok_response, response_with_params,
    };
    use crate::ptp::ResponseCode;

    #[tokio::test]
    async fn test_receive_stream_small_file() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1)); // OpenSession

        // GetObject data response (small file fits in one container)
        let file_data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        mock.queue_response(data_container(2, OperationCode::GetObject, &file_data));
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());

        // Use streaming API
        let mut stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();

        // Collect all chunks
        let mut received = Vec::new();
        while let Some(result) = stream.next_chunk().await {
            let chunk = result.unwrap();
            received.extend_from_slice(&chunk);
        }

        assert_eq!(received, file_data);
    }

    #[tokio::test]
    async fn test_receive_stream_empty_file() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1)); // OpenSession

        // A zero-length object: a data container with a header and no payload.
        let file_data: Vec<u8> = Vec::new();
        mock.queue_response(data_container(2, OperationCode::GetObject, &file_data));
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());

        let stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();
        let collected = stream.collect().await.unwrap();

        assert!(collected.is_empty());
    }

    #[tokio::test]
    async fn test_receive_stream_collect() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1)); // OpenSession

        let file_data = vec![1, 2, 3, 4, 5];
        mock.queue_response(data_container(2, OperationCode::GetObject, &file_data));
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());

        let stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();
        let collected = stream.collect().await.unwrap();

        assert_eq!(collected, file_data);
    }

    #[tokio::test]
    async fn test_receive_stream_error_response() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1)); // OpenSession

        // Return error response instead of data
        mock.queue_response(response_with_params(
            2,
            ResponseCode::InvalidObjectHandle,
            &[],
        ));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());

        let mut stream = session.get_object_stream(ObjectHandle(999)).await.unwrap();

        // The device's response code must surface as a protocol error, not as an
        // unrelated failure (e.g. a transaction ID mismatch).
        match stream.next_chunk().await {
            Some(Err(Error::Protocol { code, .. })) => {
                assert!(code == ResponseCode::InvalidObjectHandle);
            }
            _ => panic!("expected a protocol error from the stream"),
        }

        // The stream is finished after reporting the error.
        assert!(stream.next_chunk().await.is_none());
    }

    #[tokio::test]
    async fn test_send_stream_small_file() {
        use futures::stream;

        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1)); // OpenSession
        mock.queue_response(ok_response(2)); // SendObject response

        let session = PtpSession::open(transport, 1).await.unwrap();

        // Create a small data stream (use iter instead of once for Unpin)
        let data_stream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(vec![
            1u8, 2, 3, 4, 5,
        ]))]);

        // Send using streaming API
        session.send_object_stream(5, data_stream).await.unwrap();
    }

    #[tokio::test]
    async fn test_send_stream_multiple_chunks() {
        use futures::stream;

        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1)); // OpenSession
        mock.queue_response(ok_response(2)); // SendObject response

        let session = PtpSession::open(transport, 1).await.unwrap();

        // Create a multi-chunk data stream
        let chunks = vec![
            Ok::<_, std::io::Error>(Bytes::from(vec![1, 2, 3])),
            Ok(Bytes::from(vec![4, 5, 6])),
            Ok(Bytes::from(vec![7, 8, 9, 10])),
        ];
        let data_stream = stream::iter(chunks);

        // Send using streaming API (total size = 10)
        session.send_object_stream(10, data_stream).await.unwrap();
    }

    #[tokio::test]
    async fn test_receive_stream_to_stream_conversion() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1)); // OpenSession

        let file_data = vec![1, 2, 3, 4, 5];
        mock.queue_response(data_container(2, OperationCode::GetObject, &file_data));
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());

        let recv_stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();

        // Convert to futures::Stream and use StreamExt. The adapted stream is
        // not Unpin, so pin it on the stack with `pin!` before calling `next()`.
        use futures::StreamExt;
        use std::pin::pin;
        let mut stream = pin!(receive_stream_to_stream(recv_stream));

        let mut collected = Vec::new();
        while let Some(result) = stream.next().await {
            collected.extend_from_slice(&result.unwrap());
        }

        assert_eq!(collected, file_data);
    }
}