xrpc/
server.rs

1use async_trait::async_trait;
2use bytes::Bytes;
3use futures::Stream;
4use parking_lot::RwLock;
5use serde::{Deserialize, Serialize};
6use std::collections::HashMap;
7use std::future::Future;
8use std::sync::Arc;
9use tokio::sync::mpsc;
10
11use crate::channel::message::MessageChannel;
12use crate::codec::{BincodeCodec, Codec};
13use crate::error::{Result, RpcError};
14use crate::message::Message;
15use crate::message::metadata::MessageMetadata;
16use crate::message::types::{MessageId, MessageType};
17use crate::streaming::{StreamId, next_stream_id};
18
19#[async_trait]
20pub trait Handler<C: Codec>: Send + Sync {
21    async fn handle(&self, request: Message<C>, codec: &C) -> Result<Message<C>>;
22    fn method_name(&self) -> &str;
23}
24
25#[async_trait]
26pub trait StreamHandler<C: Codec>: Send + Sync {
27    async fn handle(
28        &self,
29        request: Message<C>,
30        sender: ServerStreamSender<C>,
31        codec: &C,
32    ) -> Result<()>;
33    fn method_name(&self) -> &str;
34}
35
/// Adapter that turns a function or closure taking a raw [`Message`] into a
/// [`Handler`].
pub struct FnHandler<F, C> {
    /// Method name reported by `method_name`.
    method: String,
    /// The wrapped callable.
    func: Arc<F>,
    /// Ties the handler to a codec type without storing a codec value.
    _codec: std::marker::PhantomData<C>,
}
41
42impl<F, Fut, C> FnHandler<F, C>
43where
44    F: Fn(Message<C>) -> Fut + Send + Sync + 'static,
45    Fut: Future<Output = Result<Message<C>>> + Send + 'static,
46    C: Codec,
47{
48    pub fn new(method: impl Into<String>, func: F) -> Self {
49        Self {
50            method: method.into(),
51            func: Arc::new(func),
52            _codec: std::marker::PhantomData,
53        }
54    }
55}
56
57#[async_trait]
58impl<F, Fut, C: Codec + Default> Handler<C> for FnHandler<F, C>
59where
60    F: Fn(Message<C>) -> Fut + Send + Sync + 'static,
61    Fut: Future<Output = Result<Message<C>>> + Send + 'static,
62{
63    async fn handle(&self, request: Message<C>, _codec: &C) -> Result<Message<C>> {
64        (self.func)(request).await
65    }
66
67    fn method_name(&self) -> &str {
68        &self.method
69    }
70}
71
/// Adapter that turns a function from a typed request `Req` to a typed
/// response `Resp` into a [`Handler`], decoding and encoding payloads with
/// the codec `C`.
pub struct TypedHandler<Req, Resp, F, C> {
    /// Method name reported by `method_name`.
    method: String,
    /// The wrapped callable.
    func: Arc<F>,
    /// Records the request, response and codec types without storing values.
    _phantom: std::marker::PhantomData<(Req, Resp, C)>,
}
77
78impl<Req, Resp, F, Fut, C> TypedHandler<Req, Resp, F, C>
79where
80    Req: for<'de> Deserialize<'de> + Send + 'static,
81    Resp: Serialize + Send + 'static,
82    F: Fn(Req) -> Fut + Send + Sync + 'static,
83    Fut: Future<Output = Result<Resp>> + Send + 'static,
84    C: Codec,
85{
86    pub fn new(method: impl Into<String>, func: F) -> Self {
87        Self {
88            method: method.into(),
89            func: Arc::new(func),
90            _phantom: std::marker::PhantomData,
91        }
92    }
93}
94
95#[async_trait]
96impl<Req, Resp, F, Fut, C> Handler<C> for TypedHandler<Req, Resp, F, C>
97where
98    Req: for<'de> Deserialize<'de> + Send + Sync + 'static,
99    Resp: Serialize + Send + Sync + 'static,
100    F: Fn(Req) -> Fut + Send + Sync + 'static,
101    Fut: Future<Output = Result<Resp>> + Send + 'static,
102    C: Codec + Default,
103{
104    async fn handle(&self, request: Message<C>, codec: &C) -> Result<Message<C>> {
105        let req: Req = codec.decode(&request.payload)?;
106        let resp = (self.func)(req).await?;
107        let payload = codec.encode(&resp)?;
108        Ok(Message::new(
109            request.id,
110            MessageType::Reply,
111            "",
112            Bytes::from(payload),
113            MessageMetadata::new(),
114        ))
115    }
116
117    fn method_name(&self) -> &str {
118        &self.method
119    }
120}
121
122pub struct ServerStreamSender<C: Codec> {
123    stream_id: StreamId,
124    tx: mpsc::UnboundedSender<Bytes>,
125    sequence: std::sync::atomic::AtomicU64,
126    codec: C,
127}
128
129impl<C: Codec> ServerStreamSender<C> {
130    fn new(stream_id: StreamId, tx: mpsc::UnboundedSender<Bytes>, codec: C) -> Self {
131        Self {
132            stream_id,
133            tx,
134            sequence: std::sync::atomic::AtomicU64::new(0),
135            codec,
136        }
137    }
138
139    pub fn stream_id(&self) -> StreamId {
140        self.stream_id
141    }
142
143    pub fn send<T: Serialize>(&self, data: T) -> Result<()> {
144        let seq = self
145            .sequence
146            .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
147        let payload = self.codec.encode(&data)?;
148        let chunk: Message = Message::new(
149            MessageId::new(),
150            MessageType::StreamChunk,
151            "",
152            Bytes::from(payload),
153            MessageMetadata::new().with_stream(self.stream_id, seq),
154        );
155        let encoded = chunk.encode().map_err(RpcError::Transport)?;
156
157        self.tx
158            .send(encoded.freeze())
159            .map_err(|_| RpcError::StreamError("Stream closed".to_string()))
160    }
161
162    pub fn end(&self) -> Result<()> {
163        let end_msg: Message = Message::stream_end(self.stream_id);
164        let encoded = end_msg.encode().map_err(RpcError::Transport)?;
165
166        self.tx
167            .send(encoded.freeze())
168            .map_err(|_| RpcError::StreamError("Stream closed".to_string()))
169    }
170}
171
/// Adapter that turns a function from a typed request `Req` to a stream of
/// `Item` results into a [`StreamHandler`].
pub struct TypedStreamHandler<Req, Item, F, C> {
    /// Method name reported by `method_name`.
    method: String,
    /// The wrapped callable.
    func: Arc<F>,
    /// Records the request, item and codec types without storing values.
    _phantom: std::marker::PhantomData<(Req, Item, C)>,
}
177
178impl<Req, Item, F, S, C> TypedStreamHandler<Req, Item, F, C>
179where
180    Req: for<'de> Deserialize<'de> + Send + 'static,
181    Item: Serialize + Send + 'static,
182    S: Stream<Item = Result<Item>> + Send + 'static,
183    F: Fn(Req) -> S + Send + Sync + 'static,
184    C: Codec,
185{
186    pub fn new(method: impl Into<String>, func: F) -> Self {
187        Self {
188            method: method.into(),
189            func: Arc::new(func),
190            _phantom: std::marker::PhantomData,
191        }
192    }
193}
194
195#[async_trait]
196impl<Req, Item, F, S, C> StreamHandler<C> for TypedStreamHandler<Req, Item, F, C>
197where
198    Req: for<'de> Deserialize<'de> + Send + Sync + 'static,
199    Item: Serialize + Send + Sync + 'static,
200    S: Stream<Item = Result<Item>> + Send + 'static,
201    F: Fn(Req) -> S + Send + Sync + 'static,
202    C: Codec + Default,
203{
204    async fn handle(
205        &self,
206        request: Message<C>,
207        sender: ServerStreamSender<C>,
208        codec: &C,
209    ) -> Result<()> {
210        use futures::StreamExt;
211
212        let req: Req = codec.decode(&request.payload)?;
213        let mut stream = Box::pin((self.func)(req));
214
215        while let Some(result) = stream.next().await {
216            match result {
217                Ok(item) => sender.send(item)?,
218                Err(e) => return Err(e),
219            }
220        }
221
222        sender.end()?;
223        Ok(())
224    }
225
226    fn method_name(&self) -> &str {
227        &self.method
228    }
229}
230
/// Adapter that turns a function taking a raw [`Message`] and a
/// [`ServerStreamSender`] into a [`StreamHandler`].
pub struct FnStreamHandler<F, C> {
    /// Method name reported by `method_name`.
    method: String,
    /// The wrapped callable.
    func: Arc<F>,
    /// Ties the handler to a codec type without storing a codec value.
    _codec: std::marker::PhantomData<C>,
}
236
237impl<F, Fut, C> FnStreamHandler<F, C>
238where
239    F: Fn(Message<C>, ServerStreamSender<C>) -> Fut + Send + Sync + 'static,
240    Fut: Future<Output = Result<()>> + Send + 'static,
241    C: Codec,
242{
243    pub fn new(method: impl Into<String>, func: F) -> Self {
244        Self {
245            method: method.into(),
246            func: Arc::new(func),
247            _codec: std::marker::PhantomData,
248        }
249    }
250}
251
252#[async_trait]
253impl<F, Fut, C> StreamHandler<C> for FnStreamHandler<F, C>
254where
255    F: Fn(Message<C>, ServerStreamSender<C>) -> Fut + Send + Sync + 'static,
256    Fut: Future<Output = Result<()>> + Send + 'static,
257    C: Codec + Default,
258{
259    async fn handle(
260        &self,
261        request: Message<C>,
262        sender: ServerStreamSender<C>,
263        _codec: &C,
264    ) -> Result<()> {
265        (self.func)(request, sender).await
266    }
267
268    fn method_name(&self) -> &str {
269        &self.method
270    }
271}
272
273pub struct RpcServer<C: Codec = BincodeCodec> {
274    handlers: Arc<RwLock<HashMap<String, Arc<dyn Handler<C>>>>>,
275    stream_handlers: Arc<RwLock<HashMap<String, Arc<dyn StreamHandler<C>>>>>,
276    codec: C,
277}
278
279impl RpcServer<BincodeCodec> {
280    pub fn new() -> Self {
281        Self {
282            handlers: Arc::new(RwLock::new(HashMap::new())),
283            stream_handlers: Arc::new(RwLock::new(HashMap::new())),
284            codec: BincodeCodec,
285        }
286    }
287}
288
289impl<C: Codec + Clone + Default + 'static> RpcServer<C> {
290    pub fn with_codec(codec: C) -> Self {
291        Self {
292            handlers: Arc::new(RwLock::new(HashMap::new())),
293            stream_handlers: Arc::new(RwLock::new(HashMap::new())),
294            codec,
295        }
296    }
297
298    pub fn register(&self, handler: Arc<dyn Handler<C>>) {
299        let method = handler.method_name().to_string();
300        self.handlers.write().insert(method, handler);
301    }
302
303    pub fn register_fn<F, Fut>(&self, method: impl Into<String>, func: F)
304    where
305        F: Fn(Message<C>) -> Fut + Send + Sync + 'static,
306        Fut: Future<Output = Result<Message<C>>> + Send + 'static,
307    {
308        let handler: Arc<FnHandler<F, C>> = Arc::new(FnHandler::new(method, func));
309        self.register(handler);
310    }
311
312    pub fn register_typed<Req, Resp, F, Fut>(&self, method: impl Into<String>, func: F)
313    where
314        Req: for<'de> Deserialize<'de> + Send + Sync + 'static,
315        Resp: Serialize + Send + Sync + 'static,
316        F: Fn(Req) -> Fut + Send + Sync + 'static,
317        Fut: Future<Output = Result<Resp>> + Send + 'static,
318    {
319        let handler: Arc<TypedHandler<Req, Resp, F, C>> = Arc::new(TypedHandler::new(method, func));
320        self.register(handler);
321    }
322
323    pub fn register_stream<Req, Item, F, S>(&self, method: impl Into<String>, func: F)
324    where
325        Req: for<'de> Deserialize<'de> + Send + Sync + 'static,
326        Item: Serialize + Send + Sync + 'static,
327        S: Stream<Item = Result<Item>> + Send + 'static,
328        F: Fn(Req) -> S + Send + Sync + 'static,
329    {
330        let method = method.into();
331        let handler: Arc<TypedStreamHandler<Req, Item, F, C>> =
332            Arc::new(TypedStreamHandler::new(method.clone(), func));
333        self.stream_handlers.write().insert(method, handler);
334    }
335
336    pub fn register_stream_fn<F, Fut>(&self, method: impl Into<String>, func: F)
337    where
338        F: Fn(Message<C>, ServerStreamSender<C>) -> Fut + Send + Sync + 'static,
339        Fut: Future<Output = Result<()>> + Send + 'static,
340    {
341        let method = method.into();
342        let handler: Arc<FnStreamHandler<F, C>> =
343            Arc::new(FnStreamHandler::new(method.clone(), func));
344        self.stream_handlers.write().insert(method, handler);
345    }
346
347    pub async fn handle_message<T: MessageChannel<C>>(
348        &self,
349        message: Message<C>,
350        transport: &T,
351    ) -> Option<Message<C>> {
352        match message.msg_type {
353            MessageType::Call => {
354                if message.metadata.stream_id.is_some() {
355                    self.handle_stream_call(message, transport).await;
356                    return None;
357                }
358
359                let handler = self.handlers.read().get(&message.method).cloned();
360                match handler {
361                    Some(h) => match h.handle(message.clone(), &self.codec).await {
362                        Ok(response) => Some(response),
363                        Err(e) => Some(Message::error(message.id, e.to_string())),
364                    },
365                    None => Some(Message::error(
366                        message.id,
367                        format!("Method not found: {}", message.method),
368                    )),
369                }
370            }
371            MessageType::Notification => {
372                let handler = self.handlers.read().get(&message.method).cloned();
373                if let Some(h) = handler {
374                    let _ = h.handle(message, &self.codec).await;
375                }
376                None
377            }
378            _ => None,
379        }
380    }
381
382    async fn handle_stream_call<T: MessageChannel<C>>(&self, message: Message<C>, transport: &T) {
383        let stream_id = message.metadata.stream_id.unwrap_or_else(next_stream_id);
384        let handler = self.stream_handlers.read().get(&message.method).cloned();
385
386        let Some(h) = handler else {
387            let error = Message::stream_error(
388                message.id,
389                stream_id,
390                format!("Stream method not found: {}", message.method),
391            );
392            let _ = transport.send(&error).await;
393            return;
394        };
395
396        let (tx, mut rx) = mpsc::unbounded_channel::<Bytes>();
397        let sender = ServerStreamSender::new(stream_id, tx, self.codec.clone());
398
399        let transport_send = async {
400            while let Some(data) = rx.recv().await {
401                if let Ok(msg) = Message::<C>::decode(&data[..]) {
402                    let _ = transport.send(&msg).await;
403                }
404            }
405        };
406
407        let codec = self.codec.clone();
408        let handler_task = async {
409            if let Err(e) = h.handle(message.clone(), sender, &codec).await {
410                let error = Message::stream_error(message.id, stream_id, e.to_string());
411                let _ = transport.send(&error).await;
412            }
413        };
414
415        tokio::join!(handler_task, transport_send);
416    }
417
418    pub async fn serve<T: MessageChannel<C>>(&self, transport: Arc<T>) -> Result<()> {
419        loop {
420            let message = transport.recv().await.map_err(RpcError::Transport)?;
421
422            if let Some(response) = self.handle_message(message, transport.as_ref()).await {
423                transport
424                    .send(&response)
425                    .await
426                    .map_err(RpcError::Transport)?;
427            }
428        }
429    }
430
431    pub fn spawn_handler<T: MessageChannel<C> + 'static>(&self, transport: T) -> ServerHandle {
432        let handlers = self.handlers.clone();
433        let stream_handlers = self.stream_handlers.clone();
434        let codec = self.codec.clone();
435        let transport = Arc::new(transport);
436
437        let handle = tokio::spawn(async move {
438            let server = RpcServer {
439                handlers,
440                stream_handlers,
441                codec,
442            };
443            let _ = server.serve(transport).await;
444        });
445
446        ServerHandle { handle }
447    }
448
449    pub fn handler_count(&self) -> usize {
450        self.handlers.read().len() + self.stream_handlers.read().len()
451    }
452}
453
454impl Default for RpcServer<BincodeCodec> {
455    fn default() -> Self {
456        Self::new()
457    }
458}
459
460pub struct ServerHandle {
461    handle: tokio::task::JoinHandle<()>,
462}
463
464impl ServerHandle {
465    pub async fn shutdown(self) {
466        self.handle.abort();
467        let _ = self.handle.await;
468    }
469
470    pub fn is_finished(&self) -> bool {
471        self.handle.is_finished()
472    }
473}
474
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::message::MessageChannelAdapter;
    use crate::streaming::StreamReceiver;
    use crate::transport::channel::{ChannelConfig, ChannelFrameTransport};

    /// Request payload for the `add` test method.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddRequest {
        a: i32,
        b: i32,
    }

    /// Response payload for the `add` test method.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddResponse {
        result: i32,
    }

    /// A typed handler receives the decoded request and its result comes back
    /// as a `Reply` message.
    #[tokio::test]
    async fn test_server_typed_handler() {
        let config = ChannelConfig::default();
        let (t1, t2) = ChannelFrameTransport::create_pair("test", config).unwrap();

        let client_channel = MessageChannelAdapter::new(t1);
        let server_channel = MessageChannelAdapter::new(t2);

        let server = RpcServer::new();
        server.register_typed("add", |req: AddRequest| async move {
            Ok(AddResponse {
                result: req.a + req.b,
            })
        });

        let _handle = server.spawn_handler(server_channel);

        let request: Message = Message::call("add", AddRequest { a: 10, b: 32 }).unwrap();
        client_channel.send(&request).await.unwrap();

        let response = client_channel.recv().await.unwrap();
        assert_eq!(response.msg_type, MessageType::Reply);

        let resp: AddResponse = response.deserialize_payload().unwrap();
        assert_eq!(resp.result, 42);
    }

    /// A streaming handler delivers every item, in order, followed by the end
    /// of the stream.
    #[tokio::test]
    async fn test_server_stream_handler() {
        let config = ChannelConfig::default();
        let (t1, t2) = ChannelFrameTransport::create_pair("test", config).unwrap();

        let client_channel = Arc::new(MessageChannelAdapter::new(t1));
        let server_channel = MessageChannelAdapter::new(t2);

        let server = RpcServer::new();
        server.register_stream("range", |count: i32| {
            futures::stream::iter((1..=count).map(Ok))
        });

        let _handle = server.spawn_handler(server_channel);

        // A call is treated as streaming when its metadata carries a stream id.
        let stream_id = next_stream_id();
        let mut request: Message = Message::call("range", 5i32).unwrap();
        request.metadata = request.metadata.with_stream(stream_id, 0);

        let manager = crate::streaming::StreamManager::new();
        let mut receiver: StreamReceiver<i32> = manager.create_receiver(stream_id);

        client_channel.send(&request).await.unwrap();

        // Pump incoming messages into the stream manager until the end marker
        // has been delivered or the channel fails.
        let client_channel_clone = client_channel.clone();
        let recv_task = tokio::spawn(async move {
            while let Ok(msg) = client_channel_clone.recv().await {
                manager.handle_message(&msg);
                if msg.msg_type == MessageType::StreamEnd {
                    break;
                }
            }
        });

        let mut items = Vec::new();
        while let Some(result) = receiver.recv().await {
            items.push(result.unwrap());
        }

        recv_task.await.unwrap();
        assert_eq!(items, vec![1, 2, 3, 4, 5]);
    }

    /// A notification invokes the registered handler.
    #[tokio::test]
    async fn test_server_notification() {
        use std::sync::atomic::{AtomicBool, Ordering};
        use std::time::Duration;

        let config = ChannelConfig::default();
        let (t1, t2) = ChannelFrameTransport::create_pair("test", config).unwrap();

        let client_channel = MessageChannelAdapter::new(t1);
        let server_channel = MessageChannelAdapter::new(t2);

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();

        let server = RpcServer::new();
        server.register_fn("log", move |_msg: Message| {
            let called = called_clone.clone();
            async move {
                called.store(true, Ordering::Release);
                Ok(Message::reply(MessageId::new(), ())?)
            }
        });

        let _handle = server.spawn_handler(server_channel);

        let notification: Message = Message::notification("log", "test").unwrap();
        client_channel.send(&notification).await.unwrap();

        // Poll for the flag under a generous timeout instead of asserting
        // after one fixed sleep. The test passes as soon as the handler has
        // run and does not depend on the server task being scheduled within a
        // fixed window.
        let wait_for_call = async {
            while !called.load(Ordering::Acquire) {
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        };
        tokio::time::timeout(Duration::from_secs(2), wait_for_call)
            .await
            .expect("notification handler was not invoked in time");
    }
}