use crate::ptp::{
container_type, unpack_u32, CommandContainer, ContainerType, DataContainer, DeviceInfo,
OperationCode, PtpSession, ResponseCode, ResponseContainer,
};
use crate::transport::{NusbTransport, Transport};
use crate::Error;
use std::sync::Arc;
use std::time::Duration;
pub struct PtpDevice {
transport: Arc<dyn Transport>,
}
impl PtpDevice {
pub async fn open_by_location(location_id: u64) -> Result<Self, Error> {
Self::open_by_location_with_timeout(location_id, NusbTransport::DEFAULT_TIMEOUT).await
}
pub async fn open_by_location_with_timeout(
location_id: u64,
timeout: Duration,
) -> Result<Self, Error> {
let devices = NusbTransport::list_mtp_devices()?;
let device_info = devices
.into_iter()
.find(|d| d.location_id == location_id)
.ok_or(Error::NoDevice)?;
Self::open_device(device_info, timeout).await
}
pub async fn open_by_serial(serial: &str) -> Result<Self, Error> {
Self::open_by_serial_with_timeout(serial, NusbTransport::DEFAULT_TIMEOUT).await
}
pub async fn open_by_serial_with_timeout(
serial: &str,
timeout: Duration,
) -> Result<Self, Error> {
let devices = NusbTransport::list_mtp_devices()?;
let device_info = devices
.into_iter()
.find(|d| d.serial_number.as_deref() == Some(serial))
.ok_or(Error::NoDevice)?;
Self::open_device(device_info, timeout).await
}
pub async fn open_first() -> Result<Self, Error> {
Self::open_first_with_timeout(NusbTransport::DEFAULT_TIMEOUT).await
}
pub async fn open_first_with_timeout(timeout: Duration) -> Result<Self, Error> {
let devices = NusbTransport::list_mtp_devices()?;
let device_info = devices.into_iter().next().ok_or(Error::NoDevice)?;
Self::open_device(device_info, timeout).await
}
async fn open_device(
device_info: crate::transport::UsbDeviceInfo,
timeout: Duration,
) -> Result<Self, Error> {
let device = device_info.open().map_err(Error::Usb)?;
let transport = NusbTransport::open_with_timeout(device, timeout).await?;
Ok(Self {
transport: Arc::new(transport) as Arc<dyn Transport>,
})
}
pub async fn get_device_info(&self) -> Result<DeviceInfo, Error> {
let cmd = CommandContainer {
code: OperationCode::GetDeviceInfo,
transaction_id: 0,
params: vec![],
};
self.transport.send_bulk(&cmd.to_bytes()).await?;
let mut data_payload = Vec::new();
loop {
let mut bytes = self.transport.receive_bulk(64 * 1024).await?;
if bytes.is_empty() {
return Err(Error::invalid_data("Empty response"));
}
let ct = container_type(&bytes)?;
match ct {
ContainerType::Data => {
if bytes.len() >= 4 {
let total_length = unpack_u32(&bytes[0..4])? as usize;
while bytes.len() < total_length {
let more = self.transport.receive_bulk(64 * 1024).await?;
if more.is_empty() {
return Err(Error::invalid_data(
"Incomplete data container: device stopped sending",
));
}
bytes.extend_from_slice(&more);
}
}
let container = DataContainer::from_bytes(&bytes)?;
data_payload.extend_from_slice(&container.payload);
}
ContainerType::Response => {
let response = ResponseContainer::from_bytes(&bytes)?;
if response.code != ResponseCode::Ok {
return Err(Error::Protocol {
code: response.code,
operation: OperationCode::GetDeviceInfo,
});
}
break;
}
_ => {
return Err(Error::invalid_data(format!(
"Unexpected container type: {:?}",
ct
)));
}
}
}
DeviceInfo::from_bytes(&data_payload)
}
pub async fn open_session(&self) -> Result<PtpSession, Error> {
self.open_session_with_id(1).await
}
pub async fn open_session_with_id(&self, session_id: u32) -> Result<PtpSession, Error> {
PtpSession::open(Arc::clone(&self.transport), session_id).await
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ptp::pack::{pack_u16, pack_u16_array, pack_u32};
    use crate::transport::mock::MockTransport;

    /// Size in bytes of a PTP container header: length (u32), type (u16), code (u16) and
    /// transaction ID (u32).
    const HEADER_LEN: u32 = 12;

    #[tokio::test]
    #[ignore] // Needs a physical MTP device attached over USB.
    async fn test_open_first() {
        let device = PtpDevice::open_first().await.unwrap();
        let info = device.get_device_info().await.unwrap();
        println!("Model: {}", info.model);
    }

    #[tokio::test]
    #[ignore] // Needs a physical MTP device attached over USB.
    async fn test_open_session() {
        let device = PtpDevice::open_first().await.unwrap();
        let session = device.open_session().await.unwrap();
        let info = session.get_device_info().await.unwrap();
        println!("Model: {}", info.model);
        session.close().await.unwrap();
    }

    /// Builds a minimal DeviceInfo dataset. The standard version is 100, every other
    /// integer is zero, and every array and string is empty. Field order follows the PTP
    /// DeviceInfo dataset.
    fn minimal_device_info_payload() -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(&pack_u16(100)); // StandardVersion
        buf.extend_from_slice(&pack_u32(0)); // VendorExtensionID
        buf.extend_from_slice(&pack_u16(0)); // VendorExtensionVersion
        buf.push(0x00); // VendorExtensionDesc: empty string
        buf.extend_from_slice(&pack_u16(0)); // FunctionalMode
        buf.extend_from_slice(&pack_u16_array(&[])); // OperationsSupported
        buf.extend_from_slice(&pack_u16_array(&[])); // EventsSupported
        buf.extend_from_slice(&pack_u16_array(&[])); // DevicePropertiesSupported
        buf.extend_from_slice(&pack_u16_array(&[])); // CaptureFormats
        buf.extend_from_slice(&pack_u16_array(&[])); // ImageFormats
        buf.push(0x00); // Manufacturer: empty string
        buf.push(0x00); // Model: empty string
        buf.push(0x00); // DeviceVersion: empty string
        buf.push(0x00); // SerialNumber: empty string
        buf
    }

    /// Builds only the header of a GetDeviceInfo data container announcing `total_length`
    /// bytes, header included, with transaction ID 0.
    fn data_container_header(total_length: u32) -> Vec<u8> {
        let mut buf = Vec::with_capacity(HEADER_LEN as usize);
        buf.extend_from_slice(&pack_u32(total_length));
        buf.extend_from_slice(&pack_u16(ContainerType::Data.to_code()));
        buf.extend_from_slice(&pack_u16(OperationCode::GetDeviceInfo.into()));
        buf.extend_from_slice(&pack_u32(0));
        buf
    }

    /// Builds a complete data container (header plus payload) for a GetDeviceInfo reply.
    fn data_container(payload: &[u8]) -> Vec<u8> {
        let mut buf = data_container_header(HEADER_LEN + payload.len() as u32);
        buf.extend_from_slice(payload);
        buf
    }

    /// Builds a session-less (transaction ID 0) response container carrying `code`, with no
    /// parameters.
    fn response_session_less(code: u16) -> Vec<u8> {
        let mut buf = Vec::with_capacity(HEADER_LEN as usize);
        buf.extend_from_slice(&pack_u32(HEADER_LEN));
        buf.extend_from_slice(&pack_u16(ContainerType::Response.to_code()));
        buf.extend_from_slice(&pack_u16(code));
        buf.extend_from_slice(&pack_u32(0));
        buf
    }

    /// Session-less response container with the OK response code.
    fn ok_response_session_less() -> Vec<u8> {
        response_session_less(ResponseCode::Ok.into())
    }

    /// Wraps a mock transport in a `PtpDevice` without touching real USB.
    fn device_with(mock: Arc<MockTransport>) -> PtpDevice {
        PtpDevice {
            transport: mock as Arc<dyn Transport>,
        }
    }

    #[tokio::test]
    async fn get_device_info_handles_split_header_and_data() {
        let mock = Arc::new(MockTransport::new());
        let payload = minimal_device_info_payload();
        // The header and the payload arrive as separate bulk transfers.
        mock.queue_response(data_container_header(HEADER_LEN + payload.len() as u32));
        mock.queue_response(payload);
        mock.queue_response(ok_response_session_less());
        let info = device_with(mock).get_device_info().await.unwrap();
        assert_eq!(info.standard_version, 100);
    }

    #[tokio::test]
    async fn get_device_info_handles_combined_header_and_data() {
        let mock = Arc::new(MockTransport::new());
        // The header and the payload arrive in a single bulk transfer.
        mock.queue_response(data_container(&minimal_device_info_payload()));
        mock.queue_response(ok_response_session_less());
        let info = device_with(mock).get_device_info().await.unwrap();
        assert_eq!(info.standard_version, 100);
    }

    #[tokio::test]
    async fn get_device_info_fails_when_data_phase_is_truncated() {
        let mock = Arc::new(MockTransport::new());
        let payload = minimal_device_info_payload();
        // The header promises a payload, but the next transfer is empty.
        mock.queue_response(data_container_header(HEADER_LEN + payload.len() as u32));
        mock.queue_response(Vec::<u8>::new());
        let result = device_with(mock).get_device_info().await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn get_device_info_reports_non_ok_response() {
        let mock = Arc::new(MockTransport::new());
        mock.queue_response(data_container(&minimal_device_info_payload()));
        // 0x2002 is the PTP "General Error" response code.
        mock.queue_response(response_session_less(0x2002));
        let result = device_with(mock).get_device_info().await;
        assert!(matches!(result, Err(Error::Protocol { .. })));
    }
}