Skip to main content

kompact_net/transport/
network_thread.rs

1use super::*;
2use crate::{
3    NetworkStatus,
4    dispatch::{
5        NetworkConfig,
6        lookup::{ActorLookup, LookupResult},
7    },
8    events::NetworkDispatcherEvent,
9    messaging::{DispatchEnvelope, NetMessage, SerialisedFrame},
10    net::{
11        SessionId,
12        buffers::{BufferChunk, BufferPool, EncodeBuffer},
13    },
14    transport::{
15        network_channel::{ChannelState, TcpChannel},
16        udp_state::UdpState,
17    },
18};
19use crossbeam_channel::Receiver as Recv;
20use ipnet::{IpNet, Ipv4Net, Ipv6Net};
21use iprange::{IpRange, ToNetwork};
22use kompact::prelude::deserialise_chunk_lease;
23use lru::LruCache;
24use mio::{
25    Events,
26    Poll,
27    Token,
28    event::Event,
29    net::{TcpListener, TcpStream, UdpSocket},
30};
31use rustc_hash::{FxHashMap, FxHashSet};
32use snafu::ResultExt;
33use std::{
34    cell::{RefCell, RefMut},
35    collections::VecDeque,
36    io,
37    net::{IpAddr, Shutdown, SocketAddr},
38    num::NonZeroUsize,
39    ops::DerefMut,
40    rc::Rc,
41    sync::Arc,
42    time::Duration,
43};
44
// Used for identifying connections
/// Poll token under which the TCP listener is registered.
const TCP_SERVER: Token = Token(0);
/// Poll token under which the UDP socket is registered.
const UDP_SOCKET: Token = Token(1);
// Used for identifying the dispatcher/input queue
/// Poll token of the `Waker`; events on it make the thread drain its dispatcher input queue.
const DISPATCHER: Token = Token(2);
/// First token handed out to TCP channels; every token other than the fixed ones above is
/// treated as a channel token by the event loop.
const START_TOKEN: Token = Token(3);
/// Capacity of the `Events` buffer filled by each `Poll::poll` call.
const MAX_POLL_EVENTS: usize = 1024;
/// How many times to retry on interrupt before we give up
pub const MAX_INTERRUPTS: i32 = 9;
// We do retries when we fail to bind a socket listener during boot-up:
/// Retry budget passed to `bind_with_retries` when binding the TCP listener.
const MAX_BIND_RETRIES: usize = 5;
/// Delay between bind attempts; presumably in milliseconds (TODO confirm against
/// `bind_with_retries`, which is not part of this view).
const BIND_RETRY_INTERVAL: u64 = 1000;
57
/// Builder struct, can be sent to a thread safely to launch a NetworkThread
///
/// [`NetworkThreadBuilder::new`] already binds the TCP listener, so the actually bound address
/// is known before the thread starts; [`NetworkThreadBuilder::build`] then binds the UDP socket
/// and assembles the [`NetworkThread`].
pub struct NetworkThreadBuilder {
    /// mio poll instance the listener, UDP socket and waker are registered with.
    poll: Poll,
    /// Waker registered under the `DISPATCHER` token; handed out once via `take_waker`.
    waker: Option<Waker>,
    log: KompactLogger,
    /// Address the TCP listener is actually bound to (read back via `local_addr`).
    pub address: SocketAddr,
    lookup: Arc<ArcSwap<ActorStore>>,
    /// Receiving end for the `DispatchEvent`s handled by the network thread.
    input_queue: Recv<DispatchEvent>,
    /// Completed by the network thread when it stops.
    shutdown_promise: KPromise<()>,
    dispatcher_ref: DispatcherRef,
    network_config: NetworkConfig,
    /// TCP listener bound in `new`.
    tcp_listener: TcpListener,
}
71
72impl NetworkThreadBuilder {
73    pub(crate) fn new(
74        log: KompactLogger,
75        address: SocketAddr,
76        lookup: Arc<ArcSwap<ActorStore>>,
77        input_queue: Recv<DispatchEvent>,
78        shutdown_promise: KPromise<()>,
79        dispatcher_ref: DispatcherRef,
80        network_config: NetworkConfig,
81    ) -> Result<NetworkThreadBuilder, NetworkBridgeError> {
82        let poll = Poll::new().expect("failed to create Poll instance in NetworkThread");
83        let waker =
84            Waker::new(poll.registry(), DISPATCHER).expect("failed to create Waker for DISPATCHER");
85        let tcp_listener = bind_with_retries(&address, MAX_BIND_RETRIES, &log)
86            .context(network_bridge_error::IoSnafu)?;
87        let actual_address = tcp_listener
88            .local_addr()
89            .context(network_bridge_error::IoSnafu)?;
90        Ok(NetworkThreadBuilder {
91            poll,
92            tcp_listener,
93            waker: Some(waker),
94            log,
95            address: actual_address,
96            lookup,
97            input_queue,
98            shutdown_promise,
99            dispatcher_ref,
100            network_config,
101        })
102    }
103
104    pub fn take_waker(&mut self) -> Option<Waker> {
105        self.waker.take()
106    }
107
108    pub fn build(mut self) -> NetworkThread {
109        let actual_addr = self
110            .tcp_listener
111            .local_addr()
112            .expect("could not get real addr");
113        let logger = self.log.new(o!("addr" => format!("{}", actual_addr)));
114        let mut udp_socket = UdpSocket::bind(actual_addr).expect("could not bind UDP on TCP port");
115
116        // Register Listeners
117        self.poll
118            .registry()
119            .register(&mut self.tcp_listener, TCP_SERVER, Interest::READABLE)
120            .expect("failed to register TCP SERVER");
121        self.poll
122            .registry()
123            .register(
124                &mut udp_socket,
125                UDP_SOCKET,
126                Interest::READABLE | Interest::WRITABLE,
127            )
128            .expect("failed to register UDP SOCKET");
129
130        let mut buffer_pool = BufferPool::with_config(
131            self.network_config.get_buffer_config(),
132            self.network_config.get_custom_allocator(),
133        );
134        let encode_buffer = EncodeBuffer::with_config(
135            self.network_config.get_buffer_config(),
136            self.network_config.get_custom_allocator(),
137        );
138        let udp_buffer = buffer_pool
139            .get_buffer()
140            .expect("Could not get buffer for setting up UDP");
141        let udp_state = UdpState::new(udp_socket, udp_buffer, logger.clone(), &self.network_config);
142
143        NetworkThread {
144            log: logger,
145            addr: actual_addr,
146            lookup: self.lookup,
147            tcp_listener: self.tcp_listener,
148            udp_state: Some(udp_state),
149            poll: self.poll,
150            address_map: FxHashMap::default(),
151            token_map: LruCache::new(
152                NonZeroUsize::new(self.network_config.get_hard_connection_limit() as usize)
153                    .unwrap(),
154            ),
155            token: START_TOKEN,
156            input_queue: self.input_queue,
157            buffer_pool: RefCell::new(buffer_pool),
158            stopped: false,
159            shutdown_promise: Some(self.shutdown_promise),
160            dispatcher_ref: self.dispatcher_ref,
161            network_config: self.network_config,
162            retry_queue: VecDeque::new(),
163            out_of_buffers: false,
164            encode_buffer,
165            block_list: BlockList::default(), // TODO: extend NetworkConfig to build NetworkThread with a blocklist
166        }
167    }
168}
/// Thread structure responsible for driving the Network IO
pub struct NetworkThread {
    log: KompactLogger,
    /// Address the TCP listener (and the UDP socket) is bound to.
    pub addr: SocketAddr,
    /// Actor store used to resolve the receivers of incoming messages.
    lookup: Arc<ArcSwap<ActorStore>>,
    tcp_listener: TcpListener,
    /// UDP socket state; moved out (leaving `None`) while UDP I/O is being performed.
    udp_state: Option<UdpState>,
    poll: Poll,
    /// TCP channels indexed by remote address.
    address_map: FxHashMap<SocketAddr, Rc<RefCell<TcpChannel>>>,
    /// TCP channels indexed by poll token, kept in least-recently-used order; its capacity is
    /// the hard connection limit.
    token_map: LruCache<Token, Rc<RefCell<TcpChannel>>>,
    /// Token assigned to the next stored channel.
    token: Token,
    input_queue: Recv<DispatchEvent>,
    dispatcher_ref: DispatcherRef,
    /// Pool of network buffers, behind a `RefCell` so it can be used through `&self`.
    buffer_pool: RefCell<BufferPool>,
    /// Set when the thread should stop; `run` checks it after every handled event.
    stopped: bool,
    /// Completed (and cleared) by `run` once `stopped` is set.
    shutdown_promise: Option<KPromise<()>>,
    network_config: NetworkConfig,
    /// Events to handle again in the next iteration of the event loop.
    retry_queue: VecDeque<EventWithRetries>,
    /// Set when buffers ran out; makes `get_poll_timeout` wake up after the retry interval.
    out_of_buffers: bool,
    encode_buffer: EncodeBuffer,
    /// Addresses and networks whose connections are refused.
    block_list: BlockList,
}
191
192impl NetworkThread {
193    pub fn run(mut self) {
194        trace!(self.log, "NetworkThread starting");
195        let mut events = Events::with_capacity(MAX_POLL_EVENTS);
196        loop {
197            self.poll
198                .poll(&mut events, self.get_poll_timeout())
199                .expect("Error when calling Poll");
200
201            for event in events
202                .iter()
203                .map(EventWithRetries::from)
204                .chain(self.retry_queue.split_off(0))
205            {
206                self.handle_event(event);
207
208                if self.stopped {
209                    if let Some(Err(e)) = self
210                        .shutdown_promise
211                        .take()
212                        .map(|promise| promise.complete())
213                    {
214                        error!(self.log, "Error, shutting down sender: {:?}", e);
215                    };
216                    trace!(self.log, "Stopped");
217                    return;
218                };
219            }
220        }
221    }
222
223    fn get_poll_timeout(&self) -> Option<Duration> {
224        if self.out_of_buffers {
225            Some(Duration::from_millis(
226                self.network_config.get_connection_retry_interval(),
227            ))
228        } else if self.retry_queue.is_empty() {
229            None
230        } else {
231            Some(Duration::from_secs(0))
232        }
233    }
234
235    fn handle_event(&mut self, event: EventWithRetries) {
236        match event.token {
237            TCP_SERVER => {
238                if let Err(e) = self.receive_stream() {
239                    error!(self.log, "Error while accepting stream {:?}", e);
240                }
241            }
242            UDP_SOCKET => {
243                if let Some(mut udp_state) = self.udp_state.take() {
244                    if event.writeable {
245                        self.write_udp(&mut udp_state);
246                    }
247                    if event.readable {
248                        self.read_udp(&mut udp_state, event);
249                    }
250                    self.udp_state = Some(udp_state);
251                }
252            }
253            DISPATCHER => {
254                self.receive_dispatch();
255            }
256            _ => {
257                if event.writeable {
258                    self.write_tcp(&event.token);
259                }
260                if event.readable {
261                    self.read_tcp(&event);
262                }
263            }
264        }
265    }
266
267    fn retry_event(&mut self, event: &EventWithRetries) {
268        if event.retries <= self.network_config.get_max_connection_retry_attempts() {
269            self.retry_queue.push_back(event.get_retry_event());
270        } else if let Some(channel) = self.get_channel_by_token(&event.token) {
271            self.lost_connection(channel.borrow_mut());
272        }
273    }
274
275    fn enqueue_writeable_event(&mut self, token: &Token) {
276        self.retry_queue
277            .push_back(EventWithRetries::writeable_with_token(token));
278    }
279
280    fn get_buffer(&self) -> Option<BufferChunk> {
281        self.buffer_pool.borrow_mut().get_buffer()
282    }
283
284    fn return_buffer(&self, buffer: BufferChunk) {
285        self.buffer_pool.borrow_mut().return_buffer(buffer)
286    }
287
288    fn receive_dispatch(&mut self) {
289        while let Ok(event) = self.input_queue.try_recv() {
290            self.handle_dispatch_event(event);
291        }
292    }
293
294    fn handle_dispatch_event(&mut self, event: DispatchEvent) {
295        match event {
296            DispatchEvent::SendTcp(address, data) => {
297                self.send_tcp_message(address, data);
298            }
299            DispatchEvent::SendUdp(address, data) => {
300                self.send_udp_message(address, data);
301            }
302            DispatchEvent::Stop => {
303                self.stop();
304            }
305            DispatchEvent::Kill => {
306                self.kill();
307            }
308            DispatchEvent::Connect(addr) => {
309                if self.block_list.socket_addr_is_blocked(&addr) {
310                    return;
311                }
312                self.request_stream(addr);
313            }
314            DispatchEvent::ClosedAck(addr) => {
315                self.handle_closed_ack(addr);
316            }
317            DispatchEvent::Close(addr) => {
318                self.close_connection(addr);
319            }
320            DispatchEvent::BlockSocket(addr) => {
321                self.block_socket_addr(addr);
322            }
323            DispatchEvent::BlockIpAddr(ip_addr) => {
324                self.block_ip_addr(ip_addr);
325            }
326            DispatchEvent::AllowSocket(addr) => {
327                self.allow_socket_addr(addr);
328            }
329            DispatchEvent::AllowIpAddr(ip_addr) => {
330                self.allow_ip_addr(ip_addr);
331            }
332            DispatchEvent::BlockIpNet(net) => {
333                self.block_ip_net(net);
334            }
335            DispatchEvent::AllowIpNet(net) => {
336                self.allow_ip_net(net);
337            }
338        }
339    }
340
341    fn get_channel_by_token(&mut self, token: &Token) -> Option<Rc<RefCell<TcpChannel>>> {
342        self.token_map.get(token).cloned()
343    }
344
345    fn update_lru(&mut self, token: &Token) {
346        let _ = self.token_map.get(token);
347    }
348
349    fn get_channel_by_address(&self, address: &SocketAddr) -> Option<Rc<RefCell<TcpChannel>>> {
350        self.address_map.get(address).cloned()
351    }
352
353    fn reregister_channel_address(&mut self, old_address: SocketAddr, new_address: SocketAddr) {
354        if let Some(channel_rc) = self.address_map.remove(&old_address) {
355            self.address_map.insert(new_address, channel_rc);
356        }
357    }
358
    /// Reads and handles frames from the TCP channel registered under `event.token`.
    ///
    /// Reading continues until `read_frame` returns `Ok(None)` (presumably: no complete frame is
    /// buffered), a `Start` or `Bye` frame was handled, or an error occurred. Running out of
    /// buffers marks the thread `out_of_buffers` and schedules `event` for a retry; a connection
    /// reset is handled as a lost connection; any other read error is only logged.
    ///
    /// NOTE(review): the channel stays mutably borrowed while handlers run. On `Ack`,
    /// `check_soft_connection_limit` iterates all channels, and it must tolerate this borrow.
    /// After `handle_hello` has dropped a blocked channel, the loop still keeps reading from it.
    fn read_tcp(&mut self, event: &EventWithRetries) {
        if let Some(channel_rc) = self.get_channel_by_token(&event.token) {
            let mut channel = channel_rc.borrow_mut();
            loop {
                match channel.read_frame(&self.buffer_pool) {
                    Ok(None) => {
                        return;
                    }
                    Ok(Some(Frame::Data(data))) => {
                        self.handle_data_frame(
                            data,
                            channel
                                .session_id()
                                .expect("Connected Channel must have a SessionId"),
                        );
                    }
                    Ok(Some(Frame::Start(start))) => {
                        // `handle_start` either drops this channel or re-queues `event` via
                        // `retry_event`, so reading stops here.
                        self.handle_start(event, &mut channel, &start);
                        return;
                    }
                    Ok(Some(Frame::Hello(hello))) => {
                        self.handle_hello(channel.deref_mut(), &hello);
                    }
                    Ok(Some(Frame::Ack())) => {
                        // The peer acknowledged the handshake: the connection is established.
                        self.check_soft_connection_limit();
                        self.notify_network_status(NetworkStatus::ConnectionEstablished(
                            SystemPath::with_socket(Transport::Tcp, channel.address()),
                            channel.session_id().unwrap(),
                        ))
                    }
                    Ok(Some(Frame::Bye())) => {
                        self.handle_bye(&mut channel);
                        return;
                    }
                    Err(e) if no_buffer_space(&e) => {
                        self.out_of_buffers = true;
                        warn!(self.log, "Out of Buffers");
                        // Release the borrow first: once retries are exhausted, `retry_event`
                        // borrows this channel mutably to report it lost.
                        drop(channel);
                        self.retry_event(event);
                        return;
                    }
                    Err(e) if connection_reset(&e) => {
                        warn!(
                            self.log,
                            "Connection lost, reset by peer {}",
                            channel.address()
                        );
                        self.lost_connection(channel);
                        return;
                    }
                    Err(e) => {
                        // Only logged; the channel stays registered.
                        warn!(
                            self.log,
                            "Error reading from channel {}: {}",
                            channel.address(),
                            &e
                        );
                        return;
                    }
                }
            }
        }
    }
422
423    fn read_udp(&mut self, udp_state: &mut UdpState, event: EventWithRetries) {
424        match udp_state.try_read(&self.buffer_pool) {
425            Ok(_) => {}
426            Err(e) if no_buffer_space(&e) => {
427                warn!(
428                    self.log,
429                    "Could not get UDP buffer, retries: {}", event.retries
430                );
431                self.out_of_buffers = true;
432                self.retry_event(&event);
433            }
434            Err(e) => {
435                warn!(self.log, "Error during UDP reading: {}", e);
436            }
437        }
438        while let Some(net_message) = udp_state.incoming_messages.pop_front() {
439            self.deliver_net_message(net_message);
440        }
441    }
442
    /// Drains the outbound queue of the channel registered under `token`.
    ///
    /// A broken pipe is handled as a lost connection. If the drain succeeds while the channel is
    /// in `CloseReceived`, the graceful close is completed: the channel moves to `Closed`, is
    /// deregistered from polling and the token map (its address-map entry is kept), a
    /// `ConnectionClosed` status is sent and `reject_outbound_for_channel` is called. Any other
    /// write error is only logged.
    fn write_tcp(&mut self, token: &Token) {
        if let Some(channel_rc) = self.get_channel_by_token(token) {
            let mut channel = channel_rc.borrow_mut();
            match channel.try_drain() {
                Err(ref err) if broken_pipe(err) => {
                    self.lost_connection(channel);
                }
                Ok(_) => {
                    if let ChannelState::CloseReceived(addr, id) = channel.state {
                        channel.state = ChannelState::Closed(addr, id);
                        debug!(self.log, "Connection to {} shutdown gracefully", &addr);
                        // Only the token/poll registration is removed here; the address-map
                        // entry is removed by `drop_channel` (e.g. on `ClosedAck`).
                        self.deregister_channel(channel.deref_mut());
                        self.notify_network_status(NetworkStatus::ConnectionClosed(
                            SystemPath::with_socket(Transport::Tcp, channel.address()),
                            id,
                        ));
                        self.reject_outbound_for_channel(&mut channel);
                    }
                }
                Err(e) => {
                    warn!(
                        self.log,
                        "Unhandled error while writing to {}\n{:?}",
                        channel.address(),
                        e
                    );
                }
            }
        }
    }
473
474    fn write_udp(&mut self, udp_state: &mut UdpState) {
475        match udp_state.try_write() {
476            Ok(_) => {}
477            Err(e) => {
478                warn!(self.log, "Error during UDP sending: {}", e);
479            }
480        }
481    }
482
483    fn send_tcp_message(&mut self, address: SocketAddr, data: DispatchData) {
484        if let Some(channel_rc) = self.get_channel_by_address(&address) {
485            let mut channel = channel_rc.borrow_mut();
486            self.update_lru(&channel.token);
487            if channel.connected() {
488                match self.serialise_dispatch_data(data) {
489                    Ok(frame) => {
490                        channel.enqueue_serialised(frame);
491                        self.enqueue_writeable_event(&channel.token);
492                    }
493                    Err(e) if out_of_buffers(&e) => {
494                        self.out_of_buffers = true;
495                        warn!(
496                            self.log,
497                            "No network buffers available, dropping outbound message.\
498                        slow down message rate or increase buffer limits."
499                        );
500                    }
501                    Err(e) => {
502                        error!(self.log, "Error serialising message {}", e);
503                    }
504                }
505            } else {
506                trace!(
507                    self.log,
508                    "Dispatch trying to route to non connected channel {:?}, rejecting the message",
509                    channel
510                );
511                self.reject_dispatch_data(address, data);
512            }
513        } else {
514            trace!(
515                self.log,
516                "Dispatch trying to route to unrecognized address {}, rejecting the message",
517                address
518            );
519            self.reject_dispatch_data(address, data);
520        }
521    }
522
523    fn send_udp_message(&mut self, address: SocketAddr, data: DispatchData) {
524        if let Some(mut udp_state) = self.udp_state.take() {
525            match self.serialise_dispatch_data(data) {
526                Ok(frame) => {
527                    udp_state.enqueue_serialised(address, frame);
528                    match udp_state.try_write() {
529                        Ok(_) => {}
530                        Err(e) => {
531                            warn!(self.log, "Error during UDP sending: {}", e);
532                            debug!(self.log, "UDP error debug info: {:?}", e);
533                        }
534                    }
535                }
536                Err(e) if out_of_buffers(&e) => {
537                    self.out_of_buffers = true;
538                    warn!(
539                        self.log,
540                        "No network buffers available, dropping outbound message.\
541                        slow down message rate or increase buffer limits."
542                    );
543                }
544                Err(e) => {
545                    error!(self.log, "Error serialising message {}", e);
546                }
547            }
548            self.udp_state = Some(udp_state);
549        } else {
550            self.reject_dispatch_data(address, data);
551            trace!(
552                self.log,
553                "Rejecting UDP message to {} as socket is already shut down.", address
554            );
555        }
556    }
557
558    fn handle_data_frame(&self, data: Data, session: SessionId) {
559        let buf = data.payload();
560        let mut envelope = deserialise_chunk_lease(buf).expect("s11n errors");
561        envelope.set_session(session);
562        self.deliver_net_message(envelope);
563    }
564
565    fn deliver_net_message(&self, envelope: NetMessage) {
566        let lease_lookup = self.lookup.load();
567        match lease_lookup.get_by_actor_path(&envelope.receiver) {
568            LookupResult::Ref(actor) => {
569                actor.enqueue(envelope);
570            }
571            LookupResult::Group(group) => {
572                group.route(envelope, &self.log);
573            }
574            LookupResult::None => {
575                warn!(
576                    self.log,
577                    "Could not find actor reference for destination: {:?}, dropping message",
578                    envelope.receiver
579                );
580            }
581            LookupResult::Err(e) => {
582                error!(
583                    self.log,
584                    "An error occurred during local actor lookup for destination: {:?}, dropping message. The error was: {}",
585                    envelope.receiver,
586                    e
587                );
588            }
589        }
590    }
591
592    fn handle_hello(&mut self, channel: &mut TcpChannel, hello: &Hello) {
593        if self.block_list.socket_addr_is_blocked(&hello.addr) {
594            self.drop_channel(channel);
595        } else {
596            self.reregister_channel_address(channel.address(), hello.addr);
597            channel.handle_hello(hello);
598        }
599    }
600
    /// During channel initialisation the three-way handshake to establish connections
    /// culminates with this function.
    /// The Start(remote_addr, id) is received by the host on the receiving end of the channel
    /// initialisation, and the decision which channel to keep is made here and now:
    /// - If the announced address is blocked, this channel is dropped.
    /// - If no other channel is registered for the announced address, this channel is started.
    /// - If another channel is registered for it:
    ///     - it is `Connected` and its `messages` count is 0: this channel is dropped;
    ///     - it is `Connected` with a non-zero `messages` count: the other channel is handled as
    ///       a lost connection and this channel is started;
    ///     - it is `Requested` or `Initialised` with an id greater than `start.id`: this channel
    ///       is dropped (the ids act as the tie breaker);
    ///     - in any other case the other channel is dropped and this channel is started.
    ///
    /// Starting the channel re-indexes it under the announced address, calls
    /// `TcpChannel::handle_start`, re-queues `event`, checks the soft connection limit and sends
    /// a `ConnectionEstablished` network status.
    ///
    /// NOTE(review): `channel` is borrowed from its `RefCell` by the caller (`read_tcp`). If
    /// `event` has exhausted its retries, `retry_event` borrows the channel for the same token
    /// mutably again — confirm that case cannot occur.
    fn handle_start(&mut self, event: &EventWithRetries, channel: &mut TcpChannel, start: &Start) {
        if self.block_list.socket_addr_is_blocked(&start.addr) {
            self.drop_channel(channel);
            return;
        }
        if let Some(other_channel_rc) = self.get_channel_by_address(&start.addr) {
            let mut other_channel = other_channel_rc.borrow_mut();
            debug!(
                self.log,
                "Merging channels {:?} and {:?}", channel, other_channel
            );
            match other_channel.read_state() {
                ChannelState::Connected(_, _) => {
                    if other_channel.messages == 0 {
                        self.drop_channel(channel);
                        return;
                    } else {
                        self.lost_connection(other_channel);
                    }
                }
                // Tie breaker: the channel with the larger id survives.
                ChannelState::Requested(_, other_id) if other_id.as_u128() > start.id.as_u128() => {
                    self.drop_channel(channel);
                    return;
                }
                ChannelState::Initialised(_, other_id)
                    if other_id.as_u128() > start.id.as_u128() =>
                {
                    self.drop_channel(channel);
                    return;
                }
                _ => {
                    self.drop_channel(other_channel.deref_mut());
                }
            }
        }
        self.reregister_channel_address(channel.address(), start.addr);
        channel.handle_start(start);
        // `read_tcp` stops reading after a Start frame; re-queue the event so the channel is
        // serviced again on the next loop iteration.
        self.retry_event(event);
        self.check_soft_connection_limit();
        self.notify_network_status(NetworkStatus::ConnectionEstablished(
            SystemPath::with_socket(Transport::Tcp, start.addr),
            start.id,
        ));
    }
653
654    fn handle_bye(&mut self, channel: &mut TcpChannel) {
655        match channel.state {
656            ChannelState::Closed(addr, id) => {
657                debug!(self.log, "Connection shutdown gracefully");
658                self.deregister_channel(channel);
659                self.notify_network_status(NetworkStatus::ConnectionClosed(
660                    SystemPath::with_socket(Transport::Tcp, addr),
661                    id,
662                ));
663                self.reject_outbound_for_channel(channel);
664            }
665            ChannelState::CloseReceived(_, _) => {}
666            _ => {
667                self.drop_channel(channel);
668            }
669        }
670    }
671
672    fn handle_closed_ack(&mut self, address: SocketAddr) {
673        if let Some(channel_rc) = self.get_channel_by_address(&address) {
674            let mut channel = channel_rc.borrow_mut();
675            if let ChannelState::Connected(_, _) = channel.state {
676                error!(self.log, "ClosedAck for connected Channel: {:#?}", &channel);
677            } else {
678                self.drop_channel(&mut channel)
679            }
680        } else {
681            error!(
682                self.log,
683                "ClosedAck for unrecognized address: {:#?}", &address
684            );
685        }
686    }
687
688    fn drop_channel(&mut self, channel: &mut TcpChannel) {
689        self.deregister_channel(channel);
690        self.address_map.remove(&channel.address());
691        channel.shutdown();
692        let mut buffer = BufferChunk::new(0);
693        channel.swap_buffer(&mut buffer);
694        self.return_buffer(buffer);
695    }
696
697    fn deregister_channel(&mut self, channel: &mut TcpChannel) {
698        let _ = self.poll.registry().deregister(channel.stream_mut());
699        self.token_map.pop(&channel.token);
700    }
701
    /// Opens an outbound TCP connection to `address`, unless a usable channel already exists.
    ///
    /// If a channel for `address` exists and is `Connected`, or `Closed` (still waiting for the
    /// `ClosedAck`), nothing happens. A channel in any other state is dropped and replaced. The
    /// request is refused when the hard connection limit is reached.
    ///
    /// When no buffer is available, the thread is only marked `out_of_buffers`. A failed
    /// `connect` returns the buffer and is only logged. On success the stream is stored as a
    /// `Requested` channel with a fresh `SessionId`.
    fn request_stream(&mut self, address: SocketAddr) {
        if let Some(channel_rc) = self.get_channel_by_address(&address) {
            let mut channel = channel_rc.borrow_mut();
            match channel.state {
                ChannelState::Connected(_, _) => {
                    debug!(
                        self.log,
                        "Asked to request connection to already connected host {}", &address
                    );
                    return;
                }
                ChannelState::Closed(_, _) => {
                    debug!(
                        self.log,
                        "Requested connection to host before receiving ClosedAck {}", &address
                    );
                    return;
                }
                _ => {
                    // Replace any channel that is not (yet) usable with a fresh request.
                    self.drop_channel(&mut channel);
                }
            }
        }
        if self.check_hard_connection_limit() {
            warn!(
                self.log,
                "Hard Connection limit reached, rejecting request to connect to remote \
                host {}",
                &address
            );
            return;
        }
        if let Some(buffer) = self.get_buffer() {
            trace!(self.log, "Requesting connection to {}", &address);
            match TcpStream::connect(address) {
                Ok(stream) => {
                    self.store_stream(
                        stream,
                        address,
                        ChannelState::Requested(address, SessionId::new_unique()),
                        buffer,
                    );
                }
                Err(e) => {
                    //  Connection will be re-requested
                    trace!(
                        self.log,
                        "Failed to connect to remote host {}, error: {:?}", &address, e
                    );
                    self.return_buffer(buffer);
                }
            }
        } else {
            self.out_of_buffers = true;
            trace!(
                self.log,
                "No Buffers available when attempting to connect to remote host {}", &address
            );
        }
    }
762
763    fn receive_stream(&mut self) -> io::Result<()> {
764        while let Ok((stream, address)) = self.tcp_listener.accept() {
765            if self.block_list.ip_addr_is_blocked(&address.ip()) {
766                trace!(
767                    self.log,
768                    "Rejecting connection request from blocked source: {}", &address
769                );
770                stream.shutdown(Shutdown::Both)?;
771            } else if self.check_hard_connection_limit() {
772                warn!(
773                    self.log,
774                    "Hard Connection limit reached, rejecting incoming connection \
775                request from {}",
776                    &address
777                );
778                stream.shutdown(Shutdown::Both)?;
779            } else if let Some(buffer) = self.get_buffer() {
780                trace!(self.log, "Accepting connection from {}", &address);
781                self.store_stream(stream, address, ChannelState::Initialising, buffer);
782            } else {
783                warn!(
784                    self.log,
785                    "Network Thread out of buffers, rejecting incoming connection \
786                request from {}",
787                    &address
788                );
789                stream.shutdown(Shutdown::Both)?;
790            }
791        }
792        Ok(())
793    }
794
795    fn store_stream(
796        &mut self,
797        stream: TcpStream,
798        address: SocketAddr,
799        state: ChannelState,
800        buffer: BufferChunk,
801    ) {
802        let mut channel = TcpChannel::new(
803            stream,
804            self.token,
805            address,
806            buffer,
807            state,
808            self.addr,
809            &self.network_config,
810        );
811        channel.initialise(&self.addr);
812        if let Err(e) = self.poll.registry().register(
813            channel.stream_mut(),
814            self.token,
815            Interest::READABLE | Interest::WRITABLE,
816        ) {
817            error!(
818                self.log,
819                "Failed to register polling for {}\n{:?}", address, e
820            );
821        }
822        let rc = Rc::new(RefCell::new(channel));
823        self.address_map.insert(address, rc.clone());
824        self.token_map.put(self.token, rc);
825        self.next_token();
826    }
827
828    /// Checks the current active channel count and initiates a graceful shutdown of the LRU channel.
829    fn check_soft_connection_limit(&mut self) {
830        let limit = self.network_config.get_soft_connection_limit() as usize;
831        // First condition allows for early returns without doing the count
832        if self.token_map.len() > limit && self.count_active_connections() > limit {
833            // Find the LRU ACTIVE connection
834            for (_, channel) in self.token_map.iter().rev() {
835                if channel.borrow().connected() {
836                    let addr = channel.borrow().address();
837                    warn!(
838                        self.log,
839                        "Soft Connection Limit exceeded! limit = {}. Closing channel {:?}",
840                        self.network_config.get_soft_connection_limit(),
841                        &channel.borrow(),
842                    );
843                    self.notify_network_status(NetworkStatus::SoftConnectionLimitExceeded);
844                    self.close_connection(addr);
845                    return;
846                }
847            }
848        }
849    }
850
851    /// Returns true if the limit is reached, and notifies the NetworkDispatcher that it is reached.
852    fn check_hard_connection_limit(&self) -> bool {
853        if self.token_map.len() >= self.network_config.get_hard_connection_limit() as usize {
854            self.notify_network_status(NetworkStatus::HardConnectionLimitReached);
855            true
856        } else {
857            false
858        }
859    }
860
861    fn count_active_connections(&self) -> usize {
862        self.token_map
863            .iter()
864            .filter(|(_, connection)| {
865                if let Ok(con) = connection.try_borrow() {
866                    con.connected()
867                } else {
868                    true
869                } // If we fail to borrow the connection it's active.
870            })
871            .count()
872    }
873
    /// Closes the channel to `addr`, if there is one.
    ///
    /// If the channel is connected, a graceful closing sequence is initiated and `update_lru`
    /// is called with the channel's token. If it is not connected, it is dropped immediately
    /// via `drop_channel`. Addresses without a channel are ignored.
    fn close_connection(&mut self, addr: SocketAddr) {
        if let Some(channel) = self.get_channel_by_address(&addr) {
            let mut channel_mut = channel.borrow_mut();
            if channel_mut.connected() {
                channel_mut.initiate_graceful_shutdown();
                self.update_lru(&channel_mut.token);
            } else {
                self.drop_channel(channel_mut.deref_mut());
            }
        }
    }
886
    /// Handles all logic necessary to shut down a channel whose connection has been lost.
    ///
    /// Steps, in order:
    /// 1. If the channel has a session id, notify the dispatcher with `ConnectionLost`.
    /// 2. Reject every outbound frame still queued on the channel back to the dispatcher.
    /// 3. Send a best-effort `Bye` (its result is ignored).
    /// 4. Deregister the channel and shut it down.
    fn lost_connection(&mut self, mut channel: RefMut<TcpChannel>) {
        trace!(self.log, "Lost connection to address {}", channel.address());
        if let Some(id) = channel.session_id() {
            self.notify_network_status(NetworkStatus::ConnectionLost(
                SystemPath::with_socket(Transport::Tcp, channel.address()),
                id,
            ));
        }
        // Queued frames can no longer be delivered on this channel; hand them back.
        self.reject_outbound_for_channel(&mut channel);
        // Try to inform the other end that we're closing the channel
        let _ = channel.send_bye();
        self.deregister_channel(channel.deref_mut());
        channel.shutdown();
    }
902
903    fn reject_outbound_for_channel(&mut self, channel: &mut TcpChannel) {
904        for rejected_frame in channel.take_outbound() {
905            self.reject_dispatch_data(channel.address(), DispatchData::Serialised(rejected_frame));
906        }
907    }
908
    /// Stops the network thread in an orderly fashion.
    ///
    /// Steps, in order:
    /// 1. Start a graceful shutdown on every channel in `address_map`, removing it from both
    ///    `address_map` and `token_map`. Channels present only in `token_map` are not visited.
    /// 2. Deregister the TCP listener from the poller; a failure here panics.
    /// 3. Deregister the UDP socket, if any (errors ignored), and drop it.
    /// 4. Set `stopped`, so that `Drop` skips its crash-time channel cleanup.
    fn stop(&mut self) {
        for (_, channel_rc) in self.address_map.drain() {
            let mut channel = channel_rc.borrow_mut();
            debug!(
                self.log,
                "Stopping channel with message count {}", channel.messages
            );
            channel.initiate_graceful_shutdown();
            self.token_map.pop(&channel.token);
        }
        self.poll
            .registry()
            .deregister(&mut self.tcp_listener)
            .expect("Deregistering listener while stopping network should work");
        if let Some(mut udp_state) = self.udp_state.take() {
            // Deregistration is best-effort; the socket is dropped right after.
            self.poll.registry().deregister(&mut udp_state.socket).ok();
            let count = udp_state.pending_messages();
            drop(udp_state);
            debug!(
                self.log,
                "Dropped its UDP socket with message count {}", count
            );
        }
        self.stopped = true;
    }
934
935    fn kill(&mut self) {
936        trace!(self.log, "Killing the NetworkThread");
937        for (_, channel_rc) in self.address_map.drain() {
938            channel_rc.borrow_mut().kill();
939        }
940        self.stop();
941    }
942
943    fn notify_network_status(&self, status: NetworkStatus) {
944        self.dispatcher_ref.tell(DispatchEnvelope::Event(Box::new(
945            NetworkDispatcherEvent::Network(status),
946        )))
947    }
948
949    fn reject_dispatch_data(&self, address: SocketAddr, data: DispatchData) {
950        self.dispatcher_ref.tell(DispatchEnvelope::Event(Box::new(
951            NetworkDispatcherEvent::RejectedData((address, Box::new(data))),
952        )));
953    }
954
955    fn next_token(&mut self) {
956        let next = self.token.0 + 1;
957        self.token = Token(next);
958    }
959
960    fn serialise_dispatch_data(&mut self, data: DispatchData) -> Result<SerialisedFrame, SerError> {
961        match data {
962            DispatchData::Serialised(frame) => Ok(frame),
963            _ => data.into_serialised(&mut self.encode_buffer.get_buffer_encoder()?),
964        }
965    }
966
967    fn block_ip_addr(&mut self, ip_addr: IpAddr) {
968        if self.block_list.block_ip_addr(ip_addr) {
969            debug!(self.log, "Blocking ip: {:?}", ip_addr);
970            // Drop all the open channels
971            let blocked_sockets: Vec<SocketAddr> = self
972                .address_map
973                .keys()
974                .filter(|socket_addr| {
975                    socket_addr.ip() == ip_addr
976                        && self.block_list.socket_addr_is_blocked(socket_addr)
977                })
978                .copied()
979                .collect();
980            for socket_addr in blocked_sockets {
981                if let Some(channel_rc) = self.get_channel_by_address(&socket_addr) {
982                    debug!(
983                        self.log,
984                        "Dropping channel to blocked socket: {:?}", socket_addr
985                    );
986                    let mut channel = channel_rc.borrow_mut();
987                    if channel.connected() {
988                        self.notify_network_status(NetworkStatus::ConnectionDropped(
989                            SystemPath::with_socket(Transport::Tcp, socket_addr),
990                        ));
991                    }
992                    self.drop_channel(&mut channel);
993                }
994            }
995        }
996        self.notify_network_status(NetworkStatus::BlockedIp(ip_addr));
997    }
998
999    fn block_socket_addr(&mut self, socket_addr: SocketAddr) {
1000        if self.block_list.block_socket_addr(socket_addr) {
1001            debug!(self.log, "Blocking socket: {:?}", socket_addr);
1002            if let Some(channel_rc) = self.get_channel_by_address(&socket_addr) {
1003                debug!(
1004                    self.log,
1005                    "Dropping channel to blocked socket: {:?}", socket_addr
1006                );
1007                let mut channel = channel_rc.borrow_mut();
1008                self.drop_channel(&mut channel);
1009            }
1010        }
1011        self.notify_network_status(NetworkStatus::BlockedSystem(SystemPath::with_socket(
1012            Transport::Tcp,
1013            socket_addr,
1014        )));
1015    }
1016
1017    fn allow_ip_addr(&mut self, ip_addr: IpAddr) {
1018        debug!(self.log, "Allowing ip: {:?}", ip_addr);
1019        self.block_list.allow_ip_addr(&ip_addr);
1020        self.notify_network_status(NetworkStatus::AllowedIp(ip_addr));
1021    }
1022
1023    fn allow_socket_addr(&mut self, socket_addr: SocketAddr) {
1024        if self.block_list.allow_socket_addr(&socket_addr) {
1025            debug!(self.log, "Allowing socket: {:?}", socket_addr);
1026            self.notify_network_status(NetworkStatus::AllowedSystem(SystemPath::with_socket(
1027                Transport::Tcp,
1028                socket_addr,
1029            )));
1030        }
1031    }
1032
1033    fn block_ip_net(&mut self, ip_net: IpNet) {
1034        self.block_list.block_ip_net(ip_net);
1035        debug!(self.log, "Blocking IpNet: {:?}", &ip_net);
1036        // Drop all the open channels
1037        let blocked_sockets: Vec<SocketAddr> = self
1038            .address_map
1039            .keys()
1040            .filter(|socket_addr| {
1041                ip_net.contains(&socket_addr.ip())
1042                    && self.block_list.socket_addr_is_blocked(socket_addr)
1043            })
1044            .copied()
1045            .collect();
1046        for socket_addr in blocked_sockets {
1047            if let Some(channel_rc) = self.get_channel_by_address(&socket_addr) {
1048                debug!(
1049                    self.log,
1050                    "Dropping channel to blocked socket: {:?}", socket_addr
1051                );
1052                let mut channel = channel_rc.borrow_mut();
1053                if channel.connected() {
1054                    self.notify_network_status(NetworkStatus::ConnectionDropped(
1055                        SystemPath::with_socket(Transport::Tcp, socket_addr),
1056                    ));
1057                }
1058                self.drop_channel(&mut channel);
1059            }
1060        }
1061        self.notify_network_status(NetworkStatus::BlockedIpNet(ip_net));
1062    }
1063
1064    fn allow_ip_net(&mut self, ip_net: IpNet) {
1065        self.block_list.allow_ip_net(ip_net);
1066        debug!(self.log, "Allowing IpNet: {:?}", &ip_net);
1067        self.notify_network_status(NetworkStatus::AllowedIpNet(ip_net));
1068    }
1069}
1070
1071impl std::ops::Drop for NetworkThread {
1072    fn drop(&mut self) {
1073        // Ensure that the channels are shutdown and buffers are deallocated on panic
1074        if !self.stopped {
1075            while let Some((_, channel)) = self.token_map.pop_lru() {
1076                trace!(self.log, "Dropping channel in crashed NetworkThread");
1077                self.drop_channel(channel.borrow_mut().deref_mut());
1078            }
1079        }
1080    }
1081}
1082
1083fn bind_with_retries(
1084    addr: &SocketAddr,
1085    retries: usize,
1086    log: &KompactLogger,
1087) -> io::Result<TcpListener> {
1088    match TcpListener::bind(*addr) {
1089        Ok(listener) => Ok(listener),
1090        Err(e) => {
1091            if retries > 0 {
1092                debug!(
1093                    log,
1094                    "Failed to bind to addr {}, will retry {} more times, error was: {:?}",
1095                    addr,
1096                    retries,
1097                    e
1098                );
1099                // Lets give cleanup some time to do it's thing before we retry
1100                thread::sleep(Duration::from_millis(BIND_RETRY_INTERVAL));
1101                bind_with_retries(addr, retries - 1, log)
1102            } else {
1103                Err(e)
1104            }
1105        }
1106    }
1107}
1108
1109#[derive(Clone)]
1110struct EventWithRetries {
1111    token: Token,
1112    readable: bool,
1113    writeable: bool,
1114    retries: u8,
1115}
1116impl EventWithRetries {
1117    fn from(event: &Event) -> EventWithRetries {
1118        EventWithRetries {
1119            token: event.token(),
1120            readable: event.is_readable(),
1121            writeable: event.is_writable(),
1122            retries: 0,
1123        }
1124    }
1125
1126    fn writeable_with_token(token: &Token) -> EventWithRetries {
1127        EventWithRetries {
1128            token: *token,
1129            readable: false,
1130            writeable: true,
1131            retries: 0,
1132        }
1133    }
1134
1135    fn get_retry_event(&self) -> EventWithRetries {
1136        EventWithRetries {
1137            token: self.token,
1138            readable: self.readable,
1139            writeable: self.writeable,
1140            retries: self.retries + 1,
1141        }
1142    }
1143}
1144
1145#[derive(Default)]
1146pub struct BlockList {
1147    ipv4_set: IpRange<Ipv4Net>,
1148    ipv6_set: IpRange<Ipv6Net>,
1149    blocked_socket_addr: FxHashSet<SocketAddr>,
1150    allowed_socket_addr: FxHashSet<SocketAddr>,
1151}
1152
1153impl BlockList {
1154    /// Returns true if the rule-set has been modified
1155    fn block_ip_addr(&mut self, ip_addr: IpAddr) -> bool {
1156        match ip_addr {
1157            IpAddr::V4(addr) => {
1158                if self.ipv4_set.contains(&addr.to_network()) {
1159                    return false;
1160                }
1161                self.ipv4_set.add(addr.to_network());
1162            }
1163            IpAddr::V6(addr) => {
1164                if self.ipv6_set.contains(&addr.to_network()) {
1165                    return false;
1166                }
1167                self.ipv6_set.add(addr.to_network());
1168            }
1169        }
1170        true
1171    }
1172
1173    fn block_ip_net(&mut self, ip_net: IpNet) {
1174        match ip_net {
1175            IpNet::V4(net) => {
1176                self.ipv4_set.add(net);
1177            }
1178            IpNet::V6(net) => {
1179                self.ipv6_set.add(net);
1180            }
1181        }
1182    }
1183
1184    fn allow_ip_net(&mut self, ip_net: IpNet) {
1185        match ip_net {
1186            IpNet::V4(net) => {
1187                self.ipv4_set.remove(net);
1188            }
1189            IpNet::V6(net) => {
1190                self.ipv6_set.remove(net);
1191            }
1192        }
1193    }
1194
1195    /// Returns true if the rule-set has been modified
1196    fn allow_ip_addr(&mut self, ip_addr: &IpAddr) -> bool {
1197        match ip_addr {
1198            IpAddr::V4(addr) => {
1199                if self.ipv4_set.contains(&addr.to_network()) {
1200                    self.ipv4_set.remove(addr.to_network());
1201                    return true;
1202                }
1203            }
1204            IpAddr::V6(addr) => {
1205                if self.ipv6_set.contains(&addr.to_network()) {
1206                    self.ipv6_set.remove(addr.to_network());
1207                    return true;
1208                }
1209            }
1210        }
1211        false
1212    }
1213
1214    /// Returns true if the rule-set has been modified
1215    fn block_socket_addr(&mut self, socket_addr: SocketAddr) -> bool {
1216        self.allowed_socket_addr.remove(&socket_addr)
1217            || self.blocked_socket_addr.insert(socket_addr)
1218    }
1219
1220    /// Returns true if the rule-set has been modified
1221    fn allow_socket_addr(&mut self, socket_addr: &SocketAddr) -> bool {
1222        self.blocked_socket_addr.remove(socket_addr)
1223            || self.allowed_socket_addr.insert(*socket_addr)
1224    }
1225
1226    /// Returns true if the IpAddr is fully blocked, i.e. it's Blocked and there's no Allowed SocketAddr with the given IP
1227    fn ip_addr_is_blocked(&self, ip_addr: &IpAddr) -> bool {
1228        if self.ip_sets_contains_ip_addr(ip_addr) {
1229            // The IP may be partially blocked
1230            !self
1231                .allowed_socket_addr
1232                .iter()
1233                .any(|socket_addr| socket_addr.ip() == *ip_addr)
1234        } else {
1235            // The IP isn't Blocked at all, no need to check the Socket address list
1236            false
1237        }
1238    }
1239
1240    /// Returns true if the SocketAddr is blocked
1241    fn socket_addr_is_blocked(&self, socket_addr: &SocketAddr) -> bool {
1242        if self.allowed_socket_addr.contains(socket_addr) {
1243            false
1244        } else if self.blocked_socket_addr.contains(socket_addr) {
1245            true
1246        } else {
1247            self.ip_sets_contains_ip_addr(&socket_addr.ip())
1248        }
1249    }
1250
1251    fn ip_sets_contains_ip_addr(&self, ip_addr: &IpAddr) -> bool {
1252        match ip_addr {
1253            IpAddr::V4(addr) => self.ipv4_set.contains(&addr.to_network()),
1254            IpAddr::V6(addr) => self.ipv6_set.contains(&addr.to_network()),
1255        }
1256    }
1257}
1258
1259#[cfg(test)]
1260#[allow(unused_must_use)]
1261mod tests {
1262    use super::*;
1263    use crate::{dispatch::NetworkConfig, net::buffers::BufferConfig};
1264    use std::str::FromStr;
1265
1266    // Cleaner test-cases for manually running the thread
1267    fn poll_and_handle(thread: &mut NetworkThread) {
1268        let mut events = Events::with_capacity(10);
1269        thread
1270            .poll
1271            .poll(&mut events, Some(Duration::from_millis(100)));
1272        for event in events.iter() {
1273            thread.handle_event(EventWithRetries::from(event));
1274        }
1275        while let Some(event) = thread.retry_queue.pop_front() {
1276            thread.handle_event(event);
1277        }
1278    }
1279
1280    #[allow(unused_must_use)]
1281    fn setup_network_thread(
1282        network_config: &NetworkConfig,
1283    ) -> (NetworkThread, Sender<DispatchEvent>) {
1284        let mut cfg = kompact::test_support::test_kompact_config();
1285        cfg.system_components(DeadletterBox::new, network_config.clone().build());
1286        let system = cfg.build().wait().expect("KompactSystem");
1287
1288        // Set-up the the threads arguments
1289        let lookup = Arc::new(ArcSwap::from_pointee(ActorStore::new()));
1290        //network_thread_registration.set_readiness(Interest::empty());
1291        let (input_queue_sender, input_queue_receiver) = channel();
1292        let (dispatch_shutdown_sender, _) = promise();
1293        let logger = system.logger().clone();
1294        let dispatcher_ref = system.dispatcher_ref();
1295
1296        let network_thread = NetworkThreadBuilder::new(
1297            logger,
1298            "127.0.0.1:0".parse().expect("Address should work"),
1299            lookup,
1300            input_queue_receiver,
1301            dispatch_shutdown_sender,
1302            dispatcher_ref,
1303            network_config.clone(),
1304        )
1305        .expect("Should work")
1306        .build();
1307        (network_thread, input_queue_sender)
1308    }
1309
1310    fn run_handshake_sequence(requester: &mut NetworkThread, acceptor: &mut NetworkThread) {
1311        requester.receive_dispatch();
1312        thread::sleep(Duration::from_millis(100));
1313        poll_and_handle(acceptor);
1314        thread::sleep(Duration::from_millis(100));
1315        poll_and_handle(requester);
1316        thread::sleep(Duration::from_millis(100));
1317        poll_and_handle(acceptor);
1318        thread::sleep(Duration::from_millis(100));
1319        poll_and_handle(requester);
1320        thread::sleep(Duration::from_millis(100));
1321        poll_and_handle(acceptor);
1322        thread::sleep(Duration::from_millis(100));
1323    }
1324
1325    const PATH: &str = "local://127.0.0.1:0/test_actor";
1326
1327    fn empty_message() -> DispatchData {
1328        let path = ActorPath::from_str(PATH).expect("a proper path");
1329        DispatchData::Lazy(Box::new(()), path.clone(), path)
1330    }
1331
    // Both threads request a connection to each other at the same time, which yields two TCP
    // connections between the same pair of addresses. After a few poll rounds each side must
    // be left with exactly one channel, and it must be the same TCP connection on both ends
    // (thread1's local address equals thread2's peer address). The sleeps give the loopback
    // traffic time to arrive between steps.
    #[test]
    fn merge_connections_basic() {
        // Sets up two NetworkThreads and does mutual connection request
        let (mut thread1, input_queue_1_sender) = setup_network_thread(&NetworkConfig::default());
        let (mut thread2, input_queue_2_sender) = setup_network_thread(&NetworkConfig::default());
        let addr1 = thread1.addr;
        let addr2 = thread2.addr;
        // Tell both to connect to each-other before they start running:
        input_queue_1_sender.send(DispatchEvent::Connect(addr2));
        input_queue_2_sender.send(DispatchEvent::Connect(addr1));

        // Let both handle the connect event:
        thread1.receive_dispatch();
        thread2.receive_dispatch();

        // Wait for the connect requests to reach destination:
        thread::sleep(Duration::from_millis(100));

        // Accept requested streams
        thread1.receive_stream();
        thread2.receive_stream();

        // Wait for Hello to reach destination:
        thread::sleep(Duration::from_millis(100));

        // We need to make sure the TCP buffers are actually flushing the messages.
        // Handle events on both ends, say hello:
        poll_and_handle(&mut thread1);
        poll_and_handle(&mut thread2);
        thread::sleep(Duration::from_millis(100));
        // Cycle two Requested channels
        poll_and_handle(&mut thread1);
        poll_and_handle(&mut thread2);
        thread::sleep(Duration::from_millis(100));
        // Cycle three, merge and close
        poll_and_handle(&mut thread1);
        poll_and_handle(&mut thread2);
        thread::sleep(Duration::from_millis(100));
        // Cycle four, receive close and close
        poll_and_handle(&mut thread1);
        poll_and_handle(&mut thread2);
        thread::sleep(Duration::from_millis(100));
        // Now we can inspect the Network channels, both only have one channel:
        assert_eq!(thread1.address_map.len(), 1);
        assert_eq!(thread2.address_map.len(), 1);

        // Now assert that they've kept the same channel:
        assert_eq!(
            thread1
                .address_map
                .drain()
                .next()
                .unwrap()
                .1
                .borrow_mut()
                .stream()
                .local_addr()
                .unwrap(),
            thread2
                .address_map
                .drain()
                .next()
                .unwrap()
                .1
                .borrow_mut()
                .stream()
                .peer_addr()
                .unwrap()
        );
    }
1402
    // Variant of `merge_connections_basic`. Thread 1 only requests its own connection to
    // thread 2 after it has already accepted thread 2's connection and exchanged Hello on it.
    // Both sides must still converge on a single, shared TCP connection.
    #[test]
    fn merge_connections_tricky() {
        // Sets up two NetworkThreads and does mutual connection request
        // This test uses a different order of events than basic
        let (mut thread1, input_queue_1_sender) = setup_network_thread(&NetworkConfig::default());
        let (mut thread2, input_queue_2_sender) = setup_network_thread(&NetworkConfig::default());
        let addr1 = thread1.addr;
        let addr2 = thread2.addr;
        // 2 Requests connection to 1 and sends Hello
        input_queue_2_sender.send(DispatchEvent::Connect(addr1));
        thread2.receive_dispatch();
        thread::sleep(Duration::from_millis(100));

        // 1 accepts the connection and sends hello back
        thread1.receive_stream();
        thread::sleep(Duration::from_millis(100));
        // 2 receives the Hello
        poll_and_handle(&mut thread2);
        thread::sleep(Duration::from_millis(100));
        // 1 Receives Hello
        poll_and_handle(&mut thread1);

        // 1 Receives Request Connection Event, this is the tricky part
        // 1 Requests connection to 2 and sends Hello
        input_queue_1_sender.send(DispatchEvent::Connect(addr2));
        thread1.receive_dispatch();
        thread::sleep(Duration::from_millis(100));

        // 2 accepts the connection and replies with hello
        thread2.receive_stream();
        thread::sleep(Duration::from_millis(100));

        // 2 receives the Hello on the new channel and merges
        poll_and_handle(&mut thread2);
        thread::sleep(Duration::from_millis(100));

        // 1 receives the Hello on the new channel and merges
        poll_and_handle(&mut thread1);
        thread::sleep(Duration::from_millis(100));

        // 2 receives the Bye and the Ack.
        poll_and_handle(&mut thread2);
        thread::sleep(Duration::from_millis(100));

        // One final round on 1 to process whatever 2 sent in response.
        poll_and_handle(&mut thread1);
        thread::sleep(Duration::from_millis(100));

        // Now we can inspect the Network channels, both only have one channel:
        assert_eq!(thread1.address_map.len(), 1);
        assert_eq!(thread2.address_map.len(), 1);

        // Now assert that they've kept the same channel:
        assert_eq!(
            thread1
                .address_map
                .drain()
                .next()
                .unwrap()
                .1
                .borrow_mut()
                .stream()
                .local_addr()
                .unwrap(),
            thread2
                .address_map
                .drain()
                .next()
                .unwrap()
                .1
                .borrow_mut()
                .stream()
                .peer_addr()
                .unwrap()
        );
    }
1478
1479    #[test]
1480    fn network_thread_custom_buffer_config() {
1481        let addr = "127.0.0.1:0".parse().expect("Address should work");
1482        let mut buffer_config = BufferConfig::default();
1483        buffer_config.chunk_size(128);
1484        buffer_config.max_chunk_count(14);
1485        buffer_config.initial_chunk_count(13);
1486        buffer_config.encode_buf_min_free_space(10);
1487        let network_config = NetworkConfig::with_buffer_config(addr, buffer_config);
1488        let (mut network_thread, _) = setup_network_thread(&network_config);
1489        // Assert that the buffer_pool is created correctly
1490        let (pool_size, _) = network_thread.buffer_pool.borrow_mut().get_pool_sizes();
1491        assert_eq!(pool_size, 13); // initial_pool_size
1492        assert_eq!(
1493            network_thread
1494                .buffer_pool
1495                .borrow_mut()
1496                .get_buffer()
1497                .unwrap()
1498                .len(),
1499            128
1500        );
1501        network_thread.stop();
1502    }
1503
    // Creates 5 network threads (soft limit 3) and connects thread 1 to threads 2, 3 and 4.
    // Messages are then sent to 3 and 4, leaving 2 as the least recently used channel, so
    // connecting 1 to 5 must close the channel to 2. After that closure is acknowledged, 2
    // reconnects to 1, which must close the channel to 3, now the least recently used one.
    #[test]
    fn soft_channel_limit() {
        let mut network_config = NetworkConfig::default();
        network_config.set_soft_connection_limit(3);
        let (mut thread1, input_queue_1_sender) = setup_network_thread(&network_config);
        let (mut thread2, input_queue_2_sender) = setup_network_thread(&network_config);
        let (mut thread3, _) = setup_network_thread(&network_config);
        let (mut thread4, _) = setup_network_thread(&network_config);
        let (mut thread5, _) = setup_network_thread(&network_config);
        let addr1 = thread1.addr;
        let addr2 = thread2.addr;
        let addr3 = thread3.addr;
        let addr4 = thread4.addr;
        let addr5 = thread5.addr;

        input_queue_1_sender.send(DispatchEvent::Connect(addr2));
        input_queue_1_sender.send(DispatchEvent::Connect(addr3));
        input_queue_1_sender.send(DispatchEvent::Connect(addr4));

        // Run the handshake sequences
        run_handshake_sequence(&mut thread1, &mut thread2);
        run_handshake_sequence(&mut thread1, &mut thread3);
        run_handshake_sequence(&mut thread1, &mut thread4);

        // Assert that 2, 3, 4 is connected to 1
        assert!(
            thread1
                .address_map
                .get(&addr2)
                .unwrap()
                .borrow_mut()
                .connected()
        );
        assert!(
            thread1
                .address_map
                .get(&addr3)
                .unwrap()
                .borrow_mut()
                .connected()
        );
        assert!(
            thread1
                .address_map
                .get(&addr4)
                .unwrap()
                .borrow_mut()
                .connected()
        );

        // Send something to 3 and 4, to ensure that 2 is LRU
        input_queue_1_sender.send(DispatchEvent::SendTcp(addr3, empty_message()));
        input_queue_1_sender.send(DispatchEvent::SendTcp(addr4, empty_message()));
        thread1.receive_dispatch();
        poll_and_handle(&mut thread1);
        thread::sleep(Duration::from_millis(100));

        // Initiate connection to 5, execute handshake
        input_queue_1_sender.send(DispatchEvent::Connect(addr5));
        run_handshake_sequence(&mut thread1, &mut thread5);

        // Assert that 2 is no longer connected to 1
        assert!(
            !thread1
                .address_map
                .get(&addr2)
                .unwrap()
                .borrow_mut()
                .connected()
        );

        // Receive and send bye
        poll_and_handle(&mut thread2);
        thread::sleep(Duration::from_millis(100));
        // Receive bye
        poll_and_handle(&mut thread1);
        thread::sleep(Duration::from_millis(100));
        // Acknowledge the closed channel to 2 through 1's input queue
        input_queue_1_sender.send(DispatchEvent::ClosedAck(addr2));
        thread1.receive_dispatch();

        // Assert that 2 is dropped
        assert!(!thread1.address_map.contains_key(&addr2));

        // Ack the closed connection on 2
        input_queue_2_sender.send(DispatchEvent::ClosedAck(addr1));
        thread2.receive_dispatch();

        // Initiate new connection to 1 from 2 and execute handshake
        input_queue_2_sender.send(DispatchEvent::Connect(addr1));
        run_handshake_sequence(&mut thread2, &mut thread1);

        // Receive the incoming connection and assert that 3 (LRU) is no longer connected
        poll_and_handle(&mut thread1);
        assert!(
            !thread1
                .address_map
                .get(&addr3)
                .unwrap()
                .borrow_mut()
                .connected()
        );

        thread1.stop();
        thread2.stop();
        thread3.stop();
        thread4.stop();
        thread5.stop();
    }
1615
1616    // Creates 5 different network_threads, connects "1" to 2, 3, and 4 properly, then the 5th,
1617    // asserts that the 5th is refused.
1618    // Then attempts to connect 5 to 1 and asserts that it's refused again
1619    #[test]
1620    fn hard_channel_limit() {
1621        let mut network_config = NetworkConfig::default();
1622        network_config.set_hard_connection_limit(3);
1623        let (mut thread1, input_queue_1_sender) = setup_network_thread(&network_config);
1624        let (mut thread2, _) = setup_network_thread(&network_config);
1625        let (mut thread3, _) = setup_network_thread(&network_config);
1626        let (mut thread4, _) = setup_network_thread(&network_config);
1627        let (mut thread5, input_queue_5_sender) = setup_network_thread(&network_config);
1628        let addr1 = thread1.addr;
1629        let addr2 = thread2.addr;
1630        let addr3 = thread3.addr;
1631        let addr4 = thread4.addr;
1632        let addr5 = thread5.addr;
1633
1634        input_queue_1_sender.send(DispatchEvent::Connect(addr2));
1635        input_queue_1_sender.send(DispatchEvent::Connect(addr3));
1636        input_queue_1_sender.send(DispatchEvent::Connect(addr4));
1637
1638        // Run the handhsake sequences
1639        run_handshake_sequence(&mut thread1, &mut thread2);
1640        run_handshake_sequence(&mut thread1, &mut thread3);
1641        run_handshake_sequence(&mut thread1, &mut thread4);
1642
1643        // Assert that 2, 3, 4 is connected to 1
1644        assert!(
1645            thread1
1646                .address_map
1647                .get(&addr2)
1648                .unwrap()
1649                .borrow_mut()
1650                .connected()
1651        );
1652        assert!(
1653            thread1
1654                .address_map
1655                .get(&addr3)
1656                .unwrap()
1657                .borrow_mut()
1658                .connected()
1659        );
1660        assert!(
1661            thread1
1662                .address_map
1663                .get(&addr4)
1664                .unwrap()
1665                .borrow_mut()
1666                .connected()
1667        );
1668
1669        // Initiate connection to 5, execute handshake
1670        input_queue_1_sender.send(DispatchEvent::Connect(addr5));
1671        thread1.receive_dispatch();
1672        // That it was immediately discarded
1673        assert!(!thread1.address_map.contains_key(&addr5));
1674
1675        // This should do nothing...
1676        run_handshake_sequence(&mut thread1, &mut thread5);
1677
1678        // Assert channels are unchanged
1679        assert!(!thread1.address_map.contains_key(&addr5));
1680        assert!(
1681            thread1
1682                .address_map
1683                .get(&addr2)
1684                .unwrap()
1685                .borrow_mut()
1686                .connected()
1687        );
1688        assert!(
1689            thread1
1690                .address_map
1691                .get(&addr3)
1692                .unwrap()
1693                .borrow_mut()
1694                .connected()
1695        );
1696        assert!(
1697            thread1
1698                .address_map
1699                .get(&addr4)
1700                .unwrap()
1701                .borrow_mut()
1702                .connected()
1703        );
1704
1705        // Initiate new connection to 1 from 2 and execute handshake
1706        input_queue_5_sender.send(DispatchEvent::Connect(addr1));
1707        thread5.receive_dispatch();
1708        thread::sleep(Duration::from_millis(100));
1709        poll_and_handle(&mut thread1);
1710        // Should have been rejected immediately
1711        assert!(!thread1.address_map.contains_key(&addr5));
1712
1713        // This should do nothing
1714        run_handshake_sequence(&mut thread5, &mut thread1);
1715
1716        // Assert channels are unchanged
1717        assert!(!thread1.address_map.contains_key(&addr5));
1718        assert!(
1719            thread1
1720                .address_map
1721                .get(&addr2)
1722                .unwrap()
1723                .borrow_mut()
1724                .connected()
1725        );
1726        assert!(
1727            thread1
1728                .address_map
1729                .get(&addr3)
1730                .unwrap()
1731                .borrow_mut()
1732                .connected()
1733        );
1734        assert!(
1735            thread1
1736                .address_map
1737                .get(&addr4)
1738                .unwrap()
1739                .borrow_mut()
1740                .connected()
1741        );
1742
1743        thread1.stop();
1744        thread2.stop();
1745        thread3.stop();
1746        thread4.stop();
1747        thread5.stop();
1748    }
1749}