use std::fmt::Display;

use nom::bytes::complete::{take_till, take_until};
use nom::character::complete::{char, digit1, newline, space0};
use nom::combinator::map_res;
use nom::multi::many0;
use nom::{IResult, Parser};
16
17use crate::Error;
18
/// The kinds of protocol message this module can read and write.
///
/// Every variant maps to exactly one numeric status code (see `MessageType::code`),
/// which is what actually appears on the wire.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MessageType {
    /// Status code 100.
    Capabilities,
    /// Status code 101.
    Log,
    /// Status code 102.
    Status,
    /// Status code 200.
    URIStart,
    /// Status code 201.
    URIDone,
    /// Status code 400.
    URIFailure,
    /// Status code 401.
    GeneralFailure,
    /// Status code 600.
    URIAcquire,
    /// Status code 601.
    Configuration,
}
31
32impl MessageType {
33 fn code(&self) -> u16 {
34 match self {
35 MessageType::Capabilities => 100,
36 MessageType::Log => 101,
37 MessageType::Status => 102,
38 MessageType::URIStart => 200,
39 MessageType::URIDone => 201,
40 MessageType::URIFailure => 400,
41 MessageType::GeneralFailure => 401,
42 MessageType::URIAcquire => 600,
43 MessageType::Configuration => 601,
44 }
45 }
46
47 fn description(&self) -> &str {
48 match self {
49 MessageType::Capabilities => "Capabilities",
50 MessageType::Log => "Log",
51 MessageType::Status => "Status",
52 MessageType::URIStart => "URI Start",
53 MessageType::URIDone => "URI Done",
54 MessageType::URIFailure => "URI Failure",
55 MessageType::GeneralFailure => "General Failure",
56 MessageType::URIAcquire => "URI Acquire",
57 MessageType::Configuration => "Configuration",
58 }
59 }
60
61 pub fn from_bytes(input: &[u8]) -> IResult<&[u8], MessageType> {
62 let (input, code) = digit1(input)?;
65 let (input, _) = take_until("\n")(input)?;
66 let (input, _) = newline(input)?;
67
68 match code {
69 b"100" => Ok((input, MessageType::Capabilities)),
70 b"101" => Ok((input, MessageType::Log)),
71 b"102" => Ok((input, MessageType::Status)),
72 b"200" => Ok((input, MessageType::URIStart)),
73 b"201" => Ok((input, MessageType::URIDone)),
74 b"400" => Ok((input, MessageType::URIFailure)),
75 b"401" => Ok((input, MessageType::GeneralFailure)),
76 b"600" => Ok((input, MessageType::URIAcquire)),
77 b"601" => Ok((input, MessageType::Configuration)),
78 _ => unimplemented!("Unknown message type: {:?}", code),
79 }
80 }
81}
82
83#[derive(Debug, PartialEq)]
84pub struct Message {
85 pub message_type: MessageType,
86 pub headers: Vec<(String, String)>,
87}
88
89fn key_value_pair(input: &[u8]) -> IResult<&[u8], (String, String)> {
90 let mut parse_key = map_res(take_until(":"), |buf| std::str::from_utf8(buf));
91 let mut parse_value = map_res(take_until("\n"), |buf| std::str::from_utf8(buf));
92
93 let (input, key) = parse_key.parse(input)?;
94 let (input, _) = char(':')(input)?;
95 let (input, _) = space0(input)?;
96 let (input, value) = parse_value.parse(input)?;
97 let (input, _) = newline(input)?;
98
99 let res = (key.to_string(), value.to_string());
100 Ok((input, res))
101}
102
103impl Message {
104 pub fn new(message_type: MessageType, headers: Vec<(&str, &str)>) -> Message {
108 Message {
109 message_type,
110 headers: headers
111 .iter()
112 .map(|(k, v)| (k.to_string(), v.to_string()))
113 .collect(),
114 }
115 }
116
117 fn parse(input: &[u8]) -> IResult<&[u8], Message> {
118 let (input, message_type) = MessageType::from_bytes(input)?;
120
121 let (input, headers) = many0(key_value_pair).parse(input)?;
123
124 let (input, _) = newline(input)?;
126
127 Ok((
128 input,
129 Message {
130 message_type,
131 headers,
132 },
133 ))
134 }
135
136 pub fn from_bytes(input: &[u8]) -> Result<Message, Error> {
137 match Message::parse(input) {
138 Ok((b"", message)) => Ok(message),
139 Ok((_, _)) => Err(Error::MessageTooMuchData),
140 Err(err) => Err(Error::MessageParse(format!("{err}"))),
141 }
142 }
143
144 pub fn status(message: &str) -> Self {
145 Self::new(MessageType::Status, vec![("Message", message)])
146 }
147
148 pub fn general_failure(message: &str) -> Self {
149 Self::new(MessageType::GeneralFailure, vec![("Message", message)])
150 }
151
152 pub fn uri_start(uri: &str, size: u64, last_modified: &str) -> Self {
153 Self::new(
154 MessageType::URIStart,
155 vec![
156 ("URI", uri),
157 ("Size", &size.to_string()),
158 ("Last-Modified", last_modified),
159 ],
160 )
161 }
162
163 pub fn uri_failure(uri: &str, message: &str) -> Self {
164 Self::new(
165 MessageType::URIFailure,
166 vec![("URI", uri), ("Message", message)],
167 )
168 }
169
170 pub fn uri_success(uri: &str, filename: &str) -> Self {
171 Self::new(
172 MessageType::URIDone,
173 vec![("URI", uri), ("Filename", filename)],
174 )
175 }
176
177 pub fn header(&self, key: &str) -> Result<&str, Error> {
178 self.headers
179 .iter()
180 .find(|(k, _)| k == key)
181 .map(|(_, v)| v.as_str())
182 .ok_or(Error::HeaderNotFound(key.to_string()))
183 }
184}
185
186impl Display for Message {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 writeln!(
189 f,
190 "{} {}",
191 self.message_type.code(),
192 self.message_type.description()
193 )?;
194 for (key, value) in &self.headers {
195 writeln!(f, "{key}: {value}")?;
196 }
197 writeln!(f)?;
198 Ok(())
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{cover_debug, cover_error};

    /// Parses `input` as a status line and asserts it yields `expected` and consumes everything.
    fn check_parse(input: &[u8], expected: MessageType) {
        let (rest, message_type) =
            MessageType::from_bytes(input).expect("Failed to parse message type");
        assert_eq!(message_type, expected);
        assert_eq!(rest, &b""[..])
    }

    #[test]
    fn test_coverage() -> Result<(), Box<dyn std::error::Error>> {
        let message = Message::new(MessageType::Log, vec![]);
        cover_debug(&message);

        let error = Error::HeaderNotFound("text".to_string());
        cover_error(&error);
        cover_debug(&error);

        Ok(())
    }

    #[test]
    fn test_message_codes() {
        assert_eq!(MessageType::Capabilities.code(), 100);
        assert_eq!(MessageType::Log.code(), 101);
        assert_eq!(MessageType::Status.code(), 102);
        assert_eq!(MessageType::URIStart.code(), 200);
        assert_eq!(MessageType::URIDone.code(), 201);
        assert_eq!(MessageType::URIFailure.code(), 400);
        assert_eq!(MessageType::GeneralFailure.code(), 401);
        assert_eq!(MessageType::URIAcquire.code(), 600);
        assert_eq!(MessageType::Configuration.code(), 601);
    }

    #[test]
    fn test_message_descriptions() {
        assert_eq!(MessageType::Capabilities.description(), "Capabilities");
        assert_eq!(MessageType::Log.description(), "Log");
        assert_eq!(MessageType::Status.description(), "Status");
        assert_eq!(MessageType::URIStart.description(), "URI Start");
        assert_eq!(MessageType::URIDone.description(), "URI Done");
        assert_eq!(MessageType::URIFailure.description(), "URI Failure");
        assert_eq!(MessageType::GeneralFailure.description(), "General Failure");
        assert_eq!(MessageType::URIAcquire.description(), "URI Acquire");
        assert_eq!(MessageType::Configuration.description(), "Configuration");
    }

    #[test]
    fn test_message_type_from_bytes() {
        check_parse(b"100 Capabilities\n", MessageType::Capabilities);
        check_parse(b"101 Log\n", MessageType::Log);
        check_parse(b"102 Status\n", MessageType::Status);
        check_parse(b"200 URI Start\n", MessageType::URIStart);
        check_parse(b"201 URI Done\n", MessageType::URIDone);
        check_parse(b"400 URI Failure\n", MessageType::URIFailure);
        check_parse(b"401 General Failure\n", MessageType::GeneralFailure);
        check_parse(b"600 URI Acquire\n", MessageType::URIAcquire);
        check_parse(b"601 Configuration\n", MessageType::Configuration);
    }

    #[test]
    #[should_panic(expected = "Unknown message type")]
    fn test_unimplemented_message_type() {
        let _ = MessageType::from_bytes(b"999 Unknown\n").unwrap();
    }

    #[test]
    fn test_key_value_pair() {
        let (input, (key, value)) = key_value_pair(b"Key: Value\n").unwrap();
        assert_eq!(key, "Key");
        assert_eq!(value, "Value");
        assert_eq!(input, &b""[..]);
    }

    #[test]
    fn test_message_from_bytes() -> Result<(), Box<dyn std::error::Error>> {
        let input = b"100 Capabilities\n\
                      Key: Value\n\
                      \n";
        let message = Message::from_bytes(input)?;
        assert_eq!(message.message_type, MessageType::Capabilities);

        let (key, value) = message.headers.first().unwrap();
        assert_eq!(key, "Key");
        assert_eq!(value, "Value");
        Ok(())
    }

    #[test]
    fn test_too_much_data() {
        let input = b"100 Capabilities\n\
                      Key: Value\n\
                      \ntoo much data";
        let result = Message::from_bytes(input);
        assert!(
            matches!(result, Err(Error::MessageTooMuchData)),
            "expected MessageTooMuchData, got {result:?}"
        );
    }

    #[test]
    fn test_buggy_message() {
        let input = b"100 Capabilities\n\
                      No header line\n\
                      \n";
        let result = Message::from_bytes(input);
        assert!(
            matches!(result, Err(Error::MessageParse(_))),
            "expected MessageParse, got {result:?}"
        );
    }

    #[test]
    fn test_message_write() -> Result<(), Box<dyn std::error::Error>> {
        let message = Message {
            message_type: MessageType::Capabilities,
            headers: vec![("Key".to_string(), "Value".to_string())],
        };

        let output = format!("{message}");
        assert_eq!(
            output,
            "100 Capabilities\n\
             Key: Value\n\
             \n"
        );
        Ok(())
    }

    #[test]
    fn test_round_trip() -> Result<(), Box<dyn std::error::Error>> {
        let message = Message {
            message_type: MessageType::Capabilities,
            headers: vec![("Key".to_string(), "Value".to_string())],
        };

        let output = format!("{message}");
        let parsed_message = Message::from_bytes(output.as_bytes())?;
        assert_eq!(parsed_message, message);
        Ok(())
    }

    #[test]
    fn test_constructor_round_trip() -> Result<(), Box<dyn std::error::Error>> {
        let messages = [
            Message::status("working"),
            Message::general_failure("broken"),
            Message::uri_start("http://example.com/a", 42, "Thu, 01 Jan 1970 00:00:00 GMT"),
            Message::uri_failure("http://example.com/a", "not found"),
            Message::uri_success("http://example.com/a", "/tmp/a"),
        ];
        for message in messages {
            let parsed = Message::from_bytes(format!("{message}").as_bytes())?;
            assert_eq!(parsed, message);
        }
        Ok(())
    }

    #[test]
    fn test_header_lookup() {
        let failure = Message::uri_failure("http://example.com/a", "boom");
        assert_eq!(failure.header("URI").unwrap(), "http://example.com/a");
        assert_eq!(failure.header("Message").unwrap(), "boom");
        match failure.header("Missing") {
            Err(Error::HeaderNotFound(key)) => assert_eq!(key, "Missing"),
            other => panic!("expected HeaderNotFound, got {other:?}"),
        }

        let start = Message::uri_start("http://example.com/a", 42, "today");
        assert_eq!(start.header("Size").unwrap(), "42");
    }

    #[test]
    fn test_description() {
        let message = Message {
            message_type: MessageType::Capabilities,
            headers: vec![],
        };
        assert_eq!(message.description(), "100 Capabilities");
    }
}