#![deny(missing_docs)]
use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
use std::ffi::{OsStr, OsString};
use std::fmt::{self, Display};
use std::path::PathBuf;
/// The standard TCP port on which SSH servers listen.
pub const DEFAULT_SSH_PORT: u16 = 22;
/// A remote host together with an optional port.
///
/// The textual form is `host` or `host:port`; it is what the `FromStr`,
/// `Display` and serde implementations of this type read and write.
#[derive(Clone, Debug, Default, Eq, Hash, PartialEq)]
pub struct Address {
    /// Host name or IP address of the remote machine.
    pub host: String,
    /// Port to connect to. `None` means no port is specified, so the client
    /// (e.g. `ssh`, which is then not given `-p`) uses its own default.
    pub port: Option<u16>,
}
/// Error returned when a string cannot be parsed into an [`Address`].
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
pub enum AddressError {
    /// The input is not of the form `host[:port]`, for example because the
    /// host part is empty or the input contains more than one `:`.
    #[error("invalid address format")]
    InvalidFormat,
    /// The text after the `:` is not a valid `u16` port number.
    #[error("invalid address port")]
    InvalidPort,
}
impl Address {
    /// Creates an address for `host` with an explicit `port`.
    ///
    /// No validation is performed on `host`; use `str::parse` to validate
    /// untrusted input.
    pub fn new(host: &str, port: u16) -> Address {
        Address {
            host: host.to_string(),
            port: Some(port),
        }
    }

    /// Creates an address for `host` without a port, leaving the port
    /// choice to the client.
    ///
    /// No validation is performed on `host`.
    pub fn from_host(host: &str) -> Address {
        Address {
            host: host.to_string(),
            port: None,
        }
    }
}
impl std::str::FromStr for Address {
    type Err = AddressError;

    /// Parses `host` or `host:port`.
    ///
    /// # Errors
    ///
    /// * [`AddressError::InvalidFormat`] if the host part is empty, starts
    ///   with `-`, or the input contains more than one `:`.
    /// * [`AddressError::InvalidPort`] if the text after `:` is not a valid
    ///   `u16`.
    fn from_str(address: &str) -> Result<Self, Self::Err> {
        let (host, port) = match address.split_once(':') {
            Some((host, port)) => (host, Some(port)),
            None => (address, None),
        };
        // A hostname can never begin with '-'. Rejecting it here also keeps
        // the host from being interpreted as an option when it is later
        // passed to `ssh` as a positional argument.
        if host.is_empty() || host.starts_with('-') {
            return Err(AddressError::InvalidFormat);
        }
        match port {
            None => Ok(Address::from_host(host)),
            // More than one ':' (e.g. a bare IPv6 literal) is not supported.
            Some(port) if port.contains(':') => Err(AddressError::InvalidFormat),
            Some(port) => port
                .parse::<u16>()
                .map(|port| Address::new(host, port))
                .map_err(|_| AddressError::InvalidPort),
        }
    }
}
/// Serde visitor that builds an [`Address`] from its `host[:port]` string.
struct AddressVisitor;

impl<'de> de::Visitor<'de> for AddressVisitor {
    type Value = Address;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("host[:port]")
    }

    /// Parses `value` with `Address`'s `FromStr` impl and converts each
    /// parse failure into a custom deserializer error message.
    fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        value.parse::<Address>().map_err(|failure| {
            let message = match failure {
                AddressError::InvalidFormat => "invalid address format",
                AddressError::InvalidPort => "invalid port number",
            };
            E::custom(message)
        })
    }
}
impl<'de> Deserialize<'de> for Address {
fn deserialize<D>(deserializer: D) -> Result<Address, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(AddressVisitor)
}
}
impl Serialize for Address {
    /// Serializes the address as its `Display` string (`host` or `host:port`).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // `collect_str` produces the same string as `serialize_str(&self.to_string())`
        // but lets serializers write the `Display` output without an
        // intermediate `String` allocation.
        serializer.collect_str(self)
    }
}
impl Display for Address {
    /// Writes `host` when no port is set, otherwise `host:port`, the same
    /// textual form accepted by the `FromStr` impl.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.port {
            Some(port) => write!(f, "{}:{}", self.host, port),
            None => f.write_str(&self.host),
        }
    }
}
/// Parameters describing how to reach a remote machine with `ssh`.
///
/// Use [`SshParams::command`] to turn them into an argument vector.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SshParams {
    /// Remote host and optional port (passed with `-p` when set).
    pub address: Address,
    /// Private key file passed with `-i`, if any.
    pub identity: Option<PathBuf>,
    /// Remote user name; when set the destination becomes `user@host`.
    pub user: Option<String>,
    /// Whether `ssh`'s host key checking stays enabled. When `false`, the
    /// generated command sets `StrictHostKeyChecking=no` and points
    /// `UserKnownHostsFile` at `/dev/null`. `Default` sets this to `true`.
    pub strict_host_key_checking: bool,
}
impl Default for SshParams {
    /// Returns parameters with an empty address, no identity file, no user
    /// and strict host key checking enabled.
    fn default() -> SshParams {
        let strict_host_key_checking = true;
        SshParams {
            strict_host_key_checking,
            address: Default::default(),
            identity: Default::default(),
            user: Default::default(),
        }
    }
}
impl SshParams {
    /// Builds the full argument vector for running `ssh` against this target.
    ///
    /// The result starts with the program name `"ssh"` and contains, in order:
    /// * the host-key options (only when `strict_host_key_checking` is `false`);
    /// * `-oBatchMode=yes`, which is always passed;
    /// * `-i <identity>` when an identity file is set;
    /// * `-p <port>` when the address has a port;
    /// * the destination, `user@host` or just `host`;
    /// * every element of `args`, verbatim.
    ///
    /// Note: the destination is a positional argument, so a host or user that
    /// begins with `-` could be interpreted by `ssh` as an option. Callers that
    /// construct [`Address`] or set `user` directly should reject such values.
    pub fn command<S: AsRef<OsStr>>(&self, args: &[S]) -> Vec<OsString> {
        // Maximum number of fixed arguments: program, two host-key options,
        // batch mode, `-i <path>`, `-p <port>` and the destination.
        let mut output: Vec<OsString> = Vec::with_capacity(9 + args.len());
        output.push("ssh".into());
        if !self.strict_host_key_checking {
            // Accept any host key and keep it out of the user's known_hosts.
            output.push("-oStrictHostKeyChecking=no".into());
            output.push("-oUserKnownHostsFile=/dev/null".into());
        }
        // Disable interactive prompts (see ssh_config(5), BatchMode).
        output.push("-oBatchMode=yes".into());
        if let Some(identity) = &self.identity {
            output.push("-i".into());
            output.push(identity.into());
        }
        if let Some(port) = self.address.port {
            output.push("-p".into());
            output.push(port.to_string().into());
        }
        let target = match &self.user {
            Some(user) => format!("{}@{}", user, self.address.host),
            None => self.address.host.clone(),
        };
        output.push(target.into());
        output.extend(args.iter().map(|arg| arg.into()));
        output
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_test::{assert_de_tokens_error, assert_tokens, Token};
    use std::path::Path;

    #[test]
    fn test_address_parse() {
        assert_eq!("a".parse(), Ok(Address::from_host("a")));
        assert_eq!("a:1234".parse(), Ok(Address::new("a", 1234)));
        assert_eq!("".parse::<Address>(), Err(AddressError::InvalidFormat));
        assert_eq!(":1234".parse::<Address>(), Err(AddressError::InvalidFormat));
        assert_eq!("a:b".parse::<Address>(), Err(AddressError::InvalidPort));
        assert_eq!("a:".parse::<Address>(), Err(AddressError::InvalidPort));
        assert_eq!("a:65536".parse::<Address>(), Err(AddressError::InvalidPort));
        assert_eq!(
            "a:1234:5678".parse::<Address>(),
            Err(AddressError::InvalidFormat)
        );
    }

    #[test]
    fn test_address_display() {
        let addr = Address::from_host("abc");
        assert_eq!(format!("{}", addr), "abc");
        let addr = Address::new("abc", 123);
        assert_eq!(format!("{}", addr), "abc:123");
    }

    /// `Display` output must parse back to the same address.
    #[test]
    fn test_address_display_roundtrip() {
        for text in &["abc", "abc:22", "10.0.0.1:65535"] {
            let addr: Address = text.parse().unwrap();
            assert_eq!(addr.to_string(), *text);
        }
    }

    #[test]
    fn test_address_tokens() {
        assert_tokens(&Address::from_host("abc"), &[Token::Str("abc")]);
        assert_tokens(&Address::new("abc", 123), &[Token::Str("abc:123")]);
    }

    /// Parse failures surface as the visitor's custom error messages.
    #[test]
    fn test_address_deserialize_errors() {
        assert_de_tokens_error::<Address>(&[Token::Str("a:b")], "invalid port number");
        assert_de_tokens_error::<Address>(&[Token::Str("a:1:2")], "invalid address format");
        assert_de_tokens_error::<Address>(&[Token::Str("")], "invalid address format");
    }

    #[test]
    fn test_command() {
        let target = SshParams {
            address: "localhost:9222".parse().unwrap(),
            identity: Some(Path::new("/myIdentity").to_path_buf()),
            user: Some("me".to_string()),
            strict_host_key_checking: false,
        };
        let cmd = target.command(&["arg1", "arg2"]);
        assert_eq!(
            cmd,
            vec![
                "ssh",
                "-oStrictHostKeyChecking=no",
                "-oUserKnownHostsFile=/dev/null",
                "-oBatchMode=yes",
                "-i",
                "/myIdentity",
                "-p",
                "9222",
                "me@localhost",
                "arg1",
                "arg2"
            ]
        );
    }

    /// With default parameters only batch mode and the bare host are emitted.
    #[test]
    fn test_command_defaults() {
        let target = SshParams {
            address: Address::from_host("host"),
            ..SshParams::default()
        };
        let cmd = target.command::<&str>(&[]);
        assert_eq!(cmd, vec!["ssh", "-oBatchMode=yes", "host"]);
    }
}