1use axum::{body::HttpBody, response::IntoResponse};
2use bytes::{buf::Buf, Bytes};
3use futures_lite::stream::{Stream, StreamExt};
4use http::Request;
5use js_sys::Uint8Array;
6use std::{error::Error, pin::Pin, task::Poll, sync::{Mutex, Arc}};
7use tower::{util::BoxCloneService, BoxError, ServiceBuilder, ServiceExt};
8use wasm_bindgen::prelude::*;
9use wasm_bindgen_futures::spawn_local;
10use wasm_streams::ReadableStream;
11
12struct StreamingBody<T> {
13 stream: T,
14}
15
16impl<T> HttpBody for StreamingBody<T>
17where
18 T: Stream<Item = Bytes> + Unpin,
19{
20 type Data = Bytes;
21
22 type Error = BoxError;
23
24 fn poll_data(
25 self: Pin<&mut Self>,
26 cx: &mut std::task::Context<'_>,
27 ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
28 Pin::into_inner(self)
29 .stream
30 .poll_next(cx)
31 .map(|d| d.map(Ok))
32 }
33
34 fn poll_trailers(
35 self: Pin<&mut Self>,
36 _cx: &mut std::task::Context<'_>,
37 ) -> std::task::Poll<Result<Option<http::HeaderMap>, Self::Error>> {
38 std::task::Poll::Ready(Ok(None))
40 }
41}
42
43impl<T> Stream for StreamingBody<T>
44where
45 T: HttpBody + Unpin,
46 T::Data: Buf,
47 T::Error: Into<Box<dyn Error + Send + Sync>>,
48{
49 type Item = Result<JsValue, JsValue>;
50
51 fn poll_next(
52 self: Pin<&mut Self>,
53 cx: &mut std::task::Context<'_>,
54 ) -> Poll<Option<Self::Item>> {
55 let s = Pin::into_inner(self);
56
57 Pin::new(&mut s.stream).poll_data(cx).map(|f| match f {
58 Some(Ok(b)) => {
59 let slice = b.chunk();
60 let a = Uint8Array::new(&JsValue::from(slice.len()));
61 a.copy_from(slice);
62 let out: JsValue = a.into();
63 Some(Ok(out))
64 }
65 Some(Err(e)) => Some(Err(JsValue::from_str(&e.into().to_string()))),
66 None => None,
67 })
68 }
69}
70
71#[derive(Debug)]
72enum Err {
73 ConvertError,
74}
75
76fn to_http_body(req: gloo_net::http::Request) -> axum::body::Body {
77 let (sender, body) = axum::body::Body::channel();
78 let arc_sender = Arc::new(Mutex::new(sender));
79
80 if let Some(b) = req.body() {
81 spawn_local(async move {
82 let sender_clone = Arc::clone(&arc_sender);
83 ReadableStream::from_raw(b.unchecked_into())
84 .into_stream()
85 .try_for_each(|buf_js| -> Result<(), Err> {
86 let buffer =
87 js_sys::Uint8Array::new(&buf_js.map_err(|_| Err::ConvertError)?);
88 let bytes: Bytes = buffer.to_vec().into();
89
90 let sender_clone = Arc::clone(&sender_clone);
91 spawn_local(async move {
93 let binding = Arc::clone(&sender_clone);
94 let mut sender = binding.lock().unwrap();
95 let _ = sender.send_data(bytes).await;
96 });
97 Ok(())
98 })
99 .await.unwrap();
100 });
101 }
102
103 body
104}
105
106fn from_fetch_request(req: web_sys::Request) -> Request<axum::body::Body> {
107 let gloo_req = gloo_net::http::Request::from(req);
108 let headers = gloo_req.headers();
109
110 let mut builder = Request::builder()
111 .uri(gloo_req.url())
112 .method(gloo_req.method());
113
114 for (key, value) in headers.entries() {
115 builder = builder.header(key, value)
116 }
117
118 let body = to_http_body(gloo_req);
119 builder.body(body).unwrap()
120}
121
122fn create_default_error(e: impl Error) -> gloo_net::http::Response {
123 gloo_net::http::Response::builder()
124 .status(500)
125 .body(Some(format!("{e}").as_str()))
126 .unwrap()
127}
128
129fn to_fetch_response(res: impl IntoResponse) -> web_sys::Response {
130 let (parts, body) = res.into_response().into_parts();
131 let headers = gloo_net::http::Headers::new();
132 for (key, value) in parts.headers.iter() {
133 headers.append(key.as_str(), value.to_str().unwrap());
134 }
135 let stream_body: web_sys::ReadableStream =
136 ReadableStream::from_stream(StreamingBody { stream: body })
137 .into_raw()
138 .unchecked_into();
139
140 gloo_net::http::Response::builder()
141 .status(parts.status.as_u16())
142 .headers(headers)
143 .body(Some(&stream_body))
144 .unwrap_or_else(create_default_error)
145 .into()
146}
147
148pub struct App {
149 service: BoxCloneService<web_sys::Request, web_sys::Response, BoxError>,
150}
151
152impl App {
164 pub fn new<S>(service: S) -> Self
165 where
166 S: ServiceExt<Request<axum::body::Body>> + Clone + Send + Sized + 'static,
167 S::Future: Send + 'static,
168 S::Response: IntoResponse,
169 S::Error: Into<Box<dyn Error + Send + Sync>>,
170 {
171 let svc = ServiceBuilder::new()
172 .map_request(from_fetch_request)
173 .map_response(to_fetch_response)
174 .service(service)
175 .map_err(|e| e.into());
176
177 let service = svc.boxed_clone();
178 Self { service }
179 }
180
181 pub async fn serve(&self, req: web_sys::Request) -> web_sys::Response {
182 self.service.clone().oneshot(req).await.unwrap()
183 }
184}
185
#[cfg(test)]
pub mod tests {
    use crate::App;
    use axum::{
        body::StreamBody,
        extract::{Json, Query},
        routing::get,
        Router,
    };
    use futures_lite::{stream, Stream};
    use serde::Deserialize;
    use std::{collections::HashMap, convert::Infallible};

    /// Test wrapper around [`App`] wired to the fixture handlers below.
    struct MyApp(App);

    impl Default for MyApp {
        fn default() -> Self {
            let router = Router::new()
                .route("/", get(handler).post(handler2))
                .route("/stream", get(handler3));
            MyApp(App::new(router))
        }
    }

    /// `GET /`: echoes the parsed query parameters in their `Debug` form.
    async fn handler(Query(query): Query<HashMap<String, String>>) -> String {
        format!("received: {query:?}")
    }

    /// JSON payload accepted by `POST /`.
    #[derive(Debug, Deserialize)]
    struct TestStruct {
        hello: String,
    }

    /// `POST /`: streams the payload's `hello` field repeatedly, without end.
    async fn handler2(
        Json(body): Json<TestStruct>,
    ) -> StreamBody<impl Stream<Item = Result<String, Infallible>>> {
        StreamBody::new(stream::repeat(Ok(body.hello)))
    }

    /// `GET /stream`: streams three fixed string chunks, then ends.
    async fn handler3() -> StreamBody<impl Stream<Item = std::io::Result<&'static str>>> {
        let parts: [std::io::Result<&'static str>; 3] = [Ok("Hello,"), Ok(" "), Ok("world!")];
        StreamBody::new(stream::iter(parts))
    }
}