mycrl_turn/sessions.rs
1use crate::Observer;
2
3use std::{
4 hash::Hash,
5 net::SocketAddr,
6 ops::{Deref, DerefMut, Range},
7 sync::{
8 atomic::{AtomicU64, Ordering},
9 Arc,
10 },
11 thread::{self, sleep},
12 time::Duration,
13};
14
15use ahash::{HashMap, HashMapExt};
16use parking_lot::{Mutex, RwLock, RwLockReadGuard};
17use rand::{distributions::Alphanumeric, thread_rng, Rng};
18use stun::util::long_term_credential_digest;
19
/// Authentication information for the session.
///
/// Digest data is data that summarises usernames and passwords by means of
/// long-term authentication.
#[derive(Debug, Clone)]
pub struct Auth {
    /// Username the session was authenticated with.
    pub username: String,
    /// Password the observer returned for `username`.
    pub password: String,
    /// 16-byte long-term credential digest, computed from the username,
    /// password and realm when the session is created.
    pub digest: [u8; 16],
}
30
/// Assignment information for the session.
///
/// A session holds at most one allocated port and a list of the channel
/// numbers it has bound.
#[derive(Debug, Clone)]
pub struct Allocate {
    /// Port assigned to the session; `None` until a port has been allocated.
    pub port: Option<u16>,
    /// Channel numbers bound by the session, without duplicates.
    pub channels: Vec<u16>,
}
39
/// TURN session information.
///
/// A user can have many sessions.
///
/// The default survival time for a session is 600 seconds.
#[derive(Debug, Clone)]
pub struct Session {
    /// Credentials the session was authenticated with.
    pub auth: Auth,
    /// Port and channels currently assigned to the session.
    pub allocate: Allocate,
    /// Peer ports this session has created permissions for, without
    /// duplicates.
    pub permissions: Vec<u16>,
    /// Tick of the session timer at which the session expires. The timer is
    /// advanced once per second by the background thread.
    pub expires: u64,
}
52
/// The identifier of the session or addr.
///
/// Each session is identified by the combination of the two socket addresses
/// below. NOTE(review): earlier wording also mentioned the transport
/// protocol, but no such field exists in this struct.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionAddr {
    /// NOTE(review): presumably the remote (client) socket address; confirm
    /// against the callers.
    pub address: SocketAddr,
    /// NOTE(review): presumably the local server interface the client talks
    /// to; confirm against the callers.
    pub interface: SocketAddr,
}
62
/// The addr used to record the current session.
///
/// This is used when forwarding data.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Endpoint {
    /// `address` of the session that created the permission or channel
    /// binding.
    pub address: SocketAddr,
    /// Endpoint supplied by the caller when the permission or channel binding
    /// was created. NOTE(review): presumably the server interface used to
    /// reach `address`; confirm against the callers.
    pub endpoint: SocketAddr,
}
71
/// A specially optimised timer.
///
/// The timer never advances on its own: whoever owns it has to call
/// [`Timer::add`] once for every tick that should be counted.
///
/// ```
/// use turn::sessions::Timer;
///
/// let timer = Timer::default();
///
/// assert_eq!(timer.get(), 0);
/// assert_eq!(timer.add(), 1);
/// assert_eq!(timer.get(), 1);
/// ```
#[derive(Default)]
pub struct Timer(AtomicU64);

impl Timer {
    /// Returns the number of ticks counted so far.
    pub fn get(&self) -> u64 {
        let Self(ticks) = self;
        ticks.load(Ordering::Relaxed)
    }

    /// Counts one more tick and returns the tick count after the increment.
    pub fn add(&self) -> u64 {
        // `fetch_add` hands back the value before the increment.
        let previous = self.0.fetch_add(1, Ordering::Relaxed);
        previous + 1
    }
}
98
/// Shared state of the session manager. Every table sits behind its own lock.
#[derive(Default)]
pub struct State {
    // Every known session, keyed by its session address.
    sessions: RwLock<Table<SessionAddr, Session>>,
    // Pool that relay ports are taken from and returned to.
    port_allocate_pool: Mutex<PortAllocatePools>,
    // Records the session that owns each assigned port, so that the session can be looked up
    // from a port number.
    port_mapping_table: RwLock<Table</* port */ u16, SessionAddr>>,
    // Records the nonce (and the tick at which it expires) for each network connection. It is
    // independent of the session because a nonce can exist before the connection is
    // authenticated.
    address_nonce_tanle: RwLock<Table<SessionAddr, (String, /* expires */ u64)>>,
    // Stores the endpoint that data should be forwarded to when a session sends an indication
    // to a port. Entries are written when a permission is created: under the peer session's
    // key, the port of the session that created the permission maps to that session's endpoint.
    port_relay_table: RwLock<Table<SessionAddr, HashMap</* port */ u16, Endpoint>>>,
    // Stores the endpoint that data sent by a session to a channel should be forwarded to.
    // Entries are written under the peer session's key when a channel is bound.
    channel_relay_table: RwLock<Table<SessionAddr, HashMap</* channel */ u16, Endpoint>>>,
}
116
/// Session manager: owns the session tables, the tick timer used for expiry
/// and the observer that supplies passwords and is told about closed sessions.
pub struct Sessions<T> {
    // Tick counter, advanced once per second by the background thread.
    timer: Timer,
    // Lock-protected session, port, nonce and relay tables.
    state: State,
    // External hook used for password lookup and close notifications.
    observer: T,
}
122
123impl<T: Observer + 'static> Sessions<T> {
124 pub fn new(observer: T) -> Arc<Self> {
125 let this = Arc::new(Self {
126 state: State::default(),
127 timer: Timer::default(),
128 observer,
129 });
130
131 // This is a background thread that silently handles expiring sessions and
132 // cleans up session information when it expires.
133 let this_ = Arc::downgrade(&this);
134 thread::spawn(move || {
135 let mut address = Vec::with_capacity(255);
136
137 while let Some(this) = this_.upgrade() {
138 // The timer advances one second and gets the current time offset.
139 let now = this.timer.add();
140
141 // This is the part that deletes the session information.
142 {
143 // Finds sessions that have expired.
144 {
145 this.state
146 .sessions
147 .read()
148 .iter()
149 .filter(|(_, v)| v.expires <= now)
150 .for_each(|(k, _)| address.push(*k));
151 }
152
153 // Delete the expired sessions.
154 if !address.is_empty() {
155 this.remove_session(&address);
156 address.clear();
157 }
158 }
159
160 // Because nonce does not follow session creation, nonce is created for each
161 // addr, so nonce deletion is handled independently.
162 {
163 this.state
164 .address_nonce_tanle
165 .read()
166 .iter()
167 .filter(|(_, v)| v.1 <= now)
168 .for_each(|(k, _)| address.push(*k));
169
170 if !address.is_empty() {
171 this.remove_nonce(&address);
172 address.clear();
173 }
174 }
175
176 // Fixing a second tick.
177 sleep(Duration::from_secs(1));
178 }
179 });
180
181 this
182 }
183
184 fn remove_session(&self, addrs: &[SessionAddr]) {
185 let mut sessions = self.state.sessions.write();
186 let mut port_allocate_pool = self.state.port_allocate_pool.lock();
187 let mut port_mapping_table = self.state.port_mapping_table.write();
188 let mut port_relay_table = self.state.port_relay_table.write();
189 let mut channel_relay_table = self.state.channel_relay_table.write();
190
191 addrs.iter().for_each(|k| {
192 port_relay_table.remove(k);
193 channel_relay_table.remove(k);
194
195 if let Some(session) = sessions.remove(k) {
196 // Removes the session-bound port from the port binding table and
197 // releases the port back into the allocation pool.
198 if let Some(port) = session.allocate.port {
199 port_mapping_table.remove(&port);
200 port_allocate_pool.restore(port);
201 }
202
203 // Notifies that the external session has been closed.
204 self.observer.closed(k, &session.auth.username);
205 }
206 });
207 }
208
209 fn remove_nonce(&self, addrs: &[SessionAddr]) {
210 let mut address_nonce_tanle = self.state.address_nonce_tanle.write();
211
212 addrs.iter().for_each(|k| {
213 address_nonce_tanle.remove(k);
214 });
215 }
216
217 /// Get session for addr.
218 ///
219 /// # Test
220 ///
221 /// ```
222 /// use turn::*;
223 ///
224 /// #[derive(Clone)]
225 /// struct ObserverTest;
226 ///
227 /// impl Observer for ObserverTest {
228 /// async fn get_password(
229 /// &self,
230 /// addr: &SessionAddr,
231 /// username: &str,
232 /// ) -> Option<String> {
233 /// if username == "test" {
234 /// Some("test".to_string())
235 /// } else {
236 /// None
237 /// }
238 /// }
239 /// }
240 ///
241 /// let addr = SessionAddr {
242 /// address: "127.0.0.1:8080".parse().unwrap(),
243 /// interface: "127.0.0.1:3478".parse().unwrap(),
244 /// };
245 ///
246 /// let digest = [
247 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
248 /// 239,
249 /// ];
250 ///
251 /// let sessions = Sessions::new(ObserverTest);
252 ///
253 /// assert!(sessions.get_session(&addr).get_ref().is_none());
254 ///
255 /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
256 ///
257 /// let lock = sessions.get_session(&addr);
258 /// let session = lock.get_ref().unwrap();
259 /// assert_eq!(session.auth.username, "test");
260 /// assert_eq!(session.auth.password, "test");
261 /// assert_eq!(session.allocate.port, None);
262 /// assert_eq!(session.allocate.channels.len(), 0);
263 /// ```
264 pub fn get_session<'a, 'b>(
265 &'a self,
266 key: &'b SessionAddr,
267 ) -> ReadLock<'b, 'a, SessionAddr, Table<SessionAddr, Session>> {
268 ReadLock {
269 lock: self.state.sessions.read(),
270 key,
271 }
272 }
273
274 /// Get nonce for addr.
275 ///
276 /// # Test
277 ///
278 /// ```
279 /// use turn::*;
280 ///
281 /// #[derive(Clone)]
282 /// struct ObserverTest;
283 ///
284 /// impl Observer for ObserverTest {}
285 ///
286 /// let addr = SessionAddr {
287 /// address: "127.0.0.1:8080".parse().unwrap(),
288 /// interface: "127.0.0.1:3478".parse().unwrap(),
289 /// };
290 ///
291 /// let sessions = Sessions::new(ObserverTest);
292 ///
293 /// let a = sessions.get_nonce(&addr).get_ref().unwrap().clone();
294 /// assert!(a.0.len() == 16);
295 /// assert!(a.1 == 600 || a.1 == 601 || a.1 == 602);
296 ///
297 /// let b = sessions.get_nonce(&addr).get_ref().unwrap().clone();
298 /// assert_eq!(a.0, b.0);
299 /// assert!(b.1 == 600 || b.1 == 601 || b.1 == 602);
300 /// ```
301 pub fn get_nonce<'a, 'b>(
302 &'a self,
303 key: &'b SessionAddr,
304 ) -> ReadLock<'b, 'a, SessionAddr, Table<SessionAddr, (String, u64)>> {
305 // If no nonce is created, create a new one.
306 {
307 if !self.state.address_nonce_tanle.read().contains_key(key) {
308 self.state.address_nonce_tanle.write().insert(
309 *key,
310 (
311 // A random string of length 16.
312 {
313 let mut rng = thread_rng();
314 std::iter::repeat(())
315 .map(|_| rng.sample(Alphanumeric) as char)
316 .take(16)
317 .collect::<String>()
318 .to_lowercase()
319 },
320 // Current time stacks for 600 seconds.
321 self.timer.get() + 600,
322 ),
323 );
324 }
325 }
326
327 ReadLock {
328 lock: self.state.address_nonce_tanle.read(),
329 key,
330 }
331 }
332
333 /// Get digest for addr.
334 ///
335 /// # Test
336 ///
337 /// ```
338 /// use turn::*;
339 ///
340 /// #[derive(Clone)]
341 /// struct ObserverTest;
342 ///
343 /// impl Observer for ObserverTest {
344 /// async fn get_password(
345 /// &self,
346 /// addr: &SessionAddr,
347 /// username: &str,
348 /// ) -> Option<String> {
349 /// if username == "test" {
350 /// Some("test".to_string())
351 /// } else {
352 /// None
353 /// }
354 /// }
355 /// }
356 ///
357 /// let addr = SessionAddr {
358 /// address: "127.0.0.1:8080".parse().unwrap(),
359 /// interface: "127.0.0.1:3478".parse().unwrap(),
360 /// };
361 ///
362 /// let digest = [
363 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
364 /// 239,
365 /// ];
366 ///
367 /// let sessions = Sessions::new(ObserverTest);
368 ///
369 /// assert_eq!(
370 /// pollster::block_on(sessions.get_digest(&addr, "test1", "test")),
371 /// None
372 /// );
373 ///
374 /// assert_eq!(
375 /// pollster::block_on(sessions.get_digest(&addr, "test", "test")),
376 /// Some(digest)
377 /// );
378 ///
379 /// assert_eq!(
380 /// pollster::block_on(sessions.get_digest(&addr, "test", "test")),
381 /// Some(digest)
382 /// );
383 /// ```
384 pub async fn get_digest(
385 &self,
386 addr: &SessionAddr,
387 username: &str,
388 realm: &str,
389 ) -> Option<[u8; 16]> {
390 // Already authenticated, get the cached digest directly.
391 {
392 if let Some(it) = self.state.sessions.read().get(addr) {
393 return Some(it.auth.digest);
394 }
395 }
396
397 // Get the current user's password from an external observer and create a
398 // digest.
399 let password = self.observer.get_password(addr, username).await?;
400 let digest = long_term_credential_digest(&username, &password, realm);
401
402 // Record a new session.
403 {
404 self.state.sessions.write().insert(
405 *addr,
406 Session {
407 permissions: Vec::with_capacity(10),
408 expires: self.timer.get() + 600,
409 auth: Auth {
410 username: username.to_string(),
411 password,
412 digest,
413 },
414 allocate: Allocate {
415 channels: Vec::with_capacity(10),
416 port: None,
417 },
418 },
419 );
420 }
421
422 Some(digest)
423 }
424
425 pub fn allocated(&self) -> usize {
426 self.state.port_allocate_pool.lock().len()
427 }
428
429 /// Assign a port number to the session.
430 ///
431 /// # Test
432 ///
433 /// ```
434 /// use turn::*;
435 ///
436 /// #[derive(Clone)]
437 /// struct ObserverTest;
438 ///
439 /// impl Observer for ObserverTest {
440 /// async fn get_password(
441 /// &self,
442 /// addr: &SessionAddr,
443 /// username: &str,
444 /// ) -> Option<String> {
445 /// if username == "test" {
446 /// Some("test".to_string())
447 /// } else {
448 /// None
449 /// }
450 /// }
451 /// }
452 ///
453 /// let addr = SessionAddr {
454 /// address: "127.0.0.1:8080".parse().unwrap(),
455 /// interface: "127.0.0.1:3478".parse().unwrap(),
456 /// };
457 ///
458 /// let digest = [
459 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
460 /// 239,
461 /// ];
462 ///
463 /// let sessions = Sessions::new(ObserverTest);
464 ///
465 /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
466 ///
467 /// {
468 /// let lock = sessions.get_session(&addr);
469 /// let session = lock.get_ref().unwrap();
470 /// assert_eq!(session.auth.username, "test");
471 /// assert_eq!(session.auth.password, "test");
472 /// assert_eq!(session.allocate.port, None);
473 /// assert_eq!(session.allocate.channels.len(), 0);
474 /// }
475 ///
476 /// let port = sessions.allocate(&addr).unwrap();
477 /// {
478 /// let lock = sessions.get_session(&addr);
479 /// let session = lock.get_ref().unwrap();
480 /// assert_eq!(session.auth.username, "test");
481 /// assert_eq!(session.auth.password, "test");
482 /// assert_eq!(session.allocate.port, Some(port));
483 /// assert_eq!(session.allocate.channels.len(), 0);
484 /// }
485 ///
486 /// assert_eq!(sessions.allocate(&addr), Some(port));
487 /// ```
488 pub fn allocate(&self, addr: &SessionAddr) -> Option<u16> {
489 let mut lock = self.state.sessions.write();
490 let session = lock.get_mut(addr)?;
491
492 // If the port has already been allocated, re-allocation is not allowed.
493 if let Some(port) = session.allocate.port {
494 return Some(port);
495 }
496
497 // Records the port assigned to the current session and resets the alive time.
498 let port = self.state.port_allocate_pool.lock().alloc(None)?;
499 session.expires = self.timer.get() + 600;
500 session.allocate.port = Some(port);
501
502 // Write the allocation port binding table.
503 self.state.port_mapping_table.write().insert(port, *addr);
504 Some(port)
505 }
506
507 /// Create permission for session.
508 ///
509 /// # Test
510 ///
511 /// ```
512 /// use turn::*;
513 ///
514 /// #[derive(Clone)]
515 /// struct ObserverTest;
516 ///
517 /// impl Observer for ObserverTest {
518 /// async fn get_password(
519 /// &self,
520 /// addr: &SessionAddr,
521 /// username: &str,
522 /// ) -> Option<String> {
523 /// if username == "test" {
524 /// Some("test".to_string())
525 /// } else {
526 /// None
527 /// }
528 /// }
529 /// }
530 ///
531 /// let endpoint = "127.0.0.1:3478".parse().unwrap();
532 /// let addr = SessionAddr {
533 /// address: "127.0.0.1:8080".parse().unwrap(),
534 /// interface: "127.0.0.1:3478".parse().unwrap(),
535 /// };
536 ///
537 /// let peer_addr = SessionAddr {
538 /// address: "127.0.0.1:8081".parse().unwrap(),
539 /// interface: "127.0.0.1:3478".parse().unwrap(),
540 /// };
541 ///
542 /// let digest = [
543 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
544 /// 239,
545 /// ];
546 ///
547 /// let sessions = Sessions::new(ObserverTest);
548 ///
549 /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
550 /// pollster::block_on(sessions.get_digest(&peer_addr, "test", "test"));
551 ///
552 /// let port = sessions.allocate(&addr).unwrap();
553 /// let peer_port = sessions.allocate(&peer_addr).unwrap();
554 ///
555 /// assert!(!sessions.create_permission(&addr, &endpoint, &[port]));
556 /// assert!(sessions.create_permission(&addr, &endpoint, &[peer_port]));
557 ///
558 /// assert!(!sessions.create_permission(&peer_addr, &endpoint, &[peer_port]));
559 /// assert!(sessions.create_permission(&peer_addr, &endpoint, &[port]));
560 /// ```
561 pub fn create_permission(
562 &self,
563 addr: &SessionAddr,
564 endpoint: &SocketAddr,
565 ports: &[u16],
566 ) -> bool {
567 let mut sessions = self.state.sessions.write();
568 let mut port_relay_table = self.state.port_relay_table.write();
569 let port_mapping_table = self.state.port_mapping_table.read();
570
571 // Finds information about the current session.
572 let session = if let Some(it) = sessions.get_mut(addr) {
573 it
574 } else {
575 return false;
576 };
577
578 // The port number assigned to the current session.
579 let local_port = if let Some(it) = session.allocate.port {
580 it
581 } else {
582 return false;
583 };
584
585 // You cannot create permissions for yourself.
586 if ports.contains(&local_port) {
587 return false;
588 }
589
590 // Each peer port must be present.
591 let mut peers = Vec::with_capacity(15);
592 for port in ports {
593 if let Some(it) = port_mapping_table.get(&port) {
594 peers.push((it, *port));
595 } else {
596 return false;
597 }
598 }
599
600 // Create a port forwarding mapping relationship for each peer session.
601 for (peer, port) in peers {
602 port_relay_table
603 .entry(*peer)
604 .or_insert_with(|| HashMap::with_capacity(20))
605 .insert(
606 local_port,
607 Endpoint {
608 address: addr.address,
609 endpoint: *endpoint,
610 },
611 );
612
613 // Do not store the same peer ports to the permission list over and over again.
614 if !session.permissions.contains(&port) {
615 session.permissions.push(port);
616 }
617 }
618
619 true
620 }
621
622 /// Binding a channel to the session.
623 ///
624 /// # Test
625 ///
626 /// ```
627 /// use turn::*;
628 ///
629 /// #[derive(Clone)]
630 /// struct ObserverTest;
631 ///
632 /// impl Observer for ObserverTest {
633 /// async fn get_password(
634 /// &self,
635 /// addr: &SessionAddr,
636 /// username: &str,
637 /// ) -> Option<String> {
638 /// if username == "test" {
639 /// Some("test".to_string())
640 /// } else {
641 /// None
642 /// }
643 /// }
644 /// }
645 ///
646 /// let endpoint = "127.0.0.1:3478".parse().unwrap();
647 /// let addr = SessionAddr {
648 /// address: "127.0.0.1:8080".parse().unwrap(),
649 /// interface: "127.0.0.1:3478".parse().unwrap(),
650 /// };
651 ///
652 /// let peer_addr = SessionAddr {
653 /// address: "127.0.0.1:8081".parse().unwrap(),
654 /// interface: "127.0.0.1:3478".parse().unwrap(),
655 /// };
656 ///
657 /// let digest = [
658 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
659 /// 239,
660 /// ];
661 ///
662 /// let sessions = Sessions::new(ObserverTest);
663 ///
664 /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
665 /// pollster::block_on(sessions.get_digest(&peer_addr, "test", "test"));
666 ///
667 /// let port = sessions.allocate(&addr).unwrap();
668 /// let peer_port = sessions.allocate(&peer_addr).unwrap();
669 /// assert_eq!(
670 /// sessions
671 /// .get_session(&addr)
672 /// .get_ref()
673 /// .unwrap()
674 /// .allocate
675 /// .channels
676 /// .len(),
677 /// 0
678 /// );
679 ///
680 /// assert_eq!(
681 /// sessions
682 /// .get_session(&peer_addr)
683 /// .get_ref()
684 /// .unwrap()
685 /// .allocate
686 /// .channels
687 /// .len(),
688 /// 0
689 /// );
690 ///
691 /// assert!(sessions.bind_channel(&addr, &endpoint, peer_port, 0x4000));
692 /// assert!(sessions.bind_channel(&peer_addr, &endpoint, port, 0x4000));
693 /// assert_eq!(
694 /// sessions
695 /// .get_session(&addr)
696 /// .get_ref()
697 /// .unwrap()
698 /// .allocate
699 /// .channels,
700 /// vec![0x4000]
701 /// );
702 ///
703 /// assert_eq!(
704 /// sessions
705 /// .get_session(&peer_addr)
706 /// .get_ref()
707 /// .unwrap()
708 /// .allocate
709 /// .channels,
710 /// vec![0x4000]
711 /// );
712 /// ```
713 pub fn bind_channel(
714 &self,
715 addr: &SessionAddr,
716 endpoint: &SocketAddr,
717 port: u16,
718 channel: u16,
719 ) -> bool {
720 // Finds the address of the bound opposing port.
721 let peer = if let Some(it) = self.state.port_mapping_table.read().get(&port) {
722 *it
723 } else {
724 return false;
725 };
726
727 // Records the channel used for the current session.
728 {
729 let mut lock = self.state.sessions.write();
730 let session = if let Some(it) = lock.get_mut(addr) {
731 it
732 } else {
733 return false;
734 };
735
736 if !session.allocate.channels.contains(&channel) {
737 session.allocate.channels.push(channel);
738 }
739 }
740
741 // Binding ports also creates permissions.
742 if !self.create_permission(addr, endpoint, &[port]) {
743 return false;
744 }
745
746 // Create channel forwarding mapping relationships for peers.
747 self.state
748 .channel_relay_table
749 .write()
750 .entry(peer)
751 .or_insert_with(|| HashMap::with_capacity(10))
752 .insert(
753 channel,
754 Endpoint {
755 address: addr.address,
756 endpoint: *endpoint,
757 },
758 );
759
760 true
761 }
762
763 /// Gets the peer of the current session bound channel.
764 ///
765 /// # Test
766 ///
767 /// ```
768 /// use turn::*;
769 ///
770 /// #[derive(Clone)]
771 /// struct ObserverTest;
772 ///
773 /// impl Observer for ObserverTest {
774 /// async fn get_password(
775 /// &self,
776 /// addr: &SessionAddr,
777 /// username: &str,
778 /// ) -> Option<String> {
779 /// if username == "test" {
780 /// Some("test".to_string())
781 /// } else {
782 /// None
783 /// }
784 /// }
785 /// }
786 ///
787 /// let endpoint = "127.0.0.1:3478".parse().unwrap();
788 /// let addr = SessionAddr {
789 /// address: "127.0.0.1:8080".parse().unwrap(),
790 /// interface: "127.0.0.1:3478".parse().unwrap(),
791 /// };
792 ///
793 /// let peer_addr = SessionAddr {
794 /// address: "127.0.0.1:8081".parse().unwrap(),
795 /// interface: "127.0.0.1:3478".parse().unwrap(),
796 /// };
797 ///
798 /// let digest = [
799 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
800 /// 239,
801 /// ];
802 ///
803 /// let sessions = Sessions::new(ObserverTest);
804 ///
805 /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
806 /// pollster::block_on(sessions.get_digest(&peer_addr, "test", "test"));
807 ///
808 /// let port = sessions.allocate(&addr).unwrap();
809 /// let peer_port = sessions.allocate(&peer_addr).unwrap();
810 ///
811 /// assert!(sessions.bind_channel(&addr, &endpoint, peer_port, 0x4000));
812 /// assert!(sessions.bind_channel(&peer_addr, &endpoint, port, 0x4000));
813 /// assert_eq!(
814 /// sessions
815 /// .get_channel_relay_address(&addr, 0x4000)
816 /// .unwrap()
817 /// .endpoint,
818 /// endpoint
819 /// );
820 ///
821 /// assert_eq!(
822 /// sessions
823 /// .get_channel_relay_address(&peer_addr, 0x4000)
824 /// .unwrap()
825 /// .endpoint,
826 /// endpoint
827 /// );
828 /// ```
829 pub fn get_channel_relay_address(&self, addr: &SessionAddr, channel: u16) -> Option<Endpoint> {
830 self.state
831 .channel_relay_table
832 .read()
833 .get(&addr)?
834 .get(&channel)
835 .copied()
836 }
837
838 /// Get the address of the port binding.
839 ///
840 /// # Test
841 ///
842 /// ```
843 /// use turn::*;
844 ///
845 /// #[derive(Clone)]
846 /// struct ObserverTest;
847 ///
848 /// impl Observer for ObserverTest {
849 /// async fn get_password(
850 /// &self,
851 /// addr: &SessionAddr,
852 /// username: &str,
853 /// ) -> Option<String> {
854 /// if username == "test" {
855 /// Some("test".to_string())
856 /// } else {
857 /// None
858 /// }
859 /// }
860 /// }
861 ///
862 /// let endpoint = "127.0.0.1:3478".parse().unwrap();
863 /// let addr = SessionAddr {
864 /// address: "127.0.0.1:8080".parse().unwrap(),
865 /// interface: "127.0.0.1:3478".parse().unwrap(),
866 /// };
867 ///
868 /// let peer_addr = SessionAddr {
869 /// address: "127.0.0.1:8081".parse().unwrap(),
870 /// interface: "127.0.0.1:3478".parse().unwrap(),
871 /// };
872 ///
873 /// let digest = [
874 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
875 /// 239,
876 /// ];
877 ///
878 /// let sessions = Sessions::new(ObserverTest);
879 ///
880 /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
881 /// pollster::block_on(sessions.get_digest(&peer_addr, "test", "test"));
882 ///
883 /// let port = sessions.allocate(&addr).unwrap();
884 /// let peer_port = sessions.allocate(&peer_addr).unwrap();
885 ///
886 /// assert!(sessions.create_permission(&addr, &endpoint, &[peer_port]));
887 /// assert!(sessions.create_permission(&peer_addr, &endpoint, &[port]));
888 ///
889 /// assert_eq!(
890 /// sessions
891 /// .get_relay_address(&addr, peer_port)
892 /// .unwrap()
893 /// .endpoint,
894 /// endpoint
895 /// );
896 ///
897 /// assert_eq!(
898 /// sessions
899 /// .get_relay_address(&peer_addr, port)
900 /// .unwrap()
901 /// .endpoint,
902 /// endpoint
903 /// );
904 /// ```
905 pub fn get_relay_address(&self, addr: &SessionAddr, port: u16) -> Option<Endpoint> {
906 self.state
907 .port_relay_table
908 .read()
909 .get(&addr)?
910 .get(&port)
911 .copied()
912 }
913
914 /// Refresh the session for addr.
915 ///
916 /// # Test
917 ///
918 /// ```
919 /// use turn::*;
920 ///
921 /// #[derive(Clone)]
922 /// struct ObserverTest;
923 ///
924 /// impl Observer for ObserverTest {
925 /// async fn get_password(
926 /// &self,
927 /// addr: &SessionAddr,
928 /// username: &str,
929 /// ) -> Option<String> {
930 /// if username == "test" {
931 /// Some("test".to_string())
932 /// } else {
933 /// None
934 /// }
935 /// }
936 /// }
937 ///
938 /// let addr = SessionAddr {
939 /// address: "127.0.0.1:8080".parse().unwrap(),
940 /// interface: "127.0.0.1:3478".parse().unwrap(),
941 /// };
942 ///
943 /// let digest = [
944 /// 174, 238, 187, 253, 117, 209, 73, 157, 36, 56, 143, 91, 155, 16, 224,
945 /// 239,
946 /// ];
947 ///
948 /// let sessions = Sessions::new(ObserverTest);
949 ///
950 /// assert!(sessions.get_session(&addr).get_ref().is_none());
951 ///
952 /// pollster::block_on(sessions.get_digest(&addr, "test", "test"));
953 ///
954 /// let expires = sessions.get_session(&addr).get_ref().unwrap().expires;
955 /// assert!(expires == 600 || expires == 601 || expires == 602);
956 ///
957 /// assert!(sessions.refresh(&addr, 0));
958 ///
959 /// assert!(sessions.get_session(&addr).get_ref().is_none());
960 /// ```
961 pub fn refresh(&self, addr: &SessionAddr, lifetime: u32) -> bool {
962 if lifetime > 3600 {
963 return false;
964 }
965
966 if lifetime == 0 {
967 self.remove_session(&[*addr]);
968 self.remove_nonce(&[*addr]);
969 } else {
970 if let Some(session) = self.state.sessions.write().get_mut(addr) {
971 session.expires = self.timer.get() + lifetime as u64;
972 } else {
973 return false;
974 }
975
976 if let Some(nonce) = self.state.address_nonce_tanle.write().get_mut(addr) {
977 nonce.1 = self.timer.get() + lifetime as u64;
978 }
979 }
980
981 true
982 }
983}
984
/// The default HashMap is created without allocating capacity. To improve
/// performance, the turn server needs to pre-allocate the available capacity.
///
/// So here the HashMap is rewrapped to allocate a large capacity (number of
/// ports that can be allocated) at the default creation time as well.
///
/// The wrapper dereferences to the inner map, so all `HashMap` methods are
/// available on it directly.
pub struct Table<K, V>(HashMap<K, V>);
991
992impl<K, V> Default for Table<K, V> {
993 fn default() -> Self {
994 Self(HashMap::with_capacity(PortAllocatePools::capacity()))
995 }
996}
997
998impl<K, V> AsRef<HashMap<K, V>> for Table<K, V> {
999 fn as_ref(&self) -> &HashMap<K, V> {
1000 &self.0
1001 }
1002}
1003
1004impl<K, V> Deref for Table<K, V> {
1005 type Target = HashMap<K, V>;
1006
1007 fn deref(&self) -> &Self::Target {
1008 &self.0
1009 }
1010}
1011
1012impl<K, V> DerefMut for Table<K, V> {
1013 fn deref_mut(&mut self) -> &mut Self::Target {
1014 &mut self.0
1015 }
1016}
1017
/// Used to lengthen the timing of the release of a readable lock guard and to
/// provide a more convenient way for external access to the lock's internal
/// data.
///
/// The guard is held for as long as this value lives, so the table it
/// protects stays read-locked until the `ReadLock` is dropped.
pub struct ReadLock<'a, 'b, K, R> {
    // Key to look up once the caller asks for the value.
    key: &'a K,
    // Read guard that keeps the table locked while the value is borrowed.
    lock: RwLockReadGuard<'b, R>,
}
1025
1026impl<'a, 'b, K, V> ReadLock<'a, 'b, K, Table<K, V>>
1027where
1028 K: Eq + Hash,
1029{
1030 pub fn get_ref(&self) -> Option<&V> {
1031 self.lock.get(self.key)
1032 }
1033}
1034
/// Bit Flag
///
/// Value written to a single position of a port bucket.
#[derive(PartialEq, Eq)]
pub enum Bit {
    /// Bit cleared: the port at this position is free.
    Low,
    /// Bit set: the port at this position is allocated.
    High,
}
1041
/// Random Port
///
/// Recently, awareness has been raised about a number of "blind" attacks
/// (i.e., attacks that can be performed without the need to sniff the
/// packets that correspond to the transport protocol instance to be
/// attacked) that can be performed against the Transmission Control
/// Protocol (TCP) [RFC0793] and similar protocols. The consequences of
/// these attacks range from throughput reduction to broken connections
/// or data corruption [RFC5927] [RFC4953] [Watson].
///
/// All these attacks rely on the attacker's ability to guess or know the
/// five-tuple (Protocol, Source Address, Source port, Destination
/// Address, Destination Port) that identifies the transport protocol
/// instance to be attacked.
///
/// Services are usually located at fixed, "well-known" ports [IANA] at
/// the host supplying the service (the server). Client applications
/// connecting to any such service will contact the server by specifying
/// the server IP address and service port number. The IP address and
/// port number of the client are normally left unspecified by the client
/// application and thus are chosen automatically by the client
/// networking stack. Ports chosen automatically by the networking stack
/// are known as ephemeral ports [Stevens].
///
/// While the server IP address, the well-known port, and the client IP
/// address may be known by an attacker, the ephemeral port of the client
/// is usually unknown and must be guessed.
///
/// # Layout
///
/// The pool is a bitmap over the allocatable port range. Every bucket covers
/// 64 consecutive ports, with position 0 stored in the most significant bit;
/// a set bit marks the port as allocated.
pub struct PortAllocatePools {
    /// Bitmap buckets, 64 ports each.
    pub buckets: Vec<u64>,
    // Number of ports currently allocated.
    allocated: usize,
    // Value of `bit_len()` at construction: the remainder of the capacity
    // divided by 64, which bounds the positions used in the last bucket.
    bit_len: u32,
    // Index of the last bucket.
    peak: usize,
}
1075
1076impl Default for PortAllocatePools {
1077 fn default() -> Self {
1078 Self {
1079 buckets: vec![0; Self::bucket_size()],
1080 peak: Self::bucket_size() - 1,
1081 bit_len: Self::bit_len(),
1082 allocated: 0,
1083 }
1084 }
1085}
1086
1087impl PortAllocatePools {
1088 /// compute bucket size.
1089 ///
1090 /// # Test
1091 ///
1092 /// ```
1093 /// use turn::sessions::*;
1094 ///
1095 /// assert_eq!(PortAllocatePools::bucket_size(), 256);
1096 /// ```
1097 pub fn bucket_size() -> usize {
1098 (Self::capacity() as f32 / 64.0).ceil() as usize
1099 }
1100
1101 /// compute bucket last bit max offset.
1102 ///
1103 /// # Test
1104 ///
1105 /// ```
1106 /// use turn::sessions::*;
1107 ///
1108 /// assert_eq!(PortAllocatePools::bit_len(), 63);
1109 /// ```
1110 pub fn bit_len() -> u32 {
1111 (Self::capacity() as f32 % 64.0).ceil() as u32
1112 }
1113
1114 /// get pools capacity.
1115 ///
1116 /// # Test
1117 ///
1118 /// ```
1119 /// use turn::sessions::Bit;
1120 /// use turn::sessions::PortAllocatePools;
1121 ///
1122 /// assert_eq!(PortAllocatePools::capacity(), 65535 - 49152);
1123 /// ```
pub const fn capacity() -> usize {
    // Bounds of the allocatable port range; the end is not part of the range.
    const RANGE_START: usize = 49152;
    const RANGE_END: usize = 65535;

    RANGE_END - RANGE_START
}
1127
1128 /// get port range.
1129 ///
1130 /// # Test
1131 ///
1132 /// ```
1133 /// use turn::sessions::*;
1134 ///
1135 /// assert_eq!(PortAllocatePools::port_range(), 49152..65535);
1136 /// ```
pub const fn port_range() -> Range<u16> {
    // Half-open range: 49152 is the first allocatable port, 65535 is excluded.
    Range {
        start: 49152,
        end: 65535,
    }
}
1140
1141 /// get pools allocated size.
1142 ///
1143 /// ```
1144 /// use turn::sessions::PortAllocatePools;
1145 ///
1146 /// let mut pools = PortAllocatePools::default();
1147 /// assert_eq!(pools.len(), 0);
1148 ///
1149 /// pools.alloc(None).unwrap();
1150 /// assert_eq!(pools.len(), 1);
1151 /// ```
1152 pub fn len(&self) -> usize {
1153 self.allocated
1154 }
1155
1156 /// get pools allocated size is empty.
1157 ///
1158 /// ```
1159 /// use turn::sessions::PortAllocatePools;
1160 ///
1161 /// let mut pools = PortAllocatePools::default();
1162 /// assert_eq!(pools.len(), 0);
1163 /// assert_eq!(pools.is_empty(), true);
1164 /// ```
1165 pub fn is_empty(&self) -> bool {
1166 self.allocated == 0
1167 }
1168
1169 /// random assign a port.
1170 ///
1171 /// # Test
1172 ///
1173 /// ```
1174 /// use turn::sessions::PortAllocatePools;
1175 ///
1176 /// let mut pool = PortAllocatePools::default();
1177 ///
1178 /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1179 /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1180 ///
1181 /// assert!(pool.alloc(None).is_some());
1182 /// ```
1183 pub fn alloc(&mut self, start_index: Option<usize>) -> Option<u16> {
1184 let mut index = None;
1185 let mut start =
1186 start_index.unwrap_or_else(|| thread_rng().gen_range(0..self.peak as u16) as usize);
1187
1188 // When the partition lookup has gone through the entire partition list, the
1189 // lookup should be stopped, and the location where it should be stopped is
1190 // recorded here.
1191 let previous = if start == 0 { self.peak } else { start - 1 };
1192
1193 loop {
1194 // Finds the first high position in the partition.
1195 if let Some(i) = {
1196 let bucket = self.buckets[start];
1197 let offset = if bucket < u64::MAX {
1198 bucket.leading_ones()
1199 } else {
1200 return None;
1201 };
1202
1203 // Check to see if the jump is beyond the partition list or the lookup exceeds
1204 // the maximum length of the allocation table.
1205 if start == self.peak && offset > self.bit_len {
1206 return None;
1207 }
1208
1209 Some(offset)
1210 } {
1211 index = Some(i as usize);
1212 break;
1213 }
1214
1215 // As long as it doesn't find it, it continues to re-find it from the next
1216 // partition.
1217 if start == self.peak {
1218 start = 0;
1219 } else {
1220 start += 1;
1221 }
1222
1223 // Already gone through all partitions, lookup failed.
1224 if start == previous {
1225 break;
1226 }
1227 }
1228
1229 // Writes to the partition, marking the current location as already allocated.
1230 let index = index?;
1231 self.set_bit(start, index, Bit::High);
1232 self.allocated += 1;
1233
1234 // The actual port number is calculated from the partition offset position.
1235 let num = (start * 64 + index) as u16;
1236 let port = Self::port_range().start + num;
1237 Some(port)
1238 }
1239
1240 /// write bit flag in the bucket.
1241 ///
1242 /// # Test
1243 ///
1244 /// ```
1245 /// use turn::sessions::Bit;
1246 /// use turn::sessions::PortAllocatePools;
1247 ///
1248 /// let mut pool = PortAllocatePools::default();
1249 ///
1250 /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1251 /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1252 ///
1253 /// pool.set_bit(0, 0, Bit::High);
1254 /// pool.set_bit(0, 1, Bit::High);
1255 ///
1256 /// assert_eq!(pool.alloc(Some(0)), Some(49154));
1257 /// assert_eq!(pool.alloc(Some(0)), Some(49155));
1258 /// ```
1259 pub fn set_bit(&mut self, bucket: usize, index: usize, bit: Bit) {
1260 let high_mask = 1 << (63 - index);
1261 let mask = match bit {
1262 Bit::Low => u64::MAX ^ high_mask,
1263 Bit::High => high_mask,
1264 };
1265
1266 let value = self.buckets[bucket];
1267 self.buckets[bucket] = match bit {
1268 Bit::High => value | mask,
1269 Bit::Low => value & mask,
1270 };
1271 }
1272
1273 /// restore port in the buckets.
1274 ///
1275 /// # Test
1276 ///
1277 /// ```
1278 /// use turn::sessions::PortAllocatePools;
1279 ///
1280 /// let mut pool = PortAllocatePools::default();
1281 ///
1282 /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1283 /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1284 ///
1285 /// pool.restore(49152);
1286 /// pool.restore(49153);
1287 ///
1288 /// assert_eq!(pool.alloc(Some(0)), Some(49152));
1289 /// assert_eq!(pool.alloc(Some(0)), Some(49153));
1290 /// ```
1291 pub fn restore(&mut self, port: u16) {
1292 assert!(Self::port_range().contains(&port));
1293
1294 // Calculate the location in the partition from the port number.
1295 let offset = (port - Self::port_range().start) as usize;
1296 let bucket = offset / 64;
1297 let index = offset - (bucket * 64);
1298
1299 // Gets the bit value in the port position in the partition, if it is low, no
1300 // processing is required.
1301 if {
1302 match (self.buckets[bucket] & (1 << (63 - index))) >> (63 - index) {
1303 0 => Bit::Low,
1304 1 => Bit::High,
1305 _ => panic!(),
1306 }
1307 } == Bit::Low
1308 {
1309 return;
1310 }
1311
1312 self.set_bit(bucket, index, Bit::Low);
1313 self.allocated -= 1;
1314 }
1315}