1use bytes::BytesMut;
2use shared::error::*;
3use std::collections::{HashMap, VecDeque};
4use std::io::BufReader;
5use std::net::SocketAddr;
6use std::ops::Add;
7use std::time::{Duration, Instant};
8
9use crate::agent::*;
10use crate::message::*;
11use shared::{TaggedBytesMut, TransportContext, TransportMessage, TransportProtocol};
12
/// Default value for `ClientSettings::rto_rate`.
/// NOTE(review): not read by the client code in this file — presumably the
/// timeout-collection tick interval; confirm against its consumer.
const DEFAULT_TIMEOUT_RATE: Duration = Duration::from_millis(5);
/// Default base retransmission timeout; attempt `n` (0-based) waits `(n + 1) * DEFAULT_RTO`.
const DEFAULT_RTO: Duration = Duration::from_millis(300);
/// Default number of retransmissions made after the initial send before giving up.
const DEFAULT_MAX_ATTEMPTS: u32 = 7;
/// Default value for `ClientSettings::buffer_size`.
/// NOTE(review): not read by the client code in this file — confirm its meaning.
const DEFAULT_MAX_BUFFER_SIZE: usize = 8;
17
18#[derive(Debug, Clone)]
23pub struct ClientTransaction {
24 id: TransactionId,
25 attempt: u32,
26 start: Instant,
27 rto: Duration,
28 raw: Vec<u8>,
29}
30
31impl ClientTransaction {
32 pub(crate) fn next_timeout(&self, now: Instant) -> Instant {
33 now.add((self.attempt + 1) * self.rto)
34 }
35}
36
37struct ClientSettings {
38 buffer_size: usize,
39 rto: Duration,
40 rto_rate: Duration,
41 max_attempts: u32,
42 closed: bool,
43}
44
45impl Default for ClientSettings {
46 fn default() -> Self {
47 ClientSettings {
48 buffer_size: DEFAULT_MAX_BUFFER_SIZE,
49 rto: DEFAULT_RTO,
50 rto_rate: DEFAULT_TIMEOUT_RATE,
51 max_attempts: DEFAULT_MAX_ATTEMPTS,
52 closed: false,
53 }
54 }
55}
56
57#[derive(Default)]
58pub struct ClientBuilder {
59 settings: ClientSettings,
60}
61
62impl ClientBuilder {
63 pub fn with_rto(mut self, rto: Duration) -> Self {
65 self.settings.rto = rto;
66 self
67 }
68
69 pub fn with_timeout_rate(mut self, d: Duration) -> Self {
71 self.settings.rto_rate = d;
72 self
73 }
74
75 pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
77 self.settings.buffer_size = buffer_size;
78 self
79 }
80
81 pub fn with_no_retransmit(mut self) -> Self {
86 self.settings.max_attempts = 0;
87 if self.settings.rto == Duration::from_secs(0) {
88 self.settings.rto = DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO;
89 }
90 self
91 }
92
93 pub fn new() -> Self {
94 ClientBuilder {
95 settings: ClientSettings::default(),
96 }
97 }
98
99 pub fn build(
100 self,
101 local: SocketAddr,
102 remote: SocketAddr,
103 protocol: TransportProtocol,
104 ) -> Result<Client> {
105 Ok(Client::new(local, remote, protocol, self.settings))
106 }
107}
108
109pub struct Client {
111 local: SocketAddr,
112 remote: SocketAddr,
113 transport_protocol: TransportProtocol,
114 agent: Agent,
115 settings: ClientSettings,
116 transactions: HashMap<TransactionId, ClientTransaction>,
117 transmits: VecDeque<TransportMessage<BytesMut>>,
118}
119
120impl Client {
121 fn new(
122 local: SocketAddr,
123 remote: SocketAddr,
124 transport_protocol: TransportProtocol,
125 settings: ClientSettings,
126 ) -> Self {
127 Self {
128 local,
129 remote,
130 transport_protocol,
131 agent: Agent::new(),
132 settings,
133 transactions: HashMap::new(),
134 transmits: VecDeque::new(),
135 }
136 }
137}
138
139impl sansio::Protocol<TaggedBytesMut, Message, Event> for Client {
140 type Rout = ();
141 type Wout = TaggedBytesMut;
142 type Eout = Event;
143 type Error = Error;
144 type Time = Instant;
145
146 fn handle_read(&mut self, msg: TaggedBytesMut) -> Result<()> {
147 let mut stun_msg = Message::new();
148 let mut reader = BufReader::new(&msg.message[..]);
149 stun_msg.read_from(&mut reader)?;
150 self.agent.handle_event(ClientAgent::Process(stun_msg))
151 }
152
153 fn poll_read(&mut self) -> Option<Self::Rout> {
154 None
155 }
156
157 fn handle_write(&mut self, m: Message) -> Result<()> {
158 if self.settings.closed {
159 return Err(Error::ErrClientClosed);
160 }
161
162 let payload = BytesMut::from(&m.raw[..]);
163
164 let ct = ClientTransaction {
165 id: m.transaction_id,
166 attempt: 0,
167 start: Instant::now(),
168 rto: self.settings.rto,
169 raw: m.raw,
170 };
171 let deadline = ct.next_timeout(ct.start);
172 self.transactions.entry(ct.id).or_insert(ct);
173 self.agent
174 .handle_event(ClientAgent::Start(m.transaction_id, deadline))?;
175
176 self.transmits.push_back(TransportMessage {
177 now: Instant::now(),
178 transport: TransportContext {
179 local_addr: self.local,
180 peer_addr: self.remote,
181 ecn: None,
182 transport_protocol: self.transport_protocol,
183 },
184 message: payload,
185 });
186
187 Ok(())
188 }
189
190 fn poll_write(&mut self) -> Option<Self::Wout> {
198 self.transmits.pop_front()
199 }
200
201 fn poll_event(&mut self) -> Option<Self::Eout> {
202 while let Some(event) = self.agent.poll_event() {
203 let mut ct = if self.transactions.contains_key(&event.id) {
204 self.transactions.remove(&event.id).unwrap()
205 } else {
206 continue;
207 };
208
209 if ct.attempt >= self.settings.max_attempts || event.result.is_ok() {
210 return Some(event);
211 }
212
213 ct.attempt += 1;
215
216 let payload = BytesMut::from(&ct.raw[..]);
217 let timeout = ct.next_timeout(Instant::now());
218 let id = ct.id;
219
220 self.transactions.entry(ct.id).or_insert(ct);
222
223 if self
225 .agent
226 .handle_event(ClientAgent::Start(id, timeout))
227 .is_err()
228 {
229 self.transactions.remove(&id);
230 return Some(event);
231 }
232
233 self.transmits.push_back(TransportMessage {
235 now: Instant::now(),
236 transport: TransportContext {
237 local_addr: self.local,
238 peer_addr: self.remote,
239 ecn: None,
240 transport_protocol: self.transport_protocol,
241 },
242 message: payload,
243 });
244 }
245
246 None
247 }
248
249 fn poll_timeout(&mut self) -> Option<Self::Time> {
250 self.agent.poll_timeout()
251 }
252
253 fn handle_timeout(&mut self, now: Instant) -> Result<()> {
254 self.agent.handle_event(ClientAgent::Collect(now))
255 }
256
257 fn close(&mut self) -> Result<()> {
258 if self.settings.closed {
259 return Err(Error::ErrClientClosed);
260 }
261 self.settings.closed = true;
262 self.agent.handle_event(ClientAgent::Close)
263 }
264}