use usb_device::class_prelude::*;
use usb_device::{
Result as UsbResult,
control::{
RequestType,
Request,
},
};
pub use UsbError::WouldBlock;
use packing::{
Error as PackingError,
Packed,
PackedSize,
};
use usbd_mass_storage::{
MscClass,
InterfaceSubclass,
InterfaceProtocol,
};
use crate::logging::*;
use super::{
CommandBlockWrapper,
CommandStatusWrapper,
Direction,
CommandStatus,
};
/// `bRequest` value of the class-specific Get Max LUN request.
const REQ_GET_MAX_LUN: u8 = 0xFE;
/// `bRequest` value of the class-specific Bulk-Only Mass Storage Reset request.
const REQ_BULK_ONLY_RESET: u8 = 0xFF;
/// Size of the transport's staging buffer.
/// The same buffer holds incoming CBW bytes, data-phase bytes in either direction, and the outgoing CSW.
const BUFFER_BYTES: usize = 512;
/// Errors returned by the bulk-only transport.
#[derive(Debug)]
pub enum Error {
/// An error from the USB stack. `WouldBlock` is reported through this variant whenever the
/// transport cannot make progress yet (no buffer space, no buffered data, endpoint busy).
UsbError(UsbError),
/// An error from the `packing` crate.
PackingError(PackingError),
/// Received bytes could not be interpreted; currently produced when a command block
/// wrapper fails to unpack.
DataError,
}
impl From<UsbError> for Error {
fn from(e: UsbError) -> Error {
Error::UsbError(e)
}
}
impl From<PackingError> for Error {
fn from(e: PackingError) -> Error {
Error::PackingError(e)
}
}
/// States of the bulk-only transport state machine.
#[derive(Clone, Copy, Eq, PartialEq, Debug)]
enum State {
/// Idle: `read()` accumulates OUT packets looking for a command block wrapper.
WaitingForCommand,
/// Data phase, device to host: `write()` flushes buffered data.
SendingDataToHost,
/// Data phase, host to device: `read()` stores incoming packets in the buffer.
ReceivingDataFromHost,
/// A zero-length packet must be written before the status phase can start.
NeedZlp,
/// The packed command status wrapper is buffered and `write()` is sending it.
NeedToSendStatus,
}
/// Snapshot of the transport buffer, as returned by `BulkOnlyTransport::transfer_state`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferState {
/// No data phase is active.
/// `bytes_remaining` is the number of unconsumed buffered bytes; `empty` means nothing is buffered.
NotTransferring { bytes_remaining: usize, empty: bool },
/// Host-to-device data phase.
/// `bytes_available` can be taken with `take_buffered_data`, and `full` means the buffer has no free space.
/// `done` means the host has sent every byte announced in the CBW.
ReceivingDataFromHost { bytes_available: usize, full: bool, done: bool },
/// Device-to-host data phase.
/// `bytes_remaining` are buffered bytes not yet written to the host; `empty` means nothing is buffered.
SendingDataToHost { bytes_remaining: usize, empty: bool },
}
/// USB Mass Storage bulk-only transport layered on a `MscClass`.
///
/// Consumers drive it with `read()` / `write()` and exchange payload through the internal buffer.
pub struct BulkOnlyTransport<'a, B: UsbBus> {
/// Wrapped class providing descriptors and packet I/O.
inner: MscClass<'a, B>,
/// Value reported in response to Get Max LUN.
max_lun: u8,
/// Current position in the transport state machine.
state: State,
/// Most recently received command block wrapper.
command_block_wrapper: CommandBlockWrapper,
/// Status wrapper for the current command.
/// During the data phase `data_residue` counts bytes still to transfer.
/// After `pack_csw` it counts the CSW bytes still to be sent.
command_status_wrapper: CommandStatusWrapper,
/// Staging buffer shared by the command, data and status phases.
buffer: [u8; BUFFER_BYTES],
/// End of the valid bytes in `buffer`.
buffer_i: usize,
/// Start of the bytes in `buffer` not yet consumed (taken by the consumer or sent to the host).
data_i: usize,
/// Whether the most recent packet written by `flush` was exactly max-packet-size.
last_packet_full: bool,
/// Set by `send_command_ok` / `send_command_error`: no more data will be supplied for this command.
data_done: bool,
}
impl<B: UsbBus> BulkOnlyTransport<'_, B> {
pub const BUFFER_BYTES: usize = BUFFER_BYTES;
pub fn new(
alloc: &UsbBusAllocator<B>,
max_packet_size: u16,
subclass: InterfaceSubclass,
max_lun: u8,
) -> BulkOnlyTransport<'_, B> {
assert!(max_lun < 16);
BulkOnlyTransport {
inner: MscClass::new(
alloc,
max_packet_size,
subclass,
InterfaceProtocol::BulkOnlyTransport,
),
max_lun,
state: State::WaitingForCommand,
command_block_wrapper: Default::default(),
command_status_wrapper: Default::default(),
buffer: [0; BUFFER_BYTES],
buffer_i: 0,
data_i: 0,
last_packet_full: false,
data_done: false,
}
}
/// Max packet size as reported by the wrapped `MscClass`.
fn max_packet_size(&self) -> u16 {
    self.inner.max_packet_size()
}
/// `max_packet_size()` widened to `usize` for buffer arithmetic.
fn max_packet_usize(&self) -> usize {
    usize::from(self.max_packet_size())
}
/// Advances the state machine on the receiving side.
///
/// Looks for a new command while idle, or pulls data while in a host-to-device data phase.
/// In any other state this is a no-op.
pub fn read(&mut self) -> Result<(), Error> {
    if self.state == State::WaitingForCommand {
        return self.waiting_for_command();
    }
    if self.state == State::ReceivingDataFromHost {
        return self.receiving_data_from_host();
    }
    Ok(())
}
/// Advances the state machine on the sending side.
///
/// Flushes data, sends a pending zero-length packet, or sends the status wrapper, depending on the state.
pub fn write(&mut self) -> Result<(), Error> {
    let current = self.state;
    match current {
        State::SendingDataToHost => self.sending_data_to_host(),
        State::NeedZlp => self.send_zlp(),
        State::NeedToSendStatus => self.need_to_send_status(),
        State::WaitingForCommand | State::ReceivingDataFromHost => Ok(()),
    }
}
fn waiting_for_command(&mut self) -> Result<(), Error> {
if self.buffer.len() - self.buffer_i < (self.max_packet_size() as usize) {
trace_bot_buffer!("BUFFER> too full to read command");
Err(WouldBlock)?;
}
let bytes = self.inner.read_packet(&mut self.buffer[self.buffer_i..])?;
trace_bot_bytes!("BYTES> Read {} bytes for command", bytes);
self.buffer_i += bytes;
let new_i = CommandBlockWrapper::truncate_to_signature(&mut self.buffer[..self.buffer_i]);
if self.buffer_i != new_i {
trace_bot_headers!("HEADER> Discarded {} bytes looking for command block wrapper signature", self.buffer_i - new_i);
self.buffer_i = new_i;
}
if self.buffer_i >= CommandBlockWrapper::BYTES {
trace_bot_buffer!("BUFFER> full enough to try deserializing command block wrapper");
let cbw = CommandBlockWrapper::unpack(&self.buffer)
.map_err(|_| Error::DataError);
if cbw.is_err() {
let err = cbw.err().unwrap();
warn!("CBW unpack error: {:?}", err);
self.buffer_i = 0;
return Err(err);
}
self.transition_to_data(cbw?);
self.read()?;
}
Ok(())
}
/// Seeds the status wrapper for `cbw`.
///
/// Echoes the CBW tag, starts the residue at the full requested transfer length, and
/// presets the status to `CommandOk`.
fn prepare_for_command(&mut self, cbw: &CommandBlockWrapper) {
    let csw = &mut self.command_status_wrapper;
    csw.status = CommandStatus::CommandOk;
    csw.data_residue = cbw.data_transfer_length;
    csw.tag = cbw.tag;
}
/// Moves the state machine to `new_state`, tracing the transition.
fn change_state(&mut self, new_state: State) {
    trace_bot_states!("STATE> {:?} -> {:?}", self.state, new_state,);
    self.state = new_state;
}
/// Starts the data phase for a freshly received CBW.
///
/// Empties the buffer, seeds the status wrapper, picks the data-phase state from the CBW
/// direction, and finally stores the CBW as the current command.
fn transition_to_data(&mut self, cbw: CommandBlockWrapper) {
    trace_bot_headers!("HEADER> CommandBlockWrapper: {:X?}", cbw);
    self.buffer_i = 0;
    self.data_i = 0;
    self.data_done = false;
    self.prepare_for_command(&cbw);
    let data_state = match cbw.direction {
        Direction::HostToDevice => State::ReceivingDataFromHost,
        Direction::DeviceToHost => State::SendingDataToHost,
    };
    self.change_state(data_state);
    self.command_block_wrapper = cbw;
}
/// `true` while a data phase (in either direction) is in progress.
fn is_in_data_phase(&self) -> bool {
    match self.state {
        State::SendingDataToHost | State::ReceivingDataFromHost => true,
        State::WaitingForCommand | State::NeedZlp | State::NeedToSendStatus => false,
    }
}
/// The CBW being serviced, available only during its data phase.
pub fn get_current_command(&self) -> Option<&CommandBlockWrapper> {
    if self.is_in_data_phase() {
        Some(&self.command_block_wrapper)
    } else {
        None
    }
}
/// Bytes of the current command's data phase still to be transferred.
///
/// Available only during the data phase.
pub fn data_residue(&self) -> Option<u32> {
    if self.is_in_data_phase() {
        Some(self.command_status_wrapper.data_residue)
    } else {
        None
    }
}
/// Reports buffer occupancy for the current phase; see `TransferState`.
pub fn transfer_state(&self) -> TransferState {
    trace_bot_buffer!("BUFFER> i: {}, di: {}", self.buffer_i, self.data_i);
    let pending = self.buffer_i - self.data_i;
    let empty = self.buffer_i == 0;
    match self.state {
        State::ReceivingDataFromHost => TransferState::ReceivingDataFromHost {
            bytes_available: pending,
            full: self.buffer_i == self.buffer.len(),
            done: self.command_status_wrapper.data_residue == 0,
        },
        State::SendingDataToHost => TransferState::SendingDataToHost {
            bytes_remaining: pending,
            empty,
        },
        _ => TransferState::NotTransferring {
            bytes_remaining: pending,
            empty,
        },
    }
}
/// Reserves `len` unused bytes after the currently buffered data and returns them for the
/// caller to fill.
///
/// # Errors
/// `UsbError(WouldBlock)` if fewer than `len` bytes are currently free.
///
/// # Panics
/// If `len` exceeds the total buffer capacity, since such a request can never succeed.
pub fn take_buffer_space(&mut self, len: usize) -> Result<&mut [u8], Error> {
    let capacity = self.buffer.len();
    if len > capacity {
        panic!("BulkOnlyTransport::take_buffer_space called with len > buffer.len() ({} > {}) which can never be successful",
            len, capacity);
    }
    let free = capacity - self.buffer_i;
    if len > free {
        trace_bot_buffer!("BUFFER> insufficient space to allocate {} bytes", len);
        return Err(WouldBlock.into());
    }
    trace_bot_buffer!("BUFFER> successfully allocated {} bytes", len);
    let start = self.buffer_i;
    self.buffer_i = start + len;
    Ok(&mut self.buffer[start..self.buffer_i])
}
/// Consumes up to `len` buffered bytes and returns them.
///
/// When `take_available` is `false`, exactly `len` bytes must be buffered. Otherwise
/// whatever is available (possibly nothing) is returned. Once every buffered byte has been
/// consumed, the buffer indices rewind to the start.
///
/// # Errors
/// `UsbError(WouldBlock)` if `take_available` is `false` and fewer than `len` bytes are buffered.
///
/// # Panics
/// If `len` exceeds the total buffer capacity.
pub fn take_buffered_data(&mut self, len: usize, take_available: bool) -> Result<&[u8], Error> {
    let capacity = self.buffer.len();
    if len > capacity {
        panic!("BulkOnlyTransport::take_buffered_data called with len > buffer.len() ({} > {}) which can never be successful",
            len, capacity);
    }
    let available = self.buffer_i - self.data_i;
    if len > available && !take_available {
        trace_bot_buffer!("BUFFER> contains insufficient data for take; requested: {}, available: {}", len, available);
        return Err(WouldBlock.into());
    }
    let taken = if len < available { len } else { available };
    let start = self.data_i;
    self.data_i += taken;
    if self.data_i == self.buffer_i {
        // Everything buffered has been consumed: make the whole buffer free again.
        self.data_i = 0;
        self.buffer_i = 0;
    }
    trace_bot_buffer!("BUFFER> took {}, available after: {}", taken, self.buffer_i - self.data_i);
    Ok(&self.buffer[start..start + taken])
}
/// Writes at most one packet of buffered bytes through `inner.write_packet`.
///
/// The packet is bounded by three limits: the unconsumed buffered bytes, the remaining
/// `data_residue`, and the max packet size.
/// `data_residue` is decremented by the bytes written, and `last_packet_full` records
/// whether the packet was max-packet-size.
/// The buffer rewinds once it is drained or the residue hits zero; in the latter case any
/// excess buffered bytes are dropped.
fn flush(&mut self) -> Result<(), Error> {
    let packet_size = self.max_packet_usize();
    let residue = self.command_status_wrapper.data_residue as usize;
    let bytes = if residue > 0 && self.data_i < self.buffer_i {
        let start = self.data_i;
        let chunk = packet_size.min(residue).min(self.buffer_i - start);
        let sent = self.inner.write_packet(&self.buffer[start..start + chunk])?;
        self.last_packet_full = sent == packet_size;
        self.data_i += sent;
        let remaining = residue - sent;
        self.command_status_wrapper.data_residue = remaining as u32;
        if remaining == 0 || self.data_i == self.buffer_i {
            self.buffer_i = 0;
            self.data_i = 0;
        }
        sent
    } else {
        0
    };
    trace_bot_bytes!("BYTES> Sent {} bytes. Data residue {} -> {}. Buff bytes: {}",
        bytes,
        residue,
        self.command_status_wrapper.data_residue,
        self.buffer_i - self.data_i,
    );
    Ok(())
}
/// Writes a zero-length packet through `inner.write_packet`, then moves to `NeedToSendStatus`.
///
/// On failure the state is left unchanged, so a later `write()` retries from `NeedZlp`.
fn send_zlp(&mut self) -> Result<(), Error> {
    if let Err(e) = self.inner.write_packet(&[]) {
        trace_bot_zlp!("ZLP> sending failed: {:?}", e);
        return Err(e.into());
    }
    trace_bot_zlp!("ZLP> sent");
    self.change_state(State::NeedToSendStatus);
    Ok(())
}
/// Serializes the command status wrapper (CSW) into the start of the buffer and arms it for sending.
///
/// After packing, `data_residue` is repurposed as the count of CSW bytes still to be sent
/// by `flush`.
/// The trace therefore runs first, so it logs the residue actually reported to the host
/// rather than the CSW length.
fn pack_csw(&mut self) {
    trace_bot_headers!("HEADER> CommandStatusWrapper buffered to send: {:X?}", self.command_status_wrapper);
    self.command_status_wrapper
        .pack(&mut self.buffer[..CommandStatusWrapper::BYTES])
        .expect("CSW target slice is exactly CommandStatusWrapper::BYTES long");
    self.buffer_i = CommandStatusWrapper::BYTES;
    self.data_i = 0;
    self.command_status_wrapper.data_residue = self.buffer_i as u32;
}
/// Finishes the data phase: buffers the CSW and enters the status phase.
///
/// A zero-length packet (ZLP) is sent first when the host would otherwise still be waiting
/// for data.
///
/// A device sending less than the CBW's transfer length must end the data phase with a
/// short packet (BOT spec 6.7.2). A ZLP is needed for that when either:
/// - the last IN packet was max-packet-size, or
/// - no data packet was sent at all.
///
/// Without the ZLP, the host would consume the CSW as data.
///
/// # Errors
/// Propagates USB errors (e.g. `WouldBlock`) from sending the ZLP or the first CSW packet.
/// The state has already advanced, so a later `write()` resumes from the right step.
fn end_data_transfer(&mut self) -> Result<(), Error> {
    // Decide on the ZLP *before* `pack_csw()`. That call overwrites `data_residue` with
    // the CSW length, which would make the residue check below always true.
    let residue = self.command_status_wrapper.data_residue;
    let nothing_sent = residue == self.command_block_wrapper.data_transfer_length;
    let needs_zlp = self.state == State::SendingDataToHost &&
        residue > 0 &&
        (self.last_packet_full || nothing_sent);
    self.pack_csw();
    if needs_zlp {
        trace_bot_zlp!("ZLP> required");
        self.change_state(State::NeedZlp);
        self.send_zlp()?;
    } else {
        trace_bot_zlp!("ZLP> not required");
        self.change_state(State::NeedToSendStatus);
        self.flush()?;
    }
    Ok(())
}
pub fn send_command_ok(&mut self) -> Result<(), Error> {
self.command_status_wrapper.status = CommandStatus::CommandOk;
self.data_done = true;
self.check_end_data_transfer()
}
pub fn send_command_error(&mut self) -> Result<(), Error> {
self.command_status_wrapper.status = CommandStatus::CommandError;
self.data_done = true;
self.check_end_data_transfer()
}
/// Sends the next packet of buffered data, then checks whether the data phase is over.
fn sending_data_to_host(&mut self) -> Result<(), Error> {
    self.flush().and_then(|()| self.check_end_data_transfer())
}
/// Ends the data phase when it is complete.
///
/// - Receiving: the residue is zero and all buffered bytes have been consumed.
/// - Sending: the residue is zero, or the consumer has set `data_done` and the buffer is drained.
///
/// Does nothing in other states.
fn check_end_data_transfer(&mut self) -> Result<(), Error> {
    let residue_zero = self.command_status_wrapper.data_residue == 0;
    let drained = self.data_i == self.buffer_i;
    let finished = match self.state {
        State::ReceivingDataFromHost if residue_zero && drained => {
            trace_bot_states!("STATE> Data residue = 0 and buffer empty, all data received");
            true
        }
        State::SendingDataToHost if residue_zero => {
            trace_bot_states!("STATE> Data residue = 0, all data sent");
            true
        }
        State::SendingDataToHost if self.data_done && drained => {
            trace_bot_states!("STATE> Data residue > 0, early termination");
            true
        }
        _ => false,
    };
    if finished {
        self.end_data_transfer()?;
    }
    Ok(())
}
/// Reads one OUT packet into the buffer, then ends the data phase if everything has been
/// received and consumed.
///
/// A packet is read only while the command still expects data and a full packet fits in the buffer.
/// `data_residue` is decremented by the bytes read. If the host sends more than the CBW
/// announced, the residue is clamped to zero and a warning is logged; the extra bytes stay
/// in the buffer.
///
/// # Errors
/// USB errors from `read_packet` (e.g. `WouldBlock` when no packet is ready), and anything
/// returned by `check_end_data_transfer`.
fn receiving_data_from_host(&mut self) -> Result<(), Error> {
    let expecting_data = self.command_status_wrapper.data_residue > 0;
    if expecting_data && self.buffer.len() - self.buffer_i >= self.max_packet_usize() {
        let bytes = self.inner.read_packet(&mut self.buffer[self.buffer_i..])?;
        self.buffer_i += bytes;
        // A single packet never exceeds the 512-byte buffer, so this cast cannot truncate.
        let bytes = bytes as u32;
        let residue = self.command_status_wrapper.data_residue;
        self.command_status_wrapper.data_residue = match residue.checked_sub(bytes) {
            Some(remaining) => remaining,
            None => {
                warn!("Read more bytes than CBW offered");
                0
            }
        };
        trace_bot_bytes!("BYTES> Read {} bytes. Data residue {} -> {}. Buff bytes: {}",
            bytes,
            residue,
            self.command_status_wrapper.data_residue,
            self.buffer_i - self.data_i,
        );
    }
    self.check_end_data_transfer()
}
/// Sends the next packet of the buffered CSW.
///
/// Once every CSW byte has gone out, the transport returns to waiting for a command.
fn need_to_send_status(&mut self) -> Result<(), Error> {
    self.flush()?;
    let csw_fully_sent = self.command_status_wrapper.data_residue == 0;
    if csw_fully_sent {
        self.change_state(State::WaitingForCommand);
    }
    Ok(())
}
}
impl<B: UsbBus> UsbClass<B> for BulkOnlyTransport<'_, B> {
fn get_configuration_descriptors(&self, writer: &mut DescriptorWriter) -> UsbResult<()> {
self.inner.get_configuration_descriptors(writer)
}
fn reset(&mut self) {
trace_usb_control!("USB_CONTROL> reset");
self.buffer_i = 0;
self.data_i = 0;
self.data_done = false;
self.change_state(State::WaitingForCommand);
self.inner.reset()
}
fn control_in(&mut self, xfer: ControlIn<B>) {
let req = xfer.request();
if !self.inner.correct_interface_number(req.index) {
self.inner.control_in(xfer);
return;
}
let handled_res = match req {
Request { request_type: RequestType::Class, request: REQ_GET_MAX_LUN, .. } =>
Some(xfer.accept(|data| {
trace_usb_control!("USB_CONTROL> Get max lun. Response: {}", self.max_lun);
data[0] = self.max_lun;
Ok(1)
})),
Request { request_type: RequestType::Class, request: REQ_BULK_ONLY_RESET, .. } => {
self.reset();
Some(xfer.accept(|_| {
trace_usb_control!("USB_CONTROL> Bulk only mass storage reset");
Ok(0)
}))
},
_ => {
self.inner.control_in(xfer);
None
},
};
if let Some(Err(e)) = handled_res {
error!("Error from ControlIn.accept: {:?}", e);
}
}
fn control_out(&mut self, xfer: ControlOut<B>) {
self.inner.control_out(xfer)
}
fn poll(&mut self) {
panic!("BulkOnlyTransport::poll should never be called. Consumers (SCSI for example) should use BulkOnlyTransport::read and BulkOnlyTransport::write");
}
}