net_shell/ssh/
mod.rs

1pub mod local;
2
3use anyhow::{Context, Error, Result};
4use ssh2::Session;
5use std::fs;
6use std::io::{BufRead, BufReader};
7use std::net::TcpStream;
8use std::path::Path;
9use std::sync::Arc;
10use std::sync::mpsc;
11use std::time::Duration;
12use tokio::sync::mpsc as tokio_mpsc;
13use tracing::info;
14
15use crate::models::{ExecutionResult, SshConfig, OutputEvent, OutputType, OutputCallback};
16use crate::Step;
17use crate::vars::VariableManager;
18use crate::ExtractRule;
19
20/// SSH执行器
21pub struct SshExecutor;
22
23impl SshExecutor {
24    /// 通过SSH执行脚本(支持实时输出)
25    pub fn execute_script_with_realtime_output(
26        global_scripts:Arc<Vec<String>>,
27        server_name: &str,
28        ssh_config: &SshConfig, 
29        step: &Step,
30        pipeline_name: &str,
31        step_name: &str,
32        output_callback: Option<OutputCallback>,
33        mut variable_manager: VariableManager,
34        extract_rules: Option<Vec<ExtractRule>>
35    ) -> Result<ExecutionResult> {
36        info!("Connecting to {}:{} as {}", ssh_config.host, ssh_config.port, ssh_config.username);
37
38        // 只用step.script作为脚本路径,不做参数处理
39        let script_path = step.script.as_str();
40
41        // 读取本地脚本内容并替换变量
42        let script_content = std::fs::read_to_string(script_path)
43            .context(format!("Failed to read script file: {}", script_path))?;
44
45        let mut gloabl_script_content = global_scripts.iter()
46        .map(|v|std::fs::read_to_string(v).context(format!("read file:[{}]", v)))
47        .fold(Ok("".to_string()), |p:Result<String>,v|{
48            if p.is_err(){
49                return p;
50            }
51
52            if v.is_err(){
53                return Err(Error::msg(format!("{:?}", v.err())));
54            }
55            let content = v.unwrap();
56
57            let mut s = p.unwrap_or_default();
58
59            s.push_str("\n");
60            s.push_str(&content);
61
62            return  Ok(s.clone());
63        })?;
64
65        gloabl_script_content.push_str("\n");
66        gloabl_script_content.push_str(&script_content);
67
68        let script_content = gloabl_script_content.clone();
69
70        match fs::write("script.sh", script_path.as_bytes())
71                    .context("Failed to write temporary script file") {
72            Ok(_) => {},
73            Err(e) => {
74                println!("Warning: Failed to write temporary script file: {}", e);
75            },
76        };
77
78        variable_manager.set_variable("ssh_server_name".to_string(), server_name.to_string());
79        variable_manager.set_variable("ssh_server_ip".to_string(), ssh_config.host.to_string());
80
81        let script_content = variable_manager.replace_variables(&script_content);
82
83        // 设置连接超时
84        let ssh_timeout_seconds = ssh_config.timeout_seconds.unwrap_or(3);
85        let ssh_timeout_duration = Duration::from_secs(ssh_timeout_seconds);
86
87        // 建立TCP连接(带严格超时)
88        let tcp = connect_with_timeout(&format!("{}:{}", ssh_config.host, ssh_config.port), ssh_timeout_duration)
89            .context("Failed to connect to SSH server")?;
90
91        let timeout_duration = Duration::from_secs(step.timeout_seconds.unwrap_or(30));
92        
93        // 设置TCP连接超时
94        tcp.set_read_timeout(Some(timeout_duration))
95            .context("Failed to set read timeout")?;
96        tcp.set_write_timeout(Some(timeout_duration))
97            .context("Failed to set write timeout")?;
98        tcp.set_nodelay(true)
99            .context("Failed to set TCP nodelay")?;
100
101        // 创建SSH会话
102        let mut sess = Session::new()
103            .context("Failed to create SSH session")?;
104        
105        sess.set_tcp_stream(tcp);
106        
107        // 设置SSH会话超时(使用步骤级别的超时,如果没有则使用默认值)
108        let session_timeout_seconds = step.timeout_seconds.unwrap_or(30);
109        let session_timeout_duration = Duration::from_secs(session_timeout_seconds);
110        sess.set_timeout(session_timeout_duration.as_millis() as u32);
111        
112        // SSH握手(带超时)
113        sess.handshake()
114            .context(format!("SSH handshake failed: timeout {} s", ssh_timeout_seconds))?;
115
116        info!("SSH handshake completed, starting authentication");
117
118        // 认证(带超时)
119        let auth_result = if let Some(ref password) = ssh_config.password {
120            sess.userauth_password(&ssh_config.username, password)
121                .context("SSH password authentication failed")
122        } else if let Some(ref key_path) = ssh_config.private_key_path {
123            sess.userauth_pubkey_file(&ssh_config.username, None, Path::new(key_path), None)
124                .context("SSH key authentication failed")
125        } else {
126            Err(anyhow::anyhow!("No authentication method provided"))
127        };
128
129        auth_result?;
130        info!("SSH authentication successful");
131
132        // 打开远程shell
133        let mut channel = sess.channel_session()
134            .context("Failed to create SSH channel")?;
135        channel.exec("sh")
136            .context("Failed to exec remote shell")?;
137
138        // 把脚本内容写入远程shell的stdin
139        use std::io::Write;
140        channel.write_all(script_content.as_bytes())
141            .context("Failed to write script to remote shell")?;
142        channel.send_eof()
143            .context("Failed to send EOF to remote shell")?;
144
145        // 创建通道用于实时输出
146        let (tx, mut rx) = tokio_mpsc::channel::<OutputEvent>(100);
147        let output_callback = output_callback.map(|cb| Arc::new(cb));
148
149        // 在单独的线程中处理实时输出
150        let server_name = server_name.to_string();
151        let _step_name = step_name.to_string();
152        let pipeline_name = pipeline_name.to_string();
153        let output_callback_clone = output_callback.clone();
154        
155        let output_handle = std::thread::spawn(move || {
156            while let Some(event) = rx.blocking_recv() {
157                if let Some(callback) = &output_callback_clone {
158                    callback(event);
159                }
160            }
161        });
162
163        // 读取stdout和stderr
164        let mut stdout = String::new();
165        let mut stderr = String::new();
166        let start_time = std::time::Instant::now();
167
168        // 实时读取stdout
169        let stdout_stream = channel.stream(0);
170        let mut stdout_reader = BufReader::new(stdout_stream);
171        let mut line = String::new();
172        
173        while stdout_reader.read_line(&mut line)? > 0 {
174            let content = line.clone();
175            stdout.push_str(&content);
176            
177            // 发送实时输出事件
178            let event = OutputEvent {
179                pipeline_name: pipeline_name.clone(),
180                server_name: server_name.clone(),
181                step: step.clone(), // 传递完整的Step对象
182                output_type: OutputType::Stdout,
183                content: content.trim().to_string(),
184                timestamp: std::time::Instant::now(),
185                variables: variable_manager.get_variables().clone(),
186            };
187            
188            if tx.blocking_send(event).is_err() {
189                break;
190            }
191            
192            line.clear();
193        }
194
195        // 实时读取stderr
196        let stderr_stream = channel.stderr();
197        let mut stderr_reader = BufReader::new(stderr_stream);
198        line.clear();
199        
200        while stderr_reader.read_line(&mut line)? > 0 {
201            let content = line.clone();
202            stderr.push_str(&content);
203            
204            // 发送实时输出事件
205            let event = OutputEvent {
206                pipeline_name: pipeline_name.clone(),
207                server_name: server_name.clone(),
208                step: step.clone(), // 传递完整的Step对象
209                output_type: OutputType::Stderr,
210                content: content.trim().to_string(),
211                timestamp: std::time::Instant::now(),
212                variables: variable_manager.get_variables().clone(),
213            };
214            
215            if tx.blocking_send(event).is_err() {
216                break;
217            }
218            
219            line.clear();
220        }
221
222        // 等待通道关闭
223        drop(tx);
224        if let Err(e) = output_handle.join() {
225            eprintln!("Output handler thread error: {:?}", e);
226        }
227
228        channel.wait_close()
229            .context("Failed to wait for channel close")?;
230
231        let exit_code = channel.exit_status()
232            .context("Failed to get exit status")?;
233
234        let execution_time = start_time.elapsed().as_millis() as u64;
235        info!("SSH command executed with exit code: {}", exit_code);
236
237        // 创建执行结果
238        let execution_result = ExecutionResult {
239            success: exit_code == 0,
240            stdout,
241            stderr,
242            script: step.script.to_string(),
243            exit_code,
244            execution_time_ms: execution_time,
245            error_message: None,
246        };
247
248        // 提取变量
249        if let Some(rules) = extract_rules {
250            if let Err(e) = variable_manager.extract_variables(&rules, &execution_result) {
251                info!("Failed to extract variables: {}", e);
252            }
253        }
254
255        Ok(execution_result)
256    }
257
258}
259
/// Opens a TCP connection to `addr` (a `host:port` string), giving up after `timeout`.
///
/// The blocking `TcpStream::connect` call (including address resolution) runs on a
/// helper thread. If no result arrives within `timeout` — or the helper thread ends
/// without reporting one — an `ErrorKind::TimedOut` error is returned. The helper
/// thread is not cancelled on timeout; its eventual result is simply discarded.
fn connect_with_timeout(addr: &str, timeout: Duration) -> std::io::Result<TcpStream> {
    let target = addr.to_owned();
    let timeout_message = format!("connect to {} timeout {} s", target, timeout.as_secs());
    let (result_tx, result_rx) = mpsc::channel();

    std::thread::spawn(move || {
        // Sending fails only when the caller has already given up; nothing to do then.
        let _ = result_tx.send(TcpStream::connect(target));
    });

    match result_rx.recv_timeout(timeout) {
        Ok(connect_result) => connect_result,
        Err(_) => Err(std::io::Error::new(std::io::ErrorKind::TimedOut, timeout_message)),
    }
}
270}