1pub mod local;
2
3use anyhow::{Context, Result};
4use ssh2::Session;
5use std::io::{BufRead, BufReader};
6use std::net::TcpStream;
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::mpsc;
10use std::time::Duration;
11use tokio::sync::mpsc as tokio_mpsc;
12use tracing::info;
13
14use crate::models::{ExecutionResult, SshConfig, OutputEvent, OutputType, OutputCallback};
15use crate::Step;
16use crate::vars::VariableManager;
17use crate::ExtractRule;
18
19pub struct SshExecutor;
21
22impl SshExecutor {
23 pub fn execute_script_with_realtime_output(
25 server_name: &str,
26 ssh_config: &SshConfig,
27 step: &Step,
28 pipeline_name: &str,
29 step_name: &str,
30 output_callback: Option<OutputCallback>,
31 mut variable_manager: VariableManager,
32 extract_rules: Option<Vec<ExtractRule>>
33 ) -> Result<ExecutionResult> {
34 info!("Connecting to {}:{} as {}", ssh_config.host, ssh_config.port, ssh_config.username);
35
36 let script_path = step.script.as_str();
38 let script_content = std::fs::read_to_string(script_path)
40 .context(format!("Failed to read script file: {}", script_path))?;
41
42 variable_manager.set_variable("ssh_server_name".to_string(), server_name.to_string());
43 variable_manager.set_variable("ssh_server_ip".to_string(), ssh_config.host.to_string());
44
45 let script_content = variable_manager.replace_variables(&script_content);
46
47 let timeout_seconds = step.timeout_seconds
49 .or(ssh_config.timeout_seconds)
50 .unwrap_or(3);
51 let timeout_duration = Duration::from_secs(timeout_seconds);
52
53 let tcp = connect_with_timeout(&format!("{}:{}", ssh_config.host, ssh_config.port), timeout_duration)
55 .context("Failed to connect to SSH server")?;
56
57 tcp.set_read_timeout(Some(timeout_duration))
59 .context("Failed to set read timeout")?;
60 tcp.set_write_timeout(Some(timeout_duration))
61 .context("Failed to set write timeout")?;
62 tcp.set_nodelay(true)
63 .context("Failed to set TCP nodelay")?;
64
65 let mut sess = Session::new()
67 .context("Failed to create SSH session")?;
68
69 sess.set_tcp_stream(tcp);
70
71 let session_timeout_seconds = step.timeout_seconds.unwrap_or(30);
73 let session_timeout_duration = Duration::from_secs(session_timeout_seconds);
74 sess.set_timeout(session_timeout_duration.as_millis() as u32);
75
76 sess.handshake()
78 .context("SSH handshake failed")?;
79
80 info!("SSH handshake completed, starting authentication");
81
82 let auth_result = if let Some(ref password) = ssh_config.password {
84 sess.userauth_password(&ssh_config.username, password)
85 .context("SSH password authentication failed")
86 } else if let Some(ref key_path) = ssh_config.private_key_path {
87 sess.userauth_pubkey_file(&ssh_config.username, None, Path::new(key_path), None)
88 .context("SSH key authentication failed")
89 } else {
90 Err(anyhow::anyhow!("No authentication method provided"))
91 };
92
93 auth_result?;
94 info!("SSH authentication successful");
95
96 let mut channel = sess.channel_session()
98 .context("Failed to create SSH channel")?;
99 channel.exec("sh")
100 .context("Failed to exec remote shell")?;
101
102 use std::io::Write;
104 channel.write_all(script_content.as_bytes())
105 .context("Failed to write script to remote shell")?;
106 channel.send_eof()
107 .context("Failed to send EOF to remote shell")?;
108
109 let (tx, mut rx) = tokio_mpsc::channel::<OutputEvent>(100);
111 let output_callback = output_callback.map(|cb| Arc::new(cb));
112
113 let server_name = server_name.to_string();
115 let _step_name = step_name.to_string();
116 let pipeline_name = pipeline_name.to_string();
117 let output_callback_clone = output_callback.clone();
118
119 let output_handle = std::thread::spawn(move || {
120 while let Some(event) = rx.blocking_recv() {
121 if let Some(callback) = &output_callback_clone {
122 callback(event);
123 }
124 }
125 });
126
127 let mut stdout = String::new();
129 let mut stderr = String::new();
130 let start_time = std::time::Instant::now();
131
132 let stdout_stream = channel.stream(0);
134 let mut stdout_reader = BufReader::new(stdout_stream);
135 let mut line = String::new();
136
137 while stdout_reader.read_line(&mut line)? > 0 {
138 let content = line.clone();
139 stdout.push_str(&content);
140
141 let event = OutputEvent {
143 pipeline_name: pipeline_name.clone(),
144 server_name: server_name.clone(),
145 step: step.clone(), output_type: OutputType::Stdout,
147 content: content.trim().to_string(),
148 timestamp: std::time::Instant::now(),
149 variables: variable_manager.get_variables().clone(),
150 };
151
152 if tx.blocking_send(event).is_err() {
153 break;
154 }
155
156 line.clear();
157 }
158
159 let stderr_stream = channel.stderr();
161 let mut stderr_reader = BufReader::new(stderr_stream);
162 line.clear();
163
164 while stderr_reader.read_line(&mut line)? > 0 {
165 let content = line.clone();
166 stderr.push_str(&content);
167
168 let event = OutputEvent {
170 pipeline_name: pipeline_name.clone(),
171 server_name: server_name.clone(),
172 step: step.clone(), output_type: OutputType::Stderr,
174 content: content.trim().to_string(),
175 timestamp: std::time::Instant::now(),
176 variables: variable_manager.get_variables().clone(),
177 };
178
179 if tx.blocking_send(event).is_err() {
180 break;
181 }
182
183 line.clear();
184 }
185
186 drop(tx);
188 if let Err(e) = output_handle.join() {
189 eprintln!("Output handler thread error: {:?}", e);
190 }
191
192 channel.wait_close()
193 .context("Failed to wait for channel close")?;
194
195 let exit_code = channel.exit_status()
196 .context("Failed to get exit status")?;
197
198 let execution_time = start_time.elapsed().as_millis() as u64;
199 info!("SSH command executed with exit code: {}", exit_code);
200
201 let execution_result = ExecutionResult {
203 success: exit_code == 0,
204 stdout,
205 stderr,
206 script: step.script.to_string(),
207 exit_code,
208 execution_time_ms: execution_time,
209 error_message: None,
210 };
211
212 if let Some(rules) = extract_rules {
214 if let Err(e) = variable_manager.extract_variables(&rules, &execution_result) {
215 info!("Failed to extract variables: {}", e);
216 }
217 }
218
219 Ok(execution_result)
220 }
221
222}
223
/// Opens a TCP connection to `addr` (a `host:port` string), waiting at most `timeout`.
///
/// The blocking `TcpStream::connect`, including name resolution, runs on a helper
/// thread while the caller waits on a channel for its result. If the deadline
/// passes first, the helper thread is left to finish on its own and its result is
/// discarded.
///
/// # Errors
///
/// Returns the connect/resolution error if it arrives in time. Otherwise returns an
/// error of kind [`std::io::ErrorKind::TimedOut`]; this also covers the helper
/// thread dying without reporting.
fn connect_with_timeout(addr: &str, timeout: Duration) -> std::io::Result<TcpStream> {
    let target = addr.to_string();
    let timeout_message = format!("connect to {} timeout {} s", target, timeout.as_secs());
    let (sender, receiver) = mpsc::channel();

    std::thread::spawn(move || {
        // The caller may already have timed out and dropped the receiver;
        // a failed send is expected in that case.
        let _ = sender.send(TcpStream::connect(target));
    });

    match receiver.recv_timeout(timeout) {
        Ok(connect_result) => connect_result,
        Err(_) => Err(std::io::Error::new(std::io::ErrorKind::TimedOut, timeout_message)),
    }
}