1use eyre::{bail, eyre, Result};
8use reqwest::Client;
9use serde::de::DeserializeOwned;
10use serde::Serialize;
11use url::Url;
12
13use crate::storage::JwtToken;
15use crate::traits::{ClientAuthenticator, ClientStorage};
16
/// Whether the target API expects requests to carry credentials.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AuthMode {
    /// Requests must be authenticated.
    Required,
    /// Requests may be sent without credentials.
    None,
}
25
/// HTTP method used by the shared JSON request path.
#[derive(Clone, Copy, Debug)]
enum RequestType {
    /// `GET` request.
    Get,
    /// `POST` request carrying a JSON body.
    Post,
    /// `DELETE` request.
    Delete,
}
33
/// Reason a refresh-token exchange did not yield new tokens.
#[derive(Debug)]
enum RefreshError {
    /// No node name is configured, no tokens could be loaded, or the stored
    /// tokens carry no refresh token.
    NoRefreshToken,
    /// A refresh token was available but exchanging it with the server failed.
    RefreshFailed,
}
39
40#[derive(Clone, Debug)]
42pub struct ConnectionInfo<A, S>
43where
44 A: ClientAuthenticator + Clone + Send + Sync,
45 S: ClientStorage + Clone + Send + Sync,
46{
47 pub api_url: Url,
48 pub client: Client,
49 pub node_name: Option<String>,
50 pub authenticator: A,
51 pub client_storage: S,
52}
53
54impl<A, S> ConnectionInfo<A, S>
55where
56 A: ClientAuthenticator + Clone + Send + Sync,
57 S: ClientStorage + Clone + Send + Sync,
58{
59 pub fn new(
60 api_url: Url,
61 node_name: Option<String>,
62 authenticator: A,
63 client_storage: S,
64 ) -> Self {
65 Self {
66 api_url,
67 client: Client::new(),
68 node_name,
69 authenticator,
70 client_storage,
71 }
72 }
73
74 pub async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
75 self.request(RequestType::Get, path, None::<()>).await
76 }
77
78 fn path_requires_auth(&self, path: &str) -> bool {
80 !path.starts_with("admin-api/health")
82 }
83
84 pub async fn post<I, O>(&self, path: &str, body: I) -> Result<O>
85 where
86 I: Serialize,
87 O: DeserializeOwned,
88 {
89 self.request(RequestType::Post, path, Some(body)).await
90 }
91
92 pub async fn post_no_body<O: DeserializeOwned>(&self, path: &str) -> Result<O> {
93 self.request(RequestType::Post, path, None::<()>).await
94 }
95
96 pub async fn delete<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
97 self.request(RequestType::Delete, path, None::<()>).await
98 }
99
100 pub async fn put_binary(&self, path: &str, data: Vec<u8>) -> Result<reqwest::Response> {
101 let mut url = self.api_url.clone();
102
103 if let Some((path_part, query_part)) = path.split_once('?') {
104 url.set_path(path_part);
105 url.set_query(Some(query_part));
106 } else {
107 url.set_path(path);
108 }
109
110 let requires_auth = self.path_requires_auth(path);
111
112 let auth_header = if requires_auth && self.node_name.is_some() {
113 if let Ok(Some(tokens)) = self
114 .client_storage
115 .load_tokens(&self.node_name.as_ref().unwrap())
116 .await
117 {
118 Some(format!("Bearer {}", tokens.access_token))
119 } else {
120 match self.authenticator.authenticate(&self.api_url).await {
121 Ok(new_tokens) => {
122 self.client_storage
123 .update_tokens(&self.node_name.as_ref().unwrap(), &new_tokens)
124 .await?;
125 Some(format!("Bearer {}", new_tokens.access_token))
126 }
127 Err(auth_err) => {
128 bail!("Authentication failed: {}", auth_err);
129 }
130 }
131 }
132 } else {
133 None
134 };
135
136 let response = self
137 .execute_request_with_auth_retry(|| {
138 let mut builder = self.client.put(url.clone()).body(data.clone());
139
140 if let Some(ref auth_header) = auth_header {
141 builder = builder.header("Authorization", auth_header);
142 }
143
144 builder.send()
145 })
146 .await?;
147
148 Ok(response)
149 }
150
151 pub async fn get_binary(&self, path: &str) -> Result<Vec<u8>> {
152 let mut url = self.api_url.clone();
153
154 if let Some((path_part, query_part)) = path.split_once('?') {
155 url.set_path(path_part);
156 url.set_query(Some(query_part));
157 } else {
158 url.set_path(path);
159 }
160
161 let requires_auth = self.path_requires_auth(path);
162
163 let auth_header = if requires_auth && self.node_name.is_some() {
164 if let Ok(Some(tokens)) = self
165 .client_storage
166 .load_tokens(&self.node_name.as_ref().unwrap())
167 .await
168 {
169 Some(format!("Bearer {}", tokens.access_token))
170 } else {
171 match self.authenticator.authenticate(&self.api_url).await {
172 Ok(new_tokens) => {
173 self.client_storage
174 .update_tokens(&self.node_name.as_ref().unwrap(), &new_tokens)
175 .await?;
176 Some(format!("Bearer {}", new_tokens.access_token))
177 }
178 Err(auth_err) => {
179 bail!("Authentication failed: {}", auth_err);
180 }
181 }
182 }
183 } else {
184 None
185 };
186
187 let response = self
188 .execute_request_with_auth_retry(|| {
189 let mut builder = self.client.get(url.clone());
190
191 if let Some(ref auth_header) = auth_header {
192 builder = builder.header("Authorization", auth_header);
193 }
194
195 builder.send()
196 })
197 .await?;
198
199 response
200 .bytes()
201 .await
202 .map(|b| b.to_vec())
203 .map_err(Into::into)
204 }
205
206 pub async fn head(&self, path: &str) -> Result<reqwest::header::HeaderMap> {
207 let mut url = self.api_url.clone();
208 url.set_path(path);
209
210 let requires_auth = self.path_requires_auth(path);
212
213 let auth_header = if requires_auth && self.node_name.is_some() {
215 if let Ok(Some(tokens)) = self
216 .client_storage
217 .load_tokens(&self.node_name.as_ref().unwrap())
218 .await
219 {
220 Some(format!("Bearer {}", tokens.access_token))
221 } else {
222 match self.authenticator.authenticate(&self.api_url).await {
224 Ok(new_tokens) => {
225 self.client_storage
227 .update_tokens(&self.node_name.as_ref().unwrap(), &new_tokens)
228 .await?;
229 Some(format!("Bearer {}", new_tokens.access_token))
230 }
231 Err(auth_err) => {
232 bail!("Authentication failed: {}", auth_err);
233 }
234 }
235 }
236 } else {
237 None
238 };
239
240 let response = self
241 .execute_request_with_auth_retry(|| {
242 let mut builder = self.client.head(url.clone());
243
244 if let Some(ref auth_header) = auth_header {
245 builder = builder.header("Authorization", auth_header);
246 }
247
248 builder.send()
249 })
250 .await?;
251
252 Ok(response.headers().clone())
253 }
254
255 async fn request<I, O>(&self, req_type: RequestType, path: &str, body: Option<I>) -> Result<O>
256 where
257 I: Serialize,
258 O: DeserializeOwned,
259 {
260 let mut url = self.api_url.clone();
261 url.set_path(path);
262
263 let requires_auth = self.path_requires_auth(path);
265
266 let auth_header = if requires_auth && self.node_name.is_some() {
268 if let Ok(Some(tokens)) = self
269 .client_storage
270 .load_tokens(&self.node_name.as_ref().unwrap())
271 .await
272 {
273 Some(format!("Bearer {}", tokens.access_token))
274 } else {
275 match self.authenticator.authenticate(&self.api_url).await {
277 Ok(new_tokens) => {
278 self.client_storage
280 .update_tokens(&self.node_name.as_ref().unwrap(), &new_tokens)
281 .await?;
282 Some(format!("Bearer {}", new_tokens.access_token))
283 }
284 Err(auth_err) => {
285 bail!("Authentication failed: {}", auth_err);
286 }
287 }
288 }
289 } else {
290 None
291 };
292
293 let response = self
294 .execute_request_with_auth_retry(|| {
295 let mut builder = match req_type {
296 RequestType::Get => self.client.get(url.clone()),
297 RequestType::Post => self.client.post(url.clone()).json(&body),
298 RequestType::Delete => self.client.delete(url.clone()),
299 };
300
301 if let Some(ref auth_header) = auth_header {
302 builder = builder.header("Authorization", auth_header);
303 }
304
305 builder.send()
306 })
307 .await?;
308
309 response.json::<O>().await.map_err(Into::into)
310 }
311
312 async fn execute_request_with_auth_retry<F, Fut>(
313 &self,
314 request_builder: F,
315 ) -> Result<reqwest::Response>
316 where
317 F: Fn() -> Fut,
318 Fut: std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
319 {
320 let mut retry_count = 0;
321 const MAX_RETRIES: u32 = 2;
322
323 loop {
324 let response = request_builder().await?;
325
326 if response.status() == 401 && retry_count < MAX_RETRIES {
327 retry_count += 1;
328
329 match self.refresh_token().await {
331 Ok(new_token) => {
332 if let Some(ref node_name) = self.node_name {
334 self.client_storage
335 .update_tokens(node_name, &new_token)
336 .await?;
337 }
338 continue;
339 }
340 Err(RefreshError::RefreshFailed) => {
341 match self.authenticator.authenticate(&self.api_url).await {
343 Ok(new_tokens) => {
344 if let Some(ref node_name) = self.node_name {
346 self.client_storage
347 .update_tokens(node_name, &new_tokens)
348 .await?;
349 }
350 continue;
351 }
352 Err(auth_err) => {
353 bail!("Authentication failed: {}", auth_err);
354 }
355 }
356 }
357 Err(RefreshError::NoRefreshToken) => {
358 bail!("No refresh token available for authentication");
360 }
361 }
362 }
363
364 if response.status() == 403 {
365 bail!("Access denied. Your authentication may not have sufficient permissions.");
366 }
367
368 if !response.status().is_success() {
369 bail!("Request failed with status: {}", response.status());
370 }
371
372 return Ok(response);
373 }
374 }
375
376 async fn refresh_token(&self) -> Result<JwtToken, RefreshError> {
377 if let Some(ref node_name) = self.node_name {
378 if let Ok(Some(tokens)) = self.client_storage.load_tokens(node_name).await {
379 let refresh_token = tokens
380 .refresh_token
381 .clone()
382 .ok_or(RefreshError::NoRefreshToken)?;
383
384 match self
385 .try_refresh_token(&tokens.access_token, &refresh_token)
386 .await
387 {
388 Ok(new_token) => {
389 return Ok(new_token);
390 }
391 Err(_) => {
392 return Err(RefreshError::RefreshFailed);
393 }
394 }
395 }
396 }
397
398 Err(RefreshError::NoRefreshToken)
399 }
400
401 async fn try_refresh_token(&self, access_token: &str, refresh_token: &str) -> Result<JwtToken> {
402 let refresh_url = self.api_url.join("/auth/refresh")?;
403
404 #[derive(serde::Serialize)]
405 struct RefreshRequest {
406 access_token: String,
407 refresh_token: String,
408 }
409
410 #[derive(serde::Deserialize, Debug)]
411 struct RefreshResponse {
412 access_token: String,
413 refresh_token: String,
414 }
415
416 #[derive(serde::Deserialize, Debug)]
417 struct WrappedResponse {
418 data: RefreshResponse,
419 }
420
421 let request_body = RefreshRequest {
422 access_token: access_token.to_owned(),
423 refresh_token: refresh_token.to_owned(),
424 };
425
426 let response = self
427 .client
428 .post(refresh_url)
429 .json(&request_body)
430 .send()
431 .await?;
432
433 if !response.status().is_success() {
434 return Err(eyre!(
435 "Token refresh failed with status: {}",
436 response.status()
437 ));
438 }
439
440 let wrapped_response: WrappedResponse = response.json().await?;
441
442 Ok(JwtToken::with_refresh(
443 wrapped_response.data.access_token,
444 wrapped_response.data.refresh_token,
445 ))
446 }
447
448 pub async fn update_tokens(&self, new_tokens: &JwtToken) -> Result<()> {
450 if let Some(node_name) = &self.node_name {
451 self.client_storage
452 .update_tokens(node_name, new_tokens)
453 .await
454 } else {
455 Ok(())
458 }
459 }
460
461 pub async fn detect_auth_mode(&self) -> Result<AuthMode> {
463 if self.api_url.host_str() == Some("localhost")
465 || self.api_url.host_str() == Some("127.0.0.1")
466 {
467 return Ok(AuthMode::None);
468 }
469
470 let health_url = self.api_url.join("admin-api/health")?;
473
474 match self.client.get(health_url).send().await {
475 Ok(response) => {
476 if response.status() == 401 {
477 Ok(AuthMode::Required)
479 } else if response.status().is_success() {
480 Ok(AuthMode::None)
482 } else {
483 Ok(AuthMode::Required)
485 }
486 }
487 Err(_) => {
488 Ok(AuthMode::None)
491 }
492 }
493 }
494}