1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use futures::lock::Mutex;
4use reqwest::Client;
5use serde::Deserialize;
6
7#[cfg(not(target_family = "wasm"))]
8use tokio::spawn;
9#[cfg(target_family = "wasm")]
10use wasm_bindgen_futures::spawn_local as spawn;
11
12use posemesh_utils::now_unix_secs;
13#[cfg(target_family = "wasm")]
14use posemesh_utils::sleep;
15#[cfg(not(target_family = "wasm"))]
16use tokio::time::sleep;
17
18use crate::auth::{AuthClient, TokenCache, get_cached_or_fresh_token, parse_jwt};
/// `org` value `"all"` for [`DiscoveryService::list_domains`]; presumably
/// selects every domain visible to the caller (TODO confirm with the DDS API).
pub const ALL_DOMAINS_ORG: &str = "all";
/// `org` value `"own"` for [`DiscoveryService::list_domains`]; presumably
/// selects only the caller's own organization's domains (TODO confirm).
pub const OWN_DOMAINS_ORG: &str = "own";
21
22#[derive(Debug, Deserialize, Clone)]
23pub struct Domain {
24 pub id: String,
25 pub name: String,
26 pub organization_id: String,
27 pub domain_server_id: String,
28 pub redirect_url: Option<String>,
29}
30
31#[derive(Debug, Deserialize, Clone)]
32pub struct DomainServer {
33 pub id: String,
34 pub organization_id: String,
35 pub name: String,
36 pub url: String,
37}
38
39#[derive(Debug, Deserialize, Clone)]
40pub struct DomainWithToken {
41 #[serde(flatten)]
42 pub domain: DomainWithServer,
43 #[serde(skip)]
44 pub expires_at: u64,
45 access_token: String,
46}
47impl TokenCache for DomainWithToken {
48 fn get_access_token(&self) -> String {
49 self.access_token.clone()
50 }
51
52 fn get_expires_at(&self) -> u64 {
53 self.expires_at
54 }
55}
56
57#[derive(Debug, Deserialize, Clone)]
58pub struct DomainWithServer {
59 #[serde(flatten)]
60 pub domain: Domain,
61 pub domain_server: DomainServer,
62}
63
64#[derive(Debug, Clone)]
65pub struct DiscoveryService {
66 dds_url: String,
67 client: Client,
68 cache: Arc<Mutex<HashMap<String, DomainWithToken>>>,
69 api_client: AuthClient,
70 oidc_access_token: Option<String>,
71}
72
73#[derive(Debug, Deserialize)]
74pub struct ListDomainsResponse {
75 pub domains: Vec<DomainWithServer>,
76}
77
78impl DiscoveryService {
79 pub fn new(api_url: &str, dds_url: &str, client_id: &str) -> Self {
80 let api_client = AuthClient::new(api_url, client_id);
81
82 Self {
83 dds_url: dds_url.to_string(),
84 client: Client::new(),
85 cache: Arc::new(Mutex::new(HashMap::new())),
86 api_client,
87 oidc_access_token: None,
88 }
89 }
90
91 pub async fn list_domains(
93 &self,
94 org: &str,
95 ) -> Result<Vec<DomainWithServer>, Box<dyn std::error::Error + Send + Sync>> {
96 let access_token = self
97 .api_client
98 .get_dds_access_token(self.oidc_access_token.as_deref())
99 .await?;
100 let response = self
101 .client
102 .get(&format!(
103 "{}/api/v1/domains?org={}&with=domain_server",
104 self.dds_url, org
105 ))
106 .bearer_auth(access_token)
107 .header("Content-Type", "application/json")
108 .header("posemesh-client-id", self.api_client.client_id.clone())
109 .send()
110 .await?;
111
112 if response.status().is_success() {
113 let domain_servers: ListDomainsResponse = response.json().await?;
114 Ok(domain_servers.domains)
115 } else {
116 Err(format!("Failed to list domains. Status: {}", response.status()).into())
117 }
118 }
119
120 pub async fn sign_in_with_auki_account(
121 &mut self,
122 email: &str,
123 password: &str,
124 remember_password: bool,
125 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
126 self.cache.lock().await.clear();
127 self.oidc_access_token = None;
128 let _ = self.api_client.user_login(email, password).await?;
129 if remember_password {
130 let mut api_client = self.api_client.clone();
131 let email = email.to_string();
132 let password = password.to_string();
133 spawn(async move {
134 loop {
135 let expires_at = api_client
136 .get_expires_at()
137 .await
138 .inspect_err(|e| tracing::error!("Failed to get expires at: {}", e));
139 if let Ok(expires_at) = expires_at {
140 let expiration = {
141 let now = now_unix_secs();
142 let duration = expires_at - now;
143 if duration > 600 {
144 Some(Duration::from_secs(duration))
145 } else {
146 None
147 }
148 };
149
150 if let Some(expiration) = expiration {
151 tracing::info!("Refreshing token in {} seconds", expiration.as_secs());
152 sleep(expiration).await;
153 }
154
155 let _ = api_client
156 .user_login(&email, &password)
157 .await
158 .inspect_err(|e| tracing::error!("Failed to login: {}", e));
159 }
160 }
161 });
162 }
163 Ok(())
164 }
165
166 pub async fn sign_in_as_auki_app(
167 &mut self,
168 app_key: &str,
169 app_secret: &str,
170 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
171 self.cache.lock().await.clear();
172 self.oidc_access_token = None;
173 let _ = self
174 .api_client
175 .sign_in_with_app_credentials(app_key, app_secret)
176 .await?;
177 Ok(())
178 }
179
180 pub fn with_oidc_access_token(&self, oidc_access_token: &str) -> Self {
181 if let Some(cached_oidc_access_token) = self.oidc_access_token.as_deref() {
182 if cached_oidc_access_token == oidc_access_token {
183 return self.clone();
184 }
185 }
186 Self {
187 dds_url: self.dds_url.clone(),
188 client: self.client.clone(),
189 cache: Arc::new(Mutex::new(HashMap::new())),
190 api_client: AuthClient::new(&self.api_client.api_url, &self.api_client.client_id),
191 oidc_access_token: Some(oidc_access_token.to_string()),
192 }
193 }
194
195 pub async fn auth_domain(
196 &self,
197 domain_id: &str,
198 ) -> Result<DomainWithToken, Box<dyn std::error::Error + Send + Sync>> {
199 let access_token = self
200 .api_client
201 .get_dds_access_token(self.oidc_access_token.as_deref())
202 .await?;
203 let cache = if let Some(cached_domain) = self.cache.lock().await.get(domain_id) {
205 cached_domain.clone()
206 } else {
207 DomainWithToken {
208 domain: DomainWithServer {
209 domain: Domain {
210 id: domain_id.to_string(),
211 name: "".to_string(),
212 organization_id: "".to_string(),
213 domain_server_id: "".to_string(),
214 redirect_url: None,
215 },
216 domain_server: DomainServer {
217 id: "".to_string(),
218 organization_id: "".to_string(),
219 name: "".to_string(),
220 url: "".to_string(),
221 },
222 },
223 expires_at: 0,
224 access_token: "".to_string(),
225 }
226 };
227
228 let cached = get_cached_or_fresh_token(&cache, || {
229 let client = self.client.clone();
230 let dds_url = self.dds_url.clone();
231 let client_id = self.api_client.client_id.clone();
232 async move {
233 let response = client
234 .post(&format!("{}/api/v1/domains/{}/auth", dds_url, domain_id))
235 .bearer_auth(access_token)
236 .header("Content-Type", "application/json")
237 .header("posemesh-client-id", client_id)
238 .send()
239 .await?;
240
241 if response.status().is_success() {
242 let mut domain_with_token: DomainWithToken = response.json().await?;
243 domain_with_token.expires_at =
244 parse_jwt(&domain_with_token.get_access_token())?.exp;
245 Ok(domain_with_token)
246 } else {
247 let status = response.status();
248 let text = response
249 .text()
250 .await
251 .unwrap_or_else(|_| "Unknown error".to_string());
252 Err(format!("Failed to auth domain. Status: {} - {}", status, text).into())
253 }
254 }
255 })
256 .await?;
257
258 let mut cache = self.cache.lock().await;
260 cache.insert(domain_id.to_string(), cached.clone());
261 Ok(cached)
262 }
263}