Skip to main content

rustgate/c2/
client.rs

1use crate::error::{ProxyError, Result};
2use crate::protocol::{
3    frame_tunnel_data, parse_tunnel_data, Command, CommandResponse, ControlMessage, WsTextMessage,
4};
5use crate::socks5::Socks5Listener;
6use crate::ws::{self, ChannelMap};
7use bytes::Bytes;
8use futures_util::{SinkExt, StreamExt};
9use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::TcpStream;
14use tokio::sync::{mpsc, RwLock};
15use tokio_tungstenite::tungstenite::Message;
16use tracing::{info, warn};
17
18/// Run the C2 client with automatic reconnect.
19pub async fn run(
20    server_url: &str,
21    cert_pem_path: &str,
22    key_pem_path: &str,
23    ca_cert_pem_path: &str,
24) -> Result<()> {
25    // Load client cert and key
26    let cert_pem = tokio::fs::read_to_string(cert_pem_path).await?;
27    let key_pem = tokio::fs::read_to_string(key_pem_path).await?;
28    let ca_pem = tokio::fs::read_to_string(ca_cert_pem_path).await?;
29
30    let client_cert_der = pem_to_cert_der(&cert_pem)?;
31    let client_key_der = pem_to_key_der(&key_pem)?;
32    let ca_cert_der = pem_to_cert_der(&ca_pem)?;
33
34    let tls_config = crate::tls::make_mtls_client_config(
35        client_cert_der,
36        client_key_der,
37        ca_cert_der,
38    )?;
39
40    // Parse host:port from server URL (wss://host:port)
41    let (host, port) = parse_wss_url(server_url)?;
42
43    let mut backoff = 1u64;
44    loop {
45        info!("Connecting to {server_url}...");
46        match connect_and_run(&host, port, server_url, tls_config.clone()).await {
47            Ok(()) => {
48                info!("Disconnected from server");
49                backoff = 1;
50            }
51            Err(e) => {
52                warn!("Connection error: {e}");
53            }
54        }
55
56        info!("Reconnecting in {backoff}s...");
57        tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
58        backoff = (backoff * 2).min(60);
59    }
60}
61
62async fn connect_and_run(
63    host: &str,
64    port: u16,
65    server_url: &str,
66    tls_config: Arc<rustls::ClientConfig>,
67) -> Result<()> {
68    let addr = format!("{host}:{port}");
69    let tcp = TcpStream::connect(&addr).await?;
70
71    let connector = tokio_rustls::TlsConnector::from(tls_config);
72    let server_name = rustls::pki_types::ServerName::try_from(host.to_string())
73        .map_err(|e| ProxyError::Other(e.to_string()))?;
74    let tls_stream = connector.connect(server_name, tcp).await?;
75
76    info!("TLS handshake complete, upgrading to WebSocket...");
77    let ws_stream = ws::connect_ws(tls_stream, server_url).await?;
78    let (mut ws_sink, mut ws_source) = ws_stream.split();
79
80    let channels = Arc::new(ChannelMap::new(1)); // Client uses odd IDs
81    let tunnel_targets: Arc<RwLock<HashMap<u32, String>>> = Arc::new(RwLock::new(HashMap::new()));
82    // Track spawned tunnel tasks for lifecycle management (StopTunnel)
83    let tunnel_handles: Arc<RwLock<HashMap<u32, tokio::task::AbortHandle>>> =
84        Arc::new(RwLock::new(HashMap::new()));
85    let (ws_tx, mut ws_rx) = mpsc::channel::<Message>(256);
86
87    info!("Connected to C2 server");
88
89    // Writer task
90    let writer_handle = tokio::spawn(async move {
91        while let Some(msg) = ws_rx.recv().await {
92            if ws_sink.send(msg).await.is_err() {
93                break;
94            }
95        }
96    });
97
98    // Reader loop
99    while let Some(msg_result) = ws_source.next().await {
100        let msg = match msg_result {
101            Ok(m) => m,
102            Err(e) => {
103                warn!("WebSocket read error: {e}");
104                break;
105            }
106        };
107
108        match msg {
109            Message::Text(text) => {
110                match serde_json::from_str::<WsTextMessage>(&text) {
111                    Ok(WsTextMessage::Command(cmd)) => {
112                        handle_command(
113                            cmd,
114                            &channels,
115                            &tunnel_targets,
116                            &tunnel_handles,
117                            ws_tx.clone(),
118                        )
119                        .await;
120                    }
121                    Ok(WsTextMessage::Control(ctrl)) => {
122                        handle_client_control(
123                            ctrl,
124                            &channels,
125                            &tunnel_targets,
126                            ws_tx.clone(),
127                        )
128                        .await;
129                    }
130                    Ok(WsTextMessage::Response(_)) => {
131                        warn!("Unexpected response from server");
132                    }
133                    Err(e) => {
134                        warn!("Failed to parse message: {e}");
135                    }
136                }
137            }
138            Message::Binary(data) => {
139                if let Some((channel_id, payload)) = parse_tunnel_data(&data) {
140                    if !channels.send(channel_id, Bytes::copy_from_slice(payload)).await {
141                        warn!("Data for unknown channel {channel_id}");
142                    }
143                }
144            }
145            Message::Close(_) => break,
146            _ => {}
147        }
148    }
149
150    writer_handle.abort();
151
152    // Close all channels — drops senders so relay tasks exit immediately
153    channels.close_all().await;
154
155    // Clean up all tunnel tasks so ports are freed for reconnect
156    {
157        let mut handles = tunnel_handles.write().await;
158        for (tid, handle) in handles.drain() {
159            handle.abort();
160            info!("Aborted tunnel {tid} on disconnect");
161        }
162    }
163    tunnel_targets.write().await.clear();
164
165    Ok(())
166}
167
168/// Handle a command from the server.
169async fn handle_command(
170    cmd: Command,
171    channels: &Arc<ChannelMap>,
172    tunnel_targets: &Arc<RwLock<HashMap<u32, String>>>,
173    tunnel_handles: &Arc<RwLock<HashMap<u32, tokio::task::AbortHandle>>>,
174    ws_tx: mpsc::Sender<Message>,
175) {
176    match cmd {
177        Command::Socks { tunnel_id, port } => {
178            let addr = format!("127.0.0.1:{port}");
179            info!("Starting SOCKS5 listener on {addr} (tunnel {tunnel_id})");
180
181            match Socks5Listener::bind(&addr, tunnel_id).await {
182                Ok(socks_listener) => {
183                    send_response(
184                        &ws_tx,
185                        CommandResponse::SocksReady { tunnel_id },
186                    )
187                    .await;
188
189                    let channels = channels.clone();
190                    let ws_tx = ws_tx.clone();
191                    let handle = tokio::spawn(async move {
192                        socks_accept_loop(socks_listener, channels, ws_tx).await;
193                    });
194                    tunnel_handles
195                        .write()
196                        .await
197                        .insert(tunnel_id, handle.abort_handle());
198                }
199                Err(e) => {
200                    warn!("Failed to bind SOCKS5: {e}");
201                    send_response(
202                        &ws_tx,
203                        CommandResponse::Error {
204                            tunnel_id: Some(tunnel_id),
205                            message: format!("Failed to bind: {e}"),
206                        },
207                    )
208                    .await;
209                }
210            }
211        }
212        Command::ReverseTunnel {
213            tunnel_id,
214            remote_port,
215            local_target,
216        } => {
217            info!(
218                "Reverse tunnel {tunnel_id}: validating {local_target} \
219                 (remote_port={remote_port})"
220            );
221
222            // Validate target is reachable before acknowledging
223            match tokio::time::timeout(
224                std::time::Duration::from_secs(5),
225                TcpStream::connect(&local_target),
226            )
227            .await
228            {
229                Ok(Ok(_tcp)) => {
230                    // Target reachable — register and confirm
231                    tunnel_targets
232                        .write()
233                        .await
234                        .insert(tunnel_id, local_target);
235                    send_response(
236                        &ws_tx,
237                        CommandResponse::ReverseTunnelReady { tunnel_id },
238                    )
239                    .await;
240                }
241                Ok(Err(e)) => {
242                    warn!("Reverse tunnel {tunnel_id}: target {local_target} unreachable: {e}");
243                    send_response(
244                        &ws_tx,
245                        CommandResponse::Error {
246                            tunnel_id: Some(tunnel_id),
247                            message: format!("Target unreachable: {e}"),
248                        },
249                    )
250                    .await;
251                }
252                Err(_) => {
253                    warn!("Reverse tunnel {tunnel_id}: target {local_target} connect timed out");
254                    send_response(
255                        &ws_tx,
256                        CommandResponse::Error {
257                            tunnel_id: Some(tunnel_id),
258                            message: "Target connect timed out".into(),
259                        },
260                    )
261                    .await;
262                }
263            }
264        }
265        Command::Ping { seq } => {
266            send_response(&ws_tx, CommandResponse::Pong { seq }).await;
267        }
268        Command::StopTunnel { tunnel_id } => {
269            tunnel_targets.write().await.remove(&tunnel_id);
270            // Abort the spawned listener/task for this tunnel
271            if let Some(handle) = tunnel_handles.write().await.remove(&tunnel_id) {
272                handle.abort();
273            }
274            // Close all active channels belonging to this tunnel
275            let closed = channels.close_tunnel(tunnel_id).await;
276            if !closed.is_empty() {
277                info!("Closed {} active channels for tunnel {tunnel_id}", closed.len());
278            }
279            info!("Tunnel {tunnel_id} stopped");
280            send_response(
281                &ws_tx,
282                CommandResponse::Ok {
283                    tunnel_id: Some(tunnel_id),
284                    message: Some("Tunnel stopped".into()),
285                },
286            )
287            .await;
288        }
289    }
290}
291
292/// Handle control messages from server on the client side.
293async fn handle_client_control(
294    ctrl: ControlMessage,
295    channels: &Arc<ChannelMap>,
296    tunnel_targets: &Arc<RwLock<HashMap<u32, String>>>,
297    ws_tx: mpsc::Sender<Message>,
298) {
299    match ctrl {
300        ControlMessage::ChannelOpen {
301            channel_id,
302            tunnel_id,
303            target: _,
304        } => {
305            // Validate: server-originated channel_id must be even and not already in use
306            if channel_id % 2 != 0 {
307                warn!("Rejected ChannelOpen with odd channel_id {channel_id} from server");
308                return;
309            }
310            if channels.has(channel_id).await {
311                warn!("Rejected ChannelOpen with duplicate channel_id {channel_id}");
312                let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
313                if let Ok(json) = serde_json::to_string(&close) {
314                    let _ = ws_tx.send(Message::Text(json)).await;
315                }
316                return;
317            }
318
319            // Server opened a channel for a reverse tunnel — connect to local target
320            let targets = tunnel_targets.read().await;
321            let local_target = match targets.get(&tunnel_id) {
322                Some(t) => t.clone(),
323                None => {
324                    warn!("ChannelOpen for unknown tunnel {tunnel_id}");
325                    return;
326                }
327            };
328            drop(targets);
329
330            info!("Reverse channel {channel_id} -> connecting to {local_target}");
331
332            // Reserve channel BEFORE async connect so ChannelClose can cancel it
333            let (data_tx, data_rx) = mpsc::channel::<Bytes>(256);
334            channels.insert_with_tunnel(channel_id, tunnel_id, data_tx).await;
335
336            let channels = channels.clone();
337            tokio::spawn(async move {
338                // Timeout connect at 8s (< server's 10s readiness timeout)
339                let connect_result = tokio::time::timeout(
340                    std::time::Duration::from_secs(8),
341                    TcpStream::connect(&local_target),
342                )
343                .await;
344                match connect_result {
345                    Ok(Ok(tcp)) => {
346                        // Re-check channel still exists (not revoked during connect)
347                        if !channels.has(channel_id).await {
348                            warn!("Channel {channel_id} revoked during reverse connect, dropping");
349                            drop(tcp);
350                            return;
351                        }
352
353                        let ready = WsTextMessage::Control(ControlMessage::ChannelReady {
354                            channel_id,
355                        });
356                        if let Ok(json) = serde_json::to_string(&ready) {
357                            let _ = ws_tx.send(Message::Text(json)).await;
358                        }
359
360                        relay_tcp_ws(tcp, channel_id, data_rx, channels, ws_tx).await;
361                    }
362                    Ok(Err(e)) => {
363                        warn!("Failed to connect to {local_target}: {e}");
364                        channels.remove(channel_id).await;
365                        let close = WsTextMessage::Control(ControlMessage::ChannelClose {
366                            channel_id,
367                        });
368                        if let Ok(json) = serde_json::to_string(&close) {
369                            let _ = ws_tx.send(Message::Text(json)).await;
370                        }
371                    }
372                    Err(_) => {
373                        warn!("Connect to {local_target} timed out for channel {channel_id}");
374                        channels.remove(channel_id).await;
375                        let close = WsTextMessage::Control(ControlMessage::ChannelClose {
376                            channel_id,
377                        });
378                        if let Ok(json) = serde_json::to_string(&close) {
379                            let _ = ws_tx.send(Message::Text(json)).await;
380                        }
381                    }
382                }
383            });
384        }
385        ControlMessage::ChannelReady { channel_id } => {
386            channels.signal_ready(channel_id).await;
387            info!("Channel {channel_id} ready");
388        }
389        ControlMessage::ChannelClose { channel_id } => {
390            channels.remove(channel_id).await;
391            info!("Channel {channel_id} closed by server");
392        }
393    }
394}
395
396/// SOCKS5 accept loop: accepts raw TCP connections concurrently, then performs
397/// handshake per-connection with a timeout so one stalled client cannot block others.
398async fn socks_accept_loop(
399    listener: Socks5Listener,
400    channels: Arc<ChannelMap>,
401    ws_tx: mpsc::Sender<Message>,
402) {
403    let tunnel_id = listener.tunnel_id;
404    loop {
405        match listener.accept_raw().await {
406            Ok(raw_stream) => {
407                let channels = channels.clone();
408                let ws_tx = ws_tx.clone();
409                tokio::spawn(async move {
410                    handle_socks_connection(raw_stream, tunnel_id, channels, ws_tx).await;
411                });
412            }
413            Err(e) => {
414                warn!("SOCKS5 accept error: {e}");
415            }
416        }
417    }
418}
419
420/// Handle a single SOCKS5 connection: handshake (with timeout) -> ChannelOpen -> relay.
421async fn handle_socks_connection(
422    raw_stream: TcpStream,
423    tunnel_id: u32,
424    channels: Arc<ChannelMap>,
425    ws_tx: mpsc::Sender<Message>,
426) {
427    // SOCKS handshake with 5s timeout to prevent one stalled client from blocking
428    let handshake = tokio::time::timeout(
429        std::time::Duration::from_secs(5),
430        crate::socks5::socks5_handshake(raw_stream),
431    )
432    .await;
433    let (mut tcp_stream, req) = match handshake {
434        Ok(Ok(result)) => result,
435        Ok(Err(e)) => {
436            warn!("SOCKS5 handshake failed: {e}");
437            return;
438        }
439        Err(_) => {
440            warn!("SOCKS5 handshake timed out");
441            return;
442        }
443    };
444
445    let channel_id = channels.alloc_id();
446    info!(
447        "SOCKS5 connection -> {}, channel {channel_id}",
448        req.target_addr
449    );
450
451    let (data_tx, data_rx) = mpsc::channel::<Bytes>(256);
452    channels
453        .insert_with_tunnel(channel_id, tunnel_id, data_tx)
454        .await;
455
456    let ready_rx = channels.wait_ready(channel_id).await;
457
458    let open = WsTextMessage::Control(ControlMessage::ChannelOpen {
459        channel_id,
460        tunnel_id,
461        target: Some(req.target_addr),
462    });
463    if let Ok(json) = serde_json::to_string(&open) {
464        if ws_tx.send(Message::Text(json)).await.is_err() {
465            channels.remove(channel_id).await;
466            return;
467        }
468    }
469
470    // Wait for server to confirm the remote connection is ready (bounded timeout)
471    let ready_result = tokio::time::timeout(
472        std::time::Duration::from_secs(10),
473        ready_rx,
474    )
475    .await;
476    if ready_result.is_err() || ready_result.unwrap().is_err() {
477        warn!("Channel {channel_id} ready timeout or signal dropped");
478        channels.remove(channel_id).await;
479        let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
480        if let Ok(json) = serde_json::to_string(&close) {
481            let _ = ws_tx.send(Message::Text(json)).await;
482        }
483        return;
484    }
485
486    if crate::socks5::send_socks5_success(&mut tcp_stream)
487        .await
488        .is_err()
489    {
490        warn!("Failed to send SOCKS5 success for channel {channel_id}");
491        channels.remove(channel_id).await;
492        let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
493        if let Ok(json) = serde_json::to_string(&close) {
494            let _ = ws_tx.send(Message::Text(json)).await;
495        }
496        return;
497    }
498
499    relay_tcp_ws(tcp_stream, channel_id, data_rx, channels, ws_tx).await;
500}
501
502/// Bidirectional relay between a TCP stream and a WS channel.
503/// `data_rx` must already be registered in `channels` before calling this.
504async fn relay_tcp_ws(
505    tcp: TcpStream,
506    channel_id: u32,
507    mut data_rx: mpsc::Receiver<Bytes>,
508    channels: Arc<ChannelMap>,
509    ws_tx: mpsc::Sender<Message>,
510) {
511    let (mut tcp_read, mut tcp_write) = tcp.into_split();
512
513    // WS -> TCP
514    let ws2tcp = tokio::spawn(async move {
515        while let Some(data) = data_rx.recv().await {
516            if tcp_write.write_all(&data).await.is_err() {
517                break;
518            }
519        }
520        let _ = tcp_write.shutdown().await;
521    });
522
523    // TCP -> WS
524    let ws_tx_data = ws_tx.clone();
525    let tcp2ws = tokio::spawn(async move {
526        let mut buf = vec![0u8; 8192];
527        loop {
528            match tcp_read.read(&mut buf).await {
529                Ok(0) | Err(_) => break,
530                Ok(n) => {
531                    let frame = frame_tunnel_data(channel_id, &buf[..n]);
532                    if ws_tx_data.send(Message::Binary(frame)).await.is_err() {
533                        break;
534                    }
535                }
536            }
537        }
538    });
539
540    // When first direction finishes: notify peer, give grace period to drain,
541    // then remove channel routing and force-abort.
542    let ws2tcp_abort = ws2tcp.abort_handle();
543    let tcp2ws_abort = tcp2ws.abort_handle();
544
545    tokio::select! {
546        _ = ws2tcp => {}
547        _ = tcp2ws => {}
548    }
549
550    // Notify peer (channel stays registered for drain)
551    let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
552    if let Ok(json) = serde_json::to_string(&close) {
553        let _ = ws_tx.send(Message::Text(json)).await;
554    }
555
556    // Grace period: channel stays registered so in-flight frames are delivered
557    tokio::time::sleep(std::time::Duration::from_secs(5)).await;
558
559    // Now remove and force-abort
560    channels.remove(channel_id).await;
561    ws2tcp_abort.abort();
562    tcp2ws_abort.abort();
563}
564
565async fn send_response(ws_tx: &mpsc::Sender<Message>, resp: CommandResponse) {
566    let msg = WsTextMessage::Response(resp);
567    if let Ok(json) = serde_json::to_string(&msg) {
568        let _ = ws_tx.send(Message::Text(json)).await;
569    }
570}
571
572fn parse_wss_url(url: &str) -> Result<(String, u16)> {
573    let stripped = url
574        .strip_prefix("wss://")
575        .ok_or_else(|| ProxyError::Other("Server URL must start with wss://".into()))?;
576    let (host, port) = if let Some((h, p)) = stripped.rsplit_once(':') {
577        let port: u16 = p
578            .parse()
579            .map_err(|_| ProxyError::Other(format!("Invalid port in URL: {p}")))?;
580        (h.to_string(), port)
581    } else {
582        (stripped.to_string(), 443)
583    };
584    Ok((host, port))
585}
586
587fn pem_to_cert_der(pem: &str) -> Result<CertificateDer<'static>> {
588    let mut reader = std::io::BufReader::new(pem.as_bytes());
589    let certs = rustls_pemfile::certs(&mut reader)
590        .collect::<std::result::Result<Vec<_>, _>>()?;
591    certs
592        .into_iter()
593        .next()
594        .ok_or_else(|| ProxyError::Other("No certificate found in PEM".into()))
595}
596
597fn pem_to_key_der(pem: &str) -> Result<PrivatePkcs8KeyDer<'static>> {
598    let mut reader = std::io::BufReader::new(pem.as_bytes());
599    let keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
600        .collect::<std::result::Result<Vec<_>, _>>()?;
601    keys.into_iter()
602        .next()
603        .ok_or_else(|| ProxyError::Other("No PKCS8 private key found in PEM".into()))
604}