mod operations;
mod properties;
mod streaming;
pub use streaming::{receive_stream_to_stream, ReceiveStream};
use crate::ptp::{
container_type, pack_u16, pack_u32, unpack_u32, CommandContainer, ContainerType, DataContainer,
OperationCode, ResponseCode, ResponseContainer, SessionId, TransactionId,
};
use crate::transport::Transport;
use crate::Error;
use futures::lock::Mutex;
use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
use std::sync::Arc;
/// Size in bytes of a PTP generic container header, spelled out field by field:
/// length (u32) + container type (u16) + operation/response code (u16) + transaction ID (u32).
pub(crate) const HEADER_SIZE: usize = 4 + 2 + 2 + 4;
/// Timeout handed to `Transport::cancel_transfer` when a session recovers from an
/// abandoned transaction before running its next operation.
const RECOVERY_DRAIN_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(300);
/// Shared record of a transaction whose bulk pipe may still hold unread data.
///
/// `flag` records the offending transaction ID, `take` consumes the pending
/// request (yielding the ID once per raised flag), and `is_needed` peeks
/// without consuming.
#[derive(Debug, Default)]
pub(crate) struct RecoveryState {
    needed: AtomicBool,
    poisoned_tx_id: AtomicU32,
}

impl RecoveryState {
    /// Marks transaction `tx_id` as needing recovery (the latest flag wins).
    fn flag(&self, tx_id: u32) {
        // Store the ID before raising the flag so whoever observes the raised
        // flag also observes this ID.
        self.poisoned_tx_id.store(tx_id, Ordering::SeqCst);
        self.needed.store(true, Ordering::SeqCst);
    }

    /// Clears a pending recovery request and returns its transaction ID, or
    /// `None` if nothing was pending.
    fn take(&self) -> Option<u32> {
        let was_pending = self.needed.swap(false, Ordering::SeqCst);
        was_pending.then(|| self.poisoned_tx_id.load(Ordering::SeqCst))
    }

    /// Reports whether a recovery request is pending, without clearing it.
    fn is_needed(&self) -> bool {
        self.needed.load(Ordering::SeqCst)
    }
}
/// Drop guard for one in-flight transaction.
///
/// While armed, dropping the guard flags `tx_id` in the shared `RecoveryState`.
/// This happens when the owning future returns early via `?` or is itself dropped
/// before it calls `disarm`. Callers disarm only after a complete, matching
/// response has been read.
struct TransactionScope<'a> {
    recovery: &'a RecoveryState,
    tx_id: u32,
    armed: bool,
}

impl<'a> TransactionScope<'a> {
    /// Starts guarding transaction `tx_id` against `recovery`.
    fn arm(recovery: &'a RecoveryState, tx_id: u32) -> Self {
        let armed = true;
        Self {
            recovery,
            tx_id,
            armed,
        }
    }

    /// Declares the transaction cleanly finished; dropping is then a no-op.
    fn disarm(&mut self) {
        self.armed = false;
    }
}

impl Drop for TransactionScope<'_> {
    fn drop(&mut self) {
        if !self.armed {
            return;
        }
        self.recovery.flag(self.tx_id);
    }
}
/// An open PTP session on a [`Transport`].
///
/// Each `execute*` method in this file holds `operation_lock` for its entire
/// transaction, so transactions issued through them do not interleave.
pub struct PtpSession {
    /// Transport that carries every bulk transfer of this session.
    pub(crate) transport: Arc<dyn Transport>,
    /// Session ID sent with OpenSession.
    session_id: SessionId,
    /// Next transaction ID to hand out (advanced by `next_transaction_id`).
    transaction_id: AtomicU32,
    /// Serializes operations. It is `pub(crate)`, so code elsewhere in the crate
    /// presumably locks it as well — verify before relying on exclusivity.
    pub(crate) operation_lock: Arc<Mutex<()>>,
    /// When `true`, `execute_with_send` sends the data-container header and the
    /// payload as separate bulk transfers instead of one combined transfer.
    split_header_data: AtomicBool,
    /// Pending-recovery state flagged by `TransactionScope` guards when a
    /// transaction is abandoned or fails before a matching response is read.
    recovery: Arc<RecoveryState>,
}
impl PtpSession {
    /// Wraps `transport` in a session object without contacting the device.
    ///
    /// Transaction IDs start at `TransactionId::FIRST`, the data phase is sent as
    /// one combined container, and no recovery is pending.
    fn new(transport: Arc<dyn Transport>, session_id: SessionId) -> Self {
        Self {
            transport,
            session_id,
            transaction_id: AtomicU32::new(TransactionId::FIRST.0),
            operation_lock: Arc::new(Mutex::new(())),
            split_header_data: AtomicBool::new(false),
            recovery: Arc::new(RecoveryState::default()),
        }
    }

    /// Returns `true` when an earlier transaction was abandoned or failed before a
    /// matching response was read. The next operation will first ask the
    /// transport to cancel/drain that transaction.
    #[must_use]
    pub fn needs_recovery(&self) -> bool {
        self.recovery.is_needed()
    }

    /// Chooses how [`execute_with_send`](Self::execute_with_send) transmits the
    /// data phase. `true` sends the container header and the payload as separate
    /// bulk transfers; `false` (the default) sends them as one transfer.
    pub fn set_split_header_data(&self, split: bool) {
        self.split_header_data.store(split, Ordering::Relaxed);
    }

    /// Returns the mode set by [`set_split_header_data`](Self::set_split_header_data).
    #[must_use]
    pub fn is_split_header_data(&self) -> bool {
        self.split_header_data.load(Ordering::Relaxed)
    }

    /// Opens PTP session `session_id` on `transport`.
    ///
    /// If the device answers `SessionAlreadyOpen`, a `CloseSession` is sent (its
    /// outcome is ignored). `OpenSession` is then retried once on a fresh session
    /// object, so transaction IDs restart from the beginning.
    ///
    /// # Errors
    ///
    /// Transport or parse errors from the OpenSession exchange, or
    /// `Error::Protocol` when the device's final answer is not `Ok`.
    pub async fn open(transport: Arc<dyn Transport>, session_id: u32) -> Result<Self, Error> {
        let session = Self::new(transport, SessionId(session_id));
        let response = Self::send_open_session(&session.transport, session_id).await?;
        if response.code == ResponseCode::Ok {
            return Ok(session);
        }
        if response.code == ResponseCode::SessionAlreadyOpen {
            // Best effort: the retry below decides whether the session is usable.
            let _ = session.execute(OperationCode::CloseSession, &[]).await;
            let fresh_session = Self::new(Arc::clone(&session.transport), SessionId(session_id));
            let retry_response =
                Self::send_open_session(&fresh_session.transport, session_id).await?;
            if retry_response.code != ResponseCode::Ok {
                return Err(Error::Protocol {
                    code: retry_response.code,
                    operation: OperationCode::OpenSession,
                });
            }
            return Ok(fresh_session);
        }
        Err(Error::Protocol {
            code: response.code,
            operation: OperationCode::OpenSession,
        })
    }

    /// Sends a session-less OpenSession command and parses the reply.
    async fn send_open_session(
        transport: &Arc<dyn Transport>,
        session_id: u32,
    ) -> Result<ResponseContainer, Error> {
        let cmd = CommandContainer {
            code: OperationCode::OpenSession,
            transaction_id: TransactionId::SESSION_LESS.0,
            params: vec![session_id],
        };
        transport.send_bulk(&cmd.to_bytes()).await?;
        let response_bytes = transport.receive_bulk(RESPONSE_BUFFER_LEN).await?;
        ResponseContainer::from_bytes(&response_bytes)
    }

    /// The session ID this session was opened with.
    #[must_use]
    pub fn session_id(&self) -> SessionId {
        self.session_id
    }

    /// Sends CloseSession and consumes the session.
    ///
    /// Any failure is ignored; this always returns `Ok(())`.
    pub async fn close(self) -> Result<(), Error> {
        let _ = self.execute(OperationCode::CloseSession, &[]).await;
        Ok(())
    }

    /// Returns the current transaction ID and atomically advances the counter
    /// with `TransactionId::next`.
    pub(crate) fn next_transaction_id(&self) -> u32 {
        let advance = |current: u32| Some(TransactionId(current).next().0);
        // The closure never returns `None`, so this is always `Ok(previous)`.
        match self
            .transaction_id
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, advance)
        {
            Ok(previous) | Err(previous) => previous,
        }
    }

    /// Drains a transaction left behind by an abandoned or failed operation.
    ///
    /// If the cancel fails, or this future is dropped while it is in flight, the
    /// poisoned transaction is flagged again so the next operation retries the
    /// drain instead of reading stale data.
    async fn recover_if_needed(&self) -> Result<(), Error> {
        let Some(tx_id) = self.recovery.take() else {
            return Ok(());
        };
        // `take()` already cleared the flag; this guard restores it on error or drop.
        let mut scope = TransactionScope::arm(&self.recovery, tx_id);
        self.transport
            .cancel_transfer(tx_id, RECOVERY_DRAIN_TIMEOUT)
            .await?;
        scope.disarm();
        Ok(())
    }

    /// Runs a transaction with no data phase and returns the device's response.
    ///
    /// A non-`Ok` response code is still returned as `Ok(response)`; use
    /// [`check_response`](Self::check_response) to turn it into an error.
    ///
    /// # Errors
    ///
    /// Recovery failure, transport errors, an unparsable response, or a response
    /// for a different transaction ID. In every error case after the command is
    /// sent, the session is flagged for recovery.
    pub async fn execute(
        &self,
        operation: OperationCode,
        params: &[u32],
    ) -> Result<ResponseContainer, Error> {
        let _guard = self.operation_lock.lock().await;
        self.recover_if_needed().await?;
        let tx_id = self.next_transaction_id();
        // Declared after `_guard`, so it drops (and may flag) while the lock is held.
        let mut scope = TransactionScope::arm(&self.recovery, tx_id);
        let cmd = CommandContainer {
            code: operation,
            transaction_id: tx_id,
            params: params.to_vec(),
        };
        self.transport.send_bulk(&cmd.to_bytes()).await?;
        let response_bytes = self.transport.receive_bulk(RESPONSE_BUFFER_LEN).await?;
        let response = ResponseContainer::from_bytes(&response_bytes)?;
        check_transaction_id(&response, tx_id)?;
        scope.disarm();
        Ok(response)
    }

    /// Runs a transaction with a device-to-host data phase.
    ///
    /// Reads containers until a response arrives, concatenating the payloads of
    /// all data containers. A data container larger than one read is reassembled
    /// from follow-up reads, up to the length in its header.
    ///
    /// # Errors
    ///
    /// Same as [`execute`](Self::execute), plus an empty read, a data container
    /// cut short by the device, or an unexpected container type.
    pub async fn execute_with_receive(
        &self,
        operation: OperationCode,
        params: &[u32],
    ) -> Result<(ResponseContainer, Vec<u8>), Error> {
        let _guard = self.operation_lock.lock().await;
        self.recover_if_needed().await?;
        let tx_id = self.next_transaction_id();
        let mut scope = TransactionScope::arm(&self.recovery, tx_id);
        let cmd = CommandContainer {
            code: operation,
            transaction_id: tx_id,
            params: params.to_vec(),
        };
        self.transport.send_bulk(&cmd.to_bytes()).await?;
        let mut data = Vec::new();
        loop {
            let mut bytes = self.transport.receive_bulk(DATA_CHUNK_LEN).await?;
            if bytes.is_empty() {
                return Err(Error::invalid_data("Empty response"));
            }
            let ct = container_type(&bytes)?;
            match ct {
                ContainerType::Data => {
                    if bytes.len() >= 4 {
                        // The container's declared length may exceed a single read.
                        let total_length = unpack_u32(&bytes[0..4])? as usize;
                        while bytes.len() < total_length {
                            let more = self.transport.receive_bulk(DATA_CHUNK_LEN).await?;
                            if more.is_empty() {
                                return Err(Error::invalid_data(
                                    "Incomplete data container: device stopped sending",
                                ));
                            }
                            bytes.extend_from_slice(&more);
                        }
                    }
                    let container = DataContainer::from_bytes(&bytes)?;
                    data.extend_from_slice(&container.payload);
                }
                ContainerType::Response => {
                    let response = ResponseContainer::from_bytes(&bytes)?;
                    check_transaction_id(&response, tx_id)?;
                    scope.disarm();
                    return Ok((response, data));
                }
                _ => {
                    return Err(Error::invalid_data(format!(
                        "Unexpected container type: {:?}",
                        ct
                    )));
                }
            }
        }
    }

    /// Runs a transaction with a host-to-device data phase carrying `data`.
    ///
    /// The data container is sent as one transfer, or as header and payload
    /// transfers when split mode is enabled
    /// ([`set_split_header_data`](Self::set_split_header_data)). An empty payload
    /// in split mode sends only the header.
    ///
    /// # Errors
    ///
    /// Same as [`execute`](Self::execute). An `invalid_data` error is returned,
    /// before any transaction starts, when `data` is too large for the container's
    /// 32-bit length field.
    pub async fn execute_with_send(
        &self,
        operation: OperationCode,
        params: &[u32],
        data: &[u8],
    ) -> Result<ResponseContainer, Error> {
        // The container length field is a u32. Reject oversized payloads up front
        // rather than sending a silently truncated length.
        let container_len = HEADER_SIZE
            .checked_add(data.len())
            .filter(|&len| len <= u32::MAX as usize)
            .map(|len| len as u32) // lossless: bounded by the filter above
            .ok_or_else(|| {
                Error::invalid_data(format!(
                    "Data phase of {} bytes does not fit in a PTP container",
                    data.len()
                ))
            })?;
        let _guard = self.operation_lock.lock().await;
        self.recover_if_needed().await?;
        let tx_id = self.next_transaction_id();
        let mut scope = TransactionScope::arm(&self.recovery, tx_id);
        let cmd = CommandContainer {
            code: operation,
            transaction_id: tx_id,
            params: params.to_vec(),
        };
        self.transport.send_bulk(&cmd.to_bytes()).await?;
        if self.split_header_data.load(Ordering::Relaxed) {
            let mut header = Vec::with_capacity(HEADER_SIZE);
            header.extend_from_slice(&pack_u32(container_len));
            header.extend_from_slice(&pack_u16(ContainerType::Data.to_code()));
            header.extend_from_slice(&pack_u16(operation.into()));
            header.extend_from_slice(&pack_u32(tx_id));
            self.transport.send_bulk(&header).await?;
            if !data.is_empty() {
                self.transport.send_bulk(data).await?;
            }
        } else {
            let data_container = DataContainer {
                code: operation,
                transaction_id: tx_id,
                payload: data.to_vec(),
            };
            self.transport.send_bulk(&data_container.to_bytes()).await?;
        }
        let response_bytes = self.transport.receive_bulk(RESPONSE_BUFFER_LEN).await?;
        let response = ResponseContainer::from_bytes(&response_bytes)?;
        check_transaction_id(&response, tx_id)?;
        scope.disarm();
        Ok(response)
    }

    /// Maps a non-`Ok` response code to `Error::Protocol` tagged with `operation`.
    pub(crate) fn check_response(
        response: &ResponseContainer,
        operation: OperationCode,
    ) -> Result<(), Error> {
        if response.code == ResponseCode::Ok {
            Ok(())
        } else {
            Err(Error::Protocol {
                code: response.code,
                operation,
            })
        }
    }
}

/// Buffer size requested when reading a response container.
const RESPONSE_BUFFER_LEN: usize = 512;

/// Buffer size requested per bulk read during a device-to-host data phase.
const DATA_CHUNK_LEN: usize = 64 * 1024;

/// Fails unless `response` carries transaction ID `expected`.
///
/// A mismatch means the bytes belong to some other transaction. Callers return
/// this error with their `TransactionScope` still armed, so the session is
/// flagged for recovery.
fn check_transaction_id(response: &ResponseContainer, expected: u32) -> Result<(), Error> {
    if response.transaction_id == expected {
        Ok(())
    } else {
        Err(Error::invalid_data(format!(
            "Transaction ID mismatch: expected {}, got {}",
            expected, response.transaction_id
        )))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptp::{pack_u16, pack_u32, ContainerType, ObjectHandle};
    use crate::transport::mock::MockTransport;

    /// Returns one mock twice: type-erased for `PtpSession`, and concrete so tests
    /// can queue responses and inspect what was sent.
    pub(crate) fn mock_transport() -> (Arc<dyn Transport>, Arc<MockTransport>) {
        let mock = Arc::new(MockTransport::new());
        let erased = Arc::clone(&mock) as Arc<dyn Transport>;
        (erased, mock)
    }

    /// Encodes a container header declaring a total length of `len` bytes.
    fn container_header(len: usize, kind: ContainerType, code: u16, tx_id: u32) -> Vec<u8> {
        let mut buf = Vec::with_capacity(len);
        buf.extend_from_slice(&pack_u32(len as u32));
        buf.extend_from_slice(&pack_u16(kind.to_code()));
        buf.extend_from_slice(&pack_u16(code));
        buf.extend_from_slice(&pack_u32(tx_id));
        buf
    }

    /// A parameterless `Ok` response for `tx_id`.
    pub(crate) fn ok_response(tx_id: u32) -> Vec<u8> {
        response_with_params(tx_id, ResponseCode::Ok, &[])
    }

    /// A response container with `code` and the given parameters.
    pub(crate) fn response_with_params(tx_id: u32, code: ResponseCode, params: &[u32]) -> Vec<u8> {
        let len = HEADER_SIZE + params.len() * 4;
        let mut buf = container_header(len, ContainerType::Response, code.into(), tx_id);
        params
            .iter()
            .for_each(|param| buf.extend_from_slice(&pack_u32(*param)));
        buf
    }

    /// A data container carrying `payload` for `tx_id`.
    pub(crate) fn data_container(tx_id: u32, code: OperationCode, payload: &[u8]) -> Vec<u8> {
        let len = HEADER_SIZE + payload.len();
        let mut buf = container_header(len, ContainerType::Data, code.into(), tx_id);
        buf.extend_from_slice(payload);
        buf
    }

    /// Queues `responses` in order (the first answers OpenSession), then opens
    /// session 1 on a fresh mock and returns both.
    async fn open_with(responses: Vec<Vec<u8>>) -> (PtpSession, Arc<MockTransport>) {
        let (transport, mock) = mock_transport();
        for response in responses {
            mock.queue_response(response);
        }
        let session = PtpSession::open(transport, 1).await.unwrap();
        (session, mock)
    }

    #[tokio::test]
    async fn test_open_session() {
        let (session, _mock) = open_with(vec![ok_response(0)]).await;
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn abandoned_receive_flags_session_for_recovery() {
        let (session, mock) = open_with(vec![ok_response(0)]).await;
        mock.queue_response(data_container(1, OperationCode::GetObjectInfo, &[0u8; 4]));
        mock.queue_response(ok_response(1));
        mock.block_receive();

        let mut pending =
            Box::pin(session.execute_with_receive(OperationCode::GetObjectInfo, &[42]));
        assert!(
            matches!(futures::poll!(pending.as_mut()), std::task::Poll::Pending),
            "op should suspend at receive_bulk after sending its command",
        );
        drop(pending);

        assert!(
            session.needs_recovery(),
            "abandoning an op mid-receive must flag the session for recovery",
        );
    }

    #[tokio::test]
    async fn abandoned_receive_does_not_poison_next_operation() {
        let (session, mock) = open_with(vec![ok_response(0)]).await;
        mock.queue_response(data_container(1, OperationCode::GetObjectInfo, &[0u8; 4]));
        mock.queue_response(ok_response(1));
        mock.block_receive();

        let mut op_a = Box::pin(session.execute_with_receive(OperationCode::GetObjectInfo, &[1]));
        assert!(matches!(
            futures::poll!(op_a.as_mut()),
            std::task::Poll::Pending
        ));
        drop(op_a);
        mock.unblock_receive();

        let result = session
            .execute_with_receive(OperationCode::GetObjectInfo, &[2])
            .await;
        assert!(
            matches!(result, Err(Error::NoDevice)),
            "op B must not inherit op A's response; got {result:?}",
        );
        assert_eq!(
            mock.get_cancel_calls(),
            vec![1],
            "recovery must drain the pipe for the poisoned transaction (tx=1)",
        );
    }

    #[tokio::test]
    async fn test_open_session_already_open_recovers() {
        let (session, _mock) = open_with(vec![
            response_with_params(0, ResponseCode::SessionAlreadyOpen, &[]),
            ok_response(1), // CloseSession
            ok_response(0), // retried OpenSession
        ])
        .await;
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn test_open_session_already_open_transaction_id_reset() {
        let (session, _mock) = open_with(vec![
            response_with_params(0, ResponseCode::SessionAlreadyOpen, &[]),
            ok_response(1), // CloseSession
            ok_response(0), // retried OpenSession
            ok_response(1), // first op on the fresh session starts again at tx 1
        ])
        .await;
        session.delete_object(ObjectHandle(1)).await.unwrap();
    }

    #[tokio::test]
    async fn test_open_session_already_open_close_error_ignored() {
        let (session, _mock) = open_with(vec![
            response_with_params(0, ResponseCode::SessionAlreadyOpen, &[]),
            response_with_params(1, ResponseCode::GeneralError, &[]),
            ok_response(0),
        ])
        .await;
        assert_eq!(session.session_id(), SessionId(1));
    }

    #[tokio::test]
    async fn test_open_session_error() {
        let (transport, mock) = mock_transport();
        mock.queue_response(response_with_params(0, ResponseCode::GeneralError, &[]));
        let outcome = PtpSession::open(transport, 1).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_transaction_id_increment() {
        let (session, _mock) =
            open_with(vec![ok_response(0), ok_response(1), ok_response(2)]).await;
        session.delete_object(ObjectHandle(1)).await.unwrap();
        session.delete_object(ObjectHandle(2)).await.unwrap();
    }

    #[tokio::test]
    async fn test_transaction_id_mismatch() {
        let (session, _mock) = open_with(vec![ok_response(0), ok_response(999)]).await;
        let outcome = session.delete_object(ObjectHandle(1)).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn test_execute_with_send_combined_default() {
        let (session, mock) = open_with(vec![ok_response(0), ok_response(1)]).await;
        assert!(!session.is_split_header_data());

        let payload = vec![0xAA, 0xBB, 0xCC, 0xDD];
        session
            .execute_with_send(OperationCode::SendObject, &[], &payload)
            .await
            .unwrap();

        let sends = mock.get_sends();
        assert_eq!(sends.len(), 3);
        let combined = &sends[2];
        assert_eq!(combined.len(), HEADER_SIZE + payload.len());
        assert_eq!(unpack_u32(&combined[0..4]).unwrap() as usize, combined.len());
        assert_eq!(&combined[HEADER_SIZE..], payload.as_slice());
    }

    #[tokio::test]
    async fn test_execute_with_send_split_header_data() {
        let (session, mock) = open_with(vec![ok_response(0), ok_response(1)]).await;
        session.set_split_header_data(true);
        assert!(session.is_split_header_data());

        let payload = vec![0xAA, 0xBB, 0xCC, 0xDD];
        session
            .execute_with_send(OperationCode::SendObject, &[], &payload)
            .await
            .unwrap();

        let sends = mock.get_sends();
        assert_eq!(sends.len(), 4);
        let header = &sends[2];
        assert_eq!(header.len(), HEADER_SIZE);
        assert_eq!(
            unpack_u32(&header[0..4]).unwrap() as usize,
            HEADER_SIZE + payload.len()
        );
        assert_eq!(sends[3], payload);
    }

    #[tokio::test]
    async fn test_execute_with_send_split_empty_payload() {
        let (session, mock) = open_with(vec![ok_response(0), ok_response(1)]).await;
        session.set_split_header_data(true);

        session
            .execute_with_send(OperationCode::SendObject, &[], &[])
            .await
            .unwrap();

        let sends = mock.get_sends();
        assert_eq!(sends.len(), 3);
        assert_eq!(sends[2].len(), HEADER_SIZE);
    }

    #[tokio::test]
    async fn test_execute_with_send_stream_combined_default() {
        use bytes::Bytes;
        use futures::stream;

        let (session, mock) = open_with(vec![ok_response(0), ok_response(1)]).await;
        assert!(!session.is_split_header_data());

        let chunks: Vec<Result<Bytes, std::io::Error>> = vec![
            Ok(Bytes::from_static(&[0xAA, 0xBB])),
            Ok(Bytes::from_static(&[0xCC, 0xDD])),
        ];
        let total_size = 4u64;
        session
            .execute_with_send_stream(
                OperationCode::SendObject,
                &[],
                total_size,
                stream::iter(chunks),
            )
            .await
            .unwrap();

        let sends = mock.get_sends();
        assert_eq!(sends.len(), 3);
        let combined = &sends[2];
        assert_eq!(combined.len(), HEADER_SIZE + total_size as usize);
        assert_eq!(unpack_u32(&combined[0..4]).unwrap() as usize, combined.len());
        assert_eq!(&combined[HEADER_SIZE..], &[0xAA, 0xBB, 0xCC, 0xDD]);
    }

    #[tokio::test]
    async fn test_execute_with_send_stream_combined_large_multichunk() {
        use bytes::Bytes;
        use futures::stream;

        let (session, mock) = open_with(vec![ok_response(0), ok_response(1)]).await;

        let chunk_size = 64 * 1024;
        let num_chunks = 48;
        let total_size = (chunk_size * num_chunks) as u64;
        let chunks: Vec<Result<Bytes, std::io::Error>> = (0..num_chunks)
            .map(|i| Ok(Bytes::from(vec![(i % 256) as u8; chunk_size])))
            .collect();
        session
            .execute_with_send_stream(
                OperationCode::SendObject,
                &[],
                total_size,
                stream::iter(chunks),
            )
            .await
            .unwrap();

        let sends = mock.get_sends();
        let data_sends: Vec<u8> = sends[2..].concat();
        assert_eq!(data_sends.len(), HEADER_SIZE + total_size as usize);
        let payload = &data_sends[HEADER_SIZE..];
        for (i, chunk) in payload.chunks_exact(chunk_size).enumerate() {
            let expected = (i % 256) as u8;
            assert!(
                chunk.iter().all(|&b| b == expected),
                "chunk {i} data mismatch"
            );
        }
    }

    #[tokio::test]
    async fn test_execute_with_send_stream_split_header_data() {
        use bytes::Bytes;
        use futures::stream;

        let (session, mock) = open_with(vec![ok_response(0), ok_response(1)]).await;
        session.set_split_header_data(true);

        let chunks: Vec<Result<Bytes, std::io::Error>> = vec![
            Ok(Bytes::from_static(&[0xAA, 0xBB])),
            Ok(Bytes::from_static(&[0xCC, 0xDD])),
        ];
        let total_size = 4u64;
        session
            .execute_with_send_stream(
                OperationCode::SendObject,
                &[],
                total_size,
                stream::iter(chunks),
            )
            .await
            .unwrap();

        let sends = mock.get_sends();
        assert_eq!(sends.len(), 5);
        let header = &sends[2];
        assert_eq!(header.len(), HEADER_SIZE);
        assert_eq!(
            unpack_u32(&header[0..4]).unwrap() as usize,
            HEADER_SIZE + total_size as usize
        );
        assert_eq!(sends[3], &[0xAA, 0xBB]);
        assert_eq!(sends[4], &[0xCC, 0xDD]);
    }

    #[tokio::test]
    async fn test_execute_with_send_stream_split_skips_empty_chunks() {
        use bytes::Bytes;
        use futures::stream;

        let (session, mock) = open_with(vec![ok_response(0), ok_response(1)]).await;
        session.set_split_header_data(true);

        let chunks: Vec<Result<Bytes, std::io::Error>> = vec![
            Ok(Bytes::from_static(&[0xAA])),
            Ok(Bytes::new()),
            Ok(Bytes::from_static(&[0xBB])),
        ];
        let total_size = 2u64;
        session
            .execute_with_send_stream(
                OperationCode::SendObject,
                &[],
                total_size,
                stream::iter(chunks),
            )
            .await
            .unwrap();

        let sends = mock.get_sends();
        assert_eq!(sends.len(), 5);
        assert_eq!(sends[2].len(), HEADER_SIZE);
        assert_eq!(sends[3], &[0xAA]);
        assert_eq!(sends[4], &[0xBB]);
    }

    #[tokio::test]
    async fn test_close_session() {
        let (session, _mock) = open_with(vec![ok_response(0), ok_response(1)]).await;
        session.close().await.unwrap();
    }

    #[tokio::test]
    async fn test_close_session_ignores_errors() {
        let (session, _mock) = open_with(vec![
            ok_response(0),
            response_with_params(1, ResponseCode::GeneralError, &[]),
        ])
        .await;
        session.close().await.unwrap();
    }
}