forgefix 0.3.0

ForgeFIX is an opinionated FIX 4.2 client library for the buy side, written in Rust. ForgeFIX is optimized for the subset of the FIX protocol that buy-side firms use when connecting to brokers and exchanges to communicate orders and fills.
Documentation
use std::io::Write;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncWrite;

/// A [`Write`] adapter that forwards bytes to an inner writer while keeping a
/// running sum of every byte the inner writer accepted.
///
/// [`ChecksumWriter::checksum`] returns that sum modulo 256, the same value
/// [`calc_checksum`] computes over the written bytes.
pub struct ChecksumWriter<W>(W, usize);

impl<W> Write for ChecksumWriter<W>
where
    W: Write,
{
    /// Writes `buf` to the inner writer and adds the bytes it actually accepted
    /// to the running sum.
    ///
    /// Only the first `n` bytes reported by the inner writer are counted. On a
    /// partial write the caller retries with the remainder, and on error no
    /// bytes were written. Counting all of `buf` up front would double-count in
    /// both cases.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = self.0.write(buf)?;
        // Wrapping addition keeps `sum % 256` exact: 256 divides 2^usize::BITS.
        // `take` (rather than slicing) avoids a panic if a misbehaving writer
        // ever reports more bytes than it was given.
        self.1 = buf
            .iter()
            .take(written)
            .fold(self.1, |sum, &byte| sum.wrapping_add(usize::from(byte)));
        Ok(written)
    }

    /// Flushes the inner writer; the running sum is unaffected.
    fn flush(&mut self) -> std::io::Result<()> {
        self.0.flush()
    }
}

impl<W> ChecksumWriter<W> {
    /// Wraps `w`, starting with a running sum of zero.
    #[allow(dead_code)]
    pub fn new(w: W) -> Self {
        ChecksumWriter(w, 0)
    }

    /// Returns the sum of all bytes successfully written so far, modulo 256.
    #[allow(dead_code)]
    pub fn checksum(&self) -> usize {
        self.1 % 256
    }
}

/// Async counterpart of [`ChecksumWriter`]: an [`AsyncWrite`] adapter that
/// forwards bytes to an inner writer while keeping a running sum of every byte
/// the inner writer accepted.
///
/// [`AsyncChecksumWriter::checksum`] returns that sum modulo 256.
pub struct AsyncChecksumWriter<W>(W, usize);

impl<W> AsyncWrite for AsyncChecksumWriter<W>
where
    W: AsyncWrite + Unpin,
{
    /// Polls the inner writer with `buf` and counts bytes only once they are
    /// reported as written.
    ///
    /// `Poll::Pending` means nothing was written yet, and the caller is expected
    /// to retry with the same buffer. Counting before the outcome is known would
    /// add those bytes again on every retry. The same reasoning applies to
    /// errors and to partial writes (`Ready(Ok(n))` with `n < buf.len()`).
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let this = self.get_mut();
        let result = Pin::new(&mut this.0).poll_write(cx, buf);
        if let Poll::Ready(Ok(written)) = &result {
            // Wrapping addition keeps `sum % 256` exact: 256 divides 2^usize::BITS.
            this.1 = buf
                .iter()
                .take(*written)
                .fold(this.1, |sum, &byte| sum.wrapping_add(usize::from(byte)));
        }
        result
    }

    /// Flushes the inner writer; the running sum is unaffected.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.get_mut().0).poll_flush(cx)
    }

    /// Shuts down the inner writer; the running sum is unaffected.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        Pin::new(&mut self.get_mut().0).poll_shutdown(cx)
    }
}

impl<W> AsyncChecksumWriter<W> {
    /// Wraps `w`, starting with a running sum of zero.
    pub fn new(w: W) -> Self {
        AsyncChecksumWriter(w, 0)
    }

    /// Returns the sum of all bytes successfully written so far, modulo 256.
    pub fn checksum(&self) -> usize {
        self.1 % 256
    }
}

/// Computes the checksum of `bytes`: the sum of every byte, modulo 256.
///
/// The sum is accumulated with wrapping `u8` arithmetic, which yields exactly
/// `sum % 256` for input of any length. A plain `i32` sum overflows once the
/// input exceeds about 8.4 million bytes (`i32::MAX / 255`). That overflow
/// panics in debug builds and produces a negative remainder in release builds.
///
/// The result is always in `0..=255`.
pub fn calc_checksum(bytes: &[u8]) -> i32 {
    let sum = bytes.iter().fold(0u8, |acc, &byte| acc.wrapping_add(byte));
    i32::from(sum)
}

/// Returns `true` when `msg_buf` ends in a well-formed `10=NNN\x01` trailer
/// whose value equals the checksum of every byte that precedes the trailer.
///
/// Buffers that are too short or lack a well-formed trailer are rejected.
pub fn checksum_is_valid(msg_buf: &[u8]) -> bool {
    // Length of the `10=NNN\x01` trailer that `parse_checksum` validates.
    const TRAILER_LEN: usize = 7;
    match parse_checksum(msg_buf) {
        Some(expected) => {
            let (body, _trailer) = msg_buf.split_at(msg_buf.len() - TRAILER_LEN);
            checksum_matches(body, expected)
        }
        None => false,
    }
}

/// Extracts the numeric value of a trailing `10=NNN\x01` field.
///
/// The last seven bytes of `msg_buf` must be exactly `1`, `0`, `=`, three
/// ASCII digits, and the SOH byte (`0x01`); otherwise `None` is returned. The
/// three digits are decoded as a decimal number (`000`..=`999`). No range check
/// against 255 is done here.
fn parse_checksum(msg_buf: &[u8]) -> Option<i32> {
    const TRAILER_LEN: usize = 7;
    let start = msg_buf.len().checked_sub(TRAILER_LEN)?;
    match &msg_buf[start..] {
        &[b'1', b'0', b'=', d0, d1, d2, b'\x01']
            if d0.is_ascii_digit() && d1.is_ascii_digit() && d2.is_ascii_digit() =>
        {
            let value = [d0, d1, d2]
                .iter()
                .fold(0i32, |acc, &digit| acc * 10 + i32::from(digit - b'0'));
            Some(value)
        }
        _ => None,
    }
}

/// Reports whether `checksum` equals the value [`calc_checksum`] computes over `msg`.
fn checksum_matches(msg: &[u8], checksum: i32) -> bool {
    calc_checksum(msg) == checksum
}

#[cfg(test)]
mod test {
    use super::*;
    use std::io::Write;

    /// `checksum_matches` compares a supplied value against the computed sum.
    #[test]
    fn test_checksum_matches() {
        let tests: Vec<(&[u8], i32, bool)> = vec![
            (b"8=FIX.4.2\x019=98\x0135=5\x0134=2\x0149=ISLD5\x012=20230803-14:13:08.157\x0156=TW\x0158=MsgSeqNum too low, expecting 3 but received 2\x01", 81, true),
            (b"8=FIX.4.2\x019=98\x0135=5\x0134=2\x0149=ISLD5\x012=20230803-14:13:08.157\x0156=TW\x0158=MsgSeqNum too low, expecting 3 but received 2\x01", 0, false),
            (b"8=FIX.4.2\x019=57\x0135=A\x0134=1\x0149=TW\x0152=20230803-15:42:57\x0156=ISLD\x0198=0\x01108=30\x01", 19, true),
        ];
        for t in tests {
            assert_eq!(checksum_matches(t.0, t.1), t.2);
        }
    }

    /// `parse_checksum` accepts only a `10=NNN\x01` trailer with three digits.
    #[test]
    fn test_parse_checksum() {
        let tests: Vec<(&[u8], bool)> = vec![
            (b"aaaaaaaaaaaaaaaa10=123\x01", true),
            (b"aaaaaaaa10=43\x01", false),
            (b"aaaaaaaa10=123", false),
            (b"aaaaaaaa11=123\x01", false),
        ];
        for t in tests {
            assert_eq!(
                parse_checksum(t.0).is_some(),
                t.1,
                "{:?} {}",
                parse_checksum(t.0),
                t.1
            );
        }
    }

    /// `checksum_is_valid` checks the trailer against the bytes preceding it.
    #[test]
    fn test_checksum_is_valid() {
        let tests: Vec<(&[u8], bool)> = vec![
            // Correct trailer (body checksum is 19, see `test_checksum_matches`).
            (b"8=FIX.4.2\x019=57\x0135=A\x0134=1\x0149=TW\x0152=20230803-15:42:57\x0156=ISLD\x0198=0\x01108=30\x0110=019\x01", true),
            // Well-formed trailer with the wrong value.
            (b"8=FIX.4.2\x019=57\x0135=A\x0134=1\x0149=TW\x0152=20230803-15:42:57\x0156=ISLD\x0198=0\x01108=30\x0110=020\x01", false),
            // No trailer at all.
            (b"8=FIX.4.2\x019=57\x0135=A\x0134=1\x0149=TW\x0152=20230803-15:42:57\x0156=ISLD\x0198=0\x01108=30\x01", false),
            // Empty body: its checksum is 0.
            (b"10=000\x01", true),
            // Too short to contain a trailer.
            (b"", false),
        ];
        for t in tests {
            assert_eq!(checksum_is_valid(t.0), t.1, "{:?}", t.0);
        }
    }

    /// `ChecksumWriter` agrees with `calc_checksum` and passes bytes through unchanged.
    #[test]
    fn test_checksum_writer_matches_calc_checksum() {
        let msg: &[u8] = b"8=FIX.4.2\x019=57\x0135=A\x0134=1\x0149=TW\x0152=20230803-15:42:57\x0156=ISLD\x0198=0\x01108=30\x01";
        let mut writer = ChecksumWriter::new(Vec::<u8>::new());
        writer.write_all(msg).unwrap();
        assert_eq!(writer.checksum(), calc_checksum(msg) as usize);
        assert_eq!(writer.0.as_slice(), msg);
    }
}