pub mod message;
use core::future::Future;
use core::marker::PhantomData;
use byteorder::{ByteOrder, LittleEndian};
use embassy_futures::select::{Either, select};
use heapless::Vec;
use message::Message;
use message::data::{Data, request};
use message::extended::extended_control::ExtendedControlMessageType;
use message::header::{ControlMessageType, DataMessageType, ExtendedMessageType, Header, MessageType};
use usbpd_traits::{Driver, DriverRxError, DriverTxError};
use crate::PowerRole;
use crate::counters::{Counter, CounterType, Error as CounterError};
use crate::protocol_layer::message::data::epr_mode::EprModeDataObject;
use crate::protocol_layer::message::extended::Extended;
use crate::protocol_layer::message::{ParseError, Payload};
use crate::timers::{Timer, TimerType};
/// Capacity, in bytes, of each scratch frame buffer and of the chunk reassembly buffer.
const MAX_MESSAGE_SIZE: usize = 272;
/// Length, in bytes, of the message header at the start of every frame.
const MSG_HEADER_SIZE: usize = 2;
/// Length, in bytes, of the extended-message header that follows the message header.
const EXT_HEADER_SIZE: usize = 2;
/// Errors reported by the protocol layer to its callers.
#[derive(thiserror::Error, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum ProtocolError {
    /// Reception failed; the inner [`RxError`] says why.
    #[error("RX error")]
    RxError(#[from] RxError),

    /// Transmission failed; the inner [`TxError`] says why.
    #[error("TX error")]
    TxError(#[from] TxError),

    /// The message was not acknowledged within the retry budget (or, with hardware
    /// auto-retry, the driver discarded it). Carries the retry counter's maximum value.
    #[error("transmit retries (`{0}`) exceeded")]
    TransmitRetriesExceeded(u8),

    /// A message arrived whose type was not among those the caller was waiting for.
    #[error("unexpected message")]
    UnexpectedMessage,
}
/// Reasons a receive operation can fail.
#[derive(thiserror::Error, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum RxError {
    /// A `Soft_Reset` control message was received.
    #[error("soft reset")]
    SoftReset,

    /// The driver signalled a hard reset.
    #[error("hard reset")]
    HardReset,

    /// A timer expired before the awaited message arrived.
    #[error("receive timeout")]
    ReceiveTimeout,

    /// The message (or chunk sequence) cannot be handled by this layer.
    #[error("unsupported message")]
    UnsupportedMessage,

    /// The received bytes could not be decoded.
    #[error("parse error")]
    ParseError(#[from] ParseError),

    /// A `GoodCRC` acknowledged a MessageID other than the one last sent; carries the
    /// acknowledged ID.
    #[error("wrong tx id `{0}` acknowledged")]
    AcknowledgeMismatch(u8),
}
/// Reasons a transmit operation can fail.
#[derive(thiserror::Error, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum TxError {
    /// The driver signalled a hard reset while transmitting.
    #[error("hard reset")]
    HardReset,

    /// An outgoing request advertised unchunked extended message support, which this layer
    /// does not implement.
    #[error("unchunked extended messages not supported")]
    UnchunkedExtendedMessagesNotSupported,

    /// An outgoing AVS request's voltage field has its two least-significant bits set.
    #[error("AVS voltage alignment invalid")]
    AvsVoltageAlignmentInvalid,
}
/// MessageID and retry counters owned by the protocol layer.
#[derive(Debug)]
struct Counters {
    /// Busy counter; not read anywhere in this module (hence the leading underscore).
    _busy: Counter,
    /// Caps counter; not read anywhere in this module.
    _caps: Counter,
    /// Discover-identity counter; not read anywhere in this module.
    _discover_identity: Counter,
    /// MessageID of the last accepted received message; `None` until the first message
    /// after a reset.
    rx_message: Option<Counter>,
    /// MessageID to place in the next outgoing message.
    tx_message: Counter,
    /// Retransmission attempts for the current outgoing message.
    retry: Counter,
}
impl Default for Counters {
    /// Builds fresh counters of each type, with no received MessageID recorded yet.
    fn default() -> Self {
        let fresh = Counter::new;
        Self {
            _busy: fresh(CounterType::Busy),
            _caps: fresh(CounterType::Caps),
            _discover_identity: fresh(CounterType::DiscoverIdentity),
            rx_message: None,
            tx_message: fresh(CounterType::MessageId),
            retry: fresh(CounterType::Retry),
        }
    }
}
/// USB PD protocol layer on top of a [`Driver`].
///
/// It handles message framing, MessageID tracking, `GoodCRC` handling, retries and
/// reassembly of chunked extended messages.
#[derive(Debug)]
pub(crate) struct ProtocolLayer<DRIVER: Driver, TIMER: Timer> {
    /// PHY driver used for every transmission and reception.
    driver: DRIVER,
    /// MessageID and retry counters.
    counters: Counters,
    /// Template for outgoing headers; its spec revision tracks received messages.
    default_header: Header,
    /// Payload bytes of the chunked extended message being reassembled.
    extended_rx_buffer: Vec<u8, MAX_MESSAGE_SIZE>,
    /// `(message type, total data size, next expected chunk number)` of the chunked
    /// reception in progress, if any.
    extended_rx_expected: Option<(ExtendedMessageType, u16, u8)>,
    /// Selects the [`Timer`] implementation used for timeouts; no timer value is stored.
    _timer: PhantomData<TIMER>,
}
impl<DRIVER: Driver, TIMER: Timer> ProtocolLayer<DRIVER, TIMER> {
    /// Creates a protocol layer around `driver`.
    ///
    /// `default_header` is the template from which outgoing headers are built. Its spec
    /// revision is later overwritten with the revision of received messages.
    pub fn new(driver: DRIVER, default_header: Header) -> Self {
        Self {
            driver,
            counters: Default::default(),
            default_header,
            extended_rx_buffer: Vec::new(),
            extended_rx_expected: None,
            _timer: PhantomData,
        }
    }

    /// Resets the protocol layer.
    ///
    /// All counters return to their initial state, and any partially reassembled chunked
    /// message is discarded.
    pub fn reset(&mut self) {
        self.counters = Default::default();
        // A chunked reception interrupted by a reset is stale; drop it so it cannot leak into
        // the next reassembly.
        self.reset_chunked_rx();
    }

    /// Mutable access to the driver, for tests.
    #[cfg(test)]
    pub fn driver(&mut self) -> &mut DRIVER {
        &mut self.driver
    }

    /// The header template used for outgoing messages, for tests.
    #[cfg(test)]
    pub fn header(&self) -> &Header {
        &self.default_header
    }

    /// Returns a fresh, zero-filled frame buffer of [`MAX_MESSAGE_SIZE`] bytes.
    fn get_message_buffer() -> [u8; MAX_MESSAGE_SIZE] {
        [0u8; MAX_MESSAGE_SIZE]
    }

    /// Returns a future that completes when a timer of `timer_type` expires.
    pub fn get_timer(timer_type: TimerType) -> impl Future<Output = ()> {
        TimerType::get_timer::<TIMER>(timer_type)
    }

    /// Receives one raw frame from the driver.
    ///
    /// Frames the driver reports as discarded are skipped, and every attempt uses a freshly
    /// zeroed buffer. Returns the buffer together with the number of valid bytes in it.
    ///
    /// # Errors
    /// [`RxError::HardReset`] if the driver signals a hard reset.
    async fn receive_frame(&mut self) -> Result<([u8; MAX_MESSAGE_SIZE], usize), RxError> {
        loop {
            let mut buffer = Self::get_message_buffer();
            let result = self.driver.receive(&mut buffer).await;
            match result {
                Ok(length) => return Ok((buffer, length)),
                Err(DriverRxError::Discarded) => continue,
                Err(DriverRxError::HardReset) => return Err(RxError::HardReset),
            }
        }
    }

    /// Receives and parses the next frame.
    ///
    /// It performs no acknowledgement and no MessageID tracking. It is used by
    /// [`Self::wait_for_good_crc`].
    async fn receive_simple(&mut self) -> Result<Message, RxError> {
        let (buffer, length) = self.receive_frame().await?;
        Ok(Message::from_bytes(&buffer[..length])?)
    }

    /// Waits, bounded by the `CRCReceive` timer, for the `GoodCRC` acknowledging the last
    /// transmission.
    ///
    /// On a matching acknowledgement the retry counter is reset and the TX MessageID advanced.
    ///
    /// # Errors
    /// - [`RxError::ReceiveTimeout`] if the timer expires first.
    /// - [`RxError::AcknowledgeMismatch`] if the `GoodCRC` carries a different MessageID.
    /// - [`RxError::ParseError`] if the frame cannot be parsed or is not a `GoodCRC`.
    /// - [`RxError::HardReset`] if the driver signals a hard reset.
    async fn wait_for_good_crc(&mut self) -> Result<(), RxError> {
        trace!("Wait for GoodCrc");
        let timeout_fut = Self::get_timer(TimerType::CRCReceive);
        let receive_fut = async {
            let message = self.receive_simple().await?;
            if matches!(
                message.header.message_type(),
                MessageType::Control(ControlMessageType::GoodCRC)
            ) {
                trace!(
                    "Received GoodCrc, TX message count: {}, expected: {}",
                    message.header.message_id(),
                    self.counters.tx_message.value()
                );
                if message.header.message_id() == self.counters.tx_message.value() {
                    self.counters.retry.reset();
                    _ = self.counters.tx_message.increment();
                    Ok(())
                } else {
                    Err(RxError::AcknowledgeMismatch(message.header.message_id()))
                }
            } else if matches!(message.header.message_type(), MessageType::Control(_)) {
                Err(ParseError::InvalidControlMessageType(message.header.message_type_raw()).into())
            } else {
                Err(ParseError::InvalidMessageType(message.header.message_type_raw()).into())
            }
        };
        match select(timeout_fut, receive_fut).await {
            Either::First(_) => Err(RxError::ReceiveTimeout),
            Either::Second(receive_result) => receive_result,
        }
    }

    /// Rejects outgoing requests that this layer cannot honour.
    ///
    /// Any request that sets the "unchunked extended messages supported" flag is refused. An
    /// EPR request classified as AVS (see the note below) must also have the two
    /// least-significant bits of its voltage field clear.
    fn validate_outgoing_message(message: &Message) -> Result<(), TxError> {
        if let Some(Payload::Data(message::data::Data::Request(power_source))) = &message.payload {
            use message::data::request::PowerSource;
            match power_source {
                PowerSource::FixedVariableSupply(rdo) => {
                    if rdo.unchunked_extended_messages_supported() {
                        return Err(TxError::UnchunkedExtendedMessagesNotSupported);
                    }
                }
                PowerSource::Pps(rdo) => {
                    if rdo.unchunked_extended_messages_supported() {
                        return Err(TxError::UnchunkedExtendedMessagesNotSupported);
                    }
                }
                PowerSource::EprRequest(epr) => {
                    let rdo_bits = epr.rdo;
                    // Bit 23 of the RDO is the "unchunked extended messages supported" flag.
                    let unchunked = (rdo_bits >> 23) & 1 == 1;
                    if unchunked {
                        return Err(TxError::UnchunkedExtendedMessagesNotSupported);
                    }
                    // NOTE(review): this tests RDO bits 31..28 == 0b0011 (the object-position
                    // field equal to 3). That does not by itself identify an AVS request. AVS
                    // detection probably needs the PDO carried with the EPR request — confirm.
                    let is_avs = ((rdo_bits >> 30) & 0x3 == 0) && ((rdo_bits >> 28) & 0x3 == 3);
                    if is_avs {
                        let voltage = (rdo_bits >> 9) & 0xFFF;
                        if (voltage as u16) & 0x3 != 0 {
                            return Err(TxError::AvsVoltageAlignmentInvalid);
                        }
                    }
                }
                _ => {}
            }
        }
        Ok(())
    }

    /// Hands `buffer` to the driver, retrying for as long as the driver discards it.
    ///
    /// # Errors
    /// [`TxError::HardReset`] if the driver signals a hard reset.
    async fn transmit_inner(&mut self, buffer: &[u8]) -> Result<(), TxError> {
        loop {
            match self.driver.transmit(buffer).await {
                Ok(_) => return Ok(()),
                Err(DriverTxError::HardReset) => return Err(TxError::HardReset),
                // Discarded: try again.
                Err(DriverTxError::Discarded) => {}
            }
        }
    }

    /// Validates, serializes and transmits `message` until it is acknowledged.
    ///
    /// If the driver has hardware auto-retry, its result is used directly. Otherwise the frame
    /// is resent on every `GoodCRC` timeout until the retry counter is exhausted. On success
    /// the TX MessageID counter is advanced.
    ///
    /// # Panics
    /// If `message` is a `GoodCRC`; those are sent by [`Self::transmit_good_crc`].
    ///
    /// # Errors
    /// - [`TxError`] variants from validation or a hard reset during transmission.
    /// - [`ProtocolError::TransmitRetriesExceeded`] when retries run out.
    /// - [`RxError`] variants other than a timeout while waiting for `GoodCRC`.
    pub async fn transmit(&mut self, message: Message) -> Result<(), ProtocolError> {
        assert_ne!(
            message.header.message_type(),
            MessageType::Control(ControlMessageType::GoodCRC)
        );
        Self::validate_outgoing_message(&message)?;
        trace!("Transmit message: {:?}", message);
        let mut buffer = Self::get_message_buffer();
        let size = message.to_bytes(&mut buffer);
        if DRIVER::HAS_AUTO_RETRY {
            match self.driver.transmit(&buffer[..size]).await {
                Ok(()) => {
                    self.counters.retry.reset();
                    _ = self.counters.tx_message.increment();
                    trace!("Transmit success (hardware retry)");
                    Ok(())
                }
                Err(DriverTxError::HardReset) => Err(TxError::HardReset.into()),
                // The hardware already exhausted its own retries.
                Err(DriverTxError::Discarded) => {
                    Err(ProtocolError::TransmitRetriesExceeded(self.counters.retry.max_value()))
                }
            }
        } else {
            self.counters.retry.reset();
            loop {
                match self.transmit_inner(&buffer[..size]).await {
                    Ok(_) => match self.wait_for_good_crc().await {
                        Ok(()) => {
                            trace!("Transmit success");
                            return Ok(());
                        }
                        // No acknowledgement in time: resend if the retry budget allows.
                        Err(RxError::ReceiveTimeout) => match self.counters.retry.increment() {
                            Ok(_) => {}
                            Err(CounterError::Exceeded) => {
                                return Err(ProtocolError::TransmitRetriesExceeded(
                                    self.counters.retry.max_value(),
                                ));
                            }
                        },
                        Err(other) => return Err(other.into()),
                    },
                    Err(other) => return Err(other.into()),
                }
            }
        }
    }

    /// Sends a `GoodCRC` acknowledging the last received MessageID.
    ///
    /// # Panics
    /// If no message has been received since the last reset. Callers only invoke this after
    /// [`Self::update_rx_message_counter`] has recorded an ID.
    async fn transmit_good_crc(&mut self) -> Result<(), ProtocolError> {
        let rx_counter = self
            .counters
            .rx_message
            .expect("GoodCRC is only sent after a received message set the RX MessageID");
        trace!(
            "Transmit message GoodCrc for RX message count: {}",
            rx_counter.value()
        );
        let mut buffer = Self::get_message_buffer();
        let size = Message::new(Header::new_control(
            self.default_header,
            rx_counter,
            ControlMessageType::GoodCRC,
        ))
        .to_bytes(&mut buffer);
        Ok(self.transmit_inner(&buffer[..size]).await?)
    }

    /// Records the MessageID of a received non-`GoodCRC` message.
    ///
    /// If the driver does not acknowledge automatically, it also sends the `GoodCRC`.
    /// Returns `true` if the message is a retransmission of the previously received one.
    ///
    /// # Errors
    /// [`RxError::HardReset`] if sending the `GoodCRC` hit a hard reset;
    /// [`RxError::UnsupportedMessage`] for any other `GoodCRC` transmit failure.
    async fn handle_rx_ack(&mut self, message: &Message) -> Result<bool, RxError> {
        let is_good_crc = matches!(
            message.header.message_type(),
            MessageType::Control(ControlMessageType::GoodCRC)
        );
        let is_retransmission = if is_good_crc {
            false
        } else {
            self.update_rx_message_counter(message)
        };
        // Retransmissions are acknowledged again, since the partner evidently missed our GoodCRC.
        if !DRIVER::HAS_AUTO_GOOD_CRC && !is_good_crc {
            match self.transmit_good_crc().await {
                Ok(()) => {}
                Err(ProtocolError::TxError(TxError::HardReset)) => return Err(RxError::HardReset),
                Err(_) => return Err(RxError::UnsupportedMessage),
            }
        }
        Ok(is_retransmission)
    }

    /// Discards any partially reassembled chunked extended message.
    fn reset_chunked_rx(&mut self) {
        self.extended_rx_buffer.clear();
        self.extended_rx_expected = None;
    }

    /// Decodes a fully reassembled extended-message payload of type `msg_type`.
    ///
    /// Types without a dedicated decoder become `Extended::Unknown`.
    fn parse_extended_payload(msg_type: ExtendedMessageType, bytes: &[u8]) -> Payload {
        match msg_type {
            ExtendedMessageType::ExtendedControl => {
                Payload::Extended(message::extended::Extended::ExtendedControl(
                    message::extended::extended_control::ExtendedControl::from_bytes(bytes),
                ))
            }
            ExtendedMessageType::EprSourceCapabilities => {
                Payload::Extended(message::extended::Extended::EprSourceCapabilities(
                    bytes
                        .chunks_exact(4)
                        .map(|raw| {
                            message::data::source_capabilities::parse_raw_pdo(LittleEndian::read_u32(raw))
                        })
                        .collect(),
                ))
            }
            _ => Payload::Extended(message::extended::Extended::Unknown),
        }
    }

    /// Receives the next message for the caller.
    ///
    /// Along the way it:
    /// - skips driver-discarded frames;
    /// - tracks the partner's spec revision;
    /// - acknowledges messages and filters retransmissions via [`Self::handle_rx_ack`];
    /// - reassembles chunked extended messages, requesting each further chunk.
    ///
    /// Retransmitted messages are acknowledged but not returned. `GoodCRC` messages are
    /// returned as-is.
    ///
    /// # Errors
    /// - [`RxError::HardReset`] / [`RxError::SoftReset`] when a reset is signalled or received.
    /// - [`RxError::UnsupportedMessage`] for reserved message types, truncated extended frames,
    ///   and out-of-sequence or oversized chunked messages.
    /// - [`RxError::ParseError`] for undecodable headers or messages.
    /// - Errors from sending a `GoodCRC` or a chunk request.
    async fn receive_message_inner(&mut self) -> Result<Message, RxError> {
        loop {
            let (buffer, length) = self.receive_frame().await?;
            let header = Header::from_bytes(&buffer[..MSG_HEADER_SIZE])?;
            let message_type = header.message_type();

            if let MessageType::Extended(msg_type) = message_type {
                let ext_header_end = MSG_HEADER_SIZE + EXT_HEADER_SIZE;
                // A frame too short to hold both headers would make `ext_header_end..length`
                // an inverted range and panic; reject it instead.
                if length < ext_header_end {
                    trace!("Extended message too short: {} bytes", length);
                    return Err(RxError::UnsupportedMessage);
                }
                let ext_header =
                    message::extended::ExtendedHeader::from_bytes(&buffer[MSG_HEADER_SIZE..ext_header_end]);
                let payload = &buffer[ext_header_end..length];
                let total_size = ext_header.data_size();
                let chunked = ext_header.chunked();
                let chunk_number = ext_header.chunk_number();
                self.default_header = self.default_header.with_spec_revision(header.spec_revision()?);

                if chunked {
                    trace!(
                        "Received chunked extended message {:?}, chunk {}, size {}",
                        message_type,
                        chunk_number,
                        payload.len()
                    );
                    let tmp_message = Message { header, payload: None };
                    if self.handle_rx_ack(&tmp_message).await? {
                        // Already-processed chunk: acknowledged above, otherwise ignored.
                        continue;
                    }

                    let (expected_total, expected_next) = match self.extended_rx_expected {
                        Some((ty, total, next)) if ty == msg_type => (total, next),
                        _ => (total_size, 0),
                    };
                    // Chunk 0 always (re)starts an assembly. Any other chunk must be exactly the
                    // next one of an assembly already in progress for the same type.
                    let (assembly_total, next_chunk) = if chunk_number == 0 {
                        self.extended_rx_buffer.clear();
                        (total_size, 1)
                    } else if chunk_number == expected_next {
                        (expected_total, expected_next + 1)
                    } else {
                        self.reset_chunked_rx();
                        return Err(RxError::UnsupportedMessage);
                    };
                    self.extended_rx_expected = Some((msg_type, assembly_total, next_chunk));

                    // `extend_from_slice` fails without writing when capacity would be exceeded.
                    if self.extended_rx_buffer.extend_from_slice(payload).is_err() {
                        self.reset_chunked_rx();
                        return Err(RxError::UnsupportedMessage);
                    }

                    if self.extended_rx_buffer.len() < total_size as usize {
                        self.transmit_chunk_request(msg_type, next_chunk).await?;
                        continue;
                    }

                    // Trailing padding beyond `total_size` in the last chunk is ignored.
                    let parsed_payload = Self::parse_extended_payload(
                        msg_type,
                        &self.extended_rx_buffer[..total_size as usize],
                    );
                    self.reset_chunked_rx();
                    let mut message = Message::new(header);
                    message.payload = Some(parsed_payload);
                    trace!("Received assembled extended message {:?}", message);
                    return Ok(message);
                }
                // Unchunked extended messages fall through to the generic parser.
            }

            let message = Message::from_bytes(&buffer[..length])?;
            self.default_header = self.default_header.with_spec_revision(message.header.spec_revision()?);
            match message.header.message_type() {
                MessageType::Control(ControlMessageType::Reserved) | MessageType::Data(DataMessageType::Reserved) => {
                    trace!("Unsupported message type in header: {:?}", message.header);
                    return Err(RxError::UnsupportedMessage);
                }
                // NOTE(review): this returns before `handle_rx_ack`, so drivers without
                // HAS_AUTO_GOOD_CRC never send a GoodCRC for Soft_Reset — confirm intended.
                MessageType::Control(ControlMessageType::SoftReset) => return Err(RxError::SoftReset),
                _ => (),
            }
            if self.handle_rx_ack(&message).await? {
                // Retransmission: acknowledged again but not delivered twice.
                continue;
            }
            trace!("Received message {:?}", message);
            return Ok(message);
        }
    }

    /// Receives the next message; see [`Self::receive_message_inner`].
    pub async fn receive_message(&mut self) -> Result<Message, ProtocolError> {
        self.receive_message_inner().await.map_err(|err| err.into())
    }

    /// Records the MessageID of `rx_message`.
    ///
    /// Returns `true` if it repeats the previously recorded ID, i.e. the message is a
    /// retransmission. The first message after a reset always establishes the counter.
    fn update_rx_message_counter(&mut self, rx_message: &Message) -> bool {
        match self.counters.rx_message.as_mut() {
            None => {
                trace!(
                    "Received first message after protocol layer reset with RX counter value: {}",
                    rx_message.header.message_id()
                );
                self.counters.rx_message = Some(Counter::new_from_value(
                    CounterType::MessageId,
                    rx_message.header.message_id(),
                ));
                false
            }
            Some(counter) => {
                if rx_message.header.message_id() == counter.value() {
                    trace!("Received retransmission of RX counter value: {}", counter.value());
                    true
                } else {
                    counter.set(rx_message.header.message_id());
                    false
                }
            }
        }
    }

    /// Waits, bounded by `timer_type`, for the next message, which must be one of
    /// `message_types`. Incoming `GoodCRC` messages are skipped.
    ///
    /// # Panics
    /// If `message_types` contains `GoodCRC`.
    ///
    /// # Errors
    /// - [`ProtocolError::UnexpectedMessage`] if a message of another type arrives.
    /// - [`RxError::ReceiveTimeout`] if the timer expires.
    /// - Any receive error, including [`RxError::ParseError`] for malformed frames. Parse errors
    ///   come from partner-controlled bytes and are therefore propagated, not treated as
    ///   impossible.
    pub async fn receive_message_type(
        &mut self,
        message_types: &[MessageType],
        timer_type: TimerType,
    ) -> Result<Message, ProtocolError> {
        for message_type in message_types {
            assert_ne!(*message_type, MessageType::Control(ControlMessageType::GoodCRC));
        }
        let timeout_fut = Self::get_timer(timer_type);
        let receive_fut = async {
            loop {
                match self.receive_message_inner().await {
                    Ok(message) => {
                        if matches!(
                            message.header.message_type(),
                            MessageType::Control(ControlMessageType::GoodCRC)
                        ) {
                            continue;
                        }
                        return if message_types.contains(&message.header.message_type()) {
                            Ok(message)
                        } else {
                            Err(ProtocolError::UnexpectedMessage)
                        };
                    }
                    Err(other) => return Err(other.into()),
                }
            }
        };
        match select(timeout_fut, receive_fut).await {
            Either::First(_) => Err(RxError::ReceiveTimeout.into()),
            Either::Second(receive_result) => receive_result,
        }
    }

    /// Transmits a hard reset and resets the protocol layer's message tracking.
    ///
    /// Discarded attempts are retried. A hard reset reported by the driver while sending
    /// counts as done.
    pub async fn hard_reset(&mut self) -> Result<(), ProtocolError> {
        self.counters.tx_message.reset();
        self.counters.retry.reset();
        // The partner restarts its MessageIDs after a hard reset. Forget the last received ID
        // so the partner's first new message is not mistaken for a retransmission and dropped.
        self.counters.rx_message = None;
        self.reset_chunked_rx();
        loop {
            match self.driver.transmit_hard_reset().await {
                Ok(_) | Err(DriverTxError::HardReset) => break,
                Err(DriverTxError::Discarded) => (),
            }
        }
        trace!("Performed hard reset");
        Ok(())
    }

    /// Waits until the driver reports VBUS.
    pub async fn wait_for_vbus(&mut self) {
        self.driver.wait_for_vbus().await
    }

    /// Waits, bounded by the `SinkWaitCap` timer, for `Source_Capabilities` or
    /// `EPR_Source_Capabilities`.
    pub async fn wait_for_source_capabilities(&mut self) -> Result<Message, ProtocolError> {
        self.receive_message_type(
            &[
                MessageType::Data(message::header::DataMessageType::SourceCapabilities),
                MessageType::Extended(ExtendedMessageType::EprSourceCapabilities),
            ],
            TimerType::SinkWaitCap,
        )
        .await
    }

    /// Transmits a payload-less control message of `message_type`.
    pub async fn transmit_control_message(&mut self, message_type: ControlMessageType) -> Result<(), ProtocolError> {
        let message = Message::new(Header::new_control(
            self.default_header,
            self.counters.tx_message,
            message_type,
        ));
        self.transmit(message).await
    }

    /// Transmits an `Extended_Control` message carrying `message_type`.
    pub async fn transmit_extended_control_message(
        &mut self,
        message_type: ExtendedControlMessageType,
    ) -> Result<(), ProtocolError> {
        let mut message = Message::new(Header::new_extended(
            self.default_header,
            self.counters.tx_message,
            ExtendedMessageType::ExtendedControl,
            1,
        ));
        message.payload = Some(Payload::Extended(Extended::ExtendedControl(
            message::extended::extended_control::ExtendedControl::default().with_message_type(message_type),
        )));
        self.transmit(message).await
    }

    /// Transmits an `EPR_Mode` data message with the given action and data byte.
    pub async fn transmit_epr_mode(
        &mut self,
        action: message::data::epr_mode::Action,
        data: u8,
    ) -> Result<(), ProtocolError> {
        let header = Header::new_data(
            self.default_header,
            self.counters.tx_message,
            DataMessageType::EprMode,
            1,
        );
        let mdo = EprModeDataObject::default().with_action(action).with_data(data);
        self.transmit(Message::new_with_data(header, Data::EprMode(mdo))).await
    }

    /// Transmits a power request built from `power_source_request`.
    ///
    /// # Panics
    /// If this layer's port power role is not `Sink`.
    pub async fn request_power(&mut self, power_source_request: request::PowerSource) -> Result<(), ProtocolError> {
        assert!(matches!(self.default_header.port_power_role(), PowerRole::Sink));
        let message_type = power_source_request.message_type();
        let num_objects = power_source_request.num_objects();
        let header = Header::new_data(self.default_header, self.counters.tx_message, message_type, num_objects);
        self.transmit(Message::new_with_data(header, Data::Request(power_source_request)))
            .await
    }

    /// Asks the partner for chunk `chunk_number` of the extended message `message_type`.
    ///
    /// The request is a header plus an extended header with `chunked` and `request_chunk`
    /// set. It bypasses [`Self::transmit`], so there is no validation and, without hardware
    /// auto-retry, no resend on a `GoodCRC` timeout.
    ///
    /// # Errors
    /// [`RxError::HardReset`] on a hard reset. [`RxError::ReceiveTimeout`] if the hardware
    /// discarded the frame or no `GoodCRC` arrived in time. Other `GoodCRC` wait errors are
    /// passed through.
    async fn transmit_chunk_request(
        &mut self,
        message_type: ExtendedMessageType,
        chunk_number: u8,
    ) -> Result<(), RxError> {
        trace!("Transmit chunk request for {:?} chunk {}", message_type, chunk_number);
        let ext_header = message::extended::ExtendedHeader::default()
            .with_chunked(true)
            .with_request_chunk(true)
            .with_chunk_number(chunk_number);
        let header = Header::new_extended(self.default_header, self.counters.tx_message, message_type, 1);
        let mut buffer = Self::get_message_buffer();
        let mut offset = header.to_bytes(&mut buffer);
        offset += ext_header.to_bytes(&mut buffer[offset..]);
        // Two zero padding bytes; presumably this completes the single data object declared in
        // the header — TODO confirm against ExtendedHeader::to_bytes length.
        offset += 2;
        if DRIVER::HAS_AUTO_RETRY {
            match self.driver.transmit(&buffer[..offset]).await {
                Ok(()) => {
                    self.counters.retry.reset();
                    _ = self.counters.tx_message.increment();
                    Ok(())
                }
                Err(DriverTxError::HardReset) => Err(RxError::HardReset),
                Err(DriverTxError::Discarded) => Err(RxError::ReceiveTimeout),
            }
        } else {
            match self.transmit_inner(&buffer[..offset]).await {
                Ok(_) => self.wait_for_good_crc().await,
                Err(TxError::HardReset) => Err(RxError::HardReset),
                Err(TxError::UnchunkedExtendedMessagesNotSupported | TxError::AvsVoltageAlignmentInvalid) => {
                    unreachable!("validation should happen before transmit_inner")
                }
            }
        }
    }

    /// Transmits a `Sink_Capabilities` data message.
    pub async fn transmit_sink_capabilities(
        &mut self,
        capabilities: message::data::sink_capabilities::SinkCapabilities,
    ) -> Result<(), ProtocolError> {
        let num_objects = capabilities.num_objects();
        let header = Header::new_data(
            self.default_header,
            self.counters.tx_message,
            DataMessageType::SinkCapabilities,
            num_objects,
        );
        self.transmit(Message::new_with_data(header, Data::SinkCapabilities(capabilities)))
            .await
    }

    /// Transmits an `EPR_Sink_Capabilities` extended message.
    ///
    /// # Panics
    /// If `capabilities` holds more than 7 PDOs (the `heapless::Vec` capacity).
    pub async fn transmit_epr_sink_capabilities(
        &mut self,
        capabilities: message::data::sink_capabilities::SinkCapabilities,
    ) -> Result<(), ProtocolError> {
        let pdos: heapless::Vec<_, 7> = capabilities.0.iter().cloned().collect();
        let extended_payload = message::extended::Extended::EprSinkCapabilities(pdos);
        let header = Header::new_extended(
            self.default_header,
            self.counters.tx_message,
            ExtendedMessageType::EprSinkCapabilities,
            // NOTE(review): presumably the real object count is derived from the payload during
            // serialization — confirm.
            0,
        );
        let mut message = Message::new(header);
        message.payload = Some(Payload::Extended(extended_payload));
        self.transmit(message).await
    }
}
#[cfg(test)]
mod tests {
    use super::ProtocolLayer;
    use super::message::data::Data;
    use super::message::data::source_capabilities::SourceCapabilities;
    use super::message::header::Header;
    use crate::dummy::{
        DUMMY_CAPABILITIES, DummyDriver, DummyTimer, MAX_DATA_MESSAGE_SIZE, get_dummy_source_capabilities,
    };
    use crate::protocol_layer::message::Payload;

    /// Builds a UFP/sink protocol layer (spec revision 3.x) on top of the dummy driver.
    fn get_protocol_layer() -> ProtocolLayer<DummyDriver<MAX_DATA_MESSAGE_SIZE>, DummyTimer> {
        let driver = DummyDriver::new();
        let template = Header::new_template(
            crate::DataRole::Ufp,
            crate::PowerRole::Sink,
            super::message::header::SpecificationRevision::R3_X,
        );
        ProtocolLayer::new(driver, template)
    }

    /// Injected `Source_Capabilities` bytes must decode back into the dummy capabilities.
    #[tokio::test]
    async fn test_it() {
        let mut protocol_layer = get_protocol_layer();
        protocol_layer.driver.inject_received_data(&DUMMY_CAPABILITIES);
        let message = protocol_layer.receive_message().await.unwrap();

        let Some(Payload::Data(Data::SourceCapabilities(SourceCapabilities(caps)))) = message.payload else {
            panic!()
        };
        for (cap, dummy_cap) in caps.into_iter().zip(get_dummy_source_capabilities()) {
            assert_eq!(cap, dummy_cap);
        }
    }
}