use anyhow::Result;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use owo_colors::OwoColorize;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Semaphore;
use crate::node::Node;
use super::connection_manager::{
download_from_node, execute_on_node_with_jump_hosts, upload_to_node, ExecutionConfig,
};
use super::result_types::{DownloadResult, ExecutionResult, UploadResult};
/// Steady-tick interval (milliseconds) for spinners created by `setup_progress_bar`.
const PROGRESS_BAR_TICK_RATE_MS: u64 = 80;
/// Steady-tick interval (milliseconds) for spinners created by `setup_download_progress_bar`.
const DOWNLOAD_PROGRESS_TICK_RATE_MS: u64 = 100;
/// Builds the spinner style shared by the per-node progress bars.
///
/// Each bar renders as a bold prefix, a cyan braille spinner, and the current
/// message.
///
/// # Errors
///
/// Returns an error if indicatif rejects the template string.
pub(crate) fn create_progress_style() -> Result<ProgressStyle> {
    let template = "{prefix:.bold} {spinner:.cyan} {msg}";
    let spinner_frames = "⣾⣽⣻⢿⡿⣟⣯⣷ ";
    let base = ProgressStyle::default_bar()
        .template(template)
        .map_err(|err| anyhow::anyhow!("Failed to create progress bar template: {err}"))?;
    Ok(base.tick_chars(spinner_frames))
}
/// Returns `node`'s display string, shortened for use as a progress-bar prefix.
///
/// Names longer than 20 characters are cut to their first 17 characters and
/// suffixed with `...`. Shorter names are returned unchanged.
///
/// Lengths are measured in `char`s rather than bytes. Slicing a `String` at a
/// fixed byte offset panics when the offset lands inside a multi-byte UTF-8
/// sequence, e.g. a node whose user or host name contains non-ASCII text.
pub(crate) fn format_node_display(node: &Node) -> String {
    const MAX_CHARS: usize = 20;
    const KEPT_CHARS: usize = 17;

    let display = node.to_string();
    if display.chars().count() > MAX_CHARS {
        let head: String = display.chars().take(KEPT_CHARS).collect();
        format!("{head}...")
    } else {
        display
    }
}
/// Runs `command` on `node` once a concurrency permit is available, reporting
/// progress on `pb`.
///
/// A permit is taken from `semaphore` and kept (via `_permit`) until this
/// function returns, so at most as many commands run concurrently as the
/// semaphore allows. If the semaphore has been closed, `pb` is finished with a
/// red "Semaphore closed" marker and the returned result carries the
/// acquisition error. The command is not run in that case.
///
/// On completion `pb` shows "Success", the non-zero exit status, or a
/// single-line, length-bounded summary of the error.
///
/// The returned [`ExecutionResult`] always has `is_main_rank` set to `false`.
pub(crate) async fn execute_command_task(
    node: Node,
    command: String,
    config: ExecutionConfig<'_>,
    semaphore: Arc<Semaphore>,
    pb: ProgressBar,
) -> ExecutionResult {
    // Longest error summary shown verbatim, and how many chars to keep when
    // truncating (leaving room for the "..." suffix).
    const MAX_ERROR_CHARS: usize = 50;
    const KEPT_ERROR_CHARS: usize = 47;

    let _permit = match semaphore.acquire().await {
        Ok(permit) => permit,
        Err(e) => {
            pb.finish_with_message(format!("{} {}", "●".red(), "Semaphore closed".red()));
            return ExecutionResult {
                node,
                result: Err(anyhow::anyhow!("Semaphore acquisition failed: {e}")),
                is_main_rank: false,
            };
        }
    };

    pb.set_message("Executing...".blue().to_string());
    let result = execute_on_node_with_jump_hosts(node.clone(), &command, &config).await;

    match &result {
        Ok(cmd_result) if cmd_result.is_success() => {
            pb.finish_with_message(format!("{} {}", "●".green(), "Success".green()));
        }
        Ok(cmd_result) => {
            pb.finish_with_message(format!(
                "{} Exit code: {}",
                "●".red(),
                cmd_result.exit_status.to_string().red()
            ));
        }
        Err(e) => {
            // `{:#}` includes the anyhow context chain; only its first line is
            // shown so the spinner stays on one row.
            let error_msg = format!("{e:#}");
            let first_line = error_msg.lines().next().unwrap_or("Unknown error");
            // Truncate by chars: a fixed byte index can split a multi-byte
            // UTF-8 character (e.g. in a non-ASCII path) and panic.
            let short_error = if first_line.chars().count() > MAX_ERROR_CHARS {
                let head: String = first_line.chars().take(KEPT_ERROR_CHARS).collect();
                format!("{head}...")
            } else {
                first_line.to_string()
            };
            pb.finish_with_message(format!("{} {}", "●".red(), short_error.red()));
        }
    }

    ExecutionResult {
        node,
        result,
        is_main_rank: false,
    }
}
/// Uploads `local_path` to `remote_path` on `node` once a concurrency permit is
/// available, reporting progress on `pb`.
///
/// A permit is taken from `semaphore` and kept (via `_permit`) until this
/// function returns. If the semaphore has been closed, `pb` is finished with a
/// red "Semaphore closed" marker and the returned [`UploadResult`] carries the
/// acquisition error. No upload is attempted in that case.
///
/// The connection options (`key_path`, `strict_mode`, `use_agent`,
/// `use_password`, `jump_hosts`, `connect_timeout`) are forwarded unchanged to
/// `upload_to_node`. On completion `pb` shows "Uploaded" or a single-line,
/// length-bounded summary of the error.
// The argument list mirrors the connection options forwarded to `upload_to_node`.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn upload_file_task(
    node: Node,
    local_path: std::path::PathBuf,
    remote_path: String,
    key_path: Option<String>,
    strict_mode: crate::ssh::known_hosts::StrictHostKeyChecking,
    use_agent: bool,
    use_password: bool,
    jump_hosts: Option<String>,
    connect_timeout: Option<u64>,
    semaphore: Arc<Semaphore>,
    pb: ProgressBar,
) -> UploadResult {
    // Longest error summary shown verbatim, and how many chars to keep when
    // truncating (leaving room for the "..." suffix).
    const MAX_ERROR_CHARS: usize = 50;
    const KEPT_ERROR_CHARS: usize = 47;

    let _permit = match semaphore.acquire().await {
        Ok(permit) => permit,
        Err(e) => {
            pb.finish_with_message(format!("{} {}", "●".red(), "Semaphore closed".red()));
            return UploadResult {
                node,
                result: Err(anyhow::anyhow!("Semaphore acquisition failed: {e}")),
            };
        }
    };

    pb.set_message("Uploading (SFTP)...".blue().to_string());
    let result = upload_to_node(
        node.clone(),
        &local_path,
        &remote_path,
        key_path.as_deref(),
        strict_mode,
        use_agent,
        use_password,
        jump_hosts.as_deref(),
        connect_timeout,
    )
    .await;

    match &result {
        Ok(()) => {
            pb.finish_with_message(format!("{} {}", "●".green(), "Uploaded".green()));
        }
        Err(e) => {
            // `{:#}` includes the anyhow context chain; only its first line is
            // shown so the spinner stays on one row.
            let error_msg = format!("{e:#}");
            let first_line = error_msg.lines().next().unwrap_or("Unknown error");
            // Truncate by chars: a fixed byte index can split a multi-byte
            // UTF-8 character (e.g. in a non-ASCII path) and panic.
            let short_error = if first_line.chars().count() > MAX_ERROR_CHARS {
                let head: String = first_line.chars().take(KEPT_ERROR_CHARS).collect();
                format!("{head}...")
            } else {
                first_line.to_string()
            };
            pb.finish_with_message(format!("{} {}", "●".red(), short_error.red()));
        }
    }

    UploadResult { node, result }
}
#[allow(clippy::too_many_arguments)]
pub(crate) async fn download_file_task(
node: Node,
remote_path: String,
local_dir: std::path::PathBuf,
key_path: Option<String>,
strict_mode: crate::ssh::known_hosts::StrictHostKeyChecking,
use_agent: bool,
use_password: bool,
jump_hosts: Option<String>,
connect_timeout: Option<u64>,
semaphore: Arc<Semaphore>,
pb: ProgressBar,
) -> DownloadResult {
let _permit = match semaphore.acquire().await {
Ok(permit) => permit,
Err(e) => {
pb.finish_with_message(format!("{} {}", "●".red(), "Semaphore closed".red()));
return DownloadResult {
node,
result: Err(anyhow::anyhow!("Semaphore acquisition failed: {e}")),
};
}
};
pb.set_message(format!("{}", "Downloading (SFTP)...".blue()));
let filename = if let Some(file_name) = Path::new(&remote_path).file_name() {
format!(
"{}_{}",
node.host.replace(':', "_"),
file_name.to_string_lossy()
)
} else {
format!("{}_download", node.host.replace(':', "_"))
};
let local_path = local_dir.join(filename);
let result = download_from_node(
node.clone(),
&remote_path,
&local_path,
key_path.as_deref(),
strict_mode,
use_agent,
use_password,
jump_hosts.as_deref(),
connect_timeout,
)
.await;
match &result {
Ok(path) => {
pb.finish_with_message(format!("✓ Downloaded to {}", path.display()));
}
Err(e) => {
pb.finish_with_message(format!("✗ Error: {e}"));
}
}
DownloadResult {
node,
result: result.map(|_| local_path),
}
}
/// Adds a spinner for `node` to `multi_progress` and starts it ticking.
///
/// The prefix is `[<node>]`, with the node name shortened by
/// `format_node_display`. The initial message is `initial_message` in cyan.
/// The spinner redraws every `PROGRESS_BAR_TICK_RATE_MS` milliseconds.
pub(crate) fn setup_progress_bar(
    multi_progress: &MultiProgress,
    node: &Node,
    style: ProgressStyle,
    initial_message: &str,
) -> ProgressBar {
    let prefix = format!("[{}]", format_node_display(node));
    let tick = std::time::Duration::from_millis(PROGRESS_BAR_TICK_RATE_MS);

    let spinner = multi_progress.add(ProgressBar::new_spinner());
    spinner.set_style(style);
    spinner.set_prefix(prefix);
    spinner.set_message(initial_message.cyan().to_string());
    spinner.enable_steady_tick(tick);
    spinner
}
pub(crate) fn setup_download_progress_bar(
multi_progress: &MultiProgress,
node: &Node,
style: ProgressStyle,
remote_path: &str,
) -> ProgressBar {
let pb = multi_progress.add(ProgressBar::new_spinner());
pb.set_style(style);
pb.set_prefix(format!("[{node}]"));
pb.set_message(format!("Downloading {remote_path}"));
pb.enable_steady_tick(std::time::Duration::from_millis(
DOWNLOAD_PROGRESS_TICK_RATE_MS,
));
pb
}