1use std::error::Error;
2use std::fmt::{Display, Formatter, Write};
3
4use anyhow::anyhow;
5use bytes::BytesMut;
6
7#[derive(Debug)]
8pub enum OutputError {
9 IOError(std::fmt::Error),
10 ProtocolError(anyhow::Error),
11}
12
13impl From<std::fmt::Error> for OutputError {
14 fn from(err: std::fmt::Error) -> OutputError {
15 OutputError::IOError(err)
16 }
17}
18
19impl From<anyhow::Error> for OutputError {
20 fn from(err: anyhow::Error) -> OutputError {
21 OutputError::ProtocolError(err)
22 }
23}
24
25impl Display for OutputError {
26 fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
27 match self {
28 OutputError::IOError(e) => write!(f, "IO error: {:?}", e),
29 OutputError::ProtocolError(e) => write!(f, "Protocol error: {:?}", e),
30 }
31 }
32}
33
34impl Error for OutputError {
35 fn source(&self) -> Option<&(dyn Error + 'static)> {
36 match *self {
37 OutputError::IOError(ref e) => Some(e),
38 OutputError::ProtocolError(_) => None,
39 }
40 }
41}
42
43pub type OutputResult = std::result::Result<(), OutputError>;
44
45pub struct Response {
46 buffer: BytesMut,
47 nesting: Vec<i32>,
48}
49
50impl Response {
51 pub fn new() -> Self {
52 Response {
53 buffer: BytesMut::with_capacity(8192),
54 nesting: vec![1],
55 }
56 }
57
58 fn check_nesting(&mut self) -> OutputResult {
59 let current_nesting = match self.nesting.last_mut() {
60 Some(level) => level,
61 None => {
62 return Err(OutputError::ProtocolError(anyhow!(
63 "Invalid result nesting!"
64 )))
65 }
66 };
67
68 *current_nesting -= 1;
69 return if *current_nesting > 0 {
70 Ok(())
71 } else if *current_nesting == 0 {
72 self.nesting.pop();
73 Ok(())
74 } else {
75 Err(OutputError::ProtocolError(anyhow!(
76 "Invalid result nesting!"
77 )))
78 };
79 }
80
81 #[inline]
82 fn reserve(&mut self, required_length: usize) {
83 let len = self.buffer.len();
84 let rem = self.buffer.capacity() - len;
85
86 if rem < required_length {
87 self.reserve_inner(required_length);
88 }
89 }
90
91 fn reserve_inner(&mut self, required_length: usize) {
92 let required_blocks = (required_length / 8192) + 1;
93 self.buffer.reserve(required_blocks * 8192);
94 }
95
96 pub fn complete(mut self) -> Result<BytesMut, OutputError> {
97 if !self.nesting.is_empty() {
98 return Err(OutputError::ProtocolError(anyhow!(
99 "Invalid result nesting!"
100 )));
101 }
102
103 self.nesting.push(1);
104 Ok(self.buffer)
105 }
106
107 pub fn array(&mut self, items: i32) -> OutputResult {
108 self.check_nesting()?;
109 self.nesting.push(items);
110 self.reserve(16);
111 self.buffer.write_char('*')?;
112 write!(self.buffer, "{}\r\n", items)?;
113 Ok(())
114 }
115
116 pub fn ok(&mut self) -> OutputResult {
117 self.check_nesting()?;
118 self.reserve(5);
119 self.buffer.write_str("+OK\r\n")?;
120 Ok(())
121 }
122
123 pub fn zero(&mut self) -> OutputResult {
124 self.check_nesting()?;
125 self.reserve(4);
126 self.buffer.write_str(":0\r\n")?;
127 Ok(())
128 }
129
130 pub fn one(&mut self) -> OutputResult {
131 self.check_nesting()?;
132 self.reserve(4);
133 self.buffer.write_str(":1\r\n")?;
134 Ok(())
135 }
136
137 pub fn number(&mut self, number: i32) -> OutputResult {
138 if number == 0 {
139 self.zero()
140 } else if number == 1 {
141 self.one()
142 } else {
143 self.check_nesting()?;
144 self.reserve(16);
145 self.buffer.write_char(':')?;
146 write!(self.buffer, "{}\r\n", number)?;
147 Ok(())
148 }
149 }
150
151 pub fn boolean(&mut self, boolean: bool) -> OutputResult {
152 self.number(if boolean { 1 } else { 0 })
153 }
154
155 pub fn simple(&mut self, string: impl AsRef<str>) -> OutputResult {
156 if string.as_ref().len() == 0 {
157 self.empty_string()
158 } else {
159 self.check_nesting()?;
160 self.reserve(3 + string.as_ref().len());
161 self.buffer.write_char('+')?;
162 self.buffer.write_str(string.as_ref())?;
163 self.buffer.write_str("\r\n")?;
164
165 Ok(())
166 }
167 }
168
169 pub fn empty_string(&mut self) -> OutputResult {
170 self.check_nesting()?;
171 self.reserve(3);
172 self.buffer.write_str("+\r\n")?;
173
174 Ok(())
175 }
176
177 pub fn bulk(&mut self, string: impl AsRef<str>) -> OutputResult {
178 self.check_nesting()?;
179 self.reserve(3 + 16 + string.as_ref().len());
180 self.buffer.write_char('$')?;
181 write!(self.buffer, "{}\r\n", string.as_ref().len())?;
182 self.buffer.write_str(string.as_ref())?;
183 self.buffer.write_str("\r\n")?;
184
185 Ok(())
186 }
187
188 pub fn error(&mut self, string: impl AsRef<str>) -> OutputResult {
189 self.check_nesting()?;
190 self.reserve(3 + string.as_ref().len());
191 self.buffer.write_char('-')?;
192 self.buffer.write_str(
193 string
194 .as_ref()
195 .to_owned()
196 .replace(&"\r", " ")
197 .replace(&"\n", " ")
198 .as_str(),
199 )?;
200 self.buffer.write_str("\r\n")?;
201
202 Ok(())
203 }
204
205 pub fn int_triple(
206 &mut self,
207 name: impl AsRef<str>,
208 value: i32,
209 human_value: impl AsRef<str>,
210 ) -> OutputResult {
211 self.array(3)?;
212 self.bulk(name)?;
213 self.number(value)?;
214 self.bulk(human_value)?;
215
216 Ok(())
217 }
218
219 pub fn string_triple(
220 &mut self,
221 name: impl AsRef<str>,
222 value: impl AsRef<str>,
223 human_value: impl AsRef<str>,
224 ) -> OutputResult {
225 self.array(3)?;
226 self.bulk(name)?;
227 self.bulk(value.as_ref())?;
228 self.bulk(human_value.as_ref())?;
229
230 Ok(())
231 }
232
233 pub fn as_str(&self) -> &str {
234 std::str::from_utf8(&self.buffer[..]).unwrap()
235 }
236}
237
#[cfg(test)]
mod tests {
    use crate::request::Request;
    use crate::response::Response;

    #[test]
    fn an_array_of_bulk_strings_can_be_read_by_request() {
        let mut response = Response::new();
        response.array(2).unwrap();
        response.bulk("Hello").unwrap();
        response.bulk("World").unwrap();

        assert_eq!(response.as_str(), "*2\r\n$5\r\nHello\r\n$5\r\nWorld\r\n");

        let mut buffer = response.complete().unwrap();
        let request = Request::parse(&mut buffer).unwrap().unwrap();
        assert_eq!(request.command(), "Hello");
        assert_eq!(request.parameter_count(), 1);
        assert_eq!(request.str_parameter(0).unwrap(), "World");
    }

    #[test]
    fn errors_are_sanitized() {
        let mut response = Response::new();
        response.error("Error\nProblem").unwrap();
        assert_eq!(response.as_str(), "-Error Problem\r\n");

        // CR and LF are replaced individually, so CRLF becomes two spaces.
        let mut response = Response::new();
        response.error("Error\r\nProblem").unwrap();
        assert_eq!(response.as_str(), "-Error  Problem\r\n");
    }

    #[test]
    fn incorrect_nesting_is_detected() {
        // Fewer elements than the array announced.
        {
            let mut response = Response::new();
            response.array(2).unwrap();
            response.ok().unwrap();
            assert!(response.complete().is_err());
        }
        // More than one top-level value.
        {
            let mut response = Response::new();
            response.ok().unwrap();
            assert!(response.ok().is_err());
        }
        // More elements than the array announced.
        {
            let mut response = Response::new();
            response.array(1).unwrap();
            response.ok().unwrap();
            assert!(response.ok().is_err());
        }
    }

    #[test]
    fn nested_arrays_are_tracked() {
        let mut response = Response::new();
        response.array(2).unwrap();
        response.array(1).unwrap();
        response.ok().unwrap();
        response.number(42).unwrap();

        assert_eq!(response.as_str(), "*2\r\n*1\r\n+OK\r\n:42\r\n");
        assert!(response.complete().is_ok());
    }

    #[test]
    fn dynamic_buffer_allocation_works() {
        let many_x = "X".repeat(16_000);
        let many_y = "Y".repeat(16_000);

        let mut response = Response::new();
        response.array(2).unwrap();
        response.simple(many_x.as_str()).unwrap();
        response.bulk(many_y.as_str()).unwrap();

        assert_eq!(
            response.as_str(),
            format!("*2\r\n+{}\r\n$16000\r\n{}\r\n", many_x, many_y)
        );
    }
}