ezk_stun/
lib.rs

1use parking_lot::Mutex;
2use std::collections::HashMap;
3use std::io;
4use std::net::SocketAddr;
5use std::time::Duration;
6use stun_types::{Message, TransactionId};
7use tokio::sync::oneshot;
8use tokio::time::timeout;
9
10pub mod auth;
11
/// Properties of a transport that a [`StunEndpoint`] needs to know about.
pub trait TransportInfo {
    /// Returns `true` if the transport is reliable.
    ///
    /// [`StunEndpoint::send_request`] only retransmits a request when this
    /// returns `false`.
    fn reliable(&self) -> bool;
}
15
/// A request handed to [`StunEndpoint::send_request`].
pub struct Request<'r, T> {
    /// Bytes passed to [`StunEndpointUser::send_to`]
    /// (presumably the already encoded STUN message - confirm with callers).
    pub bytes: &'r [u8],
    /// Transaction id under which the pending request is registered. An
    /// incoming message with the same id is delivered as the response.
    pub tsx_id: TransactionId,
    /// Transport passed to [`StunEndpointUser::send_to`]. Its
    /// [`TransportInfo::reliable`] value decides whether the request is
    /// retransmitted.
    pub transport: &'r T,
}
21
/// A message that [`StunEndpoint::receive`] could not match to a pending
/// transaction and therefore hands to [`StunEndpointUser::receive`].
pub struct IncomingMessage<T> {
    /// The received message.
    pub message: Message,
    /// The `source` address that was given to [`StunEndpoint::receive`].
    pub source: SocketAddr,
    /// The `transport` value that was given to [`StunEndpoint::receive`].
    pub transport: T,
}
27
/// Defines the "user" of a [`StunEndpoint`].
///
/// It is designed to be somewhat flexible and transport agnostic.
///
/// When using a [`StunEndpoint`] for multiple transports, the
/// [`Transport`](StunEndpointUser::Transport) associated type can either be
/// the transport itself, passed around directly, or just a key identifying it.
#[async_trait::async_trait]
pub trait StunEndpointUser: Send + Sync {
    /// Value describing the transport a message is sent over or was received on.
    type Transport: TransportInfo + Send + Sync;

    /// Send the given `bytes` to `target` with the given `transport`.
    ///
    /// An error returned here is propagated to the caller of
    /// [`StunEndpoint::send_request`].
    async fn send_to(
        &self,
        bytes: &[u8],
        target: SocketAddr,
        transport: &Self::Transport,
    ) -> io::Result<()>;

    /// Called by [`StunEndpoint::receive`] when it encounters a message
    /// without a matching transaction id.
    async fn receive(&self, message: IncomingMessage<Self::Transport>);
}
51
/// Transport agnostic endpoint. Uses [`StunEndpointUser`] to define
/// send/receive behavior.
pub struct StunEndpoint<U: StunEndpointUser> {
    /// Sends requests and handles messages that match no pending transaction.
    user: U,
    /// Requests that are waiting for a response, keyed by transaction id.
    transactions: Mutex<HashMap<TransactionId, Transaction>>,
}
58
/// State kept for a request that is waiting for its response.
struct Transaction {
    /// Channel used by `StunEndpoint::receive` to hand the matching response
    /// to the waiting `StunEndpoint::send_request` call.
    sender: oneshot::Sender<Message>,
}
62
63impl<U: StunEndpointUser> StunEndpoint<U> {
64    pub fn new(user: U) -> Self {
65        Self {
66            user,
67            transactions: Default::default(),
68        }
69    }
70
71    pub fn user(&self) -> &U {
72        &self.user
73    }
74
75    pub fn user_mut(&mut self) -> &mut U {
76        &mut self.user
77    }
78
79    pub async fn send_request(
80        &self,
81        request: Request<'_, U::Transport>,
82        target: SocketAddr,
83    ) -> io::Result<Option<Message>> {
84        struct DropGuard<'s, U>(&'s StunEndpoint<U>, TransactionId)
85        where
86            U: StunEndpointUser;
87
88        impl<U> Drop for DropGuard<'_, U>
89        where
90            U: StunEndpointUser,
91        {
92            fn drop(&mut self) {
93                self.0.transactions.lock().remove(&self.1);
94            }
95        }
96
97        let _guard = DropGuard(self, request.tsx_id);
98
99        let (tx, mut rx) = oneshot::channel();
100        self.transactions
101            .lock()
102            .insert(request.tsx_id, Transaction { sender: tx });
103
104        let mut delta = Duration::from_millis(500);
105
106        if request.transport.reliable() {
107            match timeout(delta, &mut rx).await {
108                Ok(Ok(response)) => Ok(Some(response)),
109                Ok(Err(_)) => unreachable!(),
110                Err(_) => Ok(None),
111            }
112        } else {
113            for _ in 0..7 {
114                self.user
115                    .send_to(request.bytes, target, request.transport)
116                    .await?;
117
118                match timeout(delta, &mut rx).await {
119                    Ok(Ok(response)) => return Ok(Some(response)),
120                    Ok(Err(_)) => unreachable!(),
121                    Err(_) => {
122                        delta *= 2;
123                    }
124                }
125            }
126
127            Ok(None)
128        }
129    }
130
131    /// Pass a received STUN message to the endpoint for further processing
132    pub async fn receive(&self, message: Message, source: SocketAddr, transport: U::Transport) {
133        {
134            let mut transactions = self.transactions.lock();
135            if let Some(Transaction { sender }) = transactions.remove(&message.transaction_id()) {
136                let _ = sender.send(message);
137                return;
138            }
139        }
140
141        self.user
142            .receive(IncomingMessage {
143                source,
144                message,
145                transport,
146            })
147            .await;
148    }
149}