1use std::{collections::HashMap, error::Error, fmt::Debug, pin::Pin, task::Poll};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6 codec::DecodeError,
7 rpc::{ExtractUserData, MessageMethodName, UserData},
8 Message, Notify, RecvMsg, Request, TypedCallError,
9};
10
/// Builder that collects message handlers together with their routing
/// patterns, then produces a [`Service`] via [`ServiceBuilder::build`].
pub struct ServiceBuilder<T = ExactMatchRouter>(Service<T>);
12
/// Dispatches inbound messages to registered handlers.
///
/// `router` resolves a message's method name to an index into `methods`;
/// see [`Service::route_message`].
pub struct Service<T = ExactMatchRouter> {
    /// Resolves method names to handler indices.
    router: T,
    /// Handlers in registration order; the indices handed to `router` refer
    /// to positions in this list.
    methods: Vec<InboundHandler>,
}

// Test-build-only compile-time check that the default service type
// implements `Send` and `Sync`.
#[cfg(test)]
static_assertions::assert_impl_all!(Service<ExactMatchRouter>: Send, Sync);
20
21enum InboundHandler {
22 Request(Box<dyn Fn(Request) -> Result<(), RouteMessageError> + Send + Sync + 'static>),
23 Notify(Box<dyn Fn(Notify) -> Result<(), RouteMessageError> + Send + Sync + 'static>),
24}
25
26impl<T: Debug> Debug for Service<T> {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("Service")
29 .field("router", &self.router)
30 .field("methods", &self.methods.len())
31 .finish()
32 }
33}
34
/// Error returned when a handler's routing patterns cannot be registered.
#[derive(thiserror::Error, Debug)]
pub enum RegisterError {
    /// A pattern is already mapped to a handler.
    #[error("Given key is already registered")]
    AlreadyRegistered,
    /// The router rejected a pattern. The router's own error is kept as the
    /// error source (`#[from]` also marks the field as `#[source]`).
    #[error("Invalid routing key: {0}")]
    InvalidRoutingKey(#[from] Box<dyn Error + Send + Sync + 'static>),
}
42
/// Maps method names (routing keys) to handler indices for a [`Service`].
pub trait Router: Send + Sync + 'static {
    /// Associates every pattern in `patterns` with the handler slot `index`.
    ///
    /// `ServiceBuilder` stores the handler at `index` *before* calling this,
    /// and keeps it there even if this returns an error.
    ///
    /// NOTE(review): an implementation that applies some patterns and then
    /// fails leaves those patterns routable to a registration the caller saw
    /// fail. Prefer validating everything before mutating.
    fn register(&mut self, patterns: &[&str], index: usize) -> Result<(), RegisterError>;

    /// Hook invoked by `ServiceBuilder::build` after all registrations.
    /// The default implementation does nothing.
    fn finish(&mut self) {}

    /// Returns the handler index for `routing_key`, or `None` when nothing
    /// matches. `Service::route_message` reports `None` as `MethodNotFound`.
    fn route(&self, routing_key: &str) -> Option<usize>;
}
54
55impl<T: Default + Router> Default for ServiceBuilder<T> {
56 fn default() -> Self {
57 Self::new(T::default())
58 }
59}
60
61impl<T> ServiceBuilder<T>
62where
63 T: Router,
64{
65 pub fn new(router: T) -> Self {
66 Self(Service { router, methods: Vec::new() })
67 }
68
69 pub fn register_request_handler<Req, Rep, Err>(
70 &mut self,
71 patterns: &[&str],
72 func: impl Fn(TypedRequest<Rep, Err>, Req) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
73 + 'static
74 + Send
75 + Sync,
76 ) -> Result<(), RegisterError>
77 where
78 Req: for<'de> Deserialize<'de>,
79 Rep: Serialize,
80 Err: Serialize,
81 {
82 let index = self.0.methods.len();
83 self.0.methods.push(InboundHandler::Request(Box::new(move |request| {
84 let request = TypedRequest::<Rep, Err>::new(request);
85 let param = match request.parse::<Req>() {
86 Ok(x) => x,
87 Err(e) => {
88 request.into_request().error_parse_failed_deferred::<Req>().ok();
89 return Err(RouteMessageError::ParseError(e));
90 }
91 };
92
93 func(request, param)?;
94 Ok(())
95 })));
96 self.0.router.register(patterns, index)
97 }
98
99 pub fn register_notify_handler<Noti>(
100 &mut self,
101 patterns: &[&str],
102 func: impl Fn(crate::Notify, Noti) -> Result<(), Box<dyn Error + Send + Sync + 'static>>
103 + 'static
104 + Send
105 + Sync,
106 ) -> Result<(), RegisterError>
107 where
108 Noti: for<'de> Deserialize<'de>,
109 {
110 let index = self.0.methods.len();
111 self.0.methods.push(InboundHandler::Notify(Box::new(move |msg| {
112 let param = match msg.parse::<Noti>() {
113 Ok(x) => x,
114 Err(e) => {
115 return Err(RouteMessageError::ParseError(e));
116 }
117 };
118
119 func(msg, param)?;
120 Ok(())
121 })));
122 self.0.router.register(patterns, index)
123 }
124
125 pub fn build(mut self) -> Service<T> {
126 self.0.router.finish();
127 self.0
128 }
129}
130
131impl<T> Service<T>
132where
133 T: Router,
134{
135 pub fn route_message(&self, msg: RecvMsg) -> Result<(), RouteMessageError> {
136 let method = std::str::from_utf8(msg.method_raw())?;
137 let index = self.router.route(method).ok_or(RouteMessageError::MethodNotFound)?;
138 match (self.methods.get(index).ok_or(RouteMessageError::MethodNotFound)?, msg) {
139 (InboundHandler::Request(func), RecvMsg::Request(req)) => func(req),
140 (InboundHandler::Notify(func), RecvMsg::Notify(noti)) => func(noti),
141 (_, RecvMsg::Notify(noti)) => Err(RouteMessageError::NotifyToRequestHandler(noti)),
142 (_, RecvMsg::Request(req)) => {
143 req.abort_deferred().ok();
144 Err(RouteMessageError::RequestToNotifyHandler)
145 }
146 }
147 }
148}
149
/// Error produced while routing an inbound message with
/// [`Service::route_message`], including errors surfaced by handlers.
#[derive(Debug, thiserror::Error)]
pub enum RouteMessageError {
    /// The message's method name bytes are not valid UTF-8.
    #[error("Non-utf method name: {0}")]
    NonUtfMethodName(#[from] std::str::Utf8Error),

    /// No handler is registered for the method name.
    #[error("Method couldn't be routed")]
    MethodNotFound,

    /// A notify message was routed to a request handler. The notify is
    /// returned here, unconsumed.
    #[error("Notify message to request handler")]
    NotifyToRequestHandler(Notify),

    /// A request message was routed to a notify handler. A deferred abort was
    /// attempted on the request before this error was returned.
    #[error("Request message to notify handler")]
    RequestToNotifyHandler,

    /// The message payload could not be deserialized into the handler's
    /// parameter type.
    #[error("Failed to parse incoming message: {0}")]
    ParseError(DecodeError),

    /// The user-supplied handler returned an error; it is kept as the source.
    #[error("Internal handler returned error")]
    HandlerError(#[from] Box<dyn Error + Send + Sync + 'static>),
}
170
// Hidden helper items. Presumably referenced by macro-generated code (going
// by the module name and `#[doc(hidden)]`); not intended for direct use.
#[doc(hidden)]
pub mod macro_utils {
    /// Result type of a handler registration.
    pub type RegisterResult = Result<(), super::RegisterError>;
}
175
/// A [`Request`] whose reply types are fixed at compile time: `T` for a
/// successful reply and `E` for an error reply.
#[derive(Debug)]
pub struct TypedRequest<T, E>(Request, std::marker::PhantomData<(T, E)>);
179
180impl<T, E> ExtractUserData for TypedRequest<T, E> {
181 fn user_data_raw(&self) -> &dyn UserData {
182 self.0.user_data_raw()
183 }
184
185 fn user_data_owned(&self) -> crate::rpc::OwnedUserData {
186 self.0.user_data_owned()
187 }
188
189 fn extract_sender(&self) -> crate::Sender {
190 self.0.extract_sender()
191 }
192}
193
194impl<T, E> TypedRequest<T, E>
195where
196 T: serde::Serialize,
197 E: serde::Serialize,
198{
199 pub fn new(req: Request) -> Self {
200 Self(req, Default::default())
201 }
202
203 pub fn into_request(self) -> Request {
204 self.0
205 }
206
207 pub fn response(self, res: Result<&T, &E>) -> Result<(), super::SendError> {
208 match res {
209 Ok(x) => self.0.response_deferred(Ok(x)),
210 Err(e) => self.0.response_deferred(Err(e)),
211 }
212 }
213
214 pub async fn response_async(self, res: Result<&T, &E>) -> Result<(), super::SendError> {
215 match res {
216 Ok(x) => self.0.response(Ok(x)).await,
217 Err(e) => self.0.response(Err(e)).await,
218 }
219 }
220 pub async fn ok_async(self, value: &T) -> Result<(), super::SendError> {
221 self.0.response(Ok(value)).await
222 }
223
224 pub async fn err_async(self, value: &E) -> Result<(), super::SendError> {
225 self.0.response(Err(value)).await
226 }
227
228 pub fn ok(self, value: &T) -> Result<(), super::SendError> {
229 self.0.response_deferred(Ok(value))
230 }
231
232 pub fn err(self, value: &E) -> Result<(), super::SendError> {
233 self.0.response_deferred(Err(value))
234 }
235}
236
237impl<T, E> std::ops::Deref for TypedRequest<T, E> {
238 type Target = Request;
239
240 fn deref(&self) -> &Self::Target {
241 &self.0
242 }
243}
244
/// Future over a pending reply that resolves to `Result<T, TypedCallError<E>>`.
/// Once received, the reply is decoded with `result::<T, E>()`.
#[derive(Debug)]
pub struct TypedResponse<T, E>(crate::OwnedResponseFuture, std::marker::PhantomData<(T, E)>);
248
249impl<T, E> std::future::Future for TypedResponse<T, E>
250where
251 T: serde::de::DeserializeOwned + Unpin,
252 E: serde::de::DeserializeOwned + Unpin,
253{
254 type Output = Result<T, TypedCallError<E>>;
255
256 fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
257 let this = self.get_mut();
258 let Poll::Ready(msg) = Pin::new(&mut this.0).poll(cx)? else { return Poll::Pending };
259 Poll::Ready(Ok(msg.result::<T, E>()?))
260 }
261}
262
263impl<T, E> TypedResponse<T, E>
264where
265 T: serde::de::DeserializeOwned + Unpin,
266 E: serde::de::DeserializeOwned + Unpin,
267{
268 pub fn new(fut: crate::OwnedResponseFuture) -> Self {
269 Self(fut, Default::default())
270 }
271
272 pub fn try_recv(&mut self) -> Result<Option<T>, TypedCallError<E>> {
273 match self.0.try_recv()? {
274 None => Ok(None),
275 Some(msg) => Ok(Some(msg.result::<T, E>()?)),
276 }
277 }
278}
279
/// [`Router`] that matches method names by exact string equality.
#[derive(Debug, Default, Clone)]
pub struct ExactMatchRouter {
    /// Registered pattern -> handler index.
    routes: HashMap<String, usize>,
}
285
286impl Router for ExactMatchRouter {
287 fn register(&mut self, pattern: &[&str], index: usize) -> Result<(), RegisterError> {
288 for pat in pattern.into_iter().copied() {
289 if self.routes.contains_key(pat) {
290 return Err(RegisterError::AlreadyRegistered);
291 }
292
293 self.routes.insert(pat.to_owned(), index);
294 }
295
296 Ok(())
297 }
298
299 fn route(&self, routing_key: &str) -> Option<usize> {
300 self.routes.get(routing_key).copied()
301 }
302}