ellidri_reader/
lib.rs

1//! Asynchronous IRC message reading.
2//!
//! This exposes a more robust alternative to tokio's `BufReader`, with finer control over how
//! lines are read.
5
6use futures::ready;
7use std::{io, marker, mem, pin, task};
8use std::future::Future;
9use tokio::io::{AsyncBufRead, AsyncRead, BufReader};
10
/// Error text used when a peer needs too many reads to complete a single message.
const ABUSE_ERR: &'static str = "Bad client, bad! >:(";
/// Error text used when the received bytes are not valid UTF-8.
const UTF8_ERR: &'static str = "This was definitely not UTF-8...";
/// Error text used when a message exceeds the allowed length.
const TOO_LONG_ERR: &'static str = "Kyaa! Your message is too long!";
14
// TODO make this configurable
/// Maximum number of buffer refills tolerated while assembling one message.
const MAX_READ_PER_MESSAGE: u8 = 4;

/// Extra length budget, in bytes, granted to lines that start with `@` (message tags).
const MAX_TAG_LENGTH: usize = 4096;
18
19/// Asynchronous IRC message reader.
20pub struct IrcReader<R> {
21    inner: BufReader<R>,
22    message_max: usize,
23}
24
25impl<R: AsyncRead> IrcReader<R> {
26    /// Creates a new `IrcReader` with the given maximum length for messages.
27    ///
28    /// Although `message_max` allows restriction on the message length, `IrcReader` will always
29    /// allow lines of `4096 + message_max` bytes if the line starts with `@`.  This is because the
30    /// [message tag spec][1] states that tags can occupy up to 4096 bytes.  Thus, `message_max`
31    /// designates the maximum length of a message without tags (should default to 512, see RFCs
32    /// 1459 and 2812).
33    ///
34    /// [1]: https://ircv3.net/specs/extensions/message-tags.html
35    pub fn new(r: R, message_max: usize) -> Self {
36        Self { inner: BufReader::new(r), message_max }
37    }
38
39    /// Equivalent of tokio's `AsyncBufReadExt::read_line` for IRC messages.
40    ///
41    /// Function signature can also be read like so:
42    ///
43    /// ```rust
44    /// async fn read_message(&mut self, buf: &mut String) -> io::Result<usize>
45    /// ```
46    pub fn read_message<'a>(&'a mut self, buf: &'a mut String) -> ReadMessage<'a, R>
47        where Self: marker::Unpin,
48    {
49        ReadMessage {
50            reader: &mut self.inner,
51            bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) },
52            buf,
53            n: ReadInfo {
54                read: 0,
55                limit: 0,
56                message_max: self.message_max,
57                count: 0,
58            },
59        }
60    }
61}
62
/// Bookkeeping carried across polls while one message is being read.
#[derive(Debug)]
struct ReadInfo {
    /// Bytes consumed from the reader so far for the current message.
    read: usize,

    /// Length limit for the current message; zero until the first bytes have been seen.
    limit: usize,

    /// Configured maximum length of a message without tags.
    message_max: usize,

    /// Number of buffer refills that did not complete the current message.
    count: u8,
}
70
71/// Future returned by `IrcReader::read_message`.
72#[must_use = "futures do nothing unless polled or .await'ed"]
73#[derive(Debug)]
74pub struct ReadMessage<'a, R> {
75    reader: &'a mut BufReader<R>,
76    bytes: Vec<u8>,
77    buf: &'a mut String,
78    n: ReadInfo,
79}
80
81impl<R: AsyncRead + marker::Unpin> Future for ReadMessage<'_, R> {
82    type Output = io::Result<usize>;
83
84    fn poll(mut self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
85        let Self { reader, buf, bytes, n } = &mut *self;
86        read_message(pin::Pin::new(reader), cx, buf, bytes, n)
87    }
88}
89
90fn read_message<R>(reader: pin::Pin<&mut BufReader<R>>, cx: &mut task::Context<'_>,
91                   buf: &mut String, bytes: &mut Vec<u8>, n: &mut ReadInfo)
92                   -> task::Poll<io::Result<usize>>
93    where R: AsyncRead,
94{
95    let ret = ready!(read_line(reader, cx, bytes, n))?;
96    if std::str::from_utf8(&bytes).is_err() {
97        task::Poll::Ready(
98            Err(io::Error::new(io::ErrorKind::InvalidData, UTF8_ERR))
99        )
100    } else {
101        mem::swap(unsafe { buf.as_mut_vec() }, bytes);
102        task::Poll::Ready(Ok(ret))
103    }
104}
105
106fn read_line<R>(mut reader: pin::Pin<&mut BufReader<R>>, cx: &mut task::Context<'_>,
107                bytes: &mut Vec<u8>, n: &mut ReadInfo)
108                -> task::Poll<io::Result<usize>>
109    where R: AsyncRead,
110{
111    loop {
112        if MAX_READ_PER_MESSAGE <= n.count {
113            return task::Poll::Ready(Err(io::Error::new(io::ErrorKind::TimedOut, ABUSE_ERR)));
114        }
115        if 0 < n.limit && n.limit <= n.read {
116            return task::Poll::Ready(Err(io::Error::new(io::ErrorKind::TimedOut, TOO_LONG_ERR)));
117        }
118        let (done, used) = {
119            // TODO prevent spam +inf times "\r" or "\n"
120            let available = ready!(reader.as_mut().poll_fill_buf(cx))?;
121
122            if n.limit == 0 && !available.is_empty() {
123                if available[0] == b'@' {
124                    n.limit = MAX_TAG_LENGTH;
125                }
126                n.limit += n.message_max;
127            }
128
129            if let Some(i) = memchr::memchr2(b'\r', b'\n', available) {
130                bytes.extend_from_slice(&available[..=i]);
131                if i + 1 < available.len() && available[i + 1] == b'\n' {
132                    (true, i + 2)
133                } else {
134                    (true, i + 1)
135                }
136            } else {
137                bytes.extend_from_slice(available);
138                (false, available.len())
139            }
140        };
141        reader.as_mut().consume(used);
142        n.read += used;
143        if done || used == 0 {
144            return task::Poll::Ready(Ok(mem::replace(&mut n.read, 0)));
145        }
146        n.count += 1;
147    }
148}