libtw2_net/
connection.rs

1use crate::protocol;
2use crate::protocol::ChunksIter;
3use crate::protocol::ConnectedPacket;
4use crate::protocol::ConnectedPacketType;
5use crate::protocol::ControlPacket;
6use crate::protocol::Packet;
7use crate::protocol::Token;
8use crate::protocol::MAX_PACKETSIZE;
9use crate::protocol::MAX_PAYLOAD;
10use crate::protocol::TOKEN_NONE;
11use crate::Timeout;
12use crate::Timestamp;
13use arrayvec::ArrayVec;
14use buffer::with_buffer;
15use buffer::Buffer;
16use buffer::BufferRef;
17use std::cmp;
18use std::collections::VecDeque;
19use std::iter;
20use std::time::Duration;
21use warn::Warn;
22
23// TODO: Implement receive timeout.
24// TODO: Don't allow for unbounded backlog of vital messages.
25
/// Hooks through which a connection reaches the outside world.
///
/// The connection code in this file does no I/O, time-keeping or random
/// number generation itself; all of it goes through this trait.
pub trait Callback {
    /// Error type reported by `send`.
    type Error;
    /// Called when a new connection token has to be generated (it is handed
    /// to `Token::random` while handling an incoming `Connect`).
    ///
    /// NOTE(review): presumably must fill all of `buffer` with
    /// cryptographically secure random bytes -- confirm with `Token::random`.
    fn secure_random(&mut self, buffer: &mut [u8]);
    /// Called with the bytes of one serialized packet that is to be
    /// transmitted to the peer.
    fn send(&mut self, buffer: &[u8]) -> Result<(), Self::Error>;
    /// Returns the current time. All timeouts are computed relative to it.
    fn time(&mut self) -> Timestamp;
}
32
/// Error returned by the sending functions of a connection.
///
/// `CE` is the error type of the user-supplied callback.
#[derive(Debug)]
pub enum Error<CE> {
    /// The data is too long to be put into a packet.
    TooLongData,
    /// The callback reported an error while sending.
    Callback(CE),
}

impl<CE> From<CE> for Error<CE> {
    /// Wraps a callback error, which allows `?` on callback results.
    fn from(callback_error: CE) -> Self {
        Error::Callback(callback_error)
    }
}

impl<CE> Error<CE> {
    /// Extracts the callback error.
    ///
    /// # Panics
    ///
    /// Panics with the message "too long data" if the error is
    /// `Error::TooLongData`.
    pub fn unwrap_callback(self) -> CE {
        if let Error::Callback(callback_error) = self {
            return callback_error;
        }
        panic!("too long data")
    }
}
53
/// Non-fatal problems noticed while processing incoming data.
#[derive(Debug)]
pub enum Warning {
    /// A warning forwarded from the packet-level `protocol` code.
    Packet(protocol::Warning),
    /// The incoming data could not be read as a packet and was dropped.
    Read(protocol::PacketReadError),
    /// The token of an incoming packet differs from the token expected for
    /// this connection; the packet was dropped.
    TokenMismatch,
    // NOTE(review): not emitted anywhere in the visible part of this file.
    Unexpected,
}
61
62trait TimeoutExt {
63    fn set<CB: Callback>(&mut self, cb: &mut CB, value: Duration);
64    fn has_triggered_level<CB: Callback>(&self, cb: &mut CB) -> bool;
65    fn has_triggered_edge<CB: Callback>(&mut self, cb: &mut CB) -> bool;
66}
67
68impl TimeoutExt for Timeout {
69    fn set<CB: Callback>(&mut self, cb: &mut CB, value: Duration) {
70        *self = Timeout::active(cb.time() + value);
71    }
72    fn has_triggered_level<CB: Callback>(&self, cb: &mut CB) -> bool {
73        self.to_opt().map(|time| time <= cb.time()).unwrap_or(false)
74    }
75    fn has_triggered_edge<CB: Callback>(&mut self, cb: &mut CB) -> bool {
76        let triggered = self.has_triggered_level(cb);
77        if triggered {
78            *self = Timeout::inactive();
79        }
80        triggered
81    }
82}
83
/// State machine for a single connection to one peer.
///
/// All interaction with the network, the clock and the random number
/// generator happens through the `Callback` passed to each method.
pub struct Connection {
    // Current phase of the connection.
    state: State,
    // When the next periodic action (see `tick_action`) is due.
    send: Timeout,
    // Scratch buffer used to serialize outgoing packets.
    builder: PacketBuilder,
}
89
90#[derive(Clone, Debug)]
91enum State {
92    Unconnected,
93    Connecting,
94    Pending(PendingState),
95    Online(OnlineState),
96    Disconnected,
97}
98
99impl State {
100    pub fn assert_online(&mut self) -> &mut OnlineState {
101        match *self {
102            State::Online(ref mut s) => s,
103            _ => panic!("state not online"),
104        }
105    }
106    pub fn token(&self) -> Option<&Option<Token>> {
107        match *self {
108            State::Unconnected => None,
109            State::Connecting => None,
110            State::Pending(ref pending) => Some(&pending.token),
111            State::Online(ref online) => Some(&online.token),
112            State::Disconnected => None,
113        }
114    }
115}
116
117#[derive(Clone, Debug)]
118struct ResendChunk {
119    next_send: Timeout,
120    sequence: Sequence,
121    data: ArrayVec<[u8; 2048]>,
122}
123
124impl ResendChunk {
125    fn new<CB: Callback>(cb: &mut CB, sequence: Sequence, data: &[u8]) -> ResendChunk {
126        let mut result = ResendChunk {
127            next_send: Timeout::inactive(),
128            sequence: sequence,
129            data: data.iter().cloned().collect(),
130        };
131        assert!(
132            result.data.len() == data.len(),
133            "overlong resend packet {}",
134            data.len()
135        );
136        result.start_timeout(cb);
137        result
138    }
139    fn start_timeout<CB: Callback>(&mut self, cb: &mut CB) {
140        self.next_send.set(cb, Duration::from_millis(1_000));
141    }
142}
143
144pub struct ReceivePacket<'a> {
145    type_: ReceivePacketType<'a>,
146}
147
148impl<'a> Clone for ReceivePacket<'a> {
149    fn clone(&self) -> ReceivePacket<'a> {
150        ReceivePacket {
151            type_: self.type_.clone(),
152        }
153    }
154}
155
156impl<'a> ReceivePacket<'a> {
157    fn none() -> ReceivePacket<'a> {
158        ReceivePacket {
159            type_: ReceivePacketType::None,
160        }
161    }
162    fn ready() -> ReceivePacket<'a> {
163        ReceivePacket {
164            type_: ReceivePacketType::Ready(iter::once(())),
165        }
166    }
167    fn connless(data: &[u8]) -> ReceivePacket {
168        ReceivePacket {
169            type_: ReceivePacketType::Connless(iter::once(data)),
170        }
171    }
172    fn connected<W>(
173        warn: &mut W,
174        online: &mut OnlineState,
175        num_chunks: u8,
176        data: &'a [u8],
177    ) -> ReceivePacket<'a>
178    where
179        W: Warn<Warning>,
180    {
181        let chunks_iter = ChunksIter::new(data, num_chunks);
182        let ack = online.ack.clone();
183        let mut iter = chunks_iter.clone();
184        while let Some(c) = iter.next_warn(&mut w(warn)) {
185            if let Some((sequence, resend)) = c.vital {
186                let _ = resend;
187                if online.ack.update(Sequence::from_u16(sequence)) != SequenceOrdering::Current {
188                    online.request_resend = true;
189                }
190            }
191        }
192        ReceivePacket {
193            type_: ReceivePacketType::Connected(ReceiveChunks {
194                ack: ack,
195                chunks: chunks_iter,
196            }),
197        }
198    }
199    fn disconnect(reason: &[u8]) -> ReceivePacket {
200        ReceivePacket {
201            type_: ReceivePacketType::Close(iter::once(reason)),
202        }
203    }
204}
205
// The different sources of chunks a `ReceivePacket` can iterate over.
#[derive(Clone)]
enum ReceivePacketType<'a> {
    // Yields nothing.
    None,
    // Yields one `ReceiveChunk::Connless` with the packet payload.
    Connless(iter::Once<&'a [u8]>),
    // Yields the accepted chunks of a connected packet.
    Connected(ReceiveChunks<'a>),
    // Yields one `ReceiveChunk::Ready`.
    Ready(iter::Once<()>),
    // Yields one `ReceiveChunk::Disconnect` with the close reason.
    Close(iter::Once<&'a [u8]>),
}
214
215impl<'a> Iterator for ReceivePacket<'a> {
216    type Item = ReceiveChunk<'a>;
217    fn next(&mut self) -> Option<ReceiveChunk<'a>> {
218        match self.type_ {
219            ReceivePacketType::None => None,
220            ReceivePacketType::Ready(ref mut once) => once.next().map(|()| ReceiveChunk::Ready),
221            ReceivePacketType::Connless(ref mut once) => once.next().map(ReceiveChunk::Connless),
222            ReceivePacketType::Connected(ref mut chunks) => chunks.next(),
223            ReceivePacketType::Close(ref mut once) => once.next().map(ReceiveChunk::Disconnect),
224        }
225    }
226    fn size_hint(&self) -> (usize, Option<usize>) {
227        let len = self.clone().count();
228        (len, Some(len))
229    }
230}
231
232impl<'a> ExactSizeIterator for ReceivePacket<'a> {}
233
234#[derive(Clone)]
235struct ReceiveChunks<'a> {
236    ack: Sequence,
237    chunks: ChunksIter<'a>,
238}
239
240impl<'a> Iterator for ReceiveChunks<'a> {
241    type Item = ReceiveChunk<'a>;
242    fn next(&mut self) -> Option<ReceiveChunk<'a>> {
243        self.chunks.next().and_then(|c| {
244            if let Some((sequence, resend)) = c.vital {
245                let _ = resend;
246                if self.ack.update(Sequence::from_u16(sequence)) != SequenceOrdering::Current {
247                    return self.next();
248                }
249            }
250            Some(ReceiveChunk::Connected(c.data, c.vital.is_some()))
251        })
252    }
253    fn size_hint(&self) -> (usize, Option<usize>) {
254        let len = self.clone().count();
255        (len, Some(len))
256    }
257}
258
259impl<'a> ExactSizeIterator for ReceiveChunks<'a> {}
260
/// One event produced by feeding a packet to a connection.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ReceiveChunk<'a> {
    /// Payload of a connectionless packet.
    Connless(&'a [u8]),
    /// `Connected(data, vital)`: a chunk received over the connection.
    /// `vital` is `true` if the chunk carried a sequence number.
    Connected(&'a [u8], bool),
    /// The peer accepted our connection request; the connection is online.
    Ready,
    /// The peer closed the connection, giving this reason.
    Disconnect(&'a [u8]),
}
269
270#[derive(Clone, Debug)]
271struct PendingState {
272    // `token`, if present, is included in every message from and to the peer
273    // in order to protect against IP spoofing.
274    token: Option<Token>,
275}
276
277impl PendingState {
278    fn new(token: Option<Token>) -> PendingState {
279        PendingState { token: token }
280    }
281}
282
283#[derive(Clone, Debug)]
284struct OnlineState {
285    // `token`, if present, is included in every message from and to the peer
286    // in order to protect against IP spoofing.
287    token: Option<Token>,
288    // `ack` is the vital chunk from the peer we want to acknowledge.
289    ack: Sequence,
290    // `sequence` is the vital chunk from us that the peer acknowledged.
291    sequence: Sequence,
292    request_resend: bool,
293    // `packet` contains all the queued chunks, `packet_nonvital` only the
294    // non-vital ones. This is important for resending.
295    packet: PacketContents,
296    packet_nonvital: PacketContents,
297    // This contains the unacked chunks that we sent, starting from the most
298    // recently sent chunk.
299    resend_queue: VecDeque<ResendChunk>,
300}
301
302impl OnlineState {
303    fn new(token: Option<Token>) -> OnlineState {
304        OnlineState {
305            token: token,
306            ack: Sequence::new(),
307            sequence: Sequence::new(),
308            request_resend: false,
309            packet: PacketContents::new(),
310            packet_nonvital: PacketContents::new(),
311            resend_queue: VecDeque::new(),
312        }
313    }
314    fn can_send(&self) -> bool {
315        self.packet.num_chunks != 0 || self.request_resend
316    }
317    fn ack_chunks(&mut self, ack: Sequence) {
318        let index = self
319            .resend_queue
320            .iter()
321            .position(|chunk| chunk.sequence == ack);
322        index.map(|i| self.resend_queue.truncate(i));
323    }
324    fn flush<CB: Callback>(
325        &mut self,
326        cb: &mut CB,
327        builder: &mut PacketBuilder,
328    ) -> Result<(), CB::Error> {
329        if !self.can_send() {
330            return Ok(());
331        }
332        let result = builder
333            .send(
334                cb,
335                Packet::Connected(ConnectedPacket {
336                    token: self.token,
337                    ack: self.ack.to_u16(),
338                    type_: ConnectedPacketType::Chunks(
339                        self.request_resend,
340                        self.packet.num_chunks,
341                        &self.packet.data,
342                    ),
343                }),
344            )
345            .map_err(|e| e.unwrap_callback());
346        self.request_resend = false;
347        self.packet.clear();
348        self.packet_nonvital.clear();
349        result
350    }
351}
352
353#[derive(Clone, Debug, Eq, PartialEq)]
354struct PacketContents {
355    num_chunks: u8,
356    data: ArrayVec<[u8; 2048]>,
357}
358
359impl PacketContents {
360    fn new() -> PacketContents {
361        PacketContents {
362            num_chunks: 0,
363            data: ArrayVec::new(),
364        }
365    }
366    fn write_chunk(&mut self, data: &[u8], vital: Option<(u16, bool)>) {
367        protocol::write_chunk(data, vital, &mut self.data).unwrap();
368        self.num_chunks += 1;
369    }
370    fn can_fit_chunk(&self, data: &[u8], vital: bool) -> bool {
371        // current size + chunk header + chunk length
372        self.data.len() + protocol::chunk_header_size(vital) + data.len() <= MAX_PAYLOAD
373    }
374    fn clear(&mut self) {
375        *self = PacketContents::new();
376    }
377}
378
379#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
380struct Sequence {
381    seq: u16, // u10
382}
383
384#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
385enum SequenceOrdering {
386    Past,
387    Current,
388    Future,
389}
390
391impl Sequence {
392    fn new() -> Sequence {
393        Default::default()
394    }
395    fn from_u16(seq: u16) -> Sequence {
396        assert!(seq < protocol::SEQUENCE_MODULUS);
397        Sequence { seq: seq }
398    }
399    fn to_u16(self) -> u16 {
400        self.seq
401    }
402    fn next(&mut self) -> Sequence {
403        self.seq = (self.seq + 1) % protocol::SEQUENCE_MODULUS;
404        *self
405    }
406    fn update(&mut self, other: Sequence) -> SequenceOrdering {
407        let mut next_self = *self;
408        next_self.next();
409        let result = next_self.compare(other);
410        if result == SequenceOrdering::Current {
411            *self = next_self;
412        }
413        result
414    }
415    /// Returns what `other` is in relation to `self`.
416    fn compare(self, other: Sequence) -> SequenceOrdering {
417        let half = protocol::SEQUENCE_MODULUS / 2;
418        let less;
419        match self.seq.cmp(&other.seq) {
420            cmp::Ordering::Less => less = other.seq - self.seq < half,
421            cmp::Ordering::Greater => less = self.seq - other.seq > half,
422            cmp::Ordering::Equal => return SequenceOrdering::Current,
423        }
424        if less {
425            SequenceOrdering::Future
426        } else {
427            SequenceOrdering::Past
428        }
429    }
430}
431
432struct PacketBuilder {
433    buffer: [u8; MAX_PACKETSIZE],
434}
435
436impl PacketBuilder {
437    fn new() -> PacketBuilder {
438        PacketBuilder {
439            buffer: [0; MAX_PACKETSIZE],
440        }
441    }
442    fn send<CB: Callback>(&mut self, cb: &mut CB, packet: Packet) -> Result<(), Error<CB::Error>> {
443        let data = match packet.write(&mut self.buffer[..]) {
444            Ok(d) => d,
445            Err(protocol::Error::Capacity(_)) => unreachable!("too short buffer provided"),
446            Err(protocol::Error::TooLongData) => return Err(Error::TooLongData),
447        };
448        cb.send(data)?;
449        Ok(())
450    }
451}
452
453struct WarnCallback<'a, W: Warn<Warning> + 'a> {
454    warn: &'a mut W,
455}
456
457fn w<W: Warn<Warning>>(warn: &mut W) -> WarnCallback<W> {
458    WarnCallback { warn: warn }
459}
460
461impl<'a, W: Warn<Warning>> Warn<protocol::Warning> for WarnCallback<'a, W> {
462    fn warn(&mut self, warning: protocol::Warning) {
463        self.warn.warn(Warning::Packet(warning))
464    }
465}
466
467impl Connection {
468    pub fn new() -> Connection {
469        Connection {
470            state: State::Unconnected,
471            send: Timeout::inactive(),
472            builder: PacketBuilder::new(),
473        }
474    }
475    pub fn new_accept_token<CB: Callback>(cb: &mut CB, token: Token) -> Connection {
476        let mut result = Connection {
477            state: State::Online(OnlineState::new(Some(token))),
478            send: Timeout::inactive(),
479            builder: PacketBuilder::new(),
480        };
481        result.send.set(cb, Duration::from_millis(500));
482        result
483    }
484    pub fn reset(&mut self) {
485        assert!(matches!(self.state, State::Disconnected));
486        *self = Connection::new();
487    }
488    pub fn is_unconnected(&self) -> bool {
489        matches!(self.state, State::Unconnected)
490    }
491    pub fn needs_tick(&self) -> Timeout {
492        match self.state {
493            State::Unconnected | State::Disconnected => return Timeout::inactive(),
494            _ => {}
495        }
496        let resends = match self.state {
497            State::Online(ref online) => online
498                .resend_queue
499                .back()
500                .map(|r| r.next_send)
501                .unwrap_or_default(),
502            _ => Timeout::inactive(),
503        };
504        cmp::min(self.send, resends)
505    }
506    pub fn connect<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
507        assert!(matches!(self.state, State::Unconnected));
508        self.state = State::Connecting;
509        self.tick_action(cb)?;
510        Ok(())
511    }
512    pub fn disconnect<CB: Callback>(
513        &mut self,
514        cb: &mut CB,
515        reason: &[u8],
516    ) -> Result<(), CB::Error> {
517        if let State::Disconnected = self.state {
518            assert!(
519                false,
520                "Can't call disconnect on an already disconnected connection"
521            );
522        }
523        assert!(
524            reason.iter().all(|&b| b != 0),
525            "reason must not contain NULs"
526        );
527        let result = self.send_control(cb, ControlPacket::Close(reason));
528        self.state = State::Disconnected;
529        result
530    }
531    fn resend<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
532        let online = self.state.assert_online();
533        if online.resend_queue.is_empty() {
534            return Ok(());
535        }
536        online.packet = online.packet_nonvital.clone();
537        let mut i = 0;
538        for chunk in &mut online.resend_queue {
539            chunk.start_timeout(cb);
540        }
541        while i < online.resend_queue.len() {
542            let can_fit;
543            {
544                let chunk = &online.resend_queue[online.resend_queue.len() - i - 1];
545                can_fit = online.packet.can_fit_chunk(&chunk.data, true);
546                if can_fit {
547                    let vital = (chunk.sequence.to_u16(), true);
548                    online.packet.write_chunk(&chunk.data, Some(vital));
549                    i += 1;
550                }
551            }
552            if !can_fit {
553                self.send.set(cb, Duration::from_millis(500));
554                online.flush(cb, &mut self.builder)?;
555            }
556        }
557        Ok(())
558    }
559    pub fn flush<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
560        self.send.set(cb, Duration::from_millis(500));
561        self.state.assert_online().flush(cb, &mut self.builder)
562    }
563    fn queue<CB: Callback>(&mut self, cb: &mut CB, buffer: &[u8], vital: bool) {
564        let online = self.state.assert_online();
565        let vital = if vital {
566            let sequence = online.sequence.next();
567            online
568                .resend_queue
569                .push_front(ResendChunk::new(cb, sequence, buffer));
570            Some((sequence.to_u16(), false))
571        } else {
572            None
573        };
574        if vital.is_none() {
575            online.packet_nonvital.write_chunk(buffer, vital);
576        }
577        online.packet.write_chunk(buffer, vital)
578    }
579    pub fn send<CB: Callback>(
580        &mut self,
581        cb: &mut CB,
582        buffer: &[u8],
583        vital: bool,
584    ) -> Result<(), Error<CB::Error>> {
585        let result;
586        {
587            let online = self.state.assert_online();
588            if buffer.len() > MAX_PAYLOAD {
589                return Err(Error::TooLongData);
590            }
591            if !online.packet.can_fit_chunk(buffer, vital) {
592                result = online.flush(cb, &mut self.builder).map_err(Error::from);
593            } else {
594                result = Ok(());
595            }
596        }
597        self.queue(cb, buffer, vital);
598        result
599    }
600    pub fn send_connless<CB: Callback>(
601        &mut self,
602        cb: &mut CB,
603        data: &[u8],
604    ) -> Result<(), Error<CB::Error>> {
605        self.state.assert_online();
606        self.send.set(cb, Duration::from_millis(500));
607        self.builder.send(cb, Packet::Connless(data))
608    }
609    fn send_control<CB: Callback>(
610        &mut self,
611        cb: &mut CB,
612        control: ControlPacket,
613    ) -> Result<(), CB::Error> {
614        let ack = match self.state {
615            State::Online(ref mut online) => online.ack.to_u16(),
616            _ => 0,
617        };
618        let token = match self.state {
619            State::Unconnected => unreachable!(),
620            // Signal support for the token protocol.
621            State::Connecting => Some(TOKEN_NONE),
622            State::Pending(ref pending) => pending.token,
623            State::Online(ref online) => online.token,
624            State::Disconnected => unreachable!(),
625        };
626        self.builder
627            .send(
628                cb,
629                Packet::Connected(ConnectedPacket {
630                    token: token,
631                    ack: ack,
632                    type_: ConnectedPacketType::Control(control),
633                }),
634            )
635            .map_err(|e| e.unwrap_callback())
636    }
637    pub fn tick<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
638        let do_resend = match self.state {
639            State::Online(ref online) => {
640                // WARN?
641                online
642                    .resend_queue
643                    .back()
644                    .map(|c| c.next_send.has_triggered_level(cb))
645                    .unwrap_or(false)
646            }
647            _ => false,
648        };
649        if do_resend {
650            self.resend(cb)
651        } else if self.send.has_triggered_edge(cb) {
652            self.tick_action(cb)
653        } else {
654            Ok(())
655        }
656    }
657    fn tick_action<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
658        let control = match self.state {
659            State::Connecting => ControlPacket::Connect,
660            State::Pending(_) => ControlPacket::ConnectAccept,
661            State::Online(ref mut online) => {
662                if online.can_send() {
663                    // TODO: Warn if this happens on reliable networks.
664                    self.send.set(cb, Duration::from_millis(500));
665                    return online.flush(cb, &mut self.builder);
666                }
667                ControlPacket::KeepAlive
668            }
669            _ => return Ok(()),
670        };
671        self.send.set(cb, Duration::from_millis(500));
672        self.send_control(cb, control)
673    }
674    /// Notifies the connection of incoming data.
675    ///
676    /// `buffer` must have at least size `MAX_PAYLOAD`.
677    pub fn feed<'a, B, CB, W>(
678        &mut self,
679        cb: &mut CB,
680        warn: &mut W,
681        data: &'a [u8],
682        buf: B,
683    ) -> (ReceivePacket<'a>, Result<(), CB::Error>)
684    where
685        B: Buffer<'a>,
686        CB: Callback,
687        W: Warn<Warning>,
688    {
689        with_buffer(buf, |b| self.feed_impl(cb, warn, data, b))
690    }
691
    /// Processes one incoming packet; the implementation behind `feed`.
    ///
    /// Returns an iterator over the chunks the packet produced, plus the
    /// result of anything that had to be sent in response (resends,
    /// handshake replies). Packets that cannot be read or that carry an
    /// unexpected token are reported through `warn` and otherwise ignored.
    pub fn feed_impl<'d, 's, CB, W>(
        &mut self,
        cb: &mut CB,
        warn: &mut W,
        data: &'d [u8],
        mut buffer: BufferRef<'d, 's>,
    ) -> (ReceivePacket<'d>, Result<(), CB::Error>)
    where
        CB: Callback,
        W: Warn<Warning>,
    {
        // Result for every packet that is ignored.
        let none = (ReceivePacket::none(), Ok(()));
        {
            use protocol::ConnectedPacketType::*;
            use protocol::ControlPacket::*;

            // `Some(true/false)` if we know whether the peer uses tokens,
            // `None` if the current state doesn't tell.
            let token_hint = self.state.token().map(|t| t.is_some());
            let packet = match Packet::read(&mut w(warn), data, token_hint, &mut buffer) {
                Ok(p) => p,
                Err(e) => {
                    warn.warn(Warning::Read(e));
                    return none;
                }
            };

            // Connectionless packets are passed through in every state.
            let connected = match packet {
                Packet::Connless(data) => return (ReceivePacket::connless(data), Ok(())),
                Packet::Connected(c) => c,
            };
            let ConnectedPacket { token, ack, type_ } = connected;

            // If we're in a state where we know whether the other side uses
            // tokens and which token they use, make sure it matches.
            if let Some(&expected_token) = self.state.token() {
                if token != expected_token {
                    warn.warn(Warning::TokenMismatch);
                    return none;
                }
            }

            // TODO: Check ack for sanity.
            // Forget the chunks the peer has acknowledged with this packet.
            if let State::Online(ref mut online) = self.state {
                online.ack_chunks(Sequence::from_u16(ack));
            }

            match type_ {
                Chunks(request_resend, num_chunks, chunks) => {
                    let _ = num_chunks;
                    // The first chunk packet from the peer completes the
                    // handshake on the accepting side.
                    if let State::Pending(ref pending) = self.state {
                        self.state = State::Online(OnlineState::new(pending.token));
                    }
                    // Honor the peer's resend request before handing out the
                    // received chunks.
                    let result;
                    if request_resend {
                        if let State::Online(_) = self.state {
                            result = self.resend(cb);
                        } else {
                            result = Ok(());
                        }
                    } else {
                        result = Ok(())
                    }
                    match self.state {
                        State::Online(ref mut online) => {
                            return (
                                ReceivePacket::connected(warn, online, num_chunks, chunks),
                                result,
                            );
                        }
                        // `Pending` was turned into `Online` above.
                        State::Pending(_) => unreachable!(),
                        // WARN: packet received while not online.
                        _ => return none,
                    }
                }
                Control(KeepAlive) => return none,
                Control(Connect) => {
                    // Connection requests are only accepted by a fresh
                    // connection.
                    if let State::Unconnected = self.state {
                        let new_token = match token {
                            // The peer doesn't use tokens.
                            None => None,
                            // The peer supports tokens: generate one.
                            Some(TOKEN_NONE) => Some(Token::random(|b| cb.secure_random(b))),
                            // Ignore invalid tokens.
                            Some(_) => return none,
                        };
                        self.state = State::Pending(PendingState::new(new_token));
                        // Fall through to tick.
                    } else {
                        return none;
                    }
                }
                Control(ConnectAccept) => {
                    // The peer accepted our request: adopt its token, go
                    // online and confirm with `Accept`.
                    if let State::Connecting = self.state {
                        self.state = State::Online(OnlineState::new(token));
                        return (
                            ReceivePacket::ready(),
                            self.send_control(cb, ControlPacket::Accept),
                        );
                    } else {
                        return none;
                    }
                }
                Control(Accept) => return none,
                Control(Close(reason)) => {
                    // NOTE(review): this is reached in every state, including
                    // those without a known token -- confirm that is intended.
                    self.state = State::Disconnected;
                    return (ReceivePacket::disconnect(reason), Ok(()));
                }
            }
        }
        // Fall-through from `Control(Connect)`
        // The tick action for `Pending` sends the `ConnectAccept` reply.
        (ReceivePacket::none(), self.tick_action(cb))
    }
801}
802
803#[cfg(test)]
804mod test {
805    use super::Callback;
806    use super::Connection;
807    use super::ReceiveChunk;
808    use super::Sequence;
809    use super::SequenceOrdering;
810    use crate::protocol;
811    use crate::Timestamp;
812    use hexdump::hexdump;
813    use itertools::Itertools;
814    use std::collections::VecDeque;
815    use void::ResultVoidExt;
816    use void::Void;
817    use warn::Panic;
818
819    #[test]
820    fn sequence_compare() {
821        use super::SequenceOrdering::*;
822
823        fn cmp(a: Sequence, b: Sequence) -> SequenceOrdering {
824            Sequence::compare(a, b)
825        }
826        let default = Sequence::new();
827        let first = Sequence::from_u16(0);
828        let mid = Sequence::from_u16(protocol::SEQUENCE_MODULUS / 2);
829        let end = Sequence::from_u16(protocol::SEQUENCE_MODULUS - 1);
830        assert_eq!(cmp(default, first), Current);
831        assert_eq!(cmp(first, mid), Past);
832        assert_eq!(cmp(first, end), Past);
833        assert_eq!(cmp(mid, first), Past);
834        assert_eq!(cmp(mid, end), Future);
835        assert_eq!(cmp(end, first), Future);
836        assert_eq!(cmp(end, mid), Past);
837    }
838
    /// Walks a client and a server through a complete session without
    /// tokens: handshake, one vital chunk from client to server, disconnect
    /// by the server and reset of both sides. Every packet put on the wire
    /// is compared byte by byte.
    #[test]
    fn establish_connection_no_token() {
        // Callback that records every sent packet in a queue and always
        // reports the same point in time.
        struct Cb(VecDeque<Vec<u8>>);
        impl Cb {
            fn new() -> Cb {
                Cb(VecDeque::new())
            }
        }
        impl Callback for Cb {
            type Error = Void;
            // No token is generated in this scenario, so this must never be
            // called.
            fn secure_random(&mut self, buffer: &mut [u8]) {
                let _ = buffer;
                unimplemented!();
            }
            fn send(&mut self, data: &[u8]) -> Result<(), Void> {
                self.0.push_back(data.to_owned());
                Ok(())
            }
            fn time(&mut self) -> Timestamp {
                Timestamp::from_secs_since_epoch(0)
            }
        }
        let mut buffer = [0; protocol::MAX_PACKETSIZE];
        // Client and server share one callback, i.e. one packet queue.
        let mut cb = Cb::new();
        let cb = &mut cb;
        println!("");

        let mut client = Connection::new();
        let mut server = Connection::new();

        // Connect
        client.connect(cb).void_unwrap();
        let packet = cb.0.pop_front().unwrap();
        assert!(cb.0.is_empty());
        hexdump(&packet);
        assert!(&packet == b"\x10\x00\x00\x01TKEN\xff\xff\xff\xff");
        // Simulate that the token protocol is not supported.
        let packet = *b"\x10\x00\x00\x01";

        // ConnectAccept
        assert!(server
            .feed(cb, &mut Panic, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        let packet = cb.0.pop_front().unwrap();
        assert!(cb.0.is_empty());
        hexdump(&packet);
        assert!(&packet == b"\x10\x00\x00\x02");

        // Accept
        assert!(
            client
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Ready]
        );
        let packet = cb.0.pop_front().unwrap();
        assert!(cb.0.is_empty());
        hexdump(&packet);
        assert!(&packet == b"\x10\x00\x00\x03");

        // The server yields nothing and sends nothing for the `Accept`.
        assert!(server
            .feed(cb, &mut Panic, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        assert!(cb.0.is_empty());

        // Send
        // The chunk is only queued; nothing goes out before the flush.
        client.send(cb, b"\x42", true).unwrap();
        assert!(cb.0.is_empty());

        // Flush
        client.flush(cb).void_unwrap();
        let packet = cb.0.pop_front().unwrap();
        assert!(cb.0.is_empty());
        hexdump(&packet);
        assert!(&packet == b"\x00\x00\x01\x40\x01\x01\x42");

        // Receive
        assert!(
            server
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Connected(b"\x42", true)]
        );
        assert!(cb.0.is_empty());

        // Disconnect
        server.disconnect(cb, b"42").void_unwrap();
        let packet = cb.0.pop_front().unwrap();
        hexdump(&packet);
        assert!(&packet == b"\x10\x01\x00\x0442\0");

        assert!(
            client
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Disconnect(b"42")]
        );

        // Both sides are disconnected now, so `reset` must not panic.
        client.reset();
        server.reset();
    }
947
948    #[test]
949    fn establish_connection() {
950        struct Cb(VecDeque<Vec<u8>>);
951        impl Cb {
952            fn new() -> Cb {
953                Cb(VecDeque::new())
954            }
955        }
956        impl Callback for Cb {
957            type Error = Void;
958            fn secure_random(&mut self, buffer: &mut [u8]) {
959                if buffer.len() != 4 {
960                    unimplemented!();
961                }
962                buffer[0] = 0x12;
963                buffer[1] = 0x34;
964                buffer[2] = 0x56;
965                buffer[3] = 0x78;
966            }
967            fn send(&mut self, data: &[u8]) -> Result<(), Void> {
968                self.0.push_back(data.to_owned());
969                Ok(())
970            }
971            fn time(&mut self) -> Timestamp {
972                Timestamp::from_secs_since_epoch(0)
973            }
974        }
975        let mut buffer = [0; protocol::MAX_PACKETSIZE];
976        let mut cb = Cb::new();
977        let cb = &mut cb;
978        println!("");
979
980        let mut client = Connection::new();
981        let mut server = Connection::new();
982
983        // Connect
984        client.connect(cb).void_unwrap();
985        let packet = cb.0.pop_front().unwrap();
986        assert!(cb.0.is_empty());
987        hexdump(&packet);
988        assert!(&packet == b"\x10\x00\x00\x01TKEN\xff\xff\xff\xff");
989
990        // ConnectAccept
991        assert!(server
992            .feed(cb, &mut Panic, &packet, &mut buffer[..])
993            .0
994            .next()
995            .is_none());
996        let packet = cb.0.pop_front().unwrap();
997        assert!(cb.0.is_empty());
998        hexdump(&packet);
999        assert!(&packet == b"\x10\x00\x00\x02TKEN\x12\x34\x56\x78");
1000
1001        // Accept
1002        assert!(
1003            client
1004                .feed(cb, &mut Panic, &packet, &mut buffer[..])
1005                .0
1006                .collect_vec()
1007                == &[ReceiveChunk::Ready]
1008        );
1009        let packet = cb.0.pop_front().unwrap();
1010        assert!(cb.0.is_empty());
1011        hexdump(&packet);
1012        assert!(&packet == b"\x10\x00\x00\x03\x12\x34\x56\x78");
1013
1014        assert!(server
1015            .feed(cb, &mut Panic, &packet, &mut buffer[..])
1016            .0
1017            .next()
1018            .is_none());
1019        assert!(cb.0.is_empty());
1020
1021        // Send
1022        client.send(cb, b"\x42", true).unwrap();
1023        assert!(cb.0.is_empty());
1024
1025        // Flush
1026        client.flush(cb).void_unwrap();
1027        let packet = cb.0.pop_front().unwrap();
1028        assert!(cb.0.is_empty());
1029        hexdump(&packet);
1030        assert!(&packet == b"\x00\x00\x01\x40\x01\x01\x42\x12\x34\x56\x78");
1031
1032        // Receive
1033        assert!(
1034            server
1035                .feed(cb, &mut Panic, &packet, &mut buffer[..])
1036                .0
1037                .collect_vec()
1038                == &[ReceiveChunk::Connected(b"\x42", true)]
1039        );
1040        assert!(cb.0.is_empty());
1041
1042        // Disconnect
1043        server.disconnect(cb, b"42").void_unwrap();
1044        let packet = cb.0.pop_front().unwrap();
1045        hexdump(&packet);
1046        assert!(&packet == b"\x10\x01\x00\x0442\0\x12\x34\x56\x78");
1047
1048        assert!(
1049            client
1050                .feed(cb, &mut Panic, &packet, &mut buffer[..])
1051                .0
1052                .collect_vec()
1053                == &[ReceiveChunk::Disconnect(b"42")]
1054        );
1055
1056        client.reset();
1057        server.reset();
1058    }
1059
1060    #[test]
1061    fn establish_connection_anti_ip_addr_spoofing() {
1062        struct Cb(VecDeque<Vec<u8>>);
1063        impl Cb {
1064            fn new() -> Cb {
1065                Cb(VecDeque::new())
1066            }
1067        }
1068        impl Callback for Cb {
1069            type Error = Void;
1070            fn secure_random(&mut self, buffer: &mut [u8]) {
1071                if buffer.len() != 4 {
1072                    unimplemented!();
1073                }
1074                buffer[0] = 0x12;
1075                buffer[1] = 0x34;
1076                buffer[2] = 0x56;
1077                buffer[3] = 0x78;
1078            }
1079            fn send(&mut self, data: &[u8]) -> Result<(), Void> {
1080                self.0.push_back(data.to_owned());
1081                Ok(())
1082            }
1083            fn time(&mut self) -> Timestamp {
1084                Timestamp::from_secs_since_epoch(0)
1085            }
1086        }
1087        let mut buffer = [0; protocol::MAX_PACKETSIZE];
1088        let mut cb = Cb::new();
1089        let cb = &mut cb;
1090        println!("");
1091
1092        let mut client = Connection::new();
1093        let mut server = Connection::new();
1094
1095        // Connect
1096        client.connect(cb).void_unwrap();
1097        let packet = cb.0.pop_front().unwrap();
1098        assert!(cb.0.is_empty());
1099        hexdump(&packet);
1100        assert!(&packet == b"\x10\x00\x00\x01TKEN\xff\xff\xff\xff");
1101
1102        // ConnectAccept
1103        assert!(server
1104            .feed(cb, &mut Panic, &packet, &mut buffer[..])
1105            .0
1106            .next()
1107            .is_none());
1108        let packet = cb.0.pop_front().unwrap();
1109        assert!(cb.0.is_empty());
1110        hexdump(&packet);
1111        assert!(&packet == b"\x10\x00\x00\x02TKEN\x12\x34\x56\x78");
1112
1113        // Accept
1114        assert!(
1115            client
1116                .feed(cb, &mut Panic, &packet, &mut buffer[..])
1117                .0
1118                .collect_vec()
1119                == &[ReceiveChunk::Ready]
1120        );
1121        let packet = cb.0.pop_front().unwrap();
1122        assert!(cb.0.is_empty());
1123        hexdump(&packet);
1124        assert!(&packet == b"\x10\x00\x00\x03\x12\x34\x56\x78");
1125
1126        assert!(server
1127            .feed(cb, &mut Panic, &packet, &mut buffer[..])
1128            .0
1129            .next()
1130            .is_none());
1131        assert!(cb.0.is_empty());
1132
1133        let token = protocol::Token([0x12, 0x34, 0x56, 0x78]);
1134        let mut server = Connection::new_accept_token(cb, token);
1135
1136        // Send
1137        client.send(cb, b"\x42", true).unwrap();
1138        assert!(cb.0.is_empty());
1139
1140        // Flush
1141        client.flush(cb).void_unwrap();
1142        let packet = cb.0.pop_front().unwrap();
1143        assert!(cb.0.is_empty());
1144        hexdump(&packet);
1145        assert!(&packet == b"\x00\x00\x01\x40\x01\x01\x42\x12\x34\x56\x78");
1146
1147        // Receive
1148        assert!(
1149            server
1150                .feed(cb, &mut Panic, &packet, &mut buffer[..])
1151                .0
1152                .collect_vec()
1153                == &[ReceiveChunk::Connected(b"\x42", true)]
1154        );
1155        assert!(cb.0.is_empty());
1156
1157        // Disconnect
1158        server.disconnect(cb, b"42").void_unwrap();
1159        let packet = cb.0.pop_front().unwrap();
1160        hexdump(&packet);
1161        assert!(&packet == b"\x10\x01\x00\x0442\0\x12\x34\x56\x78");
1162
1163        assert!(
1164            client
1165                .feed(cb, &mut Panic, &packet, &mut buffer[..])
1166                .0
1167                .collect_vec()
1168                == &[ReceiveChunk::Disconnect(b"42")]
1169        );
1170
1171        client.reset();
1172        server.reset();
1173    }
1174}