1use base64::{Engine as _, engine::general_purpose};
2use futures::lock::Mutex;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5
6use posemesh_utils::now_unix_secs;
7use std::sync::Arc;
8
9#[derive(Debug, Clone)]
10pub struct AuthClient {
11 pub api_url: String,
12 client: Client,
13 dds_token_cache: Arc<Mutex<Option<DdsTokenCache>>>,
14 user_token_cache: Arc<Mutex<Option<UserTokenCache>>>,
15 pub client_id: String,
16 app_key: Option<String>,
17 app_secret: Option<String>,
18}
19
/// Cached token pair from a user login or token refresh.
// NOTE(review): the derived `Debug` prints both raw tokens; consider redacting before logging.
#[derive(Debug, Clone)]
pub struct UserTokenCache {
    /// Sent as a `Bearer` token to `/user/refresh` to obtain a new token pair.
    refresh_token: String,
    /// User access token; exchanged for DDS access tokens.
    access_token: String,
    /// `exp` claim decoded from `access_token` (Unix seconds).
    expires_at: u64,
}
26
27impl TokenCache for UserTokenCache {
28 fn get_access_token(&self) -> String {
29 self.access_token.clone()
30 }
31
32 fn get_expires_at(&self) -> u64 {
33 self.expires_at
34 }
35}
36
/// Cached DDS access token together with its expiry.
// NOTE(review): the derived `Debug` prints the raw token; consider redacting before logging.
#[derive(Debug, Clone)]
pub(crate) struct DdsTokenCache {
    /// DDS access token as returned by `/service/domains-access-token`.
    access_token: String,
    /// `exp` claim decoded from `access_token` (Unix seconds); `0` marks an empty placeholder.
    expires_at: u64,
}
44
45impl TokenCache for DdsTokenCache {
46 fn get_access_token(&self) -> String {
47 self.access_token.clone()
48 }
49
50 fn get_expires_at(&self) -> u64 {
51 self.expires_at
52 }
53}
54
/// Read access to a cached token and its expiry; `get_cached_or_fresh_token` uses
/// the expiry to decide whether the token must be refetched.
pub(crate) trait TokenCache {
    /// Returns an owned copy of the cached access token.
    fn get_access_token(&self) -> String;
    /// Returns the token's expiry as a Unix timestamp in seconds.
    fn get_expires_at(&self) -> u64;
}
59
60#[derive(Debug, Serialize)]
61pub struct UserCredentials {
62 pub email: String,
63 pub password: String,
64}
65
/// JSON body returned by `/user/login` and `/user/refresh` on success.
// NOTE(review): the derived `Debug` prints both raw tokens.
#[derive(Debug, Deserialize)]
pub struct UserTokenResponse {
    /// Short-lived user access token (a JWT whose `exp` is decoded by `parse_jwt`).
    pub access_token: String,
    /// Token used to obtain a new pair from `/user/refresh`.
    pub refresh_token: String,
}
71
/// JSON body returned by `/service/domains-access-token` on success.
#[derive(Debug, Deserialize)]
pub struct DdsTokenResponse {
    /// DDS access token (a JWT whose `exp` is decoded by `parse_jwt`).
    pub access_token: String,
}
76
77impl AuthClient {
78 pub fn new(api_url: &str, client_id: &str) -> Self {
79 Self {
80 api_url: api_url.to_string(),
81 client: Client::new(),
82 dds_token_cache: Arc::new(Mutex::new(None)),
83 user_token_cache: Arc::new(Mutex::new(None)),
84 client_id: client_id.to_string(),
85 app_key: None,
86 app_secret: None,
87 }
88 }
89
90 pub async fn get_expires_at(&self) -> Result<u64, Box<dyn std::error::Error + Send + Sync>> {
92 let token_cache = {
93 let cache = self.user_token_cache.lock().await;
94 cache.clone()
95 };
96 if token_cache.is_none() {
97 let dds_token_cache = {
98 let cache = self.dds_token_cache.lock().await;
99 cache.clone()
100 };
101 if dds_token_cache.is_none() {
102 return Err("No token found".into());
103 }
104 return Ok(dds_token_cache.unwrap().expires_at);
105 }
106 Ok(parse_jwt(&token_cache.unwrap().refresh_token)?.exp)
107 }
108
109 pub async fn sign_in_with_app_credentials(
110 &mut self,
111 app_key: &str,
112 app_secret: &str,
113 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
114 self.app_key = Some(app_key.to_string());
115 self.app_secret = Some(app_secret.to_string());
116 *self.dds_token_cache.lock().await = None;
117 *self.user_token_cache.lock().await = None;
118
119 self.get_dds_app_access_token().await
120 }
121
122 pub async fn get_dds_access_token(
127 &self,
128 oidc_access_token: Option<&str>,
129 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
130 let result = if let Some(oidc_access_token) = oidc_access_token {
131 self.get_dds_access_token_with_oidc_access_token(oidc_access_token).await
132 } else if self.app_key.is_some() {
133 self.get_dds_app_access_token().await
134 } else {
135 self.get_dds_user_access_token().await
136 };
137
138 if result.is_err() {
139 *self.dds_token_cache.lock().await = None;
140 *self.user_token_cache.lock().await = None;
141 }
142
143 result
144 }
145
146 async fn get_dds_access_token_with_oidc_access_token(
148 &self,
149 oidc_access_token: &str,
150 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
151 *self.dds_token_cache.lock().await = None;
153 *self.user_token_cache.lock().await = None;
154
155 let response = self.get_dds_token_by_token(oidc_access_token).await?;
156 {
157 let mut cache = self.dds_token_cache.lock().await;
158 *cache = Some(DdsTokenCache {
159 access_token: response.access_token.clone(),
160 expires_at: parse_jwt(&response.access_token)?.exp,
161 });
162 }
163 Ok(response.access_token)
164 }
165
166 async fn get_dds_app_access_token(
169 &self,
170 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
171 let token_cache = {
172 let cache = self.dds_token_cache.lock().await;
173 cache.clone()
174 };
175
176 let app_key = self
177 .app_key
178 .clone()
179 .ok_or("App key is not set".to_string())?;
180 let app_secret = self
181 .app_secret
182 .clone()
183 .ok_or("App secret is not set".to_string())?;
184
185 let token_cache = get_cached_or_fresh_token(
186 &token_cache.unwrap_or(DdsTokenCache {
187 access_token: "".to_string(),
188 expires_at: 0,
189 }),
190 || {
191 let app_key = app_key.to_string();
192 let app_secret = app_secret.to_string();
193 let client = self.client.clone();
194 let api_url = self.api_url.clone();
195 let client_id = self.client_id.clone();
196 async move {
197 let response = client
198 .post(format!("{}/service/domains-access-token", api_url))
199 .basic_auth(app_key, Some(app_secret))
200 .header("Content-Type", "application/json")
201 .header("posemesh-client-id", client_id)
202 .send()
203 .await?;
204
205 if response.status().is_success() {
206 let token_response: DdsTokenResponse = response.json().await?;
207 Ok(DdsTokenCache {
208 access_token: token_response.access_token.clone(),
209 expires_at: parse_jwt(&token_response.access_token)?.exp,
210 })
211 } else {
212 let status = response.status();
213 let text = response
214 .text()
215 .await
216 .unwrap_or_else(|_| "Unknown error".to_string());
217 Err(format!(
218 "Failed to get DDS access token. Status: {} - {}",
219 status, text
220 )
221 .into())
222 }
223 }
224 },
225 )
226 .await?;
227
228 {
229 let mut cache = self.dds_token_cache.lock().await;
230 *cache = Some(token_cache.clone());
231 }
232
233 Ok(token_cache.access_token)
234 }
235
236 async fn get_dds_user_access_token(
240 &self,
241 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
242 let token_cache = {
243 let cache = self.dds_token_cache.lock().await;
244 cache.clone()
245 };
246
247 if token_cache.is_none() {
248 return Err("No access token found".into());
249 }
250
251 let user_token_cache = {
252 let cache = self.user_token_cache.lock().await;
253 cache.clone()
254 };
255
256 if user_token_cache.is_none() {
257 return Err("Login first".into());
258 }
259
260 let token_cache = get_cached_or_fresh_token(&token_cache.unwrap(), || {
261 let client = self.client.clone();
262 let api_url = self.api_url.clone();
263 let client_id = self.client_id.clone();
264
265 async move {
266 let client_clone = client.clone();
267 let api_url_clone = api_url.clone();
268 let client_id_clone = client_id.clone();
269 let refresh_token = user_token_cache.clone().unwrap().refresh_token;
270 let user_token_cache =
271 get_cached_or_fresh_token(&user_token_cache.unwrap(), || async move {
272 let response = client_clone
273 .post(format!("{}/user/refresh", api_url_clone))
274 .header("Content-Type", "application/json")
275 .header("posemesh-client-id", client_id_clone)
276 .header("Authorization", format!("Bearer {}", refresh_token))
277 .send()
278 .await
279 .expect("Failed to refresh token");
280
281 if response.status().is_success() {
282 let token_response: UserTokenResponse = response.json().await?;
283 Ok(UserTokenCache {
284 refresh_token: token_response.refresh_token.clone(),
285 access_token: token_response.access_token.clone(),
286 expires_at: parse_jwt(&token_response.access_token)?.exp,
287 })
288 } else {
289 Err(
290 format!("Failed to refresh token. Status: {}", response.status())
291 .into(),
292 )
293 }
294 })
295 .await?;
296
297 {
298 let mut cache = self.user_token_cache.lock().await;
299 *cache = Some(user_token_cache.clone());
300 }
301
302 let dds_token_response = self.get_dds_token_by_token(&user_token_cache.access_token).await?;
303
304 let dds_cache = DdsTokenCache {
305 access_token: dds_token_response.access_token.clone(),
306 expires_at: parse_jwt(&dds_token_response.access_token)?.exp,
307 };
308 {
309 let mut cache = self.dds_token_cache.lock().await;
310 *cache = Some(dds_cache.clone());
311 }
312 Ok(dds_cache)
313 }
314 })
315 .await?;
316
317 {
318 let mut cache = self.dds_token_cache.lock().await;
319 *cache = Some(token_cache.clone());
320 }
321
322 Ok(token_cache.access_token)
323 }
324
325 pub async fn user_login(
327 &mut self,
328 email: &str,
329 password: &str,
330 ) -> Result<String, Box<dyn std::error::Error + Send + Sync>> {
331 *self.dds_token_cache.lock().await = None;
332 *self.user_token_cache.lock().await = None;
333 self.app_key = None;
334 self.app_secret = None;
335
336 let credentials = UserCredentials { email: email.to_string(), password: password.to_string() };
337
338 let response = self.client
339 .post(format!("{}/user/login", &self.api_url))
340 .header("Content-Type", "application/json")
341 .header("posemesh-client-id", &self.client_id)
342 .json(&credentials)
343 .send()
344 .await?;
345
346 if response.status().is_success() {
347 let token_response: UserTokenResponse = response.json().await?;
348 {
349 let mut cache = self.user_token_cache.lock().await;
350 *cache = Some(UserTokenCache {
351 refresh_token: token_response.refresh_token.clone(),
352 access_token: token_response.access_token.clone(),
353 expires_at: parse_jwt(&token_response.access_token)?.exp,
354 });
355 }
356
357 let dds_token_response = self.get_dds_token_by_token(&token_response.access_token).await?;
358 let mut cache = self.dds_token_cache.lock().await;
359 let token_cache = DdsTokenCache {
360 access_token: dds_token_response.access_token.clone(),
361 expires_at: parse_jwt(&dds_token_response.access_token)?.exp,
362 };
363 *cache = Some(token_cache.clone());
364 Ok(token_cache.access_token)
365 } else {
366 Err(format!("Failed to login. Status: {}", response.status()).into())
367 }
368 }
369
370 async fn get_dds_token_by_token(
372 &self,
373 token: &str,
374 ) -> Result<DdsTokenResponse, Box<dyn std::error::Error + Send + Sync>> {
375 let dds_response = self.client.post(format!("{}/service/domains-access-token", &self.api_url))
376 .header(
377 "Authorization",
378 format!("Bearer {}", token),
379 )
380 .header("Content-Type", "application/json")
381 .header("posemesh-client-id", &self.client_id)
382 .send()
383 .await?;
384
385 if dds_response.status().is_success() {
386 dds_response.json::<DdsTokenResponse>().await.map_err(|e| e.into())
387 } else {
388 let status = dds_response.status();
389 let text = dds_response
390 .text()
391 .await
392 .unwrap_or_else(|_| "Unknown error".to_string());
393 Err(format!(
394 "Failed to get DDS access token. Status: {} - {}",
395 status, text
396 )
397 .into())
398 }
399 }
400}
401
402const REFRESH_CACHE_TIME: u64 = 3;
403
404pub(crate) async fn get_cached_or_fresh_token<R, F, Fut>(
405 cache: &R,
406 token_fetcher: F,
407) -> Result<R, Box<dyn std::error::Error + Send + Sync>>
408where
409 F: FnOnce() -> Fut,
410 R: TokenCache + Clone,
411 Fut: std::future::Future<Output = Result<R, Box<dyn std::error::Error + Send + Sync>>>,
412{
413 let expires_at = cache.get_expires_at();
415 let current_time = now_unix_secs();
416 if expires_at > current_time && expires_at - current_time > REFRESH_CACHE_TIME {
418 return Ok(cache.clone());
419 }
420
421 token_fetcher().await
423}
424
/// Subset of JWT claims decoded by `parse_jwt`.
#[derive(Debug, Deserialize)]
pub struct JwtClaim {
    /// `exp` claim (Unix seconds). Required: deserialization fails when it is absent.
    pub exp: u64,
    /// Optional `org` claim; `None` when the token does not carry it.
    #[serde(default)]
    pub org: Option<String>,
}
431
432pub fn parse_jwt(token: &str) -> Result<JwtClaim, Box<dyn std::error::Error + Send + Sync>> {
433 let parts = token.split('.').collect::<Vec<&str>>();
434 if parts.len() != 3 {
435 return Err("Invalid JWT token".into());
436 }
437 let payload = parts[1];
438 let decoded = general_purpose::URL_SAFE_NO_PAD.decode(payload)?;
439 let claims: JwtClaim = serde_json::from_slice(&decoded)?;
440 Ok(claims)
441}
442
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    type BoxError = Box<dyn std::error::Error + Send + Sync>;

    /// Minimal `TokenCache` used to drive `get_cached_or_fresh_token`.
    #[derive(Clone, Debug)]
    struct DummyTokenCache {
        access_token: String,
        expires_at: u64,
    }

    impl TokenCache for DummyTokenCache {
        fn get_access_token(&self) -> String {
            self.access_token.clone()
        }
        fn get_expires_at(&self) -> u64 {
            self.expires_at
        }
    }

    fn now_unix_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Builds an unsigned JWT whose payload is `{"exp":<exp>}`.
    fn make_jwt(exp: u64) -> String {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(r#"{"alg":"HS256","typ":"JWT"}"#);
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(format!(r#"{{"exp":{}}}"#, exp));
        format!("{}.{}.sig", header, payload)
    }

    fn dummy(exp: u64) -> DummyTokenCache {
        DummyTokenCache {
            access_token: make_jwt(exp),
            expires_at: exp,
        }
    }

    /// Runs `get_cached_or_fresh_token` with a fetcher that yields `fresh`, returning the
    /// result and whether the fetcher was invoked.
    async fn resolve(cache: &DummyTokenCache, fresh: DummyTokenCache) -> (DummyTokenCache, bool) {
        let fetch_called = AtomicBool::new(false);
        let flag = &fetch_called;
        let result = get_cached_or_fresh_token(cache, move || async move {
            flag.store(true, Ordering::SeqCst);
            Ok::<_, BoxError>(fresh)
        })
        .await
        .unwrap();
        (result, fetch_called.load(Ordering::SeqCst))
    }

    #[tokio::test]
    async fn test_ddstoken_about_to_expire_should_refetch() {
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (result, fetched) = resolve(&dummy(now + 2), dummy(new_exp)).await;
        assert!(fetched, "Fetcher should have been called");
        assert_eq!(result.expires_at, new_exp);
    }

    #[tokio::test]
    async fn test_ddstoken_not_expiring_should_use_cache() {
        let now = now_unix_secs();
        let not_expiring = now + 100;
        let (result, fetched) = resolve(&dummy(not_expiring), dummy(now + 1000)).await;
        assert!(!fetched, "Fetcher should NOT have been called");
        assert_eq!(result.expires_at, not_expiring);
    }

    #[tokio::test]
    async fn test_expired_token_should_refetch() {
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (result, fetched) = resolve(&dummy(now.saturating_sub(10)), dummy(new_exp)).await;
        assert!(fetched, "Fetcher should have been called for an expired token");
        assert_eq!(result.expires_at, new_exp);
    }

    #[tokio::test]
    async fn test_token_at_refresh_threshold_should_refetch() {
        // Exactly REFRESH_CACHE_TIME seconds left is not "more than" the threshold.
        let now = now_unix_secs();
        let new_exp = now + 1000;
        let (result, fetched) = resolve(&dummy(now + REFRESH_CACHE_TIME), dummy(new_exp)).await;
        assert!(fetched, "Fetcher should have been called at the threshold");
        assert_eq!(result.expires_at, new_exp);
    }

    #[test]
    fn test_parse_jwt_reads_exp_claim() {
        let claims = parse_jwt(&make_jwt(1234)).unwrap();
        assert_eq!(claims.exp, 1234);
        assert!(claims.org.is_none());
    }

    #[test]
    fn test_parse_jwt_rejects_wrong_segment_count() {
        assert!(parse_jwt("").is_err());
        assert!(parse_jwt("only.two").is_err());
        assert!(parse_jwt("a.b.c.d").is_err());
    }
}