netty_rs/lib.rs
1#![warn(missing_docs)]
//! Netty-rs exposes a simple-to-use API for creating stateful, application-level network
//! protocols, acting as either a client or a server.
4//!
//! Netty-rs requires consumers to specify how messages are handled in different
//! circumstances. In every case the same API is used: [Connection](Connection).
//! This simple API allows consumers to specify stateful protocols of varying complexity.
//! Each message-and-reply chain runs in its own channel, which neither affects nor is affected
//! by messages sent or received in other message-and-reply chains.
10//!
//! Message handling must be specified in the following situations:
12//! 1. When a consumer sends a message it can choose to wait for a reply and handle it in a
13//! custom way.
14//! 2. When acting as a server consumers need to specify how to handshake with new connections,
15//! which allows custom authentication of clients among any other handshake related action.
16//! 3. When acting as a server consumers need to specify how to handle non-reply messages from
17//! connections that have already been authenticated.
18//!
19//! The main API is accessed through the [Networker](Networker) struct.
20//!
//! Netty-rs uses the [DirectoryService](DirectoryService) trait so that consumers can either
//! implement their own directory service (for example, one backed by DNS) or use the
//! [SimpleDirectoryService](SimpleDirectoryService) struct, which implements this trait.
24//!
//! # Example
26//! ```rust
27//!# use futures::FutureExt;
28//!# use serde::{Serialize, Deserialize};
29//!# use rand::{thread_rng, Rng};
30//!# use tokio::time::Duration;
31//!# use std::sync::Arc;
32//!# use netty_rs::{Networker, SimpleDirectoryService, NetworkMessage, Connection, Action};
33//!# // Custom error struct
34//!# #[derive(Debug, Clone)]
35//!# struct MyError;
36//!# fn main() -> Result<(), MyError> {
37//!# tokio_test::block_on(async {
38//!fn generate_challenge() -> Vec<u8> {
39//! // Generate the challenge ...
40//!# let mut arr = [0u8; 32];
41//!# thread_rng().fill(&mut arr[..]);
42//!# arr.to_vec()
43//!}
44//!
45//!fn verify_challenge(a: &Vec<u8>) -> bool {
46//! // Verify challenge answer ...
47//!# let mut arr = [0u8; 32];
48//!# thread_rng().fill(&mut arr[..]);
49//!# arr.to_vec();
50//!# true
51//!}
52//!
53//!fn sign(c: &Vec<u8>) -> Vec<u8> {
54//! // Sign the challenge
55//!# let mut arr = [0u8; 32];
56//!# thread_rng().fill(&mut arr[..]);
57//!# arr.to_vec()
58//!}
59//!
60//!// Enum for the different types of messages we want to send
61//!#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq)]
62//!enum Content {
63//! Init,
64//! Challenge(Vec<u8>),
65//! Answer(Vec<u8>),
66//! Accept,
67//! Deny,
68//! Request,
69//! Response(i32),
70//! ProtocolError,
71//!}
72//!
73//!let ds = SimpleDirectoryService::new();
74//!let networker = Networker::new("127.0.0.1:8080".parse().unwrap(), ds,
75//! |handshake_msg: NetworkMessage<Content>, mut con: Connection<Content, MyError>| async move {
76//! // Perhaps you authenticate by producing a challenge and then
77//! // waiting for a response
78//! let challenge = generate_challenge();
79//! let message = handshake_msg.reply(Content::Challenge(challenge));
80//! let timeout = Duration::from_secs(2);
81//! // On timeout or other errors we just abort this whole process
82//! let response = con.send_message_await_reply(message, Some(timeout)).await?;
83//! if let Content::Answer(a) = &response.content {
84//! if verify_challenge(a) {
85//! let accept_msg = response.reply(Content::Accept);
86//! con.send_message(accept_msg).await?;
87//! } else {
88//! let deny_msg = response.reply(Content::Deny);
89//! con.send_message(deny_msg).await?;
90//! }
91//! } else {
92//! let deny_msg = response.reply(Content::Deny);
93//! con.send_message(deny_msg).await?;
94//! }
95//! // Return the id of this client
96//! let inner = Arc::try_unwrap(handshake_msg.from).unwrap_or_else(|e| (*e).clone());
97//! Ok(inner)
98//! },
99//! |message: NetworkMessage<Content>, mut con: Connection<Content, MyError>| async move {
100//! if let Content::Request = message.content {
101//! // Respond with the magical number for the meaning of life
102//! let response = message.reply(Content::Response(42));
103//! con.send_message(response).await?;
104//! } else {
105//! let response = message.reply(Content::ProtocolError);
106//! con.send_message(response).await?;
107//! }
108//! Ok(())
109//! }).await.map_err(|_| MyError)?;
110//!networker.listen(true).await.map_err(|_| MyError)?;
111//!// Send a message to ourselves
112//!let first_message = NetworkMessage::new(
113//! Arc::new("127.0.0.1:8080".to_string()),
114//! Arc::new("127.0.0.1:8080".to_string()),
115//! Content::Init,
116//!);
117//!let timeout = Duration::from_secs(2);
118//!let action = Action::new(
119//! |msg: NetworkMessage<Content>, mut con: Connection<Content, MyError>| {
120//! async move {
121//! if let Content::Challenge(c) = &msg.content {
122//! let answer = sign(c);
123//! let resp = msg.reply(Content::Answer(answer));
124//! let timeout = Duration::from_secs(2);
125//! let accept = con.send_message_await_reply(resp, Some(timeout)).await?;
126//! if let Content::Accept = accept.content {
127//! Ok(())
128//! } else {
129//! Err(MyError.into())
130//! }
131//! } else {
132//! Err(MyError.into())
133//! }
134//! }
135//! .boxed()
136//! },
137//!);
138//!networker
139//! .send_message(first_message, Some(timeout), Some(action))
140//! .await.map_err(|_| MyError)?;
141//!Result::<(), MyError>::Ok(())
142//!# })?;
143//!# Ok(())
144//!# }
145//!```
146
147// TODO: Add a retry-loop feature
148// TODO: Consider making handshakes have a different content than regular messages
149// TODO: If a channel dies then it will need to handshake again, this MIGHT be a problem
150// TODO: Create function for handshaking on sending message
151// TODO: We should verify the fields of the network message so that to and from are correct
152// TODO: Make the handshake closure at least into a FnOnce, perhaps also the message closure. They
153// are already cloned so they don't need to be FnMut to be called more than once
154use futures::future::BoxFuture;
155use futures::Future;
156use log::debug;
157use rand::distributions::Alphanumeric;
158use rand::{thread_rng, Rng};
159use serde::de::DeserializeOwned;
160use serde::{Deserialize, Serialize};
161use std::collections::HashMap;
162use std::fmt::Debug;
163use std::net::SocketAddr;
164use std::net::ToSocketAddrs;
165use std::sync::Arc;
166use tokio::io::AsyncReadExt;
167use tokio::io::AsyncWriteExt;
168use tokio::net::{TcpListener, TcpStream};
169use tokio::sync::mpsc::{channel, Receiver, Sender};
170use tokio::sync::oneshot::{channel as os_channel, Sender as OsSender};
171use tokio::time::sleep;
172use tokio::time::Duration;
173
174/// Marker trait for errors that are returned by the handlers
175pub trait HandlerError: Send + Sync + Debug + 'static + Clone {}
176
177impl<T> HandlerError for T where T: Send + Sync + Debug + 'static + Clone {}
178
179/// Errorkind in the error of netty-rs
180#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
181#[doc(hidden)]
182pub enum ErrorKind<E: HandlerError> {
183 Timeout,
184 ProtocolBreak,
185 Unspecified,
186 SerializationError,
187 NotFound,
188 DirectoryService,
189 HandlerError(E),
190}
191
192/// Error returned by netty-rs
193#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
194pub struct Error<E: HandlerError> {
195 /// The kind of error
196 pub kind: ErrorKind<E>,
197 /// The message provided with the error
198 pub msg: String,
199}
200
201impl<T: HandlerError> From<T> for Error<T> {
202 fn from(e: T) -> Self {
203 Self::handler_error(e)
204 }
205}
206
207impl<E: HandlerError> Error<E> {
208 fn handler_error(e: E) -> Self {
209 Self {
210 kind: ErrorKind::HandlerError(e),
211 msg: "Handler returned an error".to_string(),
212 }
213 }
214
215 fn timeout<S: ToString>(msg: S) -> Self {
216 Self {
217 kind: ErrorKind::Timeout,
218 msg: msg.to_string(),
219 }
220 }
221
222 fn custom<S: ToString>(msg: S) -> Self {
223 Self {
224 kind: ErrorKind::Unspecified,
225 msg: msg.to_string(),
226 }
227 }
228
229 fn serialization_error<S: ToString>(msg: S) -> Self {
230 Self {
231 kind: ErrorKind::SerializationError,
232 msg: msg.to_string(),
233 }
234 }
235
236 fn directory_service_error<S: ToString>(msg: S) -> Self {
237 Self {
238 kind: ErrorKind::DirectoryService,
239 msg: msg.to_string(),
240 }
241 }
242
243 fn network_error<S: ToString>(msg: S) -> Self {
244 Self {
245 kind: ErrorKind::DirectoryService,
246 msg: msg.to_string(),
247 }
248 }
249}
250
251/// Typedef for the result in netty-rs
252type Result<T, E> = std::result::Result<T, Error<E>>;
253
254/// Marker trait for the content contained in the [NetworkMessage](NetworkMessage) struct
255pub trait NetworkContent:
256 Serialize + DeserializeOwned + Send + Sync + Eq + PartialEq + 'static + Debug
257{
258}
259
260impl<T> NetworkContent for T where
261 T: Serialize + DeserializeOwned + Send + Sync + Eq + PartialEq + 'static + Debug
262{
263}
264
265const MAX_SOCKET_BUF_SIZE: usize = 1500;
266const DEFAULT_MESSAGE_TIMEOUT_MILLIS: u64 = 5000;
267const CHANNEL_SIZE: usize = 100;
268
269/// Netty-rs uses a pluggable directory service to translate from an id to an IP address and
270/// port number. This is provided via this trait. A [SimpleDirectoryService](SimpleDirectoryService)
271/// struct is provided which translates from any type that implements [ToSocketAddrs](ToSocketAddrs).
272/// This includes strings which are in the format of for example "127.0.0.1:8080", which will allow consumers to
273/// easily use strings or IP addresses as identifiers. If a more complicated lookup is required
274/// then implementing this trait is always avaliable.
275pub trait DirectoryService<N: Send + Sync, E: HandlerError>: Send + Sync {
276 /// Translates names to Socket addresses
277 fn translate(&self, name: &N) -> Result<SocketAddr, E>;
278}
279
280/// Directory service that translates Strings and socket addresses to socket addresses
281/// Strings have to be in the format of "127.0.0.1:8080"
282pub struct SimpleDirectoryService<
283 S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync,
284> {
285 _pd: std::marker::PhantomData<S>,
286}
287
288impl<S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync>
289 SimpleDirectoryService<S>
290{
291 /// Create a new [SimpleDirectoryService](SimpleDirectoryService)
292 pub fn new() -> Self {
293 Self {
294 _pd: std::marker::PhantomData::<S>,
295 }
296 }
297}
298
299impl<S: ToSocketAddrs<Iter = std::vec::IntoIter<SocketAddr>> + Send + Sync, E: HandlerError>
300 DirectoryService<S, E> for SimpleDirectoryService<S>
301{
302 /// Translates a `ToSocketAddrs` to the first `SocketAddr` it yields
303 fn translate(&self, name: &S) -> Result<SocketAddr, E> {
304 let mut sockets = name.to_socket_addrs().map_err(|_| {
305 Error::directory_service_error("Could not get socket address from directory service")
306 })?;
307 let socket = sockets.next().ok_or_else(|| {
308 Error::directory_service_error("Could not get socket address from directory service")
309 })?;
310 Ok(socket)
311 }
312}
313
314type NetPack<T, E> = (
315 NetworkMessage<T>,
316 Option<Action<T, E>>,
317 Option<Duration>,
318 OsSender<Result<(), E>>,
319);
320type ConnectionPackage<T, E> = (
321 NetworkMessage<T>,
322 Option<Duration>,
323 bool,
324 OsSender<Result<Option<NetworkMessage<T>>, E>>,
325);
326
327/// This struct is the main API for using netty. It allows the creation of a server and the ability
328/// to send messages to clients.
329#[derive(Debug, Clone)]
330pub struct Networker<T: NetworkContent + 'static, E: HandlerError> {
331 tx: Sender<NetPack<T, E>>,
332 // NOTE: command_tx currently sends a bool that represents if the networker should act as a
333 // server or not.
334 command_tx: Sender<(bool, OsSender<Result<(), E>>)>,
335}
336
337/// This struct represents the network messages sent and received, it can be created either from
338/// the new `new` constructor if a fresh message is desired.
339/// If a reply is desired then the `reply` method should be used.
340#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Serialize, Deserialize)]
341pub struct NetworkMessage<T: NetworkContent> {
342 /// The receiver id, this is currently a string but this may change to a generic representing
343 /// an id in the future
344 pub to: Arc<String>,
345 /// The sender id, this is currently a string but this may change to a generic representing
346 /// an id in the future
347 pub from: Arc<String>,
348 /// The message id, used to reply to this message
349 pub id: Arc<String>,
350 /// If this is set to `Some` then this message is a reply to the contained ID, else it is a
351 /// fresh message
352 pub reply: Option<Arc<String>>,
353 /// The content of the message
354 #[serde(bound(deserialize = "T: DeserializeOwned"))]
355 pub content: T,
356}
357
358fn new_id() -> String {
359 let rand_string: String = thread_rng()
360 .sample_iter(&Alphanumeric)
361 .take(30)
362 .map(char::from)
363 .collect();
364 rand_string
365}
366
367impl<T: NetworkContent> NetworkMessage<T> {
368 /// Create a new fresh message
369 pub fn new(to: Arc<String>, from: Arc<String>, content: T) -> Self {
370 Self {
371 to,
372 from,
373 id: Arc::new(new_id()),
374 reply: None,
375 content,
376 }
377 }
378
379 /// Used to construct replies to the provided message.
380 pub fn reply(&self, content: T) -> Self {
381 Self {
382 to: self.from.clone(),
383 from: self.to.clone(),
384 id: Arc::new(new_id()),
385 reply: Some(self.id.clone()),
386 content,
387 }
388 }
389}
390
391async fn create_socket_task<T: NetworkContent, M, F, E: HandlerError>(
392 mut socket: TcpStream,
393 handle_message: M,
394) -> Result<Sender<NetPack<T, E>>, E>
395where
396 F: Future<Output = HandlerResult<(), E>> + Send,
397 M: FnMut(NetworkMessage<T>, Connection<T, E>) -> F + Send + Sync + Clone + 'static,
398{
399 let (tx, mut rx): (Sender<NetPack<T, E>>, Receiver<NetPack<T, E>>) = channel(CHANNEL_SIZE);
400 tokio::spawn(async move {
401 debug!("Starting up a new socket task");
402 let (tx, mut reaction_rx): (
403 Sender<ConnectionPackage<T, E>>,
404 Receiver<ConnectionPackage<T, E>>,
405 ) = channel(CHANNEL_SIZE);
406 let mut awaiting_reply: HashMap<Arc<String>, OsSender<NetworkMessage<T>>> = HashMap::new();
407 loop {
408 //let mut buf = [0; MAX_SOCKET_BUF_SIZE];
409 // This task is responsible for:
410 // 1. Listening to the channel from the main manager thread for messages to send on
411 // this socket, sending the message out on the socket and recording which (if
412 // any) reaction was interested in responses to that message as well as
413 // recording when a timeout happens and then returning an error on the channel
414 // 2. Listening to the reaction_rx channel for NetworkMessages to send and
415 // recording which message (if any) message the reaction is interested in
416 // listening to.
417 // 3. Listening to the socket and parsing the data to NetworkMessages, figuring out
418 // if the parsed message is a response to any channels and in that case routing
419 // it to that channel, else calling a new handle_message instance with the
420 // message
421 tokio::select! {
422 Some((msg, timeout, want_reply, os_tx)) = reaction_rx.recv() => {
423 // 2) Send the message and then create a timeout that sends a timeout error
424 // on expiry. Also create a one-shot channel chain where we record which
425 // message we are waiting for and give it the transmission end of the
426 // chain
427 debug!("Socket task - Received a request to send a message on reaction thread");
428 let msg_s = match serde_json::to_string(&msg) {
429 Ok(r) => r,
430 Err(_) => {
431 let e = Error::serialization_error(format!("Could not serialize: {:?}", msg));
432 if let Err(e) = os_tx.send(Err(e)) {
433 debug!("Oneshot return channel did not stay open: {:?}", e);
434 }
435 continue;
436 },
437 };
438 match socket.write_all(msg_s.as_bytes()).await {
439 Ok(()) => (),
440 Err(_) => {
441 let e = Error::network_error(format!("Could not write on socket"));
442 if let Err(e) = os_tx.send(Err(e)) {
443 debug!("Oneshot return channel did not stay open: {:?}", e);
444 }
445 continue;
446 },
447 }
448 if want_reply {
449 let timeout_time = match timeout {
450 Some(t) => t,
451 // If no timeout is provided then we use a default
452 None => Duration::from_millis(DEFAULT_MESSAGE_TIMEOUT_MILLIS),
453 };
454 let (hm_tx, hm_rx) = os_channel();
455 awaiting_reply.insert(msg.id, hm_tx);
456 tokio::spawn(async move {
457 let timeout = tokio::time::sleep(timeout_time);
458 tokio::pin!(timeout);
459 tokio::select! {
460 Ok(msg) = hm_rx => {
461 if let Err(e) = os_tx.send(Ok(Some(msg))) {
462 debug!("Oneshot return channel did not stay open: {:?}", e);
463 }
464 },
465 _ = timeout => {
466 if let Err(e) = os_tx.send(Err(Error::timeout("Did not recieve a response in time!"))) {
467 debug!("Oneshot return channel did not stay open: {:?}", e);
468 }
469 }
470 }
471 });
472 } else {
473 if let Err(e) = os_tx.send(Ok(None)) {
474 debug!("Oneshot return channel did not stay open: {:?}", e);
475 }
476 }
477 },
478 Result::<NetworkMessage<T>, E>::Ok(msg) = read_message(&mut socket) => {
479 // 3) Listens to the socket and parses the data to NetworkMessages then
480 // looks in the awaiting_reply hashmap to see if there is any connection
481 // that are waiting for reply to a message that the parsed message is
482 // replying to, in that case send the message on that channel if it
483 // remains open, if the channel is closed discard the message.
484 // If there is no connection that awaits reply from this message then
485 // start a new connection for this message
486 match &msg.reply {
487 Some(id) => {
488 match awaiting_reply.remove(id) {
489 Some(tx) => {
490 debug!("Sending message to waiter {:?}", msg);
491 if let Err(e) = tx.send(msg) {
492 //Discard the channel and message
493 debug!("Discarded the message due to: {:?}", e);
494 }
495 }
496 None => {
497 //Discard the channel and message
498 debug!("Could not find channel to pass message on");
499 }
500 };
501 },
502 None => {
503 debug!("Did not find any waiters for {:?}", msg);
504 let con = Connection {
505 sender: tx.clone(),
506 };
507 let mut handle_message = handle_message.clone();
508 tokio::spawn(async move {
509 if let Err(e) = handle_message(msg, con).await {
510 debug!("Handle message returned an error {:?}", e);
511 }
512 });
513 },
514 };
515 },
516 Some((msg, react, timeout, return_tx)) = rx.recv() => {
517 debug!("On socket task - Received a message to send {:?}", msg);
518 // 1. Listens to the main-thread communication channel and sends the
519 // message on the internal reaction channel to schedule it. If a
520 // reaction is passed then listens to the oneshot channel for a reply
521 // and provides this to the reaction call
522 let (os_tx, os_rx) = os_channel();
523 let want_reply = react.is_some();
524 tx.send((msg, timeout, want_reply, os_tx)).await.expect("Networker internal error due to reaction channel being closed");
525 let tx = tx.clone();
526 tokio::spawn(async move {
527 let r = os_rx.await.expect("Networker internal error, awaiting os_rx channel but transmitter closed");
528 match r {
529 Ok(r) => {
530 if want_reply {
531 let react = react.expect("Networker unreachable state");
532 let msg = r.expect("Network unreachable state - received no message while expecting reply");
533 let con = Connection {
534 sender: tx.clone(),
535 };
536 let result = react.0(msg, con).await;
537 return_tx.send(result).expect("Networker internal error - networker did not listen to return channel");
538 } else {
539 return_tx.send(Ok(())).expect("Networker internal error - networker did not listen to return channel");
540 }
541 },
542 Err(e) => {
543 return_tx.send(Err(e)).expect("Networker internal error - networker did not listen to return channel");
544 }
545 };
546 });
547 }
548 }
549 }
550 });
551 Ok(tx)
552}
553
554async fn read_message<T: NetworkContent, E: HandlerError>(
555 socket: &mut TcpStream,
556) -> Result<NetworkMessage<T>, E> {
557 let mut buf = [0; MAX_SOCKET_BUF_SIZE];
558 match socket.read(&mut buf).await {
559 Ok(n) => match String::from_utf8(buf[..n].to_vec()) {
560 Ok(s) => match serde_json::from_str(&s) {
561 Ok(s) => return Ok(s),
562 Err(_) => {
563 return Err(Error::serialization_error(
564 "Could not deserialize recieved message",
565 ));
566 }
567 },
568 Err(e) => {
569 return Err(Error::serialization_error(format!(
570 "Could not convert to utf-8 - {:?}",
571 e
572 )));
573 }
574 },
575
576 Err(e) => {
577 return Err(Error::network_error(format!(
578 "Could not read from socket - {:?}",
579 e
580 )));
581 }
582 }
583}
584
585async fn process_socket<T: NetworkContent, H, M, FH, FM, E: HandlerError>(
586 mut socket: TcpStream,
587 mut handle_handshake: H,
588 handle_message: M,
589) -> Result<(String, Sender<NetPack<T, E>>), E>
590where
591 FH: Future<Output = HandlerResult<String, E>> + Send,
592 FM: Future<Output = HandlerResult<(), E>> + Send,
593 H: FnMut(NetworkMessage<T>, Connection<T, E>) -> FH + Send + Sync + 'static,
594 M: FnMut(NetworkMessage<T>, Connection<T, E>) -> FM + Send + Sync + Clone + 'static,
595{
596 let msg = read_message(&mut socket).await?;
597 let (handshake_tx, mut handshake_rx) = channel(CHANNEL_SIZE);
598 let con = Connection {
599 sender: handshake_tx.clone(),
600 };
601 let listen_for_messages = async {
602 //This loop never returns Ok, instead it sends successes on the channel to the
603 //Connection thread in order to drive that thread forward
604 loop {
605 if let Some((msg, timeout, want_reply, os_tx)) = handshake_rx.recv().await {
606 let msg = match serde_json::to_string(&msg) {
607 Ok(s) => s,
608 Err(_) => {
609 let e =
610 Error::serialization_error(format!("Could not serialize {:?}", msg));
611 if let Err(e) = os_tx.send(Err(e)) {
612 debug!("Oneshot return channel did not stay open: {:?}", e);
613 }
614 continue;
615 }
616 };
617 match socket.write_all(msg.as_bytes()).await {
618 Ok(()) => (),
619 Err(_) => {
620 let e = Error::network_error("Could not send over network socket");
621 if let Err(e) = os_tx.send(Err(e)) {
622 debug!("Oneshot return channel did not stay open: {:?}", e);
623 }
624 continue;
625 }
626 };
627 if !want_reply {
628 if let Err(e) = os_tx.send(Ok(None)) {
629 debug!("Oneshot return channel did not stay open: {:?}", e);
630 }
631 continue;
632 } else {
633 let timeout = sleep(
634 timeout.unwrap_or(Duration::from_millis(DEFAULT_MESSAGE_TIMEOUT_MILLIS)),
635 );
636 tokio::pin!(timeout);
637 tokio::select! {
638 s = read_message(&mut socket) => {
639 match s {
640 Ok(s) => {
641 if let Err(e) = os_tx.send(Ok(Some(s))) {
642 debug!("Oneshot return channel did not stay open: {:?}", e);
643 }
644 },
645 Err(e) => {
646 if let Err(e) = os_tx.send(Err(e)) {
647 debug!("Oneshot return channel did not stay open: {:?}", e);
648 }
649 }
650 }
651 },
652 _ = timeout => {
653 if let Err(e) = os_tx.send(Err(Error::timeout("Did not recieve a response in time"))) {
654 debug!("Oneshot return channel did not stay open: {:?}", e);
655 }
656 }
657 }
658 }
659 }
660 }
661 };
662 let r = tokio::select! {
663 Result::<T, E>::Err(e) = listen_for_messages => return Err(e),
664 r = handle_handshake(msg, con) => {
665 match r {
666 Ok(r) => r,
667 Err(e) => {
668 return Err(e);
669 }
670 }
671 }
672 };
673 let tx = create_socket_task(socket, handle_message).await?;
674 Ok((r, tx))
675}
676
677type HandlerResult<T, E> = Result<T, E>;
678impl<T: NetworkContent, E: HandlerError> Networker<T, E> {
679 /// Creates a new networker using. `address` is the network socket address that the server should
680 /// listen to. `directory_service` is the directory service to use in order to translate IDs to
681 /// addresses. `handle_handshakes` is a closure that is called when a new connection is
682 /// recieved which sends a `NetworkMessage`. This closure should authenticate if appropriate
683 /// and do other handshake and setup related things. `handle_messages` is a closure that is
684 /// called for all messages received from a connection that has already been handshaked.
685 pub async fn new<H, M, FH, FM>(
686 address: SocketAddr,
687 directory_service: impl DirectoryService<String, E> + 'static,
688 handle_handshakes: H,
689 handle_messages: M,
690 ) -> Result<Networker<T, E>, E>
691 where
692 FM: Future<Output = HandlerResult<(), E>> + Send,
693 FH: Future<Output = HandlerResult<String, E>> + Send,
694 M: FnMut(NetworkMessage<T>, Connection<T, E>) -> FM + Send + Sync + Clone + 'static,
695 H: FnMut(NetworkMessage<T>, Connection<T, E>) -> FH + Send + Sync + Clone + 'static,
696 {
697 let (net_tx, mut thread_rx): (Sender<NetPack<T, E>>, Receiver<NetPack<T, E>>) =
698 channel(CHANNEL_SIZE);
699 let (command_tx, mut command_rx): (
700 Sender<(bool, OsSender<Result<(), E>>)>,
701 Receiver<(bool, OsSender<Result<(), E>>)>,
702 ) = channel(CHANNEL_SIZE);
703 let mut name_channel_hm = HashMap::new();
704 let mut listener: Option<TcpListener> = None;
705 tokio::spawn(async move {
706 loop {
707 tokio::select! {
708 Some((should_be_server, os_tx)) = command_rx.recv() => {
709 if should_be_server {
710 listener = match TcpListener::bind(address).await {
711 Ok(r) => Some(r),
712 Err(e) => {
713 match os_tx.send(Err(Error::network_error(format!(
714 "Could not listen to address: {} due to: {:?}",
715 address, e
716 )))) {
717 Ok(()) => (),
718 Err(_) => {
719 debug!("Internal networker error - could not return result from turning on/off server");
720 },
721 };
722 continue;
723 }
724 };
725 } else {
726 listener = None;
727 }
728 match os_tx.send(Ok(())) {
729 Ok(()) => {
730 },
731 Err(_) => {
732 debug!("Internal networker error - could not send on return channel from request to start listening");
733 },
734 };
735 }
736 // NOTE: We only listen for new connections in the case where the networker
737 // should act like a server
738 Ok((socket, _)) = async {
739 if let Some(listener) = &listener {
740 listener.accept().await
741 } else {
742 // If the networker is not a server then we will wait forever, if this
743 // changes then this future will be cancelled
744 let forever = futures::future::pending();
745 let () = forever.await;
746 unreachable!("Networker unreachable state - tried to listen to a non-existent server");
747 }
748 } => {
749 debug!("Received a TCP connection");
750 let (name, tx) = match process_socket(socket, handle_handshakes.clone(), handle_messages.clone()).await {
751 Ok(r) => r,
752 Err(e) => {
753 debug!("Could not establish contact {:?}", e);
754 continue;
755 },
756 };
757 debug!("Handshake finished peer name is: {}", name);
758 name_channel_hm.insert(name, tx);
759 },
760 Some((message, react, timeout, os_tx)) = thread_rx.recv() => {
761 // Here we need to find the correct socket to send to. And we should
762 // maybe allow a "ALL" option to send towards
763 match name_channel_hm.get(&*message.to) {
764 // There is already a recipiant connected with that name
765 Some(tx) => {
766 if let Err(e) = tx.send((message, react, timeout, os_tx)).await {
767 debug!("{}", e);
768 }
769 }
770 // No connection to provided recipiant was found
771 None => {
772 // TODO: We should allow consumers to provide a do_handshake
773 // function which is run automatically be ran on sending a new
774 // message to an uninitiated peer. This function should return
775 // a Result containing the name of the peer. If no such
776 // function is provided then we skip that part
777 match directory_service.translate(&message.to) {
778 Ok(address) => {
779 let socket = match TcpStream::connect(address).await
780 .map_err(|_| Error::network_error(format!("Could not connect to address {:?}", address))) {
781 Ok(s) => s,
782 Err(e) => {
783 if let Err(_) = os_tx.send(Err(e)) {
784 debug!("Could not return error send on one-shot channel");
785 }
786 continue;
787 }
788 };
789 // Spawn a new thread with a new tcp connection and
790 // send the message
791 let tx = match create_socket_task(socket, handle_messages.clone()).await {
792 Ok(tx) => tx,
793 Err(e) => {
794 if let Err(_) = os_tx.send(Err(e)) {
795 debug!("Could not return error send on one-shot channel");
796 }
797 continue;
798 }
799 };
800 let name = message.to.clone();
801 if let Err(e) = tx.send((message, react, timeout, os_tx)).await {
802 debug!("{}", e);
803 }
804 name_channel_hm.insert((*name).clone(), tx);
805 },
806 Err(e) => {
807 if let Err(_) = os_tx.send(Err(e)) {
808 debug!("Could not return error send on one-shot channel");
809 }
810 continue;
811 }
812 }
813 }
814 }
815 },
816 }
817 }
818 });
819
820 Ok(Networker {
821 tx: net_tx,
822 command_tx,
823 })
824 }
825
826 /// Sends a message and then reacts to the response with the action and then returns the
827 /// last message returned
828 pub async fn send_message(
829 &self,
830 message: NetworkMessage<T>,
831 timeout: Option<Duration>,
832 react: Option<Action<T, E>>,
833 ) -> Result<(), E> {
834 let (os_tx, os_rx) = os_channel();
835 if let Err(_) = self.tx.send((message, react, timeout, os_tx)).await {
836 debug!("Could not send to networker");
837 }
838 match os_rx.await.expect("Oneshot transmitter dropped in socket") {
839 Ok(()) => Ok(()),
840 Err(e) => Err(e),
841 }
842 }
843
844 /// Starts the server
845 pub async fn listen(&self, should_listen: bool) -> Result<(), E> {
846 let (tx, rx) = os_channel();
847 match self.command_tx.send((should_listen, tx)).await {
848 Ok(()) => (),
849 Err(_) => {
850 return Err(Error::custom(format!(
851 "Internal error - Could change listening status due to channel being down"
852 )))
853 }
854 };
855 match rx.await {
856 Ok(r) => r,
857 Err(_) => return Err(Error::custom(format!("Internal error - Could not get response from listening call due to return channel closing prematurely"))),
858 }
859 }
860}
861
862/// Action contains a closure that handles the communication on a channel
863pub struct Action<
864 T: Send + Sync + Serialize + DeserializeOwned + Eq + PartialEq + Debug + 'static,
865 E: HandlerError,
866>(
867 Box<
868 dyn FnOnce(NetworkMessage<T>, Connection<T, E>) -> BoxFuture<'static, HandlerResult<(), E>>
869 + Send
870 + Sync,
871 >,
872);
873
874impl<T: NetworkContent, E: HandlerError> Action<T, E> {
875 /// Creates a new Action
876 pub fn new<
877 F: FnOnce(NetworkMessage<T>, Connection<T, E>) -> BoxFuture<'static, HandlerResult<(), E>>
878 + 'static
879 + Send
880 + Sync,
881 >(
882 f: F,
883 ) -> Self {
884 Self(Box::new(f))
885 }
886}
887
888/// A connection serves as the main API to specify how a network conversation should look. It
889/// allows sending messages to the recipiant and awaiting their response using the two methods
890/// `send_message` and `send_message_await_reply`.
891pub struct Connection<T: NetworkContent, E: HandlerError> {
892 sender: Sender<ConnectionPackage<T, E>>,
893}
894
895impl<T: NetworkContent, E: HandlerError> Connection<T, E> {
896 /// Send a message over the connection without waiting for a reply, returning upon successfully
897 /// sending the message out.
898 pub async fn send_message(&mut self, msg: NetworkMessage<T>) -> Result<(), E> {
899 let (tx, rx) = os_channel();
900 match self.sender.send((msg, None, false, tx)).await {
901 Ok(()) => (),
902 Err(_) => {
903 panic!("Internal error - could not send on internal channel",);
904 }
905 };
906 let r = match rx.await {
907 Ok(r) => r,
908 Err(_) => {
909 panic!("Internal error - internal return channel was closed before receiving a message");
910 }
911 };
912 r.map(|r| {
913 if r.is_some() {
914 panic!("Unreachable state - expected None but was provided a network message")
915 }
916 })
917 }
918
919 /// Send a message over the connection while waiting for a reply, a `Result` is returned with
920 /// the replying `NetworkMessage`.
921 pub async fn send_message_await_reply(
922 &mut self,
923 msg: NetworkMessage<T>,
924 timeout: Option<Duration>,
925 ) -> Result<NetworkMessage<T>, E> {
926 let (tx, rx) = os_channel();
927 match self.sender.send((msg, timeout, true, tx)).await {
928 Ok(()) => (),
929 Err(_) => {
930 return Err(Error::custom(
931 "Internal error - could not send on internal channel",
932 ));
933 }
934 };
935 let r = match rx.await {
936 Ok(r) => r,
937 Err(_) => {
938 return Err(Error::custom("Internal error - internal return channel was closed before receiving a message"));
939 }
940 };
941 r.map(|r| r.expect("Expecting a network message as response but None was provided"))
942 }
943}