1use std::collections::HashMap;
2use std::future::Future;
3use std::sync::Arc;
4
5use crate::ToBusinessId;
6use crate::ToBytes;
7use crate::err::SendError;
8use crate::protocol::Multiplexer2ServerReceiver;
9use crate::protocol::Response;
10use crate::protocol::ResponsePacket;
11use crate::protocol::Server2MultiplexerSender;
12
13pub struct RpcServerBuilder {
14 recv_req: Option<Multiplexer2ServerReceiver>,
15 send_res: Option<Server2MultiplexerSender>,
16 handlers: HashMap<u64, Arc<dyn Handler + Send + Sync>>,
17}
18
19impl Default for RpcServerBuilder {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25impl RpcServerBuilder {
26 pub fn new() -> Self {
27 Self {
28 recv_req: None,
29 send_res: None,
30 handlers: HashMap::new(),
31 }
32 }
33
34 pub fn add_receiver(mut self, receiver: Multiplexer2ServerReceiver) -> Self {
35 self.recv_req = Some(receiver);
36 self
37 }
38
39 pub fn add_sender(mut self, sender: Server2MultiplexerSender) -> Self {
40 self.send_res = Some(sender);
41 self
42 }
43
44 pub fn add_handler<F, Args, Res>(mut self, business_id: &impl ToBusinessId, func: F) -> Self
45 where
46 F: IntoHandlerWrapper<Args, Res> + 'static,
47 HandlerWrapper<F, Args, Res>: Handler + 'static,
48 Res: ToBytes + 'static,
49 {
50 let business_id = business_id.to_business_id();
51 let wrapper = func.into_handler_wrapper();
52 self.handlers.insert(business_id, Arc::new(wrapper));
53 self
54 }
55
56 pub fn build(self) -> Result<RpcServer, String> {
57 let receiver = self.recv_req.ok_or("receiver is required")?;
58 let sender = self.send_res.ok_or("Sender is required")?;
59 if self.handlers.is_empty() {
60 return Err("at least one handler is required".to_string());
61 }
62 Ok(RpcServer::new(Arc::new(self.handlers), receiver, sender))
63 }
64}
65
66pub struct RpcServer {
67 handlers: Arc<HashMap<u64, Arc<dyn Handler + Send + Sync>>>,
68 receiver: Option<Multiplexer2ServerReceiver>,
69 sender: Option<Server2MultiplexerSender>,
70}
71
72impl RpcServer {
73 pub fn new(
74 handlers: Arc<HashMap<u64, Arc<dyn Handler + Send + Sync>>>,
75 receiver: Multiplexer2ServerReceiver,
76 sender: Server2MultiplexerSender,
77 ) -> Self {
78 Self {
79 handlers,
80 receiver: Some(receiver),
81 sender: Some(sender),
82 }
83 }
84}
85
86impl RpcServer {
87 pub async fn run(&mut self) -> Result<(), String> {
88 let receiver = self.receiver.take();
89 let handlers = self.handlers.clone();
90 let sender = self.sender.take().expect("sender is required");
91
92 tokio::spawn(async move {
93 let mut rx = receiver.expect("receiver is required");
94
95 while let Some(req) = rx.recv().await {
96 if let Some(handler) = handlers.get(&req.business_id).cloned() {
97 let sender = sender.clone();
98 tokio::spawn(async move {
99 let (tx, rc) = tokio::sync::mpsc::channel(1 << 10);
100 let mut ctx = Context {
101 sender: Some(tx),
102 request_id: req.request_id,
103 session_id: 0,
104 data: req.data.clone(),
105 };
106
107 let t = handler.handle(&mut ctx).await;
108
109 if ctx.sender.is_none() {
110 let mut rx = rc;
111
112 let mut last_data = None;
113
114 while let Some(data) = rx.recv().await {
115 if let Some(last_data) = last_data {
116 sender
117 .send(ResponsePacket {
118 response: Response::new(
119 req.request_id,
120 false,
121 last_data,
122 ),
123 })
124 .await
125 .expect("Multiplexer closed unexpectedly");
126 }
127 last_data = Some(data);
128 }
129
130 if let Some(last_data) = last_data {
131 sender
132 .send(ResponsePacket {
133 response: Response::new(req.request_id, true, last_data),
134 })
135 .await
136 .expect("Multiplexer closed unexpectedly");
137 }
138 } else {
139 match t {
140 Ok(data) => {
141 sender
142 .send(ResponsePacket {
143 response: Response::new(req.request_id, true, data),
144 })
145 .await
146 .expect("Multiplexer closed unexpectedly");
147 }
148 Err(t) => {
149 log::error!("handler error: {}", t);
150 panic!("Handler Error: {}", t);
151 }
152 }
153 }
154 });
155 }
156 }
157
158 log::info!("server thread exit");
159 })
160 .await
161 .map_err(|e| e.to_string())?;
162 Ok(())
163 }
164}
165
/// Adapter that lets a plain async function be stored and invoked as a [`Handler`].
///
/// `Req` records the function's argument signature (one extractor type, or a tuple of
/// them) and `Res` its return type. Neither is stored; they only select which
/// `Handler` impl applies.
pub struct HandlerWrapper<F, Req, Res> {
    func: F,
    _phantom: std::marker::PhantomData<(Req, Res)>,
}

impl<F, Req, Res> HandlerWrapper<F, Req, Res> {
    /// Wraps `func` as-is; no work happens until the handler is invoked.
    pub fn new(func: F) -> Self {
        let marker = std::marker::PhantomData;
        Self {
            _phantom: marker,
            func,
        }
    }
}
179
180pub struct Context {
181 sender: Option<tokio::sync::mpsc::Sender<Vec<u8>>>,
182 pub request_id: u64,
183 pub session_id: u64,
184 pub data: Vec<u8>,
185}
186
/// Error produced when a handler argument cannot be extracted from a [`Context`].
///
/// Derives `Debug` and implements [`std::error::Error`] so it can be boxed, used with
/// `?` into generic error types, and printed with `{:?}` like other error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FromContextError {
    /// The request payload could not be decoded into the requested type.
    ParseError(String),
    /// Any other extraction failure, described by the message.
    Custom(String),
}

impl std::fmt::Display for FromContextError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FromContextError::ParseError(e) => write!(f, "Parse error: {}", e),
            FromContextError::Custom(e) => write!(f, "Custom error: {}", e),
        }
    }
}

impl std::error::Error for FromContextError {}
200
201pub trait FromContext {
202 fn from_context(ctx: &mut Context) -> Result<Self, FromContextError>
203 where
204 Self: Sized;
205}
206
207#[async_trait::async_trait]
208pub trait Handler: Send + Sync {
209 async fn handle(&self, ctx: &mut Context) -> Result<Vec<u8>, FromContextError>;
210}
211
212#[async_trait::async_trait]
213impl<F, Fut, F1, Res> Handler for HandlerWrapper<F, F1, Res>
214where
215 F: Fn(F1) -> Fut + Send + Sync + 'static,
216 Fut: Future<Output = Res> + Send + 'static,
217 F1: FromContext + Send + Sync + 'static,
218 Res: ToBytes + Send + Sync + 'static,
219{
220 async fn handle(&self, ctx: &mut Context) -> Result<Vec<u8>, FromContextError> {
221 let req1 = F1::from_context(ctx)?;
222 let res = (self.func)(req1).await;
223 Ok(res.to_bytes())
224 }
225}
226
227#[async_trait::async_trait]
228impl<F, Fut, F1, F2, Res> Handler for HandlerWrapper<F, (F1, F2), Res>
229where
230 F: Fn(F1, F2) -> Fut + Send + Sync + 'static,
231 Fut: Future<Output = Res> + Send + 'static,
232 F1: FromContext + Send + Sync + 'static,
233 F2: FromContext + Send + Sync + 'static,
234 Res: ToBytes + Send + Sync + 'static,
235{
236 async fn handle(&self, ctx: &mut Context) -> Result<Vec<u8>, FromContextError> {
237 let req1 = F1::from_context(ctx)?;
238 let req2 = F2::from_context(ctx)?;
239 let res = (self.func)(req1, req2).await;
240 Ok(res.to_bytes())
241 }
242}
243
244#[async_trait::async_trait]
245impl<F, Fut, F1, F2, F3, Res> Handler for HandlerWrapper<F, (F1, F2, F3), Res>
246where
247 F: Fn(F1, F2, F3) -> Fut + Send + Sync + 'static,
248 Fut: Future<Output = Res> + Send + 'static,
249 F1: FromContext + Send + Sync + 'static,
250 F2: FromContext + Send + Sync + 'static,
251 F3: FromContext + Send + Sync + 'static,
252 Res: ToBytes + Send + Sync + 'static,
253{
254 async fn handle(&self, ctx: &mut Context) -> Result<Vec<u8>, FromContextError> {
255 let req1 = F1::from_context(ctx)?;
256 let req2 = F2::from_context(ctx)?;
257 let req3 = F3::from_context(ctx)?;
258 let res = (self.func)(req1, req2, req3).await;
259 Ok(res.to_bytes())
260 }
261}
262
263pub trait IntoHandlerWrapper<F1, Res>: Sized {
265 fn into_handler_wrapper(self) -> HandlerWrapper<Self, F1, Res>;
266}
267
268impl<F, Fut, F1, F2, Res> IntoHandlerWrapper<(F1, F2), Res> for F
269where
270 F: Fn(F1, F2) -> Fut + Send + Sync + 'static,
271 Fut: Future<Output = Res> + Send + 'static,
272 F1: FromContext + Send + Sync + 'static,
273 F2: FromContext + Send + Sync + 'static,
274 Res: ToBytes + Send + Sync + 'static,
275{
276 fn into_handler_wrapper(self) -> HandlerWrapper<Self, (F1, F2), Res> {
277 HandlerWrapper::new(self)
278 }
279}
280
281impl<F, Fut, F1, Res> IntoHandlerWrapper<F1, Res> for F
282where
283 F: Fn(F1) -> Fut + Send + Sync + 'static,
284 Fut: Future<Output = Res> + Send + 'static,
285 F1: FromContext + Send + Sync + 'static,
286 Res: ToBytes + Send + Sync + 'static,
287{
288 fn into_handler_wrapper(self) -> HandlerWrapper<Self, F1, Res> {
289 HandlerWrapper::new(self)
290 }
291}
292
293pub struct RespSender(tokio::sync::mpsc::Sender<Vec<u8>>);
296
297impl RespSender {
298 pub fn blocking_send(&self, data: impl ToBytes) -> Result<(), SendError> {
299 self.0
300 .blocking_send(data.to_bytes())
301 .map_err(|_| SendError::SendFailed("failed to send".to_string()))
302 }
303
304 pub async fn send(&self, data: impl ToBytes) -> Result<(), SendError> {
305 self.0
306 .send(data.to_bytes())
307 .await
308 .map_err(|_| SendError::SendFailed("failed to send".to_string()))
309 }
310}
311
312impl FromContext for RespSender {
313 fn from_context(ctx: &mut Context) -> Result<Self, FromContextError> {
314 let sender = ctx
315 .sender
316 .take()
317 .ok_or_else(|| FromContextError::Custom("sender is not available".to_string()))?;
318 Ok(Self(sender))
319 }
320}
321
322impl FromContext for String {
323 fn from_context(ctx: &mut Context) -> Result<Self, FromContextError> {
324 String::from_utf8(ctx.data.clone()).map_err(|e| FromContextError::ParseError(e.to_string()))
325 }
326}
327
328impl FromContext for Vec<u8> {
329 fn from_context(ctx: &mut Context) -> Result<Self, FromContextError> {
330 Ok(ctx.data.clone())
331 }
332}