use anyhow::Result;
use std::path::Path;
use std::sync::Arc;
use crate::executor::{ExitCodeStrategy, OutputMode, ParallelExecutor, RankDetector};
use crate::forwarding::ForwardingType;
use crate::node::Node;
use crate::security::SudoPassword;
use crate::ssh::known_hosts::StrictHostKeyChecking;
use crate::ssh::tokio_client::SshConnectionConfig;
use crate::ssh::SshConfig;
use crate::ui::OutputFormatter;
use crate::utils::output::save_outputs_to_files;
/// All inputs needed to run one command across a set of nodes.
///
/// Consumed by value by `execute_command`. Borrowed fields (`command`,
/// `key_path`, `output_dir`, `jump_hosts`, `ssh_config`) must outlive `'a`.
pub struct ExecuteCommandParams<'a> {
    // --- What to run, and where ---
    /// Command line executed on every node.
    pub command: &'a str,
    /// Target hosts. The command runs on all of them; when port forwarding is
    /// requested, the forwards are attached to the first node only.
    pub nodes: Vec<Node>,
    /// Passed to `ParallelExecutor::new_with_all_options` as its parallelism
    /// setting (presumably the max number of nodes handled at once — confirm).
    pub max_parallel: usize,

    // --- Authentication and host verification ---
    /// Explicit private key file. For the port-forwarding connection, when this
    /// is `None`, `~/.ssh/id_ed25519` and then `~/.ssh/id_rsa` are tried.
    pub key_path: Option<&'a Path>,
    /// Authenticate through the SSH agent (rejected on Windows when forwarding).
    pub use_agent: bool,
    /// Authenticate with a password (not supported together with forwarding).
    pub use_password: bool,
    /// Forwarded to the executor's `with_keychain` (macOS builds only).
    #[cfg(target_os = "macos")]
    pub use_keychain: bool,
    /// Optional sudo password handed to the executor.
    pub sudo_password: Option<Arc<SudoPassword>>,
    /// Host key checking policy.
    pub strict_mode: StrictHostKeyChecking,

    // --- Connection ---
    /// Execution timeout passed to the executor (unit not visible here —
    /// presumably seconds; confirm).
    pub timeout: Option<u64>,
    /// Connection timeout passed to the executor (unit presumably seconds).
    pub connect_timeout: Option<u64>,
    /// Jump host specification, passed to the executor as an owned string.
    pub jump_hosts: Option<&'a str>,
    /// Port forwards to keep open while the command runs; `None` and an empty
    /// list both mean "no forwarding".
    pub port_forwards: Option<Vec<ForwardingType>>,
    /// Parsed SSH config, cloned into the executor when present.
    pub ssh_config: Option<&'a SshConfig>,
    /// Low-level connection settings handed to the executor.
    pub ssh_connection_config: SshConnectionConfig,

    // --- Output ---
    /// Passed to each result's `print_output`.
    pub verbose: bool,
    /// Directory for per-node output files; files are written after execution
    /// only when not streaming.
    pub output_dir: Option<&'a Path>,
    /// When set, collected results are neither saved nor printed afterwards;
    /// output handling is left to the streaming `OutputMode`.
    pub stream: bool,
    /// Passed to `OutputMode::from_args_with_no_prefix`.
    pub no_prefix: bool,

    // --- Execution and exit-code policy ---
    /// Forwarded to the executor's `with_batch_mode`.
    pub batch: bool,
    /// Forwarded to the executor's `with_fail_fast`.
    pub fail_fast: bool,
    /// Use `ExitCodeStrategy::RequireAllSuccess` (takes precedence over
    /// `check_all_nodes`).
    pub require_all_success: bool,
    /// Use `ExitCodeStrategy::MainRankWithFailureCheck` instead of `MainRank`.
    pub check_all_nodes: bool,
}
/// Runs `params.command` on every node in `params.nodes`.
///
/// Prints the command header first, then dispatches: if `params.port_forwards`
/// holds at least one forward, the forwarding path is used; otherwise (`None`
/// or an empty list) the command runs directly.
///
/// # Errors
///
/// Propagates any error from the selected execution path. Note that the
/// direct execution path terminates the process with `std::process::exit`
/// when the computed exit code is non-zero.
pub async fn execute_command(params: ExecuteCommandParams<'_>) -> Result<()> {
    let header = OutputFormatter::format_command_header(params.command, params.nodes.len());
    println!("{header}");

    let wants_forwarding =
        matches!(params.port_forwards.as_deref(), Some(list) if !list.is_empty());

    if wants_forwarding {
        execute_command_with_forwarding(params).await
    } else {
        execute_command_without_forwarding(params).await
    }
}
async fn execute_command_with_forwarding(params: ExecuteCommandParams<'_>) -> Result<()> {
use crate::forwarding::{ForwardingConfig, ForwardingManager};
use std::sync::Arc;
println!("Setting up port forwarding...");
let forwards = params.port_forwards.as_ref().unwrap();
let node = ¶ms.nodes[0];
for forward in forwards {
println!(" {forward}");
}
let forwarding_config = ForwardingConfig::default();
let mut manager = ForwardingManager::new(forwarding_config);
use crate::ssh::known_hosts::StrictHostKeyChecking;
use crate::ssh::tokio_client::{AuthMethod, Client, ServerCheckMethod};
let auth_method = if params.use_agent {
#[cfg(not(target_os = "windows"))]
{
AuthMethod::with_agent()
}
#[cfg(target_os = "windows")]
{
return Err(anyhow::anyhow!("SSH agent not supported on Windows"));
}
} else if params.use_password {
return Err(anyhow::anyhow!(
"Password authentication not yet supported with port forwarding"
));
} else {
let key_path = params
.key_path
.map(|p| p.to_path_buf())
.or_else(|| {
let home = std::env::var("HOME").ok()?;
let ed25519_path = std::path::PathBuf::from(&home)
.join(".ssh")
.join("id_ed25519");
let rsa_path = std::path::PathBuf::from(&home).join(".ssh").join("id_rsa");
if ed25519_path.exists() {
Some(ed25519_path)
} else if rsa_path.exists() {
Some(rsa_path)
} else {
None
}
})
.ok_or_else(|| anyhow::anyhow!("No SSH key found for port forwarding"))?;
AuthMethod::with_key_file(key_path, None)
};
let server_check = match params.strict_mode {
StrictHostKeyChecking::Yes => ServerCheckMethod::DefaultKnownHostsFile,
StrictHostKeyChecking::No => ServerCheckMethod::NoCheck,
StrictHostKeyChecking::AcceptNew => ServerCheckMethod::DefaultKnownHostsFile, };
let ssh_client = Arc::new(
Client::connect(
(node.host.as_str(), node.port),
&node.username,
auth_method,
server_check,
)
.await?,
);
println!(
"SSH connection established to {}@{}",
node.username, node.host
);
let mut forwarding_ids = Vec::new();
for forward in forwards {
let id = manager.add_forwarding(forward.clone()).await?;
forwarding_ids.push(id);
manager
.start_forwarding(id, Arc::clone(&ssh_client))
.await?;
}
println!("Port forwarding active. Executing command...");
let result = execute_command_without_forwarding(ExecuteCommandParams {
port_forwards: None, batch: params.batch,
..params
})
.await;
println!("Stopping port forwarding...");
for id in forwarding_ids {
if let Err(e) = manager.stop_forwarding(id).await {
eprintln!("Warning: Failed to stop forwarding {id}: {e}");
}
}
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
result
}
/// Runs `params.command` on every node through a `ParallelExecutor`, reports
/// the results, and turns them into a process exit code.
///
/// Steps:
/// 1. Build the executor from the connection, auth and policy fields of `params`.
/// 2. Execute. In the "normal" `OutputMode` results are collected; any other
///    mode uses `execute_with_streaming`.
/// 3. When not streaming, save outputs to `params.output_dir` (if set) and
///    print each node's output.
/// 4. Print a success/failure summary.
/// 5. Choose an `ExitCodeStrategy` and compute the exit code relative to the
///    detected main rank. `require_all_success` wins over `check_all_nodes`.
///
/// # Errors
///
/// Propagates errors from execution and from saving output files.
///
/// NOTE: a non-zero exit code terminates the process via `std::process::exit`,
/// so this function returns `Ok(())` only when the exit code is 0. Any cleanup
/// a caller runs after awaiting it is skipped in the failure case.
async fn execute_command_without_forwarding(params: ExecuteCommandParams<'_>) -> Result<()> {
    // The executor takes ownership of the node list, but rank detection runs
    // after execution, so keep an independent copy for it.
    let rank_nodes = params.nodes.clone();
    let command = params.command;
    let stream = params.stream;
    let verbose = params.verbose;
    let key_file = params
        .key_path
        .map(|path| path.to_string_lossy().into_owned());

    let executor = ParallelExecutor::new_with_all_options(
        params.nodes,
        params.max_parallel,
        key_file,
        params.strict_mode,
        params.use_agent,
        params.use_password,
    )
    .with_timeout(params.timeout)
    .with_connect_timeout(params.connect_timeout)
    .with_jump_hosts(params.jump_hosts.map(str::to_owned))
    .with_sudo_password(params.sudo_password)
    .with_batch_mode(params.batch)
    .with_fail_fast(params.fail_fast)
    .with_ssh_config(params.ssh_config.cloned())
    .with_ssh_connection_config(params.ssh_connection_config);
    #[cfg(target_os = "macos")]
    let executor = executor.with_keychain(params.use_keychain);

    let output_mode = OutputMode::from_args_with_no_prefix(
        stream,
        params.output_dir.map(|dir| dir.to_path_buf()),
        params.no_prefix,
    );
    let results = if !output_mode.is_normal() {
        executor.execute_with_streaming(command, output_mode).await?
    } else {
        executor.execute(command).await?
    };

    // Streaming mode has already handled output; otherwise persist, then print.
    if !stream {
        if let Some(dir) = params.output_dir {
            save_outputs_to_files(&results, dir, command).await?;
        }
        for node_result in &results {
            node_result.print_output(verbose);
        }
    }

    let total = results.len();
    let succeeded = results.iter().filter(|r| r.is_success()).count();
    let failed = total - succeeded;
    println!(
        "{}",
        OutputFormatter::format_summary(total, succeeded, failed)
    );

    let strategy = match (params.require_all_success, params.check_all_nodes) {
        (true, _) => ExitCodeStrategy::RequireAllSuccess,
        (false, true) => ExitCodeStrategy::MainRankWithFailureCheck,
        (false, false) => ExitCodeStrategy::MainRank,
    };
    let main_rank = RankDetector::identify_main_rank(&rank_nodes);

    match strategy.calculate(&results, main_rank) {
        0 => Ok(()),
        code => std::process::exit(code),
    }
}