1use base64::{Engine as _, engine::general_purpose};
2use futures::lock::Mutex;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use posemesh_utils::now_unix_secs;
7use std::sync::Arc;
8
9use crate::errors::{AukiErrorResponse, AuthError, DomainError};
10
11#[derive(Debug, Clone)]
12pub struct AuthClient {
13 pub api_url: String,
14 client: Client,
15 dds_token_cache: Arc<Mutex<Option<DdsTokenCache>>>,
16 user_token_cache: Arc<Mutex<Option<UserTokenCache>>>,
17 pub client_id: String,
18 app_key: Option<String>,
19 app_secret: Option<String>,
20}
21
/// Cached user session: the access/refresh token pair returned by login or refresh.
#[derive(Clone, Debug)]
pub struct UserTokenCache {
    /// Token sent as `Bearer` to `/user/refresh` to obtain a new pair.
    refresh_token: String,
    /// Token exchanged for a DDS access token.
    access_token: String,
    /// `exp` claim of `access_token`.
    expires_at: u64,
}
28
29impl TokenCache for UserTokenCache {
30 fn get_access_token(&self) -> String {
31 self.access_token.clone()
32 }
33
34 fn get_expires_at(&self) -> u64 {
35 self.expires_at
36 }
37}
38
39#[derive(Debug, Clone)]
40pub(crate) struct DdsTokenCache {
41 access_token: String,
43 claim: JwtClaim,
44}
45
46impl TokenCache for DdsTokenCache {
47 fn get_access_token(&self) -> String {
48 self.access_token.clone()
49 }
50
51 fn get_expires_at(&self) -> u64 {
52 self.claim.exp
53 }
54}
55
56impl Default for DdsTokenCache {
57 fn default() -> Self {
58 Self {
59 access_token: "".to_string(),
60 claim: JwtClaim { exp: 0, org: None },
61 }
62 }
63}
/// A cached token that can report its value and its expiry.
pub(crate) trait TokenCache {
    /// Expiry of the cached token, compared against the current Unix time in seconds.
    fn get_expires_at(&self) -> u64;

    /// Returns an owned copy of the cached access token.
    fn get_access_token(&self) -> String;
}
68
69#[derive(Debug, Serialize)]
70pub struct UserCredentials {
71 pub email: String,
72 pub password: String,
73}
74
75#[derive(Debug, Deserialize)]
76pub struct UserTokenResponse {
77 pub access_token: String,
78 pub refresh_token: String,
79}
80
81#[derive(Debug, Deserialize)]
82pub struct DdsTokenResponse {
83 pub access_token: String,
84}
85
86impl AuthClient {
87 pub fn new(api_url: &str, client_id: &str) -> Self {
88 Self {
89 api_url: api_url.to_string(),
90 client: Client::new(),
91 dds_token_cache: Arc::new(Mutex::new(None)),
92 user_token_cache: Arc::new(Mutex::new(None)),
93 client_id: client_id.to_string(),
94 app_key: None,
95 app_secret: None,
96 }
97 }
98
99 pub async fn get_expires_at(&self) -> Result<u64, DomainError> {
101 let token_cache = {
102 let cache = self.user_token_cache.lock().await;
103 cache.clone()
104 };
105 if token_cache.is_none() {
106 let dds_token_cache = {
107 let cache = self.dds_token_cache.lock().await;
108 cache.clone()
109 };
110 if dds_token_cache.is_none() {
111 return Err(DomainError::AuthError(AuthError::Unauthorized(
112 "No token found",
113 )));
114 }
115 return Ok(dds_token_cache.unwrap().claim.exp);
116 }
117 Ok(parse_jwt(&token_cache.unwrap().refresh_token)?.exp)
118 }
119
120 pub async fn sign_in_with_app_credentials(
121 &mut self,
122 app_key: &str,
123 app_secret: &str,
124 ) -> Result<String, DomainError> {
125 self.app_key = Some(app_key.to_string());
126 self.app_secret = Some(app_secret.to_string());
127 *self.dds_token_cache.lock().await = None;
128 *self.user_token_cache.lock().await = None;
129
130 self.get_dds_app_access_token().await
131 }
132
133 pub async fn get_dds_access_token(
138 &self,
139 oidc_access_token: Option<&str>,
140 ) -> Result<String, DomainError> {
141 let result = if let Some(oidc_access_token) = oidc_access_token {
142 self.get_dds_access_token_with_oidc_access_token(oidc_access_token)
143 .await
144 } else if self.app_key.is_some() {
145 self.get_dds_app_access_token().await
146 } else {
147 self.get_dds_user_access_token().await
148 };
149
150 if result.is_err() {
151 *self.dds_token_cache.lock().await = None;
152 *self.user_token_cache.lock().await = None;
153 }
154
155 result
156 }
157
158 async fn get_dds_access_token_with_oidc_access_token(
160 &self,
161 oidc_access_token: &str,
162 ) -> Result<String, DomainError> {
163 *self.dds_token_cache.lock().await = None;
165 *self.user_token_cache.lock().await = None;
166
167 let response = self.get_dds_token_by_token(oidc_access_token).await?;
168 {
169 let mut cache = self.dds_token_cache.lock().await;
170 *cache = Some(DdsTokenCache {
171 access_token: response.access_token.clone(),
172 claim: parse_jwt(&response.access_token)?,
173 });
174 }
175 Ok(response.access_token)
176 }
177
178 async fn get_dds_app_access_token(&self) -> Result<String, DomainError> {
181 let token_cache = {
182 let cache = self.dds_token_cache.lock().await;
183 cache.clone()
184 };
185
186 let app_key = self
187 .app_key
188 .clone()
189 .ok_or(AuthError::Unauthorized("App key is not set"))?;
190 let app_secret = self
191 .app_secret
192 .clone()
193 .ok_or(AuthError::Unauthorized("App secret is not set"))?;
194
195 let token_cache = get_cached_or_fresh_token(
196 &token_cache.unwrap_or(DdsTokenCache {
197 access_token: "".to_string(),
198 claim: JwtClaim { exp: 0, org: None },
199 }),
200 || {
201 let app_key = app_key.to_string();
202 let app_secret = app_secret.to_string();
203 let client = self.client.clone();
204 let api_url = self.api_url.clone();
205 let client_id = self.client_id.clone();
206 async move {
207 let response = client
208 .post(format!("{}/service/domains-access-token", api_url))
209 .basic_auth(app_key, Some(app_secret))
210 .header("Content-Type", "application/json")
211 .header("posemesh-client-id", client_id)
212 .send()
213 .await?;
214
215 if response.status().is_success() {
216 let token_response: DdsTokenResponse = response.json().await?;
217 Ok(DdsTokenCache {
218 access_token: token_response.access_token.clone(),
219 claim: parse_jwt(&token_response.access_token)?,
220 })
221 } else {
222 let status = response.status();
223 let text = response
224 .text()
225 .await
226 .unwrap_or_else(|_| "Unknown error".to_string());
227 Err(AukiErrorResponse {
228 status,
229 error: format!("Failed to get DDS access token. {}", text),
230 }
231 .into())
232 }
233 }
234 },
235 )
236 .await?;
237
238 {
239 let mut cache = self.dds_token_cache.lock().await;
240 *cache = Some(token_cache.clone());
241 }
242
243 Ok(token_cache.access_token)
244 }
245
246 async fn get_dds_user_access_token(&self) -> Result<String, DomainError> {
250 let token_cache = {
251 let cache = self.dds_token_cache.lock().await;
252 cache.clone()
253 };
254
255 if token_cache.is_none() {
256 return Err(AuthError::Unauthorized("No user access token found").into());
257 }
258
259 let user_token_cache = {
260 let cache = self.user_token_cache.lock().await;
261 cache.clone()
262 };
263
264 if user_token_cache.is_none() {
265 return Err(AuthError::Unauthorized("Login first").into());
266 }
267
268 let token_cache = get_cached_or_fresh_token(&token_cache.unwrap(), || {
269 let client = self.client.clone();
270 let api_url = self.api_url.clone();
271 let client_id = self.client_id.clone();
272
273 async move {
274 let client_clone = client.clone();
275 let api_url_clone = api_url.clone();
276 let client_id_clone = client_id.clone();
277 let refresh_token = user_token_cache.clone().unwrap().refresh_token;
278 let user_token_cache =
279 get_cached_or_fresh_token(&user_token_cache.unwrap(), || async move {
280 let response = client_clone
281 .post(format!("{}/user/refresh", api_url_clone))
282 .header("Content-Type", "application/json")
283 .header("posemesh-client-id", client_id_clone)
284 .header("Authorization", format!("Bearer {}", refresh_token))
285 .send()
286 .await
287 .expect("Failed to refresh token");
288
289 if response.status().is_success() {
290 let token_response: UserTokenResponse = response.json().await?;
291 Ok(UserTokenCache {
292 refresh_token: token_response.refresh_token.clone(),
293 access_token: token_response.access_token.clone(),
294 expires_at: parse_jwt(&token_response.access_token)?.exp,
295 })
296 } else {
297 let status = response.status();
298 let text = response
299 .text()
300 .await
301 .unwrap_or_else(|_| "Unknown error".to_string());
302 Err(AukiErrorResponse {
303 status,
304 error: format!("Failed to refresh token. {}", text),
305 }
306 .into())
307 }
308 })
309 .await?;
310
311 {
312 let mut cache = self.user_token_cache.lock().await;
313 *cache = Some(user_token_cache.clone());
314 }
315
316 let dds_token_response = self
317 .get_dds_token_by_token(&user_token_cache.access_token)
318 .await?;
319
320 let dds_cache = DdsTokenCache {
321 access_token: dds_token_response.access_token.clone(),
322 claim: parse_jwt(&dds_token_response.access_token)?,
323 };
324 {
325 let mut cache = self.dds_token_cache.lock().await;
326 *cache = Some(dds_cache.clone());
327 }
328 Ok(dds_cache)
329 }
330 })
331 .await?;
332
333 {
334 let mut cache = self.dds_token_cache.lock().await;
335 *cache = Some(token_cache.clone());
336 }
337
338 Ok(token_cache.access_token)
339 }
340
341 pub async fn user_login(&mut self, email: &str, password: &str) -> Result<String, DomainError> {
343 self.app_key = None;
344 self.app_secret = None;
345
346 let credentials = UserCredentials {
347 email: email.to_string(),
348 password: password.to_string(),
349 };
350
351 let response = self
352 .client
353 .post(format!("{}/user/login", &self.api_url))
354 .header("Content-Type", "application/json")
355 .header("posemesh-client-id", &self.client_id)
356 .json(&credentials)
357 .send()
358 .await?;
359
360 if response.status().is_success() {
361 let token_response: UserTokenResponse = response.json().await?;
362 {
363 let mut cache = self.user_token_cache.lock().await;
364 *cache = Some(UserTokenCache {
365 refresh_token: token_response.refresh_token.clone(),
366 access_token: token_response.access_token.clone(),
367 expires_at: parse_jwt(&token_response.access_token)?.exp,
368 });
369 }
370
371 let dds_token_response = self
372 .get_dds_token_by_token(&token_response.access_token)
373 .await?;
374 let mut cache = self.dds_token_cache.lock().await;
375 let token_cache = DdsTokenCache {
376 access_token: dds_token_response.access_token.clone(),
377 claim: parse_jwt(&dds_token_response.access_token)?,
378 };
379 *cache = Some(token_cache.clone());
380 Ok(token_cache.access_token)
381 } else {
382 let status = response.status();
383 let text = response
384 .text()
385 .await
386 .unwrap_or_else(|_| "Unknown error".to_string());
387
388 Err(AukiErrorResponse {
389 status,
390 error: format!("Failed to login. {}", text),
391 }
392 .into())
393 }
394 }
395
396 async fn get_dds_token_by_token(&self, token: &str) -> Result<DdsTokenResponse, DomainError> {
398 let dds_response = self
399 .client
400 .post(format!("{}/service/domains-access-token", &self.api_url))
401 .header("Authorization", format!("Bearer {}", token))
402 .header("Content-Type", "application/json")
403 .header("posemesh-client-id", &self.client_id)
404 .send()
405 .await?;
406
407 if dds_response.status().is_success() {
408 dds_response
409 .json::<DdsTokenResponse>()
410 .await
411 .map_err(|e| e.into())
412 } else {
413 let status = dds_response.status();
414 let text = dds_response
415 .text()
416 .await
417 .unwrap_or_else(|_| "Unknown error".to_string());
418 Err(AukiErrorResponse {
419 status,
420 error: format!("Failed to get DDS access token. {}", text),
421 }
422 .into())
423 }
424 }
425}
426
427pub const REFRESH_CACHE_TIME: u64 = 60; pub(crate) async fn get_cached_or_fresh_token<R, F, Fut>(
430 cache: &R,
431 token_fetcher: F,
432) -> Result<R, DomainError>
433where
434 F: FnOnce() -> Fut,
435 R: TokenCache + Clone,
436 Fut: std::future::Future<Output = Result<R, DomainError>>,
437{
438 let expires_at = cache.get_expires_at();
440 let current_time = now_unix_secs();
441 if expires_at > current_time && expires_at - current_time > REFRESH_CACHE_TIME {
443 return Ok(cache.clone());
444 }
445
446 token_fetcher().await
448}
449
450#[derive(Debug, Deserialize, Clone)]
451pub struct JwtClaim {
452 pub exp: u64,
453 pub org: Option<String>,
454}
455
456pub fn parse_jwt(token: &str) -> Result<JwtClaim, AuthError> {
457 let parts = token.split('.').collect::<Vec<&str>>();
458 if parts.len() != 3 {
459 return Err(AuthError::Unauthorized("Invalid JWT token"));
460 }
461 let payload = parts[1];
462 let decoded = general_purpose::URL_SAFE_NO_PAD.decode(payload)?;
463 let claims: JwtClaim = serde_json::from_slice(&decoded)?;
464 Ok(claims)
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokio::sync::Mutex;

    /// Minimal [`TokenCache`] used to drive `get_cached_or_fresh_token`.
    #[derive(Clone, Debug)]
    struct DummyTokenCache {
        access_token: String,
        expires_at: u64,
    }

    impl TokenCache for DummyTokenCache {
        fn get_access_token(&self) -> String {
            self.access_token.clone()
        }
        fn get_expires_at(&self) -> u64 {
            self.expires_at
        }
    }

    /// Local clock helper; shadows the glob-imported `now_unix_secs`.
    fn now_unix_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds an unsigned JWT-shaped string whose payload holds only the given `exp` claim.
    fn make_jwt(exp: u64) -> String {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"HS256","typ":"JWT"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(format!(r#"{{"exp":{}}}"#, exp));
        format!("{}.{}.sig", header, payload)
    }

    /// Builds a cache entry whose token expires at `exp`.
    fn dummy(exp: u64) -> DummyTokenCache {
        DummyTokenCache {
            access_token: make_jwt(exp),
            expires_at: exp,
        }
    }

    /// Runs `get_cached_or_fresh_token` on `cache` with a fetcher that yields a token expiring
    /// at `new_exp`. Returns the result and whether the fetcher was invoked.
    async fn run_with_fetcher(cache: &DummyTokenCache, new_exp: u64) -> (DummyTokenCache, bool) {
        let fetch_called = Arc::new(Mutex::new(false));
        let fetch_called_clone = fetch_called.clone();
        let token_fetcher = move || {
            let fetch_called_clone = fetch_called_clone.clone();
            async move {
                *fetch_called_clone.lock().await = true;
                Ok(dummy(new_exp))
            }
        };

        let result = get_cached_or_fresh_token(cache, token_fetcher)
            .await
            .unwrap();
        let called = *fetch_called.lock().await;
        (result, called)
    }

    #[tokio::test]
    async fn test_ddstoken_about_to_expire_should_refetch() {
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (result, called) = run_with_fetcher(&dummy(now + 2), new_exp).await;
        assert!(called, "Fetcher should have been called");
        assert_eq!(result.expires_at, new_exp);
    }

    #[tokio::test]
    async fn test_ddstoken_not_expiring_should_use_cache() {
        let now = now_unix_secs();
        let not_expiring = now + 100;
        let (result, called) = run_with_fetcher(&dummy(not_expiring), now + 1000).await;
        assert!(!called, "Fetcher should NOT have been called");
        assert_eq!(result.expires_at, not_expiring);
    }

    #[tokio::test]
    async fn test_already_expired_token_should_refetch() {
        // Exercises the guard against `expires_at - now` underflowing.
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (result, called) = run_with_fetcher(&dummy(now.saturating_sub(10)), new_exp).await;
        assert!(called, "Fetcher should have been called for an expired token");
        assert_eq!(result.expires_at, new_exp);
    }

    #[tokio::test]
    async fn test_token_at_refresh_threshold_should_refetch() {
        // Exactly REFRESH_CACHE_TIME seconds left (or less, if the clock ticks) is not fresh.
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (result, called) = run_with_fetcher(&dummy(now + REFRESH_CACHE_TIME), new_exp).await;
        assert!(called, "Fetcher should have been called at the threshold");
        assert_eq!(result.expires_at, new_exp);
    }

    #[test]
    fn test_parse_jwt_reads_exp_claim() {
        match parse_jwt(&make_jwt(1_234_567)) {
            Ok(claim) => {
                assert_eq!(claim.exp, 1_234_567);
                assert!(claim.org.is_none());
            }
            Err(_) => panic!("a well-formed JWT should parse"),
        }
    }

    #[test]
    fn test_parse_jwt_rejects_malformed_tokens() {
        assert!(parse_jwt("not-a-jwt").is_err());
        assert!(parse_jwt("a.b").is_err());
        assert!(parse_jwt("a.b.c.d").is_err());
        assert!(parse_jwt("a.!!!.c").is_err());
    }
}