Skip to main content

fetsig/browser/
common.rs

1use std::time::Duration;
2
3use artwrap::TimeoutFutureExt;
4use base64::{Engine, engine::general_purpose};
5use js_sys::{JsString, Uint8Array};
6use smol_str::{SmolStr, ToSmolStr, format_smolstr};
7use wasm_bindgen::{JsCast, JsValue};
8use wasm_bindgen_futures::JsFuture;
9use web_sys::{AbortController, AbortSignal, Response, ResponseType};
10
11use crate::{HEADER_SIGNATURE, MacVerify, MediaType, StatusCode, uformat_smolstr};
12
13#[cfg(feature = "json")]
14use crate::JSONDeserialize;
15
16#[cfg(feature = "postcard")]
17use crate::PostcardDeserialize;
18
19use super::js_error;
20pub fn none(_: StatusCode) {}
21
// `FetchDeserializable` is the single bound that generic fetch code places on payload types.
// Its supertraits follow the enabled serialization features. In every configuration except
// "no feature", types satisfying those supertraits implement it automatically.

/// Payload bound when both the `json` and `postcard` features are enabled.
#[cfg(all(feature = "json", feature = "postcard"))]
pub trait FetchDeserializable: JSONDeserialize + PostcardDeserialize {}

#[cfg(all(feature = "json", feature = "postcard"))]
impl<T: JSONDeserialize + PostcardDeserialize> FetchDeserializable for T {}

/// Payload bound when only the `json` feature is enabled.
#[cfg(all(feature = "json", not(feature = "postcard")))]
pub trait FetchDeserializable: JSONDeserialize {}

#[cfg(all(feature = "json", not(feature = "postcard")))]
impl<T: JSONDeserialize> FetchDeserializable for T {}

/// Payload bound when only the `postcard` feature is enabled.
#[cfg(all(feature = "postcard", not(feature = "json")))]
pub trait FetchDeserializable: PostcardDeserialize {}

#[cfg(all(feature = "postcard", not(feature = "json")))]
impl<T: PostcardDeserialize> FetchDeserializable for T {}

// NOTE(review): unlike the configurations above, there is no blanket impl here, so no type
// implements this trait automatically — confirm this is intended.
/// Placeholder bound when neither serialization feature is enabled.
#[cfg(not(any(feature = "json", feature = "postcard")))]
pub trait FetchDeserializable {}
39
40pub struct Abort {
41    controller: AbortController,
42}
43
44impl Abort {
45    pub fn new() -> Result<Self, SmolStr> {
46        Ok(Self {
47            controller: AbortController::new().map_err(js_error)?,
48        })
49    }
50
51    pub fn signal(&self) -> AbortSignal {
52        self.controller.signal()
53    }
54
55    pub fn abort(&self) {
56        self.controller.abort()
57    }
58}
59
60pub(crate) struct PendingFetch {
61    url: SmolStr,
62    #[allow(dead_code)]
63    abort: Abort,
64    timeout: Option<Duration>,
65    request_future: JsFuture,
66}
67
68impl PendingFetch {
69    pub fn new(
70        url: impl ToSmolStr,
71        abort: Abort,
72        timeout: Option<Duration>,
73        request_future: JsFuture,
74    ) -> Self {
75        Self {
76            url: url.to_smolstr(),
77            abort,
78            timeout,
79            request_future,
80        }
81    }
82
83    pub async fn wait_completion(self) -> DecodedResponse<Response> {
84        match self
85            .request_future
86            .timeout(self.timeout.unwrap_or_else(|| Duration::from_secs(900)))
87            .await
88        {
89            Ok(Ok(response)) => {
90                let response = response.unchecked_into::<Response>();
91                if !response.ok() && matches!(response.type_(), ResponseType::Error) {
92                    DecodedResponse::new(StatusCode::FetchFailed).with_hint("Fetch network error")
93                } else {
94                    DecodedResponse::new(response.status()).with_response(response)
95                }
96            }
97            Ok(Err(error)) => DecodedResponse::new(StatusCode::FetchFailed).with_hint(
98                uformat_smolstr!("Fetch start failed ({})", js_error(error).as_str()),
99            ),
100            Err(_) => {
101                self.abort.abort();
102                DecodedResponse::new(StatusCode::FetchTimeout).with_hint(self.url)
103            }
104        }
105    }
106}
107
108pub(crate) struct DecodedResponse<R> {
109    status: StatusCode,
110    hint: Option<SmolStr>,
111    response: Option<R>,
112}
113
114impl<R> DecodedResponse<R> {
115    pub fn new(status: impl Into<StatusCode>) -> Self {
116        Self {
117            status: status.into(),
118            hint: None,
119            response: None,
120        }
121    }
122
123    pub fn with_response(mut self, response: R) -> Self {
124        self.response = Some(response);
125        self
126    }
127
128    pub fn with_hint(mut self, hint: impl ToSmolStr) -> Self {
129        self.hint = Some(hint.to_smolstr());
130        self
131    }
132
133    pub fn status(&self) -> StatusCode {
134        self.status
135    }
136
137    pub fn take_response(&mut self) -> Option<R> {
138        self.response.take()
139    }
140
141    pub fn hint(&self) -> Option<&str> {
142        self.hint.as_deref()
143    }
144
145    fn as_empty<U>(self) -> DecodedResponse<U> {
146        DecodedResponse {
147            status: self.status,
148            hint: self.hint,
149            response: None,
150        }
151    }
152}
153
154pub(crate) async fn execute_fetch<R, MV>(fetch: PendingFetch) -> DecodedResponse<R>
155where
156    R: FetchDeserializable,
157    MV: MacVerify,
158{
159    let mut fetched = fetch.wait_completion().await;
160    let Some(response) = fetched.take_response() else {
161        return fetched.as_empty();
162    };
163
164    let status = fetched.status();
165    match status {
166        StatusCode::Ok
167        | StatusCode::Created
168        | StatusCode::BadRequest
169        | StatusCode::Forbidden
170        | StatusCode::InternalServerError
171        | StatusCode::NotFound
172        | StatusCode::Conflict
173        | StatusCode::PayloadTooBig
174        | StatusCode::RateLimited
175        | StatusCode::Unauthorized => match decode_response::<R, MV>(status, response).await {
176            Ok(result) => result,
177            Err(result) => result,
178        },
179        _ => fetched.as_empty(),
180    }
181}
182
183async fn decode_response<R, MV>(
184    status: StatusCode,
185    response: Response,
186) -> Result<DecodedResponse<R>, DecodedResponse<R>>
187where
188    R: FetchDeserializable,
189    MV: MacVerify,
190{
191    let headers = response.headers();
192    let content_type = headers.get("Content-Type").map_err(|error| {
193        DecodedResponse::new(StatusCode::FetchFailed).with_hint(uformat_smolstr!(
194            "Cannot decode Content-Type header: {}.",
195            js_error(error).as_str()
196        ))
197    })?;
198    let media_type = match content_type {
199        Some(content_type) => MediaType::from(content_type.as_str()),
200        None => MediaType::Plain,
201    };
202
203    let signature = headers.get(HEADER_SIGNATURE).map_err(|error| {
204        DecodedResponse::new(StatusCode::FetchFailed).with_hint(uformat_smolstr!(
205            "Cannot decode {} header: {}.",
206            HEADER_SIGNATURE,
207            js_error(error).as_str()
208        ))
209    })?;
210
211    let array_promise = response
212        .array_buffer()
213        .map_err(|_| DecodedResponse::new(StatusCode::DecodeFailed).with_hint("Decode 1"))?;
214    let content_array_buffer = JsFuture::from(array_promise)
215        .await
216        .map_err(|_| DecodedResponse::new(StatusCode::DecodeFailed).with_hint("Decode 2"))?;
217
218    match deserialize_content::<_, MV>(
219        media_type,
220        DeserializeMode::Deserialize,
221        content_array_buffer,
222        signature.as_deref(),
223    ) {
224        Ok(None) => Ok(DecodedResponse::new(status)),
225        Ok(Some(response)) => Ok(DecodedResponse::new(status).with_response(response)),
226        Err((status, hint)) => Err(DecodedResponse::new(status).with_hint(hint)),
227    }
228}
229
/// How raw content bytes are turned into the bytes to deserialize.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DecodeMode {
    /// Content is base64 text (standard alphabet, no padding) and is decoded first.
    Base64,
    /// Content bytes are used as-is.
    Plain,
}
235
236pub fn decode_content(
237    mode: DecodeMode,
238    content: JsValue,
239) -> Result<Option<Vec<u8>>, (StatusCode, SmolStr)> {
240    let data = if content.is_string() {
241        if let Some(string) = content.dyn_ref::<JsString>().and_then(|s| s.as_string()) {
242            if string.is_empty() {
243                None
244            } else {
245                Some(string.as_bytes().to_vec())
246            }
247        } else {
248            None
249        }
250    } else {
251        // otherwise content is an array buffer
252        let array = Uint8Array::new(&content);
253        if array.length() == 0 {
254            None
255        } else {
256            Some(array.to_vec())
257        }
258    };
259
260    data.map(|data| {
261        if mode == DecodeMode::Base64 {
262            general_purpose::STANDARD_NO_PAD
263                .decode(data)
264                .map_err(|error| (StatusCode::DecodeFailed, format_smolstr!("{error}")))
265        } else {
266            Ok(data)
267        }
268    })
269    .transpose()
270}
271
272#[derive(Clone, Copy, PartialEq, Eq)]
273pub enum DeserializeMode {
274    Base64AndDeserialize,
275    Deserialize,
276}
277
278impl From<DeserializeMode> for DecodeMode {
279    fn from(mode: DeserializeMode) -> Self {
280        match mode {
281            DeserializeMode::Base64AndDeserialize => DecodeMode::Base64,
282            DeserializeMode::Deserialize => DecodeMode::Plain,
283        }
284    }
285}
286
287pub fn deserialize_content<R, MV>(
288    media_type: MediaType,
289    mode: DeserializeMode,
290    content: JsValue,
291    signature: Option<&str>,
292) -> Result<Option<R>, (StatusCode, SmolStr)>
293where
294    R: FetchDeserializable,
295    MV: MacVerify,
296{
297    match media_type {
298        #[cfg(feature = "json")]
299        MediaType::Json => (),
300        #[cfg(feature = "postcard")]
301        MediaType::Postcard => (),
302        _ => Err((StatusCode::UnsupportedMediaType, SmolStr::default()))?,
303    }
304
305    let data = decode_content(mode.into(), content)?;
306    let Some(data) = data else {
307        return Ok(None);
308    };
309
310    match MV::verify(&data, signature) {
311        Ok(true) => (),
312        Ok(false) => Err((
313            StatusCode::DecodeFailed,
314            "Response signature is invalid.".into(),
315        ))?,
316        Err(error) => Err((
317            StatusCode::DecodeFailed,
318            SmolStr::from_iter([
319                "Response signature verification failed: {}.",
320                error.as_str(),
321            ]),
322        ))?,
323    }
324
325    match media_type {
326        #[cfg(feature = "json")]
327        MediaType::Json => R::try_from_json(&data),
328        #[cfg(feature = "postcard")]
329        MediaType::Postcard => R::try_from_postcard(&data),
330        _ => {
331            return Err((
332                StatusCode::UnsupportedMediaType,
333                "Decode/deserialize error, unexpected data flow for unsupported media type.".into(),
334            ));
335        }
336    }
337    .map_err(|error| {
338        (
339            StatusCode::DecodeFailed,
340            SmolStr::from_iter(["Deserialization failed: ", error.as_str()]),
341        )
342    })
343    .map(|response| Some(response))
344}