use super::error::Error;
use super::{MtpEligibility, UsbDeviceDescriptor, UsbDeviceFlagSet};
use crate::error::MtpError;
use std::fmt::Debug;
use std::sync::Arc;
use std::time::Duration;
use bitflags::Flags;
use futures::stream::FuturesUnordered;
use futures::{Stream, StreamExt};
use mtp_spec::device::session::MtpSession;
pub use nusb;
use nusb::descriptors::TransferType;
use nusb::descriptors::language_id::US_ENGLISH;
use nusb::transfer::Direction;
/// A USB device considered for MTP access.
///
/// Bundles the enumeration-time [`nusb::DeviceInfo`] with whatever was learned
/// about the device while checking it for MTP eligibility.
#[derive(Clone)]
pub struct Device {
    /// Enumeration information for the device (vendor/product id, class, ...).
    info: nusb::DeviceInfo,
    /// Entry from `WELL_KNOWN_DEVICE_DESCRIPTORS` whose vendor/product id matches
    /// this device, if any.
    well_known_info: Option<UsbDeviceDescriptor>,
    /// Device flag set; copied from the matching well-known entry when there is
    /// one, otherwise empty. Passed on to the `DeviceHandle` when opened.
    flags: UsbDeviceFlagSet,
    /// Handle opened while probing the device's interface descriptors, if probing
    /// got that far.
    handle: Option<nusb::Device>,
}
impl Debug for Device {
    /// Prints only the enumeration info; the remaining fields are elided and the
    /// output is marked non-exhaustive (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("Device");
        builder.field("info", &self.info);
        builder.finish_non_exhaustive()
    }
}
impl From<nusb::DeviceInfo> for Device {
fn from(info: nusb::DeviceInfo) -> Self {
Self {
info,
well_known_info: None,
flags: UsbDeviceFlagSet::empty(),
handle: None,
}
}
}
impl Device {
#[allow(clippy::missing_panics_doc)] pub async fn open(
self,
) -> Result<MtpSession<super::handle::DeviceHandle>, super::error::Error> {
let handle = self.open_raw().await?;
MtpSession::open(handle)
.await
.map_err(|e| Error::Generic(Arc::new(e)))
}
#[allow(clippy::missing_panics_doc)] pub async fn open_raw(self) -> Result<super::handle::DeviceHandle, super::error::Error> {
const MTP_ENDPOINT_COUNT: u8 = 3;
let device = self.info.open().await.map_err(|e| Arc::new(e.into()))?;
let mut interface_num = None;
let mut endpoints = None;
'outer: for config in device.configurations() {
for interface in config.interfaces() {
for alt_settings in interface.alt_settings() {
if alt_settings.num_endpoints() != MTP_ENDPOINT_COUNT {
continue;
}
let mut bulk_in = None;
let mut bulk_in_buffer_size = 0;
let mut bulk_out = None;
let mut bulk_out_buffer_size = 0;
let mut interrupt = None;
let mut interrupt_buffer_size = 0;
for endpoint in alt_settings.endpoints() {
match endpoint.transfer_type() {
TransferType::Bulk => match endpoint.direction() {
Direction::In => {
bulk_in = Some(endpoint.address());
bulk_in_buffer_size = endpoint.max_packet_size();
},
Direction::Out => {
bulk_out = Some(endpoint.address());
bulk_out_buffer_size = endpoint.max_packet_size();
},
},
TransferType::Interrupt => {
interrupt = Some(endpoint.address());
interrupt_buffer_size = endpoint.max_packet_size();
},
_ => {},
}
}
let (Some(bulk_in), Some(bulk_out), Some(interrupt)) =
(bulk_in, bulk_out, interrupt)
else {
continue;
};
endpoints = Some(super::handle::Endpoints {
bulk_in,
bulk_in_buffer_size,
bulk_out,
bulk_out_buffer_size,
interrupt,
interrupt_buffer_size,
});
interface_num = Some(interface.interface_number());
break 'outer;
}
}
}
let Some(interface_num) = interface_num else {
return Err(Error::Core(MtpError::Transport(Arc::new(
super::error::UsbError::NoApplicableInterface,
))));
};
let interface = device
.claim_interface(interface_num)
.await
.map_err(|e| Arc::new(e.into()))?;
super::handle::DeviceHandle::new(device, self.flags, interface, endpoints.unwrap())
}
pub fn info(&self) -> &nusb::DeviceInfo {
&self.info
}
pub fn well_known_info(&self) -> Option<UsbDeviceDescriptor> {
self.well_known_info
}
}
impl Device {
    /// Decides whether this device should be offered as an MTP device.
    ///
    /// A device whose vendor/product id appears in `WELL_KNOWN_DEVICE_DESCRIPTORS`
    /// is eligible immediately, and that entry and its flags are recorded on
    /// `self`. Otherwise the device's interfaces are probed with
    /// `check_for_mtp_descriptor`.
    ///
    /// # Errors
    /// Propagates any error from `check_for_mtp_descriptor`, which currently
    /// always returns `Ok`.
    async fn check_mtp_eligibility(&mut self) -> Result<MtpEligibility, super::error::UsbError> {
        if let Some(well_known_entry) = super::WELL_KNOWN_DEVICE_DESCRIPTORS.iter().find(|d| {
            d.vendor_id == self.info.vendor_id() && d.product_id == self.info.product_id()
        }) {
            self.flags = well_known_entry.flags;
            self.well_known_info = Some(*well_known_entry);
            return Ok(MtpEligibility::Eligible);
        }
        if self.check_for_mtp_descriptor().await? {
            return Ok(MtpEligibility::Eligible);
        }
        Ok(MtpEligibility::Ineligible)
    }

    /// Probes the device for an interface whose name contains `"MTP"`.
    ///
    /// Only devices whose class is one of `LIKELY_DEVICE_CLASSES` are probed.
    /// Among those, only alternate settings with exactly three endpoints are
    /// considered. The device is opened so the interface string descriptors
    /// (US English) can be read, and the opened handle is kept in `self.handle`.
    ///
    /// This is a best-effort check. Failing to open the device, or to read a
    /// descriptor, counts as "not MTP" rather than an error.
    async fn check_for_mtp_descriptor(&mut self) -> Result<bool, super::error::UsbError> {
        const CLASS_PER_INTERFACE: u8 = 0;
        const CLASS_COMM: u8 = 2;
        const CLASS_PTP: u8 = 6;
        const CLASS_MISCELLANEOUS: u8 = 0xEF;
        const CLASS_VENDOR_SPECIFIC: u8 = 255;
        const LIKELY_DEVICE_CLASSES: &[u8] = &[
            CLASS_PER_INTERFACE,
            CLASS_COMM,
            CLASS_PTP,
            CLASS_MISCELLANEOUS,
            CLASS_VENDOR_SPECIFIC,
        ];
        // Same endpoint count `open_raw` requires: bulk IN, bulk OUT and interrupt.
        const MTP_ENDPOINT_COUNT: u8 = 3;

        if !LIKELY_DEVICE_CLASSES.contains(&self.info.class()) {
            return Ok(false);
        }
        let Ok(handle) = self.info.open().await else {
            return Ok(false);
        };
        let handle = &*self.handle.insert(handle);
        for config in handle.configurations() {
            for interface in config.interfaces() {
                for alt_settings in interface.alt_settings() {
                    if alt_settings.num_endpoints() != MTP_ENDPOINT_COUNT {
                        continue;
                    }
                    // Interfaces of every class, including vendor-specific
                    // (`CLASS_VENDOR_SPECIFIC`) ones, go through the name check below.
                    // Probing Microsoft OS descriptors for vendor-specific interfaces
                    // is not implemented; such interfaces are only detected when
                    // their name contains "MTP".
                    let Some(string_index) = alt_settings.string_index() else {
                        continue;
                    };
                    let timeout = Duration::from_secs(1);
                    let Ok(interface_name) = handle
                        .get_string_descriptor(string_index, US_ENGLISH, timeout)
                        .await
                    else {
                        continue;
                    };
                    if interface_name.contains("MTP") {
                        return Ok(true);
                    }
                }
            }
        }
        Ok(false)
    }
}
/// Enumerates connected USB devices and yields those eligible for MTP.
///
/// Every listed device is checked concurrently. The returned stream yields
/// `Ok(device)` for each eligible device and `Err(..)` for each device whose check
/// failed. Ineligible devices are skipped. Items arrive in completion order, not
/// enumeration order.
///
/// # Errors
/// Returns an error if the system's USB device list cannot be obtained.
pub async fn device_list() -> Result<
    impl Stream<Item = Result<Device, Arc<super::error::UsbError>>>,
    Arc<super::error::UsbError>,
> {
    let candidates = match nusb::list_devices().await {
        Ok(listing) => listing,
        Err(err) => return Err(Arc::new(err.into())),
    };
    let probes: FuturesUnordered<_> = candidates
        .map(Device::from)
        .map(|mut candidate| async move {
            match candidate.check_mtp_eligibility().await {
                Err(err) => Some(Err(Arc::new(err))),
                Ok(MtpEligibility::Ineligible) => None,
                Ok(MtpEligibility::Eligible) => Some(Ok(candidate)),
            }
        })
        .collect();
    Ok(probes.filter_map(std::future::ready))
}