use anyhow::Result;
use clap::Parser;
use std::path::{Path, PathBuf};
use std::time::Duration;
use bssh::{
cli::{Cli, Commands},
commands::{
download::download_file,
exec::{execute_command, ExecuteCommandParams},
interactive::InteractiveCommand,
list::list_clusters,
ping::ping_nodes,
upload::{upload_file, FileTransferParams},
},
config::{Config, InteractiveMode},
node::Node,
ssh::known_hosts::StrictHostKeyChecking,
utils::init_logging,
};
/// Formats a [`Duration`] for the human-readable session summary.
///
/// - under one second: milliseconds with one decimal, e.g. `"250.0 ms"`;
/// - under one minute: seconds with two decimals, e.g. `"12.34 s"`;
/// - otherwise: whole minutes plus the remainder, e.g. `"2m"`, `"1m 5s"`, or
///   `"1m 5.250s"` when there is a millisecond component.
fn format_duration(duration: Duration) -> String {
    let total_seconds = duration.as_secs_f64();
    if total_seconds < 1.0 {
        return format!("{:.1} ms", total_seconds * 1000.0);
    }
    if total_seconds < 60.0 {
        return format!("{total_seconds:.2} s");
    }

    let minutes = duration.as_secs() / 60;
    let seconds = duration.as_secs() % 60;
    let millis = duration.subsec_millis();
    match (seconds, millis) {
        // Omit the remainder only when there is none at all, so that e.g.
        // 60.5 s renders as "1m 0.500s" instead of losing its sub-second part.
        (0, 0) => format!("{minutes}m"),
        (_, 0) => format!("{minutes}m {seconds}s"),
        _ => format!("{minutes}m {seconds}.{millis:03}s"),
    }
}
/// Command-line entry point.
///
/// Parses arguments, loads the configuration, resolves the target nodes and
/// dispatches to the selected subcommand (the command-execution path when no
/// subcommand, or `exec`, is given).
///
/// Settings are merged with explicit command-line flags taking precedence over
/// values configured for the effective cluster, which in turn take precedence
/// over the CLI's own values.
///
/// # Errors
///
/// Returns an error when an explicitly requested config file does not exist,
/// when loading the configuration or resolving nodes fails, when no hosts are
/// available, when a command is required but missing, or when the dispatched
/// subcommand fails.
#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();
    init_logging(cli.verbose);

    // `cli.config` and `cli.parallel` are plain values, not `Option`s, so they
    // cannot tell whether the user set them; inspect the raw arguments instead.
    let args: Vec<String> = std::env::args().collect();
    let has_explicit_config = option_was_given(&args, "--config", None);
    let has_explicit_parallel = option_was_given(&args, "--parallel", Some("-p"));

    // Fail fast with a clear message when an explicitly requested config file
    // is missing. Tilde expansion reuses the helper used for SSH key paths.
    if has_explicit_config {
        let expanded_path = bssh::config::expand_tilde(&cli.config);
        if !expanded_path.exists() {
            anyhow::bail!("Config file not found: {:?}", expanded_path);
        }
    }

    let config = Config::load_with_priority(&cli.config).await?;

    // `list` only needs the configuration, not any nodes.
    if matches!(cli.command, Some(Commands::List)) {
        list_clusters(&config);
        return Ok(());
    }

    let (nodes, actual_cluster_name) = resolve_nodes(&cli, &config).await?;

    // Cluster used for every per-cluster configuration lookup below: the one
    // the nodes were resolved from (including the `bai_auto` fallback), or else
    // the one named with `-c` when hosts were given explicitly.
    let effective_cluster = actual_cluster_name.as_deref().or(cli.cluster.as_deref());

    let max_parallel = if has_explicit_parallel {
        cli.parallel
    } else {
        config
            .get_parallel(effective_cluster)
            .unwrap_or(cli.parallel)
    };

    if nodes.is_empty() {
        anyhow::bail!(
            "No hosts specified. Please use one of the following options:\n -H <hosts> Specify comma-separated hosts (e.g., -H user@host1,user@host2)\n -c <cluster> Use a cluster from your configuration file"
        );
    }

    // NOTE(review): unparseable values silently fall back to
    // `StrictHostKeyChecking::default()`.
    let strict_mode: StrictHostKeyChecking =
        cli.strict_host_key_checking.parse().unwrap_or_default();

    let command = cli.get_command();
    let needs_command = matches!(cli.command, None | Some(Commands::Exec { .. }));
    if command.is_empty() && needs_command {
        anyhow::bail!(
            "No command specified. Please provide a command to execute.\nExample: bssh -H host1,host2 'ls -la'"
        );
    }

    // Every remaining subcommand authenticates the same way: an explicit
    // identity file wins, otherwise the cluster's configured key (`~` expanded).
    let key_path = cli.identity.clone().or_else(|| {
        config
            .get_ssh_key(effective_cluster)
            .map(|ssh_key| bssh::config::expand_tilde(Path::new(&ssh_key)))
    });

    match cli.command {
        Some(Commands::Ping) => {
            ping_nodes(
                nodes,
                max_parallel,
                key_path.as_deref(),
                strict_mode,
                cli.use_agent,
                cli.password,
            )
            .await
        }
        Some(Commands::Upload {
            source,
            destination,
            recursive,
        }) => {
            let params = FileTransferParams {
                nodes,
                max_parallel,
                key_path: key_path.as_deref(),
                strict_mode,
                use_agent: cli.use_agent,
                use_password: cli.password,
                recursive,
            };
            upload_file(params, &source, &destination).await
        }
        Some(Commands::Download {
            source,
            destination,
            recursive,
        }) => {
            let params = FileTransferParams {
                nodes,
                max_parallel,
                key_path: key_path.as_deref(),
                strict_mode,
                use_agent: cli.use_agent,
                use_password: cli.password,
                recursive,
            };
            download_file(params, &source, &destination).await
        }
        Some(Commands::Interactive {
            single_node,
            multiplex,
            prompt_format,
            history_file,
            work_dir,
        }) => {
            // Same cluster as every other lookup, so the `bai_auto` fallback's
            // interactive settings are honoured too.
            let interactive_config = config.get_interactive_config(effective_cluster);

            // Explicit mode flags win; otherwise use the configured default mode.
            let merged_mode = if single_node {
                (true, false)
            } else if multiplex {
                (false, true)
            } else {
                match interactive_config.default_mode {
                    InteractiveMode::SingleNode => (true, false),
                    InteractiveMode::Multiplex => (false, true),
                }
            };

            // Values equal to these literals (presumably the CLI defaults — keep
            // in sync with `Cli`) are treated as "not given", letting the
            // configured value take over.
            let merged_prompt = if prompt_format != "[{node}:{user}@{host}:{pwd}]$ " {
                prompt_format
            } else {
                interactive_config.prompt_format.clone()
            };
            let merged_history = if history_file.to_string_lossy() != "~/.bssh_history" {
                history_file
            } else if let Some(config_history) = interactive_config.history_file.clone() {
                PathBuf::from(config_history)
            } else {
                history_file
            };
            let merged_work_dir = work_dir.or(interactive_config.work_dir.clone());

            let interactive_cmd = InteractiveCommand {
                single_node: merged_mode.0,
                multiplex: merged_mode.1,
                prompt_format: merged_prompt,
                history_file: merged_history,
                work_dir: merged_work_dir,
                nodes,
                config: config.clone(),
                interactive_config,
                cluster_name: effective_cluster.map(String::from),
                key_path,
                use_agent: cli.use_agent,
                use_password: cli.password,
                strict_mode,
            };
            let result = interactive_cmd.execute().await?;

            println!("\nInteractive session ended.");
            println!("Duration: {}", format_duration(result.duration));
            println!("Commands executed: {}", result.commands_executed);
            println!("Nodes connected: {}", result.nodes_connected);
            Ok(())
        }
        _ => {
            // A positive `cli.timeout` is an explicit override; otherwise use the
            // cluster's configured timeout, if any.
            let timeout = if cli.timeout > 0 {
                Some(cli.timeout)
            } else {
                config.get_timeout(effective_cluster)
            };
            let params = ExecuteCommandParams {
                nodes,
                command: &command,
                max_parallel,
                key_path: key_path.as_deref(),
                verbose: cli.verbose > 0,
                strict_mode,
                use_agent: cli.use_agent,
                use_password: cli.password,
                output_dir: cli.output_dir.as_deref(),
                timeout,
            };
            execute_command(params).await
        }
    }
}

/// Reports whether an option was supplied on the command line.
///
/// `args` is the raw process argument list (program name first). The option
/// counts as given when it appears as a standalone token (`--parallel 4`),
/// with an inline value (`--parallel=4`, `-p=4`) or, for the short form, with
/// an attached value (`-p4`). Scanning stops at a `--` terminator so options
/// meant for the remote command are not mistaken for ours.
///
/// NOTE(review): remote-command tokens given without `--` (e.g.
/// `bssh -c prod mkdir -p /tmp/x`) are still scanned and can yield a false
/// positive; clap's `ArgMatches::value_source` would give an exact answer.
fn option_was_given(args: &[String], long: &str, short: Option<&str>) -> bool {
    args.iter()
        .skip(1)
        .take_while(|arg| arg.as_str() != "--")
        .any(|arg| {
            let is_long = arg == long
                || matches!(arg.strip_prefix(long), Some(rest) if rest.starts_with('='));
            let is_short = matches!(
                short,
                Some(flag) if !arg.starts_with("--") && arg.starts_with(flag)
            );
            is_long || is_short
        })
}
/// Resolves the nodes this invocation should target.
///
/// Sources are consulted in order, first match wins:
/// 1. hosts given with `-H` (each value may itself be a comma-separated list);
/// 2. the cluster named with `-c`, looked up in `config`;
/// 3. the `bai_auto` cluster, if the configuration defines one.
///
/// Returns the nodes together with the name of the cluster they came from
/// (`None` for explicit hosts). When no source applies, an empty node list is
/// returned rather than an error; the caller decides how to report that.
///
/// # Errors
///
/// Propagates failures from `Node::parse` and `Config::resolve_nodes`.
async fn resolve_nodes(cli: &Cli, config: &Config) -> Result<(Vec<Node>, Option<String>)> {
    const AUTO_CLUSTER: &str = "bai_auto";

    if let Some(hosts) = &cli.hosts {
        // Blank entries (e.g. from a trailing comma in `-H a,b,`) are skipped
        // instead of being handed to the parser.
        let nodes = hosts
            .iter()
            .flat_map(|spec| spec.split(','))
            .map(str::trim)
            .filter(|host| !host.is_empty())
            .map(|host| Node::parse(host, None))
            .collect::<Result<Vec<_>, _>>()?;
        return Ok((nodes, None));
    }

    let cluster = match &cli.cluster {
        Some(name) => name.as_str(),
        None if config.clusters.contains_key(AUTO_CLUSTER) => AUTO_CLUSTER,
        None => return Ok((Vec::new(), None)),
    };
    let nodes = config.resolve_nodes(cluster)?;
    Ok((nodes, Some(cluster.to_string())))
}