use anyhow::{Context, Result};
use bssh::{
cli::Cli,
config::Config,
jump::parse_jump_hosts,
node::Node,
ssh::{known_hosts::StrictHostKeyChecking, SshConfig},
utils::init_logging,
};
use std::path::PathBuf;
/// Everything `initialize_app` resolves at startup, bundled for the rest of the program.
pub struct AppContext {
    /// bssh configuration loaded from `cli.config` via `Config::load_with_priority`.
    pub config: Config,
    /// OpenSSH client configuration: the `cli.ssh_config` file when given, otherwise the
    /// default one, or an empty config if the default could not be loaded.
    pub ssh_config: SshConfig,
    /// Target hosts. `initialize_app` bails before building the context if this is empty.
    pub nodes: Vec<Node>,
    /// Cluster name reported by node resolution, if any.
    pub cluster_name: Option<String>,
    /// Host-key checking policy chosen from the CLI and, when applicable, the SSH config.
    pub strict_mode: StrictHostKeyChecking,
    /// Parsed jump-host chain from `cli.jump_hosts`, if one was given.
    // NOTE(review): `dead_code` is allowed because this field appears not to be read
    // yet; drop the allow once a consumer uses it.
    #[allow(dead_code)]
    pub jump_hosts: Option<Vec<bssh::jump::JumpHost>>,
    /// Parallelism limit: 1 in SSH mode, otherwise `--parallel` or the cluster setting.
    pub max_parallel: usize,
}
pub async fn initialize_app(cli: &Cli, args: &[String]) -> Result<AppContext> {
init_logging(cli.verbose);
let has_explicit_config = args.iter().any(|arg| arg == "--config");
let has_explicit_parallel = args
.iter()
.any(|arg| arg == "--parallel" || arg.starts_with("--parallel="));
if has_explicit_config {
let expanded_path = if cli.config.starts_with("~") {
let path_str = cli.config.to_string_lossy();
if let Ok(home) = std::env::var("HOME") {
PathBuf::from(path_str.replacen("~", &home, 1))
} else {
cli.config.clone()
}
} else {
cli.config.clone()
};
if !expanded_path.exists() {
anyhow::bail!("Config file not found: {expanded_path:?}");
}
}
let config = Config::load_with_priority(&cli.config).await?;
let ssh_config = if let Some(ref ssh_config_path) = cli.ssh_config {
SshConfig::load_from_file_cached(ssh_config_path)
.await
.with_context(|| format!("Failed to load SSH config from {ssh_config_path:?}"))?
} else {
SshConfig::load_default_cached().await.unwrap_or_else(|_| {
tracing::debug!("No SSH config found or failed to load, using empty config");
SshConfig::new()
})
};
let (nodes, actual_cluster_name) =
super::nodes::resolve_nodes(cli, &config, &ssh_config).await?;
if nodes.is_empty() {
anyhow::bail!(
"No hosts specified. Please use one of the following options:\n \
-H <hosts> Specify comma-separated hosts (e.g., -H user@host1,user@host2)\n \
-c <cluster> Use a cluster from your configuration file"
);
}
let jump_hosts = if let Some(ref jump_spec) = cli.jump_hosts {
Some(
parse_jump_hosts(jump_spec)
.with_context(|| format!("Invalid jump host specification: '{jump_spec}'"))?,
)
} else {
None
};
if let Some(ref jumps) = jump_hosts {
if jumps.len() == 1 {
tracing::info!("Using jump host: {}", jumps[0]);
} else {
tracing::info!(
"Using jump host chain: {}",
jumps
.iter()
.map(|j| j.to_string())
.collect::<Vec<_>>()
.join(" -> ")
);
}
}
let hostname = if cli.is_ssh_mode() {
cli.parse_destination().map(|(_, host, _)| host)
} else {
None
};
let strict_mode = determine_strict_host_key_checking(cli, &ssh_config, hostname.as_deref());
let max_parallel = if cli.is_ssh_mode() {
1
} else if has_explicit_parallel {
cli.parallel
} else {
config
.get_parallel(actual_cluster_name.as_deref().or(cli.cluster.as_deref()))
.unwrap_or(cli.parallel) };
Ok(AppContext {
config,
ssh_config,
nodes,
cluster_name: actual_cluster_name,
strict_mode,
jump_hosts,
max_parallel,
})
}
/// Decide the host-key checking policy for a connection.
///
/// Any CLI value other than `"accept-new"` is used directly. When the CLI value is
/// `"accept-new"` and a `hostname` is known, a `StrictHostKeyChecking` entry for that
/// host in the SSH config takes over. Its value is matched case-insensitively:
/// `yes` → `Yes`, `no` → `No`, and anything else (including `ask`) → `AcceptNew`.
/// In every remaining case the CLI value is parsed, falling back to the default on
/// a parse failure.
pub fn determine_strict_host_key_checking(
    cli: &Cli,
    ssh_config: &SshConfig,
    hostname: Option<&str>,
) -> StrictHostKeyChecking {
    let cli_is_accept_new = cli.strict_host_key_checking == "accept-new";

    // The SSH config is only consulted when the CLI is at "accept-new".
    let ssh_config_value = if cli_is_accept_new {
        hostname.and_then(|host| ssh_config.get_strict_host_key_checking(host))
    } else {
        None
    };

    match ssh_config_value {
        Some(value) => match value.to_lowercase().as_str() {
            "yes" => StrictHostKeyChecking::Yes,
            "no" => StrictHostKeyChecking::No,
            // "ask", "accept-new" and unrecognised values all degrade to AcceptNew.
            _ => StrictHostKeyChecking::AcceptNew,
        },
        None => cli.strict_host_key_checking.parse().unwrap_or_default(),
    }
}
/// Pick the private key file to authenticate with.
///
/// Sources are tried in this order, and the first one that yields a path wins:
/// 1. `cli.identity`, an explicit command-line key;
/// 2. the first identity file the SSH config lists for `hostname`;
/// 3. the SSH key configured for `cluster_name` in the bssh config, with `~` expanded.
///
/// Returns `None` when no source provides a key.
pub fn determine_ssh_key_path(
    cli: &Cli,
    config: &Config,
    ssh_config: &SshConfig,
    hostname: Option<&str>,
    cluster_name: Option<&str>,
) -> Option<PathBuf> {
    let from_ssh_config =
        || hostname.and_then(|host| ssh_config.get_identity_files(host).first().cloned());

    let from_bssh_config = || {
        config
            .get_ssh_key(cluster_name)
            .map(|key| bssh::config::expand_tilde(std::path::Path::new(&key)))
    };

    // `or_else` keeps the lookups lazy: later sources are only queried when needed.
    cli.identity
        .clone()
        .or_else(from_ssh_config)
        .or_else(from_bssh_config)
}
/// macOS: whether to use the keychain for `hostname`, per the SSH config's `UseKeychain`.
///
/// Returns `false` when there is no hostname or the matching host config does not set
/// `use_keychain`.
#[cfg(target_os = "macos")]
pub fn determine_use_keychain(ssh_config: &SshConfig, hostname: Option<&str>) -> bool {
    hostname
        .and_then(|host| ssh_config.find_host_config(host).use_keychain)
        .unwrap_or(false)
}
/// Non-macOS stub: no keychain is available, so this always returns `false`.
///
/// It exists so callers can use the same signature on every platform.
#[cfg(not(target_os = "macos"))]
// Allowed because callers may only invoke this on macOS code paths, leaving it unused here.
#[allow(dead_code)]
pub fn determine_use_keychain(_ssh_config: &SshConfig, _hostname: Option<&str>) -> bool {
    false
}