1use crate::strongbox;
2use crate::volga;
3use crate::{Error, Result};
4use bytes::Bytes;
5use serde_json::json;
6
7#[derive(Clone, serde::Deserialize)]
8#[serde(rename_all = "kebab-case")]
9struct LoginToken {
10 token: String,
11 expires_in: i64,
12 expires: chrono::DateTime<chrono::FixedOffset>,
13 creation_time: chrono::DateTime<chrono::FixedOffset>,
14}
15
16impl LoginToken {
17 fn renew_at(&self) -> chrono::DateTime<chrono::FixedOffset> {
18 self.expires - chrono::Duration::seconds(self.expires_in / 4)
19 }
20}
21
22impl std::fmt::Debug for LoginToken {
23 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 f.debug_struct("LoginToken")
25 .field("expires_in", &self.expires_in)
26 .field("creation_time", &self.creation_time)
27 .finish_non_exhaustive()
28 }
29}
30
31#[derive(Debug)]
32struct ClientState {
33 login_token: LoginToken,
34}
35
36#[derive(Clone)]
38#[allow(clippy::struct_excessive_bools)]
39pub struct ClientBuilder {
40 reqwest_ca: Vec<reqwest::Certificate>,
41 tls_ca: tokio_rustls::rustls::RootCertStore,
42 disable_cert_verification: bool,
43 connection_verbose: bool,
44 auto_renew_token: bool,
45 timeout: Option<core::time::Duration>,
46 connect_timeout: Option<core::time::Duration>,
47}
48
49impl ClientBuilder {
50 #[must_use]
52 pub(crate) fn new() -> Self {
53 let tls_ca = webpki_roots::TLS_SERVER_ROOTS.iter().cloned().collect();
54 Self {
55 reqwest_ca: Vec::new(),
56 tls_ca,
57 disable_cert_verification: false,
58 connection_verbose: false,
59 auto_renew_token: true,
60 timeout: None,
61 connect_timeout: None,
62 }
63 }
64
65 #[must_use]
67 pub fn timeout(self, timeout: core::time::Duration) -> Self {
68 Self {
69 timeout: Some(timeout),
70 ..self
71 }
72 }
73
74 #[must_use]
76 pub fn connection_timeout(self, timeout: core::time::Duration) -> Self {
77 Self {
78 connect_timeout: Some(timeout),
79 ..self
80 }
81 }
82
83 pub fn add_root_certificate(mut self, cert: &[u8]) -> Result<Self> {
85 use std::iter;
86 let r_ca = reqwest::Certificate::from_pem(cert)?;
87 let mut ca_reader = std::io::BufReader::new(cert);
88 for item in iter::from_fn(|| rustls_pemfile::read_one(&mut ca_reader).transpose()) {
89 if let rustls_pemfile::Item::X509Certificate(cert) = item? {
90 self.tls_ca.add(cert)?;
91 }
92 }
93 self.reqwest_ca.push(r_ca);
94 Ok(self)
95 }
96
97 #[must_use]
99 pub fn danger_disable_cert_verification(self) -> Self {
100 Self {
101 disable_cert_verification: true,
102 ..self
103 }
104 }
105
106 #[must_use]
109 pub fn enable_verbose_connection(self) -> Self {
110 Self {
111 connection_verbose: true,
112 ..self
113 }
114 }
115
116 #[must_use]
118 pub fn disable_token_auto_renewal(self) -> Self {
119 Self {
120 auto_renew_token: false,
121 ..self
122 }
123 }
124
125 pub async fn application_login(&self, host: &str, approle_id: Option<&str>) -> Result<Client> {
129 let secret_id = std::env::var("APPROLE_SECRET_ID")
130 .map_err(|_| Error::LoginFailureMissingEnv(String::from("APPROLE_SECRET_ID")))?;
131
132 let role_id = approle_id.unwrap_or(&secret_id);
134
135 let base_url = url::Url::parse(host)?;
136 let url = base_url.join("v1/approle-login")?;
137 let data = json!({
138 "role-id": role_id,
139 "secret-id": secret_id,
140 });
141 Client::do_login(self, base_url, url, data).await
142 }
143
144 #[tracing::instrument(skip(self, password))]
147 pub async fn login(&self, host: &str, username: &str, password: &str) -> Result<Client> {
148 let base_url = url::Url::parse(host)?;
149 let url = base_url.join("v1/login")?;
150
151 let data = json!({
153 "username":username,
154 "password":password
155 });
156 Client::do_login(self, base_url, url, data).await
157 }
158
159 #[tracing::instrument(skip(self, token))]
161 pub fn token_login(&self, host: &str, token: &str) -> Result<Client> {
162 let base_url = url::Url::parse(host)?;
163 Client::new_from_token(self, base_url, token)
164 }
165}
166
167impl Default for ClientBuilder {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173#[derive(Clone)]
176pub struct Client {
177 base_url: url::Url,
178 pub(crate) websocket_url: url::Url,
179 state: std::sync::Arc<tokio::sync::Mutex<ClientState>>,
180 client: reqwest::Client,
181 tls_ca: tokio_rustls::rustls::RootCertStore,
182 disable_cert_verification: bool,
183}
184
185impl std::fmt::Debug for Client {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 f.debug_struct("Client")
188 .field("base_url", &self.base_url)
189 .field("websocket_url", &self.websocket_url)
190 .field("state", &self.state)
191 .field("client", &self.client)
192 .field("disable_cert_verification", &self.disable_cert_verification)
193 .finish_non_exhaustive()
194 }
195}
196
197impl Client {
198 #[must_use]
200 pub fn builder() -> ClientBuilder {
201 ClientBuilder::new()
202 }
203
204 async fn do_login(
205 builder: &ClientBuilder,
206 base_url: url::Url,
207 url: url::Url,
208 payload: serde_json::Value,
209 ) -> Result<Self> {
210 let json = serde_json::to_string(&payload)?;
211 let client = Self::reqwest_client(builder)?;
212 let result = client
213 .post(url)
214 .header("content-type", "application/json")
215 .body(json)
216 .send()
217 .await?;
218
219 if result.status().is_success() {
220 let login_token = result.json().await?;
221
222 Self::new(builder, client, base_url, login_token)
223 } else {
224 let text = result.text().await?;
225 tracing::debug!("login returned {}", text);
226 Err(Error::LoginFailure(text))
227 }
228 }
229
230 fn reqwest_client(builder: &ClientBuilder) -> Result<reqwest::Client> {
231 let reqwest_client_builder = reqwest::Client::builder().use_rustls_tls();
232
233 let reqwest_client_builder = builder
235 .reqwest_ca
236 .iter()
237 .fold(reqwest_client_builder, |reqwest_client_builder, ca| {
238 reqwest_client_builder.add_root_certificate(ca.clone())
239 });
240
241 tracing::debug!("Added {} CA certs", builder.reqwest_ca.len());
242
243 let reqwest_client_builder =
244 reqwest_client_builder.danger_accept_invalid_certs(builder.disable_cert_verification);
245
246 let reqwest_client_builder =
247 reqwest_client_builder.connection_verbose(builder.connection_verbose);
248
249 let reqwest_client_builder = if let Some(duration) = builder.timeout {
250 reqwest_client_builder.timeout(duration)
251 } else {
252 reqwest_client_builder
253 };
254
255 let reqwest_client_builder = if let Some(duration) = builder.connect_timeout {
256 reqwest_client_builder.connect_timeout(duration)
257 } else {
258 reqwest_client_builder
259 };
260
261 let client = reqwest_client_builder.build()?;
262 Ok(client)
263 }
264
265 fn new_from_token(builder: &ClientBuilder, base_url: url::Url, token: &str) -> Result<Self> {
266 let client = Self::reqwest_client(builder)?;
267 let creation_time = chrono::Local::now().into();
268 let expires = creation_time + chrono::Duration::seconds(1);
269
270 let login_token = LoginToken {
271 token: token.to_string(),
272 expires_in: 1,
273 creation_time,
274 expires,
275 };
276
277 Self::new(builder, client, base_url, login_token)
278 }
279
280 fn new(
281 builder: &ClientBuilder,
282 client: reqwest::Client,
283 base_url: url::Url,
284 login_token: LoginToken,
285 ) -> Result<Self> {
286 let websocket_url = url::Url::parse(&format!("wss://{}/v1/ws/", base_url.host_port()?))?;
287
288 let renew_at = login_token.renew_at();
289
290 let state = std::sync::Arc::new(tokio::sync::Mutex::new(ClientState { login_token }));
291
292 let weak_state = std::sync::Arc::downgrade(&state);
293 let refresh_url = base_url.join("/v1/state/strongbox/token/refresh")?;
294
295 if builder.auto_renew_token {
296 tokio::spawn(renew_token_task(
297 weak_state,
298 renew_at,
299 client.clone(),
300 refresh_url,
301 ));
302 }
303
304 Ok(Self {
305 client,
306 tls_ca: builder.tls_ca.clone(),
307 disable_cert_verification: builder.disable_cert_verification,
308 base_url,
309 websocket_url,
310 state,
311 })
312 }
313
314 pub async fn bearer_token(&self) -> String {
316 let state = self.state.lock().await;
317 state.login_token.token.clone()
318 }
319
320 pub async fn get_json<T: serde::de::DeserializeOwned>(
322 &self,
323 path: &str,
324 query_params: Option<&[(&str, &str)]>,
325 ) -> Result<T> {
326 let url = self.base_url.join(path)?;
327
328 let token = self.bearer_token().await;
329
330 let mut builder = self
331 .client
332 .get(url)
333 .bearer_auth(&token)
334 .header("Accept", "application/json");
335 if let Some(qp) = query_params {
336 builder = builder.query(qp);
337 }
338
339 let result = builder.send().await?;
340
341 if result.status().is_success() {
342 let res = result.json().await?;
343 Ok(res)
344 } else {
345 let status = result.status();
346 let error_payload = result
347 .text()
348 .await
349 .unwrap_or_else(|_| "No error payload".to_string());
350 Err(Error::WebServer(
351 status.as_u16(),
352 status.to_string(),
353 error_payload,
354 ))
355 }
356 }
357
358 pub async fn get_bytes(
360 &self,
361 path: &str,
362 query_params: Option<&[(&str, &str)]>,
363 ) -> Result<Bytes> {
364 let url = self.base_url.join(path)?;
365
366 let token = self.bearer_token().await;
367
368 let mut builder = self.client.get(url).bearer_auth(&token);
369
370 if let Some(qp) = query_params {
371 builder = builder.query(qp);
372 }
373
374 let result = builder.send().await?;
375
376 if result.status().is_success() {
377 let res = result.bytes().await?;
378 Ok(res)
379 } else {
380 let status = result.status();
381 let error_payload = result
382 .text()
383 .await
384 .unwrap_or_else(|_| "No error payload".to_string());
385 Err(Error::WebServer(
386 status.as_u16(),
387 status.to_string(),
388 error_payload,
389 ))
390 }
391 }
392
393 pub async fn post_json(
397 &self,
398 path: &str,
399 data: &serde_json::Value,
400 ) -> Result<serde_json::Value> {
401 let url = self.base_url.join(path)?;
402 let token = self.bearer_token().await;
403
404 tracing::debug!("POST {} {:?}", url, data);
405
406 let result = self
407 .client
408 .post(url)
409 .json(&data)
410 .bearer_auth(&token)
411 .send()
412 .await?;
413
414 if result.status().is_success() {
415 let resp = result.bytes().await?;
416
417 let mut responses: Vec<serde_json::Value> = Vec::new();
418 let decoder = serde_json::Deserializer::from_slice(&resp);
419
420 for v in decoder.into_iter() {
421 responses.push(v?);
422 }
423
424 match responses.len() {
425 0 => Ok(serde_json::Value::Object(serde_json::Map::default())),
426 1 => Ok(responses.into_iter().next().unwrap()),
427 _ => {
428 Ok(serde_json::Value::Array(responses))
430 }
431 }
432 } else {
433 tracing::error!("POST call failed");
434 let status = result.status();
435 let resp = result.json().await;
436 match resp {
437 Ok(resp) => Err(Error::REST(resp)),
438 Err(_) => Err(Error::WebServer(
439 status.as_u16(),
440 status.to_string(),
441 "Failed to get JSON responses".to_string(),
442 )),
443 }
444 }
445 }
446
447 pub async fn put_json(
449 &self,
450 path: &str,
451 data: &serde_json::Value,
452 ) -> Result<serde_json::Value> {
453 let url = self.base_url.join(path)?;
454 let token = self.state.lock().await.login_token.token.clone();
455
456 tracing::debug!("PUT {} {:?}", url, data);
457
458 let result = self
459 .client
460 .put(url)
461 .json(&data)
462 .bearer_auth(&token)
463 .send()
464 .await?;
465
466 #[allow(clippy::redundant_closure_for_method_calls)]
467 if result.status().is_success() {
468 use std::error::Error;
469 let resp = result.json().await.or_else(|e| match e {
470 e if e.is_decode() => {
471 match e
472 .source()
473 .and_then(|e| e.downcast_ref::<serde_json::Error>())
474 {
475 Some(e) if e.is_eof() => {
476 Ok(serde_json::Value::Object(serde_json::Map::new()))
477 }
478 _ => Err(e),
479 }
480 }
481 e => Err(e),
482 })?;
483 Ok(resp)
484 } else {
485 tracing::error!("PUT call failed");
486 let status = result.status();
487 let resp = result.json().await;
488 match resp {
489 Ok(resp) => Err(Error::REST(resp)),
490 Err(_) => Err(Error::WebServer(
491 status.as_u16(),
492 status.to_string(),
493 "Failed to get JSON reply".to_string(),
494 )),
495 }
496 }
497 }
498
499 pub async fn volga_open_producer(
501 &self,
502 producer_name: &str,
503 topic: &str,
504 on_no_exists: volga::OnNoExists,
505 ) -> Result<volga::producer::Producer> {
506 crate::volga::producer::Builder::new(self, producer_name, topic, on_no_exists)?
507 .connect()
508 .await
509 }
510
511 pub async fn volga_open_child_site_producer(
513 &self,
514 producer_name: &str,
515 topic: &str,
516 site: &str,
517 on_no_exists: volga::OnNoExists,
518 ) -> Result<volga::producer::Producer> {
519 crate::volga::producer::Builder::new_child(self, producer_name, topic, site, on_no_exists)?
520 .connect()
521 .await
522 }
523
524 #[tracing::instrument]
526 pub async fn volga_open_consumer(
527 &self,
528 consumer_name: &str,
529 topic: &str,
530 options: crate::volga::consumer::Options<'_>,
531 ) -> Result<volga::consumer::Consumer> {
532 crate::volga::consumer::Builder::new(self, consumer_name, topic)?
533 .set_options(options)
534 .connect()
535 .await
536 }
537
538 pub async fn volga_open_child_site_consumer(
540 &self,
541 consumer_name: &str,
542 topic: &str,
543 site: &str,
544 options: crate::volga::consumer::Options<'_>,
545 ) -> Result<volga::consumer::Consumer> {
546 crate::volga::consumer::Builder::new_child(self, consumer_name, topic, site)?
547 .set_options(options)
548 .connect()
549 .await
550 }
551
552 #[tracing::instrument(skip(self))]
553 pub(crate) async fn open_tls_stream(
554 &self,
555 ) -> Result<tokio_rustls::client::TlsStream<tokio::net::TcpStream>> {
556 let mut client_config = tokio_rustls::rustls::ClientConfig::builder()
557 .with_root_certificates(self.tls_ca.clone())
558 .with_no_client_auth();
559
560 if self.disable_cert_verification {
561 let mut danger = client_config.dangerous();
562
563 danger.set_certificate_verifier(std::sync::Arc::new(CertificateVerifier));
564 }
565
566 let client_config = std::sync::Arc::new(client_config);
567
568 let connector: tokio_rustls::TlsConnector = client_config.into();
569 let addrs = self.websocket_url.socket_addrs(|| None)?;
570 let stream = tokio::net::TcpStream::connect(&*addrs).await?;
571
572 let server_name = tokio_rustls::rustls::pki_types::ServerName::try_from(
573 self.websocket_url.host_str().unwrap().to_owned(),
574 )?;
575 let stream = connector.connect(server_name, stream).await?;
576 Ok(stream)
577 }
578
579 pub async fn volga_open_log_query(
581 &self,
582 query: &volga::log_query::Query,
583 ) -> Result<volga::log_query::QueryStream> {
584 volga::log_query::QueryStream::new(self, query).await
585 }
586
587 pub async fn open_strongbox_vault(&self, vault: &str) -> Result<strongbox::Vault> {
589 strongbox::Vault::open(self, vault).await
590 }
591}
592
/// `rustls` verifier that accepts every certificate and signature.
///
/// Installed only when certificate verification is disabled on the builder.
/// With it, the TLS handshake proves nothing about the server's identity.
#[derive(Debug)]
struct CertificateVerifier;
595
596impl tokio_rustls::rustls::client::danger::ServerCertVerifier for CertificateVerifier {
597 fn verify_server_cert(
598 &self,
599 _end_entity: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
600 _intermediates: &[tokio_rustls::rustls::pki_types::CertificateDer<'_>],
601 _server_name: &tokio_rustls::rustls::pki_types::ServerName<'_>,
602 _ocsp_response: &[u8],
603 _now: tokio_rustls::rustls::pki_types::UnixTime,
604 ) -> std::result::Result<
605 tokio_rustls::rustls::client::danger::ServerCertVerified,
606 tokio_rustls::rustls::Error,
607 > {
608 Ok(tokio_rustls::rustls::client::danger::ServerCertVerified::assertion())
609 }
610
611 fn verify_tls12_signature(
612 &self,
613 _message: &[u8],
614 _cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
615 _dss: &tokio_rustls::rustls::DigitallySignedStruct,
616 ) -> std::result::Result<
617 tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
618 tokio_rustls::rustls::Error,
619 > {
620 Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
621 }
622
623 fn verify_tls13_signature(
624 &self,
625 _message: &[u8],
626 _cert: &tokio_rustls::rustls::pki_types::CertificateDer<'_>,
627 _dss: &tokio_rustls::rustls::DigitallySignedStruct,
628 ) -> std::result::Result<
629 tokio_rustls::rustls::client::danger::HandshakeSignatureValid,
630 tokio_rustls::rustls::Error,
631 > {
632 Ok(tokio_rustls::rustls::client::danger::HandshakeSignatureValid::assertion())
633 }
634
635 fn supported_verify_schemes(&self) -> Vec<tokio_rustls::rustls::SignatureScheme> {
636 vec![
637 tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA1,
638 tokio_rustls::rustls::SignatureScheme::ECDSA_SHA1_Legacy,
639 tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA256,
640 tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
641 tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA384,
642 tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
643 tokio_rustls::rustls::SignatureScheme::RSA_PKCS1_SHA512,
644 tokio_rustls::rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
645 tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA256,
646 tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA384,
647 tokio_rustls::rustls::SignatureScheme::RSA_PSS_SHA512,
648 tokio_rustls::rustls::SignatureScheme::ED25519,
649 tokio_rustls::rustls::SignatureScheme::ED448,
650 ]
651 }
652}
653
654pub(crate) trait URLExt {
655 fn host_port(&self) -> std::result::Result<String, url::ParseError>;
656}
657
658impl URLExt for url::Url {
659 fn host_port(&self) -> std::result::Result<String, url::ParseError> {
660 let host = self.host_str().ok_or(url::ParseError::EmptyHost)?;
661 Ok(match (host, self.port()) {
662 (host, Some(port)) => format!("{host}:{port}"),
663 (host, _) => host.to_string(),
664 })
665 }
666}
667
668#[tracing::instrument(skip(next_renew_at, weak_state, client, refresh_url))]
669async fn renew_token_task(
670 weak_state: std::sync::Weak<tokio::sync::Mutex<ClientState>>,
671 mut next_renew_at: chrono::DateTime<chrono::FixedOffset>,
672 client: reqwest::Client,
673 refresh_url: url::Url,
674) {
675 loop {
676 let now: chrono::DateTime<_> = chrono::Local::now().into();
677 let sleep_time = next_renew_at - now;
678
679 tracing::debug!("renew token in {sleep_time}");
680
681 tokio::time::sleep(
682 sleep_time
683 .to_std()
684 .unwrap_or_else(|_| std::time::Duration::from_secs(0)),
685 )
686 .await;
687
688 if let Some(state) = weak_state.upgrade() {
689 let mut state = state.lock().await;
690 let response = client
691 .post(refresh_url.clone())
692 .bearer_auth(&state.login_token.token)
693 .send()
694 .await;
695
696 let response = match response {
697 Ok(r) => r,
698 Err(e) => {
699 tracing::error!("Failed to renew token: {e}");
700 let now: chrono::DateTime<chrono::FixedOffset> = chrono::Local::now().into();
701 next_renew_at = now + chrono::Duration::seconds(1);
702 continue;
703 }
704 };
705
706 let text = response.text().await.unwrap();
707 let new_login_token = serde_json::from_str::<LoginToken>(&text);
708
709 match new_login_token {
710 Ok(new_login_token) => {
711 next_renew_at = new_login_token.renew_at();
712 state.login_token = new_login_token;
713 tracing::debug!("Successfully renewed token");
714 }
715 Err(e) => {
716 tracing::error!("Failed to parse or get token: {e}");
717 let now: chrono::DateTime<chrono::FixedOffset> = chrono::Local::now().into();
719 next_renew_at = now + chrono::Duration::seconds(1);
720 }
721 }
722 } else {
723 tracing::info!("renew_token: State lost");
724 break;
726 }
727 }
728}
729
#[cfg(test)]
mod test {
    use super::URLExt;

    /// Explicit ports are kept; hosts without a port are returned bare.
    #[test]
    fn url_ext() {
        let url = url::Url::parse("https://1.2.3.4:5000/a/b/c").unwrap();
        let host_port = url.host_port().unwrap();
        assert_eq!(&host_port, "1.2.3.4:5000");

        let url = url::Url::parse("https://1.2.3.4/a/b/c").unwrap();
        let host_port = url.host_port().unwrap();
        assert_eq!(&host_port, "1.2.3.4");

        let url = url::Url::parse("https://www.avassa.com/a/b/c").unwrap();
        let host_port = url.host_port().unwrap();
        assert_eq!(&host_port, "www.avassa.com");

        let url = url::Url::parse("https://www.avassa.com:1234/a/b/c").unwrap();
        let host_port = url.host_port().unwrap();
        assert_eq!(&host_port, "www.avassa.com:1234");
    }

    /// The scheme's default port is normalized away by `url`, so it is omitted.
    #[test]
    fn url_ext_default_port_is_omitted() {
        let url = url::Url::parse("https://www.avassa.com:443/a").unwrap();
        assert_eq!(url.host_port().unwrap(), "www.avassa.com");
    }

    /// IPv6 hosts keep their brackets, so the port suffix stays unambiguous.
    #[test]
    fn url_ext_ipv6() {
        let url = url::Url::parse("https://[::1]:8080/").unwrap();
        assert_eq!(url.host_port().unwrap(), "[::1]:8080");
    }

    /// URLs without a host report `EmptyHost`.
    #[test]
    fn url_ext_missing_host() {
        let url = url::Url::parse("mailto:someone@example.com").unwrap();
        assert_eq!(url.host_port(), Err(url::ParseError::EmptyHost));
    }
}
752