1use std::{
5 time::{Duration, Instant},
6 collections::{BTreeMap, BTreeSet},
7 net::{SocketAddr, IpAddr},
8 io, fmt,
9 error::Error,
10};
11use mio::{
12 Poll, Events, Token,
13 net::{TcpListener, TcpStream},
14 Interest,
15};
16
17use super::{
18 request::{Request, ConnectionSource},
19 managed_stream::{ManagedStream, TcpReadOnce, TcpWriteOnce},
20 state::State,
21 proposal::{Proposal, ProposalKind, ConnectionId},
22};
23
/// Drives a `mio` poll loop and turns socket activity into proposals
/// that are fed to a `State` implementation.
pub struct Proposer {
    /// `false` until the first call to `run`, which only emits the initial
    /// `Wake` proposal and flips this flag.
    started: bool,
    /// Requests accumulated from the values returned by `State::accept`.
    request: Request,
    /// The `mio` poll instance all sockets are registered with.
    poll: Poll,
    /// Capacity used whenever a fresh `Events` buffer is created.
    events_capacity: usize,
    /// Buffer that `Poll::poll` fills with readiness events.
    events: Events,
    /// Moment the previous proposal was sent; used to compute `elapsed`.
    last: Instant,
    /// Identifier of this proposer, copied into every `ConnectionId::poll_id`.
    id: u16,
    /// Listening socket, present only while a connection source is configured.
    listener: Option<TcpListener>,
    /// Managed streams keyed by the peer address.
    streams: BTreeMap<SocketAddr, ManagedStream>,
    /// Tokens that are registered with the poll and waiting for an event,
    /// mapped to the address of the stream they belong to.
    in_progress: BTreeMap<Token, SocketAddr>,
    /// IP addresses whose incoming connections are dropped right after accept.
    blacklist: BTreeSet<IpAddr>,
    /// The token that will be handed out by the next allocation
    /// (despite the name, it has not been used yet).
    last_token: Token,
}
39
40impl Proposer {
41 const LISTENER: Token = Token(usize::MAX);
42
43 pub fn new(id: u16, events_capacity: usize) -> io::Result<Self> {
45 let poll = Poll::new()?;
46
47 Ok(Proposer {
48 started: false,
49 request: Request::default(),
50 poll,
51 events_capacity,
52 events: Events::with_capacity(events_capacity),
53 last: Instant::now(),
54 id,
55 listener: None,
56 streams: BTreeMap::default(),
57 in_progress: BTreeMap::default(),
58 blacklist: BTreeSet::default(),
59 last_token: Token(0),
60 })
61 }
62
63 fn allocate_token(&mut self) -> Token {
64 let t = self.last_token;
65 self.last_token = Token(self.last_token.0 + 1);
66 t
67 }
68
69 fn send_proposal<S>(
70 &mut self,
71 rng: S::Rng,
72 state: &mut S,
73 kind: ProposalKind<TcpReadOnce, TcpWriteOnce, S::Ext>,
74 ) where
75 S: State<TcpReadOnce, TcpWriteOnce>,
76 {
77 use std::mem;
78
79 let last = mem::replace(&mut self.last, Instant::now());
80 let proposal = Proposal {
81 rng,
82 elapsed: last.elapsed(),
83 kind,
84 };
85
86 self.request += state.accept(proposal);
87 }
88
89 fn set_source(&mut self, source: ConnectionSource) -> io::Result<()> {
90 if let Some(mut listener) = self.listener.take() {
91 self.poll.registry().deregister(&mut listener).expect("bug");
95 }
96
97 match source {
98 ConnectionSource::None => Ok(()),
99 ConnectionSource::Port(port) => {
100 let mut listener = TcpListener::bind(([0, 0, 0, 0], port).into())?;
101 self.poll
102 .registry()
103 .register(&mut listener, Self::LISTENER, Interest::READABLE)
104 .expect("bug");
105 self.listener = Some(listener);
106 Ok(())
107 },
108 }
109 }
110
111 fn disconnect_peer(&mut self, addr: SocketAddr) -> io::Result<()> {
112 if let Some(stream) = self.streams.remove(&addr) {
113 self.poll
114 .registry()
115 .deregister(stream.borrow_mut().as_mut())
116 .expect("bug");
117 stream.discard()?;
118 }
119
120 Ok(())
121 }
122
123 fn register_stream(
124 &mut self,
125 stream: TcpStream,
126 addr: SocketAddr,
127 interests: Interest,
128 ) -> Token {
129 let token = self.allocate_token();
130 let stream = ManagedStream::new(stream, token);
131 self.poll
132 .registry()
133 .register(stream.borrow_mut().as_mut(), token, interests)
134 .expect("bug");
135 self.streams.insert(addr, stream);
136 self.in_progress.insert(token, addr);
137 token
138 }
139
140 fn connect_peer(&mut self, addr: SocketAddr) -> io::Result<Option<Token>> {
141 if !self.streams.contains_key(&addr) {
142 Ok(Some(self.register_stream(
143 TcpStream::connect(addr)?,
144 addr,
145 Interest::WRITABLE,
146 )))
147 } else {
148 Ok(None)
149 }
150 }
151
152 fn reregister(&mut self) {
153 self.streams.retain(|_, stream| !stream.closed());
154 for (addr, stream) in &self.streams {
155 if let Some(i) = stream.interests() {
156 self.poll
157 .registry()
158 .reregister(stream.borrow_mut().as_mut(), stream.token(), i)
159 .expect("bug");
160 self.in_progress.insert(stream.token(), *addr);
161 }
162 }
163 if let Some(listener) = &mut self.listener {
164 self.poll
165 .registry()
166 .reregister(listener, Self::LISTENER, Interest::READABLE)
167 .expect("bug");
168 }
169 }
170
171 fn take_events(&mut self) -> Events {
172 std::mem::replace(
173 &mut self.events,
174 Events::with_capacity(self.events_capacity),
175 )
176 }
177
178 pub fn run<Rngs, S>(
180 &mut self,
181 rngs: &mut Rngs,
182 state: &mut S,
183 timeout: Duration,
184 ) -> Result<(), ProposerError>
185 where
186 Rngs: Iterator<Item = S::Rng>,
187 S: State<TcpReadOnce, TcpWriteOnce>,
188 {
189 if self.started {
190 self.run_inner(rngs, state, timeout)
191 } else {
192 self.started = true;
193 self.send_proposal(rngs.next().unwrap(), state, ProposalKind::Wake);
194 Ok(())
195 }
196 }
197
198 fn run_inner<Rngs, S>(
199 &mut self,
200 rngs: &mut Rngs,
201 state: &mut S,
202 timeout: Duration,
203 ) -> Result<(), ProposerError>
204 where
205 Rngs: Iterator<Item = S::Rng>,
206 S: State<TcpReadOnce, TcpWriteOnce>,
207 {
208 let mut error = ProposerError::default();
209
210 if let Some(source) = self.request.take_new_source() {
211 if let Err(e) = self.set_source(source) {
212 error.listen_error = Some((source, e));
213 }
214 }
215
216 for addr in self.request.take_blacklist() {
217 self.blacklist.insert(addr.ip());
218 if let Err(e) = self.disconnect_peer(addr) {
219 error.disconnect_errors.push((addr, e));
220 }
221 }
222
223 self.reregister();
224
225 for addr in self.request.take_connects() {
226 match self.connect_peer(addr) {
227 Err(e) => error.connect_errors.push((addr, e)),
228 Ok(None) => (),
229 Ok(Some(token)) => {
230 let kind = ProposalKind::Connection {
231 addr,
232 incoming: true,
233 id: ConnectionId {
234 poll_id: self.id,
235 token: token.0 as u16,
236 },
237 };
238 self.send_proposal(rngs.next().unwrap(), state, kind);
239 },
240 }
241 }
242
243 match self.poll.poll(&mut self.events, Some(timeout)) {
244 Ok(()) => (),
245 Err(e) if e.kind() == io::ErrorKind::Interrupted => (),
246 Err(e) => {
247 let _ = self.take_events();
248 error.poll_error = Some(e);
249 return Err(error);
250 },
251 }
252
253 let events = self.take_events();
254 if events.is_empty() {
255 self.send_proposal(rngs.next().unwrap(), state, ProposalKind::Idle);
256 }
257 for event in events.into_iter() {
258 if event.token() == Self::LISTENER {
259 if let Some(listener) = self.listener.as_mut() {
260 match listener.accept() {
261 Ok((stream, addr)) => {
262 if !self.blacklist.contains(&addr.ip()) {
263 let token = self.register_stream(stream, addr, Interest::READABLE);
264 let kind = ProposalKind::Connection {
265 addr,
266 incoming: true,
267 id: ConnectionId {
268 poll_id: self.id,
269 token: token.0 as u16,
270 },
271 };
272 self.send_proposal(rngs.next().unwrap(), state, kind);
273 }
274 },
275 Err(e) => {
276 error.accept_error = Some(e);
277 },
278 }
279 }
280 } else if let Some(addr) = self.in_progress.remove(&event.token()) {
281 let stream = self.streams.get(&addr).unwrap();
282 let id = ConnectionId {
283 poll_id: self.id,
284 token: stream.token().0 as u16,
285 };
286 let mut pr = Vec::with_capacity(2);
287 if event.is_writable() {
288 if let Some(w) = stream.write_once() {
289 pr.push(ProposalKind::OnWritable(id, w));
290 if event.is_write_closed() {
291 stream.set_write_closed();
292 }
293 } else {
294 debug_assert!(false, "mio should not poll for this event");
295 }
296 }
297 if event.is_readable() {
298 if let Some(r) = stream.read_once() {
299 pr.push(ProposalKind::OnReadable(id, r));
300 if event.is_read_closed() {
301 stream.set_read_closed();
302 }
303 } else {
304 debug_assert!(false, "mio should not poll for this event");
305 }
306 }
307 for pr in pr {
308 self.send_proposal(rngs.next().unwrap(), state, pr);
309 }
310 }
311 }
312
313 error.into_result()
314 }
315}
316
/// Collects every I/O failure that happened during a single proposer step.
///
/// A value with no recorded failure is considered empty, see `is_empty`.
#[derive(Debug, Default)]
pub struct ProposerError {
    /// Failure to start listening, together with the requested source.
    listen_error: Option<(ConnectionSource, io::Error)>,
    /// Failed outgoing connection attempts, with the target address.
    connect_errors: Vec<(SocketAddr, io::Error)>,
    /// Failures while disconnecting a peer, with the peer address.
    disconnect_errors: Vec<(SocketAddr, io::Error)>,
    /// Failure to accept a connection on the listener.
    accept_error: Option<io::Error>,
    /// Failure of the poll call itself.
    poll_error: Option<io::Error>,
}
325
326impl fmt::Display for ProposerError {
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 if let Some((source, error)) = &self.listen_error {
329 write!(f, "failed to listen: {}, error: {}", source, error)?;
330 }
331 for (addr, error) in &self.connect_errors {
332 write!(f, "failed to connect to: {}, error: {}", addr, error)?;
333 }
334 for (addr, error) in &self.disconnect_errors {
335 write!(f, "failed to disconnect from: {}, error: {}", addr, error)?;
336 }
337 if let Some(error) = &self.accept_error {
338 write!(f, "failed to accept a connection, error: {}", error)?;
339 }
340 if let Some(error) = &self.poll_error {
341 write!(f, "failed to poll the events, error: {}", error)?;
342 }
343
344 Ok(())
345 }
346}
347
// Marks `ProposerError` as a standard error type; the default methods of
// the trait are used, so no underlying `source` is exposed.
impl Error for ProposerError {}
349
350impl ProposerError {
351 pub fn into_result(self) -> Result<(), Self> {
352 if self.is_empty() {
353 Ok(())
354 } else {
355 Err(self)
356 }
357 }
358
359 pub fn is_empty(&self) -> bool {
360 self.listen_error.is_none()
361 && self.connect_errors.is_empty()
362 && self.disconnect_errors.is_empty()
363 && self.accept_error.is_none()
364 && self.poll_error.is_none()
365 }
366}