#![allow(missing_docs)]
use core::num::NonZeroU8;
use bitflags::bitflags;
use embassy_time::Timer;
use embassy_usb::control::Request;
use embassy_usb_driver::host::{HostError, SplitInfo, SplitSpeed, UsbHostAllocator, UsbPipe, pipe};
use embassy_usb_driver::{Direction, EndpointInfo, EndpointType, Speed};
use zerocopy::{FromBytes, Immutable, KnownLayout};
use crate::control::{ControlPipeExt, ControlType, Recipient, RequestType, SetupPacket};
use crate::descriptor::{DEFAULT_MAX_DESCRIPTOR_SIZE, InterfaceDescriptor, USBDescriptor};
use crate::handler::{BusRoute, EnumerationInfo, HandlerEvent, RegisterError};
use crate::{BusHandle, EnumerationError};
/// Driver state for one attached USB hub.
///
/// `MAX_PORTS` sizes the per-port address table (`device_lut`). The port count
/// read from the hub descriptor is stored separately in `desc`.
pub struct HubHandler<'d, A: UsbHostAllocator<'d>, const MAX_PORTS: usize> {
    /// Handle to the host bus; used to enumerate devices found behind this hub.
    bus: BusHandle<'d, A>,
    /// Interrupt IN pipe from which the hub's status-change bitmap is read.
    interrupt_channel: A::Pipe<pipe::Interrupt, pipe::In>,
    /// Control pipe on endpoint 0, used for hub and port class requests.
    control_channel: A::Pipe<pipe::Control, pipe::InOut>,
    /// Class-specific hub descriptor fetched during registration.
    desc: HubDescriptor,
    /// Bus address of the hub itself.
    device_address: u8,
    /// Address recorded for the device on each zero-based port, if any.
    device_lut: [Option<NonZeroU8>; MAX_PORTS],
    /// Bus route to the hub itself; consulted when routing devices behind it.
    route: BusRoute,
}
/// Downstream-port event produced by [`HubHandler::wait_for_event`].
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum HubEvent {
    /// A device was attached to the zero-based `port` and reports `speed`.
    DeviceDetected {
        port: u8,
        speed: Speed,
    },
    /// The device on the zero-based `port` was detached.
    ///
    /// `address` is the address previously recorded for that port by
    /// `HubHandler::enumerate_port`, or `None` if nothing was recorded.
    DeviceRemoved {
        address: Option<NonZeroU8>,
        port: u8,
    },
}
impl<'d, A: UsbHostAllocator<'d>, const MAX_PORTS: usize> HubHandler<'d, A, MAX_PORTS> {
pub async fn try_register(bus: &BusHandle<'d, A>, enum_info: &EnumerationInfo) -> Result<Self, RegisterError> {
let ls_over_fs = matches!(enum_info.split(), Some(s) if s.device_speed() == SplitSpeed::Low);
let mut control_channel = bus.alloc_pipe::<pipe::Control, pipe::InOut>(
enum_info.device_address,
&EndpointInfo {
addr: 0.into(),
ep_type: EndpointType::Control,
max_packet_size: enum_info
.device_desc
.max_packet_size0
.min(if ls_over_fs { 8 } else { 64 }) as u16,
interval_ms: 0,
},
enum_info.split(),
)?;
let mut cfg_desc_buf = [0u8; DEFAULT_MAX_DESCRIPTOR_SIZE];
let configuration = enum_info
.active_config_or_set_default(&mut control_channel, &mut cfg_desc_buf)
.await?;
let iface = configuration
.iter_interface()
.find(|v| {
matches!(
v,
InterfaceDescriptor {
interface_class: 0x09,
interface_subclass: 0x0,
interface_protocol: 0x0,
..
}
)
})
.ok_or(RegisterError::NoSupportedInterface)?;
let interrupt_ep = iface
.iter_endpoints()
.find(|v| v.ep_type() == EndpointType::Interrupt && v.ep_dir() == Direction::In)
.ok_or(RegisterError::NoSupportedInterface)?;
let interrupt_channel = bus.alloc_pipe::<pipe::Interrupt, pipe::In>(
enum_info.device_address,
&interrupt_ep.into(),
enum_info.split(),
)?;
let desc = control_channel.request_descriptor::<HubDescriptor, 64>(0, true).await?;
let mut hub = HubHandler {
bus: bus.clone(),
interrupt_channel,
control_channel,
desc,
device_address: enum_info.device_address,
device_lut: [None; MAX_PORTS],
route: enum_info.route,
};
for port in 0..hub.desc.port_num {
hub.port_feature(true, PortFeature::Power, port, 0).await?;
}
Timer::after_millis(hub.desc.power_on_delay as u64 * 2).await;
Ok(hub)
}
pub async fn wait_for_event(&mut self) -> Result<HandlerEvent<HubEvent>, HostError> {
loop {
let mut buf = [0u8; (1 + 255) / u8::BITS as usize];
let slice = &mut buf[..(self.desc.port_num as usize / 8) + 1];
self.interrupt_channel.request_in(slice).await?;
let mut hub_changes = HubInterrupt(slice);
if hub_changes.take_hub_change() {
trace!("HUB {}: hub changed, requesting status", self.device_address);
let (status, change) = self.get_hub_status().await?;
debug!(
"HUB {}: hub status: {:?} change: {:?}",
self.device_address, status, change
);
if !change.is_empty() {
return Err(HostError::Other("Unhandled hub status change"));
}
}
while let Some(port) = hub_changes.take_port_change() {
trace!("HUB {}: port {} changed, requesting status", self.device_address, port);
let (status, mut change) = self.get_port_status(port).await?;
debug!(
"HUB {}: port {} status: {:?} change: {:?}",
self.device_address, port, status, change
);
if change.contains(PortStatusChange::RESET) {
change.toggle(PortStatusChange::RESET);
self.port_feature(false, PortFeature::ChangeReset, port, 0).await?;
}
if change.contains(PortStatusChange::CONNECT) {
change.toggle(PortStatusChange::CONNECT);
self.port_feature(false, PortFeature::ChangeConnection, port, 0).await?;
match status.contains(PortStatus::CONNECTED) {
true => {
let speed: Speed = status.into();
debug!(
"HUB {}: Device connected to port {} at {:?}",
self.device_address, port, speed
);
return Ok(HandlerEvent::HandlerEvent(HubEvent::DeviceDetected { port, speed }));
}
false => {
debug!("HUB {}: Device disconnected from port {}", self.device_address, port);
let device_ref = self.device_lut.get_mut(port as usize);
return Ok(HandlerEvent::HandlerEvent(HubEvent::DeviceRemoved {
address: device_ref.and_then(|v| v.take()),
port,
}));
}
}
}
if !change.is_empty() {
return Err(HostError::Other("Unhandled port status change"));
}
}
}
}
#[allow(dead_code)]
async fn hub_feature(&mut self, set: bool, feature: HubFeature) -> Result<(), HostError> {
let setup = SetupPacket {
request_type: RequestType {
direction: Direction::Out,
control_type: ControlType::Class,
recipient: Recipient::Device,
},
request: if set {
Request::SET_FEATURE
} else {
Request::CLEAR_FEATURE
},
value: feature as u16,
index: 0,
length: 0,
};
self.control_channel.control_out(&setup.to_bytes(), &[]).await?;
Ok(())
}
async fn get_hub_status(&mut self) -> Result<(HubStatus, HubStatusChange), HostError> {
let setup = SetupPacket {
request_type: RequestType {
direction: Direction::In,
control_type: ControlType::Class,
recipient: Recipient::Device,
},
request: Request::GET_STATUS,
value: 0,
index: 0,
length: 4,
};
let mut buf = [0u8; 4];
self.control_channel.control_in(&setup.to_bytes(), &mut buf).await?;
Ok((
HubStatus::from_bits_truncate(u16::from_le_bytes(buf[..2].try_into().unwrap())),
HubStatusChange::from_bits_truncate(u16::from_le_bytes(buf[2..].try_into().unwrap())),
))
}
pub async fn enumerate_port(
&mut self,
config_buffer: &mut [u8],
port: u8,
speed: Speed,
) -> Result<(EnumerationInfo, usize), EnumerationError> {
self.port_feature(true, PortFeature::Reset, port, 0).await?;
Timer::after_millis(50).await;
self.port_feature(false, PortFeature::ChangeReset, port, 0).await?;
let route = match self.route.split() {
Some(parent_split) => match speed {
Speed::Low => BusRoute::Translated(SplitInfo::new(
parent_split.hub_addr(),
parent_split.port(),
SplitSpeed::Low,
)),
Speed::Full => BusRoute::Translated(SplitInfo::new(
parent_split.hub_addr(),
parent_split.port(),
SplitSpeed::Full,
)),
Speed::High => BusRoute::Direct(speed),
},
None => {
let split_speed = match (speed, self.route.device_speed()) {
(Speed::Low, Speed::Full | Speed::High) => Some(SplitSpeed::Low),
(Speed::Full, Speed::High) => Some(SplitSpeed::Full),
_ => None,
};
match split_speed {
Some(ss) => BusRoute::Translated(SplitInfo::new(self.device_address, port + 1, ss)),
None => BusRoute::Direct(speed),
}
}
};
let (info, config_len) = self.bus.enumerate(route, config_buffer).await?;
self.device_lut[port as usize] = NonZeroU8::new(info.device_address);
Ok((info, config_len))
}
async fn port_feature(&mut self, set: bool, feature: PortFeature, port: u8, selector: u8) -> Result<(), HostError> {
let setup = SetupPacket {
request_type: RequestType {
direction: Direction::Out,
control_type: ControlType::Class,
recipient: Recipient::Other,
},
request: if set {
Request::SET_FEATURE
} else {
Request::CLEAR_FEATURE
},
value: feature as u16,
index: ((selector as u16) << 8) | (port + 1) as u16,
length: 0,
};
self.control_channel.control_out(&setup.to_bytes(), &[]).await?;
Ok(())
}
async fn get_port_status(&mut self, port: u8) -> Result<(PortStatus, PortStatusChange), HostError> {
let setup = SetupPacket {
request_type: RequestType {
direction: Direction::In,
control_type: ControlType::Class,
recipient: Recipient::Other,
},
request: Request::GET_STATUS,
value: 0,
index: (port + 1) as u16,
length: 4,
};
let mut buf = [0u8; 4];
self.control_channel.control_in(&setup.to_bytes(), &mut buf).await?;
Ok((
PortStatus::from_bits_truncate(u16::from_le_bytes(buf[..2].try_into().unwrap())),
PortStatusChange::from_bits_truncate(u16::from_le_bytes(buf[2..].try_into().unwrap())),
))
}
}
/// Mutable view over the status-change bitmap read from the hub's interrupt pipe.
///
/// Bit 0 of byte 0 flags a hub-level change. Each later bit N flags a change on
/// 1-based port N, which `take_port_change` reports as the zero-based `N - 1`.
/// Taking a change clears its bit, so each change is reported once.
struct HubInterrupt<'a>(&'a mut [u8]);

impl HubInterrupt<'_> {
    /// Returns whether the hub-level bit was set, clearing it.
    fn take_hub_change(&mut self) -> bool {
        match self.0.first_mut() {
            Some(first) if *first & 1 != 0 => {
                *first &= !1;
                true
            }
            _ => false,
        }
    }

    /// Returns and clears the lowest pending port change as a zero-based port index.
    ///
    /// # Panics
    /// Panics if the hub-level bit is still set; call `take_hub_change` first.
    fn take_port_change(&mut self) -> Option<u8> {
        let (byte_idx, byte) = self.0.iter_mut().enumerate().find(|(_, b)| **b != 0)?;
        let bit = byte.trailing_zeros() as usize;
        if byte_idx == 0 && bit == 0 {
            panic!("the hub change must be taken before a port change is taken");
        }
        *byte &= !(1 << bit);
        Some((byte_idx * 8 + bit - 1) as u8)
    }
}
/// Class-specific USB 2.0 hub descriptor (descriptor type `0x29`).
///
/// The first seven fields are the fixed header. `port_buf` receives the
/// variable-length trailing bitmaps, up to 32 bytes.
#[derive(KnownLayout, FromBytes, Immutable, Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[repr(C)]
struct HubDescriptor {
    /// bLength: total descriptor length in bytes.
    len: u8,
    /// bDescriptorType: expected to be `0x29`.
    desc_type: u8,
    /// bNbrPorts: number of downstream ports.
    port_num: u8,
    /// wHubCharacteristics, low byte.
    characteristics0: u8,
    /// wHubCharacteristics, high byte.
    characteristics1: u8,
    /// bPwrOn2PwrGood: port power-on to power-good time, in 2 ms units.
    power_on_delay: u8,
    /// bHubContrCurrent: hub controller current requirement.
    max_current: u8,
    /// Trailing variable-length bytes; zero-filled past what the hub returned.
    port_buf: [u8; 32],
}

impl USBDescriptor for HubDescriptor {
    const SIZE: usize = core::mem::size_of::<Self>();
    const DESC_TYPE: u8 = 0x29;
    type Error = ();

    /// Parses a hub descriptor from `bytes`.
    ///
    /// Only the 7-byte fixed header is required. Real hub descriptors are
    /// variable-length (e.g. 9 bytes for a hub with up to 7 ports). Any missing
    /// tail is therefore zero-filled rather than rejected. Bytes beyond the
    /// struct size are ignored.
    ///
    /// # Errors
    /// Returns `Err(())` if fewer than 7 bytes are given or the descriptor type is
    /// not `0x29`.
    fn try_from_bytes(bytes: &[u8]) -> Result<Self, Self::Error> {
        const HEADER_LEN: usize = 7;
        if bytes.len() < HEADER_LEN {
            return Err(());
        }
        let mut padded = [0u8; core::mem::size_of::<HubDescriptor>()];
        let copy_len = bytes.len().min(padded.len());
        padded[..copy_len].copy_from_slice(&bytes[..copy_len]);
        let (byref, _) = Self::ref_from_prefix(&padded[..]).map_err(|_| ())?;
        if byref.desc_type != Self::DESC_TYPE {
            return Err(());
        }
        Ok(byref.clone())
    }
}
/// Hub-level feature selectors for class `SET_FEATURE` / `CLEAR_FEATURE`.
///
/// Each value is sent as the request's `wValue`.
#[allow(dead_code)]
#[derive(Clone, Copy)]
#[repr(u8)]
enum HubFeature {
    /// `C_HUB_LOCAL_POWER`
    ChangeHubLocalPower = 0,
    /// `C_HUB_OVER_CURRENT`
    ChangeHubOverCurrent = 1,
}
bitflags! {
    /// Status word: the first little-endian `u16` of a hub `GET_STATUS` reply.
    #[derive(Debug)]
    struct HubStatus: u16 {
        /// Local power supply bit.
        const LOCAL_POWER = 0x0001;
        /// Over-current bit.
        const OVERCURRENT = 0x0002;
    }
}

bitflags! {
    /// Change word: the second little-endian `u16` of a hub `GET_STATUS` reply.
    #[derive(Debug)]
    struct HubStatusChange: u16 {
        /// Local power status changed.
        const LOCAL_POWER = 0x0001;
        /// Over-current status changed.
        const OVERCURRENT = 0x0002;
    }
}

// The bitflags types above only derive `Debug`, so defmt output is written by
// hand as a raw binary bitmask.
#[cfg(feature = "defmt")]
impl defmt::Format for HubStatus {
    fn format(&self, f: defmt::Formatter) {
        defmt::write!(f, "HubStatus({=u16:b})", self.bits());
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for HubStatusChange {
    fn format(&self, f: defmt::Formatter) {
        defmt::write!(f, "HubStatusChange({=u16:b})", self.bits());
    }
}
/// Port feature selectors for class `SET_FEATURE` / `CLEAR_FEATURE`.
///
/// Each value is sent as the request's `wValue`. The values are spelled out
/// explicitly because the numbering has gaps.
#[allow(dead_code)]
#[derive(Clone, Copy)]
#[repr(u8)]
enum PortFeature {
    Connection = 0,
    Enable = 1,
    Suspend = 2,
    OverCurrent = 3,
    Reset = 4,
    Power = 8,
    LowSpeed = 9,
    ChangeConnection = 16,
    ChangeEnable = 17,
    ChangeSuspend = 18,
    ChangeOverCurrent = 19,
    ChangeReset = 20,
    Test = 21,
    Indicator = 22,
}
bitflags! {
    /// Status word: the first little-endian `u16` of a port `GET_STATUS` reply.
    #[derive(Debug)]
    struct PortStatus: u16 {
        const CONNECTED = 0x0001;
        const ENABLED = 0x0002;
        const SUSPENDED = 0x0004;
        const OVERCURRENT = 0x0008;
        const RESET = 0x0010;
        const POWERED = 0x0100;
        /// Consulted by `From<PortStatus> for Speed`.
        const LOW_SPEED = 0x0200;
        /// Consulted by `From<PortStatus> for Speed`.
        const HIGH_SPEED = 0x0400;
        const TEST_MODE = 0x0800;
        const INDICATOR_CUSTOM_COLOR = 0x1000;
    }
}

bitflags! {
    /// Change word: the second little-endian `u16` of a port `GET_STATUS` reply.
    #[derive(Debug)]
    struct PortStatusChange: u16 {
        const CONNECT = 0x0001;
        const ENABLE = 0x0002;
        const SUSPEND = 0x0004;
        const OVERCURRENT = 0x0008;
        const RESET = 0x0010;
    }
}

// The bitflags types above only derive `Debug`, so defmt output is written by
// hand as a raw binary bitmask.
#[cfg(feature = "defmt")]
impl defmt::Format for PortStatus {
    fn format(&self, f: defmt::Formatter) {
        defmt::write!(f, "PortStatus({=u16:b})", self.bits());
    }
}

#[cfg(feature = "defmt")]
impl defmt::Format for PortStatusChange {
    fn format(&self, f: defmt::Formatter) {
        defmt::write!(f, "PortStatusChange({=u16:b})", self.bits());
    }
}
/// Maps a port's status bits to the attached device's speed.
///
/// `LOW_SPEED` wins if set. Otherwise `HIGH_SPEED` selects high speed, and the
/// absence of both means full speed.
impl From<PortStatus> for Speed {
    fn from(value: PortStatus) -> Self {
        if value.contains(PortStatus::LOW_SPEED) {
            Speed::Low
        } else if value.contains(PortStatus::HIGH_SPEED) {
            Speed::High
        } else {
            Speed::Full
        }
    }
}
#[cfg(test)]
pub mod tests {
    use super::HubInterrupt;

    /// Feeds `buf` through `HubInterrupt`, checking the following in order:
    /// - the hub bit equals `hub`;
    /// - exactly the zero-based `ports` come out;
    /// - then nothing more.
    fn expect_changes(buf: &mut [u8], hub: bool, ports: &[u8]) {
        let mut changes = HubInterrupt(buf);
        assert_eq!(changes.take_hub_change(), hub);
        for &port in ports {
            assert_eq!(changes.take_port_change(), Some(port));
        }
        assert_eq!(changes.take_port_change(), None);
    }

    #[test]
    fn test_hub_interrupt_0() {
        expect_changes(&mut [0u8; 0], false, &[]);
    }

    #[test]
    fn test_hub_interrupt_1_empty() {
        expect_changes(&mut [0b0000_0000], false, &[]);
    }

    #[test]
    fn test_hub_interrupt_1_hub() {
        expect_changes(&mut [0b0000_0001], true, &[]);
    }

    #[test]
    fn test_hub_interrupt_1_port() {
        expect_changes(&mut [0b0000_0010], false, &[0]);
    }

    #[test]
    fn test_hub_interrupt_1_full() {
        expect_changes(&mut [0b1111_1111], true, &[0, 1, 2, 3, 4, 5, 6]);
    }

    #[test]
    fn test_hub_interrupt_3_hub_empty_port() {
        expect_changes(&mut [0b0000_0001, 0b0000_0000, 0b1000_0000], true, &[22]);
    }
}