use crate::client::Config;
use anyhow::Result;
use base64::Engine;
use crate::{BatchResult, ResultSet, Statement};
/// HTTP client for a sqld / libsql server.
///
/// Holds the endpoint URLs and a pre-rendered `Authorization` header value.
/// `Debug` is implemented by hand so that credentials never end up in logs.
#[derive(Clone)]
pub struct Client {
    /// Server base URL (always carries a scheme).
    base_url: String,
    /// URL that query batches are POSTed to first; equal to `base_url` unless the
    /// `separate_url_for_queries` feature is enabled.
    url_for_queries: String,
    /// Full `Authorization` header value (`Bearer …` or `Basic …`). Secret.
    auth: String,
}

impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // The auth header embeds a token or base64-encoded credentials; never print it.
        f.debug_struct("Client")
            .field("base_url", &self.base_url)
            .field("url_for_queries", &self.url_for_queries)
            .field("auth", &"<redacted>")
            .finish()
    }
}
impl Client {
    /// Creates a client that authenticates with a bearer token.
    ///
    /// `url` may omit the scheme, in which case `https://` is assumed.
    pub fn new(url: impl Into<String>, token: impl Into<String>) -> Self {
        let token = token.into();
        Self::with_auth_header(url.into(), format!("Bearer {token}"))
    }

    /// Creates a client that authenticates with HTTP Basic auth.
    ///
    /// `url` may omit the scheme, in which case `https://` is assumed.
    pub fn from_credentials(
        url: impl Into<String>,
        username: impl Into<String>,
        pass: impl Into<String>,
    ) -> Self {
        let credentials = format!("{}:{}", username.into(), pass.into());
        let encoded = base64::engine::general_purpose::STANDARD.encode(credentials);
        Self::with_auth_header(url.into(), format!("Basic {encoded}"))
    }

    /// Creates a client from a `Config`, using its auth token (or an empty one
    /// when none is set) as a bearer token.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for API compatibility.
    pub fn from_config(config: Config) -> anyhow::Result<Self> {
        Ok(Self::new(config.url, config.auth_token.unwrap_or_default()))
    }

    /// Creates a client from a URL.
    ///
    /// * If the query string has a `token` parameter, it is used as a bearer
    ///   token and removed from the URL (other query parameters are kept).
    /// * Otherwise the URL's userinfo (`user:pass@`) is percent-decoded, used
    ///   for Basic auth, and stripped from the URL.
    ///
    /// # Errors
    /// Fails when `url` cannot be converted into a `url::Url`, or when the
    /// userinfo cannot be removed from it.
    pub fn from_url<T: TryInto<url::Url>>(url: T) -> anyhow::Result<Client>
    where
        <T as TryInto<url::Url>>::Error: std::fmt::Display,
    {
        let mut url = url
            .try_into()
            .map_err(|e| anyhow::anyhow!(format!("{e}")))?;

        let token = url
            .query_pairs()
            .find(|(key, _)| key == "token")
            .map(|(_, value)| value.into_owned());
        if let Some(token) = token {
            // The token travels in the Authorization header. Leaving it in the URL
            // would leak it into URL-bearing error messages and, with
            // `separate_url_for_queries`, produce a malformed `…?token=…/queries` URL.
            let remaining: Vec<(String, String)> = url
                .query_pairs()
                .filter(|(key, _)| key != "token")
                .map(|(key, value)| (key.into_owned(), value.into_owned()))
                .collect();
            if remaining.is_empty() {
                url.set_query(None);
            } else {
                let mut pairs = url.query_pairs_mut();
                pairs.clear();
                for (key, value) in &remaining {
                    pairs.append_pair(key, value);
                }
            }
            return Ok(Client::new(url.as_str(), token));
        }

        // `Url::username`/`password` return percent-encoded text; Basic auth needs
        // the raw credentials.
        let username = percent_decode(url.username());
        let password = percent_decode(url.password().unwrap_or_default());
        url.set_username("")
            .map_err(|_| anyhow::anyhow!("Could not extract username from URL. Invalid URL?"))?;
        url.set_password(None)
            .map_err(|_| anyhow::anyhow!("Could not extract password from URL. Invalid URL?"))?;
        Ok(Client::from_credentials(url.as_str(), username, password))
    }

    /// Creates a client from environment variables.
    ///
    /// * `LIBSQL_CLIENT_URL` (required): server URL.
    /// * `LIBSQL_CLIENT_TOKEN`: if set, used as a bearer token.
    /// * Otherwise, if `LIBSQL_CLIENT_USER` is set, `LIBSQL_CLIENT_PASS` is
    ///   required and Basic auth is used.
    /// * Otherwise the URL itself is parsed with [`Client::from_url`].
    ///
    /// # Errors
    /// Fails if a required variable is missing or the URL cannot be parsed.
    pub fn from_env() -> anyhow::Result<Client> {
        let url = std::env::var("LIBSQL_CLIENT_URL").map_err(|_| {
            anyhow::anyhow!("LIBSQL_CLIENT_URL variable should point to your sqld database")
        })?;
        if let Ok(token) = std::env::var("LIBSQL_CLIENT_TOKEN") {
            return Ok(Client::new(url, token));
        }
        let user = match std::env::var("LIBSQL_CLIENT_USER") {
            Ok(user) => user,
            Err(_) => return Client::from_url(url.as_str()),
        };
        let pass = std::env::var("LIBSQL_CLIENT_PASS").map_err(|_| {
            anyhow::anyhow!("LIBSQL_CLIENT_PASS variable should be set to your sqld password")
        })?;
        Ok(Client::from_credentials(url, user, pass))
    }

    /// Shared constructor: normalises `url` (defaulting to `https://`), derives
    /// the query endpoint, and stores the ready-made `Authorization` value.
    fn with_auth_header(url: String, auth: String) -> Self {
        let base_url = if url.contains("://") {
            url
        } else {
            format!("https://{url}")
        };
        let url_for_queries = if cfg!(feature = "separate_url_for_queries") {
            // Avoid `https://host//queries` when the base URL ends in `/`
            // (which `url::Url` serialisation always produces).
            format!("{}/queries", base_url.trim_end_matches('/'))
        } else {
            base_url.clone()
        };
        Self {
            base_url,
            url_for_queries,
            auth,
        }
    }
}

/// Decodes `%XX` escapes in a URL component, leaving malformed escapes as-is.
/// Decoded bytes that are not valid UTF-8 are replaced with U+FFFD.
fn percent_decode(input: &str) -> String {
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            let high = bytes.get(i + 1).copied().and_then(hex_value);
            let low = bytes.get(i + 2).copied().and_then(hex_value);
            if let (Some(high), Some(low)) = (high, low) {
                decoded.push((high << 4) | low);
                i += 3;
                continue;
            }
        }
        decoded.push(bytes[i]);
        i += 1;
    }
    String::from_utf8_lossy(&decoded).into_owned()
}
impl Client {
    /// Sends a batch of statements in a single HTTP POST and parses the reply.
    ///
    /// With the `separate_url_for_queries` feature, the batch goes to
    /// `url_for_queries` first. If that attempt fails or returns a non-200
    /// status, the batch is retried once against `base_url`.
    ///
    /// # Errors
    /// Returns transport errors, a non-200 status (including the server's
    /// response text when there is any), or JSON/shape errors from parsing.
    pub async fn raw_batch(
        &self,
        stmts: impl IntoIterator<Item = impl Into<Statement>>,
    ) -> anyhow::Result<BatchResult> {
        let (body, stmts_count) = crate::client::statements_to_string(stmts);
        // NOTE(review): a new reqwest::Client per call means no connection reuse
        // across calls; storing one in `Client` would need a struct change.
        let client = reqwest::Client::new();
        let first_attempt = client
            .post(&self.url_for_queries)
            .body(body.clone())
            .header("Authorization", &self.auth)
            .send()
            .await;
        let response = match first_attempt {
            Ok(resp) if resp.status() == reqwest::StatusCode::OK => resp,
            rejected => {
                if cfg!(feature = "separate_url_for_queries") {
                    // Fall back to the base URL for servers without a `/queries` route.
                    client
                        .post(&self.base_url)
                        .body(body)
                        .header("Authorization", &self.auth)
                        .send()
                        .await?
                } else {
                    rejected?
                }
            }
        };

        let status = response.status();
        if status != reqwest::StatusCode::OK {
            // The body usually explains why the server rejected the batch.
            let detail = response.text().await.unwrap_or_default();
            let detail = detail.trim();
            if detail.is_empty() {
                anyhow::bail!("{status}");
            }
            anyhow::bail!("{status}: {detail}");
        }

        let text: String = response.text().await?;
        let response_json: serde_json::Value = serde_json::from_str(&text)?;
        crate::client::http_json_to_batch_result(response_json, stmts_count)
    }

    /// Executes a single statement and returns its result set.
    ///
    /// # Errors
    /// Returns the server-reported statement error, any error from
    /// [`Client::raw_batch`], or an error if the reply does not contain exactly
    /// one result or one error for the statement.
    pub async fn execute(&self, stmt: impl Into<Statement> + Send) -> Result<ResultSet> {
        let results = self.raw_batch(std::iter::once(stmt)).await?;
        match (results.step_results.first(), results.step_errors.first()) {
            (Some(Some(result)), Some(None)) => Ok(ResultSet::from(result.clone())),
            (Some(None), Some(Some(err))) => Err(anyhow::anyhow!(err.message.clone())),
            // Server output is external input: report it instead of panicking.
            _ => Err(anyhow::anyhow!(
                "unexpected batch response: expected exactly one result or one error for a single statement"
            )),
        }
    }

    /// Executes `stmt`; `_tx_id` is ignored by this client.
    ///
    /// NOTE(review): each call is an independent HTTP request; whether the
    /// server keeps transaction state between requests is not visible here.
    pub async fn execute_in_transaction(&self, _tx_id: u64, stmt: Statement) -> Result<ResultSet> {
        self.execute(stmt).await
    }

    /// Sends `COMMIT` as a standalone statement; `_tx_id` is ignored.
    pub async fn commit_transaction(&self, _tx_id: u64) -> Result<()> {
        self.execute("COMMIT").await.map(|_| ())
    }

    /// Sends `ROLLBACK` as a standalone statement; `_tx_id` is ignored.
    pub async fn rollback_transaction(&self, _tx_id: u64) -> Result<()> {
        self.execute("ROLLBACK").await.map(|_| ())
    }
}