use std::{cell::RefCell, cmp::min, rc::Rc};
use usb_device::bus::PollResult;
use usb_device::endpoint::{EndpointAddress, EndpointType};
use usb_device::{Result as UsbDeviceResult, UsbDirection, UsbError};
struct EndpointImpl {
alloc: bool,
stall: bool,
read_len: usize,
read: [u8; 1024],
read_ready: bool,
write_len: usize,
write: [u8; 1024],
write_done: bool,
setup: bool,
max_size: usize,
}
impl EndpointImpl {
fn new() -> Self {
EndpointImpl {
alloc: false,
stall: false,
read_len: 0,
read: [0; 1024],
read_ready: false,
write_len: 0,
write: [0; 1024],
write_done: false,
setup: false,
max_size: 0,
}
}
fn set_read(&mut self, data: &[u8], setup: bool) {
self.read_len = data.len();
if self.read_len > 0 {
self.read[..self.read_len].clone_from_slice(data);
self.setup = setup;
self.read_ready = true;
}
}
fn get_write(&mut self, data: &mut [u8]) -> usize {
let res = self.write_len;
dbg!("g", self.write_len);
self.write_len = 0;
data[..res].clone_from_slice(&self.write[..res]);
self.write_done = true;
res
}
}
pub(crate) struct UsbBusImpl {
ep_i: [RefCell<EndpointImpl>; 4],
ep_o: [RefCell<EndpointImpl>; 4],
}
impl UsbBusImpl {
pub(crate) fn new() -> Self {
Self {
ep_i: [
RefCell::new(EndpointImpl::new()),
RefCell::new(EndpointImpl::new()),
RefCell::new(EndpointImpl::new()),
RefCell::new(EndpointImpl::new()),
],
ep_o: [
RefCell::new(EndpointImpl::new()),
RefCell::new(EndpointImpl::new()),
RefCell::new(EndpointImpl::new()),
RefCell::new(EndpointImpl::new()),
],
}
}
fn epidx(&self, ep_addr: EndpointAddress) -> &RefCell<EndpointImpl> {
match ep_addr.direction() {
UsbDirection::In => self.ep_i.get(ep_addr.index()).unwrap(),
UsbDirection::Out => self.ep_o.get(ep_addr.index()).unwrap(),
}
}
pub(crate) fn get_write(&self, ep_addr: EndpointAddress, data: &mut [u8]) -> usize {
let mut ep = self.epidx(ep_addr).borrow_mut();
ep.get_write(data)
}
pub(crate) fn set_read(&self, ep_addr: EndpointAddress, data: &[u8], setup: bool) {
let mut ep = self.epidx(ep_addr).borrow_mut();
if setup && ep_addr.index() == 0 && ep_addr.direction() == UsbDirection::Out {
ep.stall = false;
let mut ep0in = self.ep_i.get(0).unwrap().borrow_mut();
ep0in.stall = false;
}
ep.set_read(data, setup)
}
pub(crate) fn stalled0(&self) -> bool {
let in0 = EndpointAddress::from_parts(0, UsbDirection::In);
let out0 = EndpointAddress::from_parts(0, UsbDirection::Out);
{
let ep = self.epidx(in0).borrow();
if ep.stall {
return true;
}
}
{
let ep = self.epidx(out0).borrow();
if ep.stall {
return true;
}
}
false
}
}
pub struct EmulatedUsbBus {
usb_address: RefCell<u8>,
bus: Rc<RefCell<UsbBusImpl>>,
}
unsafe impl Sync for EmulatedUsbBus {}
impl EmulatedUsbBus {
pub(crate) fn new(bus: &Rc<RefCell<UsbBusImpl>>) -> Self {
Self {
usb_address: RefCell::new(0),
bus: bus.clone(),
}
}
fn bus_ref(&self) -> &RefCell<UsbBusImpl> {
self.bus.as_ref()
}
pub fn get_address(&self) -> u8 {
*self.usb_address.borrow()
}
}
impl usb_device::bus::UsbBus for EmulatedUsbBus {
fn alloc_ep(
&mut self,
_ep_dir: UsbDirection,
ep_addr: Option<EndpointAddress>,
_ep_type: EndpointType,
max_packet_size: u16,
_interval: u8,
) -> UsbDeviceResult<EndpointAddress> {
if let Some(ea) = ep_addr {
let io = self.bus_ref().borrow();
let mut sep = io.epidx(ea).borrow_mut();
if sep.alloc {
return Err(UsbError::InvalidEndpoint);
}
sep.alloc = true;
sep.stall = false;
sep.max_size = max_packet_size as usize;
Ok(ea)
} else {
Err(UsbError::EndpointMemoryOverflow)
}
}
fn enable(&mut self) {}
fn force_reset(&self) -> UsbDeviceResult<()> {
Err(UsbError::Unsupported)
}
fn poll(&self) -> PollResult {
let in0 = EndpointAddress::from_parts(0, UsbDirection::In);
let out0 = EndpointAddress::from_parts(0, UsbDirection::Out);
let io = self.bus_ref().borrow();
let ep0out = io.epidx(out0).borrow();
let mut ep0in = io.epidx(in0).borrow_mut();
let ep0_write_done = ep0in.write_done;
let ep0_can_read = ep0out.read_ready | ep0in.read_ready;
let ep0_setup = ep0out.setup;
ep0in.write_done = false;
dbg!(ep0_write_done, ep0_can_read, ep0_setup);
if ep0_write_done || ep0_can_read || ep0_setup {
PollResult::Data {
ep_in_complete: if ep0_write_done { 1 } else { 0 },
ep_out: if ep0_can_read { 1 } else { 0 },
ep_setup: if ep0_setup { 1 } else { 0 },
}
} else {
PollResult::None
}
}
fn read(&self, ep_addr: EndpointAddress, buf: &mut [u8]) -> UsbDeviceResult<usize> {
let io = self.bus_ref().borrow();
let mut ep = io.epidx(ep_addr).borrow_mut();
let len = min(buf.len(), min(ep.read_len, ep.max_size));
dbg!("read len from", buf.len(), len, ep_addr);
if len == 0 {
return Err(UsbError::WouldBlock);
}
buf[..len].clone_from_slice(&ep.read[..len]);
ep.read_len -= len;
ep.read.copy_within(len.., 0);
if ep.read_len == 0 {
ep.setup = false;
}
ep.read_ready = ep.read_len > 0;
Ok(len)
}
fn reset(&self) {
todo!()
}
fn resume(&self) {
todo!()
}
fn suspend(&self) {
todo!()
}
fn set_device_address(&self, addr: u8) {
self.usb_address.replace(addr);
}
fn is_stalled(&self, ep_addr: EndpointAddress) -> bool {
let io = self.bus_ref().borrow();
let ep = io.epidx(ep_addr).borrow();
ep.stall
}
fn set_stalled(&self, ep_addr: EndpointAddress, stalled: bool) {
let io = self.bus_ref().borrow();
let mut ep = io.epidx(ep_addr).borrow_mut();
ep.stall = stalled;
}
fn write(&self, ep_addr: EndpointAddress, buf: &[u8]) -> UsbDeviceResult<usize> {
let io = self.bus_ref().borrow();
let mut ep = io.epidx(ep_addr).borrow_mut();
let offset = ep.write_len;
let mut len = 0;
dbg!("write", buf.len());
if buf.len() > ep.max_size {
return Err(UsbError::BufferOverflow);
}
for (i, e) in ep.write[offset..].iter_mut().enumerate() {
if i >= buf.len() {
break;
}
*e = buf[i];
len += 1;
}
dbg!("wrote", len);
ep.write_len += len;
ep.write_done = false;
Ok(len)
}
}