use anyhow::Result;
use futures::future::join_all;
use indicatif::MultiProgress;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::node::Node;
use crate::ssh::known_hosts::StrictHostKeyChecking;
use super::connection_manager::{download_from_node, ExecutionConfig};
use super::execution_strategy::{
create_progress_style, download_file_task, execute_command_task, setup_download_progress_bar,
setup_progress_bar, upload_file_task,
};
use super::result_types::{DownloadResult, ExecutionResult, UploadResult};
/// Runs a command, or transfers files, against a set of nodes concurrently.
///
/// Build one with [`ParallelExecutor::new`] or a `new_with_*` constructor, then adjust it
/// with the `with_*` builder methods.
pub struct ParallelExecutor {
    /// Target nodes; every operation fans out to each of them.
    pub(crate) nodes: Vec<Node>,
    /// Number of semaphore permits shared by the per-node tasks of one operation.
    pub(crate) max_parallel: usize,

    // --- connection routing / limits ---
    /// Timeout forwarded to command execution; `None` leaves it unset
    /// (units defined by the connection layer — presumably seconds, TODO confirm).
    pub(crate) timeout: Option<u64>,
    /// Jump-host specification, forwarded as a string slice to the connection helpers.
    pub(crate) jump_hosts: Option<String>,

    // --- authentication ---
    /// Optional key path forwarded to the connection helpers.
    pub(crate) key_path: Option<String>,
    /// Host-key checking policy.
    pub(crate) strict_mode: StrictHostKeyChecking,
    /// Flag forwarded as `use_agent` (SSH-agent authentication).
    pub(crate) use_agent: bool,
    /// Flag forwarded as `use_password` (password authentication).
    pub(crate) use_password: bool,
    /// Flag forwarded as `use_keychain`; only present on macOS builds.
    #[cfg(target_os = "macos")]
    pub(crate) use_keychain: bool,
}
impl ParallelExecutor {
pub fn new(nodes: Vec<Node>, max_parallel: usize, key_path: Option<String>) -> Self {
Self::new_with_strict_mode(
nodes,
max_parallel,
key_path,
StrictHostKeyChecking::AcceptNew,
)
}
pub fn new_with_strict_mode(
nodes: Vec<Node>,
max_parallel: usize,
key_path: Option<String>,
strict_mode: StrictHostKeyChecking,
) -> Self {
Self {
nodes,
max_parallel,
key_path,
strict_mode,
use_agent: false,
use_password: false,
#[cfg(target_os = "macos")]
use_keychain: false,
timeout: None,
jump_hosts: None,
}
}
pub fn new_with_strict_mode_and_agent(
nodes: Vec<Node>,
max_parallel: usize,
key_path: Option<String>,
strict_mode: StrictHostKeyChecking,
use_agent: bool,
) -> Self {
Self {
nodes,
max_parallel,
key_path,
strict_mode,
use_agent,
use_password: false,
#[cfg(target_os = "macos")]
use_keychain: false,
timeout: None,
jump_hosts: None,
}
}
pub fn new_with_all_options(
nodes: Vec<Node>,
max_parallel: usize,
key_path: Option<String>,
strict_mode: StrictHostKeyChecking,
use_agent: bool,
use_password: bool,
) -> Self {
Self {
nodes,
max_parallel,
key_path,
strict_mode,
use_agent,
use_password,
#[cfg(target_os = "macos")]
use_keychain: false,
timeout: None,
jump_hosts: None,
}
}
pub fn with_timeout(mut self, timeout: Option<u64>) -> Self {
self.timeout = timeout;
self
}
pub fn with_jump_hosts(mut self, jump_hosts: Option<String>) -> Self {
self.jump_hosts = jump_hosts;
self
}
#[cfg(target_os = "macos")]
pub fn with_keychain(mut self, use_keychain: bool) -> Self {
self.use_keychain = use_keychain;
self
}
pub async fn execute(&self, command: &str) -> Result<Vec<ExecutionResult>> {
let semaphore = Arc::new(Semaphore::new(self.max_parallel));
let multi_progress = MultiProgress::new();
let style = create_progress_style()?;
let tasks: Vec<_> = self
.nodes
.iter()
.map(|node| {
let node = node.clone();
let command = command.to_string();
let key_path = self.key_path.clone();
let strict_mode = self.strict_mode;
let use_agent = self.use_agent;
let use_password = self.use_password;
#[cfg(target_os = "macos")]
let use_keychain = self.use_keychain;
let timeout = self.timeout;
let jump_hosts = self.jump_hosts.clone();
let semaphore = Arc::clone(&semaphore);
let pb = setup_progress_bar(&multi_progress, &node, style.clone(), "Connecting...");
tokio::spawn(async move {
let config = ExecutionConfig {
key_path: key_path.as_deref(),
strict_mode,
use_agent,
use_password,
#[cfg(target_os = "macos")]
use_keychain,
timeout,
jump_hosts: jump_hosts.as_deref(),
};
execute_command_task(node, command, config, semaphore, pb).await
})
})
.collect();
let results = join_all(tasks).await;
self.collect_results(results)
}
pub async fn upload_file(
&self,
local_path: &Path,
remote_path: &str,
) -> Result<Vec<UploadResult>> {
let semaphore = Arc::new(Semaphore::new(self.max_parallel));
let multi_progress = MultiProgress::new();
let style = create_progress_style()?;
let tasks: Vec<_> = self
.nodes
.iter()
.map(|node| {
let node = node.clone();
let local_path = local_path.to_path_buf();
let remote_path = remote_path.to_string();
let key_path = self.key_path.clone();
let strict_mode = self.strict_mode;
let use_agent = self.use_agent;
let use_password = self.use_password;
let jump_hosts = self.jump_hosts.clone();
let semaphore = Arc::clone(&semaphore);
let pb = setup_progress_bar(&multi_progress, &node, style.clone(), "Connecting...");
tokio::spawn(upload_file_task(
node,
local_path,
remote_path,
key_path,
strict_mode,
use_agent,
use_password,
jump_hosts,
semaphore,
pb,
))
})
.collect();
let results = join_all(tasks).await;
self.collect_upload_results(results)
}
pub async fn download_file(
&self,
remote_path: &str,
local_dir: &Path,
) -> Result<Vec<DownloadResult>> {
let semaphore = Arc::new(Semaphore::new(self.max_parallel));
let multi_progress = MultiProgress::new();
let style = create_progress_style()?;
let tasks: Vec<_> = self
.nodes
.iter()
.map(|node| {
let node = node.clone();
let remote_path = remote_path.to_string();
let local_dir = local_dir.to_path_buf();
let key_path = self.key_path.clone();
let strict_mode = self.strict_mode;
let use_agent = self.use_agent;
let use_password = self.use_password;
let jump_hosts = self.jump_hosts.clone();
let semaphore = Arc::clone(&semaphore);
let pb = setup_progress_bar(&multi_progress, &node, style.clone(), "Connecting...");
tokio::spawn(download_file_task(
node,
remote_path,
local_dir,
key_path,
strict_mode,
use_agent,
use_password,
jump_hosts,
semaphore,
pb,
))
})
.collect();
let results = join_all(tasks).await;
self.collect_download_results(results)
}
pub async fn download_files(
&self,
remote_paths: Vec<String>,
local_dir: &Path,
) -> Result<Vec<DownloadResult>> {
let semaphore = Arc::new(Semaphore::new(self.max_parallel));
let multi_progress = MultiProgress::new();
let style = create_progress_style()?;
let mut all_results = Vec::new();
for remote_path in remote_paths {
let tasks: Vec<_> = self
.nodes
.iter()
.map(|node| {
let node = node.clone();
let remote_path = remote_path.clone();
let local_dir = local_dir.to_path_buf();
let semaphore = Arc::clone(&semaphore);
let pb = setup_download_progress_bar(
&multi_progress,
&node,
style.clone(),
&remote_path,
);
let filename = if let Some(file_name) = Path::new(&remote_path).file_name() {
format!(
"{}_{}",
node.host.replace(':', "_"),
file_name.to_string_lossy()
)
} else {
format!("{}_download", node.host.replace(':', "_"))
};
let local_path = local_dir.join(filename);
let key_path = self.key_path.clone();
let strict_mode = self.strict_mode;
let use_agent = self.use_agent;
let use_password = self.use_password;
let jump_hosts = self.jump_hosts.clone();
tokio::spawn(async move {
let _permit = match semaphore.acquire().await {
Ok(permit) => permit,
Err(e) => {
pb.finish_with_message(format!("✗ Semaphore failed: {e}"));
return DownloadResult {
node,
result: Err(anyhow::anyhow!(
"Semaphore acquisition failed: {e}"
)),
};
}
};
let result = download_from_node(
node.clone(),
&remote_path,
&local_path,
key_path.as_deref(),
strict_mode,
use_agent,
use_password,
jump_hosts.as_deref(),
)
.await;
match &result {
Ok(path) => {
pb.finish_with_message(format!("✓ Downloaded {}", path.display()));
}
Err(e) => {
pb.finish_with_message(format!("✗ Failed: {e}"));
}
}
DownloadResult {
node,
result: result.map(|_| local_path),
}
})
})
.collect();
let results = join_all(tasks).await;
for result in results {
match result {
Ok(download_result) => all_results.push(download_result),
Err(e) => {
tracing::error!("Task failed: {}", e);
}
}
}
}
Ok(all_results)
}
fn collect_results(
&self,
results: Vec<Result<ExecutionResult, tokio::task::JoinError>>,
) -> Result<Vec<ExecutionResult>> {
let mut execution_results = Vec::new();
for result in results {
match result {
Ok(exec_result) => execution_results.push(exec_result),
Err(e) => {
tracing::error!("Task failed: {}", e);
}
}
}
Ok(execution_results)
}
fn collect_upload_results(
&self,
results: Vec<Result<UploadResult, tokio::task::JoinError>>,
) -> Result<Vec<UploadResult>> {
let mut upload_results = Vec::new();
for result in results {
match result {
Ok(upload_result) => upload_results.push(upload_result),
Err(e) => {
tracing::error!("Task failed: {}", e);
}
}
}
Ok(upload_results)
}
fn collect_download_results(
&self,
results: Vec<Result<DownloadResult, tokio::task::JoinError>>,
) -> Result<Vec<DownloadResult>> {
let mut download_results = Vec::new();
for result in results {
match result {
Ok(download_result) => download_results.push(download_result),
Err(e) => {
tracing::error!("Task failed: {}", e);
}
}
}
Ok(download_results)
}
}