reverse_ssh/
lib.rs

1use anyhow::{Context, Result};
2use russh::client::{self, Handle, Msg};
3use russh::keys::*;
4use russh::*;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use tokio::net::TcpStream;
8use tokio::sync::mpsc;
9use tracing::{debug, error, info, warn};
10
/// Configuration for the reverse SSH connection.
///
/// `Debug` is implemented by hand instead of being derived so that the value of
/// `password` is never written out (for example by a `{:?}` in a log statement).
/// All other fields are printed unchanged.
#[derive(Clone)]
pub struct ReverseSshConfig {
    /// The SSH server address to connect to
    pub server_addr: String,
    /// The SSH server port
    pub server_port: u16,
    /// Username for SSH authentication
    pub username: String,
    /// Private key path for authentication
    pub key_path: Option<String>,
    /// Password for authentication (if not using key)
    pub password: Option<String>,
    /// Remote port to listen on (on the SSH server)
    pub remote_port: u32,
    /// Local address to forward connections to
    pub local_addr: String,
    /// Local port to forward connections to
    pub local_port: u16,
}

impl std::fmt::Debug for ReverseSshConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a password is configured, never its value.
        let password = self.password.as_ref().map(|_| "<redacted>");

        f.debug_struct("ReverseSshConfig")
            .field("server_addr", &self.server_addr)
            .field("server_port", &self.server_port)
            .field("username", &self.username)
            .field("key_path", &self.key_path)
            .field("password", &password)
            .field("remote_port", &self.remote_port)
            .field("local_addr", &self.local_addr)
            .field("local_port", &self.local_port)
            .finish()
    }
}
31
/// SSH client handler
///
/// Holds the sending halves of two unbounded channels through which the
/// `client::Handler` callbacks pass work to the rest of the crate.
struct Client {
    // Forwarded channels, sent together with the address and port reported as
    // "connected" in the forwarded-tcpip open callback.
    tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
    // UTF-8 text received in the `data` / `extended_data` callbacks.
    message_tx: mpsc::UnboundedSender<String>,
}
37
38#[async_trait::async_trait]
39impl client::Handler for Client {
40    type Error = russh::Error;
41
42    async fn check_server_key(
43        &mut self,
44        _server_public_key: &key::PublicKey,
45    ) -> Result<bool, Self::Error> {
46        // In production, you should verify the server's public key
47        // For now, we accept any key
48        Ok(true)
49    }
50
51    async fn server_channel_open_forwarded_tcpip(
52        &mut self,
53        channel: Channel<Msg>,
54        connected_address: &str,
55        connected_port: u32,
56        originator_address: &str,
57        originator_port: u32,
58        _session: &mut client::Session,
59    ) -> Result<(), Self::Error> {
60        debug!(
61            "Forwarded channel: {}:{} -> {}:{}",
62            originator_address, originator_port, connected_address, connected_port
63        );
64
65        // Send the channel to be handled
66        let _ = self
67            .tx
68            .send((channel, connected_address.to_string(), connected_port));
69
70        Ok(())
71    }
72
73    async fn data(
74        &mut self,
75        _channel: ChannelId,
76        data: &[u8],
77        _session: &mut client::Session,
78    ) -> Result<(), Self::Error> {
79        // Convert data to string and send it for processing
80        // Don't filter out partial messages - send everything
81        if let Ok(message) = String::from_utf8(data.to_vec()) {
82            debug!("Received data ({} bytes): {}", data.len(), message);
83            let _ = self.message_tx.send(message);
84        } else {
85            // Log if we received non-UTF8 data
86            debug!(
87                "Received {} bytes of non-UTF8 data on channel {:?}",
88                data.len(),
89                _channel
90            );
91        }
92        Ok(())
93    }
94
95    async fn extended_data(
96        &mut self,
97        _channel: ChannelId,
98        ext: u32,
99        data: &[u8],
100        _session: &mut client::Session,
101    ) -> Result<(), Self::Error> {
102        // Extended data includes stderr (ext == 1)
103        // localhost.run sends URL info through stderr
104        if let Ok(message) = String::from_utf8(data.to_vec()) {
105            info!("Received extended data (type {}): {}", ext, message);
106            let _ = self.message_tx.send(message);
107        }
108        debug!(
109            "Received {} bytes of extended data (type {}) on channel {:?}",
110            data.len(),
111            ext,
112            _channel
113        );
114        Ok(())
115    }
116}
117
118impl Client {
119    fn new(
120        tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
121        message_tx: mpsc::UnboundedSender<String>,
122    ) -> Self {
123        Self { tx, message_tx }
124    }
125}
126
/// Reverse SSH client that establishes a reverse tunnel
pub struct ReverseSshClient {
    // Connection, authentication and forwarding settings.
    config: ReverseSshConfig,
    // Session handle; `None` until `connect()` has authenticated successfully.
    handle: Option<Handle<Client>>,
}
132
133impl ReverseSshClient {
134    /// Create a new reverse SSH client with the given configuration
135    pub fn new(config: ReverseSshConfig) -> Self {
136        Self {
137            config,
138            handle: None,
139        }
140    }
141
142    /// Connect to the SSH server and authenticate
143    pub async fn connect(
144        &mut self,
145        tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
146        message_tx: mpsc::UnboundedSender<String>,
147    ) -> Result<()> {
148        info!(
149            "Connecting to SSH server {}:{}",
150            self.config.server_addr, self.config.server_port
151        );
152
153        let client_config = client::Config {
154            inactivity_timeout: Some(std::time::Duration::from_secs(3600)),
155            ..<_>::default()
156        };
157
158        let client_handler = Client::new(tx, message_tx);
159
160        let mut session = client::connect(
161            Arc::new(client_config),
162            (self.config.server_addr.as_str(), self.config.server_port),
163            client_handler,
164        )
165        .await
166        .context("Failed to connect to SSH server")?;
167
168        // Authenticate
169        let auth_result = if let Some(key_path) = &self.config.key_path {
170            info!("Authenticating with private key: {}", key_path);
171            let key_pair = russh_keys::load_secret_key(key_path, None)
172                .context("Failed to load private key")?;
173            session
174                .authenticate_publickey(&self.config.username, Arc::new(key_pair))
175                .await
176        } else if let Some(password) = &self.config.password {
177            info!("Authenticating with password");
178            session
179                .authenticate_password(&self.config.username, password)
180                .await
181        } else {
182            anyhow::bail!("No authentication method provided (need key_path or password)");
183        };
184
185        if !auth_result.context("Authentication failed")? {
186            anyhow::bail!("Authentication rejected by server");
187        }
188
189        info!("Successfully authenticated to SSH server");
190        self.handle = Some(session);
191        Ok(())
192    }
193
194    /// Set up a reverse port forward (remote port forwarding)
195    /// This makes the SSH server listen on a port and forward connections back to us
196    pub async fn setup_reverse_tunnel(&mut self) -> Result<()> {
197        let handle = self
198            .handle
199            .as_mut()
200            .context("Not connected - call connect() first")?;
201
202        info!(
203            "Setting up reverse tunnel: server port {} -> local {}:{}",
204            self.config.remote_port, self.config.local_addr, self.config.local_port
205        );
206
207        // Request remote port forwarding
208        // Use empty string "" instead of "0.0.0.0" - this lets the SSH server choose
209        // the bind address. localhost.run requires this format.
210        handle
211            .tcpip_forward("", self.config.remote_port)
212            .await
213            .context("Failed to set up remote port forwarding")?;
214
215        info!("Reverse tunnel established successfully");
216
217        // Open a shell session to receive server messages (like the URL from localhost.run)
218        // This is important for services that send connection info via shell
219        match handle.channel_open_session().await {
220            Ok(channel) => {
221                info!("Opened shell session to receive server messages");
222                // Request a shell - this triggers the server to send welcome messages
223                if let Err(e) = channel.request_shell(false).await {
224                    warn!("Failed to request shell: {}", e);
225                } else {
226                    debug!("Shell requested successfully");
227                }
228                // Don't close the channel - keep it open to receive messages
229                // The channel will be kept alive by the handler
230            }
231            Err(e) => {
232                warn!(
233                    "Could not open shell session: {} (this may be normal for some servers)",
234                    e
235                );
236            }
237        }
238
239        Ok(())
240    }
241
242    /// Read server messages (useful for services like localhost.run that send URL info)
243    /// This opens a session channel and attempts to read any messages from the server
244    #[allow(dead_code)]
245    pub async fn read_server_messages(&mut self) -> Result<Vec<String>> {
246        let handle = self
247            .handle
248            .as_mut()
249            .context("Not connected - call connect() first")?;
250
251        let mut messages = Vec::new();
252
253        // Try to open a session channel to read any server messages
254        match handle.channel_open_session().await {
255            Ok(channel) => {
256                // Request a shell to trigger server messages
257                let _ = channel.request_shell(false).await;
258
259                // Wait a bit for messages to arrive
260                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
261
262                // Try to read data from the channel
263                // Note: This is a simplified approach - in practice, we'd need to
264                // handle the channel data in the Handler's data() method
265
266                // Close the channel
267                let _ = channel.eof().await;
268                let _ = channel.close().await;
269
270                messages.push("Check SSH session output for connection URL".to_string());
271            }
272            Err(e) => {
273                warn!("Could not open session channel: {}", e);
274            }
275        }
276
277        Ok(messages)
278    }
279
280    /// Handle forwarded connections from the SSH server
281    pub async fn handle_forwarded_connections(
282        &mut self,
283        mut rx: mpsc::UnboundedReceiver<(Channel<Msg>, String, u32)>,
284    ) -> Result<()> {
285        info!("Waiting for forwarded connections...");
286
287        while let Some((channel, _remote_addr, _remote_port)) = rx.recv().await {
288            info!("New forwarded connection received");
289
290            // Spawn a task to handle this connection
291            let local_addr = self.config.local_addr.clone();
292            let local_port = self.config.local_port;
293
294            tokio::spawn(async move {
295                if let Err(e) = handle_connection(channel, &local_addr, local_port).await {
296                    error!("Error handling connection: {}", e);
297                }
298            });
299        }
300
301        warn!("Connection closed by server");
302        Ok(())
303    }
304
305    /// Run the reverse SSH client (connect, setup tunnel, and handle connections)
306    #[allow(dead_code)]
307    pub async fn run(&mut self) -> Result<()> {
308        let (tx, rx) = mpsc::unbounded_channel();
309        let (message_tx, mut message_rx) = mpsc::unbounded_channel();
310
311        self.connect(tx, message_tx).await?;
312        self.setup_reverse_tunnel().await?;
313
314        // Spawn a task to print server messages
315        tokio::spawn(async move {
316            while let Some(message) = message_rx.recv().await {
317                // Print server messages, which may include URLs
318                if !message.trim().is_empty() {
319                    println!("[Server] {}", message.trim());
320                }
321            }
322        });
323
324        self.handle_forwarded_connections(rx).await?;
325
326        Ok(())
327    }
328
329    /// Run the client with custom message handling
330    pub async fn run_with_message_handler<F>(&mut self, mut message_handler: F) -> Result<()>
331    where
332        F: FnMut(String) + Send + 'static,
333    {
334        let (tx, rx) = mpsc::unbounded_channel();
335        let (message_tx, mut message_rx) = mpsc::unbounded_channel();
336
337        self.connect(tx, message_tx).await?;
338        self.setup_reverse_tunnel().await?;
339
340        // Spawn a task to handle server messages with custom handler
341        tokio::spawn(async move {
342            while let Some(message) = message_rx.recv().await {
343                message_handler(message);
344            }
345        });
346
347        self.handle_forwarded_connections(rx).await?;
348
349        Ok(())
350    }
351}
352
353/// Handle a single forwarded connection by proxying data between SSH channel and local service
354async fn handle_connection(
355    mut channel: Channel<Msg>,
356    local_addr: &str,
357    local_port: u16,
358) -> Result<()> {
359    use tokio::io::{AsyncReadExt, AsyncWriteExt};
360
361    info!("Connecting to local service {}:{}", local_addr, local_port);
362
363    // Connect to the local service
364    let local_socket_addr: SocketAddr = format!("{}:{}", local_addr, local_port)
365        .parse()
366        .context("Invalid local address")?;
367
368    let mut local_stream = TcpStream::connect(local_socket_addr)
369        .await
370        .context("Failed to connect to local service")?;
371
372    info!("Connected to local service, starting bidirectional proxy");
373
374    // Bidirectional proxy using tokio::select!
375    let mut local_buf = vec![0u8; 8192];
376
377    // Read from local and forward to SSH
378    loop {
379        tokio::select! {
380            // Read from SSH channel and write to local service
381            msg = channel.wait() => {
382                match msg {
383                    Some(russh::ChannelMsg::Data { data }) => {
384                        debug!("Received {} bytes from SSH channel", data.len());
385                        if let Err(e) = local_stream.write_all(&data).await {
386                            error!("Failed to write to local service: {}", e);
387                            break;
388                        }
389                    }
390                    Some(russh::ChannelMsg::Eof) => {
391                        debug!("Received EOF from SSH channel");
392                        let _ = local_stream.shutdown().await;
393                        break;
394                    }
395                    Some(russh::ChannelMsg::Close) => {
396                        debug!("SSH channel closed");
397                        break;
398                    }
399                    Some(other) => {
400                        debug!("Received other channel message: {:?}", other);
401                    }
402                    None => {
403                        debug!("SSH channel receiver closed");
404                        break;
405                    }
406                }
407            }
408
409            // Read from local service and write to SSH channel
410            result = local_stream.read(&mut local_buf) => {
411                match result {
412                    Ok(0) => {
413                        debug!("Local connection closed");
414                        break;
415                    }
416                    Ok(n) => {
417                        debug!("Read {} bytes from local service", n);
418                        if let Err(e) = channel.data(&local_buf[..n]).await {
419                            error!("Failed to send data to SSH channel: {}", e);
420                            break;
421                        }
422                    }
423                    Err(e) => {
424                        error!("Error reading from local service: {}", e);
425                        break;
426                    }
427                }
428            }
429        }
430    }
431
432    // Close the channel gracefully
433    let _ = channel.eof().await;
434    let _ = channel.close().await;
435
436    info!("Connection proxy closed");
437
438    Ok(())
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// A key-authenticated configuration keeps the values it was built with.
    #[test]
    fn test_config_creation() {
        let key_path = Some(String::from("/path/to/key"));
        let config = ReverseSshConfig {
            server_addr: String::from("example.com"),
            server_port: 22,
            username: String::from("user"),
            key_path,
            password: None,
            remote_port: 8080,
            local_addr: String::from("127.0.0.1"),
            local_port: 3000,
        };

        assert_eq!(config.server_addr, "example.com");
        assert_eq!(config.remote_port, 8080);
    }
}