use anyhow::{Context, Result};
use russh::client::{self, Handle, Msg};
use russh::keys::*;
use russh::*;
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
11
/// Settings for a reverse SSH tunnel.
///
/// Covers where to connect, how to authenticate, which port the server is
/// asked to listen on, and which local service forwarded connections are
/// proxied to.
///
/// `Debug` is implemented by hand so that the password never ends up in logs.
#[derive(Clone)]
pub struct ReverseSshConfig {
    /// Hostname or IP address of the SSH server.
    pub server_addr: String,
    /// TCP port of the SSH server.
    pub server_port: u16,
    /// User name to authenticate as.
    pub username: String,
    /// Path to a private key file; when set it is used instead of `password`.
    pub key_path: Option<String>,
    /// Password, used only when `key_path` is `None`.
    pub password: Option<String>,
    /// Port the server is asked to listen on via `tcpip-forward`.
    pub remote_port: u32,
    /// Address of the local service that forwarded connections are proxied to.
    pub local_addr: String,
    /// Port of the local service.
    pub local_port: u16,
}

impl std::fmt::Debug for ReverseSshConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReverseSshConfig")
            .field("server_addr", &self.server_addr)
            .field("server_port", &self.server_port)
            .field("username", &self.username)
            .field("key_path", &self.key_path)
            // Only reveal whether a password is configured, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("remote_port", &self.remote_port)
            .field("local_addr", &self.local_addr)
            .field("local_port", &self.local_port)
            .finish()
    }
}
32
33struct Client {
35 tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
36 message_tx: mpsc::UnboundedSender<String>,
37}
38
39#[async_trait::async_trait]
40impl client::Handler for Client {
41 type Error = russh::Error;
42
43 async fn check_server_key(
44 &mut self,
45 _server_public_key: &key::PublicKey,
46 ) -> Result<bool, Self::Error> {
47 Ok(true)
50 }
51
52 async fn server_channel_open_forwarded_tcpip(
53 &mut self,
54 channel: Channel<Msg>,
55 connected_address: &str,
56 connected_port: u32,
57 originator_address: &str,
58 originator_port: u32,
59 _session: &mut client::Session,
60 ) -> Result<(), Self::Error> {
61 info!(
62 "Server opened forwarded channel: {}:{} -> {}:{}",
63 originator_address, originator_port, connected_address, connected_port
64 );
65
66 let _ = self.tx.send((channel, connected_address.to_string(), connected_port));
68
69 Ok(())
70 }
71
72 async fn data(
73 &mut self,
74 _channel: ChannelId,
75 data: &[u8],
76 _session: &mut client::Session,
77 ) -> Result<(), Self::Error> {
78 if let Ok(message) = String::from_utf8(data.to_vec()) {
81 debug!("Received data ({} bytes): {}", data.len(), message);
82 let _ = self.message_tx.send(message);
83 } else {
84 debug!("Received {} bytes of non-UTF8 data on channel {:?}", data.len(), _channel);
86 }
87 Ok(())
88 }
89
90 async fn extended_data(
91 &mut self,
92 _channel: ChannelId,
93 ext: u32,
94 data: &[u8],
95 _session: &mut client::Session,
96 ) -> Result<(), Self::Error> {
97 if let Ok(message) = String::from_utf8(data.to_vec()) {
100 info!("Received extended data (type {}): {}", ext, message);
101 let _ = self.message_tx.send(message);
102 }
103 debug!("Received {} bytes of extended data (type {}) on channel {:?}", data.len(), ext, _channel);
104 Ok(())
105 }
106}
107
108impl Client {
109 fn new(
110 tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
111 message_tx: mpsc::UnboundedSender<String>,
112 ) -> Self {
113 Self { tx, message_tx }
114 }
115}
116
117pub struct ReverseSshClient {
119 config: ReverseSshConfig,
120 handle: Option<Handle<Client>>,
121}
122
123impl ReverseSshClient {
124 pub fn new(config: ReverseSshConfig) -> Self {
126 Self {
127 config,
128 handle: None,
129 }
130 }
131
132 pub async fn connect(
134 &mut self,
135 tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
136 message_tx: mpsc::UnboundedSender<String>,
137 ) -> Result<()> {
138 info!("Connecting to SSH server {}:{}", self.config.server_addr, self.config.server_port);
139
140 let client_config = client::Config {
141 inactivity_timeout: Some(std::time::Duration::from_secs(3600)),
142 ..<_>::default()
143 };
144
145 let client_handler = Client::new(tx, message_tx);
146
147 let mut session = client::connect(
148 Arc::new(client_config),
149 (self.config.server_addr.as_str(), self.config.server_port),
150 client_handler,
151 )
152 .await
153 .context("Failed to connect to SSH server")?;
154
155 let auth_result = if let Some(key_path) = &self.config.key_path {
157 info!("Authenticating with private key: {}", key_path);
158 let key_pair = russh_keys::load_secret_key(key_path, None)
159 .context("Failed to load private key")?;
160 session
161 .authenticate_publickey(&self.config.username, Arc::new(key_pair))
162 .await
163 } else if let Some(password) = &self.config.password {
164 info!("Authenticating with password");
165 session
166 .authenticate_password(&self.config.username, password)
167 .await
168 } else {
169 anyhow::bail!("No authentication method provided (need key_path or password)");
170 };
171
172 if !auth_result.context("Authentication failed")? {
173 anyhow::bail!("Authentication rejected by server");
174 }
175
176 info!("Successfully authenticated to SSH server");
177 self.handle = Some(session);
178 Ok(())
179 }
180
181 pub async fn setup_reverse_tunnel(&mut self) -> Result<()> {
184 let handle = self
185 .handle
186 .as_mut()
187 .context("Not connected - call connect() first")?;
188
189 info!(
190 "Setting up reverse tunnel: server port {} -> local {}:{}",
191 self.config.remote_port, self.config.local_addr, self.config.local_port
192 );
193
194 handle
198 .tcpip_forward("", self.config.remote_port)
199 .await
200 .context("Failed to set up remote port forwarding")?;
201
202 info!("Reverse tunnel established successfully");
203
204 match handle.channel_open_session().await {
207 Ok(channel) => {
208 info!("Opened shell session to receive server messages");
209 if let Err(e) = channel.request_shell(false).await {
211 warn!("Failed to request shell: {}", e);
212 } else {
213 debug!("Shell requested successfully");
214 }
215 }
218 Err(e) => {
219 warn!("Could not open shell session: {} (this may be normal for some servers)", e);
220 }
221 }
222
223 Ok(())
224 }
225
226 pub async fn read_server_messages(&mut self) -> Result<Vec<String>> {
229 let handle = self
230 .handle
231 .as_mut()
232 .context("Not connected - call connect() first")?;
233
234 let mut messages = Vec::new();
235
236 match handle.channel_open_session().await {
238 Ok(channel) => {
239 let _ = channel.request_shell(false).await;
241
242 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
244
245 let _ = channel.eof().await;
251 let _ = channel.close().await;
252
253 messages.push("Check SSH session output for connection URL".to_string());
254 }
255 Err(e) => {
256 warn!("Could not open session channel: {}", e);
257 }
258 }
259
260 Ok(messages)
261 }
262
263 pub async fn handle_forwarded_connections(&mut self, mut rx: mpsc::UnboundedReceiver<(Channel<Msg>, String, u32)>) -> Result<()> {
265 info!("Waiting for forwarded connections...");
266
267 while let Some((channel, _remote_addr, _remote_port)) = rx.recv().await {
268 info!("New forwarded connection received");
269
270 let local_addr = self.config.local_addr.clone();
272 let local_port = self.config.local_port;
273
274 tokio::spawn(async move {
275 if let Err(e) = handle_connection(channel, &local_addr, local_port).await {
276 error!("Error handling connection: {}", e);
277 }
278 });
279 }
280
281 warn!("Connection closed by server");
282 Ok(())
283 }
284
285 pub async fn run(&mut self) -> Result<()> {
287 let (tx, rx) = mpsc::unbounded_channel();
288 let (message_tx, mut message_rx) = mpsc::unbounded_channel();
289
290 self.connect(tx, message_tx).await?;
291 self.setup_reverse_tunnel().await?;
292
293 tokio::spawn(async move {
295 while let Some(message) = message_rx.recv().await {
296 if !message.trim().is_empty() {
298 println!("[Server] {}", message.trim());
299 }
300 }
301 });
302
303 self.handle_forwarded_connections(rx).await?;
304
305 Ok(())
306 }
307
308 pub async fn run_with_message_handler<F>(&mut self, mut message_handler: F) -> Result<()>
310 where
311 F: FnMut(String) + Send + 'static,
312 {
313 let (tx, rx) = mpsc::unbounded_channel();
314 let (message_tx, mut message_rx) = mpsc::unbounded_channel();
315
316 self.connect(tx, message_tx).await?;
317 self.setup_reverse_tunnel().await?;
318
319 tokio::spawn(async move {
321 while let Some(message) = message_rx.recv().await {
322 message_handler(message);
323 }
324 });
325
326 self.handle_forwarded_connections(rx).await?;
327
328 Ok(())
329 }
330}
331
332async fn handle_connection(
334 channel: Channel<Msg>,
335 local_addr: &str,
336 local_port: u16,
337) -> Result<()> {
338 info!("Connecting to local service {}:{}", local_addr, local_port);
339
340 let local_socket_addr: SocketAddr = format!("{}:{}", local_addr, local_port)
342 .parse()
343 .context("Invalid local address")?;
344
345 let mut local_stream = TcpStream::connect(local_socket_addr)
346 .await
347 .context("Failed to connect to local service")?;
348
349 info!("Connected to local service, starting proxy");
350
351 let mut buffer = vec![0u8; 8192];
356
357 loop {
359 match local_stream.read(&mut buffer).await {
360 Ok(0) => {
361 debug!("Local connection closed");
362 break;
363 }
364 Ok(n) => {
365 debug!("Read {} bytes from local service", n);
366 if let Err(e) = channel.data(&buffer[..n]).await {
368 error!("Failed to send data to SSH channel: {}", e);
369 break;
370 }
371 }
372 Err(e) => {
373 error!("Error reading from local service: {}", e);
374 break;
375 }
376 }
377 }
378
379 let _ = channel.eof().await;
381 let _ = channel.close().await;
382
383 info!("Connection closed");
384 Ok(())
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// A hand-built config exposes the values it was constructed with.
    #[test]
    fn test_config_creation() {
        let cfg = ReverseSshConfig {
            server_addr: String::from("example.com"),
            server_port: 22,
            username: String::from("user"),
            key_path: Some(String::from("/path/to/key")),
            password: None,
            remote_port: 8080,
            local_addr: String::from("127.0.0.1"),
            local_port: 3000,
        };

        assert_eq!(cfg.server_addr.as_str(), "example.com");
        assert_eq!(cfg.remote_port, 8080u32);
    }
}