// easyfix_session/io/input_stream.rs

use std::{
    io,
    pin::Pin,
    task::{ready, Context, Poll},
};

use bytes::{Buf, BytesMut};
use easyfix_messages::{
    deserializer::{self, raw_message, RawMessageError},
    messages::FixtMessage,
};
use futures_util::Stream;
use pin_project::pin_project;
use tokio::io::AsyncRead;
use tokio_util::io::poll_read_buf;
use tracing::{debug, info, warn};

use crate::application::DeserializeError;
19
20#[derive(Debug)]
21pub enum InputEvent {
22    Message(Box<FixtMessage>),
23    DeserializeError(DeserializeError),
24    IoError(io::Error),
25    Timeout,
26}
27
28fn process_garbled_data(buf: &mut BytesMut) {
29    let len = buf.len();
30    for i in 1..buf.len() {
31        if let Ok(_) | Err(RawMessageError::Incomplete) = raw_message(&buf[i..]) {
32            buf.split_to(i).freeze();
33            info!("dropped {i} bytes of garbled message");
34            return;
35        }
36    }
37    buf.clear();
38    info!("dropped {len} bytes of garbled message");
39}
40
41fn parse_message(
42    bytes: &mut BytesMut,
43) -> Result<Option<Box<FixtMessage>>, deserializer::DeserializeError> {
44    if bytes.is_empty() {
45        return Ok(None);
46    }
47    debug!(
48        "Raw data input :: {}",
49        String::from_utf8_lossy(bytes).replace('\x01', "|")
50    );
51
52    let src_len = bytes.len();
53
54    match raw_message(bytes) {
55        Ok((leftover, raw_msg)) => {
56            let result = FixtMessage::from_raw_message(raw_msg).map(Some);
57            let leftover_len = leftover.len();
58            bytes.split_to(src_len - leftover_len).freeze();
59            result
60        }
61        Err(RawMessageError::Incomplete) => Ok(None),
62        Err(err) => {
63            process_garbled_data(bytes);
64            Err(err.into())
65        }
66    }
67}
68
69#[pin_project]
70pub struct InputStream<S> {
71    buffer: BytesMut,
72    #[pin]
73    source: S,
74}
75
76impl<S> Stream for InputStream<S>
77where
78    S: AsyncRead + Unpin,
79{
80    type Item = InputEvent;
81
82    fn poll_next(self: std::pin::Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
83        let mut this = self.project();
84
85        loop {
86            // Attempt to parse a message from the buffered data.
87            // If enough data has been buffered, the message is returned.
88            match parse_message(this.buffer) {
89                Ok(Some(msg)) => {
90                    return Poll::Ready(Some(InputEvent::Message(msg)));
91                }
92                Ok(None) => {}
93                // Convert `deserializer::DeserializeError` to `application::DeserializeError`
94                // to prevent leaking ParseRejectReason to user code.
95                Err(error) => {
96                    return Poll::Ready(Some(InputEvent::DeserializeError(error.into())));
97                }
98            }
99
100            // There is not enough buffered data to read a message.
101            // Attempt to read more data from the socket.
102            //
103            // On success, the number of bytes is returned. `0` indicates "end
104            // of stream".
105            let future = poll_read_buf(Pin::new(&mut this.source), cx, this.buffer);
106            match ready!(future) {
107                Ok(0) => {
108                    // The remote closed the connection. For this to be a clean
109                    // shutdown, there should be no data in the read buffer. If
110                    // there is, this means that the peer closed the socket while
111                    // sending a frame.
112                    if this.buffer.is_empty() {
113                        info!("Stream closed");
114                        return Poll::Ready(None);
115                    } else {
116                        warn!("Connection reset by peer");
117                        return Poll::Ready(None);
118                    }
119                }
120                Ok(_n) => continue,
121                Err(err) => return Poll::Ready(Some(InputEvent::IoError(err))),
122            }
123        }
124    }
125}
126
127pub fn input_stream<S>(source: S) -> InputStream<S>
128where
129    S: AsyncRead + Unpin,
130{
131    InputStream {
132        // TODO: Max MSG size
133        buffer: BytesMut::with_capacity(4096),
134        source,
135    }
136}