nbssh/
lib.rs

1#![deny(missing_docs)]
2
3//! SSH command generator. Example usage:
4//!
5//! ```rust
6//! use nbssh::{Address, SshParams};
7//! use std::process::Command;
8//!
9//! let params = SshParams {
10//!   address: Address::from_host("myHost"),
11//!   ..Default::default()
12//! };
13//! let args = params.command(&["echo", "hello"]);
14//! Command::new(&args[0]).args(&args[1..]).status().unwrap();
15//! ```
16
17use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
18use std::ffi::{OsStr, OsString};
19use std::fmt::{self, Display};
20use std::path::PathBuf;
21
/// Default SSH port number (22).
///
/// Note that `SshParams::command` does not substitute this value when
/// the address has no port set; it simply omits the "-p" option.
pub const DEFAULT_SSH_PORT: u16 = 22;
24
/// Host and optional port number. Can be serialized and deserialized
/// with serde using the "host[:port]" format, which is also the format
/// accepted by the `FromStr` implementation and produced by `Display`.
///
/// Because ":" is the host/port separator and at most one colon is
/// accepted, a host that itself contains a colon (for example a bare
/// IPv6 address) cannot be parsed from the string form.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct Address {
    /// Host name or IP address.
    pub host: String,
    /// Port number, or `None` if no port is set.
    pub port: Option<u16>,
}
34
/// Errors returned when parsing an `Address` from a string.
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
pub enum AddressError {
    /// The host part is empty (this includes the empty string and
    /// inputs such as ":22") or the address contains more than one
    /// colon.
    #[error("invalid address format")]
    InvalidFormat,

    /// The text after the colon could not be parsed as a u16.
    #[error("invalid address port")]
    InvalidPort,
}
46
47impl Address {
48    /// Create a new address.
49    pub fn new(host: &str, port: u16) -> Address {
50        Address {
51            host: host.to_string(),
52            port: Some(port),
53        }
54    }
55
56    /// Create a new address with no port number set.
57    pub fn from_host(host: &str) -> Address {
58        Address {
59            host: host.to_string(),
60            port: None,
61        }
62    }
63}
64
65impl std::str::FromStr for Address {
66    type Err = AddressError;
67
68    /// Parse an address in "host[:port]" format.
69    fn from_str(address: &str) -> Result<Self, Self::Err> {
70        let mut iter = address.split(':');
71        if let Some(host) = iter.next() {
72            // Reject empty hosts
73            if host.is_empty() {
74                return Err(AddressError::InvalidFormat);
75            }
76
77            if let Some(port) = iter.next() {
78                // Reject more than two colons
79                if iter.next().is_some() {
80                    return Err(AddressError::InvalidFormat);
81                }
82
83                // Parse the port
84                if let Ok(port) = port.parse() {
85                    Ok(Address::new(host, port))
86                } else {
87                    Err(AddressError::InvalidPort)
88                }
89            } else {
90                Ok(Address::from_host(address))
91            }
92        } else {
93            Err(AddressError::InvalidFormat)
94        }
95    }
96}
97
98struct AddressVisitor;
99
100impl<'de> de::Visitor<'de> for AddressVisitor {
101    type Value = Address;
102
103    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
104        formatter.write_str("host[:port]")
105    }
106
107    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
108    where
109        E: de::Error,
110    {
111        match value.parse::<Address>() {
112            Ok(addr) => Ok(addr),
113            Err(AddressError::InvalidFormat) => {
114                Err(E::custom("invalid address format"))
115            }
116            Err(AddressError::InvalidPort) => {
117                Err(E::custom("invalid port number"))
118            }
119        }
120    }
121}
122
123impl<'de> Deserialize<'de> for Address {
124    fn deserialize<D>(deserializer: D) -> Result<Address, D::Error>
125    where
126        D: Deserializer<'de>,
127    {
128        deserializer.deserialize_str(AddressVisitor)
129    }
130}
131
132impl Serialize for Address {
133    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
134    where
135        S: Serializer,
136    {
137        serializer.serialize_str(&self.to_string())
138    }
139}
140
141impl Display for Address {
142    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
143        if let Some(port) = self.port {
144            write!(f, "{}:{}", self.host, port)
145        } else {
146            write!(f, "{}", self.host)
147        }
148    }
149}
150
/// Inputs for an SSH command, excluding the remote command itself.
///
/// The remote command is supplied separately as the argument of
/// `SshParams::command`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SshParams {
    /// Target address. If its port is set, the port is passed with the
    /// "-p" option; otherwise no port option is emitted.
    pub address: Address,

    /// Optional identity path ("-i" option).
    pub identity: Option<PathBuf>,

    /// Optional target user name. When set, the target is passed as
    /// "user@host"; otherwise only the host is used.
    pub user: Option<String>,

    /// If false, skip the known-host check and do not add the target
    /// to the known-hosts file. This is useful, for example, with
    /// ephemeral VMs.
    ///
    /// Setting this to false adds these flags:
    /// 1. -oStrictHostKeyChecking=no
    /// 2. -oUserKnownHostsFile=/dev/null
    ///
    /// The `Default` implementation sets this to true.
    pub strict_host_key_checking: bool,
}
172
173impl Default for SshParams {
174    fn default() -> SshParams {
175        SshParams {
176            address: Address::default(),
177            identity: None,
178            user: None,
179            strict_host_key_checking: true,
180        }
181    }
182}
183
184impl SshParams {
185    /// Create a full SSH command.
186    pub fn command<S: AsRef<OsStr>>(&self, args: &[S]) -> Vec<OsString> {
187        let mut output: Vec<OsString> = Vec::new();
188        output.push("ssh".into());
189
190        if !self.strict_host_key_checking {
191            output.extend_from_slice(&[
192                "-oStrictHostKeyChecking=no".into(),
193                "-oUserKnownHostsFile=/dev/null".into(),
194            ]);
195        }
196        output.push("-oBatchMode=yes".into());
197
198        if let Some(identity) = &self.identity {
199            output.extend_from_slice(&["-i".into(), identity.into()]);
200        }
201
202        if let Some(port) = self.address.port {
203            output.extend_from_slice(&["-p".into(), port.to_string().into()]);
204        }
205
206        let target = if let Some(user) = &self.user {
207            format!("{}@{}", user, self.address.host)
208        } else {
209            self.address.host.clone()
210        };
211
212        output.push(target.into());
213        output.extend(args.iter().map(|arg| arg.into()));
214
215        output
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use serde_test::{assert_de_tokens_error, assert_tokens, Token};
    use std::path::Path;

    /// Valid and invalid inputs for the "host[:port]" parser.
    #[test]
    fn test_address_parse() {
        assert_eq!("a".parse(), Ok(Address::from_host("a")));
        assert_eq!("a:1234".parse(), Ok(Address::new("a", 1234)));
        assert_eq!("a:65535".parse(), Ok(Address::new("a", 65535)));
        assert_eq!("".parse::<Address>(), Err(AddressError::InvalidFormat));
        assert_eq!("a:b".parse::<Address>(), Err(AddressError::InvalidPort));
        assert_eq!(
            "a:1234:5678".parse::<Address>(),
            Err(AddressError::InvalidFormat)
        );

        // An empty host is rejected even when a port is present.
        assert_eq!(
            ":22".parse::<Address>(),
            Err(AddressError::InvalidFormat)
        );

        // Empty and out-of-range ports are port errors.
        assert_eq!("a:".parse::<Address>(), Err(AddressError::InvalidPort));
        assert_eq!(
            "a:65536".parse::<Address>(),
            Err(AddressError::InvalidPort)
        );

        // The colon count is checked before the port text is parsed.
        assert_eq!(
            "a:b:c".parse::<Address>(),
            Err(AddressError::InvalidFormat)
        );
    }

    /// `Display` output with and without a port.
    #[test]
    fn test_address_display() {
        let addr = Address::from_host("abc");
        assert_eq!(format!("{}", addr), "abc");
        let addr = Address::new("abc", 123);
        assert_eq!(format!("{}", addr), "abc:123");
    }

    /// Serde round trip through a single string token.
    #[test]
    fn test_address_tokens() {
        assert_tokens(&Address::from_host("abc"), &[Token::Str("abc")]);
        assert_tokens(&Address::new("abc", 123), &[Token::Str("abc:123")]);
    }

    /// Deserialization failures report the visitor's fixed messages.
    #[test]
    fn test_address_deserialize_errors() {
        assert_de_tokens_error::<Address>(
            &[Token::Str("")],
            "invalid address format",
        );
        assert_de_tokens_error::<Address>(
            &[Token::Str("a:b")],
            "invalid port number",
        );
    }

    /// Command with every optional parameter set.
    #[test]
    fn test_command() {
        let target = SshParams {
            address: "localhost:9222".parse().unwrap(),
            identity: Some(Path::new("/myIdentity").to_path_buf()),
            user: Some("me".to_string()),
            strict_host_key_checking: false,
        };
        let cmd = target.command(&["arg1", "arg2"]);
        assert_eq!(
            cmd,
            vec![
                "ssh",
                "-oStrictHostKeyChecking=no",
                "-oUserKnownHostsFile=/dev/null",
                "-oBatchMode=yes",
                "-i",
                "/myIdentity",
                "-p",
                "9222",
                "me@localhost",
                "arg1",
                "arg2"
            ]
        );
    }

    /// Command built from the defaults: no identity, user, or port, and
    /// strict host key checking left enabled.
    #[test]
    fn test_command_defaults() {
        let target = SshParams {
            address: Address::from_host("myHost"),
            ..Default::default()
        };
        let cmd = target.command(&["echo", "hello"]);
        assert_eq!(
            cmd,
            vec!["ssh", "-oBatchMode=yes", "myHost", "echo", "hello"]
        );
    }
}