reverse_ssh/
lib.rs

use anyhow::{Context, Result};
use russh::client::{self, Handle, Msg};
use russh::keys::*;
use russh::*;
use std::collections::HashSet;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tracing::{debug, error, info, warn};
11
/// Configuration for the reverse SSH connection.
///
/// `Debug` is implemented by hand so the password is never written to logs;
/// every other field is printed as-is, in declaration order.
#[derive(Clone, PartialEq, Eq)]
pub struct ReverseSshConfig {
    /// The SSH server address to connect to
    pub server_addr: String,
    /// The SSH server port
    pub server_port: u16,
    /// Username for SSH authentication
    pub username: String,
    /// Private key path for public-key authentication; used in preference to
    /// `password` when both are set
    pub key_path: Option<String>,
    /// Password for authentication (used only when `key_path` is `None`)
    pub password: Option<String>,
    /// Remote port the SSH server should listen on
    pub remote_port: u32,
    /// Local address to forward connections to
    pub local_addr: String,
    /// Local port to forward connections to
    pub local_port: u16,
}

impl std::fmt::Debug for ReverseSshConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReverseSshConfig")
            .field("server_addr", &self.server_addr)
            .field("server_port", &self.server_port)
            .field("username", &self.username)
            .field("key_path", &self.key_path)
            // Reveal only whether a password is configured, never its value.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("remote_port", &self.remote_port)
            .field("local_addr", &self.local_addr)
            .field("local_port", &self.local_port)
            .finish()
    }
}
32
33/// SSH client handler
34struct Client {
35    tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
36    message_tx: mpsc::UnboundedSender<String>,
37}
38
39#[async_trait::async_trait]
40impl client::Handler for Client {
41    type Error = russh::Error;
42
43    async fn check_server_key(
44        &mut self,
45        _server_public_key: &key::PublicKey,
46    ) -> Result<bool, Self::Error> {
47        // In production, you should verify the server's public key
48        // For now, we accept any key
49        Ok(true)
50    }
51
52    async fn server_channel_open_forwarded_tcpip(
53        &mut self,
54        channel: Channel<Msg>,
55        connected_address: &str,
56        connected_port: u32,
57        originator_address: &str,
58        originator_port: u32,
59        _session: &mut client::Session,
60    ) -> Result<(), Self::Error> {
61        info!(
62            "Server opened forwarded channel: {}:{} -> {}:{}",
63            originator_address, originator_port, connected_address, connected_port
64        );
65
66        // Send the channel to be handled
67        let _ = self.tx.send((channel, connected_address.to_string(), connected_port));
68
69        Ok(())
70    }
71
72    async fn data(
73        &mut self,
74        _channel: ChannelId,
75        data: &[u8],
76        _session: &mut client::Session,
77    ) -> Result<(), Self::Error> {
78        // Convert data to string and send it for processing
79        // Don't filter out partial messages - send everything
80        if let Ok(message) = String::from_utf8(data.to_vec()) {
81            debug!("Received data ({} bytes): {}", data.len(), message);
82            let _ = self.message_tx.send(message);
83        } else {
84            // Log if we received non-UTF8 data
85            debug!("Received {} bytes of non-UTF8 data on channel {:?}", data.len(), _channel);
86        }
87        Ok(())
88    }
89
90    async fn extended_data(
91        &mut self,
92        _channel: ChannelId,
93        ext: u32,
94        data: &[u8],
95        _session: &mut client::Session,
96    ) -> Result<(), Self::Error> {
97        // Extended data includes stderr (ext == 1)
98        // localhost.run sends URL info through stderr
99        if let Ok(message) = String::from_utf8(data.to_vec()) {
100            info!("Received extended data (type {}): {}", ext, message);
101            let _ = self.message_tx.send(message);
102        }
103        debug!("Received {} bytes of extended data (type {}) on channel {:?}", data.len(), ext, _channel);
104        Ok(())
105    }
106}
107
108impl Client {
109    fn new(
110        tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
111        message_tx: mpsc::UnboundedSender<String>,
112    ) -> Self {
113        Self { tx, message_tx }
114    }
115}
116
117/// Reverse SSH client that establishes a reverse tunnel
118pub struct ReverseSshClient {
119    config: ReverseSshConfig,
120    handle: Option<Handle<Client>>,
121}
122
123impl ReverseSshClient {
124    /// Create a new reverse SSH client with the given configuration
125    pub fn new(config: ReverseSshConfig) -> Self {
126        Self {
127            config,
128            handle: None,
129        }
130    }
131
132    /// Connect to the SSH server and authenticate
133    pub async fn connect(
134        &mut self,
135        tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
136        message_tx: mpsc::UnboundedSender<String>,
137    ) -> Result<()> {
138        info!("Connecting to SSH server {}:{}", self.config.server_addr, self.config.server_port);
139
140        let client_config = client::Config {
141            inactivity_timeout: Some(std::time::Duration::from_secs(3600)),
142            ..<_>::default()
143        };
144
145        let client_handler = Client::new(tx, message_tx);
146
147        let mut session = client::connect(
148            Arc::new(client_config),
149            (self.config.server_addr.as_str(), self.config.server_port),
150            client_handler,
151        )
152        .await
153        .context("Failed to connect to SSH server")?;
154
155        // Authenticate
156        let auth_result = if let Some(key_path) = &self.config.key_path {
157            info!("Authenticating with private key: {}", key_path);
158            let key_pair = russh_keys::load_secret_key(key_path, None)
159                .context("Failed to load private key")?;
160            session
161                .authenticate_publickey(&self.config.username, Arc::new(key_pair))
162                .await
163        } else if let Some(password) = &self.config.password {
164            info!("Authenticating with password");
165            session
166                .authenticate_password(&self.config.username, password)
167                .await
168        } else {
169            anyhow::bail!("No authentication method provided (need key_path or password)");
170        };
171
172        if !auth_result.context("Authentication failed")? {
173            anyhow::bail!("Authentication rejected by server");
174        }
175
176        info!("Successfully authenticated to SSH server");
177        self.handle = Some(session);
178        Ok(())
179    }
180
181    /// Set up a reverse port forward (remote port forwarding)
182    /// This makes the SSH server listen on a port and forward connections back to us
183    pub async fn setup_reverse_tunnel(&mut self) -> Result<()> {
184        let handle = self
185            .handle
186            .as_mut()
187            .context("Not connected - call connect() first")?;
188
189        info!(
190            "Setting up reverse tunnel: server port {} -> local {}:{}",
191            self.config.remote_port, self.config.local_addr, self.config.local_port
192        );
193
194        // Request remote port forwarding
195        // Use empty string "" instead of "0.0.0.0" - this lets the SSH server choose
196        // the bind address. localhost.run requires this format.
197        handle
198            .tcpip_forward("", self.config.remote_port)
199            .await
200            .context("Failed to set up remote port forwarding")?;
201
202        info!("Reverse tunnel established successfully");
203
204        // Open a shell session to receive server messages (like the URL from localhost.run)
205        // This is important for services that send connection info via shell
206        match handle.channel_open_session().await {
207            Ok(channel) => {
208                info!("Opened shell session to receive server messages");
209                // Request a shell - this triggers the server to send welcome messages
210                if let Err(e) = channel.request_shell(false).await {
211                    warn!("Failed to request shell: {}", e);
212                } else {
213                    debug!("Shell requested successfully");
214                }
215                // Don't close the channel - keep it open to receive messages
216                // The channel will be kept alive by the handler
217            }
218            Err(e) => {
219                warn!("Could not open shell session: {} (this may be normal for some servers)", e);
220            }
221        }
222
223        Ok(())
224    }
225
226    /// Read server messages (useful for services like localhost.run that send URL info)
227    /// This opens a session channel and attempts to read any messages from the server
228    pub async fn read_server_messages(&mut self) -> Result<Vec<String>> {
229        let handle = self
230            .handle
231            .as_mut()
232            .context("Not connected - call connect() first")?;
233
234        let mut messages = Vec::new();
235
236        // Try to open a session channel to read any server messages
237        match handle.channel_open_session().await {
238            Ok(channel) => {
239                // Request a shell to trigger server messages
240                let _ = channel.request_shell(false).await;
241
242                // Wait a bit for messages to arrive
243                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
244
245                // Try to read data from the channel
246                // Note: This is a simplified approach - in practice, we'd need to
247                // handle the channel data in the Handler's data() method
248
249                // Close the channel
250                let _ = channel.eof().await;
251                let _ = channel.close().await;
252
253                messages.push("Check SSH session output for connection URL".to_string());
254            }
255            Err(e) => {
256                warn!("Could not open session channel: {}", e);
257            }
258        }
259
260        Ok(messages)
261    }
262
263    /// Handle forwarded connections from the SSH server
264    pub async fn handle_forwarded_connections(&mut self, mut rx: mpsc::UnboundedReceiver<(Channel<Msg>, String, u32)>) -> Result<()> {
265        info!("Waiting for forwarded connections...");
266
267        while let Some((channel, _remote_addr, _remote_port)) = rx.recv().await {
268            info!("New forwarded connection received");
269
270            // Spawn a task to handle this connection
271            let local_addr = self.config.local_addr.clone();
272            let local_port = self.config.local_port;
273
274            tokio::spawn(async move {
275                if let Err(e) = handle_connection(channel, &local_addr, local_port).await {
276                    error!("Error handling connection: {}", e);
277                }
278            });
279        }
280
281        warn!("Connection closed by server");
282        Ok(())
283    }
284
285    /// Run the reverse SSH client (connect, setup tunnel, and handle connections)
286    pub async fn run(&mut self) -> Result<()> {
287        let (tx, rx) = mpsc::unbounded_channel();
288        let (message_tx, mut message_rx) = mpsc::unbounded_channel();
289
290        self.connect(tx, message_tx).await?;
291        self.setup_reverse_tunnel().await?;
292
293        // Spawn a task to print server messages
294        tokio::spawn(async move {
295            while let Some(message) = message_rx.recv().await {
296                // Print server messages, which may include URLs
297                if !message.trim().is_empty() {
298                    println!("[Server] {}", message.trim());
299                }
300            }
301        });
302
303        self.handle_forwarded_connections(rx).await?;
304
305        Ok(())
306    }
307
308    /// Run the client with custom message handling
309    pub async fn run_with_message_handler<F>(&mut self, mut message_handler: F) -> Result<()>
310    where
311        F: FnMut(String) + Send + 'static,
312    {
313        let (tx, rx) = mpsc::unbounded_channel();
314        let (message_tx, mut message_rx) = mpsc::unbounded_channel();
315
316        self.connect(tx, message_tx).await?;
317        self.setup_reverse_tunnel().await?;
318
319        // Spawn a task to handle server messages with custom handler
320        tokio::spawn(async move {
321            while let Some(message) = message_rx.recv().await {
322                message_handler(message);
323            }
324        });
325
326        self.handle_forwarded_connections(rx).await?;
327
328        Ok(())
329    }
330}
331
332/// Handle a single forwarded connection by proxying data between SSH channel and local service
333async fn handle_connection(
334    channel: Channel<Msg>,
335    local_addr: &str,
336    local_port: u16,
337) -> Result<()> {
338    info!("Connecting to local service {}:{}", local_addr, local_port);
339
340    // Connect to the local service
341    let local_socket_addr: SocketAddr = format!("{}:{}", local_addr, local_port)
342        .parse()
343        .context("Invalid local address")?;
344
345    let mut local_stream = TcpStream::connect(local_socket_addr)
346        .await
347        .context("Failed to connect to local service")?;
348
349    info!("Connected to local service, starting proxy");
350
351    // For now, we'll implement a simple bidirectional proxy
352    // Note: The russh Channel API may need adjustment based on actual usage
353
354    // Create a simple buffer for reading from local service
355    let mut buffer = vec![0u8; 8192];
356
357    // Read from local and forward to SSH
358    loop {
359        match local_stream.read(&mut buffer).await {
360            Ok(0) => {
361                debug!("Local connection closed");
362                break;
363            }
364            Ok(n) => {
365                debug!("Read {} bytes from local service", n);
366                // Send data through SSH channel
367                if let Err(e) = channel.data(&buffer[..n]).await {
368                    error!("Failed to send data to SSH channel: {}", e);
369                    break;
370                }
371            }
372            Err(e) => {
373                error!("Error reading from local service: {}", e);
374                break;
375            }
376        }
377    }
378
379    // Close the channel
380    let _ = channel.eof().await;
381    let _ = channel.close().await;
382
383    info!("Connection closed");
384    Ok(())
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// A key-authenticated configuration with easily recognisable values.
    fn sample_config() -> ReverseSshConfig {
        ReverseSshConfig {
            server_addr: String::from("example.com"),
            server_port: 22,
            username: String::from("user"),
            key_path: Some(String::from("/path/to/key")),
            password: None,
            remote_port: 8080,
            local_addr: String::from("127.0.0.1"),
            local_port: 3000,
        }
    }

    #[test]
    fn test_config_creation() {
        let ReverseSshConfig {
            server_addr,
            remote_port,
            ..
        } = sample_config();

        assert_eq!(server_addr, "example.com");
        assert_eq!(remote_port, 8080);
    }
}
407}