// rtc_stun/client.rs

1use bytes::BytesMut;
2use shared::error::*;
3use std::collections::{HashMap, VecDeque};
4use std::io::BufReader;
5use std::net::SocketAddr;
6use std::ops::Add;
7use std::time::{Duration, Instant};
8
9use crate::agent::*;
10use crate::message::*;
11use shared::{TransportContext, TransportMessage, TransportProtocol};
12
/// Default minimum resolution of the RTO timer (the `rto_rate` setting).
const DEFAULT_TIMEOUT_RATE: Duration = Duration::from_micros(5_000);
/// Default base retransmission timeout; attempt `n` waits `(n + 1) * RTO`.
const DEFAULT_RTO: Duration = Duration::from_millis(300);
/// Default number of retransmissions before a failed transaction is reported.
const DEFAULT_MAX_ATTEMPTS: u32 = 7;
/// Default value of the `buffer_size` setting.
const DEFAULT_MAX_BUFFER_SIZE: usize = 8;
17
18/// ClientTransaction represents transaction in progress.
19/// If transaction is succeed or failed, f will be called
20/// provided by event.
21/// Concurrent access is invalid.
22#[derive(Debug, Clone)]
23pub struct ClientTransaction {
24    id: TransactionId,
25    attempt: u32,
26    start: Instant,
27    rto: Duration,
28    raw: Vec<u8>,
29}
30
31impl ClientTransaction {
32    pub(crate) fn next_timeout(&self, now: Instant) -> Instant {
33        now.add((self.attempt + 1) * self.rto)
34    }
35}
36
37struct ClientSettings {
38    buffer_size: usize,
39    rto: Duration,
40    rto_rate: Duration,
41    max_attempts: u32,
42    closed: bool,
43}
44
45impl Default for ClientSettings {
46    fn default() -> Self {
47        ClientSettings {
48            buffer_size: DEFAULT_MAX_BUFFER_SIZE,
49            rto: DEFAULT_RTO,
50            rto_rate: DEFAULT_TIMEOUT_RATE,
51            max_attempts: DEFAULT_MAX_ATTEMPTS,
52            closed: false,
53        }
54    }
55}
56
57#[derive(Default)]
58pub struct ClientBuilder {
59    settings: ClientSettings,
60}
61
62impl ClientBuilder {
63    /// with_rto sets client RTO as defined in STUN RFC.
64    pub fn with_rto(mut self, rto: Duration) -> Self {
65        self.settings.rto = rto;
66        self
67    }
68
69    /// with_timeout_rate sets RTO timer minimum resolution.
70    pub fn with_timeout_rate(mut self, d: Duration) -> Self {
71        self.settings.rto_rate = d;
72        self
73    }
74
75    /// with_buffer_size sets buffer size.
76    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
77        self.settings.buffer_size = buffer_size;
78        self
79    }
80
81    /// with_no_retransmit disables retransmissions and sets RTO to
82    /// DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO which will be effectively time out
83    /// if not set.
84    /// Useful for TCP connections where transport handles RTO.
85    pub fn with_no_retransmit(mut self) -> Self {
86        self.settings.max_attempts = 0;
87        if self.settings.rto == Duration::from_secs(0) {
88            self.settings.rto = DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO;
89        }
90        self
91    }
92
93    pub fn new() -> Self {
94        ClientBuilder {
95            settings: ClientSettings::default(),
96        }
97    }
98
99    pub fn build(
100        self,
101        local: SocketAddr,
102        remote: SocketAddr,
103        protocol: TransportProtocol,
104    ) -> Result<Client> {
105        Ok(Client::new(local, remote, protocol, self.settings))
106    }
107}
108
109/// Client simulates "connection" to STUN server.
110pub struct Client {
111    local: SocketAddr,
112    remote: SocketAddr,
113    transport_protocol: TransportProtocol,
114    agent: Agent,
115    settings: ClientSettings,
116    transactions: HashMap<TransactionId, ClientTransaction>,
117    transmits: VecDeque<TransportMessage<BytesMut>>,
118}
119
120impl Client {
121    fn new(
122        local: SocketAddr,
123        remote: SocketAddr,
124        transport_protocol: TransportProtocol,
125        settings: ClientSettings,
126    ) -> Self {
127        Self {
128            local,
129            remote,
130            transport_protocol,
131            agent: Agent::new(),
132            settings,
133            transactions: HashMap::new(),
134            transmits: VecDeque::new(),
135        }
136    }
137
138    /// Returns packets to transmit
139    ///
140    /// It should be polled for transmit after:
141    /// - the application performed some I/O
142    /// - a call was made to `handle_read`
143    /// - a call was made to `handle_write`
144    /// - a call was made to `handle_timeout`
145    #[must_use]
146    pub fn poll_transmit(&mut self) -> Option<TransportMessage<BytesMut>> {
147        self.transmits.pop_front()
148    }
149
150    pub fn poll_event(&mut self) -> Option<Event> {
151        while let Some(event) = self.agent.poll_event() {
152            let mut ct = if self.transactions.contains_key(&event.id) {
153                self.transactions.remove(&event.id).unwrap()
154            } else {
155                continue;
156            };
157
158            if ct.attempt >= self.settings.max_attempts || event.result.is_ok() {
159                return Some(event);
160            }
161
162            // Doing re-transmission.
163            ct.attempt += 1;
164
165            let payload = BytesMut::from(&ct.raw[..]);
166            let timeout = ct.next_timeout(Instant::now());
167            let id = ct.id;
168
169            // Starting client transaction.
170            self.transactions.entry(ct.id).or_insert(ct);
171
172            // Starting agent transaction.
173            if self
174                .agent
175                .handle_event(ClientAgent::Start(id, timeout))
176                .is_err()
177            {
178                self.transactions.remove(&id);
179                return Some(event);
180            }
181
182            // Writing message to connection again.
183            self.transmits.push_back(TransportMessage {
184                now: Instant::now(),
185                transport: TransportContext {
186                    local_addr: self.local,
187                    peer_addr: self.remote,
188                    ecn: None,
189                    transport_protocol: self.transport_protocol,
190                },
191                message: payload,
192            });
193        }
194
195        None
196    }
197
198    pub fn handle_read(&mut self, buf: &[u8]) -> Result<()> {
199        let mut msg = Message::new();
200        let mut reader = BufReader::new(buf);
201        msg.read_from(&mut reader)?;
202        self.agent.handle_event(ClientAgent::Process(msg))
203    }
204
205    pub fn handle_write(&mut self, m: Message) -> Result<()> {
206        if self.settings.closed {
207            return Err(Error::ErrClientClosed);
208        }
209
210        let payload = BytesMut::from(&m.raw[..]);
211
212        let ct = ClientTransaction {
213            id: m.transaction_id,
214            attempt: 0,
215            start: Instant::now(),
216            rto: self.settings.rto,
217            raw: m.raw,
218        };
219        let deadline = ct.next_timeout(ct.start);
220        self.transactions.entry(ct.id).or_insert(ct);
221        self.agent
222            .handle_event(ClientAgent::Start(m.transaction_id, deadline))?;
223
224        self.transmits.push_back(TransportMessage {
225            now: Instant::now(),
226            transport: TransportContext {
227                local_addr: self.local,
228                peer_addr: self.remote,
229                ecn: None,
230                transport_protocol: self.transport_protocol,
231            },
232            message: payload,
233        });
234
235        Ok(())
236    }
237
238    pub fn poll_timeout(&mut self) -> Option<Instant> {
239        self.agent.poll_timeout()
240    }
241
242    pub fn handle_timeout(&mut self, now: Instant) -> Result<()> {
243        self.agent.handle_event(ClientAgent::Collect(now))
244    }
245
246    pub fn handle_close(&mut self) -> Result<()> {
247        if self.settings.closed {
248            return Err(Error::ErrClientClosed);
249        }
250        self.settings.closed = true;
251        self.agent.handle_event(ClientAgent::Close)
252    }
253}