#![cfg(not(loom))]
use std::{
io,
sync::{
Arc,
atomic::{AtomicU8, Ordering},
},
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio_util::codec::{Decoder, Encoder, Framed};
use wireframe::{
app::{Envelope, Packet, WireframeApp},
codec::FrameCodec,
correlation::CorrelatableFrame,
serializer::{BincodeSerializer, Serializer},
};
/// Frame type used by the custom codec in these tests: a one-byte tag plus
/// an opaque payload.
#[derive(Clone, Debug)]
struct TaggedFrame {
/// One-byte tag; `TaggedAdapter` writes it as the first header byte and
/// `TaggedFrameCodec::correlation_id` exposes it as the correlation id.
tag: u8,
/// Payload bytes; `TaggedAdapter` precedes them with a single length byte
/// on the wire, so encoding rejects payloads that do not fit in a `u8`.
payload: Vec<u8>,
}
/// Stateful codec that stamps every outbound frame with an incrementing tag.
#[derive(Debug)]
struct TaggedFrameCodec {
    /// Largest payload (in bytes) the adapters built from this codec accept.
    max_frame_length: usize,
    /// Number of tags handed out so far, modulo 256.
    counter: AtomicU8,
}

impl TaggedFrameCodec {
    /// Builds a codec with the given payload limit and a tag counter at zero.
    fn new(max_frame_length: usize) -> Self {
        let counter = AtomicU8::new(0);
        Self {
            max_frame_length,
            counter,
        }
    }

    /// Advances the counter and returns the new value.
    ///
    /// The first call yields `1`; after `255` the sequence wraps to `0`.
    fn next_tag(&self) -> u8 {
        // `fetch_add` returns the value *before* the increment (and wraps on
        // overflow), so add one again to report the post-increment tag.
        let previous = self.counter.fetch_add(1, Ordering::SeqCst);
        previous.wrapping_add(1)
    }
}

impl Clone for TaggedFrameCodec {
    /// Copies the payload limit but starts the clone with a fresh counter,
    /// so tag sequences are never shared between clones.
    fn clone(&self) -> Self {
        Self::new(self.max_frame_length)
    }
}
/// Encoder/decoder for the tagged wire format; holds only the payload limit.
#[derive(Clone, Debug)]
struct TaggedAdapter {
    /// Largest payload (in bytes) this adapter will encode or decode.
    max_frame_length: usize,
}

impl TaggedAdapter {
    /// Creates an adapter enforcing `max_frame_length` on payload sizes.
    fn new(max_frame_length: usize) -> Self {
        TaggedAdapter { max_frame_length }
    }
}
impl Decoder for TaggedAdapter {
    type Item = TaggedFrame;
    type Error = io::Error;

    /// Decodes one `[tag: u8][len: u8][payload: len bytes]` frame from `src`.
    ///
    /// Returns `Ok(None)` while the header or payload is still incomplete,
    /// leaving `src` untouched apart from a capacity reservation. On success
    /// the whole frame is removed from the front of `src`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` when the declared payload length exceeds
    /// `max_frame_length`.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        const HEADER_LEN: usize = 2;

        if src.len() < HEADER_LEN {
            return Ok(None);
        }

        // Peek at the header through a slice cursor so `src` is not consumed
        // until the complete frame is known to be available.
        let mut header = src.as_ref();
        let tag = header.get_u8();
        let payload_len = usize::from(header.get_u8());
        if payload_len > self.max_frame_length {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "payload too large",
            ));
        }

        let frame_len = HEADER_LEN + payload_len;
        if src.len() < frame_len {
            // Pre-size the buffer for the rest of the frame so later reads
            // do not have to grow it piecemeal.
            src.reserve(frame_len - src.len());
            return Ok(None);
        }

        let mut frame_bytes = src.split_to(frame_len);
        frame_bytes.advance(HEADER_LEN);
        let payload = frame_bytes.to_vec();
        Ok(Some(TaggedFrame { tag, payload }))
    }
}
impl Encoder<TaggedFrame> for TaggedAdapter {
type Error = io::Error;
fn encode(&mut self, item: TaggedFrame, dst: &mut BytesMut) -> Result<(), Self::Error> {
if item.payload.len() > self.max_frame_length {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"payload too large",
));
}
let payload_len = u8::try_from(item.payload.len())
.map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "payload too long"))?;
dst.reserve(2 + item.payload.len());
dst.put_u8(item.tag);
dst.put_u8(payload_len);
dst.extend_from_slice(&item.payload);
Ok(())
}
}
impl FrameCodec for TaggedFrameCodec {
    type Frame = TaggedFrame;
    type Decoder = TaggedAdapter;
    type Encoder = TaggedAdapter;

    /// Builds a decoder that enforces this codec's payload limit.
    fn decoder(&self) -> Self::Decoder {
        let limit = self.max_frame_length;
        TaggedAdapter::new(limit)
    }

    /// Builds an encoder that enforces this codec's payload limit.
    fn encoder(&self) -> Self::Encoder {
        let limit = self.max_frame_length;
        TaggedAdapter::new(limit)
    }

    /// Borrows the payload bytes carried by `frame`.
    fn frame_payload(frame: &Self::Frame) -> &[u8] {
        &frame.payload
    }

    /// Wraps `payload` in a frame stamped with the next tag from the counter.
    fn wrap_payload(&self, payload: Bytes) -> Self::Frame {
        let tag = self.next_tag();
        let payload = payload.to_vec();
        TaggedFrame { tag, payload }
    }

    /// Reports the frame's tag, widened to `u64`, as its correlation id.
    fn correlation_id(frame: &Self::Frame) -> Option<u64> {
        Some(frame.tag.into())
    }

    /// Returns the configured payload limit.
    fn max_frame_length(&self) -> usize {
        self.max_frame_length
    }
}
/// Sends one tagged request through an app configured with
/// `TaggedFrameCodec` and checks that exactly one response frame comes back
/// whose envelope carries correlation id 7 (the request frame's tag) and the
/// original `ping` payload.
#[tokio::test]
async fn custom_codec_round_trips_frames() {
    let codec = TaggedFrameCodec::new(64);
    let app = WireframeApp::<BincodeSerializer, (), Envelope>::new()
        .expect("build app")
        .with_codec(codec)
        .route(1, Arc::new(|_: &Envelope| Box::pin(async {})))
        .expect("route configured");

    let (mut client, server) = tokio::io::duplex(256);
    let server_task = tokio::spawn(async move {
        let outcome = app.handle_connection_result(server).await;
        outcome.expect("server should exit cleanly");
    });

    // Build the request frame by hand: a serialized envelope with no
    // correlation id of its own, wrapped in a frame tagged 7.
    let envelope = Envelope::new(1, None, b"ping".to_vec());
    let payload = BincodeSerializer
        .serialize(&envelope)
        .expect("serialize request");
    let request_frame = TaggedFrame { tag: 7, payload };
    let mut request_codec = TaggedAdapter::new(64);
    let mut wire = BytesMut::new();
    request_codec
        .encode(request_frame, &mut wire)
        .expect("encode request");

    // Shut down the write half before reading so `read_to_end` can observe
    // end-of-stream once the server side has finished.
    client.write_all(&wire).await.expect("write request");
    client.shutdown().await.expect("shutdown client");
    let mut raw_response = Vec::new();
    client
        .read_to_end(&mut raw_response)
        .await
        .expect("read response");
    server_task.await.expect("join server task");

    let mut response_codec = TaggedAdapter::new(64);
    let mut remaining = BytesMut::from(raw_response.as_slice());
    let decoded = response_codec
        .decode(&mut remaining)
        .expect("decode response");
    let response_frame = decoded.expect("response frame");
    assert!(remaining.is_empty(), "unexpected trailing bytes");

    let (reply, _) = BincodeSerializer
        .deserialize::<Envelope>(&response_frame.payload)
        .expect("deserialize response");
    assert_eq!(reply.correlation_id(), Some(7));
    let reply_payload = reply.into_parts().into_payload();
    assert_eq!(reply_payload, b"ping".to_vec());
}
/// Sends two requests over a single connection and checks that the response
/// frames are tagged 1 and then 2, i.e. that the codec's tag counter
/// advances for each frame it wraps rather than echoing the request tag.
#[tokio::test]
async fn stateful_codec_advances_tags_per_connection() {
    let codec = TaggedFrameCodec::new(64);
    let app = WireframeApp::<BincodeSerializer, (), Envelope>::new()
        .expect("build app")
        .with_codec(codec)
        .route(1, Arc::new(|_: &Envelope| Box::pin(async {})))
        .expect("route configured");

    let (client, server) = tokio::io::duplex(256);
    let server_task = tokio::spawn(async move {
        let outcome = app.handle_connection_result(server).await;
        outcome.expect("server should exit cleanly");
    });

    let mut connection = Framed::new(client, TaggedAdapter::new(64));
    for expected_tag in 1_u8..=2 {
        let envelope = Envelope::new(1, None, b"ping".to_vec());
        let payload = BincodeSerializer
            .serialize(&envelope)
            .expect("serialize request");

        // Every request goes out with tag 7; the assertion below shows the
        // response tag comes from the server-side counter instead.
        let outbound = TaggedFrame { tag: 7, payload };
        connection.send(outbound).await.expect("send request");

        let next_item = connection.next().await;
        let reply = next_item
            .expect("missing response")
            .expect("response frame");
        assert_eq!(reply.tag, expected_tag);
    }

    // Close the client side so the server task can finish before joining it.
    let mut stream = connection.into_inner();
    stream.shutdown().await.expect("shutdown client");
    server_task.await.expect("join server task");
}