1use anyhow::{Context, Result};
2use russh::client::{self, Handle, Msg};
3use russh::keys::*;
4use russh::*;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use tokio::net::TcpStream;
8use tokio::sync::mpsc;
9use tracing::{debug, error, info, warn};
10
/// Connection and tunnel settings for a [`ReverseSshClient`].
///
/// `Debug` is implemented by hand (instead of derived) so that the password
/// never ends up in log output or panic messages.
#[derive(Clone)]
pub struct ReverseSshConfig {
    /// Host name or IP address of the SSH server.
    pub server_addr: String,
    /// TCP port of the SSH server.
    pub server_port: u16,
    /// User name to authenticate as.
    pub username: String,
    /// Path to a private key file; when set it is used instead of `password`.
    pub key_path: Option<String>,
    /// Password, used only when no `key_path` is configured.
    pub password: Option<String>,
    /// Bind address sent in the remote-forwarding request. An empty string is
    /// passed through unchanged; how it is interpreted is up to the server
    /// (RFC 4254 treats "" as "all addresses", and some services accept a
    /// name such as "dev" here).
    pub bind_address: String,
    /// Port requested on the server side of the tunnel.
    pub remote_port: u32,
    /// Address of the local service that forwarded connections are proxied to.
    pub local_addr: String,
    /// Port of the local service.
    pub local_port: u16,
}

impl std::fmt::Debug for ReverseSshConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReverseSshConfig")
            .field("server_addr", &self.server_addr)
            .field("server_port", &self.server_port)
            .field("username", &self.username)
            .field("key_path", &self.key_path)
            // Reveal only whether a password is configured, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("bind_address", &self.bind_address)
            .field("remote_port", &self.remote_port)
            .field("local_addr", &self.local_addr)
            .field("local_port", &self.local_port)
            .finish()
    }
}
36
37struct Client {
39 tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
40 message_tx: mpsc::UnboundedSender<String>,
41}
42
43#[async_trait::async_trait]
44impl client::Handler for Client {
45 type Error = russh::Error;
46
47 async fn check_server_key(
48 &mut self,
49 _server_public_key: &key::PublicKey,
50 ) -> Result<bool, Self::Error> {
51 Ok(true)
54 }
55
56 async fn server_channel_open_forwarded_tcpip(
57 &mut self,
58 channel: Channel<Msg>,
59 connected_address: &str,
60 connected_port: u32,
61 originator_address: &str,
62 originator_port: u32,
63 _session: &mut client::Session,
64 ) -> Result<(), Self::Error> {
65 debug!(
66 "Forwarded channel: {}:{} -> {}:{}",
67 originator_address, originator_port, connected_address, connected_port
68 );
69
70 let _ = self
72 .tx
73 .send((channel, connected_address.to_string(), connected_port));
74
75 Ok(())
76 }
77
78 async fn data(
79 &mut self,
80 _channel: ChannelId,
81 data: &[u8],
82 _session: &mut client::Session,
83 ) -> Result<(), Self::Error> {
84 if let Ok(message) = String::from_utf8(data.to_vec()) {
87 debug!("Received data ({} bytes): {}", data.len(), message);
88 let _ = self.message_tx.send(message);
89 } else {
90 debug!(
92 "Received {} bytes of non-UTF8 data on channel {:?}",
93 data.len(),
94 _channel
95 );
96 }
97 Ok(())
98 }
99
100 async fn extended_data(
101 &mut self,
102 _channel: ChannelId,
103 ext: u32,
104 data: &[u8],
105 _session: &mut client::Session,
106 ) -> Result<(), Self::Error> {
107 if let Ok(message) = String::from_utf8(data.to_vec()) {
110 info!("Received extended data (type {}): {}", ext, message);
111 let _ = self.message_tx.send(message);
112 }
113 debug!(
114 "Received {} bytes of extended data (type {}) on channel {:?}",
115 data.len(),
116 ext,
117 _channel
118 );
119 Ok(())
120 }
121}
122
123impl Client {
124 fn new(
125 tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
126 message_tx: mpsc::UnboundedSender<String>,
127 ) -> Self {
128 Self { tx, message_tx }
129 }
130}
131
132pub struct ReverseSshClient {
134 config: ReverseSshConfig,
135 handle: Option<Handle<Client>>,
136}
137
138impl ReverseSshClient {
139 pub fn new(config: ReverseSshConfig) -> Self {
141 Self {
142 config,
143 handle: None,
144 }
145 }
146
147 pub async fn connect(
149 &mut self,
150 tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
151 message_tx: mpsc::UnboundedSender<String>,
152 ) -> Result<()> {
153 info!(
154 "Connecting to SSH server {}:{}",
155 self.config.server_addr, self.config.server_port
156 );
157
158 let client_config = client::Config {
159 inactivity_timeout: Some(std::time::Duration::from_secs(3600)),
160 ..<_>::default()
161 };
162
163 let client_handler = Client::new(tx, message_tx);
164
165 let mut session = client::connect(
166 Arc::new(client_config),
167 (self.config.server_addr.as_str(), self.config.server_port),
168 client_handler,
169 )
170 .await
171 .context("Failed to connect to SSH server")?;
172
173 let auth_result = if let Some(key_path) = &self.config.key_path {
175 info!("Authenticating with private key: {}", key_path);
176 let key_pair = russh_keys::load_secret_key(key_path, None)
177 .context("Failed to load private key")?;
178 session
179 .authenticate_publickey(&self.config.username, Arc::new(key_pair))
180 .await
181 } else if let Some(password) = &self.config.password {
182 info!("Authenticating with password");
183 session
184 .authenticate_password(&self.config.username, password)
185 .await
186 } else {
187 anyhow::bail!("No authentication method provided (need key_path or password)");
188 };
189
190 if !auth_result.context("Authentication failed")? {
191 anyhow::bail!("Authentication rejected by server");
192 }
193
194 info!("Successfully authenticated to SSH server");
195 self.handle = Some(session);
196 Ok(())
197 }
198
199 pub async fn setup_reverse_tunnel(&mut self) -> Result<()> {
202 let handle = self
203 .handle
204 .as_mut()
205 .context("Not connected - call connect() first")?;
206
207 if self.config.bind_address.is_empty() {
208 info!(
209 "Setting up reverse tunnel: server port {} -> local {}:{}",
210 self.config.remote_port, self.config.local_addr, self.config.local_port
211 );
212 } else {
213 info!(
214 "Setting up reverse tunnel: {}:{} -> local {}:{}",
215 self.config.bind_address,
216 self.config.remote_port,
217 self.config.local_addr,
218 self.config.local_port
219 );
220 }
221
222 handle
226 .tcpip_forward(&self.config.bind_address, self.config.remote_port)
227 .await
228 .context("Failed to set up remote port forwarding")?;
229
230 info!("Reverse tunnel established successfully");
231
232 match handle.channel_open_session().await {
235 Ok(channel) => {
236 info!("Opened shell session to receive server messages");
237 if let Err(e) = channel.request_shell(false).await {
239 warn!("Failed to request shell: {}", e);
240 } else {
241 debug!("Shell requested successfully");
242 }
243 }
246 Err(e) => {
247 warn!(
248 "Could not open shell session: {} (this may be normal for some servers)",
249 e
250 );
251 }
252 }
253
254 Ok(())
255 }
256
257 #[allow(dead_code)]
260 pub async fn read_server_messages(&mut self) -> Result<Vec<String>> {
261 let handle = self
262 .handle
263 .as_mut()
264 .context("Not connected - call connect() first")?;
265
266 let mut messages = Vec::new();
267
268 match handle.channel_open_session().await {
270 Ok(channel) => {
271 let _ = channel.request_shell(false).await;
273
274 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
276
277 let _ = channel.eof().await;
283 let _ = channel.close().await;
284
285 messages.push("Check SSH session output for connection URL".to_string());
286 }
287 Err(e) => {
288 warn!("Could not open session channel: {}", e);
289 }
290 }
291
292 Ok(messages)
293 }
294
295 pub async fn handle_forwarded_connections(
297 &mut self,
298 mut rx: mpsc::UnboundedReceiver<(Channel<Msg>, String, u32)>,
299 ) -> Result<()> {
300 info!("Waiting for forwarded connections...");
301
302 while let Some((channel, _remote_addr, _remote_port)) = rx.recv().await {
303 info!("New forwarded connection received");
304
305 let local_addr = self.config.local_addr.clone();
307 let local_port = self.config.local_port;
308
309 tokio::spawn(async move {
310 if let Err(e) = handle_connection(channel, &local_addr, local_port).await {
311 error!("Error handling connection: {}", e);
312 }
313 });
314 }
315
316 warn!("Connection closed by server");
317 Ok(())
318 }
319
320 #[allow(dead_code)]
322 pub async fn run(&mut self) -> Result<()> {
323 let (tx, rx) = mpsc::unbounded_channel();
324 let (message_tx, mut message_rx) = mpsc::unbounded_channel();
325
326 self.connect(tx, message_tx).await?;
327 self.setup_reverse_tunnel().await?;
328
329 tokio::spawn(async move {
331 while let Some(message) = message_rx.recv().await {
332 if !message.trim().is_empty() {
334 println!("[Server] {}", message.trim());
335 }
336 }
337 });
338
339 self.handle_forwarded_connections(rx).await?;
340
341 Ok(())
342 }
343
344 pub async fn run_with_message_handler<F>(&mut self, mut message_handler: F) -> Result<()>
346 where
347 F: FnMut(String) + Send + 'static,
348 {
349 let (tx, rx) = mpsc::unbounded_channel();
350 let (message_tx, mut message_rx) = mpsc::unbounded_channel();
351
352 self.connect(tx, message_tx).await?;
353 self.setup_reverse_tunnel().await?;
354
355 tokio::spawn(async move {
357 while let Some(message) = message_rx.recv().await {
358 message_handler(message);
359 }
360 });
361
362 self.handle_forwarded_connections(rx).await?;
363
364 Ok(())
365 }
366}
367
368async fn handle_connection(
370 mut channel: Channel<Msg>,
371 local_addr: &str,
372 local_port: u16,
373) -> Result<()> {
374 use tokio::io::{AsyncReadExt, AsyncWriteExt};
375
376 info!("Connecting to local service {}:{}", local_addr, local_port);
377
378 let local_socket_addr: SocketAddr = format!("{}:{}", local_addr, local_port)
380 .parse()
381 .context("Invalid local address")?;
382
383 let mut local_stream = TcpStream::connect(local_socket_addr)
384 .await
385 .context("Failed to connect to local service")?;
386
387 info!("Connected to local service, starting bidirectional proxy");
388
389 let mut local_buf = vec![0u8; 8192];
391
392 loop {
394 tokio::select! {
395 msg = channel.wait() => {
397 match msg {
398 Some(russh::ChannelMsg::Data { data }) => {
399 debug!("Received {} bytes from SSH channel", data.len());
400 if let Err(e) = local_stream.write_all(&data).await {
401 error!("Failed to write to local service: {}", e);
402 break;
403 }
404 }
405 Some(russh::ChannelMsg::Eof) => {
406 debug!("Received EOF from SSH channel");
407 let _ = local_stream.shutdown().await;
408 break;
409 }
410 Some(russh::ChannelMsg::Close) => {
411 debug!("SSH channel closed");
412 break;
413 }
414 Some(other) => {
415 debug!("Received other channel message: {:?}", other);
416 }
417 None => {
418 debug!("SSH channel receiver closed");
419 break;
420 }
421 }
422 }
423
424 result = local_stream.read(&mut local_buf) => {
426 match result {
427 Ok(0) => {
428 debug!("Local connection closed");
429 break;
430 }
431 Ok(n) => {
432 debug!("Read {} bytes from local service", n);
433 if let Err(e) = channel.data(&local_buf[..n]).await {
434 error!("Failed to send data to SSH channel: {}", e);
435 break;
436 }
437 }
438 Err(e) => {
439 error!("Error reading from local service: {}", e);
440 break;
441 }
442 }
443 }
444 }
445 }
446
447 let _ = channel.eof().await;
449 let _ = channel.close().await;
450
451 info!("Connection proxy closed");
452
453 Ok(())
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key-authenticated config (port 22, key at "/path/to/key",
    /// local service on 127.0.0.1) with the given server and tunnel settings.
    fn key_auth_config(
        server_addr: &str,
        username: &str,
        bind_address: &str,
        remote_port: u32,
        local_port: u16,
    ) -> ReverseSshConfig {
        ReverseSshConfig {
            server_addr: server_addr.to_string(),
            server_port: 22,
            username: username.to_string(),
            key_path: Some("/path/to/key".to_string()),
            password: None,
            bind_address: bind_address.to_string(),
            remote_port,
            local_addr: "127.0.0.1".to_string(),
            local_port,
        }
    }

    #[test]
    fn test_config_creation() {
        let config = key_auth_config("example.com", "user", "", 8080, 3000);

        assert_eq!(config.server_addr, "example.com");
        assert_eq!(config.remote_port, 8080);
        assert!(config.bind_address.is_empty());
    }

    #[test]
    fn test_config_with_bind_address() {
        let config = key_auth_config("tuns.sh", "myuser", "dev", 80, 8000);

        assert_eq!(config.server_addr, "tuns.sh");
        assert_eq!(config.bind_address, "dev");
        assert_eq!(config.remote_port, 80);
    }
}