1mod operations;
7mod properties;
8mod streaming;
9
10pub use streaming::{receive_stream_to_stream, ReceiveStream};
11
12use crate::ptp::{
13 container_type, unpack_u32, CommandContainer, ContainerType, DataContainer, OperationCode,
14 ResponseCode, ResponseContainer, SessionId, TransactionId,
15};
16use crate::transport::Transport;
17use crate::Error;
18use futures::lock::Mutex;
19use std::sync::atomic::{AtomicU32, Ordering};
20use std::sync::Arc;
21
/// Size in bytes of a PTP container header: total length (u32), container
/// type (u16), operation/response code (u16) and transaction ID (u32).
pub(crate) const HEADER_SIZE: usize = 12;
24
/// A PTP session over a [`Transport`], created with `PtpSession::open`.
///
/// The `execute*` methods of this type hold `operation_lock` for their whole
/// command / data / response exchange, so those exchanges do not interleave.
pub struct PtpSession {
    /// Transport used for every bulk transfer of this session.
    pub(crate) transport: Arc<dyn Transport>,
    /// Session ID that was sent with `OpenSession`.
    session_id: SessionId,
    /// Transaction ID to hand out next; advanced by `next_transaction_id`.
    transaction_id: AtomicU32,
    /// Async mutex serializing operations on this session.
    /// NOTE(review): wrapped in `Arc`, presumably so other components (e.g. the
    /// streaming receive path) can hold it independently of a borrow of the
    /// session — confirm against callers.
    pub(crate) operation_lock: Arc<Mutex<()>>,
}
61
62impl PtpSession {
63 fn new(transport: Arc<dyn Transport>, session_id: SessionId) -> Self {
65 Self {
66 transport,
67 session_id,
68 transaction_id: AtomicU32::new(TransactionId::FIRST.0),
69 operation_lock: Arc::new(Mutex::new(())),
70 }
71 }
72
73 pub async fn open(transport: Arc<dyn Transport>, session_id: u32) -> Result<Self, Error> {
87 let session = Self::new(transport, SessionId(session_id));
88
89 let response = session
91 .execute(OperationCode::OpenSession, &[session_id])
92 .await?;
93
94 if response.code == ResponseCode::Ok {
95 return Ok(session);
96 }
97
98 if response.code == ResponseCode::SessionAlreadyOpen {
99 let _ = session.execute(OperationCode::CloseSession, &[]).await;
102
103 let fresh_session = Self::new(Arc::clone(&session.transport), SessionId(session_id));
105
106 let retry_response = fresh_session
107 .execute(OperationCode::OpenSession, &[session_id])
108 .await?;
109
110 if retry_response.code != ResponseCode::Ok {
111 return Err(Error::Protocol {
112 code: retry_response.code,
113 operation: OperationCode::OpenSession,
114 });
115 }
116
117 return Ok(fresh_session);
118 }
119
120 Err(Error::Protocol {
121 code: response.code,
122 operation: OperationCode::OpenSession,
123 })
124 }
125
126 #[must_use]
128 pub fn session_id(&self) -> SessionId {
129 self.session_id
130 }
131
132 pub async fn close(self) -> Result<(), Error> {
137 let _ = self.execute(OperationCode::CloseSession, &[]).await;
138 Ok(())
139 }
140
141 pub(crate) fn next_transaction_id(&self) -> u32 {
145 loop {
146 let current = self.transaction_id.load(Ordering::SeqCst);
147 let next = TransactionId(current).next().0;
148 if self
149 .transaction_id
150 .compare_exchange(current, next, Ordering::SeqCst, Ordering::SeqCst)
151 .is_ok()
152 {
153 return current;
154 }
155 }
156 }
157
158 pub(crate) async fn execute(
164 &self,
165 operation: OperationCode,
166 params: &[u32],
167 ) -> Result<ResponseContainer, Error> {
168 let _guard = self.operation_lock.lock().await;
169
170 let tx_id = self.next_transaction_id();
171
172 let cmd = CommandContainer {
174 code: operation,
175 transaction_id: tx_id,
176 params: params.to_vec(),
177 };
178 self.transport.send_bulk(&cmd.to_bytes()).await?;
179
180 let response_bytes = self.transport.receive_bulk(512).await?;
182 let response = ResponseContainer::from_bytes(&response_bytes)?;
183
184 if response.transaction_id != tx_id {
186 return Err(Error::invalid_data(format!(
187 "Transaction ID mismatch: expected {}, got {}",
188 tx_id, response.transaction_id
189 )));
190 }
191
192 Ok(response)
193 }
194
195 pub(crate) async fn execute_with_receive(
197 &self,
198 operation: OperationCode,
199 params: &[u32],
200 ) -> Result<(ResponseContainer, Vec<u8>), Error> {
201 let _guard = self.operation_lock.lock().await;
202
203 let tx_id = self.next_transaction_id();
204
205 let cmd = CommandContainer {
207 code: operation,
208 transaction_id: tx_id,
209 params: params.to_vec(),
210 };
211 self.transport.send_bulk(&cmd.to_bytes()).await?;
212
213 let mut data = Vec::new();
217
218 loop {
219 let mut bytes = self.transport.receive_bulk(64 * 1024).await?;
220 if bytes.is_empty() {
221 return Err(Error::invalid_data("Empty response"));
222 }
223
224 let ct = container_type(&bytes)?;
225 match ct {
226 ContainerType::Data => {
227 if bytes.len() >= 4 {
230 let total_length = unpack_u32(&bytes[0..4])? as usize;
231 while bytes.len() < total_length {
233 let more = self.transport.receive_bulk(64 * 1024).await?;
234 if more.is_empty() {
235 return Err(Error::invalid_data(
236 "Incomplete data container: device stopped sending",
237 ));
238 }
239 bytes.extend_from_slice(&more);
240 }
241 }
242 let container = DataContainer::from_bytes(&bytes)?;
243 data.extend_from_slice(&container.payload);
244 }
246 ContainerType::Response => {
247 let response = ResponseContainer::from_bytes(&bytes)?;
248 if response.transaction_id != tx_id {
249 return Err(Error::invalid_data(format!(
250 "Transaction ID mismatch: expected {}, got {}",
251 tx_id, response.transaction_id
252 )));
253 }
254 return Ok((response, data));
255 }
256 _ => {
257 return Err(Error::invalid_data(format!(
258 "Unexpected container type: {:?}",
259 ct
260 )));
261 }
262 }
263 }
264 }
265
266 pub(crate) async fn execute_with_send(
268 &self,
269 operation: OperationCode,
270 params: &[u32],
271 data: &[u8],
272 ) -> Result<ResponseContainer, Error> {
273 let _guard = self.operation_lock.lock().await;
274
275 let tx_id = self.next_transaction_id();
276
277 let cmd = CommandContainer {
279 code: operation,
280 transaction_id: tx_id,
281 params: params.to_vec(),
282 };
283 self.transport.send_bulk(&cmd.to_bytes()).await?;
284
285 let data_container = DataContainer {
287 code: operation,
288 transaction_id: tx_id,
289 payload: data.to_vec(),
290 };
291 self.transport.send_bulk(&data_container.to_bytes()).await?;
292
293 let response_bytes = self.transport.receive_bulk(512).await?;
295 let response = ResponseContainer::from_bytes(&response_bytes)?;
296
297 if response.transaction_id != tx_id {
298 return Err(Error::invalid_data(format!(
299 "Transaction ID mismatch: expected {}, got {}",
300 tx_id, response.transaction_id
301 )));
302 }
303
304 Ok(response)
305 }
306
307 pub(crate) fn check_response(
313 response: &ResponseContainer,
314 operation: OperationCode,
315 ) -> Result<(), Error> {
316 if response.code == ResponseCode::Ok {
317 Ok(())
318 } else {
319 Err(Error::Protocol {
320 code: response.code,
321 operation,
322 })
323 }
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptp::{pack_u16, pack_u32, ContainerType, ObjectHandle};
    use crate::transport::mock::MockTransport;

    /// Returns the mock both as a `dyn Transport` (handed to the session) and
    /// as the concrete type (used to queue replies).
    pub(crate) fn mock_transport() -> (Arc<dyn Transport>, Arc<MockTransport>) {
        let mock = Arc::new(MockTransport::new());
        let transport: Arc<dyn Transport> = Arc::clone(&mock) as Arc<dyn Transport>;
        (transport, mock)
    }

    /// Encodes a parameterless `Ok` response container for `tx_id`.
    pub(crate) fn ok_response(tx_id: u32) -> Vec<u8> {
        let mut buf = Vec::with_capacity(12);
        buf.extend_from_slice(&pack_u32(12));
        buf.extend_from_slice(&pack_u16(ContainerType::Response.to_code()));
        buf.extend_from_slice(&pack_u16(ResponseCode::Ok.into()));
        buf.extend_from_slice(&pack_u32(tx_id));
        buf
    }

    /// Encodes a response container with the given code and u32 parameters.
    pub(crate) fn response_with_params(tx_id: u32, code: ResponseCode, params: &[u32]) -> Vec<u8> {
        let len = 12 + params.len() * 4;
        let mut buf = Vec::with_capacity(len);
        buf.extend_from_slice(&pack_u32(len as u32));
        buf.extend_from_slice(&pack_u16(ContainerType::Response.to_code()));
        buf.extend_from_slice(&pack_u16(code.into()));
        buf.extend_from_slice(&pack_u32(tx_id));
        for p in params {
            buf.extend_from_slice(&pack_u32(*p));
        }
        buf
    }

    /// Encodes a data container carrying `payload`.
    pub(crate) fn data_container(tx_id: u32, code: OperationCode, payload: &[u8]) -> Vec<u8> {
        let len = 12 + payload.len();
        let mut buf = Vec::with_capacity(len);
        buf.extend_from_slice(&pack_u32(len as u32));
        buf.extend_from_slice(&pack_u16(ContainerType::Data.to_code()));
        buf.extend_from_slice(&pack_u16(code.into()));
        buf.extend_from_slice(&pack_u32(tx_id));
        buf.extend_from_slice(payload);
        buf
    }

    #[tokio::test]
    async fn test_open_session() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));

        let session = PtpSession::open(transport, 1).await.unwrap();
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn test_open_session_already_open_recovers() {
        let (transport, mock) = mock_transport();

        // OpenSession (tx 1) -> SessionAlreadyOpen
        mock.queue_response(response_with_params(
            1,
            ResponseCode::SessionAlreadyOpen,
            &[],
        ));
        // CloseSession (tx 2)
        mock.queue_response(ok_response(2));
        // Retried OpenSession on a fresh session (tx restarts at 1)
        mock.queue_response(ok_response(1));

        let session = PtpSession::open(transport, 1).await.unwrap();
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn test_open_session_already_open_transaction_id_reset() {
        let (transport, mock) = mock_transport();

        mock.queue_response(response_with_params(
            1,
            ResponseCode::SessionAlreadyOpen,
            &[],
        ));
        mock.queue_response(ok_response(2));
        mock.queue_response(ok_response(1));
        // First operation after recovery must use tx 2.
        mock.queue_response(ok_response(2));

        let session = PtpSession::open(transport, 1).await.unwrap();

        session.delete_object(ObjectHandle(1)).await.unwrap();
    }

    #[tokio::test]
    async fn test_open_session_already_open_close_error_ignored() {
        let (transport, mock) = mock_transport();

        mock.queue_response(response_with_params(
            1,
            ResponseCode::SessionAlreadyOpen,
            &[],
        ));
        // CloseSession fails, which must not abort the recovery.
        mock.queue_response(response_with_params(2, ResponseCode::GeneralError, &[]));
        mock.queue_response(ok_response(1));

        let session = PtpSession::open(transport, 1).await.unwrap();
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn test_open_session_error() {
        let (transport, mock) = mock_transport();
        mock.queue_response(response_with_params(1, ResponseCode::GeneralError, &[]));

        let result = PtpSession::open(transport, 1).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_transaction_id_increment() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(ok_response(2));
        mock.queue_response(ok_response(3));
        let session = PtpSession::open(transport, 1).await.unwrap();

        session.delete_object(ObjectHandle(1)).await.unwrap();
        session.delete_object(ObjectHandle(2)).await.unwrap();
    }

    #[tokio::test]
    async fn test_transaction_id_mismatch() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(ok_response(999));
        let session = PtpSession::open(transport, 1).await.unwrap();

        let result = session.delete_object(ObjectHandle(1)).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_close_session() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(ok_response(2));
        let session = PtpSession::open(transport, 1).await.unwrap();

        session.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_close_session_ignores_errors() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(response_with_params(2, ResponseCode::GeneralError, &[]));
        let session = PtpSession::open(transport, 1).await.unwrap();

        session.close().await.unwrap();
    }

    // The raw `execute_*` tests below queue replies independently of the
    // opcode, so `OpenSession` is used merely as a convenient known opcode.

    #[tokio::test]
    async fn test_execute_with_receive_single_container() {
        let (transport, mock) = mock_transport();
        let op = OperationCode::OpenSession;
        mock.queue_response(ok_response(1));
        mock.queue_response(data_container(2, op, b"hello"));
        mock.queue_response(ok_response(2));
        let session = PtpSession::open(transport, 1).await.unwrap();

        let (response, data) = session.execute_with_receive(op, &[]).await.unwrap();

        assert!(response.code == ResponseCode::Ok);
        assert_eq!(data, b"hello".to_vec());
    }

    #[tokio::test]
    async fn test_execute_with_receive_concatenates_containers() {
        let (transport, mock) = mock_transport();
        let op = OperationCode::OpenSession;
        mock.queue_response(ok_response(1));
        mock.queue_response(data_container(2, op, b"abc"));
        mock.queue_response(data_container(2, op, b"de"));
        mock.queue_response(ok_response(2));
        let session = PtpSession::open(transport, 1).await.unwrap();

        let (_, data) = session.execute_with_receive(op, &[]).await.unwrap();

        assert_eq!(data, b"abcde".to_vec());
    }

    #[tokio::test]
    async fn test_execute_with_receive_reassembles_split_container() {
        let (transport, mock) = mock_transport();
        let op = OperationCode::OpenSession;
        mock.queue_response(ok_response(1));
        // One data container delivered across two bulk transfers.
        let container = data_container(2, op, b"hello world");
        let (head, tail) = container.split_at(HEADER_SIZE + 2);
        mock.queue_response(head.to_vec());
        mock.queue_response(tail.to_vec());
        mock.queue_response(ok_response(2));
        let session = PtpSession::open(transport, 1).await.unwrap();

        let (_, data) = session.execute_with_receive(op, &[]).await.unwrap();

        assert_eq!(data, b"hello world".to_vec());
    }

    #[tokio::test]
    async fn test_execute_with_receive_transaction_id_mismatch() {
        let (transport, mock) = mock_transport();
        let op = OperationCode::OpenSession;
        mock.queue_response(ok_response(1));
        mock.queue_response(data_container(2, op, b"x"));
        mock.queue_response(ok_response(999));
        let session = PtpSession::open(transport, 1).await.unwrap();

        let result = session.execute_with_receive(op, &[]).await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_with_receive_empty_transfer_is_error() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(Vec::new());
        let session = PtpSession::open(transport, 1).await.unwrap();

        let result = session
            .execute_with_receive(OperationCode::OpenSession, &[])
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_with_send_transaction_id_mismatch() {
        let (transport, mock) = mock_transport();
        mock.queue_response(ok_response(1));
        mock.queue_response(ok_response(5));
        let session = PtpSession::open(transport, 1).await.unwrap();

        let result = session
            .execute_with_send(OperationCode::OpenSession, &[], b"payload")
            .await;

        assert!(result.is_err());
    }
}