// net_shell/ssh/mod.rs

1pub mod local;
2
3use anyhow::{Context, Result};
4use ssh2::Session;
5use std::io::{BufRead, BufReader};
6use std::net::TcpStream;
7use std::path::Path;
8use std::sync::Arc;
9use std::sync::mpsc;
10use std::time::Duration;
11use tokio::sync::mpsc as tokio_mpsc;
12use tracing::info;
13
14use crate::models::{ExecutionResult, SshConfig, OutputEvent, OutputType, OutputCallback};
15use crate::Step;
16use crate::vars::VariableManager;
17use crate::ExtractRule;
18
19/// SSH执行器
20pub struct SshExecutor;
21
22impl SshExecutor {
23    /// 通过SSH执行脚本(支持实时输出)
24    pub fn execute_script_with_realtime_output(
25        server_name: &str,
26        ssh_config: &SshConfig, 
27        step: &Step,
28        pipeline_name: &str,
29        step_name: &str,
30        output_callback: Option<OutputCallback>,
31        mut variable_manager: VariableManager,
32        extract_rules: Option<Vec<ExtractRule>>
33    ) -> Result<ExecutionResult> {
34        info!("Connecting to {}:{} as {}", ssh_config.host, ssh_config.port, ssh_config.username);
35
36        // 只用step.script作为脚本路径,不做参数处理
37        let script_path = step.script.as_str();
38        // 读取本地脚本内容并替换变量
39        let script_content = std::fs::read_to_string(script_path)
40            .context(format!("Failed to read script file: {}", script_path))?;
41
42        variable_manager.set_variable("ssh_server_name".to_string(), server_name.to_string());
43        variable_manager.set_variable("ssh_server_ip".to_string(), ssh_config.host.to_string());
44
45        let script_content = variable_manager.replace_variables(&script_content);
46
47        // 设置连接超时
48        let timeout_seconds = step.timeout_seconds
49            .or(ssh_config.timeout_seconds)
50            .unwrap_or(3);
51        let timeout_duration = Duration::from_secs(timeout_seconds);
52        
53        // 建立TCP连接(带严格超时)
54        let tcp = connect_with_timeout(&format!("{}:{}", ssh_config.host, ssh_config.port), timeout_duration)
55            .context("Failed to connect to SSH server")?;
56        
57        // 设置TCP连接超时
58        tcp.set_read_timeout(Some(timeout_duration))
59            .context("Failed to set read timeout")?;
60        tcp.set_write_timeout(Some(timeout_duration))
61            .context("Failed to set write timeout")?;
62        tcp.set_nodelay(true)
63            .context("Failed to set TCP nodelay")?;
64
65        // 创建SSH会话
66        let mut sess = Session::new()
67            .context("Failed to create SSH session")?;
68        
69        sess.set_tcp_stream(tcp);
70        
71        // 设置SSH会话超时(使用步骤级别的超时,如果没有则使用默认值)
72        let session_timeout_seconds = step.timeout_seconds.unwrap_or(30);
73        let session_timeout_duration = Duration::from_secs(session_timeout_seconds);
74        sess.set_timeout(session_timeout_duration.as_millis() as u32);
75        
76        // SSH握手(带超时)
77        sess.handshake()
78            .context("SSH handshake failed")?;
79
80        info!("SSH handshake completed, starting authentication");
81
82        // 认证(带超时)
83        let auth_result = if let Some(ref password) = ssh_config.password {
84            sess.userauth_password(&ssh_config.username, password)
85                .context("SSH password authentication failed")
86        } else if let Some(ref key_path) = ssh_config.private_key_path {
87            sess.userauth_pubkey_file(&ssh_config.username, None, Path::new(key_path), None)
88                .context("SSH key authentication failed")
89        } else {
90            Err(anyhow::anyhow!("No authentication method provided"))
91        };
92
93        auth_result?;
94        info!("SSH authentication successful");
95
96        // 打开远程shell
97        let mut channel = sess.channel_session()
98            .context("Failed to create SSH channel")?;
99        channel.exec("sh")
100            .context("Failed to exec remote shell")?;
101
102        // 把脚本内容写入远程shell的stdin
103        use std::io::Write;
104        channel.write_all(script_content.as_bytes())
105            .context("Failed to write script to remote shell")?;
106        channel.send_eof()
107            .context("Failed to send EOF to remote shell")?;
108
109        // 创建通道用于实时输出
110        let (tx, mut rx) = tokio_mpsc::channel::<OutputEvent>(100);
111        let output_callback = output_callback.map(|cb| Arc::new(cb));
112
113        // 在单独的线程中处理实时输出
114        let server_name = server_name.to_string();
115        let _step_name = step_name.to_string();
116        let pipeline_name = pipeline_name.to_string();
117        let output_callback_clone = output_callback.clone();
118        
119        let output_handle = std::thread::spawn(move || {
120            while let Some(event) = rx.blocking_recv() {
121                if let Some(callback) = &output_callback_clone {
122                    callback(event);
123                }
124            }
125        });
126
127        // 读取stdout和stderr
128        let mut stdout = String::new();
129        let mut stderr = String::new();
130        let start_time = std::time::Instant::now();
131
132        // 实时读取stdout
133        let stdout_stream = channel.stream(0);
134        let mut stdout_reader = BufReader::new(stdout_stream);
135        let mut line = String::new();
136        
137        while stdout_reader.read_line(&mut line)? > 0 {
138            let content = line.clone();
139            stdout.push_str(&content);
140            
141            // 发送实时输出事件
142            let event = OutputEvent {
143                pipeline_name: pipeline_name.clone(),
144                server_name: server_name.clone(),
145                step: step.clone(), // 传递完整的Step对象
146                output_type: OutputType::Stdout,
147                content: content.trim().to_string(),
148                timestamp: std::time::Instant::now(),
149                variables: variable_manager.get_variables().clone(),
150            };
151            
152            if tx.blocking_send(event).is_err() {
153                break;
154            }
155            
156            line.clear();
157        }
158
159        // 实时读取stderr
160        let stderr_stream = channel.stderr();
161        let mut stderr_reader = BufReader::new(stderr_stream);
162        line.clear();
163        
164        while stderr_reader.read_line(&mut line)? > 0 {
165            let content = line.clone();
166            stderr.push_str(&content);
167            
168            // 发送实时输出事件
169            let event = OutputEvent {
170                pipeline_name: pipeline_name.clone(),
171                server_name: server_name.clone(),
172                step: step.clone(), // 传递完整的Step对象
173                output_type: OutputType::Stderr,
174                content: content.trim().to_string(),
175                timestamp: std::time::Instant::now(),
176                variables: variable_manager.get_variables().clone(),
177            };
178            
179            if tx.blocking_send(event).is_err() {
180                break;
181            }
182            
183            line.clear();
184        }
185
186        // 等待通道关闭
187        drop(tx);
188        if let Err(e) = output_handle.join() {
189            eprintln!("Output handler thread error: {:?}", e);
190        }
191
192        channel.wait_close()
193            .context("Failed to wait for channel close")?;
194
195        let exit_code = channel.exit_status()
196            .context("Failed to get exit status")?;
197
198        let execution_time = start_time.elapsed().as_millis() as u64;
199        info!("SSH command executed with exit code: {}", exit_code);
200
201        // 创建执行结果
202        let execution_result = ExecutionResult {
203            success: exit_code == 0,
204            stdout,
205            stderr,
206            script: step.script.to_string(),
207            exit_code,
208            execution_time_ms: execution_time,
209            error_message: None,
210        };
211
212        // 提取变量
213        if let Some(rules) = extract_rules {
214            if let Err(e) = variable_manager.extract_variables(&rules, &execution_result) {
215                info!("Failed to extract variables: {}", e);
216            }
217        }
218
219        Ok(execution_result)
220    }
221
222}
223
/// Utility: opens a TCP connection to `addr`, giving up after `timeout`.
///
/// Address resolution and the connect attempts run on a helper thread so that the whole
/// operation (including a slow DNS lookup) is bounded from the caller's point of view.
/// The connect attempts themselves are also time-limited, so the helper thread does not
/// stay blocked in an unbounded `connect` after the caller has already given up.
///
/// # Errors
/// Returns `ErrorKind::TimedOut` when no connection was established within `timeout`;
/// otherwise the resolution error or the last connection error is returned.
fn connect_with_timeout(addr: &str, timeout: Duration) -> std::io::Result<TcpStream> {
    let (tx, rx) = mpsc::channel();
    let target = addr.to_string();
    let timeout_message = format!("connect to {} timeout {} s", addr, timeout.as_secs());

    std::thread::spawn(move || {
        // The receiver may already be gone if the caller timed out; that is fine.
        let _ = tx.send(connect_resolved_addresses(&target, timeout));
    });

    rx.recv_timeout(timeout).unwrap_or_else(|_| {
        Err(std::io::Error::new(std::io::ErrorKind::TimedOut, timeout_message))
    })
}

/// Resolves `addr` and tries each resulting socket address in turn, sharing a single
/// `timeout` budget across all attempts.
fn connect_resolved_addresses(addr: &str, timeout: Duration) -> std::io::Result<TcpStream> {
    let started = std::time::Instant::now();
    let mut last_error = std::io::Error::new(
        std::io::ErrorKind::InvalidInput,
        "could not resolve to any addresses",
    );

    for candidate in std::net::ToSocketAddrs::to_socket_addrs(addr)? {
        // `connect_timeout` rejects a zero duration, so stop once the budget is used up.
        let remaining = timeout.saturating_sub(started.elapsed());
        if remaining.is_zero() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!("connect to {} timeout {} s", addr, timeout.as_secs()),
            ));
        }
        match TcpStream::connect_timeout(&candidate, remaining) {
            Ok(stream) => return Ok(stream),
            Err(err) => last_error = err,
        }
    }

    Err(last_error)
}