use usb_device::{
bus::{
InterfaceNumber,
StringIndex,
UsbBus,
UsbBusAllocator,
},
class::{
ControlIn,
ControlOut,
UsbClass,
},
control::{
self,
Recipient,
RequestType,
},
descriptor::DescriptorWriter,
endpoint::{
EndpointAddress,
EndpointIn,
},
UsbError,
};
/// HID specification release (`bcdHID`) advertised in the HID descriptor, BCD-encoded (1.11).
const SPECIFICATION_RELEASE: u16 = 0x0111;
/// USB interface class code written into the interface descriptor for HID interfaces.
const INTERFACE_CLASS_HID: u8 = 0x03;
/// Interface subclass code (`bInterfaceSubClass`) reported in the interface descriptor.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Subclass {
    /// No subclass.
    None = 0,
    /// Boot interface subclass.
    BootInterface = 1,
}
/// Interface protocol code (`bInterfaceProtocol`) reported in the interface descriptor.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Protocol {
    /// No specific protocol.
    None = 0,
    /// Keyboard protocol code.
    Keyboard = 1,
    /// Mouse protocol code.
    Mouse = 2,
}
/// Class-specific descriptor type codes used by the HID class.
#[repr(u8)]
#[derive(Debug, Clone, Copy)]
pub enum DescriptorType {
    /// HID descriptor (embedded in the configuration descriptor).
    Hid = 0x21,
    /// Report descriptor.
    Report = 0x22,
    /// Physical descriptor; the leading underscore silences the unused-variant lint.
    _Physical = 0x23,
}
/// HID class-specific request codes (`bRequest`).
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Request {
    GetReport = 0x01,
    GetIdle = 0x02,
    GetProtocol = 0x03,
    SetReport = 0x09,
    SetIdle = 0x0a,
    SetProtocol = 0x0b,
}

impl Request {
    /// Every variant, so that decoding relies on the discriminants declared above
    /// instead of repeating them.
    const ALL: [Request; 6] = [
        Request::GetReport,
        Request::GetIdle,
        Request::GetProtocol,
        Request::SetReport,
        Request::SetIdle,
        Request::SetProtocol,
    ];

    /// Decodes a raw `bRequest` byte; returns `None` for codes that are not listed above.
    fn new(u: u8) -> Option<Request> {
        Self::ALL.iter().copied().find(|&candidate| candidate as u8 == u)
    }
}
/// Report category carried in the high byte of `wValue` in GET_REPORT / SET_REPORT requests.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ReportType {
    /// Type byte 1.
    Input,
    /// Type byte 2.
    Output,
    /// Type byte 3.
    Feature,
    /// Any other type byte, kept verbatim.
    Reserved(u8),
}

impl From<u8> for ReportType {
    /// Decodes a raw report-type byte; unrecognized values become [`ReportType::Reserved`].
    fn from(raw: u8) -> Self {
        use ReportType::{Feature, Input, Output, Reserved};
        match raw {
            1 => Input,
            2 => Output,
            3 => Feature,
            other => Reserved(other),
        }
    }
}
/// Application hooks that [`HidClass`] calls to describe the interface and to exchange
/// reports with the host.
pub trait HidDevice {
    /// Value written as `bInterfaceSubClass` in the interface descriptor.
    fn subclass(&self) -> Subclass;

    /// Value written as `bInterfaceProtocol` in the interface descriptor.
    fn protocol(&self) -> Protocol;

    /// The HID report descriptor. Its length is advertised in the HID descriptor and
    /// its bytes are returned for GET_DESCRIPTOR(Report) requests.
    fn report_descriptor(&self) -> &[u8];

    /// Receives a report sent by the host via SET_REPORT.
    ///
    /// Returning `Err` makes the class reject the control request.
    fn set_report(
        &mut self,
        report_type: ReportType,
        report_id: u8,
        data: &[u8],
    ) -> Result<(), &'static str>;

    /// Notifies the device of a protocol selection by the host; `boot` is meant to be
    /// `true` when the boot protocol is selected.
    fn set_boot(&mut self, boot: bool);

    /// Returns the current report for `report_type` / `report_id`.
    ///
    /// Used to answer GET_REPORT (an `Err` rejects the request). It is also called for the
    /// Input report with ID 1 when the class is constructed and on each interrupt write.
    fn get_report(
        &mut self,
        report_type: ReportType,
        report_id: u8,
    ) -> Result<&[u8], &'static str>;
}
/// USB class driver exposing a [`HidDevice`] as one HID interface with a single
/// interrupt IN endpoint.
pub struct HidClass<'a, B, D>
where
    B: UsbBus,
    D: HidDevice,
{
    /// Application callbacks and report storage.
    device: D,
    /// Interface number allocated for this class.
    interface: InterfaceNumber,
    /// Interrupt IN endpoint on which Input reports are written.
    endpoint_interrupt_in: EndpointIn<'a, B>,
    /// While `true`, `write_report` sends nothing; cleared when the interrupt IN endpoint
    /// signals completion and on bus reset.
    expect_interrupt_in_complete: bool,
    /// Whether the boot protocol is currently selected (initialized to `false`).
    boot_proto: bool,
}
impl<B: UsbBus, D: HidDevice> HidClass<'_, B, D> {
    /// Creates a HID class driver for `device`, allocating one interface and one
    /// interrupt IN endpoint from `alloc`.
    ///
    /// The endpoint's max packet size is the length of the device's Input report with
    /// ID 1 at construction time. The interval passed to the allocator is 10.
    ///
    /// # Panics
    ///
    /// Panics if `device.get_report(ReportType::Input, 1)` returns an error.
    pub fn new(mut device: D, alloc: &UsbBusAllocator<B>) -> HidClass<'_, B, D> {
        let input_report_len = device
            .get_report(ReportType::Input, 1)
            .expect("HidDevice::get_report(Input, 1) must succeed to size the interrupt endpoint")
            .len();
        HidClass {
            interface: alloc.interface(),
            // Interrupt packets are at most 1024 bytes, so a usable report length fits in u16.
            endpoint_interrupt_in: alloc.interrupt(input_report_len as u16, 10),
            expect_interrupt_in_complete: false,
            boot_proto: false,
            device,
        }
    }

    /// Writes the device's current Input report (ID 1) to the interrupt IN endpoint.
    ///
    /// Returns `Ok(0)` without sending while a previous transfer is still awaiting
    /// completion, or when the endpoint is busy (`WouldBlock`). Otherwise it returns the
    /// number of bytes written.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`HidDevice::get_report`]. Any endpoint error other than
    /// `WouldBlock` yields `"error writing data to endpoint"`.
    pub fn write_report(&mut self) -> Result<usize, &'static str> {
        if self.expect_interrupt_in_complete {
            return Ok(0);
        }
        let data = self.device.get_report(ReportType::Input, 1)?;
        let wait_for_completion = data.len() >= 8;
        match self.endpoint_interrupt_in.write(data) {
            Ok(count) => {
                // Only wait for `endpoint_in_complete` once a transfer was actually queued.
                // Setting the flag before a failed write would leave it stuck, so every
                // later call would silently send nothing.
                if wait_for_completion {
                    self.expect_interrupt_in_complete = true;
                }
                Ok(count)
            }
            Err(UsbError::WouldBlock) => Ok(0),
            Err(_) => Err("error writing data to endpoint"),
        }
    }

    /// Answers a GET_REPORT control request.
    ///
    /// The high byte of `wValue` is the report type and the low byte is the report ID. The
    /// request is rejected when the device returns an error.
    fn get_report(&mut self, xfer: ControlIn<B>) {
        let req = xfer.request();
        let [report_type, report_id] = req.value.to_be_bytes();
        let report_type = ReportType::from(report_type);
        match self.device.get_report(report_type, report_id) {
            Ok(data) => xfer.accept_with(data).ok(),
            Err(_) => xfer.reject().ok(),
        };
    }

    /// Handles a SET_REPORT control request by forwarding the data stage to the device.
    ///
    /// `wValue` is decoded as in `get_report`. The request is rejected when the device
    /// returns an error.
    fn set_report(&mut self, xfer: ControlOut<B>) {
        let req = xfer.request();
        let [report_type, report_id] = req.value.to_be_bytes();
        let report_type = ReportType::from(report_type);
        match self.device.set_report(report_type, report_id, xfer.data()) {
            Ok(()) => xfer.accept().ok(),
            Err(_) => xfer.reject().ok(),
        };
    }

    /// Returns a mutable reference to the wrapped device.
    pub fn get_device(&mut self) -> &mut D {
        &mut self.device
    }
}
impl<B: UsbBus, D: HidDevice> UsbClass<B> for HidClass<'_, B, D> {
    fn poll(&mut self) {}

    /// Clears per-connection state on a USB bus reset.
    ///
    /// After reset a HID device is back in the report protocol (HID 1.11 §7.2.6). The
    /// cached selection is restored and the device is told about it.
    fn reset(&mut self) {
        self.expect_interrupt_in_complete = false;
        self.boot_proto = false;
        self.device.set_boot(false);
    }

    /// Writes the interface descriptor, the HID descriptor and the interrupt IN
    /// endpoint descriptor.
    ///
    /// Returns `UsbError::InvalidState` if the report descriptor is too long for the
    /// 16-bit `wDescriptorLength` field.
    fn get_configuration_descriptors(
        &self,
        writer: &mut DescriptorWriter,
    ) -> usb_device::Result<()> {
        writer.interface(
            self.interface,
            INTERFACE_CLASS_HID,
            self.device.subclass() as u8,
            self.device.protocol() as u8,
        )?;

        let report_descriptor_len = self.device.report_descriptor().len();
        if report_descriptor_len > usize::from(u16::MAX) {
            return Err(UsbError::InvalidState);
        }
        let [len_lo, len_hi] = (report_descriptor_len as u16).to_le_bytes();
        let [release_lo, release_hi] = SPECIFICATION_RELEASE.to_le_bytes();

        // HID descriptor body (HID 1.11 §6.2.1); DescriptorWriter::write supplies
        // bLength and bDescriptorType.
        writer.write(
            DescriptorType::Hid as u8,
            &[
                release_lo,                   // bcdHID (low byte)
                release_hi,                   // bcdHID (high byte)
                0,                            // bCountryCode: not localized
                1,                            // bNumDescriptors
                DescriptorType::Report as u8, // bDescriptorType of the class descriptor
                len_lo,                       // wDescriptorLength (low byte)
                len_hi,                       // wDescriptorLength (high byte)
            ],
        )?;

        writer.endpoint(&self.endpoint_interrupt_in)?;
        Ok(())
    }

    fn get_string(&self, _index: StringIndex, _lang_id: u16) -> Option<&str> {
        None
    }

    /// Clears the pending-write flag once the interrupt IN transfer has completed.
    fn endpoint_in_complete(&mut self, addr: EndpointAddress) {
        if addr == self.endpoint_interrupt_in.address() {
            self.expect_interrupt_in_complete = false;
        }
    }

    fn endpoint_out(&mut self, _addr: EndpointAddress) {}

    /// Handles device-to-host requests addressed to this interface:
    /// GET_DESCRIPTOR(Report), GET_REPORT and GET_PROTOCOL.
    ///
    /// Other requests are left unanswered.
    fn control_in(&mut self, xfer: ControlIn<B>) {
        let req = xfer.request();
        // Only answer requests aimed at this interface, so that several HID classes
        // can coexist on one device.
        if req.recipient != Recipient::Interface || req.index != u8::from(self.interface) as u16 {
            return;
        }

        match req.request_type {
            RequestType::Standard => {
                if req.request == control::Request::GET_DESCRIPTOR {
                    let (dtype, index) = req.descriptor_type_index();
                    if dtype == DescriptorType::Report as u8 && index == 0 {
                        xfer.accept_with(self.device.report_descriptor()).ok();
                    }
                }
            }
            RequestType::Class => match Request::new(req.request) {
                Some(Request::GetReport) => self.get_report(xfer),
                Some(Request::GetProtocol) => {
                    // HID 1.11 §7.2.5: 0 = boot protocol, 1 = report protocol.
                    let protocol: u8 = if self.boot_proto { 0 } else { 1 };
                    xfer.accept_with(&[protocol]).ok();
                }
                // SET_PROTOCOL has no IN data stage; it is dispatched to `control_out`.
                _ => {}
            },
            _ => {}
        }
    }

    /// Handles host-to-device class requests addressed to this interface:
    /// SET_REPORT and SET_PROTOCOL.
    fn control_out(&mut self, xfer: ControlOut<B>) {
        let req = xfer.request();
        if !(req.request_type == RequestType::Class
            && req.recipient == Recipient::Interface
            && req.index == u8::from(self.interface) as u16)
        {
            return;
        }

        match Request::new(req.request) {
            Some(Request::SetReport) => self.set_report(xfer),
            Some(Request::SetProtocol) => match req.value {
                // HID 1.11 §7.2.6: wValue 0 selects boot protocol, 1 selects report protocol.
                0 | 1 => {
                    self.boot_proto = req.value == 0;
                    self.device.set_boot(self.boot_proto);
                    xfer.accept().ok();
                }
                _ => {
                    xfer.reject().ok();
                }
            },
            _ => {}
        }
    }
}