1use anyhow::{Context, Result};
2use russh::client::{self, Handle, Msg};
3use russh::keys::*;
4use russh::*;
5use std::net::SocketAddr;
6use std::sync::Arc;
7use tokio::net::TcpStream;
8use tokio::sync::mpsc;
9use tracing::{debug, error, info, warn};
10
/// Connection settings for a reverse SSH tunnel.
///
/// `Debug` is implemented by hand so that the password never ends up in log
/// output; every other field is printed as usual.
#[derive(Clone, PartialEq, Eq)]
pub struct ReverseSshConfig {
    /// Host name or IP address of the SSH server.
    pub server_addr: String,
    /// TCP port of the SSH server.
    pub server_port: u16,
    /// User name presented during authentication.
    pub username: String,
    /// Path to a private key file; when set, public-key authentication is used.
    pub key_path: Option<String>,
    /// Password, used only when no `key_path` is configured.
    pub password: Option<String>,
    /// Port the SSH server is asked to forward from.
    pub remote_port: u32,
    /// Address of the local service that forwarded connections are proxied to.
    pub local_addr: String,
    /// Port of the local service.
    pub local_port: u16,
}

impl std::fmt::Debug for ReverseSshConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReverseSshConfig")
            .field("server_addr", &self.server_addr)
            .field("server_port", &self.server_port)
            .field("username", &self.username)
            .field("key_path", &self.key_path)
            // Show whether a password is configured without revealing it.
            .field("password", &self.password.as_ref().map(|_| "<redacted>"))
            .field("remote_port", &self.remote_port)
            .field("local_addr", &self.local_addr)
            .field("local_port", &self.local_port)
            .finish()
    }
}
31
/// russh client handler that relays forwarded channels and server text to
/// the rest of the program through unbounded channels.
struct Client {
    /// Carries each forwarded channel together with the address and port it
    /// was connected on.
    tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
    /// Carries UTF-8 text received as channel data or extended data.
    message_tx: mpsc::UnboundedSender<String>,
}
37
38#[async_trait::async_trait]
39impl client::Handler for Client {
40 type Error = russh::Error;
41
42 async fn check_server_key(
43 &mut self,
44 _server_public_key: &key::PublicKey,
45 ) -> Result<bool, Self::Error> {
46 Ok(true)
49 }
50
51 async fn server_channel_open_forwarded_tcpip(
52 &mut self,
53 channel: Channel<Msg>,
54 connected_address: &str,
55 connected_port: u32,
56 originator_address: &str,
57 originator_port: u32,
58 _session: &mut client::Session,
59 ) -> Result<(), Self::Error> {
60 debug!(
61 "Forwarded channel: {}:{} -> {}:{}",
62 originator_address, originator_port, connected_address, connected_port
63 );
64
65 let _ = self
67 .tx
68 .send((channel, connected_address.to_string(), connected_port));
69
70 Ok(())
71 }
72
73 async fn data(
74 &mut self,
75 _channel: ChannelId,
76 data: &[u8],
77 _session: &mut client::Session,
78 ) -> Result<(), Self::Error> {
79 if let Ok(message) = String::from_utf8(data.to_vec()) {
82 debug!("Received data ({} bytes): {}", data.len(), message);
83 let _ = self.message_tx.send(message);
84 } else {
85 debug!(
87 "Received {} bytes of non-UTF8 data on channel {:?}",
88 data.len(),
89 _channel
90 );
91 }
92 Ok(())
93 }
94
95 async fn extended_data(
96 &mut self,
97 _channel: ChannelId,
98 ext: u32,
99 data: &[u8],
100 _session: &mut client::Session,
101 ) -> Result<(), Self::Error> {
102 if let Ok(message) = String::from_utf8(data.to_vec()) {
105 info!("Received extended data (type {}): {}", ext, message);
106 let _ = self.message_tx.send(message);
107 }
108 debug!(
109 "Received {} bytes of extended data (type {}) on channel {:?}",
110 data.len(),
111 ext,
112 _channel
113 );
114 Ok(())
115 }
116}
117
impl Client {
    /// Builds a handler that reports forwarded channels on `tx` and server
    /// text on `message_tx`.
    fn new(
        tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
        message_tx: mpsc::UnboundedSender<String>,
    ) -> Self {
        Self { tx, message_tx }
    }
}
126
/// Reverse-tunnel SSH client: holds the configuration and, once connected,
/// the session handle.
pub struct ReverseSshClient {
    /// Settings supplied at construction time.
    config: ReverseSshConfig,
    /// Session handle; `None` until `connect` has succeeded.
    handle: Option<Handle<Client>>,
}
132
133impl ReverseSshClient {
134 pub fn new(config: ReverseSshConfig) -> Self {
136 Self {
137 config,
138 handle: None,
139 }
140 }
141
142 pub async fn connect(
144 &mut self,
145 tx: mpsc::UnboundedSender<(Channel<Msg>, String, u32)>,
146 message_tx: mpsc::UnboundedSender<String>,
147 ) -> Result<()> {
148 info!(
149 "Connecting to SSH server {}:{}",
150 self.config.server_addr, self.config.server_port
151 );
152
153 let client_config = client::Config {
154 inactivity_timeout: Some(std::time::Duration::from_secs(3600)),
155 ..<_>::default()
156 };
157
158 let client_handler = Client::new(tx, message_tx);
159
160 let mut session = client::connect(
161 Arc::new(client_config),
162 (self.config.server_addr.as_str(), self.config.server_port),
163 client_handler,
164 )
165 .await
166 .context("Failed to connect to SSH server")?;
167
168 let auth_result = if let Some(key_path) = &self.config.key_path {
170 info!("Authenticating with private key: {}", key_path);
171 let key_pair = russh_keys::load_secret_key(key_path, None)
172 .context("Failed to load private key")?;
173 session
174 .authenticate_publickey(&self.config.username, Arc::new(key_pair))
175 .await
176 } else if let Some(password) = &self.config.password {
177 info!("Authenticating with password");
178 session
179 .authenticate_password(&self.config.username, password)
180 .await
181 } else {
182 anyhow::bail!("No authentication method provided (need key_path or password)");
183 };
184
185 if !auth_result.context("Authentication failed")? {
186 anyhow::bail!("Authentication rejected by server");
187 }
188
189 info!("Successfully authenticated to SSH server");
190 self.handle = Some(session);
191 Ok(())
192 }
193
194 pub async fn setup_reverse_tunnel(&mut self) -> Result<()> {
197 let handle = self
198 .handle
199 .as_mut()
200 .context("Not connected - call connect() first")?;
201
202 info!(
203 "Setting up reverse tunnel: server port {} -> local {}:{}",
204 self.config.remote_port, self.config.local_addr, self.config.local_port
205 );
206
207 handle
211 .tcpip_forward("", self.config.remote_port)
212 .await
213 .context("Failed to set up remote port forwarding")?;
214
215 info!("Reverse tunnel established successfully");
216
217 match handle.channel_open_session().await {
220 Ok(channel) => {
221 info!("Opened shell session to receive server messages");
222 if let Err(e) = channel.request_shell(false).await {
224 warn!("Failed to request shell: {}", e);
225 } else {
226 debug!("Shell requested successfully");
227 }
228 }
231 Err(e) => {
232 warn!(
233 "Could not open shell session: {} (this may be normal for some servers)",
234 e
235 );
236 }
237 }
238
239 Ok(())
240 }
241
242 #[allow(dead_code)]
245 pub async fn read_server_messages(&mut self) -> Result<Vec<String>> {
246 let handle = self
247 .handle
248 .as_mut()
249 .context("Not connected - call connect() first")?;
250
251 let mut messages = Vec::new();
252
253 match handle.channel_open_session().await {
255 Ok(channel) => {
256 let _ = channel.request_shell(false).await;
258
259 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
261
262 let _ = channel.eof().await;
268 let _ = channel.close().await;
269
270 messages.push("Check SSH session output for connection URL".to_string());
271 }
272 Err(e) => {
273 warn!("Could not open session channel: {}", e);
274 }
275 }
276
277 Ok(messages)
278 }
279
280 pub async fn handle_forwarded_connections(
282 &mut self,
283 mut rx: mpsc::UnboundedReceiver<(Channel<Msg>, String, u32)>,
284 ) -> Result<()> {
285 info!("Waiting for forwarded connections...");
286
287 while let Some((channel, _remote_addr, _remote_port)) = rx.recv().await {
288 info!("New forwarded connection received");
289
290 let local_addr = self.config.local_addr.clone();
292 let local_port = self.config.local_port;
293
294 tokio::spawn(async move {
295 if let Err(e) = handle_connection(channel, &local_addr, local_port).await {
296 error!("Error handling connection: {}", e);
297 }
298 });
299 }
300
301 warn!("Connection closed by server");
302 Ok(())
303 }
304
305 #[allow(dead_code)]
307 pub async fn run(&mut self) -> Result<()> {
308 let (tx, rx) = mpsc::unbounded_channel();
309 let (message_tx, mut message_rx) = mpsc::unbounded_channel();
310
311 self.connect(tx, message_tx).await?;
312 self.setup_reverse_tunnel().await?;
313
314 tokio::spawn(async move {
316 while let Some(message) = message_rx.recv().await {
317 if !message.trim().is_empty() {
319 println!("[Server] {}", message.trim());
320 }
321 }
322 });
323
324 self.handle_forwarded_connections(rx).await?;
325
326 Ok(())
327 }
328
329 pub async fn run_with_message_handler<F>(&mut self, mut message_handler: F) -> Result<()>
331 where
332 F: FnMut(String) + Send + 'static,
333 {
334 let (tx, rx) = mpsc::unbounded_channel();
335 let (message_tx, mut message_rx) = mpsc::unbounded_channel();
336
337 self.connect(tx, message_tx).await?;
338 self.setup_reverse_tunnel().await?;
339
340 tokio::spawn(async move {
342 while let Some(message) = message_rx.recv().await {
343 message_handler(message);
344 }
345 });
346
347 self.handle_forwarded_connections(rx).await?;
348
349 Ok(())
350 }
351}
352
353async fn handle_connection(
355 mut channel: Channel<Msg>,
356 local_addr: &str,
357 local_port: u16,
358) -> Result<()> {
359 use tokio::io::{AsyncReadExt, AsyncWriteExt};
360
361 info!("Connecting to local service {}:{}", local_addr, local_port);
362
363 let local_socket_addr: SocketAddr = format!("{}:{}", local_addr, local_port)
365 .parse()
366 .context("Invalid local address")?;
367
368 let mut local_stream = TcpStream::connect(local_socket_addr)
369 .await
370 .context("Failed to connect to local service")?;
371
372 info!("Connected to local service, starting bidirectional proxy");
373
374 let mut local_buf = vec![0u8; 8192];
376
377 loop {
379 tokio::select! {
380 msg = channel.wait() => {
382 match msg {
383 Some(russh::ChannelMsg::Data { data }) => {
384 debug!("Received {} bytes from SSH channel", data.len());
385 if let Err(e) = local_stream.write_all(&data).await {
386 error!("Failed to write to local service: {}", e);
387 break;
388 }
389 }
390 Some(russh::ChannelMsg::Eof) => {
391 debug!("Received EOF from SSH channel");
392 let _ = local_stream.shutdown().await;
393 break;
394 }
395 Some(russh::ChannelMsg::Close) => {
396 debug!("SSH channel closed");
397 break;
398 }
399 Some(other) => {
400 debug!("Received other channel message: {:?}", other);
401 }
402 None => {
403 debug!("SSH channel receiver closed");
404 break;
405 }
406 }
407 }
408
409 result = local_stream.read(&mut local_buf) => {
411 match result {
412 Ok(0) => {
413 debug!("Local connection closed");
414 break;
415 }
416 Ok(n) => {
417 debug!("Read {} bytes from local service", n);
418 if let Err(e) = channel.data(&local_buf[..n]).await {
419 error!("Failed to send data to SSH channel: {}", e);
420 break;
421 }
422 }
423 Err(e) => {
424 error!("Error reading from local service: {}", e);
425 break;
426 }
427 }
428 }
429 }
430 }
431
432 let _ = channel.eof().await;
434 let _ = channel.close().await;
435
436 info!("Connection proxy closed");
437
438 Ok(())
439}
440
#[cfg(test)]
mod tests {
    use super::*;

    /// A config built from literal values exposes those values through its
    /// public fields.
    #[test]
    fn test_config_creation() {
        let cfg = ReverseSshConfig {
            server_addr: String::from("example.com"),
            server_port: 22,
            username: String::from("user"),
            key_path: Some(String::from("/path/to/key")),
            password: None,
            remote_port: 8080,
            local_addr: String::from("127.0.0.1"),
            local_port: 3000,
        };

        assert_eq!(cfg.server_addr, "example.com");
        assert_eq!(cfg.remote_port, 8080);
    }
}