1use crate::ptp::{
8 container_type, pack_u16, pack_u32, unpack_u32, CommandContainer, ContainerType, ObjectHandle,
9 OperationCode, ResponseCode, ResponseContainer,
10};
11use crate::transport::Transport;
12use crate::Error;
13use bytes::Bytes;
14use futures::lock::OwnedMutexGuard;
15use futures::Stream;
16use std::sync::Arc;
17
18use super::{PtpSession, HEADER_SIZE};
19
20impl PtpSession {
21 pub async fn execute_with_receive_stream(
44 self: &Arc<Self>,
45 operation: OperationCode,
46 params: &[u32],
47 ) -> Result<ReceiveStream, Error> {
48 let lock = Arc::clone(&self.operation_lock);
50 let guard = lock.lock_owned().await;
51
52 let tx_id = self.next_transaction_id();
53
54 let cmd = CommandContainer {
56 code: operation,
57 transaction_id: tx_id,
58 params: params.to_vec(),
59 };
60 self.transport.send_bulk(&cmd.to_bytes()).await?;
61
62 Ok(ReceiveStream {
63 transport: Arc::clone(&self.transport),
64 _guard: guard,
65 transaction_id: tx_id,
66 operation,
67 buffer: Vec::new(),
68 container_length: 0,
69 payload_yielded: 0,
70 header_parsed: false,
71 done: false,
72 })
73 }
74
75 pub async fn execute_with_send_stream<S>(
92 &self,
93 operation: OperationCode,
94 params: &[u32],
95 total_size: u64,
96 mut data: S,
97 ) -> Result<ResponseContainer, Error>
98 where
99 S: Stream<Item = Result<Bytes, std::io::Error>> + Unpin,
100 {
101 use futures::StreamExt;
102
103 let _guard = self.operation_lock.lock().await;
104 let tx_id = self.next_transaction_id();
105
106 let cmd = CommandContainer {
108 code: operation,
109 transaction_id: tx_id,
110 params: params.to_vec(),
111 };
112 self.transport.send_bulk(&cmd.to_bytes()).await?;
113
114 let container_length = HEADER_SIZE as u64 + total_size;
117 let mut buffer = Vec::with_capacity(container_length as usize);
118
119 if container_length <= u32::MAX as u64 {
121 buffer.extend_from_slice(&pack_u32(container_length as u32));
122 } else {
123 buffer.extend_from_slice(&pack_u32(0xFFFFFFFF));
124 }
125 buffer.extend_from_slice(&pack_u16(ContainerType::Data.to_code()));
126 buffer.extend_from_slice(&pack_u16(operation.into()));
127 buffer.extend_from_slice(&pack_u32(tx_id));
128
129 while let Some(chunk_result) = data.next().await {
131 let chunk = chunk_result.map_err(Error::Io)?;
132 buffer.extend_from_slice(&chunk);
133 }
134
135 self.transport.send_bulk(&buffer).await?;
137
138 let response_bytes = self.transport.receive_bulk(512).await?;
140 let response = ResponseContainer::from_bytes(&response_bytes)?;
141
142 if response.transaction_id != tx_id {
143 return Err(Error::invalid_data(format!(
144 "Transaction ID mismatch: expected {}, got {}",
145 tx_id, response.transaction_id
146 )));
147 }
148
149 Ok(response)
150 }
151
152 pub async fn get_object_stream(
162 self: &Arc<Self>,
163 handle: ObjectHandle,
164 ) -> Result<ReceiveStream, Error> {
165 self.execute_with_receive_stream(OperationCode::GetObject, &[handle.0])
166 .await
167 }
168
169 pub async fn send_object_stream<S>(&self, total_size: u64, data: S) -> Result<(), Error>
178 where
179 S: Stream<Item = Result<Bytes, std::io::Error>> + Unpin,
180 {
181 let response = self
182 .execute_with_send_stream(OperationCode::SendObject, &[], total_size, data)
183 .await?;
184 Self::check_response(&response, OperationCode::SendObject)?;
185 Ok(())
186 }
187}
188
/// Streaming receiver for the data phase of a single PTP transaction.
///
/// Created by `PtpSession::execute_with_receive_stream`; payload is pulled with
/// `ReceiveStream::next_chunk`. The stream holds the session's operation lock for as long as
/// it is alive.
pub struct ReceiveStream {
    /// Transport used to read bulk transfers from the device.
    transport: Arc<dyn Transport>,
    /// Owned guard on the session's operation lock; the lock is released when the stream is
    /// dropped.
    _guard: OwnedMutexGuard<()>,
    /// Transaction ID of the operation; the response container's ID is checked against it.
    transaction_id: u32,
    /// Operation being executed, reported in protocol errors.
    operation: OperationCode,
    /// Raw bytes read from the transport that have not been discarded yet.
    buffer: Vec<u8>,
    /// Declared length (header included) of the data container currently being read.
    container_length: usize,
    /// Number of payload bytes of the current data container already handed to the caller.
    payload_yielded: usize,
    /// Whether a data container header has been parsed and its payload is being streamed.
    header_parsed: bool,
    /// Set once the transaction has finished or failed; no further chunks are produced.
    done: bool,
}
218
219impl ReceiveStream {
220 #[must_use]
222 pub fn transaction_id(&self) -> u32 {
223 self.transaction_id
224 }
225
226 pub async fn next_chunk(&mut self) -> Option<Result<Bytes, Error>> {
230 if self.done {
231 return None;
232 }
233
234 loop {
235 if self.header_parsed {
237 let payload_start = HEADER_SIZE + self.payload_yielded;
238 let payload_end = std::cmp::min(self.buffer.len(), self.container_length);
239
240 if payload_start < payload_end {
241 let chunk_data = self.buffer[payload_start..payload_end].to_vec();
243 self.payload_yielded += chunk_data.len();
244
245 if self.buffer.len() >= self.container_length {
247 self.buffer.drain(..self.container_length);
249 self.header_parsed = false;
250 self.container_length = 0;
251 self.payload_yielded = 0;
252 }
253
254 if !chunk_data.is_empty() {
255 return Some(Ok(Bytes::from(chunk_data)));
256 }
257 } else if self.buffer.len() >= self.container_length {
258 self.buffer.drain(..self.container_length);
260 self.header_parsed = false;
261 self.container_length = 0;
262 self.payload_yielded = 0;
263 }
264 }
265
266 match self.transport.receive_bulk(64 * 1024).await {
268 Ok(bytes) => {
269 if bytes.is_empty() {
270 return Some(Err(Error::invalid_data("Empty response from device")));
271 }
272 self.buffer.extend_from_slice(&bytes);
273 }
274 Err(e) => {
275 self.done = true;
276 return Some(Err(e));
277 }
278 }
279
280 if !self.header_parsed && self.buffer.len() >= HEADER_SIZE {
282 let ct = match container_type(&self.buffer) {
283 Ok(ct) => ct,
284 Err(e) => {
285 self.done = true;
286 return Some(Err(e));
287 }
288 };
289
290 match ct {
291 ContainerType::Data => {
292 let length = match unpack_u32(&self.buffer[0..4]) {
293 Ok(l) => l as usize,
294 Err(e) => {
295 self.done = true;
296 return Some(Err(e));
297 }
298 };
299 self.container_length = length;
300 self.header_parsed = true;
301 }
302 ContainerType::Response => {
303 let response = match ResponseContainer::from_bytes(&self.buffer) {
305 Ok(r) => r,
306 Err(e) => {
307 self.done = true;
308 return Some(Err(e));
309 }
310 };
311
312 self.done = true;
313
314 if response.transaction_id != self.transaction_id {
316 return Some(Err(Error::invalid_data(format!(
317 "Transaction ID mismatch: expected {}, got {}",
318 self.transaction_id, response.transaction_id
319 ))));
320 }
321
322 if response.code != ResponseCode::Ok {
324 return Some(Err(Error::Protocol {
325 code: response.code,
326 operation: self.operation,
327 }));
328 }
329
330 return None;
331 }
332 _ => {
333 self.done = true;
334 return Some(Err(Error::invalid_data(format!(
335 "Unexpected container type: {:?}",
336 ct
337 ))));
338 }
339 }
340 }
341 }
342 }
343
344 pub async fn collect(mut self) -> Result<Vec<u8>, Error> {
348 let mut data = Vec::new();
349 while let Some(result) = self.next_chunk().await {
350 let chunk = result?;
351 data.extend_from_slice(&chunk);
352 }
353 Ok(data)
354 }
355}
356
357pub fn receive_stream_to_stream(recv: ReceiveStream) -> impl Stream<Item = Result<Bytes, Error>> {
361 futures::stream::unfold(recv, |mut recv| async move {
362 recv.next_chunk().await.map(|result| (result, recv))
363 })
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptp::session::tests::{
        data_container, mock_transport, ok_response, response_with_params,
    };
    use crate::ptp::ResponseCode;

    #[tokio::test]
    async fn test_receive_stream_small_file() {
        let (transport, mock) = mock_transport();
        // Consumed by `PtpSession::open`.
        mock.queue_response(ok_response(1));
        let file_data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        mock.queue_response(data_container(2, OperationCode::GetObject, &file_data));
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());
        let mut stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();

        let mut received = Vec::new();
        while let Some(result) = stream.next_chunk().await {
            received.extend_from_slice(&result.unwrap());
        }

        assert_eq!(received, file_data);
    }

    #[tokio::test]
    async fn test_receive_stream_collect() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        let file_data = vec![1, 2, 3, 4, 5];
        mock.queue_response(data_container(2, OperationCode::GetObject, &file_data));
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());
        let stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();
        let collected = stream.collect().await.unwrap();

        assert_eq!(collected, file_data);
    }

    #[tokio::test]
    async fn test_receive_stream_data_split_across_transfers() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        let file_data: Vec<u8> = (0..100).collect();
        let mut first = data_container(2, OperationCode::GetObject, &file_data);
        // Split inside the payload so the data container arrives in two bulk reads.
        let second = first.split_off(40);
        mock.queue_response(first);
        mock.queue_response(second);
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());
        let stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();
        let collected = stream.collect().await.unwrap();

        assert_eq!(collected, file_data);
    }

    #[tokio::test]
    async fn test_receive_stream_empty_object() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        let file_data: Vec<u8> = Vec::new();
        mock.queue_response(data_container(2, OperationCode::GetObject, &file_data));
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());
        let stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();
        let collected = stream.collect().await.unwrap();

        assert!(collected.is_empty());
    }

    #[tokio::test]
    async fn test_receive_stream_error_response() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(response_with_params(
            2,
            ResponseCode::InvalidObjectHandle,
            &[],
        ));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());
        let mut stream = session.get_object_stream(ObjectHandle(999)).await.unwrap();

        let result = stream.next_chunk().await;
        assert!(matches!(result, Some(Err(Error::Protocol { .. }))));
        // The stream is finished once the error has been reported.
        assert!(stream.next_chunk().await.is_none());
    }

    #[tokio::test]
    async fn test_send_stream_small_file() {
        use futures::stream;

        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(ok_response(2));
        let session = PtpSession::open(transport, 1).await.unwrap();

        let data = vec![1u8, 2, 3, 4, 5];
        let data_stream = stream::iter(vec![Ok::<_, std::io::Error>(Bytes::from(data))]);

        session.send_object_stream(5, data_stream).await.unwrap();
    }

    #[tokio::test]
    async fn test_send_stream_multiple_chunks() {
        use futures::stream;

        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(ok_response(2));
        let session = PtpSession::open(transport, 1).await.unwrap();

        let chunks = vec![
            Ok::<_, std::io::Error>(Bytes::from(vec![1, 2, 3])),
            Ok(Bytes::from(vec![4, 5, 6])),
            Ok(Bytes::from(vec![7, 8, 9, 10])),
        ];
        let data_stream = stream::iter(chunks);

        session.send_object_stream(10, data_stream).await.unwrap();
    }

    #[tokio::test]
    async fn test_receive_stream_to_stream_conversion() {
        use futures::StreamExt;
        use std::pin::pin;

        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        let file_data = vec![1, 2, 3, 4, 5];
        mock.queue_response(data_container(2, OperationCode::GetObject, &file_data));
        mock.queue_response(ok_response(2));

        let session = Arc::new(PtpSession::open(transport, 1).await.unwrap());
        let recv_stream = session.get_object_stream(ObjectHandle(1)).await.unwrap();
        let mut stream = pin!(receive_stream_to_stream(recv_stream));

        let mut collected = Vec::new();
        while let Some(result) = stream.next().await {
            collected.extend_from_slice(&result.unwrap());
        }

        assert_eq!(collected, file_data);
    }
}