Skip to main content

avassa_client/
client.rs

1use crate::strongbox;
2use crate::volga;
3use crate::{Error, Result};
4use bytes::Bytes;
5use serde_json::json;
6
7#[derive(Clone, serde::Deserialize)]
8#[serde(rename_all = "kebab-case")]
9/// Cert and private key in PEM format
10pub struct X509SVID {
11    pub cert: String,
12    pub private_key: String,
13}
14
15#[derive(Clone, serde::Deserialize)]
16#[serde(rename_all = "kebab-case")]
17struct LoginToken {
18    token: String,
19    expires_in: i64,
20    expires: chrono::DateTime<chrono::Utc>,
21    creation_time: chrono::DateTime<chrono::Utc>,
22    jwt_svid: Option<String>,
23    x509_svid: Option<X509SVID>,
24}
25
26impl LoginToken {
27    fn renew_at(&self) -> chrono::DateTime<chrono::Utc> {
28        self.expires - chrono::Duration::seconds(self.expires_in / 4)
29    }
30}
31
32impl std::fmt::Debug for LoginToken {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        f.debug_struct("LoginToken")
35            .field("expires_in", &self.expires_in)
36            .field("creation_time", &self.creation_time)
37            .field("jwt-svid", &self.jwt_svid.is_some())
38            .field("x509-svid", &self.x509_svid.is_some())
39            .finish_non_exhaustive()
40    }
41}
42
43#[derive(Debug)]
44struct ClientState {
45    login_token: LoginToken,
46}
47
48/// Builder for an Avassa [`Client`]
49#[derive(Clone)]
50#[allow(clippy::struct_excessive_bools)]
51pub struct ClientBuilder {
52    reqwest_ca: Vec<reqwest::Certificate>,
53    tls_ca: tokio_rustls::rustls::RootCertStore,
54    disable_cert_verification: bool,
55    connection_verbose: bool,
56    auto_renew_token: bool,
57    timeout: Option<core::time::Duration>,
58    connect_timeout: Option<core::time::Duration>,
59}
60
61impl ClientBuilder {
62    /// Create a new builder instance
63    #[must_use]
64    pub(crate) fn new() -> Self {
65        let tls_ca = webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect();
66        Self {
67            reqwest_ca: Vec::new(),
68            tls_ca,
69            disable_cert_verification: false,
70            connection_verbose: false,
71            auto_renew_token: true,
72            timeout: None,
73            connect_timeout: None,
74        }
75    }
76
77    /// Enables a request timeout
78    #[must_use]
79    pub fn timeout(self, timeout: core::time::Duration) -> Self {
80        Self {
81            timeout: Some(timeout),
82            ..self
83        }
84    }
85
86    /// Set a timeout for only the connect phase of a Client
87    #[must_use]
88    pub fn connection_timeout(self, timeout: core::time::Duration) -> Self {
89        Self {
90            connect_timeout: Some(timeout),
91            ..self
92        }
93    }
94
95    /// Add a root certificate for API certificate verification
96    pub fn add_root_certificate(mut self, cert: &[u8]) -> Result<Self> {
97        use std::iter;
98        let r_ca = reqwest::Certificate::from_pem(cert)?;
99        let mut ca_reader = std::io::BufReader::new(cert);
100        for item in iter::from_fn(|| rustls_pemfile::read_one(&mut ca_reader).transpose()) {
101            if let rustls_pemfile::Item::X509Certificate(cert) = item? {
102                self.tls_ca.add(cert)?;
103            }
104        }
105        self.reqwest_ca.push(r_ca);
106        Ok(self)
107    }
108
109    /// Disable certificate verification
110    #[must_use]
111    pub fn danger_disable_cert_verification(self) -> Self {
112        Self {
113            disable_cert_verification: true,
114            ..self
115        }
116    }
117
118    /// Enabling this option will emit log messages at the TRACE level for read and write operations
119    /// on the https client
120    #[must_use]
121    pub fn enable_verbose_connection(self) -> Self {
122        Self {
123            connection_verbose: true,
124            ..self
125        }
126    }
127
128    /// Disable auto renewal of authentication token
129    #[must_use]
130    pub fn disable_token_auto_renewal(self) -> Self {
131        Self {
132            auto_renew_token: false,
133            ..self
134        }
135    }
136
137    /// Login the application from secret set in the environment
138    /// `approle_id` can optionally be provided
139    /// This assumes the environment variable `APPROLE_SECRET_ID` is set by the system.
140    #[deprecated]
141    pub async fn application_login(&self, host: &str, approle_id: Option<&str>) -> Result<Client> {
142        let secret_id = std::env::var("APPROLE_SECRET_ID")
143            .map_err(|_| Error::LoginFailureMissingEnv(String::from("APPROLE_SECRET_ID")))?;
144
145        // If no app role is provided, we can try to use the secret id as app role.
146        let role_id = approle_id.unwrap_or(&secret_id);
147
148        let base_url = url::Url::parse(host)?;
149        let url = base_url.join("v1/approle-login")?;
150        let data = json!({
151            "role-id": role_id,
152            "secret-id": secret_id,
153        });
154        Client::do_login(self, base_url, url, data).await
155    }
156
157    /// Login using approle
158    pub async fn approle_login(
159        &self,
160        host: &str,
161        secret_id: &str,
162        role_id: Option<&str>,
163    ) -> Result<Client> {
164        // If no role id is provided, use the secret id;
165        let role_id = role_id.unwrap_or(secret_id);
166
167        let base_url = url::Url::parse(host)?;
168        let url = base_url.join("v1/approle-login")?;
169        let data = json!({
170            "role-id": role_id,
171            "secret-id": secret_id,
172        });
173        Client::do_login(self, base_url, url, data).await
174    }
175
176    /// Login to an avassa Control Tower or Edge Enforcer instance. If possible,
177    /// please use the `application_login` as no credentials needs to be distributed.
178    #[tracing::instrument(skip(self, password))]
179    pub async fn login(&self, host: &str, username: &str, password: &str) -> Result<Client> {
180        let base_url = url::Url::parse(host)?;
181        let url = base_url.join("v1/login")?;
182
183        // If we have a tenant, send it.
184        let data = json!({
185            "username":username,
186            "password":password
187        });
188        Client::do_login(self, base_url, url, data).await
189    }
190
191    /// Login using an existing bearer token
192    #[tracing::instrument(skip(self, token))]
193    pub fn token_login(&self, host: &str, token: &str) -> Result<Client> {
194        let base_url = url::Url::parse(host)?;
195        Client::new_from_token(self, base_url, token)
196    }
197
198    pub async fn jwt_login(
199        &self,
200        host: &str,
201        tenant_name: &str,
202        jwt_auth: &str,
203        role: &str,
204        jwt: &str,
205    ) -> Result<Client> {
206        let base_url = url::Url::parse(host)?;
207        let url = base_url.join("v1/jwt-login")?;
208
209        let data = json!({
210            "tenant": tenant_name,
211            "jwt-auth": jwt_auth,
212            "role": role,
213            "jwt": jwt
214        });
215        Client::do_login(self, base_url, url, data).await
216    }
217}
218
219impl Default for ClientBuilder {
220    fn default() -> Self {
221        Self::new()
222    }
223}
224
225/// The `Client` is used for all interaction with Control Tower or Edge Enforcer instances.
226/// Use one of the login functions to create an instance.
227#[derive(Clone)]
228pub struct Client {
229    pub(crate) base_url: url::Url,
230    pub(crate) websocket_url: url::Url,
231    state: std::sync::Arc<tokio::sync::Mutex<ClientState>>,
232    #[allow(clippy::struct_field_names)]
233    http_client: reqwest::Client,
234    tls_ca: tokio_rustls::rustls::RootCertStore,
235    disable_cert_verification: bool,
236}
237
238impl std::fmt::Debug for Client {
239    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
240        f.debug_struct("Client")
241            .field("base_url", &self.base_url)
242            .field("websocket_url", &self.websocket_url)
243            .field("state", &self.state)
244            .field("http_client", &self.http_client)
245            .field("disable_cert_verification", &self.disable_cert_verification)
246            .finish_non_exhaustive()
247    }
248}
249
250impl Client {
251    /// Create a Client builder
252    #[must_use]
253    pub fn builder() -> ClientBuilder {
254        ClientBuilder::new()
255    }
256
257    /// Returns true if the token has expired
258    pub async fn token_is_expired(&self) -> bool {
259        let state = self.state.lock().await;
260        chrono::Local::now() > state.login_token.expires
261    }
262
263    pub async fn token_expires(&self) -> chrono::DateTime<chrono::Utc> {
264        let state = self.state.lock().await;
265        state.login_token.expires.into()
266    }
267
268    /// Returns the JWT SVID if exists
269    pub async fn jwt_svid(&self) -> Option<String> {
270        let state = self.state.lock().await;
271        state.login_token.jwt_svid.clone()
272    }
273
274    pub async fn x509_svid(&self) -> Option<X509SVID> {
275        let state = self.state.lock().await;
276        state.login_token.x509_svid.clone()
277    }
278
279    async fn do_login(
280        builder: &ClientBuilder,
281        base_url: url::Url,
282        url: url::Url,
283        payload: serde_json::Value,
284    ) -> Result<Self> {
285        let json = serde_json::to_string(&payload)?;
286        let client = Self::reqwest_client(builder)?;
287        let result = client
288            .post(url)
289            .header("content-type", "application/json")
290            .body(json)
291            .send()
292            .await?;
293
294        if result.status().is_success() {
295            let login_token = result.json().await?;
296
297            Self::new(builder, client, base_url, login_token)
298        } else {
299            let text = result.text().await?;
300            tracing::debug!("login returned {}", text);
301            Err(Error::LoginFailure(text))
302        }
303    }
304
305    fn reqwest_client(builder: &ClientBuilder) -> Result<reqwest::Client> {
306        let reqwest_client_builder = reqwest::Client::builder().use_rustls_tls();
307
308        // Add CA certificates
309        let reqwest_client_builder = builder
310            .reqwest_ca
311            .iter()
312            .fold(reqwest_client_builder, |reqwest_client_builder, ca| {
313                reqwest_client_builder.add_root_certificate(ca.clone())
314            });
315
316        tracing::debug!("Added {} CA certs", builder.reqwest_ca.len());
317
318        let reqwest_client_builder =
319            reqwest_client_builder.danger_accept_invalid_certs(builder.disable_cert_verification);
320
321        let reqwest_client_builder =
322            reqwest_client_builder.connection_verbose(builder.connection_verbose);
323
324        let reqwest_client_builder = if let Some(duration) = builder.timeout {
325            reqwest_client_builder.timeout(duration)
326        } else {
327            reqwest_client_builder
328        };
329
330        let reqwest_client_builder = if let Some(duration) = builder.connect_timeout {
331            reqwest_client_builder.connect_timeout(duration)
332        } else {
333            reqwest_client_builder
334        };
335
336        let client = reqwest_client_builder.build()?;
337        Ok(client)
338    }
339
340    fn new_from_token(builder: &ClientBuilder, base_url: url::Url, token: &str) -> Result<Self> {
341        let client = Self::reqwest_client(builder)?;
342        let creation_time = chrono::Local::now().into();
343        let expires = creation_time + chrono::Duration::seconds(1);
344
345        let login_token = LoginToken {
346            token: token.to_string(),
347            expires_in: 1,
348            creation_time,
349            expires,
350            jwt_svid: None,
351            x509_svid: None,
352        };
353
354        Self::new(builder, client, base_url, login_token)
355    }
356
357    fn new(
358        builder: &ClientBuilder,
359        client: reqwest::Client,
360        base_url: url::Url,
361        login_token: LoginToken,
362    ) -> Result<Self> {
363        let websocket_url = url::Url::parse(&format!("wss://{}/v1/ws/", base_url.host_port()?))?;
364
365        let renew_at = login_token.renew_at();
366
367        let state = std::sync::Arc::new(tokio::sync::Mutex::new(ClientState { login_token }));
368
369        let weak_state = std::sync::Arc::downgrade(&state);
370        let refresh_url = base_url.join("/v1/state/strongbox/token/refresh")?;
371
372        if builder.auto_renew_token {
373            tokio::spawn(renew_token_task(
374                weak_state,
375                renew_at,
376                client.clone(),
377                refresh_url,
378            ));
379        }
380
381        Ok(Self {
382            http_client: client,
383            tls_ca: builder.tls_ca.clone(),
384            disable_cert_verification: builder.disable_cert_verification,
385            base_url,
386            websocket_url,
387            state,
388        })
389    }
390
391    /// Returns the login bearer token
392    pub async fn bearer_token(&self) -> String {
393        let state = self.state.lock().await;
394        state.login_token.token.clone()
395    }
396
397    /// GET a json payload from the REST API.
398    pub async fn get_json<T: serde::de::DeserializeOwned>(
399        &self,
400        path: &str,
401        query_params: Option<&[(&str, &str)]>,
402    ) -> Result<T> {
403        let url = self.base_url.join(path)?;
404
405        let token = self.bearer_token().await;
406
407        let mut builder = self
408            .http_client
409            .get(url)
410            .bearer_auth(&token)
411            .header("Accept", "application/json");
412        if let Some(qp) = query_params {
413            builder = builder.query(qp);
414        }
415
416        let result = builder.send().await?;
417
418        if result.status().is_success() {
419            let res = result.json().await?;
420            Ok(res)
421        } else {
422            let status = result.status();
423            let error_payload = result
424                .text()
425                .await
426                .unwrap_or_else(|_| "No error payload".to_string());
427            Err(Error::WebServer(
428                status.as_u16(),
429                status.to_string(),
430                error_payload,
431            ))
432        }
433    }
434
435    /// GET a bytes payload from the REST API.
436    pub async fn get_bytes(
437        &self,
438        path: &str,
439        query_params: Option<&[(&str, &str)]>,
440    ) -> Result<Bytes> {
441        let url = self.base_url.join(path)?;
442
443        let token = self.bearer_token().await;
444
445        let mut builder = self.http_client.get(url).bearer_auth(&token);
446
447        if let Some(qp) = query_params {
448            builder = builder.query(qp);
449        }
450
451        let result = builder.send().await?;
452
453        if result.status().is_success() {
454            let res = result.bytes().await?;
455            Ok(res)
456        } else {
457            let status = result.status();
458            let error_payload = result
459                .text()
460                .await
461                .unwrap_or_else(|_| "No error payload".to_string());
462            Err(Error::WebServer(
463                status.as_u16(),
464                status.to_string(),
465                error_payload,
466            ))
467        }
468    }
469
470    /// POST arbitrary JSON to a path
471    /// # Panics
472    /// never
473    pub async fn post_json(
474        &self,
475        path: &str,
476        data: &serde_json::Value,
477    ) -> Result<serde_json::Value> {
478        let url = self.base_url.join(path)?;
479        let token = self.bearer_token().await;
480
481        tracing::debug!("POST {} {:?}", url, data);
482
483        let result = self
484            .http_client
485            .post(url)
486            .json(&data)
487            .bearer_auth(&token)
488            .send()
489            .await?;
490
491        if result.status().is_success() {
492            let resp = result.bytes().await?;
493
494            let mut responses: Vec<serde_json::Value> = Vec::new();
495            let decoder = serde_json::Deserializer::from_slice(&resp);
496
497            for v in decoder.into_iter() {
498                responses.push(v?);
499            }
500
501            match responses.len() {
502                0 => Ok(serde_json::Value::Object(serde_json::Map::default())),
503                1 => Ok(responses.into_iter().next().unwrap()),
504                _ => {
505                    // Convert to a JSON array
506                    Ok(serde_json::Value::Array(responses))
507                }
508            }
509        } else {
510            tracing::error!("POST call failed");
511            let status = result.status();
512            let resp = result.json().await;
513            match resp {
514                Ok(resp) => Err(Error::REST(resp)),
515                Err(_) => Err(Error::WebServer(
516                    status.as_u16(),
517                    status.to_string(),
518                    "Failed to get JSON responses".to_string(),
519                )),
520            }
521        }
522    }
523
524    /// PUT arbitrary data to a path
525    pub async fn put<T: Into<reqwest::Body> + std::fmt::Debug>(
526        &self,
527        path: &str,
528        content_type: &str,
529        data: T,
530    ) -> Result<()> {
531        let url = self.base_url.join(path)?;
532        let token = self.state.lock().await.login_token.token.clone();
533
534        tracing::debug!("PUT {} {:?}", url, data);
535
536        let result = self
537            .http_client
538            .put(url)
539            .header(reqwest::header::CONTENT_TYPE, content_type)
540            .body(data)
541            .bearer_auth(&token)
542            .send()
543            .await?;
544
545        //#[allow(clippy::redundant_closure_for_method_calls)]
546        if result.status().is_success() {
547            Ok(())
548        } else {
549            tracing::error!("PUT call failed");
550            let status = result.status();
551            let resp = result.json().await;
552            match resp {
553                Ok(resp) => Err(Error::REST(resp)),
554                Err(_) => Err(Error::WebServer(
555                    status.as_u16(),
556                    status.to_string(),
557                    "Failed to get JSON reply".to_string(),
558                )),
559            }
560        }
561    }
562
563    /// PUT arbitrary JSON to a path
564    pub async fn put_json(
565        &self,
566        path: &str,
567        data: &serde_json::Value,
568    ) -> Result<serde_json::Value> {
569        let url = self.base_url.join(path)?;
570        let token = self.state.lock().await.login_token.token.clone();
571
572        tracing::debug!("PUT {} {:?}", url, data);
573
574        let result = self
575            .http_client
576            .put(url)
577            .json(&data)
578            .bearer_auth(&token)
579            .send()
580            .await?;
581
582        #[allow(clippy::redundant_closure_for_method_calls)]
583        if result.status().is_success() {
584            use std::error::Error;
585            let resp = result.json().await.or_else(|e| match e {
586                e if e.is_decode() => {
587                    match e
588                        .source()
589                        .and_then(|e| e.downcast_ref::<serde_json::Error>())
590                    {
591                        Some(e) if e.is_eof() => {
592                            Ok(serde_json::Value::Object(serde_json::Map::new()))
593                        }
594                        _ => Err(e),
595                    }
596                }
597                e => Err(e),
598            })?;
599            Ok(resp)
600        } else {
601            tracing::error!("PUT call failed");
602            let status = result.status();
603            let resp = result.json().await;
604            match resp {
605                Ok(resp) => Err(Error::REST(resp)),
606                Err(_) => Err(Error::WebServer(
607                    status.as_u16(),
608                    status.to_string(),
609                    "Failed to get JSON reply".to_string(),
610                )),
611            }
612        }
613    }
614
615    /// Open a volga producer on a topic
616    pub async fn volga_open_producer(
617        &self,
618        producer_name: &str,
619        topic: &str,
620        on_no_exists: volga::OnNoExists,
621    ) -> Result<volga::producer::Producer> {
622        crate::volga::producer::Builder::new(self, producer_name, topic, on_no_exists)?
623            .connect()
624            .await
625    }
626
627    /// Open a volga NAT producer on a topic in a site
628    pub async fn volga_open_child_site_producer(
629        &self,
630        producer_name: &str,
631        topic: &str,
632        site: &str,
633        on_no_exists: volga::OnNoExists,
634    ) -> Result<volga::producer::Producer> {
635        crate::volga::producer::Builder::new_child(self, producer_name, topic, site, on_no_exists)?
636            .connect()
637            .await
638    }
639
640    /// Open a volga producer on a topic in the parent site
641    pub async fn volga_open_parent_site_producer(
642        &self,
643        producer_name: &str,
644        topic: &str,
645        on_no_exists: volga::OnNoExists,
646    ) -> Result<volga::producer::Producer> {
647        crate::volga::producer::Builder::new_parent(self, producer_name, topic, on_no_exists)?
648            .connect()
649            .await
650    }
651
652    /// Creates and opens a Volga consumer
653    #[tracing::instrument(skip(self))]
654    pub async fn volga_open_consumer(
655        &self,
656        consumer_name: &str,
657        topic: &str,
658        options: crate::volga::consumer::Options<'_>,
659    ) -> Result<volga::consumer::Consumer> {
660        crate::volga::consumer::Builder::new(self, consumer_name, topic)?
661            .set_options(options)
662            .connect()
663            .await
664    }
665
666    /// Creates and opens a Volga consumer on a child site
667    pub async fn volga_open_child_site_consumer(
668        &self,
669        consumer_name: &str,
670        topic: &str,
671        site: &str,
672        options: crate::volga::consumer::Options<'_>,
673    ) -> Result<volga::consumer::Consumer> {
674        crate::volga::consumer::Builder::new_child(self, consumer_name, topic, site)?
675            .set_options(options)
676            .connect()
677            .await
678    }
679
680    /// Creates and opens a Volga consumer on the parent site
681    pub async fn volga_open_parent_site_consumer(
682        &self,
683        consumer_name: &str,
684        topic: &str,
685        options: crate::volga::consumer::Options<'_>,
686    ) -> Result<volga::consumer::Consumer> {
687        crate::volga::consumer::Builder::new_parent(self, consumer_name, topic)?
688            .set_options(options)
689            .connect()
690            .await
691    }
692
693    #[tracing::instrument(skip(self))]
694    pub(crate) async fn open_tls_stream(
695        &self,
696    ) -> Result<tokio_rustls::client::TlsStream<tokio::net::TcpStream>> {
697        let mut client_config = tokio_rustls::rustls::ClientConfig::builder()
698            .with_root_certificates(self.tls_ca.clone())
699            .with_no_client_auth();
700
701        if self.disable_cert_verification {
702            let mut danger = client_config.dangerous();
703
704            danger.set_certificate_verifier(std::sync::Arc::new(CertificateVerifier));
705        }
706
707        let client_config = std::sync::Arc::new(client_config);
708
709        let connector: tokio_rustls::TlsConnector = client_config.into();
710        let addrs = self.websocket_url.socket_addrs(|| None)?;
711        let stream = tokio::net::TcpStream::connect(&*addrs).await?;
712
713        let server_name = tokio_rustls::rustls::pki_types::ServerName::try_from(
714            self.websocket_url.host_str().unwrap().to_owned(),
715        )?;
716        let stream = connector.connect(server_name, stream).await?;
717        Ok(stream)
718    }
719
720    /// Opens a query stream
721    pub async fn volga_query_topic(
722        &self,
723        query: volga::query_topic::Query,
724    ) -> Result<volga::query_topic::QueryStream> {
725        volga::query_topic::QueryStream::new(self, query).await
726    }
727
728    /// Try to open a Strongbox Vault
729    pub async fn open_strongbox_vault(&self, vault: &str) -> Result<strongbox::Vault> {
730        strongbox::Vault::open(self, vault).await
731    }
732
733    /// Setup a connection to a remote container
734    pub async fn connect(
735        &self,
736        application: &str,
737        service_instance: &str,
738        site: &str,
739        protocol: crate::app_connect::Protocol,
740        port: u16,
741        ip_address: Option<std::net::IpAddr>,
742    ) -> Result<crate::app_connect::Connection> {
743        crate::app_connect::connect(
744            self,
745            application,
746            service_instance,
747            site,
748            protocol,
749            port,
750            ip_address,
751        )
752        .await
753    }
754}
755
756#[derive(Debug)]
757struct CertificateVerifier;
758
759impl tokio_rustls::rustls::client::danger::ServerCertVerifier for CertificateVerifier {
760    fn verify_server_cert(
761        &self,
762        _end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
763        _intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>],
764        _server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>,
765        _ocsp_response: &[u8],
766        _now: tokio_rustls::rustls::pki_types::UnixTime,
767    ) -> std::result::Result<
768        tokio_rustls::rustls::client::danger::ServerCertVerified,
769        tokio_rustls::rustls::Error,
770    > {
771        Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
772    }
773
774    fn verify_tls12_signature(
775        &self,
776        _message: &[u8],
777        _cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
778        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
779    ) -> std::result::Result<
780        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
781        tokio_rustls::rustls::Error,
782    > {
783        Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
784    }
785
786    fn verify_tls13_signature(
787        &self,
788        _message: &[u8],
789        _cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
790        _dss: &tokio_rustls::rustls::DigitallySignedStruct,
791    ) -> std::result::Result<
792        tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
793        tokio_rustls::rustls::Error,
794    > {
795        Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
796    }
797
798    fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
799        vec![
800            tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA1,
801            tokio_rustls::rustls::SignatureScheme::ECDSA_SHA1_Legacy,
802            tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA256,
803            tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
804            tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA384,
805            tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
806            tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA512,
807            tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
808            tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA256,
809            tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA384,
810            tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA512,
811            tokio_rustls::rustls::SignatureScheme::ED25519,
812            tokio_rustls::rustls::SignatureScheme::ED448,
813        ]
814    }
815}
816
817pub(crate) trait URLExt {
818    fn host_port(&self) -> std::result::Result<String, url::ParseError>;
819}
820
821impl URLExt for url::Url {
822    fn host_port(&self) -> std::result::Result<String, url::ParseError> {
823        let host = self.host_str().ok_or(url::ParseError::EmptyHost)?;
824        Ok(match (host, self.port()) {
825            (host, Some(port)) => format!("{host}:{port}"),
826            (host, _) => host.to_string(),
827        })
828    }
829}
830
831#[tracing::instrument(skip(next_renew_at, weak_state, client, refresh_url))]
832async fn renew_token_task(
833    weak_state: std::sync::Weak<tokio::sync::Mutex<ClientState>>,
834    mut next_renew_at: chrono::DateTime<chrono::Utc>,
835    client: reqwest::Client,
836    refresh_url: url::Url,
837) {
838    loop {
839        let now: chrono::DateTime<_> = chrono::Local::now().into();
840
841        let sleep_time = next_renew_at - now;
842
843        tracing::info!("renew token in {sleep_time} ({})", next_renew_at);
844
845        tokio::time::sleep(
846            sleep_time
847                .to_std()
848                .unwrap_or_else(|_| std::time::Duration::from_secs(0)),
849        )
850        .await;
851
852        if let Some(state) = weak_state.upgrade() {
853            let mut state = state.lock().await;
854
855            if now > state.login_token.expires {
856                tracing::error!("Token is expired and we can't renew. Giving up");
857                break;
858            }
859            let response = client
860                .post(refresh_url.clone())
861                .bearer_auth(&state.login_token.token)
862                .send()
863                .await;
864
865            let response = match response {
866                Ok(r) => r,
867                Err(e) => {
868                    tracing::error!("Failed to renew token: {e}");
869                    let now = chrono::Utc::now();
870                    next_renew_at = now + chrono::Duration::seconds(1);
871                    continue;
872                }
873            };
874
875            let text = response.text().await.unwrap();
876            let new_login_token = serde_json::from_str::<LoginToken>(&text);
877
878            match new_login_token {
879                Ok(new_login_token) => {
880                    next_renew_at = new_login_token.renew_at();
881                    state.login_token = new_login_token;
882                    tracing::info!(
883                        "Successfully renewed token, expires: {}",
884                        state.login_token.expires
885                    );
886                }
887                Err(e) => {
888                    tracing::error!("Failed to parse or get token: {e}");
889                    // After failure, we check every second
890                    let now = chrono::Utc::now();
891                    next_renew_at = now + chrono::Duration::seconds(1);
892                }
893            }
894        } else {
895            tracing::debug!("renew_token: State lost");
896            // If we can't get the state, the client is gone and we should go as well
897            break;
898        }
899    }
900}
901
902#[cfg(test)]
903mod test {
904    #[test]
905    fn url_ext() {
906        use super::URLExt;
907        let url = url::Url::parse("https://1.2.3.4:5000/a/b/c").unwrap();
908        let host_port = url.host_port().unwrap();
909        assert_eq!(&host_port, "1.2.3.4:5000");
910
911        let url = url::Url::parse("https://1.2.3.4/a/b/c").unwrap();
912        let host_port = url.host_port().unwrap();
913        assert_eq!(&host_port, "1.2.3.4");
914
915        let url = url::Url::parse("https://www.avassa.com/a/b/c").unwrap();
916        let host_port = url.host_port().unwrap();
917        assert_eq!(&host_port, "www.avassa.com");
918
919        let url = url::Url::parse("https://www.avassa.com:1234/a/b/c").unwrap();
920        let host_port = url.host_port().unwrap();
921        assert_eq!(&host_port, "www.avassa.com:1234");
922    }
923}