tokio_hglib/
message.rs

1//! Utilities for parsing and building command-server messages.
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use std::cmp;
5use std::collections::HashSet;
6use std::error;
7#[cfg(unix)]
8use std::ffi::OsStr;
9use std::io::{self, Cursor};
10use std::mem;
11#[cfg(unix)]
12use std::os::unix::ffi::OsStrExt;
13
/// Server properties announced in the command server's initial "hello"
/// message; produced by [`parse_hello`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServerSpec {
    /// Capability names listed on the `capabilities:` line(s).
    pub capabilities: HashSet<String>,
    /// Value of the `encoding:` line, if the server sent one.
    pub encoding: Option<String>,
    /// Value of the `pid:` line, if the server sent one.
    pub process_id: Option<u32>,
    /// Value of the `pgid:` line, if the server sent one.
    pub process_group_id: Option<u32>,
}
22
23/// Parses "hello" response into `ServerSpec`.
24pub fn parse_hello(src: Bytes) -> io::Result<ServerSpec> {
25    let mut capabilities = HashSet::new();
26    let mut encoding = None;
27    let mut process_id = None;
28    let mut process_group_id = None;
29
30    for line in decode_latin1(src).lines() {
31        let mut split = line.splitn(2, ':').fuse();
32        let key = split.next().unwrap();
33        let value = split
34            .next()
35            .map(|s| s.trim_start())
36            .ok_or(new_parse_error(format!(
37                "malformed line in hello message: {}",
38                line
39            )))?;
40
41        match key {
42            "capabilities" => {
43                capabilities.extend(value.split_whitespace().map(|s| s.to_owned()));
44            }
45            "encoding" => {
46                encoding = Some(value.to_owned());
47            }
48            "pid" => {
49                let n = value
50                    .parse()
51                    .map_err(|_| new_parse_error(format!("malformed pid: {}", value)))?;
52                process_id = Some(n);
53            }
54            "pgid" => {
55                let n = value
56                    .parse()
57                    .map_err(|_| new_parse_error(format!("malformed pgid: {}", value)))?;
58                process_group_id = Some(n);
59            }
60            _ => {} // ignores unknown key
61        }
62    }
63
64    Ok(ServerSpec {
65        capabilities,
66        encoding,
67        process_id,
68        process_group_id,
69    })
70}
71
72/// Parses command exit code into integer.
73pub fn parse_result_code(data: Bytes) -> io::Result<i32> {
74    if data.len() < mem::size_of::<i32>() {
75        let msg = format!("result code too short: {}", data.len());
76        return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
77    }
78    Ok(Cursor::new(data).get_i32())
79}
80
81/// Packs command exit code into bytes.
82pub fn pack_result_code(code: i32) -> Bytes {
83    let mut data = BytesMut::with_capacity(mem::size_of::<i32>());
84    data.put_i32(code);
85    data.freeze()
86}
87
88/// Packs command arguments of arbitrary encoding into bytes.
89///
90/// # Panics
91///
92/// Panics if argument contains `\0` character.
93pub fn pack_args(args: impl IntoIterator<Item = impl AsRef<[u8]>>) -> Bytes {
94    let mut args_iter = args.into_iter();
95    if let Some(a) = args_iter.next() {
96        assert!(!a.as_ref().contains(&0), "argument shouldn't contain NUL");
97        let mut dst = BytesMut::with_capacity(cmp::max(a.as_ref().len(), 200));
98        dst.put_slice(a.as_ref());
99        for a in args_iter {
100            assert!(!a.as_ref().contains(&0), "argument shouldn't contain NUL");
101            dst.reserve(1 + a.as_ref().len());
102            dst.put_u8(b'\0');
103            dst.put_slice(a.as_ref());
104        }
105        dst.freeze()
106    } else {
107        Bytes::new()
108    }
109}
110
111/// Packs command arguments of platform encoding into bytes.
112#[cfg(unix)]
113pub fn pack_args_os(args: impl IntoIterator<Item = impl AsRef<OsStr>>) -> Bytes {
114    pack_args(args.into_iter().map(|a| a.as_ref().as_bytes().to_owned()))
115}
116
/// Decodes `s` as Latin-1 (ISO-8859-1). Every byte becomes the Unicode scalar
/// value with the same number, so decoding never fails.
fn decode_latin1(s: impl AsRef<[u8]>) -> String {
    let bytes = s.as_ref();
    let mut decoded = String::with_capacity(bytes.len());
    for &byte in bytes {
        decoded.push(char::from(byte));
    }
    decoded
}
120
// TODO: error type
/// Wraps `cause` in an [`io::Error`] of kind [`io::ErrorKind::InvalidData`],
/// the kind this module reports for malformed server messages.
fn new_parse_error(cause: impl Into<Box<dyn error::Error + Send + Sync>>) -> io::Error {
    let boxed: Box<dyn error::Error + Send + Sync> = cause.into();
    io::Error::new(io::ErrorKind::InvalidData, boxed)
}
125
#[cfg(test)]
mod tests {
    //! Unit tests for message parsing and packing, including round trips
    //! between the `pack_*` and `parse_*` functions.

    use super::*;
    use std::iter::FromIterator;
    use std::panic;

    #[test]
    fn parse_hello_empty() {
        let spec = ServerSpec {
            capabilities: HashSet::new(),
            encoding: None,
            process_id: None,
            process_group_id: None,
        };
        assert_eq!(parse_hello(Bytes::from_static(b"")).unwrap(), spec);
    }

    #[test]
    fn parse_hello_some() {
        let src = [
            b"capabilities: getencoding runcommand".as_ref(),
            b"encoding: utf-8".as_ref(),
            b"pid: 12345".as_ref(),
            b"pgid: 6789",
        ]
        .join(&b'\n');
        let spec = ServerSpec {
            capabilities: HashSet::from_iter(
                ["getencoding", "runcommand"].iter().map(|&s| s.to_owned()),
            ),
            encoding: Some("utf-8".to_owned()),
            process_id: Some(12345),
            process_group_id: Some(6789),
        };
        assert_eq!(parse_hello(Bytes::from(src)).unwrap(), spec);
    }

    #[test]
    fn parse_hello_unsupported() {
        let src = [
            b"capabilities: runcommand unsupported".as_ref(),
            b"unsupported: value".as_ref(),
            b"pid: 12345".as_ref(),
        ]
        .join(&b'\n');
        let spec = ServerSpec {
            capabilities: HashSet::from_iter(
                ["runcommand", "unsupported"].iter().map(|&s| s.to_owned()),
            ),
            encoding: None,
            process_id: Some(12345),
            process_group_id: None,
        };
        assert_eq!(parse_hello(Bytes::from(src)).unwrap(), spec);
    }

    #[test]
    fn parse_hello_malformed() {
        assert_eq!(
            parse_hello(Bytes::from_static(b"\n")).unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
        assert_eq!(
            parse_hello(Bytes::from_static(b"caba")).unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
        assert_eq!(
            parse_hello(Bytes::from_static(b"pid: foo"))
                .unwrap_err()
                .kind(),
            io::ErrorKind::InvalidData
        );
        assert_eq!(
            parse_hello(Bytes::from_static(b"pgid: foo"))
                .unwrap_err()
                .kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[test]
    fn result_code_round_trip() {
        for &code in &[0, 1, -1, 255, i32::MIN, i32::MAX] {
            assert_eq!(parse_result_code(pack_result_code(code)).unwrap(), code);
        }
    }

    #[test]
    fn pack_result_code_big_endian() {
        assert_eq!(pack_result_code(258), Bytes::from_static(b"\0\0\x01\x02"));
        assert_eq!(
            pack_result_code(-1),
            Bytes::from_static(b"\xff\xff\xff\xff")
        );
    }

    #[test]
    fn parse_result_code_short_or_long() {
        assert_eq!(
            parse_result_code(Bytes::from_static(b"")).unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
        assert_eq!(
            parse_result_code(Bytes::from_static(b"\0\0\0"))
                .unwrap_err()
                .kind(),
            io::ErrorKind::InvalidData
        );
        // Bytes beyond the first four are ignored.
        assert_eq!(
            parse_result_code(Bytes::from_static(b"\0\0\0\x07extra")).unwrap(),
            7
        );
    }

    #[test]
    fn decode_latin1_non_ascii() {
        assert_eq!(decode_latin1(b"caf\xe9"), "caf\u{e9}");
    }

    #[test]
    fn pack_args_some() {
        assert_eq!(pack_args(&[] as &[&[u8]]), Bytes::new());
        assert_eq!(pack_args(&[b"foo".as_ref()]), Bytes::from_static(b"foo"));
        assert_eq!(
            pack_args(&[b"foo".as_ref(), b"".as_ref(), b"bar".as_ref()]),
            Bytes::from_static(b"foo\0\0bar")
        );
    }

    #[test]
    fn pack_args_nul() {
        assert!(panic::catch_unwind(|| pack_args(&[b"\0"])).is_err());
        assert!(panic::catch_unwind(|| pack_args(&[b"fo\0"])).is_err());
        assert!(panic::catch_unwind(|| pack_args(&[b"foo", b"\0ar"])).is_err());
        assert!(panic::catch_unwind(|| pack_args(&[b"foo", b"bar", b"b\0z"])).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn pack_args_os_some() {
        assert_eq!(
            pack_args_os(&["foo", "", "bar"]),
            Bytes::from_static(b"foo\0\0bar")
        );
    }
}