mycrl_turn/
sessions.rs

1use crate::Observer;
2
3use std::{
4    hash::Hash,
5    net::SocketAddr,
6    ops::{Deref, DerefMut, Range},
7    sync::{
8        atomic::{AtomicU64, Ordering},
9        Arc,
10    },
11    thread::{self, sleep},
12    time::Duration,
13};
14
15use ahash::{HashMap, HashMapExt};
16use parking_lot::{Mutex, RwLock, RwLockReadGuard};
17use rand::{distributions::Alphanumeric, thread_rng, Rng};
18use stun::util::long_term_credential_digest;
19
/// Credentials a session authenticated with.
///
/// `digest` is the long-term credential key computed from the username, realm and
/// password by `long_term_credential_digest`. It is cached here so it is derived only
/// once per session.
///
/// `Debug` is implemented by hand so that neither the password nor the digest (which is
/// derived from the password and presumably used as key material) ends up in logs. Any
/// `{:?}` of a containing `Session` is redacted as well.
#[derive(Clone)]
pub struct Auth {
    pub username: String,
    pub password: String,
    pub digest: [u8; 16],
}

impl std::fmt::Debug for Auth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the username is safe to print; both other fields are secrets.
        f.debug_struct("Auth")
            .field("username", &self.username)
            .field("password", &format_args!("<redacted>"))
            .field("digest", &format_args!("<redacted>"))
            .finish()
    }
}
30
/// Relay resources held by a session.
///
/// A session owns at most one relay `port` (`None` until one has been allocated) but may
/// bind any number of `channels`.
#[derive(Clone, Debug)]
pub struct Allocate {
    pub port: Option<u16>,
    pub channels: Vec<u16>,
}
39
40/// turn session information.
41///
42/// A user can have many sessions.
43///
44/// The default survival time for a session is 600 seconds.
45#[derive(Debug, Clone)]
46pub struct Session {
47    pub auth: Auth,
48    pub allocate: Allocate,
49    pub permissions: Vec<u16>,
50    pub expires: u64,
51}
52
/// Key identifying a session, or a not-yet-authenticated connection.
///
/// Two socket addresses make up the key: `address` and `interface`. Presumably these are
/// the client's address and the server interface the traffic arrived on; the examples in
/// this module use e.g. `127.0.0.1:8080` and `127.0.0.1:3478`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionAddr {
    pub address: SocketAddr,
    pub interface: SocketAddr,
}
62
/// Forwarding target stored in the relay tables.
///
/// `create_permission` and `bind_channel` fill it in:
/// - `address` is the address of the session that created the binding;
/// - `endpoint` is the socket address the caller passed to those methods.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Endpoint {
    pub address: SocketAddr,
    pub endpoint: SocketAddr,
}
71
/// A coarse clock that only moves when told to.
///
/// The counter never advances on its own. Its owner calls [`Timer::add`] once per tick
/// (the session store's background thread does so every second), and readers sample the
/// current tick with [`Timer::get`].
///
/// ```
/// use turn::sessions::Timer;
///
/// let timer = Timer::default();
///
/// assert_eq!(timer.get(), 0);
/// assert_eq!(timer.add(), 1);
/// assert_eq!(timer.get(), 1);
/// ```
#[derive(Default)]
pub struct Timer(AtomicU64);

impl Timer {
    /// Returns the current tick count.
    pub fn get(&self) -> u64 {
        let Timer(ticks) = self;
        ticks.load(Ordering::Relaxed)
    }

    /// Advances the clock by one tick and returns the new tick count.
    pub fn add(&self) -> u64 {
        let previous = self.0.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }
}
98
99#[derive(Default)]
100pub struct State {
101    sessions: RwLock<Table<SessionAddr, Session>>,
102    port_allocate_pool: Mutex<PortAllocatePools>,
103    // Records the sessions corresponding to each assigned port, which will be needed when looking
104    // up sessions assigned to this port based on the port number.
105    port_mapping_table: RwLock<Table</* port */ u16, SessionAddr>>,
106    // Records the nonce value for each network connection, which is independent of the session
107    // because it can exist before it is authenticated.
108    address_nonce_tanle: RwLock<Table<SessionAddr, (String, /* expires */ u64)>>,
109    // Stores the address to which the session should be forwarded when it sends indication to a
110    // port. This is written when permissions are created to allow a certain address to be
111    // forwarded to the current session.
112    port_relay_table: RwLock<Table<SessionAddr, HashMap</* port */ u16, Endpoint>>>,
113    // Indicates to which session the data sent by a session to a channel should be forwarded.
114    channel_relay_table: RwLock<Table<SessionAddr, HashMap</* channel */ u16, Endpoint>>>,
115}
116
117pub struct Sessions<T> {
118    timer: Timer,
119    state: State,
120    observer: T,
121}
122
123impl<T: Observer + 'static> Sessions<T> {
124    pub fn new(observer: T) -> Arc<Self> {
125        let this = Arc::new(Self {
126            state: State::default(),
127            timer: Timer::default(),
128            observer,
129        });
130
131        // This is a background thread that silently handles expiring sessions and
132        // cleans up session information when it expires.
133        let this_ = Arc::downgrade(&this);
134        thread::spawn(move || {
135            let mut address = Vec::with_capacity(255);
136
137            while let Some(this) = this_.upgrade() {
138                // The timer advances one second and gets the current time offset.
139                let now = this.timer.add();
140
141                // This is the part that deletes the session information.
142                {
143                    // Finds sessions that have expired.
144                    {
145                        this.state
146                            .sessions
147                            .read()
148                            .iter()
149                            .filter(|(_, v)| v.expires <= now)
150                            .for_each(|(k, _)| address.push(*k));
151                    }
152
153                    // Delete the expired sessions.
154                    if !address.is_empty() {
155                        this.remove_session(&address);
156                        address.clear();
157                    }
158                }
159
160                // Because nonce does not follow session creation, nonce is created for each
161                // addr, so nonce deletion is handled independently.
162                {
163                    this.state
164                        .address_nonce_tanle
165                        .read()
166                        .iter()
167                        .filter(|(_, v)| v.1 <= now)
168                        .for_each(|(k, _)| address.push(*k));
169
170                    if !address.is_empty() {
171                        this.remove_nonce(&address);
172                        address.clear();
173                    }
174                }
175
176                // Fixing a second tick.
177                sleep(Duration::from_secs(1));
178            }
179        });
180
181        this
182    }
183
184    fn remove_session(&self, addrs: &[SessionAddr]) {
185        let mut sessions = self.state.sessions.write();
186        let mut port_allocate_pool = self.state.port_allocate_pool.lock();
187        let mut port_mapping_table = self.state.port_mapping_table.write();
188        let mut port_relay_table = self.state.port_relay_table.write();
189        let mut channel_relay_table = self.state.channel_relay_table.write();
190
191        addrs.iter().for_each(|k| {
192            port_relay_table.remove(k);
193            channel_relay_table.remove(k);
194
195            if let Some(session) = sessions.remove(k) {
196                // Removes the session-bound port from the port binding table and
197                // releases the port back into the allocation pool.
198                if let Some(port) = session.allocate.port {
199                    port_mapping_table.remove(&port);
200                    port_allocate_pool.restore(port);
201                }
202
203                // Notifies that the external session has been closed.
204                self.observer.closed(k, &session.auth.username);
205            }
206        });
207    }
208
209    fn remove_nonce(&self, addrs: &[SessionAddr]) {
210        let mut address_nonce_tanle = self.state.address_nonce_tanle.write();
211
212        addrs.iter().for_each(|k| {
213            address_nonce_tanle.remove(k);
214        });
215    }
216
217    /// Get session for addr.
218    ///
219    /// # Test
220    ///
221    /// ```
222    /// use turn::*;
223    ///
224    /// #[derive(Clone)]
225    /// struct ObserverTest;
226    ///
227    /// impl Observer for ObserverTest {
228    ///     async fn get_password(
229    ///         &self,
230    ///         addr: &SessionAddr,
231    ///         username: &str,
232    ///     ) -> Option<String> {
233    ///         if username == "test" {
234    ///             Some("test".to_string())
235    ///         } else {
236    ///             None
237    ///         }
238    ///     }
239    /// }
240    ///
241    /// let addr = SessionAddr {
242    ///     address: "127.0.0.1:8080".parse().unwrap(),
243    ///     interface: "127.0.0.1:3478".parse().unwrap(),
244    /// };
245    ///
246    /// let digest = [
247    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
248    ///     239,
249    /// ];
250    ///
251    /// let sessions = Sessions::new(ObserverTest);
252    ///
253    /// assert!(sessions.get_session(&addr).get_ref().is_none());
254    ///
255    /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
256    ///
257    /// let lock = sessions.get_session(&addr);
258    /// let session = lock.get_ref().unwrap();
259    /// assert_eq!(session.auth.username, "test");
260    /// assert_eq!(session.auth.password, "test");
261    /// assert_eq!(session.allocate.port, None);
262    /// assert_eq!(session.allocate.channels.len(), 0);
263    /// ```
264    pub fn get_session<'a, 'b>(
265        &'a self,
266        key: &'b SessionAddr,
267    ) -> ReadLock<'b, 'a, SessionAddr, Table<SessionAddr, Session>> {
268        ReadLock {
269            lock: self.state.sessions.read(),
270            key,
271        }
272    }
273
274    /// Get nonce for addr.
275    ///
276    /// # Test
277    ///
278    /// ```
279    /// use turn::*;
280    ///
281    /// #[derive(Clone)]
282    /// struct ObserverTest;
283    ///
284    /// impl Observer for ObserverTest {}
285    ///
286    /// let addr = SessionAddr {
287    ///     address: "127.0.0.1:8080".parse().unwrap(),
288    ///     interface: "127.0.0.1:3478".parse().unwrap(),
289    /// };
290    ///
291    /// let sessions = Sessions::new(ObserverTest);
292    ///
293    /// let a = sessions.get_nonce(&addr).get_ref().unwrap().clone();
294    /// assert!(a.0.len() == 16);
295    /// assert!(a.1 == 600 || a.1 == 601 || a.1 == 602);
296    ///
297    /// let b = sessions.get_nonce(&addr).get_ref().unwrap().clone();
298    /// assert_eq!(a.0, b.0);
299    /// assert!(b.1 == 600 || b.1 == 601 || b.1 == 602);
300    /// ```
301    pub fn get_nonce<'a, 'b>(
302        &'a self,
303        key: &'b SessionAddr,
304    ) -> ReadLock<'b, 'a, SessionAddr, Table<SessionAddr, (String, u64)>> {
305        // If no nonce is created, create a new one.
306        {
307            if !self.state.address_nonce_tanle.read().contains_key(key) {
308                self.state.address_nonce_tanle.write().insert(
309                    *key,
310                    (
311                        // A random string of length 16.
312                        {
313                            let mut rng = thread_rng();
314                            std::iter::repeat(())
315                                .map(|_| rng.sample(Alphanumeric) as char)
316                                .take(16)
317                                .collect::<String>()
318                                .to_lowercase()
319                        },
320                        // Current time stacks for 600 seconds.
321                        self.timer.get() + 600,
322                    ),
323                );
324            }
325        }
326
327        ReadLock {
328            lock: self.state.address_nonce_tanle.read(),
329            key,
330        }
331    }
332
333    /// Get digest for addr.
334    ///
335    /// # Test
336    ///
337    /// ```
338    /// use turn::*;
339    ///
340    /// #[derive(Clone)]
341    /// struct ObserverTest;
342    ///
343    /// impl Observer for ObserverTest {
344    ///     async fn get_password(
345    ///         &self,
346    ///         addr: &SessionAddr,
347    ///         username: &str,
348    ///     ) -> Option<String> {
349    ///         if username == "test" {
350    ///             Some("test".to_string())
351    ///         } else {
352    ///             None
353    ///         }
354    ///     }
355    /// }
356    ///
357    /// let addr = SessionAddr {
358    ///     address: "127.0.0.1:8080".parse().unwrap(),
359    ///     interface: "127.0.0.1:3478".parse().unwrap(),
360    /// };
361    ///
362    /// let digest = [
363    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
364    ///     239,
365    /// ];
366    ///
367    /// let sessions = Sessions::new(ObserverTest);
368    ///
369    /// assert_eq!(
370    ///     pollster::block_on(sessions.get_digest(&addr, "test1", "test")),
371    ///     None
372    /// );
373    ///
374    /// assert_eq!(
375    ///     pollster::block_on(sessions.get_digest(&addr, "test", "test")),
376    ///     Some(digest)
377    /// );
378    ///
379    /// assert_eq!(
380    ///     pollster::block_on(sessions.get_digest(&addr, "test", "test")),
381    ///     Some(digest)
382    /// );
383    /// ```
384    pub async fn get_digest(
385        &self,
386        addr: &SessionAddr,
387        username: &str,
388        realm: &str,
389    ) -> Option<[u8; 16]> {
390        // Already authenticated, get the cached digest directly.
391        {
392            if let Some(it) = self.state.sessions.read().get(addr) {
393                return Some(it.auth.digest);
394            }
395        }
396
397        // Get the current user's password from an external observer and create a
398        // digest.
399        let password = self.observer.get_password(addr, username).await?;
400        let digest = long_term_credential_digest(&username, &password, realm);
401
402        // Record a new session.
403        {
404            self.state.sessions.write().insert(
405                *addr,
406                Session {
407                    permissions: Vec::with_capacity(10),
408                    expires: self.timer.get() + 600,
409                    auth: Auth {
410                        username: username.to_string(),
411                        password,
412                        digest,
413                    },
414                    allocate: Allocate {
415                        channels: Vec::with_capacity(10),
416                        port: None,
417                    },
418                },
419            );
420        }
421
422        Some(digest)
423    }
424
425    pub fn allocated(&self) -> usize {
426        self.state.port_allocate_pool.lock().len()
427    }
428
429    /// Assign a port number to the session.
430    ///
431    /// # Test
432    ///
433    /// ```
434    /// use turn::*;
435    ///
436    /// #[derive(Clone)]
437    /// struct ObserverTest;
438    ///
439    /// impl Observer for ObserverTest {
440    ///     async fn get_password(
441    ///         &self,
442    ///         addr: &SessionAddr,
443    ///         username: &str,
444    ///     ) -> Option<String> {
445    ///         if username == "test" {
446    ///             Some("test".to_string())
447    ///         } else {
448    ///             None
449    ///         }
450    ///     }
451    /// }
452    ///
453    /// let addr = SessionAddr {
454    ///     address: "127.0.0.1:8080".parse().unwrap(),
455    ///     interface: "127.0.0.1:3478".parse().unwrap(),
456    /// };
457    ///
458    /// let digest = [
459    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
460    ///     239,
461    /// ];
462    ///
463    /// let sessions = Sessions::new(ObserverTest);
464    ///
465    /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
466    ///
467    /// {
468    ///     let lock = sessions.get_session(&addr);
469    ///     let session = lock.get_ref().unwrap();
470    ///     assert_eq!(session.auth.username, "test");
471    ///     assert_eq!(session.auth.password, "test");
472    ///     assert_eq!(session.allocate.port, None);
473    ///     assert_eq!(session.allocate.channels.len(), 0);
474    /// }
475    ///
476    /// let port = sessions.allocate(&addr).unwrap();
477    /// {
478    ///     let lock = sessions.get_session(&addr);
479    ///     let session = lock.get_ref().unwrap();
480    ///     assert_eq!(session.auth.username, "test");
481    ///     assert_eq!(session.auth.password, "test");
482    ///     assert_eq!(session.allocate.port, Some(port));
483    ///     assert_eq!(session.allocate.channels.len(), 0);
484    /// }
485    ///
486    /// assert_eq!(sessions.allocate(&addr), Some(port));
487    /// ```
488    pub fn allocate(&self, addr: &SessionAddr) -> Option<u16> {
489        let mut lock = self.state.sessions.write();
490        let session = lock.get_mut(addr)?;
491
492        // If the port has already been allocated, re-allocation is not allowed.
493        if let Some(port) = session.allocate.port {
494            return Some(port);
495        }
496
497        // Records the port assigned to the current session and resets the alive time.
498        let port = self.state.port_allocate_pool.lock().alloc(None)?;
499        session.expires = self.timer.get() + 600;
500        session.allocate.port = Some(port);
501
502        // Write the allocation port binding table.
503        self.state.port_mapping_table.write().insert(port, *addr);
504        Some(port)
505    }
506
507    /// Create permission for session.
508    ///
509    /// # Test
510    ///
511    /// ```
512    /// use turn::*;
513    ///
514    /// #[derive(Clone)]
515    /// struct ObserverTest;
516    ///
517    /// impl Observer for ObserverTest {
518    ///     async fn get_password(
519    ///         &self,
520    ///         addr: &SessionAddr,
521    ///         username: &str,
522    ///     ) -> Option<String> {
523    ///         if username == "test" {
524    ///             Some("test".to_string())
525    ///         } else {
526    ///             None
527    ///         }
528    ///     }
529    /// }
530    ///
531    /// let endpoint = "127.0.0.1:3478".parse().unwrap();
532    /// let addr = SessionAddr {
533    ///     address: "127.0.0.1:8080".parse().unwrap(),
534    ///     interface: "127.0.0.1:3478".parse().unwrap(),
535    /// };
536    ///
537    /// let peer_addr = SessionAddr {
538    ///     address: "127.0.0.1:8081".parse().unwrap(),
539    ///     interface: "127.0.0.1:3478".parse().unwrap(),
540    /// };
541    ///
542    /// let digest = [
543    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
544    ///     239,
545    /// ];
546    ///
547    /// let sessions = Sessions::new(ObserverTest);
548    ///
549    /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
550    /// pollster::block_on(sessions.get_digest(&peer_addr, "test", "test"));
551    ///
552    /// let port = sessions.allocate(&addr).unwrap();
553    /// let peer_port = sessions.allocate(&peer_addr).unwrap();
554    ///
555    /// assert!(!sessions.create_permission(&addr, &endpoint, &[port]));
556    /// assert!(sessions.create_permission(&addr, &endpoint, &[peer_port]));
557    ///
558    /// assert!(!sessions.create_permission(&peer_addr, &endpoint, &[peer_port]));
559    /// assert!(sessions.create_permission(&peer_addr, &endpoint, &[port]));
560    /// ```
561    pub fn create_permission(
562        &self,
563        addr: &SessionAddr,
564        endpoint: &SocketAddr,
565        ports: &[u16],
566    ) -> bool {
567        let mut sessions = self.state.sessions.write();
568        let mut port_relay_table = self.state.port_relay_table.write();
569        let port_mapping_table = self.state.port_mapping_table.read();
570
571        // Finds information about the current session.
572        let session = if let Some(it) = sessions.get_mut(addr) {
573            it
574        } else {
575            return false;
576        };
577
578        // The port number assigned to the current session.
579        let local_port = if let Some(it) = session.allocate.port {
580            it
581        } else {
582            return false;
583        };
584
585        // You cannot create permissions for yourself.
586        if ports.contains(&local_port) {
587            return false;
588        }
589
590        // Each peer port must be present.
591        let mut peers = Vec::with_capacity(15);
592        for port in ports {
593            if let Some(it) = port_mapping_table.get(&port) {
594                peers.push((it, *port));
595            } else {
596                return false;
597            }
598        }
599
600        // Create a port forwarding mapping relationship for each peer session.
601        for (peer, port) in peers {
602            port_relay_table
603                .entry(*peer)
604                .or_insert_with(|| HashMap::with_capacity(20))
605                .insert(
606                    local_port,
607                    Endpoint {
608                        address: addr.address,
609                        endpoint: *endpoint,
610                    },
611                );
612
613            // Do not store the same peer ports to the permission list over and over again.
614            if !session.permissions.contains(&port) {
615                session.permissions.push(port);
616            }
617        }
618
619        true
620    }
621
622    /// Binding a channel to the session.
623    ///
624    /// # Test
625    ///
626    /// ```
627    /// use turn::*;
628    ///
629    /// #[derive(Clone)]
630    /// struct ObserverTest;
631    ///
632    /// impl Observer for ObserverTest {
633    ///     async fn get_password(
634    ///         &self,
635    ///         addr: &SessionAddr,
636    ///         username: &str,
637    ///     ) -> Option<String> {
638    ///         if username == "test" {
639    ///             Some("test".to_string())
640    ///         } else {
641    ///             None
642    ///         }
643    ///     }
644    /// }
645    ///
646    /// let endpoint = "127.0.0.1:3478".parse().unwrap();
647    /// let addr = SessionAddr {
648    ///     address: "127.0.0.1:8080".parse().unwrap(),
649    ///     interface: "127.0.0.1:3478".parse().unwrap(),
650    /// };
651    ///
652    /// let peer_addr = SessionAddr {
653    ///     address: "127.0.0.1:8081".parse().unwrap(),
654    ///     interface: "127.0.0.1:3478".parse().unwrap(),
655    /// };
656    ///
657    /// let digest = [
658    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
659    ///     239,
660    /// ];
661    ///
662    /// let sessions = Sessions::new(ObserverTest);
663    ///
664    /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
665    /// pollster::block_on(sessions.get_digest(&peer_addr, "test", "test"));
666    ///
667    /// let port = sessions.allocate(&addr).unwrap();
668    /// let peer_port = sessions.allocate(&peer_addr).unwrap();
669    /// assert_eq!(
670    ///     sessions
671    ///         .get_session(&addr)
672    ///         .get_ref()
673    ///         .unwrap()
674    ///         .allocate
675    ///         .channels
676    ///         .len(),
677    ///     0
678    /// );
679    ///
680    /// assert_eq!(
681    ///     sessions
682    ///         .get_session(&peer_addr)
683    ///         .get_ref()
684    ///         .unwrap()
685    ///         .allocate
686    ///         .channels
687    ///         .len(),
688    ///     0
689    /// );
690    ///
691    /// assert!(sessions.bind_channel(&addr, &endpoint, peer_port, 0x4000));
692    /// assert!(sessions.bind_channel(&peer_addr, &endpoint, port, 0x4000));
693    /// assert_eq!(
694    ///     sessions
695    ///         .get_session(&addr)
696    ///         .get_ref()
697    ///         .unwrap()
698    ///         .allocate
699    ///         .channels,
700    ///     vec![0x4000]
701    /// );
702    ///
703    /// assert_eq!(
704    ///     sessions
705    ///         .get_session(&peer_addr)
706    ///         .get_ref()
707    ///         .unwrap()
708    ///         .allocate
709    ///         .channels,
710    ///     vec![0x4000]
711    /// );
712    /// ```
713    pub fn bind_channel(
714        &self,
715        addr: &SessionAddr,
716        endpoint: &SocketAddr,
717        port: u16,
718        channel: u16,
719    ) -> bool {
720        // Finds the address of the bound opposing port.
721        let peer = if let Some(it) = self.state.port_mapping_table.read().get(&port) {
722            *it
723        } else {
724            return false;
725        };
726
727        // Records the channel used for the current session.
728        {
729            let mut lock = self.state.sessions.write();
730            let session = if let Some(it) = lock.get_mut(addr) {
731                it
732            } else {
733                return false;
734            };
735
736            if !session.allocate.channels.contains(&channel) {
737                session.allocate.channels.push(channel);
738            }
739        }
740
741        // Binding ports also creates permissions.
742        if !self.create_permission(addr, endpoint, &[port]) {
743            return false;
744        }
745
746        // Create channel forwarding mapping relationships for peers.
747        self.state
748            .channel_relay_table
749            .write()
750            .entry(peer)
751            .or_insert_with(|| HashMap::with_capacity(10))
752            .insert(
753                channel,
754                Endpoint {
755                    address: addr.address,
756                    endpoint: *endpoint,
757                },
758            );
759
760        true
761    }
762
763    /// Gets the peer of the current session bound channel.
764    ///
765    /// # Test
766    ///
767    /// ```
768    /// use turn::*;
769    ///
770    /// #[derive(Clone)]
771    /// struct ObserverTest;
772    ///
773    /// impl Observer for ObserverTest {
774    ///     async fn get_password(
775    ///         &self,
776    ///         addr: &SessionAddr,
777    ///         username: &str,
778    ///     ) -> Option<String> {
779    ///         if username == "test" {
780    ///             Some("test".to_string())
781    ///         } else {
782    ///             None
783    ///         }
784    ///     }
785    /// }
786    ///
787    /// let endpoint = "127.0.0.1:3478".parse().unwrap();
788    /// let addr = SessionAddr {
789    ///     address: "127.0.0.1:8080".parse().unwrap(),
790    ///     interface: "127.0.0.1:3478".parse().unwrap(),
791    /// };
792    ///
793    /// let peer_addr = SessionAddr {
794    ///     address: "127.0.0.1:8081".parse().unwrap(),
795    ///     interface: "127.0.0.1:3478".parse().unwrap(),
796    /// };
797    ///
798    /// let digest = [
799    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
800    ///     239,
801    /// ];
802    ///
803    /// let sessions = Sessions::new(ObserverTest);
804    ///
805    /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
806    /// pollster::block_on(sessions.get_digest(&peer_addr, "test", "test"));
807    ///
808    /// let port = sessions.allocate(&addr).unwrap();
809    /// let peer_port = sessions.allocate(&peer_addr).unwrap();
810    ///
811    /// assert!(sessions.bind_channel(&addr, &endpoint, peer_port, 0x4000));
812    /// assert!(sessions.bind_channel(&peer_addr, &endpoint, port, 0x4000));
813    /// assert_eq!(
814    ///     sessions
815    ///         .get_channel_relay_address(&addr, 0x4000)
816    ///         .unwrap()
817    ///         .endpoint,
818    ///     endpoint
819    /// );
820    ///
821    /// assert_eq!(
822    ///     sessions
823    ///         .get_channel_relay_address(&peer_addr, 0x4000)
824    ///         .unwrap()
825    ///         .endpoint,
826    ///     endpoint
827    /// );
828    /// ```
829    pub fn get_channel_relay_address(&self, addr: &SessionAddr, channel: u16) -> Option<Endpoint> {
830        self.state
831            .channel_relay_table
832            .read()
833            .get(&addr)?
834            .get(&channel)
835            .copied()
836    }
837
838    /// Get the address of the port binding.
839    ///
840    /// # Test
841    ///
842    /// ```
843    /// use turn::*;
844    ///
845    /// #[derive(Clone)]
846    /// struct ObserverTest;
847    ///
848    /// impl Observer for ObserverTest {
849    ///     async fn get_password(
850    ///         &self,
851    ///         addr: &SessionAddr,
852    ///         username: &str,
853    ///     ) -> Option<String> {
854    ///         if username == "test" {
855    ///             Some("test".to_string())
856    ///         } else {
857    ///             None
858    ///         }
859    ///     }
860    /// }
861    ///
862    /// let endpoint = "127.0.0.1:3478".parse().unwrap();
863    /// let addr = SessionAddr {
864    ///     address: "127.0.0.1:8080".parse().unwrap(),
865    ///     interface: "127.0.0.1:3478".parse().unwrap(),
866    /// };
867    ///
868    /// let peer_addr = SessionAddr {
869    ///     address: "127.0.0.1:8081".parse().unwrap(),
870    ///     interface: "127.0.0.1:3478".parse().unwrap(),
871    /// };
872    ///
873    /// let digest = [
874    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
875    ///     239,
876    /// ];
877    ///
878    /// let sessions = Sessions::new(ObserverTest);
879    ///
880    /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
881    /// pollster::block_on(sessions.get_digest(&peer_addr, "test", "test"));
882    ///
883    /// let port = sessions.allocate(&addr).unwrap();
884    /// let peer_port = sessions.allocate(&peer_addr).unwrap();
885    ///
886    /// assert!(sessions.create_permission(&addr, &endpoint, &[peer_port]));
887    /// assert!(sessions.create_permission(&peer_addr, &endpoint, &[port]));
888    ///
889    /// assert_eq!(
890    ///     sessions
891    ///         .get_relay_address(&addr, peer_port)
892    ///         .unwrap()
893    ///         .endpoint,
894    ///     endpoint
895    /// );
896    ///
897    /// assert_eq!(
898    ///     sessions
899    ///         .get_relay_address(&peer_addr, port)
900    ///         .unwrap()
901    ///         .endpoint,
902    ///     endpoint
903    /// );
904    /// ```
905    pub fn get_relay_address(&self, addr: &SessionAddr, port: u16) -> Option<Endpoint> {
906        self.state
907            .port_relay_table
908            .read()
909            .get(&addr)?
910            .get(&port)
911            .copied()
912    }
913
914    /// Refresh the session for addr.
915    ///
916    /// # Test
917    ///
918    /// ```
919    /// use turn::*;
920    ///
921    /// #[derive(Clone)]
922    /// struct ObserverTest;
923    ///
924    /// impl Observer for ObserverTest {
925    ///     async fn get_password(
926    ///         &self,
927    ///         addr: &SessionAddr,
928    ///         username: &str,
929    ///     ) -> Option<String> {
930    ///         if username == "test" {
931    ///             Some("test".to_string())
932    ///         } else {
933    ///             None
934    ///         }
935    ///     }
936    /// }
937    ///
938    /// let addr = SessionAddr {
939    ///     address: "127.0.0.1:8080".parse().unwrap(),
940    ///     interface: "127.0.0.1:3478".parse().unwrap(),
941    /// };
942    ///
943    /// let digest = [
944    ///     174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
945    ///     239,
946    /// ];
947    ///
948    /// let sessions = Sessions::new(ObserverTest);
949    ///
950    /// assert!(sessions.get_session(&addr).get_ref().is_none());
951    ///
952    /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
953    ///
954    /// let expires = sessions.get_session(&addr).get_ref().unwrap().expires;
955    /// assert!(expires == 600 || expires == 601 || expires == 602);
956    ///
957    /// assert!(sessions.refresh(&addr, 0));
958    ///
959    /// assert!(sessions.get_session(&addr).get_ref().is_none());
960    /// ```
961    pub fn refresh(&self, addr: &SessionAddr, lifetime: u32) -> bool {
962        if lifetime > 3600 {
963            return false;
964        }
965
966        if lifetime == 0 {
967            self.remove_session(&[*addr]);
968            self.remove_nonce(&[*addr]);
969        } else {
970            if let Some(session) = self.state.sessions.write().get_mut(addr) {
971                session.expires = self.timer.get() + lifetime as u64;
972            } else {
973                return false;
974            }
975
976            if let Some(nonce) = self.state.address_nonce_tanle.write().get_mut(addr) {
977                nonce.1 = self.timer.get() + lifetime as u64;
978            }
979        }
980
981        true
982    }
983}
984
985/// The default HashMap is created without allocating capacity. To improve
986/// performance, the turn server needs to pre-allocate the available capacity.
987///
988/// So here the HashMap is rewrapped to allocate a large capacity (number of
989/// ports that can be allocated) at the default creation time as well.
990pub struct Table<K, V>(HashMap<K, V>);
991
992impl<K, V> Default for Table<K, V> {
993    fn default() -> Self {
994        Self(HashMap::with_capacity(PortAllocatePools::capacity()))
995    }
996}
997
998impl<K, V> AsRef<HashMap<K, V>> for Table<K, V> {
999    fn as_ref(&self) -> &HashMap<K, V> {
1000        &self.0
1001    }
1002}
1003
1004impl<K, V> Deref for Table<K, V> {
1005    type Target = HashMap<K, V>;
1006
1007    fn deref(&self) -> &Self::Target {
1008        &self.0
1009    }
1010}
1011
1012impl<K, V> DerefMut for Table<K, V> {
1013    fn deref_mut(&mut self) -> &mut Self::Target {
1014        &mut self.0
1015    }
1016}
1017
1018/// Used to lengthen the timing of the release of a readable lock guard and to
1019/// provide a more convenient way for external access to the lock's internal
1020/// data.
1021pub struct ReadLock<'a, 'b, K, R> {
1022    key: &'a K,
1023    lock: RwLockReadGuard<'b, R>,
1024}
1025
1026impl<'a, 'b, K, V> ReadLock<'a, 'b, K, Table<K, V>>
1027where
1028    K: Eq + Hash,
1029{
1030    pub fn get_ref(&self) -> Option<&V> {
1031        self.lock.get(self.key)
1032    }
1033}
1034
/// State of a single port slot inside a [`PortAllocatePools`] bucket.
///
/// The pool writes `High` when a port is allocated and `Low` when it is
/// restored. The derives make this small public value enum printable and
/// freely copyable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Bit {
    /// The slot is free.
    Low,
    /// The slot is allocated.
    High,
}
1041
/// Bitmap-backed pool of relay ports.
///
/// Each port in `PortAllocatePools::port_range()` owns one bit in `buckets`.
/// Every `u64` bucket covers 64 consecutive ports, and the most significant
/// bit stands for the first port of the bucket. A set bit means the port is
/// allocated. Unless a starting bucket is given, allocation begins at a random
/// bucket so that assigned ports are hard to guess.
///
/// # Background: random ports (RFC 6056)
///
/// Recently, awareness has been raised about a number of "blind" attacks
/// (i.e., attacks that can be performed without the need to sniff the
/// packets that correspond to the transport protocol instance to be
/// attacked) that can be performed against the Transmission Control
/// Protocol (TCP) [RFC0793] and similar protocols.  The consequences of
/// these attacks range from throughput reduction to broken connections
/// or data corruption [RFC5927] [RFC4953] [Watson].
///
/// All these attacks rely on the attacker's ability to guess or know the
/// five-tuple (Protocol, Source Address, Source port, Destination
/// Address, Destination Port) that identifies the transport protocol
/// instance to be attacked.
///
/// Services are usually located at fixed, "well-known" ports [IANA] at
/// the host supplying the service (the server).  Client applications
/// connecting to any such service will contact the server by specifying
/// the server IP address and service port number.  The IP address and
/// port number of the client are normally left unspecified by the client
/// application and thus are chosen automatically by the client
/// networking stack.  Ports chosen automatically by the networking stack
/// are known as ephemeral ports [Stevens].
///
/// While the server IP address, the well-known port, and the client IP
/// address may be known by an attacker, the ephemeral port of the client
/// is usually unknown and must be guessed.
pub struct PortAllocatePools {
    /// Number of ports currently handed out.
    allocated: usize,
    /// Index of the last bucket (`bucket_size() - 1` for a default pool).
    peak: usize,
    /// Bit-length bound for the last bucket, as computed by
    /// `PortAllocatePools::bit_len()`.
    bit_len: u32,
    /// One bit per port; a set bit marks the port as allocated.
    pub buckets: Vec<u64>,
}
1075
1076impl Default for PortAllocatePools {
1077    fn default() -> Self {
1078        Self {
1079            buckets: vec![0; Self::bucket_size()],
1080            peak: Self::bucket_size() - 1,
1081            bit_len: Self::bit_len(),
1082            allocated: 0,
1083        }
1084    }
1085}
1086
1087impl PortAllocatePools {
1088    /// compute bucket size.
1089    ///
1090    /// # Test
1091    ///
1092    /// ```
1093    /// use turn::sessions::*;
1094    ///
1095    /// assert_eq!(PortAllocatePools::bucket_size(), 256);
1096    /// ```
1097    pub fn bucket_size() -> usize {
1098        (Self::capacity() as f32 / 64.0).ceil() as usize
1099    }
1100
1101    /// compute bucket last bit max offset.
1102    ///
1103    /// # Test
1104    ///
1105    /// ```
1106    /// use turn::sessions::*;
1107    ///
1108    /// assert_eq!(PortAllocatePools::bit_len(), 63);
1109    /// ```
1110    pub fn bit_len() -> u32 {
1111        (Self::capacity() as f32 % 64.0).ceil() as u32
1112    }
1113
1114    /// get pools capacity.
1115    ///
1116    /// # Test
1117    ///
1118    /// ```
1119    /// use turn::sessions::Bit;
1120    /// use turn::sessions::PortAllocatePools;
1121    ///
1122    /// assert_eq!(PortAllocatePools::capacity(), 65535 - 49152);
1123    /// ```
1124    pub const fn capacity() -> usize {
1125        65535 - 49152
1126    }
1127
1128    /// get port range.
1129    ///
1130    /// # Test
1131    ///
1132    /// ```
1133    /// use turn::sessions::*;
1134    ///
1135    /// assert_eq!(PortAllocatePools::port_range(), 49152..65535);
1136    /// ```
1137    pub const fn port_range() -> Range<u16> {
1138        49152..65535
1139    }
1140
1141    /// get pools allocated size.
1142    ///
1143    /// ```
1144    /// use turn::sessions::PortAllocatePools;
1145    ///
1146    /// let mut pools = PortAllocatePools::default();
1147    /// assert_eq!(pools.len(), 0);
1148    ///
1149    /// pools.alloc(None).unwrap();
1150    /// assert_eq!(pools.len(), 1);
1151    /// ```
1152    pub fn len(&self) -> usize {
1153        self.allocated
1154    }
1155
1156    /// get pools allocated size is empty.
1157    ///
1158    /// ```
1159    /// use turn::sessions::PortAllocatePools;
1160    ///
1161    /// let mut pools = PortAllocatePools::default();
1162    /// assert_eq!(pools.len(), 0);
1163    /// assert_eq!(pools.is_empty(), true);
1164    /// ```
1165    pub fn is_empty(&self) -> bool {
1166        self.allocated == 0
1167    }
1168
1169    /// random assign a port.
1170    ///
1171    /// # Test
1172    ///
1173    /// ```
1174    /// use turn::sessions::PortAllocatePools;
1175    ///
1176    /// let mut pool = PortAllocatePools::default();
1177    ///
1178    /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1179    /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1180    ///
1181    /// assert!(pool.alloc(None).is_some());
1182    /// ```
1183    pub fn alloc(&mut self, start_index: Option<usize>) -> Option<u16> {
1184        let mut index = None;
1185        let mut start =
1186            start_index.unwrap_or_else(|| thread_rng().gen_range(0..self.peak as u16) as usize);
1187
1188        // When the partition lookup has gone through the entire partition list, the
1189        // lookup should be stopped, and the location where it should be stopped is
1190        // recorded here.
1191        let previous = if start == 0 { self.peak } else { start - 1 };
1192
1193        loop {
1194            // Finds the first high position in the partition.
1195            if let Some(i) = {
1196                let bucket = self.buckets[start];
1197                let offset = if bucket < u64::MAX {
1198                    bucket.leading_ones()
1199                } else {
1200                    return None;
1201                };
1202
1203                // Check to see if the jump is beyond the partition list or the lookup exceeds
1204                // the maximum length of the allocation table.
1205                if start == self.peak && offset > self.bit_len {
1206                    return None;
1207                }
1208
1209                Some(offset)
1210            } {
1211                index = Some(i as usize);
1212                break;
1213            }
1214
1215            // As long as it doesn't find it, it continues to re-find it from the next
1216            // partition.
1217            if start == self.peak {
1218                start = 0;
1219            } else {
1220                start += 1;
1221            }
1222
1223            // Already gone through all partitions, lookup failed.
1224            if start == previous {
1225                break;
1226            }
1227        }
1228
1229        // Writes to the partition, marking the current location as already allocated.
1230        let index = index?;
1231        self.set_bit(start, index, Bit::High);
1232        self.allocated += 1;
1233
1234        // The actual port number is calculated from the partition offset position.
1235        let num = (start * 64 + index) as u16;
1236        let port = Self::port_range().start + num;
1237        Some(port)
1238    }
1239
1240    /// write bit flag in the bucket.
1241    ///
1242    /// # Test
1243    ///
1244    /// ```
1245    /// use turn::sessions::Bit;
1246    /// use turn::sessions::PortAllocatePools;
1247    ///
1248    /// let mut pool = PortAllocatePools::default();
1249    ///
1250    /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1251    /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1252    ///
1253    /// pool.set_bit(0, 0, Bit::High);
1254    /// pool.set_bit(0, 1, Bit::High);
1255    ///
1256    /// assert_eq!(pool.alloc(Some(0)), Some(49154));
1257    /// assert_eq!(pool.alloc(Some(0)), Some(49155));
1258    /// ```
1259    pub fn set_bit(&mut self, bucket: usize, index: usize, bit: Bit) {
1260        let high_mask = 1 << (63 - index);
1261        let mask = match bit {
1262            Bit::Low => u64::MAX ^ high_mask,
1263            Bit::High => high_mask,
1264        };
1265
1266        let value = self.buckets[bucket];
1267        self.buckets[bucket] = match bit {
1268            Bit::High => value | mask,
1269            Bit::Low => value & mask,
1270        };
1271    }
1272
1273    /// restore port in the buckets.
1274    ///
1275    /// # Test
1276    ///
1277    /// ```
1278    /// use turn::sessions::PortAllocatePools;
1279    ///
1280    /// let mut pool = PortAllocatePools::default();
1281    ///
1282    /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1283    /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1284    ///
1285    /// pool.restore(49152);
1286    /// pool.restore(49153);
1287    ///
1288    /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1289    /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1290    /// ```
1291    pub fn restore(&mut self, port: u16) {
1292        assert!(Self::port_range().contains(&port));
1293
1294        // Calculate the location in the partition from the port number.
1295        let offset = (port - Self::port_range().start) as usize;
1296        let bucket = offset / 64;
1297        let index = offset - (bucket * 64);
1298
1299        // Gets the bit value in the port position in the partition, if it is low, no
1300        // processing is required.
1301        if {
1302            match (self.buckets[bucket] & (1 << (63 - index))) >> (63 - index) {
1303                0 => Bit::Low,
1304                1 => Bit::High,
1305                _ => panic!(),
1306            }
1307        } == Bit::Low
1308        {
1309            return;
1310        }
1311
1312        self.set_bit(bucket, index, Bit::Low);
1313        self.allocated -= 1;
1314    }
1315}