pub mod local;
use anyhow::{Context, Error, Result};
use ssh2::Session;
use std::fs;
use std::io::{BufRead, BufReader};
use std::net::TcpStream;
use std::path::Path;
use std::sync::Arc;
use std::sync::mpsc;
use std::time::Duration;
use tokio::sync::mpsc as tokio_mpsc;
use tracing::info;
use crate::models::{ExecutionResult, SshConfig, OutputEvent, OutputType, OutputCallback};
use crate::Step;
use crate::vars::VariableManager;
use crate::ExtractRule;
pub struct SshExecutor;
impl SshExecutor {
pub fn execute_script_with_realtime_output(
global_scripts:Arc<Vec<String>>,
server_name: &str,
ssh_config: &SshConfig,
step: &Step,
pipeline_name: &str,
step_name: &str,
output_callback: Option<OutputCallback>,
mut variable_manager: VariableManager,
extract_rules: Option<Vec<ExtractRule>>
) -> Result<ExecutionResult> {
info!("Connecting to {}:{} as {}", ssh_config.host, ssh_config.port, ssh_config.username);
let script_path = step.script.as_str();
let script_content = std::fs::read_to_string(script_path)
.context(format!("Failed to read script file: {}", script_path))?;
let mut gloabl_script_content = global_scripts.iter()
.map(|v|std::fs::read_to_string(v).context(format!("read file:[{}]", v)))
.fold(Ok("".to_string()), |p:Result<String>,v|{
if p.is_err(){
return p;
}
if v.is_err(){
return Err(Error::msg(format!("{:?}", v.err())));
}
let content = v.unwrap();
let mut s = p.unwrap_or_default();
s.push_str("\n");
s.push_str(&content);
return Ok(s.clone());
})?;
gloabl_script_content.push_str("\n");
gloabl_script_content.push_str(&script_content);
let script_content = gloabl_script_content.clone();
match fs::write("script.sh", script_path.as_bytes())
.context("Failed to write temporary script file") {
Ok(_) => {},
Err(e) => {
println!("Warning: Failed to write temporary script file: {}", e);
},
};
variable_manager.set_variable("ssh_server_name".to_string(), server_name.to_string());
variable_manager.set_variable("ssh_server_ip".to_string(), ssh_config.host.to_string());
let script_content = variable_manager.replace_variables(&script_content);
let ssh_timeout_seconds = ssh_config.timeout_seconds.unwrap_or(3);
let ssh_timeout_duration = Duration::from_secs(ssh_timeout_seconds);
let tcp = connect_with_timeout(&format!("{}:{}", ssh_config.host, ssh_config.port), ssh_timeout_duration)
.context("Failed to connect to SSH server")?;
let timeout_duration = Duration::from_secs(step.timeout_seconds.unwrap_or(30));
tcp.set_read_timeout(Some(timeout_duration))
.context("Failed to set read timeout")?;
tcp.set_write_timeout(Some(timeout_duration))
.context("Failed to set write timeout")?;
tcp.set_nodelay(true)
.context("Failed to set TCP nodelay")?;
let mut sess = Session::new()
.context("Failed to create SSH session")?;
sess.set_tcp_stream(tcp);
let session_timeout_seconds = step.timeout_seconds.unwrap_or(30);
let session_timeout_duration = Duration::from_secs(session_timeout_seconds);
sess.set_timeout(session_timeout_duration.as_millis() as u32);
sess.handshake()
.context(format!("SSH handshake failed: timeout {} s", ssh_timeout_seconds))?;
info!("SSH handshake completed, starting authentication");
let auth_result = if let Some(ref password) = ssh_config.password {
sess.userauth_password(&ssh_config.username, password)
.context("SSH password authentication failed")
} else if let Some(ref key_path) = ssh_config.private_key_path {
sess.userauth_pubkey_file(&ssh_config.username, None, Path::new(key_path), None)
.context("SSH key authentication failed")
} else {
Err(anyhow::anyhow!("No authentication method provided"))
};
auth_result?;
info!("SSH authentication successful");
let mut channel = sess.channel_session()
.context("Failed to create SSH channel")?;
channel.exec("sh")
.context("Failed to exec remote shell")?;
use std::io::Write;
channel.write_all(script_content.as_bytes())
.context("Failed to write script to remote shell")?;
channel.send_eof()
.context("Failed to send EOF to remote shell")?;
let (tx, mut rx) = tokio_mpsc::channel::<OutputEvent>(100);
let output_callback = output_callback.map(|cb| Arc::new(cb));
let server_name = server_name.to_string();
let _step_name = step_name.to_string();
let pipeline_name = pipeline_name.to_string();
let output_callback_clone = output_callback.clone();
let output_handle = std::thread::spawn(move || {
while let Some(event) = rx.blocking_recv() {
if let Some(callback) = &output_callback_clone {
callback(event);
}
}
});
let mut stdout = String::new();
let mut stderr = String::new();
let start_time = std::time::Instant::now();
let stdout_stream = channel.stream(0);
let mut stdout_reader = BufReader::new(stdout_stream);
let mut line = String::new();
while stdout_reader.read_line(&mut line)? > 0 {
let content = line.clone();
stdout.push_str(&content);
let event = OutputEvent {
pipeline_name: pipeline_name.clone(),
server_name: server_name.clone(),
step: step.clone(), output_type: OutputType::Stdout,
content: content.trim().to_string(),
timestamp: std::time::Instant::now(),
variables: variable_manager.get_variables().clone(),
};
if tx.blocking_send(event).is_err() {
break;
}
line.clear();
}
let stderr_stream = channel.stderr();
let mut stderr_reader = BufReader::new(stderr_stream);
line.clear();
while stderr_reader.read_line(&mut line)? > 0 {
let content = line.clone();
stderr.push_str(&content);
let event = OutputEvent {
pipeline_name: pipeline_name.clone(),
server_name: server_name.clone(),
step: step.clone(), output_type: OutputType::Stderr,
content: content.trim().to_string(),
timestamp: std::time::Instant::now(),
variables: variable_manager.get_variables().clone(),
};
if tx.blocking_send(event).is_err() {
break;
}
line.clear();
}
drop(tx);
if let Err(e) = output_handle.join() {
eprintln!("Output handler thread error: {:?}", e);
}
channel.wait_close()
.context("Failed to wait for channel close")?;
let exit_code = channel.exit_status()
.context("Failed to get exit status")?;
let execution_time = start_time.elapsed().as_millis() as u64;
info!("SSH command executed with exit code: {}", exit_code);
let execution_result = ExecutionResult {
success: exit_code == 0,
stdout,
stderr,
script: step.script.to_string(),
exit_code,
execution_time_ms: execution_time,
error_message: None,
};
if let Some(rules) = extract_rules {
if let Err(e) = variable_manager.extract_variables(&rules, &execution_result) {
info!("Failed to extract variables: {}", e);
}
}
Ok(execution_result)
}
}
/// Opens a TCP connection to `addr` (a `host:port` string), giving up after `timeout`.
///
/// Name resolution and the connect itself run on a helper thread, presumably
/// because `TcpStream::connect_timeout` only accepts an already-resolved
/// `SocketAddr`. If the deadline passes first, the helper thread is left to
/// finish on its own and its result is discarded.
///
/// # Errors
///
/// Returns the helper thread's connect error as-is. If no result arrives within
/// `timeout` (or the helper exits without reporting), returns an
/// `ErrorKind::TimedOut` error whose message names the address and the timeout
/// in whole seconds.
fn connect_with_timeout(addr: &str, timeout: Duration) -> std::io::Result<TcpStream> {
    let target = addr.to_owned();
    let timeout_error = format!("connect to {} timeout {} s", target, timeout.as_secs());
    let (sender, receiver) = mpsc::channel();

    std::thread::spawn(move || {
        // After a timeout the receiver is gone; dropping the late result is intended.
        let _ = sender.send(TcpStream::connect(target));
    });

    match receiver.recv_timeout(timeout) {
        Ok(outcome) => outcome,
        Err(_) => Err(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            timeout_error,
        )),
    }
}