use async_trait::async_trait;
use anyhow::Result;
use super::{parse_query_result, QueryResult, Statement};
#[async_trait(?Send)]
pub trait DatabaseClient {
async fn execute(&self, stmt: impl Into<Statement>) -> Result<QueryResult> {
let mut results = self.batch(std::iter::once(stmt)).await?;
Ok(results.remove(0))
}
async fn batch(
&self,
stmts: impl IntoIterator<Item = impl Into<Statement>>,
) -> Result<Vec<QueryResult>>;
}
/// A client wrapping exactly one concrete backend, chosen when the client is built
/// (see `new_client` and `new_client_from_config`).
///
/// Every variant is compiled in only when its Cargo feature is enabled, so the set of
/// available variants depends on the build configuration.
pub enum GenericClient {
    /// Client from `super::local`; requires the `local_backend` feature.
    #[cfg(feature = "local_backend")]
    Local(super::local::Client),
    /// Client from `super::reqwest`; requires the `reqwest_backend` feature.
    #[cfg(feature = "reqwest_backend")]
    Reqwest(super::reqwest::Client),
    /// Client from `super::hrana`; requires the `hrana_backend` feature.
    #[cfg(feature = "hrana_backend")]
    Hrana(super::hrana::Client),
    /// Client from `super::workers`; requires the `workers_backend` feature.
    #[cfg(feature = "workers_backend")]
    Workers(super::workers::Client),
    /// Client from `super::spin`; requires the `spin_backend` feature.
    #[cfg(feature = "spin_backend")]
    Spin(super::spin::Client),
}
/// Forwards every batch to the backend client held by the active variant.
#[async_trait(?Send)]
impl DatabaseClient for GenericClient {
    /// Runs `stmts` on the wrapped backend and returns whatever that backend's
    /// `batch` produces, unchanged.
    async fn batch(
        &self,
        stmts: impl IntoIterator<Item = impl Into<Statement>>,
    ) -> Result<Vec<QueryResult>> {
        match *self {
            #[cfg(feature = "local_backend")]
            Self::Local(ref client) => client.batch(stmts).await,
            #[cfg(feature = "reqwest_backend")]
            Self::Reqwest(ref client) => client.batch(stmts).await,
            #[cfg(feature = "hrana_backend")]
            Self::Hrana(ref client) => client.batch(stmts).await,
            #[cfg(feature = "workers_backend")]
            Self::Workers(ref client) => client.batch(stmts).await,
            #[cfg(feature = "spin_backend")]
            Self::Spin(ref client) => client.batch(stmts).await,
        }
    }
}
/// Connection settings consumed by `new_client_from_config`.
pub struct Config {
    /// Database URL. `libsql://` / `libsqls://` URLs are rewritten to a transport
    /// scheme before use, and the resulting scheme selects the backend.
    pub url: url::Url,
    /// Optional authentication token.
    /// NOTE(review): presumably read by the backends' `from_config` constructors — confirm
    /// per backend; nothing in this module reads it directly.
    pub auth_token: Option<String>,
}
/// Rewrites a `libsql://` or `libsqls://` URL into the transport scheme of the backend
/// this crate was built with.
///
/// With the `hrana_backend` feature enabled, `libsql` becomes `ws` and `libsqls` becomes
/// `wss`; otherwise they become `http` and `https`. URLs with any other scheme are
/// returned unchanged.
///
/// Only the scheme prefix is rewritten; the remainder of the URL is kept verbatim. If the
/// rewritten text is not a valid URL (e.g. `libsql://` with no host), the input URL is
/// returned unchanged rather than panicking, leaving the caller to deal with the
/// untranslated scheme.
fn maybe_translate_url(url: url::Url) -> url::Url {
    let (insecure, secure) = if cfg!(feature = "hrana_backend") {
        ("ws", "wss")
    } else {
        ("http", "https")
    };
    let target = match url.scheme() {
        "libsql" => insecure,
        "libsqls" => secure,
        _ => return url,
    };
    // `Url::set_scheme` refuses to switch from a non-special scheme (`libsql`) to a
    // special one (`http`, `ws`, ...), so rebuild the URL from its serialization.
    // The serialization always begins with the scheme, so slicing it off is safe and
    // leaves `:` plus everything after it untouched.
    let rest = &url.as_str()[url.scheme().len()..];
    match url::Url::parse(&format!("{target}{rest}")) {
        Ok(translated) => translated,
        Err(_) => url,
    }
}
/// Builds a [`GenericClient`] from an explicit [`Config`].
///
/// `libsql://` / `libsqls://` URLs are first rewritten to a transport scheme; the scheme
/// then picks the backend: `file` → local, `ws`/`wss` → Hrana, `http`/`https` → reqwest,
/// `workers` → Workers, `spin` → Spin. Each mapping exists only when the corresponding
/// backend feature is enabled.
///
/// # Errors
///
/// Fails when no enabled backend handles the scheme, or when the local, Hrana or reqwest
/// constructor returns an error.
pub async fn new_client_from_config(mut config: Config) -> anyhow::Result<GenericClient> {
    config.url = maybe_translate_url(config.url);
    let scheme = config.url.scheme();
    let client = match scheme {
        #[cfg(feature = "local_backend")]
        "file" => GenericClient::Local(super::local::Client::new(config.url.to_string())?),
        #[cfg(feature = "hrana_backend")]
        "ws" | "wss" => GenericClient::Hrana(super::hrana::Client::from_config(config).await?),
        #[cfg(feature = "reqwest_backend")]
        "http" | "https" => GenericClient::Reqwest(super::reqwest::Client::from_config(config)?),
        #[cfg(feature = "workers_backend")]
        "workers" => GenericClient::Workers(super::workers::Client::from_config(config)),
        #[cfg(feature = "spin_backend")]
        "spin" => GenericClient::Spin(super::spin::Client::from_config(config)),
        unsupported => anyhow::bail!(
            "Unknown scheme: {unsupported}. Make sure your backend exists and is enabled with its feature flag"
        ),
    };
    Ok(client)
}
/// Creates a client from the `LIBSQL_CLIENT_URL` environment variable.
///
/// `libsql://` / `libsqls://` URLs are first rewritten to a transport scheme. The backend
/// is taken from `LIBSQL_CLIENT_BACKEND` when it is set; otherwise it is inferred from the
/// scheme: `ws`/`wss` → `hrana` (if that feature is enabled), `http`/`https` → the first
/// enabled of `reqwest`, `workers`, `spin` (falling back to `local`), and `local` for any
/// other scheme.
///
/// If `LIBSQL_CLIENT_URL` does not parse as a URL (for example a bare filesystem path) and
/// the `local_backend` feature is enabled, the value is opened with the local backend.
///
/// # Errors
///
/// Fails if `LIBSQL_CLIENT_URL` is unset; if it is not a valid URL and the local backend
/// is not compiled in; if the selected backend is unknown or not enabled; if the `workers`
/// or `spin` backend is selected (those must be constructed directly); or if the backend
/// constructor fails.
pub async fn new_client() -> anyhow::Result<GenericClient> {
    let url = std::env::var("LIBSQL_CLIENT_URL").map_err(|_| {
        anyhow::anyhow!("LIBSQL_CLIENT_URL variable should point to your libSQL/sqld database")
    })?;
    let url = match url::Url::parse(&url) {
        Ok(url) => url,
        // Not a URL (e.g. `/tmp/data.db`): treat the raw value as a local database path.
        // This must be an attribute-level `#[cfg]`, not a `cfg!` guard: `cfg!` leaves the
        // arm compiled, and `GenericClient::Local` does not exist without the feature.
        #[cfg(feature = "local_backend")]
        Err(_) => return Ok(GenericClient::Local(super::local::Client::new(url)?)),
        #[cfg(not(feature = "local_backend"))]
        Err(e) => return Err(e.into()),
    };
    let url = maybe_translate_url(url);
    let scheme = url.scheme();
    let backend = std::env::var("LIBSQL_CLIENT_BACKEND").unwrap_or_else(|_| {
        match scheme {
            "ws" | "wss" if cfg!(feature = "hrana_backend") => "hrana",
            "http" | "https" => {
                if cfg!(feature = "reqwest_backend") {
                    "reqwest"
                } else if cfg!(feature = "workers_backend") {
                    "workers"
                } else if cfg!(feature = "spin_backend") {
                    "spin"
                } else {
                    "local"
                }
            }
            _ => "local",
        }
        .to_string()
    });
    Ok(match backend.as_str() {
        #[cfg(feature = "local_backend")]
        "local" => GenericClient::Local(super::local::Client::new(url.as_str())?),
        #[cfg(feature = "reqwest_backend")]
        "reqwest" => GenericClient::Reqwest(super::reqwest::Client::from_url(url.as_str())?),
        #[cfg(feature = "hrana_backend")]
        "hrana" => GenericClient::Hrana(super::hrana::Client::new(url.as_str(), "").await?),
        #[cfg(feature = "workers_backend")]
        "workers" => {
            anyhow::bail!("Connecting from workers API may need access to worker::RouteContext. Please call libsql_client::workers::Client::connect_from_ctx() directly")
        }
        #[cfg(feature = "spin_backend")]
        "spin" => {
            anyhow::bail!("Connecting from spin API may need access to specific Spin SDK secrets. Please call libsql_client::spin::Client::connect_from_url() directly")
        }
        _ => anyhow::bail!("Unknown backend: {backend}. Make sure your backend exists and is enabled with its feature flag"),
    })
}
pub(crate) fn statements_to_string(
stmts: impl IntoIterator<Item = impl Into<Statement>>,
) -> (String, usize) {
let mut body = "{\"statements\": [".to_string();
let mut stmts_count = 0;
for stmt in stmts {
body += &format!("{},", stmt.into());
stmts_count += 1;
}
if stmts_count > 0 {
body.pop();
}
body += "]}";
(body, stmts_count)
}
pub(crate) fn json_to_query_result(
response_json: serde_json::Value,
stmts_count: usize,
) -> anyhow::Result<Vec<QueryResult>> {
match response_json {
serde_json::Value::Array(results) => {
if results.len() != stmts_count {
Err(anyhow::anyhow!(
"Response array did not contain expected {stmts_count} results"
))
} else {
let mut query_results: Vec<QueryResult> = Vec::with_capacity(stmts_count);
for (idx, result) in results.into_iter().enumerate() {
query_results
.push(parse_query_result(result, idx).map_err(|e| anyhow::anyhow!("{e}"))?);
}
Ok(query_results)
}
}
e => Err(anyhow::anyhow!("Error: {}", e)),
}
}