turn_server/turn/
sessions.rs

1use super::Observer;
2use crate::stun::util::long_term_credential_digest;
3
4use std::{
5    hash::Hash,
6    net::SocketAddr,
7    ops::{Deref, DerefMut, Range},
8    sync::{
9        Arc,
10        atomic::{AtomicU64, Ordering},
11    },
12    thread::{self, sleep},
13    time::Duration,
14};
15
16use ahash::{HashMap, HashMapExt};
17use parking_lot::{Mutex, RwLock, RwLockReadGuard};
18use rand::{Rng, distributions::Alphanumeric, thread_rng};
19
/// Credentials attached to an authenticated session.
///
/// `integrity` caches the 16-byte long-term credential digest computed from the
/// username, password and realm when the session was created, so it does not
/// have to be recomputed for every request.
#[derive(Debug, Clone)]
pub struct Auth {
    /// Name the client authenticated with.
    pub username: String,
    /// Cached long-term credential digest for this user.
    pub integrity: [u8; 16],
}
29
/// Relay resources allocated to a session.
///
/// A session owns at most one relay port, and it may have any number of
/// channels bound to it.
#[derive(Debug, Clone)]
pub struct Allocate {
    /// The relay port, or `None` until an allocation has been made.
    pub port: Option<u16>,
    /// Channel numbers this session has bound.
    pub channels: Vec<u16>,
}
38
39/// turn session information.
40///
41/// A user can have many sessions.
42///
43/// The default survival time for a session is 600 seconds.
44#[derive(Debug, Clone)]
45pub struct Session {
46    pub auth: Auth,
47    pub allocate: Allocate,
48    pub permissions: Vec<u16>,
49    pub expires: u64,
50}
51
/// Key that identifies a session (and the nonce issued before it).
///
/// A session is identified by a pair of socket addresses: `address`, the remote
/// end of the connection, and `interface`, the local address it arrived on. In
/// this module's examples, `interface` is the server's 3478 listener.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionAddr {
    /// Remote address of the client.
    pub address: SocketAddr,
    /// Local interface address the client is talking to.
    pub interface: SocketAddr,
}
61
/// Forwarding target recorded for a session, used when relaying data.
///
/// `address` is the remote address of the session that data should be
/// delivered to. `endpoint` is the socket address supplied when the permission
/// or channel binding was created; in this module's examples, that is the
/// server interface.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Endpoint {
    /// Remote address of the receiving session.
    pub address: SocketAddr,
    /// Address supplied by the creator of the permission/binding.
    pub endpoint: SocketAddr,
}
70
/// A coarse clock that only moves when told to.
///
/// The timer never advances on its own. Some external driver must call
/// [`Timer::add`] once per tick; in this module, the background thread started
/// by `Sessions::new` does so once per second. All expiry times are expressed in
/// these ticks.
///
/// ```
/// use turn_server::turn::sessions::Timer;
///
/// let timer = Timer::default();
///
/// assert_eq!(timer.get(), 0);
/// assert_eq!(timer.add(), 1);
/// assert_eq!(timer.get(), 1);
/// ```
#[derive(Default)]
pub struct Timer(AtomicU64);

impl Timer {
    /// Returns the current tick count.
    pub fn get(&self) -> u64 {
        self.0.load(Ordering::Relaxed)
    }

    /// Advances the clock by one tick and returns the new tick count.
    pub fn add(&self) -> u64 {
        let previous = self.0.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }
}
97
98#[derive(Default)]
99pub struct State {
100    sessions: RwLock<Table<SessionAddr, Session>>,
101    port_allocate_pool: Mutex<PortAllocatePools>,
102    // Records the sessions corresponding to each assigned port, which will be needed when looking
103    // up sessions assigned to this port based on the port number.
104    port_mapping_table: RwLock<Table</* port */ u16, SessionAddr>>,
105    // Records the nonce value for each network connection, which is independent of the session
106    // because it can exist before it is authenticated.
107    address_nonce_tanle: RwLock<Table<SessionAddr, (String, /* expires */ u64)>>,
108    // Stores the address to which the session should be forwarded when it sends indication to a
109    // port. This is written when permissions are created to allow a certain address to be
110    // forwarded to the current session.
111    port_relay_table: RwLock<Table<SessionAddr, HashMap</* port */ u16, Endpoint>>>,
112    // Indicates to which session the data sent by a session to a channel should be forwarded.
113    channel_relay_table: RwLock<Table<SessionAddr, HashMap</* channel */ u16, Endpoint>>>,
114}
115
116pub struct Sessions<T> {
117    timer: Timer,
118    state: State,
119    observer: T,
120}
121
122impl<T: Observer + 'static> Sessions<T> {
123    pub fn new(observer: T) -> Arc<Self> {
124        let this = Arc::new(Self {
125            state: State::default(),
126            timer: Timer::default(),
127            observer,
128        });
129
130        // This is a background thread that silently handles expiring sessions and
131        // cleans up session information when it expires.
132        let this_ = Arc::downgrade(&this);
133        thread::spawn(move || {
134            let mut address = Vec::with_capacity(255);
135
136            while let Some(this) = this_.upgrade() {
137                // The timer advances one second and gets the current time offset.
138                let now = this.timer.add();
139
140                // This is the part that deletes the session information.
141                {
142                    // Finds sessions that have expired.
143                    {
144                        this.state
145                            .sessions
146                            .read()
147                            .iter()
148                            .filter(|(_, v)| v.expires <= now)
149                            .for_each(|(k, _)| address.push(*k));
150                    }
151
152                    // Delete the expired sessions.
153                    if !address.is_empty() {
154                        this.remove_session(&address);
155                        address.clear();
156                    }
157                }
158
159                // Because nonce does not follow session creation, nonce is created for each
160                // addr, so nonce deletion is handled independently.
161                {
162                    this.state
163                        .address_nonce_tanle
164                        .read()
165                        .iter()
166                        .filter(|(_, v)| v.1 <= now)
167                        .for_each(|(k, _)| address.push(*k));
168
169                    if !address.is_empty() {
170                        this.remove_nonce(&address);
171                        address.clear();
172                    }
173                }
174
175                // Fixing a second tick.
176                sleep(Duration::from_secs(1));
177            }
178        });
179
180        this
181    }
182
183    fn remove_session(&self, addrs: &[SessionAddr]) {
184        let mut sessions = self.state.sessions.write();
185        let mut port_allocate_pool = self.state.port_allocate_pool.lock();
186        let mut port_mapping_table = self.state.port_mapping_table.write();
187        let mut port_relay_table = self.state.port_relay_table.write();
188        let mut channel_relay_table = self.state.channel_relay_table.write();
189
190        addrs.iter().for_each(|k| {
191            port_relay_table.remove(k);
192            channel_relay_table.remove(k);
193
194            if let Some(session) = sessions.remove(k) {
195                // Removes the session-bound port from the port binding table and
196                // releases the port back into the allocation pool.
197                if let Some(port) = session.allocate.port {
198                    port_mapping_table.remove(&port);
199                    port_allocate_pool.restore(port);
200                }
201
202                // Notifies that the external session has been closed.
203                self.observer.closed(k, &session.auth.username);
204            }
205        });
206    }
207
208    fn remove_nonce(&self, addrs: &[SessionAddr]) {
209        let mut address_nonce_tanle = self.state.address_nonce_tanle.write();
210
211        addrs.iter().for_each(|k| {
212            address_nonce_tanle.remove(k);
213        });
214    }
215
216    /// Get session for addr.
217    ///
218    /// # Test
219    ///
220    /// ```
221    /// use turn_server::turn::*;
222    ///
223    /// #[derive(Clone)]
224    /// struct ObserverTest;
225    ///
226    /// impl Observer for ObserverTest {
227    ///     fn get_password(&self, username: &str) -> Option<String> {
228    ///         if username == "test" {
229    ///             Some("test".to_string())
230    ///         } else {
231    ///             None
232    ///         }
233    ///     }
234    /// }
235    ///
236    /// let addr = SessionAddr {
237    ///     address: "127.0.0.1:8080".parse().unwrap(),
238    ///     interface: "127.0.0.1:3478".parse().unwrap(),
239    /// };
240    ///
241    /// let digest = [
242    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
243    ///     239,
244    /// ];
245    ///
246    /// let sessions = Sessions::new(ObserverTest);
247    ///
248    /// assert!(sessions.get_session(&addr).get_ref().is_none());
249    ///
250    /// sessions.get_integrity(&addr, "test", "test");
251    ///
252    /// let lock = sessions.get_session(&addr);
253    /// let session = lock.get_ref().unwrap();
254    /// assert_eq!(session.auth.username, "test");
255    /// assert_eq!(session.allocate.port, None);
256    /// assert_eq!(session.allocate.channels.len(), 0);
257    /// ```
258    pub fn get_session<'a, 'b>(
259        &'a self,
260        key: &'b SessionAddr,
261    ) -> ReadLock<'b, 'a, SessionAddr, Table<SessionAddr, Session>> {
262        ReadLock {
263            lock: self.state.sessions.read(),
264            key,
265        }
266    }
267
268    /// Get nonce for addr.
269    ///
270    /// # Test
271    ///
272    /// ```
273    /// use turn_server::turn::*;
274    ///
275    /// #[derive(Clone)]
276    /// struct ObserverTest;
277    ///
278    /// impl Observer for ObserverTest {}
279    ///
280    /// let addr = SessionAddr {
281    ///     address: "127.0.0.1:8080".parse().unwrap(),
282    ///     interface: "127.0.0.1:3478".parse().unwrap(),
283    /// };
284    ///
285    /// let sessions = Sessions::new(ObserverTest);
286    ///
287    /// let a = sessions.get_nonce(&addr).get_ref().unwrap().clone();
288    /// assert!(a.0.len() == 16);
289    /// assert!(a.1 == 600 || a.1 == 601 || a.1 == 602);
290    ///
291    /// let b = sessions.get_nonce(&addr).get_ref().unwrap().clone();
292    /// assert_eq!(a.0, b.0);
293    /// assert!(b.1 == 600 || b.1 == 601 || b.1 == 602);
294    /// ```
295    pub fn get_nonce<'a, 'b>(
296        &'a self,
297        key: &'b SessionAddr,
298    ) -> ReadLock<'b, 'a, SessionAddr, Table<SessionAddr, (String, u64)>> {
299        // If no nonce is created, create a new one.
300        {
301            if !self.state.address_nonce_tanle.read().contains_key(key) {
302                self.state.address_nonce_tanle.write().insert(
303                    *key,
304                    (
305                        // A random string of length 16.
306                        {
307                            let mut rng = thread_rng();
308                            std::iter::repeat(())
309                                .map(|_| rng.sample(Alphanumeric) as char)
310                                .take(16)
311                                .collect::<String>()
312                                .to_lowercase()
313                        },
314                        // Current time stacks for 600 seconds.
315                        self.timer.get() + 600,
316                    ),
317                );
318            }
319        }
320
321        ReadLock {
322            lock: self.state.address_nonce_tanle.read(),
323            key,
324        }
325    }
326
327    /// Get digest for addr.
328    ///
329    /// # Test
330    ///
331    /// ```
332    /// use turn_server::turn::*;
333    ///
334    /// #[derive(Clone)]
335    /// struct ObserverTest;
336    ///
337    /// impl Observer for ObserverTest {
338    ///     fn get_password(&self, username: &str) -> Option<String> {
339    ///         if username == "test" {
340    ///             Some("test".to_string())
341    ///         } else {
342    ///             None
343    ///         }
344    ///     }
345    /// }
346    ///
347    /// let addr = SessionAddr {
348    ///     address: "127.0.0.1:8080".parse().unwrap(),
349    ///     interface: "127.0.0.1:3478".parse().unwrap(),
350    /// };
351    ///
352    /// let digest = [
353    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
354    ///     239,
355    /// ];
356    ///
357    /// let sessions = Sessions::new(ObserverTest);
358    ///
359    /// assert_eq!(sessions.get_integrity(&addr, "test1", "test"), None);
360    ///
361    /// assert_eq!(sessions.get_integrity(&addr, "test", "test"), Some(digest));
362    ///
363    /// assert_eq!(sessions.get_integrity(&addr, "test", "test"), Some(digest));
364    /// ```
365    pub fn get_integrity(&self, addr: &SessionAddr, username: &str, realm: &str) -> Option<[u8; 16]> {
366        // Already authenticated, get the cached digest directly.
367        {
368            if let Some(it) = self.state.sessions.read().get(addr) {
369                return Some(it.auth.integrity);
370            }
371        }
372
373        // Get the current user's password from an external observer and create a
374        // digest.
375        let password = self.observer.get_password(username)?;
376        let integrity = long_term_credential_digest(&username, &password, realm);
377
378        // Record a new session.
379        {
380            self.state.sessions.write().insert(
381                *addr,
382                Session {
383                    permissions: Vec::with_capacity(10),
384                    expires: self.timer.get() + 600,
385                    auth: Auth {
386                        username: username.to_string(),
387                        integrity,
388                    },
389                    allocate: Allocate {
390                        channels: Vec::with_capacity(10),
391                        port: None,
392                    },
393                },
394            );
395        }
396
397        Some(integrity)
398    }
399
400    pub fn allocated(&self) -> usize {
401        self.state.port_allocate_pool.lock().len()
402    }
403
404    /// Assign a port number to the session.
405    ///
406    /// # Test
407    ///
408    /// ```
409    /// use turn_server::turn::*;
410    ///
411    /// #[derive(Clone)]
412    /// struct ObserverTest;
413    ///
414    /// impl Observer for ObserverTest {
415    ///     fn get_password(&self, username: &str) -> Option<String> {
416    ///         if username == "test" {
417    ///             Some("test".to_string())
418    ///         } else {
419    ///             None
420    ///         }
421    ///     }
422    /// }
423    ///
424    /// let addr = SessionAddr {
425    ///     address: "127.0.0.1:8080".parse().unwrap(),
426    ///     interface: "127.0.0.1:3478".parse().unwrap(),
427    /// };
428    ///
429    /// let digest = [
430    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
431    ///     239,
432    /// ];
433    ///
434    /// let sessions = Sessions::new(ObserverTest);
435    ///
436    /// sessions.get_integrity(&addr, "test", "test");
437    ///
438    /// {
439    ///     let lock = sessions.get_session(&addr);
440    ///     let session = lock.get_ref().unwrap();
441    ///     assert_eq!(session.auth.username, "test");
442    ///     assert_eq!(session.allocate.port, None);
443    ///     assert_eq!(session.allocate.channels.len(), 0);
444    /// }
445    ///
446    /// let port = sessions.allocate(&addr).unwrap();
447    /// {
448    ///     let lock = sessions.get_session(&addr);
449    ///     let session = lock.get_ref().unwrap();
450    ///     assert_eq!(session.auth.username, "test");
451    ///     assert_eq!(session.allocate.port, Some(port));
452    ///     assert_eq!(session.allocate.channels.len(), 0);
453    /// }
454    ///
455    /// assert_eq!(sessions.allocate(&addr), Some(port));
456    /// ```
457    pub fn allocate(&self, addr: &SessionAddr) -> Option<u16> {
458        let mut lock = self.state.sessions.write();
459        let session = lock.get_mut(addr)?;
460
461        // If the port has already been allocated, re-allocation is not allowed.
462        if let Some(port) = session.allocate.port {
463            return Some(port);
464        }
465
466        // Records the port assigned to the current session and resets the alive time.
467        let port = self.state.port_allocate_pool.lock().alloc(None)?;
468        session.expires = self.timer.get() + 600;
469        session.allocate.port = Some(port);
470
471        // Write the allocation port binding table.
472        self.state.port_mapping_table.write().insert(port, *addr);
473        Some(port)
474    }
475
476    /// Create permission for session.
477    ///
478    /// # Test
479    ///
480    /// ```
481    /// use turn_server::turn::*;
482    ///
483    /// #[derive(Clone)]
484    /// struct ObserverTest;
485    ///
486    /// impl Observer for ObserverTest {
487    ///     fn get_password(&self, username: &str) -> Option<String> {
488    ///         if username == "test" {
489    ///             Some("test".to_string())
490    ///         } else {
491    ///             None
492    ///         }
493    ///     }
494    /// }
495    ///
496    /// let endpoint = "127.0.0.1:3478".parse().unwrap();
497    /// let addr = SessionAddr {
498    ///     address: "127.0.0.1:8080".parse().unwrap(),
499    ///     interface: "127.0.0.1:3478".parse().unwrap(),
500    /// };
501    ///
502    /// let peer_addr = SessionAddr {
503    ///     address: "127.0.0.1:8081".parse().unwrap(),
504    ///     interface: "127.0.0.1:3478".parse().unwrap(),
505    /// };
506    ///
507    /// let digest = [
508    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
509    ///     239,
510    /// ];
511    ///
512    /// let sessions = Sessions::new(ObserverTest);
513    ///
514    /// sessions.get_integrity(&addr, "test", "test");
515    /// sessions.get_integrity(&peer_addr, "test", "test");
516    ///
517    /// let port = sessions.allocate(&addr).unwrap();
518    /// let peer_port = sessions.allocate(&peer_addr).unwrap();
519    ///
520    /// assert!(!sessions.create_permission(&addr, &endpoint, &[port]));
521    /// assert!(sessions.create_permission(&addr, &endpoint, &[peer_port]));
522    ///
523    /// assert!(!sessions.create_permission(&peer_addr, &endpoint, &[peer_port]));
524    /// assert!(sessions.create_permission(&peer_addr, &endpoint, &[port]));
525    /// ```
526    pub fn create_permission(&self, addr: &SessionAddr, endpoint: &SocketAddr, ports: &[u16]) -> bool {
527        let mut sessions = self.state.sessions.write();
528        let mut port_relay_table = self.state.port_relay_table.write();
529        let port_mapping_table = self.state.port_mapping_table.read();
530
531        // Finds information about the current session.
532        let session = if let Some(it) = sessions.get_mut(addr) {
533            it
534        } else {
535            return false;
536        };
537
538        // The port number assigned to the current session.
539        let local_port = if let Some(it) = session.allocate.port {
540            it
541        } else {
542            return false;
543        };
544
545        // You cannot create permissions for yourself.
546        if ports.contains(&local_port) {
547            return false;
548        }
549
550        // Each peer port must be present.
551        let mut peers = Vec::with_capacity(15);
552        for port in ports {
553            if let Some(it) = port_mapping_table.get(&port) {
554                peers.push((it, *port));
555            } else {
556                return false;
557            }
558        }
559
560        // Create a port forwarding mapping relationship for each peer session.
561        for (peer, port) in peers {
562            port_relay_table
563                .entry(*peer)
564                .or_insert_with(|| HashMap::with_capacity(20))
565                .insert(
566                    local_port,
567                    Endpoint {
568                        address: addr.address,
569                        endpoint: *endpoint,
570                    },
571                );
572
573            // Do not store the same peer ports to the permission list over and over again.
574            if !session.permissions.contains(&port) {
575                session.permissions.push(port);
576            }
577        }
578
579        true
580    }
581
582    /// Binding a channel to the session.
583    ///
584    /// # Test
585    ///
586    /// ```
587    /// use turn_server::turn::*;
588    ///
589    /// #[derive(Clone)]
590    /// struct ObserverTest;
591    ///
592    /// impl Observer for ObserverTest {
593    ///     fn get_password(&self, username: &str) -> Option<String> {
594    ///         if username == "test" {
595    ///             Some("test".to_string())
596    ///         } else {
597    ///             None
598    ///         }
599    ///     }
600    /// }
601    ///
602    /// let endpoint = "127.0.0.1:3478".parse().unwrap();
603    /// let addr = SessionAddr {
604    ///     address: "127.0.0.1:8080".parse().unwrap(),
605    ///     interface: "127.0.0.1:3478".parse().unwrap(),
606    /// };
607    ///
608    /// let peer_addr = SessionAddr {
609    ///     address: "127.0.0.1:8081".parse().unwrap(),
610    ///     interface: "127.0.0.1:3478".parse().unwrap(),
611    /// };
612    ///
613    /// let digest = [
614    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
615    ///     239,
616    /// ];
617    ///
618    /// let sessions = Sessions::new(ObserverTest);
619    ///
620    /// sessions.get_integrity(&addr, "test", "test");
621    /// sessions.get_integrity(&peer_addr, "test", "test");
622    ///
623    /// let port = sessions.allocate(&addr).unwrap();
624    /// let peer_port = sessions.allocate(&peer_addr).unwrap();
625    /// assert_eq!(
626    ///     sessions
627    ///         .get_session(&addr)
628    ///         .get_ref()
629    ///         .unwrap()
630    ///         .allocate
631    ///         .channels
632    ///         .len(),
633    ///     0
634    /// );
635    ///
636    /// assert_eq!(
637    ///     sessions
638    ///         .get_session(&peer_addr)
639    ///         .get_ref()
640    ///         .unwrap()
641    ///         .allocate
642    ///         .channels
643    ///         .len(),
644    ///     0
645    /// );
646    ///
647    /// assert!(sessions.bind_channel(&addr, &endpoint, peer_port, 0x4000));
648    /// assert!(sessions.bind_channel(&peer_addr, &endpoint, port, 0x4000));
649    /// assert_eq!(
650    ///     sessions
651    ///         .get_session(&addr)
652    ///         .get_ref()
653    ///         .unwrap()
654    ///         .allocate
655    ///         .channels,
656    ///     vec![0x4000]
657    /// );
658    ///
659    /// assert_eq!(
660    ///     sessions
661    ///         .get_session(&peer_addr)
662    ///         .get_ref()
663    ///         .unwrap()
664    ///         .allocate
665    ///         .channels,
666    ///     vec![0x4000]
667    /// );
668    /// ```
669    pub fn bind_channel(&self, addr: &SessionAddr, endpoint: &SocketAddr, port: u16, channel: u16) -> bool {
670        // Finds the address of the bound opposing port.
671        let peer = if let Some(it) = self.state.port_mapping_table.read().get(&port) {
672            *it
673        } else {
674            return false;
675        };
676
677        // Records the channel used for the current session.
678        {
679            let mut lock = self.state.sessions.write();
680            let session = if let Some(it) = lock.get_mut(addr) {
681                it
682            } else {
683                return false;
684            };
685
686            if !session.allocate.channels.contains(&channel) {
687                session.allocate.channels.push(channel);
688            }
689        }
690
691        // Binding ports also creates permissions.
692        if !self.create_permission(addr, endpoint, &[port]) {
693            return false;
694        }
695
696        // Create channel forwarding mapping relationships for peers.
697        self.state
698            .channel_relay_table
699            .write()
700            .entry(peer)
701            .or_insert_with(|| HashMap::with_capacity(10))
702            .insert(
703                channel,
704                Endpoint {
705                    address: addr.address,
706                    endpoint: *endpoint,
707                },
708            );
709
710        true
711    }
712
713    /// Gets the peer of the current session bound channel.
714    ///
715    /// # Test
716    ///
717    /// ```
718    /// use turn_server::turn::*;
719    ///
720    /// #[derive(Clone)]
721    /// struct ObserverTest;
722    ///
723    /// impl Observer for ObserverTest {
724    ///     fn get_password(&self, username: &str) -> Option<String> {
725    ///         if username == "test" {
726    ///             Some("test".to_string())
727    ///         } else {
728    ///             None
729    ///         }
730    ///     }
731    /// }
732    ///
733    /// let endpoint = "127.0.0.1:3478".parse().unwrap();
734    /// let addr = SessionAddr {
735    ///     address: "127.0.0.1:8080".parse().unwrap(),
736    ///     interface: "127.0.0.1:3478".parse().unwrap(),
737    /// };
738    ///
739    /// let peer_addr = SessionAddr {
740    ///     address: "127.0.0.1:8081".parse().unwrap(),
741    ///     interface: "127.0.0.1:3478".parse().unwrap(),
742    /// };
743    ///
744    /// let digest = [
745    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
746    ///     239,
747    /// ];
748    ///
749    /// let sessions = Sessions::new(ObserverTest);
750    ///
751    /// sessions.get_integrity(&addr, "test", "test");
752    /// sessions.get_integrity(&peer_addr, "test", "test");
753    ///
754    /// let port = sessions.allocate(&addr).unwrap();
755    /// let peer_port = sessions.allocate(&peer_addr).unwrap();
756    ///
757    /// assert!(sessions.bind_channel(&addr, &endpoint, peer_port, 0x4000));
758    /// assert!(sessions.bind_channel(&peer_addr, &endpoint, port, 0x4000));
759    /// assert_eq!(
760    ///     sessions
761    ///         .get_channel_relay_address(&addr, 0x4000)
762    ///         .unwrap()
763    ///         .endpoint,
764    ///     endpoint
765    /// );
766    ///
767    /// assert_eq!(
768    ///     sessions
769    ///         .get_channel_relay_address(&peer_addr, 0x4000)
770    ///         .unwrap()
771    ///         .endpoint,
772    ///     endpoint
773    /// );
774    /// ```
775    pub fn get_channel_relay_address(&self, addr: &SessionAddr, channel: u16) -> Option<Endpoint> {
776        self.state.channel_relay_table.read().get(&addr)?.get(&channel).copied()
777    }
778
779    /// Get the address of the port binding.
780    ///
781    /// # Test
782    ///
783    /// ```
784    /// use turn_server::turn::*;
785    ///
786    /// #[derive(Clone)]
787    /// struct ObserverTest;
788    ///
789    /// impl Observer for ObserverTest {
790    ///     fn get_password(&self, username: &str) -> Option<String> {
791    ///         if username == "test" {
792    ///             Some("test".to_string())
793    ///         } else {
794    ///             None
795    ///         }
796    ///     }
797    /// }
798    ///
799    /// let endpoint = "127.0.0.1:3478".parse().unwrap();
800    /// let addr = SessionAddr {
801    ///     address: "127.0.0.1:8080".parse().unwrap(),
802    ///     interface: "127.0.0.1:3478".parse().unwrap(),
803    /// };
804    ///
805    /// let peer_addr = SessionAddr {
806    ///     address: "127.0.0.1:8081".parse().unwrap(),
807    ///     interface: "127.0.0.1:3478".parse().unwrap(),
808    /// };
809    ///
810    /// let digest = [
811    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
812    ///     239,
813    /// ];
814    ///
815    /// let sessions = Sessions::new(ObserverTest);
816    ///
817    /// sessions.get_integrity(&addr, "test", "test");
818    /// sessions.get_integrity(&peer_addr, "test", "test");
819    ///
820    /// let port = sessions.allocate(&addr).unwrap();
821    /// let peer_port = sessions.allocate(&peer_addr).unwrap();
822    ///
823    /// assert!(sessions.create_permission(&addr, &endpoint, &[peer_port]));
824    /// assert!(sessions.create_permission(&peer_addr, &endpoint, &[port]));
825    ///
826    /// assert_eq!(
827    ///     sessions
828    ///         .get_relay_address(&addr, peer_port)
829    ///         .unwrap()
830    ///         .endpoint,
831    ///     endpoint
832    /// );
833    ///
834    /// assert_eq!(
835    ///     sessions
836    ///         .get_relay_address(&peer_addr, port)
837    ///         .unwrap()
838    ///         .endpoint,
839    ///     endpoint
840    /// );
841    /// ```
842    pub fn get_relay_address(&self, addr: &SessionAddr, port: u16) -> Option<Endpoint> {
843        self.state.port_relay_table.read().get(&addr)?.get(&port).copied()
844    }
845
846    /// Refresh the session for addr.
847    ///
848    /// # Test
849    ///
850    /// ```
851    /// use turn_server::turn::*;
852    ///
853    /// #[derive(Clone)]
854    /// struct ObserverTest;
855    ///
856    /// impl Observer for ObserverTest {
857    ///     fn get_password(&self, username: &str) -> Option<String> {
858    ///         if username == "test" {
859    ///             Some("test".to_string())
860    ///         } else {
861    ///             None
862    ///         }
863    ///     }
864    /// }
865    ///
866    /// let addr = SessionAddr {
867    ///     address: "127.0.0.1:8080".parse().unwrap(),
868    ///     interface: "127.0.0.1:3478".parse().unwrap(),
869    /// };
870    ///
871    /// let digest = [
872    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
873    ///     239,
874    /// ];
875    ///
876    /// let sessions = Sessions::new(ObserverTest);
877    ///
878    /// assert!(sessions.get_session(&addr).get_ref().is_none());
879    ///
880    /// sessions.get_integrity(&addr, "test", "test");
881    ///
882    /// let expires = sessions.get_session(&addr).get_ref().unwrap().expires;
883    /// assert!(expires == 600 || expires == 601 || expires == 602);
884    ///
885    /// assert!(sessions.refresh(&addr, 0));
886    ///
887    /// assert!(sessions.get_session(&addr).get_ref().is_none());
888    /// ```
889    pub fn refresh(&self, addr: &SessionAddr, lifetime: u32) -> bool {
890        if lifetime > 3600 {
891            return false;
892        }
893
894        if lifetime == 0 {
895            self.remove_session(&[*addr]);
896            self.remove_nonce(&[*addr]);
897        } else {
898            if let Some(session) = self.state.sessions.write().get_mut(addr) {
899                session.expires = self.timer.get() + lifetime as u64;
900            } else {
901                return false;
902            }
903
904            if let Some(nonce) = self.state.address_nonce_tanle.write().get_mut(addr) {
905                nonce.1 = self.timer.get() + lifetime as u64;
906            }
907        }
908
909        true
910    }
911}
912
913/// The default HashMap is created without allocating capacity. To improve
914/// performance, the turn server needs to pre-allocate the available capacity.
915///
916/// So here the HashMap is rewrapped to allocate a large capacity (number of
917/// ports that can be allocated) at the default creation time as well.
918pub struct Table<K, V>(HashMap<K, V>);
919
920impl<K, V> Default for Table<K, V> {
921    fn default() -> Self {
922        Self(HashMap::with_capacity(PortAllocatePools::capacity()))
923    }
924}
925
926impl<K, V> AsRef<HashMap<K, V>> for Table<K, V> {
927    fn as_ref(&self) -> &HashMap<K, V> {
928        &self.0
929    }
930}
931
932impl<K, V> Deref for Table<K, V> {
933    type Target = HashMap<K, V>;
934
935    fn deref(&self) -> &Self::Target {
936        &self.0
937    }
938}
939
940impl<K, V> DerefMut for Table<K, V> {
941    fn deref_mut(&mut self) -> &mut Self::Target {
942        &mut self.0
943    }
944}
945
946/// Used to lengthen the timing of the release of a readable lock guard and to
947/// provide a more convenient way for external access to the lock's internal
948/// data.
949pub struct ReadLock<'a, 'b, K, R> {
950    key: &'a K,
951    lock: RwLockReadGuard<'b, R>,
952}
953
954impl<'a, 'b, K, V> ReadLock<'a, 'b, K, Table<K, V>>
955where
956    K: Eq + Hash,
957{
958    pub fn get_ref(&self) -> Option<&V> {
959        self.lock.get(self.key)
960    }
961}
962
/// Bit Flag
///
/// The state of one port slot in a `PortAllocatePools` bucket: `Low` is a
/// free slot, `High` an allocated one.
///
/// This is a plain two-state flag, so it is `Copy`. It derives `Debug` so it
/// can appear in diagnostics and `assert_eq!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bit {
    /// Slot is clear (port free).
    Low,
    /// Slot is set (port allocated).
    High,
}
969
/// Random Port
///
/// A bitmap-backed pool of relay ports.
///
/// Each `u64` in `buckets` tracks 64 consecutive ports, most-significant bit
/// first. Bit `63 - i` of bucket `b` stands for port
/// `port_range().start + b * 64 + i`, and a set bit means that port is in use.
/// `alloc(None)` picks a random bucket to start searching from, so the relay
/// port handed out is not simply the next sequential one. Inside a bucket,
/// the lowest free port is taken.
///
/// Background (text taken from RFC 6056, "Recommendations for
/// Transport-Protocol Port Randomization"):
///
/// Recently, awareness has been raised about a number of "blind" attacks
/// (i.e., attacks that can be performed without the need to sniff the
/// packets that correspond to the transport protocol instance to be
/// attacked) that can be performed against the Transmission Control
/// Protocol (TCP) [RFC0793] and similar protocols.  The consequences of
/// these attacks range from throughput reduction to broken connections
/// or data corruption [RFC5927] [RFC4953] [Watson].
///
/// All these attacks rely on the attacker's ability to guess or know the
/// five-tuple (Protocol, Source Address, Source port, Destination
/// Address, Destination Port) that identifies the transport protocol
/// instance to be attacked.
///
/// Services are usually located at fixed, "well-known" ports [IANA] at
/// the host supplying the service (the server).  Client applications
/// connecting to any such service will contact the server by specifying
/// the server IP address and service port number.  The IP address and
/// port number of the client are normally left unspecified by the client
/// application and thus are chosen automatically by the client
/// networking stack.  Ports chosen automatically by the networking stack
/// are known as ephemeral ports [Stevens].
///
/// While the server IP address, the well-known port, and the client IP
/// address may be known by an attacker, the ephemeral port of the client
/// is usually unknown and must be guessed.
///
/// # Test
///
/// Exhausting the pool yields `capacity() + 1` distinct ports. The bitmap
/// has `bucket_size() * 64` bits, and its final bit maps to port 65535,
/// one past the end of the half-open `port_range()`.
///
/// ```
/// use std::collections::HashSet;
/// use turn_server::turn::sessions::*;
///
/// let mut pool = PortAllocatePools::default();
/// let mut ports = HashSet::with_capacity(PortAllocatePools::capacity());
///
/// while let Some(port) = pool.alloc(None) {
///     ports.insert(port);
/// }
///
/// assert_eq!(PortAllocatePools::capacity() + 1, ports.len());
/// ```
pub struct PortAllocatePools {
    /// Allocation bitmap: one bit per port, MSB-first within each bucket.
    pub buckets: Vec<u64>,
    /// Number of ports currently marked as allocated.
    allocated: usize,
    /// Largest bit offset `alloc` will hand out in the final bucket.
    bit_len: u32,
    /// Index of the final bucket; `alloc` wraps back to 0 after it.
    peak: usize,
}
1019
1020impl Default for PortAllocatePools {
1021    fn default() -> Self {
1022        Self {
1023            buckets: vec![0; Self::bucket_size()],
1024            peak: Self::bucket_size() - 1,
1025            bit_len: Self::bit_len(),
1026            allocated: 0,
1027        }
1028    }
1029}
1030
1031impl PortAllocatePools {
1032    /// compute bucket size.
1033    ///
1034    /// # Test
1035    ///
1036    /// ```
1037    /// use turn_server::turn::sessions::*;
1038    ///
1039    /// assert_eq!(PortAllocatePools::bucket_size(), 256);
1040    /// ```
1041    pub fn bucket_size() -> usize {
1042        (Self::capacity() as f32 / 64.0).ceil() as usize
1043    }
1044
1045    /// compute bucket last bit max offset.
1046    ///
1047    /// # Test
1048    ///
1049    /// ```
1050    /// use turn_server::turn::sessions::*;
1051    ///
1052    /// assert_eq!(PortAllocatePools::bit_len(), 63);
1053    /// ```
1054    pub fn bit_len() -> u32 {
1055        (Self::capacity() as f32 % 64.0).ceil() as u32
1056    }
1057
1058    /// get pools capacity.
1059    ///
1060    /// # Test
1061    ///
1062    /// ```
1063    /// use turn_server::turn::sessions::Bit;
1064    /// use turn_server::turn::sessions::PortAllocatePools;
1065    ///
1066    /// assert_eq!(PortAllocatePools::capacity(), 65535 - 49152);
1067    /// ```
1068    pub const fn capacity() -> usize {
1069        65535 - 49152
1070    }
1071
1072    /// get port range.
1073    ///
1074    /// # Test
1075    ///
1076    /// ```
1077    /// use turn_server::turn::sessions::*;
1078    ///
1079    /// assert_eq!(PortAllocatePools::port_range(), 49152..65535);
1080    /// ```
1081    pub const fn port_range() -> Range<u16> {
1082        49152..65535
1083    }
1084
1085    /// get pools allocated size.
1086    ///
1087    /// ```
1088    /// use turn_server::turn::sessions::PortAllocatePools;
1089    ///
1090    /// let mut pools = PortAllocatePools::default();
1091    /// assert_eq!(pools.len(), 0);
1092    ///
1093    /// pools.alloc(None).unwrap();
1094    /// assert_eq!(pools.len(), 1);
1095    /// ```
1096    pub fn len(&self) -> usize {
1097        self.allocated
1098    }
1099
1100    /// get pools allocated size is empty.
1101    ///
1102    /// ```
1103    /// use turn_server::turn::sessions::PortAllocatePools;
1104    ///
1105    /// let mut pools = PortAllocatePools::default();
1106    /// assert_eq!(pools.len(), 0);
1107    /// assert_eq!(pools.is_empty(), true);
1108    /// ```
1109    pub fn is_empty(&self) -> bool {
1110        self.allocated == 0
1111    }
1112
1113    /// random assign a port.
1114    ///
1115    /// # Test
1116    ///
1117    /// ```
1118    /// use turn_server::turn::sessions::PortAllocatePools;
1119    ///
1120    /// let mut pool = PortAllocatePools::default();
1121    ///
1122    /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1123    /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1124    ///
1125    /// assert!(pool.alloc(None).is_some());
1126    /// ```
1127    pub fn alloc(&mut self, start_index: Option<usize>) -> Option<u16> {
1128        let mut index = None;
1129        let mut start = start_index.unwrap_or_else(|| thread_rng().gen_range(0..self.peak as u16) as usize);
1130
1131        // When the partition lookup has gone through the entire partition list, the
1132        // lookup should be stopped, and the location where it should be stopped is
1133        // recorded here.
1134        let previous = if start == 0 { self.peak } else { start - 1 };
1135
1136        loop {
1137            // Finds the first high position in the partition.
1138            if let Some(i) = {
1139                let bucket = self.buckets[start];
1140                if bucket < u64::MAX {
1141                    let offset = bucket.leading_ones();
1142
1143                    // Check to see if the jump is beyond the partition list or the lookup exceeds
1144                    // the maximum length of the allocation table.
1145                    if start == self.peak && offset > self.bit_len {
1146                        None
1147                    } else {
1148                        Some(offset)
1149                    }
1150                } else {
1151                    None
1152                }
1153            } {
1154                index = Some(i as usize);
1155                break;
1156            }
1157
1158            // As long as it doesn't find it, it continues to re-find it from the next
1159            // partition.
1160            if start == self.peak {
1161                start = 0;
1162            } else {
1163                start += 1;
1164            }
1165
1166            // Already gone through all partitions, lookup failed.
1167            if start == previous {
1168                break;
1169            }
1170        }
1171
1172        // Writes to the partition, marking the current location as already allocated.
1173        let index = index?;
1174        self.set_bit(start, index, Bit::High);
1175        self.allocated += 1;
1176
1177        // The actual port number is calculated from the partition offset position.
1178        let num = (start * 64 + index) as u16;
1179        let port = Self::port_range().start + num;
1180        Some(port)
1181    }
1182
1183    /// write bit flag in the bucket.
1184    ///
1185    /// # Test
1186    ///
1187    /// ```
1188    /// use turn_server::turn::sessions::Bit;
1189    /// use turn_server::turn::sessions::PortAllocatePools;
1190    ///
1191    /// let mut pool = PortAllocatePools::default();
1192    ///
1193    /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1194    /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1195    ///
1196    /// pool.set_bit(0, 0, Bit::High);
1197    /// pool.set_bit(0, 1, Bit::High);
1198    ///
1199    /// assert_eq!(pool.alloc(Some(0)), Some(49154));
1200    /// assert_eq!(pool.alloc(Some(0)), Some(49155));
1201    /// ```
1202    pub fn set_bit(&mut self, bucket: usize, index: usize, bit: Bit) {
1203        let high_mask = 1 << (63 - index);
1204        let mask = match bit {
1205            Bit::Low => u64::MAX ^ high_mask,
1206            Bit::High => high_mask,
1207        };
1208
1209        let value = self.buckets[bucket];
1210        self.buckets[bucket] = match bit {
1211            Bit::High => value | mask,
1212            Bit::Low => value & mask,
1213        };
1214    }
1215
1216    /// restore port in the buckets.
1217    ///
1218    /// # Test
1219    ///
1220    /// ```
1221    /// use turn_server::turn::sessions::PortAllocatePools;
1222    ///
1223    /// let mut pool = PortAllocatePools::default();
1224    ///
1225    /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1226    /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1227    ///
1228    /// pool.restore(49152);
1229    /// pool.restore(49153);
1230    ///
1231    /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1232    /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1233    /// ```
1234    pub fn restore(&mut self, port: u16) {
1235        assert!(Self::port_range().contains(&port));
1236
1237        // Calculate the location in the partition from the port number.
1238        let offset = (port - Self::port_range().start) as usize;
1239        let bucket = offset / 64;
1240        let index = offset - (bucket * 64);
1241
1242        // Gets the bit value in the port position in the partition, if it is low, no
1243        // processing is required.
1244        if {
1245            match (self.buckets[bucket] & (1 << (63 - index))) >> (63 - index) {
1246                0 => Bit::Low,
1247                1 => Bit::High,
1248                _ => panic!(),
1249            }
1250        } == Bit::Low
1251        {
1252            return;
1253        }
1254
1255        self.set_bit(bucket, index, Bit::Low);
1256        self.allocated -= 1;
1257    }
1258}