easyfix_session/io/
input_stream.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
use std::{
    io,
    pin::Pin,
    task::{ready, Context, Poll},
};

use bytes::BytesMut;
use easyfix_messages::{
    deserializer::{self, raw_message, RawMessageError},
    messages::FixtMessage,
};
use futures_util::Stream;
use pin_project::pin_project;
use tokio::io::AsyncRead;
use tokio_util::io::poll_read_buf;
use tracing::{debug, info, warn};

use crate::application::DeserializeError;

/// An event yielded by [`InputStream`].
#[derive(Debug)]
pub enum InputEvent {
    /// A message that was framed and deserialized successfully.
    Message(Box<FixtMessage>),
    /// Received bytes could not be framed or deserialized into a message.
    DeserializeError(DeserializeError),
    /// Reading from the underlying source failed.
    IoError(io::Error),
    /// A timeout notification. `InputStream` in this file never yields it;
    /// presumably it is produced elsewhere in the session layer (confirm).
    Timeout,
}

/// Discards leading bytes of `buf` that cannot begin a valid message.
///
/// Looks for the first offset (starting at 1) at which `raw_message` either
/// parses successfully or reports `Incomplete`. Everything before that offset
/// is dropped. If no such offset exists, the whole buffer is cleared.
fn process_garbled_data(buf: &mut BytesMut) {
    let len = buf.len();

    // First position from which parsing could plausibly resume.
    let resync_offset = (1..len).find(|&offset| {
        matches!(
            raw_message(&buf[offset..]),
            Ok(_) | Err(RawMessageError::Incomplete)
        )
    });

    match resync_offset {
        Some(i) => {
            // Only the remainder after the split is kept.
            drop(buf.split_to(i));
            info!("dropped {i} bytes of garbled message");
        }
        None => {
            buf.clear();
            info!("dropped {len} bytes of garbled message");
        }
    }
}

/// Tries to take one complete message off the front of `bytes`.
///
/// Returns:
/// - `Ok(None)` when `bytes` is empty or holds only part of a message.
///   Nothing is consumed.
/// - `Ok(Some(msg))` when a message was framed and deserialized. Its bytes are
///   consumed.
/// - `Err(_)` when a framed message failed to deserialize (its bytes are still
///   consumed), or when the data could not be framed at all. In the latter case,
///   garbled leading bytes are discarded by `process_garbled_data`.
fn parse_message(
    bytes: &mut BytesMut,
) -> Result<Option<Box<FixtMessage>>, deserializer::DeserializeError> {
    if bytes.is_empty() {
        return Ok(None);
    }
    debug!(
        "Raw data input :: {}",
        String::from_utf8_lossy(bytes).replace('\x01', "|")
    );

    let available = bytes.len();

    let (consumed, decoded) = match raw_message(bytes) {
        Ok((leftover, raw_msg)) => (
            available - leftover.len(),
            FixtMessage::from_raw_message(raw_msg),
        ),
        // Wait for more data; leave the buffer untouched.
        Err(RawMessageError::Incomplete) => return Ok(None),
        Err(err) => {
            process_garbled_data(bytes);
            return Err(err.into());
        }
    };

    // Remove the framed bytes even if deserialization failed. Otherwise the same
    // bad message would be parsed again on the next call.
    drop(bytes.split_to(consumed));
    decoded.map(Some)
}

/// A [`Stream`] of [`InputEvent`]s decoded from an [`AsyncRead`] byte source.
///
/// Bytes read from `source` accumulate in `buffer` until `parse_message` can
/// extract a complete message from them.
#[pin_project]
pub struct InputStream<S> {
    /// Bytes read from `source` that the parser has not consumed yet.
    buffer: BytesMut,
    /// The underlying byte source.
    #[pin]
    source: S,
}

impl<S> Stream for InputStream<S>
where
    S: AsyncRead + Unpin,
{
    type Item = InputEvent;

    /// Yields the next decoded message, deserialization error, or I/O error.
    ///
    /// Already-buffered data is parsed first. The source is read only when the
    /// buffer does not hold a complete message. Returns `Poll::Ready(None)` once
    /// a read returns 0 bytes (end of stream).
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        loop {
            // Attempt to parse a message from the buffered data.
            // If enough data has been buffered, the message is returned.
            match parse_message(this.buffer) {
                Ok(Some(msg)) => return Poll::Ready(Some(InputEvent::Message(msg))),
                Ok(None) => {}
                // Convert `deserializer::DeserializeError` to `application::DeserializeError`
                // to prevent leaking ParseRejectReason to user code.
                Err(error) => {
                    return Poll::Ready(Some(InputEvent::DeserializeError(error.into())));
                }
            }

            // There is not enough buffered data to read a message, so read more
            // from the source. The projected `this.source` is already a
            // `Pin<&mut S>`. Reborrow it with `as_mut()` rather than wrapping it
            // in a second `Pin`.
            match ready!(poll_read_buf(this.source.as_mut(), cx, this.buffer)) {
                Ok(0) => {
                    // End of stream. A non-empty buffer here means the peer
                    // closed the connection partway through a message:
                    // `parse_message` just returned `Ok(None)`, so the data is
                    // incomplete.
                    if this.buffer.is_empty() {
                        info!("Stream closed");
                    } else {
                        warn!("Connection reset by peer");
                    }
                    return Poll::Ready(None);
                }
                // New data arrived; try parsing again.
                Ok(_) => {}
                Err(err) => return Poll::Ready(Some(InputEvent::IoError(err))),
            }
        }
    }
}

/// Wraps `source` in an [`InputStream`] that decodes it into [`InputEvent`]s.
///
/// The read buffer is pre-allocated with 4096 bytes. Nothing here limits how
/// large it may become.
pub fn input_stream<S>(source: S) -> InputStream<S>
where
    S: AsyncRead + Unpin,
{
    // Initial allocation only.
    // TODO: Max MSG size
    const INITIAL_BUFFER_CAPACITY: usize = 4096;

    let buffer = BytesMut::with_capacity(INITIAL_BUFFER_CAPACITY);
    InputStream { source, buffer }
}