1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use futures::lock::Mutex;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6
7#[cfg(not(target_family = "wasm"))]
8use tokio::spawn;
9#[cfg(target_family = "wasm")]
10use wasm_bindgen_futures::spawn_local as spawn;
11
12use posemesh_utils::now_unix_secs;
13#[cfg(target_family = "wasm")]
14use posemesh_utils::sleep;
15#[cfg(not(target_family = "wasm"))]
16use tokio::time::sleep;
17
18use crate::auth::{AuthClient, TokenCache, get_cached_or_fresh_token, parse_jwt};
/// Organization selector with the literal value `"all"`.
// NOTE(review): presumably passed as the `org` argument of `DiscoveryService::list_domains`
// to request domains of every organization — confirm against callers and the DDS API.
pub const ALL_DOMAINS_ORG: &str = "all";
/// Organization selector with the literal value `"own"`.
// NOTE(review): presumably restricts `list_domains` to the caller's own organization — confirm.
pub const OWN_DOMAINS_ORG: &str = "own";
21
/// Description of a domain server as exchanged with the discovery service.
///
/// Embedded in [`DomainWithServer`] as the `domain_server` field; all fields are plain strings
/// that are (de)serialized under their own names.
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct DomainServer {
    /// Identifier of the domain server.
    pub id: String,
    /// Identifier of the organization associated with the server.
    pub organization_id: String,
    /// Name of the server.
    pub name: String,
    /// URL of the server.
    pub url: String,
}
29
/// A domain together with the access token obtained for it.
///
/// This is the value stored in the `DiscoveryService` cache and returned by
/// `DiscoveryService::auth_domain`.
// NOTE(review): the derived `Debug` output includes `access_token`; be careful when logging.
#[derive(Debug, Deserialize, Clone)]
pub struct DomainWithToken {
    /// Domain description; its fields are read from the same JSON object as the token
    /// (`#[serde(flatten)]`).
    #[serde(flatten)]
    pub domain: DomainWithServer,
    /// Token expiry value. Never read from the payload (`#[serde(skip)]`), so it holds the
    /// default `0` right after deserialization; `auth_domain` overwrites it with the `exp`
    /// claim parsed from the access token.
    #[serde(skip)]
    pub expires_at: u64,
    /// Access token for the domain; private, exposed through `TokenCache::get_access_token`.
    access_token: String,
}
38impl TokenCache for DomainWithToken {
39 fn get_access_token(&self) -> String {
40 self.access_token.clone()
41 }
42
43 fn get_expires_at(&self) -> u64 {
44 self.expires_at
45 }
46}
47
/// A domain record together with the description of the server it refers to.
///
/// Element type of [`ListDomainsResponse::domains`] and the flattened part of
/// [`DomainWithToken`].
#[derive(Debug, Deserialize, Clone, Serialize)]
pub struct DomainWithServer {
    /// Identifier of the domain.
    pub id: String,
    /// Name of the domain.
    pub name: String,
    /// Identifier of the organization associated with the domain.
    pub organization_id: String,
    /// Identifier of the associated domain server.
    pub domain_server_id: String,
    /// Optional redirect URL; `None` when absent from the payload.
    pub redirect_url: Option<String>,
    /// Full description of the associated domain server.
    pub domain_server: DomainServer,
}
57
/// Client for the domain discovery service (DDS) HTTP API.
///
/// Cloning is cheap with respect to the cache: the clone shares the same `Arc`-wrapped map.
#[derive(Debug, Clone)]
pub struct DiscoveryService {
    /// Base URL that the `/api/v1/...` request paths are appended to.
    dds_url: String,
    /// HTTP client used for every DDS request.
    client: Client,
    /// Per-domain tokens keyed by domain id, shared between clones of this service.
    cache: Arc<Mutex<HashMap<String, DomainWithToken>>>,
    /// Auth client used to sign in and to obtain the DDS access token.
    api_client: AuthClient,
    /// Optional OIDC access token handed to the auth client when requesting a DDS token.
    oidc_access_token: Option<String>,
}
66
/// JSON body returned by the list-domains endpoint, as decoded by
/// `DiscoveryService::list_domains`.
#[derive(Debug, Deserialize)]
pub struct ListDomainsResponse {
    /// The domains contained in the response.
    pub domains: Vec<DomainWithServer>,
}
71
72impl DiscoveryService {
73 pub fn new(api_url: &str, dds_url: &str, client_id: &str) -> Self {
74 let api_client = AuthClient::new(api_url, client_id);
75
76 Self {
77 dds_url: dds_url.to_string(),
78 client: Client::new(),
79 cache: Arc::new(Mutex::new(HashMap::new())),
80 api_client,
81 oidc_access_token: None,
82 }
83 }
84
85 pub async fn list_domains(
87 &self,
88 org: &str,
89 ) -> Result<Vec<DomainWithServer>, Box<dyn std::error::Error + Send + Sync>> {
90 let access_token = self
91 .api_client
92 .get_dds_access_token(self.oidc_access_token.as_deref())
93 .await?;
94 let response = self
95 .client
96 .get(&format!(
97 "{}/api/v1/domains?org={}&with=domain_server",
98 self.dds_url, org
99 ))
100 .bearer_auth(access_token)
101 .header("Content-Type", "application/json")
102 .header("posemesh-client-id", self.api_client.client_id.clone())
103 .send()
104 .await?;
105
106 if response.status().is_success() {
107 let domain_servers: ListDomainsResponse = response.json().await?;
108 Ok(domain_servers.domains)
109 } else {
110 let status = response.status();
111 let text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());
112 Err(format!("Failed to list domains. Status: {} - {} - {}", status, text, org).into())
113 }
114 }
115
116 pub async fn sign_in_with_auki_account(
117 &mut self,
118 email: &str,
119 password: &str,
120 remember_password: bool,
121 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
122 self.cache.lock().await.clear();
123 self.oidc_access_token = None;
124 let _ = self.api_client.user_login(email, password).await?;
125 if remember_password {
126 let mut api_client = self.api_client.clone();
127 let email = email.to_string();
128 let password = password.to_string();
129 spawn(async move {
130 loop {
131 let expires_at = api_client
132 .get_expires_at()
133 .await
134 .inspect_err(|e| tracing::error!("Failed to get expires at: {}", e));
135 if let Ok(expires_at) = expires_at {
136 let expiration = {
137 let now = now_unix_secs();
138 let duration = expires_at - now;
139 if duration > 600 {
140 Some(Duration::from_secs(duration))
141 } else {
142 None
143 }
144 };
145
146 if let Some(expiration) = expiration {
147 tracing::info!("Refreshing token in {} seconds", expiration.as_secs());
148 sleep(expiration).await;
149 }
150
151 let _ = api_client
152 .user_login(&email, &password)
153 .await
154 .inspect_err(|e| tracing::error!("Failed to login: {}", e));
155 }
156 }
157 });
158 }
159 Ok(())
160 }
161
162 pub async fn sign_in_as_auki_app(
163 &mut self,
164 app_key: &str,
165 app_secret: &str,
166 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
167 self.cache.lock().await.clear();
168 self.oidc_access_token = None;
169 let _ = self
170 .api_client
171 .sign_in_with_app_credentials(app_key, app_secret)
172 .await?;
173 Ok(())
174 }
175
176 pub fn with_oidc_access_token(&self, oidc_access_token: &str) -> Self {
177 if let Some(cached_oidc_access_token) = self.oidc_access_token.as_deref() {
178 if cached_oidc_access_token == oidc_access_token {
179 return self.clone();
180 }
181 }
182 Self {
183 dds_url: self.dds_url.clone(),
184 client: self.client.clone(),
185 cache: Arc::new(Mutex::new(HashMap::new())),
186 api_client: AuthClient::new(&self.api_client.api_url, &self.api_client.client_id),
187 oidc_access_token: Some(oidc_access_token.to_string()),
188 }
189 }
190
191 pub async fn auth_domain(
192 &self,
193 domain_id: &str,
194 ) -> Result<DomainWithToken, Box<dyn std::error::Error + Send + Sync>> {
195 let access_token = self
196 .api_client
197 .get_dds_access_token(self.oidc_access_token.as_deref())
198 .await?;
199 let cache = if let Some(cached_domain) = self.cache.lock().await.get(domain_id) {
201 cached_domain.clone()
202 } else {
203 DomainWithToken {
204 domain: DomainWithServer {
205 id: domain_id.to_string(),
206 name: "".to_string(),
207 organization_id: "".to_string(),
208 domain_server_id: "".to_string(),
209 redirect_url: None,
210 domain_server: DomainServer {
211 id: "".to_string(),
212 organization_id: "".to_string(),
213 name: "".to_string(),
214 url: "".to_string(),
215 },
216 },
217 expires_at: 0,
218 access_token: "".to_string(),
219 }
220 };
221
222 let cached = get_cached_or_fresh_token(&cache, || {
223 let client = self.client.clone();
224 let dds_url = self.dds_url.clone();
225 let client_id = self.api_client.client_id.clone();
226 async move {
227 let response = client
228 .post(&format!("{}/api/v1/domains/{}/auth", dds_url, domain_id))
229 .bearer_auth(access_token)
230 .header("Content-Type", "application/json")
231 .header("posemesh-client-id", client_id)
232 .send()
233 .await?;
234
235 if response.status().is_success() {
236 let mut domain_with_token: DomainWithToken = response.json().await?;
237 domain_with_token.expires_at =
238 parse_jwt(&domain_with_token.get_access_token())?.exp;
239 Ok(domain_with_token)
240 } else {
241 let status = response.status();
242 let text = response
243 .text()
244 .await
245 .unwrap_or_else(|_| "Unknown error".to_string());
246 Err(format!("Failed to auth domain. Status: {} - {}", status, text).into())
247 }
248 }
249 })
250 .await?;
251
252 let mut cache = self.cache.lock().await;
254 cache.insert(domain_id.to_string(), cached.clone());
255 Ok(cached)
256 }
257}