#![deny(clippy::all)]
#![warn(clippy::pedantic)]
#![warn(clippy::cargo)]
#![allow(clippy::missing_errors_doc)]
use bytes::Bytes;
use serde::Deserialize;
use serde_json::json;
pub mod strongbox;
pub mod volga;
#[cfg(feature = "utilities")]
pub mod utilities;
#[cfg(feature = "login-helper")]
pub mod login_helper;
#[cfg(feature = "supctl")]
pub mod supctl;
/// A single error entry as returned in the body of a failed REST call.
#[derive(Debug, Deserialize)]
pub struct RESTError {
    /// Human readable error message (JSON key `error-message`).
    #[serde(rename = "error-message")]
    pub error_message: String,
    /// Additional, unstructured error details (JSON key `error-info`).
    #[serde(rename = "error-info")]
    pub error_info: serde_json::Value,
}
/// Error body of a failed REST call: a JSON object with an `errors` array.
#[derive(Debug, Deserialize)]
pub struct RESTErrorList {
    /// The individual errors reported by the server.
    pub errors: Vec<RESTError>,
}
/// Errors returned by this crate.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The login endpoint answered with a non-success status; carries the response body.
    #[error("Login failed: {0}")]
    LoginFailure(String),
    /// A login helper needed an environment variable that is not set; carries its name.
    #[error("Login failure, missing environment variable '{0}'")]
    LoginFailureMissingEnv(String),
    /// Non-success HTTP response: status code, status text and response payload.
    #[error("HTTP failed {0}, {1} - {2}")]
    WebServer(u16, String, String),
    /// Error from the websocket layer.
    #[error("Websocket error: {0}")]
    WebSocket(#[from] tokio_tungstenite::tungstenite::Error),
    /// JSON (de)serialization error.
    #[error("Serde JSON error: {0}")]
    Serde(#[from] serde_json::Error),
    /// Invalid URL, or a URL missing a required part such as the host.
    #[error("URL: {0}")]
    URL(#[from] url::ParseError),
    /// Error from the underlying HTTP client.
    #[error("Reqwest: {0}")]
    HTTPClient(#[from] reqwest::Error),
    /// Error reported by Volga. NOTE(review): constructed outside this file;
    /// presumably carries the server's message when one is available.
    #[error("Error from Volga {0:?}")]
    Volga(Option<String>),
    /// API-level error. NOTE(review): not constructed in this file.
    #[error("API Error {0:?}")]
    API(String),
    /// Structured error list returned by a failed REST call.
    #[error("REST error {0:?}")]
    REST(RESTErrorList),
    /// TLS setup or handshake error.
    #[error("TLS error {0}")]
    TLS(#[from] tokio_native_tls::native_tls::Error),
    /// I/O error (e.g. while opening a TCP connection).
    #[error("IO error {0}")]
    IO(#[from] std::io::Error),
    /// Free-form error message.
    #[error("Error {0}")]
    General(String),
}
impl Error {
    /// Convenience constructor for [`Error::General`] from a borrowed message.
    #[must_use]
    pub fn general(err: &str) -> Self {
        let message = err.to_owned();
        Self::General(message)
    }
}
/// Crate-wide result type using [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
/// Bearer token returned by the login and token-refresh endpoints.
///
/// JSON keys are kebab-case (`token`, `expires-in`, `expires`, `creation-time`).
#[derive(Clone, Deserialize)]
#[serde(rename_all = "kebab-case")]
struct LoginToken {
    /// The bearer token sent in the `Authorization` header.
    token: String,
    /// Token lifetime; treated as seconds by `renew_at`.
    expires_in: i64,
    /// Absolute expiry time.
    expires: chrono::DateTime<chrono::FixedOffset>,
    /// Time the token was issued.
    creation_time: chrono::DateTime<chrono::FixedOffset>,
}
impl LoginToken {
    /// Time at which the token should be refreshed: `expires` minus a quarter
    /// of `expires_in` (as whole seconds, integer division).
    fn renew_at(&self) -> chrono::DateTime<chrono::FixedOffset> {
        let safety_margin = chrono::Duration::seconds(self.expires_in / 4);
        self.expires - safety_margin
    }
}
/// Debug output that deliberately omits the bearer token itself so it never
/// ends up in logs; `finish_non_exhaustive` marks the omission (`..`).
impl std::fmt::Debug for LoginToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginToken")
            .field("expires_in", &self.expires_in)
            .field("expires", &self.expires)
            .field("creation_time", &self.creation_time)
            .finish_non_exhaustive()
    }
}
/// Mutable client state shared (behind a mutex) between the [`Client`] clones
/// and the token renewal task.
#[derive(Debug)]
struct ClientState {
    /// The currently valid login token.
    login_token: LoginToken,
}
/// Builder used to configure and log in a [`Client`].
#[derive(Clone)]
// The bools are independent on/off switches set by separate builder methods,
// so grouping them into an enum/bitflags would not simplify the API.
#[allow(clippy::struct_excessive_bools)]
pub struct ClientBuilder {
    /// Extra root certificates for the HTTP client.
    reqwest_ca: Vec<reqwest::Certificate>,
    /// The same extra root certificates for the native-tls websocket connector.
    tls_ca: Vec<tokio_native_tls::native_tls::Certificate>,
    /// Accept certificates whose hostname does not match (websocket TLS only).
    disable_hostname_check: bool,
    /// Accept invalid certificates (HTTP client and websocket TLS).
    disable_cert_verification: bool,
    /// Enable reqwest's verbose connection logging.
    connection_verbose: bool,
    /// Spawn a background task that refreshes the login token.
    auto_renew_token: bool,
    /// Total request timeout for the HTTP client.
    timeout: Option<core::time::Duration>,
    /// Connect timeout for the HTTP client.
    connect_timeout: Option<core::time::Duration>,
}
impl ClientBuilder {
#[must_use]
const fn new() -> Self {
Self {
reqwest_ca: Vec::new(),
tls_ca: Vec::new(),
disable_cert_verification: false,
disable_hostname_check: false,
connection_verbose: false,
auto_renew_token: true,
timeout: None,
connect_timeout: None,
}
}
#[must_use]
pub fn timeout(self, timeout: core::time::Duration) -> Self {
Self {
timeout: Some(timeout),
..self
}
}
#[must_use]
pub fn connection_timeout(self, timeout: core::time::Duration) -> Self {
Self {
connect_timeout: Some(timeout),
..self
}
}
pub fn add_root_certificate(mut self, cert: &[u8]) -> Result<Self> {
let r_ca = reqwest::Certificate::from_pem(cert)?;
let t_ca = tokio_native_tls::native_tls::Certificate::from_pem(cert)?;
self.reqwest_ca.push(r_ca);
self.tls_ca.push(t_ca);
Ok(self)
}
#[must_use]
pub fn danger_accept_invalid_certs(self) -> Self {
Self {
disable_cert_verification: true,
..self
}
}
#[must_use]
pub fn danger_accept_invalid_hostnames(self) -> Self {
Self {
disable_hostname_check: true,
..self
}
}
#[must_use]
pub fn enable_verbose_connection(self) -> Self {
Self {
connection_verbose: true,
..self
}
}
#[must_use]
pub fn disable_token_auto_renewal(self) -> Self {
Self {
auto_renew_token: false,
..self
}
}
pub async fn application_login(&self, host: &str, approle_id: Option<&str>) -> Result<Client> {
let secret_id = std::env::var("APPROLE_SECRET_ID")
.map_err(|_| Error::LoginFailureMissingEnv(String::from("APPROLE_SECRET_ID")))?;
let role_id = approle_id.unwrap_or(&secret_id);
let base_url = url::Url::parse(host)?;
let url = base_url.join("v1/approle-login")?;
let data = json!({
"role-id": role_id,
"secret-id": secret_id,
});
Client::do_login(self, base_url, url, data).await
}
#[tracing::instrument(skip(self, password))]
pub async fn login(&self, host: &str, username: &str, password: &str) -> Result<Client> {
let base_url = url::Url::parse(host)?;
let url = base_url.join("v1/login")?;
let data = json!({
"username":username,
"password":password
});
Client::do_login(self, base_url, url, data).await
}
#[tracing::instrument(skip(self, token))]
pub fn token_login(&self, host: &str, token: &str) -> Result<Client> {
let base_url = url::Url::parse(host)?;
Client::new_from_token(self, base_url, token)
}
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
/// Logged-in API client. Cloning is cheap-ish and all clones share the same
/// login token state.
#[derive(Clone)]
pub struct Client {
    /// Base URL that request paths are joined onto.
    base_url: url::Url,
    /// `wss://<host[:port]>/v1/ws/` derived from `base_url`.
    websocket_url: url::Url,
    /// Login token, shared with the renewal task (which only holds a `Weak`).
    state: std::sync::Arc<tokio::sync::Mutex<ClientState>>,
    /// HTTP client configured from the builder.
    client: reqwest::Client,
    /// Extra root certificates for the websocket TLS connection.
    tls_ca: Vec<tokio_native_tls::native_tls::Certificate>,
    /// Accept mismatching hostnames on the websocket TLS connection.
    disable_hostname_check: bool,
    /// Accept invalid certificates on the websocket TLS connection.
    disable_cert_verification: bool,
}
/// Debug output without the root certificate list; `finish_non_exhaustive`
/// marks that not every field is shown.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("websocket_url", &self.websocket_url)
            .field("state", &self.state)
            .field("client", &self.client)
            .field("disable_hostname_check", &self.disable_hostname_check)
            .field("disable_cert_verification", &self.disable_cert_verification)
            .finish_non_exhaustive()
    }
}
impl Client {
    /// Returns a [`ClientBuilder`] with default settings, from which a logged-in
    /// [`Client`] is created.
    #[must_use]
    pub const fn builder() -> ClientBuilder {
        ClientBuilder::new()
    }

    /// POSTs `payload` as JSON to the login endpoint `url` and, on a success
    /// status, builds a [`Client`] from the returned [`LoginToken`].
    ///
    /// # Errors
    ///
    /// [`Error::LoginFailure`] with the response body on a non-success status;
    /// transport and decoding errors otherwise.
    async fn do_login(
        builder: &ClientBuilder,
        base_url: url::Url,
        url: url::Url,
        payload: serde_json::Value,
    ) -> Result<Self> {
        let json = serde_json::to_string(&payload)?;
        let client = Self::reqwest_client(builder)?;
        let result = client
            .post(url)
            .header("content-type", "application/json")
            .body(json)
            .send()
            .await?;
        if result.status().is_success() {
            let login_token = result.json().await?;
            Self::new(builder, client, base_url, login_token)
        } else {
            let text = result.text().await?;
            tracing::debug!("login returned {}", text);
            Err(Error::LoginFailure(text))
        }
    }

    /// Builds the HTTP client from the builder's certificates, verification,
    /// verbosity and timeout settings.
    fn reqwest_client(builder: &ClientBuilder) -> Result<reqwest::Client> {
        let mut client_builder = builder
            .reqwest_ca
            .iter()
            .fold(reqwest::Client::builder(), |client_builder, ca| {
                client_builder.add_root_certificate(ca.clone())
            })
            .danger_accept_invalid_certs(builder.disable_cert_verification)
            .connection_verbose(builder.connection_verbose);
        if let Some(duration) = builder.timeout {
            client_builder = client_builder.timeout(duration);
        }
        if let Some(duration) = builder.connect_timeout {
            client_builder = client_builder.connect_timeout(duration);
        }
        Ok(client_builder.build()?)
    }

    /// Creates a client from an existing bearer `token`.
    ///
    /// The token's real lifetime is unknown, so it is recorded as expiring one
    /// second after creation; with auto-renewal enabled the renewal task
    /// therefore refreshes it almost immediately.
    fn new_from_token(builder: &ClientBuilder, base_url: url::Url, token: &str) -> Result<Self> {
        let client = Self::reqwest_client(builder)?;
        let creation_time = chrono::Local::now().into();
        let expires = creation_time + chrono::Duration::seconds(1);
        let login_token = LoginToken {
            token: token.to_string(),
            expires_in: 1,
            creation_time,
            expires,
        };
        Self::new(builder, client, base_url, login_token)
    }

    /// Assembles a client around `login_token` and, if enabled in `builder`,
    /// spawns the token renewal task with `tokio::spawn` (so this must run
    /// inside a Tokio runtime in that case).
    fn new(
        builder: &ClientBuilder,
        client: reqwest::Client,
        base_url: url::Url,
        login_token: LoginToken,
    ) -> Result<Self> {
        let websocket_url = url::Url::parse(&format!("wss://{}/v1/ws/", base_url.host_port()?))?;
        let renew_at = login_token.renew_at();
        let state = std::sync::Arc::new(tokio::sync::Mutex::new(ClientState { login_token }));
        // The task only holds a Weak so it stops once every Client clone is dropped.
        let weak_state = std::sync::Arc::downgrade(&state);
        let refresh_url = base_url.join("/v1/state/strongbox/token/refresh")?;
        if builder.auto_renew_token {
            tokio::spawn(renew_token_task(
                weak_state,
                renew_at,
                client.clone(),
                refresh_url,
            ));
        }
        Ok(Self {
            client,
            tls_ca: builder.tls_ca.clone(),
            disable_cert_verification: builder.disable_cert_verification,
            disable_hostname_check: builder.disable_hostname_check,
            base_url,
            websocket_url,
            state,
        })
    }

    /// Returns a copy of the current bearer token.
    pub async fn bearer_token(&self) -> String {
        let state = self.state.lock().await;
        state.login_token.token.clone()
    }

    /// Passes a success response through unchanged; otherwise converts it into
    /// [`Error::WebServer`] carrying status code, status text and body.
    async fn ensure_success(response: reqwest::Response) -> Result<reqwest::Response> {
        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }
        let error_payload = response
            .text()
            .await
            .unwrap_or_else(|_| "No error payload".to_string());
        Err(Error::WebServer(
            status.as_u16(),
            status.to_string(),
            error_payload,
        ))
    }

    /// Converts a failed response into [`Error::REST`] when its body parses as
    /// a [`RESTErrorList`], else into [`Error::WebServer`] with `fallback`.
    async fn rest_error(response: reqwest::Response, fallback: &str) -> Error {
        let status = response.status();
        match response.json::<RESTErrorList>().await {
            Ok(errors) => Error::REST(errors),
            Err(_) => Error::WebServer(status.as_u16(), status.to_string(), fallback.to_string()),
        }
    }

    /// GETs `path` (relative to the base URL) with optional query parameters
    /// and deserializes the JSON response into `T`.
    ///
    /// # Errors
    ///
    /// [`Error::WebServer`] on a non-success status; transport, URL and
    /// decoding errors otherwise.
    pub async fn get_json<T: serde::de::DeserializeOwned>(
        &self,
        path: &str,
        query_params: Option<&[(&str, &str)]>,
    ) -> Result<T> {
        let url = self.base_url.join(path)?;
        let token = self.bearer_token().await;
        let mut request = self
            .client
            .get(url)
            .bearer_auth(&token)
            .header("Accept", "application/json");
        if let Some(qp) = query_params {
            request = request.query(qp);
        }
        let response = Self::ensure_success(request.send().await?).await?;
        Ok(response.json().await?)
    }

    /// GETs `path` (relative to the base URL) with optional query parameters
    /// and returns the raw response body. No `Accept` header is sent.
    ///
    /// # Errors
    ///
    /// [`Error::WebServer`] on a non-success status; transport and URL errors
    /// otherwise.
    pub async fn get_bytes(
        &self,
        path: &str,
        query_params: Option<&[(&str, &str)]>,
    ) -> Result<Bytes> {
        let url = self.base_url.join(path)?;
        let token = self.bearer_token().await;
        let mut request = self.client.get(url).bearer_auth(&token);
        if let Some(qp) = query_params {
            request = request.query(qp);
        }
        let response = Self::ensure_success(request.send().await?).await?;
        Ok(response.bytes().await?)
    }

    /// POSTs `data` as JSON to `path` (relative to the base URL).
    ///
    /// The response body may hold zero or more concatenated JSON values: none
    /// yields an empty object, one is returned as-is, several are returned as
    /// an array.
    ///
    /// # Errors
    ///
    /// [`Error::REST`] or [`Error::WebServer`] on a non-success status;
    /// transport, URL and JSON errors otherwise.
    pub async fn post_json(
        &self,
        path: &str,
        data: &serde_json::Value,
    ) -> Result<serde_json::Value> {
        let url = self.base_url.join(path)?;
        let token = self.bearer_token().await;
        tracing::debug!("POST {} {:?}", url, data);
        let result = self
            .client
            .post(url)
            .json(data)
            .bearer_auth(&token)
            .send()
            .await?;
        if !result.status().is_success() {
            tracing::error!("POST call failed");
            return Err(Self::rest_error(result, "Failed to get JSON responses").await);
        }
        let body = result.bytes().await?;
        let mut responses = serde_json::Deserializer::from_slice(&body)
            .into_iter::<serde_json::Value>()
            .collect::<std::result::Result<Vec<_>, _>>()?;
        Ok(match responses.len() {
            0 => serde_json::Value::Object(serde_json::Map::default()),
            1 => responses.swap_remove(0),
            _ => serde_json::Value::Array(responses),
        })
    }

    /// PUTs `data` as JSON to `path` (relative to the base URL) and returns the
    /// JSON reply. A success reply whose body ends prematurely (e.g. an empty
    /// body) is returned as an empty object.
    ///
    /// # Errors
    ///
    /// [`Error::REST`] or [`Error::WebServer`] on a non-success status;
    /// transport, URL and decoding errors otherwise.
    pub async fn put_json(
        &self,
        path: &str,
        data: &serde_json::Value,
    ) -> Result<serde_json::Value> {
        // Import the trait anonymously so it does not shadow the crate `Error`.
        use std::error::Error as _;
        let url = self.base_url.join(path)?;
        let token = self.bearer_token().await;
        tracing::debug!("PUT {} {:?}", url, data);
        let result = self
            .client
            .put(url)
            .json(data)
            .bearer_auth(&token)
            .send()
            .await?;
        if !result.status().is_success() {
            tracing::error!("PUT call failed");
            return Err(Self::rest_error(result, "Failed to get JSON reply").await);
        }
        #[allow(clippy::redundant_closure_for_method_calls)]
        let resp = result.json().await.or_else(|e: reqwest::Error| {
            let premature_eof = e.is_decode()
                && matches!(
                    e.source()
                        .and_then(|source| source.downcast_ref::<serde_json::Error>()),
                    Some(json_error) if json_error.is_eof()
                );
            if premature_eof {
                Ok(serde_json::Value::Object(serde_json::Map::new()))
            } else {
                Err(e)
            }
        })?;
        Ok(resp)
    }

    /// Opens a Volga producer named `producer_name` on `topic`.
    pub async fn volga_open_producer(
        &self,
        producer_name: &str,
        topic: &str,
        on_no_exists: volga::OnNoExists,
    ) -> Result<volga::producer::Producer> {
        crate::volga::producer::Builder::new(self, producer_name, topic, on_no_exists)?
            .connect()
            .await
    }

    /// Opens a Volga producer named `producer_name` on `topic` of child `site`.
    pub async fn volga_open_child_site_producer(
        &self,
        producer_name: &str,
        topic: &str,
        site: &str,
        on_no_exists: volga::OnNoExists,
    ) -> Result<volga::producer::Producer> {
        crate::volga::producer::Builder::new_child(self, producer_name, topic, site, on_no_exists)?
            .connect()
            .await
    }

    /// Opens a Volga consumer named `consumer_name` on `topic` with `options`.
    #[tracing::instrument]
    pub async fn volga_open_consumer(
        &self,
        consumer_name: &str,
        topic: &str,
        options: crate::volga::consumer::Options,
    ) -> Result<volga::consumer::Consumer> {
        crate::volga::consumer::Builder::new(self, consumer_name, topic)?
            .set_options(options)
            .connect()
            .await
    }

    /// Opens a Volga consumer named `consumer_name` on `topic` of child `site`.
    pub async fn volga_open_child_site_consumer(
        &self,
        consumer_name: &str,
        topic: &str,
        site: &str,
        options: crate::volga::consumer::Options,
    ) -> Result<volga::consumer::Consumer> {
        crate::volga::consumer::Builder::new_child(self, consumer_name, topic, site)?
            .set_options(options)
            .connect()
            .await
    }

    /// Opens a TCP connection to the websocket endpoint and performs the TLS
    /// handshake, using the client's extra root certificates and verification
    /// settings.
    #[tracing::instrument(skip(self))]
    pub(crate) async fn open_tls_stream(
        &self,
    ) -> Result<tokio_native_tls::TlsStream<tokio::net::TcpStream>> {
        let mut connector = tokio_native_tls::native_tls::TlsConnector::builder();
        self.tls_ca.iter().for_each(|ca| {
            connector.add_root_certificate(ca.clone());
        });
        connector
            .danger_accept_invalid_hostnames(self.disable_hostname_check)
            .danger_accept_invalid_certs(self.disable_cert_verification);
        let connector = connector.build()?;
        let connector: tokio_native_tls::TlsConnector = connector.into();
        // native-tls uses this value for SNI and certificate hostname
        // verification, so it must be the bare host, not the whole URL.
        let domain = match self.websocket_url.host() {
            Some(url::Host::Domain(domain)) => domain.to_string(),
            Some(url::Host::Ipv4(ip)) => ip.to_string(),
            Some(url::Host::Ipv6(ip)) => ip.to_string(),
            None => return Err(url::ParseError::EmptyHost.into()),
        };
        let addrs = self.websocket_url.socket_addrs(|| None)?;
        let stream = tokio::net::TcpStream::connect(&*addrs).await?;
        let stream = connector.connect(&domain, stream).await?;
        Ok(stream)
    }

    /// Starts a Volga log query.
    pub async fn volga_open_log_query(
        &self,
        query: &volga::log_query::Query,
    ) -> Result<volga::log_query::QueryStream> {
        volga::log_query::QueryStream::new(self, query).await
    }

    /// Opens the Strongbox vault named `vault`.
    pub async fn open_strongbox_vault(&self, vault: &str) -> Result<strongbox::Vault> {
        strongbox::Vault::open(self, vault).await
    }
}
/// Crate-internal helpers on [`url::Url`].
pub(crate) trait URLExt {
    /// Returns `host` or `host:port` for the URL; fails with
    /// [`url::ParseError::EmptyHost`] when the URL has no host.
    fn host_port(&self) -> std::result::Result<String, url::ParseError>;
}
impl URLExt for url::Url {
    /// Formats the host, followed by `:port` only when the URL carries an
    /// explicit port (`Url::port` yields `None` for the scheme's default port).
    fn host_port(&self) -> std::result::Result<String, url::ParseError> {
        let host = self.host_str().ok_or(url::ParseError::EmptyHost)?;
        let host_port = match self.port() {
            Some(port) => format!("{host}:{port}"),
            None => host.to_owned(),
        };
        Ok(host_port)
    }
}
#[tracing::instrument(skip(next_renew_at, weak_state, client, refresh_url))]
async fn renew_token_task(
weak_state: std::sync::Weak<tokio::sync::Mutex<ClientState>>,
mut next_renew_at: chrono::DateTime<chrono::FixedOffset>,
client: reqwest::Client,
refresh_url: url::Url,
) {
loop {
let now: chrono::DateTime<chrono::FixedOffset> = chrono::Local::now().into();
chrono::Utc::now();
let sleep_time = next_renew_at - now;
tracing::debug!("renew token in {sleep_time}");
tokio::time::sleep(
sleep_time
.to_std()
.unwrap_or_else(|_| std::time::Duration::from_secs(0)),
)
.await;
if let Some(state) = weak_state.upgrade() {
let mut state = state.lock().await;
let response = client
.post(refresh_url.clone())
.bearer_auth(&state.login_token.token)
.send()
.await;
let response = match response {
Ok(r) => r,
Err(e) => {
tracing::error!("Failed to renew token: {e}");
let now: chrono::DateTime<chrono::FixedOffset> = chrono::Local::now().into();
next_renew_at = now + chrono::Duration::seconds(1);
continue;
}
};
let text = response.text().await.unwrap();
let new_login_token = serde_json::from_str::<LoginToken>(&text);
match new_login_token {
Ok(new_login_token) => {
next_renew_at = new_login_token.renew_at();
state.login_token = new_login_token;
tracing::debug!("Successfully renewed token");
}
Err(e) => {
tracing::error!("Failed to parse or get token: {e}");
let now: chrono::DateTime<chrono::FixedOffset> = chrono::Local::now().into();
next_renew_at = now + chrono::Duration::seconds(1);
}
}
} else {
tracing::info!("renew_token: State lost");
break;
}
}
}
#[cfg(test)]
mod test {
    /// `host_port` keeps an explicit port and omits it when absent, for both
    /// IP addresses and domain names.
    #[test]
    fn url_ext() {
        use super::URLExt;
        let cases = [
            ("https://1.2.3.4:5000/a/b/c", "1.2.3.4:5000"),
            ("https://1.2.3.4/a/b/c", "1.2.3.4"),
            ("https://www.avassa.com/a/b/c", "www.avassa.com"),
            ("https://www.avassa.com:1234/a/b/c", "www.avassa.com:1234"),
        ];
        for &(input, expected) in &cases {
            let parsed = url::Url::parse(input).unwrap();
            let host_port = parsed.host_port().unwrap();
            assert_eq!(host_port, expected, "host_port of {input}");
        }
    }
}