turn_server/router.rs
1use std::{net::SocketAddr, sync::Arc};
2
3use ahash::AHashMap;
4use parking_lot::RwLock;
5use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender, unbounded_channel};
6
7use crate::turn::ResponseMethod;
8
9type Receiver = UnboundedSender<(Vec<u8>, ResponseMethod, SocketAddr)>;
10
11/// Handles packet forwarding between transport protocols.
12#[derive(Clone)]
13pub struct Router(Arc<RwLock<AHashMap<SocketAddr, Receiver>>>);
14
15impl Default for Router {
16 fn default() -> Self {
17 Self(Arc::new(RwLock::new(AHashMap::with_capacity(1024))))
18 }
19}
20
21impl Router {
22 /// Get the socket reader for the route.
23 ///
24 /// Each transport protocol is layered according to its own socket, and
25 /// the data forwarded to this socket can be obtained by routing.
26 ///
27 /// # Example
28 ///
29 /// ```
30 /// use std::net::SocketAddr;
31 /// use turn_server::router::*;
32 /// use turn_server::turn::ResponseMethod;
33 ///
34 /// #[tokio::main]
35 /// async fn main() {
36 /// let addr = "127.0.0.1:8080".parse::<SocketAddr>().unwrap();
37 /// let router = Router::default();
38 /// let mut receiver = router.get_receiver(addr);
39 ///
40 /// router.send(&addr, ResponseMethod::ChannelData, &addr, &[1, 2, 3]);
41 /// let ret = receiver.recv().await.unwrap();
42 /// assert_eq!(ret.0, vec![1, 2, 3]);
43 /// assert_eq!(ret.1, ResponseMethod::ChannelData);
44 /// assert_eq!(ret.2, addr);
45 /// }
46 /// ```
47 pub fn get_receiver(&self, interface: SocketAddr) -> UnboundedReceiver<(Vec<u8>, ResponseMethod, SocketAddr)> {
48 let (sender, receiver) = unbounded_channel();
49 self.0.write().insert(interface, sender);
50 receiver
51 }
52
53 /// Send data to router.
54 ///
55 /// By specifying the socket identifier and destination address, the route
56 /// is forwarded to the corresponding socket. However, it should be noted
57 /// that calling this function will not notify whether the socket exists.
58 /// If it does not exist, the data will be discarded by default.
59 ///
60 /// # Example
61 ///
62 /// ```
63 /// use std::net::SocketAddr;
64 /// use turn_server::router::*;
65 /// use turn_server::turn::ResponseMethod;
66 ///
67 /// #[tokio::main]
68 /// async fn main() {
69 /// let addr = "127.0.0.1:8080".parse::<SocketAddr>().unwrap();
70 /// let router = Router::default();
71 /// let mut receiver = router.get_receiver(addr);
72 ///
73 /// router.send(&addr, ResponseMethod::ChannelData, &addr, &[1, 2, 3]);
74 /// let ret = receiver.recv().await.unwrap();
75 /// assert_eq!(ret.0, vec![1, 2, 3]);
76 /// assert_eq!(ret.1, ResponseMethod::ChannelData);
77 /// assert_eq!(ret.2, addr);
78 /// }
79 /// ```
80 pub fn send(&self, interface: &SocketAddr, method: ResponseMethod, addr: &SocketAddr, data: &[u8]) {
81 let mut is_destroy = false;
82
83 {
84 if let Some(sender) = self.0.read().get(interface) {
85 if sender.send((data.to_vec(), method, *addr)).is_err() {
86 is_destroy = true;
87 }
88 }
89 }
90
91 if is_destroy {
92 self.remove(interface);
93 }
94 }
95
96 /// delete socket.
97 ///
98 /// # Example
99 ///
100 /// ```
101 /// use std::net::SocketAddr;
102 /// use turn_server::router::*;
103 /// use turn_server::turn::ResponseMethod;
104 ///
105 /// #[tokio::main]
106 /// async fn main() {
107 /// let addr = "127.0.0.1:8080".parse::<SocketAddr>().unwrap();
108 /// let router = Router::default();
109 /// let mut receiver = router.get_receiver(addr);
110 ///
111 /// router.send(&addr, ResponseMethod::ChannelData, &addr, &[1, 2, 3]);
112 /// let ret = receiver.recv().await.unwrap();
113 /// assert_eq!(ret.0, vec![1, 2, 3]);
114 /// assert_eq!(ret.1, ResponseMethod::ChannelData);
115 /// assert_eq!(ret.2, addr);
116 ///
117 /// router.remove(&addr);
118 /// assert!(receiver.recv().await.is_none());
119 /// }
120 /// ```
121 pub fn remove(&self, interface: &SocketAddr) {
122 drop(self.0.write().remove(interface))
123 }
124}