Skip to main content

aeron_rpc/
server.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::sync::Arc;
4
5use crate::ToBusinessId;
6use crate::ToBytes;
7use crate::err::SendError;
8use crate::protocol::Multiplexer2ServerReceiver;
9use crate::protocol::Response;
10use crate::protocol::ResponsePacket;
11use crate::protocol::Server2MultiplexerSender;
12
13pub struct RpcServerBuilder {
14    recv_req: Option<Multiplexer2ServerReceiver>,
15    send_res: Option<Server2MultiplexerSender>,
16    handlers: HashMap<u64, Arc<dyn Handler + Send + Sync>>,
17}
18
19impl Default for RpcServerBuilder {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25impl RpcServerBuilder {
26    pub fn new() -> Self {
27        Self {
28            recv_req: None,
29            send_res: None,
30            handlers: HashMap::new(),
31        }
32    }
33
34    pub fn add_receiver(mut self, receiver: Multiplexer2ServerReceiver) -> Self {
35        self.recv_req = Some(receiver);
36        self
37    }
38
39    pub fn add_sender(mut self, sender: Server2MultiplexerSender) -> Self {
40        self.send_res = Some(sender);
41        self
42    }
43
44    pub fn add_handler<F, Args, Res>(mut self, business_id: &impl ToBusinessId, func: F) -> Self
45    where
46        F: IntoHandlerWrapper<Args, Res> + 'static,
47        HandlerWrapper<F, Args, Res>: Handler + 'static,
48        Res: ToBytes + 'static,
49    {
50        let business_id = business_id.to_business_id();
51        let wrapper = func.into_handler_wrapper();
52        self.handlers.insert(business_id, Arc::new(wrapper));
53        self
54    }
55
56    pub fn build(self) -> Result<RpcServer, String> {
57        let receiver = self.recv_req.ok_or("receiver is required")?;
58        let sender = self.send_res.ok_or("Sender is required")?;
59        if self.handlers.is_empty() {
60            return Err("at least one handler is required".to_string());
61        }
62        Ok(RpcServer::new(Arc::new(self.handlers), receiver, sender))
63    }
64}
65
66pub struct RpcServer {
67    handlers: Arc<HashMap<u64, Arc<dyn Handler + Send + Sync>>>,
68    receiver: Option<Multiplexer2ServerReceiver>,
69    sender: Option<Server2MultiplexerSender>,
70}
71
72impl RpcServer {
73    pub fn new(
74        handlers: Arc<HashMap<u64, Arc<dyn Handler + Send + Sync>>>,
75        receiver: Multiplexer2ServerReceiver,
76        sender: Server2MultiplexerSender,
77    ) -> Self {
78        Self {
79            handlers,
80            receiver: Some(receiver),
81            sender: Some(sender),
82        }
83    }
84}
85
86impl RpcServer {
87    pub async fn run(&mut self) -> Result<(), String> {
88        let receiver = self.receiver.take();
89        let handlers = self.handlers.clone();
90        let sender = self.sender.take().expect("sender is required");
91
92        tokio::spawn(async move {
93            let mut rx = receiver.expect("receiver is required");
94
95            while let Some(req) = rx.recv().await {
96                if let Some(handler) = handlers.get(&req.business_id).cloned() {
97                    let sender = sender.clone();
98                    tokio::spawn(async move {
99                        let (tx, rc) = tokio::sync::mpsc::channel(1 << 10);
100                        let mut ctx = Context {
101                            sender: Some(tx),
102                            request_id: req.request_id,
103                            session_id: 0,
104                            data: req.data.clone(),
105                        };
106
107                        let t = handler.handle(&mut ctx).await;
108
109                        if ctx.sender.is_none() {
110                            let mut rx = rc;
111
112                            let mut last_data = None;
113
114                            while let Some(data) = rx.recv().await {
115                                if let Some(last_data) = last_data {
116                                    sender
117                                        .send(ResponsePacket {
118                                            response: Response::new(
119                                                req.request_id,
120                                                false,
121                                                last_data,
122                                            ),
123                                        })
124                                        .await
125                                        .expect("Multiplexer closed unexpectedly");
126                                }
127                                last_data = Some(data);
128                            }
129
130                            if let Some(last_data) = last_data {
131                                sender
132                                    .send(ResponsePacket {
133                                        response: Response::new(req.request_id, true, last_data),
134                                    })
135                                    .await
136                                    .expect("Multiplexer closed unexpectedly");
137                            }
138                        } else {
139                            match t {
140                                Ok(data) => {
141                                    sender
142                                        .send(ResponsePacket {
143                                            response: Response::new(req.request_id, true, data),
144                                        })
145                                        .await
146                                        .expect("Multiplexer closed unexpectedly");
147                                }
148                                Err(t) => {
149                                    log::error!("handler error: {}", t);
150                                    panic!("Handler Error: {}", t);
151                                }
152                            }
153                        }
154                    });
155                }
156            }
157
158            log::info!("server thread exit");
159        })
160        .await
161        .map_err(|e| e.to_string())?;
162        Ok(())
163    }
164}
165
/// Adapts a plain function or closure `F` so it can be used as a [`Handler`].
///
/// `Req` records the function's argument type(s) — a single extractor or a
/// tuple of extractors — and `Res` its return type. Neither is stored; they
/// only select which arity-specific `Handler` impl applies.
pub struct HandlerWrapper<F, Req, Res> {
    func: F,
    _phantom: std::marker::PhantomData<(Req, Res)>,
}

impl<F, Req, Res> HandlerWrapper<F, Req, Res> {
    /// Wraps `func` as-is.
    pub fn new(func: F) -> Self {
        Self {
            _phantom: Default::default(),
            func,
        }
    }
}
179
180pub struct Context {
181    sender: Option<tokio::sync::mpsc::Sender<Vec<u8>>>,
182    pub request_id: u64,
183    pub session_id: u64,
184    pub data: Vec<u8>,
185}
186
/// Error produced when a handler's arguments cannot be extracted from a [`Context`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum FromContextError {
    /// The request payload could not be decoded into the requested type.
    ParseError(String),
    /// Any other extraction failure, described by the message.
    Custom(String),
}

impl std::fmt::Display for FromContextError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            FromContextError::ParseError(e) => write!(f, "Parse error: {}", e),
            FromContextError::Custom(e) => write!(f, "Custom error: {}", e),
        }
    }
}

// Lets the error be boxed as `dyn Error` and propagated with `?` by callers.
impl std::error::Error for FromContextError {}
200
201pub trait FromContext {
202    fn from_context(ctx: &mut Context) -> Result<Self, FromContextError>
203    where
204        Self: Sized;
205}
206
207#[async_trait::async_trait]
208pub trait Handler: Send + Sync {
209    async fn handle(&self, ctx: &mut Context) -> Result<Vec<u8>, FromContextError>;
210}
211
212#[async_trait::async_trait]
213impl<F, Fut, F1, Res> Handler for HandlerWrapper<F, F1, Res>
214where
215    F: Fn(F1) -> Fut + Send + Sync + 'static,
216    Fut: Future<Output = Res> + Send + 'static,
217    F1: FromContext + Send + Sync + 'static,
218    Res: ToBytes + Send + Sync + 'static,
219{
220    async fn handle(&self, ctx: &mut Context) -> Result<Vec<u8>, FromContextError> {
221        let req1 = F1::from_context(ctx)?;
222        let res = (self.func)(req1).await;
223        Ok(res.to_bytes())
224    }
225}
226
227#[async_trait::async_trait]
228impl<F, Fut, F1, F2, Res> Handler for HandlerWrapper<F, (F1, F2), Res>
229where
230    F: Fn(F1, F2) -> Fut + Send + Sync + 'static,
231    Fut: Future<Output = Res> + Send + 'static,
232    F1: FromContext + Send + Sync + 'static,
233    F2: FromContext + Send + Sync + 'static,
234    Res: ToBytes + Send + Sync + 'static,
235{
236    async fn handle(&self, ctx: &mut Context) -> Result<Vec<u8>, FromContextError> {
237        let req1 = F1::from_context(ctx)?;
238        let req2 = F2::from_context(ctx)?;
239        let res = (self.func)(req1, req2).await;
240        Ok(res.to_bytes())
241    }
242}
243
244#[async_trait::async_trait]
245impl<F, Fut, F1, F2, F3, Res> Handler for HandlerWrapper<F, (F1, F2, F3), Res>
246where
247    F: Fn(F1, F2, F3) -> Fut + Send + Sync + 'static,
248    Fut: Future<Output = Res> + Send + 'static,
249    F1: FromContext + Send + Sync + 'static,
250    F2: FromContext + Send + Sync + 'static,
251    F3: FromContext + Send + Sync + 'static,
252    Res: ToBytes + Send + Sync + 'static,
253{
254    async fn handle(&self, ctx: &mut Context) -> Result<Vec<u8>, FromContextError> {
255        let req1 = F1::from_context(ctx)?;
256        let req2 = F2::from_context(ctx)?;
257        let req3 = F3::from_context(ctx)?;
258        let res = (self.func)(req1, req2, req3).await;
259        Ok(res.to_bytes())
260    }
261}
262
263/// Helper trait for converting a function into a HandlerWrapper.
264pub trait IntoHandlerWrapper<F1, Res>: Sized {
265    fn into_handler_wrapper(self) -> HandlerWrapper<Self, F1, Res>;
266}
267
268impl<F, Fut, F1, F2, Res> IntoHandlerWrapper<(F1, F2), Res> for F
269where
270    F: Fn(F1, F2) -> Fut + Send + Sync + 'static,
271    Fut: Future<Output = Res> + Send + 'static,
272    F1: FromContext + Send + Sync + 'static,
273    F2: FromContext + Send + Sync + 'static,
274    Res: ToBytes + Send + Sync + 'static,
275{
276    fn into_handler_wrapper(self) -> HandlerWrapper<Self, (F1, F2), Res> {
277        HandlerWrapper::new(self)
278    }
279}
280
281impl<F, Fut, F1, Res> IntoHandlerWrapper<F1, Res> for F
282where
283    F: Fn(F1) -> Fut + Send + Sync + 'static,
284    Fut: Future<Output = Res> + Send + 'static,
285    F1: FromContext + Send + Sync + 'static,
286    Res: ToBytes + Send + Sync + 'static,
287{
288    fn into_handler_wrapper(self) -> HandlerWrapper<Self, F1, Res> {
289        HandlerWrapper::new(self)
290    }
291}
292
293/// The result of a handler will be ignored if using RespSender.
294/// Direct return data from the handler if just one data needs to be sent.
295pub struct RespSender(tokio::sync::mpsc::Sender<Vec<u8>>);
296
297impl RespSender {
298    pub fn blocking_send(&self, data: impl ToBytes) -> Result<(), SendError> {
299        self.0
300            .blocking_send(data.to_bytes())
301            .map_err(|_| SendError::SendFailed("failed to send".to_string()))
302    }
303
304    pub async fn send(&self, data: impl ToBytes) -> Result<(), SendError> {
305        self.0
306            .send(data.to_bytes())
307            .await
308            .map_err(|_| SendError::SendFailed("failed to send".to_string()))
309    }
310}
311
312impl FromContext for RespSender {
313    fn from_context(ctx: &mut Context) -> Result<Self, FromContextError> {
314        let sender = ctx
315            .sender
316            .take()
317            .ok_or_else(|| FromContextError::Custom("sender is not available".to_string()))?;
318        Ok(Self(sender))
319    }
320}
321
322impl FromContext for String {
323    fn from_context(ctx: &mut Context) -> Result<Self, FromContextError> {
324        String::from_utf8(ctx.data.clone()).map_err(|e| FromContextError::ParseError(e.to_string()))
325    }
326}
327
328impl FromContext for Vec<u8> {
329    fn from_context(ctx: &mut Context) -> Result<Self, FromContextError> {
330        Ok(ctx.data.clone())
331    }
332}