Skip to main content

libtw2_net/
connection7.rs

1use crate::protocol7 as protocol;
2use crate::protocol7::ChunksIter;
3use crate::protocol7::ConnectedPacket;
4use crate::protocol7::ConnectedPacketType;
5use crate::protocol7::ConnlessPacket;
6use crate::protocol7::ControlPacket;
7use crate::protocol7::Packet;
8use crate::protocol7::Token;
9use crate::protocol7::MAX_PACKETSIZE;
10use crate::protocol7::MAX_PAYLOAD;
11use crate::protocol7::TOKEN_NONE;
12use crate::Timeout;
13use crate::Timestamp;
14use arrayvec::ArrayVec;
15use libtw2_buffer::with_buffer;
16use libtw2_buffer::Buffer;
17use libtw2_buffer::BufferRef;
18use libtw2_warn::Warn;
19use std::cmp;
20use std::collections::VecDeque;
21use std::iter;
22use std::time::Duration;
23
24// TODO: Implement receive timeout.
25// TODO: Don't allow for unbounded backlog of vital messages.
26
27pub trait Callback {
28    type Error;
29    fn secure_random(&mut self, buffer: &mut [u8]);
30    fn send(&mut self, buffer: &[u8]) -> Result<(), Self::Error>;
31    fn time(&mut self) -> Timestamp;
32}
33
/// Errors returned by the sending side of a [`Connection`].
#[derive(Debug)]
pub enum Error<CE> {
    /// The data is too long to be sent in a packet.
    TooLongData,
    /// The [`Callback`] failed, e.g. while sending.
    Callback(CE),
}

impl<CE> From<CE> for Error<CE> {
    fn from(inner: CE) -> Error<CE> {
        Error::Callback(inner)
    }
}

impl<CE> Error<CE> {
    /// Extracts the callback error.
    ///
    /// # Panics
    ///
    /// Panics with "too long data" if `self` is [`Error::TooLongData`].
    pub fn unwrap_callback(self) -> CE {
        if let Error::Callback(inner) = self {
            return inner;
        }
        panic!("too long data")
    }
}
54
55#[derive(Debug)]
56pub enum Warning {
57    ConnlessResponseTokenMismatch,
58    ConnlessTokenMismatch,
59    Packet(protocol::Warning),
60    Read(protocol::PacketReadError),
61    TokenMismatch,
62    Unexpected,
63}
64
65trait TimeoutExt {
66    fn set<CB: Callback>(&mut self, cb: &mut CB, value: Duration);
67    fn has_triggered_level<CB: Callback>(&self, cb: &mut CB) -> bool;
68    fn has_triggered_edge<CB: Callback>(&mut self, cb: &mut CB) -> bool;
69}
70
71impl TimeoutExt for Timeout {
72    fn set<CB: Callback>(&mut self, cb: &mut CB, value: Duration) {
73        *self = Timeout::active(cb.time() + value);
74    }
75    fn has_triggered_level<CB: Callback>(&self, cb: &mut CB) -> bool {
76        self.to_opt().map(|time| time <= cb.time()).unwrap_or(false)
77    }
78    fn has_triggered_edge<CB: Callback>(&mut self, cb: &mut CB) -> bool {
79        let triggered = self.has_triggered_level(cb);
80        if triggered {
81            *self = Timeout::inactive();
82        }
83        triggered
84    }
85}
86
87pub struct Connection {
88    state: State,
89    send: Timeout,
90    builder: PacketBuilder,
91}
92
93#[derive(Clone, Debug)]
94enum State {
95    Unconnected,
96    Token(TokenState),
97    PendingConnect(PendingConnectState),
98    Connecting(ConnectingState),
99    Pending(PendingState),
100    Online(OnlineState),
101    Disconnected,
102}
103
104impl State {
105    pub fn assert_online(&mut self) -> &mut OnlineState {
106        match *self {
107            State::Online(ref mut s) => s,
108            _ => panic!("state not online"),
109        }
110    }
111    pub fn own_token(&self) -> Option<Token> {
112        use self::State::*;
113        match *self {
114            Unconnected => None,
115            Token(ref token) => Some(token.own_token),
116            PendingConnect(ref pending_connect) => Some(pending_connect.own_token),
117            Connecting(ref connecting) => Some(connecting.own_token),
118            Pending(ref pending) => Some(pending.own_token),
119            Online(ref online) => Some(online.own_token),
120            Disconnected => None,
121        }
122    }
123    pub fn their_token(&self) -> Option<Token> {
124        use self::State::*;
125        match *self {
126            Unconnected => None,
127            Token(_) => None,
128            PendingConnect(_) => None,
129            Connecting(ref connecting) => Some(connecting.their_token),
130            Pending(ref pending) => Some(pending.their_token),
131            Online(ref online) => Some(online.their_token),
132            Disconnected => None,
133        }
134    }
135}
136
137#[derive(Clone, Debug)]
138struct ResendChunk {
139    next_send: Timeout,
140    sequence: Sequence,
141    data: ArrayVec<[u8; 2048]>,
142}
143
144impl ResendChunk {
145    fn new<CB: Callback>(cb: &mut CB, sequence: Sequence, data: &[u8]) -> ResendChunk {
146        let mut result = ResendChunk {
147            next_send: Timeout::inactive(),
148            sequence: sequence,
149            data: data.iter().cloned().collect(),
150        };
151        assert!(
152            result.data.len() == data.len(),
153            "overlong resend packet {}",
154            data.len()
155        );
156        result.start_timeout(cb);
157        result
158    }
159    fn start_timeout<CB: Callback>(&mut self, cb: &mut CB) {
160        self.next_send.set(cb, Duration::from_millis(1_000));
161    }
162}
163
164pub struct ReceivePacket<'a> {
165    type_: ReceivePacketType<'a>,
166}
167
168impl<'a> Clone for ReceivePacket<'a> {
169    fn clone(&self) -> ReceivePacket<'a> {
170        ReceivePacket {
171            type_: self.type_.clone(),
172        }
173    }
174}
175
176impl<'a> ReceivePacket<'a> {
177    fn none() -> ReceivePacket<'a> {
178        ReceivePacket {
179            type_: ReceivePacketType::None,
180        }
181    }
182    fn ready() -> ReceivePacket<'a> {
183        ReceivePacket {
184            type_: ReceivePacketType::Ready(iter::once(())),
185        }
186    }
187    fn connless(data: &'a [u8]) -> ReceivePacket<'a> {
188        ReceivePacket {
189            type_: ReceivePacketType::Connless(iter::once(data)),
190        }
191    }
192    fn connected<W>(
193        warn: &mut W,
194        online: &mut OnlineState,
195        num_chunks: u8,
196        data: &'a [u8],
197    ) -> ReceivePacket<'a>
198    where
199        W: Warn<Warning>,
200    {
201        let chunks_iter = ChunksIter::new(data, num_chunks);
202        let ack = online.ack.clone();
203        let mut iter = chunks_iter.clone();
204        while let Some(c) = iter.next_warn(&mut w(warn)) {
205            if let Some((sequence, resend)) = c.vital {
206                let _ = resend;
207                if online.ack.update(Sequence::from_u16(sequence)) != SequenceOrdering::Current {
208                    online.request_resend = true;
209                }
210            }
211        }
212        ReceivePacket {
213            type_: ReceivePacketType::Connected(ReceiveChunks {
214                ack: ack,
215                chunks: chunks_iter,
216            }),
217        }
218    }
219    fn disconnect(reason: &'a [u8]) -> ReceivePacket<'a> {
220        ReceivePacket {
221            type_: ReceivePacketType::Close(iter::once(reason)),
222        }
223    }
224}
225
226#[derive(Clone)]
227enum ReceivePacketType<'a> {
228    None,
229    Connless(iter::Once<&'a [u8]>),
230    Connected(ReceiveChunks<'a>),
231    Ready(iter::Once<()>),
232    Close(iter::Once<&'a [u8]>),
233}
234
235impl<'a> Iterator for ReceivePacket<'a> {
236    type Item = ReceiveChunk<'a>;
237    fn next(&mut self) -> Option<ReceiveChunk<'a>> {
238        match self.type_ {
239            ReceivePacketType::None => None,
240            ReceivePacketType::Ready(ref mut once) => once.next().map(|()| ReceiveChunk::Ready),
241            ReceivePacketType::Connless(ref mut once) => once.next().map(ReceiveChunk::Connless),
242            ReceivePacketType::Connected(ref mut chunks) => chunks.next(),
243            ReceivePacketType::Close(ref mut once) => once.next().map(ReceiveChunk::Disconnect),
244        }
245    }
246    fn size_hint(&self) -> (usize, Option<usize>) {
247        let len = self.clone().count();
248        (len, Some(len))
249    }
250}
251
252impl<'a> ExactSizeIterator for ReceivePacket<'a> {}
253
254#[derive(Clone)]
255struct ReceiveChunks<'a> {
256    ack: Sequence,
257    chunks: ChunksIter<'a>,
258}
259
260impl<'a> Iterator for ReceiveChunks<'a> {
261    type Item = ReceiveChunk<'a>;
262    fn next(&mut self) -> Option<ReceiveChunk<'a>> {
263        self.chunks.next().and_then(|c| {
264            if let Some((sequence, resend)) = c.vital {
265                let _ = resend;
266                if self.ack.update(Sequence::from_u16(sequence)) != SequenceOrdering::Current {
267                    return self.next();
268                }
269            }
270            Some(ReceiveChunk::Connected(c.data, c.vital.is_some()))
271        })
272    }
273    fn size_hint(&self) -> (usize, Option<usize>) {
274        let len = self.clone().count();
275        (len, Some(len))
276    }
277}
278
279impl<'a> ExactSizeIterator for ReceiveChunks<'a> {}
280
/// One item yielded by a [`ReceivePacket`].
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ReceiveChunk<'a> {
    /// Payload of a connectionless packet whose tokens matched.
    Connless(&'a [u8]),
    /// Payload of one chunk of a connected packet, and whether it was vital.
    Connected(&'a [u8], bool),
    /// The peer accepted our connection; we are now online.
    Ready,
    /// The peer closed the connection with the given reason.
    Disconnect(&'a [u8]),
}
289
290#[derive(Clone, Debug)]
291struct TokenState {
292    own_token: Token,
293}
294
295impl TokenState {
296    fn new(own_token: Token) -> TokenState {
297        TokenState {
298            own_token: own_token,
299        }
300    }
301}
302
303#[derive(Clone, Debug)]
304struct PendingConnectState {
305    own_token: Token,
306}
307
308impl PendingConnectState {
309    fn new(own_token: Token) -> PendingConnectState {
310        PendingConnectState {
311            own_token: own_token,
312        }
313    }
314}
315
316#[derive(Clone, Debug)]
317struct ConnectingState {
318    own_token: Token,
319    their_token: Token,
320}
321
322impl ConnectingState {
323    fn new(own_token: Token, their_token: Token) -> ConnectingState {
324        ConnectingState {
325            own_token: own_token,
326            their_token: their_token,
327        }
328    }
329}
330
331#[derive(Clone, Debug)]
332struct PendingState {
333    own_token: Token,
334    their_token: Token,
335}
336
337impl PendingState {
338    fn new(own_token: Token, their_token: Token) -> PendingState {
339        PendingState {
340            own_token: own_token,
341            their_token: their_token,
342        }
343    }
344}
345
346#[derive(Clone, Debug)]
347struct OnlineState {
348    // `own_token` is included in every message from the other peer in order to
349    // protect against IP spoofing.
350    own_token: Token,
351    // `their_token` is included in every message to the other peer in order to
352    // protect against IP spoofing.
353    their_token: Token,
354    // `ack` is the vital chunk from the peer we want to acknowledge.
355    ack: Sequence,
356    // `sequence` is the vital chunk from us that the peer acknowledged.
357    sequence: Sequence,
358    request_resend: bool,
359    // `packet` contains all the queued chunks, `packet_nonvital` only the
360    // non-vital ones. This is important for resending.
361    packet: PacketContents,
362    packet_nonvital: PacketContents,
363    // This contains the unacked chunks that we sent, starting from the most
364    // recently sent chunk.
365    resend_queue: VecDeque<ResendChunk>,
366}
367
368impl OnlineState {
369    fn new(own_token: Token, their_token: Token) -> OnlineState {
370        OnlineState {
371            own_token: own_token,
372            their_token: their_token,
373            ack: Sequence::new(),
374            sequence: Sequence::new(),
375            request_resend: false,
376            packet: PacketContents::new(),
377            packet_nonvital: PacketContents::new(),
378            resend_queue: VecDeque::new(),
379        }
380    }
381    fn can_send(&self) -> bool {
382        self.packet.num_chunks != 0 || self.request_resend
383    }
384    fn ack_chunks(&mut self, ack: Sequence) {
385        let index = self
386            .resend_queue
387            .iter()
388            .position(|chunk| chunk.sequence == ack);
389        index.map(|i| self.resend_queue.truncate(i));
390    }
391    fn flush<CB: Callback>(
392        &mut self,
393        cb: &mut CB,
394        builder: &mut PacketBuilder,
395    ) -> Result<(), CB::Error> {
396        if !self.can_send() {
397            return Ok(());
398        }
399        let result = builder
400            .send(
401                cb,
402                Packet::Connected(ConnectedPacket {
403                    token: self.their_token,
404                    ack: self.ack.to_u16(),
405                    type_: ConnectedPacketType::Chunks(
406                        self.request_resend,
407                        self.packet.num_chunks,
408                        &self.packet.data,
409                    ),
410                }),
411            )
412            .map_err(|e| e.unwrap_callback());
413        self.request_resend = false;
414        self.packet.clear();
415        self.packet_nonvital.clear();
416        result
417    }
418}
419
420#[derive(Clone, Debug, Eq, PartialEq)]
421struct PacketContents {
422    num_chunks: u8,
423    data: ArrayVec<[u8; 2048]>,
424}
425
426impl PacketContents {
427    fn new() -> PacketContents {
428        PacketContents {
429            num_chunks: 0,
430            data: ArrayVec::new(),
431        }
432    }
433    fn write_chunk(&mut self, data: &[u8], vital: Option<(u16, bool)>) {
434        protocol::write_chunk(data, vital, &mut self.data).unwrap();
435        self.num_chunks += 1;
436    }
437    fn can_fit_chunk(&self, data: &[u8], vital: bool) -> bool {
438        // current size + chunk header + chunk length
439        self.data.len() + protocol::chunk_header_size(vital) + data.len() <= MAX_PAYLOAD
440    }
441    fn clear(&mut self) {
442        *self = PacketContents::new();
443    }
444}
445
446#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
447struct Sequence {
448    seq: u16, // u10
449}
450
451#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
452enum SequenceOrdering {
453    Past,
454    Current,
455    Future,
456}
457
458impl Sequence {
459    fn new() -> Sequence {
460        Default::default()
461    }
462    fn from_u16(seq: u16) -> Sequence {
463        assert!(seq < protocol::SEQUENCE_MODULUS);
464        Sequence { seq: seq }
465    }
466    fn to_u16(self) -> u16 {
467        self.seq
468    }
469    fn next(&mut self) -> Sequence {
470        self.seq = (self.seq + 1) % protocol::SEQUENCE_MODULUS;
471        *self
472    }
473    fn update(&mut self, other: Sequence) -> SequenceOrdering {
474        let mut next_self = *self;
475        next_self.next();
476        let result = next_self.compare(other);
477        if result == SequenceOrdering::Current {
478            *self = next_self;
479        }
480        result
481    }
482    /// Returns what `other` is in relation to `self`.
483    fn compare(self, other: Sequence) -> SequenceOrdering {
484        let half = protocol::SEQUENCE_MODULUS / 2;
485        let less;
486        match self.seq.cmp(&other.seq) {
487            cmp::Ordering::Less => less = other.seq - self.seq < half,
488            cmp::Ordering::Greater => less = self.seq - other.seq > half,
489            cmp::Ordering::Equal => return SequenceOrdering::Current,
490        }
491        if less {
492            SequenceOrdering::Future
493        } else {
494            SequenceOrdering::Past
495        }
496    }
497}
498
499struct PacketBuilder {
500    buffer: [u8; MAX_PACKETSIZE],
501}
502
503impl PacketBuilder {
504    fn new() -> PacketBuilder {
505        PacketBuilder {
506            buffer: [0; MAX_PACKETSIZE],
507        }
508    }
509    fn send<CB: Callback>(&mut self, cb: &mut CB, packet: Packet) -> Result<(), Error<CB::Error>> {
510        let data = match packet.write(&mut self.buffer[..]) {
511            Ok(d) => d,
512            Err(protocol::Error::Capacity(_)) => unreachable!("too short buffer provided"),
513            Err(protocol::Error::TooLongData) => return Err(Error::TooLongData),
514        };
515        cb.send(data)?;
516        Ok(())
517    }
518}
519
520struct WarnCallback<'a, W: Warn<Warning> + 'a> {
521    warn: &'a mut W,
522}
523
524fn w<W: Warn<Warning>>(warn: &mut W) -> WarnCallback<'_, W> {
525    WarnCallback { warn: warn }
526}
527
528impl<'a, W: Warn<Warning>> Warn<protocol::Warning> for WarnCallback<'a, W> {
529    fn warn(&mut self, warning: protocol::Warning) {
530        self.warn.warn(Warning::Packet(warning))
531    }
532}
533
534impl Connection {
535    pub fn new() -> Connection {
536        Connection {
537            state: State::Unconnected,
538            send: Timeout::inactive(),
539            builder: PacketBuilder::new(),
540        }
541    }
542    pub fn reset(&mut self) {
543        assert!(matches!(self.state, State::Disconnected));
544        *self = Connection::new();
545    }
546    pub fn is_unconnected(&self) -> bool {
547        matches!(self.state, State::Unconnected)
548    }
549    pub fn needs_tick(&self) -> Timeout {
550        match self.state {
551            State::Unconnected | State::Disconnected => return Timeout::inactive(),
552            _ => {}
553        }
554        let resends = match self.state {
555            State::Online(ref online) => online
556                .resend_queue
557                .back()
558                .map(|r| r.next_send)
559                .unwrap_or_default(),
560            _ => Timeout::inactive(),
561        };
562        cmp::min(self.send, resends)
563    }
564    pub fn connect<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
565        assert!(matches!(self.state, State::Unconnected));
566        self.state = State::Token(TokenState::new(Token::random(|b| cb.secure_random(b))));
567        self.tick_action(cb)?;
568        Ok(())
569    }
570    pub fn disconnect<CB: Callback>(
571        &mut self,
572        cb: &mut CB,
573        reason: &[u8],
574    ) -> Result<(), CB::Error> {
575        if let State::Disconnected = self.state {
576            assert!(
577                false,
578                "Can't call disconnect on an already disconnected connection"
579            );
580        }
581        assert!(
582            reason.iter().all(|&b| b != 0),
583            "reason must not contain NULs"
584        );
585        let result = self.send_control(cb, ControlPacket::Close(reason));
586        self.state = State::Disconnected;
587        result
588    }
589    fn resend<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
590        let online = self.state.assert_online();
591        if online.resend_queue.is_empty() {
592            return Ok(());
593        }
594        online.packet = online.packet_nonvital.clone();
595        let mut i = 0;
596        for chunk in &mut online.resend_queue {
597            chunk.start_timeout(cb);
598        }
599        while i < online.resend_queue.len() {
600            let can_fit;
601            {
602                let chunk = &online.resend_queue[online.resend_queue.len() - i - 1];
603                can_fit = online.packet.can_fit_chunk(&chunk.data, true);
604                if can_fit {
605                    let vital = (chunk.sequence.to_u16(), true);
606                    online.packet.write_chunk(&chunk.data, Some(vital));
607                    i += 1;
608                }
609            }
610            if !can_fit {
611                self.send.set(cb, Duration::from_millis(500));
612                online.flush(cb, &mut self.builder)?;
613            }
614        }
615        Ok(())
616    }
617    pub fn flush<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
618        self.send.set(cb, Duration::from_millis(500));
619        self.state.assert_online().flush(cb, &mut self.builder)
620    }
621    fn queue<CB: Callback>(&mut self, cb: &mut CB, buffer: &[u8], vital: bool) {
622        let online = self.state.assert_online();
623        let vital = if vital {
624            let sequence = online.sequence.next();
625            online
626                .resend_queue
627                .push_front(ResendChunk::new(cb, sequence, buffer));
628            Some((sequence.to_u16(), false))
629        } else {
630            None
631        };
632        if vital.is_none() {
633            online.packet_nonvital.write_chunk(buffer, vital);
634        }
635        online.packet.write_chunk(buffer, vital)
636    }
637    pub fn send<CB: Callback>(
638        &mut self,
639        cb: &mut CB,
640        buffer: &[u8],
641        vital: bool,
642    ) -> Result<(), Error<CB::Error>> {
643        let result;
644        {
645            let online = self.state.assert_online();
646            if buffer.len() > MAX_PAYLOAD {
647                return Err(Error::TooLongData);
648            }
649            if !online.packet.can_fit_chunk(buffer, vital) {
650                result = online.flush(cb, &mut self.builder).map_err(Error::from);
651            } else {
652                result = Ok(());
653            }
654        }
655        self.queue(cb, buffer, vital);
656        result
657    }
658    pub fn send_connless<CB: Callback>(
659        &mut self,
660        cb: &mut CB,
661        data: &[u8],
662    ) -> Result<(), Error<CB::Error>> {
663        let online = self.state.assert_online();
664        self.send.set(cb, Duration::from_millis(500));
665        self.builder.send(
666            cb,
667            Packet::Connless(ConnlessPacket {
668                token: online.their_token,
669                response_token: online.own_token,
670                payload: data,
671            }),
672        )
673    }
674    fn send_control<CB: Callback>(
675        &mut self,
676        cb: &mut CB,
677        control: ControlPacket,
678    ) -> Result<(), CB::Error> {
679        self.send_control_with_token(cb, control, self.state.their_token().unwrap_or(TOKEN_NONE))
680    }
681    fn send_control_with_token<CB: Callback>(
682        &mut self,
683        cb: &mut CB,
684        control: ControlPacket,
685        token: Token,
686    ) -> Result<(), CB::Error> {
687        let ack = match self.state {
688            State::Online(ref mut online) => online.ack.to_u16(),
689            _ => 0,
690        };
691        self.builder
692            .send(
693                cb,
694                Packet::Connected(ConnectedPacket {
695                    token: token,
696                    ack: ack,
697                    type_: ConnectedPacketType::Control(control),
698                }),
699            )
700            .map_err(|e| e.unwrap_callback())
701    }
702    pub fn tick<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
703        let do_resend = match self.state {
704            State::Online(ref online) => {
705                // WARN?
706                online
707                    .resend_queue
708                    .back()
709                    .map(|c| c.next_send.has_triggered_level(cb))
710                    .unwrap_or(false)
711            }
712            _ => false,
713        };
714        if do_resend {
715            self.resend(cb)
716        } else if self.send.has_triggered_edge(cb) {
717            self.tick_action(cb)
718        } else {
719            Ok(())
720        }
721    }
722    fn tick_action<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
723        let control = match self.state {
724            State::Token(ref token) => ControlPacket::Token(token.own_token),
725            State::Connecting(ref connecting) => ControlPacket::Connect(connecting.own_token),
726            State::Pending(_) => ControlPacket::Accept,
727            State::Online(ref mut online) => {
728                if online.can_send() {
729                    // TODO: Warn if this happens on reliable networks.
730                    self.send.set(cb, Duration::from_millis(500));
731                    return online.flush(cb, &mut self.builder);
732                }
733                ControlPacket::KeepAlive
734            }
735            _ => return Ok(()),
736        };
737        self.send.set(cb, Duration::from_millis(500));
738        self.send_control(cb, control)
739    }
740    /// Notifies the connection of incoming data.
741    ///
742    /// `buffer` must have at least size `MAX_PAYLOAD`.
743    pub fn feed<'a, B, CB, W>(
744        &mut self,
745        cb: &mut CB,
746        warn: &mut W,
747        data: &'a [u8],
748        buf: B,
749    ) -> (ReceivePacket<'a>, Result<(), CB::Error>)
750    where
751        B: Buffer<'a>,
752        CB: Callback,
753        W: Warn<Warning>,
754    {
755        with_buffer(buf, |b| self.feed_impl(cb, warn, data, b))
756    }
757
758    pub fn feed_impl<'d, 's, CB, W>(
759        &mut self,
760        cb: &mut CB,
761        warn: &mut W,
762        data: &'d [u8],
763        mut buffer: BufferRef<'d, 's>,
764    ) -> (ReceivePacket<'d>, Result<(), CB::Error>)
765    where
766        CB: Callback,
767        W: Warn<Warning>,
768    {
769        let none = (ReceivePacket::none(), Ok(()));
770        {
771            use protocol::ConnectedPacketType::*;
772            use protocol::ControlPacket::*;
773
774            let packet = match Packet::read(&mut w(warn), data, &mut buffer) {
775                Ok(p) => p,
776                Err(e) => {
777                    warn.warn(Warning::Read(e));
778                    return none;
779                }
780            };
781
782            let connected = match packet {
783                Packet::Connless(ConnlessPacket {
784                    token,
785                    response_token,
786                    payload,
787                }) => {
788                    if Some(token) != self.state.own_token() {
789                        warn.warn(Warning::ConnlessTokenMismatch);
790                        return none;
791                    }
792                    if Some(response_token) != self.state.their_token() {
793                        warn.warn(Warning::ConnlessResponseTokenMismatch);
794                        return none;
795                    }
796                    return (ReceivePacket::connless(payload), Ok(()));
797                }
798                Packet::Connected(c) => c,
799            };
800            let ConnectedPacket { token, ack, type_ } = connected;
801
802            let mut expected_token = self.state.own_token().unwrap_or(TOKEN_NONE);
803            // Allow unauthenticated TOKEN requests before the first
804            // authenticated message, even if we've replied to a TOKEN request
805            // before.
806            //
807            // TODO: test this
808            if matches!(type_, Control(Token(_)))
809                && matches!(self.state, State::PendingConnect(_))
810                && token == TOKEN_NONE
811            {
812                expected_token = TOKEN_NONE;
813            }
814            if token != expected_token {
815                warn.warn(Warning::TokenMismatch);
816                return none;
817            }
818
819            // TODO: Check ack for sanity.
820            if let State::Online(ref mut online) = self.state {
821                online.ack_chunks(Sequence::from_u16(ack));
822            }
823
824            match type_ {
825                Chunks(request_resend, num_chunks, chunks) => {
826                    let _ = num_chunks;
827                    if let State::Pending(ref pending) = self.state {
828                        self.state =
829                            State::Online(OnlineState::new(pending.own_token, pending.their_token));
830                    }
831                    let result;
832                    if request_resend {
833                        if let State::Online(_) = self.state {
834                            result = self.resend(cb);
835                        } else {
836                            result = Ok(());
837                        }
838                    } else {
839                        result = Ok(())
840                    }
841                    match self.state {
842                        State::Online(ref mut online) => {
843                            return (
844                                ReceivePacket::connected(warn, online, num_chunks, chunks),
845                                result,
846                            );
847                        }
848                        State::Pending(_) => unreachable!(),
849                        // WARN: packet received while not online.
850                        _ => return none,
851                    }
852                }
853                Control(Token(their_token)) => {
854                    if let State::Unconnected = self.state {
855                        use self::Token;
856                        self.state =
857                            State::PendingConnect(PendingConnectState::new(Token::random(|b| {
858                                cb.secure_random(b)
859                            })));
860                    }
861
862                    match self.state {
863                        State::Unconnected => unreachable!(),
864                        State::PendingConnect(ref pending_connect) => {
865                            return (
866                                ReceivePacket::none(),
867                                self.send_control_with_token(
868                                    cb,
869                                    ControlPacket::Token(pending_connect.own_token),
870                                    their_token,
871                                ),
872                            );
873                        }
874                        State::Token(ref token) => {
875                            self.state = State::Connecting(ConnectingState::new(
876                                token.own_token,
877                                their_token,
878                            ));
879                            // Fall through to tick.
880                        }
881                        _ => return none,
882                    }
883                }
884                Control(KeepAlive) => return none,
885                Control(Connect(their_token)) => {
886                    if let State::PendingConnect(ref pending_connect) = self.state {
887                        self.state = State::Pending(PendingState::new(
888                            pending_connect.own_token,
889                            their_token,
890                        ));
891                        // Fall through to tick.
892                    } else {
893                        return none;
894                    }
895                }
896                Control(Accept) => {
897                    if let State::Connecting(ref connecting) = self.state {
898                        self.state = State::Online(OnlineState::new(
899                            connecting.own_token,
900                            connecting.their_token,
901                        ));
902                        return (ReceivePacket::ready(), Ok(()));
903                    } else {
904                        return none;
905                    }
906                }
907                Control(Close(reason)) => {
908                    self.state = State::Disconnected;
909                    return (ReceivePacket::disconnect(reason), Ok(()));
910                }
911            }
912        }
913        // Fall-through from `Control(Connect)`
914        (ReceivePacket::none(), self.tick_action(cb))
915    }
916}
917
#[cfg(test)]
mod test {
    use super::Callback;
    use super::Connection;
    use super::ReceiveChunk;
    use super::Sequence;
    use super::SequenceOrdering;
    use crate::protocol7 as protocol;
    use crate::Timestamp;
    use hexdump::hexdump;
    use itertools::Itertools;
    use libtw2_warn::Panic;
    use std::collections::VecDeque;
    use void::ResultVoidExt;
    use void::Void;

    /// Checks `Sequence::compare` for the first, middle and last values of
    /// the sequence space (so comparisons across the wrap-around point are
    /// covered), and that a fresh `Sequence::new()` compares as `Current`
    /// to sequence 0.
    #[test]
    fn sequence_compare() {
        use super::SequenceOrdering::*;

        fn cmp(a: Sequence, b: Sequence) -> SequenceOrdering {
            Sequence::compare(a, b)
        }
        let default = Sequence::new();
        let first = Sequence::from_u16(0);
        let mid = Sequence::from_u16(protocol::SEQUENCE_MODULUS / 2);
        let end = Sequence::from_u16(protocol::SEQUENCE_MODULUS - 1);
        assert_eq!(cmp(default, first), Current);
        assert_eq!(cmp(first, mid), Past);
        assert_eq!(cmp(first, end), Past);
        assert_eq!(cmp(mid, first), Past);
        assert_eq!(cmp(mid, end), Future);
        assert_eq!(cmp(end, first), Future);
        assert_eq!(cmp(end, mid), Past);
    }

    /// Drives a client/server `Connection` pair through the handshake
    /// (token request, token response, connect, accept), one vital message
    /// and a server-initiated disconnect, checking every packet on the wire
    /// byte-for-byte.
    #[test]
    fn establish_connection() {
        /// Mock callback: records every sent packet in `sent` and hands out
        /// the pre-seeded `random_value` (exactly once) for a 4-byte
        /// `secure_random` request.
        #[derive(Default)]
        struct Cb {
            sent: VecDeque<Vec<u8>>,
            random_value: Option<[u8; 4]>,
        }
        impl Callback for Cb {
            type Error = Void;
            fn secure_random(&mut self, buffer: &mut [u8]) {
                if buffer.len() != 4 {
                    unimplemented!();
                }
                buffer.copy_from_slice(&self.random_value.take().unwrap());
            }
            fn send(&mut self, data: &[u8]) -> Result<(), Void> {
                self.sent.push_back(data.to_owned());
                Ok(())
            }
            fn time(&mut self) -> Timestamp {
                Timestamp::from_secs_since_epoch(0)
            }
        }

        /// Takes the packet produced by the previous operation, asserting
        /// that it was the only one sent, and hexdumps it to aid debugging
        /// a failing comparison.
        fn pop_only_packet(cb: &mut Cb) -> Vec<u8> {
            let packet = cb.sent.pop_front().expect("expected a packet to be sent");
            assert!(cb.sent.is_empty(), "expected exactly one packet to be sent");
            hexdump(&packet);
            packet
        }

        let mut buffer = [0; protocol::MAX_PACKETSIZE];
        let mut cb = Cb::default();
        let cb = &mut cb;
        // Start the hexdump output on a fresh line in the test log.
        println!();

        let mut client = Connection::new();
        let mut server = Connection::new();

        // Token request
        assert!(cb.random_value.replace([0x12, 0x34, 0x56, 0x78]).is_none());
        client.connect(cb).void_unwrap();
        let packet = pop_only_packet(cb);
        let expected = {
            let mut expected = Vec::new();
            expected.extend_from_slice(b"\x04\x00\x00\xff\xff\xff\xff\x05\x12\x34\x56\x78");
            expected.extend_from_slice(&[0; protocol::TOKEN_REQUEST_PACKET_SIZE - 12]);
            expected
        };
        // FIXME(rust-lang/rust#87555): Use concat_bytes!
        assert!(&packet == &expected);

        // Token response
        assert!(cb.random_value.replace([0x9a, 0xbc, 0xde, 0xf0]).is_none());
        assert!(server
            .feed(cb, &mut Panic, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        let packet = pop_only_packet(cb);
        assert!(&packet == b"\x04\x00\x00\x12\x34\x56\x78\x05\x9a\xbc\xde\xf0");

        // Connect
        assert!(client
            .feed(cb, &mut Panic, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        let packet = pop_only_packet(cb);
        assert!(&packet == b"\x04\x00\x00\x9a\xbc\xde\xf0\x01\x12\x34\x56\x78");

        // Accept
        assert!(server
            .feed(cb, &mut Panic, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        let packet = pop_only_packet(cb);
        assert!(&packet == b"\x04\x00\x00\x12\x34\x56\x78\x02");

        // Receiving the accept makes the client report readiness without
        // replying.
        assert!(
            client
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Ready]
        );
        assert!(cb.sent.is_empty());

        // Send: queued only, nothing goes out until a flush.
        client.send(cb, b"\x42", true).unwrap();
        assert!(cb.sent.is_empty());

        // Flush
        client.flush(cb).void_unwrap();
        let packet = pop_only_packet(cb);
        assert!(&packet == b"\x00\x00\x01\x9a\xbc\xde\xf0\x40\x01\x01\x42");

        // Receive
        assert!(
            server
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Connected(b"\x42", true)]
        );
        assert!(cb.sent.is_empty());

        // Disconnect: like every other step, exactly one packet is expected.
        server.disconnect(cb, b"42").void_unwrap();
        let packet = pop_only_packet(cb);
        assert!(&packet == b"\x04\x01\x00\x12\x34\x56\x78\x0442\x00");

        // The client reports the disconnect reason and must not answer a
        // close packet.
        assert!(
            client
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Disconnect(b"42")]
        );
        assert!(cb.sent.is_empty());

        client.reset();
        server.reset();
    }
}