1mod channel_bind;
6mod manager;
7mod permission;
8
9use std::{
10 borrow::Cow,
11 collections::HashMap,
12 marker::{Send, Sync},
13 mem,
14 net::{IpAddr, SocketAddr},
15 sync::{
16 Arc,
17 atomic::{AtomicUsize, Ordering},
18 },
19};
20
21use derive_more::with_trait::Display;
22use rand::random;
23use stun_codec::{
24 Message, MessageClass, TransactionId, rfc5766::methods::DATA,
25};
26use tokio::{
27 net::UdpSocket,
28 sync::{Mutex, mpsc},
29 time::{Duration, Instant, sleep},
30};
31
32pub(crate) use self::manager::{Config as ManagerConfig, Manager};
33use self::{channel_bind::ChannelBind, permission::Permission};
34use crate::{
35 Error, Transport,
36 allocation::permission::PERMISSION_LIFETIME,
37 attr::{Attribute, Data, Username, XorPeerAddress},
38 chandata::ChannelData,
39 server::INBOUND_MTU,
40 transport,
41};
42
/// Type-erased [`Transport`] shared behind an [`Arc`], so it can be cloned
/// cheaply and moved into spawned tasks (hence the `Send + Sync` bounds).
type DynTransport = Arc<dyn Transport + Send + Sync>;
45
/// Transport 5-tuple: a transport protocol together with the source and
/// destination socket addresses (IP + port each).
///
/// Displayed as `{protocol}_{src_addr}_{dst_addr}`.
#[derive(Clone, Copy, Debug, Display, Eq, Hash, PartialEq)]
#[display("{protocol}_{src_addr}_{dst_addr}")]
pub struct FiveTuple {
    /// Transport protocol number (e.g. `PROTO_UDP` or `PROTO_TCP`).
    pub protocol: u8,

    /// Source transport address.
    ///
    /// For an allocation, this is the client address that data received on
    /// the relay socket is forwarded to.
    pub src_addr: SocketAddr,

    /// Destination transport address.
    pub dst_addr: SocketAddr,
}
68
/// Information about an allocation.
///
/// Such a value is sent through the allocation's close notification channel
/// (if one is configured) when the allocation is dropped.
#[derive(Clone, Debug)]
pub struct Info {
    /// 5-tuple of the client that owns the allocation.
    pub five_tuple: FiveTuple,

    /// Username the allocation was created for.
    pub username: Username,

    /// Number of bytes relayed through the allocation.
    ///
    /// The allocation counts the bytes it sends to peers on behalf of the
    /// client.
    pub relayed_bytes: usize,
}
89
90impl Info {
91 #[must_use]
93 pub const fn new(
94 five_tuple: FiveTuple,
95 username: Username,
96 relayed_bytes: usize,
97 ) -> Self {
98 Self { five_tuple, username, relayed_bytes }
99 }
100}
101
/// TURN allocation: a relayed transport address bound to a single client
/// 5-tuple, together with its permissions and channel bindings.
#[derive(Debug)]
pub(crate) struct Allocation {
    /// Relayed transport address (the local address of `relay_socket`).
    relay_addr: SocketAddr,

    /// Socket used to send data to peers and to receive data from them.
    relay_socket: Arc<UdpSocket>,

    /// 5-tuple of the client owning this allocation.
    five_tuple: FiveTuple,

    /// Username this allocation was created for.
    username: Username,

    /// Installed permissions, keyed by peer IP address.
    ///
    /// The map is shared with every `Permission` it contains (each one is
    /// given a clone of this `Arc`), presumably so expired entries can remove
    /// themselves — TODO confirm in `permission.rs`.
    permissions: Arc<Mutex<HashMap<IpAddr, Permission>>>,

    /// Channel bindings, keyed by channel number.
    ///
    /// Shared with every `ChannelBind` it contains, like `permissions`.
    channel_bindings: Arc<Mutex<HashMap<u16, ChannelBind>>>,

    /// Sends new lifetimes to the relay handler task; `Duration::ZERO` stops
    /// it. The channel is closed once the handler task finishes, which is
    /// what `is_alive()` checks.
    refresh_tx: mpsc::Sender<Duration>,

    /// Number of bytes sent to peers via `relay()`.
    relayed_bytes: AtomicUsize,

    /// Optional channel notified with this allocation's [`Info`] on drop.
    alloc_close_notify: Option<mpsc::Sender<Info>>,
}
137
138impl Allocation {
139 pub(crate) fn new(
141 turn_socket: Arc<dyn Transport + Send + Sync>,
142 relay_socket: Arc<UdpSocket>,
143 relay_addr: SocketAddr,
144 five_tuple: FiveTuple,
145 lifetime: Duration,
146 username: Username,
147 alloc_close_notify: Option<mpsc::Sender<Info>>,
148 ) -> Self {
149 let (refresh_tx, refresh_rx) = mpsc::channel(1);
150
151 let this = Self {
152 relay_addr,
153 relay_socket,
154 five_tuple,
155 username,
156 permissions: Arc::new(Mutex::new(HashMap::new())),
157 channel_bindings: Arc::new(Mutex::new(HashMap::new())),
158 refresh_tx,
159 relayed_bytes: AtomicUsize::default(),
160 alloc_close_notify,
161 };
162
163 this.spawn_relay_handler(refresh_rx, lifetime, turn_socket);
164
165 this
166 }
167
168 pub(crate) fn is_alive(&self) -> bool {
171 !self.refresh_tx.is_closed()
172 }
173
174 pub(crate) async fn relay(
181 &self,
182 data: &[u8],
183 to: SocketAddr,
184 ) -> Result<(), Error> {
185 if !self.is_alive() {
186 return Err(Error::NoAllocationFound);
187 }
188
189 let n = self
190 .relay_socket
191 .send_to(data, to)
192 .await
193 .map_err(transport::Error::from)?;
194 _ = self.relayed_bytes.fetch_add(n, Ordering::AcqRel);
195 Ok(())
196 }
197
198 pub(crate) const fn relay_addr(&self) -> SocketAddr {
200 self.relay_addr
201 }
202
203 pub(crate) async fn has_permission(&self, addr: &SocketAddr) -> bool {
205 if !self.is_alive() {
206 return false;
207 }
208
209 self.permissions.lock().await.get(&addr.ip()).is_some()
210 }
211
212 pub(crate) async fn add_permission(&self, ip: IpAddr) {
214 if !self.is_alive() {
215 return;
216 }
217
218 let mut permissions = self.permissions.lock().await;
219 if let Some(existed_permission) = permissions.get(&ip) {
220 existed_permission.refresh(PERMISSION_LIFETIME).await;
221 } else {
222 let p = Permission::new(
223 ip,
224 Arc::clone(&self.permissions),
225 PERMISSION_LIFETIME,
226 );
227 drop(permissions.insert(p.ip(), p));
228 }
229 }
230
231 pub(crate) async fn add_channel_bind(
234 &self,
235 number: u16,
236 peer_addr: SocketAddr,
237 lifetime: Duration,
238 ) -> Result<(), Error> {
239 if !self.is_alive() {
240 return Err(Error::NoAllocationFound);
241 }
242
243 if let Some(addr) = self.get_channel_addr(&number).await {
246 if addr != peer_addr {
247 return Err(Error::SameChannelDifferentPeer);
248 }
249 }
250
251 if let Some(n) = self.get_channel_number(&peer_addr).await {
254 if number != n {
255 return Err(Error::SamePeerDifferentChannel);
256 }
257 }
258
259 let mut channel_bindings = self.channel_bindings.lock().await;
260 if let Some(cb) = channel_bindings.get(&number).cloned() {
261 drop(channel_bindings);
262
263 cb.refresh(lifetime).await;
264
265 self.add_permission(cb.peer().ip()).await;
267 } else {
268 let bind = ChannelBind::new(
269 number,
270 peer_addr,
271 Arc::clone(&self.channel_bindings),
272 lifetime,
273 );
274
275 drop(channel_bindings.insert(number, bind));
276 drop(channel_bindings);
277
278 self.add_permission(peer_addr.ip()).await;
280 }
281 Ok(())
282 }
283
284 pub(crate) async fn get_channel_addr(
286 &self,
287 number: &u16,
288 ) -> Option<SocketAddr> {
289 if !self.is_alive() {
290 return None;
291 }
292
293 self.channel_bindings.lock().await.get(number).map(ChannelBind::peer)
294 }
295
296 pub(crate) async fn get_channel_number(
299 &self,
300 addr: &SocketAddr,
301 ) -> Option<u16> {
302 if !self.is_alive() {
303 return None;
304 }
305 self.channel_bindings
306 .lock()
307 .await
308 .values()
309 .find_map(|b| (b.peer() == *addr).then_some(b.num()))
310 }
311
312 pub(crate) async fn refresh(&self, lifetime: Duration) {
314 _ = self.refresh_tx.send(lifetime).await;
315 }
316
317 #[expect(clippy::too_many_lines, reason = "needs refactoring")]
346 fn spawn_relay_handler(
347 &self,
348 mut refresh_rx: mpsc::Receiver<Duration>,
349 lifetime: Duration,
350 turn_socket: Arc<dyn Transport + Send + Sync>,
351 ) {
352 let five_tuple = self.five_tuple;
353 let relay_addr = self.relay_addr;
354 let relay_socket = Arc::clone(&self.relay_socket);
355 let channel_bindings = Arc::clone(&self.channel_bindings);
356 let permissions = Arc::clone(&self.permissions);
357
358 drop(tokio::spawn(async move {
359 log::trace!("Listening on relay addr: {relay_addr}");
360
361 let expired = sleep(lifetime);
362 tokio::pin!(expired);
363 let mut buf = vec![0u8; INBOUND_MTU];
364
365 loop {
366 let (recv_len, src_addr) = tokio::select! {
367 result = relay_socket.recv_from(
368 &mut buf[ChannelData::HEADER_SIZE..],
369 ) => {
370 if let Ok((n, src_addr)) = result {
371 (n, src_addr)
372 } else {
373 break;
374 }
375 }
376 () = &mut expired => {
377 break;
378 },
379 refresh = refresh_rx.recv() => {
380 match refresh {
381 Some(lf) => {
382 if lf == Duration::ZERO {
383 break;
384 }
385 expired.as_mut().reset(Instant::now() + lf);
386 continue;
387 }
388 None => {
389 break;
390 }
391 }
392 },
393 };
394
395 let cb_number = channel_bindings
396 .lock()
397 .await
398 .iter()
399 .find(|(_, cb)| cb.peer() == src_addr)
400 .map(|(cn, _)| *cn);
401
402 if let Some(number) = cb_number {
403 match ChannelData::encode(&mut buf, recv_len, number) {
404 Ok(n) => {
405 if let Err(e) = turn_socket
406 .send_to(
407 Cow::Borrowed(&buf[..n]),
408 five_tuple.src_addr,
409 )
410 .await
411 {
412 match e {
413 transport::Error::TransportIsDead => {
414 break;
415 }
416 transport::Error::ChannelData(..)
417 | transport::Error::Decode(..)
418 | transport::Error::Encode(..)
419 | transport::Error::Io(..) => {
420 log::warn!(
421 "Failed to send `ChannelData` from \
422 `Allocation(scr: {src_addr}`: {e}",
423 );
424 }
425 }
426 }
427 }
428 Err(e) => {
429 log::warn!(
430 "Failed to send `ChannelData` from \
431 `Allocation(src: {src_addr})`: {e}",
432 );
433 }
434 }
435 } else {
436 let has_permission =
437 permissions.lock().await.contains_key(&src_addr.ip());
438
439 if has_permission {
440 log::trace!(
441 "Relaying message from {src_addr} to client at {}",
442 five_tuple.src_addr,
443 );
444
445 let mut msg: Message<Attribute> = Message::new(
446 MessageClass::Indication,
447 DATA,
448 TransactionId::new(random()),
449 );
450 msg.add_attribute(XorPeerAddress::new(src_addr));
451
452 let data = buf[ChannelData::HEADER_SIZE
453 ..recv_len + ChannelData::HEADER_SIZE]
454 .to_vec();
455 let Ok(data) = Data::new(data) else {
456 log::error!("`DataIndication` is too long");
457 continue;
458 };
459 msg.add_attribute(data);
460
461 if let Err(e) = turn_socket
462 .send_msg_to(msg, five_tuple.src_addr)
463 .await
464 {
465 log::error!(
466 "Failed to send `DataIndication` from \
467 `Allocation(src: {src_addr})`: {e}",
468 );
469 }
470 } else {
471 log::info!(
472 "No `Permission` or `ChannelBind` exists for \
473 `{src_addr}` on `Allocation(relay: {relay_addr})`",
474 );
475 }
476 }
477 }
478
479 drop(mem::take(&mut *channel_bindings.lock().await));
480 drop(mem::take(&mut *permissions.lock().await));
481
482 log::trace!(
483 "`Allocation(five_tuple: {five_tuple})` stopped, stop \
484 `relay_handler`",
485 );
486 }));
487 }
488}
489
490impl Drop for Allocation {
491 fn drop(&mut self) {
492 if let Some(notify_tx) = self.alloc_close_notify.take() {
493 let info = Info {
494 five_tuple: self.five_tuple,
495 username: self.username.clone(),
496 relayed_bytes: self.relayed_bytes.load(Ordering::Acquire),
497 };
498
499 drop(tokio::spawn(async move {
500 drop(notify_tx.send(info).await);
501 }));
502 }
503 }
504}
505
#[cfg(test)]
mod spec {
    use std::{
        net::{Ipv4Addr, SocketAddr},
        str::FromStr,
        sync::Arc,
    };

    use tokio::{
        net::UdpSocket,
        task,
        time::{self, Duration},
    };

    use super::{Allocation, FiveTuple};
    use crate::{
        Error,
        attr::{ChannelNumber, PROTO_UDP, Username},
        server::DEFAULT_LIFETIME,
    };

    impl Default for FiveTuple {
        fn default() -> Self {
            FiveTuple {
                protocol: PROTO_UDP,
                src_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0),
                dst_addr: SocketAddr::new(Ipv4Addr::new(0, 0, 0, 0).into(), 0),
            }
        }
    }

    /// Creates an [`Allocation`] whose TURN transport and relay socket are
    /// the same freshly bound UDP socket.
    async fn new_allocation() -> Allocation {
        let turn_socket = Arc::new(UdpSocket::bind("0.0.0.0:0").await.unwrap());
        let relay_socket = Arc::clone(&turn_socket);
        let relay_addr = relay_socket.local_addr().unwrap();
        Allocation::new(
            turn_socket,
            relay_socket,
            relay_addr,
            FiveTuple::default(),
            DEFAULT_LIFETIME,
            Username::new(String::from("user")).unwrap(),
            None,
        )
    }

    #[tokio::test]
    async fn has_permission() {
        let a = new_allocation().await;

        let addr1 = SocketAddr::from_str("127.0.0.1:3478").unwrap();
        let addr2 = SocketAddr::from_str("127.0.0.1:3479").unwrap();
        let addr3 = SocketAddr::from_str("127.0.0.2:3478").unwrap();
        let addr4 = SocketAddr::from_str("127.0.0.3:3478").unwrap();

        a.add_permission(addr1.ip()).await;
        a.add_permission(addr2.ip()).await;
        a.add_permission(addr3.ip()).await;

        let found_p1 = a.has_permission(&addr1).await;
        assert!(found_p1, "`Permission` for the first address should be found");

        let found_p2 = a.has_permission(&addr2).await;
        assert!(
            found_p2,
            "`Permission` is keyed by IP, so the same IP with another port \
             should be found",
        );

        let found_p3 = a.has_permission(&addr3).await;
        assert!(found_p3, "`Permission` with another IP should be found");

        let found_p4 = a.has_permission(&addr4).await;
        assert!(!found_p4, "`Permission` was never added, but found");
    }

    #[tokio::test]
    async fn add_permission() {
        let a = new_allocation().await;

        let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap();
        a.add_permission(addr.ip()).await;

        let found_p = a.has_permission(&addr).await;
        assert!(found_p, "added `Permission` should be found");
    }

    #[tokio::test]
    async fn get_channel_by_number() {
        let a = new_allocation().await;

        let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap();

        a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME)
            .await
            .unwrap();

        let exist_channel_addr =
            a.get_channel_addr(&ChannelNumber::MIN).await.unwrap();
        assert_eq!(addr, exist_channel_addr);

        let not_exist_channel =
            a.get_channel_addr(&(ChannelNumber::MIN + 1)).await;
        assert!(not_exist_channel.is_none(), "found, but shouldn't");
    }

    #[tokio::test]
    async fn get_channel_by_addr() {
        let a = new_allocation().await;

        let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap();
        let addr2 = SocketAddr::from_str("127.0.0.1:3479").unwrap();

        a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME)
            .await
            .unwrap();

        let exist_channel_number = a.get_channel_number(&addr).await.unwrap();
        assert_eq!(ChannelNumber::MIN, exist_channel_number);

        let not_exist_channel = a.get_channel_number(&addr2).await;
        assert!(not_exist_channel.is_none(), "found, but shouldn't");
    }

    #[tokio::test]
    async fn channel_bind_conflicts() {
        let a = new_allocation().await;

        let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap();
        let addr2 = SocketAddr::from_str("127.0.0.1:3479").unwrap();

        a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME)
            .await
            .unwrap();

        let err = a
            .add_channel_bind(ChannelNumber::MIN, addr2, DEFAULT_LIFETIME)
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::SameChannelDifferentPeer),
            "binding a bound channel to another peer should fail",
        );

        let err = a
            .add_channel_bind(ChannelNumber::MIN + 1, addr, DEFAULT_LIFETIME)
            .await
            .unwrap_err();
        assert!(
            matches!(err, Error::SamePeerDifferentChannel),
            "binding a bound peer to another channel should fail",
        );

        // Repeating the very same binding refreshes it.
        a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME)
            .await
            .unwrap();
        assert_eq!(
            a.get_channel_addr(&ChannelNumber::MIN).await,
            Some(addr),
            "refreshed binding should stay in place",
        );
    }

    #[tokio::test]
    async fn closing() {
        let a = new_allocation().await;

        let addr = SocketAddr::from_str("127.0.0.1:3478").unwrap();
        a.add_channel_bind(ChannelNumber::MIN, addr, DEFAULT_LIFETIME)
            .await
            .unwrap();
        a.add_permission(addr.ip()).await;
        assert!(a.is_alive(), "`Allocation` should be alive before closing");

        // A zero lifetime asks the relay handler to stop.
        a.refresh(Duration::ZERO).await;
        time::timeout(Duration::from_secs(1), async {
            while a.is_alive() {
                task::yield_now().await;
            }
        })
        .await
        .expect("relay handler should stop after a zero-lifetime refresh");

        assert!(
            !a.has_permission(&addr).await,
            "closed `Allocation` shouldn't have permissions",
        );
        assert!(
            a.get_channel_addr(&ChannelNumber::MIN).await.is_none(),
            "closed `Allocation` shouldn't have channel bindings",
        );
    }
}
668
#[cfg(test)]
mod five_tuple_spec {
    use std::net::SocketAddr;

    use crate::{
        FiveTuple,
        attr::{PROTO_TCP, PROTO_UDP},
    };

    /// Parses a socket address literal used by the test cases below.
    fn addr(literal: &str) -> SocketAddr {
        literal.parse::<SocketAddr>().unwrap()
    }

    /// Shorthand constructor for a [`FiveTuple`].
    fn tuple(protocol: u8, src_addr: SocketAddr, dst_addr: SocketAddr) -> FiveTuple {
        FiveTuple { protocol, src_addr, dst_addr }
    }

    #[test]
    fn equality() {
        let (src1, src2) = (addr("0.0.0.0:3478"), addr("0.0.0.0:3479"));
        let (dst1, dst2) = (addr("0.0.0.0:3480"), addr("0.0.0.0:3481"));

        // Every case differs from this tuple in at most one component.
        let base = tuple(PROTO_UDP, src1, dst1);

        let cases = [
            ("Equal", true, base, tuple(PROTO_UDP, src1, dst1)),
            ("DifferentProtocol", false, tuple(PROTO_TCP, src1, dst1), base),
            ("DifferentSrcAddr", false, base, tuple(PROTO_UDP, src2, dst1)),
            ("DifferentDstAddr", false, base, tuple(PROTO_UDP, src1, dst2)),
        ];

        for (name, expect, a, b) in cases {
            let fact = a == b;
            assert_eq!(
                expect, fact,
                "{name}: {a}, {b} equal check should be {expect}, but {fact}",
            );
        }
    }
}