1use crate::config::{Config, ServerConfig, ServerServiceConfig, ServiceType, TransportType};
2use crate::config_watcher::{ConfigChange, ServerServiceChange};
3use crate::constants::{listen_backoff, UDP_BUFFER_SIZE};
4use crate::helper::retry_notify_with_deadline;
5use crate::multi_map::MultiMap;
6use crate::protocol::Hello::{ControlChannelHello, DataChannelHello};
7use crate::protocol::{
8 self, read_auth, read_hello, Ack, ControlChannelCmd, DataChannelCmd, Hello, UdpTraffic,
9 HASH_WIDTH_IN_BYTES,
10};
11use crate::transport::{SocketOpts, TcpTransport, Transport};
12use anyhow::{anyhow, bail, Context, Result};
13use backoff::backoff::Backoff;
14use backoff::ExponentialBackoff;
15
16use rand::RngCore;
17use std::collections::HashMap;
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::io::{self, copy_bidirectional, AsyncReadExt, AsyncWriteExt};
21use tokio::net::{TcpListener, TcpStream, UdpSocket};
22use tokio::sync::{broadcast, mpsc, RwLock};
23use tokio::time;
24use tracing::{debug, error, info, info_span, instrument, warn, Instrument, Span};
25
26#[cfg(feature = "noise")]
27use crate::transport::NoiseTransport;
28#[cfg(feature = "tls")]
29use crate::transport::TlsTransport;
30#[cfg(feature = "websocket")]
31use crate::transport::WebsocketTransport;
32
33type ServiceDigest = protocol::Digest; type Nonce = protocol::Digest; const TCP_POOL_SIZE: usize = 8; const UDP_POOL_SIZE: usize = 2; const CHAN_SIZE: usize = 2048; const HANDSHAKE_TIMEOUT: u64 = 5; pub async fn run_server(
43 config: Config,
44 shutdown_rx: broadcast::Receiver<bool>,
45 update_rx: mpsc::Receiver<ConfigChange>,
46) -> Result<()> {
47 let config = match config.server {
48 Some(config) => config,
49 None => {
50 return Err(anyhow!("Try to run as a server, but the configuration is missing. Please add the `[server]` block"))
51 }
52 };
53
54 match config.transport.transport_type {
55 TransportType::Tcp => {
56 let mut server = Server::<TcpTransport>::from(config).await?;
57 server.run(shutdown_rx, update_rx).await?;
58 }
59 TransportType::Tls => {
60 #[cfg(feature = "tls")]
61 {
62 let mut server = Server::<TlsTransport>::from(config).await?;
63 server.run(shutdown_rx, update_rx).await?;
64 }
65 #[cfg(not(feature = "tls"))]
66 crate::helper::feature_not_compile("tls")
67 }
68 TransportType::Noise => {
69 #[cfg(feature = "noise")]
70 {
71 let mut server = Server::<NoiseTransport>::from(config).await?;
72 server.run(shutdown_rx, update_rx).await?;
73 }
74 #[cfg(not(feature = "noise"))]
75 crate::helper::feature_not_compile("noise")
76 }
77 TransportType::Websocket => {
78 #[cfg(feature = "websocket")]
79 {
80 let mut server = Server::<WebsocketTransport>::from(config).await?;
81 server.run(shutdown_rx, update_rx).await?;
82 }
83 #[cfg(not(feature = "websocket"))]
84 crate::helper::feature_not_compile("websocket")
85 }
86 }
87
88 Ok(())
89}
90
/// Live control channels, reachable by two keys: the service digest (used when a
/// channel is (re)registered or a service is hot-reloaded) and the session key
/// (used to route incoming data channels to their control channel).
type ControlChannelMap<T> = MultiMap<ServiceDigest, Nonce, ControlChannelHandle<T>>;
94
/// Server state shared between the accept loop and the per-connection tasks.
struct Server<T: Transport> {
    /// The `[server]` configuration (bind address, transport settings, heartbeat interval, ...).
    config: Arc<ServerConfig>,

    /// Service configs keyed by the digest of the service name; updated on hot reload.
    services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,

    /// Currently registered control channels.
    control_channels: Arc<RwLock<ControlChannelMap<T>>>,

    /// Transport used to bind, accept and handshake incoming connections.
    transport: Arc<T>,
}
107
108fn generate_service_hashmap(
110 server_config: &ServerConfig,
111) -> HashMap<ServiceDigest, ServerServiceConfig> {
112 let mut ret = HashMap::new();
113 for u in &server_config.services {
114 ret.insert(protocol::digest(u.0.as_bytes()), (*u.1).clone());
115 }
116 ret
117}
118
119impl<T: 'static + Transport> Server<T> {
120 pub async fn from(config: ServerConfig) -> Result<Server<T>> {
122 let config = Arc::new(config);
123 let services = Arc::new(RwLock::new(generate_service_hashmap(&config)));
124 let control_channels = Arc::new(RwLock::new(ControlChannelMap::new()));
125 let transport = Arc::new(T::new(&config.transport)?);
126 Ok(Server {
127 config,
128 services,
129 control_channels,
130 transport,
131 })
132 }
133
134 pub async fn run(
136 &mut self,
137 mut shutdown_rx: broadcast::Receiver<bool>,
138 mut update_rx: mpsc::Receiver<ConfigChange>,
139 ) -> Result<()> {
140 let l = self
142 .transport
143 .bind(&self.config.bind_addr)
144 .await
145 .with_context(|| "Failed to listen at `server.bind_addr`")?;
146 info!("Listening at {}", self.config.bind_addr);
147
148 let mut backoff = ExponentialBackoff {
150 max_interval: Duration::from_millis(100),
151 max_elapsed_time: None,
152 ..Default::default()
153 };
154
155 loop {
157 tokio::select! {
158 ret = self.transport.accept(&l) => {
160 match ret {
161 Err(err) => {
162 if let Some(err) = err.downcast_ref::<io::Error>() {
164 if let Some(d) = backoff.next_backoff() {
168 error!("Failed to accept: {:#}. Retry in {:?}...", err, d);
169 time::sleep(d).await;
170 } else {
171 error!("Too many retries. Aborting...");
173 break;
174 }
175 }
176 }
179 Ok((conn, addr)) => {
180 backoff.reset();
181
182 match time::timeout(Duration::from_secs(HANDSHAKE_TIMEOUT), self.transport.handshake(conn)).await {
184 Ok(conn) => {
185 match conn.with_context(|| "Failed to do transport handshake") {
186 Ok(conn) => {
187 let services = self.services.clone();
188 let control_channels = self.control_channels.clone();
189 let server_config = self.config.clone();
190 tokio::spawn(async move {
191 if let Err(err) = handle_connection(conn, services, control_channels, server_config).await {
192 error!("{:#}", err);
193 }
194 }.instrument(info_span!("connection", %addr)));
195 }, Err(e) => {
196 error!("{:#}", e);
197 }
198 }
199 },
200 Err(e) => {
201 error!("Transport handshake timeout: {}", e);
202 }
203 }
204 }
205 }
206 },
207 _ = shutdown_rx.recv() => {
209 info!("Shuting down gracefully...");
210 break;
211 },
212 e = update_rx.recv() => {
213 if let Some(e) = e {
214 self.handle_hot_reload(e).await;
215 }
216 }
217 }
218 }
219
220 info!("Shutdown");
221
222 Ok(())
223 }
224
225 async fn handle_hot_reload(&mut self, e: ConfigChange) {
226 match e {
227 ConfigChange::ServerChange(server_change) => match server_change {
228 ServerServiceChange::Add(cfg) => {
229 let hash = protocol::digest(cfg.name.as_bytes());
230 let mut wg = self.services.write().await;
231 let _ = wg.insert(hash, cfg);
232
233 let mut wg = self.control_channels.write().await;
234 let _ = wg.remove1(&hash);
235 }
236 ServerServiceChange::Delete(s) => {
237 let hash = protocol::digest(s.as_bytes());
238 let _ = self.services.write().await.remove(&hash);
239
240 let mut wg = self.control_channels.write().await;
241 let _ = wg.remove1(&hash);
242 }
243 },
244 ignored => warn!("Ignored {:?} since running as a server", ignored),
245 }
246 }
247}
248
249async fn handle_connection<T: 'static + Transport>(
251 mut conn: T::Stream,
252 services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
253 control_channels: Arc<RwLock<ControlChannelMap<T>>>,
254 server_config: Arc<ServerConfig>,
255) -> Result<()> {
256 let hello = read_hello(&mut conn).await?;
258 match hello {
259 ControlChannelHello(_, service_digest) => {
260 do_control_channel_handshake(
261 conn,
262 services,
263 control_channels,
264 service_digest,
265 server_config,
266 )
267 .await?;
268 }
269 DataChannelHello(_, nonce) => {
270 do_data_channel_handshake(conn, control_channels, nonce).await?;
271 }
272 }
273 Ok(())
274}
275
276async fn do_control_channel_handshake<T: 'static + Transport>(
277 mut conn: T::Stream,
278 services: Arc<RwLock<HashMap<ServiceDigest, ServerServiceConfig>>>,
279 control_channels: Arc<RwLock<ControlChannelMap<T>>>,
280 service_digest: ServiceDigest,
281 server_config: Arc<ServerConfig>,
282) -> Result<()> {
283 info!("Try to handshake a control channel");
284
285 T::hint(&conn, SocketOpts::for_control_channel());
286
287 let mut nonce = vec![0u8; HASH_WIDTH_IN_BYTES];
289 rand::thread_rng().fill_bytes(&mut nonce);
290
291 let hello_send = Hello::ControlChannelHello(
293 protocol::CURRENT_PROTO_VERSION,
294 nonce.clone().try_into().unwrap(),
295 );
296 conn.write_all(&bincode::serialize(&hello_send).unwrap())
297 .await?;
298 conn.flush().await?;
299
300 let service_config = match services.read().await.get(&service_digest) {
302 Some(v) => v,
303 None => {
304 conn.write_all(&bincode::serialize(&Ack::ServiceNotExist).unwrap())
305 .await?;
306 bail!("No such a service {}", hex::encode(service_digest));
307 }
308 }
309 .to_owned();
310
311 let service_name = &service_config.name;
312
313 let mut concat = Vec::from(service_config.token.as_ref().unwrap().as_bytes());
315 concat.append(&mut nonce);
316
317 let protocol::Auth(d) = read_auth(&mut conn).await?;
319
320 let session_key = protocol::digest(&concat);
322 if session_key != d {
323 conn.write_all(&bincode::serialize(&Ack::AuthFailed).unwrap())
324 .await?;
325 debug!(
326 "Expect {}, but got {}",
327 hex::encode(session_key),
328 hex::encode(d)
329 );
330 bail!("Service {} failed the authentication", service_name);
331 } else {
332 let mut h = control_channels.write().await;
333
334 if h.remove1(&service_digest).is_some() {
339 warn!(
340 "Dropping previous control channel for service {}",
341 service_name
342 );
343 }
344
345 conn.write_all(&bincode::serialize(&Ack::Ok).unwrap())
347 .await?;
348 conn.flush().await?;
349
350 info!(service = %service_config.name, "Control channel established");
351 let handle =
352 ControlChannelHandle::new(conn, service_config, server_config.heartbeat_interval);
353
354 let _ = h.insert(service_digest, session_key, handle);
356 }
357
358 Ok(())
359}
360
361async fn do_data_channel_handshake<T: 'static + Transport>(
362 conn: T::Stream,
363 control_channels: Arc<RwLock<ControlChannelMap<T>>>,
364 nonce: Nonce,
365) -> Result<()> {
366 debug!("Try to handshake a data channel");
367
368 let control_channels_guard = control_channels.read().await;
370 match control_channels_guard.get2(&nonce) {
371 Some(handle) => {
372 T::hint(&conn, SocketOpts::from_server_cfg(&handle.service));
373
374 handle
376 .data_ch_tx
377 .send(conn)
378 .await
379 .with_context(|| "Data channel for a stale control channel")?;
380 }
381 None => {
382 warn!("Data channel has incorrect nonce");
383 }
384 }
385 Ok(())
386}
387
/// Owner-side handle of a running control channel, stored in the control-channel map.
pub struct ControlChannelHandle<T: Transport> {
    /// Never sent on. It is kept only so that dropping the handle drops the last
    /// sender, which makes the `shutdown_rx.recv()` calls in the spawned control
    /// channel and pool tasks resolve.
    _shutdown_tx: broadcast::Sender<bool>,
    /// Queue of data channels handed to this control channel's connection pool.
    data_ch_tx: mpsc::Sender<T::Stream>,
    /// Config of the service this control channel serves.
    service: ServerServiceConfig,
}
394
395impl<T> ControlChannelHandle<T>
396where
397 T: 'static + Transport,
398{
399 #[instrument(name = "handle", skip_all, fields(service = %service.name))]
402 fn new(
403 conn: T::Stream,
404 service: ServerServiceConfig,
405 heartbeat_interval: u64,
406 ) -> ControlChannelHandle<T> {
407 let (shutdown_tx, shutdown_rx) = broadcast::channel::<bool>(1);
409
410 let (data_ch_tx, data_ch_rx) = mpsc::channel(CHAN_SIZE * 2);
412
413 let (data_ch_req_tx, data_ch_req_rx) = mpsc::unbounded_channel();
415
416 let pool_size = match service.service_type {
418 ServiceType::Tcp => TCP_POOL_SIZE,
419 ServiceType::Udp => UDP_POOL_SIZE,
420 };
421
422 for _i in 0..pool_size {
423 if let Err(e) = data_ch_req_tx.send(true) {
424 error!("Failed to request data channel {}", e);
425 };
426 }
427
428 let shutdown_rx_clone = shutdown_tx.subscribe();
429 let bind_addr = service.bind_addr.clone();
430 match service.service_type {
431 ServiceType::Tcp => tokio::spawn(
432 async move {
433 if let Err(e) = run_tcp_connection_pool::<T>(
434 bind_addr,
435 data_ch_rx,
436 data_ch_req_tx,
437 shutdown_rx_clone,
438 )
439 .await
440 .with_context(|| "Failed to run TCP connection pool")
441 {
442 error!("{:#}", e);
443 }
444 }
445 .instrument(Span::current()),
446 ),
447 ServiceType::Udp => tokio::spawn(
448 async move {
449 if let Err(e) = run_udp_connection_pool::<T>(
450 bind_addr,
451 data_ch_rx,
452 data_ch_req_tx,
453 shutdown_rx_clone,
454 )
455 .await
456 .with_context(|| "Failed to run TCP connection pool")
457 {
458 error!("{:#}", e);
459 }
460 }
461 .instrument(Span::current()),
462 ),
463 };
464
465 let ch = ControlChannel::<T> {
467 conn,
468 shutdown_rx,
469 data_ch_req_rx,
470 heartbeat_interval,
471 };
472
473 tokio::spawn(
475 async move {
476 if let Err(err) = ch.run().await {
477 error!("{:#}", err);
478 }
479 }
480 .instrument(Span::current()),
481 );
482
483 ControlChannelHandle {
484 _shutdown_tx: shutdown_tx,
485 data_ch_tx,
486 service,
487 }
488 }
489}
490
/// Server side of a control channel: writes commands to the client over `conn`.
struct ControlChannel<T: Transport> {
    /// Connection to the client.
    conn: T::Stream,
    /// Resolves when the owning `ControlChannelHandle` is dropped.
    shutdown_rx: broadcast::Receiver<bool>,
    /// Each received item triggers one `CreateDataChannel` command.
    data_ch_req_rx: mpsc::UnboundedReceiver<bool>,
    /// Seconds between heartbeats; 0 disables them.
    heartbeat_interval: u64,
}
498
499impl<T: Transport> ControlChannel<T> {
500 async fn write_and_flush(&mut self, data: &[u8]) -> Result<()> {
501 self.conn
502 .write_all(data)
503 .await
504 .with_context(|| "Failed to write control cmds")?;
505 self.conn
506 .flush()
507 .await
508 .with_context(|| "Failed to flush control cmds")?;
509 Ok(())
510 }
511 #[instrument(skip_all)]
513 async fn run(mut self) -> Result<()> {
514 let create_ch_cmd = bincode::serialize(&ControlChannelCmd::CreateDataChannel).unwrap();
515 let heartbeat = bincode::serialize(&ControlChannelCmd::HeartBeat).unwrap();
516
517 loop {
519 tokio::select! {
520 val = self.data_ch_req_rx.recv() => {
521 match val {
522 Some(_) => {
523 if let Err(e) = self.write_and_flush(&create_ch_cmd).await {
524 error!("{:#}", e);
525 break;
526 }
527 }
528 None => {
529 break;
530 }
531 }
532 },
533 _ = time::sleep(Duration::from_secs(self.heartbeat_interval)), if self.heartbeat_interval != 0 => {
534 if let Err(e) = self.write_and_flush(&heartbeat).await {
535 error!("{:#}", e);
536 break;
537 }
538 }
539 _ = self.shutdown_rx.recv() => {
541 break;
542 }
543 }
544 }
545
546 info!("Control channel shutdown");
547
548 Ok(())
549 }
550}
551
552fn tcp_listen_and_send(
553 addr: String,
554 data_ch_req_tx: mpsc::UnboundedSender<bool>,
555 mut shutdown_rx: broadcast::Receiver<bool>,
556) -> mpsc::Receiver<TcpStream> {
557 let (tx, rx) = mpsc::channel(CHAN_SIZE);
558
559 tokio::spawn(async move {
560 let l = retry_notify_with_deadline(listen_backoff(), || async {
561 Ok(TcpListener::bind(&addr).await?)
562 }, |e, duration| {
563 error!("{:#}. Retry in {:?}", e, duration);
564 }, &mut shutdown_rx).await
565 .with_context(|| "Failed to listen for the service");
566
567 let l: TcpListener = match l {
568 Ok(v) => v,
569 Err(e) => {
570 error!("{:#}", e);
571 return;
572 }
573 };
574
575 info!("Listening at {}", &addr);
576
577 let mut backoff = ExponentialBackoff {
579 max_interval: Duration::from_secs(1),
580 max_elapsed_time: None,
581 ..Default::default()
582 };
583
584 loop {
586 tokio::select! {
587 val = l.accept() => {
588 match val {
589 Err(e) => {
590 error!("{}. Sleep for a while", e);
593 if let Some(d) = backoff.next_backoff() {
594 time::sleep(d).await;
595 } else {
596 error!("Too many retries. Aborting...");
598 break;
599 }
600 }
601 Ok((incoming, addr)) => {
602 if data_ch_req_tx.send(true).with_context(|| "Failed to send data chan create request").is_err() {
604 break;
607 }
608
609 backoff.reset();
610
611 debug!("New visitor from {}", addr);
612
613 let _ = tx.send(incoming).await;
615 }
616 }
617 },
618 _ = shutdown_rx.recv() => {
619 break;
620 }
621 }
622 }
623
624 info!("TCPListener shutdown");
625 }.instrument(Span::current()));
626
627 rx
628}
629
630#[instrument(skip_all)]
631async fn run_tcp_connection_pool<T: Transport>(
632 bind_addr: String,
633 mut data_ch_rx: mpsc::Receiver<T::Stream>,
634 data_ch_req_tx: mpsc::UnboundedSender<bool>,
635 shutdown_rx: broadcast::Receiver<bool>,
636) -> Result<()> {
637 let mut visitor_rx = tcp_listen_and_send(bind_addr, data_ch_req_tx.clone(), shutdown_rx);
638 let cmd = bincode::serialize(&DataChannelCmd::StartForwardTcp).unwrap();
639
640 'pool: while let Some(mut visitor) = visitor_rx.recv().await {
641 loop {
642 if let Some(mut ch) = data_ch_rx.recv().await {
643 if ch.write_all(&cmd).await.is_ok() {
644 tokio::spawn(async move {
645 let _ = copy_bidirectional(&mut ch, &mut visitor).await;
646 });
647 break;
648 } else {
649 if data_ch_req_tx.send(true).is_err() {
651 break 'pool;
652 }
653 }
654 } else {
655 break 'pool;
656 }
657 }
658 }
659
660 info!("Shutdown");
661 Ok(())
662}
663
664#[instrument(skip_all)]
665async fn run_udp_connection_pool<T: Transport>(
666 bind_addr: String,
667 mut data_ch_rx: mpsc::Receiver<T::Stream>,
668 _data_ch_req_tx: mpsc::UnboundedSender<bool>,
669 mut shutdown_rx: broadcast::Receiver<bool>,
670) -> Result<()> {
671 let l = retry_notify_with_deadline(
674 listen_backoff(),
675 || async { Ok(UdpSocket::bind(&bind_addr).await?) },
676 |e, duration| {
677 warn!("{:#}. Retry in {:?}", e, duration);
678 },
679 &mut shutdown_rx,
680 )
681 .await
682 .with_context(|| "Failed to listen for the service")?;
683
684 info!("Listening at {}", &bind_addr);
685
686 let cmd = bincode::serialize(&DataChannelCmd::StartForwardUdp).unwrap();
687
688 let mut conn = data_ch_rx
690 .recv()
691 .await
692 .ok_or_else(|| anyhow!("No available data channels"))?;
693 conn.write_all(&cmd).await?;
694
695 let mut buf = [0u8; UDP_BUFFER_SIZE];
696 loop {
697 tokio::select! {
698 val = l.recv_from(&mut buf) => {
700 let (n, from) = val?;
701 UdpTraffic::write_slice(&mut conn, from, &buf[..n]).await?;
702 },
703
704 hdr_len = conn.read_u8() => {
706 let t = UdpTraffic::read(&mut conn, hdr_len?).await?;
707 l.send_to(&t.data, t.from).await?;
708 }
709
710 _ = shutdown_rx.recv() => {
711 break;
712 }
713 }
714 }
715
716 debug!("UDP pool dropped");
717
718 Ok(())
719}