use crate::buffer::Buffer;
use crate::fmt::{info, trace};
use crate::transport::{CommandStatus, Transport, TransportError};
use core::borrow::BorrowMut;
use core::cmp::min;
use usb_device::bus::{UsbBus, UsbBusAllocator};
use usb_device::class::ControlIn;
use usb_device::class_prelude::DescriptorWriter;
use usb_device::control::{Recipient, RequestType};
use usb_device::endpoint::{Endpoint, In, Out};
use usb_device::UsbError;
/// bInterfaceProtocol value identifying the Bulk Only (BBB) transport.
pub(crate) const TRANSPORT_BBB: u8 = 0x50;
// Class-specific control request codes for the Bulk Only Transport.
const CLASS_SPECIFIC_BULK_ONLY_MASS_STORAGE_RESET: u8 = 0xFF;
const CLASS_SPECIFIC_GET_MAX_LUN: u8 = 0xFE;
// CBW/CSW signatures; in little-endian byte order these spell "USBC" and "USBS".
const CBW_SIGNATURE_LE: [u8; 4] = 0x43425355u32.to_le_bytes();
const CSW_SIGNATURE_LE: [u8; 4] = 0x53425355u32.to_le_bytes();
/// Length of a Command Block Wrapper in bytes.
const CBW_LEN: usize = 31;
/// Length of a Command Status Wrapper in bytes.
const CSW_LEN: usize = 13;
/// Marker error: a received CBW failed validation.
struct InvalidCbwError;
/// Errors produced by the Bulk Only Transport layer.
#[derive(Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum BulkOnlyError {
    /// The I/O buffer cannot hold the requested amount of data.
    IoBufferOverflow,
    /// The configured `max_lun` was greater than 0x0F.
    InvalidMaxLun,
    /// The operation is not allowed in the current transport state.
    InvalidState,
    /// Only a full packet may be written at this point of an IN transfer.
    FullPacketExpected,
    /// The provided buffer cannot hold a CBW or a single packet.
    BufferTooSmall,
}
/// A command received from the host, borrowed from the currently active CBW.
pub struct CommandBlock<'a> {
    /// Raw command bytes (the valid prefix of the CBW's command block).
    pub bytes: &'a [u8],
    /// Logical Unit Number the command targets.
    pub lun: u8,
}
/// Bulk Only Transport state machine.
#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
enum State {
    /// Waiting for the start of a new CBW.
    Idle,
    /// A partial CBW has been buffered; more packets are needed.
    CommandTransfer,
    /// Data phase, device to host (IN).
    DataTransferToHost,
    /// Data phase, host to device (OUT).
    DataTransferFromHost,
    /// Command without a data phase.
    DataTransferNoData,
    /// Sending the CSW to the host.
    StatusTransfer,
}
/// Direction of a command's data phase, derived from the CBW.
#[repr(u8)]
#[derive(Default, Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
enum DataDirection {
    /// Host-to-device data phase.
    Out,
    /// Device-to-host data phase.
    In,
    /// No data phase (the CBW's transfer length is zero).
    #[default]
    NotExpected,
}
/// Result type used throughout the Bulk Only Transport implementation.
type BulkOnlyTransportResult<T> = Result<T, TransportError<BulkOnlyError>>;

/// USB Mass Storage Bulk Only (BBB) transport.
pub struct BulkOnly<'alloc, Bus: UsbBus, Buf: BorrowMut<[u8]>> {
    /// Bulk IN endpoint (device-to-host).
    in_ep: Endpoint<'alloc, Bus, In>,
    /// Bulk OUT endpoint (host-to-device).
    out_ep: Endpoint<'alloc, Bus, Out>,
    /// Scratch buffer holding CBW, data-phase and CSW bytes.
    buf: Buffer<Buf>,
    /// Current transport state.
    state: State,
    /// The CBW currently being processed.
    cbw: CommandBlockWrapper,
    /// Status to report in the CSW, set by the command layer via `set_status`.
    cs: Option<CommandStatus>,
    /// Highest LUN index, reported via the Get Max LUN request.
    max_lun: u8,
}
impl<'alloc, Bus, Buf> BulkOnly<'alloc, Bus, Buf>
where
    Bus: UsbBus,
    Buf: BorrowMut<[u8]>,
{
    /// Creates a new Bulk Only Transport.
    ///
    /// `packet_size` is the bulk endpoint max packet size, `max_lun` the
    /// highest Logical Unit Number (0..=15) reported to the host, and `buf`
    /// the backing I/O buffer.
    ///
    /// # Errors
    /// - [`BulkOnlyError::InvalidMaxLun`] if `max_lun` exceeds 0x0F.
    /// - [`BulkOnlyError::BufferTooSmall`] if `buf` cannot hold a full CBW or
    ///   a single packet.
    pub fn new(
        alloc: &'alloc UsbBusAllocator<Bus>,
        packet_size: u16,
        max_lun: u8,
        buf: Buf,
    ) -> Result<BulkOnly<'alloc, Bus, Buf>, BulkOnlyError> {
        if max_lun > 0x0F {
            return Err(BulkOnlyError::InvalidMaxLun);
        }
        // The buffer must fit a complete CBW and at least one packet.
        let buf_len = buf.borrow().len();
        if buf_len < CBW_LEN || buf_len < packet_size as usize {
            return Err(BulkOnlyError::BufferTooSmall);
        }
        Ok(BulkOnly {
            in_ep: alloc.bulk(packet_size),
            out_ep: alloc.bulk(packet_size),
            buf: Buffer::new(buf),
            state: State::Idle,
            cbw: Default::default(),
            cs: Default::default(),
            max_lun,
        })
    }

    /// Drives the OUT endpoint: collects a CBW while idle, or receives data
    /// from the host during an OUT data phase. Other states ignore the event.
    pub fn read(&mut self) -> BulkOnlyTransportResult<()> {
        match self.state {
            State::Idle | State::CommandTransfer => self.handle_read_cbw(),
            State::DataTransferFromHost => self.handle_read_from_host(),
            _ => Ok(()),
        }
    }

    /// Drives the IN endpoint: sends the CSW, sends data to the host, or
    /// completes a no-data command, depending on the current state.
    pub fn write(&mut self) -> BulkOnlyTransportResult<()> {
        match self.state {
            State::StatusTransfer => self.handle_write_csw(),
            State::DataTransferToHost => self.handle_write_to_host(),
            State::DataTransferNoData => self.handle_no_data_transfer(),
            _ => Ok(()),
        }
    }

    /// Records the command status to be reported in the CSW.
    ///
    /// # Panics
    /// Panics if called outside of a data-transfer state.
    pub fn set_status(&mut self, status: CommandStatus) {
        assert!(matches!(
            self.state,
            State::DataTransferToHost | State::DataTransferFromHost | State::DataTransferNoData
        ));
        info!("usb: bbb: Set status: {}", status);
        self.cs = Some(status);
    }

    /// Returns the active command block, or `None` while no command is active.
    pub fn get_command(&self) -> Option<CommandBlock<'_>> {
        match self.state {
            State::Idle | State::CommandTransfer => None,
            _ => Some(CommandBlock {
                bytes: &self.cbw.block[..self.cbw.block_len],
                lun: self.cbw.lun,
            }),
        }
    }

    /// Copies buffered host data into `dst`, returning the number of bytes
    /// copied. Valid only during an OUT (host-to-device) data phase.
    pub fn read_data(&mut self, dst: &mut [u8]) -> BulkOnlyTransportResult<usize> {
        if !matches!(self.state, State::DataTransferFromHost) {
            return Err(TransportError::Error(BulkOnlyError::InvalidState));
        }
        Ok(self
            .buf
            .read(|buf| {
                let size = min(dst.len(), buf.len());
                dst[..size].copy_from_slice(&buf[..size]);
                Ok::<usize, ()>(size)
            })
            .unwrap())
    }

    /// Queues bytes from `src` to send to the host, clamped to the remaining
    /// transfer length requested in the CBW. Returns the number of bytes
    /// accepted. Fails once a status has been set or outside an IN data phase.
    pub fn write_data(&mut self, src: &[u8]) -> BulkOnlyTransportResult<usize> {
        if !matches!(self.state, State::DataTransferToHost) {
            return Err(TransportError::Error(BulkOnlyError::InvalidState));
        }
        if !self.status_present() {
            Ok(self
                .buf
                .write(&src[..min(src.len(), self.cbw.data_transfer_len as usize)]))
        } else {
            Err(TransportError::Error(BulkOnlyError::InvalidState))
        }
    }

    /// Queues all of `src` to send to the host, failing with
    /// [`BulkOnlyError::IoBufferOverflow`] if it does not fit in the buffer.
    pub fn try_write_data_all(&mut self, src: &[u8]) -> BulkOnlyTransportResult<()> {
        if !matches!(self.state, State::DataTransferToHost) {
            return Err(TransportError::Error(BulkOnlyError::InvalidState));
        }
        if !self.status_present() {
            self.buf
                .write_all(
                    src.len(),
                    TransportError::Error(BulkOnlyError::IoBufferOverflow),
                    |dst| {
                        dst[..src.len()].copy_from_slice(src);
                        Ok(src.len())
                    },
                )
                .map(|_| ())
        } else {
            Err(TransportError::Error(BulkOnlyError::InvalidState))
        }
    }

    /// Returns `true` if a command status has already been set.
    pub fn has_status(&self) -> bool {
        self.status_present()
    }

    /// Accumulates packets until a full CBW is buffered, then parses it and
    /// starts the corresponding transfer.
    fn handle_read_cbw(&mut self) -> BulkOnlyTransportResult<()> {
        self.read_packet()?;
        if self.buf.available_read() >= CBW_LEN {
            match self.try_parse_cbw() {
                Ok(cbw) => {
                    info!("usb: bbb: Recv CBW: {}", cbw);
                    self.start_data_transfer(cbw);
                }
                Err(_) => {
                    // Invalid CBW: stall both endpoints and discard internal
                    // state. The previous code called `self.reset()` here,
                    // which immediately unstalled both endpoints again; the
                    // BBB spec (6.6.1) requires the stall to persist until
                    // the host performs Reset Recovery, so only drop back to
                    // Idle (which cleans the buffer, CBW and status).
                    self.stall_eps();
                    self.enter_state(State::Idle);
                }
            }
        } else {
            // Partial CBW so far; keep collecting packets.
            self.enter_state(State::CommandTransfer)
        }
        Ok(())
    }

    /// Receives one packet of command data from the host, decrements the
    /// remaining transfer length, then checks whether the transfer can end.
    fn handle_read_from_host(&mut self) -> BulkOnlyTransportResult<()> {
        if !self.status_present() {
            let count = self.read_packet()?;
            // `data_transfer_len` counts down the bytes the host still owes
            // us; what remains at the end is the CSW data residue.
            self.cbw.data_transfer_len = self.cbw.data_transfer_len.saturating_sub(count as u32);
            trace!("usb: bbb: Data residue: {}", self.cbw.data_transfer_len);
        }
        self.check_end_data_transfer()
    }

    /// Sends one packet of command data to the host.
    ///
    /// While the host still expects at least one more full packet and the
    /// command is still producing data, only full packets may be sent — a
    /// short packet would end the transfer early on the host side.
    fn handle_write_to_host(&mut self) -> BulkOnlyTransportResult<()> {
        let max_packet_size = self.packet_size() as u32;
        let full_packet_expected =
            self.cbw.data_transfer_len >= max_packet_size && !self.status_present();
        let full_packet = self.buf.available_read() >= max_packet_size as usize;
        let full_packet_or_zero = full_packet || !full_packet_expected;
        if full_packet_or_zero {
            if self.buf.available_read() > 0 {
                let count = self.write_packet()?;
                self.cbw.data_transfer_len =
                    self.cbw.data_transfer_len.saturating_sub(count as u32);
                trace!("usb: bbb: Data residue: {}", self.cbw.data_transfer_len);
            }
            self.check_end_data_transfer()
        } else {
            Err(TransportError::Error(BulkOnlyError::FullPacketExpected))
        }
    }

    /// No data phase: just wait for a status and then move to the CSW.
    fn handle_no_data_transfer(&mut self) -> BulkOnlyTransportResult<()> {
        self.check_end_data_transfer()
    }

    /// Sends the buffered CSW; returns to `Idle` once fully written.
    fn handle_write_csw(&mut self) -> BulkOnlyTransportResult<()> {
        self.write_packet()?;
        if self.buf.available_read() == 0 {
            self.enter_state(State::Idle)
        }
        Ok(())
    }

    /// Ends the data phase once a status is present and, for IN transfers,
    /// all buffered data has been sent to the host.
    fn check_end_data_transfer(&mut self) -> BulkOnlyTransportResult<()> {
        match self.state {
            State::DataTransferNoData | State::DataTransferFromHost => {
                if self.cs.is_some() {
                    self.end_data_transfer()?;
                }
            }
            State::DataTransferToHost => {
                if self.cs.is_some() && self.buf.available_read() == 0 {
                    self.end_data_transfer()?;
                }
            }
            _ => {}
        }
        Ok(())
    }

    /// Finishes the data phase: stalls the data endpoint if the host expected
    /// more bytes than were transferred, then queues the CSW and tries to
    /// send it.
    fn end_data_transfer(&mut self) -> BulkOnlyTransportResult<()> {
        if self.cbw.data_transfer_len > 0 {
            // Short transfer: signal it to the host by stalling the endpoint
            // that carried the data phase.
            match self.state {
                State::DataTransferToHost => {
                    self.stall_in_ep();
                }
                State::DataTransferFromHost => {
                    self.stall_out_ep();
                }
                _ => {}
            }
        }
        // A status is guaranteed present here; callers check `cs` first.
        let csw = self.build_csw().unwrap();
        self.buf.clean();
        self.buf.write(csw.as_slice());
        self.enter_state(State::StatusTransfer);
        self.write()
    }

    #[inline]
    fn status_present(&self) -> bool {
        self.cs.is_some()
    }

    /// Builds the 13-byte CSW: signature, the tag echoed from the CBW, the
    /// data residue (bytes not transferred) and the command status.
    fn build_csw(&mut self) -> Option<[u8; CSW_LEN]> {
        self.cs.map(|status| {
            let mut csw = [0u8; CSW_LEN];
            csw[..4].copy_from_slice(CSW_SIGNATURE_LE.as_slice());
            csw[4..8].copy_from_slice(self.cbw.tag.to_le_bytes().as_slice());
            csw[8..12].copy_from_slice(self.cbw.data_transfer_len.to_le_bytes().as_slice());
            csw[12..].copy_from_slice(&[status as u8]);
            csw
        })
    }

    /// Pops a complete CBW from the I/O buffer and parses it. The signature
    /// is checked here; the remaining fields are validated by
    /// [`CommandBlockWrapper::from_le_bytes`].
    fn try_parse_cbw(&mut self) -> Result<CommandBlockWrapper, InvalidCbwError> {
        debug_assert!(matches!(self.state, State::Idle | State::CommandTransfer));
        debug_assert!(self.buf.available_read() >= CBW_LEN);
        let mut raw_cbw = [0u8; CBW_LEN];
        self.buf
            .read::<()>(|buf| {
                raw_cbw.copy_from_slice(&buf[..CBW_LEN]);
                Ok(CBW_LEN)
            })
            .unwrap();
        if !raw_cbw.starts_with(&CBW_SIGNATURE_LE) {
            return Err(InvalidCbwError);
        }
        CommandBlockWrapper::from_le_bytes(&raw_cbw[4..])
    }

    /// Enters the data-transfer state matching the CBW's direction flag.
    fn start_data_transfer(&mut self, mut cbw: CommandBlockWrapper) {
        debug_assert!(matches!(self.state, State::Idle | State::CommandTransfer));
        match cbw.direction {
            DataDirection::Out => {
                self.enter_state(State::DataTransferFromHost);
            }
            DataDirection::In => {
                self.enter_state(State::DataTransferToHost);
            }
            DataDirection::NotExpected => {
                self.enter_state(State::DataTransferNoData);
                // No data phase: nothing left to transfer.
                cbw.data_transfer_len = 0;
            }
        };
        self.cbw = cbw;
    }

    #[inline]
    fn packet_size(&self) -> usize {
        // IN and OUT endpoints are allocated with the same packet size.
        self.in_ep.max_packet_size() as usize
    }

    /// Reads a single packet from the OUT endpoint into the I/O buffer.
    /// Returns `Err(TransportError::Usb(UsbError::WouldBlock))` when no data
    /// was available.
    fn read_packet(&mut self) -> BulkOnlyTransportResult<usize> {
        let count = self.buf.write_all(
            self.packet_size(),
            TransportError::Error(BulkOnlyError::IoBufferOverflow),
            |buf| match self.out_ep.read(buf) {
                Ok(count) => Ok(count),
                // "No data yet" is treated as a zero-byte read.
                Err(UsbError::WouldBlock) => Ok(0),
                Err(err) => Err(TransportError::Usb(err)),
            },
        )?;
        trace!(
            "usb: bbb: Read bytes: {}, buf available: {}",
            count,
            self.buf.available_read()
        );
        if count == 0 {
            Err(TransportError::Usb(UsbError::WouldBlock))
        } else {
            Ok(count)
        }
    }

    /// Writes a single packet from the I/O buffer to the IN endpoint.
    /// Returns `Err(TransportError::Usb(UsbError::WouldBlock))` when nothing
    /// could be written.
    fn write_packet(&mut self) -> BulkOnlyTransportResult<usize> {
        let packet_size = self.packet_size();
        let count = self.buf.read(|buf| {
            if !buf.is_empty() {
                match self.in_ep.write(&buf[..min(packet_size, buf.len())]) {
                    Ok(count) => Ok(count),
                    Err(UsbError::WouldBlock) => Ok(0),
                    Err(err) => Err(TransportError::Usb(err)),
                }
            } else {
                Ok(0)
            }
        })?;
        trace!(
            "usb: bbb: Wrote bytes: {}, buf available: {}",
            count,
            self.buf.available_read()
        );
        if count == 0 {
            Err(TransportError::Usb(UsbError::WouldBlock))
        } else {
            Ok(count)
        }
    }

    #[inline]
    fn stall_eps(&self) {
        self.stall_in_ep();
        self.stall_out_ep();
    }

    #[inline]
    fn stall_in_ep(&self) {
        info!("usb: bbb: Stall IN ep");
        self.in_ep.stall();
    }

    #[inline]
    fn stall_out_ep(&self) {
        info!("usb: bbb: Stall OUT ep");
        self.out_ep.stall();
    }

    /// Switches state; entering `Idle` also clears the buffer, the stored CBW
    /// and any pending status.
    #[inline]
    fn enter_state(&mut self, state: State) {
        info!("usb: bbb: Enter state: {}", state);
        if matches!(state, State::Idle) {
            self.buf.clean();
            self.cbw = Default::default();
            self.cs = None;
        }
        self.state = state;
    }
}
impl<Bus, Buf> Transport for BulkOnly<'_, Bus, Buf>
where
    Bus: UsbBus,
    Buf: BorrowMut<[u8]>,
{
    const PROTO: u8 = TRANSPORT_BBB;
    type Bus = Bus;

    /// Writes the bulk IN/OUT endpoint descriptors for this interface.
    fn get_endpoint_descriptors(&self, writer: &mut DescriptorWriter) -> Result<(), UsbError> {
        writer.endpoint(&self.in_ep)?;
        writer.endpoint(&self.out_ep)?;
        Ok(())
    }

    /// Handles a reset: clears endpoint stalls and returns to `Idle`, which
    /// also discards any buffered data, the stored CBW and pending status.
    fn reset(&mut self) {
        info!("usb: bbb: Recv reset");
        self.in_ep.unstall();
        self.out_ep.unstall();
        self.enter_state(State::Idle);
    }

    /// Handles class-specific device-to-host control requests addressed to
    /// the interface. Only `Get Max LUN` produces a response; everything else
    /// is ignored.
    fn control_in(&mut self, xfer: ControlIn<Self::Bus>) {
        let req = xfer.request();
        // Only class-specific requests for this interface are relevant.
        if !(req.request_type == RequestType::Class && req.recipient == Recipient::Interface) {
            return;
        }
        info!("usb: bbb: Recv ctrl_in: {}", req);
        match req.request {
            // NOTE(review): Bulk-Only Mass Storage Reset is a host-to-device
            // request, so it is not expected on the IN control path; this arm
            // is a no-op — confirm the reset is handled on the OUT control
            // path elsewhere.
            CLASS_SPECIFIC_BULK_ONLY_MASS_STORAGE_RESET => {}
            CLASS_SPECIFIC_GET_MAX_LUN => {
                xfer.accept_with(&[self.max_lun])
                    .expect("Failed to accept Get Max Lun!");
            }
            _ => {}
        }
    }
}
/// Parsed Command Block Wrapper (CBW), signature excluded.
#[derive(Default, Debug, Copy, Clone)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
struct CommandBlockWrapper {
    /// dCBWTag: echoed back to the host in the CSW.
    tag: u32,
    /// dCBWDataTransferLength: reused at runtime as the remaining-byte
    /// counter, so whatever is left becomes the CSW data residue.
    data_transfer_len: u32,
    /// Direction of the data phase, derived from the CBW flags byte.
    direction: DataDirection,
    /// Target Logical Unit Number (low 4 bits of the CBW LUN byte).
    lun: u8,
    /// Number of valid bytes in `block`.
    block_len: usize,
    /// Raw command block bytes; only the first `block_len` are valid.
    block: [u8; 16],
}
impl CommandBlockWrapper {
fn from_le_bytes(value: &[u8]) -> Result<Self, InvalidCbwError> {
const MIN_CB_LEN: u8 = 1;
const MAX_CB_LEN: u8 = 16;
let block_len = value[10];
if !(MIN_CB_LEN..=MAX_CB_LEN).contains(&block_len) {
return Err(InvalidCbwError);
}
Ok(CommandBlockWrapper {
tag: u32::from_le_bytes(value[..4].try_into().unwrap()),
data_transfer_len: u32::from_le_bytes(value[4..8].try_into().unwrap()),
direction: if u32::from_le_bytes(value[4..8].try_into().unwrap()) != 0 {
if (value[8] & (1 << 7)) > 0 {
DataDirection::In
} else {
DataDirection::Out
}
} else {
DataDirection::NotExpected
},
lun: value[9] & 0b00001111,
block_len: block_len as usize,
block: value[11..].try_into().unwrap(), })
}
}
#[cfg(test)]
mod tests {
    use crate::transport::bbb::BulkOnly;
    use crate::transport::bbb::State::DataTransferFromHost;
    use usb_device::bus::{PollResult, UsbBus, UsbBusAllocator};
    use usb_device::class_prelude::{EndpointAddress, EndpointType};
    use usb_device::{UsbDirection, UsbError};

    /// Minimal no-op `UsbBus` used only to allocate endpoints in tests.
    struct NullBus;

    impl UsbBus for NullBus {
        fn alloc_ep(
            &mut self,
            _ep_dir: UsbDirection,
            _ep_addr: Option<EndpointAddress>,
            _ep_type: EndpointType,
            _max_packet_size: u16,
            _interval: u8,
        ) -> usb_device::Result<EndpointAddress> {
            Ok(EndpointAddress::from(0))
        }
        fn enable(&mut self) {}
        fn reset(&self) {}
        fn suspend(&self) {}
        fn resume(&self) {}
        fn set_device_address(&self, _addr: u8) {}
        fn set_stalled(&self, _ep_addr: EndpointAddress, _stalled: bool) {}
        fn is_stalled(&self, _ep_addr: EndpointAddress) -> bool {
            false
        }
        fn write(&self, _ep_addr: EndpointAddress, _buf: &[u8]) -> usb_device::Result<usize> {
            Err(UsbError::InvalidEndpoint)
        }
        fn read(&self, _ep_addr: EndpointAddress, _buf: &mut [u8]) -> usb_device::Result<usize> {
            Err(UsbError::InvalidEndpoint)
        }
        fn poll(&self) -> PollResult {
            PollResult::None
        }
    }

    /// `read_data` must fill a destination smaller than the buffered payload
    /// and report exactly how many bytes it copied.
    #[test]
    fn should_read_data_into_small_buffer() {
        const BUF_SIZE: usize = 512;
        const N: usize = 123;
        let alloc = UsbBusAllocator::new(NullBus);
        let mut bbb = BulkOnly::new(&alloc, 8, 0, vec![0u8; BUF_SIZE]).unwrap();
        // Pretend an OUT data phase is in progress with a full buffer.
        bbb.state = DataTransferFromHost;
        let payload = [0xFFu8; BUF_SIZE];
        bbb.buf.write(&payload[..]);
        let mut dst = [0u8; N];
        assert_eq!(N, bbb.read_data(&mut dst[..]).unwrap());
    }
}