cipherstash_grpc_web_client/
lib.rs

1mod call;
2use cfg_if::cfg_if;
3
4cfg_if! {
5    if #[cfg(feature = "web-worker")] {
6        mod worker;
7        use crate::worker as request;
8    } else {
9        mod browser;
10        use crate::browser as request;
11    }
12}
13
14use bytes::Bytes;
15use call::{Encoding, GrpcWebCall};
16use core::{
17    fmt,
18    task::{Context, Poll},
19};
20use futures::{Future, Stream, TryStreamExt};
21use http::{
22    header::{HeaderName, InvalidHeaderName, InvalidHeaderValue, ToStrError},
23    request::Request,
24    response::Response,
25    HeaderMap, HeaderValue,
26};
27use http_body::Body;
28use js_sys::{Array, Uint8Array};
29use std::{error::Error, pin::Pin};
30use tonic::{body::BoxBody, client::GrpcService, Status};
31use wasm_bindgen::{JsCast, JsValue};
32use wasm_streams::ReadableStream;
33use web_sys::Headers;
34
35#[derive(Debug, Clone, PartialEq)]
36pub enum ClientError {
37    Err,
38    FetchFailed(JsValue),
39    Other(String),
40}
41
42impl From<ToStrError> for ClientError {
43    fn from(_: ToStrError) -> Self {
44        Self::Other("Header value contained invalid ASCII".into())
45    }
46}
47
48impl From<InvalidHeaderName> for ClientError {
49    fn from(_: InvalidHeaderName) -> Self {
50        Self::Other("Header name was invalid".into())
51    }
52}
53
54impl From<InvalidHeaderValue> for ClientError {
55    fn from(_: InvalidHeaderValue) -> Self {
56        Self::Other("Header value was invalid".into())
57    }
58}
59
60impl From<JsValue> for ClientError {
61    fn from(val: JsValue) -> Self {
62        Self::FetchFailed(val)
63    }
64}
65
66impl Error for ClientError {}
67impl fmt::Display for ClientError {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        write!(f, "{:?}", self)
70    }
71}
72
/// Alias for `web_sys::RequestCredentials`: the `fetch` credentials mode
/// stored in a `Client`.
pub type CredentialsMode = web_sys::RequestCredentials;

/// Alias for `web_sys::RequestMode`: the `fetch` request mode stored in a
/// `Client`.
pub type RequestMode = web_sys::RequestMode;
76
77#[derive(Clone)]
78pub struct Client {
79    base_uri: String,
80    credentials: CredentialsMode,
81    mode: RequestMode,
82    encoding: Encoding,
83}
84
85impl Client {
86    pub fn new(base_uri: String) -> Self {
87        Client {
88            base_uri,
89            credentials: CredentialsMode::SameOrigin,
90            mode: RequestMode::Cors,
91            encoding: Encoding::None,
92        }
93    }
94
95    async fn request(self, rpc: Request<BoxBody>) -> Result<Response<BoxBody>, ClientError> {
96        let mut uri = rpc.uri().to_string();
97        uri.insert_str(0, &self.base_uri);
98
99        let headers =
100            Headers::new().map_err(|_| ClientError::Other("Failed to create headers".into()))?;
101
102        for (k, v) in rpc.headers().iter() {
103            headers.set(k.as_str(), v.to_str()?)?;
104        }
105        headers.set("x-user-agent", "grpc-web-rust/0.1")?;
106        headers.set("content-type", self.encoding.to_content_type())?;
107
108        let body_bytes = hyper::body::to_bytes(rpc.into_body())
109            .await
110            .map_err(|_| ClientError::Other("Failed to convert RPC body to bytes".into()))?;
111
112        let body_array: Uint8Array = body_bytes.as_ref().into();
113        let body_js: &JsValue = body_array.as_ref();
114
115        let mut init = request::post_init(self);
116        init.body(Some(body_js)).headers(headers.as_ref());
117
118        let request = web_sys::Request::new_with_str_and_init(&uri, &init)?;
119        let fetch_res = request::fetch_with_request(request).await?;
120
121        let mut res = Response::builder().status(fetch_res.status());
122        let headers = res
123            .headers_mut()
124            .ok_or_else(|| ClientError::Other("Could not get response headers".into()))?;
125
126        for kv in js_sys::try_iter(fetch_res.headers().as_ref())?
127            .ok_or_else(|| ClientError::Other("Response headers iterator was empty".into()))?
128        {
129            let pair: Array = kv?.into();
130            headers.append(
131                HeaderName::from_bytes(
132                    pair.get(0)
133                        .as_string()
134                        .ok_or_else(|| ClientError::Other("Header pair had no name".into()))?
135                        .as_bytes(),
136                )?,
137                HeaderValue::from_str(
138                    &pair
139                        .get(1)
140                        .as_string()
141                        .ok_or_else(|| ClientError::Other("Header pair had no value".into()))?,
142                )?,
143            );
144        }
145
146        let body_stream = ReadableStream::from_raw(
147            fetch_res
148                .body()
149                .ok_or_else(|| ClientError::Other("Response body was empty".into()))?
150                .unchecked_into(),
151        );
152        let body = GrpcWebCall::client_response(
153            ReadableStreamBody::new(body_stream),
154            Encoding::from_content_type(headers),
155        );
156
157        Ok(res
158            .body(BoxBody::new(body))
159            .map_err(|e| ClientError::Other(format!("An HTTP error ocurred: {}", e)))?)
160    }
161}
162
163impl GrpcService<BoxBody> for Client {
164    type ResponseBody = BoxBody;
165    type Error = ClientError;
166    type Future = Pin<Box<dyn Future<Output = Result<Response<BoxBody>, ClientError>>>>;
167
168    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169        Poll::Ready(Ok(()))
170    }
171
172    fn call(&mut self, rpc: Request<BoxBody>) -> Self::Future {
173        Box::pin(self.clone().request(rpc))
174    }
175}
176
177struct ReadableStreamBody {
178    stream: Pin<Box<dyn Stream<Item = Result<Bytes, Status>>>>,
179}
180
181impl ReadableStreamBody {
182    fn new(inner: ReadableStream) -> Self {
183        ReadableStreamBody {
184            stream: Box::pin(
185                inner
186                    .into_stream()
187                    .map_ok(|buf_js| {
188                        let buffer = Uint8Array::new(&buf_js);
189                        let mut bytes_vec = vec![0; buffer.length() as usize];
190                        buffer.copy_to(&mut bytes_vec);
191                        let bytes: Bytes = bytes_vec.into();
192                        bytes
193                    })
194                    .map_err(|_| Status::unknown("readablestream error")),
195            ),
196        }
197    }
198}
199
200impl Body for ReadableStreamBody {
201    type Data = Bytes;
202    type Error = Status;
203
204    fn poll_data(
205        mut self: Pin<&mut Self>,
206        cx: &mut Context<'_>,
207    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
208        self.stream.as_mut().poll_next(cx)
209    }
210
211    fn poll_trailers(
212        self: Pin<&mut Self>,
213        _: &mut Context<'_>,
214    ) -> Poll<Result<Option<HeaderMap>, Self::Error>> {
215        Poll::Ready(Ok(None))
216    }
217
218    fn is_end_stream(&self) -> bool {
219        false
220    }
221}
222
// WARNING: neither type below is automatically `Send`/`Sync`. `ClientError`
// holds a `JsValue` (`FetchFailed`), and `ReadableStreamBody` holds a
// `dyn Stream` trait object without a `Send` bound that wraps a JS stream.
// NOTE(review): the impls appear to be needed because tonic's `BoxBody` and
// its service error bound expect `Send + Sync` (the `Body`/`Error` traits
// themselves do not) — confirm against the tonic version in use.
//
// SAFETY: relies on the wasm target being single-threaded, so no value of
// these types can actually reach another thread. If the crate is ever built
// for multi-threaded wasm (e.g. with the `atomics` target feature), these
// impls become unsound and must be revisited.

unsafe impl Sync for ReadableStreamBody {}
unsafe impl Send for ReadableStreamBody {}

unsafe impl Sync for ClientError {}
unsafe impl Send for ClientError {}