1#![deny(missing_docs)]
2
3use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
18use std::ffi::{OsStr, OsString};
19use std::fmt::{self, Display};
20use std::path::PathBuf;
21
/// The well-known TCP port for SSH (22).
///
/// NOTE(review): not referenced by the code in this file; presumably callers use it as
/// the fallback when an [`Address`] carries no explicit port — confirm.
pub const DEFAULT_SSH_PORT: u16 = 22;
24
/// A remote endpoint written as `host` or `host:port`.
///
/// Parse one from a string with [`str::parse`], and render it back with `Display`
/// (or serde, which uses the same string form).
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct Address {
    /// Host name or IP address of the remote machine.
    pub host: String,
    /// Explicit TCP port, or `None` to leave the port unspecified (no `-p` is passed
    /// to `ssh`).
    pub port: Option<u16>,
}
34
35#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
37pub enum AddressError {
38 #[error("invalid address format")]
40 InvalidFormat,
41
42 #[error("invalid address port")]
44 InvalidPort,
45}
46
47impl Address {
48 pub fn new(host: &str, port: u16) -> Address {
50 Address {
51 host: host.to_string(),
52 port: Some(port),
53 }
54 }
55
56 pub fn from_host(host: &str) -> Address {
58 Address {
59 host: host.to_string(),
60 port: None,
61 }
62 }
63}
64
65impl std::str::FromStr for Address {
66 type Err = AddressError;
67
68 fn from_str(address: &str) -> Result<Self, Self::Err> {
70 let mut iter = address.split(':');
71 if let Some(host) = iter.next() {
72 if host.is_empty() {
74 return Err(AddressError::InvalidFormat);
75 }
76
77 if let Some(port) = iter.next() {
78 if iter.next().is_some() {
80 return Err(AddressError::InvalidFormat);
81 }
82
83 if let Ok(port) = port.parse() {
85 Ok(Address::new(host, port))
86 } else {
87 Err(AddressError::InvalidPort)
88 }
89 } else {
90 Ok(Address::from_host(address))
91 }
92 } else {
93 Err(AddressError::InvalidFormat)
94 }
95 }
96}
97
98struct AddressVisitor;
99
100impl<'de> de::Visitor<'de> for AddressVisitor {
101 type Value = Address;
102
103 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
104 formatter.write_str("host[:port]")
105 }
106
107 fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
108 where
109 E: de::Error,
110 {
111 match value.parse::<Address>() {
112 Ok(addr) => Ok(addr),
113 Err(AddressError::InvalidFormat) => {
114 Err(E::custom("invalid address format"))
115 }
116 Err(AddressError::InvalidPort) => {
117 Err(E::custom("invalid port number"))
118 }
119 }
120 }
121}
122
123impl<'de> Deserialize<'de> for Address {
124 fn deserialize<D>(deserializer: D) -> Result<Address, D::Error>
125 where
126 D: Deserializer<'de>,
127 {
128 deserializer.deserialize_str(AddressVisitor)
129 }
130}
131
132impl Serialize for Address {
133 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134 where
135 S: Serializer,
136 {
137 serializer.serialize_str(&self.to_string())
138 }
139}
140
141impl Display for Address {
142 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
143 if let Some(port) = self.port {
144 write!(f, "{}:{}", self.host, port)
145 } else {
146 write!(f, "{}", self.host)
147 }
148 }
149}
150
151#[derive(Clone, Debug, Eq, PartialEq)]
153pub struct SshParams {
154 pub address: Address,
156
157 pub identity: Option<PathBuf>,
159
160 pub user: Option<String>,
162
163 pub strict_host_key_checking: bool,
171}
172
173impl Default for SshParams {
174 fn default() -> SshParams {
175 SshParams {
176 address: Address::default(),
177 identity: None,
178 user: None,
179 strict_host_key_checking: true,
180 }
181 }
182}
183
184impl SshParams {
185 pub fn command<S: AsRef<OsStr>>(&self, args: &[S]) -> Vec<OsString> {
187 let mut output: Vec<OsString> = Vec::new();
188 output.push("ssh".into());
189
190 if !self.strict_host_key_checking {
191 output.extend_from_slice(&[
192 "-oStrictHostKeyChecking=no".into(),
193 "-oUserKnownHostsFile=/dev/null".into(),
194 ]);
195 }
196 output.push("-oBatchMode=yes".into());
197
198 if let Some(identity) = &self.identity {
199 output.extend_from_slice(&["-i".into(), identity.into()]);
200 }
201
202 if let Some(port) = self.address.port {
203 output.extend_from_slice(&["-p".into(), port.to_string().into()]);
204 }
205
206 let target = if let Some(user) = &self.user {
207 format!("{}@{}", user, self.address.host)
208 } else {
209 self.address.host.clone()
210 };
211
212 output.push(target.into());
213 output.extend(args.iter().map(|arg| arg.into()));
214
215 output
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use serde_test::{assert_tokens, Token};
    use std::path::Path;

    /// `host` and `host:port` parse; malformed input yields the matching error.
    #[test]
    fn test_address_parse() {
        let parse = |s: &str| s.parse::<Address>();

        assert_eq!(parse("a"), Ok(Address::from_host("a")));
        assert_eq!(parse("a:1234"), Ok(Address::new("a", 1234)));

        let failures = [
            ("", AddressError::InvalidFormat),
            ("a:b", AddressError::InvalidPort),
            ("a:1234:5678", AddressError::InvalidFormat),
        ];
        for (input, expected) in failures.iter().cloned() {
            assert_eq!(parse(input), Err(expected), "input: {:?}", input);
        }
    }

    /// `Display` omits the port when it is absent.
    #[test]
    fn test_address_display() {
        assert_eq!(Address::from_host("abc").to_string(), "abc");
        assert_eq!(Address::new("abc", 123).to_string(), "abc:123");
    }

    /// Serde round-trips an address through a single string token.
    #[test]
    fn test_address_tokens() {
        let cases = [
            (Address::from_host("abc"), "abc"),
            (Address::new("abc", 123), "abc:123"),
        ];
        for (addr, text) in cases.iter() {
            assert_tokens(addr, &[Token::Str(*text)]);
        }
    }

    /// All optional flags appear, in order, when every field is populated.
    #[test]
    fn test_command() {
        let params = SshParams {
            address: "localhost:9222".parse().unwrap(),
            identity: Some(Path::new("/myIdentity").to_path_buf()),
            user: Some("me".to_string()),
            strict_host_key_checking: false,
        };
        let expected = vec![
            "ssh",
            "-oStrictHostKeyChecking=no",
            "-oUserKnownHostsFile=/dev/null",
            "-oBatchMode=yes",
            "-i",
            "/myIdentity",
            "-p",
            "9222",
            "me@localhost",
            "arg1",
            "arg2",
        ];
        assert_eq!(params.command(&["arg1", "arg2"]), expected);
    }
}