use super::{Config, DecodeError, Decoder, Encoder, RawDecoder, RawFrame};
use crate::tagvalue::Message;
use crate::{Dictionary, GetConfig};
use bytes::{BufMut, Bytes, BytesMut};
use tokio_util::codec;
/// A [`tokio_util::codec::Decoder`] adapter around [`RawDecoder`].
///
/// Each decoded item is a [`RawFrame`] that owns its bytes as [`Bytes`].
#[derive(Debug)]
#[cfg_attr(doc_cfg, doc(cfg(feature = "utils-tokio")))]
pub struct TokioRawDecoder {
    /// The wrapped decoder that parses a single, complete frame.
    raw_decoder: RawDecoder,
}
impl TokioRawDecoder {
    /// Creates a new [`TokioRawDecoder`] wrapping a freshly constructed
    /// [`RawDecoder`].
    pub fn new() -> Self {
        let raw_decoder = RawDecoder::new();
        Self { raw_decoder }
    }
}
impl Default for TokioRawDecoder {
fn default() -> Self {
Self::new()
}
}
impl codec::Decoder for TokioRawDecoder {
    type Item = RawFrame<Bytes>;
    type Error = DecodeError;

    /// Attempts to pull one complete frame out of `src`.
    ///
    /// Returns `Ok(None)` while `src` does not yet hold a complete frame.
    /// Once it does, the frame's bytes are split off the front of `src`,
    /// frozen, and handed to the wrapped [`RawDecoder`].
    ///
    /// # Errors
    ///
    /// Propagates any [`DecodeError`] from the wrapped decoder. The frame's
    /// bytes have already been removed from `src` at that point.
    fn decode(
        &mut self,
        src: &mut BytesMut,
    ) -> Result<Option<Self::Item>, Self::Error> {
        let frame_len = match self.find_complete_message(src)? {
            Some(len) => len,
            None => return Ok(None),
        };

        let frame_bytes = src.split_to(frame_len).freeze();
        // A second handle to the same buffer: the parsed frame below borrows
        // `frame_bytes`, while the returned frame must own its data.
        let owned_bytes = frame_bytes.clone();

        let parsed = self.raw_decoder.decode(&frame_bytes[..])?;
        Ok(Some(RawFrame {
            data: owned_bytes,
            begin_string: parsed.begin_string,
            payload: parsed.payload,
        }))
    }
}
impl TokioRawDecoder {
fn find_complete_message(
&self,
src: &BytesMut,
) -> Result<Option<usize>, DecodeError> {
if src.len() < 10 {
return Ok(None); }
let separator = self.raw_decoder.config().separator;
let mut pos = 0;
let mut in_body_length = false;
let mut body_length = 0u32;
let mut body_length_end = 0;
while pos < src.len() {
if src[pos] == b'9' && pos + 1 < src.len() && src[pos + 1] == b'='
{
in_body_length = true;
pos += 2;
continue;
}
if in_body_length {
if src[pos] == separator {
body_length_end = pos + 1;
break;
} else if src[pos].is_ascii_digit() {
body_length = body_length
.saturating_mul(10)
.saturating_add((src[pos] - b'0') as u32);
}
}
pos += 1;
}
if body_length_end == 0 {
return Ok(None); }
let expected_total_length = body_length_end + body_length as usize + 7;
if src.len() >= expected_total_length {
Ok(Some(expected_total_length))
} else {
Ok(None)
}
}
}
impl GetConfig for TokioRawDecoder {
    type Config = Config;

    /// Borrows the configuration of the wrapped [`RawDecoder`].
    fn config(&self) -> &Self::Config {
        let inner = &self.raw_decoder;
        inner.config()
    }

    /// Mutably borrows the configuration of the wrapped [`RawDecoder`].
    fn config_mut(&mut self) -> &mut Self::Config {
        let inner = &mut self.raw_decoder;
        inner.config_mut()
    }
}
/// A [`tokio_util::codec::Decoder`] adapter around [`Decoder`].
///
/// Its codec item type is `Message<'static, Bytes>`.
///
/// NOTE(review): the conversion to an owned message used by `decode` is
/// currently a `todo!()`, so a successful decode panics.
#[derive(Debug)]
#[cfg_attr(doc_cfg, doc(cfg(feature = "utils-tokio")))]
pub struct TokioDecoder {
// The wrapped decoder, built from the dictionary passed to `new`.
decoder: Decoder,
}
impl TokioDecoder {
    /// Creates a new [`TokioDecoder`] whose wrapped [`Decoder`] is built from
    /// `dict`.
    pub fn new(dict: Dictionary) -> Self {
        let decoder = Decoder::new(dict);
        Self { decoder }
    }
}
impl codec::Decoder for TokioDecoder {
    type Item = Message<'static, Bytes>;
    type Error = DecodeError;

    /// Attempts to pull one complete message out of `src`.
    ///
    /// Returns `Ok(None)` while `src` does not yet hold a complete message.
    /// Once it does, the message's bytes are split off the front of `src`,
    /// frozen, decoded by the wrapped [`Decoder`], and converted into an
    /// owned message.
    ///
    /// # Errors
    ///
    /// Propagates any [`DecodeError`] from the wrapped decoder. The message's
    /// bytes have already been removed from `src` at that point.
    ///
    /// # Panics
    ///
    /// The owned-message conversion is currently a `todo!()`, so this panics
    /// whenever the wrapped decoder succeeds.
    fn decode(
        &mut self,
        src: &mut BytesMut,
    ) -> Result<Option<Self::Item>, Self::Error> {
        let message_len = match self.find_complete_message(src)? {
            Some(len) => len,
            None => return Ok(None),
        };

        let message_bytes = src.split_to(message_len).freeze();
        // A second handle to the same buffer, passed to the conversion below
        // while the decoded message still borrows `message_bytes`.
        let owned_bytes = message_bytes.clone();

        let borrowed_message = self.decoder.decode(&message_bytes[..])?;
        Ok(Some(Self::message_to_owned_static(
            borrowed_message,
            owned_bytes,
        )))
    }
}
impl TokioDecoder {
fn find_complete_message(
&self,
src: &BytesMut,
) -> Result<Option<usize>, DecodeError> {
if src.len() < 10 {
return Ok(None);
}
let separator = self.decoder.config().separator;
let mut pos = 0;
let mut in_body_length = false;
let mut body_length = 0u32;
let mut body_length_end = 0;
while pos < src.len() {
if src[pos] == b'9' && pos + 1 < src.len() && src[pos + 1] == b'='
{
in_body_length = true;
pos += 2;
continue;
}
if in_body_length {
if src[pos] == separator {
body_length_end = pos + 1;
break;
} else if src[pos].is_ascii_digit() {
body_length = body_length * 10 + (src[pos] - b'0') as u32;
}
}
pos += 1;
}
if body_length_end == 0 {
return Ok(None);
}
let expected_total_length = body_length_end + body_length as usize + 7;
if src.len() >= expected_total_length {
Ok(Some(expected_total_length))
} else {
Ok(None)
}
}
fn message_to_owned_static(
_message: Message<&[u8]>,
_data: Bytes,
) -> Message<'static, Bytes> {
todo!(
"Implement proper message ownership conversion without unsafe transmute"
)
}
}
/// A [`tokio_util::codec::Encoder`] adapter that holds an [`Encoder`].
///
/// The `codec::Encoder` implementations in this file accept `&[u8]`,
/// `Vec<u8>` and [`Bytes`] items and append them unchanged to the output
/// buffer.
#[derive(Debug)]
#[cfg_attr(doc_cfg, doc(cfg(feature = "utils-tokio")))]
pub struct TokioEncoder {
// The wrapped encoder; in this file it is only reached through `GetConfig`.
encoder: Encoder,
}
impl TokioEncoder {
    /// Creates a new [`TokioEncoder`] wrapping a freshly constructed
    /// [`Encoder`].
    pub fn new() -> Self {
        let encoder = Encoder::new();
        Self { encoder }
    }
}
impl Default for TokioEncoder {
fn default() -> Self {
Self::new()
}
}
impl codec::Encoder<&[u8]> for TokioEncoder {
    type Error = std::io::Error;

    /// Appends the bytes of `item` to `dst` unchanged.
    ///
    /// This implementation always returns `Ok(())`.
    fn encode(
        &mut self,
        item: &[u8],
        dst: &mut BytesMut,
    ) -> Result<(), Self::Error> {
        let needed = item.len();
        // Make room up front, then copy the bytes in.
        dst.reserve(needed);
        dst.put_slice(item);
        Ok(())
    }
}
impl codec::Encoder<Vec<u8>> for TokioEncoder {
    type Error = std::io::Error;

    /// Appends the bytes of `item` to `dst` unchanged.
    ///
    /// This implementation always returns `Ok(())`.
    fn encode(
        &mut self,
        item: Vec<u8>,
        dst: &mut BytesMut,
    ) -> Result<(), Self::Error> {
        let payload = item.as_slice();
        // Make room up front, then copy the bytes in.
        dst.reserve(payload.len());
        dst.put_slice(payload);
        Ok(())
    }
}
impl codec::Encoder<Bytes> for TokioEncoder {
    type Error = std::io::Error;

    /// Appends the bytes of `item` to `dst` unchanged.
    ///
    /// This implementation always returns `Ok(())`.
    fn encode(
        &mut self,
        item: Bytes,
        dst: &mut BytesMut,
    ) -> Result<(), Self::Error> {
        let payload: &[u8] = &item[..];
        // Make room up front, then copy the bytes in.
        dst.reserve(payload.len());
        dst.put_slice(payload);
        Ok(())
    }
}
impl GetConfig for TokioEncoder {
    type Config = Config;

    /// Borrows the configuration of the wrapped [`Encoder`].
    fn config(&self) -> &Self::Config {
        let inner = &self.encoder;
        inner.config()
    }

    /// Mutably borrows the configuration of the wrapped [`Encoder`].
    fn config_mut(&mut self) -> &mut Self::Config {
        let inner = &mut self.encoder;
        inner.config_mut()
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use tokio_util::codec::{Decoder, Encoder};

    // Feeding a `|`-delimited sample message to the raw decoder must not
    // produce an error.
    //
    // NOTE(review): only `is_ok()` is asserted, so `Ok(None)` also passes.
    // Whether a frame is actually decoded depends on the configured
    // separator matching `|` -- confirm and consider a stronger assertion.
    #[tokio::test]
    async fn tokio_raw_decoder_basic() {
        let mut raw_codec = TokioRawDecoder::new();
        let mut input = BytesMut::from(&b"8=FIX.4.4|9=42|35=0|49=A|56=B|34=12|52=20100304-07:59:30|10=185|"[..]);
        assert!(raw_codec.decode(&mut input).is_ok());
    }

    // Encoding a byte slice must succeed and leave exactly those bytes in
    // the output buffer.
    #[tokio::test]
    async fn tokio_encoder_basic() {
        let payload: &[u8] = b"8=FIX.4.4|9=42|35=0|10=185|";
        let mut enc = TokioEncoder::new();
        let mut output = BytesMut::new();
        assert!(enc.encode(payload, &mut output).is_ok());
        assert_eq!(output.as_ref(), payload);
    }
}