use anyhow::{anyhow, Context, Result};
use base64::{prelude::BASE64_STANDARD_NO_PAD, Engine};
use crate::{proto, BatchResult, Col, ResultSet, Statement, Transaction, Value};
/// Process-wide counter, starting at 0, from which `Client::transaction`
/// draws the numeric id of each new transaction via `fetch_add`.
static TRANSACTION_IDS: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
/// Database client that wraps exactly one backend-specific client.
///
/// Every variant is compiled in only when the Cargo feature named in its
/// `cfg` attribute is enabled. The URL schemes mentioned below are the ones
/// `new_client_from_config` maps to each variant.
pub enum Client {
/// Backend selected for `file` URLs.
#[cfg(feature = "local_backend")]
Local(crate::local::Client),
/// Backend selected for `http` and `https` URLs.
#[cfg(feature = "reqwest_backend")]
Reqwest(crate::reqwest::Client),
/// Backend selected for `ws`, `wss` and `libsql` URLs.
#[cfg(feature = "hrana_backend")]
Hrana(crate::hrana::Client),
/// Backend selected for `workers` URLs.
#[cfg(feature = "workers_backend")]
Workers(crate::workers::Client),
/// Backend selected for `spin` URLs.
#[cfg(feature = "spin_backend")]
Spin(crate::spin::Client),
}
// SAFETY: NOTE(review): this asserts that every backend client wrapped by
// `Client` can be moved to another thread. Nothing visible in this file
// establishes that invariant; confirm it for each enabled backend and
// record the reasoning here — TODO.
unsafe impl Send for Client {}
impl Client {
/// Forwards `stmts` to the backend held by this client and returns its
/// [`BatchResult`] untouched (per-step results and per-step errors).
///
/// No statements are added or removed here; see `batch` for the variant
/// that wraps the statements in a transaction. The local and spin backends
/// are called synchronously, all other backends are awaited.
pub async fn raw_batch(
    &self,
    stmts: impl IntoIterator<Item = impl Into<Statement> + Send> + Send,
) -> Result<BatchResult> {
    match self {
        #[cfg(feature = "local_backend")]
        Client::Local(local) => local.raw_batch(stmts),
        #[cfg(feature = "reqwest_backend")]
        Client::Reqwest(http) => http.raw_batch(stmts).await,
        #[cfg(feature = "hrana_backend")]
        Client::Hrana(hrana) => hrana.raw_batch(stmts).await,
        #[cfg(feature = "workers_backend")]
        Client::Workers(workers) => workers.raw_batch(stmts).await,
        #[cfg(feature = "spin_backend")]
        Client::Spin(spin) => spin.raw_batch(stmts),
    }
}
pub async fn batch<I: IntoIterator<Item = impl Into<Statement> + Send> + Send>(
&self,
stmts: I,
) -> Result<Vec<ResultSet>>
where
<I as IntoIterator>::IntoIter: Send,
{
let batch_results = self
.raw_batch(
std::iter::once(Statement::new("BEGIN"))
.chain(stmts.into_iter().map(|s| s.into()))
.chain(std::iter::once(Statement::new("END"))),
)
.await?;
let step_error: Option<proto::Error> = batch_results
.step_errors
.into_iter()
.skip(1)
.find(|e| e.is_some())
.flatten();
if let Some(error) = step_error {
return Err(anyhow::anyhow!(error.message));
}
let mut step_results: Vec<Result<ResultSet>> = batch_results
.step_results
.into_iter()
.skip(1) .map(|maybe_rs| {
maybe_rs
.map(ResultSet::from)
.ok_or_else(|| anyhow!("Unexpected missing result set"))
})
.collect();
step_results.pop(); step_results.into_iter().collect::<Result<Vec<ResultSet>>>()
}
/// Runs a single statement on the backend held by this client and returns
/// its [`ResultSet`].
///
/// The local and spin backends are called synchronously, all other
/// backends are awaited.
pub async fn execute(&self, stmt: impl Into<Statement> + Send) -> Result<ResultSet> {
    match self {
        #[cfg(feature = "local_backend")]
        Client::Local(local) => local.execute(stmt),
        #[cfg(feature = "reqwest_backend")]
        Client::Reqwest(http) => http.execute(stmt).await,
        #[cfg(feature = "hrana_backend")]
        Client::Hrana(hrana) => hrana.execute(stmt).await,
        #[cfg(feature = "workers_backend")]
        Client::Workers(workers) => workers.execute(stmt).await,
        #[cfg(feature = "spin_backend")]
        Client::Spin(spin) => spin.execute(stmt),
    }
}
/// Starts an interactive transaction on this client.
///
/// A fresh id is taken from the global `TRANSACTION_IDS` counter on every
/// call, including calls that end up failing because the backend is
/// unsupported.
///
/// # Errors
///
/// Returns an error for the reqwest and spin backends, which do not support
/// interactive transactions, and propagates any error from
/// `Transaction::new` for the other backends.
pub async fn transaction(&self) -> Result<Transaction> {
    // Relaxed is sufficient: the counter only has to hand out distinct
    // values, it does not order any other memory access.
    let id = TRANSACTION_IDS.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    match self {
        #[cfg(feature = "local_backend")]
        Self::Local(_) => Transaction::new(self, id).await,
        #[cfg(feature = "hrana_backend")]
        Self::Hrana(_) => Transaction::new(self, id).await,
        #[cfg(feature = "reqwest_backend")]
        Self::Reqwest(_) => {
            anyhow::bail!("Interactive transactions are not supported with the reqwest backend. Use batch() instead.")
        }
        #[cfg(feature = "workers_backend")]
        Self::Workers(_) => Transaction::new(self, id).await,
        #[cfg(feature = "spin_backend")]
        Self::Spin(_) => {
            anyhow::bail!("Interactive transactions are not supported with the spin backend. Use batch() instead.")
        }
    }
}
/// Executes `stmt` inside the transaction identified by `tx_id`.
///
/// # Errors
///
/// Returns an error for the reqwest and spin backends, which have no
/// interactive transactions. This method is public, so it can be called on
/// such a client directly; that case is reported as an error rather than a
/// panic. Errors from the other backends are propagated unchanged.
pub async fn execute_in_transaction(&self, tx_id: u64, stmt: Statement) -> Result<ResultSet> {
    match self {
        #[cfg(feature = "local_backend")]
        Self::Local(l) => l.execute_in_transaction(tx_id, stmt),
        #[cfg(feature = "reqwest_backend")]
        Self::Reqwest(_) => {
            anyhow::bail!("Interactive transactions are not supported with the reqwest backend. Use batch() instead.")
        }
        #[cfg(feature = "hrana_backend")]
        Self::Hrana(h) => h.execute_in_transaction(tx_id, stmt).await,
        #[cfg(feature = "workers_backend")]
        Self::Workers(w) => w.execute_in_transaction(tx_id, stmt).await,
        #[cfg(feature = "spin_backend")]
        Self::Spin(_) => {
            anyhow::bail!("Interactive transactions are not supported with the spin backend. Use batch() instead.")
        }
    }
}
/// Commits the transaction identified by `tx_id`.
///
/// # Errors
///
/// Returns an error for the reqwest and spin backends, which have no
/// interactive transactions. This method is public, so it can be called on
/// such a client directly; that case is reported as an error rather than a
/// panic. Errors from the other backends are propagated unchanged.
pub async fn commit_transaction(&self, tx_id: u64) -> Result<()> {
    match self {
        #[cfg(feature = "local_backend")]
        Self::Local(l) => l.commit_transaction(tx_id),
        #[cfg(feature = "reqwest_backend")]
        Self::Reqwest(_) => {
            anyhow::bail!("Interactive transactions are not supported with the reqwest backend. Use batch() instead.")
        }
        #[cfg(feature = "hrana_backend")]
        Self::Hrana(h) => h.commit_transaction(tx_id).await,
        #[cfg(feature = "workers_backend")]
        Self::Workers(w) => w.commit_transaction(tx_id).await,
        #[cfg(feature = "spin_backend")]
        Self::Spin(_) => {
            anyhow::bail!("Interactive transactions are not supported with the spin backend. Use batch() instead.")
        }
    }
}
/// Rolls back the transaction identified by `tx_id`.
///
/// # Errors
///
/// Returns an error for the reqwest and spin backends, which have no
/// interactive transactions. This method is public, so it can be called on
/// such a client directly; that case is reported as an error rather than a
/// panic. Errors from the other backends are propagated unchanged.
pub async fn rollback_transaction(&self, tx_id: u64) -> Result<()> {
    match self {
        #[cfg(feature = "local_backend")]
        Self::Local(l) => l.rollback_transaction(tx_id),
        #[cfg(feature = "reqwest_backend")]
        Self::Reqwest(_) => {
            anyhow::bail!("Interactive transactions are not supported with the reqwest backend. Use batch() instead.")
        }
        #[cfg(feature = "hrana_backend")]
        Self::Hrana(h) => h.rollback_transaction(tx_id).await,
        #[cfg(feature = "workers_backend")]
        Self::Workers(w) => w.rollback_transaction(tx_id).await,
        #[cfg(feature = "spin_backend")]
        Self::Spin(_) => {
            anyhow::bail!("Interactive transactions are not supported with the spin backend. Use batch() instead.")
        }
    }
}
}
/// Connection settings consumed by `new_client_from_config`.
pub struct Config {
/// Database URL; its scheme decides which backend client is created.
pub url: url::Url,
/// Optional authentication token. `new_client` fills it from the
/// `LIBSQL_CLIENT_TOKEN` environment variable when that variable is set.
pub auth_token: Option<String>,
}
/// Builds a [`Client`] for the backend selected by the scheme of
/// `config.url`.
///
/// Scheme to backend mapping (each arm exists only when its feature is
/// enabled):
///
/// * `file` → local backend
/// * `ws`, `wss` → hrana backend
/// * `libsql` → hrana backend, after rewriting the URL to `wss://`
/// * `http`, `https` → reqwest backend
/// * `workers` → workers backend
/// * `spin` → spin backend
///
/// # Errors
///
/// Returns an error when the scheme is unknown (or its backend feature is
/// disabled), when a `libsql://` URL cannot be expressed as a `wss://` URL,
/// or when the backend constructor itself fails.
pub async fn new_client_from_config<'a>(config: Config) -> anyhow::Result<Client> {
    let scheme = config.url.scheme();
    Ok(match scheme {
        #[cfg(feature = "local_backend")]
        "file" => {
            Client::Local(crate::local::Client::new(config.url.to_string())?)
        },
        #[cfg(feature = "hrana_backend")]
        "ws" | "wss" => {
            Client::Hrana(crate::hrana::Client::from_config(config).await?)
        },
        #[cfg(feature = "hrana_backend")]
        "libsql" => {
            let mut config = config;
            // Swap only the leading scheme; a `libsql://` substring further
            // inside the URL (e.g. in the query) must stay as it is.
            if let Some(rest) = config.url.as_str().strip_prefix("libsql://") {
                // `wss` follows stricter host rules than a custom scheme, so
                // the re-parse can fail on caller-supplied input; report that
                // as an error instead of panicking.
                let wss_url = url::Url::parse(&format!("wss://{rest}"))
                    .context("Failed to convert libsql:// URL into a wss:// URL")?;
                config.url = wss_url;
            }
            Client::Hrana(crate::hrana::Client::from_config(config).await?)
        }
        #[cfg(feature = "reqwest_backend")]
        "http" | "https" => {
            Client::Reqwest(crate::reqwest::Client::from_config(config)?)
        },
        #[cfg(feature = "workers_backend")]
        "workers" => {
            Client::Workers(crate::workers::Client::from_config(config).await.map_err(|e| anyhow::anyhow!("{}", e))?)
        },
        #[cfg(feature = "spin_backend")]
        "spin" => {
            Client::Spin(crate::spin::Client::from_config(config))
        },
        _ => anyhow::bail!("Unknown scheme: {scheme}. Make sure your backend exists and is enabled with its feature flag"),
    })
}
/// Builds a [`Client`] from environment variables.
///
/// `LIBSQL_CLIENT_URL` is mandatory and must hold a parseable URL;
/// `LIBSQL_CLIENT_TOKEN` is optional and becomes `Config::auth_token`.
///
/// # Errors
///
/// Fails when `LIBSQL_CLIENT_URL` cannot be read, when its value is not a
/// valid URL, or when `new_client_from_config` rejects the configuration.
pub async fn new_client() -> anyhow::Result<Client> {
    let raw_url = match std::env::var("LIBSQL_CLIENT_URL") {
        Ok(value) => value,
        Err(_) => {
            return Err(anyhow::anyhow!(
                "LIBSQL_CLIENT_URL variable should point to your libSQL/sqld database"
            ))
        }
    };
    let auth_token = std::env::var("LIBSQL_CLIENT_TOKEN").ok();
    let url = url::Url::parse(&raw_url)?;
    let config = Config { url, auth_token };
    new_client_from_config(config).await
}
pub(crate) fn statements_to_string(
stmts: impl IntoIterator<Item = impl Into<Statement>>,
) -> (String, usize) {
let mut body = "{\"statements\": [".to_string();
let mut stmts_count = 0;
for stmt in stmts {
body += &format!("{},", stmt.into());
stmts_count += 1;
}
if stmts_count > 0 {
body.pop();
}
body += "]}";
(body, stmts_count)
}
/// Converts the JSON column list of result number `result_idx` into
/// [`Col`] values.
///
/// Every entry has to be a JSON string, which becomes the column name.
///
/// # Errors
///
/// Fails on the first entry that is not a string, naming the result and
/// the column position in the message.
pub(crate) fn parse_columns(
    columns: Vec<serde_json::Value>,
    result_idx: usize,
) -> Result<Vec<Col>> {
    columns
        .into_iter()
        .enumerate()
        .map(|(idx, column)| {
            if let serde_json::Value::String(name) = column {
                Ok(Col { name: Some(name) })
            } else {
                Err(anyhow!(format!(
                    "Result {result_idx} column name {idx} not a string",
                )))
            }
        })
        .collect()
}
/// Converts one JSON cell into a [`Value`].
///
/// * `null` → `Value::Null`
/// * number → `Value::Integer` when it fits an `i64`, otherwise
///   `Value::Float` when it is representable as `f64`
/// * string → `Value::Text`
/// * object with a string `base64` member → `Value::Blob`, decoded as
///   unpadded standard base64
///
/// `result_idx`, `row_idx` and `cell_idx` are only used to build error
/// messages.
///
/// # Errors
///
/// Fails for any other JSON type, for a number that is neither `i64` nor
/// `f64`, for an object lacking a string `base64` member, and for invalid
/// base64 data.
pub(crate) fn parse_value(
    cell: serde_json::Value,
    result_idx: usize,
    row_idx: usize,
    cell_idx: usize,
) -> Result<Value> {
    let parsed = match cell {
        serde_json::Value::Null => Value::Null,
        serde_json::Value::String(text) => Value::Text { value: text },
        serde_json::Value::Number(v) => {
            // Integers take precedence; floats are the fallback.
            if let Some(int) = v.as_i64() {
                Value::Integer { value: int }
            } else if let Some(float) = v.as_f64() {
                Value::Float { value: float }
            } else {
                return Err(anyhow!(
                    "Result {result_idx} row {row_idx} cell {cell_idx} had unknown number value: {v}",
                ));
            }
        }
        serde_json::Value::Object(fields) => {
            let base64_field = fields.get("base64").with_context(|| {
                format!("Result {result_idx} row {row_idx} cell {cell_idx} had unknown object, expected base64 field")
            })?;
            let base64_string = base64_field.as_str().with_context(|| {
                format!("Result {result_idx} row {row_idx} cell {cell_idx} had empty base64 field: {base64_field}")
            })?;
            Value::Blob {
                value: BASE64_STANDARD_NO_PAD.decode(base64_string)?,
            }
        }
        _ => {
            return Err(anyhow!(
                "Result {result_idx} row {row_idx} cell {cell_idx} had unknown type",
            ))
        }
    };
    Ok(parsed)
}
/// Converts the JSON rows of result number `result_idx` into rows of
/// [`Value`]s.
///
/// Each row must be a JSON array holding exactly `cols_len` cells; the
/// cells are converted with `parse_value`.
///
/// # Errors
///
/// Stops at the first offending row: a row that is not an array, a row
/// whose length differs from `cols_len`, or a cell `parse_value` rejects.
pub(crate) fn parse_rows(
    rows: Vec<serde_json::Value>,
    cols_len: usize,
    result_idx: usize,
) -> Result<Vec<Vec<Value>>> {
    rows.into_iter()
        .enumerate()
        .map(|(idx, row)| {
            let cells = match row {
                serde_json::Value::Array(cells) => cells,
                _ => return Err(anyhow!("Result {result_idx} row {idx} was not an array",)),
            };
            // Length is validated before any cell is parsed.
            if cells.len() != cols_len {
                return Err(anyhow!(
                    "Result {result_idx} row {idx} had wrong number of cells",
                ));
            }
            cells
                .into_iter()
                .enumerate()
                .map(|(cell_idx, cell)| parse_value(cell, result_idx, idx, cell_idx))
                .collect::<Result<Vec<Value>>>()
        })
        .collect()
}
pub(crate) fn parse_query_result(
result: serde_json::Value,
idx: usize,
) -> Result<(Option<proto::StmtResult>, Option<proto::Error>)> {
match result {
serde_json::Value::Object(obj) => {
if let Some(err) = obj.get("error") {
return match err {
serde_json::Value::Object(obj) => match obj.get("message") {
Some(serde_json::Value::String(msg)) => Ok((
None,
Some(proto::Error {
message: msg.clone(),
}),
)),
_ => Err(anyhow!("Result {idx} error message was not a string",)),
},
_ => Err(anyhow!("Result {idx} results was not an object",)),
};
}
let results = obj.get("results");
match results {
Some(serde_json::Value::Object(obj)) => {
let columns = obj
.get("columns")
.ok_or_else(|| anyhow!(format!("Result {idx} had no columns")))?;
let rows = obj
.get("rows")
.ok_or_else(|| anyhow!(format!("Result {idx} had no rows")))?;
match (rows, columns) {
(serde_json::Value::Array(rows), serde_json::Value::Array(columns)) => {
let cols = parse_columns(columns.to_vec(), idx)?;
let rows = parse_rows(rows.to_vec(), columns.len(), idx)?;
let result_set = proto::StmtResult {
cols,
rows,
affected_row_count: 0,
last_insert_rowid: None,
};
Ok((Some(result_set), None))
}
_ => Err(anyhow!(
"Result {idx} had rows or columns that were not an array",
)),
}
}
Some(_) => Err(anyhow!("Result {idx} was not an object",)),
None => Err(anyhow!("Result {idx} did not contain results or error",)),
}
}
_ => Err(anyhow!("Result {idx} was not an object",)),
}
}
pub(crate) fn http_json_to_batch_result(
response_json: serde_json::Value,
stmts_count: usize,
) -> anyhow::Result<BatchResult> {
match response_json {
serde_json::Value::Array(results) => {
if results.len() != stmts_count {
return Err(anyhow::anyhow!(
"Response array did not contain expected {stmts_count} results"
));
}
let mut step_results: Vec<Option<proto::StmtResult>> = Vec::with_capacity(stmts_count);
let mut step_errors: Vec<Option<proto::Error>> = Vec::with_capacity(stmts_count);
for (idx, result) in results.into_iter().enumerate() {
let (step_result, step_error) =
parse_query_result(result, idx).map_err(|e| anyhow::anyhow!("{e}"))?;
step_results.push(step_result);
step_errors.push(step_error);
}
Ok(BatchResult {
step_results,
step_errors,
})
}
e => Err(anyhow::anyhow!("Error: {}", e)),
}
}