1use futures::ready;
7use std::{io, marker, mem, pin, task};
8use std::future::Future;
9use tokio::io::{AsyncBufRead, AsyncRead, BufReader};
10
/// Maximum number of buffer fills a single line may span; a line that is still
/// unterminated after this many fills is rejected with [`ABUSE_ERR`].
const MAX_READ_PER_MESSAGE: u8 = 4;
/// Extra bytes added to a line's length budget when the line starts with `@`
/// (i.e. carries message tags).
const MAX_TAG_LENGTH: usize = 4096;

/// Error text for a client that needs too many buffer fills to finish one line.
const ABUSE_ERR: &str = "Bad client, bad! >:(";
/// Error text for a line whose bytes are not valid UTF-8.
const UTF8_ERR: &str = "This was definitely not UTF-8...";
/// Error text for a line that exceeds its length budget.
const TOO_LONG_ERR: &str = "Kyaa! Your message is too long!";
18
19pub struct IrcReader<R> {
21 inner: BufReader<R>,
22 message_max: usize,
23}
24
25impl<R: AsyncRead> IrcReader<R> {
26 pub fn new(r: R, message_max: usize) -> Self {
36 Self { inner: BufReader::new(r), message_max }
37 }
38
39 pub fn read_message<'a>(&'a mut self, buf: &'a mut String) -> ReadMessage<'a, R>
47 where Self: marker::Unpin,
48 {
49 ReadMessage {
50 reader: &mut self.inner,
51 bytes: unsafe { mem::replace(buf.as_mut_vec(), Vec::new()) },
52 buf,
53 n: ReadInfo {
54 read: 0,
55 limit: 0,
56 message_max: self.message_max,
57 count: 0,
58 },
59 }
60 }
61}
62
/// Progress counters and length limits for one in-flight [`ReadMessage`].
#[derive(Debug)]
struct ReadInfo {
    /// Bytes consumed from the reader so far for the current line; reset to 0
    /// when a line is returned.
    read: usize,
    /// Byte budget for the current line. 0 means no budget has been set, and the
    /// length check is skipped.
    limit: usize,
    /// Base byte budget copied from the owning `IrcReader`.
    message_max: usize,
    /// Number of buffer fills consumed so far without finding a line terminator.
    count: u8,
}
70
/// Future returned by `IrcReader::read_message`.
///
/// It resolves to the number of bytes consumed from the stream for one line.
#[must_use = "futures do nothing unless polled or .await'ed"]
#[derive(Debug)]
pub struct ReadMessage<'a, R> {
    /// The owning reader's buffered stream.
    reader: &'a mut BufReader<R>,
    /// The caller's previous text, followed by the bytes of the line read so far.
    bytes: Vec<u8>,
    /// The caller's output string; empty until the read completes successfully.
    buf: &'a mut String,
    /// Progress counters and length limits for this read.
    n: ReadInfo,
}
80
81impl<R: AsyncRead + marker::Unpin> Future for ReadMessage<'_, R> {
82 type Output = io::Result<usize>;
83
84 fn poll(mut self: pin::Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
85 let Self { reader, buf, bytes, n } = &mut *self;
86 read_message(pin::Pin::new(reader), cx, buf, bytes, n)
87 }
88}
89
90fn read_message<R>(reader: pin::Pin<&mut BufReader<R>>, cx: &mut task::Context<'_>,
91 buf: &mut String, bytes: &mut Vec<u8>, n: &mut ReadInfo)
92 -> task::Poll<io::Result<usize>>
93 where R: AsyncRead,
94{
95 let ret = ready!(read_line(reader, cx, bytes, n))?;
96 if std::str::from_utf8(&bytes).is_err() {
97 task::Poll::Ready(
98 Err(io::Error::new(io::ErrorKind::InvalidData, UTF8_ERR))
99 )
100 } else {
101 mem::swap(unsafe { buf.as_mut_vec() }, bytes);
102 task::Poll::Ready(Ok(ret))
103 }
104}
105
106fn read_line<R>(mut reader: pin::Pin<&mut BufReader<R>>, cx: &mut task::Context<'_>,
107 bytes: &mut Vec<u8>, n: &mut ReadInfo)
108 -> task::Poll<io::Result<usize>>
109 where R: AsyncRead,
110{
111 loop {
112 if MAX_READ_PER_MESSAGE <= n.count {
113 return task::Poll::Ready(Err(io::Error::new(io::ErrorKind::TimedOut, ABUSE_ERR)));
114 }
115 if 0 < n.limit && n.limit <= n.read {
116 return task::Poll::Ready(Err(io::Error::new(io::ErrorKind::TimedOut, TOO_LONG_ERR)));
117 }
118 let (done, used) = {
119 let available = ready!(reader.as_mut().poll_fill_buf(cx))?;
121
122 if n.limit == 0 && !available.is_empty() {
123 if available[0] == b'@' {
124 n.limit = MAX_TAG_LENGTH;
125 }
126 n.limit += n.message_max;
127 }
128
129 if let Some(i) = memchr::memchr2(b'\r', b'\n', available) {
130 bytes.extend_from_slice(&available[..=i]);
131 if i + 1 < available.len() && available[i + 1] == b'\n' {
132 (true, i + 2)
133 } else {
134 (true, i + 1)
135 }
136 } else {
137 bytes.extend_from_slice(available);
138 (false, available.len())
139 }
140 };
141 reader.as_mut().consume(used);
142 n.read += used;
143 if done || used == 0 {
144 return task::Poll::Ready(Ok(mem::replace(&mut n.read, 0)));
145 }
146 n.count += 1;
147 }
148}