use std::io::Cursor;
use super::proto::*;
use log::{debug, error, warn};
use num_traits::FromPrimitive;
use thiserror::Error;
use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Size in bytes of the fixed header that precedes every transfer extension
/// item on the wire: one byte of item flags, a 16-bit item type and a 16-bit
/// item length. An item can never occupy fewer bytes than this.
const MINIMUM_EXTENSION_ITEM_SIZE: u32 = 1 + 2 + 2;
/// One message of the TCP convergence layer as it is exchanged on the wire.
///
/// Every variant can be serialized with `TcpClPacket::write` and parsed back
/// with `TcpClPacket::read`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum TcpClPacket {
    /// Contact header: the `dtn!` magic, the protocol version and the flags.
    ContactHeader(ContactHeaderFlags),
    /// Session initialization (keepalive, segment MRU, transfer MRU, node id).
    SessInit(SessInitData),
    /// Session termination with flags and a reason code.
    SessTerm(SessTermData),
    /// A segment of a transfer, optionally carrying transfer extension items.
    XferSeg(XferSegData),
    /// Acknowledgement for a transfer segment.
    XferAck(XferAckData),
    /// Refusal of a transfer, with a reason code and the transfer id.
    XferRefuse(XferRefuseData),
    /// Keepalive message; it consists of the message type byte only.
    KeepAlive,
    /// Rejection of a previously received message.
    MsgReject(MsgRejectData),
}
impl TcpClPacket {
pub async fn write(&self, writer: &mut (impl AsyncWrite + Unpin)) -> anyhow::Result<()> {
match self {
TcpClPacket::SessInit(sess_init_data) => {
writer.write_u8(MessageType::SessInit as u8).await?;
writer.write_u16(sess_init_data.keepalive).await?;
writer.write_u64(sess_init_data.segment_mru).await?;
writer.write_u64(sess_init_data.transfer_mru).await?;
writer
.write_u16(sess_init_data.node_id.len() as u16)
.await?;
let node_id_bytes = sess_init_data.node_id.as_bytes();
writer.write_all(node_id_bytes).await?;
writer.write_u32(0).await?;
}
TcpClPacket::SessTerm(sess_term_data) => {
writer.write_u8(MessageType::SessTerm as u8).await?;
writer.write_u8(sess_term_data.flags.bits()).await?;
writer.write_u8(sess_term_data.reason as u8).await?;
}
TcpClPacket::XferSeg(xfer_seg_data) => {
writer.write_u8(MessageType::XferSegment as u8).await?;
writer.write_u8(xfer_seg_data.flags.bits()).await?;
writer.write_u64(xfer_seg_data.tid).await?;
if xfer_seg_data.flags.contains(XferSegmentFlags::START) {
let mut ext_data = Cursor::new(Vec::new());
let mut len = 0u32;
for ext in &xfer_seg_data.extensions {
ext_data.write_u8(ext.flags.bits()).await?;
ext_data.write_u16(ext.item_type as u16).await?;
ext_data.write_u16(ext.data.len() as u16).await?;
ext_data.write_all(ext.data.as_ref()).await?;
len += MINIMUM_EXTENSION_ITEM_SIZE + ext.data.len() as u32;
}
writer.write_u32(len).await?;
if len > 0 {
writer.write_all(ext_data.get_ref()).await?;
}
}
writer.write_u64(xfer_seg_data.len).await?;
if xfer_seg_data.len > 0 {
writer.write_all(xfer_seg_data.buf.as_ref()).await?;
}
}
TcpClPacket::XferAck(xfer_ack_data) => {
writer.write_u8(MessageType::XferAck as u8).await?;
writer.write_u8(xfer_ack_data.flags.bits()).await?;
writer.write_u64(xfer_ack_data.tid).await?;
writer.write_u64(xfer_ack_data.len).await?;
}
TcpClPacket::XferRefuse(xfer_refuse_data) => {
writer.write_u8(MessageType::XferRefuse as u8).await?;
writer.write_u8(xfer_refuse_data.reason as u8).await?;
writer.write_u64(xfer_refuse_data.tid).await?;
}
TcpClPacket::KeepAlive => {
writer.write_u8(MessageType::Keepalive as u8).await?;
}
TcpClPacket::MsgReject(msg_reject_data) => {
writer.write_u8(MessageType::MsgReject as u8).await?;
writer.write_u8(msg_reject_data.reason as u8).await?;
writer.write_u8(msg_reject_data.header).await?;
}
TcpClPacket::ContactHeader(flags) => {
writer.write_all(b"dtn!").await?;
writer.write_u8(4).await?;
writer.write_u8(flags.bits()).await?;
}
}
writer.flush().await?;
Ok(())
}
pub async fn read(reader: &mut (impl AsyncRead + Unpin)) -> Result<Self, TcpClError> {
let mtype = reader.read_u8().await?;
if let Some(mtype) = MessageType::from_u8(mtype) {
match mtype {
MessageType::XferSegment => {
let flags = XferSegmentFlags::from_bits_truncate(reader.read_u8().await?);
let tid: u64 = reader.read_u64().await?;
let mut extensions = Vec::new();
if flags.contains(XferSegmentFlags::START) {
let mut ext_len: u32 = reader.read_u32().await?;
if ext_len != 0 {
debug!("parsing transfer extensions");
while ext_len >= MINIMUM_EXTENSION_ITEM_SIZE {
let flag = reader.read_u8().await?;
let item_type = reader.read_u16().await?;
let item_length = reader.read_u16().await?;
ext_len =
ext_len - MINIMUM_EXTENSION_ITEM_SIZE - item_length as u32;
let mut data = vec![0; item_length as usize];
reader.read_exact(&mut data).await?;
if let Some(item_type) =
TransferExtensionItemType::from_u16(item_type)
{
let transfer_extension = TransferExtensionItem {
flags: TransferExtensionItemFlags::from_bits_truncate(flag),
item_type,
data: data.into(),
};
extensions.push(transfer_extension);
}
}
if ext_len != 0 {
warn!("malformed transfer extensions, ignoring rest");
for _ in 0..ext_len {
reader.read_u8().await?;
}
}
}
}
let len = reader.read_u64().await?;
debug!("Reading xfer segment with len {}", len);
let mut data = vec![0; len as usize];
if len > 0 {
reader.read_exact(&mut data).await?;
}
let seg = XferSegData {
flags,
tid,
len,
buf: data.into(),
extensions,
};
Ok(TcpClPacket::XferSeg(seg))
}
MessageType::XferAck => {
let flags = XferSegmentFlags::from_bits_truncate(reader.read_u8().await?);
let tid: u64 = reader.read_u64().await?;
let len: u64 = reader.read_u64().await?;
let data = XferAckData { flags, tid, len };
Ok(TcpClPacket::XferAck(data))
}
MessageType::XferRefuse => {
let rcode = reader.read_u8().await?;
if let Some(reason) = XferRefuseReasonCode::from_u8(rcode) {
let tid: u64 = reader.read_u64().await?;
let data = XferRefuseData { reason, tid };
Ok(TcpClPacket::XferRefuse(data))
} else {
Err(TcpClError::UnknownResaonCode(rcode))
}
}
MessageType::Keepalive => Ok(TcpClPacket::KeepAlive),
MessageType::SessTerm => {
let flags = SessTermFlags::from_bits_truncate(reader.read_u8().await?);
let rcode = reader.read_u8().await?;
if let Some(reason) = SessTermReasonCode::from_u8(rcode) {
let data = SessTermData { flags, reason };
Ok(TcpClPacket::SessTerm(data))
} else {
Err(TcpClError::UnknownResaonCode(rcode))
}
}
MessageType::SessInit => {
let keepalive = reader.read_u16().await?;
let segment_mru = reader.read_u64().await?;
let transfer_mru = reader.read_u64().await?;
let node_id_len = reader.read_u16().await? as usize;
let mut node_buffer = vec![0u8; node_id_len];
reader.read_exact(&mut node_buffer).await?;
let node_id: String = String::from_utf8_lossy(&node_buffer).into();
let ext_items = reader.read_u32().await?;
if ext_items != 0 {
return Err(TcpClError::SessionExtensionItemsUnsupported);
}
let data = SessInitData {
keepalive,
segment_mru,
transfer_mru,
node_id,
};
Ok(TcpClPacket::SessInit(data))
}
MessageType::MsgReject => {
let reason_code = reader.read_u8().await?;
let header = reader.read_u8().await?;
if let Some(reason) = MsgRejectReasonCode::from_u8(reason_code) {
let data = MsgRejectData { reason, header };
Ok(TcpClPacket::MsgReject(data))
} else {
Err(TcpClError::UnknownResaonCode(reason_code))
}
}
}
} else if mtype == b'd' {
let mut buf: [u8; 6] = [0; 6];
reader.read_exact(&mut buf).await?;
if &buf[0..4] != b"dtn!" {
return Err(TcpClError::InvalidMagic);
}
if buf[4] != 4 {
return Err(TcpClError::UnsupportedVersion);
}
Ok(TcpClPacket::ContactHeader(
ContactHeaderFlags::from_bits_truncate(buf[5]),
))
} else {
Err(TcpClError::UnknownPacketType(mtype))
}
}
}
#[derive(Error, Debug)]
pub(crate) enum TcpClError {
#[error("error reading bytes")]
ReadError(#[from] io::Error),
#[error("unknown packet type ({0}) encountered")]
UnknownPacketType(u8),
#[error("unknown reason code ({0}) encountered")]
UnknownResaonCode(u8),
#[error("session extension items found but unsupported")]
SessionExtensionItemsUnsupported,
#[error("unexpected packet received")]
UnexpectedPacket,
#[error("invalid magic in contact header")]
InvalidMagic,
#[error("unsupported version in contact header")]
UnsupportedVersion,
}