use usb_device::class_prelude::*;
use usb_device::Result;
use crate::descriptor::AsInputReport;
use crate::descriptor::BufferOverflow;
/// `bInterfaceClass` code written into the interface descriptor for HID interfaces.
const USB_CLASS_HID: u8 = 0x03;
/// Descriptor type code of the HID class descriptor (configuration descriptor and GET_DESCRIPTOR).
const HID_DESC_DESCTYPE_HID: u8 = 0x21;
/// Descriptor type code of the HID report descriptor.
const HID_DESC_DESCTYPE_HID_REPORT: u8 = 0x22;
/// `bcdHID` = 1.10, stored low byte first as it appears on the wire.
const HID_DESC_SPEC_1_10: [u8; 2] = [0x10, 0x01];
// HID class-specific request codes (`bRequest`).
/// GET_IDLE class request code.
const HID_REQ_GET_IDLE: u8 = 0x02;
/// SET_IDLE class request code.
const HID_REQ_SET_IDLE: u8 = 0x0a;
/// GET_PROTOCOL class request code.
const HID_REQ_GET_PROTOCOL: u8 = 0x03;
/// SET_PROTOCOL class request code.
const HID_REQ_SET_PROTOCOL: u8 = 0x0b;
/// GET_REPORT class request code.
const HID_REQ_GET_REPORT: u8 = 0x01;
/// SET_REPORT class request code.
const HID_REQ_SET_REPORT: u8 = 0x09;
/// Capacity of the buffer that stores a SET_REPORT payload; requests whose `wLength`
/// exceeds it are rejected.
const CONTROL_BUF_LEN: usize = 128;
/// Report type carried in the high byte of `wValue` of a report control request.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ReportType {
    /// Input report.
    Input = 1,
    /// Output report.
    Output = 2,
    /// Feature report.
    Feature = 3,
    /// Any report-type byte other than 1, 2 or 3.
    Reserved,
}

impl From<u8> for ReportType {
    /// Decodes a raw report-type byte; anything outside `1..=3` becomes
    /// [`ReportType::Reserved`].
    fn from(raw: u8) -> Self {
        // Table indexed by `raw - 1`; 0 underflows and values past the table miss.
        const DEFINED: [ReportType; 3] =
            [ReportType::Input, ReportType::Output, ReportType::Feature];
        raw.checked_sub(1)
            .and_then(|idx| DEFINED.get(usize::from(idx)))
            .copied()
            .unwrap_or(ReportType::Reserved)
    }
}
/// Metadata describing a report received through a SET_REPORT control request.
#[derive(Copy, Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct ReportInfo {
    /// Report type, decoded from the high byte of the request's `wValue`.
    pub report_type: ReportType,
    /// Report ID, taken from the low byte of the request's `wValue`.
    pub report_id: u8,
    /// Payload length in bytes (the request's `wLength`).
    pub len: usize,
}
/// A SET_REPORT payload buffered until the application retrieves it with
/// `HIDClass::pull_raw_report`.
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
struct Report {
    /// Metadata; only the first `info.len` bytes of `buf` are payload.
    info: ReportInfo,
    /// Fixed-size storage; bytes past `info.len` are zero-filled.
    buf: [u8; CONTROL_BUF_LEN],
}
/// Country code written as `bCountryCode` into the HID class descriptor.
///
/// The discriminants are intended to match the country-code table of the USB HID
/// specification; `NotSupported` (0) means the device is not localized.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[repr(u8)]
pub enum HidCountryCode {
    NotSupported = 0,
    Arabic = 1,
    Belgian = 2,
    CanadianBilingual = 3,
    CanadianFrench = 4,
    CzechRepublic = 5,
    Danish = 6,
    Finnish = 7,
    French = 8,
    German = 9,
    Greek = 10,
    Hebrew = 11,
    Hungary = 12,
    InternationalISO = 13,
    Italian = 14,
    JapanKatakana = 15,
    Korean = 16,
    LatinAmerica = 17,
    NetherlandsDutch = 18,
    Norwegian = 19,
    PersianFarsi = 20,
    Poland = 21,
    Portuguese = 22,
    Russia = 23,
    Slovakia = 24,
    Spanish = 25,
    Swedish = 26,
    SwissFrench = 27,
    SwissGerman = 28,
    Switzerland = 29,
    Taiwan = 30,
    TurkishQ = 31,
    UK = 32,
    US = 33,
    Yugoslavia = 34,
    TurkishF = 35,
}
/// Interface subclass (`bInterfaceSubClass`) advertised in the interface descriptor.
///
/// For keyboards and mice the class also compares it against the current protocol
/// mode to decide whether input reports may be pushed.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[repr(u8)]
pub enum HidSubClass {
    /// No subclass.
    NoSubClass = 0,
    /// Boot interface subclass.
    Boot = 1,
}
/// Interface protocol (`bInterfaceProtocol`) advertised in the interface descriptor.
///
/// `Keyboard` and `Mouse` enable protocol-mode tracking (GET/SET_PROTOCOL handling);
/// `Generic` does not.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[repr(u8)]
pub enum HidProtocol {
    Generic = 0,
    Keyboard = 1,
    Mouse = 2,
}
/// Protocol mode of a keyboard/mouse interface.
///
/// The discriminant is the byte answered to GET_PROTOCOL and read from the low byte
/// of a SET_PROTOCOL `wValue`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[repr(u8)]
pub enum HidProtocolMode {
    Boot = 0,
    Report = 1,
}

impl From<u8> for HidProtocolMode {
    /// Maps `0` to [`HidProtocolMode::Boot`]; every other byte maps to
    /// [`HidProtocolMode::Report`].
    fn from(raw: u8) -> Self {
        match raw {
            0 => Self::Boot,
            _ => Self::Report,
        }
    }
}
/// Policy governing how a keyboard/mouse interface's protocol mode may change.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ProtocolModeConfig {
    /// Start in Report mode; host SET_PROTOCOL requests change the mode.
    DefaultBehavior,
    /// Always Boot mode; host SET_PROTOCOL requests are acknowledged but ignored.
    ForceBoot,
    /// Always Report mode; host SET_PROTOCOL requests are acknowledged but ignored.
    ForceReport,
}
/// Configuration of a [`HIDClass`] interface.
///
/// Every field is a small `Copy` enum, so the settings themselves are `Copy`; this lets
/// callers reuse one value for several interfaces and print it for diagnostics, in
/// line with the other public types of this module.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct HidClassSettings {
    /// Interface subclass (`bInterfaceSubClass`).
    pub subclass: HidSubClass,
    /// Interface protocol (`bInterfaceProtocol`); `Keyboard`/`Mouse` enable protocol modes.
    pub protocol: HidProtocol,
    /// How the protocol mode reacts to host SET_PROTOCOL requests.
    pub config: ProtocolModeConfig,
    /// Country code written to the HID descriptor (`bCountryCode`).
    pub locale: HidCountryCode,
}
impl Default for HidClassSettings {
fn default() -> Self {
Self {
subclass: HidSubClass::NoSubClass,
protocol: HidProtocol::Generic,
config: ProtocolModeConfig::DefaultBehavior,
locale: HidCountryCode::NotSupported,
}
}
}
/// USB HID class for `usb-device`.
///
/// Owns one interface with an optional interrupt OUT endpoint and an optional interrupt
/// IN endpoint, serves a static report descriptor, and keeps the most recent SET_REPORT
/// payload until the application pulls it.
pub struct HIDClass<'a, B: UsbBus> {
    /// Interface number allocated for this class.
    if_num: InterfaceNumber,
    /// Interrupt OUT endpoint, if one was allocated.
    out_ep: Option<EndpointOut<'a, B>>,
    /// Interrupt IN endpoint, if one was allocated.
    in_ep: Option<EndpointIn<'a, B>>,
    /// Report descriptor returned for GET_DESCRIPTOR(Report).
    report_descriptor: &'static [u8],
    /// Last SET_REPORT payload; cleared once consumed by `pull_raw_report`.
    set_report_buf: Option<Report>,
    /// Current protocol mode; `None` unless the interface protocol is Keyboard or Mouse.
    protocol: Option<HidProtocolMode>,
    /// Construction-time settings (`config` can later be changed by `set_protocol_mode`).
    settings: HidClassSettings,
}
/// Chooses the initial protocol mode for the given settings.
///
/// Only keyboard and mouse interfaces have a protocol mode; for them it is `Boot` under
/// `ForceBoot` and `Report` otherwise. Any other interface protocol yields `None`.
fn determine_protocol_setting(settings: &HidClassSettings) -> Option<HidProtocolMode> {
    let boot_capable = matches!(
        settings.protocol,
        HidProtocol::Keyboard | HidProtocol::Mouse
    );
    if !boot_capable {
        return None;
    }
    let initial = match settings.config {
        ProtocolModeConfig::ForceBoot => HidProtocolMode::Boot,
        ProtocolModeConfig::DefaultBehavior | ProtocolModeConfig::ForceReport => {
            HidProtocolMode::Report
        }
    };
    Some(initial)
}
impl<B: UsbBus> HIDClass<'_, B> {
pub fn new<'a>(
alloc: &'a UsbBusAllocator<B>,
report_descriptor: &'static [u8],
poll_ms: u8,
) -> HIDClass<'a, B> {
let settings = HidClassSettings::default();
HIDClass {
if_num: alloc.interface(),
out_ep: Some(alloc.interrupt(64, poll_ms)),
in_ep: Some(alloc.interrupt(64, poll_ms)),
report_descriptor,
set_report_buf: None,
protocol: determine_protocol_setting(&settings),
settings,
}
}
pub fn new_with_settings<'a>(
alloc: &'a UsbBusAllocator<B>,
report_descriptor: &'static [u8],
poll_ms: u8,
settings: HidClassSettings,
) -> HIDClass<'a, B> {
HIDClass {
if_num: alloc.interface(),
out_ep: Some(alloc.interrupt(64, poll_ms)),
in_ep: Some(alloc.interrupt(64, poll_ms)),
report_descriptor,
set_report_buf: None,
protocol: determine_protocol_setting(&settings),
settings,
}
}
pub fn new_ep_in<'a>(
alloc: &'a UsbBusAllocator<B>,
report_descriptor: &'static [u8],
poll_ms: u8,
) -> HIDClass<'a, B> {
let settings = HidClassSettings::default();
HIDClass {
if_num: alloc.interface(),
out_ep: None,
in_ep: Some(alloc.interrupt(64, poll_ms)),
report_descriptor,
set_report_buf: None,
protocol: determine_protocol_setting(&settings),
settings,
}
}
pub fn new_ep_in_with_settings<'a>(
alloc: &'a UsbBusAllocator<B>,
report_descriptor: &'static [u8],
poll_ms: u8,
settings: HidClassSettings,
) -> HIDClass<'a, B> {
HIDClass {
if_num: alloc.interface(),
out_ep: None,
in_ep: Some(alloc.interrupt(64, poll_ms)),
report_descriptor,
set_report_buf: None,
protocol: determine_protocol_setting(&settings),
settings,
}
}
pub fn new_ep_out<'a>(
alloc: &'a UsbBusAllocator<B>,
report_descriptor: &'static [u8],
poll_ms: u8,
) -> HIDClass<'a, B> {
let settings = HidClassSettings::default();
HIDClass {
if_num: alloc.interface(),
out_ep: Some(alloc.interrupt(64, poll_ms)),
in_ep: None,
report_descriptor,
set_report_buf: None,
protocol: determine_protocol_setting(&settings),
settings,
}
}
pub fn new_ep_out_with_settings<'a>(
alloc: &'a UsbBusAllocator<B>,
report_descriptor: &'static [u8],
poll_ms: u8,
settings: HidClassSettings,
) -> HIDClass<'a, B> {
HIDClass {
if_num: alloc.interface(),
out_ep: Some(alloc.interrupt(64, poll_ms)),
in_ep: None,
report_descriptor,
set_report_buf: None,
protocol: determine_protocol_setting(&settings),
settings,
}
}
pub fn push_input<IR: AsInputReport>(&self, r: &IR) -> Result<usize> {
match self.settings.protocol {
HidProtocol::Keyboard | HidProtocol::Mouse => {
if let Some(protocol) = self.protocol {
if (protocol == HidProtocolMode::Report
&& self.settings.subclass != HidSubClass::NoSubClass)
|| (protocol == HidProtocolMode::Boot
&& self.settings.subclass != HidSubClass::Boot)
{
return Err(UsbError::InvalidState);
}
}
}
_ => {}
}
if let Some(ep) = &self.in_ep {
let mut buff: [u8; 64] = [0; 64];
let size = match r.serialize(&mut buff) {
Ok(l) => l,
Err(BufferOverflow) => return Err(UsbError::BufferOverflow),
};
ep.write(&buff[0..size])
} else {
Err(UsbError::InvalidEndpoint)
}
}
pub fn push_raw_input(&self, data: &[u8]) -> Result<usize> {
match self.settings.protocol {
HidProtocol::Keyboard | HidProtocol::Mouse => {
if let Some(protocol) = self.protocol {
if (protocol == HidProtocolMode::Report
&& self.settings.subclass != HidSubClass::NoSubClass)
|| (protocol == HidProtocolMode::Boot
&& self.settings.subclass != HidSubClass::Boot)
{
return Err(UsbError::InvalidState);
}
}
}
_ => {}
}
if let Some(ep) = &self.in_ep {
ep.write(data)
} else {
Err(UsbError::InvalidEndpoint)
}
}
pub fn pull_raw_output(&self, data: &mut [u8]) -> Result<usize> {
if let Some(ep) = &self.out_ep {
ep.read(data)
} else {
Err(UsbError::InvalidEndpoint)
}
}
pub fn pull_raw_report(&mut self, data: &mut [u8]) -> Result<ReportInfo> {
let info = match &self.set_report_buf {
Some(set_report_buf) => {
let info = set_report_buf.info;
if data.len() < info.len {
return Err(UsbError::BufferOverflow);
}
data[..info.len].copy_from_slice(&set_report_buf.buf[..info.len]);
info
}
None => {
return Err(UsbError::WouldBlock);
}
};
self.set_report_buf = None;
Ok(info)
}
pub fn get_protocol_mode(&self) -> Result<HidProtocolMode> {
match self.settings.protocol {
HidProtocol::Keyboard | HidProtocol::Mouse => {}
_ => {
return Err(UsbError::Unsupported);
}
}
if let Some(protocol) = self.protocol {
Ok(protocol)
} else {
Err(UsbError::InvalidState)
}
}
pub fn set_protocol_mode(
&mut self,
mode: HidProtocolMode,
config: ProtocolModeConfig,
) -> Result<()> {
match self.settings.protocol {
HidProtocol::Keyboard | HidProtocol::Mouse => {}
_ => {
return Err(UsbError::Unsupported);
}
}
match config {
ProtocolModeConfig::DefaultBehavior => self.protocol = Some(mode),
ProtocolModeConfig::ForceBoot => {
self.protocol = Some(HidProtocolMode::Boot);
}
ProtocolModeConfig::ForceReport => {
self.protocol = Some(HidProtocolMode::Report);
}
}
self.settings.config = config;
Ok(())
}
}
impl<B: UsbBus> UsbClass<B> for HIDClass<'_, B> {
fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> Result<()> {
writer.interface(
self.if_num,
USB_CLASS_HID,
self.settings.subclass as u8,
self.settings.protocol as u8,
)?;
writer.write(
HID_DESC_DESCTYPE_HID,
&[
HID_DESC_SPEC_1_10[0],
HID_DESC_SPEC_1_10[1],
self.settings.locale as u8,
1,
HID_DESC_DESCTYPE_HID_REPORT,
(self.report_descriptor.len() & 0xFF) as u8,
((self.report_descriptor.len() >> 8) & 0xFF) as u8,
],
)?;
if let Some(ep) = &self.out_ep {
writer.endpoint(ep)?;
}
if let Some(ep) = &self.in_ep {
writer.endpoint(ep)?;
}
Ok(())
}
fn control_in(&mut self, xfer: ControlIn<B>) {
let req = xfer.request();
if req.index != u8::from(self.if_num) as u16 {
return;
}
match (req.request_type, req.request) {
(control::RequestType::Standard, control::Request::GET_DESCRIPTOR) => {
match (req.value >> 8) as u8 {
HID_DESC_DESCTYPE_HID_REPORT => {
xfer.accept_with_static(self.report_descriptor).ok();
}
HID_DESC_DESCTYPE_HID => {
let buf = &[
9,
HID_DESC_DESCTYPE_HID,
HID_DESC_SPEC_1_10[0],
HID_DESC_SPEC_1_10[1],
self.settings.locale as u8,
1,
HID_DESC_DESCTYPE_HID_REPORT,
(self.report_descriptor.len() & 0xFF) as u8,
((self.report_descriptor.len() >> 8) & 0xFF) as u8,
];
xfer.accept_with(buf).ok();
}
_ => {}
}
}
(control::RequestType::Class, HID_REQ_GET_REPORT) => {
xfer.reject().ok(); }
(control::RequestType::Class, HID_REQ_GET_IDLE) => {
xfer.reject().ok(); }
(control::RequestType::Class, HID_REQ_GET_PROTOCOL) => {
if let Some(protocol) = self.protocol {
xfer.accept_with(&[protocol as u8]).ok();
} else {
xfer.reject().ok();
}
}
_ => {}
}
}
fn control_out(&mut self, xfer: ControlOut<B>) {
let req = xfer.request();
if !(req.recipient == control::Recipient::Interface
&& req.index == u8::from(self.if_num) as u16)
{
return;
}
match req.request {
HID_REQ_SET_IDLE => {
xfer.accept().ok();
}
HID_REQ_SET_PROTOCOL => {
if let Some(_protocol) = self.protocol {
if self.settings.config == ProtocolModeConfig::DefaultBehavior {
self.protocol = Some(((req.value & 0xFF) as u8).into());
}
xfer.accept().ok();
} else {
xfer.reject().ok();
}
}
HID_REQ_SET_REPORT => {
let report_type = ((req.value >> 8) as u8).into();
let report_id = (req.value & 0xFF) as u8;
let len = req.length as usize;
if len > CONTROL_BUF_LEN {
self.set_report_buf = None;
xfer.reject().ok();
} else {
let mut buf: [u8; CONTROL_BUF_LEN] = [0; CONTROL_BUF_LEN];
buf[..len].copy_from_slice(&xfer.data()[..len]);
self.set_report_buf = Some(Report {
info: ReportInfo {
report_type,
report_id,
len,
},
buf,
});
xfer.accept().ok();
}
}
_ => {
xfer.reject().ok();
}
}
}
}