use std::io::{Read, Write};
use crate::{Error, Result};
/// Protocol version exchanged in the `Version` handshake packet; readers reject any other value.
pub const PROTOCOL_VERSION: u32 = 0;
/// Default upper bound (64 MiB) on a frame's length field accepted by the readers.
pub const DEFAULT_MAX_PACKET_LEN: usize = 64 * 1024 * 1024;
// Packet type tags: the first byte of every packet body (after the 4-byte length prefix).
const PACKET_VERSION: u8 = 1;
const PACKET_REQUEST_DATA: u8 = 2;
const PACKET_DATA: u8 = 3;
/// Identifier of a data stream, carried in `RequestData` and `Data` packets.
///
/// A newtype over `String`; on the wire it is written as the string's raw UTF-8 bytes.
#[derive(
    Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, serde::Serialize, serde::Deserialize,
)]
pub struct DataStreamId(String);
impl DataStreamId {
#[must_use]
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
#[must_use]
pub fn as_str(&self) -> &str {
&self.0
}
#[must_use]
fn as_bytes(&self) -> &[u8] {
self.0.as_bytes()
}
#[must_use]
pub fn into_string(self) -> String {
self.0
}
}
impl From<String> for DataStreamId {
fn from(id: String) -> Self {
Self(id)
}
}
impl From<&str> for DataStreamId {
fn from(id: &str) -> Self {
Self(id.to_string())
}
}
impl AsRef<str> for DataStreamId {
fn as_ref(&self) -> &str {
self.as_str()
}
}
impl std::borrow::Borrow<str> for DataStreamId {
fn borrow(&self) -> &str {
self.as_str()
}
}
impl std::fmt::Display for DataStreamId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
impl std::str::FromStr for DataStreamId {
type Err = std::convert::Infallible;
fn from_str(id: &str) -> std::result::Result<Self, Self::Err> {
Ok(Self::from(id))
}
}
/// Decoded `RequestData` packet (type tag 2).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct RequestData {
    /// Stream the request refers to.
    pub stream_id: DataStreamId,
    /// Window value; sent as a `u32` on the wire, so writers reject values above `u32::MAX`.
    pub window: usize,
}

impl RequestData {
    /// Builds a request for `stream_id` with the given `window`.
    #[must_use]
    pub fn new(stream_id: impl Into<DataStreamId>, window: usize) -> Self {
        let stream_id: DataStreamId = stream_id.into();
        Self { window, stream_id }
    }
}
/// Decoded `Data` packet (type tag 3): a payload addressed to one stream.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Data {
    /// Stream the payload belongs to.
    pub stream_id: DataStreamId,
    /// Raw payload bytes.
    pub data: Vec<u8>,
}

impl Data {
    /// Builds a data packet for `stream_id` carrying `data`.
    #[must_use]
    pub fn new(stream_id: impl Into<DataStreamId>, data: impl Into<Vec<u8>>) -> Self {
        let stream_id: DataStreamId = stream_id.into();
        let data: Vec<u8> = data.into();
        Data { stream_id, data }
    }
}
/// A decoded protocol packet that owns its contents.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Packet {
    /// Handshake packet (type tag 1) carrying the sender's protocol version.
    Version(u32),
    /// `RequestData` packet (type tag 2).
    RequestData(RequestData),
    /// `Data` packet (type tag 3).
    Data(Data),
}
impl Packet {
    /// Borrows this packet as a [`PacketRef`] without cloning the stream ID or
    /// payload, e.g. to hand a decoded packet back to a writer.
    #[must_use]
    pub fn as_ref(&self) -> PacketRef<'_> {
        match *self {
            Self::Version(version) => PacketRef::Version(version),
            Self::RequestData(ref request) => PacketRef::RequestData {
                stream_id: &request.stream_id,
                window: request.window,
            },
            Self::Data(ref payload) => PacketRef::Data {
                stream_id: &payload.stream_id,
                data: payload.data.as_slice(),
            },
        }
    }
}
/// Borrowed view of a [`Packet`]; this is what the writers accept, so callers do
/// not need to clone stream IDs or payloads in order to send them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PacketRef<'a> {
    /// Handshake packet carrying a protocol version.
    Version(u32),
    /// Borrowed form of [`RequestData`]; `window` must fit in a `u32` to be written.
    RequestData {
        stream_id: &'a DataStreamId,
        window: usize,
    },
    /// Borrowed form of [`Data`]; `data` is written after the stream ID.
    Data {
        stream_id: &'a DataStreamId,
        data: &'a [u8],
    },
}
fn as_u32(value: usize, what: &str) -> Result<u32> {
u32::try_from(value).map_err(|_| Error::msg(format!("{what} does not fit in u32: {value}")))
}
fn packet_len(parts: &[usize]) -> Result<u32> {
let len = parts.iter().try_fold(1usize, |acc, part| {
acc.checked_add(*part)
.ok_or_else(|| Error::msg("packet length overflow"))
})?;
as_u32(len, "packet length")
}
fn read_raw_packet<R: Read>(reader: &mut R, max_packet_len: usize) -> Result<Option<Vec<u8>>> {
let mut len = [0u8; 4];
match reader.read(&mut len[..1]) {
Ok(0) => return Ok(None), Ok(1) => reader
.read_exact(&mut len[1..])
.map_err(|e| Error::other(e, "reading DataStream header"))?,
Ok(_) => unreachable!("one byte buffer cannot read more than one byte"),
Err(e) => return Err(Error::other(e, "reading DataStream first byte")),
}
let len = u32::from_le_bytes(len) as usize;
if len == 0 {
return Err(Error::msg("packet length must include a packet type byte"));
}
if len > max_packet_len {
return Err(Error::msg(format!(
"packet length {len} exceeds limit {max_packet_len}"
)));
}
let mut packet = vec![0u8; len];
reader
.read_exact(&mut packet)
.map_err(|e| Error::other(e, "reading DataStream packet"))?;
Ok(Some(packet))
}
fn parse_packet(packet: &[u8]) -> Result<Packet> {
if packet.is_empty() {
return Err(Error::msg("packet payload is empty"));
}
let ret = match packet[0] {
PACKET_VERSION => parse_version(packet).map(Packet::Version),
PACKET_REQUEST_DATA => parse_request_data(packet).map(Packet::RequestData),
PACKET_DATA => parse_data(packet).map(Packet::Data),
other => Err(Error::msg(format!("unsupported packet type {other}"))),
}?;
log::trace!("Got packet: {ret:?}");
Ok(ret)
}
fn parse_version(packet: &[u8]) -> Result<u32> {
if packet.len() != 5 {
return Err(Error::msg(format!(
"version packet has length {}, want 5",
packet.len()
)));
}
Ok(u32::from_le_bytes(packet[1..5].try_into()?))
}
fn parse_request_data(packet: &[u8]) -> Result<RequestData> {
if packet.len() < 6 {
return Err(Error::msg(format!(
"RequestData packet has length {}, want at least 6",
packet.len()
)));
}
Ok(RequestData {
window: u32::from_le_bytes(packet[1..5].try_into()?) as usize,
stream_id: String::from_utf8(packet[5..].to_vec())?.into(),
})
}
/// Decodes a `Data` body.
///
/// Layout: type byte, little-endian `u32` stream-ID length `n`, `n` bytes of
/// UTF-8 stream ID, then the payload (all remaining bytes). Both the stream ID
/// and the payload may be empty; the writers emit such frames for
/// `write_data` with an empty ID and empty data.
///
/// # Errors
/// Fails if the body is shorter than the 5-byte fixed header, if `n` runs past
/// the end of the body, or if the stream ID is not valid UTF-8.
fn parse_data(packet: &[u8]) -> Result<Data> {
    // Only the fixed header (type byte + u32 length) is mandatory. Demanding a
    // 6th byte would reject an empty-ID/empty-payload frame the writers produce,
    // while still accepting an empty ID whenever the payload is non-empty.
    const HEADER_LEN: usize = 5;
    if packet.len() < HEADER_LEN {
        return Err(Error::msg(format!(
            "Data packet has length {}, want at least {HEADER_LEN}",
            packet.len()
        )));
    }
    let stream_id_len = u32::from_le_bytes(packet[1..HEADER_LEN].try_into()?) as usize;
    // Guard the addition: on 32-bit targets a hostile length could overflow usize.
    let data_start = HEADER_LEN
        .checked_add(stream_id_len)
        .ok_or_else(|| Error::msg("Data packet stream ID length overflow"))?;
    let Some(stream_id) = packet.get(HEADER_LEN..data_start) else {
        return Err(Error::msg(format!(
            "Data packet stream ID length {stream_id_len} exceeds packet length {}",
            packet.len()
        )));
    };
    Ok(Data {
        stream_id: String::from_utf8(stream_id.to_vec())?.into(),
        data: packet[data_start..].to_vec(),
    })
}
fn validate_version(packet: Packet) -> Result<()> {
match packet {
Packet::Version(PROTOCOL_VERSION) => Ok(()),
Packet::Version(version) => Err(Error::msg(format!(
"unsupported protocol version {version}"
))),
other => Err(Error::msg(format!(
"expected Version packet, got {other:?}"
))),
}
}
/// Inspects the front of `buf` and reports the size of the first complete frame.
///
/// Returns `Ok(Some(n))` where `n` counts the 4-byte length prefix plus the body
/// when the whole frame is buffered, and `Ok(None)` when more bytes are needed.
///
/// # Errors
/// Fails as soon as the length prefix is available if it is zero or exceeds
/// `max_packet_len`, without waiting for the body to arrive.
fn buffered_packet_len(buf: &[u8], max_packet_len: usize) -> Result<Option<usize>> {
    let Some(prefix) = buf.get(..4) else {
        return Ok(None);
    };
    let len = u32::from_le_bytes(prefix.try_into()?) as usize;
    if len == 0 {
        return Err(Error::msg("packet length must include a packet type byte"));
    }
    if len > max_packet_len {
        return Err(Error::msg(format!(
            "packet length {len} exceeds limit {max_packet_len}"
        )));
    }
    let frame_len = 4usize
        .checked_add(len)
        .ok_or_else(|| Error::msg("packet length overflow"))?;
    Ok((buf.len() >= frame_len).then_some(frame_len))
}
/// Push-based packet decoder for callers that receive bytes in arbitrary chunks
/// rather than through a [`Read`] implementation.
///
/// Feed bytes with [`BytesReader::push_bytes`] and pull packets with
/// [`BytesReader::read_packet`], which returns `Ok(None)` until a whole frame
/// is buffered.
pub struct BytesReader {
    /// Buffered bytes; only `buf[start..]` is still unconsumed.
    buf: Vec<u8>,
    /// Offset of the first unconsumed byte in `buf`.
    ///
    /// Decoding a frame only advances this cursor. The consumed prefix is
    /// dropped lazily (on the next push, or all at once when everything has
    /// been consumed), so decoding many packets from one large chunk moves each
    /// byte at most once instead of shifting the remaining tail per packet.
    start: usize,
    /// Largest accepted value of a frame's length field.
    max_packet_len: usize,
}

impl BytesReader {
    /// Creates an empty reader accepting frames up to [`DEFAULT_MAX_PACKET_LEN`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            buf: Vec::new(),
            start: 0,
            max_packet_len: DEFAULT_MAX_PACKET_LEN,
        }
    }

    /// Replaces the largest accepted frame length field.
    #[must_use]
    pub fn with_max_packet_len(mut self, max_packet_len: usize) -> Self {
        self.max_packet_len = max_packet_len;
        self
    }

    /// Appends received bytes to the internal buffer.
    pub fn push_bytes(&mut self, bytes: &[u8]) {
        // Discard the already-decoded prefix before growing so consumed bytes
        // are not kept alive (or copied again) across pushes.
        if self.start > 0 {
            self.buf.drain(..self.start);
            self.start = 0;
        }
        self.buf.extend_from_slice(bytes);
    }

    /// Number of buffered bytes not yet consumed by a decoded packet.
    #[must_use]
    pub fn buffered_len(&self) -> usize {
        self.buf.len() - self.start
    }

    /// Returns `true` when no unconsumed bytes are buffered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.buffered_len() == 0
    }

    /// Discards every buffered byte, including any partial frame.
    pub fn clear(&mut self) {
        self.buf.clear();
        self.start = 0;
    }

    /// Decodes the next packet if its whole frame is buffered.
    ///
    /// # Errors
    /// Fails if the next frame's length field is zero or above the limit (the
    /// bytes stay buffered), or if the frame body does not parse (the frame is
    /// consumed, so following frames remain decodable).
    pub fn read_packet(&mut self) -> Result<Option<Packet>> {
        let pending = &self.buf[self.start..];
        let Some(full_len) = buffered_packet_len(pending, self.max_packet_len)? else {
            return Ok(None);
        };
        let ret = parse_packet(&pending[4..full_len]);
        self.start += full_len;
        if self.start == self.buf.len() {
            // Everything consumed: reset without moving any bytes.
            self.clear();
        }
        ret.map(Some)
    }

    /// Decodes the handshake packet.
    ///
    /// Returns `Ok(false)` while no complete frame is buffered and `Ok(true)`
    /// once a `Version` packet carrying [`PROTOCOL_VERSION`] was consumed.
    ///
    /// # Errors
    /// Fails on any [`BytesReader::read_packet`] error, on a version mismatch,
    /// or when the next packet is not a `Version` packet.
    pub fn read_version(&mut self) -> Result<bool> {
        let Some(packet) = self.read_packet()? else {
            return Ok(false);
        };
        validate_version(packet)?;
        Ok(true)
    }
}

impl Default for BytesReader {
    fn default() -> Self {
        Self::new()
    }
}
pub struct SyncReader<R> {
reader: R,
max_packet_len: usize,
}
impl<R: Read> SyncReader<R> {
#[must_use]
pub fn new(reader: R) -> Self {
Self {
reader,
max_packet_len: DEFAULT_MAX_PACKET_LEN,
}
}
#[must_use]
pub fn with_max_packet_len(mut self, max_packet_len: usize) -> Self {
self.max_packet_len = max_packet_len;
self
}
#[must_use]
pub fn into_inner(self) -> R {
self.reader
}
pub fn read_packet(&mut self) -> Result<Option<Packet>> {
read_raw_packet(&mut self.reader, self.max_packet_len)?
.as_deref()
.map(parse_packet)
.transpose()
}
pub fn read_version(&mut self) -> Result<bool> {
let Some(packet) = self.read_packet()? else {
return Ok(false);
};
validate_version(packet)?;
Ok(true)
}
}
/// Blocking packet writer over any [`Write`] sink.
pub struct SyncWriter<W> {
    writer: W,
}

impl<W: Write> SyncWriter<W> {
    /// Wraps `writer`.
    #[must_use]
    pub fn new(writer: W) -> Self {
        Self { writer }
    }

    /// Unwraps this writer, returning the underlying sink.
    #[must_use]
    pub fn into_inner(self) -> W {
        self.writer
    }

    /// Encodes `packet` as one length-prefixed frame, writes it, and flushes.
    ///
    /// The length prefix, type byte, fixed field and stream ID are assembled in
    /// memory and written with a single `write_all`; a `Data` payload follows in
    /// a second call so it is never copied. Emitting the header as several tiny
    /// writes costs a syscall each on an unbuffered sink and, on TCP, sends
    /// several small segments per packet (prone to Nagle/delayed-ACK stalls).
    ///
    /// # Errors
    /// Fails before writing anything if the frame length, the `RequestData`
    /// window or the stream-ID length does not fit in a `u32`; otherwise
    /// propagates I/O errors from writing or flushing.
    pub fn write_packet(&mut self, packet: PacketRef<'_>) -> Result<()> {
        let (header, payload) = encode_frame_header(packet)?;
        self.writer.write_all(&header)?;
        if !payload.is_empty() {
            self.writer.write_all(payload)?;
        }
        self.writer.flush()?;
        Ok(())
    }

    /// Writes the handshake packet carrying [`PROTOCOL_VERSION`].
    pub fn write_version(&mut self) -> Result<()> {
        self.write_packet(PacketRef::Version(PROTOCOL_VERSION))
    }

    /// Writes a `RequestData` packet.
    pub fn write_request_data(&mut self, stream_id: &DataStreamId, window: usize) -> Result<()> {
        self.write_packet(PacketRef::RequestData { stream_id, window })
    }

    /// Writes a `Data` packet.
    pub fn write_data(&mut self, stream_id: &DataStreamId, data: &[u8]) -> Result<()> {
        self.write_packet(PacketRef::Data { stream_id, data })
    }
}

/// Builds everything in `packet`'s frame except a `Data` payload.
///
/// The prefix is: `u32` LE body length, type byte, a `u32` LE field (version,
/// window, or stream-ID length), then the stream-ID bytes (none for `Version`).
/// Returns it together with the bytes that must follow (the `Data` payload, or
/// an empty slice). All fallible conversions happen here, before any I/O.
fn encode_frame_header<'a>(packet: PacketRef<'a>) -> Result<(Vec<u8>, &'a [u8])> {
    let no_payload: &'a [u8] = &[];
    let (len, tag, field, stream_id, payload) = match packet {
        PacketRef::Version(version) => {
            let len = packet_len(&[4])?;
            (len, PACKET_VERSION, version, no_payload, no_payload)
        }
        PacketRef::RequestData { stream_id, window } => {
            let stream_id = stream_id.as_bytes();
            let len = packet_len(&[4, stream_id.len()])?;
            let window = as_u32(window, "RequestData window")?;
            (len, PACKET_REQUEST_DATA, window, stream_id, no_payload)
        }
        PacketRef::Data { stream_id, data } => {
            let stream_id = stream_id.as_bytes();
            let len = packet_len(&[4, stream_id.len(), data.len()])?;
            let stream_id_len = as_u32(stream_id.len(), "stream ID length")?;
            (len, PACKET_DATA, stream_id_len, stream_id, data)
        }
    };
    // 4-byte length + 1-byte tag + 4-byte field, then the stream ID.
    let mut header = Vec::with_capacity(9 + stream_id.len());
    header.extend_from_slice(&len.to_le_bytes());
    header.push(tag);
    header.extend_from_slice(&field.to_le_bytes());
    header.extend_from_slice(stream_id);
    Ok((header, payload))
}
#[cfg(feature = "async")]
pub mod asynchronous {
    //! Tokio counterparts of [`SyncReader`](super::SyncReader) and
    //! [`SyncWriter`](super::SyncWriter), using the same wire format.

    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    use super::{
        DEFAULT_MAX_PACKET_LEN, DataStreamId, Error, PACKET_DATA, PACKET_REQUEST_DATA,
        PACKET_VERSION, PROTOCOL_VERSION, Packet, PacketRef, Result, as_u32, packet_len,
        parse_packet, validate_version,
    };

    /// Reads one length-prefixed packet body from `reader`.
    ///
    /// Returns `Ok(None)` when the first read of the length prefix returns 0
    /// bytes, i.e. the peer closed cleanly between packets.
    ///
    /// # Errors
    /// Fails on I/O errors (including EOF inside a header or body), on a zero
    /// length, or on a length above `max_packet_len`. I/O errors carry context
    /// naming the part of the frame being read.
    async fn read_raw_packet<R: AsyncRead + Unpin>(
        reader: &mut R,
        max_packet_len: usize,
    ) -> Result<Option<Vec<u8>>> {
        let mut len = [0u8; 4];
        // Read the first byte on its own so a clean EOF between packets can be
        // told apart from a truncated header.
        match reader.read(&mut len[..1]).await {
            Ok(0) => return Ok(None),
            Ok(1) => {
                reader
                    .read_exact(&mut len[1..])
                    .await
                    .map_err(|e| Error::other(e, "reading DataStream header"))?;
            }
            Ok(_) => unreachable!("one byte buffer cannot read more than one byte"),
            Err(e) => return Err(Error::other(e, "reading DataStream first byte")),
        }
        let len = u32::from_le_bytes(len) as usize;
        if len == 0 {
            return Err(Error::msg("packet length must include a packet type byte"));
        }
        if len > max_packet_len {
            return Err(Error::msg(format!(
                "packet length {len} exceeds limit {max_packet_len}"
            )));
        }
        let mut packet = vec![0u8; len];
        // Same context as the blocking reader so a truncated body is reported
        // identically whichever reader is used.
        reader
            .read_exact(&mut packet)
            .await
            .map_err(|e| Error::other(e, "reading DataStream packet"))?;
        Ok(Some(packet))
    }

    /// Async packet reader over any tokio [`AsyncRead`] source.
    pub struct AsyncReader<R> {
        reader: R,
        max_packet_len: usize,
    }

    impl<R: AsyncRead + Unpin> AsyncReader<R> {
        /// Wraps `reader`, accepting frames up to [`DEFAULT_MAX_PACKET_LEN`].
        #[must_use]
        pub fn new(reader: R) -> Self {
            Self {
                reader,
                max_packet_len: DEFAULT_MAX_PACKET_LEN,
            }
        }

        /// Replaces the largest accepted frame length field.
        #[must_use]
        pub fn with_max_packet_len(mut self, max_packet_len: usize) -> Self {
            self.max_packet_len = max_packet_len;
            self
        }

        /// Unwraps this reader, returning the underlying source.
        #[must_use]
        pub fn into_inner(self) -> R {
            self.reader
        }

        /// Reads and decodes the next packet; `Ok(None)` means clean EOF
        /// between packets.
        pub async fn read_packet(&mut self) -> Result<Option<Packet>> {
            read_raw_packet(&mut self.reader, self.max_packet_len)
                .await?
                .as_deref()
                .map(parse_packet)
                .transpose()
        }

        /// Reads the handshake packet: `Ok(false)` on clean EOF, `Ok(true)` for
        /// a `Version` packet carrying [`PROTOCOL_VERSION`], an error otherwise.
        pub async fn read_version(&mut self) -> Result<bool> {
            let Some(packet) = self.read_packet().await? else {
                return Ok(false);
            };
            validate_version(packet)?;
            Ok(true)
        }
    }

    /// Async packet writer over any tokio [`AsyncWrite`] sink.
    pub struct AsyncWriter<W> {
        writer: W,
    }

    impl<W: AsyncWrite + Unpin> AsyncWriter<W> {
        /// Wraps `writer`.
        #[must_use]
        pub fn new(writer: W) -> Self {
            Self { writer }
        }

        /// Unwraps this writer, returning the underlying sink.
        #[must_use]
        pub fn into_inner(self) -> W {
            self.writer
        }

        /// Encodes `packet` as one frame, writes it, and flushes.
        ///
        /// The header (length, type, fixed field, stream ID) goes out in a
        /// single `write_all`; a `Data` payload follows in a second call so it
        /// is never copied. Several tiny writes per packet would cost extra
        /// syscalls and, on TCP, small segments prone to Nagle/delayed-ACK stalls.
        ///
        /// # Errors
        /// Fails before writing anything if a length or window does not fit in
        /// a `u32`; otherwise propagates I/O errors from writing or flushing.
        pub async fn write_packet(&mut self, packet: PacketRef<'_>) -> Result<()> {
            let (header, payload) = encode_frame_header(packet)?;
            self.writer.write_all(&header).await?;
            if !payload.is_empty() {
                self.writer.write_all(payload).await?;
            }
            self.writer.flush().await?;
            Ok(())
        }

        /// Writes the handshake packet carrying [`PROTOCOL_VERSION`].
        pub async fn write_version(&mut self) -> Result<()> {
            self.write_packet(PacketRef::Version(PROTOCOL_VERSION))
                .await
        }

        /// Writes a `RequestData` packet.
        pub async fn write_request_data(
            &mut self,
            stream_id: &DataStreamId,
            window: usize,
        ) -> Result<()> {
            self.write_packet(PacketRef::RequestData { stream_id, window })
                .await
        }

        /// Writes a `Data` packet.
        pub async fn write_data(&mut self, stream_id: &DataStreamId, data: &[u8]) -> Result<()> {
            self.write_packet(PacketRef::Data { stream_id, data }).await
        }
    }

    /// Builds everything in `packet`'s frame except a `Data` payload: `u32` LE
    /// body length, type byte, a `u32` LE field (version, window, or stream-ID
    /// length), then the stream-ID bytes. Returns it with the bytes that must
    /// follow (the `Data` payload or an empty slice); all fallible conversions
    /// happen here, before any I/O.
    fn encode_frame_header<'a>(packet: PacketRef<'a>) -> Result<(Vec<u8>, &'a [u8])> {
        let no_payload: &'a [u8] = &[];
        let (len, tag, field, stream_id, payload) = match packet {
            PacketRef::Version(version) => {
                let len = packet_len(&[4])?;
                (len, PACKET_VERSION, version, no_payload, no_payload)
            }
            PacketRef::RequestData { stream_id, window } => {
                let stream_id = stream_id.as_bytes();
                let len = packet_len(&[4, stream_id.len()])?;
                let window = as_u32(window, "RequestData window")?;
                (len, PACKET_REQUEST_DATA, window, stream_id, no_payload)
            }
            PacketRef::Data { stream_id, data } => {
                let stream_id = stream_id.as_bytes();
                let len = packet_len(&[4, stream_id.len(), data.len()])?;
                let stream_id_len = as_u32(stream_id.len(), "stream ID length")?;
                (len, PACKET_DATA, stream_id_len, stream_id, data)
            }
        };
        // 4-byte length + 1-byte tag + 4-byte field, then the stream ID.
        let mut header = Vec::with_capacity(9 + stream_id.len());
        header.extend_from_slice(&len.to_le_bytes());
        header.push(tag);
        header.extend_from_slice(&field.to_le_bytes());
        header.extend_from_slice(stream_id);
        Ok((header, payload))
    }
}
#[cfg(feature = "async")]
pub use asynchronous::{AsyncReader, AsyncWriter};
#[cfg(test)]
mod tests {
    use std::io::Cursor;
    use super::*;

    /// Writes a Version, a RequestData and a Data packet into a byte buffer, reads
    /// them back in order, and expects `None` once the buffer is exhausted.
    #[test]
    fn sync_roundtrip() -> Result<()> {
        let stream_id = DataStreamId::new("rtl-sdr");
        let mut bytes = Vec::new();
        // Inner scope ends the writer's mutable borrow of `bytes` before it is read.
        {
            let mut writer = SyncWriter::new(&mut bytes);
            writer.write_version()?;
            writer.write_request_data(&stream_id, 1234)?;
            writer.write_data(&stream_id, &[1, 2, 3, 4])?;
        }
        let mut reader = SyncReader::new(Cursor::new(bytes));
        assert!(reader.read_version()?);
        assert_eq!(
            reader.read_packet()?,
            Some(Packet::RequestData(RequestData::new("rtl-sdr", 1234)))
        );
        assert_eq!(
            reader.read_packet()?,
            Some(Packet::Data(Data::new("rtl-sdr", vec![1, 2, 3, 4])))
        );
        assert_eq!(reader.read_packet()?, None);
        Ok(())
    }

    /// A stream whose first packet is not a Version packet must fail the handshake.
    #[test]
    fn sync_rejects_non_version_handshake() -> Result<()> {
        let mut bytes = Vec::new();
        SyncWriter::new(&mut bytes).write_request_data(&DataStreamId::new("rtl-sdr"), 1234)?;
        let mut reader = SyncReader::new(Cursor::new(bytes));
        let err = reader.read_version().unwrap_err().to_string();
        assert!(err.contains("expected Version packet"), "{err}");
        Ok(())
    }

    /// Feeds the encoded stream to `BytesReader` in pieces that split frames.
    ///
    /// Offsets: the Version frame is 9 bytes (4-byte length prefix + type byte +
    /// 4-byte version), so `bytes[..3]` is a partial length prefix, `bytes[3..9]`
    /// completes the Version frame, and `bytes[9..12]` is a partial length prefix
    /// of the RequestData frame that follows.
    #[test]
    fn bytes_reader_handles_partial_frames() -> Result<()> {
        let stream_id = DataStreamId::new("rtl-sdr");
        let mut bytes = Vec::new();
        {
            let mut writer = SyncWriter::new(&mut bytes);
            writer.write_version()?;
            writer.write_request_data(&stream_id, 1234)?;
            writer.write_data(&stream_id, &[1, 2, 3, 4])?;
        }
        let mut reader = BytesReader::new();
        // Incomplete length prefix: no packet yet, bytes stay buffered.
        reader.push_bytes(&bytes[..3]);
        assert!(!reader.read_version()?);
        assert_eq!(reader.buffered_len(), 3);
        // Rest of the Version frame: handshake completes and consumes everything.
        reader.push_bytes(&bytes[3..9]);
        assert!(reader.read_version()?);
        assert!(reader.is_empty());
        // Partial length prefix of the RequestData frame.
        reader.push_bytes(&bytes[9..12]);
        assert_eq!(reader.read_packet()?, None);
        // Remainder: both RequestData and Data frames become decodable.
        reader.push_bytes(&bytes[12..]);
        assert_eq!(
            reader.read_packet()?,
            Some(Packet::RequestData(RequestData::new("rtl-sdr", 1234)))
        );
        assert_eq!(
            reader.read_packet()?,
            Some(Packet::Data(Data::new("rtl-sdr", vec![1, 2, 3, 4])))
        );
        assert_eq!(reader.read_packet()?, None);
        assert!(reader.is_empty());
        Ok(())
    }

    /// Round-trips each packet through an in-memory tokio duplex pipe, reading
    /// each one right after it is written.
    #[cfg(feature = "async")]
    #[tokio::test]
    async fn async_roundtrip() -> Result<()> {
        let stream_id = DataStreamId::new("rtl-sdr");
        let (left, right) = tokio::io::duplex(1024);
        let mut writer = AsyncWriter::new(left);
        let mut reader = AsyncReader::new(right);
        writer.write_version().await?;
        assert!(reader.read_version().await?);
        writer.write_request_data(&stream_id, 1234).await?;
        assert_eq!(
            reader.read_packet().await?,
            Some(Packet::RequestData(RequestData::new("rtl-sdr", 1234)))
        );
        writer.write_data(&stream_id, &[1, 2, 3, 4]).await?;
        assert_eq!(
            reader.read_packet().await?,
            Some(Packet::Data(Data::new("rtl-sdr", vec![1, 2, 3, 4])))
        );
        Ok(())
    }
}