1use crate::client_data::ClientData;
2use crate::error::{Error, ErrorKind};
3use crate::token::Token;
4use reqwest::header::{HeaderMap, HeaderValue};
5use serde_json::{Map, Value};
6use std::sync::Arc;
7use std::time::Instant;
8
/// Safety margin, in seconds, subtracted from the provider-reported `expires_in`
/// before it is handed to `Token::new`, so the locally tracked lifetime is
/// shorter than the one the provider advertised.
///
/// Responses whose `expires_in` is less than or equal to this value are rejected
/// as malformed ("duration is too short").
const EXPIRATION_ERROR_MARGIN_IN_SECONDS: u64 = 5;
10
11pub struct Identity {
12 identity_provider_url: String,
13 client_data: ClientData,
14}
15
16impl Identity {
17 pub async fn try_new(
18 identity_provider_url: String,
19 client_data: ClientData,
20 ) -> Result<Identity, Error> {
21 Ok(Identity {
22 identity_provider_url,
23 client_data,
24 })
25 }
26
27 pub async fn try_get_token(&self) -> Result<Arc<Token>, Error> {
28 let token =
29 Arc::new(try_get_new_token(&self.identity_provider_url, &self.client_data).await?);
30
31 Ok(token)
32 }
33
34 pub async fn renew_token_if_expiring_after_seconds(
35 &self,
36 token: Arc<Token>,
37 expiring_after_seconds: u64,
38 ) -> Result<Arc<Token>, Error> {
39 if token.does_expire_after(expiring_after_seconds) {
40 return self.try_get_token().await;
41 }
42
43 Ok(token)
44 }
45}
46
47async fn try_get_new_token(
48 identity_provider_url: &String,
49 client_data: &ClientData,
50) -> Result<Token, Error> {
51 let client = reqwest::Client::new();
52 let mut header = HeaderMap::new();
53 let content_type_value = match HeaderValue::from_str("application/json") {
54 Ok(value) => value,
55 Err(error) => {
56 return Err(Error::new(
57 ErrorKind::InternalFailure,
58 format!("failed to set content type: {}", error),
59 ))
60 }
61 };
62 header.insert("content-type", content_type_value);
63
64 let response = match client
65 .post(identity_provider_url)
66 .headers(header)
67 .body(client_data.json())
68 .send()
69 .await
70 {
71 Ok(response) => response,
72 Err(error) => {
73 return Err(Error::new(
74 ErrorKind::IdentityProviderFailure,
75 format!("failed to get new token: {}", error),
76 ))
77 }
78 };
79
80 let response_object = match response.json::<Map<String, Value>>().await {
81 Ok(response_object) => response_object,
82 Err(error) => {
83 return Err(Error::new(
84 ErrorKind::IdentityProviderFailure,
85 format!("failed to parse response: {}", error),
86 ))
87 }
88 };
89
90 let token = extract_token_from_response_object(response_object)?;
91
92 Ok(token)
93}
94
95fn extract_token_from_response_object(
96 mut response_object: Map<String, Value>,
97) -> Result<Token, Error> {
98 let generated_at = Instant::now();
99
100 let token_string = match response_object.remove("access_token") {
101 Some(token_string) => match token_string.as_str() {
102 Some(token_string) => token_string.to_string(),
103 None => {
104 return Err(Error::new(
105 ErrorKind::MalformedResponse,
106 "failed to read token as string",
107 ))
108 }
109 },
110 None => return Err(Error::new(ErrorKind::MalformedResponse, "missing token")),
111 };
112
113 let mut duration = match response_object.remove("expires_in") {
114 Some(duration) => match duration.as_u64() {
115 Some(duration) => duration,
116 None => {
117 return Err(Error::new(
118 ErrorKind::MalformedResponse,
119 "failed to read duration as u64",
120 ))
121 }
122 },
123 None => return Err(Error::new(ErrorKind::MalformedResponse, "missing duration")),
124 };
125
126 if duration <= EXPIRATION_ERROR_MARGIN_IN_SECONDS {
127 return Err(Error::new(
128 ErrorKind::MalformedResponse,
129 "duration is too short",
130 ));
131 } else {
132 duration -= EXPIRATION_ERROR_MARGIN_IN_SECONDS;
133 }
134
135 Ok(Token::new(token_string, generated_at, duration))
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Number;

    /// Builds a provider-style response object from raw JSON field values.
    fn response_with(access_token: Value, expires_in: Value) -> Map<String, Value> {
        let mut response_object = Map::new();
        response_object.insert("access_token".to_string(), access_token);
        response_object.insert("expires_in".to_string(), expires_in);
        response_object
    }

    /// JSON number holding a whole number of seconds.
    fn seconds(value: u64) -> Value {
        Value::Number(Number::from(value))
    }

    /// JSON string holding a dummy access token.
    fn token_value() -> Value {
        Value::String("token".to_string())
    }

    #[test]
    fn durations_lower_or_equal_to_error_margin_are_rejected() {
        for expires_in in [0, EXPIRATION_ERROR_MARGIN_IN_SECONDS] {
            let response_object = response_with(token_value(), seconds(expires_in));
            assert!(extract_token_from_response_object(response_object).is_err());
        }
    }

    #[test]
    fn durations_just_above_error_margin_are_accepted() {
        let response_object = response_with(
            token_value(),
            seconds(EXPIRATION_ERROR_MARGIN_IN_SECONDS + 1),
        );
        assert!(extract_token_from_response_object(response_object).is_ok());
    }

    #[test]
    fn extract_token_from_response_object_correctly() {
        const TOKEN: &str = "token";
        let response_object = response_with(Value::String(TOKEN.to_string()), seconds(10));

        let token = extract_token_from_response_object(response_object).unwrap();

        assert_eq!(token.value(), TOKEN);
        assert!(!token.is_expired());
    }

    #[test]
    fn missing_fields_are_rejected() {
        let mut without_token = response_with(token_value(), seconds(10));
        without_token.remove("access_token");
        assert!(extract_token_from_response_object(without_token).is_err());

        let mut without_duration = response_with(token_value(), seconds(10));
        without_duration.remove("expires_in");
        assert!(extract_token_from_response_object(without_duration).is_err());
    }

    #[test]
    fn fields_with_wrong_json_type_are_rejected() {
        let numeric_token = response_with(seconds(1), seconds(10));
        assert!(extract_token_from_response_object(numeric_token).is_err());

        // Negative, fractional and stringly-typed durations are not valid `u64`s.
        for bad_duration in [
            Value::from(-1),
            Value::from(10.5),
            Value::String("10".to_string()),
        ] {
            let response_object = response_with(token_value(), bad_duration);
            assert!(extract_token_from_response_object(response_object).is_err());
        }
    }
}