posemesh_domain_http/
auth.rs

1use base64::{Engine as _, engine::general_purpose};
2use futures::lock::Mutex;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use posemesh_utils::now_unix_secs;
7use std::sync::Arc;
8
9#[derive(Debug, Clone)]
10pub struct AuthClient {
11    pub api_url: String,
12    client: Client,
13    dds_token_cache: Arc<Mutex<Option<DdsTokenCache>>>,
14    user_token_cache: Arc<Mutex<Option<UserTokenCache>>>,
15    pub client_id: String,
16    app_key: Option<String>,
17    app_secret: Option<String>,
18}
19
/// Cached tokens of a signed-in user.
#[derive(Clone, Debug)]
pub struct UserTokenCache {
    /// Token presented (as a bearer token) to obtain a new token pair.
    refresh_token: String,
    /// User access token, exchanged for DDS access tokens.
    access_token: String,
    /// Expiration of `access_token`, in Unix seconds.
    expires_at: u64,
}
26
27impl TokenCache for UserTokenCache {
28    fn get_access_token(&self) -> String {
29        self.access_token.clone()
30    }
31
32    fn get_expires_at(&self) -> u64 {
33        self.expires_at
34    }
35}
36
/// Cached DDS (domain) access token together with its expiry.
#[derive(Clone, Debug)]
pub(crate) struct DdsTokenCache {
    /// The DDS access token itself.
    access_token: String,
    /// Expiration of `access_token` as a UTC Unix timestamp in seconds.
    expires_at: u64,
}
44
45impl TokenCache for DdsTokenCache {
46    fn get_access_token(&self) -> String {
47        self.access_token.clone()
48    }
49
50    fn get_expires_at(&self) -> u64 {
51        self.expires_at
52    }
53}
54
/// Read-only view of a cached token and the moment it expires.
///
/// Implemented by the token caches so `get_cached_or_fresh_token` can decide whether a cached
/// entry is still usable.
pub(crate) trait TokenCache {
    /// Returns an owned copy of the cached access token.
    fn get_access_token(&self) -> String;

    /// Returns the token's expiration time.
    fn get_expires_at(&self) -> u64;
}
59
60#[derive(Debug, Serialize)]
61pub struct UserCredentials {
62    pub email: String,
63    pub password: String,
64}
65
66#[derive(Debug, Deserialize)]
67pub struct UserTokenResponse {
68    pub access_token: String,
69    pub refresh_token: String,
70}
71
72#[derive(Debug, Deserialize)]
73pub struct DdsTokenResponse {
74    pub access_token: String,
75}
76
77impl AuthClient {
78    pub fn new(api_url: &str, client_id: &str) -> Self {
79        Self {
80            api_url: api_url.to_string(),
81            client: Client::new(),
82            dds_token_cache: Arc::new(Mutex::new(None)),
83            user_token_cache: Arc::new(Mutex::new(None)),
84            client_id: client_id.to_string(),
85            app_key: None,
86            app_secret: None,
87        }
88    }
89
90    /// Get the expiration time of the user refresh token or DDS access token
91    pub async fn get_expires_at(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
92        let token_cache = {
93            let cache = self.user_token_cache.lock().await;
94            cache.clone()
95        };
96        if token_cache.is_none() {
97            let dds_token_cache = {
98                let cache = self.dds_token_cache.lock().await;
99                cache.clone()
100            };
101            if dds_token_cache.is_none() {
102                return Err("No token found".into());
103            }
104            return Ok(dds_token_cache.unwrap().expires_at);
105        }
106        Ok(parse_jwt(&token_cache.unwrap().refresh_token)?.exp)
107    }
108
109    pub async fn sign_in_with_app_credentials(
110        &mut self,
111        app_key: &str,
112        app_secret: &str,
113    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
114        self.app_key = Some(app_key.to_string());
115        self.app_secret = Some(app_secret.to_string());
116        *self.dds_token_cache.lock().await = None;
117        *self.user_token_cache.lock().await = None;
118
119        self.get_dds_app_access_token().await
120    }
121
122    // Get DDS access token with either app credentials or user access token or oidc_access_token, it checks the cache first, if found and not about to expire, return the cached token
123    // if not found or about to expire, it fetches a new token with app credentials or user access token or oidc_access_token and sets the cache.
124    // If user access token is about to expire, it refreshes the user access token with refresh token first and sets the cache.
125    // It clears all caches if there is an error.
126    pub async fn get_dds_access_token(
127        &self,
128        oidc_access_token: Option<&str>,
129    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
130        let result = if let Some(oidc_access_token) = oidc_access_token {
131            self.get_dds_access_token_with_oidc_access_token(oidc_access_token).await
132        } else if self.app_key.is_some() {
133            self.get_dds_app_access_token().await
134        } else {
135            self.get_dds_user_access_token().await
136        };
137
138        if result.is_err() {
139            *self.dds_token_cache.lock().await = None;
140            *self.user_token_cache.lock().await = None;
141        }
142
143        result
144    }
145
146    // Get DDS access token with OIDC access token, doesn't cache
147    async fn get_dds_access_token_with_oidc_access_token(
148        &self,
149        oidc_access_token: &str,
150    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
151        // Clear all caches before proceeding
152        *self.dds_token_cache.lock().await = None;
153        *self.user_token_cache.lock().await = None;
154        
155        let response = self.get_dds_token_by_token(oidc_access_token).await?;
156        {
157            let mut cache = self.dds_token_cache.lock().await;
158            *cache = Some(DdsTokenCache {
159                access_token: response.access_token.clone(),
160                expires_at: parse_jwt(&response.access_token)?.exp,
161            });
162        }
163        Ok(response.access_token)
164    }
165
166    // Get DDS access token with app credentials, it checks the cache first, if found and not about to expire, return the cached token
167    // if not found or about to expire, fetch a new token with app credentials and sets the cache.
168    async fn get_dds_app_access_token(
169        &self,
170    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
171        let token_cache = {
172            let cache = self.dds_token_cache.lock().await;
173            cache.clone()
174        };
175
176        let app_key = self
177            .app_key
178            .clone()
179            .ok_or("App key is not set".to_string())?;
180        let app_secret = self
181            .app_secret
182            .clone()
183            .ok_or("App secret is not set".to_string())?;
184
185        let token_cache = get_cached_or_fresh_token(
186            &token_cache.unwrap_or(DdsTokenCache {
187                access_token: "".to_string(),
188                expires_at: 0,
189            }),
190            || {
191                let app_key = app_key.to_string();
192                let app_secret = app_secret.to_string();
193                let client = self.client.clone();
194                let api_url = self.api_url.clone();
195                let client_id = self.client_id.clone();
196                async move {
197                    let response = client
198                        .post(format!("{}/service/domains-access-token", api_url))
199                        .basic_auth(app_key, Some(app_secret))
200                        .header("Content-Type", "application/json")
201                        .header("posemesh-client-id", client_id)
202                        .send()
203                        .await?;
204
205                    if response.status().is_success() {
206                        let token_response: DdsTokenResponse = response.json().await?;
207                        Ok(DdsTokenCache {
208                            access_token: token_response.access_token.clone(),
209                            expires_at: parse_jwt(&token_response.access_token)?.exp,
210                        })
211                    } else {
212                        let status = response.status();
213                        let text = response
214                            .text()
215                            .await
216                            .unwrap_or_else(|_| "Unknown error".to_string());
217                        Err(format!(
218                            "Failed to get DDS access token. Status: {} - {}",
219                            status, text
220                        )
221                        .into())
222                    }
223                }
224            },
225        )
226        .await?;
227
228        {
229            let mut cache = self.dds_token_cache.lock().await;
230            *cache = Some(token_cache.clone());
231        }
232
233        Ok(token_cache.access_token)
234    }
235
236    // Get DDS access token with user credentials, it checks the cache first, if found and not about to expire, return the cached token
237    // if not found or about to expire, it fetches a new token with user access token and sets the cache.
238    // If user access token is about to expire, it refreshes the user access token with refresh token first and sets the cache.
239    async fn get_dds_user_access_token(
240        &self,
241    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
242        let token_cache = {
243            let cache = self.dds_token_cache.lock().await;
244            cache.clone()
245        };
246
247        if token_cache.is_none() {
248            return Err("No access token found".into());
249        }
250
251        let user_token_cache = {
252            let cache = self.user_token_cache.lock().await;
253            cache.clone()
254        };
255
256        if user_token_cache.is_none() {
257            return Err("Login first".into());
258        }
259
260        let token_cache = get_cached_or_fresh_token(&token_cache.unwrap(), || {
261            let client = self.client.clone();
262            let api_url = self.api_url.clone();
263            let client_id = self.client_id.clone();
264
265            async move {
266                let client_clone = client.clone();
267                let api_url_clone = api_url.clone();
268                let client_id_clone = client_id.clone();
269                let refresh_token = user_token_cache.clone().unwrap().refresh_token;
270                let user_token_cache =
271                    get_cached_or_fresh_token(&user_token_cache.unwrap(), || async move {
272                        let response = client_clone
273                            .post(format!("{}/user/refresh", api_url_clone))
274                            .header("Content-Type", "application/json")
275                            .header("posemesh-client-id", client_id_clone)
276                            .header("Authorization", format!("Bearer {}", refresh_token))
277                            .send()
278                            .await
279                            .expect("Failed to refresh token");
280
281                        if response.status().is_success() {
282                            let token_response: UserTokenResponse = response.json().await?;
283                            Ok(UserTokenCache {
284                                refresh_token: token_response.refresh_token.clone(),
285                                access_token: token_response.access_token.clone(),
286                                expires_at: parse_jwt(&token_response.access_token)?.exp,
287                            })
288                        } else {
289                            Err(
290                                format!("Failed to refresh token. Status: {}", response.status())
291                                    .into(),
292                            )
293                        }
294                    })
295                    .await?;
296
297                {
298                    let mut cache = self.user_token_cache.lock().await;
299                    *cache = Some(user_token_cache.clone());
300                }
301
302                let dds_token_response = self.get_dds_token_by_token(&user_token_cache.access_token).await?;
303
304                let dds_cache = DdsTokenCache {
305                    access_token: dds_token_response.access_token.clone(),
306                    expires_at: parse_jwt(&dds_token_response.access_token)?.exp,
307                };
308                {
309                    let mut cache = self.dds_token_cache.lock().await;
310                    *cache = Some(dds_cache.clone());
311                }
312                Ok(dds_cache)
313            }
314        })
315        .await?;
316    
317        {
318            let mut cache = self.dds_token_cache.lock().await;
319            *cache = Some(token_cache.clone());
320        }
321
322        Ok(token_cache.access_token)
323    }
324
325    // Login with user credentials, return DDS access token. It clears all caches and sets the app credentials to none.
326    pub async fn user_login(
327        &mut self,
328        email: &str,
329        password: &str,
330    ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
331        *self.dds_token_cache.lock().await = None;
332        *self.user_token_cache.lock().await = None;
333        self.app_key = None;
334        self.app_secret = None;
335
336        let credentials = UserCredentials { email: email.to_string(), password: password.to_string() };
337
338        let response = self.client
339            .post(format!("{}/user/login", &self.api_url))
340            .header("Content-Type", "application/json")
341            .header("posemesh-client-id", &self.client_id)
342            .json(&credentials)
343            .send()
344            .await?;
345
346        if response.status().is_success() {
347            let token_response: UserTokenResponse = response.json().await?;
348            {
349                let mut cache = self.user_token_cache.lock().await;
350                *cache = Some(UserTokenCache {
351                    refresh_token: token_response.refresh_token.clone(),
352                    access_token: token_response.access_token.clone(),
353                    expires_at: parse_jwt(&token_response.access_token)?.exp,
354                });
355            }
356
357            let dds_token_response = self.get_dds_token_by_token(&token_response.access_token).await?;
358            let mut cache = self.dds_token_cache.lock().await;
359            let token_cache = DdsTokenCache {
360                access_token: dds_token_response.access_token.clone(),
361                expires_at: parse_jwt(&dds_token_response.access_token)?.exp,
362            };
363            *cache = Some(token_cache.clone());
364            Ok(token_cache.access_token)
365        } else {
366            Err(format!("Failed to login. Status: {}", response.status()).into())
367        }
368    }
369
370    // Get DDS access token with either user access token or oidc_access_token, doesn't cache
371    async fn get_dds_token_by_token(
372        &self,
373        token: &str,
374    ) -> Result<DdsTokenResponse, Box<dyn std::error::Error + Send + Sync>> {
375        let dds_response = self.client.post(format!("{}/service/domains-access-token", &self.api_url))
376            .header(
377                "Authorization",
378                format!("Bearer {}", token),
379            )
380            .header("Content-Type", "application/json")
381            .header("posemesh-client-id", &self.client_id)
382            .send()
383            .await?;
384
385        if dds_response.status().is_success() {
386            dds_response.json::<DdsTokenResponse>().await.map_err(|e| e.into())
387        } else {
388            let status = dds_response.status();
389            let text = dds_response
390                .text()
391                .await
392                .unwrap_or_else(|_| "Unknown error".to_string());
393            Err(format!(
394                "Failed to get DDS access token. Status: {} - {}",
395                status, text
396            )
397            .into())
398        }
399    }
400}
401
402const REFRESH_CACHE_TIME: u64 = 3;
403
404pub(crate) async fn get_cached_or_fresh_token<R, F, Fut>(
405    cache: &R,
406    token_fetcher: F,
407) -> Result<R, Box<dyn std::error::Error + Send + Sync>>
408where
409    F: FnOnce() -> Fut,
410    R: TokenCache + Clone,
411    Fut: std::future::Future<Output = Result<R, Box<dyn std::error::Error + Send + Sync>>>,
412{
413    // Check if we have a valid cached token
414    let expires_at = cache.get_expires_at();
415    let current_time = now_unix_secs();
416    // If token expires in more than REFRESH_CACHE_TIME seconds, return cached token
417    if expires_at > current_time && expires_at - current_time > REFRESH_CACHE_TIME {
418        return Ok(cache.clone());
419    }
420
421    // Fetch new token
422    token_fetcher().await
423}
424
425#[derive(Debug, Deserialize)]
426pub struct JwtClaim {
427    pub exp: u64,
428    #[serde(default)]
429    pub org: Option<String>,
430}
431
432pub fn parse_jwt(token: &str) -> Result<JwtClaim, Box<dyn std::error::Error + Send + Sync>> {
433    let parts = token.split('.').collect::<Vec<&str>>();
434    if parts.len() != 3 {
435        return Err("Invalid JWT token".into());
436    }
437    let payload = parts[1];
438    let decoded = general_purpose::URL_SAFE_NO_PAD.decode(payload)?;
439    let claims: JwtClaim = serde_json::from_slice(&decoded)?;
440    Ok(claims)
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Minimal `TokenCache` used to drive `get_cached_or_fresh_token` directly.
    #[derive(Clone, Debug)]
    struct DummyTokenCache {
        access_token: String,
        expires_at: u64,
    }

    impl TokenCache for DummyTokenCache {
        fn get_access_token(&self) -> String {
            self.access_token.clone()
        }
        fn get_expires_at(&self) -> u64 {
            self.expires_at
        }
    }

    fn now_unix_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds an unsigned JWT-shaped string whose payload is `{"exp":<exp>}`.
    fn make_jwt(exp: u64) -> String {
        // Header: {"alg":"HS256","typ":"JWT"}
        // Payload: {"exp":exp}
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"HS256","typ":"JWT"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(format!(r#"{{"exp":{}}}"#, exp));
        format!("{}.{}.sig", header, payload)
    }

    fn dummy(expires_at: u64) -> DummyTokenCache {
        DummyTokenCache {
            access_token: make_jwt(expires_at),
            expires_at,
        }
    }

    /// Runs `get_cached_or_fresh_token` on `cache` with a fetcher that returns a token expiring
    /// at `fresh_exp`. Returns the result and whether the fetcher was invoked.
    async fn fetch_with_flag(cache: &DummyTokenCache, fresh_exp: u64) -> (DummyTokenCache, bool) {
        let called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&called);
        let result = get_cached_or_fresh_token(cache, move || async move {
            flag.store(true, Ordering::SeqCst);
            Ok(dummy(fresh_exp))
        })
        .await
        .unwrap();
        (result, called.load(Ordering::SeqCst))
    }

    #[tokio::test]
    async fn test_ddstoken_about_to_expire_should_refetch() {
        // Token expires in 2 seconds (less than REFRESH_CACHE_TIME).
        let now = now_unix_secs();
        let (result, fetched) = fetch_with_flag(&dummy(now + 2), now + 1000).await;
        assert!(fetched, "Fetcher should have been called");
        assert_eq!(result.expires_at, now + 1000);
    }

    #[tokio::test]
    async fn test_ddstoken_not_expiring_should_use_cache() {
        // Token expires in 100 seconds (more than REFRESH_CACHE_TIME).
        let now = now_unix_secs();
        let cache = dummy(now + 100);
        let (result, fetched) = fetch_with_flag(&cache, now + 1000).await;
        assert!(!fetched, "Fetcher should NOT have been called");
        assert_eq!(result.expires_at, now + 100);
        assert_eq!(result.access_token, cache.access_token);
    }

    #[tokio::test]
    async fn test_token_at_refresh_boundary_should_refetch() {
        // Exactly REFRESH_CACHE_TIME seconds left is not "more than" the margin.
        let now = now_unix_secs();
        let (_, fetched) = fetch_with_flag(&dummy(now + REFRESH_CACHE_TIME), now + 1000).await;
        assert!(fetched, "Fetcher should have been called at the boundary");
    }

    #[tokio::test]
    async fn test_expired_token_should_refetch() {
        let (result, fetched) = fetch_with_flag(&dummy(0), 42).await;
        assert!(fetched, "Expired token must be refetched");
        assert_eq!(result.expires_at, 42);
    }

    #[tokio::test]
    async fn test_fetch_error_is_propagated() {
        let result = get_cached_or_fresh_token(&dummy(0), || async { Err("boom".into()) }).await;
        assert_eq!(result.unwrap_err().to_string(), "boom");
    }

    #[test]
    fn test_parse_jwt_reads_exp_claim() {
        let claims = parse_jwt(&make_jwt(1_700_000_000)).unwrap();
        assert_eq!(claims.exp, 1_700_000_000);
        assert_eq!(claims.org, None);
    }

    #[test]
    fn test_parse_jwt_rejects_wrong_segment_count() {
        assert!(parse_jwt("").is_err());
        assert!(parse_jwt("only.two").is_err());
        assert!(parse_jwt("a.b.c.d").is_err());
    }
}