// rtc_stun/client.rs

1use bytes::BytesMut;
2use shared::error::*;
3use std::collections::{HashMap, VecDeque};
4use std::io::BufReader;
5use std::net::SocketAddr;
6use std::ops::Add;
7use std::time::{Duration, Instant};
8
9use crate::agent::*;
10use crate::message::*;
11use shared::{TaggedBytesMut, TransportContext, TransportMessage, TransportProtocol};
12
/// Default base retransmission timeout (RTO) copied into every new transaction.
const DEFAULT_RTO: Duration = Duration::from_millis(300);
/// Default RTO timer resolution, stored as `ClientSettings::rto_rate`.
const DEFAULT_TIMEOUT_RATE: Duration = Duration::from_millis(5);
/// Default number of retransmissions allowed before a failed result is reported.
const DEFAULT_MAX_ATTEMPTS: u32 = 7;
/// Default value for `ClientSettings::buffer_size`.
const DEFAULT_MAX_BUFFER_SIZE: usize = 8;
17
18/// ClientTransaction represents transaction in progress.
19/// If transaction is succeed or failed, f will be called
20/// provided by event.
21/// Concurrent access is invalid.
22#[derive(Debug, Clone)]
23pub struct ClientTransaction {
24    id: TransactionId,
25    attempt: u32,
26    start: Instant,
27    rto: Duration,
28    raw: Vec<u8>,
29}
30
31impl ClientTransaction {
32    pub(crate) fn next_timeout(&self, now: Instant) -> Instant {
33        now.add((self.attempt + 1) * self.rto)
34    }
35}
36
37struct ClientSettings {
38    buffer_size: usize,
39    rto: Duration,
40    rto_rate: Duration,
41    max_attempts: u32,
42    closed: bool,
43}
44
45impl Default for ClientSettings {
46    fn default() -> Self {
47        ClientSettings {
48            buffer_size: DEFAULT_MAX_BUFFER_SIZE,
49            rto: DEFAULT_RTO,
50            rto_rate: DEFAULT_TIMEOUT_RATE,
51            max_attempts: DEFAULT_MAX_ATTEMPTS,
52            closed: false,
53        }
54    }
55}
56
57#[derive(Default)]
58pub struct ClientBuilder {
59    settings: ClientSettings,
60}
61
62impl ClientBuilder {
63    /// with_rto sets client RTO as defined in STUN RFC.
64    pub fn with_rto(mut self, rto: Duration) -> Self {
65        self.settings.rto = rto;
66        self
67    }
68
69    /// with_timeout_rate sets RTO timer minimum resolution.
70    pub fn with_timeout_rate(mut self, d: Duration) -> Self {
71        self.settings.rto_rate = d;
72        self
73    }
74
75    /// with_buffer_size sets buffer size.
76    pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
77        self.settings.buffer_size = buffer_size;
78        self
79    }
80
81    /// with_no_retransmit disables retransmissions and sets RTO to
82    /// DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO which will be effectively time out
83    /// if not set.
84    /// Useful for TCP connections where transport handles RTO.
85    pub fn with_no_retransmit(mut self) -> Self {
86        self.settings.max_attempts = 0;
87        if self.settings.rto == Duration::from_secs(0) {
88            self.settings.rto = DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO;
89        }
90        self
91    }
92
93    pub fn new() -> Self {
94        ClientBuilder {
95            settings: ClientSettings::default(),
96        }
97    }
98
99    pub fn build(
100        self,
101        local: SocketAddr,
102        remote: SocketAddr,
103        protocol: TransportProtocol,
104    ) -> Result<Client> {
105        Ok(Client::new(local, remote, protocol, self.settings))
106    }
107}
108
109/// Client simulates "connection" to STUN server.
110pub struct Client {
111    local: SocketAddr,
112    remote: SocketAddr,
113    transport_protocol: TransportProtocol,
114    agent: Agent,
115    settings: ClientSettings,
116    transactions: HashMap<TransactionId, ClientTransaction>,
117    transmits: VecDeque<TransportMessage<BytesMut>>,
118}
119
120impl Client {
121    fn new(
122        local: SocketAddr,
123        remote: SocketAddr,
124        transport_protocol: TransportProtocol,
125        settings: ClientSettings,
126    ) -> Self {
127        Self {
128            local,
129            remote,
130            transport_protocol,
131            agent: Agent::new(),
132            settings,
133            transactions: HashMap::new(),
134            transmits: VecDeque::new(),
135        }
136    }
137}
138
139impl sansio::Protocol<TaggedBytesMut, Message, Event> for Client {
140    type Rout = ();
141    type Wout = TaggedBytesMut;
142    type Eout = Event;
143    type Error = Error;
144    type Time = Instant;
145
146    fn handle_read(&mut self, msg: TaggedBytesMut) -> Result<()> {
147        let mut stun_msg = Message::new();
148        let mut reader = BufReader::new(&msg.message[..]);
149        stun_msg.read_from(&mut reader)?;
150        self.agent.handle_event(ClientAgent::Process(stun_msg))
151    }
152
153    fn poll_read(&mut self) -> Option<Self::Rout> {
154        None
155    }
156
157    fn handle_write(&mut self, m: Message) -> Result<()> {
158        if self.settings.closed {
159            return Err(Error::ErrClientClosed);
160        }
161
162        let payload = BytesMut::from(&m.raw[..]);
163
164        let ct = ClientTransaction {
165            id: m.transaction_id,
166            attempt: 0,
167            start: Instant::now(),
168            rto: self.settings.rto,
169            raw: m.raw,
170        };
171        let deadline = ct.next_timeout(ct.start);
172        self.transactions.entry(ct.id).or_insert(ct);
173        self.agent
174            .handle_event(ClientAgent::Start(m.transaction_id, deadline))?;
175
176        self.transmits.push_back(TransportMessage {
177            now: Instant::now(),
178            transport: TransportContext {
179                local_addr: self.local,
180                peer_addr: self.remote,
181                ecn: None,
182                transport_protocol: self.transport_protocol,
183            },
184            message: payload,
185        });
186
187        Ok(())
188    }
189
190    /// Returns packets to transmit
191    ///
192    /// It should be polled for transmit after:
193    /// - the application performed some I/O
194    /// - a call was made to `handle_read`
195    /// - a call was made to `handle_write`
196    /// - a call was made to `handle_timeout`
197    fn poll_write(&mut self) -> Option<Self::Wout> {
198        self.transmits.pop_front()
199    }
200
201    fn poll_event(&mut self) -> Option<Self::Eout> {
202        while let Some(event) = self.agent.poll_event() {
203            let mut ct = if self.transactions.contains_key(&event.id) {
204                self.transactions.remove(&event.id).unwrap()
205            } else {
206                continue;
207            };
208
209            if ct.attempt >= self.settings.max_attempts || event.result.is_ok() {
210                return Some(event);
211            }
212
213            // Doing re-transmission.
214            ct.attempt += 1;
215
216            let payload = BytesMut::from(&ct.raw[..]);
217            let timeout = ct.next_timeout(Instant::now());
218            let id = ct.id;
219
220            // Starting client transaction.
221            self.transactions.entry(ct.id).or_insert(ct);
222
223            // Starting agent transaction.
224            if self
225                .agent
226                .handle_event(ClientAgent::Start(id, timeout))
227                .is_err()
228            {
229                self.transactions.remove(&id);
230                return Some(event);
231            }
232
233            // Writing message to connection again.
234            self.transmits.push_back(TransportMessage {
235                now: Instant::now(),
236                transport: TransportContext {
237                    local_addr: self.local,
238                    peer_addr: self.remote,
239                    ecn: None,
240                    transport_protocol: self.transport_protocol,
241                },
242                message: payload,
243            });
244        }
245
246        None
247    }
248
249    fn poll_timeout(&mut self) -> Option<Self::Time> {
250        self.agent.poll_timeout()
251    }
252
253    fn handle_timeout(&mut self, now: Instant) -> Result<()> {
254        self.agent.handle_event(ClientAgent::Collect(now))
255    }
256
257    fn close(&mut self) -> Result<()> {
258        if self.settings.closed {
259            return Err(Error::ErrClientClosed);
260        }
261        self.settings.closed = true;
262        self.agent.handle_event(ClientAgent::Close)
263    }
264}