axum_js_fetch/
lib.rs

1use axum::{body::HttpBody, response::IntoResponse};
2use bytes::{buf::Buf, Bytes};
3use futures_lite::stream::{Stream, StreamExt};
4use http::Request;
5use js_sys::Uint8Array;
6use std::{error::Error, pin::Pin, task::Poll, sync::{Mutex, Arc}};
7use tower::{util::BoxCloneService, BoxError, ServiceBuilder, ServiceExt};
8use wasm_bindgen::prelude::*;
9use wasm_bindgen_futures::spawn_local;
10use wasm_streams::ReadableStream;
11
12struct StreamingBody<T> {
13    stream: T,
14}
15
16impl<T> HttpBody for StreamingBody<T>
17where
18    T: Stream<Item = Bytes> + Unpin,
19{
20    type Data = Bytes;
21
22    type Error = BoxError;
23
24    fn poll_data(
25        self: Pin<&mut Self>,
26        cx: &mut std::task::Context<'_>,
27    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
28        Pin::into_inner(self)
29            .stream
30            .poll_next(cx)
31            .map(|d| d.map(Ok))
32    }
33
34    fn poll_trailers(
35        self: Pin<&mut Self>,
36        _cx: &mut std::task::Context<'_>,
37    ) -> std::task::Poll<Result<Option<http::HeaderMap>, Self::Error>> {
38        // todo!('add support for trailers')
39        std::task::Poll::Ready(Ok(None))
40    }
41}
42
43impl<T> Stream for StreamingBody<T>
44where
45    T: HttpBody + Unpin,
46    T::Data: Buf,
47    T::Error: Into<Box<dyn Error + Send + Sync>>,
48{
49    type Item = Result<JsValue, JsValue>;
50
51    fn poll_next(
52        self: Pin<&mut Self>,
53        cx: &mut std::task::Context<'_>,
54    ) -> Poll<Option<Self::Item>> {
55        let s = Pin::into_inner(self);
56
57        Pin::new(&mut s.stream).poll_data(cx).map(|f| match f {
58            Some(Ok(b)) => {
59                let slice = b.chunk();
60                let a = Uint8Array::new(&JsValue::from(slice.len()));
61                a.copy_from(slice);
62                let out: JsValue = a.into();
63                Some(Ok(out))
64            }
65            Some(Err(e)) => Some(Err(JsValue::from_str(&e.into().to_string()))),
66            None => None,
67        })
68    }
69}
70
71#[derive(Debug)]
72enum Err {
73    ConvertError,
74}
75
76fn to_http_body(req: gloo_net::http::Request) -> axum::body::Body {
77    let (sender, body) = axum::body::Body::channel();
78    let arc_sender = Arc::new(Mutex::new(sender));
79
80    if let Some(b) = req.body() {
81        spawn_local(async move {
82            let sender_clone = Arc::clone(&arc_sender);
83            ReadableStream::from_raw(b.unchecked_into())
84                .into_stream()
85                .try_for_each(|buf_js| -> Result<(), Err> {
86                    let buffer =
87                        js_sys::Uint8Array::new(&buf_js.map_err(|_| Err::ConvertError)?);
88                    let bytes: Bytes = buffer.to_vec().into();
89                
90                    let sender_clone = Arc::clone(&sender_clone);
91                    // dont block on sending bytes into stream
92                    spawn_local(async move {
93                        let binding = Arc::clone(&sender_clone);
94                        let mut sender = binding.lock().unwrap();
95                        let _ = sender.send_data(bytes).await;
96                    });
97                    Ok(())
98                })
99                .await.unwrap();
100        });
101    }
102
103    body
104}
105
106fn from_fetch_request(req: web_sys::Request) -> Request<axum::body::Body> {
107    let gloo_req = gloo_net::http::Request::from(req);
108    let headers = gloo_req.headers();
109
110    let mut builder = Request::builder()
111        .uri(gloo_req.url())
112        .method(gloo_req.method());
113
114    for (key, value) in headers.entries() {
115        builder = builder.header(key, value)
116    }
117
118    let body = to_http_body(gloo_req);
119    builder.body(body).unwrap()
120}
121
122fn create_default_error(e: impl Error) -> gloo_net::http::Response {
123    gloo_net::http::Response::builder()
124        .status(500)
125        .body(Some(format!("{e}").as_str()))
126        .unwrap()
127}
128
129fn to_fetch_response(res: impl IntoResponse) -> web_sys::Response {
130    let (parts, body) = res.into_response().into_parts();
131    let headers = gloo_net::http::Headers::new();
132    for (key, value) in parts.headers.iter() {
133        headers.append(key.as_str(), value.to_str().unwrap());
134    }
135    let stream_body: web_sys::ReadableStream =
136        ReadableStream::from_stream(StreamingBody { stream: body })
137            .into_raw()
138            .unchecked_into();
139
140    gloo_net::http::Response::builder()
141        .status(parts.status.as_u16())
142        .headers(headers)
143        .body(Some(&stream_body))
144        .unwrap_or_else(create_default_error)
145        .into()
146}
147
148pub struct App {
149    service: BoxCloneService<web_sys::Request, web_sys::Response, BoxError>,
150}
151
152// impl Default for App {
153//     fn default() -> Self {
154//         console_error_panic_hook::set_once();
155
156//         let svc = ServiceBuilder::new()
157//             .map_request(to_ossi_request)
158//             .service(app)
159//             .map_response(to_ossi_response);
160//     }
161// }
162
163impl App {
164    pub fn new<S>(service: S) -> Self
165    where
166        S: ServiceExt<Request<axum::body::Body>> + Clone + Send + Sized + 'static,
167        S::Future: Send + 'static,
168        S::Response: IntoResponse,
169        S::Error: Into<Box<dyn Error + Send + Sync>>,
170    {
171        let svc = ServiceBuilder::new()
172            .map_request(from_fetch_request)
173            .map_response(to_fetch_response)
174            .service(service)
175            .map_err(|e| e.into());
176
177        let service = svc.boxed_clone();
178        Self { service }
179    }
180
181    pub async fn serve(&self, req: web_sys::Request) -> web_sys::Response {
182        self.service.clone().oneshot(req).await.unwrap()
183    }
184}
185
#[cfg(test)]
pub mod tests {
    use crate::App;
    use axum::{
        body::StreamBody,
        extract::{Json, Query},
        routing::get,
        Router,
    };
    use futures_lite::{stream, Stream};
    use serde::Deserialize;
    use std::{collections::HashMap, convert::Infallible};

    /// Fixture: an [`App`] wrapping a small demo router.
    struct MyApp(App);

    impl Default for MyApp {
        /// Routes `GET /` to `handler`, `POST /` to `handler2` and `GET /stream` to `handler3`.
        fn default() -> Self {
            let router = Router::new()
                .route("/", get(handler).post(handler2))
                .route("/stream", get(handler3));
            MyApp(App::new(router))
        }
    }

    /// Echoes the parsed query-string parameters back as debug-formatted text.
    async fn handler(Query(params): Query<HashMap<String, String>>) -> String {
        format!("received: {params:?}")
    }

    /// JSON payload accepted by `handler2`.
    #[derive(Debug, Deserialize)]
    struct TestStruct {
        hello: String,
    }

    /// Streams the payload's `hello` field back to the client.
    ///
    /// NOTE(review): `stream::repeat` never ends, so this response body is endless. Confirm
    /// that this is intended; a single echo would use `stream::once`.
    async fn handler2(
        Json(payload): Json<TestStruct>,
    ) -> StreamBody<impl Stream<Item = Result<String, Infallible>>> {
        StreamBody::new(stream::repeat(Ok(payload.hello)))
    }

    /// Streams `"Hello, world!"` to the client in three chunks.
    async fn handler3() -> StreamBody<impl Stream<Item = std::io::Result<&'static str>>> {
        let greeting = vec![Ok("Hello,"), Ok(" "), Ok("world!")];
        StreamBody::new(stream::iter(greeting))
    }
}