1#![doc = include_doc::function_body!("tests/doc.rs", router_example, [my_add, MyAdd])]
9use std::sync::Arc;
11
12use arpy::{protocol, FnRemote};
13use arpy_server::{FnRemoteBody, WebSocketRouter};
14use axum::{
15 extract::{ws::WebSocket, WebSocketUpgrade},
16 response::{
17 sse::{Event, KeepAlive},
18 Sse,
19 },
20 routing::{get, post},
21 BoxError, Router,
22};
23use futures::{Stream, StreamExt};
24use http::ArpyRequest;
25use hyper::HeaderMap;
26use serde::Serialize;
27use websocket::WebSocketHandler;
28
29pub mod http;
30mod websocket;
31
32pub trait RpcRoute {
35 fn http_rpc_route<F, T>(self, prefix: &str, f: F) -> Self
40 where
41 F: FnRemoteBody<T> + Send + Sync + 'static,
42 T: FnRemote + Send + 'static;
43
44 fn ws_rpc_route(self, path: &str, router: WebSocketRouter, max_in_flight: usize) -> Self;
46
47 fn sse_route<T, S, Error>(
49 self,
50 path: &str,
51 events: impl FnMut() -> S + Send + Clone + 'static,
52 keep_alive: Option<KeepAlive>,
53 ) -> Self
54 where
55 T: Serialize + protocol::MsgId + 'static,
56 S: Stream<Item = Result<T, Error>> + Send + 'static,
57 Error: Into<BoxError> + 'static;
58}
59
60impl RpcRoute for Router {
61 fn http_rpc_route<F, T>(self, prefix: &str, f: F) -> Self
62 where
63 F: FnRemoteBody<T> + Send + Sync + 'static,
64 T: FnRemote + Send + 'static,
65 {
66 let id = T::ID;
67 let f = Arc::new(f);
68 self.route(
69 &format!("{prefix}/{id}",),
70 post(move |headers: HeaderMap, arpy: ArpyRequest<T>| http::handler(headers, arpy, f)),
71 )
72 }
73
74 fn ws_rpc_route(self, path: &str, router: WebSocketRouter, max_in_flight: usize) -> Self {
75 let handler = WebSocketHandler::new(router, max_in_flight);
76
77 self.route(
78 path,
79 get(|ws: WebSocketUpgrade| async {
80 ws.on_upgrade(
81 |socket: WebSocket| async move { handler.handle_socket(socket).await },
82 )
83 }),
84 )
85 }
86
87 fn sse_route<T, S, Error>(
89 self,
90 path: &str,
91 mut events: impl FnMut() -> S + Send + Clone + 'static,
92 keep_alive: Option<KeepAlive>,
93 ) -> Self
94 where
95 T: Serialize + protocol::MsgId + 'static,
96 S: Stream<Item = Result<T, Error>> + Send + 'static,
97 Error: Into<BoxError> + 'static,
98 {
99 self.route(
100 path,
101 get(|| async move {
102 let sse = sse_handler(events()).await;
103
104 if let Some(keep_alive) = keep_alive {
105 sse.keep_alive(keep_alive)
106 } else {
107 sse
108 }
109 }),
110 )
111 }
112}
113
114pub async fn sse_handler<T: Serialize + protocol::MsgId, Error: Into<BoxError>>(
119 events: impl Stream<Item = Result<T, Error>> + Send + 'static,
120) -> Sse<impl Stream<Item = Result<Event, Error>>> {
121 Sse::new(
122 events.map(|item| item.map(|item| Event::default().event(T::ID).json_data(item).unwrap())),
123 )
124}