use crate::{proto, proto::StmtResult, BatchResult, Col, ResultSet, Statement, Value};
use anyhow::Result;
use std::sync::{Arc, Mutex};
use rusqlite::types::Value as RusqliteValue;
/// Client backed by a `rusqlite` connection to a database file or an
/// in-memory database.
///
/// The connection is stored behind an `Arc<Mutex<..>>` and is locked by the
/// methods that run SQL.
#[derive(Debug)]
pub struct Client {
// Mutex-guarded `rusqlite` connection that every statement is executed on.
inner: Arc<Mutex<rusqlite::Connection>>,
}
/// Private newtype around the crate's `Value`; the `From` impls that convert
/// to and from `rusqlite`'s value type are defined on this wrapper.
struct ValueWrapper(Value);
impl From<ValueWrapper> for RusqliteValue {
fn from(v: ValueWrapper) -> Self {
match v.0 {
Value::Null => RusqliteValue::Null,
Value::Integer { value: n } => RusqliteValue::Integer(n),
Value::Text { value: s } => RusqliteValue::Text(s),
Value::Float { value: d } => RusqliteValue::Real(d),
Value::Blob { value: b } => RusqliteValue::Blob(b),
}
}
}
impl From<RusqliteValue> for ValueWrapper {
fn from(v: RusqliteValue) -> Self {
match v {
RusqliteValue::Null => ValueWrapper(Value::Null),
RusqliteValue::Integer(n) => ValueWrapper(Value::Integer { value: n }),
RusqliteValue::Text(s) => ValueWrapper(Value::Text { value: s }),
RusqliteValue::Real(d) => ValueWrapper(Value::Float { value: d }),
RusqliteValue::Blob(b) => ValueWrapper(Value::Blob { value: b }),
}
}
}
impl Client {
    /// Opens (or creates) the SQLite database at `path`.
    ///
    /// # Errors
    /// Returns an error carrying the `rusqlite` message if the database
    /// cannot be opened.
    pub fn new(path: impl AsRef<std::path::Path>) -> anyhow::Result<Self> {
        let conn = rusqlite::Connection::open(path).map_err(|e| anyhow::anyhow!("{e}"))?;
        Ok(Self {
            inner: Arc::new(Mutex::new(conn)),
        })
    }

    /// Opens a database using the special `:memory:` path.
    ///
    /// # Errors
    /// Same as [`Client::new`].
    pub fn in_memory() -> anyhow::Result<Self> {
        Self::new(":memory:")
    }

    /// Builds a client from the `LIBSQL_CLIENT_URL` environment variable.
    ///
    /// The value must start with `file:///`; that prefix is stripped and the
    /// remainder is handed to [`Client::new`] unchanged.
    ///
    /// # Errors
    /// Fails if the variable is unset or not valid Unicode, if it lacks the
    /// `file:///` prefix, or if the database cannot be opened.
    pub fn from_env() -> anyhow::Result<Self> {
        let url = std::env::var("LIBSQL_CLIENT_URL").map_err(|_| {
            anyhow::anyhow!("LIBSQL_CLIENT_URL variable should point to your sqld database")
        })?;
        let Some(path) = url.strip_prefix("file:///") else {
            anyhow::bail!("Local URL needs to start with file:///");
        };
        Self::new(path)
    }

    /// Locks the shared connection, reporting a poisoned mutex as an error
    /// instead of panicking.
    fn lock_connection(&self) -> anyhow::Result<std::sync::MutexGuard<'_, rusqlite::Connection>> {
        self.inner.lock().map_err(|_| {
            anyhow::anyhow!("SQLite connection mutex was poisoned by a panicking thread")
        })
    }

    /// Runs `stmts` in order on an already-locked connection.
    ///
    /// Shared by [`Client::raw_batch`] and [`Client::batch`] so that a whole
    /// batch executes under a single lock acquisition.
    fn run_batch(
        conn: &rusqlite::Connection,
        stmts: impl IntoIterator<Item = impl Into<Statement>>,
    ) -> anyhow::Result<BatchResult> {
        let mut step_results = Vec::new();
        let mut step_errors = Vec::new();
        for stmt in stmts {
            let stmt: Statement = stmt.into();
            let params = rusqlite::params_from_iter(
                stmt.args
                    .into_iter()
                    .map(ValueWrapper)
                    .map(RusqliteValue::from),
            );
            let mut prepared = conn.prepare(&stmt.sql)?;
            // Column names are copied out before `query` takes the mutable borrow.
            let cols: Vec<Col> = prepared
                .columns()
                .into_iter()
                .map(|c| Col {
                    name: Some(c.name().to_string()),
                })
                .collect();
            let mut input_rows = match prepared.query(params) {
                Ok(rows) => rows,
                Err(e) => {
                    // Record the failure for this step and stop: later
                    // statements are not executed.
                    step_results.push(None);
                    step_errors.push(Some(proto::Error {
                        message: e.to_string(),
                    }));
                    break;
                }
            };
            let mut rows = Vec::new();
            while let Some(row) = input_rows.next()? {
                // A cell that cannot be read is surfaced as an error rather
                // than a panic.
                let cells = (0..cols.len())
                    .map(|i| {
                        row.get::<usize, RusqliteValue>(i)
                            .map(|v| ValueWrapper::from(v).0)
                    })
                    .collect::<Result<Vec<_>, rusqlite::Error>>()?;
                rows.push(cells);
            }
            step_results.push(Some(StmtResult {
                cols,
                rows,
                affected_row_count: 0,
                last_insert_rowid: None,
            }));
            step_errors.push(None);
        }
        Ok(BatchResult {
            step_results,
            step_errors,
        })
    }

    /// Turns the raw outcome of a `BEGIN … END` batch into the result sets of
    /// the caller's statements.
    fn into_result_sets(batch_results: BatchResult) -> Result<Vec<ResultSet>> {
        // Every step is inspected, including the implicit `BEGIN`, so a
        // failure there cannot be mistaken for an empty successful batch.
        if let Some(error) = batch_results.step_errors.into_iter().flatten().next() {
            return Err(anyhow::anyhow!(error.message));
        }
        let mut steps = batch_results.step_results;
        // Drop the result of the trailing `END` ...
        steps.pop();
        steps
            .into_iter()
            // ... and the result of the leading `BEGIN`.
            .skip(1)
            .map(|maybe_rs| {
                maybe_rs
                    .map(ResultSet::from)
                    .ok_or_else(|| anyhow::anyhow!("Unexpected missing result set"))
            })
            .collect()
    }

    /// Executes `stmts` in order and reports a per-step outcome.
    ///
    /// The connection is locked once for the whole call (and therefore while
    /// `stmts` is being iterated). For each executed statement one entry is
    /// pushed to both `step_results` and `step_errors`. If `query` rejects a
    /// statement, its error is recorded and the remaining statements are
    /// skipped. `affected_row_count` is always `0` and `last_insert_rowid` is
    /// always `None` in the produced results.
    ///
    /// # Errors
    /// Errors raised while preparing a statement or while stepping through
    /// its rows abort the call and are returned as `Err`, as is a poisoned
    /// connection mutex.
    pub fn raw_batch(
        &self,
        stmts: impl IntoIterator<Item = impl Into<Statement>>,
    ) -> anyhow::Result<BatchResult> {
        let conn = self.lock_connection()?;
        Self::run_batch(&conn, stmts)
    }

    /// Executes `stmts` wrapped in `BEGIN` / `END` and returns one
    /// [`ResultSet`] per caller-supplied statement.
    ///
    /// The connection stays locked for the whole batch. If anything fails
    /// after this call opened the transaction, the transaction is rolled
    /// back before the error is returned, so the connection is not left
    /// inside a half-finished transaction.
    ///
    /// # Errors
    /// Returns the first error produced by any step (including the implicit
    /// `BEGIN` and `END`).
    pub fn batch(
        &self,
        stmts: impl IntoIterator<Item = impl Into<Statement> + Send> + Send,
    ) -> Result<Vec<ResultSet>> {
        let conn = self.lock_connection()?;
        // Remember whether a transaction was already open so that we only
        // ever roll back a transaction started by this batch.
        let was_autocommit = conn.is_autocommit();
        let outcome = Self::run_batch(
            &conn,
            std::iter::once(Statement::new("BEGIN"))
                .chain(stmts.into_iter().map(|s| s.into()))
                .chain(std::iter::once(Statement::new("END"))),
        )
        .and_then(Self::into_result_sets);
        if outcome.is_err() && was_autocommit && !conn.is_autocommit() {
            // Best effort: the original error is more useful to the caller
            // than a secondary rollback failure.
            let _ = conn.execute_batch("ROLLBACK");
        }
        outcome
    }

    /// Executes a single statement and returns its result set.
    ///
    /// # Errors
    /// Returns the statement's error, whether it was recorded as a step
    /// error or propagated by [`Client::raw_batch`].
    pub fn execute(&self, stmt: impl Into<Statement> + Send) -> Result<ResultSet> {
        let results = self.raw_batch(std::iter::once(stmt))?;
        let first_result = results.step_results.into_iter().next();
        let first_error = results.step_errors.into_iter().next();
        match (first_result, first_error) {
            (Some(Some(result)), Some(None)) => Ok(ResultSet::from(result)),
            (Some(None), Some(Some(err))) => Err(anyhow::anyhow!(err.message)),
            _ => unreachable!(
                "raw_batch records exactly one result/error pair per executed statement"
            ),
        }
    }

    /// Executes `stmt` on the connection; `_tx_id` is ignored.
    pub fn execute_in_transaction(&self, _tx_id: u64, stmt: Statement) -> Result<ResultSet> {
        self.execute(stmt)
    }

    /// Executes `COMMIT` on the connection; `_tx_id` is ignored.
    pub fn commit_transaction(&self, _tx_id: u64) -> Result<()> {
        self.execute("COMMIT").map(|_| ())
    }

    /// Executes `ROLLBACK` on the connection; `_tx_id` is ignored.
    pub fn rollback_transaction(&self, _tx_id: u64) -> Result<()> {
        self.execute("ROLLBACK").map(|_| ())
    }
}