netty_rs/
lib.rs

1#![warn(missing_docs)]
//! Netty-rs exposes a simple-to-use API for creating stateful, application-level network
//! protocols, acting either as a client or as a server.
4//!
//! Netty-rs requires consumers to specify how messages are handled in different
//! circumstances. In every case this is done through the same API, [Connection](Connection).
//! This very simple API allows consumers to specify RESTful protocols of varying complexity.
//! Each message-and-reply chain runs in its own channel, which neither affects nor is affected
//! by messages sent or received in other message-and-reply chains.
10//!
//! Message handling must be specified in the following situations:
12//! 1. When a consumer sends a message it can choose to wait for a reply and handle it in a
13//! custom way.
14//! 2. When acting as a server consumers need to specify how to handshake with new connections,
15//! which allows custom authentication of clients among any other handshake related action.
16//! 3. When acting as a server consumers need to specify how to handle non-reply messages from
17//! connections that have already been authenticated.
18//!
19//! The main API is accessed through the [Networker](Networker) struct.
20//!
//! Netty-rs uses the [DirectoryService](DirectoryService) trait so that consumers can either
//! implement their own directory service (for example, one backed by DNS) or use the provided
//! [SimpleDirectoryService](SimpleDirectoryService) struct, which implements this trait.
24//!
//! # Example
26//! ```rust
27//!# use futures::FutureExt;
28//!# use serde::{Serialize, Deserialize};
29//!# use rand::{thread_rng, Rng};
30//!# use tokio::time::Duration;
31//!# use std::sync::Arc;
32//!# use netty_rs::{Networker, SimpleDirectoryService, NetworkMessage, Connection, Action};
33//!# // Custom error struct
34//!# #[derive(Debug, Clone)]
35//!# struct MyError;
36//!# fn main() -> Result<(), MyError> {
37//!#   tokio_test::block_on(async {
38//!fn generate_challenge() -> Vec<u8> {
39//!    // Generate the challenge ...
40//!#            let mut arr = [0u8; 32];
41//!#            thread_rng().fill(&mut arr[..]);
42//!#            arr.to_vec()
43//!}
44//!
45//!fn verify_challenge(a: &Vec<u8>) -> bool {
46//!    // Verify challenge answer ...
47//!#            let mut arr = [0u8; 32];
48//!#            thread_rng().fill(&mut arr[..]);
49//!#            arr.to_vec();
50//!#            true
51//!}
52//!
53//!fn sign(c: &Vec<u8>) -> Vec<u8> {
54//!    // Sign the challenge
55//!#            let mut arr = [0u8; 32];
56//!#            thread_rng().fill(&mut arr[..]);
57//!#            arr.to_vec()
58//!}
59//!
60//!// Enum for the different types of messages we want to send
61//!#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
62//!enum Content {
63//!    Init,
64//!    Challenge(Vec<u8>),
65//!    Answer(Vec<u8>),
66//!    Accept,
67//!    Deny,
68//!    Request,
69//!    Response(i32),
70//!    ProtocolError,
71//!}
72//!
73//!let ds = SimpleDirectoryService::new();
74//!let networker = Networker::new("127.0.0.1:8080".parse().unwrap(), ds,
75//!    |handshake_msg: NetworkMessage<Content>, mut con: Connection<Content, MyError>| async move {
76//!       // Perhaps you authenticate by producing a challenge and then
77//!       // waiting for a response
78//!       let challenge = generate_challenge();
79//!       let message = handshake_msg.reply(Content::Challenge(challenge));
80//!       let timeout = Duration::from_secs(2);
81//!       // On timeout or other errors we just abort this whole process
82//!       let response = con.send_message_await_reply(message, Some(timeout)).await?;
83//!       if let Content::Answer(a) = &response.content {
84//!           if verify_challenge(a) {
85//!               let accept_msg = response.reply(Content::Accept);
86//!               con.send_message(accept_msg).await?;
87//!           } else {
88//!               let deny_msg = response.reply(Content::Deny);
89//!               con.send_message(deny_msg).await?;
90//!           }
91//!       } else {
92//!           let deny_msg = response.reply(Content::Deny);
93//!           con.send_message(deny_msg).await?;
94//!       }
95//!       // Return the id of this client
96//!       let inner = Arc::try_unwrap(handshake_msg.from).unwrap_or_else(|e| (*e).clone());
97//!        Ok(inner)
98//!    },
99//!    |message: NetworkMessage<Content>, mut con: Connection<Content, MyError>| async move {
100//!        if let Content::Request = message.content {
101//!            // Respond with the magical number for the meaning of life
102//!            let response = message.reply(Content::Response(42));
103//!            con.send_message(response).await?;
104//!        } else {
105//!            let response = message.reply(Content::ProtocolError);
106//!            con.send_message(response).await?;
107//!        }
108//!        Ok(())
109//!    }).await.map_err(|_| MyError)?;
110//!networker.listen(true).await.map_err(|_| MyError)?;
111//!// Send a message to ourselves
112//!let first_message = NetworkMessage::new(
113//!    Arc::new("127.0.0.1:8080".to_string()),
114//!    Arc::new("127.0.0.1:8080".to_string()),
115//!    Content::Init,
116//!);
117//!let timeout = Duration::from_secs(2);
118//!let action = Action::new(
119//!    |msg: NetworkMessage<Content>, mut con: Connection<Content, MyError>| {
120//!        async move {
121//!            if let Content::Challenge(c) = &msg.content {
122//!                let answer = sign(c);
123//!                let resp = msg.reply(Content::Answer(answer));
124//!                let timeout = Duration::from_secs(2);
125//!                let accept = con.send_message_await_reply(resp, Some(timeout)).await?;
126//!                if let Content::Accept = accept.content {
127//!                    Ok(())
128//!                } else {
129//!                    Err(MyError.into())
130//!                }
131//!            } else {
132//!                Err(MyError.into())
133//!            }
134//!        }
135//!        .boxed()
136//!    },
137//!);
138//!networker
139//!    .send_message(first_message, Some(timeout), Some(action))
140//!    .await.map_err(|_| MyError)?;
141//!Result::<(), MyError>::Ok(())
142//!#    })?;
143//!#    Ok(())
144//!# }
145//!```
146
147// TODO: Add a retry-loop feature
148// TODO: Consider making handshakes have a different content than regular messages
149// TODO: If a channel dies then it will need to handshake again, this MIGHT be a problem
150// TODO: Create function for handshaking on sending message
151// TODO: We should verify the fields of the network message so that to and from are correct
152// TODO: Make the handshake closure at least into a FnOnce, perhaps also the message closure. They
153// are already cloned so they don't need to be FnMut to be called more than once
154use futures::future::BoxFuture;
155use futures::Future;
156use log::debug;
157use rand::distributions::Alphanumeric;
158use rand::{thread_rng, Rng};
159use serde::de::DeserializeOwned;
160use serde::{Deserialize, Serialize};
161use std::collections::HashMap;
162use std::fmt::Debug;
163use std::net::SocketAddr;
164use std::net::ToSocketAddrs;
165use std::sync::Arc;
166use tokio::io::AsyncReadExt;
167use tokio::io::AsyncWriteExt;
168use tokio::net::{TcpListener, TcpStream};
169use tokio::sync::mpsc::{channel, Receiver, Sender};
170use tokio::sync::oneshot::{channel as os_channel, Sender as OsSender};
171use tokio::time::sleep;
172use tokio::time::Duration;
173
/// Marker trait for errors that are returned by the handlers.
///
/// It is implemented automatically for every type that is
/// `Send + Sync + Debug + Clone + 'static`.
pub trait HandlerError: Send + Sync + Debug + 'static + Clone {}

impl<T> HandlerError for T where T: Send + Sync + Debug + 'static + Clone {}

/// The category of a netty-rs [Error](Error).
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[doc(hidden)]
pub enum ErrorKind<E: HandlerError> {
    /// No reply arrived before the timeout expired.
    Timeout,
    /// A protocol violation. NOTE(review): no constructor in this block produces this kind.
    ProtocolBreak,
    /// An error that fits no other category.
    Unspecified,
    /// A message could not be serialized, deserialized or decoded as UTF-8.
    SerializationError,
    /// Something could not be found. NOTE(review): no constructor in this block produces it.
    NotFound,
    /// The directory service could not translate an id into a socket address.
    DirectoryService,
    /// A user-supplied handler failed; its error is carried here.
    HandlerError(E),
    /// A network operation (binding, reading from or writing to a socket) failed.
    // Appended last so the derived `Ord` keeps the relative order of the older variants.
    Network,
}

/// Error returned by netty-rs
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct Error<E: HandlerError> {
    /// The kind of error
    pub kind: ErrorKind<E>,
    /// The message provided with the error
    pub msg: String,
}

/// Wraps a handler's own error, so `my_error.into()` (or `?`) yields an `Error<MyError>`.
impl<T: HandlerError> From<T> for Error<T> {
    fn from(e: T) -> Self {
        Self::handler_error(e)
    }
}

impl<E: HandlerError> Error<E> {
    /// Wraps an error returned by a user-supplied handler.
    fn handler_error(e: E) -> Self {
        Self::with_kind(ErrorKind::HandlerError(e), "Handler returned an error")
    }

    /// Builds an error for a reply that did not arrive in time.
    fn timeout<S: ToString>(msg: S) -> Self {
        Self::with_kind(ErrorKind::Timeout, msg)
    }

    /// Builds an uncategorised error.
    fn custom<S: ToString>(msg: S) -> Self {
        Self::with_kind(ErrorKind::Unspecified, msg)
    }

    /// Builds an error for a message that could not be (de)serialized.
    fn serialization_error<S: ToString>(msg: S) -> Self {
        Self::with_kind(ErrorKind::SerializationError, msg)
    }

    /// Builds an error for a failed directory-service lookup.
    fn directory_service_error<S: ToString>(msg: S) -> Self {
        Self::with_kind(ErrorKind::DirectoryService, msg)
    }

    /// Builds an error for a failed socket operation.
    ///
    /// Network failures get their own kind so callers can tell them apart from
    /// directory-service lookup failures.
    fn network_error<S: ToString>(msg: S) -> Self {
        Self::with_kind(ErrorKind::Network, msg)
    }

    /// Shared constructor behind every helper above.
    fn with_kind<S: ToString>(kind: ErrorKind<E>, msg: S) -> Self {
        Self {
            kind,
            msg: msg.to_string(),
        }
    }
}
250
251/// Typedef for the result in netty-rs
252type Result<T, E> = std::result::Result<T, Error<E>>;
253
254/// Marker trait for the content contained in the [NetworkMessage](NetworkMessage) struct
255pub trait NetworkContent:
256    Serialize + DeserializeOwned + Send + Sync + Eq + PartialEq + 'static + Debug
257{
258}
259
260impl<T> NetworkContent for T where
261    T: Serialize + DeserializeOwned + Send + Sync + Eq + PartialEq + 'static + Debug
262{
263}
264
265const MAX_SOCKET_BUF_SIZE: usize = 1500;
266const DEFAULT_MESSAGE_TIMEOUT_MILLIS: u64 = 5000;
267const CHANNEL_SIZE: usize = 100;
268
269/// Netty-rs uses a pluggable directory service to translate from an id to an IP address and
270/// port number. This is provided via this trait. A [SimpleDirectoryService](SimpleDirectoryService)
271/// struct is provided which translates from any type that implements [ToSocketAddrs](ToSocketAddrs).
272/// This includes strings which are in the format of for example "127.0.0.1:8080", which will allow consumers to
273/// easily use strings or IP addresses as identifiers. If a more complicated lookup is required
274/// then implementing this trait is always avaliable.
275pub trait DirectoryService<N: Send + Sync, E: HandlerError>: Send + Sync {
276    /// Translates names to Socket addresses
277    fn translate(&self, name: &N) -> Result<SocketAddr, E>;
278}
279
280/// Directory service that translates Strings and socket addresses to socket addresses
281/// Strings have to be in the format of "127.0.0.1:8080"
282pub struct SimpleDirectoryService<
283    S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync,
284> {
285    _pd: std::marker::PhantomData<S>,
286}
287
288impl<S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync>
289    SimpleDirectoryService<S>
290{
291    /// Create a new [SimpleDirectoryService](SimpleDirectoryService)
292    pub fn new() -> Self {
293        Self {
294            _pd: std::marker::PhantomData::<S>,
295        }
296    }
297}
298
299impl<S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync, E: HandlerError>
300    DirectoryService<S, E> for SimpleDirectoryService<S>
301{
302    /// Translates a `ToSocketAddrs` to the first `SocketAddr` it yields
303    fn translate(&self, name: &S) -> Result<SocketAddr, E> {
304        let mut sockets = name.to_socket_addrs().map_err(|_| {
305            Error::directory_service_error("Could not get socket address from directory service")
306        })?;
307        let socket = sockets.next().ok_or_else(|| {
308            Error::directory_service_error("Could not get socket address from directory service")
309        })?;
310        Ok(socket)
311    }
312}
313
/// Work item sent to a socket task: the message to send, an optional reaction to run on its
/// reply, an optional reply timeout, and the oneshot sender on which the final outcome
/// (the reaction's result when a reaction was given) is reported.
type NetPack<T, E> = (
    NetworkMessage<T>,
    Option<Action<T, E>>,
    Option<Duration>,
    OsSender<Result<(), E>>,
);
/// Work item queued through a [Connection](Connection)'s sender: the message to send, an
/// optional reply timeout, whether a reply is wanted, and the oneshot sender that receives
/// `Some(reply)`, `None` when no reply was wanted, or an error (e.g. a timeout).
type ConnectionPackage<T, E> = (
    NetworkMessage<T>,
    Option<Duration>,
    bool,
    OsSender<Result<Option<NetworkMessage<T>>, E>>,
);
326
/// This struct is the main API for using netty. It allows the creation of a server and the ability
/// to send messages to clients.
///
/// Cloning a `Networker` only clones the two channel senders it holds, so clones presumably
/// talk to the same background task (TODO confirm against the rest of `Networker::new`).
#[derive(Debug, Clone)]
pub struct Networker<T: NetworkContent + 'static, E: HandlerError> {
    // Channel carrying messages to send (see `NetPack`); presumably consumed by the background
    // task spawned in `new` — the consuming side is not in the visible part of `new`.
    tx: Sender<NetPack<T, E>>,
    // NOTE: command_tx currently sends a bool that represents if the networker should act as a
    // server or not.
    // `true` makes the background task bind a listener on the configured address, `false`
    // drops the listener; the oneshot sender reports whether that succeeded.
    command_tx: Sender<(bool, OsSender<Result<(), E>>)>,
}
336
/// This struct represents the network messages sent and received. A fresh message is created
/// with the `new` constructor. To answer a received message, use its `reply` method, which
/// swaps `to`/`from` and records the original message's id.
#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
pub struct NetworkMessage<T: NetworkContent> {
    /// The receiver id, this is currently a string but this may change to a generic representing
    /// an id in the future
    pub to: Arc<String>,
    /// The sender id, this is currently a string but this may change to a generic representing
    /// an id in the future
    pub from: Arc<String>,
    /// The message id, used to reply to this message
    pub id: Arc<String>,
    /// If this is set to `Some` then this message is a reply to the contained ID, else it is a
    /// fresh message
    pub reply: Option<Arc<String>>,
    /// The content of the message
    // Replaces serde's inferred `T: Deserialize<'de>` bound with `T: DeserializeOwned`.
    #[serde(bound(deserialize = "T: DeserializeOwned"))]
    pub content: T,
}
357
358fn new_id() -> String {
359    let rand_string: String = thread_rng()
360        .sample_iter(&Alphanumeric)
361        .take(30)
362        .map(char::from)
363        .collect();
364    rand_string
365}
366
367impl<T: NetworkContent> NetworkMessage<T> {
368    /// Create a new fresh message
369    pub fn new(to: Arc<String>, from: Arc<String>, content: T) -> Self {
370        Self {
371            to,
372            from,
373            id: Arc::new(new_id()),
374            reply: None,
375            content,
376        }
377    }
378
379    /// Used to construct replies to the provided message.
380    pub fn reply(&self, content: T) -> Self {
381        Self {
382            to: self.from.clone(),
383            from: self.to.clone(),
384            id: Arc::new(new_id()),
385            reply: Some(self.id.clone()),
386            content,
387        }
388    }
389}
390
391async fn create_socket_task<T: NetworkContent, M, F, E: HandlerError>(
392    mut socket: TcpStream,
393    handle_message: M,
394) -> Result<Sender<NetPack<T, E>>, E>
395where
396    F: Future<Output = HandlerResult<(), E>> + Send,
397    M: FnMut(NetworkMessage<T>, Connection<T, E>) -> F + Send + Sync + Clone + 'static,
398{
399    let (tx, mut rx): (Sender<NetPack<T, E>>, Receiver<NetPack<T, E>>) = channel(CHANNEL_SIZE);
400    tokio::spawn(async move {
401        debug!("Starting up a new socket task");
402        let (tx, mut reaction_rx): (
403            Sender<ConnectionPackage<T, E>>,
404            Receiver<ConnectionPackage<T, E>>,
405        ) = channel(CHANNEL_SIZE);
406        let mut awaiting_reply: HashMap<Arc<String>, OsSender<NetworkMessage<T>>> = HashMap::new();
407        loop {
408            //let mut buf = [0; MAX_SOCKET_BUF_SIZE];
409            // This task is responsible for:
410            // 1. Listening to the channel from the main manager thread for messages to send on
411            //      this socket, sending the message out on the socket and recording which (if
412            //      any) reaction was interested in responses to that message as well as
413            //      recording when a timeout happens and then returning an error on the channel
414            // 2. Listening to the reaction_rx channel for NetworkMessages to send and
415            //    recording which message (if any) message the reaction is interested in
416            //    listening to.
417            // 3. Listening to the socket and parsing the data to NetworkMessages, figuring out
418            //    if the parsed message is a response to any channels and in that case routing
419            //    it to that channel, else calling a new handle_message instance with the
420            //    message
421            tokio::select! {
422                Some((msg, timeout, want_reply, os_tx)) = reaction_rx.recv() => {
423                    // 2) Send the message and then create a timeout that sends a timeout error
424                    //    on expiry. Also create a one-shot channel chain where we record which
425                    //    message we are waiting for and give it the transmission end of the
426                    //    chain
427                    debug!("Socket task - Received a request to send a message on reaction thread");
428                    let msg_s = match serde_json::to_string(&msg) {
429                        Ok(r) => r,
430                        Err(_) => {
431                            let e = Error::serialization_error(format!("Could not serialize: {:?}", msg));
432                            if let Err(e) = os_tx.send(Err(e)) {
433                                debug!("Oneshot return channel did not stay open: {:?}", e);
434                            }
435                            continue;
436                        },
437                    };
438                    match socket.write_all(msg_s.as_bytes()).await {
439                        Ok(()) => (),
440                        Err(_) => {
441                            let e = Error::network_error(format!("Could not write on socket"));
442                            if let Err(e) = os_tx.send(Err(e)) {
443                                debug!("Oneshot return channel did not stay open: {:?}", e);
444                            }
445                            continue;
446                        },
447                    }
448                    if want_reply {
449                        let timeout_time = match timeout {
450                            Some(t) => t,
451                            // If no timeout is provided then we use a default
452                            None => Duration::from_millis(DEFAULT_MESSAGE_TIMEOUT_MILLIS),
453                        };
454                        let (hm_tx, hm_rx) = os_channel();
455                        awaiting_reply.insert(msg.id, hm_tx);
456                        tokio::spawn(async move {
457                            let timeout = tokio::time::sleep(timeout_time);
458                            tokio::pin!(timeout);
459                            tokio::select! {
460                                Ok(msg) = hm_rx => {
461                                    if let Err(e) = os_tx.send(Ok(Some(msg))) {
462                                        debug!("Oneshot return channel did not stay open: {:?}", e);
463                                    }
464                                },
465                                _ = timeout => {
466                                    if let Err(e) = os_tx.send(Err(Error::timeout("Did not recieve a response in time!"))) {
467                                        debug!("Oneshot return channel did not stay open: {:?}", e);
468                                    }
469                                }
470                            }
471                        });
472                    } else {
473                        if let Err(e) = os_tx.send(Ok(None)) {
474                            debug!("Oneshot return channel did not stay open: {:?}", e);
475                        }
476                    }
477                },
478                Result::<NetworkMessage<T>, E>::Ok(msg) = read_message(&mut socket) => {
479                    // 3) Listens to the socket and parses the data to NetworkMessages then
480                    //    looks in the awaiting_reply hashmap to see if there is any connection
481                    //    that are waiting for reply to a message that the parsed message is
482                    //    replying to, in that case send the message on that channel if it
483                    //    remains open, if the channel is closed discard the message.
484                    //    If there is no connection that awaits reply from this message then
485                    //    start a new connection for this message
486                    match &msg.reply {
487                        Some(id) => {
488                            match awaiting_reply.remove(id) {
489                                Some(tx) => {
490                                    debug!("Sending message to waiter {:?}", msg);
491                                    if let Err(e) = tx.send(msg) {
492                                        //Discard the channel and message
493                                        debug!("Discarded the message due to: {:?}", e);
494                                    }
495                                }
496                                None => {
497                                    //Discard the channel and message
498                                    debug!("Could not find channel to pass message on");
499                                }
500                            };
501                        },
502                        None => {
503                            debug!("Did not find any waiters for {:?}", msg);
504                            let con = Connection {
505                                sender: tx.clone(),
506                            };
507                            let mut handle_message = handle_message.clone();
508                            tokio::spawn(async move {
509                                if let Err(e) = handle_message(msg, con).await {
510                                    debug!("Handle message returned an error {:?}", e);
511                                }
512                            });
513                        },
514                    };
515                },
516                Some((msg, react, timeout, return_tx)) = rx.recv() => {
517                    debug!("On socket task - Received a message to send {:?}", msg);
518                    // 1. Listens to the main-thread communication channel and sends the
519                    //    message on the internal reaction channel to schedule it. If a
520                    //    reaction is passed then listens to the oneshot channel for a reply
521                    //    and provides this to the reaction call
522                    let (os_tx, os_rx) = os_channel();
523                    let want_reply = react.is_some();
524                    tx.send((msg, timeout, want_reply, os_tx)).await.expect("Networker internal error due to reaction channel being closed");
525                    let tx = tx.clone();
526                    tokio::spawn(async move {
527                        let r = os_rx.await.expect("Networker internal error, awaiting os_rx channel but transmitter closed");
528                        match r {
529                            Ok(r) => {
530                                if want_reply {
531                                    let react = react.expect("Networker unreachable state");
532                                    let msg = r.expect("Network unreachable state - received no message while expecting reply");
533                                    let con = Connection {
534                                        sender: tx.clone(),
535                                    };
536                                    let result = react.0(msg, con).await;
537                                    return_tx.send(result).expect("Networker internal error - networker did not listen to return channel");
538                                } else {
539                                    return_tx.send(Ok(())).expect("Networker internal error - networker did not listen to return channel");
540                                }
541                            },
542                            Err(e) => {
543                                return_tx.send(Err(e)).expect("Networker internal error - networker did not listen to return channel");
544                            }
545                        };
546                    });
547                }
548            }
549        }
550    });
551    Ok(tx)
552}
553
554async fn read_message<T: NetworkContent, E: HandlerError>(
555    socket: &mut TcpStream,
556) -> Result<NetworkMessage<T>, E> {
557    let mut buf = [0; MAX_SOCKET_BUF_SIZE];
558    match socket.read(&mut buf).await {
559        Ok(n) => match String::from_utf8(buf[..n].to_vec()) {
560            Ok(s) => match serde_json::from_str(&s) {
561                Ok(s) => return Ok(s),
562                Err(_) => {
563                    return Err(Error::serialization_error(
564                        "Could not deserialize recieved message",
565                    ));
566                }
567            },
568            Err(e) => {
569                return Err(Error::serialization_error(format!(
570                    "Could not convert to utf-8 - {:?}",
571                    e
572                )));
573            }
574        },
575
576        Err(e) => {
577            return Err(Error::network_error(format!(
578                "Could not read from socket - {:?}",
579                e
580            )));
581        }
582    }
583}
584
/// Performs the server-side handshake on a freshly accepted `socket`, then hands the socket
/// over to a long-lived task created by `create_socket_task`.
///
/// The first message read from the socket is passed to `handle_handshake` together with a
/// [Connection](Connection). While the handshake handler runs, messages it sends through that
/// connection are written straight to the socket here. When a reply is wanted, the next
/// message read from the socket (or a timeout error) is returned to the handler as the reply.
///
/// NOTE(review): the `reply` field of that next message is not checked against the id of the
/// message that was sent.
///
/// Returns the id string produced by `handle_handshake` and the sender for the socket task.
///
/// # Errors
/// Returns an error if the first message cannot be read, if `handle_handshake` fails, or if
/// `create_socket_task` fails.
async fn process_socket<T: NetworkContent, H, M, FH, FM, E: HandlerError>(
    mut socket: TcpStream,
    mut handle_handshake: H,
    handle_message: M,
) -> Result<(String, Sender<NetPack<T, E>>), E>
where
    FH: Future<Output = HandlerResult<String, E>> + Send,
    FM: Future<Output = HandlerResult<(), E>> + Send,
    H: FnMut(NetworkMessage<T>, Connection<T, E>) -> FH + Send + Sync + 'static,
    M: FnMut(NetworkMessage<T>, Connection<T, E>) -> FM + Send + Sync + Clone + 'static,
{
    let msg = read_message(&mut socket).await?;
    // `handshake_tx` stays alive for this whole function, so `handshake_rx.recv()` below does
    // not yield `None` while the handshake is running.
    let (handshake_tx, mut handshake_rx) = channel(CHANNEL_SIZE);
    let con = Connection {
        sender: handshake_tx.clone(),
    };
    let listen_for_messages = async {
        // This loop never returns: every outcome (success or error) is reported on the
        // per-message oneshot channel, which drives the handshake handler's Connection forward.
        loop {
            if let Some((msg, timeout, want_reply, os_tx)) = handshake_rx.recv().await {
                let msg = match serde_json::to_string(&msg) {
                    Ok(s) => s,
                    Err(_) => {
                        let e =
                            Error::serialization_error(format!("Could not serialize {:?}", msg));
                        if let Err(e) = os_tx.send(Err(e)) {
                            debug!("Oneshot return channel did not stay open: {:?}", e);
                        }
                        continue;
                    }
                };
                match socket.write_all(msg.as_bytes()).await {
                    Ok(()) => (),
                    Err(_) => {
                        let e = Error::network_error("Could not send over network socket");
                        if let Err(e) = os_tx.send(Err(e)) {
                            debug!("Oneshot return channel did not stay open: {:?}", e);
                        }
                        continue;
                    }
                };
                if !want_reply {
                    if let Err(e) = os_tx.send(Ok(None)) {
                        debug!("Oneshot return channel did not stay open: {:?}", e);
                    }
                    continue;
                } else {
                    // Wait for the next message on the socket, or fail after the timeout
                    // (DEFAULT_MESSAGE_TIMEOUT_MILLIS when none was given).
                    let timeout = sleep(
                        timeout.unwrap_or(Duration::from_millis(DEFAULT_MESSAGE_TIMEOUT_MILLIS)),
                    );
                    tokio::pin!(timeout);
                    tokio::select! {
                        s = read_message(&mut socket) => {
                            match s {
                                Ok(s) => {
                                    if let Err(e) = os_tx.send(Ok(Some(s))) {
                                        debug!("Oneshot return channel did not stay open: {:?}", e);
                                    }
                                },
                                Err(e) => {
                                    if let Err(e) = os_tx.send(Err(e)) {
                                        debug!("Oneshot return channel did not stay open: {:?}", e);
                                    }
                                }
                            }
                        },
                        _ = timeout => {
                            if let Err(e) = os_tx.send(Err(Error::timeout("Did not recieve a response in time"))) {
                                debug!("Oneshot return channel did not stay open: {:?}", e);
                            }
                        }
                    }
                }
            }
        }
    };
    // Run the socket pump and the handshake handler concurrently. Since `listen_for_messages`
    // never completes, this resolves when the handshake handler finishes. The typed `Err`
    // pattern presumably exists only to give the never-returning async block an output type.
    let r = tokio::select! {
        Result::<T, E>::Err(e) = listen_for_messages => return Err(e),
        r = handle_handshake(msg, con) => {
            match r {
                Ok(r) => r,
                Err(e) => {
                    return Err(e);
                }
            }
        }
    };
    // The select above dropped `listen_for_messages`, releasing its borrow of `socket`.
    let tx = create_socket_task(socket, handle_message).await?;
    Ok((r, tx))
}
676
677type HandlerResult<T, E> = Result<T, E>;
678impl<T: NetworkContent, E: HandlerError> Networker<T, E> {
679    /// Creates a new networker using. `address` is the network socket address that the server should
680    /// listen to. `directory_service` is the directory service to use in order to translate IDs to
681    /// addresses. `handle_handshakes` is a closure that is called when a new connection is
682    /// recieved which sends a `NetworkMessage`. This closure should authenticate if appropriate
683    /// and do other handshake and setup related things. `handle_messages` is a closure that is
684    /// called for all messages received from a connection that has already been handshaked.
685    pub async fn new<H, M, FH, FM>(
686        address: SocketAddr,
687        directory_service: impl DirectoryService<String, E> + 'static,
688        handle_handshakes: H,
689        handle_messages: M,
690    ) -> Result<Networker<T, E>, E>
691    where
692        FM: Future<Output = HandlerResult<(), E>> + Send,
693        FH: Future<Output = HandlerResult<String, E>> + Send,
694        M: FnMut(NetworkMessage<T>, Connection<T, E>) -> FM + Send + Sync + Clone + 'static,
695        H: FnMut(NetworkMessage<T>, Connection<T, E>) -> FH + Send + Sync + Clone + 'static,
696    {
697        let (net_tx, mut thread_rx): (Sender<NetPack<T, E>>, Receiver<NetPack<T, E>>) =
698            channel(CHANNEL_SIZE);
699        let (command_tx, mut command_rx): (
700            Sender<(bool, OsSender<Result<(), E>>)>,
701            Receiver<(bool, OsSender<Result<(), E>>)>,
702        ) = channel(CHANNEL_SIZE);
703        let mut name_channel_hm = HashMap::new();
704        let mut listener: Option<TcpListener> = None;
705        tokio::spawn(async move {
706            loop {
707                tokio::select! {
708                    Some((should_be_server, os_tx)) = command_rx.recv() => {
709                        if should_be_server {
710                            listener = match TcpListener::bind(address).await {
711                                Ok(r) => Some(r),
712                                Err(e) => {
713                                    match os_tx.send(Err(Error::network_error(format!(
714                                                    "Could not listen to address: {} due to: {:?}",
715                                                    address, e
716                                                    )))) {
717                                        Ok(()) => (),
718                                        Err(_) => {
719                                            debug!("Internal networker error - could not return result from turning on/off server");
720                                        },
721                                    };
722                                    continue;
723                                }
724                            };
725                        } else {
726                            listener = None;
727                        }
728                        match os_tx.send(Ok(())) {
729                            Ok(()) => {
730                            },
731                            Err(_) => {
732                                debug!("Internal networker error - could not send on return channel from request to start listening");
733                            },
734                        };
735                    }
736                    // NOTE: We only listen for new connections in the case where the networker
737                    // should act like a server
738                    Ok((socket, _)) = async {
739                        if let Some(listener) = &listener {
740                            listener.accept().await
741                        } else {
742                            // If the networker is not a server then we will wait forever, if this
743                            // changes then this future will be cancelled
744                            let forever = futures::future::pending();
745                            let () = forever.await;
746                            unreachable!("Networker unreachable state - tried to listen to a non-existent server");
747                        }
748                    } => {
749                        debug!("Received a TCP connection");
750                        let (name, tx) = match process_socket(socket, handle_handshakes.clone(), handle_messages.clone()).await {
751                            Ok(r) => r,
752                            Err(e) => {
753                                debug!("Could not establish contact {:?}", e);
754                                continue;
755                            },
756                        };
757                        debug!("Handshake finished peer name is: {}", name);
758                        name_channel_hm.insert(name, tx);
759                    },
760                    Some((message, react, timeout, os_tx)) = thread_rx.recv() => {
761                        // Here we need to find the correct socket to send to. And we should
762                        // maybe allow a "ALL" option to send towards
763                        match name_channel_hm.get(&*message.to) {
764                            // There is already a recipiant connected with that name
765                            Some(tx) => {
766                                if let Err(e) = tx.send((message, react, timeout, os_tx)).await {
767                                    debug!("{}", e);
768                                }
769                            }
770                            // No connection to provided recipiant was found
771                            None => {
772                                // TODO: We should allow consumers to provide a do_handshake
773                                // function which is run automatically be ran on sending a new
774                                // message to an uninitiated peer. This function should return
775                                // a Result containing the name of the peer. If no such
776                                // function is provided then we skip that part
777                                match directory_service.translate(&message.to) {
778                                    Ok(address) => {
779                                        let socket = match TcpStream::connect(address).await
780                                            .map_err(|_| Error::network_error(format!("Could not connect to address {:?}", address))) {
781                                                Ok(s) => s,
782                                                Err(e) => {
783                                                    if let Err(_) = os_tx.send(Err(e)) {
784                                                        debug!("Could not return error send on one-shot channel");
785                                                    }
786                                                    continue;
787                                                }
788                                            };
789                                        // Spawn a new thread with a new tcp connection and
790                                        // send the message
791                                        let tx = match create_socket_task(socket, handle_messages.clone()).await {
792                                            Ok(tx) => tx,
793                                            Err(e) => {
794                                                if let Err(_) = os_tx.send(Err(e)) {
795                                                    debug!("Could not return error send on one-shot channel");
796                                                }
797                                                continue;
798                                            }
799                                        };
800                                        let name = message.to.clone();
801                                        if let Err(e) = tx.send((message, react, timeout, os_tx)).await {
802                                            debug!("{}", e);
803                                        }
804                                        name_channel_hm.insert((*name).clone(), tx);
805                                    },
806                                    Err(e) => {
807                                        if let Err(_) = os_tx.send(Err(e)) {
808                                            debug!("Could not return error send on one-shot channel");
809                                        }
810                                        continue;
811                                    }
812                                }
813                            }
814                        }
815                    },
816                }
817            }
818        });
819
820        Ok(Networker {
821            tx: net_tx,
822            command_tx,
823        })
824    }
825
826    /// Sends a message and then reacts to the response with the action and then returns the
827    /// last message returned
828    pub async fn send_message(
829        &self,
830        message: NetworkMessage<T>,
831        timeout: Option<Duration>,
832        react: Option<Action<T, E>>,
833    ) -> Result<(), E> {
834        let (os_tx, os_rx) = os_channel();
835        if let Err(_) = self.tx.send((message, react, timeout, os_tx)).await {
836            debug!("Could not send to networker");
837        }
838        match os_rx.await.expect("Oneshot transmitter dropped in socket") {
839            Ok(()) => Ok(()),
840            Err(e) => Err(e),
841        }
842    }
843
844    /// Starts the server
845    pub async fn listen(&self, should_listen: bool) -> Result<(), E> {
846        let (tx, rx) = os_channel();
847        match self.command_tx.send((should_listen, tx)).await {
848            Ok(()) => (),
849            Err(_) => {
850                return Err(Error::custom(format!(
851                    "Internal error - Could change listening status due to channel being down"
852                )))
853            }
854        };
855        match rx.await {
856            Ok(r) => r,
857            Err(_) => return Err(Error::custom(format!("Internal error - Could not get response from listening call due to return channel closing prematurely"))),
858        }
859    }
860}
861
862/// Action contains a closure that handles the communication on a channel
863pub struct Action<
864    T: Send + Sync + Serialize + DeserializeOwned + Eq + PartialEq + Debug + 'static,
865    E: HandlerError,
866>(
867    Box<
868        dyn FnOnce(NetworkMessage<T>, Connection<T, E>) -> BoxFuture<'static, HandlerResult<(), E>>
869            + Send
870            + Sync,
871    >,
872);
873
874impl<T: NetworkContent, E: HandlerError> Action<T, E> {
875    /// Creates a new Action
876    pub fn new<
877        F: FnOnce(NetworkMessage<T>, Connection<T, E>) -> BoxFuture<'static, HandlerResult<(), E>>
878            + 'static
879            + Send
880            + Sync,
881    >(
882        f: F,
883    ) -> Self {
884        Self(Box::new(f))
885    }
886}
887
888/// A connection serves as the main API to specify how a network conversation should look. It
889/// allows sending messages to the recipiant and awaiting their response using the two methods
890/// `send_message` and `send_message_await_reply`.
891pub struct Connection<T: NetworkContent, E: HandlerError> {
892    sender: Sender<ConnectionPackage<T, E>>,
893}
894
895impl<T: NetworkContent, E: HandlerError> Connection<T, E> {
896    /// Send a message over the connection without waiting for a reply, returning upon successfully
897    /// sending the message out.
898    pub async fn send_message(&mut self, msg: NetworkMessage<T>) -> Result<(), E> {
899        let (tx, rx) = os_channel();
900        match self.sender.send((msg, None, false, tx)).await {
901            Ok(()) => (),
902            Err(_) => {
903                panic!("Internal error - could not send on internal channel",);
904            }
905        };
906        let r = match rx.await {
907            Ok(r) => r,
908            Err(_) => {
909                panic!("Internal error - internal return channel was closed before receiving a message");
910            }
911        };
912        r.map(|r| {
913            if r.is_some() {
914                panic!("Unreachable state - expected None but was provided a network message")
915            }
916        })
917    }
918
919    /// Send a message over the connection while waiting for a reply, a `Result` is returned with
920    /// the replying `NetworkMessage`.
921    pub async fn send_message_await_reply(
922        &mut self,
923        msg: NetworkMessage<T>,
924        timeout: Option<Duration>,
925    ) -> Result<NetworkMessage<T>, E> {
926        let (tx, rx) = os_channel();
927        match self.sender.send((msg, timeout, true, tx)).await {
928            Ok(()) => (),
929            Err(_) => {
930                return Err(Error::custom(
931                    "Internal error - could not send on internal channel",
932                ));
933            }
934        };
935        let r = match rx.await {
936            Ok(r) => r,
937            Err(_) => {
938                return Err(Error::custom("Internal error - internal return channel was closed before receiving a message"));
939            }
940        };
941        r.map(|r| r.expect("Expecting a network message as response but None was provided"))
942    }
943}