1use crate::protocol7 as protocol;
2use crate::protocol7::ChunksIter;
3use crate::protocol7::ConnectedPacket;
4use crate::protocol7::ConnectedPacketType;
5use crate::protocol7::ConnlessPacket;
6use crate::protocol7::ControlPacket;
7use crate::protocol7::Packet;
8use crate::protocol7::Token;
9use crate::protocol7::MAX_PACKETSIZE;
10use crate::protocol7::MAX_PAYLOAD;
11use crate::protocol7::TOKEN_NONE;
12use crate::Timeout;
13use crate::Timestamp;
14use arrayvec::ArrayVec;
15use libtw2_buffer::with_buffer;
16use libtw2_buffer::Buffer;
17use libtw2_buffer::BufferRef;
18use libtw2_warn::Warn;
19use std::cmp;
20use std::collections::VecDeque;
21use std::iter;
22use std::time::Duration;
23
24pub trait Callback {
28 type Error;
29 fn secure_random(&mut self, buffer: &mut [u8]);
30 fn send(&mut self, buffer: &[u8]) -> Result<(), Self::Error>;
31 fn time(&mut self) -> Timestamp;
32}
33
/// Failure of an operation that sends data.
#[derive(Debug)]
pub enum Error<CE> {
    /// The data does not fit into a packet.
    TooLongData,
    /// The user callback failed while sending.
    Callback(CE),
}

impl<CE> From<CE> for Error<CE> {
    fn from(callback_error: CE) -> Self {
        Self::Callback(callback_error)
    }
}

impl<CE> Error<CE> {
    /// Extracts the callback error.
    ///
    /// # Panics
    ///
    /// Panics with "too long data" if `self` is [`Error::TooLongData`].
    pub fn unwrap_callback(self) -> CE {
        match self {
            Self::Callback(inner) => inner,
            Self::TooLongData => panic!("too long data"),
        }
    }
}
54
55#[derive(Debug)]
56pub enum Warning {
57 ConnlessResponseTokenMismatch,
58 ConnlessTokenMismatch,
59 Packet(protocol::Warning),
60 Read(protocol::PacketReadError),
61 TokenMismatch,
62 Unexpected,
63}
64
65trait TimeoutExt {
66 fn set<CB: Callback>(&mut self, cb: &mut CB, value: Duration);
67 fn has_triggered_level<CB: Callback>(&self, cb: &mut CB) -> bool;
68 fn has_triggered_edge<CB: Callback>(&mut self, cb: &mut CB) -> bool;
69}
70
71impl TimeoutExt for Timeout {
72 fn set<CB: Callback>(&mut self, cb: &mut CB, value: Duration) {
73 *self = Timeout::active(cb.time() + value);
74 }
75 fn has_triggered_level<CB: Callback>(&self, cb: &mut CB) -> bool {
76 self.to_opt().map(|time| time <= cb.time()).unwrap_or(false)
77 }
78 fn has_triggered_edge<CB: Callback>(&mut self, cb: &mut CB) -> bool {
79 let triggered = self.has_triggered_level(cb);
80 if triggered {
81 *self = Timeout::inactive();
82 }
83 triggered
84 }
85}
86
87pub struct Connection {
88 state: State,
89 send: Timeout,
90 builder: PacketBuilder,
91}
92
93#[derive(Clone, Debug)]
94enum State {
95 Unconnected,
96 Token(TokenState),
97 PendingConnect(PendingConnectState),
98 Connecting(ConnectingState),
99 Pending(PendingState),
100 Online(OnlineState),
101 Disconnected,
102}
103
104impl State {
105 pub fn assert_online(&mut self) -> &mut OnlineState {
106 match *self {
107 State::Online(ref mut s) => s,
108 _ => panic!("state not online"),
109 }
110 }
111 pub fn own_token(&self) -> Option<Token> {
112 use self::State::*;
113 match *self {
114 Unconnected => None,
115 Token(ref token) => Some(token.own_token),
116 PendingConnect(ref pending_connect) => Some(pending_connect.own_token),
117 Connecting(ref connecting) => Some(connecting.own_token),
118 Pending(ref pending) => Some(pending.own_token),
119 Online(ref online) => Some(online.own_token),
120 Disconnected => None,
121 }
122 }
123 pub fn their_token(&self) -> Option<Token> {
124 use self::State::*;
125 match *self {
126 Unconnected => None,
127 Token(_) => None,
128 PendingConnect(_) => None,
129 Connecting(ref connecting) => Some(connecting.their_token),
130 Pending(ref pending) => Some(pending.their_token),
131 Online(ref online) => Some(online.their_token),
132 Disconnected => None,
133 }
134 }
135}
136
137#[derive(Clone, Debug)]
138struct ResendChunk {
139 next_send: Timeout,
140 sequence: Sequence,
141 data: ArrayVec<[u8; 2048]>,
142}
143
144impl ResendChunk {
145 fn new<CB: Callback>(cb: &mut CB, sequence: Sequence, data: &[u8]) -> ResendChunk {
146 let mut result = ResendChunk {
147 next_send: Timeout::inactive(),
148 sequence: sequence,
149 data: data.iter().cloned().collect(),
150 };
151 assert!(
152 result.data.len() == data.len(),
153 "overlong resend packet {}",
154 data.len()
155 );
156 result.start_timeout(cb);
157 result
158 }
159 fn start_timeout<CB: Callback>(&mut self, cb: &mut CB) {
160 self.next_send.set(cb, Duration::from_millis(1_000));
161 }
162}
163
164pub struct ReceivePacket<'a> {
165 type_: ReceivePacketType<'a>,
166}
167
168impl<'a> Clone for ReceivePacket<'a> {
169 fn clone(&self) -> ReceivePacket<'a> {
170 ReceivePacket {
171 type_: self.type_.clone(),
172 }
173 }
174}
175
176impl<'a> ReceivePacket<'a> {
177 fn none() -> ReceivePacket<'a> {
178 ReceivePacket {
179 type_: ReceivePacketType::None,
180 }
181 }
182 fn ready() -> ReceivePacket<'a> {
183 ReceivePacket {
184 type_: ReceivePacketType::Ready(iter::once(())),
185 }
186 }
187 fn connless(data: &'a [u8]) -> ReceivePacket<'a> {
188 ReceivePacket {
189 type_: ReceivePacketType::Connless(iter::once(data)),
190 }
191 }
192 fn connected<W>(
193 warn: &mut W,
194 online: &mut OnlineState,
195 num_chunks: u8,
196 data: &'a [u8],
197 ) -> ReceivePacket<'a>
198 where
199 W: Warn<Warning>,
200 {
201 let chunks_iter = ChunksIter::new(data, num_chunks);
202 let ack = online.ack.clone();
203 let mut iter = chunks_iter.clone();
204 while let Some(c) = iter.next_warn(&mut w(warn)) {
205 if let Some((sequence, resend)) = c.vital {
206 let _ = resend;
207 if online.ack.update(Sequence::from_u16(sequence)) != SequenceOrdering::Current {
208 online.request_resend = true;
209 }
210 }
211 }
212 ReceivePacket {
213 type_: ReceivePacketType::Connected(ReceiveChunks {
214 ack: ack,
215 chunks: chunks_iter,
216 }),
217 }
218 }
219 fn disconnect(reason: &'a [u8]) -> ReceivePacket<'a> {
220 ReceivePacket {
221 type_: ReceivePacketType::Close(iter::once(reason)),
222 }
223 }
224}
225
226#[derive(Clone)]
227enum ReceivePacketType<'a> {
228 None,
229 Connless(iter::Once<&'a [u8]>),
230 Connected(ReceiveChunks<'a>),
231 Ready(iter::Once<()>),
232 Close(iter::Once<&'a [u8]>),
233}
234
235impl<'a> Iterator for ReceivePacket<'a> {
236 type Item = ReceiveChunk<'a>;
237 fn next(&mut self) -> Option<ReceiveChunk<'a>> {
238 match self.type_ {
239 ReceivePacketType::None => None,
240 ReceivePacketType::Ready(ref mut once) => once.next().map(|()| ReceiveChunk::Ready),
241 ReceivePacketType::Connless(ref mut once) => once.next().map(ReceiveChunk::Connless),
242 ReceivePacketType::Connected(ref mut chunks) => chunks.next(),
243 ReceivePacketType::Close(ref mut once) => once.next().map(ReceiveChunk::Disconnect),
244 }
245 }
246 fn size_hint(&self) -> (usize, Option<usize>) {
247 let len = self.clone().count();
248 (len, Some(len))
249 }
250}
251
252impl<'a> ExactSizeIterator for ReceivePacket<'a> {}
253
254#[derive(Clone)]
255struct ReceiveChunks<'a> {
256 ack: Sequence,
257 chunks: ChunksIter<'a>,
258}
259
260impl<'a> Iterator for ReceiveChunks<'a> {
261 type Item = ReceiveChunk<'a>;
262 fn next(&mut self) -> Option<ReceiveChunk<'a>> {
263 self.chunks.next().and_then(|c| {
264 if let Some((sequence, resend)) = c.vital {
265 let _ = resend;
266 if self.ack.update(Sequence::from_u16(sequence)) != SequenceOrdering::Current {
267 return self.next();
268 }
269 }
270 Some(ReceiveChunk::Connected(c.data, c.vital.is_some()))
271 })
272 }
273 fn size_hint(&self) -> (usize, Option<usize>) {
274 let len = self.clone().count();
275 (len, Some(len))
276 }
277}
278
279impl<'a> ExactSizeIterator for ReceiveChunks<'a> {}
280
/// One item delivered to the user by `Connection::feed`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum ReceiveChunk<'a> {
    /// Payload of a connectionless packet.
    Connless(&'a [u8]),
    /// Payload of a connected chunk and whether the chunk was vital.
    Connected(&'a [u8], bool),
    /// The handshake completed; the connection is now online.
    Ready,
    /// The peer closed the connection; carries the reason bytes.
    Disconnect(&'a [u8]),
}
289
290#[derive(Clone, Debug)]
291struct TokenState {
292 own_token: Token,
293}
294
295impl TokenState {
296 fn new(own_token: Token) -> TokenState {
297 TokenState {
298 own_token: own_token,
299 }
300 }
301}
302
303#[derive(Clone, Debug)]
304struct PendingConnectState {
305 own_token: Token,
306}
307
308impl PendingConnectState {
309 fn new(own_token: Token) -> PendingConnectState {
310 PendingConnectState {
311 own_token: own_token,
312 }
313 }
314}
315
316#[derive(Clone, Debug)]
317struct ConnectingState {
318 own_token: Token,
319 their_token: Token,
320}
321
322impl ConnectingState {
323 fn new(own_token: Token, their_token: Token) -> ConnectingState {
324 ConnectingState {
325 own_token: own_token,
326 their_token: their_token,
327 }
328 }
329}
330
331#[derive(Clone, Debug)]
332struct PendingState {
333 own_token: Token,
334 their_token: Token,
335}
336
337impl PendingState {
338 fn new(own_token: Token, their_token: Token) -> PendingState {
339 PendingState {
340 own_token: own_token,
341 their_token: their_token,
342 }
343 }
344}
345
346#[derive(Clone, Debug)]
347struct OnlineState {
348 own_token: Token,
351 their_token: Token,
354 ack: Sequence,
356 sequence: Sequence,
358 request_resend: bool,
359 packet: PacketContents,
362 packet_nonvital: PacketContents,
363 resend_queue: VecDeque<ResendChunk>,
366}
367
368impl OnlineState {
369 fn new(own_token: Token, their_token: Token) -> OnlineState {
370 OnlineState {
371 own_token: own_token,
372 their_token: their_token,
373 ack: Sequence::new(),
374 sequence: Sequence::new(),
375 request_resend: false,
376 packet: PacketContents::new(),
377 packet_nonvital: PacketContents::new(),
378 resend_queue: VecDeque::new(),
379 }
380 }
381 fn can_send(&self) -> bool {
382 self.packet.num_chunks != 0 || self.request_resend
383 }
384 fn ack_chunks(&mut self, ack: Sequence) {
385 let index = self
386 .resend_queue
387 .iter()
388 .position(|chunk| chunk.sequence == ack);
389 index.map(|i| self.resend_queue.truncate(i));
390 }
391 fn flush<CB: Callback>(
392 &mut self,
393 cb: &mut CB,
394 builder: &mut PacketBuilder,
395 ) -> Result<(), CB::Error> {
396 if !self.can_send() {
397 return Ok(());
398 }
399 let result = builder
400 .send(
401 cb,
402 Packet::Connected(ConnectedPacket {
403 token: self.their_token,
404 ack: self.ack.to_u16(),
405 type_: ConnectedPacketType::Chunks(
406 self.request_resend,
407 self.packet.num_chunks,
408 &self.packet.data,
409 ),
410 }),
411 )
412 .map_err(|e| e.unwrap_callback());
413 self.request_resend = false;
414 self.packet.clear();
415 self.packet_nonvital.clear();
416 result
417 }
418}
419
420#[derive(Clone, Debug, Eq, PartialEq)]
421struct PacketContents {
422 num_chunks: u8,
423 data: ArrayVec<[u8; 2048]>,
424}
425
426impl PacketContents {
427 fn new() -> PacketContents {
428 PacketContents {
429 num_chunks: 0,
430 data: ArrayVec::new(),
431 }
432 }
433 fn write_chunk(&mut self, data: &[u8], vital: Option<(u16, bool)>) {
434 protocol::write_chunk(data, vital, &mut self.data).unwrap();
435 self.num_chunks += 1;
436 }
437 fn can_fit_chunk(&self, data: &[u8], vital: bool) -> bool {
438 self.data.len() + protocol::chunk_header_size(vital) + data.len() <= MAX_PAYLOAD
440 }
441 fn clear(&mut self) {
442 *self = PacketContents::new();
443 }
444}
445
/// A chunk sequence number. The constructors and `next` in `impl Sequence`
/// keep it below `protocol::SEQUENCE_MODULUS`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
struct Sequence {
    seq: u16,
}
450
/// Position of a received sequence number relative to the expected one.
/// Variant order matters: `Ord` is derived.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum SequenceOrdering {
    /// Behind the expected number (already seen).
    Past,
    /// Exactly the expected number.
    Current,
    /// Ahead of the expected number (something was skipped).
    Future,
}
457
458impl Sequence {
459 fn new() -> Sequence {
460 Default::default()
461 }
462 fn from_u16(seq: u16) -> Sequence {
463 assert!(seq < protocol::SEQUENCE_MODULUS);
464 Sequence { seq: seq }
465 }
466 fn to_u16(self) -> u16 {
467 self.seq
468 }
469 fn next(&mut self) -> Sequence {
470 self.seq = (self.seq + 1) % protocol::SEQUENCE_MODULUS;
471 *self
472 }
473 fn update(&mut self, other: Sequence) -> SequenceOrdering {
474 let mut next_self = *self;
475 next_self.next();
476 let result = next_self.compare(other);
477 if result == SequenceOrdering::Current {
478 *self = next_self;
479 }
480 result
481 }
482 fn compare(self, other: Sequence) -> SequenceOrdering {
484 let half = protocol::SEQUENCE_MODULUS / 2;
485 let less;
486 match self.seq.cmp(&other.seq) {
487 cmp::Ordering::Less => less = other.seq - self.seq < half,
488 cmp::Ordering::Greater => less = self.seq - other.seq > half,
489 cmp::Ordering::Equal => return SequenceOrdering::Current,
490 }
491 if less {
492 SequenceOrdering::Future
493 } else {
494 SequenceOrdering::Past
495 }
496 }
497}
498
499struct PacketBuilder {
500 buffer: [u8; MAX_PACKETSIZE],
501}
502
503impl PacketBuilder {
504 fn new() -> PacketBuilder {
505 PacketBuilder {
506 buffer: [0; MAX_PACKETSIZE],
507 }
508 }
509 fn send<CB: Callback>(&mut self, cb: &mut CB, packet: Packet) -> Result<(), Error<CB::Error>> {
510 let data = match packet.write(&mut self.buffer[..]) {
511 Ok(d) => d,
512 Err(protocol::Error::Capacity(_)) => unreachable!("too short buffer provided"),
513 Err(protocol::Error::TooLongData) => return Err(Error::TooLongData),
514 };
515 cb.send(data)?;
516 Ok(())
517 }
518}
519
520struct WarnCallback<'a, W: Warn<Warning> + 'a> {
521 warn: &'a mut W,
522}
523
524fn w<W: Warn<Warning>>(warn: &mut W) -> WarnCallback<'_, W> {
525 WarnCallback { warn: warn }
526}
527
528impl<'a, W: Warn<Warning>> Warn<protocol::Warning> for WarnCallback<'a, W> {
529 fn warn(&mut self, warning: protocol::Warning) {
530 self.warn.warn(Warning::Packet(warning))
531 }
532}
533
534impl Connection {
535 pub fn new() -> Connection {
536 Connection {
537 state: State::Unconnected,
538 send: Timeout::inactive(),
539 builder: PacketBuilder::new(),
540 }
541 }
542 pub fn reset(&mut self) {
543 assert!(matches!(self.state, State::Disconnected));
544 *self = Connection::new();
545 }
546 pub fn is_unconnected(&self) -> bool {
547 matches!(self.state, State::Unconnected)
548 }
549 pub fn needs_tick(&self) -> Timeout {
550 match self.state {
551 State::Unconnected | State::Disconnected => return Timeout::inactive(),
552 _ => {}
553 }
554 let resends = match self.state {
555 State::Online(ref online) => online
556 .resend_queue
557 .back()
558 .map(|r| r.next_send)
559 .unwrap_or_default(),
560 _ => Timeout::inactive(),
561 };
562 cmp::min(self.send, resends)
563 }
564 pub fn connect<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
565 assert!(matches!(self.state, State::Unconnected));
566 self.state = State::Token(TokenState::new(Token::random(|b| cb.secure_random(b))));
567 self.tick_action(cb)?;
568 Ok(())
569 }
570 pub fn disconnect<CB: Callback>(
571 &mut self,
572 cb: &mut CB,
573 reason: &[u8],
574 ) -> Result<(), CB::Error> {
575 if let State::Disconnected = self.state {
576 assert!(
577 false,
578 "Can't call disconnect on an already disconnected connection"
579 );
580 }
581 assert!(
582 reason.iter().all(|&b| b != 0),
583 "reason must not contain NULs"
584 );
585 let result = self.send_control(cb, ControlPacket::Close(reason));
586 self.state = State::Disconnected;
587 result
588 }
589 fn resend<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
590 let online = self.state.assert_online();
591 if online.resend_queue.is_empty() {
592 return Ok(());
593 }
594 online.packet = online.packet_nonvital.clone();
595 let mut i = 0;
596 for chunk in &mut online.resend_queue {
597 chunk.start_timeout(cb);
598 }
599 while i < online.resend_queue.len() {
600 let can_fit;
601 {
602 let chunk = &online.resend_queue[online.resend_queue.len() - i - 1];
603 can_fit = online.packet.can_fit_chunk(&chunk.data, true);
604 if can_fit {
605 let vital = (chunk.sequence.to_u16(), true);
606 online.packet.write_chunk(&chunk.data, Some(vital));
607 i += 1;
608 }
609 }
610 if !can_fit {
611 self.send.set(cb, Duration::from_millis(500));
612 online.flush(cb, &mut self.builder)?;
613 }
614 }
615 Ok(())
616 }
617 pub fn flush<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
618 self.send.set(cb, Duration::from_millis(500));
619 self.state.assert_online().flush(cb, &mut self.builder)
620 }
621 fn queue<CB: Callback>(&mut self, cb: &mut CB, buffer: &[u8], vital: bool) {
622 let online = self.state.assert_online();
623 let vital = if vital {
624 let sequence = online.sequence.next();
625 online
626 .resend_queue
627 .push_front(ResendChunk::new(cb, sequence, buffer));
628 Some((sequence.to_u16(), false))
629 } else {
630 None
631 };
632 if vital.is_none() {
633 online.packet_nonvital.write_chunk(buffer, vital);
634 }
635 online.packet.write_chunk(buffer, vital)
636 }
637 pub fn send<CB: Callback>(
638 &mut self,
639 cb: &mut CB,
640 buffer: &[u8],
641 vital: bool,
642 ) -> Result<(), Error<CB::Error>> {
643 let result;
644 {
645 let online = self.state.assert_online();
646 if buffer.len() > MAX_PAYLOAD {
647 return Err(Error::TooLongData);
648 }
649 if !online.packet.can_fit_chunk(buffer, vital) {
650 result = online.flush(cb, &mut self.builder).map_err(Error::from);
651 } else {
652 result = Ok(());
653 }
654 }
655 self.queue(cb, buffer, vital);
656 result
657 }
658 pub fn send_connless<CB: Callback>(
659 &mut self,
660 cb: &mut CB,
661 data: &[u8],
662 ) -> Result<(), Error<CB::Error>> {
663 let online = self.state.assert_online();
664 self.send.set(cb, Duration::from_millis(500));
665 self.builder.send(
666 cb,
667 Packet::Connless(ConnlessPacket {
668 token: online.their_token,
669 response_token: online.own_token,
670 payload: data,
671 }),
672 )
673 }
674 fn send_control<CB: Callback>(
675 &mut self,
676 cb: &mut CB,
677 control: ControlPacket,
678 ) -> Result<(), CB::Error> {
679 self.send_control_with_token(cb, control, self.state.their_token().unwrap_or(TOKEN_NONE))
680 }
681 fn send_control_with_token<CB: Callback>(
682 &mut self,
683 cb: &mut CB,
684 control: ControlPacket,
685 token: Token,
686 ) -> Result<(), CB::Error> {
687 let ack = match self.state {
688 State::Online(ref mut online) => online.ack.to_u16(),
689 _ => 0,
690 };
691 self.builder
692 .send(
693 cb,
694 Packet::Connected(ConnectedPacket {
695 token: token,
696 ack: ack,
697 type_: ConnectedPacketType::Control(control),
698 }),
699 )
700 .map_err(|e| e.unwrap_callback())
701 }
702 pub fn tick<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
703 let do_resend = match self.state {
704 State::Online(ref online) => {
705 online
707 .resend_queue
708 .back()
709 .map(|c| c.next_send.has_triggered_level(cb))
710 .unwrap_or(false)
711 }
712 _ => false,
713 };
714 if do_resend {
715 self.resend(cb)
716 } else if self.send.has_triggered_edge(cb) {
717 self.tick_action(cb)
718 } else {
719 Ok(())
720 }
721 }
722 fn tick_action<CB: Callback>(&mut self, cb: &mut CB) -> Result<(), CB::Error> {
723 let control = match self.state {
724 State::Token(ref token) => ControlPacket::Token(token.own_token),
725 State::Connecting(ref connecting) => ControlPacket::Connect(connecting.own_token),
726 State::Pending(_) => ControlPacket::Accept,
727 State::Online(ref mut online) => {
728 if online.can_send() {
729 self.send.set(cb, Duration::from_millis(500));
731 return online.flush(cb, &mut self.builder);
732 }
733 ControlPacket::KeepAlive
734 }
735 _ => return Ok(()),
736 };
737 self.send.set(cb, Duration::from_millis(500));
738 self.send_control(cb, control)
739 }
740 pub fn feed<'a, B, CB, W>(
744 &mut self,
745 cb: &mut CB,
746 warn: &mut W,
747 data: &'a [u8],
748 buf: B,
749 ) -> (ReceivePacket<'a>, Result<(), CB::Error>)
750 where
751 B: Buffer<'a>,
752 CB: Callback,
753 W: Warn<Warning>,
754 {
755 with_buffer(buf, |b| self.feed_impl(cb, warn, data, b))
756 }
757
758 pub fn feed_impl<'d, 's, CB, W>(
759 &mut self,
760 cb: &mut CB,
761 warn: &mut W,
762 data: &'d [u8],
763 mut buffer: BufferRef<'d, 's>,
764 ) -> (ReceivePacket<'d>, Result<(), CB::Error>)
765 where
766 CB: Callback,
767 W: Warn<Warning>,
768 {
769 let none = (ReceivePacket::none(), Ok(()));
770 {
771 use protocol::ConnectedPacketType::*;
772 use protocol::ControlPacket::*;
773
774 let packet = match Packet::read(&mut w(warn), data, &mut buffer) {
775 Ok(p) => p,
776 Err(e) => {
777 warn.warn(Warning::Read(e));
778 return none;
779 }
780 };
781
782 let connected = match packet {
783 Packet::Connless(ConnlessPacket {
784 token,
785 response_token,
786 payload,
787 }) => {
788 if Some(token) != self.state.own_token() {
789 warn.warn(Warning::ConnlessTokenMismatch);
790 return none;
791 }
792 if Some(response_token) != self.state.their_token() {
793 warn.warn(Warning::ConnlessResponseTokenMismatch);
794 return none;
795 }
796 return (ReceivePacket::connless(payload), Ok(()));
797 }
798 Packet::Connected(c) => c,
799 };
800 let ConnectedPacket { token, ack, type_ } = connected;
801
802 let mut expected_token = self.state.own_token().unwrap_or(TOKEN_NONE);
803 if matches!(type_, Control(Token(_)))
809 && matches!(self.state, State::PendingConnect(_))
810 && token == TOKEN_NONE
811 {
812 expected_token = TOKEN_NONE;
813 }
814 if token != expected_token {
815 warn.warn(Warning::TokenMismatch);
816 return none;
817 }
818
819 if let State::Online(ref mut online) = self.state {
821 online.ack_chunks(Sequence::from_u16(ack));
822 }
823
824 match type_ {
825 Chunks(request_resend, num_chunks, chunks) => {
826 let _ = num_chunks;
827 if let State::Pending(ref pending) = self.state {
828 self.state =
829 State::Online(OnlineState::new(pending.own_token, pending.their_token));
830 }
831 let result;
832 if request_resend {
833 if let State::Online(_) = self.state {
834 result = self.resend(cb);
835 } else {
836 result = Ok(());
837 }
838 } else {
839 result = Ok(())
840 }
841 match self.state {
842 State::Online(ref mut online) => {
843 return (
844 ReceivePacket::connected(warn, online, num_chunks, chunks),
845 result,
846 );
847 }
848 State::Pending(_) => unreachable!(),
849 _ => return none,
851 }
852 }
853 Control(Token(their_token)) => {
854 if let State::Unconnected = self.state {
855 use self::Token;
856 self.state =
857 State::PendingConnect(PendingConnectState::new(Token::random(|b| {
858 cb.secure_random(b)
859 })));
860 }
861
862 match self.state {
863 State::Unconnected => unreachable!(),
864 State::PendingConnect(ref pending_connect) => {
865 return (
866 ReceivePacket::none(),
867 self.send_control_with_token(
868 cb,
869 ControlPacket::Token(pending_connect.own_token),
870 their_token,
871 ),
872 );
873 }
874 State::Token(ref token) => {
875 self.state = State::Connecting(ConnectingState::new(
876 token.own_token,
877 their_token,
878 ));
879 }
881 _ => return none,
882 }
883 }
884 Control(KeepAlive) => return none,
885 Control(Connect(their_token)) => {
886 if let State::PendingConnect(ref pending_connect) = self.state {
887 self.state = State::Pending(PendingState::new(
888 pending_connect.own_token,
889 their_token,
890 ));
891 } else {
893 return none;
894 }
895 }
896 Control(Accept) => {
897 if let State::Connecting(ref connecting) = self.state {
898 self.state = State::Online(OnlineState::new(
899 connecting.own_token,
900 connecting.their_token,
901 ));
902 return (ReceivePacket::ready(), Ok(()));
903 } else {
904 return none;
905 }
906 }
907 Control(Close(reason)) => {
908 self.state = State::Disconnected;
909 return (ReceivePacket::disconnect(reason), Ok(()));
910 }
911 }
912 }
913 (ReceivePacket::none(), self.tick_action(cb))
915 }
916}
917
// Unit tests for sequence-number comparison and a full handshake / data /
// close exchange between two in-memory connections.
#[cfg(test)]
mod test {
    use super::Callback;
    use super::Connection;
    use super::ReceiveChunk;
    use super::Sequence;
    use super::SequenceOrdering;
    use crate::protocol7 as protocol;
    use crate::Timestamp;
    use hexdump::hexdump;
    use itertools::Itertools;
    use libtw2_warn::Panic;
    use std::collections::VecDeque;
    use void::ResultVoidExt;
    use void::Void;

    // Checks the wrap-around classification of `Sequence::compare` at the
    // start, middle and end of the sequence range.
    #[test]
    fn sequence_compare() {
        use super::SequenceOrdering::*;

        fn cmp(a: Sequence, b: Sequence) -> SequenceOrdering {
            Sequence::compare(a, b)
        }
        let default = Sequence::new();
        let first = Sequence::from_u16(0);
        let mid = Sequence::from_u16(protocol::SEQUENCE_MODULUS / 2);
        let end = Sequence::from_u16(protocol::SEQUENCE_MODULUS - 1);
        assert_eq!(cmp(default, first), Current);
        assert_eq!(cmp(first, mid), Past);
        assert_eq!(cmp(first, end), Past);
        assert_eq!(cmp(mid, first), Past);
        assert_eq!(cmp(mid, end), Future);
        assert_eq!(cmp(end, first), Future);
        assert_eq!(cmp(end, mid), Past);
    }

    // Drives a client and a server through the handshake, one vital chunk and
    // a close, checking the exact bytes of every packet sent.
    #[test]
    fn establish_connection() {
        // Test callback: records sent packets, hands out a preset 4-byte
        // "random" value exactly once, and reports a constant time.
        #[derive(Default)]
        struct Cb {
            sent: VecDeque<Vec<u8>>,
            random_value: Option<[u8; 4]>,
        }
        impl Callback for Cb {
            type Error = Void;
            fn secure_random(&mut self, buffer: &mut [u8]) {
                if buffer.len() != 4 {
                    unimplemented!();
                }
                buffer.copy_from_slice(&self.random_value.take().unwrap());
            }
            fn send(&mut self, data: &[u8]) -> Result<(), Void> {
                self.sent.push_back(data.to_owned());
                Ok(())
            }
            fn time(&mut self) -> Timestamp {
                Timestamp::from_secs_since_epoch(0)
            }
        }
        let mut buffer = [0; protocol::MAX_PACKETSIZE];
        let mut cb = Cb::default();
        let cb = &mut cb;
        println!("");

        let mut client = Connection::new();
        let mut server = Connection::new();

        // Client: `connect` sends a token request carrying its token
        // 12 34 56 78, addressed to the all-ones token and zero-padded.
        assert!(cb.random_value.replace([0x12, 0x34, 0x56, 0x78]).is_none());
        client.connect(cb).void_unwrap();
        let packet = cb.sent.pop_front().unwrap();
        assert!(cb.sent.is_empty());
        hexdump(&packet);
        let expected = {
            let mut expected = Vec::new();
            expected.extend_from_slice(b"\x04\x00\x00\xff\xff\xff\xff\x05\x12\x34\x56\x78");
            expected.extend_from_slice(&[0; protocol::TOKEN_REQUEST_PACKET_SIZE - 12]);
            expected
        };
        assert!(&packet == &expected);

        // Server: generates its token 9a bc de f0 and answers with it,
        // addressed to the client's token.
        assert!(cb.random_value.replace([0x9a, 0xbc, 0xde, 0xf0]).is_none());
        assert!(server
            .feed(cb, &mut Panic, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        let packet = cb.sent.pop_front().unwrap();
        assert!(cb.sent.is_empty());
        hexdump(&packet);
        assert!(&packet == b"\x04\x00\x00\x12\x34\x56\x78\x05\x9a\xbc\xde\xf0");

        // Client: learns the server token and sends a connect request.
        assert!(client
            .feed(cb, &mut Panic, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        let packet = cb.sent.pop_front().unwrap();
        assert!(cb.sent.is_empty());
        hexdump(&packet);
        assert!(&packet == b"\x04\x00\x00\x9a\xbc\xde\xf0\x01\x12\x34\x56\x78");

        // Server: accepts the connect request.
        assert!(server
            .feed(cb, &mut Panic, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        let packet = cb.sent.pop_front().unwrap();
        assert!(cb.sent.is_empty());
        hexdump(&packet);
        assert!(&packet == b"\x04\x00\x00\x12\x34\x56\x78\x02");

        // Client: the accept brings it online; it reports Ready and sends nothing.
        assert!(
            client
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Ready]
        );
        assert!(cb.sent.is_empty());

        // Queuing a vital chunk does not send anything by itself...
        client.send(cb, b"\x42", true).unwrap();
        assert!(cb.sent.is_empty());

        // ...an explicit flush does.
        client.flush(cb).void_unwrap();
        let packet = cb.sent.pop_front().unwrap();
        assert!(cb.sent.is_empty());
        hexdump(&packet);
        assert!(&packet == b"\x00\x00\x01\x9a\xbc\xde\xf0\x40\x01\x01\x42");

        // Server: the first chunk packet brings it online and delivers the chunk.
        assert!(
            server
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Connected(b"\x42", true)]
        );
        assert!(cb.sent.is_empty());

        // Server closes with reason "42"; the packet now carries ack 1.
        server.disconnect(cb, b"42").void_unwrap();
        let packet = cb.sent.pop_front().unwrap();
        hexdump(&packet);
        assert!(&packet == b"\x04\x01\x00\x12\x34\x56\x78\x0442\x00");

        // Client: receives the close reason.
        assert!(
            client
                .feed(cb, &mut Panic, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ReceiveChunk::Disconnect(b"42")]
        );

        // Both sides are now disconnected, so `reset` must not panic.
        client.reset();
        server.reset();
    }
}