1#![allow(static_mut_refs)]
2use std::{collections::HashMap, error::Error, sync::RwLock};
3
4use reqwest::RequestBuilder;
5use serde::Serialize;
6
7pub mod signature;
8pub use signature::Signature;
9
10pub trait Signer {
11 fn sign(&self, msg: &str) -> Result<Signature, Box<dyn Error>>;
12 fn signer(&self) -> String;
13}
14
15pub trait RequestHooker {
16 fn before_request(&self, req: RequestBuilder) -> RequestBuilder;
17}
18
19static mut SIGNER: Option<RwLock<Box<dyn Signer>>> = None;
20static mut MESSAGE: Option<String> = None;
21static mut HEADERS: RwLock<Option<HashMap<String, String>>> = RwLock::new(None);
24
25pub fn get_authz_token() -> Option<String> {
42 unsafe {
43 let headers = HEADERS.read().unwrap();
44 match headers.as_ref() {
45 Some(headers) => headers
46 .get("Authorization")
47 .cloned()
48 .unwrap_or_default()
49 .split(" ")
50 .last()
51 .map(|s| s.to_string()),
52 None => None,
53 }
54 }
55}
56
57pub fn get_header(key: &str) -> Option<String> {
58 unsafe {
59 let headers = HEADERS.read().unwrap();
60 match headers.as_ref() {
61 Some(headers) => headers.get(key).cloned(),
62 None => None,
63 }
64 }
65}
66
67pub fn add_header(key: String, value: String) {
68 unsafe {
69 let mut headers = HEADERS.write().unwrap();
70 match headers.as_mut() {
71 Some(headers) => {
72 headers.insert(key, value);
73 }
74 None => {
75 let mut new_headers = HashMap::new();
76 new_headers.insert(key, value);
77 *headers = Some(new_headers);
78 }
79 }
80 }
81}
82
83pub fn remove_header(key: &str) {
84 unsafe {
85 let mut headers = HEADERS.write().unwrap();
86 match headers.as_mut() {
87 Some(headers) => {
88 headers.remove(key);
89 }
90 None => {}
91 }
92 }
93}
94
95pub fn set_signer(signer: Box<dyn Signer>) {
96 unsafe {
97 SIGNER = Some(RwLock::new(signer));
98 }
99}
100
101pub fn remove_signer() {
102 unsafe {
103 SIGNER = None;
104 }
105}
106
107pub fn set_message(msg: String) {
108 unsafe {
109 MESSAGE = Some(msg);
110 }
111}
112
113pub fn sign_request(req: RequestBuilder) -> RequestBuilder {
114 if let (Some(signer), Some(msg)) = unsafe { (&SIGNER, &MESSAGE) } {
115 let signer = signer.read().unwrap();
116 let address = signer.signer();
117 tracing::debug!("Signer address: {}", address);
118 if address.is_empty() {
119 return req;
120 }
121
122 let timestamp = chrono::Utc::now().timestamp();
123 let msg = format!("{}-{}", msg, timestamp);
124 tracing::debug!("Signing message: {}", msg);
125 let signature = signer.sign(&msg);
126 if signature.is_err() {
127 return req;
128 }
129
130 let signature = signature.unwrap();
131 req.header(
132 reqwest::header::AUTHORIZATION,
133 format!("UserSig {timestamp}:{signature}"),
134 )
135 } else {
136 tracing::debug!("No signer found");
137 req
138 }
139}
140
141pub fn add_authorization(token: &str) {
142 unsafe {
143 let mut headers = HEADERS.write().unwrap();
144 match headers.as_mut() {
145 Some(headers) => {
146 headers.insert("Authorization".to_string(), token.to_string());
147 }
148 None => {
149 let mut new_headers = HashMap::new();
150 new_headers.insert("Authorization".to_string(), token.to_string());
151 *headers = Some(new_headers);
152 }
153 }
154 }
155}
156
157pub fn extract_for_next_request(res: &reqwest::Response) {
158 let headers = res.headers();
159 if let Some(authz) = headers.get(reqwest::header::AUTHORIZATION) {
160 let authz = authz.to_str().unwrap();
161 add_authorization(authz);
162 } else if let Some(authz) = headers.get("x-amzn-remapped-authorization") {
163 let authz = authz.to_str().unwrap();
164 add_authorization(authz);
165 }
166}
167
168pub fn load_headers(mut req: RequestBuilder) -> RequestBuilder {
169 unsafe {
170 match HEADERS.read().unwrap().as_ref() {
171 Some(ref headers) => {
172 for (k, v) in headers.iter() {
173 req = req.header(k, v);
174 }
175
176 req
177 }
178 None => req,
179 }
180 }
181}
182
183pub async fn send(req: RequestBuilder) -> reqwest::Result<reqwest::Response> {
184 let req = sign_request(req);
186 let req = load_headers(req);
187 let res = req.send().await;
188 if let Ok(res) = &res {
189 extract_for_next_request(res);
190 }
191
192 res
193}
194
195pub async fn get<T, E>(url: &str) -> Result<T, E>
196where
197 T: serde::de::DeserializeOwned,
198 E: serde::de::DeserializeOwned + From<reqwest::Error>,
199{
200 let client = reqwest::Client::builder().build()?;
201
202 let req = client.get(url);
203 let res = send(req).await?;
204
205 if res.status().is_success() {
206 Ok(res.json().await?)
207 } else {
208 Err(res.json().await?)
209 }
210}
211
212pub async fn get_with_query<T, E, P>(url: &str, query_params: &P) -> Result<T, E>
221where
222 T: serde::de::DeserializeOwned,
223 E: serde::de::DeserializeOwned + From<reqwest::Error>,
224 P: serde::Serialize + ?Sized,
225{
226 let client = reqwest::Client::builder().build()?;
227
228 let req = client.get(url).query(query_params);
229 let res = send(req).await?;
230
231 if res.status().is_success() {
232 Ok(res.json().await?)
233 } else {
234 Err(res.json().await?)
235 }
236}
237
238pub async fn post<R, T, E>(url: &str, body: R) -> Result<T, E>
239where
240 R: Serialize,
241 T: serde::de::DeserializeOwned,
242 E: serde::de::DeserializeOwned + From<reqwest::Error>,
243{
244 let client = reqwest::Client::builder().build()?;
245
246 let req = client.post(url).json(&body);
247 let res = send(req).await?;
248
249 if res.status().is_success() {
250 Ok(res.json().await?)
251 } else {
252 Err(res.json().await?)
253 }
254}