Skip to main content

apt_transport/
message.rs

//! Raw message types and parsing for the APT transport protocol.
//!
//! You probably want to use the higher-level abstractions in
//! the root of this crate instead of working with these types directly.
5
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.
8
9use std::fmt::Display;
10
11use nom::bytes::complete::take_until;
12use nom::character::complete::{char, digit1, newline, space0};
13use nom::combinator::map_res;
14use nom::multi::many0;
15use nom::{IResult, Parser};
16
17use crate::Error;
18
/// The kind of a transport message, as announced on the message's first line.
///
/// Every variant maps to a fixed numeric code and a short textual description;
/// both are emitted when a message is formatted, and the numeric code is what
/// the parser uses to recognise the variant.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum MessageType {
    /// Code `100`, described as "Capabilities".
    Capabilities,
    /// Code `101`, described as "Log".
    Log,
    /// Code `102`, described as "Status".
    Status,
    /// Code `200`, described as "URI Start".
    URIStart,
    /// Code `201`, described as "URI Done".
    URIDone,
    /// Code `400`, described as "URI Failure".
    URIFailure,
    /// Code `401`, described as "General Failure".
    GeneralFailure,
    /// Code `600`, described as "URI Acquire".
    URIAcquire,
    /// Code `601`, described as "Configuration".
    Configuration,
}
31
32impl MessageType {
33    fn code(&self) -> u16 {
34        match self {
35            MessageType::Capabilities => 100,
36            MessageType::Log => 101,
37            MessageType::Status => 102,
38            MessageType::URIStart => 200,
39            MessageType::URIDone => 201,
40            MessageType::URIFailure => 400,
41            MessageType::GeneralFailure => 401,
42            MessageType::URIAcquire => 600,
43            MessageType::Configuration => 601,
44        }
45    }
46
47    fn description(&self) -> &str {
48        match self {
49            MessageType::Capabilities => "Capabilities",
50            MessageType::Log => "Log",
51            MessageType::Status => "Status",
52            MessageType::URIStart => "URI Start",
53            MessageType::URIDone => "URI Done",
54            MessageType::URIFailure => "URI Failure",
55            MessageType::GeneralFailure => "General Failure",
56            MessageType::URIAcquire => "URI Acquire",
57            MessageType::Configuration => "Configuration",
58        }
59    }
60
61    pub fn from_bytes(input: &[u8]) -> IResult<&[u8], MessageType> {
62        // The first line of a message is the message type and a description,
63        // followed by a newline
64        let (input, code) = digit1(input)?;
65        let (input, _) = take_until("\n")(input)?;
66        let (input, _) = newline(input)?;
67
68        match code {
69            b"100" => Ok((input, MessageType::Capabilities)),
70            b"101" => Ok((input, MessageType::Log)),
71            b"102" => Ok((input, MessageType::Status)),
72            b"200" => Ok((input, MessageType::URIStart)),
73            b"201" => Ok((input, MessageType::URIDone)),
74            b"400" => Ok((input, MessageType::URIFailure)),
75            b"401" => Ok((input, MessageType::GeneralFailure)),
76            b"600" => Ok((input, MessageType::URIAcquire)),
77            b"601" => Ok((input, MessageType::Configuration)),
78            _ => unimplemented!("Unknown message type: {:?}", code),
79        }
80    }
81}
82
83#[derive(Debug, PartialEq)]
84pub struct Message {
85    pub message_type: MessageType,
86    pub headers: Vec<(String, String)>,
87}
88
89fn key_value_pair(input: &[u8]) -> IResult<&[u8], (String, String)> {
90    let mut parse_key = map_res(take_until(":"), |buf| std::str::from_utf8(buf));
91    let mut parse_value = map_res(take_until("\n"), |buf| std::str::from_utf8(buf));
92
93    let (input, key) = parse_key.parse(input)?;
94    let (input, _) = char(':')(input)?;
95    let (input, _) = space0(input)?;
96    let (input, value) = parse_value.parse(input)?;
97    let (input, _) = newline(input)?;
98
99    let res = (key.to_string(), value.to_string());
100    Ok((input, res))
101}
102
103impl Message {
104    //
105    // Construction and logging functions cannot log, as they are used by the logger
106    //
107    pub fn new(message_type: MessageType, headers: Vec<(&str, &str)>) -> Message {
108        Message {
109            message_type,
110            headers: headers
111                .iter()
112                .map(|(k, v)| (k.to_string(), v.to_string()))
113                .collect(),
114        }
115    }
116
117    fn parse(input: &[u8]) -> IResult<&[u8], Message> {
118        // Parse the MessageType from the message
119        let (input, message_type) = MessageType::from_bytes(input)?;
120
121        // Now take the headers; these are key-value pairs separated by a colon
122        let (input, headers) = many0(key_value_pair).parse(input)?;
123
124        // Now take the final newline.
125        let (input, _) = newline(input)?;
126
127        Ok((
128            input,
129            Message {
130                message_type,
131                headers,
132            },
133        ))
134    }
135
136    pub fn from_bytes(input: &[u8]) -> Result<Message, Error> {
137        match Message::parse(input) {
138            Ok((b"", message)) => Ok(message),
139            Ok((_, _)) => Err(Error::MessageTooMuchData),
140            Err(err) => Err(Error::MessageParse(format!("{err}"))),
141        }
142    }
143
144    pub fn status(message: &str) -> Self {
145        Self::new(MessageType::Status, vec![("Message", message)])
146    }
147
148    pub fn general_failure(message: &str) -> Self {
149        Self::new(MessageType::GeneralFailure, vec![("Message", message)])
150    }
151
152    pub fn uri_start(uri: &str, size: u64, last_modified: &str) -> Self {
153        Self::new(
154            MessageType::URIStart,
155            vec![
156                ("URI", uri),
157                ("Size", &size.to_string()),
158                ("Last-Modified", last_modified),
159            ],
160        )
161    }
162
163    pub fn uri_failure(uri: &str, message: &str) -> Self {
164        Self::new(
165            MessageType::URIFailure,
166            vec![("URI", uri), ("Message", message)],
167        )
168    }
169
170    pub fn uri_success(uri: &str, filename: &str) -> Self {
171        Self::new(
172            MessageType::URIDone,
173            vec![("URI", uri), ("Filename", filename)],
174        )
175    }
176
177    pub fn header(&self, key: &str) -> Result<&str, Error> {
178        self.headers
179            .iter()
180            .find(|(k, _)| k == key)
181            .map(|(_, v)| v.as_str())
182            .ok_or(Error::HeaderNotFound(key.to_string()))
183    }
184}
185
186impl Display for Message {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        writeln!(
189            f,
190            "{} {}",
191            self.message_type.code(),
192            self.message_type.description()
193        )?;
194        for (key, value) in &self.headers {
195            writeln!(f, "{key}: {value}")?;
196        }
197        writeln!(f)?;
198        Ok(())
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_util::{cover_debug, cover_error};

    /// Asserts that `input` parses to `expected` and that nothing is left over.
    fn check_parse(input: &[u8], expected: MessageType) {
        let (remaining, parsed) =
            MessageType::from_bytes(input).expect("Failed to parse message type");
        assert_eq!(parsed, expected);
        assert_eq!(remaining, &b""[..])
    }

    /// Convenience fixture: a `Capabilities` message with one `Key: Value` header.
    fn sample_message() -> Message {
        Message {
            message_type: MessageType::Capabilities,
            headers: vec![("Key".to_string(), "Value".to_string())],
        }
    }

    #[test]
    fn test_coverage() -> Result<(), Box<dyn std::error::Error>> {
        let log_message = Message::new(MessageType::Log, vec![]);
        cover_debug(&log_message);

        let missing_header = Error::HeaderNotFound("text".to_string());
        cover_error(&missing_header);
        cover_debug(&missing_header);

        Ok(())
    }

    #[test]
    fn test_message_codes() {
        let expected_codes = [
            (MessageType::Capabilities, 100),
            (MessageType::Log, 101),
            (MessageType::Status, 102),
            (MessageType::URIStart, 200),
            (MessageType::URIDone, 201),
            (MessageType::URIFailure, 400),
            (MessageType::GeneralFailure, 401),
            (MessageType::URIAcquire, 600),
            (MessageType::Configuration, 601),
        ];
        for (message_type, code) in expected_codes {
            assert_eq!(message_type.code(), code);
        }
    }

    #[test]
    fn test_message_descriptions() {
        let expected_descriptions = [
            (MessageType::Capabilities, "Capabilities"),
            (MessageType::Log, "Log"),
            (MessageType::Status, "Status"),
            (MessageType::URIStart, "URI Start"),
            (MessageType::URIDone, "URI Done"),
            (MessageType::URIFailure, "URI Failure"),
            (MessageType::GeneralFailure, "General Failure"),
            (MessageType::URIAcquire, "URI Acquire"),
            (MessageType::Configuration, "Configuration"),
        ];
        for (message_type, description) in expected_descriptions {
            assert_eq!(message_type.description(), description);
        }
    }

    #[test]
    fn test_message_type_from_bytes() {
        check_parse(b"100 Capabilities\n", MessageType::Capabilities);
        check_parse(b"101 Log\n", MessageType::Log);
        check_parse(b"102 Status\n", MessageType::Status);
        check_parse(b"200 URI Start\n", MessageType::URIStart);
        check_parse(b"201 URI Done\n", MessageType::URIDone);
        check_parse(b"400 URI Failure\n", MessageType::URIFailure);
        check_parse(b"401 General Failure\n", MessageType::GeneralFailure);
        check_parse(b"600 URI Acquire\n", MessageType::URIAcquire);
        check_parse(b"601 Configuration\n", MessageType::Configuration);
    }

    // Unknown codes currently abort parsing with a panic rather than an error.
    #[test]
    #[should_panic(expected = "Unknown message type")]
    fn test_unimplemented_message_type() {
        let _ = MessageType::from_bytes(b"999 Unknown\n").unwrap();
    }

    #[test]
    fn test_key_value_pair() {
        let (remaining, (key, value)) = key_value_pair(b"Key: Value\n").unwrap();
        assert_eq!(key, "Key");
        assert_eq!(value, "Value");
        assert_eq!(remaining, &b""[..]);
    }

    #[test]
    fn test_message_from_bytes() -> Result<(), Box<dyn std::error::Error>> {
        let parsed = Message::from_bytes(b"100 Capabilities\nKey: Value\n\n")?;
        assert_eq!(parsed.message_type, MessageType::Capabilities);

        let (key, value) = parsed.headers.first().unwrap();
        assert_eq!(key, "Key");
        assert_eq!(value, "Value");
        Ok(())
    }

    #[test]
    fn test_too_much_data() -> Result<(), Box<dyn std::error::Error>> {
        // Bytes after the terminating blank line must be rejected.
        let result = Message::from_bytes(b"100 Capabilities\nKey: Value\n\ntoo much data");
        if !matches!(result, Err(Error::MessageTooMuchData)) {
            panic!("Unexpected error"); // LCOV_EXCL_LINE
        }
        Ok(())
    }

    #[test]
    fn test_buggy_message() -> Result<(), Box<dyn std::error::Error>> {
        // The second line has no colon, so it is not a valid header.
        let result = Message::from_bytes(b"100 Capabilities\nNo header line\n\n");
        if !matches!(result, Err(Error::MessageParse(_))) {
            panic!("Unexpected error"); // LCOV_EXCL_LINE
        }
        Ok(())
    }

    #[test]
    fn test_message_write() -> Result<(), Box<dyn std::error::Error>> {
        let rendered = sample_message().to_string();
        assert_eq!(rendered, "100 Capabilities\nKey: Value\n\n");
        Ok(())
    }

    #[test]
    fn test_round_trip() -> Result<(), Box<dyn std::error::Error>> {
        let original = sample_message();

        let wire = original.to_string();
        let reparsed = Message::from_bytes(wire.as_bytes())?;
        assert_eq!(reparsed, original);
        Ok(())
    }

    #[test]
    fn test_description() {
        let message = Message {
            message_type: MessageType::Capabilities,
            headers: vec![],
        };
        // NOTE(review): `Message::description` is not defined in this file;
        // presumably it comes from another `impl Message` block in the crate —
        // confirm, otherwise this test does not compile.
        assert_eq!(message.description(), "100 Capabilities");
    }
}