arpy_axum/
lib.rs

1//! # Arpy Axum
2//!
3//! [`arpy`] integration for [`axum`].
4//!
5//! ## Example
6//!
7//! ```
8#![doc = include_doc::function_body!("tests/doc.rs", router_example, [my_add, MyAdd])]
9//! ```
10use std::sync::Arc;
11
12use arpy::{protocol, FnRemote};
13use arpy_server::{FnRemoteBody, WebSocketRouter};
14use axum::{
15    extract::{ws::WebSocket, WebSocketUpgrade},
16    response::{
17        sse::{Event, KeepAlive},
18        Sse,
19    },
20    routing::{get, post},
21    BoxError, Router,
22};
23use futures::{Stream, StreamExt};
24use http::ArpyRequest;
25use hyper::HeaderMap;
26use serde::Serialize;
27use websocket::WebSocketHandler;
28
29pub mod http;
30mod websocket;
31
32/// Extension trait to add RPC routes. See [module level documentation][crate]
33/// for an example.
34pub trait RpcRoute {
35    /// Add an HTTP route to handle a single RPC endpoint.
36    ///
37    /// Routes are constructed with `"{prefix}/{msg_id}"` where `msg_id = T::ID`
38    /// from [`MsgId`][arpy::protocol::MsgId].
39    fn http_rpc_route<F, T>(self, prefix: &str, f: F) -> Self
40    where
41        F: FnRemoteBody<T> + Send + Sync + 'static,
42        T: FnRemote + Send + 'static;
43
44    /// Add all the RPC endpoints in `router` to a websocket endpoint at `path`.
45    fn ws_rpc_route(self, path: &str, router: WebSocketRouter, max_in_flight: usize) -> Self;
46
47    /// Add a Server Sent Events route.
48    fn sse_route<T, S, Error>(
49        self,
50        path: &str,
51        events: impl FnMut() -> S + Send + Clone + 'static,
52        keep_alive: Option<KeepAlive>,
53    ) -> Self
54    where
55        T: Serialize + protocol::MsgId + 'static,
56        S: Stream<Item = Result<T, Error>> + Send + 'static,
57        Error: Into<BoxError> + 'static;
58}
59
60impl RpcRoute for Router {
61    fn http_rpc_route<F, T>(self, prefix: &str, f: F) -> Self
62    where
63        F: FnRemoteBody<T> + Send + Sync + 'static,
64        T: FnRemote + Send + 'static,
65    {
66        let id = T::ID;
67        let f = Arc::new(f);
68        self.route(
69            &format!("{prefix}/{id}",),
70            post(move |headers: HeaderMap, arpy: ArpyRequest<T>| http::handler(headers, arpy, f)),
71        )
72    }
73
74    fn ws_rpc_route(self, path: &str, router: WebSocketRouter, max_in_flight: usize) -> Self {
75        let handler = WebSocketHandler::new(router, max_in_flight);
76
77        self.route(
78            path,
79            get(|ws: WebSocketUpgrade| async {
80                ws.on_upgrade(
81                    |socket: WebSocket| async move { handler.handle_socket(socket).await },
82                )
83            }),
84        )
85    }
86
87    /// Add an SSE route using [`sse_handler`].
88    fn sse_route<T, S, Error>(
89        self,
90        path: &str,
91        mut events: impl FnMut() -> S + Send + Clone + 'static,
92        keep_alive: Option<KeepAlive>,
93    ) -> Self
94    where
95        T: Serialize + protocol::MsgId + 'static,
96        S: Stream<Item = Result<T, Error>> + Send + 'static,
97        Error: Into<BoxError> + 'static,
98    {
99        self.route(
100            path,
101            get(|| async move {
102                let sse = sse_handler(events()).await;
103
104                if let Some(keep_alive) = keep_alive {
105                    sse.keep_alive(keep_alive)
106                } else {
107                    sse
108                }
109            }),
110        )
111    }
112}
113
114/// SSE Handler.
115///
116/// This uses `serde_json` to serialize data, and assumes it can't fail. See
117/// [`serde_json::to_writer`] for more details.
118pub async fn sse_handler<T: Serialize + protocol::MsgId, Error: Into<BoxError>>(
119    events: impl Stream<Item = Result<T, Error>> + Send + 'static,
120) -> Sse<impl Stream<Item = Result<Event, Error>>> {
121    Sse::new(
122        events.map(|item| item.map(|item| Event::default().event(T::ID).json_data(item).unwrap())),
123    )
124}