use crate::Result;
use bytes::{Buf, BytesMut};
use futures::{SinkExt, StreamExt};
use tokio_util::codec::{Decoder, Encoder};
/// Largest payload, in bytes, that the codec will encode or decode (8 MiB).
const MAX: usize = 8 << 20;
/// Stateless codec for length-prefixed UTF-8 text frames.
///
/// It carries no fields, so it is trivially `Copy` and `Default`; `new` is
/// kept for existing callers.
#[derive(Debug, Default, Clone, Copy)]
struct Obfs4Codec {}

impl Obfs4Codec {
    /// Creates a new codec instance (equivalent to `Obfs4Codec::default()`).
    fn new() -> Self {
        Self {}
    }
}
/// Reassembles length-prefixed text frames from a byte stream.
///
/// Wire format: a 4-byte little-endian `u32` payload length, followed by that
/// many bytes of UTF-8 text.
impl Decoder for Obfs4Codec {
    type Item = String;
    type Error = std::io::Error;

    /// Tries to extract one complete frame from the front of `src`.
    ///
    /// Returns `Ok(None)` in two cases: fewer than 4 header bytes are buffered,
    /// or the payload is still incomplete. In the second case `src` first
    /// reserves capacity for the rest of the frame. On success the whole frame
    /// (header plus payload) is removed from `src`, and the payload is returned
    /// as a `String`.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` in two cases:
    /// - the declared length exceeds `MAX` (nothing is consumed);
    /// - the payload is not valid UTF-8 (the frame is still consumed).
    fn decode(
        &mut self,
        src: &mut BytesMut,
    ) -> std::result::Result<Option<Self::Item>, Self::Error> {
        let buffered = src.len();
        if buffered < 4 {
            return Ok(None);
        }

        let length = u32::from_le_bytes([src[0], src[1], src[2], src[3]]) as usize;
        if length > MAX {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Frame of length {} is too large.", length),
            ));
        }

        // `length <= MAX` at this point, so the addition cannot overflow.
        let frame_end = 4 + length;
        if buffered < frame_end {
            // Reserve room for the remainder of this frame up front, so later
            // reads can fill it without growing the buffer again.
            src.reserve(frame_end - buffered);
            return Ok(None);
        }

        let payload = src[4..frame_end].to_vec();
        src.advance(frame_end);
        String::from_utf8(payload).map(Some).map_err(|err| {
            std::io::Error::new(std::io::ErrorKind::InvalidData, err.utf8_error())
        })
    }
}
/// Serializes strings as length-prefixed frames: a 4-byte little-endian `u32`
/// length header followed by the string's UTF-8 bytes.
impl Encoder<String> for Obfs4Codec {
    type Error = std::io::Error;

    /// Appends `item` to `dst` as a single frame.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if the string is longer than `MAX` bytes; nothing
    /// is written to `dst` in that case.
    fn encode(&mut self, item: String, dst: &mut BytesMut) -> std::result::Result<(), Self::Error> {
        let payload = item.as_bytes();
        let size = payload.len();
        if size <= MAX {
            // Lossless as long as MAX stays below u32::MAX.
            let header = (size as u32).to_le_bytes();
            dst.reserve(header.len() + size);
            dst.extend_from_slice(&header);
            dst.extend_from_slice(payload);
            Ok(())
        } else {
            Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("Frame of length {} is too large.", size),
            ))
        }
    }
}
/// Round-trips several frames through an echo server built on `Obfs4Codec`.
///
/// The frames are an ordinary ASCII message, an empty frame (a zero length
/// prefix) and a multi-byte UTF-8 message. They share one connection, so frame
/// boundaries are exercised as well as decoding of a single frame.
#[tokio::test]
async fn framing_flow() -> Result<()> {
    let (c, s) = tokio::io::duplex(16 * 1024);
    // Echo server: sends every successfully decoded frame straight back.
    tokio::spawn(async move {
        let codec = Obfs4Codec::new();
        let (mut sink, mut input) = codec.framed(s).split();
        while let Some(Ok(event)) = input.next().await {
            sink.send(event).await.expect("server response failed");
        }
    });

    let messages = ["Hello there", "", "héllo wörld ✓"];
    let client_codec = Obfs4Codec::new();
    let (mut c_sink, mut c_stream) = client_codec.framed(c).split();
    for &message in &messages {
        c_sink
            .send(message.into())
            .await
            .expect("client send failed");
        let m: String = c_stream
            .next()
            .await
            .expect("you were supposed to call me back!")
            .expect("an error occurred when you called back");
        assert_eq!(m, message);
    }
    Ok(())
}