jupiter_rs/
response.rs

1use std::error::Error;
2use std::fmt::{Display, Formatter, Write};
3
4use anyhow::anyhow;
5use bytes::BytesMut;
6
7#[derive(Debug)]
8pub enum OutputError {
9    IOError(std::fmt::Error),
10    ProtocolError(anyhow::Error),
11}
12
13impl From<std::fmt::Error> for OutputError {
14    fn from(err: std::fmt::Error) -> OutputError {
15        OutputError::IOError(err)
16    }
17}
18
19impl From<anyhow::Error> for OutputError {
20    fn from(err: anyhow::Error) -> OutputError {
21        OutputError::ProtocolError(err)
22    }
23}
24
25impl Display for OutputError {
26    fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
27        match self {
28            OutputError::IOError(e) => write!(f, "IO error: {:?}", e),
29            OutputError::ProtocolError(e) => write!(f, "Protocol error: {:?}", e),
30        }
31    }
32}
33
34impl Error for OutputError {
35    fn source(&self) -> Option<&(dyn Error + 'static)> {
36        match *self {
37            OutputError::IOError(ref e) => Some(e),
38            OutputError::ProtocolError(_) => None,
39        }
40    }
41}
42
43pub type OutputResult = std::result::Result<(), OutputError>;
44
45pub struct Response {
46    buffer: BytesMut,
47    nesting: Vec<i32>,
48}
49
50impl Response {
51    pub fn new() -> Self {
52        Response {
53            buffer: BytesMut::with_capacity(8192),
54            nesting: vec![1],
55        }
56    }
57
58    fn check_nesting(&mut self) -> OutputResult {
59        let current_nesting = match self.nesting.last_mut() {
60            Some(level) => level,
61            None => {
62                return Err(OutputError::ProtocolError(anyhow!(
63                    "Invalid result nesting!"
64                )))
65            }
66        };
67
68        *current_nesting -= 1;
69        return if *current_nesting > 0 {
70            Ok(())
71        } else if *current_nesting == 0 {
72            self.nesting.pop();
73            Ok(())
74        } else {
75            Err(OutputError::ProtocolError(anyhow!(
76                "Invalid result nesting!"
77            )))
78        };
79    }
80
81    #[inline]
82    fn reserve(&mut self, required_length: usize) {
83        let len = self.buffer.len();
84        let rem = self.buffer.capacity() - len;
85
86        if rem < required_length {
87            self.reserve_inner(required_length);
88        }
89    }
90
91    fn reserve_inner(&mut self, required_length: usize) {
92        let required_blocks = (required_length / 8192) + 1;
93        self.buffer.reserve(required_blocks * 8192);
94    }
95
96    pub fn complete(mut self) -> Result<BytesMut, OutputError> {
97        if !self.nesting.is_empty() {
98            return Err(OutputError::ProtocolError(anyhow!(
99                "Invalid result nesting!"
100            )));
101        }
102
103        self.nesting.push(1);
104        Ok(self.buffer)
105    }
106
107    pub fn array(&mut self, items: i32) -> OutputResult {
108        self.check_nesting()?;
109        self.nesting.push(items);
110        self.reserve(16);
111        self.buffer.write_char('*')?;
112        write!(self.buffer, "{}\r\n", items)?;
113        Ok(())
114    }
115
116    pub fn ok(&mut self) -> OutputResult {
117        self.check_nesting()?;
118        self.reserve(5);
119        self.buffer.write_str("+OK\r\n")?;
120        Ok(())
121    }
122
123    pub fn zero(&mut self) -> OutputResult {
124        self.check_nesting()?;
125        self.reserve(4);
126        self.buffer.write_str(":0\r\n")?;
127        Ok(())
128    }
129
130    pub fn one(&mut self) -> OutputResult {
131        self.check_nesting()?;
132        self.reserve(4);
133        self.buffer.write_str(":1\r\n")?;
134        Ok(())
135    }
136
137    pub fn number(&mut self, number: i32) -> OutputResult {
138        if number == 0 {
139            self.zero()
140        } else if number == 1 {
141            self.one()
142        } else {
143            self.check_nesting()?;
144            self.reserve(16);
145            self.buffer.write_char(':')?;
146            write!(self.buffer, "{}\r\n", number)?;
147            Ok(())
148        }
149    }
150
151    pub fn boolean(&mut self, boolean: bool) -> OutputResult {
152        self.number(if boolean { 1 } else { 0 })
153    }
154
155    pub fn simple(&mut self, string: impl AsRef<str>) -> OutputResult {
156        if string.as_ref().len() == 0 {
157            self.empty_string()
158        } else {
159            self.check_nesting()?;
160            self.reserve(3 + string.as_ref().len());
161            self.buffer.write_char('+')?;
162            self.buffer.write_str(string.as_ref())?;
163            self.buffer.write_str("\r\n")?;
164
165            Ok(())
166        }
167    }
168
169    pub fn empty_string(&mut self) -> OutputResult {
170        self.check_nesting()?;
171        self.reserve(3);
172        self.buffer.write_str("+\r\n")?;
173
174        Ok(())
175    }
176
177    pub fn bulk(&mut self, string: impl AsRef<str>) -> OutputResult {
178        self.check_nesting()?;
179        self.reserve(3 + 16 + string.as_ref().len());
180        self.buffer.write_char('$')?;
181        write!(self.buffer, "{}\r\n", string.as_ref().len())?;
182        self.buffer.write_str(string.as_ref())?;
183        self.buffer.write_str("\r\n")?;
184
185        Ok(())
186    }
187
188    pub fn error(&mut self, string: impl AsRef<str>) -> OutputResult {
189        self.check_nesting()?;
190        self.reserve(3 + string.as_ref().len());
191        self.buffer.write_char('-')?;
192        self.buffer.write_str(
193            string
194                .as_ref()
195                .to_owned()
196                .replace(&"\r", " ")
197                .replace(&"\n", " ")
198                .as_str(),
199        )?;
200        self.buffer.write_str("\r\n")?;
201
202        Ok(())
203    }
204
205    pub fn int_triple(
206        &mut self,
207        name: impl AsRef<str>,
208        value: i32,
209        human_value: impl AsRef<str>,
210    ) -> OutputResult {
211        self.array(3)?;
212        self.bulk(name)?;
213        self.number(value)?;
214        self.bulk(human_value)?;
215
216        Ok(())
217    }
218
219    pub fn string_triple(
220        &mut self,
221        name: impl AsRef<str>,
222        value: impl AsRef<str>,
223        human_value: impl AsRef<str>,
224    ) -> OutputResult {
225        self.array(3)?;
226        self.bulk(name)?;
227        self.bulk(value.as_ref())?;
228        self.bulk(human_value.as_ref())?;
229
230        Ok(())
231    }
232
233    pub fn as_str(&self) -> &str {
234        std::str::from_utf8(&self.buffer[..]).unwrap()
235    }
236}
237
#[cfg(test)]
mod tests {
    use crate::request::Request;
    use crate::response::Response;

    #[test]
    fn an_array_of_bulk_strings_can_be_read_by_request() {
        let mut response = Response::new();
        response.array(2).unwrap();
        response.bulk("Hello").unwrap();
        response.bulk("World").unwrap();

        assert_eq!(response.as_str(), "*2\r\n$5\r\nHello\r\n$5\r\nWorld\r\n");

        // The request parser accepts this output: the first element becomes the command,
        // the remaining ones the parameters.
        let mut buffer = response.complete().unwrap();
        let request = Request::parse(&mut buffer).unwrap().unwrap();
        assert_eq!(request.command(), "Hello");
        assert_eq!(request.parameter_count(), 1);
        assert_eq!(request.str_parameter(0).unwrap(), "World");
    }

    #[test]
    fn errors_are_sanitized() {
        let mut response = Response::new();
        response.error("Error\nProblem").unwrap();

        assert_eq!(response.as_str(), "-Error Problem\r\n");
    }

    #[test]
    fn incorrect_nesting_is_detected() {
        {
            // Fewer elements than announced: the array is still open.
            let mut response = Response::new();
            response.array(2).unwrap();
            response.ok().unwrap();
            assert!(response.complete().is_err());
        }
        {
            // A second top-level element.
            let mut response = Response::new();
            response.ok().unwrap();
            assert!(response.ok().is_err());
        }
        {
            // More elements than announced.
            let mut response = Response::new();
            response.array(1).unwrap();
            response.ok().unwrap();
            assert!(response.ok().is_err());
        }
    }

    #[test]
    fn nested_arrays_are_tracked() {
        let mut response = Response::new();
        response.array(2).unwrap();
        response.array(1).unwrap();
        response.number(42).unwrap();
        response.bulk("x").unwrap();

        assert_eq!(response.as_str(), "*2\r\n*1\r\n:42\r\n$1\r\nx\r\n");
        assert!(response.complete().is_ok());
    }

    #[test]
    fn numbers_and_booleans_are_encoded_as_integers() {
        let mut response = Response::new();
        response.array(4).unwrap();
        response.number(0).unwrap();
        response.number(-7).unwrap();
        response.boolean(true).unwrap();
        response.boolean(false).unwrap();

        assert_eq!(response.as_str(), "*4\r\n:0\r\n:-7\r\n:1\r\n:0\r\n");
    }

    #[test]
    fn dynamic_buffer_allocation_works() {
        let many_x = "X".repeat(16_000);
        let many_y = "Y".repeat(16_000);

        // Both payloads exceed the initial 8 KiB capacity, forcing the buffer to grow.
        let mut response = Response::new();
        response.array(2).unwrap();
        response.simple(many_x.as_str()).unwrap();
        response.bulk(many_y.as_str()).unwrap();

        assert_eq!(
            response.as_str(),
            format!("*2\r\n+{}\r\n$16000\r\n{}\r\n", many_x, many_y)
        );
    }
}