use core::fmt;
use crate::error::{Error, ErrorCode};
use crate::utils::storage::WriteBuf;
bitflags::bitflags! {
    /// Flag bits carried in the first byte of a BTP packet header.
    ///
    /// The set bits decide which optional header fields follow the flags byte
    /// (see `BtpHdr::decode` / `BtpHdr::encode`). Bits not listed here are
    /// dropped when decoding (`from_bits_truncate`).
    #[repr(transparent)]
    #[derive(Default, Debug, Copy, Clone, Eq, PartialEq, Hash)]
    pub struct BtpFlags: u8 {
        /// H: handshake packet; neither a sequence number nor a message length
        /// field is present in the header.
        const HANDSHAKE = 0x40;
        /// M: a management opcode byte follows the flags byte.
        const MANAGEMENT = 0x20;
        /// A: an acknowledgement number byte is present.
        const ACK = 0x08;
        /// E: ending-segment flag (reported by `BtpHdr::is_final`).
        const ENDING_SEGMENT = 0x04;
        /// C: continue flag (reported by `BtpHdr::is_continue`).
        const CONTINUE = 0x02;
        /// B: beginning-segment flag; unless H is also set, a little-endian
        /// `u16` message length is present in the header.
        const BEGINNING_SEGMENT = 0x01;
    }
}
impl fmt::Display for BtpFlags {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut sep = false;
for flag in [
Self::HANDSHAKE,
Self::MANAGEMENT,
Self::ACK,
Self::BEGINNING_SEGMENT,
Self::CONTINUE,
Self::ENDING_SEGMENT,
] {
if self.contains(flag) {
if sep {
write!(f, "|")?;
}
let str = match flag {
Self::HANDSHAKE => "H",
Self::MANAGEMENT => "M",
Self::ACK => "A",
Self::BEGINNING_SEGMENT => "B",
Self::CONTINUE => "C",
Self::ENDING_SEGMENT => "E",
_ => "?",
};
write!(f, "{}", str)?;
sep = true;
}
}
Ok(())
}
}
#[cfg(feature = "defmt")]
impl defmt::Format for BtpFlags {
    /// defmt counterpart of the `Display` impl: set flags as single-letter
    /// mnemonics joined by `|`, in the fixed order H, M, A, B, C, E.
    fn format(&self, f: defmt::Formatter<'_>) {
        const MNEMONICS: [(BtpFlags, &str); 6] = [
            (BtpFlags::HANDSHAKE, "H"),
            (BtpFlags::MANAGEMENT, "M"),
            (BtpFlags::ACK, "A"),
            (BtpFlags::BEGINNING_SEGMENT, "B"),
            (BtpFlags::CONTINUE, "C"),
            (BtpFlags::ENDING_SEGMENT, "E"),
        ];

        let mut wrote_any = false;
        for &(flag, mnemonic) in MNEMONICS.iter() {
            if !self.contains(flag) {
                continue;
            }
            if wrote_any {
                defmt::write!(f, "|");
            }
            defmt::write!(f, "{}", mnemonic);
            wrote_any = true;
        }
    }
}
/// In-memory form of a BTP packet header.
///
/// Which of the numeric fields are actually carried on the wire is decided by
/// `flags` (see `BtpHdr::encode` / `BtpHdr::decode`). The `Option`-taking
/// setters reset a field to 0 when they are passed `None`.
#[derive(Debug, Default, Clone)]
pub struct BtpHdr {
    /// Header flag bits; always the first byte on the wire.
    flags: BtpFlags,
    /// Management opcode; meaningful only when `MANAGEMENT` is set.
    opcode: u8,
    /// Acknowledgement number; meaningful only when `ACK` is set.
    ack_num: u8,
    /// Sequence number; meaningful only when `HANDSHAKE` is clear.
    seq_num: u8,
    /// Message length; meaningful only when `BEGINNING_SEGMENT` is set and
    /// `HANDSHAKE` is clear. Encoded little-endian.
    msg_len: u16,
}
impl BtpHdr {
#[inline(always)]
pub const fn new() -> Self {
Self {
flags: BtpFlags::empty(),
opcode: 0,
ack_num: 0,
seq_num: 0,
msg_len: 0,
}
}
pub fn from<I>(msg: I) -> Result<Self, Error>
where
I: Iterator<Item = u8>,
{
let mut hdr = Self::new();
hdr.decode(msg)?;
Ok(hdr)
}
pub fn is_handshake(&self) -> bool {
self.flags.contains(BtpFlags::HANDSHAKE)
}
pub fn set_handshake(&mut self) {
self.flags |= BtpFlags::HANDSHAKE | BtpFlags::BEGINNING_SEGMENT | BtpFlags::ENDING_SEGMENT;
}
pub fn get_opcode(&self) -> Option<u8> {
self.flags
.contains(BtpFlags::MANAGEMENT)
.then_some(self.opcode)
}
pub fn set_opcode(&mut self, opcode: Option<u8>) {
if let Some(opcode) = opcode {
self.flags |= BtpFlags::MANAGEMENT;
self.opcode = opcode
} else {
self.flags.remove(BtpFlags::MANAGEMENT);
self.opcode = 0;
}
}
pub fn get_ack(&self) -> Option<u8> {
self.flags.contains(BtpFlags::ACK).then_some(self.ack_num)
}
pub fn set_ack(&mut self, ack_num: Option<u8>) {
if let Some(ack_num) = ack_num {
self.flags |= BtpFlags::ACK;
self.ack_num = ack_num;
} else {
self.flags.remove(BtpFlags::ACK);
self.ack_num = 0;
}
}
pub fn get_seq(&self) -> Option<u8> {
(!self.flags.contains(BtpFlags::HANDSHAKE)).then_some(self.seq_num)
}
pub fn set_seq(&mut self, seq_num: Option<u8>) {
if let Some(seq_num) = seq_num {
self.flags.remove(BtpFlags::HANDSHAKE);
self.seq_num = seq_num;
} else {
self.flags |= BtpFlags::HANDSHAKE;
self.seq_num = 0;
}
}
pub fn is_standalone_ack(&self) -> bool {
!self.is_handshake()
&& self.get_msg_len().is_none()
&& !self.is_continue()
&& !self.is_final()
&& self.get_ack().is_some()
}
pub fn get_msg_len(&self) -> Option<u16> {
(self.flags.contains(BtpFlags::BEGINNING_SEGMENT)
&& !self.flags.contains(BtpFlags::HANDSHAKE))
.then_some(self.msg_len)
}
pub fn set_msg_len(&mut self, msg_len: Option<u16>) {
if let Some(msg_len) = msg_len {
self.flags |= BtpFlags::BEGINNING_SEGMENT;
self.msg_len = msg_len;
} else {
self.flags.remove(BtpFlags::BEGINNING_SEGMENT);
self.msg_len = 0;
}
}
pub fn is_continue(&self) -> bool {
self.flags.contains(BtpFlags::CONTINUE)
}
pub fn set_continue(&mut self) {
self.flags |= BtpFlags::CONTINUE;
}
pub fn is_final(&self) -> bool {
self.flags.contains(BtpFlags::ENDING_SEGMENT)
}
pub fn set_final(&mut self) {
self.flags |= BtpFlags::ENDING_SEGMENT;
}
fn decode<I>(&mut self, mut msg: I) -> Result<(), Error>
where
I: Iterator<Item = u8>,
{
self.flags = BtpFlags::from_bits_truncate(msg.next().ok_or(ErrorCode::Invalid)?);
if self.flags.contains(BtpFlags::MANAGEMENT) {
self.opcode = msg.next().ok_or(ErrorCode::Invalid)?;
}
if self.flags.contains(BtpFlags::ACK) {
self.ack_num = msg.next().ok_or(ErrorCode::Invalid)?;
}
if !self.flags.contains(BtpFlags::HANDSHAKE) {
self.seq_num = msg.next().ok_or(ErrorCode::Invalid)?;
}
if self.flags.contains(BtpFlags::BEGINNING_SEGMENT)
&& !self.flags.contains(BtpFlags::HANDSHAKE)
{
let msg_len = [
msg.next().ok_or(ErrorCode::Invalid)?,
msg.next().ok_or(ErrorCode::Invalid)?,
];
self.msg_len = u16::from_le_bytes(msg_len);
}
trace!("[decode] {}", self);
Ok(())
}
pub fn encode(&self, resp_buf: &mut WriteBuf) -> Result<(), Error> {
trace!("[encode] {}", self);
resp_buf.le_u8(self.flags.bits())?;
if self.flags.contains(BtpFlags::MANAGEMENT) {
resp_buf.le_u8(self.opcode)?;
}
if self.flags.contains(BtpFlags::ACK) {
resp_buf.le_u8(self.ack_num)?;
}
if !self.flags.contains(BtpFlags::HANDSHAKE) {
resp_buf.le_u8(self.seq_num)?;
}
if self.flags.contains(BtpFlags::BEGINNING_SEGMENT)
&& !self.flags.contains(BtpFlags::HANDSHAKE)
{
resp_buf.le_u16(self.msg_len)?;
}
Ok(())
}
pub fn len(&self) -> usize {
let mut len = 1;
if self.flags.contains(BtpFlags::MANAGEMENT) {
len += 1;
}
if self.flags.contains(BtpFlags::ACK) {
len += 1;
}
if !self.flags.contains(BtpFlags::HANDSHAKE) {
len += 1;
}
if self.flags.contains(BtpFlags::BEGINNING_SEGMENT)
&& !self.flags.contains(BtpFlags::HANDSHAKE)
{
len += 2;
}
len
}
}
impl fmt::Display for BtpHdr {
    /// Compact one-line rendering for logs.
    ///
    /// Prints the flag mnemonics (when any flag is set), then `,OP:`, `,ACTR:`,
    /// `,CTR:` and `,LEN:` entries in lowercase hex for each field the flags
    /// mark as present.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if !self.flags.is_empty() {
            write!(f, "{}", self.flags)?;
        }
        self.get_opcode()
            .map_or(Ok(()), |opcode| write!(f, ",OP:{:x}", opcode))?;
        self.get_ack()
            .map_or(Ok(()), |ack_num| write!(f, ",ACTR:{:x}", ack_num))?;
        self.get_seq()
            .map_or(Ok(()), |seq_num| write!(f, ",CTR:{:x}", seq_num))?;
        self.get_msg_len()
            .map_or(Ok(()), |msg_len| write!(f, ",LEN:{:x}", msg_len))
    }
}
#[cfg(feature = "defmt")]
impl defmt::Format for BtpHdr {
    /// defmt counterpart of the `Display` impl: the flag mnemonics (when any
    /// are set) followed by `,OP:`, `,ACTR:`, `,CTR:` and `,LEN:` hex entries
    /// for the fields the flags mark as present.
    fn format(&self, f: defmt::Formatter<'_>) {
        let (opcode, ack, seq, len) = (
            self.get_opcode(),
            self.get_ack(),
            self.get_seq(),
            self.get_msg_len(),
        );

        if !self.flags.is_empty() {
            defmt::write!(f, "{}", self.flags);
        }
        if let Some(value) = opcode {
            defmt::write!(f, ",OP:{:x}", value);
        }
        if let Some(value) = ack {
            defmt::write!(f, ",ACTR:{:x}", value);
        }
        if let Some(value) = seq {
            defmt::write!(f, ",CTR:{:x}", value);
        }
        if let Some(value) = len {
            defmt::write!(f, ",LEN:{:x}", value);
        }
    }
}
/// Body of a BTP handshake request, as read by `decode` and written by
/// `encode`: a little-endian `u32` of packed versions, a little-endian `u16`
/// MTU and a one-byte window size (7 bytes in total).
#[derive(Debug, Default)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct HandshakeReq {
    /// Supported protocol versions, packed as eight 4-bit slots: slot `i`
    /// occupies bits `4 * i .. 4 * i + 4`. A zero slot is unused.
    pub versions: u32,
    /// MTU value carried in the request.
    pub mtu: u16,
    /// Window size carried in the request.
    pub window_size: u8,
}

impl HandshakeReq {
    /// Number of 4-bit version slots that fit in `versions` (32 / 4).
    const VERSION_SLOTS: u8 = 8;
    /// Mask selecting the low 4 bits, i.e. a single version slot.
    const VERSION_MASK: u32 = 0x0f;

    /// Parses a handshake request body from `msg`.
    ///
    /// # Errors
    /// Fails with an error built from `ErrorCode::Invalid` when `msg` yields
    /// fewer than 7 bytes.
    pub fn from<I>(msg: I) -> Result<Self, Error>
    where
        I: Iterator<Item = u8>,
    {
        let mut req = Self::default();
        req.decode(msg)?;
        Ok(req)
    }

    /// Iterates over the non-zero versions in slot order.
    ///
    /// Every one of the eight slots is inspected, and each is masked to its own
    /// 4 bits. Masking with `0xff` would merge a slot with its neighbour.
    pub fn versions(&self) -> impl Iterator<Item = u8> + '_ {
        // `move` copies the `&Self` into the closure, so the iterator borrows
        // `*self` (lifetime `'_`) rather than the local `self` binding.
        (0..Self::VERSION_SLOTS)
            .map(move |slot| ((self.versions >> (slot * 4)) & Self::VERSION_MASK) as u8)
            .filter(|version| *version > 0)
    }

    /// Replaces the packed version list with `versions`, one per 4-bit slot in
    /// iteration order.
    ///
    /// Previously stored versions are cleared first. Only the first eight
    /// versions are stored, since there are no further slots; this also avoids a
    /// shift overflow. Each version is truncated to its low 4 bits so it cannot
    /// spill into the next slot.
    // NOTE(review): `allow(unused)` kept from the original; presumably there is
    // no in-crate caller yet — confirm before removing.
    #[allow(unused)]
    pub fn set_versions<I>(&mut self, versions: I)
    where
        I: Iterator<Item = u8>,
    {
        self.versions = 0;
        for (slot, version) in (0..Self::VERSION_SLOTS).zip(versions) {
            self.versions |= (u32::from(version) & Self::VERSION_MASK) << (slot * 4);
        }
    }

    /// Reads versions (LE `u32`), MTU (LE `u16`) and window size (`u8`) from
    /// `msg`, in that order.
    fn decode<I>(&mut self, mut msg: I) -> Result<(), Error>
    where
        I: Iterator<Item = u8>,
    {
        self.versions = u32::from_le_bytes([
            msg.next().ok_or(ErrorCode::Invalid)?,
            msg.next().ok_or(ErrorCode::Invalid)?,
            msg.next().ok_or(ErrorCode::Invalid)?,
            msg.next().ok_or(ErrorCode::Invalid)?,
        ]);
        self.mtu = u16::from_le_bytes([
            msg.next().ok_or(ErrorCode::Invalid)?,
            msg.next().ok_or(ErrorCode::Invalid)?,
        ]);
        self.window_size = msg.next().ok_or(ErrorCode::Invalid)?;
        Ok(())
    }

    /// Appends the 7-byte request body to `resp_buf`, in the order `decode`
    /// reads it.
    ///
    /// # Errors
    /// Propagates the first error returned by `resp_buf`.
    pub fn encode(&self, resp_buf: &mut WriteBuf) -> Result<(), Error> {
        resp_buf.le_u32(self.versions)?;
        resp_buf.le_u16(self.mtu)?;
        resp_buf.le_u8(self.window_size)?;
        Ok(())
    }
}
/// Body of a BTP handshake response, as read by `decode` and written by
/// `encode`: one version byte, a little-endian `u16` MTU and a one-byte window
/// size (4 bytes in total).
#[derive(Debug, Default)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct HandshakeResp {
    /// Version byte carried in the response.
    pub version: u8,
    /// MTU value carried in the response.
    pub mtu: u16,
    /// Window size carried in the response.
    pub window_size: u8,
}

impl HandshakeResp {
    /// Parses a handshake response body from `msg`.
    ///
    /// # Errors
    /// Fails with an error built from `ErrorCode::Invalid` when `msg` yields
    /// fewer than 4 bytes.
    pub fn from<I>(msg: I) -> Result<Self, Error>
    where
        I: Iterator<Item = u8>,
    {
        let mut parsed = Self::default();
        parsed.decode(msg).map(|()| parsed)
    }

    /// Fills `self` from the next 4 bytes of `msg`: version, MTU (LE `u16`),
    /// window size. Each field is stored as soon as it is complete.
    fn decode<I>(&mut self, mut msg: I) -> Result<(), Error>
    where
        I: Iterator<Item = u8>,
    {
        let mut byte = || msg.next().ok_or(ErrorCode::Invalid);

        self.version = byte()?;
        let (low, high) = (byte()?, byte()?);
        self.mtu = u16::from_le_bytes([low, high]);
        self.window_size = byte()?;
        Ok(())
    }

    /// Appends the 4-byte response body to `resp_buf`, in the order `decode`
    /// reads it.
    ///
    /// # Errors
    /// Propagates the first error returned by `resp_buf`.
    pub fn encode(&self, resp_buf: &mut WriteBuf) -> Result<(), Error> {
        let Self {
            version,
            mtu,
            window_size,
        } = *self;
        resp_buf.le_u8(version)?;
        resp_buf.le_u16(mtu)?;
        resp_buf.le_u8(window_size)?;
        Ok(())
    }
}