rest_api/
lib.rs

1#![allow(static_mut_refs)]
2use std::{collections::HashMap, error::Error, sync::RwLock};
3
4use reqwest::RequestBuilder;
5use serde::Serialize;
6
7pub mod signature;
8pub use signature::Signature;
9
10pub trait Signer {
11    fn sign(&self, msg: &str) -> Result<Signature, Box<dyn Error>>;
12    fn signer(&self) -> String;
13}
14
15pub trait RequestHooker {
16    fn before_request(&self, req: RequestBuilder) -> RequestBuilder;
17}
18
19static mut SIGNER: Option<RwLock<Box<dyn Signer>>> = None;
20static mut MESSAGE: Option<String> = None;
21// FIXME: It causes dropping Signal of dioxus
22// static mut HOOKS: RwLock<Vec<Box<dyn RequestHooker>>> = RwLock::new(Vec::new());
23static mut HEADERS: RwLock<Option<HashMap<String, String>>> = RwLock::new(None);
24
25// pub fn add_hook<T: RequestHooker + 'static>(hook: T) {
26//     unsafe {
27//         HOOKS.write().unwrap().push(Box::new(hook));
28//     }
29// }
30
31// pub fn run_hooks(req: RequestBuilder) -> RequestBuilder {
32//     unsafe {
33//         HOOKS
34//             .read()
35//             .unwrap()
36//             .iter()
37//             .fold(req, |req, hook| hook.before_request(req))
38//     }
39// }
40
41pub fn get_authz_token() -> Option<String> {
42    unsafe {
43        let headers = HEADERS.read().unwrap();
44        match headers.as_ref() {
45            Some(headers) => headers
46                .get("Authorization")
47                .cloned()
48                .unwrap_or_default()
49                .split(" ")
50                .last()
51                .map(|s| s.to_string()),
52            None => None,
53        }
54    }
55}
56
57pub fn get_header(key: &str) -> Option<String> {
58    unsafe {
59        let headers = HEADERS.read().unwrap();
60        match headers.as_ref() {
61            Some(headers) => headers.get(key).cloned(),
62            None => None,
63        }
64    }
65}
66
67pub fn add_header(key: String, value: String) {
68    unsafe {
69        let mut headers = HEADERS.write().unwrap();
70        match headers.as_mut() {
71            Some(headers) => {
72                headers.insert(key, value);
73            }
74            None => {
75                let mut new_headers = HashMap::new();
76                new_headers.insert(key, value);
77                *headers = Some(new_headers);
78            }
79        }
80    }
81}
82
83pub fn remove_header(key: &str) {
84    unsafe {
85        let mut headers = HEADERS.write().unwrap();
86        match headers.as_mut() {
87            Some(headers) => {
88                headers.remove(key);
89            }
90            None => {}
91        }
92    }
93}
94
95pub fn set_signer(signer: Box<dyn Signer>) {
96    unsafe {
97        SIGNER = Some(RwLock::new(signer));
98    }
99}
100
101pub fn remove_signer() {
102    unsafe {
103        SIGNER = None;
104    }
105}
106
107pub fn set_message(msg: String) {
108    unsafe {
109        MESSAGE = Some(msg);
110    }
111}
112
113pub fn sign_request(req: RequestBuilder) -> RequestBuilder {
114    if let (Some(signer), Some(msg)) = unsafe { (&SIGNER, &MESSAGE) } {
115        let signer = signer.read().unwrap();
116        let address = signer.signer();
117        tracing::debug!("Signer address: {}", address);
118        if address.is_empty() {
119            return req;
120        }
121
122        let timestamp = chrono::Utc::now().timestamp();
123        let msg = format!("{}-{}", msg, timestamp);
124        tracing::debug!("Signing message: {}", msg);
125        let signature = signer.sign(&msg);
126        if signature.is_err() {
127            return req;
128        }
129
130        let signature = signature.unwrap();
131        req.header(
132            reqwest::header::AUTHORIZATION,
133            format!("UserSig {timestamp}:{signature}"),
134        )
135    } else {
136        tracing::debug!("No signer found");
137        req
138    }
139}
140
141pub fn add_authorization(token: &str) {
142    unsafe {
143        let mut headers = HEADERS.write().unwrap();
144        match headers.as_mut() {
145            Some(headers) => {
146                headers.insert("Authorization".to_string(), token.to_string());
147            }
148            None => {
149                let mut new_headers = HashMap::new();
150                new_headers.insert("Authorization".to_string(), token.to_string());
151                *headers = Some(new_headers);
152            }
153        }
154    }
155}
156
157pub fn extract_for_next_request(res: &reqwest::Response) {
158    let headers = res.headers();
159    if let Some(authz) = headers.get(reqwest::header::AUTHORIZATION) {
160        let authz = authz.to_str().unwrap();
161        add_authorization(authz);
162    } else if let Some(authz) = headers.get("x-amzn-remapped-authorization") {
163        let authz = authz.to_str().unwrap();
164        add_authorization(authz);
165    }
166}
167
168pub fn load_headers(mut req: RequestBuilder) -> RequestBuilder {
169    unsafe {
170        match HEADERS.read().unwrap().as_ref() {
171            Some(ref headers) => {
172                for (k, v) in headers.iter() {
173                    req = req.header(k, v);
174                }
175
176                req
177            }
178            None => req,
179        }
180    }
181}
182
183pub async fn send(req: RequestBuilder) -> reqwest::Result<reqwest::Response> {
184    // let req = run_hooks(req);
185    let req = sign_request(req);
186    let req = load_headers(req);
187    let res = req.send().await;
188    if let Ok(res) = &res {
189        extract_for_next_request(res);
190    }
191
192    res
193}
194
195pub async fn get<T, E>(url: &str) -> Result<T, E>
196where
197    T: serde::de::DeserializeOwned,
198    E: serde::de::DeserializeOwned + From<reqwest::Error>,
199{
200    let client = reqwest::Client::builder().build()?;
201
202    let req = client.get(url);
203    let res = send(req).await?;
204
205    if res.status().is_success() {
206        Ok(res.json().await?)
207    } else {
208        Err(res.json().await?)
209    }
210}
211
212/// Performs an HTTP GET request.
213///
214/// # Arguments
215///
216/// * `url` - The URL to send the request to
217/// * `query_params` - Query parameters for the URL. Pass `&None::<()>` to send request without query parameters
218///
219///
220pub async fn get_with_query<T, E, P>(url: &str, query_params: &P) -> Result<T, E>
221where
222    T: serde::de::DeserializeOwned,
223    E: serde::de::DeserializeOwned + From<reqwest::Error>,
224    P: serde::Serialize + ?Sized,
225{
226    let client = reqwest::Client::builder().build()?;
227
228    let req = client.get(url).query(query_params);
229    let res = send(req).await?;
230
231    if res.status().is_success() {
232        Ok(res.json().await?)
233    } else {
234        Err(res.json().await?)
235    }
236}
237
238pub async fn post<R, T, E>(url: &str, body: R) -> Result<T, E>
239where
240    R: Serialize,
241    T: serde::de::DeserializeOwned,
242    E: serde::de::DeserializeOwned + From<reqwest::Error>,
243{
244    let client = reqwest::Client::builder().build()?;
245
246    let req = client.post(url).json(&body);
247    let res = send(req).await?;
248
249    if res.status().is_success() {
250        Ok(res.json().await?)
251    } else {
252        Err(res.json().await?)
253    }
254}