net_shell/ssh/
mod.rs

1pub mod local;
2
3use anyhow::{Context, Error, Result};
4use ssh2::Session;
5use std::io::{BufRead, BufReader};
6use std::net::TcpStream;
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::mpsc;
10use std::time::{Duration};
11use tokio::sync::mpsc as tokio_mpsc;
12use tracing::info;
13
14use crate::models::{ExecutionResult, SshConfig, OutputEvent, OutputType, OutputCallback};
15use crate::Step;
16use crate::vars::VariableManager;
17use crate::ExtractRule;
18
19/// SSH执行器
20pub struct SshExecutor;
21
22impl SshExecutor {
23    /// 通过SSH执行脚本(支持实时输出)
24    pub fn execute_script_with_realtime_output(
25        global_scripts:Arc<Vec<String>>,
26        server_name: &str,
27        ssh_config: &SshConfig, 
28        step: &Step,
29        pipeline_name: &str,
30        step_name: &str,
31        output_callback: Option<OutputCallback>,
32        mut variable_manager: VariableManager,
33        extract_rules: Option<Vec<ExtractRule>>
34    ) -> Result<ExecutionResult> {
35        info!("Connecting to {}:{} as {}", ssh_config.host, ssh_config.port, ssh_config.username);
36
37        // 只用step.script作为脚本路径,不做参数处理
38        let script_path = step.script.as_str(); 
39
40        // 读取本地脚本内容并替换变量
41        let script_content = std::fs::read_to_string(script_path)
42            .context(format!("Failed to read script file: {}", script_path))?;
43
44        let mut gloabl_script_content = global_scripts.iter()
45        .map(|v|std::fs::read_to_string(v).context(format!("read file:[{}]", v)))
46        .fold(Ok("".to_string()), |p:Result<String>,v|{
47            if p.is_err(){
48                return p;
49            }
50
51            if v.is_err(){
52                return Err(Error::msg(format!("{:?}", v.err())));
53            }
54            let content = v.unwrap();
55
56            let mut s = p.unwrap_or_default();
57
58            s.push_str("\n");
59            s.push_str(&content);
60
61            return  Ok(s.clone());
62        })?;
63
64        gloabl_script_content.push_str("\n");
65        gloabl_script_content.push_str(&script_content);
66
67        let script_content = gloabl_script_content.clone();
68
69        variable_manager.set_variable("ssh_server_name".to_string(), server_name.to_string());
70        variable_manager.set_variable("ssh_server_ip".to_string(), ssh_config.host.to_string());
71
72        let script_content = variable_manager.replace_variables(&script_content);
73
74        // 设置连接超时
75        let ssh_timeout_seconds = ssh_config.timeout_seconds.unwrap_or(3);
76        let ssh_timeout_duration = Duration::from_secs(ssh_timeout_seconds);
77
78        // 建立TCP连接(带严格超时)
79        let tcp = connect_with_timeout(&format!("{}:{}", ssh_config.host, ssh_config.port), ssh_timeout_duration)
80            .context("Failed to connect to SSH server")?;
81
82        let timeout_duration = Duration::from_secs(step.timeout_seconds.unwrap_or(30));
83        
84        // 设置TCP连接超时
85        tcp.set_read_timeout(Some(timeout_duration))
86            .context("Failed to set read timeout")?;
87        tcp.set_write_timeout(Some(timeout_duration))
88            .context("Failed to set write timeout")?;
89        tcp.set_nodelay(true)
90            .context("Failed to set TCP nodelay")?;
91
92        // 创建SSH会话
93        let mut sess = Session::new()
94            .context("Failed to create SSH session")?;
95        
96        sess.set_tcp_stream(tcp);
97        
98        // 设置SSH会话超时(使用步骤级别的超时,如果没有则使用默认值)
99        let session_timeout_seconds = step.timeout_seconds.unwrap_or(30);
100        let session_timeout_duration = Duration::from_secs(session_timeout_seconds);
101        sess.set_timeout(session_timeout_duration.as_millis() as u32);
102        
103        // SSH握手(带超时)
104        sess.handshake()
105            .context(format!("SSH handshake failed: timeout {} s", ssh_timeout_seconds))?;
106
107        info!("SSH handshake completed, starting authentication");
108
109        // 认证(带超时)
110        let auth_result = if let Some(ref password) = ssh_config.password {
111            sess.userauth_password(&ssh_config.username, password)
112                .context("SSH password authentication failed")
113        } else if let Some(ref key_path) = ssh_config.private_key_path {
114            sess.userauth_pubkey_file(&ssh_config.username, None, Path::new(key_path), None)
115                .context("SSH key authentication failed")
116        } else {
117            Err(anyhow::anyhow!("No authentication method provided"))
118        };
119
120        auth_result?;
121        info!("SSH authentication successful");
122
123        // 打开远程shell
124        let mut channel = sess.channel_session()
125            .context("Failed to create SSH channel")?;
126        channel.exec("bash")
127            .context("Failed to exec remote shell")?;
128
129        // 把脚本内容写入远程shell的stdin
130        use std::io::Write;
131        channel.write_all(script_content.as_bytes())
132            .context("Failed to write script to remote shell")?;
133        channel.send_eof()
134            .context("Failed to send EOF to remote shell")?;
135
136        // 创建通道用于实时输出
137        let (tx, mut rx) = tokio_mpsc::channel::<OutputEvent>(100);
138        let output_callback = output_callback.map(|cb| Arc::new(cb));
139
140        // 在单独的线程中处理实时输出
141        let server_name = server_name.to_string();
142        let _step_name = step_name.to_string();
143        let pipeline_name = pipeline_name.to_string();
144        let output_callback_clone = output_callback.clone();
145        
146        let output_handle = std::thread::spawn(move || {
147            while let Some(event) = rx.blocking_recv() {
148                if let Some(callback) = &output_callback_clone {
149                    callback(event);
150                }
151            }
152        });
153
154        // 读取stdout和stderr
155        let mut stdout = String::new();
156        let mut stderr = String::new();
157        let start_time = std::time::Instant::now();
158
159        // 实时读取stdout
160        let stdout_stream = channel.stream(0);
161        let mut stdout_reader = BufReader::new(stdout_stream);
162        let mut line = String::new();
163        
164        while stdout_reader.read_line(&mut line)? > 0 {
165            let content = line.clone();
166            stdout.push_str(&content);
167            
168            // 发送实时输出事件
169            let event = OutputEvent {
170                pipeline_name: pipeline_name.clone(),
171                server_name: server_name.clone(),
172                step: step.clone(), // 传递完整的Step对象
173                script_path:step.script.to_string(),
174                output_type: OutputType::Stdout,
175                content: content.trim().to_string(),
176                timestamp: std::time::Instant::now(),
177                variables: variable_manager.get_variables().clone(),
178            };
179            
180            if tx.blocking_send(event).is_err() {
181                break;
182            }
183            
184            line.clear();
185        }
186
187        // 实时读取stderr
188        let stderr_stream = channel.stderr();
189        let mut stderr_reader = BufReader::new(stderr_stream);
190        line.clear();
191        
192        while stderr_reader.read_line(&mut line)? > 0 {
193            let content = line.clone();
194            stderr.push_str(&content);
195            
196            // 发送实时输出事件
197            let event = OutputEvent {
198                pipeline_name: pipeline_name.clone(),
199                server_name: server_name.clone(),
200                step: step.clone(), // 传递完整的Step对象
201                script_path:step.script.to_string(),
202                output_type: OutputType::Stderr,
203                content: content.trim().to_string(),
204                timestamp: std::time::Instant::now(),
205                variables: variable_manager.get_variables().clone(),
206            };
207            
208            if tx.blocking_send(event).is_err() {
209                break;
210            }
211            
212            line.clear();
213        }
214
215        // 等待通道关闭
216        drop(tx);
217        if let Err(e) = output_handle.join() {
218            eprintln!("Output handler thread error: {:?}", e);
219        }
220
221        channel.wait_close()
222            .context("Failed to wait for channel close")?;
223
224        let exit_code = channel.exit_status()
225            .context("Failed to get exit status")?;
226
227        let execution_time = start_time.elapsed().as_millis() as u64;
228        info!("SSH command executed with exit code: {}", exit_code);
229
230        // 创建执行结果
231        let execution_result = ExecutionResult {
232            success: exit_code == 0,
233            stdout,
234            stderr,
235            script: step.script.to_string(),
236            exit_code,
237            execution_time_ms: execution_time,
238            error_message: None,
239        };
240
241        // 提取变量
242        if let Some(rules) = extract_rules {
243            if let Err(e) = variable_manager.extract_variables(&rules, &execution_result) {
244                info!("Failed to extract variables: {}", e);
245            }
246        }
247
248        Ok(execution_result)
249    }
250
251}
252
/// Opens a TCP connection to `addr`, giving up after `timeout`.
///
/// Resolution and connection run on a helper thread, so the whole attempt (including the
/// address lookup) is bounded by `timeout`. If the deadline passes first, a `TimedOut`
/// error is returned. The helper thread is left to finish on its own, and its result is
/// discarded.
fn connect_with_timeout(addr: &str, timeout: Duration) -> std::io::Result<TcpStream> {
    let target = addr.to_owned();
    let timeout_msg = format!("connect to {} timeout {} s", target, timeout.as_secs());
    let (sender, receiver) = mpsc::channel::<std::io::Result<TcpStream>>();

    std::thread::spawn(move || {
        // The receiver may already be gone if the caller timed out; that is fine.
        sender.send(TcpStream::connect(target)).ok();
    });

    match receiver.recv_timeout(timeout) {
        Ok(outcome) => outcome,
        Err(_) => Err(std::io::Error::new(std::io::ErrorKind::TimedOut, timeout_msg)),
    }
}