1use std::fmt;
2
3use base64::{engine::general_purpose, Engine as _};
4use chrono::Utc;
5use reqwest::{
6 header::{HeaderMap, HeaderValue},
7 RequestBuilder,
8};
9use thiserror::Error;
10use crate::shared::{ExportableEncryptionKeyData};
11
12#[derive(Error, Debug)]
13pub enum RequestError {
14 #[error("reqwest failed")]
15 ReqwestError(#[from] reqwest::Error),
16 #[error("unable to create authorization")]
17 AuthConstructionError,
18 #[error("bootstrapping encrypted request failed.")]
19 ReKeyError,
20 #[error("handling the response failed")]
21 HandlingResponse(#[from] crate::client::ResponseError),
22 #[error("the argument provided was not one that can be handled")]
23 InvalidArgument,
24 #[error("the request could not be encrypted")]
25 EncryptionError,
26 #[error("the token provided has expired, and could not be renewed")]
27 TokenExpired,
28}
29
30#[derive(Debug, Clone)]
57pub struct Request<UT, RT>
58where
59 UT: UpdateTokenTrait,
60 RT: RequestTrait,
61{
62 pub client: reqwest::Client,
63 pub endpoint: String,
64 pub token: Option<crate::Token>,
65 pub ut: Option<UT>,
66 pub rt: Option<RT>,
67 ek: Option<ExportableEncryptionKeyData>,
68}
69
/// HTTP methods the client can issue.
#[derive(Clone, Debug)]
pub enum Method {
    /// HTTP `GET`.
    Get,
    /// HTTP `POST`.
    Post,
    /// HTTP `PUT`.
    Put,
    /// HTTP `PATCH`.
    Patch,
    /// HTTP `DELETE`.
    Delete,
}
78
79impl fmt::Display for Method {
80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81 write!(f, "{:?}", self)
82 }
83}
84
85pub trait UpdateTokenTrait: Send + Sync {
86 fn token_update(&self, _token: crate::Token) -> bool {
89 return true;
90 }
91}
92
93pub trait RequestTrait: Send + Sync {
94 fn before(&self, builder: RequestBuilder) -> RequestBuilder {
96 return builder;
97 }
98
99 fn after(&self, _response: crate::client::Response) {
101 return;
102 }
103}
104
105impl<UT: UpdateTokenTrait, RT: RequestTrait> Request<UT, RT> {
106 pub fn new_simple(
108 client: reqwest::Client,
109 endpoint: &str,
110 token: Option<crate::Token>,
111 ) -> Self {
112 return Self::new(client, endpoint, token, None, None);
113 }
114
115 pub fn new(
117 client: reqwest::Client,
118 endpoint: &str,
119 token: Option<crate::Token>,
120 ut: Option<UT>,
121 rt: Option<RT>,
122 ) -> Self {
123 Self {
124 client,
125 endpoint: endpoint.to_string(),
126 token,
127 ut,
128 rt,
129 ek: None,
130 }
131 }
132
133 pub fn update_token(&mut self, token: Option<crate::Token>) {
135 self.token = token.clone();
136
137 match &self.ut {
138 Some(callback) => match token {
139 Some(token) => {
140 callback.token_update(token);
141 }
142 None => {}
143 },
144 None => {}
145 };
146 }
147
148 #[async_recursion::async_recursion]
152 pub async fn rekey(&mut self, hashid: Option<String>) -> Result<bool, RequestError> {
153 let kp = crate::Keypair::new();
154 let mut headers = HeaderMap::new();
155 headers.insert(
156 "Content-Type",
157 HeaderValue::from_str(&"application/json").unwrap(),
158 );
159
160 match hashid.clone() {
161 Some(hashid) => {
162 headers.insert(
163 "Accept",
164 HeaderValue::from_str(&"application/vnd.ncryptf+json").unwrap(),
165 );
166 headers.insert("X-HashId", HeaderValue::from_str(&hashid).unwrap());
167 let pk = general_purpose::STANDARD.encode(kp.get_public_key());
168 headers.insert("X-PubKey", HeaderValue::from_str(&pk).unwrap());
169 }
170 _ => {
171 headers.insert(
172 "Accept",
173 HeaderValue::from_str(&"application/json").unwrap(),
174 );
175 }
176 };
177
178 let furi = format!("{}{}", self.endpoint, "/ncryptf/ek");
179 let builder = self.client.clone().get(furi).headers(headers);
180
181 match self.do_request(builder, kp).await {
182 Ok(response) => match response.status {
183 reqwest::StatusCode::OK => match serde_json::from_str::<
184 ExportableEncryptionKeyData,
185 >(&response.body.unwrap())
186 {
187 Ok(ek) => {
188 self.ek = Some(ek.clone());
189 match hashid.clone() {
190 Some(_) => return Ok(true),
191 _ => return self.rekey(Some(ek.hash_id)).await,
192 }
193 }
194 Err(_error) => return Err(RequestError::ReKeyError),
195 },
196 _ => return Err(RequestError::ReKeyError),
197 },
198 Err(_error) => return Err(RequestError::ReKeyError),
199 };
200 }
201
202 pub async fn get(&mut self, url: &str) -> Result<crate::client::Response, RequestError> {
204 return self.execute(Method::Get, url, None).await;
205 }
206
207 pub async fn delete(
209 &mut self,
210 url: &str,
211 payload: Option<&str>,
212 ) -> Result<crate::client::Response, RequestError> {
213 return self.execute(Method::Delete, url, payload).await;
214 }
215
216 pub async fn patch(
218 &mut self,
219 url: &str,
220 payload: Option<&str>,
221 ) -> Result<crate::client::Response, RequestError> {
222 return self.execute(Method::Patch, url, payload).await;
223 }
224
225 pub async fn post(
227 &mut self,
228 url: &str,
229 payload: Option<&str>,
230 ) -> Result<crate::client::Response, RequestError> {
231 return self.execute(Method::Post, url, payload).await;
232 }
233
234 pub async fn put(
236 &mut self,
237 url: &str,
238 payload: Option<&str>,
239 ) -> Result<crate::client::Response, RequestError> {
240 return self.execute(Method::Put, url, payload).await;
241 }
242
243 #[async_recursion::async_recursion]
251 async fn execute(
252 &mut self,
253 method: Method,
254 url: &str,
255 payload: Option<&'async_recursion str>,
256 ) -> Result<crate::client::Response, RequestError> {
257 let payload_actual = match payload {
258 Some(payload) => payload,
259 None => "",
260 };
261
262 match &self.ek {
263 Some(ek) => {
264 if ek.is_expired() {
265 match self.rekey(None).await {
266 Ok(_) => {}
267 Err(error) => return Err(error),
268 };
269 }
270 }
271 _ => match self.rekey(None).await {
272 Ok(_) => {}
273 Err(error) => return Err(error),
274 },
275 };
276
277 let auth: Option<crate::Authorization> = match self.token.clone() {
278 Some(mut token) => {
279 let expiration_limit = chrono::Utc::now().timestamp() + 120;
281 if token.expires_at <= expiration_limit {
282 let refresh_token = token.refresh_token;
283 self.token = None;
285
286 match self
287 .post(
288 format!("/ncryptf/token/refresh?refresh_token={}", refresh_token)
289 .as_str(),
290 None,
291 )
292 .await
293 {
294 Ok(response) => match response.status {
295 reqwest::StatusCode::OK => match response.into::<crate::Token>() {
296 Ok(tt) => {
297 self.update_token(Some(tt.clone()));
298 token = self.token.clone().unwrap();
299 }
300 Err(_error) => return Err(RequestError::TokenExpired),
301 },
302 _ => return Err(RequestError::TokenExpired),
303 },
304 Err(_error) => return Err(RequestError::TokenExpired),
305 };
306 }
307
308 match crate::Authorization::from(
310 method.to_string().to_uppercase(),
311 url.to_string().clone(),
312 token.clone(),
313 Utc::now(),
314 payload_actual.to_string(),
315 None,
316 None,
317 ) {
318 Ok(auth) => Some(auth),
319 Err(_error) => return Err(RequestError::AuthConstructionError),
320 }
321 }
322 None => None,
323 };
324
325 let kp = crate::Keypair::new();
326
327 let mut headers = HeaderMap::new();
328 headers.insert(
329 "Accept",
330 HeaderValue::from_str(&"application/vnd.ncryptf+json").unwrap(),
331 );
332 headers.insert(
334 "X-PubKey",
335 HeaderValue::from_str(&general_purpose::STANDARD.encode(kp.get_public_key())).unwrap(),
336 );
337 headers.insert(
338 "X-HashId",
339 HeaderValue::from_str(&self.ek.clone().unwrap().hash_id).unwrap(),
340 );
341
342 match auth {
343 Some(auth) => {
344 headers.insert(
345 "Authorization",
346 HeaderValue::from_str(auth.get_header().as_str()).unwrap(),
347 );
348 }
349 _ => {}
350 }
351
352 let furi = format!("{}{}", self.endpoint, url);
353 let mut builder: reqwest::RequestBuilder = match method {
354 Method::Get => self.client.clone().get(furi),
355 Method::Post => self.client.clone().post(furi),
356 Method::Put => self.client.clone().put(furi),
357 Method::Delete => self.client.clone().delete(furi),
358 Method::Patch => self.client.clone().patch(furi),
359 };
360
361 match payload_actual {
362 "" => {
363 headers.insert(
364 "Content-Type",
365 HeaderValue::from_str(&"application/json").unwrap(),
366 );
367 }
368 _ => {
369 headers.insert(
370 "Content-Type",
371 HeaderValue::from_str(&"application/vnd.ncryptf+json").unwrap(),
372 );
373 let sk = match self.token.clone() {
374 Some(token) => token.signature,
375 None => {
376 let sk = crate::Signature::new();
377 sk.get_secret_key()
378 }
379 };
380
381 let mut request = crate::Request::from(kp.get_secret_key(), sk).unwrap();
382 match request.encrypt(
383 payload_actual.to_string(),
384 self.ek.as_ref().unwrap().clone().get_public_key().unwrap(),
385 ) {
386 Ok(body) => {
387 builder = builder.body(general_purpose::STANDARD.encode(body));
388 }
389 Err(_error) => return Err(RequestError::EncryptionError),
390 }
391 }
392 }
393
394 builder = match &self.rt {
396 Some(rt) => rt.before(builder),
397 None => builder,
398 };
399 builder = builder.headers(headers);
400
401 match self.do_request(builder, kp).await {
402 Ok(response) => match &self.rt {
403 Some(rt) => {
404 rt.after(response.clone());
405 return Ok(response);
406 }
407 None => return Ok(response),
408 },
409 Err(error) => return Err(error),
410 };
411 }
412
413 async fn do_request(
415 &mut self,
416 builder: reqwest::RequestBuilder,
417 kp: crate::Keypair,
418 ) -> Result<crate::client::Response, RequestError> {
419 match builder.send().await {
420 Ok(response) => {
421 if self.ek.is_some() {
424 if self.ek.clone().unwrap().ephemeral || self.ek.clone().unwrap().is_expired() {
425 self.ek = None;
426 }
427 }
428
429 let result = match crate::client::Response::new(response, kp.get_secret_key()).await
430 {
431 Ok(response) => response,
432 Err(error) => return Err(RequestError::HandlingResponse(error)),
433 };
434
435 let hash_id = self.get_header_by_name(result.headers.get("x-hashid"));
437 let expires_at =
438 self.get_header_by_name(result.headers.get("x-public-key-expiration"));
439 let public_key = self.get_key_string_by_result_or_header(
440 result.pk.clone(),
441 result.headers.get("x-public-key"),
442 );
443 let signature_key = self.get_key_string_by_result_or_header(
444 result.sk.clone(),
445 result.headers.get("x-signature-key"),
446 );
447 if hash_id.is_some()
448 && expires_at.is_some()
449 && public_key.is_some()
450 && signature_key.is_some()
451 {
452 let xp = expires_at.unwrap().parse::<i64>();
453 if xp.is_ok() {
454 self.ek = Some(ExportableEncryptionKeyData {
455 public: public_key.unwrap(),
456 signature: signature_key.unwrap(),
457 hash_id: hash_id.unwrap(),
458 ephemeral: false,
459 expires_at: xp.unwrap(),
460 });
461 }
462 }
463
464 return Ok(result);
465 }
466 Err(error) => Err(RequestError::ReqwestError(error)),
467 }
468 }
469
470 fn get_key_string_by_result_or_header(
472 &self,
473 key: Option<Vec<u8>>,
474 header: Option<&HeaderValue>,
475 ) -> Option<String> {
476 match key {
477 Some(key) => Some(general_purpose::STANDARD.encode(key)),
479 None => match header {
481 Some(header) => match header.to_str() {
482 Ok(s) => Some(s.to_string()),
484 Err(_) => None,
485 },
486 None => None,
487 },
488 }
489 }
490
491 fn get_header_by_name(&self, header: Option<&HeaderValue>) -> Option<String> {
493 match header {
494 Some(h) => match h.to_str() {
495 Ok(s) => Some(s.to_string()),
496 Err(_) => None,
497 },
498 None => None,
499 }
500 }
501}