reverse_ssh/
lib.rs

1use anyhow::{Context, Result};
2use russh::client::{self, Handle, Msg};
3use russh::keys::*;
4use russh::*;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use tokio::net::TcpStream;
8use tokio::sync::mpsc;
9use tracing::{debug, error, info, warn};
10
/// Configuration for the reverse SSH connection.
///
/// `Debug` is implemented by hand so that the password is never written to logs
/// or panic messages; only whether one is configured is shown.
#[derive(Clone)]
pub struct ReverseSshConfig {
    /// The SSH server address to connect to
    pub server_addr: String,
    /// The SSH server port
    pub server_port: u16,
    /// Username for SSH authentication
    pub username: String,
    /// Private key path for authentication (preferred over `password` when set)
    pub key_path: Option<String>,
    /// Password for authentication (used only if `key_path` is `None`)
    pub password: Option<String>,
    /// Bind address for remote port forwarding, sent to the server unchanged.
    ///
    /// For services like pico.sh tuns, this is the tunnel name (e.g., "dev" -> "user-dev.tuns.sh").
    /// For localhost.run, use an empty string to let the server assign a random subdomain.
    /// There is no implicit default: set it to `String::new()` explicitly when unused.
    pub bind_address: String,
    /// Remote port to listen on (on the SSH server)
    pub remote_port: u32,
    /// Local address (IP or hostname) to forward connections to
    pub local_addr: String,
    /// Local port to forward connections to
    pub local_port: u16,
}

impl std::fmt::Debug for ReverseSshConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReverseSshConfig")
            .field("server_addr", &self.server_addr)
            .field("server_port", &self.server_port)
            .field("username", &self.username)
            .field("key_path", &self.key_path)
            // Never print the secret itself; only whether one is configured.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("bind_address", &self.bind_address)
            .field("remote_port", &self.remote_port)
            .field("local_addr", &self.local_addr)
            .field("local_port", &self.local_port)
            .finish()
    }
}
36
37/// SSH client handler
38struct Client {
39    tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
40    message_tx: mpsc::UnboundedSender<String>,
41}
42
43#[async_trait::async_trait]
44impl client::Handler for Client {
45    type Error = russh::Error;
46
47    async fn check_server_key(
48        &mut self,
49        _server_public_key: &key::PublicKey,
50    ) -> Result<bool, Self::Error> {
51        // In production, you should verify the server's public key
52        // For now, we accept any key
53        Ok(true)
54    }
55
56    async fn server_channel_open_forwarded_tcpip(
57        &mut self,
58        channel: Channel<Msg>,
59        connected_address: &str,
60        connected_port: u32,
61        originator_address: &str,
62        originator_port: u32,
63        _session: &mut client::Session,
64    ) -> Result<(), Self::Error> {
65        debug!(
66            "Forwarded channel: {}:{} -> {}:{}",
67            originator_address, originator_port, connected_address, connected_port
68        );
69
70        // Send the channel to be handled
71        let _ = self
72            .tx
73            .send((channel, connected_address.to_string(), connected_port));
74
75        Ok(())
76    }
77
78    async fn data(
79        &mut self,
80        _channel: ChannelId,
81        data: &[u8],
82        _session: &mut client::Session,
83    ) -> Result<(), Self::Error> {
84        // Convert data to string and send it for processing
85        // Don't filter out partial messages - send everything
86        if let Ok(message) = String::from_utf8(data.to_vec()) {
87            debug!("Received data ({} bytes): {}", data.len(), message);
88            let _ = self.message_tx.send(message);
89        } else {
90            // Log if we received non-UTF8 data
91            debug!(
92                "Received {} bytes of non-UTF8 data on channel {:?}",
93                data.len(),
94                _channel
95            );
96        }
97        Ok(())
98    }
99
100    async fn extended_data(
101        &mut self,
102        _channel: ChannelId,
103        ext: u32,
104        data: &[u8],
105        _session: &mut client::Session,
106    ) -> Result<(), Self::Error> {
107        // Extended data includes stderr (ext == 1)
108        // localhost.run sends URL info through stderr
109        if let Ok(message) = String::from_utf8(data.to_vec()) {
110            info!("Received extended data (type {}): {}", ext, message);
111            let _ = self.message_tx.send(message);
112        }
113        debug!(
114            "Received {} bytes of extended data (type {}) on channel {:?}",
115            data.len(),
116            ext,
117            _channel
118        );
119        Ok(())
120    }
121}
122
123impl Client {
124    fn new(
125        tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
126        message_tx: mpsc::UnboundedSender<String>,
127    ) -> Self {
128        Self { tx, message_tx }
129    }
130}
131
132/// Reverse SSH client that establishes a reverse tunnel
133pub struct ReverseSshClient {
134    config: ReverseSshConfig,
135    handle: Option<Handle<Client>>,
136}
137
138impl ReverseSshClient {
139    /// Create a new reverse SSH client with the given configuration
140    pub fn new(config: ReverseSshConfig) -> Self {
141        Self {
142            config,
143            handle: None,
144        }
145    }
146
147    /// Connect to the SSH server and authenticate
148    pub async fn connect(
149        &mut self,
150        tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
151        message_tx: mpsc::UnboundedSender<String>,
152    ) -> Result<()> {
153        info!(
154            "Connecting to SSH server {}:{}",
155            self.config.server_addr, self.config.server_port
156        );
157
158        let client_config = client::Config {
159            inactivity_timeout: Some(std::time::Duration::from_secs(3600)),
160            ..<_>::default()
161        };
162
163        let client_handler = Client::new(tx, message_tx);
164
165        let mut session = client::connect(
166            Arc::new(client_config),
167            (self.config.server_addr.as_str(), self.config.server_port),
168            client_handler,
169        )
170        .await
171        .context("Failed to connect to SSH server")?;
172
173        // Authenticate
174        let auth_result = if let Some(key_path) = &self.config.key_path {
175            info!("Authenticating with private key: {}", key_path);
176            let key_pair = russh_keys::load_secret_key(key_path, None)
177                .context("Failed to load private key")?;
178            session
179                .authenticate_publickey(&self.config.username, Arc::new(key_pair))
180                .await
181        } else if let Some(password) = &self.config.password {
182            info!("Authenticating with password");
183            session
184                .authenticate_password(&self.config.username, password)
185                .await
186        } else {
187            anyhow::bail!("No authentication method provided (need key_path or password)");
188        };
189
190        if !auth_result.context("Authentication failed")? {
191            anyhow::bail!("Authentication rejected by server");
192        }
193
194        info!("Successfully authenticated to SSH server");
195        self.handle = Some(session);
196        Ok(())
197    }
198
199    /// Set up a reverse port forward (remote port forwarding)
200    /// This makes the SSH server listen on a port and forward connections back to us
201    pub async fn setup_reverse_tunnel(&mut self) -> Result<()> {
202        let handle = self
203            .handle
204            .as_mut()
205            .context("Not connected - call connect() first")?;
206
207        if self.config.bind_address.is_empty() {
208            info!(
209                "Setting up reverse tunnel: server port {} -> local {}:{}",
210                self.config.remote_port, self.config.local_addr, self.config.local_port
211            );
212        } else {
213            info!(
214                "Setting up reverse tunnel: {}:{} -> local {}:{}",
215                self.config.bind_address,
216                self.config.remote_port,
217                self.config.local_addr,
218                self.config.local_port
219            );
220        }
221
222        // Request remote port forwarding
223        // The bind_address is used for services like pico.sh tuns where it becomes
224        // the tunnel name (subdomain). For localhost.run, use empty string.
225        handle
226            .tcpip_forward(&self.config.bind_address, self.config.remote_port)
227            .await
228            .context("Failed to set up remote port forwarding")?;
229
230        info!("Reverse tunnel established successfully");
231
232        // Open a shell session to receive server messages (like the URL from localhost.run)
233        // This is important for services that send connection info via shell
234        match handle.channel_open_session().await {
235            Ok(channel) => {
236                info!("Opened shell session to receive server messages");
237                // Request a shell - this triggers the server to send welcome messages
238                if let Err(e) = channel.request_shell(false).await {
239                    warn!("Failed to request shell: {}", e);
240                } else {
241                    debug!("Shell requested successfully");
242                }
243                // Don't close the channel - keep it open to receive messages
244                // The channel will be kept alive by the handler
245            }
246            Err(e) => {
247                warn!(
248                    "Could not open shell session: {} (this may be normal for some servers)",
249                    e
250                );
251            }
252        }
253
254        Ok(())
255    }
256
257    /// Read server messages (useful for services like localhost.run that send URL info)
258    /// This opens a session channel and attempts to read any messages from the server
259    #[allow(dead_code)]
260    pub async fn read_server_messages(&mut self) -> Result<Vec<String>> {
261        let handle = self
262            .handle
263            .as_mut()
264            .context("Not connected - call connect() first")?;
265
266        let mut messages = Vec::new();
267
268        // Try to open a session channel to read any server messages
269        match handle.channel_open_session().await {
270            Ok(channel) => {
271                // Request a shell to trigger server messages
272                let _ = channel.request_shell(false).await;
273
274                // Wait a bit for messages to arrive
275                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
276
277                // Try to read data from the channel
278                // Note: This is a simplified approach - in practice, we'd need to
279                // handle the channel data in the Handler's data() method
280
281                // Close the channel
282                let _ = channel.eof().await;
283                let _ = channel.close().await;
284
285                messages.push("Check SSH session output for connection URL".to_string());
286            }
287            Err(e) => {
288                warn!("Could not open session channel: {}", e);
289            }
290        }
291
292        Ok(messages)
293    }
294
295    /// Handle forwarded connections from the SSH server
296    pub async fn handle_forwarded_connections(
297        &mut self,
298        mut rx: mpsc::UnboundedReceiver<(Channel<Msg>, String, u32)>,
299    ) -> Result<()> {
300        info!("Waiting for forwarded connections...");
301
302        while let Some((channel, _remote_addr, _remote_port)) = rx.recv().await {
303            info!("New forwarded connection received");
304
305            // Spawn a task to handle this connection
306            let local_addr = self.config.local_addr.clone();
307            let local_port = self.config.local_port;
308
309            tokio::spawn(async move {
310                if let Err(e) = handle_connection(channel, &local_addr, local_port).await {
311                    error!("Error handling connection: {}", e);
312                }
313            });
314        }
315
316        warn!("Connection closed by server");
317        Ok(())
318    }
319
320    /// Run the reverse SSH client (connect, setup tunnel, and handle connections)
321    #[allow(dead_code)]
322    pub async fn run(&mut self) -> Result<()> {
323        let (tx, rx) = mpsc::unbounded_channel();
324        let (message_tx, mut message_rx) = mpsc::unbounded_channel();
325
326        self.connect(tx, message_tx).await?;
327        self.setup_reverse_tunnel().await?;
328
329        // Spawn a task to print server messages
330        tokio::spawn(async move {
331            while let Some(message) = message_rx.recv().await {
332                // Print server messages, which may include URLs
333                if !message.trim().is_empty() {
334                    println!("[Server] {}", message.trim());
335                }
336            }
337        });
338
339        self.handle_forwarded_connections(rx).await?;
340
341        Ok(())
342    }
343
344    /// Run the client with custom message handling
345    pub async fn run_with_message_handler<F>(&mut self, mut message_handler: F) -> Result<()>
346    where
347        F: FnMut(String) + Send + 'static,
348    {
349        let (tx, rx) = mpsc::unbounded_channel();
350        let (message_tx, mut message_rx) = mpsc::unbounded_channel();
351
352        self.connect(tx, message_tx).await?;
353        self.setup_reverse_tunnel().await?;
354
355        // Spawn a task to handle server messages with custom handler
356        tokio::spawn(async move {
357            while let Some(message) = message_rx.recv().await {
358                message_handler(message);
359            }
360        });
361
362        self.handle_forwarded_connections(rx).await?;
363
364        Ok(())
365    }
366}
367
368/// Handle a single forwarded connection by proxying data between SSH channel and local service
369async fn handle_connection(
370    mut channel: Channel<Msg>,
371    local_addr: &str,
372    local_port: u16,
373) -> Result<()> {
374    use tokio::io::{AsyncReadExt, AsyncWriteExt};
375
376    info!("Connecting to local service {}:{}", local_addr, local_port);
377
378    // Connect to the local service
379    let local_socket_addr: SocketAddr = format!("{}:{}", local_addr, local_port)
380        .parse()
381        .context("Invalid local address")?;
382
383    let mut local_stream = TcpStream::connect(local_socket_addr)
384        .await
385        .context("Failed to connect to local service")?;
386
387    info!("Connected to local service, starting bidirectional proxy");
388
389    // Bidirectional proxy using tokio::select!
390    let mut local_buf = vec![0u8; 8192];
391
392    // Read from local and forward to SSH
393    loop {
394        tokio::select! {
395            // Read from SSH channel and write to local service
396            msg = channel.wait() => {
397                match msg {
398                    Some(russh::ChannelMsg::Data { data }) => {
399                        debug!("Received {} bytes from SSH channel", data.len());
400                        if let Err(e) = local_stream.write_all(&data).await {
401                            error!("Failed to write to local service: {}", e);
402                            break;
403                        }
404                    }
405                    Some(russh::ChannelMsg::Eof) => {
406                        debug!("Received EOF from SSH channel");
407                        let _ = local_stream.shutdown().await;
408                        break;
409                    }
410                    Some(russh::ChannelMsg::Close) => {
411                        debug!("SSH channel closed");
412                        break;
413                    }
414                    Some(other) => {
415                        debug!("Received other channel message: {:?}", other);
416                    }
417                    None => {
418                        debug!("SSH channel receiver closed");
419                        break;
420                    }
421                }
422            }
423
424            // Read from local service and write to SSH channel
425            result = local_stream.read(&mut local_buf) => {
426                match result {
427                    Ok(0) => {
428                        debug!("Local connection closed");
429                        break;
430                    }
431                    Ok(n) => {
432                        debug!("Read {} bytes from local service", n);
433                        if let Err(e) = channel.data(&local_buf[..n]).await {
434                            error!("Failed to send data to SSH channel: {}", e);
435                            break;
436                        }
437                    }
438                    Err(e) => {
439                        error!("Error reading from local service: {}", e);
440                        break;
441                    }
442                }
443            }
444        }
445    }
446
447    // Close the channel gracefully
448    let _ = channel.eof().await;
449    let _ = channel.close().await;
450
451    info!("Connection proxy closed");
452
453    Ok(())
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a key-authenticated config on port 22 that forwards to 127.0.0.1.
    fn key_auth_config(
        server: &str,
        user: &str,
        bind: &str,
        remote_port: u32,
        local_port: u16,
    ) -> ReverseSshConfig {
        ReverseSshConfig {
            server_addr: server.to_string(),
            server_port: 22,
            username: user.to_string(),
            key_path: Some("/path/to/key".to_string()),
            password: None,
            bind_address: bind.to_string(),
            remote_port,
            local_addr: "127.0.0.1".to_string(),
            local_port,
        }
    }

    #[test]
    fn test_config_creation() {
        let config = key_auth_config("example.com", "user", "", 8080, 3000);

        assert_eq!(config.server_addr, "example.com");
        assert_eq!(config.remote_port, 8080);
        assert!(config.bind_address.is_empty());
    }

    #[test]
    fn test_config_with_bind_address() {
        let config = key_auth_config("tuns.sh", "myuser", "dev", 80, 8000);

        assert_eq!(config.server_addr, "tuns.sh");
        assert_eq!(config.bind_address, "dev");
        assert_eq!(config.remote_port, 80);
    }
}