use core::cell::{Cell, RefCell};
use core::future::{Future, poll_fn};
use core::mem::{self, MaybeUninit};
use core::sync::atomic::{AtomicBool, Ordering};
use core::task::Poll;
use embassy_sync::blocking_mutex::CriticalSectionMutex;
use embassy_sync::waitqueue::WakerRegistration;
use crate::control::{self, InResponse, OutResponse, Recipient, Request, RequestType};
use crate::driver::{Driver, Endpoint, EndpointError, EndpointIn, EndpointOut};
use crate::types::InterfaceNumber;
use crate::{Builder, Handler};
// ---- USB class / subclass / protocol codes ----

/// Interface class code for the Communications Device Class (CDC).
pub const USB_CLASS_CDC: u8 = 0x02;
/// Interface class code for the CDC data interface.
const USB_CLASS_CDC_DATA: u8 = 0x0A;
/// Communications subclass code for the Abstract Control Model (ACM).
const CDC_SUBCLASS_ACM: u8 = 0x02;
/// Protocol code meaning "no class-specific protocol".
const CDC_PROTOCOL_NONE: u8 = 0x00;

// ---- Class-specific descriptor type and functional-descriptor subtypes ----

/// `bDescriptorType` used for class-specific interface descriptors.
const CS_INTERFACE: u8 = 0x24;
/// Header functional descriptor subtype.
const CDC_TYPE_HEADER: u8 = 0x00;
/// Abstract Control Management functional descriptor subtype.
const CDC_TYPE_ACM: u8 = 0x02;
/// Union functional descriptor subtype.
const CDC_TYPE_UNION: u8 = 0x06;

// ---- Class-specific request codes (`bRequest`) ----

const REQ_SEND_ENCAPSULATED_COMMAND: u8 = 0x00;
// Not used by the control handlers, hence `allow(unused)`; kept so the request table is complete.
#[allow(unused)]
const REQ_GET_ENCAPSULATED_COMMAND: u8 = 0x01;
const REQ_SET_LINE_CODING: u8 = 0x20;
const REQ_GET_LINE_CODING: u8 = 0x21;
const REQ_SET_CONTROL_LINE_STATE: u8 = 0x22;
/// Error returned by the `embedded-io-async` adapters ([`Sender`] / [`BufferedReceiver`]).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum CdcAcmError {
    /// The underlying USB endpoint reported itself disabled
    /// (`EndpointError::Disabled`), e.g. because the device is not configured.
    NotConnected,
}

impl core::fmt::Display for CdcAcmError {
    /// Writes the variant name, e.g. `NotConnected`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::NotConnected => f.write_str("NotConnected"),
        }
    }
}
impl core::error::Error for CdcAcmError {}
impl embedded_io_async::Error for CdcAcmError {
fn kind(&self) -> embedded_io_async::ErrorKind {
match *self {
Self::NotConnected => embedded_io_async::ErrorKind::NotConnected,
}
}
}
/// Backing storage for a [`CdcAcmClass`].
///
/// It must outlive the class (`'d`) because [`CdcAcmClass::new`] borrows it.
/// The class writes the control-request handler into `control` and hands out
/// references to `shared`.
pub struct State<'a> {
    // Written exactly once by `CdcAcmClass::new`; uninitialized until then.
    control: MaybeUninit<Control<'a>>,
    // Line coding / DTR / RTS state shared between the handler and the class.
    shared: ControlShared,
}

impl<'a> Default for State<'a> {
    fn default() -> Self {
        State::new()
    }
}

impl<'a> State<'a> {
    /// Creates empty state, usable in `static` / `const` contexts.
    pub const fn new() -> Self {
        let shared = ControlShared::new();
        let control = MaybeUninit::uninit();
        State { control, shared }
    }
}
/// CDC-ACM ("virtual serial port") class instance.
///
/// Owns the endpoints allocated by [`CdcAcmClass::new`] and a reference to the
/// control state that the control-request handler updates.
pub struct CdcAcmClass<'d, D: Driver<'d>> {
// Interrupt IN endpoint of the communication interface. It is allocated for the
// descriptor set but not otherwise used by the code visible here.
_comm_ep: D::EndpointIn,
// Interface number of the data interface (stored, not otherwise used here).
_data_if: InterfaceNumber,
// Bulk OUT endpoint: data received from the host.
read_ep: D::EndpointOut,
// Bulk IN endpoint: data sent to the host.
write_ep: D::EndpointIn,
// Line coding / DTR / RTS state written by the control handler.
control: &'d ControlShared,
}
/// Control-request handler that `CdcAcmClass::new` registers with the builder.
struct Control<'a> {
// Communication interface number; only class requests whose `index` equals it are handled.
comm_if: InterfaceNumber,
// State shared with `CdcAcmClass` and its split halves.
shared: &'a ControlShared,
}
/// State written by the control-request handler and read by the class halves.
struct ControlShared {
    /// Most recent line coding set by the host.
    line_coding: CriticalSectionMutex<Cell<LineCoding>>,
    /// Last DTR value from SET_CONTROL_LINE_STATE.
    dtr: AtomicBool,
    /// Last RTS value from SET_CONTROL_LINE_STATE.
    rts: AtomicBool,
    /// Waker of the task awaiting `changed()`.
    waker: RefCell<WakerRegistration>,
    /// Set on every control change; cleared when `changed()` completes.
    changed: AtomicBool,
}

impl Default for ControlShared {
    fn default() -> Self {
        ControlShared::new()
    }
}

impl ControlShared {
    /// Creates state with DTR/RTS low and 8000 baud, 8 data bits, no parity, 1 stop bit.
    const fn new() -> Self {
        // Spelled out because `LineCoding::default()` is a trait method and cannot
        // be called from a `const fn`.
        let initial = LineCoding {
            stop_bits: StopBits::One,
            data_bits: 8,
            parity_type: ParityType::None,
            data_rate: 8_000,
        };
        ControlShared {
            line_coding: CriticalSectionMutex::new(Cell::new(initial)),
            dtr: AtomicBool::new(false),
            rts: AtomicBool::new(false),
            waker: RefCell::new(WakerRegistration::new()),
            changed: AtomicBool::new(false),
        }
    }

    /// Resolves once the `changed` flag is set, clearing it.
    ///
    /// Changes made while no task is waiting are remembered (coalesced into one).
    fn changed(&self) -> impl Future<Output = ()> + '_ {
        poll_fn(move |cx| {
            if !self.changed.load(Ordering::Relaxed) {
                self.waker.borrow_mut().register(cx.waker());
                return Poll::Pending;
            }
            self.changed.store(false, Ordering::Relaxed);
            Poll::Ready(())
        })
    }
}
impl<'a> Control<'a> {
fn shared(&mut self) -> &'a ControlShared {
self.shared
}
}
impl<'d> Handler for Control<'d> {
fn reset(&mut self) {
let shared = self.shared();
shared.line_coding.lock(|x| x.set(LineCoding::default()));
shared.dtr.store(false, Ordering::Relaxed);
shared.rts.store(false, Ordering::Relaxed);
shared.changed.store(true, Ordering::Relaxed);
shared.waker.borrow_mut().wake();
}
fn control_out(&mut self, req: control::Request, data: &[u8]) -> Option<OutResponse> {
if (req.request_type, req.recipient, req.index)
!= (RequestType::Class, Recipient::Interface, self.comm_if.0 as u16)
{
return None;
}
match req.request {
REQ_SEND_ENCAPSULATED_COMMAND => {
Some(OutResponse::Accepted)
}
REQ_SET_LINE_CODING if data.len() >= 7 => {
let coding = LineCoding {
data_rate: u32::from_le_bytes(data[0..4].try_into().unwrap()),
stop_bits: data[4].into(),
parity_type: data[5].into(),
data_bits: data[6],
};
let shared = self.shared();
shared.line_coding.lock(|x| x.set(coding));
debug!("Set line coding to: {:?}", coding);
shared.changed.store(true, Ordering::Relaxed);
shared.waker.borrow_mut().wake();
Some(OutResponse::Accepted)
}
REQ_SET_CONTROL_LINE_STATE => {
let dtr = (req.value & 0x0001) != 0;
let rts = (req.value & 0x0002) != 0;
let shared = self.shared();
shared.dtr.store(dtr, Ordering::Relaxed);
shared.rts.store(rts, Ordering::Relaxed);
debug!("Set dtr {}, rts {}", dtr, rts);
shared.changed.store(true, Ordering::Relaxed);
shared.waker.borrow_mut().wake();
Some(OutResponse::Accepted)
}
_ => Some(OutResponse::Rejected),
}
}
fn control_in<'a>(&'a mut self, req: Request, buf: &'a mut [u8]) -> Option<InResponse<'a>> {
if (req.request_type, req.recipient, req.index)
!= (RequestType::Class, Recipient::Interface, self.comm_if.0 as u16)
{
return None;
}
match req.request {
REQ_GET_LINE_CODING if req.length == 7 => {
debug!("Sending line coding");
let coding = self.shared().line_coding.lock(Cell::get);
assert!(buf.len() >= 7);
buf[0..4].copy_from_slice(&coding.data_rate.to_le_bytes());
buf[4] = coding.stop_bits as u8;
buf[5] = coding.parity_type as u8;
buf[6] = coding.data_bits;
Some(InResponse::Accepted(&buf[0..7]))
}
_ => Some(InResponse::Rejected),
}
}
}
impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> {
/// Creates a CDC-ACM class and registers it with `builder`.
///
/// Registers a communication interface (header, ACM and union functional
/// descriptors plus an interrupt IN endpoint) and a data interface (one bulk OUT
/// and one bulk IN endpoint, each `max_packet_size` bytes). Then installs the
/// control-request handler, stored in `state`, via `builder.handler`.
///
/// # Panics
///
/// Panics if `builder.control_buf_len()` is less than 7, the size of the
/// line-coding structure returned for GET_LINE_CODING.
pub fn new(builder: &mut Builder<'d, D>, state: &'d mut State<'d>, max_packet_size: u16) -> Self {
assert!(builder.control_buf_len() >= 7);
// NOTE(review): the allocation order below presumably determines the emitted
// descriptor layout and interface numbers — keep statements in this order.
let mut func = builder.function(USB_CLASS_CDC, CDC_SUBCLASS_ACM, CDC_PROTOCOL_NONE);
// ---- Communication interface ----
let mut iface = func.interface();
let comm_if = iface.interface_number();
// Assumes the data interface allocated below will be numbered `comm_if + 1`.
// The union descriptor needs the number before that interface exists.
let data_if = u8::from(comm_if) + 1;
let mut alt = iface.alt_setting(USB_CLASS_CDC, CDC_SUBCLASS_ACM, CDC_PROTOCOL_NONE, None);
// Header functional descriptor: bcdCDC = 0x0110 (CDC 1.10), little-endian.
alt.descriptor(
CS_INTERFACE,
&[
CDC_TYPE_HEADER, 0x10,
0x01, ],
);
// ACM functional descriptor: bmCapabilities = 0x02 (supports SET/GET_LINE_CODING,
// SET_CONTROL_LINE_STATE and the SERIAL_STATE notification).
alt.descriptor(
CS_INTERFACE,
&[
CDC_TYPE_ACM, 0x02, ],
);
// Union functional descriptor: control interface, then the subordinate data interface.
alt.descriptor(
CS_INTERFACE,
&[
CDC_TYPE_UNION, comm_if.into(), data_if, ],
);
// Notification endpoint (interrupt IN). NOTE(review): 8 and 255 are presumably the
// max packet size and polling interval — confirm against the builder API.
let comm_ep = alt.endpoint_interrupt_in(None, 8, 255);
// ---- Data interface: bulk OUT (host -> device) and bulk IN (device -> host) ----
let mut iface = func.interface();
let data_if = iface.interface_number();
let mut alt = iface.alt_setting(USB_CLASS_CDC_DATA, 0x00, CDC_PROTOCOL_NONE, None);
let read_ep = alt.endpoint_bulk_out(None, max_packet_size);
let write_ep = alt.endpoint_bulk_in(None, max_packet_size);
// Release `func` (presumably a borrow of `builder`) before `builder.handler` below.
drop(func);
// Store the handler in `state` so it lives for `'d`. The class keeps `&state.shared`.
let control = state.control.write(Control {
shared: &state.shared,
comm_if,
});
builder.handler(control);
let control_shared = &state.shared;
CdcAcmClass {
_comm_ep: comm_ep,
_data_if: data_if,
read_ep,
write_ep,
control: control_shared,
}
}
/// Max packet size of the bulk OUT (read) endpoint.
pub fn max_packet_size(&self) -> u16 {
self.read_ep.info().max_packet_size
}
/// Line coding most recently set by the host (the default after a reset).
pub fn line_coding(&self) -> LineCoding {
self.control.line_coding.lock(Cell::get)
}
/// Last DTR state set by the host via SET_CONTROL_LINE_STATE.
pub fn dtr(&self) -> bool {
self.control.dtr.load(Ordering::Relaxed)
}
/// Last RTS state set by the host via SET_CONTROL_LINE_STATE.
pub fn rts(&self) -> bool {
self.control.rts.load(Ordering::Relaxed)
}
/// Writes one packet to the bulk IN endpoint.
///
/// NOTE(review): `data` presumably must not exceed `max_packet_size`; the driver
/// decides how larger buffers are handled.
pub async fn write_packet(&mut self, data: &[u8]) -> Result<(), EndpointError> {
self.write_ep.write(data).await
}
/// Reads one packet from the bulk OUT endpoint into `data`, returning its length.
pub async fn read_packet(&mut self, data: &mut [u8]) -> Result<usize, EndpointError> {
self.read_ep.read(data).await
}
/// Waits until the read endpoint reports itself enabled (`wait_enabled`).
pub async fn wait_connection(&mut self) {
self.read_ep.wait_enabled().await;
}
/// Splits into independent transmit and receive halves that share control state.
pub fn split(self) -> (Sender<'d, D>, Receiver<'d, D>) {
(
Sender {
write_ep: self.write_ep,
control: self.control,
},
Receiver {
read_ep: self.read_ep,
control: self.control,
},
)
}
/// Like [`split`](Self::split), also returning a [`ControlChanged`] handle.
///
/// The handle can await DTR/RTS/line-coding changes.
pub fn split_with_control(self) -> (Sender<'d, D>, Receiver<'d, D>, ControlChanged<'d>) {
(
Sender {
write_ep: self.write_ep,
control: self.control,
},
Receiver {
read_ep: self.read_ep,
control: self.control,
},
ControlChanged { control: self.control },
)
}
}
/// Observer for host-initiated control changes.
///
/// Obtained from [`CdcAcmClass::split_with_control`].
pub struct ControlChanged<'d> {
    control: &'d ControlShared,
}

impl<'d> ControlChanged<'d> {
    /// Waits for the next control change, then clears the pending flag.
    ///
    /// A change is flagged by SET_LINE_CODING, SET_CONTROL_LINE_STATE and by a
    /// handler reset. Changes that occur while nobody is waiting are coalesced
    /// into one pending notification.
    pub async fn control_changed(&self) {
        let shared: &ControlShared = self.control;
        shared.changed().await
    }

    /// Last DTR state set by the host.
    pub fn dtr(&self) -> bool {
        let flag = &self.control.dtr;
        flag.load(Ordering::Relaxed)
    }

    /// Last RTS state set by the host.
    pub fn rts(&self) -> bool {
        let flag = &self.control.rts;
        flag.load(Ordering::Relaxed)
    }
}
/// Transmit half of a split [`CdcAcmClass`]; owns the bulk IN endpoint.
pub struct Sender<'d, D: Driver<'d>> {
    write_ep: D::EndpointIn,
    control: &'d ControlShared,
}

impl<'d, D: Driver<'d>> Sender<'d, D> {
    /// Max packet size of the bulk IN endpoint.
    pub fn max_packet_size(&self) -> u16 {
        let info = self.write_ep.info();
        info.max_packet_size
    }

    /// Line coding most recently set by the host.
    pub fn line_coding(&self) -> LineCoding {
        self.control.line_coding.lock(|cell| cell.get())
    }

    /// Last DTR state set by the host.
    pub fn dtr(&self) -> bool {
        let shared = self.control;
        shared.dtr.load(Ordering::Relaxed)
    }

    /// Last RTS state set by the host.
    pub fn rts(&self) -> bool {
        let shared = self.control;
        shared.rts.load(Ordering::Relaxed)
    }

    /// Writes one packet to the bulk IN endpoint.
    pub async fn write_packet(&mut self, data: &[u8]) -> Result<(), EndpointError> {
        let ep = &mut self.write_ep;
        ep.write(data).await
    }

    /// Waits until the write endpoint reports itself enabled.
    pub async fn wait_connection(&mut self) {
        let ep = &mut self.write_ep;
        ep.wait_enabled().await
    }
}
impl<'d, D: Driver<'d>> embedded_io_async::ErrorType for Sender<'d, D> {
    type Error = CdcAcmError;
}

impl<'d, D: Driver<'d>> embedded_io_async::Write for Sender<'d, D> {
    /// Writes at most one packet (`max_packet_size` bytes) from `buf`.
    ///
    /// Returns the number of bytes sent, or [`CdcAcmError::NotConnected`] if the
    /// endpoint is disabled. An empty `buf` returns `Ok(0)` immediately.
    async fn write(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        // The `Write` contract requires an empty write to return without blocking.
        // Otherwise a zero-length packet would be handed to the endpoint and awaited.
        if buf.is_empty() {
            return Ok(0);
        }
        let len = buf.len().min(usize::from(self.max_packet_size()));
        match self.write_packet(&buf[..len]).await {
            Ok(()) => Ok(len),
            // `len` never exceeds the endpoint's max packet size.
            Err(EndpointError::BufferOverflow) => unreachable!(),
            Err(EndpointError::Disabled) => Err(CdcAcmError::NotConnected),
        }
    }

    /// No-op: each `write` hands a complete packet to the endpoint.
    async fn flush(&mut self) -> Result<(), Self::Error> {
        Ok(())
    }
}
/// Receive half of a split [`CdcAcmClass`]; owns the bulk OUT endpoint.
pub struct Receiver<'d, D: Driver<'d>> {
    read_ep: D::EndpointOut,
    control: &'d ControlShared,
}

impl<'d, D: Driver<'d>> Receiver<'d, D> {
    /// Max packet size of the bulk OUT endpoint.
    pub fn max_packet_size(&self) -> u16 {
        let info = self.read_ep.info();
        info.max_packet_size
    }

    /// Line coding most recently set by the host.
    pub fn line_coding(&self) -> LineCoding {
        self.control.line_coding.lock(|cell| cell.get())
    }

    /// Last DTR state set by the host.
    pub fn dtr(&self) -> bool {
        let shared = self.control;
        shared.dtr.load(Ordering::Relaxed)
    }

    /// Last RTS state set by the host.
    pub fn rts(&self) -> bool {
        let shared = self.control;
        shared.rts.load(Ordering::Relaxed)
    }

    /// Reads one packet into `data`, returning its length.
    pub async fn read_packet(&mut self, data: &mut [u8]) -> Result<usize, EndpointError> {
        let ep = &mut self.read_ep;
        ep.read(data).await
    }

    /// Waits until the read endpoint reports itself enabled.
    pub async fn wait_connection(&mut self) {
        let ep = &mut self.read_ep;
        ep.wait_enabled().await
    }

    /// Wraps this receiver in a [`BufferedReceiver`].
    ///
    /// `buf` stages packets when a caller's read buffer is smaller than a packet.
    /// It should be at least [`max_packet_size`](Self::max_packet_size) bytes.
    /// A smaller buffer makes the staged read path fail (NOTE(review): presumably
    /// with `EndpointError::BufferOverflow` from the driver).
    pub fn into_buffered(self, buf: &'d mut [u8]) -> BufferedReceiver<'d, D> {
        BufferedReceiver {
            buffer: buf,
            receiver: self,
            start: 0,
            end: 0,
        }
    }
}
pub struct BufferedReceiver<'d, D: Driver<'d>> {
receiver: Receiver<'d, D>,
buffer: &'d mut [u8],
start: usize,
end: usize,
}
impl<'d, D: Driver<'d>> BufferedReceiver<'d, D> {
fn read_from_buffer(&mut self, buf: &mut [u8]) -> usize {
let available = &self.buffer[self.start..self.end];
let len = core::cmp::min(available.len(), buf.len());
buf[..len].copy_from_slice(&available[..len]);
self.start += len;
len
}
pub fn line_coding(&self) -> LineCoding {
self.receiver.line_coding()
}
pub fn dtr(&self) -> bool {
self.receiver.dtr()
}
pub fn rts(&self) -> bool {
self.receiver.rts()
}
pub async fn wait_connection(&mut self) {
self.receiver.wait_connection().await;
}
}
impl<'d, D: Driver<'d>> embedded_io_async::ErrorType for BufferedReceiver<'d, D> {
    type Error = CdcAcmError;
}

impl<'d, D: Driver<'d>> embedded_io_async::Read for BufferedReceiver<'d, D> {
    /// Reads up to `buf.len()` bytes.
    ///
    /// Staged bytes from an earlier packet are returned first. If `buf` can hold a
    /// whole packet, the packet is read into it directly. Otherwise a packet is
    /// staged internally and partially returned.
    ///
    /// # Errors
    ///
    /// [`CdcAcmError::NotConnected`] if the endpoint is disabled.
    ///
    /// # Panics
    ///
    /// Panics if the internal buffer given to `into_buffered` is too small for a
    /// packet the endpoint reports.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        // The `Read` contract requires an empty read to return without blocking.
        if buf.is_empty() {
            return Ok(0);
        }

        // Serve leftovers from a previously staged packet first.
        if self.start != self.end {
            return Ok(self.read_from_buffer(buf));
        }

        // A caller buffer that can hold a full packet (no larger than max_packet_size)
        // is read into directly, skipping the extra copy.
        if buf.len() >= usize::from(self.receiver.max_packet_size()) {
            return match self.receiver.read_packet(buf).await {
                Ok(n) => Ok(n),
                Err(EndpointError::BufferOverflow) => unreachable!(),
                Err(EndpointError::Disabled) => Err(CdcAcmError::NotConnected),
            };
        }

        // Otherwise stage one packet internally and hand back what fits.
        match self.receiver.read_packet(&mut self.buffer[..]).await {
            Ok(n) => self.end = n,
            // Reachable when `into_buffered` was given a buffer smaller than a packet.
            Err(EndpointError::BufferOverflow) => panic!(
                "BufferedReceiver: internal buffer is smaller than the endpoint's max packet size"
            ),
            Err(EndpointError::Disabled) => return Err(CdcAcmError::NotConnected),
        }
        self.start = 0;
        Ok(self.read_from_buffer(buf))
    }
}
/// Number of stop bits (the `bCharFormat` byte of a CDC line coding).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum StopBits {
    /// 1 stop bit.
    One = 0,
    /// 1.5 stop bits.
    OnePointFive = 1,
    /// 2 stop bits.
    Two = 2,
}

impl From<u8> for StopBits {
    /// Decodes a raw byte; unknown values fall back to [`StopBits::One`].
    fn from(value: u8) -> Self {
        // Safe match instead of `transmute`: the in-memory layout of a
        // `repr(Rust)` enum is unspecified, so transmuting a `u8` into it is not
        // guaranteed to yield the variant with that discriminant.
        match value {
            1 => StopBits::OnePointFive,
            2 => StopBits::Two,
            _ => StopBits::One,
        }
    }
}
/// Parity setting (the `bParityType` byte of a CDC line coding).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ParityType {
    /// No parity bit.
    None = 0,
    /// Odd parity.
    Odd = 1,
    /// Even parity.
    Even = 2,
    /// Parity bit always 1.
    Mark = 3,
    /// Parity bit always 0.
    Space = 4,
}

impl From<u8> for ParityType {
    /// Decodes a raw byte; unknown values fall back to [`ParityType::None`].
    fn from(value: u8) -> Self {
        // Safe match instead of `transmute`: a `repr(Rust)` enum's in-memory
        // layout is unspecified.
        match value {
            1 => ParityType::Odd,
            2 => ParityType::Even,
            3 => ParityType::Mark,
            4 => ParityType::Space,
            _ => ParityType::None,
        }
    }
}
/// Serial line settings set by the host with SET_LINE_CODING.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct LineCoding {
    stop_bits: StopBits,
    data_bits: u8,
    parity_type: ParityType,
    data_rate: u32,
}

impl LineCoding {
    /// Number of stop bits.
    pub const fn stop_bits(&self) -> StopBits {
        self.stop_bits
    }

    /// Data bits per character, exactly as sent by the host (not validated).
    pub const fn data_bits(&self) -> u8 {
        self.data_bits
    }

    /// Parity setting.
    pub const fn parity_type(&self) -> ParityType {
        self.parity_type
    }

    /// Data rate in bits per second, as sent by the host.
    pub const fn data_rate(&self) -> u32 {
        self.data_rate
    }
}

impl Default for LineCoding {
    /// 8000 baud, 8 data bits, no parity, 1 stop bit.
    fn default() -> Self {
        LineCoding {
            stop_bits: StopBits::One,
            data_bits: 8,
            parity_type: ParityType::None,
            data_rate: 8_000,
        }
    }
}