use super::com::{self, WpdDevice};
use super::props::map_hresult;
use crate::cancel::CancelToken;
use crate::mtp::backend::ByteRange;
use crate::mtp::object::NewObjectInfo;
use crate::mtp::{
Capabilities, DeviceEvent, DeviceInfo, Error, ObjectHandle, ObjectInfo, StorageId, StorageInfo,
};
use bytes::Bytes;
use futures::channel::mpsc::UnboundedSender;
use futures::channel::{mpsc, oneshot};
use futures::executor::block_on;
use futures::{SinkExt, StreamExt};
use std::sync::mpsc as std_mpsc;
use std::thread::JoinHandle;
use windows::Win32::System::Com::{CoInitializeEx, CoUninitialize, COINIT_MULTITHREADED};
/// Largest number of bytes (256 KiB) requested from a device stream in a single read.
const CHUNK: usize = 256 * 1024;
/// Buffer size passed to the bounded channel that carries download chunks from the worker.
pub(crate) const DATA_BOUND: usize = 4;
/// Selects which enumerated WPD device `open_device` should open.
pub(crate) enum OpenSpec {
/// Open the first entry returned by device enumeration.
First,
/// Open the device whose `device_info().serial_number` equals this string.
Serial(String),
/// Open the device whose PnP id carries this USB vendor/product id pair.
UsbDevice {
/// USB serial, consulted only to pick among several entries sharing the same VID/PID.
serial: Option<String>,
/// USB vendor id, compared with the `vid_` field parsed from the PnP id.
vid: u16,
/// USB product id, compared with the `pid_` field parsed from the PnP id.
pid: u16,
},
}
/// Successful reply to `Request::Download`, sent before any payload is read.
pub(crate) struct DownloadStart {
/// Size reported by the device for the whole object (not the length of the requested range).
pub(crate) size: u64,
/// Bounded channel yielding the payload chunk by chunk; a read failure arrives as a final `Err`.
pub(crate) data: mpsc::Receiver<Result<Bytes, Error>>,
}
/// Outcome of a `Request::Upload`, reported once the data channel has ended or a step failed.
pub(crate) enum UploadReply {
/// The byte count matched `NewObjectInfo::size` and the commit returned this handle.
Committed(ObjectHandle),
/// The data channel ended with a byte count different from `NewObjectInfo::size`;
/// `partial` is the result of looking the file name up under the target parent afterwards.
ShortClosed { partial: Option<ObjectHandle> },
/// Creating the upload stream, writing a chunk, or committing failed.
Error(Error),
}
/// Messages accepted by the worker thread. Every variant except `Shutdown` carries a
/// `oneshot` sender on which the worker posts the result of the corresponding device call.
pub(crate) enum Request {
/// List all storages of the device.
Storages(oneshot::Sender<Result<Vec<StorageInfo>, Error>>),
/// Fetch the description of a single storage.
StorageInfo(StorageId, oneshot::Sender<Result<StorageInfo, Error>>),
/// List the children of `parent` on `storage`.
List {
storage: StorageId,
parent: Option<ObjectHandle>,
// Forwarded to the device's `list` call as an optional reference.
cancel: Option<CancelToken>,
reply: oneshot::Sender<Result<Vec<ObjectInfo>, Error>>,
},
/// Fetch the metadata of one object.
ObjectInfo(ObjectHandle, oneshot::Sender<Result<ObjectInfo, Error>>),
/// Stream an object's content; the reply carries the size and the chunk channel.
Download {
obj: ObjectHandle,
// Start offset, plus a byte limit when this is `ByteRange::Range`.
range: ByteRange,
reply: oneshot::Sender<Result<DownloadStart, Error>>,
},
/// Read part of an object into a single buffer.
ReadRange {
obj: ObjectHandle,
offset: u64,
// `None` reads until the stream reports end of data.
len: Option<u32>,
reply: oneshot::Sender<Result<Vec<u8>, Error>>,
},
/// Read the object's thumbnail stream to its end.
Thumbnail {
obj: ObjectHandle,
reply: oneshot::Sender<Result<Vec<u8>, Error>>,
},
/// Create a folder named `name` under `parent` on `storage`.
CreateFolder {
storage: StorageId,
parent: Option<ObjectHandle>,
name: String,
reply: oneshot::Sender<Result<ObjectHandle, Error>>,
},
/// Create a new object and fill it with the chunks received on `data`.
Upload {
storage: StorageId,
parent: Option<ObjectHandle>,
info: NewObjectInfo,
// The upload is finalised once this channel yields `None`.
data: mpsc::Receiver<Bytes>,
reply: oneshot::Sender<UploadReply>,
},
/// Delete an object.
Delete {
obj: ObjectHandle,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Rename an object to `name`.
Rename {
obj: ObjectHandle,
name: String,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Move an object under `new_parent` on `new_storage`.
MoveObject {
obj: ObjectHandle,
new_parent: ObjectHandle,
new_storage: StorageId,
reply: oneshot::Sender<Result<(), Error>>,
},
/// Copy an object under `new_parent` on `new_storage`; the reply is the copy's handle.
CopyObject {
obj: ObjectHandle,
new_parent: ObjectHandle,
new_storage: StorageId,
reply: oneshot::Sender<Result<ObjectHandle, Error>>,
},
/// Ends the worker's request loop.
Shutdown,
}
/// Owner of the `wpd-com-worker` thread: requests go out over a std channel and the
/// thread is joined when the handle is dropped.
pub(crate) struct WpdHandle {
/// Sending side of the request queue consumed by `worker_main`.
req_tx: std_mpsc::Sender<Request>,
/// Join handle of the worker thread; taken (left `None`) during drop.
join: Option<JoinHandle<()>>,
}
impl WpdHandle {
pub(crate) async fn spawn(
spec: OpenSpec,
) -> Result<
(
Self,
DeviceInfo,
Capabilities,
mpsc::UnboundedReceiver<DeviceEvent>,
),
Error,
> {
let (req_tx, req_rx) = std_mpsc::channel::<Request>();
let (startup_tx, startup_rx) =
oneshot::channel::<Result<(DeviceInfo, Capabilities), Error>>();
let (event_tx, event_rx) = mpsc::unbounded::<DeviceEvent>();
let join = std::thread::Builder::new()
.name("wpd-com-worker".into())
.spawn(move || worker_main(spec, startup_tx, req_rx, event_tx))
.map_err(|e| Error::Io {
message: format!("failed to spawn WPD worker thread: {e}"),
})?;
let (device_info, capabilities) = startup_rx
.await
.map_err(|_| Error::Disconnected)? ?;
Ok((
Self {
req_tx,
join: Some(join),
},
device_info,
capabilities,
event_rx,
))
}
pub(crate) async fn call<T, F>(&self, make: F) -> Result<T, Error>
where
F: FnOnce(oneshot::Sender<Result<T, Error>>) -> Request,
{
let (tx, rx) = oneshot::channel();
self.req_tx
.send(make(tx))
.map_err(|_| Error::Disconnected)?;
rx.await.map_err(|_| Error::Disconnected)?
}
pub(crate) fn send(&self, req: Request) -> Result<(), Error> {
self.req_tx.send(req).map_err(|_| Error::Disconnected)
}
pub(crate) fn shutdown(&self) {
let _ = self.req_tx.send(Request::Shutdown);
}
}
impl Drop for WpdHandle {
    /// Requests shutdown and then blocks until the worker thread has exited.
    fn drop(&mut self) {
        // Best effort: the send fails only when the worker's receiver is already gone.
        let _ = self.req_tx.send(Request::Shutdown);

        let Some(worker) = self.join.take() else {
            return;
        };
        // A panic inside the worker surfaces here as `Err`; it is deliberately ignored.
        let _ = worker.join();
    }
}
/// Body of the `wpd-com-worker` thread.
///
/// Initialises COM (multithreaded), opens the device selected by `spec`, reports the
/// outcome on `startup`, registers `event_tx` for device events, and then serves
/// requests from `req_rx` until `Request::Shutdown` arrives or every sender is gone.
/// The device is dropped before COM is uninitialised.
fn worker_main(
    spec: OpenSpec,
    startup: oneshot::Sender<Result<(DeviceInfo, Capabilities), Error>>,
    req_rx: std_mpsc::Receiver<Request>,
    event_tx: UnboundedSender<DeviceEvent>,
) {
    let com_init = unsafe { CoInitializeEx(None, COINIT_MULTITHREADED) }.ok();
    if let Err(e) = com_init {
        // COM was never initialised on this thread, so there is nothing to balance.
        let _ = startup.send(Err(map_hresult(e)));
        return;
    }

    let mut dev = match unsafe { open_device(spec) } {
        Ok(opened) => opened,
        Err(e) => {
            let _ = startup.send(Err(e));
            unsafe { CoUninitialize() };
            return;
        }
    };
    let _ = startup.send(Ok((dev.device_info().clone(), *dev.capabilities())));

    unsafe { dev.register_events(event_tx) };

    // Replies are best effort throughout: a caller that stopped waiting is not an error.
    for req in req_rx.iter() {
        match req {
            Request::Shutdown => break,
            Request::Storages(reply) => {
                let storages = unsafe { dev.storages() };
                let _ = reply.send(storages);
            }
            Request::StorageInfo(storage, reply) => {
                let info = unsafe { dev.storage_info(storage) };
                let _ = reply.send(info);
            }
            Request::List {
                storage,
                parent,
                cancel,
                reply,
            } => {
                let listing = unsafe { dev.list(storage, parent, cancel.as_ref()) };
                let _ = reply.send(listing);
            }
            Request::ObjectInfo(obj, reply) => {
                let info = unsafe { dev.object_info(obj) };
                let _ = reply.send(info);
            }
            Request::Download { obj, range, reply } => {
                handle_download(&mut dev, obj, range, reply);
            }
            Request::ReadRange {
                obj,
                offset,
                len,
                reply,
            } => {
                let bytes = handle_read_range(&mut dev, obj, offset, len);
                let _ = reply.send(bytes);
            }
            Request::Thumbnail { obj, reply } => {
                let bytes = handle_thumbnail(&mut dev, obj);
                let _ = reply.send(bytes);
            }
            Request::CreateFolder {
                storage,
                parent,
                name,
                reply,
            } => {
                let created = unsafe { dev.create_folder(storage, parent, &name) };
                let _ = reply.send(created);
            }
            Request::Upload {
                storage,
                parent,
                info,
                data,
                reply,
            } => {
                handle_upload(&mut dev, storage, parent, info, data, reply);
            }
            Request::Delete { obj, reply } => {
                let outcome = unsafe { dev.delete(obj) };
                let _ = reply.send(outcome);
            }
            Request::Rename { obj, name, reply } => {
                let outcome = unsafe { dev.rename(obj, &name) };
                let _ = reply.send(outcome);
            }
            Request::MoveObject {
                obj,
                new_parent,
                new_storage,
                reply,
            } => {
                let outcome = unsafe { dev.move_object(obj, new_parent, new_storage) };
                let _ = reply.send(outcome);
            }
            Request::CopyObject {
                obj,
                new_parent,
                new_storage,
                reply,
            } => {
                let copied = unsafe { dev.copy_object(obj, new_parent, new_storage) };
                let _ = reply.send(copied);
            }
        }
    }

    // Release the device before tearing COM down on this thread.
    drop(dev);
    unsafe { CoUninitialize() };
}
unsafe fn open_device(spec: OpenSpec) -> Result<WpdDevice, Error> {
let entries = com::enumerate()?;
match spec {
OpenSpec::First => {
let first = entries.first().ok_or(Error::NoDevice)?;
WpdDevice::open(&first.pnp_id)
}
OpenSpec::Serial(serial) => {
for entry in &entries {
if let Ok(dev) = WpdDevice::open(&entry.pnp_id) {
if dev.device_info().serial_number == serial {
return Ok(dev);
}
}
}
Err(Error::NoDevice)
}
OpenSpec::UsbDevice { serial, vid, pid } => {
let candidates: Vec<&com::DeviceEntry> = entries
.iter()
.filter(|e| pnp_vid_pid(&e.pnp_id) == Some((vid, pid)))
.collect();
match candidates.as_slice() {
[] => Err(Error::NoDevice),
[only] => WpdDevice::open(&only.pnp_id),
many => {
if let Some(target) = serial.as_deref() {
for entry in many {
if com::wpd_device_usb_serial(&entry.pnp_id)
.is_some_and(|s| s.eq_ignore_ascii_case(target))
{
return WpdDevice::open(&entry.pnp_id);
}
}
}
Err(Error::Other {
detail: format!(
"{} WPD devices share VID/PID {vid:04x}:{pid:04x} and none matched the \
USB serial {serial:?}; open by (WPD) serial instead",
many.len()
),
})
}
}
}
}
}
/// Extracts the USB vendor and product ids from a PnP device id.
///
/// The id is searched case-insensitively for `vid_` and `pid_`; the four characters
/// after the first occurrence of each marker must all be hexadecimal digits.
/// Returns `None` when either marker is missing, is followed by fewer than four
/// characters, or is followed by anything other than four hex digits.
fn pnp_vid_pid(pnp_id: &str) -> Option<(u16, u16)> {
    let lower = pnp_id.to_ascii_lowercase();
    let hex4 = |marker: &str| -> Option<u16> {
        // Text after the first occurrence of the marker, cut to exactly four bytes
        // (`get` yields `None` when too short or not on a character boundary).
        let digits = lower.split(marker).nth(1)?.get(0..4)?;
        // `from_str_radix` also accepts a leading `+`, which would let a malformed
        // field such as `vid_+4e8` pass as a valid id; insist on four hex digits.
        if !digits.bytes().all(|b| b.is_ascii_hexdigit()) {
            return None;
        }
        u16::from_str_radix(digits, 16).ok()
    };
    Some((hex4("vid_")?, hex4("pid_")?))
}
/// Serves a `Request::Download`: validates the request, answers on `reply` with the
/// object size and a bounded chunk channel, then pumps stream data into that channel.
///
/// Setup failures (size query, offset past the object size, opening or seeking the
/// stream) are reported on `reply`. Once the `DownloadStart` has been delivered, a read
/// failure is sent as a final `Err` item on the data channel. Pumping stops at end of
/// stream, when a `ByteRange::Range` limit is exhausted, or when the consumer drops the
/// receiver.
///
/// `block_on` parks the worker thread while the channel is full, so the consumer's pace
/// bounds how far ahead the worker reads.
fn handle_download(
    dev: &mut WpdDevice,
    obj: ObjectHandle,
    range: ByteRange,
    reply: oneshot::Sender<Result<DownloadStart, Error>>,
) {
    let size = match unsafe { dev.object_size(obj) } {
        Ok(s) => s,
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    let offset = range.offset();
    if offset > size {
        let _ = reply.send(Err(Error::invalid_data(format!(
            "download offset {offset} is past the object size {size}"
        ))));
        return;
    }
    let stream = match unsafe { dev.open_stream(obj) } {
        Ok(s) => s,
        Err(e) => {
            let _ = reply.send(Err(e));
            return;
        }
    };
    if let Err(e) = unsafe { com::stream_seek(&stream, offset) } {
        let _ = reply.send(Err(e));
        return;
    }

    let (mut tx, rx) = mpsc::channel::<Result<Bytes, Error>>(DATA_BOUND);
    if reply.send(Ok(DownloadStart { size, data: rx })).is_err() {
        // The requester stopped waiting; nobody would consume the data.
        return;
    }

    // `Some(n)`: at most `n` more bytes may be sent; `None`: read to end of stream.
    let mut remaining: Option<u64> = match range {
        ByteRange::Range { len, .. } => Some(len),
        _ => None,
    };
    let mut buf = vec![0u8; CHUNK];
    loop {
        let want = match remaining {
            // Narrow without truncation: with a plain `as usize` cast on a 32-bit
            // target, a remaining count that is a multiple of 2^32 would become 0 and
            // end the download early. A value too large for `usize` is necessarily
            // larger than the buffer.
            Some(left) => usize::try_from(left).map_or(buf.len(), |left| left.min(buf.len())),
            None => buf.len(),
        };
        if want == 0 {
            break;
        }
        let n = match unsafe { com::stream_read(&stream, &mut buf[..want]) } {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) => {
                let _ = block_on(tx.send(Err(e)));
                break;
            }
        };
        if block_on(tx.send(Ok(Bytes::copy_from_slice(&buf[..n])))).is_err() {
            // Receiver dropped: the consumer cancelled the download.
            break;
        }
        if let Some(left) = remaining.as_mut() {
            // `n` is at most `want`, which never exceeds `left`.
            *left -= n as u64;
        }
    }
}
/// Reads up to `len` bytes of `obj` starting at `offset` into one buffer.
///
/// With `len == None` the stream is read until it reports end of data. Fewer than `len`
/// bytes are returned when the stream ends first.
///
/// # Errors
///
/// Propagates any error from opening, seeking, or reading the stream.
fn handle_read_range(
    dev: &mut WpdDevice,
    obj: ObjectHandle,
    offset: u64,
    len: Option<u32>,
) -> Result<Vec<u8>, Error> {
    let stream = unsafe { dev.open_stream(obj) }?;
    unsafe { com::stream_seek(&stream, offset) }?;
    let cap = len.map(|l| l as usize);

    // `len` is caller-supplied and may far exceed what the object holds, so reserve at
    // most one chunk up front and let the vector grow with the data actually read.
    // Reserving the full length could demand up to 4 GiB and, on 32-bit targets, panic
    // with a capacity overflow.
    let mut out = Vec::with_capacity(cap.map_or(0, |c| c.min(CHUNK)));
    // The scratch buffer never needs to be larger than the request itself.
    let mut buf = vec![0u8; cap.map_or(CHUNK, |c| c.min(CHUNK))];
    loop {
        let want = match cap {
            // `out.len()` never exceeds `c`: each read is limited to the shortfall.
            Some(c) => (c - out.len()).min(buf.len()),
            None => buf.len(),
        };
        if want == 0 {
            break;
        }
        let n = unsafe { com::stream_read(&stream, &mut buf[..want]) }?;
        if n == 0 {
            break;
        }
        out.extend_from_slice(&buf[..n]);
    }
    Ok(out)
}
/// Reads the thumbnail stream of `obj` to its end and returns the collected bytes.
///
/// # Errors
///
/// Propagates any error from opening or reading the thumbnail stream.
fn handle_thumbnail(dev: &mut WpdDevice, obj: ObjectHandle) -> Result<Vec<u8>, Error> {
    let stream = unsafe { dev.open_thumbnail_stream(obj) }?;
    let mut image = Vec::new();
    let mut scratch = vec![0u8; CHUNK];
    loop {
        // A zero-length read marks the end of the stream.
        match unsafe { com::stream_read(&stream, &mut scratch) }? {
            0 => return Ok(image),
            got => image.extend_from_slice(&scratch[..got]),
        }
    }
}
/// Serves a `Request::Upload`: creates the upload stream, writes every chunk received
/// on `rx`, and reports the outcome on `reply`.
///
/// * Stream creation or a chunk write fails: `UploadReply::Error`.
/// * The channel ends with exactly `info.size` bytes written: the stream is committed
///   and the reply is `Committed` or `Error`, depending on the commit.
/// * The channel ends with any other byte count: the stream is dropped and the reply is
///   `ShortClosed`, carrying the result of looking `info.filename` up under the target
///   parent.
///
/// `block_on` parks the worker thread while it waits for the next chunk.
fn handle_upload(
    dev: &mut WpdDevice,
    storage: StorageId,
    parent: Option<ObjectHandle>,
    info: NewObjectInfo,
    mut rx: mpsc::Receiver<Bytes>,
    reply: oneshot::Sender<UploadReply>,
) {
    let stream = match unsafe { dev.create_upload_stream(storage, parent, &info) } {
        Ok(opened) => opened,
        Err(e) => {
            let _ = reply.send(UploadReply::Error(e));
            return;
        }
    };

    let mut written: u64 = 0;
    loop {
        let Some(chunk) = block_on(rx.next()) else {
            break;
        };
        match unsafe { com::stream_write(&stream, &chunk) } {
            Ok(_) => written += chunk.len() as u64,
            Err(e) => {
                // Release the stream before telling the caller about the failure.
                drop(stream);
                let _ = reply.send(UploadReply::Error(e));
                return;
            }
        }
    }

    if written != info.size {
        // The stream is released first, then the device is asked whether an object
        // with the target name exists.
        drop(stream);
        let partial = unsafe { dev.find_child_by_name(storage, parent, &info.filename) };
        let _ = reply.send(UploadReply::ShortClosed { partial });
        return;
    }

    let outcome = match unsafe { dev.commit_upload_stream(&stream) } {
        Ok(handle) => UploadReply::Committed(handle),
        Err(e) => UploadReply::Error(e),
    };
    let _ = reply.send(outcome);
}