1use bytes::{Buf, BufMut, Bytes, BytesMut};
4use std::cmp;
5use std::collections::HashSet;
6use std::error;
7#[cfg(unix)]
8use std::ffi::OsStr;
9use std::io::{self, Cursor};
10use std::mem;
11#[cfg(unix)]
12use std::os::unix::ffi::OsStrExt;
13
/// Properties announced in a hello message, as produced by [`parse_hello`].
///
/// `ServerSpec::default()` yields an empty spec: no capabilities, no encoding
/// and no process ids.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
pub struct ServerSpec {
    /// Names listed on the `capabilities:` line(s).
    pub capabilities: HashSet<String>,
    /// Value of the `encoding:` line, if one was present.
    pub encoding: Option<String>,
    /// Value of the `pid:` line, if one was present.
    pub process_id: Option<u32>,
    /// Value of the `pgid:` line, if one was present.
    pub process_group_id: Option<u32>,
}
22
23pub fn parse_hello(src: Bytes) -> io::Result<ServerSpec> {
25 let mut capabilities = HashSet::new();
26 let mut encoding = None;
27 let mut process_id = None;
28 let mut process_group_id = None;
29
30 for line in decode_latin1(src).lines() {
31 let mut split = line.splitn(2, ':').fuse();
32 let key = split.next().unwrap();
33 let value = split
34 .next()
35 .map(|s| s.trim_start())
36 .ok_or(new_parse_error(format!(
37 "malformed line in hello message: {}",
38 line
39 )))?;
40
41 match key {
42 "capabilities" => {
43 capabilities.extend(value.split_whitespace().map(|s| s.to_owned()));
44 }
45 "encoding" => {
46 encoding = Some(value.to_owned());
47 }
48 "pid" => {
49 let n = value
50 .parse()
51 .map_err(|_| new_parse_error(format!("malformed pid: {}", value)))?;
52 process_id = Some(n);
53 }
54 "pgid" => {
55 let n = value
56 .parse()
57 .map_err(|_| new_parse_error(format!("malformed pgid: {}", value)))?;
58 process_group_id = Some(n);
59 }
60 _ => {} }
62 }
63
64 Ok(ServerSpec {
65 capabilities,
66 encoding,
67 process_id,
68 process_group_id,
69 })
70}
71
72pub fn parse_result_code(data: Bytes) -> io::Result<i32> {
74 if data.len() < mem::size_of::<i32>() {
75 let msg = format!("result code too short: {}", data.len());
76 return Err(io::Error::new(io::ErrorKind::InvalidData, msg));
77 }
78 Ok(Cursor::new(data).get_i32())
79}
80
81pub fn pack_result_code(code: i32) -> Bytes {
83 let mut data = BytesMut::with_capacity(mem::size_of::<i32>());
84 data.put_i32(code);
85 data.freeze()
86}
87
88pub fn pack_args(args: impl IntoIterator<Item = impl AsRef<[u8]>>) -> Bytes {
94 let mut args_iter = args.into_iter();
95 if let Some(a) = args_iter.next() {
96 assert!(!a.as_ref().contains(&0), "argument shouldn't contain NUL");
97 let mut dst = BytesMut::with_capacity(cmp::max(a.as_ref().len(), 200));
98 dst.put_slice(a.as_ref());
99 for a in args_iter {
100 assert!(!a.as_ref().contains(&0), "argument shouldn't contain NUL");
101 dst.reserve(1 + a.as_ref().len());
102 dst.put_u8(b'\0');
103 dst.put_slice(a.as_ref());
104 }
105 dst.freeze()
106 } else {
107 Bytes::new()
108 }
109}
110
111#[cfg(unix)]
113pub fn pack_args_os(args: impl IntoIterator<Item = impl AsRef<OsStr>>) -> Bytes {
114 pack_args(args.into_iter().map(|a| a.as_ref().as_bytes().to_owned()))
115}
116
/// Decodes `s` as Latin-1 (ISO-8859-1). Each byte becomes the `char` with the
/// same code point, so decoding cannot fail.
fn decode_latin1(s: impl AsRef<[u8]>) -> String {
    let bytes = s.as_ref();
    let mut decoded = String::with_capacity(bytes.len());
    for &byte in bytes {
        decoded.push(char::from(byte));
    }
    decoded
}
120
/// Wraps `error` (a message or any boxable error) in an [`io::Error`] of kind
/// [`io::ErrorKind::InvalidData`].
fn new_parse_error(error: impl Into<Box<dyn error::Error + Send + Sync>>) -> io::Error {
    let kind = io::ErrorKind::InvalidData;
    io::Error::new(kind, error)
}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::iter::FromIterator;
    use std::panic;

    #[test]
    fn parse_hello_empty() {
        let spec = ServerSpec {
            capabilities: HashSet::new(),
            encoding: None,
            process_id: None,
            process_group_id: None,
        };
        assert_eq!(parse_hello(Bytes::from_static(b"")).unwrap(), spec);
    }

    #[test]
    fn parse_hello_some() {
        let src = [
            b"capabilities: getencoding runcommand".as_ref(),
            b"encoding: utf-8".as_ref(),
            b"pid: 12345".as_ref(),
            b"pgid: 6789",
        ]
        .join(&b'\n');
        let spec = ServerSpec {
            capabilities: HashSet::from_iter(
                ["getencoding", "runcommand"].iter().map(|&s| s.to_owned()),
            ),
            encoding: Some("utf-8".to_owned()),
            process_id: Some(12345),
            process_group_id: Some(6789),
        };
        assert_eq!(parse_hello(Bytes::from(src)).unwrap(), spec);
    }

    #[test]
    fn parse_hello_unsupported() {
        // Unknown keys must be ignored; unknown capability names are still kept.
        let src = [
            b"capabilities: runcommand unsupported".as_ref(),
            b"unsupported: value".as_ref(),
            b"pid: 12345".as_ref(),
        ]
        .join(&b'\n');
        let spec = ServerSpec {
            capabilities: HashSet::from_iter(
                ["runcommand", "unsupported"].iter().map(|&s| s.to_owned()),
            ),
            encoding: None,
            process_id: Some(12345),
            process_group_id: None,
        };
        assert_eq!(parse_hello(Bytes::from(src)).unwrap(), spec);
    }

    #[test]
    fn parse_hello_malformed() {
        // An empty line has no `key:` separator.
        assert_eq!(
            parse_hello(Bytes::from_static(b"\n")).unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
        assert_eq!(
            parse_hello(Bytes::from_static(b"caba")).unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
        assert_eq!(
            parse_hello(Bytes::from_static(b"pid: foo"))
                .unwrap_err()
                .kind(),
            io::ErrorKind::InvalidData
        );
        assert_eq!(
            parse_hello(Bytes::from_static(b"pgid: foo"))
                .unwrap_err()
                .kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[test]
    fn result_code_roundtrip() {
        for &code in &[0, 1, -1, 256, 2_147_483_647, -2_147_483_648] {
            assert_eq!(parse_result_code(pack_result_code(code)).unwrap(), code);
        }
    }

    #[test]
    fn pack_result_code_big_endian() {
        // Pins the wire format: four bytes, most significant byte first.
        assert_eq!(pack_result_code(1), Bytes::from_static(b"\0\0\0\x01"));
        assert_eq!(pack_result_code(-2), Bytes::from_static(b"\xff\xff\xff\xfe"));
    }

    #[test]
    fn parse_result_code_short() {
        assert_eq!(
            parse_result_code(Bytes::from_static(b"\0\0\0"))
                .unwrap_err()
                .kind(),
            io::ErrorKind::InvalidData
        );
        assert_eq!(
            parse_result_code(Bytes::new()).unwrap_err().kind(),
            io::ErrorKind::InvalidData
        );
    }

    #[test]
    fn pack_args_some() {
        assert_eq!(pack_args(&[] as &[&[u8]]), Bytes::new());
        assert_eq!(pack_args(&[b"foo".as_ref()]), Bytes::from_static(b"foo"));
        assert_eq!(
            pack_args(&[b"foo".as_ref(), b"".as_ref(), b"bar".as_ref()]),
            Bytes::from_static(b"foo\0\0bar")
        );
    }

    #[test]
    fn pack_args_nul() {
        assert!(panic::catch_unwind(|| pack_args(&[b"\0"])).is_err());
        assert!(panic::catch_unwind(|| pack_args(&[b"fo\0"])).is_err());
        assert!(panic::catch_unwind(|| pack_args(&[b"foo", b"\0ar"])).is_err());
        assert!(panic::catch_unwind(|| pack_args(&[b"foo", b"bar", b"b\0z"])).is_err());
    }

    #[cfg(unix)]
    #[test]
    fn pack_args_os_some() {
        assert_eq!(pack_args_os(&[] as &[&str]), Bytes::new());
        assert_eq!(
            pack_args_os(&["foo", "", "bar"]),
            Bytes::from_static(b"foo\0\0bar")
        );
    }
}