datacake_rpc/server.rs
1use std::collections::{BTreeMap, BTreeSet};
2use std::io;
3use std::net::SocketAddr;
4use std::sync::Arc;
5
6use parking_lot::{Mutex, RwLock};
7use tokio::task::JoinHandle;
8
9use crate::handler::{HandlerKey, OpaqueMessageHandler, RpcService, ServiceRegistry};
10
11/// A RPC server instance.
12///
13/// This is used for listening for inbound connections and handling any RPC messages
14/// coming from clients.
15///
16/// ```rust
17/// use rkyv::{Archive, Deserialize, Serialize};
18/// use datacake_rpc::{Server, Handler, Request, RpcService, ServiceRegistry, Status};
19/// use std::net::SocketAddr;
20///
21/// #[repr(C)]
22/// #[derive(Serialize, Deserialize, Archive, PartialEq, Debug)]
23/// #[archive(compare(PartialEq), check_bytes)]
24/// #[archive_attr(derive(PartialEq, Debug))]
25/// pub struct MyMessage {
26/// name: String,
27/// age: u32,
28/// }
29///
30/// pub struct EchoService;
31///
32/// impl RpcService for EchoService {
33/// fn register_handlers(registry: &mut ServiceRegistry<Self>) {
34/// registry.add_handler::<MyMessage>();
35/// }
36/// }
37///
38/// #[datacake_rpc::async_trait]
39/// impl Handler<MyMessage> for EchoService {
40/// type Reply = MyMessage;
41///
42/// async fn on_message(&self, msg: Request<MyMessage>) -> Result<Self::Reply, Status> {
43/// Ok(msg.to_owned().unwrap())
44/// }
45/// }
46///
47/// # #[tokio::main]
48/// # async fn main() -> anyhow::Result<()> {
49/// let bind = "127.0.0.1:8000".parse::<SocketAddr>()?;
50/// // Start the RPC server listening on our bind address.
51/// let server = Server::listen(bind).await?;
52///
53/// // Once our server is running we can add or remove services.
54/// // Once a service is added it is able to begin handling RPC messages.
55/// server.add_service(EchoService);
56///
57/// // Once a service is removed the server will reject messages for the
58/// // service that is no longer registered,
59/// server.remove_service(EchoService::service_name());
60///
61/// // We can add wait() here if we want to listen for messages forever.
62/// // server.wait().await;
63/// # Ok(())
64/// # }
65/// ```
66pub struct Server {
67 state: ServerState,
68 handle: JoinHandle<()>,
69}
70
71impl Server {
72 /// Spawns the RPC server task and returns the server handle.
73 pub async fn listen(addr: SocketAddr) -> io::Result<Self> {
74 let state = ServerState::default();
75 let handle = crate::net::start_rpc_server(addr, state.clone()).await?;
76
77 Ok(Self { state, handle })
78 }
79
80 /// Adds a new service to the live RPC server.
81 pub fn add_service<Svc>(&self, service: Svc)
82 where
83 Svc: RpcService + Send + Sync + 'static,
84 {
85 let mut registry = ServiceRegistry::new(service);
86 Svc::register_handlers(&mut registry);
87 let handlers = registry.into_handlers();
88 self.state.add_handlers(Svc::service_name(), handlers);
89 }
90
91 /// Removes all handlers linked with the given service name.
92 pub fn remove_service(&self, service_name: &str) {
93 self.state.remove_handlers(service_name);
94 }
95
96 /// Signals the server to shutdown.
97 pub fn shutdown(self) {
98 self.handle.abort();
99 }
100
101 /// Waits until the server exits.
102 ///
103 /// This typically is just a future that pends forever as the server
104 /// will not exit unless an external force triggers it.
105 pub async fn wait(self) {
106 self.handle.await.expect("Wait for server handle.");
107 }
108}
109
110#[derive(Clone, Default)]
111/// Represents the shared state of the RPC server.
112pub(crate) struct ServerState {
113 services: Arc<Mutex<BTreeMap<String, BTreeSet<HandlerKey>>>>,
114 handlers: Arc<RwLock<BTreeMap<HandlerKey, Arc<dyn OpaqueMessageHandler>>>>,
115}
116
117impl ServerState {
118 /// Adds a new set of handlers to the server state.
119 ///
120 /// Handlers newly added will then be able to handle messages received by
121 /// the already running RPC system.
122 pub(crate) fn add_handlers(
123 &self,
124 service_name: &str,
125 handlers: BTreeMap<HandlerKey, Arc<dyn OpaqueMessageHandler>>,
126 ) {
127 {
128 let mut lock = self.services.lock();
129 for key in handlers.keys() {
130 lock.entry(service_name.to_string())
131 .or_default()
132 .insert(*key);
133 }
134 }
135
136 let mut lock = self.handlers.write();
137 lock.extend(handlers);
138 }
139
140 /// Removes a new set of handlers from the server state.
141 pub(crate) fn remove_handlers(&self, service: &str) {
142 let uris = {
143 match self.services.lock().remove(service) {
144 None => return,
145 Some(uris) => uris,
146 }
147 };
148
149 let mut lock = self.handlers.write();
150 lock.retain(|key, _| uris.contains(key));
151 }
152
153 /// Attempts to get the message handler for a specific service and message.
154 pub(crate) fn get_handler(
155 &self,
156 uri: &str,
157 ) -> Option<Arc<dyn OpaqueMessageHandler>> {
158 let lock = self.handlers.read();
159 lock.get(&crate::hash(uri)).cloned()
160 }
161}