cold_io/
proposer.rs

1// Copyright 2021 Vladislav Melnik
2// SPDX-License-Identifier: MIT
3
use std::{
    time::{Duration, Instant},
    collections::{BTreeMap, BTreeSet},
    net::{SocketAddr, IpAddr},
    io, fmt,
    error::Error,
};
use mio::{
    Poll, Events, Token,
    event::Event,
    net::{TcpListener, TcpStream},
    Interest,
};

use super::{
    request::{Request, ConnectionSource},
    managed_stream::{ManagedStream, TcpReadOnce, TcpWriteOnce},
    state::State,
    proposal::{Proposal, ProposalKind, ConnectionId},
};
23
24/// The proposer serves the state's requests and provides network events to it.
25pub struct Proposer {
26    started: bool,
27    request: Request,
28    poll: Poll,
29    events_capacity: usize,
30    events: Events,
31    last: Instant,
32    id: u16,
33    listener: Option<TcpListener>,
34    streams: BTreeMap<SocketAddr, ManagedStream>,
35    in_progress: BTreeMap<Token, SocketAddr>,
36    blacklist: BTreeSet<IpAddr>,
37    last_token: Token,
38}
39
40impl Proposer {
41    const LISTENER: Token = Token(usize::MAX);
42
43    /// Set the seed for the random number generator.
44    pub fn new(id: u16, events_capacity: usize) -> io::Result<Self> {
45        let poll = Poll::new()?;
46
47        Ok(Proposer {
48            started: false,
49            request: Request::default(),
50            poll,
51            events_capacity,
52            events: Events::with_capacity(events_capacity),
53            last: Instant::now(),
54            id,
55            listener: None,
56            streams: BTreeMap::default(),
57            in_progress: BTreeMap::default(),
58            blacklist: BTreeSet::default(),
59            last_token: Token(0),
60        })
61    }
62
63    fn allocate_token(&mut self) -> Token {
64        let t = self.last_token;
65        self.last_token = Token(self.last_token.0 + 1);
66        t
67    }
68
69    fn send_proposal<S>(
70        &mut self,
71        rng: S::Rng,
72        state: &mut S,
73        kind: ProposalKind<TcpReadOnce, TcpWriteOnce, S::Ext>,
74    ) where
75        S: State<TcpReadOnce, TcpWriteOnce>,
76    {
77        use std::mem;
78
79        let last = mem::replace(&mut self.last, Instant::now());
80        let proposal = Proposal {
81            rng,
82            elapsed: last.elapsed(),
83            kind,
84        };
85
86        self.request += state.accept(proposal);
87    }
88
89    fn set_source(&mut self, source: ConnectionSource) -> io::Result<()> {
90        if let Some(mut listener) = self.listener.take() {
91            // register/reregister/deregister can only fail in case of the bug
92            // here and further we should panic in such situation,
93            // rather then propagate the error
94            self.poll.registry().deregister(&mut listener).expect("bug");
95        }
96
97        match source {
98            ConnectionSource::None => Ok(()),
99            ConnectionSource::Port(port) => {
100                let mut listener = TcpListener::bind(([0, 0, 0, 0], port).into())?;
101                self.poll
102                    .registry()
103                    .register(&mut listener, Self::LISTENER, Interest::READABLE)
104                    .expect("bug");
105                self.listener = Some(listener);
106                Ok(())
107            },
108        }
109    }
110
111    fn disconnect_peer(&mut self, addr: SocketAddr) -> io::Result<()> {
112        if let Some(stream) = self.streams.remove(&addr) {
113            self.poll
114                .registry()
115                .deregister(stream.borrow_mut().as_mut())
116                .expect("bug");
117            stream.discard()?;
118        }
119
120        Ok(())
121    }
122
123    fn register_stream(
124        &mut self,
125        stream: TcpStream,
126        addr: SocketAddr,
127        interests: Interest,
128    ) -> Token {
129        let token = self.allocate_token();
130        let stream = ManagedStream::new(stream, token);
131        self.poll
132            .registry()
133            .register(stream.borrow_mut().as_mut(), token, interests)
134            .expect("bug");
135        self.streams.insert(addr, stream);
136        self.in_progress.insert(token, addr);
137        token
138    }
139
140    fn connect_peer(&mut self, addr: SocketAddr) -> io::Result<Option<Token>> {
141        if !self.streams.contains_key(&addr) {
142            Ok(Some(self.register_stream(
143                TcpStream::connect(addr)?,
144                addr,
145                Interest::WRITABLE,
146            )))
147        } else {
148            Ok(None)
149        }
150    }
151
152    fn reregister(&mut self) {
153        self.streams.retain(|_, stream| !stream.closed());
154        for (addr, stream) in &self.streams {
155            if let Some(i) = stream.interests() {
156                self.poll
157                    .registry()
158                    .reregister(stream.borrow_mut().as_mut(), stream.token(), i)
159                    .expect("bug");
160                self.in_progress.insert(stream.token(), *addr);
161            }
162        }
163        if let Some(listener) = &mut self.listener {
164            self.poll
165                .registry()
166                .reregister(listener, Self::LISTENER, Interest::READABLE)
167                .expect("bug");
168        }
169    }
170
171    fn take_events(&mut self) -> Events {
172        std::mem::replace(
173            &mut self.events,
174            Events::with_capacity(self.events_capacity),
175        )
176    }
177
178    /// Run the single iteration
179    pub fn run<Rngs, S>(
180        &mut self,
181        rngs: &mut Rngs,
182        state: &mut S,
183        timeout: Duration,
184    ) -> Result<(), ProposerError>
185    where
186        Rngs: Iterator<Item = S::Rng>,
187        S: State<TcpReadOnce, TcpWriteOnce>,
188    {
189        if self.started {
190            self.run_inner(rngs, state, timeout)
191        } else {
192            self.started = true;
193            self.send_proposal(rngs.next().unwrap(), state, ProposalKind::Wake);
194            Ok(())
195        }
196    }
197
198    fn run_inner<Rngs, S>(
199        &mut self,
200        rngs: &mut Rngs,
201        state: &mut S,
202        timeout: Duration,
203    ) -> Result<(), ProposerError>
204    where
205        Rngs: Iterator<Item = S::Rng>,
206        S: State<TcpReadOnce, TcpWriteOnce>,
207    {
208        let mut error = ProposerError::default();
209
210        if let Some(source) = self.request.take_new_source() {
211            if let Err(e) = self.set_source(source) {
212                error.listen_error = Some((source, e));
213            }
214        }
215
216        for addr in self.request.take_blacklist() {
217            self.blacklist.insert(addr.ip());
218            if let Err(e) = self.disconnect_peer(addr) {
219                error.disconnect_errors.push((addr, e));
220            }
221        }
222
223        self.reregister();
224
225        for addr in self.request.take_connects() {
226            match self.connect_peer(addr) {
227                Err(e) => error.connect_errors.push((addr, e)),
228                Ok(None) => (),
229                Ok(Some(token)) => {
230                    let kind = ProposalKind::Connection {
231                        addr,
232                        incoming: true,
233                        id: ConnectionId {
234                            poll_id: self.id,
235                            token: token.0 as u16,
236                        },
237                    };
238                    self.send_proposal(rngs.next().unwrap(), state, kind);
239                },
240            }
241        }
242
243        match self.poll.poll(&mut self.events, Some(timeout)) {
244            Ok(()) => (),
245            Err(e) if e.kind() == io::ErrorKind::Interrupted => (),
246            Err(e) => {
247                let _ = self.take_events();
248                error.poll_error = Some(e);
249                return Err(error);
250            },
251        }
252
253        let events = self.take_events();
254        if events.is_empty() {
255            self.send_proposal(rngs.next().unwrap(), state, ProposalKind::Idle);
256        }
257        for event in events.into_iter() {
258            if event.token() == Self::LISTENER {
259                if let Some(listener) = self.listener.as_mut() {
260                    match listener.accept() {
261                        Ok((stream, addr)) => {
262                            if !self.blacklist.contains(&addr.ip()) {
263                                let token = self.register_stream(stream, addr, Interest::READABLE);
264                                let kind = ProposalKind::Connection {
265                                    addr,
266                                    incoming: true,
267                                    id: ConnectionId {
268                                        poll_id: self.id,
269                                        token: token.0 as u16,
270                                    },
271                                };
272                                self.send_proposal(rngs.next().unwrap(), state, kind);
273                            }
274                        },
275                        Err(e) => {
276                            error.accept_error = Some(e);
277                        },
278                    }
279                }
280            } else if let Some(addr) = self.in_progress.remove(&event.token()) {
281                let stream = self.streams.get(&addr).unwrap();
282                let id = ConnectionId {
283                    poll_id: self.id,
284                    token: stream.token().0 as u16,
285                };
286                let mut pr = Vec::with_capacity(2);
287                if event.is_writable() {
288                    if let Some(w) = stream.write_once() {
289                        pr.push(ProposalKind::OnWritable(id, w));
290                        if event.is_write_closed() {
291                            stream.set_write_closed();
292                        }
293                    } else {
294                        debug_assert!(false, "mio should not poll for this event");
295                    }
296                }
297                if event.is_readable() {
298                    if let Some(r) = stream.read_once() {
299                        pr.push(ProposalKind::OnReadable(id, r));
300                        if event.is_read_closed() {
301                            stream.set_read_closed();
302                        }
303                    } else {
304                        debug_assert!(false, "mio should not poll for this event");
305                    }
306                }
307                for pr in pr {
308                    self.send_proposal(rngs.next().unwrap(), state, pr);
309                }
310            }
311        }
312
313        error.into_result()
314    }
315}
316
317#[derive(Debug, Default)]
318pub struct ProposerError {
319    listen_error: Option<(ConnectionSource, io::Error)>,
320    connect_errors: Vec<(SocketAddr, io::Error)>,
321    disconnect_errors: Vec<(SocketAddr, io::Error)>,
322    accept_error: Option<io::Error>,
323    poll_error: Option<io::Error>,
324}
325
326impl fmt::Display for ProposerError {
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        if let Some((source, error)) = &self.listen_error {
329            write!(f, "failed to listen: {}, error: {}", source, error)?;
330        }
331        for (addr, error) in &self.connect_errors {
332            write!(f, "failed to connect to: {}, error: {}", addr, error)?;
333        }
334        for (addr, error) in &self.disconnect_errors {
335            write!(f, "failed to disconnect from: {}, error: {}", addr, error)?;
336        }
337        if let Some(error) = &self.accept_error {
338            write!(f, "failed to accept a connection, error: {}", error)?;
339        }
340        if let Some(error) = &self.poll_error {
341            write!(f, "failed to poll the events, error: {}", error)?;
342        }
343
344        Ok(())
345    }
346}
347
348impl Error for ProposerError {}
349
350impl ProposerError {
351    pub fn into_result(self) -> Result<(), Self> {
352        if self.is_empty() {
353            Ok(())
354        } else {
355            Err(self)
356        }
357    }
358
359    pub fn is_empty(&self) -> bool {
360        self.listen_error.is_none()
361            && self.connect_errors.is_empty()
362            && self.disconnect_errors.is_empty()
363            && self.accept_error.is_none()
364            && self.poll_error.is_none()
365    }
366}