grpc_web_client/
lib.rs

1mod call;
2
3use bytes::Bytes;
4use call::{Encoding, GrpcWebCall};
5use core::{
6    fmt,
7    task::{Context, Poll},
8};
9use futures::{Future, Stream, TryStreamExt};
10use http::{header::HeaderName, request::Request, response::Response, HeaderMap, HeaderValue};
11use http_body::Body;
12use js_sys::{Array, Uint8Array};
13use std::{error::Error, pin::Pin};
14use tonic::{body::BoxBody, client::GrpcService, Status};
15use wasm_bindgen::{JsCast, JsValue};
16use wasm_bindgen_futures::JsFuture;
17use wasm_streams::ReadableStream;
18use web_sys::{Headers, RequestInit};
19
20#[derive(Debug, Clone, PartialEq)]
21pub enum ClientError {
22    Err,
23    FetchFailed(JsValue),
24}
25
26impl Error for ClientError {}
27impl fmt::Display for ClientError {
28    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
29        write!(f, "{:?}", self)
30    }
31}
32
33pub type CredentialsMode = web_sys::RequestCredentials;
34
35pub type RequestMode = web_sys::RequestMode;
36
37#[derive(Clone)]
38pub struct Client {
39    base_uri: String,
40    credentials: CredentialsMode,
41    mode: RequestMode,
42    encoding: Encoding,
43}
44
45impl Client {
46    pub fn new(base_uri: String) -> Self {
47        Client {
48            base_uri,
49            credentials: CredentialsMode::SameOrigin,
50            mode: RequestMode::Cors,
51            encoding: Encoding::None,
52        }
53    }
54
55    async fn request(self, rpc: Request<BoxBody>) -> Result<Response<BoxBody>, ClientError> {
56        let mut uri = rpc.uri().to_string();
57        uri.insert_str(0, &self.base_uri);
58
59        let headers = Headers::new().unwrap();
60        for (k, v) in rpc.headers().iter() {
61            headers.set(k.as_str(), v.to_str().unwrap()).unwrap();
62        }
63        headers.set("x-user-agent", "grpc-web-rust/0.1").unwrap();
64        headers.set("x-grpc-web", "1").unwrap();
65        headers
66            .set("content-type", self.encoding.to_content_type())
67            .unwrap();
68
69        let body_bytes = hyper::body::to_bytes(rpc.into_body()).await.unwrap();
70        let body_array: Uint8Array = body_bytes.as_ref().into();
71        let body_js: &JsValue = body_array.as_ref();
72
73        let mut init = RequestInit::new();
74        init.method("POST")
75            .mode(self.mode)
76            .credentials(self.credentials)
77            .body(Some(body_js))
78            .headers(headers.as_ref());
79
80        let request = web_sys::Request::new_with_str_and_init(&uri, &init).unwrap();
81
82        let window = web_sys::window().unwrap();
83        let fetch = JsFuture::from(window.fetch_with_request(&request))
84            .await
85            .map_err(ClientError::FetchFailed)?;
86        let fetch_res: web_sys::Response = fetch.dyn_into().unwrap();
87
88        let mut res = Response::builder().status(fetch_res.status());
89        let headers = res.headers_mut().unwrap();
90
91        for kv in js_sys::try_iter(fetch_res.headers().as_ref())
92            .unwrap()
93            .unwrap()
94        {
95            let pair: Array = kv.unwrap().into();
96            headers.append(
97                HeaderName::from_bytes(pair.get(0).as_string().unwrap().as_bytes()).unwrap(),
98                HeaderValue::from_str(&pair.get(1).as_string().unwrap()).unwrap(),
99            );
100        }
101
102        let body_stream = ReadableStream::from_raw(fetch_res.body().unwrap().unchecked_into());
103        let body = GrpcWebCall::client_response(
104            ReadableStreamBody::new(body_stream),
105            Encoding::from_content_type(headers),
106        );
107
108        Ok(res.body(BoxBody::new(body)).unwrap())
109    }
110}
111
112impl GrpcService<BoxBody> for Client {
113    type ResponseBody = BoxBody;
114    type Error = ClientError;
115    type Future = Pin<Box<dyn Future<Output = Result<Response<BoxBody>, ClientError>>>>;
116
117    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
118        Poll::Ready(Ok(()))
119    }
120
121    fn call(&mut self, rpc: Request<BoxBody>) -> Self::Future {
122        Box::pin(self.clone().request(rpc))
123    }
124}
125
126struct ReadableStreamBody {
127    stream: Pin<Box<dyn Stream<Item = Result<Bytes, Status>>>>,
128}
129
130impl ReadableStreamBody {
131    fn new(inner: ReadableStream) -> Self {
132        ReadableStreamBody {
133            stream: Box::pin(
134                inner
135                    .into_stream()
136                    .map_ok(|buf_js| {
137                        let buffer = Uint8Array::new(&buf_js);
138                        let mut bytes_vec = vec![0; buffer.length() as usize];
139                        buffer.copy_to(&mut bytes_vec);
140                        let bytes: Bytes = bytes_vec.into();
141                        bytes
142                    })
143                    .map_err(|_| Status::unknown("readablestream error")),
144            ),
145        }
146    }
147}
148
149impl Body for ReadableStreamBody {
150    type Data = Bytes;
151    type Error = Status;
152
153    fn poll_data(
154        mut self: Pin<&mut Self>,
155        cx: &mut Context<'_>,
156    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
157        self.stream.as_mut().poll_next(cx)
158    }
159
160    fn poll_trailers(
161        self: Pin<&mut Self>,
162        _: &mut Context<'_>,
163    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
164        Poll::Ready(Ok(None))
165    }
166
167    fn is_end_stream(&self) -> bool {
168        false
169    }
170}
171
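// WARNING: the `Send`/`Sync` impls below exist only to satisfy the `Body` and `Error` trait
// bounds, even though `JsValue` is not thread-safe. They are sound only while this code runs on
// single-threaded wasm; building with wasm threads (the `atomics` target feature) would make
// them unsound.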
174
175unsafe impl Sync for ReadableStreamBody {}
176unsafe impl Send for ReadableStreamBody {}
177
178unsafe impl Sync for ClientError {}
179unsafe impl Send for ClientError {}