use futures::lock::Mutex;
use futures::FutureExt;
use indexmap::IndexMap;
use libsql::{params::Params, Builder as LibsqlBuilder, Connection, Database, Value};
use serde_json::Value as JsonValue;
use std::collections::HashMap;
use std::panic::AssertUnwindSafe;
use std::path::{Component, Path, PathBuf};
use std::sync::Arc;
use crate::decode;
use crate::error::Error;
use crate::models::{EncryptionConfig, QueryResult};
/// An opened libsql database paired with the connection used to run statements on it.
pub struct DbConnection {
// Connection obtained from `db.connect()`; `execute`, `select`, `batch` and `close` all go through it.
conn: Connection,
// The database handle the connection was created from; retained so `sync` can be invoked on it.
db: Database,
}
impl DbConnection {
pub async fn connect(
path: &str,
encryption: Option<EncryptionConfig>,
base_path: PathBuf,
sync_url: Option<String>,
auth_token: Option<String>,
) -> Result<Self, Error> {
let path = path.to_string();
let db = AssertUnwindSafe(async move {
if let Some(url) = sync_url {
let full_path = Self::resolve_local_path(&path, &base_path)?;
Self::open_replica(full_path, url, auth_token.unwrap_or_default(), encryption).await
} else if path.starts_with("libsql://") || path.starts_with("https://") {
Self::open_remote(path, auth_token.unwrap_or_default()).await
} else {
let full_path = Self::resolve_local_path(&path, &base_path)?;
Self::open_local(full_path, encryption).await
}
})
.catch_unwind()
.await
.map_err(|_| {
Error::InvalidDbUrl(
"libsql panicked building the database — check your URL format \
(expected libsql://… or https://…)"
.into(),
)
})??;
let conn = db.connect()?;
Ok(Self { conn, db })
}
/// Resolves a user-supplied database path to the file-system path that should be opened.
///
/// * An optional `sqlite:` prefix is stripped.
/// * `:memory:` is returned unchanged.
/// * Absolute paths are returned unchanged; they are not confined to `base_path`.
/// * Relative paths are joined onto `base_path` and their `.` / `..` components are
///   collapsed lexically (no file-system access, symlinks are not resolved).
///
/// # Errors
/// Returns `Error::InvalidDbUrl` when a relative path, once collapsed, does not lie
/// inside `base_path`.
fn resolve_local_path(path: &str, base_path: &Path) -> Result<PathBuf, Error> {
    // Collapses `.` and `..` lexically. A `..` cancels a preceding normal component,
    // is a no-op directly under the root, and is kept when there is nothing to cancel
    // so that an upward escape stays visible to the containment check below.
    fn normalise(raw: &Path) -> PathBuf {
        let mut out = PathBuf::new();
        for component in raw.components() {
            match component {
                Component::CurDir => {}
                Component::ParentDir => {
                    let tail = out.components().next_back();
                    let cancels = matches!(tail, Some(Component::Normal(_)));
                    let at_root = matches!(tail, Some(Component::RootDir));
                    if cancels {
                        out.pop();
                    } else if !at_root {
                        out.push(component);
                    }
                }
                other => out.push(other),
            }
        }
        out
    }

    let db_path = path.strip_prefix("sqlite:").unwrap_or(path);
    if db_path == ":memory:" {
        return Ok(PathBuf::from(":memory:"));
    }

    let requested = Path::new(db_path);
    if requested.is_absolute() {
        return Ok(requested.to_path_buf());
    }

    // The base is normalised with the same routine as the joined path; comparing the
    // normalised result against the raw base would reject every path whenever the base
    // itself contains `.` or `..` (e.g. `./data`).
    let base = normalise(base_path);
    let resolved = normalise(&base_path.join(requested));

    // Being prefixed by the base is not sufficient when the base is empty or consists
    // only of `..`: the remainder must not climb upward either.
    let inside_base = match resolved.strip_prefix(&base) {
        Ok(rest) => !rest
            .components()
            .any(|c| matches!(c, Component::ParentDir)),
        Err(_) => false,
    };
    if !inside_base {
        return Err(Error::InvalidDbUrl(format!(
            "path '{}' escapes the base directory",
            db_path
        )));
    }
    Ok(resolved)
}
/// Opens (or creates) a local database file at `full_path`.
///
/// When the crate is built with the `encryption` feature and `encryption` is `Some`,
/// the configuration is applied to the builder.
///
/// # Errors
/// * `Error::InvalidDbUrl` when `encryption` is `Some` but the `encryption` feature is
///   not compiled in.
/// * Any error returned by the libsql builder.
async fn open_local(
    full_path: PathBuf,
    encryption: Option<EncryptionConfig>,
) -> Result<Database, Error> {
    // The path is handed over as-is: going through `to_string_lossy` would silently
    // replace non-UTF-8 bytes and open a differently named file.
    // `mut` is only exercised when the `encryption` feature is enabled.
    #[allow(unused_mut)]
    let mut builder = LibsqlBuilder::new_local(&full_path);
    #[cfg(feature = "encryption")]
    if let Some(config) = encryption {
        builder = builder.encryption_config(config.into());
    }
    #[cfg(not(feature = "encryption"))]
    if encryption.is_some() {
        return Err(Error::InvalidDbUrl(
            "encryption feature is not enabled — rebuild with the `encryption` feature".into(),
        ));
    }
    Ok(builder.build().await?)
}
/// Opens an embedded replica stored at `full_path` that replicates from `sync_url`,
/// then performs one sync before returning the database.
///
/// When the crate is built with the `encryption` feature and `encryption` is `Some`,
/// the configuration is applied to the builder.
///
/// # Errors
/// * `Error::InvalidDbUrl` when `encryption` is `Some` but the `encryption` feature is
///   not compiled in.
/// * Any error returned while building the database or during the initial sync.
#[cfg(feature = "replication")]
async fn open_replica(
    full_path: PathBuf,
    sync_url: String,
    auth_token: String,
    encryption: Option<EncryptionConfig>,
) -> Result<Database, Error> {
    // Same policy as `open_local`: refuse rather than silently create an unencrypted
    // replica when encryption was requested but cannot be honoured.
    #[cfg(not(feature = "encryption"))]
    if encryption.is_some() {
        return Err(Error::InvalidDbUrl(
            "encryption feature is not enabled — rebuild with the `encryption` feature".into(),
        ));
    }
    // The path is handed over as-is: going through `to_string_lossy` would silently
    // replace non-UTF-8 bytes and open a differently named file.
    // `mut` is only exercised when the `encryption` feature is enabled.
    #[allow(unused_mut)]
    let mut builder = LibsqlBuilder::new_remote_replica(full_path, sync_url, auth_token);
    #[cfg(feature = "encryption")]
    if let Some(config) = encryption {
        builder = builder.encryption_config(config.into());
    }
    let db = builder.build().await?;
    db.sync().await?;
    Ok(db)
}
#[cfg(not(feature = "replication"))]
async fn open_replica(
_full_path: PathBuf,
_sync_url: String,
_auth_token: String,
_encryption: Option<EncryptionConfig>,
) -> Result<Database, Error> {
Err(Error::InvalidDbUrl(
"embedded replica requires the `replication` feature — add features = [\"replication\"] to your Cargo.toml".into(),
))
}
/// Opens a remote database at `url`, authenticating with `auth_token`.
///
/// # Errors
/// Propagates any error returned by the libsql builder.
#[cfg(feature = "remote")]
async fn open_remote(url: String, auth_token: String) -> Result<Database, Error> {
    let builder = LibsqlBuilder::new_remote(url, auth_token);
    let database = builder.build().await?;
    Ok(database)
}
#[cfg(not(feature = "remote"))]
async fn open_remote(_url: String, _auth_token: String) -> Result<Database, Error> {
Err(Error::InvalidDbUrl(
"remote connections require the `remote` feature — add features = [\"remote\"] to your Cargo.toml".into(),
))
}
/// Synchronises the underlying database by delegating to `do_sync`.
///
/// # Errors
/// Returns whatever `do_sync` returns; without the `replication` feature that is
/// always `Error::OperationNotSupported`.
pub async fn sync(&self) -> Result<(), Error> {
    let database = &self.db;
    Self::do_sync(database).await
}
/// Runs `Database::sync` on `db` and discards its success value.
///
/// # Errors
/// Converts and returns any error reported by the sync.
#[cfg(feature = "replication")]
async fn do_sync(db: &Database) -> Result<(), Error> {
    match db.sync().await {
        Ok(_) => Ok(()),
        Err(cause) => Err(cause.into()),
    }
}
#[cfg(not(feature = "replication"))]
async fn do_sync(_db: &Database) -> Result<(), Error> {
Err(Error::OperationNotSupported(
"sync requires the `replication` feature".into(),
))
}
/// Executes a single statement with positional parameters built from `values`.
///
/// Returns the number of affected rows together with the connection's last insert
/// row id, read after the statement has run.
///
/// # Errors
/// Propagates any error returned by the connection.
pub async fn execute(&self, query: &str, values: Vec<JsonValue>) -> Result<QueryResult, Error> {
    let bound = json_to_params(values);
    let rows_affected = self.conn.execute(query, bound).await?;
    let last_insert_id = self.conn.last_insert_rowid();
    Ok(QueryResult {
        rows_affected,
        last_insert_id,
    })
}
/// Runs a query with positional parameters built from `values` and collects every row.
///
/// Each row becomes an `IndexMap` from column name to the JSON value produced by
/// `decode::to_json`, inserted in column order. Columns for which the row reports no
/// name are skipped.
///
/// # Errors
/// Propagates errors from the query, from fetching a row, or from decoding a column.
pub async fn select(
    &self,
    query: &str,
    values: Vec<JsonValue>,
) -> Result<Vec<IndexMap<String, JsonValue>>, Error> {
    let bound = json_to_params(values);
    let mut cursor = self.conn.query(query, bound).await?;
    let mut records = Vec::new();

    while let Some(row) = cursor.next().await? {
        let mut record = IndexMap::new();
        let named_columns = (0..row.column_count())
            .filter_map(|index| row.column_name(index).map(|name| (index, name)));
        for (index, name) in named_columns {
            let decoded = decode::to_json(&row, index)?;
            record.insert(name.to_string(), decoded);
        }
        records.push(record);
    }

    Ok(records)
}
pub async fn batch(&self, queries: Vec<String>) -> Result<(), Error> {
self.conn.execute("BEGIN", Params::None).await?;
for query in &queries {
if let Err(e) = self.conn.execute(query.as_str(), Params::None).await {
let _ = self.conn.execute("ROLLBACK", Params::None).await;
return Err(Error::Libsql(e));
}
}
if let Err(e) = self.conn.execute("COMMIT", Params::None).await {
let _ = self.conn.execute("ROLLBACK", Params::None).await;
return Err(Error::Libsql(e));
}
Ok(())
}
pub async fn close(&self) {
self.conn.reset().await;
}
}
/// Converts a list of JSON values into libsql statement parameters.
///
/// An empty list yields `Params::None`; otherwise every value is converted with
/// `json_to_libsql_value` and the result is bound positionally.
fn json_to_params(values: Vec<JsonValue>) -> Params {
    match values.len() {
        0 => Params::None,
        _ => Params::Positional(values.into_iter().map(json_to_libsql_value).collect()),
    }
}
/// Converts one JSON value into the libsql value bound as a statement parameter.
///
/// * `null` → `Value::Null`
/// * booleans → `Value::Integer` (`1` for `true`, `0` for `false`)
/// * numbers → `Value::Integer` when representable as `i64`, otherwise `Value::Real`
///   (`Value::Null` if neither conversion succeeds)
/// * strings → `Value::Text`
/// * arrays whose elements are all integers in `0..=255` (including the empty array)
///   → `Value::Blob`
/// * every other array, and every object → `Value::Text` holding its JSON serialisation
fn json_to_libsql_value(v: JsonValue) -> Value {
    match v {
        JsonValue::Null => Value::Null,
        JsonValue::Bool(flag) => Value::Integer(i64::from(flag)),
        JsonValue::Number(n) => {
            if let Some(i) = n.as_i64() {
                Value::Integer(i)
            } else if let Some(f) = n.as_f64() {
                Value::Real(f)
            } else {
                Value::Null
            }
        }
        JsonValue::String(s) => Value::Text(s),
        JsonValue::Array(ref items) => {
            // Only a genuine byte array becomes a blob. Treating every all-numeric array
            // as bytes would drop negative or fractional elements and wrap values above
            // 255, corrupting the data; such arrays are stored as JSON text instead.
            let bytes: Option<Vec<u8>> = items
                .iter()
                .map(|item| match item.as_u64() {
                    // Range-checked, so the cast is lossless.
                    Some(n) if n <= u64::from(u8::MAX) => Some(n as u8),
                    _ => None,
                })
                .collect();
            match bytes {
                Some(bytes) => Value::Blob(bytes),
                None => Value::Text(v.to_string()),
            }
        }
        JsonValue::Object(_) => Value::Text(v.to_string()),
    }
}
/// Shared registry of open connections: a `HashMap` from a `String` key to an
/// `Arc<DbConnection>`, guarded by an async `Mutex` and wrapped in an `Arc`.
// NOTE(review): the key is presumably the database path or alias given by the caller — confirm.
pub struct DbInstances(pub Arc<Mutex<HashMap<String, Arc<DbConnection>>>>);
impl Default for DbInstances {
fn default() -> Self {
Self(Arc::new(Mutex::new(HashMap::new())))
}
}