#![allow(clippy::test_attr_in_doctest)]
use std::fmt::Debug;
use std::marker::PhantomData;
use std::{cell::RefCell, rc::Rc};
use usb_device::bus::{UsbBus, UsbBusAllocator};
use usb_device::class::UsbClass;
use usb_device::device::{StringDescriptors, UsbDevice, UsbDeviceBuilder, UsbVidPid};
use usb_device::endpoint::EndpointAddress;
use usb_device::prelude::BuilderError;
use usb_device::UsbDirection;
mod bus;
use bus::*;
mod usbdata;
use usbdata::*;
/// Convenience re-exports: the emulated bus, the control-request helper
/// types and the top-level testing API of this crate, so that a test can
/// bring all of them into scope with a single glob import of this module.
pub mod prelude {
pub use crate::bus::EmulatedUsbBus;
pub use crate::usbdata::{CtrRequestType, SetupPacket};
pub use crate::{with_usb, AnyResult, AnyUsbError, Device, UsbDeviceCtx};
}
/// Default value of [`UsbDeviceCtx::EP0_SIZE`].
const DEFAULT_EP0_SIZE: u8 = 8;
/// Default value of [`UsbDeviceCtx::ADDRESS`].
const DEFAULT_ADDRESS: u8 = 5;
/// Errors reported by the testing helpers of this crate.
#[derive(Debug, PartialEq)]
pub enum AnyUsbError {
/// Endpoint 0 was found stalled after the device had been polled.
EP0Stalled,
/// The device kept reporting activity while the OUT data stage of a
/// control transfer was being processed, until the poll limit was hit.
EP0ReadFailed,
/// A two-byte reply (e.g. to a status request) had a different length.
EP0BadGetStatusSize,
/// A one-byte reply (e.g. to a get-configuration request) had a
/// different length.
EP0BadGetConfigSize,
/// Received bytes could not be converted to the requested Rust type.
DataConversion,
/// The bus address after a set-address request is not the requested one.
SetAddressFailed,
/// A descriptor's length is inconsistent with the data received.
InvalidDescriptorLength,
/// A descriptor's type byte is not the expected one.
InvalidDescriptorType,
/// A string descriptor has an odd length.
InvalidStringLength,
/// Wraps an error returned by the `usb-device` device builder.
UsbDeviceBuilder(BuilderError),
/// Not produced by the code visible in this file.
/// NOTE(review): presumably reserved for use by test code - confirm.
UserDefined1,
/// Not produced by the code visible in this file.
/// NOTE(review): presumably reserved for use by test code - confirm.
UserDefined2,
/// Not produced by the code visible in this file.
/// NOTE(review): presumably a user-defined error with a numeric payload - confirm.
UserDefinedU64(u64),
/// Not produced by the code visible in this file.
/// NOTE(review): presumably a user-defined error with a text payload - confirm.
UserDefinedString(String),
}
/// Result type whose error is always [`AnyUsbError`].
pub type AnyResult<T> = core::result::Result<T, AnyUsbError>;
/// Describes how a test case creates the USB class under test and the
/// `UsbDevice` that hosts it.
///
/// Only [`create_class`](UsbDeviceCtx::create_class) has to be implemented;
/// every other item has a default.
pub trait UsbDeviceCtx<B: UsbBus, C: UsbClass<B>> {
    /// Maximum packet size of endpoint 0 handed to the device builder.
    const EP0_SIZE: u8 = DEFAULT_EP0_SIZE;

    /// Address assigned to the device by [`Device::setup`].
    const ADDRESS: u8 = DEFAULT_ADDRESS;

    /// Creates the USB class under test, allocating its resources from `alloc`.
    fn create_class(&mut self, alloc: &UsbBusAllocator<B>) -> AnyResult<C>;

    /// Hook invoked after each poll of the device. Does nothing by default.
    fn post_poll(&mut self, _cls: &mut C) {}

    /// Returns `true` when [`with_usb`] must not run [`Device::setup`]
    /// before handing control to the test case. Defaults to `false`.
    fn skip_setup(&mut self) -> bool {
        false
    }

    /// Builds the `UsbDevice` used by the test.
    ///
    /// The default implementation creates a self-powered device with fixed
    /// test VID/PID and string descriptors, and with
    /// [`EP0_SIZE`](UsbDeviceCtx::EP0_SIZE) as the maximum packet size of
    /// endpoint 0.
    ///
    /// # Errors
    ///
    /// Any builder failure is returned as [`AnyUsbError::UsbDeviceBuilder`].
    fn build_usb_device<'a>(
        &mut self,
        alloc: &'a UsbBusAllocator<B>,
    ) -> AnyResult<UsbDevice<'a, B>> {
        let descriptors = StringDescriptors::default()
            .manufacturer("TestManufacturer")
            .product("TestProduct")
            .serial_number("TestSerial");

        // Each fallible builder step is its own stage so that the error
        // mapping is visible right next to the call that can fail.
        let builder = UsbDeviceBuilder::new(alloc, UsbVidPid(0x1234, 0x5678))
            .strings(&[descriptors])
            .map_err(AnyUsbError::UsbDeviceBuilder)?;
        let builder = builder
            .device_release(0x0200)
            .self_powered(true)
            .max_power(250)
            .map_err(AnyUsbError::UsbDeviceBuilder)?;
        let builder = builder
            .max_packet_size_0(Self::EP0_SIZE)
            .map_err(AnyUsbError::UsbDeviceBuilder)?;

        Ok(builder.build())
    }
}
/// Host-side handle to the device under test.
///
/// Bundles the test context, the emulated bus state and the `UsbDevice`
/// instance, and offers helpers that perform control transfers against it.
pub struct Device<'a, C, X>
where
C: UsbClass<EmulatedUsbBus>,
X: UsbDeviceCtx<EmulatedUsbBus, C>,
{
// Test context; its `post_poll` hook is called after each device poll.
ctx: X,
// State of the emulated bus, used to queue OUT data and to collect IN data.
usb: &'a RefCell<UsbBusImpl>,
// The `usb-device` device instance being driven.
dev: UsbDevice<'a, EmulatedUsbBus>,
// Ties the class type `C` to this handle without storing a value of it.
_cls: PhantomData<C>,
}
impl<'a, C, X> Device<'a, C, X>
where
C: UsbClass<EmulatedUsbBus>,
X: UsbDeviceCtx<EmulatedUsbBus, C>,
{
fn new(usb: &'a RefCell<UsbBusImpl>, ctx: X, dev: UsbDevice<'a, EmulatedUsbBus>) -> Self {
Device {
usb,
ctx,
dev,
_cls: PhantomData,
}
}
/// Gives mutable access to the underlying `UsbDevice`.
pub fn usb_dev(&mut self) -> &mut UsbDevice<'a, EmulatedUsbBus> {
&mut self.dev
}
/// Performs a control transfer on endpoint 0 described by `setup`.
///
/// The setup packet is converted to its 8-byte form and the transfer is
/// delegated to [`Device::ep0_raw`]; see that method for the meaning of
/// `data`, `out`, the return value and the possible errors.
pub fn ep0(
    &mut self,
    d: &mut C,
    setup: SetupPacket,
    data: Option<&[u8]>,
    out: &mut [u8],
) -> core::result::Result<usize, AnyUsbError> {
    let raw_setup: [u8; 8] = Into::into(setup);
    self.ep0_raw(d, &raw_setup[..], data, out)
}
pub fn ep0_raw(
&mut self,
d: &mut C,
setup_bytes: &[u8],
data: Option<&[u8]>,
out: &mut [u8],
) -> core::result::Result<usize, AnyUsbError> {
let out0 = EndpointAddress::from_parts(0, UsbDirection::Out);
let in0 = EndpointAddress::from_parts(0, UsbDirection::In);
self.usb.borrow().set_read(out0, setup_bytes, true);
self.dev.poll(&mut [d]);
self.ctx.post_poll(d);
if self.usb.borrow().stalled0() {
return Err(AnyUsbError::EP0Stalled);
}
if let Some(val) = data {
self.usb.borrow().set_read(out0, val, false);
for i in 1..100 {
let res = self.dev.poll(&mut [d]);
self.ctx.post_poll(d);
if !res {
break;
}
if i >= 99 {
return Err(AnyUsbError::EP0ReadFailed);
}
}
if self.usb.borrow().stalled0() {
return Err(AnyUsbError::EP0Stalled);
}
};
let mut len = 0;
loop {
let one = self.usb.borrow().get_write(in0, &mut out[len..]);
self.dev.poll(&mut [d]);
self.ctx.post_poll(d);
if self.usb.borrow().stalled0() {
return Err(AnyUsbError::EP0Stalled);
}
len += one;
if one < DEFAULT_EP0_SIZE as usize {
break;
}
}
Ok(len)
}
/// Performs a control transfer built from the individual SETUP fields and
/// returns the device's response.
///
/// A response buffer of `length` bytes is allocated, the transfer is run
/// through [`Device::ep0`] with `data` as the optional OUT data stage, and
/// the buffer is shortened to the number of bytes actually received.
///
/// # Errors
///
/// Propagates any error returned by [`Device::ep0`].
#[allow(clippy::too_many_arguments)]
pub fn ep_io_control(
    &mut self,
    cls: &mut C,
    reqt: CtrRequestType,
    req: u8,
    value: u16,
    index: u16,
    length: u16,
    data: Option<&[u8]>,
) -> core::result::Result<Vec<u8>, AnyUsbError> {
    let setup = SetupPacket::new(reqt, req, value, index, length);
    let mut response = vec![0u8; usize::from(length)];
    let received = self.ep0(cls, setup, data, &mut response[..])?;
    response.truncate(received);
    Ok(response)
}
/// Performs a control transfer without an OUT data stage and returns the
/// bytes sent by the device.
///
/// Equivalent to [`Device::ep_io_control`] with no data.
pub fn control_read(
    &mut self,
    cls: &mut C,
    reqt: CtrRequestType,
    req: u8,
    value: u16,
    index: u16,
    length: u16,
) -> core::result::Result<Vec<u8>, AnyUsbError> {
    let no_payload: Option<&[u8]> = None;
    self.ep_io_control(cls, reqt, req, value, index, length, no_payload)
}
/// Performs a control transfer with `data` as the OUT data stage and
/// returns whatever the device sent back.
///
/// Equivalent to [`Device::ep_io_control`] with `Some(data)`.
#[allow(clippy::too_many_arguments)]
pub fn control_write(
    &mut self,
    cls: &mut C,
    reqt: CtrRequestType,
    req: u8,
    value: u16,
    index: u16,
    length: u16,
    data: &[u8],
) -> core::result::Result<Vec<u8>, AnyUsbError> {
    let payload = Some(data);
    self.ep_io_control(cls, reqt, req, value, index, length, payload)
}
/// Sends a device-level request 0 (Get Status) and returns the reply
/// decoded as a little-endian 16-bit value.
///
/// # Errors
///
/// [`AnyUsbError::EP0BadGetStatusSize`] if the reply is not exactly two
/// bytes long; transfer errors are propagated.
pub fn device_get_status(&mut self, cls: &mut C) -> core::result::Result<u16, AnyUsbError> {
    let reply = self.control_read(cls, CtrRequestType::to_host(), 0, 0, 0, 2)?;
    match reply.as_slice() {
        [lo, hi] => Ok(u16::from_le_bytes([*lo, *hi])),
        _ => Err(AnyUsbError::EP0BadGetStatusSize),
    }
}
/// Sends a device-level request 1 (Clear Feature) for `feature`.
///
/// Any data returned by the device is discarded.
pub fn device_clear_feature(
    &mut self,
    cls: &mut C,
    feature: u16,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device();
    self.control_write(cls, recipient, 1, feature, 0, 0, &[])
        .map(drop)
}
/// Sends a device-level request 3 (Set Feature) for `feature`.
///
/// Any data returned by the device is discarded.
pub fn device_set_feature(
    &mut self,
    cls: &mut C,
    feature: u16,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device();
    self.control_write(cls, recipient, 3, feature, 0, 0, &[])
        .map(drop)
}
/// Sends a device-level request 5 (Set Address) carrying `address` in the
/// value field.
///
/// Any data returned by the device is discarded.
pub fn device_set_address(
    &mut self,
    cls: &mut C,
    address: u8,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device();
    let value = u16::from(address);
    self.control_write(cls, recipient, 5, value, 0, 0, &[])
        .map(drop)
}
/// Sends a device-level request 6 (Get Descriptor) and returns the raw
/// descriptor bytes.
///
/// The value field carries `dtype` in its high byte and `dindex` in its
/// low byte; `lang_id` goes into the index field and `length` is the
/// maximum number of bytes requested.
pub fn device_get_descriptor(
    &mut self,
    cls: &mut C,
    dtype: u8,
    dindex: u8,
    lang_id: u16,
    length: u16,
) -> core::result::Result<Vec<u8>, AnyUsbError> {
    // High byte: descriptor type, low byte: descriptor index.
    let selector = u16::from_be_bytes([dtype, dindex]);
    let direction = CtrRequestType::to_host();
    self.control_read(cls, direction, 6, selector, lang_id, length)
}
/// Fetches string descriptor `index` for language `lang_id` and decodes
/// it into a `String`.
///
/// The descriptor (type 3) is requested with a maximum length of 255
/// bytes. Its two-byte header is validated and the remaining bytes are
/// decoded as little-endian UTF-16.
///
/// # Errors
///
/// * [`AnyUsbError::InvalidDescriptorLength`] - fewer than two bytes were
///   received, or the length byte differs from the received length.
/// * [`AnyUsbError::InvalidDescriptorType`] - the type byte is not 3.
/// * [`AnyUsbError::InvalidStringLength`] - the length byte is odd.
/// * [`AnyUsbError::DataConversion`] - the payload is not valid UTF-16.
///
/// Transfer errors are propagated.
pub fn device_get_string(
    &mut self,
    cls: &mut C,
    index: u8,
    lang_id: u16,
) -> core::result::Result<String, AnyUsbError> {
    // High byte: descriptor type 3 (string), low byte: string index.
    let selector = u16::from_be_bytes([3, index]);
    let descr = self.control_read(cls, CtrRequestType::to_host(), 6, selector, lang_id, 255)?;

    let (declared_len, kind) = match descr.as_slice() {
        [len, kind, ..] => (*len, *kind),
        _ => return Err(AnyUsbError::InvalidDescriptorLength),
    };
    if usize::from(declared_len) != descr.len() {
        return Err(AnyUsbError::InvalidDescriptorLength);
    }
    if kind != 3 {
        return Err(AnyUsbError::InvalidDescriptorType);
    }
    if declared_len & 1 == 1 {
        return Err(AnyUsbError::InvalidStringLength);
    }

    // The total length is even here, so the payload splits into whole
    // 16-bit code units with nothing left over.
    let units: Vec<u16> = descr[2..]
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
        .collect();
    String::from_utf16(&units).map_err(|_| AnyUsbError::DataConversion)
}
/// Sends a device-level request 7 (Set Descriptor) with `data` as the
/// OUT data stage.
///
/// The value field carries `dtype` in its high byte and `dindex` in its
/// low byte; `lang_id` goes into the index field and `length` into the
/// length field. Any data returned by the device is discarded.
#[allow(clippy::too_many_arguments)]
pub fn device_set_descriptor(
    &mut self,
    cls: &mut C,
    dtype: u8,
    dindex: u8,
    lang_id: u16,
    length: u16,
    data: &[u8],
) -> core::result::Result<(), AnyUsbError> {
    // High byte: descriptor type, low byte: descriptor index.
    let selector = u16::from_be_bytes([dtype, dindex]);
    let recipient = CtrRequestType::to_device();
    self.control_write(cls, recipient, 7, selector, lang_id, length, data)
        .map(drop)
}
/// Sends a device-level request 8 (Get Configuration) and returns the
/// single byte of the reply.
///
/// # Errors
///
/// [`AnyUsbError::EP0BadGetConfigSize`] if the reply is not exactly one
/// byte long; transfer errors are propagated.
pub fn device_get_configuration(
    &mut self,
    cls: &mut C,
) -> core::result::Result<u8, AnyUsbError> {
    let reply = self.control_read(cls, CtrRequestType::to_host(), 8, 0, 0, 1)?;
    match reply.as_slice() {
        [configuration] => Ok(*configuration),
        _ => Err(AnyUsbError::EP0BadGetConfigSize),
    }
}
/// Sends a device-level request 9 (Set Configuration) carrying
/// `configuration` in the value field.
///
/// Any data returned by the device is discarded.
pub fn device_set_configuration(
    &mut self,
    cls: &mut C,
    configuration: u8,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device();
    let value = u16::from(configuration);
    self.control_write(cls, recipient, 9, value, 0, 0, &[])
        .map(drop)
}
/// Sends an interface-level request 0 (Get Status) with `interface` in
/// the index field and returns the reply decoded as a little-endian
/// 16-bit value.
///
/// # Errors
///
/// [`AnyUsbError::EP0BadGetStatusSize`] if the reply is not exactly two
/// bytes long; transfer errors are propagated.
pub fn interface_get_status(
    &mut self,
    cls: &mut C,
    interface: u8,
) -> core::result::Result<u16, AnyUsbError> {
    let recipient = CtrRequestType::to_host().interface();
    let reply = self.control_read(cls, recipient, 0, 0, u16::from(interface), 2)?;
    match reply.as_slice() {
        [lo, hi] => Ok(u16::from_le_bytes([*lo, *hi])),
        _ => Err(AnyUsbError::EP0BadGetStatusSize),
    }
}
/// Sends an interface-level request 1 (Clear Feature) for `feature`, with
/// `interface` in the index field.
///
/// Any data returned by the device is discarded.
pub fn interface_clear_feature(
    &mut self,
    cls: &mut C,
    interface: u8,
    feature: u16,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device().interface();
    let index = u16::from(interface);
    self.control_write(cls, recipient, 1, feature, index, 0, &[])
        .map(drop)
}
/// Sends an interface-level request 3 (Set Feature) for `feature`, with
/// `interface` in the index field.
///
/// Any data returned by the device is discarded.
pub fn interface_set_feature(
    &mut self,
    cls: &mut C,
    interface: u8,
    feature: u16,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device().interface();
    let index = u16::from(interface);
    self.control_write(cls, recipient, 3, feature, index, 0, &[])
        .map(drop)
}
/// Sends an interface-level request 10 (Get Interface) and returns the
/// single byte of the reply.
///
/// The index field of the request is fixed to 0, so the request always
/// addresses interface 0.
///
/// # Errors
///
/// [`AnyUsbError::EP0BadGetConfigSize`] if the reply is not exactly one
/// byte long; transfer errors are propagated.
pub fn interface_get_interface(
    &mut self,
    cls: &mut C,
) -> core::result::Result<u8, AnyUsbError> {
    let recipient = CtrRequestType::to_host().interface();
    let reply = self.control_read(cls, recipient, 10, 0, 0, 1)?;
    match reply.as_slice() {
        [alt_setting] => Ok(*alt_setting),
        _ => Err(AnyUsbError::EP0BadGetConfigSize),
    }
}
/// Sends an interface-level request 11 (Set Interface) with `alt_setting`
/// in the value field and `interface` in the index field.
///
/// Any data returned by the device is discarded.
pub fn interface_set_interface(
    &mut self,
    cls: &mut C,
    interface: u8,
    alt_setting: u8,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device().interface();
    let value = u16::from(alt_setting);
    let index = u16::from(interface);
    self.control_write(cls, recipient, 11, value, index, 0, &[])
        .map(drop)
}
/// Sends an endpoint-level request 0 (Get Status) with `endpoint` in the
/// index field and returns the reply decoded as a little-endian 16-bit
/// value.
///
/// # Errors
///
/// [`AnyUsbError::EP0BadGetStatusSize`] if the reply is not exactly two
/// bytes long; transfer errors are propagated.
pub fn endpoint_get_status(
    &mut self,
    cls: &mut C,
    endpoint: u8,
) -> core::result::Result<u16, AnyUsbError> {
    let recipient = CtrRequestType::to_host().endpoint();
    let reply = self.control_read(cls, recipient, 0, 0, u16::from(endpoint), 2)?;
    match reply.as_slice() {
        [lo, hi] => Ok(u16::from_le_bytes([*lo, *hi])),
        _ => Err(AnyUsbError::EP0BadGetStatusSize),
    }
}
/// Sends an endpoint-level request 1 (Clear Feature) for `feature`, with
/// `endpoint` in the index field.
///
/// Any data returned by the device is discarded.
pub fn endpoint_clear_feature(
    &mut self,
    cls: &mut C,
    endpoint: u8,
    feature: u16,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device().endpoint();
    let index = u16::from(endpoint);
    self.control_write(cls, recipient, 1, feature, index, 0, &[])
        .map(drop)
}
/// Sends an endpoint-level request 3 (Set Feature) for `feature`, with
/// `endpoint` in the index field.
///
/// Any data returned by the device is discarded.
pub fn endpoint_set_feature(
    &mut self,
    cls: &mut C,
    endpoint: u8,
    feature: u16,
) -> core::result::Result<(), AnyUsbError> {
    let recipient = CtrRequestType::to_device().endpoint();
    let index = u16::from(endpoint);
    self.control_write(cls, recipient, 3, feature, index, 0, &[])
        .map(drop)
}
/// Sends an endpoint-level request 12 (Synch Frame) with `endpoint` in
/// the index field and returns the reply decoded as a little-endian
/// 16-bit value.
///
/// # Errors
///
/// [`AnyUsbError::EP0BadGetStatusSize`] if the reply is not exactly two
/// bytes long; transfer errors are propagated.
pub fn endpoint_synch_frame(
    &mut self,
    cls: &mut C,
    endpoint: u8,
) -> core::result::Result<u16, AnyUsbError> {
    let recipient = CtrRequestType::to_host().endpoint();
    let reply = self.control_read(cls, recipient, 12, 0, u16::from(endpoint), 2)?;
    match reply.as_slice() {
        [lo, hi] => Ok(u16::from_le_bytes([*lo, *hi])),
        _ => Err(AnyUsbError::EP0BadGetStatusSize),
    }
}
/// Runs an enumeration-like sequence of standard requests against the
/// device.
///
/// The sequence is:
///
/// 1. read the device descriptor (up to 64 bytes),
/// 2. set the address to [`UsbDeviceCtx::ADDRESS`] and verify that the
///    bus reports it,
/// 3. read the device descriptor again (18 bytes),
/// 4. read the first 9 bytes of the configuration descriptor, take the
///    16-bit length stored at offset 2 and read the descriptor again with
///    that length,
/// 5. read string descriptor 0 and take the 16-bit language id stored at
///    offset 2,
/// 6. read every string descriptor whose index is stored, non-zero, in
///    bytes 14..17 of the device descriptor,
/// 7. select configuration 1.
///
/// # Errors
///
/// * [`AnyUsbError::SetAddressFailed`] - the bus address after step 2 is
///   not the requested one.
/// * [`AnyUsbError::InvalidDescriptorLength`] - a descriptor is too short
///   to contain a field needed by the steps above.
///
/// Transfer errors of the individual requests are propagated.
pub fn setup(&mut self, cls: &mut C) -> core::result::Result<(), AnyUsbError> {
    // Reads the little-endian 16-bit field stored in bytes 2 and 3 of a
    // descriptor, failing instead of panicking when it is too short.
    fn word_at_offset_2(desc: &[u8]) -> core::result::Result<u16, AnyUsbError> {
        match desc.get(2..4) {
            Some(&[lo, hi]) => Ok(u16::from_le_bytes([lo, hi])),
            _ => Err(AnyUsbError::InvalidDescriptorLength),
        }
    }

    // The first read happens before an address has been assigned.
    self.device_get_descriptor(cls, 1, 0, 0, 64)?;

    self.device_set_address(cls, X::ADDRESS)?;
    if self.dev.bus().get_address() != X::ADDRESS {
        return Err(AnyUsbError::SetAddressFailed);
    }

    let devd = self.device_get_descriptor(cls, 1, 0, 0, 18)?;

    // Header first, to learn the total length of the configuration.
    let conf_header = self.device_get_descriptor(cls, 2, 0, 0, 9)?;
    let conf_desc_len = word_at_offset_2(&conf_header)?;
    self.device_get_descriptor(cls, 2, 0, 0, conf_desc_len)?;

    // String descriptor 0 lists the supported languages; use the first.
    let languages = self.device_get_descriptor(cls, 3, 0, 0, 255)?;
    let lang_id = word_at_offset_2(&languages)?;

    // Bytes 14..17 of the device descriptor hold three string indices;
    // index 0 means "no string".
    let string_ids = devd
        .get(14..17)
        .ok_or(AnyUsbError::InvalidDescriptorLength)?;
    for &sid in string_ids {
        if sid != 0 {
            self.device_get_descriptor(cls, 3, sid, lang_id, 255)?;
        }
    }

    self.device_set_configuration(cls, 1)?;
    Ok(())
}
}
pub fn with_usb<C, X>(mut ctx: X, case: for<'a> fn(cls: C, dev: Device<'a, C, X>)) -> AnyResult<()>
where
C: UsbClass<EmulatedUsbBus>,
X: UsbDeviceCtx<EmulatedUsbBus, C>,
{
let stio: UsbBusImpl = UsbBusImpl::new();
let io = Rc::new(RefCell::new(stio));
let bus = EmulatedUsbBus::new(&io);
let alloc: usb_device::bus::UsbBusAllocator<EmulatedUsbBus> = UsbBusAllocator::new(bus);
let mut cls = ctx.create_class(&alloc)?;
let mut usb_dev = ctx.build_usb_device(&alloc)?;
let skip_setup = ctx.skip_setup();
usb_dev.poll(&mut [&mut cls]);
ctx.post_poll(&mut cls);
let mut dev = Device::new(io.as_ref(), ctx, usb_dev);
if !skip_setup {
dev.setup(&mut cls)?;
}
case(cls, dev);
Ok(())
}