1use crate::dto::utils::MaybeStringU64;
2use async_trait::async_trait;
3use futures_locks::RwLock;
4use reqwest::{
5 header::{HeaderMap, HeaderValue},
6 StatusCode,
7};
8use reqwest_middleware::ClientWithMiddleware;
9use serde::{Deserialize, Serialize};
10use std::time::{Duration, Instant};
11use std::{fmt::Display, sync::Arc};
12use thiserror::Error;
13
/// Signature of a synchronous custom authentication hook, stored in
/// `AuthHeaderManager::Custom`.
///
/// The hook receives a header map to mutate and the HTTP client that was
/// passed to `AuthHeaderManager::set_headers`; any `Err` it returns is
/// propagated unchanged to that method's caller. The `Send + Sync` bounds
/// allow the hook to be shared across threads behind an `Arc`.
type CustomAuthCallback =
    dyn Fn(&mut HeaderMap, &ClientWithMiddleware) -> Result<(), AuthenticatorError> + Send + Sync;
17
/// Asynchronous custom authentication hook, stored in
/// `AuthHeaderManager::CustomAsync`.
///
/// Use this instead of the synchronous callback variant when producing the
/// headers requires awaiting, for example making a request with `client`.
#[async_trait]
pub trait CustomAuthenticator {
    /// Inserts authentication headers into `headers`.
    ///
    /// `client` is the HTTP client that was passed to
    /// `AuthHeaderManager::set_headers`.
    ///
    /// # Errors
    ///
    /// Any `Err` returned here is propagated unchanged by
    /// `AuthHeaderManager::set_headers`.
    async fn set_headers(
        &self,
        headers: &mut HeaderMap,
        client: &ClientWithMiddleware,
    ) -> Result<(), AuthenticatorError>;
}
36
/// Strategy for attaching authentication to outgoing requests; applied by
/// `AuthHeaderManager::set_headers`.
///
/// Cloning shares the `Arc`-backed variants rather than deep-copying them;
/// the `String` variants clone their string.
#[derive(Clone)]
pub enum AuthHeaderManager {
    /// Obtain an OAuth2 client-credentials token through the shared
    /// `Authenticator`, which caches it until expiry. The token is sent as
    /// `Authorization: Bearer <token>`.
    OIDCToken(Arc<Authenticator>),
    /// Send a pre-obtained token as `Authorization: Bearer <token>`.
    FixedToken(String),
    /// Send the string verbatim in an `auth-ticket` header.
    AuthTicket(String),
    /// Run a synchronous user callback that edits the headers directly.
    Custom(Arc<CustomAuthCallback>),
    /// Run an asynchronous user hook that edits the headers directly.
    CustomAsync(Arc<dyn CustomAuthenticator + Send + Sync>),
}
51
52impl AuthHeaderManager {
53 pub async fn set_headers(
61 &self,
62 headers: &mut HeaderMap,
63 client: &ClientWithMiddleware,
64 ) -> Result<(), AuthenticatorError> {
65 match self {
66 AuthHeaderManager::OIDCToken(a) => {
67 let token = a.get_token(client).await?;
68 let auth_header_value =
69 HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| {
70 AuthenticatorError::internal_error(
71 "Failed to set authorization bearer token".to_string(),
72 Some(e.to_string()),
73 )
74 })?;
75 headers.insert("Authorization", auth_header_value);
76 }
77 AuthHeaderManager::FixedToken(token) => {
78 let auth_header_value =
79 HeaderValue::from_str(&format!("Bearer {token}")).map_err(|e| {
80 AuthenticatorError::internal_error(
81 "Failed to set authorization bearer token".to_string(),
82 Some(e.to_string()),
83 )
84 })?;
85 headers.insert("Authorization", auth_header_value);
86 }
87 AuthHeaderManager::AuthTicket(t) => {
88 let auth_ticket_header_value = HeaderValue::from_str(t).map_err(|e| {
89 AuthenticatorError::internal_error(
90 "Failed to set auth ticket".to_string(),
91 Some(e.to_string()),
92 )
93 })?;
94 headers.insert("auth-ticket", auth_ticket_header_value);
95 }
96 AuthHeaderManager::Custom(c) => c(headers, client)?,
97 AuthHeaderManager::CustomAsync(c) => c.set_headers(headers, client).await?,
98 }
99 Ok(())
100 }
101}
102
/// Configuration for an OAuth2 client-credentials `Authenticator`.
pub struct AuthenticatorConfig {
    /// Client identifier, sent as the `client_id` form field.
    pub client_id: String,
    /// Token endpoint URL; the token request is POSTed here as a form.
    pub token_url: String,
    /// Client secret, sent as the `client_secret` form field.
    pub secret: String,
    /// Optional value sent as the `resource` form field.
    pub resource: Option<String>,
    /// Optional value sent as the `audience` form field.
    pub audience: Option<String>,
    /// Optional value sent as the `scope` form field.
    pub scopes: Option<String>,
    /// Fallback token lifetime in seconds, used only when the token response
    /// has no `expires_in`. Unlike a server-provided `expires_in`, this value
    /// is applied as-is (no 60-second safety margin is subtracted).
    pub default_expires_in: Option<u64>,
}
123
124#[derive(Serialize, Deserialize, Debug)]
125struct AuthenticatorRequest {
126 client_id: String,
127 client_secret: String,
128 resource: Option<String>,
129 audience: Option<String>,
130 scope: Option<String>,
131 grant_type: String,
132}
133
134impl AuthenticatorRequest {
135 fn new(config: AuthenticatorConfig) -> AuthenticatorRequest {
136 AuthenticatorRequest {
137 client_id: config.client_id,
138 client_secret: config.secret,
139 grant_type: "client_credentials".to_string(),
140 resource: config.resource,
141 audience: config.audience,
142 scope: config.scopes,
143 }
144 }
145}
146
147#[derive(Serialize, Deserialize, Debug)]
148struct AuthenticatorResponse {
149 access_token: String,
150 expires_in: Option<MaybeStringU64>,
151}
152
/// Error produced while obtaining or attaching authentication headers.
///
/// The field names mirror the OAuth 2.0 error response (RFC 6749 §5.2), so a
/// non-200 token endpoint body can be deserialized directly into this type.
/// Errors raised locally are built with `AuthenticatorError::internal_error`.
/// `Display` is implemented manually rather than via an `#[error(...)]`
/// attribute.
#[derive(Serialize, Deserialize, Debug, Error)]
pub struct AuthenticatorError {
    /// Error code from the endpoint, or a short summary for internal errors.
    pub error: String,
    /// Optional longer, human-readable explanation.
    pub error_description: Option<String>,
    /// Optional URI pointing to more information about the error.
    pub error_uri: Option<String>,
}
163
164impl AuthenticatorError {
165 pub fn internal_error(error: String, error_description: Option<String>) -> Self {
172 Self {
173 error,
174 error_description,
175 error_uri: None,
176 }
177 }
178}
179
180impl Display for AuthenticatorError {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 write!(f, "{}", self.error,)?;
183 if let Some(error_description) = &self.error_description {
184 write!(f, ": {error_description}")?;
185 }
186 if let Some(error_uri) = &self.error_uri {
187 write!(f, " ({error_uri})")?;
188 }
189 Ok(())
190 }
191}
192
/// Mutable token cache kept behind the lock in `Authenticator`.
struct AuthenticatorState {
    /// Most recently fetched access token; `None` until the first successful
    /// fetch.
    last_token: Option<String>,
    /// Instant after which `last_token` must be refreshed. It is initialized
    /// to the construction time and only consulted while `last_token` is
    /// `Some`.
    current_token_expiry: Instant,
}
197
/// An access token together with the local instant at which it should be
/// treated as expired.
///
/// When the lifetime comes from the server's `expires_in`, 60 seconds have
/// already been subtracted as a safety margin. A configured default lifetime
/// is applied as-is.
pub struct AuthenticatorResult {
    /// The access token.
    token: String,
    /// Local expiry instant, derived as described in the type-level docs.
    expiry: Instant,
}
205
/// OAuth2 client-credentials token fetcher with an in-memory token cache.
///
/// Tokens are requested lazily and reused until their computed expiry; see
/// `get_token` and `get_token_with_expiry`.
pub struct Authenticator {
    /// Form body posted to the token endpoint on every refresh.
    req: AuthenticatorRequest,
    /// Cached token and its expiry, behind an async (awaitable) read/write lock.
    state: RwLock<AuthenticatorState>,
    /// Token endpoint URL.
    token_url: String,
    /// Fallback lifetime used when the token response has no `expires_in`.
    default_expires_in: Option<Duration>,
}
213
214impl AuthenticatorResult {
215 pub fn token(&self) -> &str {
217 &self.token
218 }
219
220 pub fn into_token(self) -> String {
222 self.token
223 }
224
225 pub fn expiry(&self) -> Instant {
227 self.expiry
228 }
229}
230
231impl Authenticator {
232 pub fn new(config: AuthenticatorConfig) -> Authenticator {
238 Authenticator {
239 token_url: config.token_url.clone(),
240 default_expires_in: config.default_expires_in.map(Duration::from_secs),
241 req: AuthenticatorRequest::new(config),
242 state: RwLock::new(AuthenticatorState {
243 last_token: None,
244 current_token_expiry: Instant::now(),
245 }),
246 }
247 }
248
249 async fn request_token(
250 &self,
251 client: &ClientWithMiddleware,
252 ) -> Result<AuthenticatorResult, AuthenticatorError> {
253 let response = client
254 .post(&self.token_url)
255 .form(&self.req)
256 .send()
257 .await
258 .map_err(|e| {
259 AuthenticatorError::internal_error(
260 "Something went wrong when sending the request".to_string(),
261 Some(e.to_string()),
262 )
263 })?;
264
265 let status = response.status();
266
267 let start = Instant::now();
268
269 let response = response.text().await.map_err(|e| {
270 AuthenticatorError::internal_error(
271 "Failed to receive response contents".to_owned(),
272 Some(e.to_string()),
273 )
274 })?;
275
276 if status != StatusCode::OK {
277 return match serde_json::from_str(&response) {
278 Ok(e) => Err(e),
279 Err(e) => Err(AuthenticatorError::internal_error(
280 format!("Something went wrong (status: {status}), but the response error couldn't be deserialized. Raw response: {response}")
281 , Some(e.to_string())))
282 };
283 }
284
285 let response: AuthenticatorResponse = serde_json::from_str(&response).map_err(|e| {
286 AuthenticatorError::internal_error(
287 "Failed to deserialize response from OAuth endpoint".to_string(),
288 Some(e.to_string()),
289 )
290 })?;
291
292 let token = response.access_token;
293 let Some(expires_in) = response
294 .expires_in
295 .map(|m| Duration::from_secs(m.0.saturating_sub(60)))
300 .or(self.default_expires_in)
301 else {
302 return Err(AuthenticatorError::internal_error(
303 "Missing expires_in in response, and no default expiration configured".to_owned(),
304 None,
305 ));
306 };
307
308 Ok(AuthenticatorResult {
309 token,
310 expiry: start + expires_in,
311 })
312 }
313
314 pub async fn get_token_with_expiry(
324 &self,
325 client: &ClientWithMiddleware,
326 ) -> Result<AuthenticatorResult, AuthenticatorError> {
327 let now = Instant::now();
328 {
329 let state = &*self.state.read().await;
330 if let Some(last) = &state.last_token {
331 if state.current_token_expiry > now {
332 return Ok(AuthenticatorResult {
333 token: last.clone(),
334 expiry: state.current_token_expiry,
335 });
336 }
337 }
338 }
339
340 let mut write = self.state.write().await;
342
343 if let Some(last) = &write.last_token {
346 if write.current_token_expiry > now {
347 return Ok(AuthenticatorResult {
348 token: last.clone(),
349 expiry: write.current_token_expiry,
350 });
351 }
352 }
353
354 match self.request_token(client).await {
355 Ok(response) => {
356 write.current_token_expiry = response.expiry;
357 write.last_token = Some(response.token.clone());
358 Ok(response)
359 }
360 Err(e) => Err(e),
361 }
362 }
363
364 pub async fn get_token(
371 &self,
372 client: &ClientWithMiddleware,
373 ) -> Result<String, AuthenticatorError> {
374 Ok(self.get_token_with_expiry(client).await?.token)
375 }
376}