libtw2_net/
net.rs

1use crate::collections::peer_map;
2use crate::collections::PeerMap;
3use crate::connection;
4use crate::connection::ReceiveChunk;
5use crate::protocol;
6use crate::protocol::ConnectedPacket;
7use crate::protocol::ConnectedPacketType;
8use crate::protocol::ControlPacket;
9use crate::protocol::Packet;
10use crate::Connection;
11use crate::Timeout;
12use crate::Timestamp;
13use arrayvec::ArrayVec;
14use buffer::with_buffer;
15use buffer::Buffer;
16use buffer::BufferRef;
17use std::fmt;
18use std::hash::Hash;
19use std::iter;
20use std::ops;
21use warn::Panic;
22use warn::Warn;
23
24pub use crate::connection::Error;
25
26pub trait Callback<A: Address> {
27    type Error;
28    fn secure_random(&mut self, buffer: &mut [u8]);
29    fn send(&mut self, addr: A, data: &[u8]) -> Result<(), Self::Error>;
30    fn time(&mut self) -> Timestamp;
31}
32
33#[derive(Debug)]
34pub enum Warning<A: Address> {
35    Peer(A, PeerId, connection::Warning),
36    Connless(A, connection::Warning),
37}
38
39impl<A: Address> Warning<A> {
40    pub fn addr(&self) -> A {
41        match *self {
42            Warning::Peer(addr, _, _) => addr,
43            Warning::Connless(addr, _) => addr,
44        }
45    }
46}
47
/// Marker trait for types usable as a network address.
///
/// It is implemented automatically for every type that is `Copy`, `Eq`,
/// `Hash` and `Ord`.
pub trait Address: Copy + Eq + Hash + Ord {}

impl<A> Address for A where A: Copy + Eq + Hash + Ord {}
50
/// Identifier of a peer inside a `Net`.
///
/// A thin wrapper around `u32`; it compares, orders and hashes like the
/// wrapped integer.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PeerId(pub u32);
53
54impl PeerId {
55    fn get_and_increment(&mut self) -> PeerId {
56        let old = *self;
57        self.0 = self.0.wrapping_add(1);
58        old
59    }
60}
61
62impl fmt::Debug for PeerId {
63    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
64        write!(f, "p{}", self.0)
65    }
66}
67
68impl fmt::Display for PeerId {
69    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
70        fmt::Debug::fmt(self, f)
71    }
72}
73
/// Canned connect packet fed into a peer's connection on `accept` when the
/// peer indicated token support: the 4-byte header followed by `TKEN` and
/// four `0xff` bytes.
const CONNECT_PACKET: &[u8; 12] = b"\x10\x00\x00\x01TKEN\xff\xff\xff\xff";
/// Canned connect packet fed into a peer's connection on `accept` when the
/// peer did not indicate token support: just the 4-byte header.
const CONNECT_PACKET_NO_TOKEN: &[u8; 4] = b"\x10\x00\x00\x01";
76
77struct Peer<A: Address> {
78    conn: Connection,
79    addr: A,
80    /// This flag says whether the incoming connection had indicated token
81    /// support.
82    ///
83    /// DDNet token support is unfortunately kind of hacked in.
84    token: bool,
85}
86
87impl<A: Address> Peer<A> {
88    fn new(addr: A, token: bool) -> Peer<A> {
89        Peer {
90            conn: Connection::new(),
91            addr: addr,
92            token: token,
93        }
94    }
95}
96
97struct Peers<A: Address> {
98    peers: PeerMap<Peer<A>>,
99    next_peer_id: PeerId,
100}
101
102impl<A: Address> Peers<A> {
103    fn new() -> Peers<A> {
104        Peers {
105            peers: PeerMap::new(),
106            next_peer_id: PeerId(0),
107        }
108    }
109    fn new_peer(&mut self, addr: A, token: bool) -> (PeerId, &mut Peer<A>) {
110        // FIXME(rust-lang/rfcs#811): Work around missing non-lexical borrows.
111        let raw_self: *mut Peers<A> = self;
112        unsafe {
113            loop {
114                let peer_id = self.next_peer_id.get_and_increment();
115                if let peer_map::Entry::Vacant(v) = (*raw_self).peers.entry(peer_id) {
116                    return (peer_id, v.insert(Peer::new(addr, token)));
117                }
118            }
119        }
120    }
121    fn iter(&self) -> peer_map::Iter<Peer<A>> {
122        self.peers.iter()
123    }
124    fn iter_mut(&mut self) -> peer_map::IterMut<Peer<A>> {
125        self.peers.iter_mut()
126    }
127    fn remove_peer(&mut self, pid: PeerId) {
128        self.peers.remove(pid)
129    }
130    fn pid_from_addr(&mut self, addr: A) -> Option<PeerId> {
131        for (pid, p) in self.peers.iter() {
132            if p.addr == addr {
133                return Some(pid);
134            }
135        }
136        None
137    }
138    fn get(&self, pid: PeerId) -> Option<&Peer<A>> {
139        self.peers.get(pid)
140    }
141    fn get_mut(&mut self, pid: PeerId) -> Option<&mut Peer<A>> {
142        self.peers.get_mut(pid)
143    }
144}
145
146impl<A: Address> ops::Index<PeerId> for Peers<A> {
147    type Output = Peer<A>;
148    fn index(&self, pid: PeerId) -> &Peer<A> {
149        self.get(pid).unwrap_or_else(|| panic!("invalid pid"))
150    }
151}
152
153impl<A: Address> ops::IndexMut<PeerId> for Peers<A> {
154    fn index_mut(&mut self, pid: PeerId) -> &mut Peer<A> {
155        self.get_mut(pid).unwrap_or_else(|| panic!("invalid pid"))
156    }
157}
158
159// TODO: Simplify these enums. A lot.
160
161#[derive(Clone, Copy, Debug, Eq, PartialEq)]
162pub enum ChunkOrEvent<'a, A: Address> {
163    Chunk(Chunk<'a>),
164    Connless(ConnlessChunk<'a, A>),
165    Connect(PeerId),
166    Ready(PeerId),
167    Disconnect(PeerId, &'a [u8]),
168}
169
170#[derive(Clone, Copy, Debug, Eq, PartialEq)]
171pub struct Chunk<'a> {
172    pub pid: PeerId,
173    pub vital: bool,
174    pub data: &'a [u8],
175}
176
177#[derive(Clone, Copy, Debug, Eq, PartialEq)]
178pub struct ConnlessChunk<'a, A: Address> {
179    pub addr: A,
180    pub pid: Option<PeerId>,
181    pub data: &'a [u8],
182}
183
184struct ConnlessBuilder {
185    buffer: [u8; protocol::MAX_PACKETSIZE],
186}
187
188impl ConnlessBuilder {
189    fn new() -> ConnlessBuilder {
190        ConnlessBuilder {
191            buffer: [0; protocol::MAX_PACKETSIZE],
192        }
193    }
194    fn send<A: Address, CB: Callback<A>>(
195        &mut self,
196        cb: &mut CB,
197        addr: A,
198        packet: Packet,
199    ) -> Result<(), Error<CB::Error>> {
200        let send_data = match packet.write(&mut self.buffer[..]) {
201            Ok(d) => d,
202            Err(protocol::Error::Capacity(_)) => unreachable!("too short buffer provided"),
203            Err(protocol::Error::TooLongData) => return Err(Error::TooLongData),
204        };
205        cb.send(addr, send_data)?;
206        Ok(())
207    }
208}
209
210#[derive(Clone)]
211pub struct ReceivePacket<'a, A: Address> {
212    type_: ReceivePacketType<'a, A>,
213}
214
215impl<'a, A: Address> Iterator for ReceivePacket<'a, A> {
216    type Item = ChunkOrEvent<'a, A>;
217    fn next(&mut self) -> Option<ChunkOrEvent<'a, A>> {
218        use self::ReceivePacketType::Connect;
219        use self::ReceivePacketType::Connected;
220        use self::ReceivePacketType::Connless;
221        match self.type_ {
222            ReceivePacketType::None => None,
223            Connect(ref mut once) => once.next().map(|pid| ChunkOrEvent::Connect(pid)),
224            Connected(addr, pid, ref mut receive_packet) => {
225                receive_packet.next().map(|chunk| match chunk {
226                    ReceiveChunk::Connless(d) => ChunkOrEvent::Connless(ConnlessChunk {
227                        addr: addr,
228                        pid: Some(pid),
229                        data: d,
230                    }),
231                    ReceiveChunk::Connected(d, vital) => ChunkOrEvent::Chunk(Chunk {
232                        pid: pid,
233                        vital: vital,
234                        data: d,
235                    }),
236                    ReceiveChunk::Ready => ChunkOrEvent::Ready(pid),
237                    ReceiveChunk::Disconnect(r) => ChunkOrEvent::Disconnect(pid, r),
238                })
239            }
240            Connless(addr, ref mut once) => once.next().map(|data| {
241                ChunkOrEvent::Connless(ConnlessChunk {
242                    addr: addr,
243                    pid: None,
244                    data: data,
245                })
246            }),
247        }
248    }
249    fn size_hint(&self) -> (usize, Option<usize>) {
250        let len = self.clone().count();
251        (len, Some(len))
252    }
253}
254
255impl<'a, A: Address> ExactSizeIterator for ReceivePacket<'a, A> {}
256
257impl<'a, A: Address> ReceivePacket<'a, A> {
258    fn none() -> ReceivePacket<'a, A> {
259        ReceivePacket {
260            type_: ReceivePacketType::None,
261        }
262    }
263    fn connect(pid: PeerId) -> ReceivePacket<'a, A> {
264        ReceivePacket {
265            type_: ReceivePacketType::Connect(iter::once(pid)),
266        }
267    }
268    fn connected(
269        addr: A,
270        pid: PeerId,
271        receive_packet: connection::ReceivePacket<'a>,
272        net: &mut Net<A>,
273    ) -> ReceivePacket<'a, A> {
274        for chunk in receive_packet.clone() {
275            if let ReceiveChunk::Disconnect(..) = chunk {
276                net.peers.remove_peer(pid);
277            }
278        }
279        ReceivePacket {
280            type_: ReceivePacketType::Connected(addr, pid, receive_packet),
281        }
282    }
283
284    fn connless(addr: A, data: &'a [u8]) -> ReceivePacket<'a, A> {
285        ReceivePacket {
286            type_: ReceivePacketType::Connless(addr, iter::once(data)),
287        }
288    }
289}
290
291#[derive(Clone)]
292enum ReceivePacketType<'a, A: Address> {
293    None,
294    Connect(iter::Once<PeerId>),
295    Connected(A, PeerId, connection::ReceivePacket<'a>),
296    Connless(A, iter::Once<&'a [u8]>),
297}
298
299pub struct Net<A: Address> {
300    peers: Peers<A>,
301    builder: ConnlessBuilder,
302    accept_connections: bool,
303}
304
305struct ConnectionCallback<'a, A: Address, CB: Callback<A> + 'a> {
306    cb: &'a mut CB,
307    addr: A,
308}
309
310// Create `ConnectionCallback`.
311fn cc<A: Address, CB: Callback<A>>(cb: &mut CB, addr: A) -> ConnectionCallback<A, CB> {
312    ConnectionCallback { cb: cb, addr: addr }
313}
314
315impl<'a, A: Address, W: Warn<Warning<A>>> Warn<connection::Warning> for WarnCallback<'a, A, W> {
316    fn warn(&mut self, warning: connection::Warning) {
317        self.warn.warn(Warning::Connless(self.addr, warning))
318    }
319}
320
321impl<'a, A: Address, W: Warn<Warning<A>>> Warn<protocol::Warning> for WarnCallback<'a, A, W> {
322    fn warn(&mut self, warning: protocol::Warning) {
323        self.warn.warn(Warning::Connless(
324            self.addr,
325            connection::Warning::Packet(warning),
326        ))
327    }
328}
329
330struct WarnCallback<'a, A: Address, W: Warn<Warning<A>> + 'a> {
331    warn: &'a mut W,
332    addr: A,
333}
334
335fn w<A: Address, W: Warn<Warning<A>>>(warn: &mut W, addr: A) -> WarnCallback<A, W> {
336    WarnCallback {
337        warn: warn,
338        addr: addr,
339    }
340}
341
342impl<'a, A: Address, W: Warn<Warning<A>>> Warn<connection::Warning> for WarnPeerCallback<'a, A, W> {
343    fn warn(&mut self, warning: connection::Warning) {
344        self.warn.warn(Warning::Peer(self.addr, self.pid, warning))
345    }
346}
347
348struct WarnPeerCallback<'a, A: Address, W: Warn<Warning<A>> + 'a> {
349    warn: &'a mut W,
350    addr: A,
351    pid: PeerId,
352}
353
354fn wp<A: Address, W: Warn<Warning<A>>>(
355    warn: &mut W,
356    addr: A,
357    pid: PeerId,
358) -> WarnPeerCallback<A, W> {
359    WarnPeerCallback {
360        warn: warn,
361        addr: addr,
362        pid: pid,
363    }
364}
365
366impl<'a, A: Address, CB: Callback<A>> connection::Callback for ConnectionCallback<'a, A, CB> {
367    type Error = CB::Error;
368    fn secure_random(&mut self, buffer: &mut [u8]) {
369        self.cb.secure_random(buffer)
370    }
371    fn send(&mut self, data: &[u8]) -> Result<(), CB::Error> {
372        self.cb.send(self.addr, data)
373    }
374    fn time(&mut self) -> Timestamp {
375        self.cb.time()
376    }
377}
378
379impl<A: Address> Net<A> {
380    fn new(accept_connections: bool) -> Net<A> {
381        Net {
382            peers: Peers::new(),
383            builder: ConnlessBuilder::new(),
384            accept_connections: accept_connections,
385        }
386    }
387    pub fn server() -> Net<A> {
388        Net::new(true)
389    }
390    pub fn client() -> Net<A> {
391        Net::new(false)
392    }
393    pub fn needs_tick(&self) -> Timeout {
394        self.peers
395            .iter()
396            .map(|(_, p)| p.conn.needs_tick())
397            .min()
398            .unwrap_or_default()
399    }
400    pub fn is_receive_chunk_still_valid(&self, chunk: &mut ChunkOrEvent<A>) -> bool {
401        if let ChunkOrEvent::Chunk(Chunk { pid, .. }) = *chunk {
402            self.peers.get(pid).is_some()
403        } else {
404            true
405        }
406    }
407    pub fn connect<CB: Callback<A>>(
408        &mut self,
409        cb: &mut CB,
410        addr: A,
411    ) -> (PeerId, Result<(), CB::Error>) {
412        let (pid, peer) = self.peers.new_peer(addr, false);
413        (pid, peer.conn.connect(&mut cc(cb, peer.addr)))
414    }
415    pub fn disconnect<CB: Callback<A>>(
416        &mut self,
417        cb: &mut CB,
418        pid: PeerId,
419        reason: &[u8],
420    ) -> Result<(), CB::Error> {
421        let result;
422        {
423            let peer = &mut self.peers[pid];
424            assert!(!peer.conn.is_unconnected());
425            result = peer.conn.disconnect(&mut cc(cb, peer.addr), reason);
426        }
427        self.peers.remove_peer(pid);
428        result
429    }
430    pub fn send_connless<CB: Callback<A>>(
431        &mut self,
432        cb: &mut CB,
433        addr: A,
434        data: &[u8],
435    ) -> Result<(), Error<CB::Error>> {
436        self.builder.send(cb, addr, Packet::Connless(data))
437    }
438    pub fn send<CB: Callback<A>>(
439        &mut self,
440        cb: &mut CB,
441        chunk: Chunk,
442    ) -> Result<(), Error<CB::Error>> {
443        let peer = &mut self.peers[chunk.pid];
444        peer.conn
445            .send(&mut cc(cb, peer.addr), chunk.data, chunk.vital)
446    }
447    pub fn flush<CB: Callback<A>>(&mut self, cb: &mut CB, pid: PeerId) -> Result<(), CB::Error> {
448        let peer = &mut self.peers[pid];
449        peer.conn.flush(&mut cc(cb, peer.addr))
450    }
451    pub fn ignore(&mut self, pid: PeerId) {
452        self.peers.remove_peer(pid);
453    }
454    pub fn accept<CB: Callback<A>>(&mut self, cb: &mut CB, pid: PeerId) -> Result<(), CB::Error> {
455        let peer = &mut self.peers[pid];
456        assert!(peer.conn.is_unconnected());
457        let mut buf: ArrayVec<[u8; 2048]> = ArrayVec::new();
458        let connect_packet: &[u8] = if peer.token {
459            CONNECT_PACKET
460        } else {
461            CONNECT_PACKET_NO_TOKEN
462        };
463        let (mut none, res) =
464            peer.conn
465                .feed(&mut cc(cb, peer.addr), &mut Panic, connect_packet, &mut buf);
466        assert!(none.next().is_none());
467        res
468    }
469    pub fn reject<CB: Callback<A>>(
470        &mut self,
471        cb: &mut CB,
472        pid: PeerId,
473        reason: &[u8],
474    ) -> Result<(), CB::Error> {
475        let result;
476        {
477            let peer = &mut self.peers[pid];
478            assert!(peer.conn.is_unconnected());
479            result = peer.conn.disconnect(&mut cc(cb, peer.addr), reason);
480        }
481        self.peers.remove_peer(pid);
482        result
483    }
484    pub fn tick<'a, CB: Callback<A>>(&'a mut self, cb: &'a mut CB) -> Tick<A, CB> {
485        Tick {
486            iter_mut: self.peers.iter_mut(),
487            cb: cb,
488        }
489    }
490    pub fn feed<'a, CB, B, W>(
491        &mut self,
492        cb: &mut CB,
493        warn: &mut W,
494        addr: A,
495        data: &'a [u8],
496        buf: B,
497    ) -> (ReceivePacket<'a, A>, Result<(), CB::Error>)
498    where
499        CB: Callback<A>,
500        B: Buffer<'a>,
501        W: Warn<Warning<A>>,
502    {
503        with_buffer(buf, |b| self.feed_impl(cb, warn, addr, data, b))
504    }
505    fn feed_impl<'d, 's, CB, W>(
506        &mut self,
507        cb: &mut CB,
508        warn: &mut W,
509        addr: A,
510        data: &'d [u8],
511        mut buf: BufferRef<'d, 's>,
512    ) -> (ReceivePacket<'d, A>, Result<(), CB::Error>)
513    where
514        CB: Callback<A>,
515        W: Warn<Warning<A>>,
516    {
517        if let Some(pid) = self.peers.pid_from_addr(addr) {
518            let (packet, e) = self.peers[pid].conn.feed(
519                &mut cc(cb, addr),
520                &mut wp(warn, addr, pid),
521                data,
522                &mut buf,
523            );
524            (ReceivePacket::connected(addr, pid, packet, self), e)
525        } else {
526            let packet = match Packet::read(&mut w(warn, addr), data, None, &mut buf) {
527                Ok(p) => p,
528                Err(e) => {
529                    w(warn, addr).warn(connection::Warning::Read(e));
530                    return (ReceivePacket::none(), Ok(()));
531                }
532            };
533            if let Packet::Connless(d) = packet {
534                (ReceivePacket::connless(addr, d), Ok(()))
535            } else if let Packet::Connected(ConnectedPacket {
536                token,
537                type_: ConnectedPacketType::Control(ControlPacket::Connect),
538                ..
539            }) = packet
540            {
541                if self.accept_connections {
542                    // TODO: This is vulnerable to IP spoofing.
543                    let (pid, _) = self.peers.new_peer(addr, token.is_some());
544                    (ReceivePacket::connect(pid), Ok(()))
545                } else {
546                    w(warn, addr).warn(connection::Warning::Unexpected);
547                    (ReceivePacket::none(), Ok(()))
548                }
549            } else {
550                w(warn, addr).warn(connection::Warning::Unexpected);
551                (ReceivePacket::none(), Ok(()))
552            }
553        }
554    }
555}
556
557pub struct Tick<'a, A: Address + 'a, CB: Callback<A> + 'a> {
558    iter_mut: peer_map::IterMut<'a, Peer<A>>,
559    cb: &'a mut CB,
560}
561
562impl<'a, A: Address + 'a, CB: Callback<A> + 'a> Iterator for Tick<'a, A, CB> {
563    type Item = CB::Error;
564    fn next(&mut self) -> Option<CB::Error> {
565        while let Some((_, p)) = self.iter_mut.next() {
566            match p.conn.tick(&mut cc(self.cb, p.addr)) {
567                Ok(()) => {}
568                Err(e) => return Some(e),
569            }
570        }
571        None
572    }
573}
574
#[cfg(test)]
mod test {
    use super::Callback;
    use super::ChunkOrEvent;
    use super::Net;
    use crate::protocol;
    use crate::Timestamp;
    use itertools::Itertools;
    use std::collections::VecDeque;
    use void::ResultVoidExt;
    use void::Void;
    use warn::Panic;

    /// Runs a full connect / accept / disconnect handshake with both ends of
    /// the connection living in the same `Net`, checking after every step
    /// which packets were sent and which events were produced.
    #[test]
    fn establish_connection() {
        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
        enum Address {
            Client,
            Server,
        }
        /// Callback that queues outgoing packets and checks that each one is
        /// addressed to the currently expected recipient.
        struct Cb {
            packets: VecDeque<Vec<u8>>,
            recipient: Address,
        }
        impl Cb {
            fn new() -> Cb {
                Cb {
                    packets: VecDeque::new(),
                    recipient: Address::Server,
                }
            }
        }
        impl Callback<Address> for Cb {
            type Error = Void;
            fn secure_random(&mut self, buffer: &mut [u8]) {
                // Only 4-byte requests are supported; they get fixed bytes.
                if buffer.len() != 4 {
                    unimplemented!();
                }
                buffer.copy_from_slice(&[0x12, 0x34, 0x56, 0x78]);
            }
            fn send(&mut self, addr: Address, data: &[u8]) -> Result<(), Void> {
                assert!(self.recipient == addr);
                self.packets.push_back(data.to_vec());
                Ok(())
            }
            fn time(&mut self) -> Timestamp {
                Timestamp::from_secs_since_epoch(0)
            }
        }
        let mut cb = Cb::new();
        let cb = &mut cb;
        let mut buffer = [0; protocol::MAX_PACKETSIZE];

        let mut net = Net::server();

        // Connect
        cb.recipient = Address::Server;
        let (c_pid, res) = net.connect(cb, Address::Server);
        res.void_unwrap();
        let packet = cb.packets.pop_front().unwrap();
        assert!(cb.packets.is_empty());

        // ConnectAccept
        cb.recipient = Address::Client;
        let p = net
            .feed(cb, &mut Panic, Address::Client, &packet, &mut buffer[..])
            .0
            .collect_vec();
        assert!(p.len() == 1);
        let s_pid = match p[0] {
            ChunkOrEvent::Connect(s) => s,
            _ => panic!(),
        };
        // No packets sent out until we accept the client.
        assert!(cb.packets.is_empty());

        net.accept(cb, s_pid).void_unwrap();
        let packet = cb.packets.pop_front().unwrap();
        assert!(cb.packets.is_empty());

        // Accept
        cb.recipient = Address::Server;
        assert!(
            net.feed(cb, &mut Panic, Address::Server, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ChunkOrEvent::Ready(c_pid)]
        );
        let packet = cb.packets.pop_front().unwrap();
        assert!(cb.packets.is_empty());

        cb.recipient = Address::Client;
        assert!(net
            .feed(cb, &mut Panic, Address::Client, &packet, &mut buffer[..])
            .0
            .next()
            .is_none());
        assert!(cb.packets.is_empty());

        // Disconnect
        cb.recipient = Address::Server;
        net.disconnect(cb, c_pid, b"foobar").void_unwrap();
        let packet = cb.packets.pop_front().unwrap();
        assert!(cb.packets.is_empty());

        cb.recipient = Address::Client;
        assert!(
            net.feed(cb, &mut Panic, Address::Client, &packet, &mut buffer[..])
                .0
                .collect_vec()
                == &[ChunkOrEvent::Disconnect(s_pid, b"foobar")]
        );
        assert!(cb.packets.is_empty());
    }
}
696}