Skip to main content

posemesh_domain_http/
auth.rs

1use base64::{Engine as _, engine::general_purpose};
2use futures::lock::Mutex;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use posemesh_utils::now_unix_secs;
7use std::sync::Arc;
8
9use crate::errors::{AukiErrorResponse, AuthError, DomainError};
10
11#[derive(Debug, Clone)]
12pub struct AuthClient {
13    pub api_url: String,
14    client: Client,
15    dds_token_cache: Arc<Mutex<Option<DdsTokenCache>>>,
16    user_token_cache: Arc<Mutex<Option<UserTokenCache>>>,
17    pub client_id: String,
18    app_key: Option<String>,
19    app_secret: Option<String>,
20}
21
/// Cached tokens of a logged-in user session.
///
/// `Debug` is implemented by hand so the bearer tokens are never printed.
#[derive(Clone)]
pub struct UserTokenCache {
    refresh_token: String,
    access_token: String,
    /// Expiry of `access_token` (JWT `exp`, Unix seconds).
    expires_at: u64,
}

impl std::fmt::Debug for UserTokenCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("UserTokenCache")
            .field("refresh_token", &"<redacted>")
            .field("access_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
28
29impl TokenCache for UserTokenCache {
30    fn get_access_token(&self) -> String {
31        self.access_token.clone()
32    }
33
34    fn get_expires_at(&self) -> u64 {
35        self.expires_at
36    }
37}
38
39#[derive(Debug, Clone)]
40pub(crate) struct DdsTokenCache {
41    // DDS access token
42    access_token: String,
43    claim: JwtClaim,
44}
45
46impl TokenCache for DdsTokenCache {
47    fn get_access_token(&self) -> String {
48        self.access_token.clone()
49    }
50
51    fn get_expires_at(&self) -> u64 {
52        self.claim.exp
53    }
54}
55
56impl Default for DdsTokenCache {
57    fn default() -> Self {
58        Self {
59            access_token: "".to_string(),
60            claim: JwtClaim { exp: 0, org: None },
61        }
62    }
63}
/// Common view over a cached token, used by `get_cached_or_fresh_token` to decide whether the
/// cached value can be reused.
pub(crate) trait TokenCache {
    /// Owned copy of the cached access token.
    fn get_access_token(&self) -> String;

    /// Expiry of the cached token in Unix seconds.
    fn get_expires_at(&self) -> u64;
}
68
69#[derive(Debug, Serialize)]
70pub struct UserCredentials {
71    pub email: String,
72    pub password: String,
73}
74
75#[derive(Debug, Deserialize)]
76pub struct UserTokenResponse {
77    pub access_token: String,
78    pub refresh_token: String,
79}
80
81#[derive(Debug, Deserialize)]
82pub struct DdsTokenResponse {
83    pub access_token: String,
84}
85
86impl AuthClient {
87    pub fn new(api_url: &str, client_id: &str) -> Self {
88        Self {
89            api_url: api_url.to_string(),
90            client: Client::new(),
91            dds_token_cache: Arc::new(Mutex::new(None)),
92            user_token_cache: Arc::new(Mutex::new(None)),
93            client_id: client_id.to_string(),
94            app_key: None,
95            app_secret: None,
96        }
97    }
98
99    /// Get the expiration time of the user refresh token or DDS access token
100    pub async fn get_expires_at(&self) -> Result<u64, DomainError> {
101        let token_cache = {
102            let cache = self.user_token_cache.lock().await;
103            cache.clone()
104        };
105        if token_cache.is_none() {
106            let dds_token_cache = {
107                let cache = self.dds_token_cache.lock().await;
108                cache.clone()
109            };
110            if dds_token_cache.is_none() {
111                return Err(DomainError::AuthError(AuthError::Unauthorized(
112                    "No token found",
113                )));
114            }
115            return Ok(dds_token_cache.unwrap().claim.exp);
116        }
117        Ok(parse_jwt(&token_cache.unwrap().refresh_token)?.exp)
118    }
119
120    pub async fn sign_in_with_app_credentials(
121        &mut self,
122        app_key: &str,
123        app_secret: &str,
124    ) -> Result<String, DomainError> {
125        self.app_key = Some(app_key.to_string());
126        self.app_secret = Some(app_secret.to_string());
127        *self.dds_token_cache.lock().await = None;
128        *self.user_token_cache.lock().await = None;
129
130        self.get_dds_app_access_token().await
131    }
132
133    // Get DDS access token with either app credentials or user access token or oidc_access_token, it checks the cache first, if found and not about to expire, return the cached token
134    // if not found or about to expire, it fetches a new token with app credentials or user access token or oidc_access_token and sets the cache.
135    // If user access token is about to expire, it refreshes the user access token with refresh token first and sets the cache.
136    // It clears all caches if there is an error.
137    pub async fn get_dds_access_token(
138        &self,
139        oidc_access_token: Option<&str>,
140    ) -> Result<String, DomainError> {
141        let result = if let Some(oidc_access_token) = oidc_access_token {
142            self.get_dds_access_token_with_oidc_access_token(oidc_access_token)
143                .await
144        } else if self.app_key.is_some() {
145            self.get_dds_app_access_token().await
146        } else {
147            self.get_dds_user_access_token().await
148        };
149
150        if result.is_err() {
151            *self.dds_token_cache.lock().await = None;
152            *self.user_token_cache.lock().await = None;
153        }
154
155        result
156    }
157
158    // Get DDS access token with OIDC access token, doesn't cache
159    async fn get_dds_access_token_with_oidc_access_token(
160        &self,
161        oidc_access_token: &str,
162    ) -> Result<String, DomainError> {
163        // Clear all caches before proceeding
164        *self.dds_token_cache.lock().await = None;
165        *self.user_token_cache.lock().await = None;
166
167        let response = self.get_dds_token_by_token(oidc_access_token).await?;
168        {
169            let mut cache = self.dds_token_cache.lock().await;
170            *cache = Some(DdsTokenCache {
171                access_token: response.access_token.clone(),
172                claim: parse_jwt(&response.access_token)?,
173            });
174        }
175        Ok(response.access_token)
176    }
177
178    // Get DDS access token with app credentials, it checks the cache first, if found and not about to expire, return the cached token
179    // if not found or about to expire, fetch a new token with app credentials and sets the cache.
180    async fn get_dds_app_access_token(&self) -> Result<String, DomainError> {
181        let token_cache = {
182            let cache = self.dds_token_cache.lock().await;
183            cache.clone()
184        };
185
186        let app_key = self
187            .app_key
188            .clone()
189            .ok_or(AuthError::Unauthorized("App key is not set"))?;
190        let app_secret = self
191            .app_secret
192            .clone()
193            .ok_or(AuthError::Unauthorized("App secret is not set"))?;
194
195        let token_cache = get_cached_or_fresh_token(
196            &token_cache.unwrap_or(DdsTokenCache {
197                access_token: "".to_string(),
198                claim: JwtClaim { exp: 0, org: None },
199            }),
200            || {
201                let app_key = app_key.to_string();
202                let app_secret = app_secret.to_string();
203                let client = self.client.clone();
204                let api_url = self.api_url.clone();
205                let client_id = self.client_id.clone();
206                async move {
207                    let response = client
208                        .post(format!("{}/service/domains-access-token", api_url))
209                        .basic_auth(app_key, Some(app_secret))
210                        .header("Content-Type", "application/json")
211                        .header("posemesh-client-id", client_id)
212                        .send()
213                        .await?;
214
215                    if response.status().is_success() {
216                        let token_response: DdsTokenResponse = response.json().await?;
217                        Ok(DdsTokenCache {
218                            access_token: token_response.access_token.clone(),
219                            claim: parse_jwt(&token_response.access_token)?,
220                        })
221                    } else {
222                        let status = response.status();
223                        let text = response
224                            .text()
225                            .await
226                            .unwrap_or_else(|_| "Unknown error".to_string());
227                        Err(AukiErrorResponse {
228                            status,
229                            error: format!("Failed to get DDS access token. {}", text),
230                        }
231                        .into())
232                    }
233                }
234            },
235        )
236        .await?;
237
238        {
239            let mut cache = self.dds_token_cache.lock().await;
240            *cache = Some(token_cache.clone());
241        }
242
243        Ok(token_cache.access_token)
244    }
245
246    // Get DDS access token with user credentials, it checks the cache first, if found and not about to expire, return the cached token
247    // if not found or about to expire, it fetches a new token with user access token and sets the cache.
248    // If user access token is about to expire, it refreshes the user access token with refresh token first and sets the cache.
249    async fn get_dds_user_access_token(&self) -> Result<String, DomainError> {
250        let token_cache = {
251            let cache = self.dds_token_cache.lock().await;
252            cache.clone()
253        };
254
255        if token_cache.is_none() {
256            return Err(AuthError::Unauthorized("No user access token found").into());
257        }
258
259        let user_token_cache = {
260            let cache = self.user_token_cache.lock().await;
261            cache.clone()
262        };
263
264        if user_token_cache.is_none() {
265            return Err(AuthError::Unauthorized("Login first").into());
266        }
267
268        let token_cache = get_cached_or_fresh_token(&token_cache.unwrap(), || {
269            let client = self.client.clone();
270            let api_url = self.api_url.clone();
271            let client_id = self.client_id.clone();
272
273            async move {
274                let client_clone = client.clone();
275                let api_url_clone = api_url.clone();
276                let client_id_clone = client_id.clone();
277                let refresh_token = user_token_cache.clone().unwrap().refresh_token;
278                let user_token_cache =
279                    get_cached_or_fresh_token(&user_token_cache.unwrap(), || async move {
280                        let response = client_clone
281                            .post(format!("{}/user/refresh", api_url_clone))
282                            .header("Content-Type", "application/json")
283                            .header("posemesh-client-id", client_id_clone)
284                            .header("Authorization", format!("Bearer {}", refresh_token))
285                            .send()
286                            .await
287                            .expect("Failed to refresh token");
288
289                        if response.status().is_success() {
290                            let token_response: UserTokenResponse = response.json().await?;
291                            Ok(UserTokenCache {
292                                refresh_token: token_response.refresh_token.clone(),
293                                access_token: token_response.access_token.clone(),
294                                expires_at: parse_jwt(&token_response.access_token)?.exp,
295                            })
296                        } else {
297                            let status = response.status();
298                            let text = response
299                                .text()
300                                .await
301                                .unwrap_or_else(|_| "Unknown error".to_string());
302                            Err(AukiErrorResponse {
303                                status,
304                                error: format!("Failed to refresh token. {}", text),
305                            }
306                            .into())
307                        }
308                    })
309                    .await?;
310
311                {
312                    let mut cache = self.user_token_cache.lock().await;
313                    *cache = Some(user_token_cache.clone());
314                }
315
316                let dds_token_response = self
317                    .get_dds_token_by_token(&user_token_cache.access_token)
318                    .await?;
319
320                let dds_cache = DdsTokenCache {
321                    access_token: dds_token_response.access_token.clone(),
322                    claim: parse_jwt(&dds_token_response.access_token)?,
323                };
324                {
325                    let mut cache = self.dds_token_cache.lock().await;
326                    *cache = Some(dds_cache.clone());
327                }
328                Ok(dds_cache)
329            }
330        })
331        .await?;
332
333        {
334            let mut cache = self.dds_token_cache.lock().await;
335            *cache = Some(token_cache.clone());
336        }
337
338        Ok(token_cache.access_token)
339    }
340
341    // Login with user credentials, return DDS access token. It clears all caches and sets the app credentials to none.
342    pub async fn user_login(&mut self, email: &str, password: &str) -> Result<String, DomainError> {
343        self.app_key = None;
344        self.app_secret = None;
345
346        let credentials = UserCredentials {
347            email: email.to_string(),
348            password: password.to_string(),
349        };
350
351        let response = self
352            .client
353            .post(format!("{}/user/login", &self.api_url))
354            .header("Content-Type", "application/json")
355            .header("posemesh-client-id", &self.client_id)
356            .json(&credentials)
357            .send()
358            .await?;
359
360        if response.status().is_success() {
361            let token_response: UserTokenResponse = response.json().await?;
362            {
363                let mut cache = self.user_token_cache.lock().await;
364                *cache = Some(UserTokenCache {
365                    refresh_token: token_response.refresh_token.clone(),
366                    access_token: token_response.access_token.clone(),
367                    expires_at: parse_jwt(&token_response.access_token)?.exp,
368                });
369            }
370
371            let dds_token_response = self
372                .get_dds_token_by_token(&token_response.access_token)
373                .await?;
374            let mut cache = self.dds_token_cache.lock().await;
375            let token_cache = DdsTokenCache {
376                access_token: dds_token_response.access_token.clone(),
377                claim: parse_jwt(&dds_token_response.access_token)?,
378            };
379            *cache = Some(token_cache.clone());
380            Ok(token_cache.access_token)
381        } else {
382            let status = response.status();
383            let text = response
384                .text()
385                .await
386                .unwrap_or_else(|_| "Unknown error".to_string());
387
388            Err(AukiErrorResponse {
389                status,
390                error: format!("Failed to login. {}", text),
391            }
392            .into())
393        }
394    }
395
396    // Get DDS access token with either user access token or oidc_access_token, doesn't cache
397    async fn get_dds_token_by_token(&self, token: &str) -> Result<DdsTokenResponse, DomainError> {
398        let dds_response = self
399            .client
400            .post(format!("{}/service/domains-access-token", &self.api_url))
401            .header("Authorization", format!("Bearer {}", token))
402            .header("Content-Type", "application/json")
403            .header("posemesh-client-id", &self.client_id)
404            .send()
405            .await?;
406
407        if dds_response.status().is_success() {
408            dds_response
409                .json::<DdsTokenResponse>()
410                .await
411                .map_err(|e| e.into())
412        } else {
413            let status = dds_response.status();
414            let text = dds_response
415                .text()
416                .await
417                .unwrap_or_else(|_| "Unknown error".to_string());
418            Err(AukiErrorResponse {
419                status,
420                error: format!("Failed to get DDS access token. {}", text),
421            }
422            .into())
423        }
424    }
425}
426
/// Refresh margin in seconds (one minute).
///
/// A cached token is reused only while it expires more than this many seconds from now.
/// Otherwise `get_cached_or_fresh_token` fetches a new one.
pub const REFRESH_CACHE_TIME: u64 = 60;
428
429pub(crate) async fn get_cached_or_fresh_token<R, F, Fut>(
430    cache: &R,
431    token_fetcher: F,
432) -> Result<R, DomainError>
433where
434    F: FnOnce() -> Fut,
435    R: TokenCache + Clone,
436    Fut: std::future::Future<Output = Result<R, DomainError>>,
437{
438    // Check if we have a valid cached token
439    let expires_at = cache.get_expires_at();
440    let current_time = now_unix_secs();
441    // If token expires in more than REFRESH_CACHE_TIME seconds, return cached token
442    if expires_at > current_time && expires_at - current_time > REFRESH_CACHE_TIME {
443        return Ok(cache.clone());
444    }
445
446    // Fetch new token
447    token_fetcher().await
448}
449
450#[derive(Debug, Deserialize, Clone)]
451pub struct JwtClaim {
452    pub exp: u64,
453    pub org: Option<String>,
454}
455
456pub fn parse_jwt(token: &str) -> Result<JwtClaim, AuthError> {
457    let parts = token.split('.').collect::<Vec<&str>>();
458    if parts.len() != 3 {
459        return Err(AuthError::Unauthorized("Invalid JWT token"));
460    }
461    let payload = parts[1];
462    let decoded = general_purpose::URL_SAFE_NO_PAD.decode(payload)?;
463    let claims: JwtClaim = serde_json::from_slice(&decoded)?;
464    Ok(claims)
465}
466
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};
    use tokio::sync::Mutex;

    /// Minimal `TokenCache` whose expiry the tests control directly.
    #[derive(Clone, Debug)]
    struct DummyTokenCache {
        access_token: String,
        expires_at: u64,
    }

    impl TokenCache for DummyTokenCache {
        fn get_access_token(&self) -> String {
            self.access_token.clone()
        }
        fn get_expires_at(&self) -> u64 {
            self.expires_at
        }
    }

    fn now_unix_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Build a JWT-shaped string with header `{"alg":"HS256","typ":"JWT"}`, the given JSON
    /// `payload`, and a dummy signature segment (`parse_jwt` never checks signatures).
    fn make_jwt_with_payload(payload: &str) -> String {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"HS256","typ":"JWT"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload);
        format!("{}.{}.sig", header, payload)
    }

    /// Build a JWT-shaped string whose payload is `{"exp":<exp>}`.
    fn make_jwt(exp: u64) -> String {
        make_jwt_with_payload(&format!(r#"{{"exp":{}}}"#, exp))
    }

    /// Run `get_cached_or_fresh_token` on a cache expiring at `cached_exp`, using a fetcher that
    /// yields a token expiring at `fresh_exp`.
    ///
    /// Returns the resulting expiry and whether the fetcher ran.
    async fn resolve(cached_exp: u64, fresh_exp: u64) -> (u64, bool) {
        let cache = DummyTokenCache {
            access_token: make_jwt(cached_exp),
            expires_at: cached_exp,
        };

        let fetch_called = Arc::new(Mutex::new(false));
        let fetch_called_clone = fetch_called.clone();
        let token_fetcher = move || async move {
            *fetch_called_clone.lock().await = true;
            Ok(DummyTokenCache {
                access_token: make_jwt(fresh_exp),
                expires_at: fresh_exp,
            })
        };

        let result = get_cached_or_fresh_token(&cache, token_fetcher)
            .await
            .unwrap();
        let called = *fetch_called.lock().await;
        (result.expires_at, called)
    }

    #[tokio::test]
    async fn test_ddstoken_about_to_expire_should_refetch() {
        // Expires in 2 seconds, inside the REFRESH_CACHE_TIME window.
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (exp, fetched) = resolve(now + 2, new_exp).await;
        assert!(fetched, "Fetcher should have been called");
        assert_eq!(exp, new_exp);
    }

    #[tokio::test]
    async fn test_ddstoken_not_expiring_should_use_cache() {
        // Expires in 100 seconds, outside the REFRESH_CACHE_TIME window.
        let now = now_unix_secs();
        let not_expiring = now + 100;
        let (exp, fetched) = resolve(not_expiring, now + 1000).await;
        assert!(!fetched, "Fetcher should NOT have been called");
        assert_eq!(exp, not_expiring);
    }

    #[tokio::test]
    async fn test_ddstoken_already_expired_should_refetch() {
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (exp, fetched) = resolve(now.saturating_sub(10), new_exp).await;
        assert!(fetched, "Fetcher should have been called for an expired token");
        assert_eq!(exp, new_exp);
    }

    #[tokio::test]
    async fn test_ddstoken_expiring_at_refresh_window_should_refetch() {
        // The cache is reused only when strictly more than REFRESH_CACHE_TIME seconds remain.
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (exp, fetched) = resolve(now + REFRESH_CACHE_TIME, new_exp).await;
        assert!(fetched, "Fetcher should have been called at the window boundary");
        assert_eq!(exp, new_exp);
    }

    #[test]
    fn test_parse_jwt_reads_claims() {
        match parse_jwt(&make_jwt(1_700_000_000)) {
            Ok(claim) => {
                assert_eq!(claim.exp, 1_700_000_000);
                assert!(claim.org.is_none());
            }
            Err(_) => panic!("a well-formed token should parse"),
        }

        match parse_jwt(&make_jwt_with_payload(r#"{"exp":1,"org":"acme"}"#)) {
            Ok(claim) => {
                assert_eq!(claim.exp, 1);
                assert_eq!(claim.org.as_deref(), Some("acme"));
            }
            Err(_) => panic!("a token with an org claim should parse"),
        }
    }

    #[test]
    fn test_parse_jwt_rejects_wrong_segment_count() {
        assert!(parse_jwt("").is_err());
        assert!(parse_jwt("only.two").is_err());
        assert!(parse_jwt("a.b.c.d").is_err());
    }
}