datacake_rpc/
server.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::io;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use parking_lot::{Mutex, RwLock};
7use tokio::task::JoinHandle;
8
9use crate::handler::{HandlerKey, OpaqueMessageHandler, RpcService, ServiceRegistry};
10
11/// A RPC server instance.
12///
13/// This is used for listening for inbound connections and handling any RPC messages
14/// coming from clients.
15///
16/// ```rust
17/// use rkyv::{Archive, Deserialize, Serialize};
18/// use datacake_rpc::{Server, Handler, Request, RpcService, ServiceRegistry, Status};
19/// use std::net::SocketAddr;
20///
21/// #[repr(C)]
22/// #[derive(Serialize, Deserialize, Archive, PartialEq, Debug)]
23/// #[archive(compare(PartialEq), check_bytes)]
24/// #[archive_attr(derive(PartialEq, Debug))]
25/// pub struct MyMessage {
26///     name: String,
27///     age: u32,
28/// }
29///
30/// pub struct EchoService;
31///
32/// impl RpcService for EchoService {
33///     fn register_handlers(registry: &mut ServiceRegistry<Self>) {
34///         registry.add_handler::<MyMessage>();
35///     }
36/// }
37///
38/// #[datacake_rpc::async_trait]
39/// impl Handler<MyMessage> for EchoService {
40///     type Reply = MyMessage;
41///
42///     async fn on_message(&self, msg: Request<MyMessage>) -> Result<Self::Reply, Status> {
43///         Ok(msg.to_owned().unwrap())
44///     }
45/// }
46///
47/// # #[tokio::main]
48/// # async fn main() -> anyhow::Result<()> {
49/// let bind = "127.0.0.1:8000".parse::<SocketAddr>()?;
50/// // Start the RPC server listening on our bind address.
51/// let server = Server::listen(bind).await?;
52///
53/// // Once our server is running we can add or remove services.
54/// // Once a service is added it is able to begin handling RPC messages.
55/// server.add_service(EchoService);
56///
57/// // Once a service is removed the server will reject messages for the
58/// // service that is no longer registered,
59/// server.remove_service(EchoService::service_name());
60///
61/// // We can add wait() here if we want to listen for messages forever.
62/// // server.wait().await;
63/// # Ok(())
64/// # }
65/// ```
66pub struct Server {
67    state: ServerState,
68    handle: JoinHandle<()>,
69}
70
71impl Server {
72    /// Spawns the RPC server task and returns the server handle.
73    pub async fn listen(addr: SocketAddr) -> io::Result<Self> {
74        let state = ServerState::default();
75        let handle = crate::net::start_rpc_server(addr, state.clone()).await?;
76
77        Ok(Self { state, handle })
78    }
79
80    /// Adds a new service to the live RPC server.
81    pub fn add_service<Svc>(&self, service: Svc)
82    where
83        Svc: RpcService + Send + Sync + 'static,
84    {
85        let mut registry = ServiceRegistry::new(service);
86        Svc::register_handlers(&mut registry);
87        let handlers = registry.into_handlers();
88        self.state.add_handlers(Svc::service_name(), handlers);
89    }
90
91    /// Removes all handlers linked with the given service name.
92    pub fn remove_service(&self, service_name: &str) {
93        self.state.remove_handlers(service_name);
94    }
95
96    /// Signals the server to shutdown.
97    pub fn shutdown(self) {
98        self.handle.abort();
99    }
100
101    /// Waits until the server exits.
102    ///
103    /// This typically is just a future that pends forever as the server
104    /// will not exit unless an external force triggers it.
105    pub async fn wait(self) {
106        self.handle.await.expect("Wait for server handle.");
107    }
108}
109
110#[derive(Clone, Default)]
111/// Represents the shared state of the RPC server.
112pub(crate) struct ServerState {
113    services: Arc<Mutex<BTreeMap<String, BTreeSet<HandlerKey>>>>,
114    handlers: Arc<RwLock<BTreeMap<HandlerKey, Arc<dyn OpaqueMessageHandler>>>>,
115}
116
117impl ServerState {
118    /// Adds a new set of handlers to the server state.
119    ///
120    /// Handlers newly added will then be able to handle messages received by
121    /// the already running RPC system.
122    pub(crate) fn add_handlers(
123        &self,
124        service_name: &str,
125        handlers: BTreeMap<HandlerKey, Arc<dyn OpaqueMessageHandler>>,
126    ) {
127        {
128            let mut lock = self.services.lock();
129            for key in handlers.keys() {
130                lock.entry(service_name.to_string())
131                    .or_default()
132                    .insert(*key);
133            }
134        }
135
136        let mut lock = self.handlers.write();
137        lock.extend(handlers);
138    }
139
140    /// Removes a new set of handlers from the server state.
141    pub(crate) fn remove_handlers(&self, service: &str) {
142        let uris = {
143            match self.services.lock().remove(service) {
144                None => return,
145                Some(uris) => uris,
146            }
147        };
148
149        let mut lock = self.handlers.write();
150        lock.retain(|key, _| uris.contains(key));
151    }
152
153    /// Attempts to get the message handler for a specific service and message.
154    pub(crate) fn get_handler(
155        &self,
156        uri: &str,
157    ) -> Option<Arc<dyn OpaqueMessageHandler>> {
158        let lock = self.handlers.read();
159        lock.get(&crate::hash(uri)).cloned()
160    }
161}