1use bytes::BytesMut;
2use shared::error::*;
3use std::collections::{HashMap, VecDeque};
4use std::io::BufReader;
5use std::net::SocketAddr;
6use std::ops::Add;
7use std::time::{Duration, Instant};
8
9use crate::agent::*;
10use crate::message::*;
11use shared::{TransportContext, TransportMessage, TransportProtocol};
12
/// Default for `ClientSettings::rto_rate` (overridable via `ClientBuilder::with_timeout_rate`).
const DEFAULT_TIMEOUT_RATE: Duration = Duration::from_millis(5);
/// Default base retransmission timeout; attempt `n` (0-based) waits `(n + 1) * rto`.
const DEFAULT_RTO: Duration = Duration::from_millis(300);
/// Default number of retransmissions made for a failing transaction before its
/// failure is reported to the caller.
const DEFAULT_MAX_ATTEMPTS: u32 = 7;
/// Default for `ClientSettings::buffer_size` (overridable via `ClientBuilder::with_buffer_size`).
const DEFAULT_MAX_BUFFER_SIZE: usize = 8;
17
18#[derive(Debug, Clone)]
23pub struct ClientTransaction {
24 id: TransactionId,
25 attempt: u32,
26 start: Instant,
27 rto: Duration,
28 raw: Vec<u8>,
29}
30
31impl ClientTransaction {
32 pub(crate) fn next_timeout(&self, now: Instant) -> Instant {
33 now.add((self.attempt + 1) * self.rto)
34 }
35}
36
37struct ClientSettings {
38 buffer_size: usize,
39 rto: Duration,
40 rto_rate: Duration,
41 max_attempts: u32,
42 closed: bool,
43}
44
45impl Default for ClientSettings {
46 fn default() -> Self {
47 ClientSettings {
48 buffer_size: DEFAULT_MAX_BUFFER_SIZE,
49 rto: DEFAULT_RTO,
50 rto_rate: DEFAULT_TIMEOUT_RATE,
51 max_attempts: DEFAULT_MAX_ATTEMPTS,
52 closed: false,
53 }
54 }
55}
56
57#[derive(Default)]
58pub struct ClientBuilder {
59 settings: ClientSettings,
60}
61
62impl ClientBuilder {
63 pub fn with_rto(mut self, rto: Duration) -> Self {
65 self.settings.rto = rto;
66 self
67 }
68
69 pub fn with_timeout_rate(mut self, d: Duration) -> Self {
71 self.settings.rto_rate = d;
72 self
73 }
74
75 pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
77 self.settings.buffer_size = buffer_size;
78 self
79 }
80
81 pub fn with_no_retransmit(mut self) -> Self {
86 self.settings.max_attempts = 0;
87 if self.settings.rto == Duration::from_secs(0) {
88 self.settings.rto = DEFAULT_MAX_ATTEMPTS * DEFAULT_RTO;
89 }
90 self
91 }
92
93 pub fn new() -> Self {
94 ClientBuilder {
95 settings: ClientSettings::default(),
96 }
97 }
98
99 pub fn build(
100 self,
101 local: SocketAddr,
102 remote: SocketAddr,
103 protocol: TransportProtocol,
104 ) -> Result<Client> {
105 Ok(Client::new(local, remote, protocol, self.settings))
106 }
107}
108
109pub struct Client {
111 local: SocketAddr,
112 remote: SocketAddr,
113 transport_protocol: TransportProtocol,
114 agent: Agent,
115 settings: ClientSettings,
116 transactions: HashMap<TransactionId, ClientTransaction>,
117 transmits: VecDeque<TransportMessage<BytesMut>>,
118}
119
120impl Client {
121 fn new(
122 local: SocketAddr,
123 remote: SocketAddr,
124 transport_protocol: TransportProtocol,
125 settings: ClientSettings,
126 ) -> Self {
127 Self {
128 local,
129 remote,
130 transport_protocol,
131 agent: Agent::new(),
132 settings,
133 transactions: HashMap::new(),
134 transmits: VecDeque::new(),
135 }
136 }
137
138 #[must_use]
146 pub fn poll_transmit(&mut self) -> Option<TransportMessage<BytesMut>> {
147 self.transmits.pop_front()
148 }
149
150 pub fn poll_event(&mut self) -> Option<Event> {
151 while let Some(event) = self.agent.poll_event() {
152 let mut ct = if self.transactions.contains_key(&event.id) {
153 self.transactions.remove(&event.id).unwrap()
154 } else {
155 continue;
156 };
157
158 if ct.attempt >= self.settings.max_attempts || event.result.is_ok() {
159 return Some(event);
160 }
161
162 ct.attempt += 1;
164
165 let payload = BytesMut::from(&ct.raw[..]);
166 let timeout = ct.next_timeout(Instant::now());
167 let id = ct.id;
168
169 self.transactions.entry(ct.id).or_insert(ct);
171
172 if self
174 .agent
175 .handle_event(ClientAgent::Start(id, timeout))
176 .is_err()
177 {
178 self.transactions.remove(&id);
179 return Some(event);
180 }
181
182 self.transmits.push_back(TransportMessage {
184 now: Instant::now(),
185 transport: TransportContext {
186 local_addr: self.local,
187 peer_addr: self.remote,
188 ecn: None,
189 transport_protocol: self.transport_protocol,
190 },
191 message: payload,
192 });
193 }
194
195 None
196 }
197
198 pub fn handle_read(&mut self, buf: &[u8]) -> Result<()> {
199 let mut msg = Message::new();
200 let mut reader = BufReader::new(buf);
201 msg.read_from(&mut reader)?;
202 self.agent.handle_event(ClientAgent::Process(msg))
203 }
204
205 pub fn handle_write(&mut self, m: Message) -> Result<()> {
206 if self.settings.closed {
207 return Err(Error::ErrClientClosed);
208 }
209
210 let payload = BytesMut::from(&m.raw[..]);
211
212 let ct = ClientTransaction {
213 id: m.transaction_id,
214 attempt: 0,
215 start: Instant::now(),
216 rto: self.settings.rto,
217 raw: m.raw,
218 };
219 let deadline = ct.next_timeout(ct.start);
220 self.transactions.entry(ct.id).or_insert(ct);
221 self.agent
222 .handle_event(ClientAgent::Start(m.transaction_id, deadline))?;
223
224 self.transmits.push_back(TransportMessage {
225 now: Instant::now(),
226 transport: TransportContext {
227 local_addr: self.local,
228 peer_addr: self.remote,
229 ecn: None,
230 transport_protocol: self.transport_protocol,
231 },
232 message: payload,
233 });
234
235 Ok(())
236 }
237
238 pub fn poll_timeout(&mut self) -> Option<Instant> {
239 self.agent.poll_timeout()
240 }
241
242 pub fn handle_timeout(&mut self, now: Instant) -> Result<()> {
243 self.agent.handle_event(ClientAgent::Collect(now))
244 }
245
246 pub fn handle_close(&mut self) -> Result<()> {
247 if self.settings.closed {
248 return Err(Error::ErrClientClosed);
249 }
250 self.settings.closed = true;
251 self.agent.handle_event(ClientAgent::Close)
252 }
253}