rpc_it/
service.rs

1use std::{collections::HashMap, error::Error, fmt::Debug, pin::Pin, task::Poll};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    codec::DecodeError,
7    rpc::{ExtractUserData, MessageMethodName, UserData},
8    Message, Notify, RecvMsg, Request, TypedCallError,
9};
10
11pub struct ServiceBuilder<T = ExactMatchRouter>(Service<T>);
12
13pub struct Service<T = ExactMatchRouter> {
14    router: T,
15    methods: Vec<InboundHandler>,
16}
17
// Compile-time guard (test builds only): the default-routed service must be `Send + Sync`.
// The closure is never called; type-checking the generic call is what enforces the bounds.
#[cfg(test)]
const _: fn() = || {
    fn require_send_sync<S: Send + Sync>() {}
    require_send_sync::<Service<ExactMatchRouter>>();
};
20
21enum InboundHandler {
22    Request(Box<dyn Fn(Request) -> Result<(), RouteMessageError> + Send + Sync + 'static>),
23    Notify(Box<dyn Fn(Notify) -> Result<(), RouteMessageError> + Send + Sync + 'static>),
24}
25
26impl<T: Debug> Debug for Service<T> {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("Service")
29            .field("router", &self.router)
30            .field("methods", &self.methods.len())
31            .finish()
32    }
33}
34
35#[derive(thiserror::Error, Debug)]
36pub enum RegisterError {
37    #[error("Given key is already registered")]
38    AlreadyRegistered,
39    #[error("Invalid routing key: {0}")]
40    InvalidRoutingKey(#[from] Box<dyn Error + Send + Sync + 'static>),
41}
42
43pub trait Router: Send + Sync + 'static {
44    /// Register a routing key with the given index.
45    fn register(&mut self, patterns: &[&str], index: usize) -> Result<(), RegisterError>;
46
47    /// Finish the registration process. All errors must've been reported through `register`, thus
48    /// this method should never fail.
49    fn finish(&mut self) {}
50
51    /// Route the given routing key to an index.
52    fn route(&self, routing_key: &str) -> Option<usize>;
53}
54
55impl<T: Default + Router> Default for ServiceBuilder<T> {
56    fn default() -> Self {
57        Self::new(T::default())
58    }
59}
60
61impl<T> ServiceBuilder<T>
62where
63    T: Router,
64{
65    pub fn new(router: T) -> Self {
66        Self(Service { router, methods: Vec::new() })
67    }
68
69    pub fn register_request_handler<Req, Rep, Err>(
70        &mut self,
71        patterns: &[&str],
72        func: impl Fn(TypedRequest<Rep, Err>, Req) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
73            + 'static
74            + Send
75            + Sync,
76    ) -> Result<(), RegisterError>
77    where
78        Req: for<'de> Deserialize<'de>,
79        Rep: Serialize,
80        Err: Serialize,
81    {
82        let index = self.0.methods.len();
83        self.0.methods.push(InboundHandler::Request(Box::new(move |request| {
84            let request = TypedRequest::<Rep, Err>::new(request);
85            let param = match request.parse::<Req>() {
86                Ok(x) => x,
87                Err(e) => {
88                    request.into_request().error_parse_failed_deferred::<Req>().ok();
89                    return Err(RouteMessageError::ParseError(e));
90                }
91            };
92
93            func(request, param)?;
94            Ok(())
95        })));
96        self.0.router.register(patterns, index)
97    }
98
99    pub fn register_notify_handler<Noti>(
100        &mut self,
101        patterns: &[&str],
102        func: impl Fn(crate::Notify, Noti) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
103            + 'static
104            + Send
105            + Sync,
106    ) -> Result<(), RegisterError>
107    where
108        Noti: for<'de> Deserialize<'de>,
109    {
110        let index = self.0.methods.len();
111        self.0.methods.push(InboundHandler::Notify(Box::new(move |msg| {
112            let param = match msg.parse::<Noti>() {
113                Ok(x) => x,
114                Err(e) => {
115                    return Err(RouteMessageError::ParseError(e));
116                }
117            };
118
119            func(msg, param)?;
120            Ok(())
121        })));
122        self.0.router.register(patterns, index)
123    }
124
125    pub fn build(mut self) -> Service<T> {
126        self.0.router.finish();
127        self.0
128    }
129}
130
131impl<T> Service<T>
132where
133    T: Router,
134{
135    pub fn route_message(&self, msg: RecvMsg) -> Result<(), RouteMessageError> {
136        let method = std::str::from_utf8(msg.method_raw())?;
137        let index = self.router.route(method).ok_or(RouteMessageError::MethodNotFound)?;
138        match (self.methods.get(index).ok_or(RouteMessageError::MethodNotFound)?, msg) {
139            (InboundHandler::Request(func), RecvMsg::Request(req)) => func(req),
140            (InboundHandler::Notify(func), RecvMsg::Notify(noti)) => func(noti),
141            (_, RecvMsg::Notify(noti)) => Err(RouteMessageError::NotifyToRequestHandler(noti)),
142            (_, RecvMsg::Request(req)) => {
143                req.abort_deferred().ok();
144                Err(RouteMessageError::RequestToNotifyHandler)
145            }
146        }
147    }
148}
149
150#[derive(Debug, thiserror::Error)]
151pub enum RouteMessageError {
152    #[error("Non-utf method name: {0}")]
153    NonUtfMethodName(#[from] std::str::Utf8Error),
154
155    #[error("Method couldn't be routed")]
156    MethodNotFound,
157
158    #[error("Notify message to request handler")]
159    NotifyToRequestHandler(Notify),
160
161    #[error("Request message to notify handler")]
162    RequestToNotifyHandler,
163
164    #[error("Failed to parse incoming message: {0}")]
165    ParseError(DecodeError),
166
167    #[error("Internal handler returned error")]
168    HandlerError(#[from] Box<dyn Error + Send + Sync + 'static>),
169}
170
171#[doc(hidden)]
172pub mod macro_utils {
173    pub type RegisterResult = Result<(), super::RegisterError>;
174}
175
176/* ---------------------------------------- Typed Request --------------------------------------- */
177#[derive(Debug)]
178pub struct TypedRequest<T, E>(Request, std::marker::PhantomData<(T, E)>);
179
180impl<T, E> ExtractUserData for TypedRequest<T, E> {
181    fn user_data_raw(&self) -> &dyn UserData {
182        self.0.user_data_raw()
183    }
184
185    fn user_data_owned(&self) -> crate::rpc::OwnedUserData {
186        self.0.user_data_owned()
187    }
188
189    fn extract_sender(&self) -> crate::Sender {
190        self.0.extract_sender()
191    }
192}
193
194impl<T, E> TypedRequest<T, E>
195where
196    T: serde::Serialize,
197    E: serde::Serialize,
198{
199    pub fn new(req: Request) -> Self {
200        Self(req, Default::default())
201    }
202
203    pub fn into_request(self) -> Request {
204        self.0
205    }
206
207    pub fn response(self, res: Result<&T, &E>) -> Result<(), super::SendError> {
208        match res {
209            Ok(x) => self.0.response_deferred(Ok(x)),
210            Err(e) => self.0.response_deferred(Err(e)),
211        }
212    }
213
214    pub async fn response_async(self, res: Result<&T, &E>) -> Result<(), super::SendError> {
215        match res {
216            Ok(x) => self.0.response(Ok(x)).await,
217            Err(e) => self.0.response(Err(e)).await,
218        }
219    }
220    pub async fn ok_async(self, value: &T) -> Result<(), super::SendError> {
221        self.0.response(Ok(value)).await
222    }
223
224    pub async fn err_async(self, value: &E) -> Result<(), super::SendError> {
225        self.0.response(Err(value)).await
226    }
227
228    pub fn ok(self, value: &T) -> Result<(), super::SendError> {
229        self.0.response_deferred(Ok(value))
230    }
231
232    pub fn err(self, value: &E) -> Result<(), super::SendError> {
233        self.0.response_deferred(Err(value))
234    }
235}
236
237impl<T, E> std::ops::Deref for TypedRequest<T, E> {
238    type Target = Request;
239
240    fn deref(&self) -> &Self::Target {
241        &self.0
242    }
243}
244
245/* ------------------------------------ Typed Response Future ----------------------------------- */
246#[derive(Debug)]
247pub struct TypedResponse<T, E>(crate::OwnedResponseFuture, std::marker::PhantomData<(T, E)>);
248
249impl<T, E> std::future::Future for TypedResponse<T, E>
250where
251    T: serde::de::DeserializeOwned + Unpin,
252    E: serde::de::DeserializeOwned + Unpin,
253{
254    type Output = Result<T, TypedCallError<E>>;
255
256    fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
257        let this = self.get_mut();
258        let Poll::Ready(msg) = Pin::new(&mut this.0).poll(cx)? else { return Poll::Pending };
259        Poll::Ready(Ok(msg.result::<T, E>()?))
260    }
261}
262
263impl<T, E> TypedResponse<T, E>
264where
265    T: serde::de::DeserializeOwned + Unpin,
266    E: serde::de::DeserializeOwned + Unpin,
267{
268    pub fn new(fut: crate::OwnedResponseFuture) -> Self {
269        Self(fut, Default::default())
270    }
271
272    pub fn try_recv(&mut self) -> Result<Option<T>, TypedCallError<E>> {
273        match self.0.try_recv()? {
274            None => Ok(None),
275            Some(msg) => Ok(Some(msg.result::<T, E>()?)),
276        }
277    }
278}
279
280/* --------------------------------- Basic Router Implementation -------------------------------- */
281#[derive(Debug, Default, Clone)]
282pub struct ExactMatchRouter {
283    routes: HashMap<String, usize>,
284}
285
286impl Router for ExactMatchRouter {
287    fn register(&mut self, pattern: &[&str], index: usize) -> Result<(), RegisterError> {
288        for pat in pattern.into_iter().copied() {
289            if self.routes.contains_key(pat) {
290                return Err(RegisterError::AlreadyRegistered);
291            }
292
293            self.routes.insert(pat.to_owned(), index);
294        }
295
296        Ok(())
297    }
298
299    fn route(&self, routing_key: &str) -> Option<usize> {
300        self.routes.get(routing_key).copied()
301    }
302}