1use crate::strongbox;
2use crate::volga;
3use crate::{Error, Result};
4use bytes::Bytes;
5use serde_json::json;
6
7#[derive(Clone, serde::Deserialize)]
8#[serde(rename_all = "kebab-case")]
9pub struct X509SVID {
11 pub cert: String,
12 pub private_key: String,
13}
14
/// Login/refresh response from the server, deserialized from kebab-case JSON keys.
///
/// Field order is kept as-is: serde's derived `Deserialize` also accepts a
/// positional sequence, so reordering fields would change that contract.
#[derive(Clone, serde::Deserialize)]
#[serde(rename_all = "kebab-case")]
struct LoginToken {
    // Bearer token sent in the `Authorization` header of every API request.
    token: String,
    // Token lifetime in seconds; `renew_at` uses it to schedule the refresh.
    expires_in: i64,
    // Absolute expiry time of `token`.
    expires: chrono::DateTime<chrono::Utc>,
    // When the token was issued (only shown in `Debug` output in this file).
    creation_time: chrono::DateTime<chrono::Utc>,
    // Optional JWT SVID handed out with the token.
    jwt_svid: Option<String>,
    // Optional X.509 SVID handed out with the token.
    x509_svid: Option<X509SVID>,
}
25
26impl LoginToken {
27 fn renew_at(&self) -> chrono::DateTime<chrono::Utc> {
28 self.expires - chrono::Duration::seconds(self.expires_in / 4)
29 }
30}
31
32impl std::fmt::Debug for LoginToken {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 f.debug_struct("LoginToken")
35 .field("expires_in", &self.expires_in)
36 .field("creation_time", &self.creation_time)
37 .field("jwt-svid", &self.jwt_svid.is_some())
38 .field("x509-svid", &self.x509_svid.is_some())
39 .finish_non_exhaustive()
40 }
41}
42
/// Mutable state shared by all clones of a `Client` (kept behind `Arc<Mutex<_>>`).
#[derive(Debug)]
struct ClientState {
    // Current token; replaced by `renew_token_task` when auto-renewal is enabled.
    login_token: LoginToken,
}
47
48#[derive(Clone)]
50#[allow(clippy::struct_excessive_bools)]
51pub struct ClientBuilder {
52 reqwest_ca: Vec<reqwest::Certificate>,
53 tls_ca: tokio_rustls::rustls::RootCertStore,
54 disable_cert_verification: bool,
55 connection_verbose: bool,
56 auto_renew_token: bool,
57 timeout: Option<core::time::Duration>,
58 connect_timeout: Option<core::time::Duration>,
59}
60
61impl ClientBuilder {
62 #[must_use]
64 pub(crate) fn new() -> Self {
65 let tls_ca = webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect();
66 Self {
67 reqwest_ca: Vec::new(),
68 tls_ca,
69 disable_cert_verification: false,
70 connection_verbose: false,
71 auto_renew_token: true,
72 timeout: None,
73 connect_timeout: None,
74 }
75 }
76
77 #[must_use]
79 pub fn timeout(self, timeout: core::time::Duration) -> Self {
80 Self {
81 timeout: Some(timeout),
82 ..self
83 }
84 }
85
86 #[must_use]
88 pub fn connection_timeout(self, timeout: core::time::Duration) -> Self {
89 Self {
90 connect_timeout: Some(timeout),
91 ..self
92 }
93 }
94
95 pub fn add_root_certificate(mut self, cert: &[u8]) -> Result<Self> {
97 use std::iter;
98 let r_ca = reqwest::Certificate::from_pem(cert)?;
99 let mut ca_reader = std::io::BufReader::new(cert);
100 for item in iter::from_fn(|| rustls_pemfile::read_one(&mut ca_reader).transpose()) {
101 if let rustls_pemfile::Item::X509Certificate(cert) = item? {
102 self.tls_ca.add(cert)?;
103 }
104 }
105 self.reqwest_ca.push(r_ca);
106 Ok(self)
107 }
108
109 #[must_use]
111 pub fn danger_disable_cert_verification(self) -> Self {
112 Self {
113 disable_cert_verification: true,
114 ..self
115 }
116 }
117
118 #[must_use]
121 pub fn enable_verbose_connection(self) -> Self {
122 Self {
123 connection_verbose: true,
124 ..self
125 }
126 }
127
128 #[must_use]
130 pub fn disable_token_auto_renewal(self) -> Self {
131 Self {
132 auto_renew_token: false,
133 ..self
134 }
135 }
136
137 #[deprecated]
141 pub async fn application_login(&self, host: &str, approle_id: Option<&str>) -> Result<Client> {
142 let secret_id = std::env::var("APPROLE_SECRET_ID")
143 .map_err(|_| Error::LoginFailureMissingEnv(String::from("APPROLE_SECRET_ID")))?;
144
145 let role_id = approle_id.unwrap_or(&secret_id);
147
148 let base_url = url::Url::parse(host)?;
149 let url = base_url.join("v1/approle-login")?;
150 let data = json!({
151 "role-id": role_id,
152 "secret-id": secret_id,
153 });
154 Client::do_login(self, base_url, url, data).await
155 }
156
157 pub async fn approle_login(
159 &self,
160 host: &str,
161 secret_id: &str,
162 role_id: Option<&str>,
163 ) -> Result<Client> {
164 let role_id = role_id.unwrap_or(secret_id);
166
167 let base_url = url::Url::parse(host)?;
168 let url = base_url.join("v1/approle-login")?;
169 let data = json!({
170 "role-id": role_id,
171 "secret-id": secret_id,
172 });
173 Client::do_login(self, base_url, url, data).await
174 }
175
176 #[tracing::instrument(skip(self, password))]
179 pub async fn login(&self, host: &str, username: &str, password: &str) -> Result<Client> {
180 let base_url = url::Url::parse(host)?;
181 let url = base_url.join("v1/login")?;
182
183 let data = json!({
185 "username":username,
186 "password":password
187 });
188 Client::do_login(self, base_url, url, data).await
189 }
190
191 #[tracing::instrument(skip(self, token))]
193 pub fn token_login(&self, host: &str, token: &str) -> Result<Client> {
194 let base_url = url::Url::parse(host)?;
195 Client::new_from_token(self, base_url, token)
196 }
197
198 pub async fn jwt_login(
199 &self,
200 host: &str,
201 tenant_name: &str,
202 jwt_auth: &str,
203 role: &str,
204 jwt: &str,
205 ) -> Result<Client> {
206 let base_url = url::Url::parse(host)?;
207 let url = base_url.join("v1/jwt-login")?;
208
209 let data = json!({
210 "tenant": tenant_name,
211 "jwt-auth": jwt_auth,
212 "role": role,
213 "jwt": jwt
214 });
215 Client::do_login(self, base_url, url, data).await
216 }
217}
218
219impl Default for ClientBuilder {
220 fn default() -> Self {
221 Self::new()
222 }
223}
224
225#[derive(Clone)]
228pub struct Client {
229 pub(crate) base_url: url::Url,
230 pub(crate) websocket_url: url::Url,
231 state: std::sync::Arc<tokio::sync::Mutex<ClientState>>,
232 #[allow(clippy::struct_field_names)]
233 http_client: reqwest::Client,
234 tls_ca: tokio_rustls::rustls::RootCertStore,
235 disable_cert_verification: bool,
236}
237
238impl std::fmt::Debug for Client {
239 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240 f.debug_struct("Client")
241 .field("base_url", &self.base_url)
242 .field("websocket_url", &self.websocket_url)
243 .field("state", &self.state)
244 .field("http_client", &self.http_client)
245 .field("disable_cert_verification", &self.disable_cert_verification)
246 .finish_non_exhaustive()
247 }
248}
249
250impl Client {
251 #[must_use]
253 pub fn builder() -> ClientBuilder {
254 ClientBuilder::new()
255 }
256
257 pub async fn token_is_expired(&self) -> bool {
259 let state = self.state.lock().await;
260 chrono::Local::now() > state.login_token.expires
261 }
262
263 pub async fn token_expires(&self) -> chrono::DateTime<chrono::Utc> {
264 let state = self.state.lock().await;
265 state.login_token.expires.into()
266 }
267
268 pub async fn jwt_svid(&self) -> Option<String> {
270 let state = self.state.lock().await;
271 state.login_token.jwt_svid.clone()
272 }
273
274 pub async fn x509_svid(&self) -> Option<X509SVID> {
275 let state = self.state.lock().await;
276 state.login_token.x509_svid.clone()
277 }
278
279 async fn do_login(
280 builder: &ClientBuilder,
281 base_url: url::Url,
282 url: url::Url,
283 payload: serde_json::Value,
284 ) -> Result<Self> {
285 let json = serde_json::to_string(&payload)?;
286 let client = Self::reqwest_client(builder)?;
287 let result = client
288 .post(url)
289 .header("content-type", "application/json")
290 .body(json)
291 .send()
292 .await?;
293
294 if result.status().is_success() {
295 let login_token = result.json().await?;
296
297 Self::new(builder, client, base_url, login_token)
298 } else {
299 let text = result.text().await?;
300 tracing::debug!("login returned {}", text);
301 Err(Error::LoginFailure(text))
302 }
303 }
304
305 fn reqwest_client(builder: &ClientBuilder) -> Result<reqwest::Client> {
306 let reqwest_client_builder = reqwest::Client::builder().use_rustls_tls();
307
308 let reqwest_client_builder = builder
310 .reqwest_ca
311 .iter()
312 .fold(reqwest_client_builder, |reqwest_client_builder, ca| {
313 reqwest_client_builder.add_root_certificate(ca.clone())
314 });
315
316 tracing::debug!("Added {} CA certs", builder.reqwest_ca.len());
317
318 let reqwest_client_builder =
319 reqwest_client_builder.danger_accept_invalid_certs(builder.disable_cert_verification);
320
321 let reqwest_client_builder =
322 reqwest_client_builder.connection_verbose(builder.connection_verbose);
323
324 let reqwest_client_builder = if let Some(duration) = builder.timeout {
325 reqwest_client_builder.timeout(duration)
326 } else {
327 reqwest_client_builder
328 };
329
330 let reqwest_client_builder = if let Some(duration) = builder.connect_timeout {
331 reqwest_client_builder.connect_timeout(duration)
332 } else {
333 reqwest_client_builder
334 };
335
336 let client = reqwest_client_builder.build()?;
337 Ok(client)
338 }
339
340 fn new_from_token(builder: &ClientBuilder, base_url: url::Url, token: &str) -> Result<Self> {
341 let client = Self::reqwest_client(builder)?;
342 let creation_time = chrono::Local::now().into();
343 let expires = creation_time + chrono::Duration::seconds(1);
344
345 let login_token = LoginToken {
346 token: token.to_string(),
347 expires_in: 1,
348 creation_time,
349 expires,
350 jwt_svid: None,
351 x509_svid: None,
352 };
353
354 Self::new(builder, client, base_url, login_token)
355 }
356
357 fn new(
358 builder: &ClientBuilder,
359 client: reqwest::Client,
360 base_url: url::Url,
361 login_token: LoginToken,
362 ) -> Result<Self> {
363 let websocket_url = url::Url::parse(&format!("wss://{}/v1/ws/", base_url.host_port()?))?;
364
365 let renew_at = login_token.renew_at();
366
367 let state = std::sync::Arc::new(tokio::sync::Mutex::new(ClientState { login_token }));
368
369 let weak_state = std::sync::Arc::downgrade(&state);
370 let refresh_url = base_url.join("/v1/state/strongbox/token/refresh")?;
371
372 if builder.auto_renew_token {
373 tokio::spawn(renew_token_task(
374 weak_state,
375 renew_at,
376 client.clone(),
377 refresh_url,
378 ));
379 }
380
381 Ok(Self {
382 http_client: client,
383 tls_ca: builder.tls_ca.clone(),
384 disable_cert_verification: builder.disable_cert_verification,
385 base_url,
386 websocket_url,
387 state,
388 })
389 }
390
391 pub async fn bearer_token(&self) -> String {
393 let state = self.state.lock().await;
394 state.login_token.token.clone()
395 }
396
397 pub async fn get_json<T: serde::de::DeserializeOwned>(
399 &self,
400 path: &str,
401 query_params: Option<&[(&str, &str)]>,
402 ) -> Result<T> {
403 let url = self.base_url.join(path)?;
404
405 let token = self.bearer_token().await;
406
407 let mut builder = self
408 .http_client
409 .get(url)
410 .bearer_auth(&token)
411 .header("Accept", "application/json");
412 if let Some(qp) = query_params {
413 builder = builder.query(qp);
414 }
415
416 let result = builder.send().await?;
417
418 if result.status().is_success() {
419 let res = result.json().await?;
420 Ok(res)
421 } else {
422 let status = result.status();
423 let error_payload = result
424 .text()
425 .await
426 .unwrap_or_else(|_| "No error payload".to_string());
427 Err(Error::WebServer(
428 status.as_u16(),
429 status.to_string(),
430 error_payload,
431 ))
432 }
433 }
434
435 pub async fn get_bytes(
437 &self,
438 path: &str,
439 query_params: Option<&[(&str, &str)]>,
440 ) -> Result<Bytes> {
441 let url = self.base_url.join(path)?;
442
443 let token = self.bearer_token().await;
444
445 let mut builder = self.http_client.get(url).bearer_auth(&token);
446
447 if let Some(qp) = query_params {
448 builder = builder.query(qp);
449 }
450
451 let result = builder.send().await?;
452
453 if result.status().is_success() {
454 let res = result.bytes().await?;
455 Ok(res)
456 } else {
457 let status = result.status();
458 let error_payload = result
459 .text()
460 .await
461 .unwrap_or_else(|_| "No error payload".to_string());
462 Err(Error::WebServer(
463 status.as_u16(),
464 status.to_string(),
465 error_payload,
466 ))
467 }
468 }
469
470 pub async fn post_json(
474 &self,
475 path: &str,
476 data: &serde_json::Value,
477 ) -> Result<serde_json::Value> {
478 let url = self.base_url.join(path)?;
479 let token = self.bearer_token().await;
480
481 tracing::debug!("POST {} {:?}", url, data);
482
483 let result = self
484 .http_client
485 .post(url)
486 .json(&data)
487 .bearer_auth(&token)
488 .send()
489 .await?;
490
491 if result.status().is_success() {
492 let resp = result.bytes().await?;
493
494 let mut responses: Vec<serde_json::Value> = Vec::new();
495 let decoder = serde_json::Deserializer::from_slice(&resp);
496
497 for v in decoder.into_iter() {
498 responses.push(v?);
499 }
500
501 match responses.len() {
502 0 => Ok(serde_json::Value::Object(serde_json::Map::default())),
503 1 => Ok(responses.into_iter().next().unwrap()),
504 _ => {
505 Ok(serde_json::Value::Array(responses))
507 }
508 }
509 } else {
510 tracing::error!("POST call failed");
511 let status = result.status();
512 let resp = result.json().await;
513 match resp {
514 Ok(resp) => Err(Error::REST(resp)),
515 Err(_) => Err(Error::WebServer(
516 status.as_u16(),
517 status.to_string(),
518 "Failed to get JSON responses".to_string(),
519 )),
520 }
521 }
522 }
523
524 pub async fn put<T: Into<reqwest::Body> + std::fmt::Debug>(
526 &self,
527 path: &str,
528 content_type: &str,
529 data: T,
530 ) -> Result<()> {
531 let url = self.base_url.join(path)?;
532 let token = self.state.lock().await.login_token.token.clone();
533
534 tracing::debug!("PUT {} {:?}", url, data);
535
536 let result = self
537 .http_client
538 .put(url)
539 .header(reqwest::header::CONTENT_TYPE, content_type)
540 .body(data)
541 .bearer_auth(&token)
542 .send()
543 .await?;
544
545 if result.status().is_success() {
547 Ok(())
548 } else {
549 tracing::error!("PUT call failed");
550 let status = result.status();
551 let resp = result.json().await;
552 match resp {
553 Ok(resp) => Err(Error::REST(resp)),
554 Err(_) => Err(Error::WebServer(
555 status.as_u16(),
556 status.to_string(),
557 "Failed to get JSON reply".to_string(),
558 )),
559 }
560 }
561 }
562
563 pub async fn put_json(
565 &self,
566 path: &str,
567 data: &serde_json::Value,
568 ) -> Result<serde_json::Value> {
569 let url = self.base_url.join(path)?;
570 let token = self.state.lock().await.login_token.token.clone();
571
572 tracing::debug!("PUT {} {:?}", url, data);
573
574 let result = self
575 .http_client
576 .put(url)
577 .json(&data)
578 .bearer_auth(&token)
579 .send()
580 .await?;
581
582 #[allow(clippy::redundant_closure_for_method_calls)]
583 if result.status().is_success() {
584 use std::error::Error;
585 let resp = result.json().await.or_else(|e| match e {
586 e if e.is_decode() => {
587 match e
588 .source()
589 .and_then(|e| e.downcast_ref::<serde_json::Error>())
590 {
591 Some(e) if e.is_eof() => {
592 Ok(serde_json::Value::Object(serde_json::Map::new()))
593 }
594 _ => Err(e),
595 }
596 }
597 e => Err(e),
598 })?;
599 Ok(resp)
600 } else {
601 tracing::error!("PUT call failed");
602 let status = result.status();
603 let resp = result.json().await;
604 match resp {
605 Ok(resp) => Err(Error::REST(resp)),
606 Err(_) => Err(Error::WebServer(
607 status.as_u16(),
608 status.to_string(),
609 "Failed to get JSON reply".to_string(),
610 )),
611 }
612 }
613 }
614
615 pub async fn volga_open_producer(
617 &self,
618 producer_name: &str,
619 topic: &str,
620 on_no_exists: volga::OnNoExists,
621 ) -> Result<volga::producer::Producer> {
622 crate::volga::producer::Builder::new(self, producer_name, topic, on_no_exists)?
623 .connect()
624 .await
625 }
626
627 pub async fn volga_open_child_site_producer(
629 &self,
630 producer_name: &str,
631 topic: &str,
632 site: &str,
633 on_no_exists: volga::OnNoExists,
634 ) -> Result<volga::producer::Producer> {
635 crate::volga::producer::Builder::new_child(self, producer_name, topic, site, on_no_exists)?
636 .connect()
637 .await
638 }
639
640 pub async fn volga_open_parent_site_producer(
642 &self,
643 producer_name: &str,
644 topic: &str,
645 on_no_exists: volga::OnNoExists,
646 ) -> Result<volga::producer::Producer> {
647 crate::volga::producer::Builder::new_parent(self, producer_name, topic, on_no_exists)?
648 .connect()
649 .await
650 }
651
652 #[tracing::instrument(skip(self))]
654 pub async fn volga_open_consumer(
655 &self,
656 consumer_name: &str,
657 topic: &str,
658 options: crate::volga::consumer::Options<'_>,
659 ) -> Result<volga::consumer::Consumer> {
660 crate::volga::consumer::Builder::new(self, consumer_name, topic)?
661 .set_options(options)
662 .connect()
663 .await
664 }
665
666 pub async fn volga_open_child_site_consumer(
668 &self,
669 consumer_name: &str,
670 topic: &str,
671 site: &str,
672 options: crate::volga::consumer::Options<'_>,
673 ) -> Result<volga::consumer::Consumer> {
674 crate::volga::consumer::Builder::new_child(self, consumer_name, topic, site)?
675 .set_options(options)
676 .connect()
677 .await
678 }
679
680 pub async fn volga_open_parent_site_consumer(
682 &self,
683 consumer_name: &str,
684 topic: &str,
685 options: crate::volga::consumer::Options<'_>,
686 ) -> Result<volga::consumer::Consumer> {
687 crate::volga::consumer::Builder::new_parent(self, consumer_name, topic)?
688 .set_options(options)
689 .connect()
690 .await
691 }
692
693 #[tracing::instrument(skip(self))]
694 pub(crate) async fn open_tls_stream(
695 &self,
696 ) -> Result<tokio_rustls::client::TlsStream<tokio::net::TcpStream>> {
697 let mut client_config = tokio_rustls::rustls::ClientConfig::builder()
698 .with_root_certificates(self.tls_ca.clone())
699 .with_no_client_auth();
700
701 if self.disable_cert_verification {
702 let mut danger = client_config.dangerous();
703
704 danger.set_certificate_verifier(std::sync::Arc::new(CertificateVerifier));
705 }
706
707 let client_config = std::sync::Arc::new(client_config);
708
709 let connector: tokio_rustls::TlsConnector = client_config.into();
710 let addrs = self.websocket_url.socket_addrs(|| None)?;
711 let stream = tokio::net::TcpStream::connect(&*addrs).await?;
712
713 let server_name = tokio_rustls::rustls::pki_types::ServerName::try_from(
714 self.websocket_url.host_str().unwrap().to_owned(),
715 )?;
716 let stream = connector.connect(server_name, stream).await?;
717 Ok(stream)
718 }
719
720 pub async fn volga_query_topic(
722 &self,
723 query: volga::query_topic::Query,
724 ) -> Result<volga::query_topic::QueryStream> {
725 volga::query_topic::QueryStream::new(self, query).await
726 }
727
728 pub async fn open_strongbox_vault(&self, vault: &str) -> Result<strongbox::Vault> {
730 strongbox::Vault::open(self, vault).await
731 }
732
733 pub async fn connect(
735 &self,
736 application: &str,
737 service_instance: &str,
738 site: &str,
739 protocol: crate::app_connect::Protocol,
740 port: u16,
741 ip_address: Option<std::net::IpAddr>,
742 ) -> Result<crate::app_connect::Connection> {
743 crate::app_connect::connect(
744 self,
745 application,
746 service_instance,
747 site,
748 protocol,
749 port,
750 ip_address,
751 )
752 .await
753 }
754}
755
756#[derive(Debug)]
757struct CertificateVerifier;
758
759impl tokio_rustls::rustls::client::danger::ServerCertVerifier for CertificateVerifier {
760 fn verify_server_cert(
761 &self,
762 _end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
763 _intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>],
764 _server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>,
765 _ocsp_response: &[u8],
766 _now: tokio_rustls::rustls::pki_types::UnixTime,
767 ) -> std::result::Result<
768 tokio_rustls::rustls::client::danger::ServerCertVerified,
769 tokio_rustls::rustls::Error,
770 > {
771 Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
772 }
773
774 fn verify_tls12_signature(
775 &self,
776 _message: &[u8],
777 _cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
778 _dss: &tokio_rustls::rustls::DigitallySignedStruct,
779 ) -> std::result::Result<
780 tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
781 tokio_rustls::rustls::Error,
782 > {
783 Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
784 }
785
786 fn verify_tls13_signature(
787 &self,
788 _message: &[u8],
789 _cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
790 _dss: &tokio_rustls::rustls::DigitallySignedStruct,
791 ) -> std::result::Result<
792 tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
793 tokio_rustls::rustls::Error,
794 > {
795 Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
796 }
797
798 fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
799 vec![
800 tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA1,
801 tokio_rustls::rustls::SignatureScheme::ECDSA_SHA1_Legacy,
802 tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA256,
803 tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
804 tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA384,
805 tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
806 tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA512,
807 tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
808 tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA256,
809 tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA384,
810 tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA512,
811 tokio_rustls::rustls::SignatureScheme::ED25519,
812 tokio_rustls::rustls::SignatureScheme::ED448,
813 ]
814 }
815}
816
817pub(crate) trait URLExt {
818 fn host_port(&self) -> std::result::Result<String, url::ParseError>;
819}
820
821impl URLExt for url::Url {
822 fn host_port(&self) -> std::result::Result<String, url::ParseError> {
823 let host = self.host_str().ok_or(url::ParseError::EmptyHost)?;
824 Ok(match (host, self.port()) {
825 (host, Some(port)) => format!("{host}:{port}"),
826 (host, _) => host.to_string(),
827 })
828 }
829}
830
831#[tracing::instrument(skip(next_renew_at, weak_state, client, refresh_url))]
832async fn renew_token_task(
833 weak_state: std::sync::Weak<tokio::sync::Mutex<ClientState>>,
834 mut next_renew_at: chrono::DateTime<chrono::Utc>,
835 client: reqwest::Client,
836 refresh_url: url::Url,
837) {
838 loop {
839 let now: chrono::DateTime<_> = chrono::Local::now().into();
840
841 let sleep_time = next_renew_at - now;
842
843 tracing::info!("renew token in {sleep_time} ({})", next_renew_at);
844
845 tokio::time::sleep(
846 sleep_time
847 .to_std()
848 .unwrap_or_else(|_| std::time::Duration::from_secs(0)),
849 )
850 .await;
851
852 if let Some(state) = weak_state.upgrade() {
853 let mut state = state.lock().await;
854
855 if now > state.login_token.expires {
856 tracing::error!("Token is expired and we can't renew. Giving up");
857 break;
858 }
859 let response = client
860 .post(refresh_url.clone())
861 .bearer_auth(&state.login_token.token)
862 .send()
863 .await;
864
865 let response = match response {
866 Ok(r) => r,
867 Err(e) => {
868 tracing::error!("Failed to renew token: {e}");
869 let now = chrono::Utc::now();
870 next_renew_at = now + chrono::Duration::seconds(1);
871 continue;
872 }
873 };
874
875 let text = response.text().await.unwrap();
876 let new_login_token = serde_json::from_str::<LoginToken>(&text);
877
878 match new_login_token {
879 Ok(new_login_token) => {
880 next_renew_at = new_login_token.renew_at();
881 state.login_token = new_login_token;
882 tracing::info!(
883 "Successfully renewed token, expires: {}",
884 state.login_token.expires
885 );
886 }
887 Err(e) => {
888 tracing::error!("Failed to parse or get token: {e}");
889 let now = chrono::Utc::now();
891 next_renew_at = now + chrono::Duration::seconds(1);
892 }
893 }
894 } else {
895 tracing::debug!("renew_token: State lost");
896 break;
898 }
899 }
900}
901
#[cfg(test)]
mod test {
    use super::URLExt;

    /// `host_port` keeps explicit ports, omits default ones and keeps IPv6 brackets.
    #[test]
    fn url_ext() {
        let url = url::Url::parse("https://1.2.3.4:5000/a/b/c").unwrap();
        assert_eq!(&url.host_port().unwrap(), "1.2.3.4:5000");

        let url = url::Url::parse("https://1.2.3.4/a/b/c").unwrap();
        assert_eq!(&url.host_port().unwrap(), "1.2.3.4");

        let url = url::Url::parse("https://www.avassa.com/a/b/c").unwrap();
        assert_eq!(&url.host_port().unwrap(), "www.avassa.com");

        let url = url::Url::parse("https://www.avassa.com:1234/a/b/c").unwrap();
        assert_eq!(&url.host_port().unwrap(), "www.avassa.com:1234");

        // The scheme's default port is normalized away by `Url`, so it is not reported.
        let url = url::Url::parse("https://www.avassa.com:443/a/b/c").unwrap();
        assert_eq!(&url.host_port().unwrap(), "www.avassa.com");

        // IPv6 literals keep their brackets so the result can be embedded in a URL.
        let url = url::Url::parse("https://[::1]:5000/a/b/c").unwrap();
        assert_eq!(&url.host_port().unwrap(), "[::1]:5000");
    }

    /// URLs without a host are rejected with `EmptyHost`.
    #[test]
    fn url_ext_without_host() {
        let data_url = url::Url::parse("data:text/plain,hello").unwrap();
        assert_eq!(data_url.host_port(), Err(url::ParseError::EmptyHost));
    }
}