use std::{
collections::{HashSet, VecDeque},
sync::{Arc, RwLock},
};
use crate::{
batch::BatchResult,
config::{self, RqliteClientConfig, RqliteClientConfigBuilder},
error::{ClientBuilderError, RequestError},
node::{Node, NodeResponse, RemoveNodeRequest},
query::{self, QueryArgs, RqliteQuery},
query_result::QueryResult,
request::{RequestOptions, RqliteQueryParam, RqliteQueryParams},
response::{RqliteResponseRaw, RqliteResult},
select::RqliteSelectResults,
};
use base64::{engine::general_purpose, Engine};
use reqwest::header;
use rqlite_rs_core::Row;
/// Asynchronous HTTP client for an rqlite cluster.
///
/// Instances are produced by [`RqliteClientBuilder::build`].
pub struct RqliteClient {
    /// Underlying reqwest client, pre-configured with the default headers
    /// (content type, optional basic auth) and request timeout.
    client: reqwest::Client,
    /// Candidate hosts shared across requests. The front entry is the current
    /// request target; the list is rotated when a connection to it fails.
    hosts: Arc<RwLock<VecDeque<String>>>,
    /// Finalized client configuration (scheme, default query parameters).
    config: RqliteClientConfig,
}
/// Builder for [`RqliteClient`].
///
/// Collects known hosts, optional basic-auth credentials and client
/// configuration, then produces a client with [`RqliteClientBuilder::build`].
#[derive(Default)]
pub struct RqliteClientBuilder {
    /// Known hosts in registration order. Duplicates are removed in `build`,
    /// keeping the first occurrence.
    hosts: Vec<String>,
    config: RqliteClientConfigBuilder,
    /// Base64-encoded `user:password`, sent as a `Basic` Authorization header.
    basic_auth: Option<String>,
}

impl RqliteClientBuilder {
    /// Creates an empty builder: no hosts, no credentials, default config.
    pub fn new() -> Self {
        RqliteClientBuilder::default()
    }

    /// Sets HTTP basic-auth credentials that are sent with every request.
    pub fn auth(mut self, user: &str, password: &str) -> Self {
        self.basic_auth = Some(general_purpose::STANDARD.encode(format!("{}:{}", user, password)));
        self
    }

    /// Registers a host the client may talk to.
    ///
    /// The first registered host is the initial request target. On connection
    /// failure the client moves on to the following hosts in registration order.
    /// Registering the same host twice has no additional effect.
    pub fn known_host(mut self, host: impl ToString) -> Self {
        self.hosts.push(host.to_string());
        self
    }

    /// Sets query parameters that are merged into every request.
    pub fn default_query_params(mut self, params: Vec<RqliteQueryParam>) -> Self {
        self.config = self.config.default_query_params(params);
        self
    }

    /// Sets the URL scheme used for requests.
    pub fn scheme(mut self, scheme: config::Scheme) -> Self {
        self.config = self.config.scheme(scheme);
        self
    }

    /// Builds the client.
    ///
    /// # Errors
    /// - `ClientBuilderError::NoHostsProvided` if no host was registered.
    /// - A conversion of the header-value error if the credentials cannot form
    ///   a valid header.
    /// - A conversion of the reqwest error if the HTTP client cannot be built.
    pub fn build(self) -> Result<RqliteClient, ClientBuilderError> {
        if self.hosts.is_empty() {
            return Err(ClientBuilderError::NoHostsProvided);
        }

        // Drop duplicates but keep the first occurrence of each host, so that
        // registration order (not hash order) decides the initial target and
        // the failover sequence.
        let mut seen = HashSet::new();
        let hosts: VecDeque<String> = self
            .hosts
            .into_iter()
            .filter(|host| seen.insert(host.clone()))
            .collect();

        let mut headers = header::HeaderMap::new();
        headers.insert(
            header::CONTENT_TYPE,
            header::HeaderValue::from_static("application/json"),
        );
        if let Some(credentials) = self.basic_auth {
            let mut auth_value =
                header::HeaderValue::from_str(format!("Basic {}", credentials).as_str())?;
            // Mark the credentials as sensitive so they are not shown in
            // Debug output of the headers (e.g. when requests are logged).
            auth_value.set_sensitive(true);
            headers.insert(header::AUTHORIZATION, auth_value);
        }

        let mut client = reqwest::ClientBuilder::new()
            .timeout(std::time::Duration::from_secs(5))
            .default_headers(headers);
        if let Some(config::Scheme::Https) = self.config.scheme {
            client = client.https_only(true)
        }

        Ok(RqliteClient {
            client: client.build()?,
            hosts: Arc::new(RwLock::new(hosts)),
            config: self.config.build(),
        })
    }
}
impl RqliteClient {
    /// Advances the shared host list past `failed_host` and returns the host
    /// that should be tried next.
    ///
    /// The list is rotated only while `failed_host` is still at the front. When
    /// several requests fail against the same host concurrently, only the first
    /// one rotates. The others pick up the already-advanced front instead of
    /// rotating again and skipping a host that nobody has tried.
    fn shift_host(&self, failed_host: &str) -> String {
        // A poisoned lock only means another thread panicked while holding it.
        // The deque of host names is still usable, so do not cascade the panic.
        let mut hosts = self
            .hosts
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if hosts.front().map(String::as_str) == Some(failed_host) {
            hosts.rotate_left(1);
        }
        // The builder rejects an empty host list, so index 0 always exists.
        hosts[0].clone()
    }

    /// Sends `options` to the current host, failing over to the next host on
    /// connection errors and timeouts. Makes at most one attempt per known host.
    ///
    /// Configured default query parameters are merged into `options` first.
    ///
    /// # Errors
    /// - `RequestError::Unauthorized` on HTTP 401.
    /// - `RequestError::ReqwestError` (status + body) on any other non-2xx status.
    /// - `RequestError::SwitchoverWrongError` on transport errors other than
    ///   connect/timeout.
    /// - `RequestError::NoAvailableHosts` when every attempt failed to connect
    ///   or timed out.
    async fn try_request(
        &self,
        mut options: RequestOptions,
    ) -> Result<reqwest::Response, RequestError> {
        // Snapshot the current target. The guard is dropped at the end of this
        // block, before any `.await`.
        let (mut host, host_count) = {
            let hosts = self
                .hosts
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            (hosts[0].clone(), hosts.len())
        };
        if let Some(default_params) = &self.config.default_query_params {
            options.merge_default_query_params(default_params);
        }
        for _ in 0..host_count {
            let req = options.to_reqwest_request(&self.client, host.as_str(), &self.config.scheme);
            match req.send().await {
                Ok(res) if res.status().is_success() => return Ok(res),
                Ok(res) => match res.status() {
                    reqwest::StatusCode::UNAUTHORIZED => {
                        return Err(RequestError::Unauthorized);
                    }
                    status => {
                        return Err(RequestError::ReqwestError {
                            body: res.text().await?,
                            status,
                        });
                    }
                },
                Err(e) => self.handle_request_error(e, &mut host)?,
            }
        }
        Err(RequestError::NoAvailableHosts)
    }

    /// Decides whether a transport error should trigger failover.
    ///
    /// On connect or timeout errors, advances to the next host, writes it into
    /// `host` and returns `Ok(())` so the caller retries. Any other error is
    /// returned as `RequestError::SwitchoverWrongError`.
    fn handle_request_error(
        &self,
        e: reqwest::Error,
        host: &mut String,
    ) -> Result<(), RequestError> {
        if e.is_connect() || e.is_timeout() {
            let next_host = self.shift_host(host.as_str());
            println!("Connection to {} failed, trying {}", host, next_host);
            *host = next_host;
            Ok(())
        } else {
            Err(RequestError::SwitchoverWrongError(e.to_string()))
        }
    }

    /// Converts every input into an [`RqliteQuery`] and serializes the whole
    /// list as one request body (via `QueryArgs`).
    fn serialize_queries<Q>(qs: Vec<Q>) -> Result<String, RequestError>
    where
        Q: TryInto<RqliteQuery>,
        RequestError: From<Q::Error>,
    {
        let queries = qs
            .into_iter()
            .map(|q| q.try_into())
            .collect::<Result<Vec<RqliteQuery>, _>>()?;
        serde_json::to_string(&QueryArgs::from(queries))
            .map_err(RequestError::FailedParseRequestBody)
    }

    /// Sends a single query to its own endpoint and returns the first entry of
    /// the response's `results` array.
    ///
    /// # Errors
    /// Returns `RequestError::NoRowsReturned` when `results` is empty, in
    /// addition to request and (de)serialization errors.
    async fn exec_query<T>(&self, q: query::RqliteQuery) -> Result<RqliteResult<T>, RequestError>
    where
        T: serde::de::DeserializeOwned + Clone,
    {
        let res = self
            .try_request(RequestOptions {
                endpoint: q.endpoint(),
                body: Some(
                    q.into_json()
                        .map_err(RequestError::FailedParseRequestBody)?,
                ),
                ..Default::default()
            })
            .await?;
        let body = res.text().await?;
        let response = serde_json::from_str::<RqliteResponseRaw<T>>(&body)
            .map_err(RequestError::FailedParseResponseBody)?;
        response
            .results
            .into_iter()
            .next()
            .ok_or(RequestError::NoRowsReturned)
    }

    /// Runs a read query and returns its rows.
    ///
    /// # Errors
    /// A database-reported error is returned as `RequestError::DatabaseError`.
    pub async fn fetch<Q>(&self, q: Q) -> Result<Vec<Row>, RequestError>
    where
        Q: TryInto<RqliteQuery>,
        RequestError: From<Q::Error>,
    {
        let result = self
            .exec_query::<RqliteSelectResults>(q.try_into()?)
            .await?;
        match result {
            RqliteResult::Success(qr) => Ok(qr.rows()),
            RqliteResult::Error(qe) => Err(RequestError::DatabaseError(qe.error)),
        }
    }

    /// Runs a write statement and returns its [`QueryResult`].
    ///
    /// # Errors
    /// A database-reported error is returned as `RequestError::DatabaseError`.
    pub async fn exec<Q>(&self, q: Q) -> Result<QueryResult, RequestError>
    where
        Q: TryInto<RqliteQuery>,
        RequestError: From<Q::Error>,
    {
        let query_result = self.exec_query::<QueryResult>(q.try_into()?).await?;
        match query_result {
            RqliteResult::Success(qr) => Ok(qr),
            RqliteResult::Error(qe) => Err(RequestError::DatabaseError(qe.error)),
        }
    }

    /// Sends all queries in one request to `db/request` and returns one result
    /// per statement, in response order. Per-statement database errors are
    /// returned inside the vector rather than as an `Err`.
    pub async fn batch<Q>(&self, qs: Vec<Q>) -> Result<Vec<RqliteResult<BatchResult>>, RequestError>
    where
        Q: TryInto<RqliteQuery>,
        RequestError: From<Q::Error>,
    {
        let body = Self::serialize_queries(qs)?;
        let res = self
            .try_request(RequestOptions {
                endpoint: "db/request".to_string(),
                body: Some(body),
                ..Default::default()
            })
            .await?;
        let body = res.text().await?;
        let results = serde_json::from_str::<RqliteResponseRaw<BatchResult>>(&body)
            .map_err(RequestError::FailedParseResponseBody)?
            .results;
        Ok(results)
    }

    /// Sends all statements in one request to `db/execute` with the
    /// `transaction` query parameter set, and returns one result per statement.
    pub async fn transaction<Q>(
        &self,
        qs: Vec<Q>,
    ) -> Result<Vec<RqliteResult<QueryResult>>, RequestError>
    where
        Q: TryInto<RqliteQuery>,
        RequestError: From<Q::Error>,
    {
        let body = Self::serialize_queries(qs)?;
        let res = self
            .try_request(RequestOptions {
                endpoint: "db/execute".to_string(),
                body: Some(body),
                params: Some(
                    RqliteQueryParams::new()
                        .transaction()
                        .into_request_query_params(),
                ),
                ..Default::default()
            })
            .await?;
        let body = res.text().await?;
        let results = serde_json::from_str::<RqliteResponseRaw<QueryResult>>(&body)
            .map_err(RequestError::FailedParseResponseBody)?
            .results;
        Ok(results)
    }

    /// Sends all statements to `db/execute` with the `queue` query parameter
    /// set. Only the HTTP status is checked; the response body is ignored.
    pub async fn queue<Q>(&self, qs: Vec<Q>) -> Result<(), RequestError>
    where
        Q: TryInto<RqliteQuery>,
        RequestError: From<Q::Error>,
    {
        let body = Self::serialize_queries(qs)?;
        self.try_request(RequestOptions {
            endpoint: "db/execute".to_string(),
            body: Some(body),
            params: Some(RqliteQueryParams::new().queue().into_request_query_params()),
            ..Default::default()
        })
        .await?;
        Ok(())
    }

    /// Returns `true` when `GET readyz` answers `200 OK`. Any request failure,
    /// including a non-2xx status, yields `false`.
    pub async fn ready(&self) -> bool {
        match self
            .try_request(RequestOptions {
                endpoint: "readyz".to_string(),
                method: reqwest::Method::GET,
                ..Default::default()
            })
            .await
        {
            Ok(res) => res.status() == reqwest::StatusCode::OK,
            Err(_) => false,
        }
    }

    /// Lists cluster nodes via `GET nodes` with the `ver` parameter set to `"2"`.
    pub async fn nodes(&self) -> Result<Vec<Node>, RequestError> {
        let res = self
            .try_request(RequestOptions {
                endpoint: "nodes".to_string(),
                params: Some(
                    RqliteQueryParams::new()
                        .ver("2".to_string())
                        .into_request_query_params(),
                ),
                method: reqwest::Method::GET,
                ..Default::default()
            })
            .await?;
        let body = res.text().await?;
        let response = serde_json::from_str::<NodeResponse>(&body)
            .map_err(RequestError::FailedParseResponseBody)?;
        Ok(response.nodes)
    }

    /// Returns the first node reported with its `leader` flag set, or `None`
    /// if no node is marked as leader.
    pub async fn leader(&self) -> Result<Option<Node>, RequestError> {
        let nodes = self.nodes().await?;
        Ok(nodes.into_iter().find(|n| n.leader))
    }

    /// Asks the cluster to remove the node with the given `id` (`DELETE remove`).
    ///
    /// # Errors
    /// Any non-2xx response is reported by `try_request` (e.g. as
    /// `RequestError::ReqwestError` carrying the status and body).
    pub async fn remove_node(&self, id: &str) -> Result<(), RequestError> {
        let body = serde_json::to_string(&RemoveNodeRequest { id: id.to_string() })
            .map_err(RequestError::FailedParseRequestBody)?;
        // `try_request` only returns `Ok` for 2xx responses, so reaching the
        // next line already means the removal succeeded. Re-checking the
        // status here would be dead code.
        self.try_request(RequestOptions {
            endpoint: "remove".to_string(),
            body: Some(body),
            method: reqwest::Method::DELETE,
            ..Default::default()
        })
        .await?;
        Ok(())
    }
}