use crate::client::Config;
use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::RwLock;
use crate::{utils, BatchResult, ResultSet, Statement};
/// Client for a remote database reached through `hrana_client`.
///
/// It keeps the connection parameters so it can reconnect. It also keeps one Hrana stream
/// per transaction id, so that every statement of a transaction is sent on the same stream.
pub struct Client {
/// Server URL used when connecting and reconnecting.
url: String,
/// Auth token; `None` when an empty token string was supplied to `Client::new`.
token: Option<String>,
/// Handle used to open Hrana streams.
client: hrana_client::Client,
/// Future returned alongside `client` by `hrana_client::Client::connect`; awaited in `shutdown`.
client_future: hrana_client::ConnFut,
/// Streams opened for transactions, keyed by transaction id.
/// Entries are added lazily on first use and removed on commit/rollback.
streams_for_transactions: RwLock<HashMap<u64, Arc<hrana_client::Stream>>>,
}
/// Debug output shows the URL and whether an auth token is configured, but never the token
/// itself. This means formatting a `Client` (e.g. in logs or error contexts) cannot leak
/// credentials.
impl std::fmt::Debug for Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Client")
            .field("url", &self.url)
            // Redact the secret; `Some("<redacted>")` / `None` still tells whether auth is used.
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .finish()
    }
}
impl Client {
    /// Connects to `url`, authenticating with `token`.
    ///
    /// An empty `token` means "no token": it is stored and sent as `None`.
    ///
    /// # Errors
    /// Returns an error if `hrana_client::Client::connect` fails.
    pub async fn new(url: impl Into<String>, token: impl Into<String>) -> Result<Self> {
        let token = token.into();
        let token = if token.is_empty() { None } else { Some(token) };
        let url = url.into();
        let (client, client_future) = hrana_client::Client::connect(&url, token.clone()).await?;
        Ok(Self {
            url,
            token,
            client,
            client_future,
            streams_for_transactions: RwLock::new(HashMap::new()),
        })
    }

    /// Replaces the underlying connection with a new one to the same URL and token.
    ///
    /// NOTE(review): the previous `client` / `client_future` are dropped without calling
    /// `shutdown`. `streams_for_transactions` is left untouched, so streams registered for
    /// in-flight transactions were opened on the previous connection.
    ///
    /// # Errors
    /// Returns an error if connecting fails; `self` is then left unchanged.
    pub async fn reconnect(&mut self) -> Result<()> {
        let (client, client_future) =
            hrana_client::Client::connect(&self.url, self.token.clone()).await?;
        self.client = client;
        self.client_future = client_future;
        Ok(())
    }

    /// Builds a client from a URL.
    ///
    /// An `authToken` query parameter, if present, is taken out of the URL via
    /// `utils::pop_query_param` and used as the auth token. A `libsql://` URL is rewritten
    /// to `wss://`.
    ///
    /// # Errors
    /// Returns an error if `url` cannot be converted into a `url::Url`, or if a `libsql` URL
    /// is not of the form `libsql://...`. It also fails if the rewritten URL does not parse
    /// or if connecting fails.
    pub async fn from_url<T: TryInto<url::Url>>(url: T) -> anyhow::Result<Client>
    where
        <T as TryInto<url::Url>>::Error: std::fmt::Display,
    {
        let mut url: url::Url = url.try_into().map_err(|e| anyhow::anyhow!("{e}"))?;
        let token = utils::pop_query_param(&mut url, "authToken".to_string());
        let url_str = if url.scheme() == "libsql" {
            // `Url::set_scheme` cannot turn a non-special scheme (`libsql`) into a special one
            // (`wss`), so the scheme is rewritten textually. Failures are reported instead of
            // panicking: e.g. `libsql:foo` has no `//` and used to hit an `unwrap`.
            let rest = url.as_str().strip_prefix("libsql://").ok_or_else(|| {
                anyhow::anyhow!("libsql URL must have the form libsql://<host>, got {url}")
            })?;
            url::Url::parse(&format!("wss://{rest}"))?.to_string()
        } else {
            url.to_string()
        };
        // A missing token becomes "", which `Client::new` treats as "no token".
        Client::new(url_str, token.unwrap_or_default()).await
    }

    /// Connects using `config.url` and `config.auth_token`. A missing token means no token.
    pub async fn from_config(config: Config) -> Result<Self> {
        Self::new(config.url, config.auth_token.unwrap_or_default()).await
    }

    /// Shuts the client down, then waits for the connection future to finish.
    ///
    /// # Errors
    /// Propagates errors from `hrana_client::Client::shutdown` and from the connection future.
    pub async fn shutdown(self) -> Result<()> {
        self.client.shutdown().await?;
        self.client_future.await?;
        Ok(())
    }

    /// Returns the stream registered for `tx_id`, opening and registering one if none exists.
    ///
    /// Every caller for the same `tx_id` gets the same stream, even when several tasks race
    /// to open it.
    async fn stream_for_transaction(&self, tx_id: u64) -> Result<Arc<hrana_client::Stream>> {
        // Fast path. The read guard lives only inside this block, so it is released before the
        // `.await` below (a `std::sync` guard must not be held across an await point).
        {
            let streams = self.streams_for_transactions.read().unwrap();
            if let Some(stream) = streams.get(&tx_id) {
                tracing::trace!("Found stream for transaction {tx_id}");
                return Ok(Arc::clone(stream));
            }
        }
        let stream = Arc::new(self.client.open_stream().await?);
        tracing::trace!("Created new stream");
        let mut streams = self.streams_for_transactions.write().unwrap();
        // Another task may have registered a stream for `tx_id` while we awaited `open_stream`.
        // Always return the registered one; returning our own unregistered stream would run the
        // statement outside the transaction. If we lost the race, our stream is dropped here.
        let registered = streams.entry(tx_id).or_insert(stream);
        Ok(Arc::clone(registered))
    }

    /// Removes the stream registered for `tx_id`, if any.
    ///
    /// The stream itself is released once the last outstanding `Arc` clone is dropped.
    fn drop_stream_for_transaction(&self, tx_id: u64) {
        let mut streams = self.streams_for_transactions.write().unwrap();
        tracing::trace!("Dropping stream for transaction {tx_id}");
        streams.remove(&tx_id);
    }

    /// Converts a `Statement` into a Hrana statement, binding each of its `args` in order.
    fn into_hrana(stmt: Statement) -> hrana_client::proto::Stmt {
        // NOTE(review): the `true` flag is presumably `want_rows`; confirm against hrana_client.
        let mut hrana_stmt = hrana_client::proto::Stmt::new(stmt.sql, true);
        for param in stmt.args {
            hrana_stmt.bind(param);
        }
        hrana_stmt
    }
}
impl Client {
pub async fn raw_batch(
&self,
stmts: impl IntoIterator<Item = impl Into<Statement>>,
) -> anyhow::Result<BatchResult> {
let mut batch = hrana_client::proto::Batch::new();
for stmt in stmts.into_iter() {
let stmt: Statement = stmt.into();
let mut hrana_stmt = hrana_client::proto::Stmt::new(stmt.sql, true);
for param in stmt.args {
hrana_stmt.bind(param);
}
batch.step(None, hrana_stmt);
}
let stream = self.client.open_stream().await?;
stream
.execute_batch(batch)
.await
.map_err(|e| anyhow::anyhow!("{}", e))
}
pub async fn execute(&self, stmt: impl Into<Statement>) -> Result<ResultSet> {
let stmt = Self::into_hrana(stmt.into());
let stream = self.client.open_stream().await?;
stream
.execute(stmt)
.await
.map(ResultSet::from)
.map_err(|e| anyhow::anyhow!("{}", e))
}
pub async fn execute_in_transaction(&self, tx_id: u64, stmt: Statement) -> Result<ResultSet> {
let stmt = Self::into_hrana(stmt);
tracing::trace!("Transaction {tx_id} executing {}", stmt.sql);
let stream = self.stream_for_transaction(tx_id).await?;
stream
.execute(stmt)
.await
.map(ResultSet::from)
.map_err(|e| anyhow::anyhow!("{}", e))
}
pub async fn commit_transaction(&self, tx_id: u64) -> Result<()> {
tracing::trace!("Transaction {tx_id} commit");
let stream = self.stream_for_transaction(tx_id).await?;
self.drop_stream_for_transaction(tx_id);
stream
.execute(Self::into_hrana(Statement::from("COMMIT")))
.await
.map(|_| ())
.map_err(|e| anyhow::anyhow!("{}", e))
}
pub async fn rollback_transaction(&self, tx_id: u64) -> Result<()> {
tracing::trace!("Transaction {tx_id} rollback");
let stream = self.stream_for_transaction(tx_id).await?;
self.drop_stream_for_transaction(tx_id);
stream
.execute(Self::into_hrana(Statement::from("ROLLBACK")))
.await
.map(|_| ())
.map_err(|e| anyhow::anyhow!("{}", e))
}
}