#![doc(html_root_url = "https://docs.rs/spurs/0.9.1")]
use std::{
io::Read,
net::{SocketAddr, TcpStream, ToSocketAddrs},
path::{Path, PathBuf},
sync::{Arc, Mutex},
thread::JoinHandle,
time::Duration,
};
use log::{debug, info, trace};
use ssh2::Session;
/// Socket read/write timeout set on freshly opened TCP connections in
/// `SshShell::with_key` and `SshShell::from_existing`. `reconnect` uses half of
/// this value both as its connect timeout and as the delay between retries.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
/// A command to run on a remote host via [`Execute::run`], together with the
/// options controlling how it is run. Create one with [`SshCommand::new`] (or the
/// `cmd!` macro) and adjust it with the builder methods.
#[derive(Debug, PartialEq, Eq)]
pub struct SshCommand {
    /// The command text as given by the caller.
    cmd: String,
    /// If set, the remote side runs `cd <cwd> ;` before the command.
    cwd: Option<PathBuf>,
    /// Wrap the command as `bash -c '<escaped command>'`.
    use_bash: bool,
    /// Return the output instead of `SshError::NonZeroExit` on a non-zero exit status.
    allow_error: bool,
    /// Print the command but do not execute it; the returned output is empty.
    dry_run: bool,
    /// Do not request a pty before executing the command.
    no_pty: bool,
}
/// Output collected from a command run through [`Execute::run`]. Both fields are
/// empty for a dry run.
#[derive(Debug)]
pub struct SshOutput {
    /// Everything read from the channel's stdout stream.
    pub stdout: String,
    /// Everything read from the channel's stderr stream (read after stdout hits EOF).
    pub stderr: String,
}
/// Errors returned by this crate.
#[derive(Debug)]
pub enum SshError {
    /// No usable private key: the home directory could not be determined, or no
    /// key in `~/.ssh/` produced a working connection. `file` names the key path
    /// or directory that was looked for.
    KeyNotFound { file: String },
    /// The session was still unauthenticated after public-key auth with `key`.
    AuthFailed { key: std::path::PathBuf },
    /// The remote command exited with status `exit`. `cmd` is the command as it
    /// was sent to the remote, i.e. after any `bash -c` / `cd` wrapping.
    NonZeroExit { cmd: String, exit: i32 },
    /// An error reported by the `ssh2` crate.
    SshError { error: ssh2::Error },
    /// A local I/O error, e.g. while connecting or reading from a channel.
    IoError { error: std::io::Error },
}
/// An authenticated SSH connection to one remote host, used to run commands.
pub struct SshShell {
    /// TCP connection the SSH session runs over; replaced by `reconnect`.
    tcp: TcpStream,
    /// Remote user name.
    username: String,
    /// Path to the private key; reused when duplicating or reconnecting.
    key: PathBuf,
    /// `Debug` rendering of the address the caller originally passed in; shown in
    /// the printed banners.
    remote_name: String,
    /// Socket address of the remote; used when opening new connections.
    remote: SocketAddr,
    /// The SSH session; `run` holds this lock for the whole duration of a command.
    sess: Arc<Mutex<Session>>,
    /// When true, every command run through this shell is forced to be a dry run.
    dry_run_mode: bool,
}
/// Handle to a command running on a background thread, created by `SshShell::spawn`.
pub struct SshSpawnHandle {
    /// The thread yields the shell that ran the command and the command's result.
    thread_handle: JoinHandle<(SshShell, Result<SshOutput, SshError>)>,
}
/// Operations for running commands on a remote host.
pub trait Execute: Sized {
    /// Run `cmd` to completion and return its captured output.
    fn run(&self, cmd: SshCommand) -> Result<SshOutput, SshError>;
    /// Open a new connection to the same host (for `SshShell`: same user and key).
    fn duplicate(&self) -> Result<Self, SshError>;
    /// Re-establish this connection in place.
    fn reconnect(&mut self) -> Result<(), SshError>;
}
impl std::fmt::Display for SshError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> Result<(), std::fmt::Error> {
match self {
SshError::KeyNotFound { file } => write!(f, "no such key: {}", file),
SshError::AuthFailed { key } => {
write!(f, "authentication failed with private key: {:?}", key)
}
SshError::NonZeroExit { cmd, exit } => {
write!(f, "non-zero exit ({}) for command: {}", exit, cmd)
}
SshError::SshError { error } => write!(f, "{}", error),
SshError::IoError { error } => write!(f, "{}", error),
}
}
}
impl std::error::Error for SshError {}
impl std::convert::From<ssh2::Error> for SshError {
fn from(error: ssh2::Error) -> Self {
SshError::SshError { error }
}
}
impl std::convert::From<std::io::Error> for SshError {
fn from(error: std::io::Error) -> Self {
SshError::IoError { error }
}
}
impl SshCommand {
pub fn new(cmd: &str) -> Self {
SshCommand {
cmd: cmd.to_owned(),
cwd: None,
use_bash: false,
allow_error: false,
dry_run: false,
no_pty: false,
}
}
pub fn cwd<P: AsRef<Path>>(self, cwd: P) -> Self {
SshCommand {
cwd: Some(cwd.as_ref().to_owned()),
..self
}
}
pub fn use_bash(self) -> Self {
SshCommand {
use_bash: true,
..self
}
}
pub fn allow_error(self) -> Self {
SshCommand {
allow_error: true,
..self
}
}
pub fn dry_run(self, is_dry: bool) -> Self {
SshCommand {
dry_run: is_dry,
..self
}
}
pub fn no_pty(self) -> Self {
SshCommand {
no_pty: true,
..self
}
}
#[cfg(any(test, feature = "test"))]
pub fn make_cmd(
cmd: &str,
cwd: Option<PathBuf>,
use_bash: bool,
allow_error: bool,
dry_run: bool,
no_pty: bool,
) -> Self {
SshCommand {
cmd: cmd.into(),
cwd,
use_bash,
allow_error,
dry_run,
no_pty,
}
}
#[cfg(any(test, feature = "test"))]
pub fn cmd(&self) -> &str {
&self.cmd
}
}
impl SshShell {
pub fn with_default_key<A: ToSocketAddrs + std::fmt::Debug>(
username: &str,
remote: A,
) -> Result<Self, SshError> {
const DEFAULT_KEY_SUFFIX: &str = ".ssh/id_rsa";
let home = if let Some(home) = dirs::home_dir() {
home
} else {
return Err(SshError::KeyNotFound {
file: DEFAULT_KEY_SUFFIX.into(),
}
.into());
};
SshShell::with_key(username, remote, home.join(DEFAULT_KEY_SUFFIX))
}
pub fn with_any_key<A: Copy + ToSocketAddrs + std::fmt::Debug>(
username: &str,
remote: A,
) -> Result<Self, SshError> {
const DEFAULT_KEY_DIR: &str = ".ssh/";
let home = if let Some(home) = dirs::home_dir() {
home
} else {
return Err(SshError::KeyNotFound {
file: DEFAULT_KEY_DIR.into(),
});
};
let key_dir = home.join(DEFAULT_KEY_DIR);
for entry in std::fs::read_dir(&key_dir)? {
let entry = entry?;
let name = entry.file_name().into_string().unwrap();
if !name.ends_with(".pub") {
continue;
}
let (priv_key, _) = name.split_at(name.len() - 4);
let shell = SshShell::with_key(username, remote, key_dir.join(priv_key));
if shell.is_ok() {
return shell;
}
}
Err(SshError::KeyNotFound {
file: DEFAULT_KEY_DIR.into(),
})
}
pub fn with_key<A: ToSocketAddrs + std::fmt::Debug, P: AsRef<Path>>(
username: &str,
remote: A,
key: P,
) -> Result<Self, SshError> {
info!("New SSH shell: {}@{:?}", username, remote);
debug!("Using key: {:?}", key.as_ref());
debug!("Create new TCP stream...");
let tcp = TcpStream::connect(&remote)?;
tcp.set_read_timeout(Some(DEFAULT_TIMEOUT))?;
tcp.set_write_timeout(Some(DEFAULT_TIMEOUT))?;
let remote_name = format!("{:?}", remote);
let remote = remote.to_socket_addrs().unwrap().next().unwrap();
debug!("Create new SSH session...");
let mut sess = Session::new().unwrap();
sess.handshake(&tcp)?;
trace!("SSH session handshook.");
sess.userauth_pubkey_file(username, None, key.as_ref(), None)?;
if !sess.authenticated() {
return Err(SshError::AuthFailed {
key: key.as_ref().to_path_buf(),
}
.into());
}
trace!("SSH session authenticated.");
println!(
"{}",
console::style(format!("{}@{} ({})", username, remote_name, remote))
.green()
.bold()
);
Ok(SshShell {
tcp,
username: username.to_owned(),
key: key.as_ref().to_owned(),
remote_name,
remote,
sess: Arc::new(Mutex::new(sess)),
dry_run_mode: false,
})
}
pub fn from_existing(shell: &SshShell) -> Result<Self, SshError> {
info!("New SSH shell: {}@{:?}", shell.username, shell.remote);
debug!("Using key: {:?}", shell.key);
debug!("Create new TCP stream...");
let tcp = TcpStream::connect(&shell.remote)?;
tcp.set_read_timeout(Some(DEFAULT_TIMEOUT))?;
tcp.set_write_timeout(Some(DEFAULT_TIMEOUT))?;
let remote = shell.remote.clone();
debug!("Create new SSH session...");
let mut sess = Session::new().unwrap();
sess.handshake(&tcp)?;
trace!("SSH session handshook.");
sess.userauth_pubkey_file(&shell.username, None, shell.key.as_ref(), None)?;
if !sess.authenticated() {
return Err(SshError::AuthFailed {
key: shell.key.clone(),
}
.into());
}
trace!("SSH session authenticated.");
println!(
"{}",
console::style(format!(
"{}@{} ({})",
shell.username, shell.remote_name, remote
))
.green()
.bold()
);
Ok(SshShell {
tcp,
username: shell.username.clone(),
key: shell.key.clone(),
remote_name: shell.remote_name.clone(),
remote,
sess: Arc::new(Mutex::new(sess)),
dry_run_mode: false,
})
}
pub fn set_dry_run(&mut self, on: bool) {
self.dry_run_mode = on;
info!(
"Toggled dry run mode: {}",
if self.dry_run_mode { "on" } else { "off" }
);
}
pub fn spawn(&self, cmd: SshCommand) -> Result<SshSpawnHandle, SshError> {
debug!("spawn({:?})", cmd);
let shell = Self::from_existing(self)?;
let cmd = if self.dry_run_mode {
cmd.dry_run(true)
} else {
cmd
};
let thread_handle = std::thread::spawn(move || {
let result = shell.run(cmd);
(shell, result)
});
debug!("spawned thread for command.");
Ok(SshSpawnHandle { thread_handle })
}
fn run_with_chan_and_opts(
host_and_username: String,
mut chan: ssh2::Channel,
cmd_opts: SshCommand,
) -> Result<SshOutput, SshError> {
debug!("run_with_chan_and_opts({:?})", cmd_opts);
let SshCommand {
cwd,
cmd,
use_bash,
allow_error,
dry_run,
no_pty,
} = cmd_opts;
let msg = cmd.clone();
let cmd = if use_bash {
format!("bash -c {}", escape_for_bash(&cmd))
} else {
cmd
};
debug!("After shell escaping: {:?}", cmd);
let cmd = if let Some(cwd) = &cwd {
format!("cd {} ; {}", cwd.display(), cmd)
} else {
cmd
};
debug!("After cwd: {:?}", cmd);
if let Some(cwd) = cwd {
println!(
"{:-<80}\n{}\n{}\n{}",
"",
console::style(host_and_username).blue(),
console::style(cwd.display()).blue(),
console::style(msg).yellow().bold()
);
} else {
println!(
"{:-<80}\n{}\n{}",
"",
console::style(host_and_username).blue(),
console::style(msg).yellow().bold()
);
}
let mut stdout = String::new();
let mut stderr = String::new();
if dry_run {
chan.close()?;
chan.wait_close()?;
debug!("Closed channel after dry run.");
return Ok(SshOutput { stdout, stderr });
}
if !no_pty {
chan.request_pty("vt100", None, None)?;
debug!("Requested pty.");
}
debug!("Execute command remotely (asynchronous)...");
chan.exec(&cmd)?;
trace!("Read stdout...");
let mut buf = [0; 256];
while chan.read(&mut buf)? > 0 {
let out = String::from_utf8_lossy(&buf);
let out = out.trim_end_matches('\u{0}');
print!("{}", out);
stdout.push_str(out);
buf.iter_mut().for_each(|x| *x = 0);
}
trace!("No more stdout.");
chan.close()?;
chan.wait_close()?;
debug!("Command completed remotely.");
buf.iter_mut().for_each(|x| *x = 0);
trace!("Read stderr...");
while chan.stderr().read(&mut buf)? > 0 {
let err = String::from_utf8_lossy(&buf);
let err = err.trim_end_matches('\u{0}');
print!("{}", err);
stderr.push_str(err);
buf.iter_mut().for_each(|x| *x = 0);
}
trace!("No more stderr.");
debug!("Checking exit status.");
let exit = chan.exit_status()?;
debug!("Exit status: {}", exit);
if exit != 0 && !allow_error {
return Err(SshError::NonZeroExit { cmd, exit }.into());
}
trace!("Done with command.");
Ok(SshOutput { stdout, stderr })
}
}
impl Execute for SshShell {
fn run(&self, cmd: SshCommand) -> Result<SshOutput, SshError> {
debug!("run(cmd)");
let sess = self.sess.lock().unwrap();
debug!("Attempt to crate channel...");
let chan = sess.channel_session()?;
debug!("Channel created.");
let host_and_username = format!("{}@{}", self.username, self.remote_name);
let cmd = if self.dry_run_mode {
cmd.dry_run(true)
} else {
cmd
};
Self::run_with_chan_and_opts(host_and_username, chan, cmd)
}
fn duplicate(&self) -> Result<Self, SshError> {
Self::from_existing(self)
}
fn reconnect(&mut self) -> Result<(), SshError> {
info!("Reconnect attempt.");
trace!("Attempt to create new TCP stream...");
loop {
print!("{}", console::style("Attempt Reconnect ... ").red());
match TcpStream::connect_timeout(&self.remote, DEFAULT_TIMEOUT / 2) {
Ok(tcp) => {
self.tcp = tcp;
break;
}
Err(e) => {
trace!("{:?}", e);
println!("{}", console::style("failed, retrying").red());
std::thread::sleep(DEFAULT_TIMEOUT / 2);
}
}
}
println!(
"{}",
console::style("TCP connected, doing SSH handshake").red()
);
debug!("Attempt to create new SSH session...");
let mut sess = Session::new().unwrap();
sess.handshake(&self.tcp)?;
trace!("Handshook!");
sess.userauth_pubkey_file(&self.username, None, self.key.as_ref(), None)?;
if !sess.authenticated() {
return Err(SshError::AuthFailed {
key: self.key.clone(),
}
.into());
}
trace!("authenticated!");
let self_sess = Arc::get_mut(&mut self.sess).unwrap().get_mut().unwrap();
let _old_sess = std::mem::replace(self_sess, sess);
println!(
"{}",
console::style(format!("{}@{}", self.username, self.remote))
.green()
.bold()
);
Ok(())
}
}
impl std::fmt::Debug for SshShell {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
f,
"SshShell {{ {}@{:?} dry_run={} key={:?} }}",
self.username, self.remote, self.dry_run_mode, self.key
)
}
}
impl SshSpawnHandle {
    /// Block until the spawned command finishes, returning the shell that ran it
    /// together with the command's result.
    ///
    /// # Panics
    ///
    /// If the worker thread panicked, that panic is re-raised on the calling
    /// thread with its original payload. The original message is kept, rather
    /// than being replaced by an opaque `unwrap` on the join error.
    pub fn join(self) -> (SshShell, Result<SshOutput, SshError>) {
        debug!("Blocking on spawned commmand.");
        let ret = self
            .thread_handle
            .join()
            .unwrap_or_else(|payload| std::panic::resume_unwind(payload));
        debug!("Spawned commmand complete.");
        ret
    }
}
impl std::fmt::Debug for SshSpawnHandle {
    /// The worker thread's state is not inspected; a handle always renders as
    /// `SshSpawnHandle { running }`.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("SshSpawnHandle { running }")
    }
}
/// Build an `SshCommand` from `format!`-style arguments:
/// `cmd!("ls {}", dir)` expands to `SshCommand::new(&format!("ls {}", dir))`.
#[macro_export]
macro_rules! cmd {
    // Format string only.
    ($fmt:expr) => {
        $crate::SshCommand::new(&format!($fmt))
    };
    // Format string plus arguments, forwarded verbatim to `format!`.
    ($fmt:expr, $($arg:tt)*) => {
        $crate::SshCommand::new(&format!($fmt, $($arg)*))
    };
}
/// Quote `s` as a single bash word.
///
/// The whole string is wrapped in single quotes. Every embedded `'` becomes
/// `'"'"'`: close the single-quoted run, emit a double-quoted apostrophe, then
/// reopen the single-quoted run.
fn escape_for_bash(s: &str) -> String {
    const QUOTED_APOSTROPHE: &str = r#"'"'"'"#;
    format!("'{}'", s.replace('\'', QUOTED_APOSTROPHE))
}
#[cfg(test)]
mod test {
    use crate::{cmd, SshCommand};

    /// `cmd!` formats its arguments exactly like `format!`.
    #[test]
    fn test_cmd_macro() {
        let from_macro = cmd!("{} {}", "ls", 3);
        let expected = SshCommand::new("ls 3");
        assert_eq!(from_macro, expected);
    }

    mod test_escape_for_bash {
        use crate::escape_for_bash;

        /// A string without quotes is simply wrapped in single quotes.
        #[test]
        fn simple() {
            let input = "ls";
            let escaped = escape_for_bash(input);
            assert_eq!(escaped, "'ls'");
        }

        /// Embedded single quotes are spliced out as `'"'"'`; double quotes pass through.
        #[test]
        fn more_complex() {
            let input = r#"echo '$HELLOWORLD="hello world"' | grep "hello""#;
            let expected = r#"'echo '"'"'$HELLOWORLD="hello world"'"'"' | grep "hello"'"#;
            assert_eq!(escape_for_bash(input), expected);
        }
    }
}