use crate::{MasshConfig, SshAuth, SshClient, SshOutput};
use anyhow::Result;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::mpsc::Receiver;
use std::sync::Arc;
use threadpool::ThreadPool;
/// Key identifying one host inside a [`MasshClient`].
pub type MasshHost = String;

/// Receiving end of a multi-host operation: every message pairs the host key
/// with the outcome of the operation on that host.
pub type MasshReceiver<T> = Receiver<(MasshHost, Result<T>)>;

/// A single host's client, shared between the owning [`MasshClient`] and the
/// tasks it spawns; the mutex lets only one task at a time use it.
type SharedSshClient = Arc<Mutex<SshClient>>;

/// Runs the same SSH operation against a set of hosts concurrently.
pub struct MasshClient {
    /// Per-host clients, keyed by host.
    clients: HashMap<MasshHost, SharedSshClient>,
    /// Worker pool for per-host tasks; when `None`, each task is run on its
    /// own newly spawned thread instead.
    pool: Option<ThreadPool>,
}
impl MasshClient {
pub fn from(config: &MasshConfig) -> Self {
let mut clients = HashMap::new();
config.hosts.iter().for_each(|host| {
let addr = host.addr;
let auth = match &host.auth {
Some(auth) => auth,
None => &config.default_auth,
};
let port = match host.port {
Some(port) => port,
None => config.default_port,
};
let user = match &host.user {
Some(user) => user,
None => &config.default_user,
};
let mut ssh = SshClient::from(user, (addr, port));
match auth {
SshAuth::Agent => ssh.set_auth_agent(),
SshAuth::Password(password) => ssh.set_auth_password(password),
SshAuth::Pubkey(path) => ssh.set_auth_pubkey(path),
};
ssh.set_timeout(config.timeout);
let host = format!("{}@{}", ssh.get_user(), ssh.get_addr());
clients.insert(host, Arc::new(Mutex::new(ssh)));
});
let pool = if config.threads == 0 {
None
} else {
Some(ThreadPool::new(config.threads as usize))
};
MasshClient { clients, pool }
}
pub fn execute(&self, command: impl Into<String>) -> MasshReceiver<SshOutput> {
let command = command.into();
let (tx, rx) = std::sync::mpsc::channel();
self.clients.iter().for_each(|(host, client)| {
let (client, host, tx) = (client.clone(), host.clone(), tx.clone());
let command = command.clone();
let task_closure = move || {
let mut client = client.lock();
let result = client.execute(&command);
let _ = tx.send((host, result));
};
if let Some(pool) = &self.pool {
pool.execute(task_closure)
} else {
std::thread::spawn(task_closure);
}
});
rx
}
pub fn scp_download<P>(&self, remote_path: P, local_path: P) -> MasshReceiver<()>
where
P: Into<PathBuf>,
{
let (remote_path, local_path) = (remote_path.into(), local_path.into());
let (tx, rx) = std::sync::mpsc::channel();
self.clients.iter().for_each(|(host, client)| {
let (client, host, tx) = (client.clone(), host.clone(), tx.clone());
let (remote_path, mut local_path) = (remote_path.clone(), local_path.clone());
local_path.push(&host);
let task_closure = move || {
let mut client = client.lock();
let result = client.scp_download(remote_path, local_path);
let _ = tx.send((host, result));
};
if let Some(pool) = &self.pool {
pool.execute(task_closure)
} else {
std::thread::spawn(task_closure);
}
});
rx
}
pub fn scp_upload<P>(&self, local_path: P, remote_path: P) -> MasshReceiver<()>
where
P: Into<PathBuf>,
{
let (local_path, remote_path) = (local_path.into(), remote_path.into());
let (tx, rx) = std::sync::mpsc::channel();
self.clients.iter().for_each(|(host, client)| {
let (client, host, tx) = (client.clone(), host.clone(), tx.clone());
let (local_path, remote_path) = (local_path.clone(), remote_path.clone());
let task_closure = move || {
let mut client = client.lock();
let result = client.scp_upload(local_path, remote_path);
let _ = tx.send((host, result));
};
if let Some(pool) = &self.pool {
pool.execute(task_closure)
} else {
std::thread::spawn(task_closure);
}
});
rx
}
}