// burn_communication/base.rs

1use burn_common::future::DynFut;
2use serde::{Deserialize, Serialize};
3use std::fmt::{Debug, Display};
4use std::hash::Hash;
5use std::str::FromStr;
6
7/// Allows nodes to find each other
8#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Debug)]
9pub struct Address {
10    pub(crate) inner: String,
11}
12
13impl FromStr for Address {
14    type Err = String;
15
16    fn from_str(s: &str) -> Result<Self, Self::Err> {
17        Ok(Self {
18            inner: s.to_string(),
19        })
20    }
21}
22
23impl Display for Address {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        write!(f, "{}", self.inner)
26    }
27}
28
29/// The protocol used for the communications.
30pub trait Protocol: Clone + Send + Sync + 'static {
31    /// The client implementation for the current protocol.
32    type Client: ProtocolClient;
33    /// The server implementation for the current protocol.
34    type Server: ProtocolServer;
35}
36
/// Marker trait for errors raised while communicating over a protocol.
///
/// Implementors must be [`Debug`], [`Send`] and `'static`, which allows the error to be
/// formatted for diagnostics and moved across threads.
pub trait CommunicationError: 'static + Send + Debug {}

40/// The client is only used to create a [channel](CommunicationChannel), which should be use to
41/// transmit information with the [server](ProtocolServer).
42pub trait ProtocolClient: Send + Sync + 'static {
43    /// Channel used by this protocol.
44    type Channel: CommunicationChannel<Error = Self::Error>;
45    /// The error type.
46    type Error: CommunicationError;
47
48    /// Opens a new [channel](CommunicationChannel) with the current protocol at the given
49    /// [address](Address) and route.
50    ///
51    /// * `address` - Address to connect to
52    /// * `route` - The name of the route (no slashes)
53    ///
54    /// Returns None if the connection can't be done.
55    fn connect(address: Address, route: &str) -> DynFut<Option<Self::Channel>>;
56}
57
58/// Data sent and received by the client and server.
59#[derive(new)]
60pub struct Message {
61    /// The data is always encoded as bytes.
62    pub data: bytes::Bytes,
63}
64
65/// Defines how to create a server that respond to a [channel](CommunicationChannel).
66pub trait ProtocolServer: Sized + Send + Sync + 'static {
67    /// Channel used by this protocol.
68    type Channel: CommunicationChannel<Error = Self::Error>;
69    /// The error type.
70    type Error: CommunicationError;
71
72    /// Defines an endpoint with the function that responds.
73    /// TODO Docs: does it need a slash?
74    fn route<C, Fut>(self, path: &str, callback: C) -> Self
75    where
76        C: FnOnce(Self::Channel) -> Fut + Clone + Send + Sync + 'static,
77        Fut: Future<Output = ()> + Send + 'static;
78
79    /// Start the server.
80    fn serve<F>(
81        self,
82        shutdown: F,
83    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static
84    where
85        F: Future<Output = ()> + Send + 'static;
86}
87
88/// Handles communications.
89pub trait CommunicationChannel: Send + 'static {
90    type Error: CommunicationError;
91
92    /// Send a [message](Message) on the channel.
93    fn send(
94        &mut self,
95        message: Message,
96    ) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
97
98    /// Receive a [message](Message) on the channel and returns a new [response message](Message).
99    fn recv(
100        &mut self,
101    ) -> impl std::future::Future<Output = Result<Option<Message>, Self::Error>> + Send;
102
103    fn close(&mut self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
104}