1use parking_lot::Mutex;
2use std::collections::HashMap;
3use std::io;
4use std::net::SocketAddr;
5use std::time::Duration;
6use stun_types::{Message, TransactionId};
7use tokio::sync::oneshot;
8use tokio::time::timeout;
9
10pub mod auth;
11
/// Describes properties of the transport a STUN message is exchanged over.
pub trait TransportInfo {
    /// Returns `true` if the transport is reliable.
    ///
    /// `StunEndpoint::send_request` only retransmits a request when this
    /// returns `false`.
    fn reliable(&self) -> bool;
}
15
/// A request handed to `StunEndpoint::send_request`.
pub struct Request<'r, T> {
    /// Bytes passed unchanged to `StunEndpointUser::send_to`.
    pub bytes: &'r [u8],
    /// Id under which the pending transaction is registered; an incoming
    /// message carrying the same id is treated as the response.
    pub tsx_id: TransactionId,
    /// Transport to send over; also queried for reliability.
    pub transport: &'r T,
}
21
/// A received message that did not match any pending transaction and is
/// therefore forwarded to `StunEndpointUser::receive`.
pub struct IncomingMessage<T> {
    /// The received message.
    pub message: Message,
    /// Address passed as `source` to `StunEndpoint::receive`.
    pub source: SocketAddr,
    /// Transport passed as `transport` to `StunEndpoint::receive`.
    pub transport: T,
}
27
/// Callbacks a `StunEndpoint` uses to perform I/O and to hand over messages
/// it does not consume itself.
#[async_trait::async_trait]
pub trait StunEndpointUser: Send + Sync {
    /// Transport handle type passed through the endpoint.
    type Transport: TransportInfo + Send + Sync;

    /// Sends `bytes` to `target` over `transport`.
    ///
    /// Called by `StunEndpoint::send_request`; an error returned here aborts
    /// that request and is returned to its caller.
    async fn send_to(
        &self,
        bytes: &[u8],
        target: SocketAddr,
        transport: &Self::Transport,
    ) -> io::Result<()>;

    /// Handles a message whose transaction id does not belong to a pending
    /// request of the endpoint.
    async fn receive(&self, message: IncomingMessage<Self::Transport>);
}
51
/// Matches incoming messages against pending requests and forwards
/// everything else to the wrapped `StunEndpointUser`.
pub struct StunEndpoint<U: StunEndpointUser> {
    // User-supplied I/O and message handler.
    user: U,
    // Requests currently awaiting a response, keyed by transaction id.
    transactions: Mutex<HashMap<TransactionId, Transaction>>,
}
58
/// State kept for a request that is awaiting its response.
struct Transaction {
    // Delivers the matching message to the waiting `send_request` call.
    sender: oneshot::Sender<Message>,
}
62
63impl<U: StunEndpointUser> StunEndpoint<U> {
64 pub fn new(user: U) -> Self {
65 Self {
66 user,
67 transactions: Default::default(),
68 }
69 }
70
71 pub fn user(&self) -> &U {
72 &self.user
73 }
74
75 pub fn user_mut(&mut self) -> &mut U {
76 &mut self.user
77 }
78
79 pub async fn send_request(
80 &self,
81 request: Request<'_, U::Transport>,
82 target: SocketAddr,
83 ) -> io::Result<Option<Message>> {
84 struct DropGuard<'s, U>(&'s StunEndpoint<U>, TransactionId)
85 where
86 U: StunEndpointUser;
87
88 impl<U> Drop for DropGuard<'_, U>
89 where
90 U: StunEndpointUser,
91 {
92 fn drop(&mut self) {
93 self.0.transactions.lock().remove(&self.1);
94 }
95 }
96
97 let _guard = DropGuard(self, request.tsx_id);
98
99 let (tx, mut rx) = oneshot::channel();
100 self.transactions
101 .lock()
102 .insert(request.tsx_id, Transaction { sender: tx });
103
104 let mut delta = Duration::from_millis(500);
105
106 if request.transport.reliable() {
107 match timeout(delta, &mut rx).await {
108 Ok(Ok(response)) => Ok(Some(response)),
109 Ok(Err(_)) => unreachable!(),
110 Err(_) => Ok(None),
111 }
112 } else {
113 for _ in 0..7 {
114 self.user
115 .send_to(request.bytes, target, request.transport)
116 .await?;
117
118 match timeout(delta, &mut rx).await {
119 Ok(Ok(response)) => return Ok(Some(response)),
120 Ok(Err(_)) => unreachable!(),
121 Err(_) => {
122 delta *= 2;
123 }
124 }
125 }
126
127 Ok(None)
128 }
129 }
130
131 pub async fn receive(&self, message: Message, source: SocketAddr, transport: U::Transport) {
133 {
134 let mut transactions = self.transactions.lock();
135 if let Some(Transaction { sender }) = transactions.remove(&message.transaction_id()) {
136 let _ = sender.send(message);
137 return;
138 }
139 }
140
141 self.user
142 .receive(IncomingMessage {
143 source,
144 message,
145 transport,
146 })
147 .await;
148 }
149}