1use crate::cert::CertificateAuthority;
2use crate::error::{ProxyError, Result};
3use crate::protocol::{
4 frame_tunnel_data, parse_tunnel_data, Command, CommandResponse, ControlMessage, WsTextMessage,
5};
6use crate::ws::{self, ChannelMap};
7use bytes::Bytes;
8use futures_util::{SinkExt, StreamExt};
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicU32, Ordering};
11use std::sync::Arc;
12use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
13use tokio::net::{TcpListener, TcpStream};
14use tokio::sync::mpsc;
15use tokio_rustls::TlsAcceptor;
16use tokio_tungstenite::tungstenite::Message;
17use tracing::{error, info, warn};
18
19struct ClientHandle {
20 cn: String,
21 session_id: u64,
22 ws_tx: mpsc::Sender<Message>,
23 shutdown_tx: tokio::sync::watch::Sender<bool>,
24 channels: Arc<ChannelMap>,
25 pending_reverse: Arc<tokio::sync::RwLock<HashMap<u32, u16>>>,
27 pending_socks: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>>,
29 authorized_tunnels: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>>,
31 reverse_listeners: Arc<tokio::sync::RwLock<HashMap<u32, tokio::task::AbortHandle>>>,
33}
34
35struct ServerState {
36 clients: Arc<tokio::sync::RwLock<HashMap<String, ClientHandle>>>,
37 next_tunnel_id: AtomicU32,
38 next_session_id: std::sync::atomic::AtomicU64,
39}
40
41impl ServerState {
42 fn alloc_tunnel_id(&self) -> u32 {
43 self.next_tunnel_id.fetch_add(1, Ordering::Relaxed)
44 }
45}
46
47pub async fn run(
49 host: &str,
50 port: u16,
51 server_name: &str,
52 ca: Arc<CertificateAuthority>,
53) -> Result<()> {
54 let listen_addr = format!("{host}:{port}");
55
56 let server_ck = ca.generate_server_cert(server_name)?;
58 let ca_cert_der = ca.ca_cert_der();
59 let tls_config =
60 crate::tls::make_mtls_server_config(server_ck.cert_der, server_ck.key_der, ca_cert_der)?;
61 let acceptor = TlsAcceptor::from(tls_config);
62
63 let state = Arc::new(ServerState {
64 clients: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
65 next_session_id: std::sync::atomic::AtomicU64::new(1),
66 next_tunnel_id: AtomicU32::new(1),
67 });
68
69 let listener = TcpListener::bind(&listen_addr).await?;
70 info!(
71 "C2 server listening on {listen_addr} (cert name: {server_name}, mTLS required)"
72 );
73
74 let state_stdin = state.clone();
75 tokio::spawn(async move {
76 if let Err(e) = stdin_command_loop(state_stdin).await {
77 error!("Stdin command loop error: {e}");
78 }
79 });
80
81 let handshake_semaphore = Arc::new(tokio::sync::Semaphore::new(64));
83
84 loop {
85 let (stream, peer) = listener.accept().await?;
86 let acceptor = acceptor.clone();
87 let state = state.clone();
88 let sem = handshake_semaphore.clone();
89
90 tokio::spawn(async move {
91 let permit = match sem.try_acquire() {
93 Ok(p) => p,
94 Err(_) => {
95 warn!("Rejecting {peer}: too many concurrent handshakes");
96 return;
97 }
98 };
99
100 let handshake_result = perform_handshake(stream, peer, &acceptor).await;
102 drop(permit); match handshake_result {
105 Ok((ws_stream, fingerprint, cn)) => {
106 let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
107 match run_session(ws_stream, peer, fingerprint, cn, state, shutdown_tx, shutdown_rx).await {
108 Ok(()) => info!("Client {peer} disconnected"),
109 Err(e) => warn!("Client {peer} error: {e}"),
110 }
111 }
112 Err(e) => warn!("Client {peer} handshake error: {e}"),
113 }
114 });
115 }
116}
117
118async fn perform_handshake(
120 stream: TcpStream,
121 peer: std::net::SocketAddr,
122 acceptor: &TlsAcceptor,
123) -> Result<(ws::ServerWsStream, String, String)> {
124 let tls_stream = tokio::time::timeout(
125 std::time::Duration::from_secs(15),
126 acceptor.accept(stream),
127 )
128 .await
129 .map_err(|_| ProxyError::Other(format!("TLS handshake timed out for {peer}")))?
130 .map_err(|e| ProxyError::Other(format!("TLS handshake failed for {peer}: {e}")))?;
131
132 let (fingerprint, cn) = extract_client_identity(&tls_stream);
133 info!("Client authenticated: {cn} [{fingerprint}] ({peer})");
134
135 let ws_stream = tokio::time::timeout(
136 std::time::Duration::from_secs(10),
137 ws::accept_ws(tls_stream),
138 )
139 .await
140 .map_err(|_| ProxyError::Other(format!("WebSocket upgrade timed out for {peer}")))?
141 ?;
142
143 Ok((ws_stream, fingerprint, cn))
144}
145
146async fn run_session(
148 ws_stream: ws::ServerWsStream,
149 _peer: std::net::SocketAddr,
150 fingerprint: String,
151 cn: String,
152 state: Arc<ServerState>,
153 shutdown_tx: tokio::sync::watch::Sender<bool>,
154 mut shutdown_rx: tokio::sync::watch::Receiver<bool>,
155) -> Result<()> {
156 let client_label = format!("{cn} [{fingerprint}]");
157 let (mut ws_sink, mut ws_source) = ws_stream.split();
158
159 let channels = Arc::new(ChannelMap::new(2)); let (ws_tx, mut ws_rx) = mpsc::channel::<Message>(256);
161
162 let reverse_listeners: Arc<tokio::sync::RwLock<HashMap<u32, tokio::task::AbortHandle>>> =
163 Arc::new(tokio::sync::RwLock::new(HashMap::new()));
164 let pending_reverse: Arc<tokio::sync::RwLock<HashMap<u32, u16>>> =
165 Arc::new(tokio::sync::RwLock::new(HashMap::new()));
166 let authorized_tunnels: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>> =
167 Arc::new(tokio::sync::RwLock::new(std::collections::HashSet::new()));
168 let pending_socks: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>> =
169 Arc::new(tokio::sync::RwLock::new(std::collections::HashSet::new()));
170
171 let session_id = state
172 .next_session_id
173 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
174
175 {
177 let mut clients = state.clients.write().await;
178 if let Some(old) = clients.remove(&fingerprint) {
179 warn!("[{client_label}] Evicting stale session for reconnect");
180 old.authorized_tunnels.write().await.clear();
182 old.pending_socks.write().await.clear();
183 old.pending_reverse.write().await.clear();
184 old.channels.close_all().await;
186 for handle in old.reverse_listeners.write().await.drain() {
188 handle.1.abort();
189 }
190 let _ = old.shutdown_tx.send(true);
192 drop(old);
193 }
194 clients.insert(
195 fingerprint.clone(),
196 ClientHandle {
197 cn: cn.clone(),
198 session_id,
199 ws_tx: ws_tx.clone(),
200 shutdown_tx,
201 channels: channels.clone(),
202 pending_reverse: pending_reverse.clone(),
203 pending_socks: pending_socks.clone(),
204 authorized_tunnels: authorized_tunnels.clone(),
205 reverse_listeners: reverse_listeners.clone(),
206 },
207 );
208 }
209
210 let label_writer = client_label.clone();
212 let writer_handle = tokio::spawn(async move {
213 while let Some(msg) = ws_rx.recv().await {
214 if ws_sink.send(msg).await.is_err() {
215 info!("[{label_writer}] WS write closed");
216 break;
217 }
218 }
219 });
220
221 let channels_reader = channels.clone();
223 let ws_tx_reader = ws_tx.clone();
224 let label_reader = client_label.clone();
225 let tunnel_state = ClientTunnelState {
226 pending_reverse: pending_reverse.clone(),
227 pending_socks: pending_socks.clone(),
228 authorized_tunnels: authorized_tunnels.clone(),
229 reverse_listeners: reverse_listeners.clone(),
230 };
231 loop {
232 let msg_result = tokio::select! {
233 msg = ws_source.next() => msg,
234 _ = shutdown_rx.changed() => {
235 info!("[{label_reader}] Session shutdown signal received");
236 break;
237 }
238 };
239 let msg = match msg_result {
240 Some(Ok(m)) => m,
241 Some(Err(e)) => {
242 warn!("[{label_reader}] WebSocket read error: {e}");
243 break;
244 }
245 None => break,
246 };
247
248 match msg {
249 Message::Text(text) => match serde_json::from_str::<WsTextMessage>(&text) {
250 Ok(WsTextMessage::Response(resp)) => {
251 handle_response(
252 &label_reader,
253 &resp,
254 &tunnel_state,
255 &channels_reader,
256 ws_tx_reader.clone(),
257 )
258 .await;
259 }
260 Ok(WsTextMessage::Control(ctrl)) => {
261 handle_server_control(
262 &label_reader,
263 ctrl,
264 channels_reader.clone(),
265 &tunnel_state.authorized_tunnels,
266 ws_tx_reader.clone(),
267 )
268 .await;
269 }
270 Ok(WsTextMessage::Command(_)) => {
271 warn!("[{label_reader}] Unexpected command from client");
272 }
273 Err(e) => {
274 warn!("[{label_reader}] Failed to parse message: {e}");
275 }
276 },
277 Message::Binary(data) => {
278 if let Some((channel_id, payload)) = parse_tunnel_data(&data) {
279 if !channels_reader
280 .send(channel_id, Bytes::copy_from_slice(payload))
281 .await
282 {
283 warn!("[{label_reader}] Data for unknown channel {channel_id}");
284 }
285 }
286 }
287 Message::Close(_) => break,
288 _ => {}
289 }
290 }
291
292 writer_handle.abort();
294 channels.close_all().await;
295 {
296 let listeners = reverse_listeners.read().await;
297 for handle in listeners.values() {
298 handle.abort();
299 }
300 }
301 {
303 let mut clients = state.clients.write().await;
304 if let Some(existing) = clients.get(&fingerprint) {
305 if existing.session_id == session_id {
306 clients.remove(&fingerprint);
307 }
308 }
309 }
310 info!("[{client_label}] Client removed");
311
312 Ok(())
313}
314
315struct ClientTunnelState {
317 pending_reverse: Arc<tokio::sync::RwLock<HashMap<u32, u16>>>,
318 pending_socks: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>>,
319 authorized_tunnels: Arc<tokio::sync::RwLock<std::collections::HashSet<u32>>>,
320 reverse_listeners: Arc<tokio::sync::RwLock<HashMap<u32, tokio::task::AbortHandle>>>,
321}
322
323async fn handle_response(
325 label: &str,
326 resp: &CommandResponse,
327 ts: &ClientTunnelState,
328 channels: &Arc<ChannelMap>,
329 ws_tx: mpsc::Sender<Message>,
330) {
331 match resp {
332 CommandResponse::SocksReady { tunnel_id: tid } => {
333 if ts.pending_socks.write().await.remove(tid) {
334 ts.authorized_tunnels.write().await.insert(*tid);
335 info!("[{label}] SOCKS tunnel {tid} authorized via SocksReady");
336 } else {
337 warn!("[{label}] Unexpected SocksReady for tunnel {tid}");
338 }
339 }
340 CommandResponse::ReverseTunnelReady { tunnel_id: tid } => {
341 let remote_port = ts.pending_reverse.write().await.remove(tid);
343 if let Some(port) = remote_port {
344 info!("[{label}] Starting reverse listener on 127.0.0.1:{port} (tunnel {tid})");
345 let channels = channels.clone();
346 let tid = *tid;
347 let label = label.to_string();
348 let handle = tokio::spawn(async move {
349 if let Err(e) =
350 reverse_listen_loop(port, tid, channels, ws_tx, &label).await
351 {
352 warn!("[{label}] Reverse listener error: {e}");
353 }
354 });
355 ts.reverse_listeners
356 .write()
357 .await
358 .insert(tid, handle.abort_handle());
359 } else {
360 info!("[{label}] Ok response: tunnel_id={tid}");
361 }
362 }
363 CommandResponse::Ok { .. } => {
364 info!("[{label}] Ok response");
365 }
366 CommandResponse::Error { tunnel_id, message } => {
367 if let Some(tid) = tunnel_id {
369 if ts.pending_socks.write().await.remove(tid) {
370 ts.authorized_tunnels.write().await.remove(tid);
371 info!("[{label}] Revoked failed SOCKS tunnel {tid}");
372 }
373 ts.pending_reverse.write().await.remove(tid);
374 }
375 warn!("[{label}] Error response: {message}");
376 }
377 CommandResponse::Pong { seq } => {
378 info!("[{label}] Pong seq={seq}");
379 }
380 }
381}
382
383async fn reverse_listen_loop(
385 port: u16,
386 tunnel_id: u32,
387 channels: Arc<ChannelMap>,
388 ws_tx: mpsc::Sender<Message>,
389 label: &str,
390) -> Result<()> {
391 let listener = TcpListener::bind(format!("127.0.0.1:{port}")).await?;
392 info!("[{label}] Reverse tunnel {tunnel_id} listening on 127.0.0.1:{port}");
393
394 loop {
395 let (tcp, peer) = listener.accept().await?;
396 let channel_id = channels.alloc_id();
397 info!("[{label}] Reverse connection from {peer}, channel {channel_id}");
398
399 let (data_tx, data_rx) = mpsc::channel::<Bytes>(256);
402 channels.insert_with_tunnel(channel_id, tunnel_id, data_tx).await;
403 let ready_rx = channels.wait_ready(channel_id).await;
404
405 let open = WsTextMessage::Control(ControlMessage::ChannelOpen {
406 channel_id,
407 tunnel_id,
408 target: None,
409 });
410 if let Ok(json) = serde_json::to_string(&open) {
411 if ws_tx.send(Message::Text(json)).await.is_err() {
412 break Ok(());
413 }
414 }
415
416 let channels = channels.clone();
417 let ws_tx = ws_tx.clone();
418 let label = label.to_string();
419 tokio::spawn(async move {
420 let ready_result = tokio::time::timeout(
422 std::time::Duration::from_secs(10),
423 ready_rx,
424 )
425 .await;
426 if ready_result.is_err() || ready_result.unwrap().is_err() {
427 warn!("[{label}] Channel {channel_id} ready timeout or signal dropped");
428 channels.remove(channel_id).await;
429 let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
430 if let Ok(json) = serde_json::to_string(&close) {
431 let _ = ws_tx.send(Message::Text(json)).await;
432 }
433 return;
434 }
435 relay_tcp_ws(tcp, channel_id, data_rx, channels, ws_tx, &label).await;
436 });
437 }
438}
439
440async fn handle_server_control(
442 label: &str,
443 ctrl: ControlMessage,
444 channels: Arc<ChannelMap>,
445 authorized_tunnels: &tokio::sync::RwLock<std::collections::HashSet<u32>>,
446 ws_tx: mpsc::Sender<Message>,
447) {
448 match ctrl {
449 ControlMessage::ChannelOpen {
450 channel_id,
451 tunnel_id,
452 target,
453 } => {
454 if channel_id % 2 == 0 {
456 warn!("[{label}] Rejected ChannelOpen with even channel_id {channel_id}");
457 return;
458 }
459 if channels.has(channel_id).await {
460 warn!("[{label}] Rejected ChannelOpen with duplicate channel_id {channel_id}");
461 let close =
462 WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
463 if let Ok(json) = serde_json::to_string(&close) {
464 let _ = ws_tx.send(Message::Text(json)).await;
465 }
466 return;
467 }
468
469 if !authorized_tunnels.read().await.contains(&tunnel_id) {
471 warn!(
472 "[{label}] Rejected unsolicited ChannelOpen for tunnel {tunnel_id}, channel {channel_id}"
473 );
474 let close =
475 WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
476 if let Ok(json) = serde_json::to_string(&close) {
477 let _ = ws_tx.send(Message::Text(json)).await;
478 }
479 return;
480 }
481
482 let target = match target {
483 Some(t) => t,
484 None => {
485 warn!("[{label}] ChannelOpen without target");
486 return;
487 }
488 };
489
490 let (data_tx, data_rx) = mpsc::channel::<Bytes>(256);
493 channels
494 .insert_with_tunnel(channel_id, tunnel_id, data_tx)
495 .await;
496
497 info!("[{label}] Channel {channel_id} -> connecting to {target}");
498
499 let channels = channels.clone();
500 let label = label.to_string();
501 tokio::spawn(async move {
502 let connect_result = tokio::time::timeout(
504 std::time::Duration::from_secs(10),
505 TcpStream::connect(&target),
506 )
507 .await;
508 match connect_result {
509 Ok(Ok(tcp)) => {
510 if !channels.has(channel_id).await {
512 warn!("[{label}] Channel {channel_id} revoked during connect, dropping");
513 drop(tcp);
514 return;
515 }
516
517 info!("[{label}] Channel {channel_id} connected to {target}");
518
519 let ready = WsTextMessage::Control(ControlMessage::ChannelReady {
520 channel_id,
521 });
522 if let Ok(json) = serde_json::to_string(&ready) {
523 let _ = ws_tx.send(Message::Text(json)).await;
524 }
525
526 relay_tcp_ws(tcp, channel_id, data_rx, channels, ws_tx.clone(), &label)
527 .await;
528 }
529 Ok(Err(e)) => {
530 warn!("[{label}] Failed to connect to {target}: {e}");
531 channels.remove(channel_id).await;
532 let close =
533 WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
534 if let Ok(json) = serde_json::to_string(&close) {
535 let _ = ws_tx.send(Message::Text(json)).await;
536 }
537 }
538 Err(_) => {
539 warn!("[{label}] Connect to {target} timed out for channel {channel_id}");
540 channels.remove(channel_id).await;
541 let close =
542 WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
543 if let Ok(json) = serde_json::to_string(&close) {
544 let _ = ws_tx.send(Message::Text(json)).await;
545 }
546 }
547 }
548 });
549 }
550 ControlMessage::ChannelReady { channel_id } => {
551 channels.signal_ready(channel_id).await;
552 info!("[{label}] Channel {channel_id} ready");
553 }
554 ControlMessage::ChannelClose { channel_id } => {
555 channels.remove(channel_id).await;
556 info!("[{label}] Channel {channel_id} closed");
557 }
558 }
559}
560
561async fn relay_tcp_ws(
564 tcp: TcpStream,
565 channel_id: u32,
566 mut data_rx: mpsc::Receiver<Bytes>,
567 channels: Arc<ChannelMap>,
568 ws_tx: mpsc::Sender<Message>,
569 label: &str,
570) {
571 let (mut tcp_read, mut tcp_write) = tcp.into_split();
572
573 let ws2tcp = tokio::spawn(async move {
574 while let Some(data) = data_rx.recv().await {
575 if tcp_write.write_all(&data).await.is_err() {
576 break;
577 }
578 }
579 let _ = tcp_write.shutdown().await;
580 });
581
582 let ws_tx_data = ws_tx.clone();
583 let tcp2ws = tokio::spawn(async move {
584 let mut buf = vec![0u8; 8192];
585 loop {
586 match tcp_read.read(&mut buf).await {
587 Ok(0) | Err(_) => break,
588 Ok(n) => {
589 let frame = frame_tunnel_data(channel_id, &buf[..n]);
590 if ws_tx_data.send(Message::Binary(frame)).await.is_err() {
591 break;
592 }
593 }
594 }
595 }
596 });
597
598 let ws2tcp_abort = ws2tcp.abort_handle();
601 let tcp2ws_abort = tcp2ws.abort_handle();
602
603 tokio::select! {
604 _ = ws2tcp => {}
605 _ = tcp2ws => {}
606 }
607
608 let close = WsTextMessage::Control(ControlMessage::ChannelClose { channel_id });
610 if let Ok(json) = serde_json::to_string(&close) {
611 let _ = ws_tx.send(Message::Text(json)).await;
612 }
613
614 tokio::time::sleep(std::time::Duration::from_secs(5)).await;
616
617 channels.remove(channel_id).await;
619 ws2tcp_abort.abort();
620 tcp2ws_abort.abort();
621 info!("[{label}] Channel {channel_id} closed");
622}
623
624fn extract_client_identity(
627 tls_stream: &tokio_rustls::server::TlsStream<TcpStream>,
628) -> (String, String) {
629 let (_, server_conn) = tls_stream.get_ref();
630 let certs = server_conn.peer_certificates().unwrap_or_default();
631 let cert_der = match certs.first() {
632 Some(c) => c.as_ref(),
633 None => return ("unknown".into(), "unknown".into()),
634 };
635
636 let digest = ring::digest::digest(&ring::digest::SHA256, cert_der);
638 let fingerprint: String = digest.as_ref().iter().map(|b| format!("{b:02x}")).collect();
639
640 let cn = extract_cn_from_der(cert_der).unwrap_or_else(|| "unknown".into());
641 (fingerprint, cn)
642}
643
/// Best-effort extraction of the last commonName (OID 2.5.4.3) value in a DER certificate.
///
/// This is a byte scan, not a full X.509 parser. It only accepts an attribute whose OID is
/// encoded as the complete TLV `06 03 55 04 03`. The value must follow with a string tag and a
/// short-form or one-byte long-form (`0x81 NN`) DER length, and its bytes must be valid UTF-8.
/// Requiring the OID header avoids false hits on stray `55 04 03` bytes inside key or
/// signature material. In an X.509 TBSCertificate the issuer Name precedes the subject Name,
/// so the last match is normally the subject CN.
///
/// Returns `None` when no acceptable commonName is found.
fn extract_cn_from_der(der: &[u8]) -> Option<String> {
    // OBJECT IDENTIFIER tag 0x06, length 3, then the encoded arcs 2.5.4.3.
    const CN_OID_TLV: [u8; 5] = [0x06, 0x03, 0x55, 0x04, 0x03];
    // UTF8String, PrintableString, TeletexString, IA5String. The content is accepted only if
    // it decodes as UTF-8.
    const TEXT_TAGS: [u8; 4] = [0x0C, 0x13, 0x14, 0x16];

    let mut last_cn = None;
    for (i, window) in der.windows(CN_OID_TLV.len()).enumerate() {
        if window != &CN_OID_TLV[..] {
            continue;
        }
        let tag_at = i + CN_OID_TLV.len();
        match der.get(tag_at) {
            Some(tag) if TEXT_TAGS.contains(tag) => {}
            _ => continue,
        }
        let (len, str_start) = match der.get(tag_at + 1) {
            Some(&short) if short < 0x80 => (usize::from(short), tag_at + 2),
            // Values of 128..=255 bytes use the one-byte long form.
            Some(&0x81) => match der.get(tag_at + 2) {
                Some(&long) => (usize::from(long), tag_at + 3),
                None => continue,
            },
            _ => continue,
        };
        if let Some(value) = der.get(str_start..str_start + len) {
            if let Ok(s) = std::str::from_utf8(value) {
                last_cn = Some(s.to_owned());
            }
        }
    }
    last_cn
}
667
668async fn stdin_command_loop(state: Arc<ServerState>) -> Result<()> {
670 let stdin = tokio::io::stdin();
671 let reader = BufReader::new(stdin);
672 let mut lines = reader.lines();
673
674 while let Ok(Some(line)) = lines.next_line().await {
675 let line = line.trim().to_string();
676 if line.is_empty() {
677 continue;
678 }
679
680 let parts: Vec<&str> = line.split_whitespace().collect();
681 match parts.first().copied() {
682 Some("list") => {
683 let clients = state.clients.read().await;
684 if clients.is_empty() {
685 info!("No connected clients");
686 } else {
687 for (fp, handle) in clients.iter() {
688 info!(" - {} [{}]", handle.cn, fp);
689 }
690 }
691 }
692 Some("socks") if parts.len() == 3 => {
693 let cn = parts[1];
694 let port: u16 = match parts[2].parse() {
695 Ok(p) => p,
696 Err(_) => {
697 warn!("Invalid port: {}", parts[2]);
698 continue;
699 }
700 };
701 let tunnel_id = state.alloc_tunnel_id();
702 {
705 let clients = state.clients.read().await;
706 if let Some(client) = find_client_in_map(&clients, cn) {
707 client
709 .pending_socks
710 .write()
711 .await
712 .insert(tunnel_id);
713 }
714 }
715 send_command_to_client(
716 &state,
717 cn,
718 WsTextMessage::Command(Command::Socks { tunnel_id, port }),
719 )
720 .await;
721 }
722 Some("reverse") if parts.len() == 4 => {
723 let cn = parts[1];
724 let remote_port: u16 = match parts[2].parse() {
725 Ok(p) => p,
726 Err(_) => {
727 warn!("Invalid port: {}", parts[2]);
728 continue;
729 }
730 };
731 let local_target = parts[3].to_string();
732 let tunnel_id = state.alloc_tunnel_id();
733
734 send_command_to_client_with_reverse(
740 &state,
741 cn,
742 tunnel_id,
743 remote_port,
744 local_target,
745 )
746 .await;
747 }
748 Some("stop") if parts.len() == 3 => {
749 let cn = parts[1];
750 let tunnel_id: u32 = match parts[2].parse() {
751 Ok(id) => id,
752 Err(_) => {
753 warn!("Invalid tunnel ID: {}", parts[2]);
754 continue;
755 }
756 };
757 {
759 let clients = state.clients.read().await;
760 if let Some(client) = find_client_in_map(&clients, cn) {
761 client.pending_socks.write().await.remove(&tunnel_id);
763 client.pending_reverse.write().await.remove(&tunnel_id);
764 client.authorized_tunnels.write().await.remove(&tunnel_id);
765 if let Some(handle) =
766 client.reverse_listeners.write().await.remove(&tunnel_id)
767 {
768 handle.abort();
769 info!("Aborted reverse listener for tunnel {tunnel_id}");
770 }
771 let closed = client.channels.close_tunnel(tunnel_id).await;
772 if !closed.is_empty() {
773 info!("Closed {} server-side channels for tunnel {tunnel_id}", closed.len());
774 }
775 }
776 }
777 send_command_to_client(
778 &state,
779 cn,
780 WsTextMessage::Command(Command::StopTunnel { tunnel_id }),
781 )
782 .await;
783 }
784 Some("help") | Some("?") => {
785 info!("Commands:");
786 info!(" list - List connected clients");
787 info!(" socks <client_cn> <port> - Start SOCKS5 on client");
788 info!(" reverse <client_cn> <remote_port> <local_target> - Reverse tunnel");
789 info!(" stop <client_cn> <tunnel_id> - Stop a tunnel");
790 }
791 _ => {
792 warn!("Unknown command. Type 'help' for usage.");
793 }
794 }
795 }
796
797 Ok(())
798}
799
800async fn send_command_to_client(state: &ServerState, id: &str, msg: WsTextMessage) {
802 let ws_tx = {
803 let clients = state.clients.read().await;
804 match find_client_in_map(&clients, id) {
805 Some(client) => client.ws_tx.clone(),
806 None => return,
807 }
808 };
809 if let Ok(json) = serde_json::to_string(&msg) {
810 if ws_tx.send(Message::Text(json)).await.is_err() {
811 warn!("Failed to send to {id}");
812 } else {
813 info!("Sent command to {id}");
814 }
815 }
816}
817
818async fn send_command_to_client_with_reverse(
820 state: &ServerState,
821 id: &str,
822 tunnel_id: u32,
823 remote_port: u16,
824 local_target: String,
825) {
826 let msg = WsTextMessage::Command(Command::ReverseTunnel {
827 tunnel_id,
828 remote_port,
829 local_target,
830 });
831 let (ws_tx, pending_reverse) = {
832 let clients = state.clients.read().await;
833 match find_client_in_map(&clients, id) {
834 Some(client) => (client.ws_tx.clone(), client.pending_reverse.clone()),
835 None => return,
836 }
837 };
838 pending_reverse.write().await.insert(tunnel_id, remote_port);
839 if let Ok(json) = serde_json::to_string(&msg) {
840 if ws_tx.send(Message::Text(json)).await.is_err() {
841 warn!("Failed to send to {id}");
842 pending_reverse.write().await.remove(&tunnel_id);
843 } else {
844 info!("Sent reverse tunnel command to {id} (tunnel {tunnel_id}, port {remote_port})");
845 }
846 }
847}
848
849fn find_client_in_map<'a>(
851 clients: &'a HashMap<String, ClientHandle>,
852 id: &str,
853) -> Option<&'a ClientHandle> {
854 if let Some(handle) = clients.get(id) {
856 return Some(handle);
857 }
858 let fp_matches: Vec<_> = clients
860 .iter()
861 .filter(|(fp, _)| fp.starts_with(id))
862 .collect();
863 if fp_matches.len() == 1 {
864 return Some(fp_matches[0].1);
865 }
866 let cn_matches: Vec<_> = clients.values().filter(|h| h.cn == id).collect();
868 match cn_matches.len() {
869 1 => Some(cn_matches[0]),
870 0 => {
871 warn!("Client not found: {id}");
872 None
873 }
874 n => {
875 warn!("Ambiguous CN '{id}' matches {n} clients. Use fingerprint instead.");
876 None
877 }
878 }
879}