Skip to main content

mtp_rs/ptp/session/
mod.rs

1//! PTP session management.
2//!
3//! This module provides session-level operations for MTP/PTP communication.
4//! A session maintains the connection state and serializes concurrent operations.
5
6mod operations;
7mod properties;
8mod streaming;
9
10pub use streaming::{receive_stream_to_stream, ReceiveStream};
11
12use crate::ptp::{
13    container_type, unpack_u32, CommandContainer, ContainerType, DataContainer, OperationCode,
14    ResponseCode, ResponseContainer, SessionId, TransactionId,
15};
16use crate::transport::Transport;
17use crate::Error;
18use futures::lock::Mutex;
19use std::sync::atomic::{AtomicU32, Ordering};
20use std::sync::Arc;
21
/// Size in bytes of the fixed header that starts every PTP container:
/// a `u32` total length, a `u16` container type, a `u16` operation/response
/// code, and a `u32` transaction ID.
pub(crate) const HEADER_SIZE: usize = 4 + 2 + 2 + 4;
24
25/// A PTP session with a device.
26///
27/// PtpSession manages the lifecycle of a PTP/MTP session, including:
28/// - Opening and closing sessions
29/// - Transaction ID management
30/// - Serializing concurrent operations (MTP only allows one operation at a time)
31/// - Executing operations and receiving responses
32///
33/// # Example
34///
35/// ```rust,ignore
36/// use mtp_rs::ptp::PtpSession;
37///
38/// // Open a session with session ID 1
39/// let session = PtpSession::open(transport, 1).await?;
40///
41/// // Get device info
42/// let device_info = session.get_device_info().await?;
43///
44/// // Get storage IDs
45/// let storage_ids = session.get_storage_ids().await?;
46///
47/// // Close the session when done
48/// session.close().await?;
49/// ```
50pub struct PtpSession {
51    /// The transport layer for USB communication.
52    pub(crate) transport: Arc<dyn Transport>,
53    /// The session ID assigned to this session.
54    session_id: SessionId,
55    /// Atomic counter for generating transaction IDs.
56    transaction_id: AtomicU32,
57    /// Mutex to serialize operations (MTP only allows one operation at a time).
58    /// Wrapped in Arc so it can be shared with ReceiveStream.
59    pub(crate) operation_lock: Arc<Mutex<()>>,
60}
61
62impl PtpSession {
63    /// Create a new session (internal, use open() to start session).
64    fn new(transport: Arc<dyn Transport>, session_id: SessionId) -> Self {
65        Self {
66            transport,
67            session_id,
68            transaction_id: AtomicU32::new(TransactionId::FIRST.0),
69            operation_lock: Arc::new(Mutex::new(())),
70        }
71    }
72
73    /// Open a new session with the device.
74    ///
75    /// This sends an OpenSession command to the device and establishes a session
76    /// with the given session ID.
77    ///
78    /// # Arguments
79    ///
80    /// * `transport` - The transport layer for USB communication
81    /// * `session_id` - The session ID to use (typically 1)
82    ///
83    /// # Errors
84    ///
85    /// Returns an error if the device rejects the session or communication fails.
86    pub async fn open(transport: Arc<dyn Transport>, session_id: u32) -> Result<Self, Error> {
87        let session = Self::new(transport, SessionId(session_id));
88
89        // Send OpenSession command
90        let response = session
91            .execute(OperationCode::OpenSession, &[session_id])
92            .await?;
93
94        if response.code == ResponseCode::Ok {
95            return Ok(session);
96        }
97
98        if response.code == ResponseCode::SessionAlreadyOpen {
99            // Session already exists with potentially mismatched transaction ID.
100            // Close the existing session (ignore errors) and open a fresh one.
101            let _ = session.execute(OperationCode::CloseSession, &[]).await;
102
103            // Create a new session instance with reset transaction ID counter
104            let fresh_session = Self::new(Arc::clone(&session.transport), SessionId(session_id));
105
106            let retry_response = fresh_session
107                .execute(OperationCode::OpenSession, &[session_id])
108                .await?;
109
110            if retry_response.code != ResponseCode::Ok {
111                return Err(Error::Protocol {
112                    code: retry_response.code,
113                    operation: OperationCode::OpenSession,
114                });
115            }
116
117            return Ok(fresh_session);
118        }
119
120        Err(Error::Protocol {
121            code: response.code,
122            operation: OperationCode::OpenSession,
123        })
124    }
125
126    /// Get the session ID.
127    #[must_use]
128    pub fn session_id(&self) -> SessionId {
129        self.session_id
130    }
131
132    /// Close the session.
133    ///
134    /// This sends a CloseSession command to the device. Errors during close
135    /// are ignored since the session is being terminated anyway.
136    pub async fn close(self) -> Result<(), Error> {
137        let _ = self.execute(OperationCode::CloseSession, &[]).await;
138        Ok(())
139    }
140
141    /// Get the next transaction ID.
142    ///
143    /// Transaction IDs start at 1 and wrap correctly, skipping 0 and 0xFFFFFFFF.
144    pub(crate) fn next_transaction_id(&self) -> u32 {
145        loop {
146            let current = self.transaction_id.load(Ordering::SeqCst);
147            let next = TransactionId(current).next().0;
148            if self
149                .transaction_id
150                .compare_exchange(current, next, Ordering::SeqCst, Ordering::SeqCst)
151                .is_ok()
152            {
153                return current;
154            }
155        }
156    }
157
158    // =========================================================================
159    // Core operation execution
160    // =========================================================================
161
162    /// Execute an operation without data phase.
163    pub(crate) async fn execute(
164        &self,
165        operation: OperationCode,
166        params: &[u32],
167    ) -> Result<ResponseContainer, Error> {
168        let _guard = self.operation_lock.lock().await;
169
170        let tx_id = self.next_transaction_id();
171
172        // Build and send command
173        let cmd = CommandContainer {
174            code: operation,
175            transaction_id: tx_id,
176            params: params.to_vec(),
177        };
178        self.transport.send_bulk(&cmd.to_bytes()).await?;
179
180        // Receive response
181        let response_bytes = self.transport.receive_bulk(512).await?;
182        let response = ResponseContainer::from_bytes(&response_bytes)?;
183
184        // Verify transaction ID matches
185        if response.transaction_id != tx_id {
186            return Err(Error::invalid_data(format!(
187                "Transaction ID mismatch: expected {}, got {}",
188                tx_id, response.transaction_id
189            )));
190        }
191
192        Ok(response)
193    }
194
195    /// Execute operation with data receive phase.
196    pub(crate) async fn execute_with_receive(
197        &self,
198        operation: OperationCode,
199        params: &[u32],
200    ) -> Result<(ResponseContainer, Vec<u8>), Error> {
201        let _guard = self.operation_lock.lock().await;
202
203        let tx_id = self.next_transaction_id();
204
205        // Send command
206        let cmd = CommandContainer {
207            code: operation,
208            transaction_id: tx_id,
209            params: params.to_vec(),
210        };
211        self.transport.send_bulk(&cmd.to_bytes()).await?;
212
213        // Receive data container(s)
214        // MTP sends data in one or more containers, then response.
215        // A single data container may span multiple USB transfers if larger than 64KB.
216        let mut data = Vec::new();
217
218        loop {
219            let mut bytes = self.transport.receive_bulk(64 * 1024).await?;
220            if bytes.is_empty() {
221                return Err(Error::invalid_data("Empty response"));
222            }
223
224            let ct = container_type(&bytes)?;
225            match ct {
226                ContainerType::Data => {
227                    // Check if we need to receive more data for this container.
228                    // The length field in the header tells us the total container size.
229                    if bytes.len() >= 4 {
230                        let total_length = unpack_u32(&bytes[0..4])? as usize;
231                        // Keep receiving until we have the complete container
232                        while bytes.len() < total_length {
233                            let more = self.transport.receive_bulk(64 * 1024).await?;
234                            if more.is_empty() {
235                                return Err(Error::invalid_data(
236                                    "Incomplete data container: device stopped sending",
237                                ));
238                            }
239                            bytes.extend_from_slice(&more);
240                        }
241                    }
242                    let container = DataContainer::from_bytes(&bytes)?;
243                    data.extend_from_slice(&container.payload);
244                    // Continue to receive more containers or response
245                }
246                ContainerType::Response => {
247                    let response = ResponseContainer::from_bytes(&bytes)?;
248                    if response.transaction_id != tx_id {
249                        return Err(Error::invalid_data(format!(
250                            "Transaction ID mismatch: expected {}, got {}",
251                            tx_id, response.transaction_id
252                        )));
253                    }
254                    return Ok((response, data));
255                }
256                _ => {
257                    return Err(Error::invalid_data(format!(
258                        "Unexpected container type: {:?}",
259                        ct
260                    )));
261                }
262            }
263        }
264    }
265
266    /// Execute operation with data send phase.
267    pub(crate) async fn execute_with_send(
268        &self,
269        operation: OperationCode,
270        params: &[u32],
271        data: &[u8],
272    ) -> Result<ResponseContainer, Error> {
273        let _guard = self.operation_lock.lock().await;
274
275        let tx_id = self.next_transaction_id();
276
277        // Send command
278        let cmd = CommandContainer {
279            code: operation,
280            transaction_id: tx_id,
281            params: params.to_vec(),
282        };
283        self.transport.send_bulk(&cmd.to_bytes()).await?;
284
285        // Send data
286        let data_container = DataContainer {
287            code: operation,
288            transaction_id: tx_id,
289            payload: data.to_vec(),
290        };
291        self.transport.send_bulk(&data_container.to_bytes()).await?;
292
293        // Receive response
294        let response_bytes = self.transport.receive_bulk(512).await?;
295        let response = ResponseContainer::from_bytes(&response_bytes)?;
296
297        if response.transaction_id != tx_id {
298            return Err(Error::invalid_data(format!(
299                "Transaction ID mismatch: expected {}, got {}",
300                tx_id, response.transaction_id
301            )));
302        }
303
304        Ok(response)
305    }
306
307    // =========================================================================
308    // Helper methods
309    // =========================================================================
310
311    /// Helper to check response is OK.
312    pub(crate) fn check_response(
313        response: &ResponseContainer,
314        operation: OperationCode,
315    ) -> Result<(), Error> {
316        if response.code == ResponseCode::Ok {
317            Ok(())
318        } else {
319            Err(Error::Protocol {
320                code: response.code,
321                operation,
322            })
323        }
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptp::{pack_u16, pack_u32, ContainerType, ObjectHandle};
    use crate::transport::mock::MockTransport;

    /// Create a mock transport and return it in two forms: as the
    /// `Arc<dyn Transport>` handed to `PtpSession`, and as the concrete
    /// `MockTransport` used to queue device replies.
    pub(crate) fn mock_transport() -> (Arc<dyn Transport>, Arc<MockTransport>) {
        let mock = Arc::new(MockTransport::new());
        (Arc::clone(&mock) as Arc<dyn Transport>, mock)
    }

    /// Assemble raw container bytes: a 12-byte header followed by `body`.
    fn raw_container(kind: ContainerType, code: u16, tx_id: u32, body: &[u8]) -> Vec<u8> {
        let total = HEADER_SIZE + body.len();
        let mut buf = Vec::with_capacity(total);
        buf.extend_from_slice(&pack_u32(total as u32));
        buf.extend_from_slice(&pack_u16(kind.to_code()));
        buf.extend_from_slice(&pack_u16(code));
        buf.extend_from_slice(&pack_u32(tx_id));
        buf.extend_from_slice(body);
        buf
    }

    /// Bytes of a parameterless `Ok` response for transaction `tx_id`.
    pub(crate) fn ok_response(tx_id: u32) -> Vec<u8> {
        response_with_params(tx_id, ResponseCode::Ok, &[])
    }

    /// Bytes of a response container with the given code and `u32` parameters.
    pub(crate) fn response_with_params(tx_id: u32, code: ResponseCode, params: &[u32]) -> Vec<u8> {
        let body: Vec<u8> = params.iter().flat_map(|&p| pack_u32(p)).collect();
        raw_container(ContainerType::Response, code.into(), tx_id, &body)
    }

    /// Bytes of a data container for `code` carrying `payload`.
    pub(crate) fn data_container(tx_id: u32, code: OperationCode, payload: &[u8]) -> Vec<u8> {
        raw_container(ContainerType::Data, code.into(), tx_id, payload)
    }

    /// Queue device replies in the order the session will read them.
    fn queue(mock: &MockTransport, replies: impl IntoIterator<Item = Vec<u8>>) {
        for reply in replies {
            mock.queue_response(reply);
        }
    }

    #[tokio::test]
    async fn test_open_session() {
        let (transport, mock) = mock_transport();
        queue(&mock, [ok_response(1)]);

        let session = PtpSession::open(transport, 1).await.unwrap();
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn test_open_session_already_open_recovers() {
        let (transport, mock) = mock_transport();
        queue(
            &mock,
            [
                // Initial OpenSession: the device reports an existing session.
                response_with_params(1, ResponseCode::SessionAlreadyOpen, &[]),
                // CloseSession reply (its result is ignored, but one must be present).
                ok_response(2),
                // Retried OpenSession on the fresh session; its counter restarts at 1.
                ok_response(1),
            ],
        );

        // Closing and reopening should make this succeed.
        let session = PtpSession::open(transport, 1).await.unwrap();
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn test_open_session_already_open_transaction_id_reset() {
        let (transport, mock) = mock_transport();
        queue(
            &mock,
            [
                // Initial OpenSession: the device reports an existing session.
                response_with_params(1, ResponseCode::SessionAlreadyOpen, &[]),
                // CloseSession reply.
                ok_response(2),
                // Retried OpenSession on the fresh session, using tx_id 1.
                ok_response(1),
                // The following operation must use tx_id 2.
                ok_response(2),
            ],
        );

        let session = PtpSession::open(transport, 1).await.unwrap();

        // Succeeds only if the fresh session's counter continued from 1 to 2.
        session.delete_object(ObjectHandle(1)).await.unwrap();
    }

    #[tokio::test]
    async fn test_open_session_already_open_close_error_ignored() {
        let (transport, mock) = mock_transport();
        queue(
            &mock,
            [
                // Initial OpenSession: the device reports an existing session.
                response_with_params(1, ResponseCode::SessionAlreadyOpen, &[]),
                // CloseSession fails; this must not abort the open.
                response_with_params(2, ResponseCode::GeneralError, &[]),
                // Retried OpenSession succeeds.
                ok_response(1),
            ],
        );

        let session = PtpSession::open(transport, 1).await.unwrap();
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn test_open_session_error() {
        let (transport, mock) = mock_transport();
        queue(&mock, [response_with_params(1, ResponseCode::GeneralError, &[])]);

        assert!(PtpSession::open(transport, 1).await.is_err());
    }

    #[tokio::test]
    async fn test_transaction_id_increment() {
        let (transport, mock) = mock_transport();
        // Replies for OpenSession, then two consecutive operations.
        queue(&mock, (1..=3).map(ok_response));

        let session = PtpSession::open(transport, 1).await.unwrap();

        // Each call succeeds only if its transaction ID is one past the previous.
        for handle in [ObjectHandle(1), ObjectHandle(2)] {
            session.delete_object(handle).await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_transaction_id_mismatch() {
        let (transport, mock) = mock_transport();
        // OpenSession succeeds; the next reply carries a bogus transaction ID.
        queue(&mock, [ok_response(1), ok_response(999)]);

        let session = PtpSession::open(transport, 1).await.unwrap();
        assert!(session.delete_object(ObjectHandle(1)).await.is_err());
    }

    #[tokio::test]
    async fn test_close_session() {
        let (transport, mock) = mock_transport();
        // Replies for OpenSession and CloseSession.
        queue(&mock, [ok_response(1), ok_response(2)]);

        let session = PtpSession::open(transport, 1).await.unwrap();
        session.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_close_session_ignores_errors() {
        let (transport, mock) = mock_transport();
        // OpenSession succeeds; CloseSession reports an error.
        queue(
            &mock,
            [
                ok_response(1),
                response_with_params(2, ResponseCode::GeneralError, &[]),
            ],
        );

        let session = PtpSession::open(transport, 1).await.unwrap();
        // A failed CloseSession must still yield Ok.
        session.close().await.unwrap();
    }
}