use nats_wasi::client::{Client, Duration, secs};
use serde::{Deserialize, Serialize, de::DeserializeOwned};
/// Errors returned by the [`LatticeSql`] client and the [`SqlResult`] accessors.
#[derive(Debug)]
pub enum Error {
    /// Failure reported by the NATS client while sending the request.
    Nats(nats_wasi::Error),
    /// Message taken from the `error` field of the SQL service's reply.
    Db(String),
    /// Failure to encode the request or decode the reply as JSON.
    Json(String),
    /// The reply kind (rows vs. affected-row count) did not match what the caller asked for.
    WrongResultType(String),
}
impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Error::Nats(e) => write!(f, "nats: {e}"),
Error::Db(e) => write!(f, "lattice-sql: {e}"),
Error::Json(e) => write!(f, "json: {e}"),
Error::WrongResultType(e) => write!(f, "wrong result type: {e}"),
}
}
}
impl std::error::Error for Error {}
impl From<nats_wasi::Error> for Error {
fn from(e: nats_wasi::Error) -> Self {
Error::Nats(e)
}
}
/// Row set returned for a statement whose reply included a column list.
#[derive(Debug, Clone)]
pub struct QueryResult {
    /// Column names; position `i` names the `i`-th value of each row.
    pub columns: Vec<String>,
    /// Row values, each inner vector indexed by column position.
    pub rows: Vec<Vec<serde_json::Value>>,
}
impl QueryResult {
    /// Number of rows in the result set.
    pub fn row_count(&self) -> usize {
        self.rows.len()
    }

    /// Index of the first column whose name equals `name`, or `None` if there is none.
    pub fn col_index(&self, name: &str) -> Option<usize> {
        self.columns.iter().position(|c| c == name)
    }

    /// Cell at row `row` in the column named `col`.
    ///
    /// Returns `None` when the column name is unknown, `row` is out of range, or that
    /// particular row has fewer values than the column's index requires.
    pub fn cell(&self, row: usize, col: &str) -> Option<&serde_json::Value> {
        let col_idx = self.col_index(col)?;
        self.rows.get(row)?.get(col_idx)
    }

    /// Deserializes each row into a `T` via a JSON object keyed by column name.
    ///
    /// Cells missing from a short row become `null`; values beyond the last column are
    /// ignored. If two columns share a name, the later column's value is the one kept
    /// (`serde_json::Map::insert` overwrites existing keys).
    ///
    /// # Errors
    ///
    /// Stops at the first row that fails to deserialize and returns its error message
    /// prefixed with the zero-based row index (`"row <i>: <serde error>"`), so a failure
    /// inside a large result set can be located.
    pub fn deserialize_rows<T: DeserializeOwned>(&self) -> Result<Vec<T>, String> {
        self.rows
            .iter()
            .enumerate()
            .map(|(row_idx, row)| {
                // Exactly one entry per column, so size the map up front.
                let mut map = serde_json::Map::with_capacity(self.columns.len());
                for (i, col) in self.columns.iter().enumerate() {
                    map.insert(
                        col.clone(),
                        row.get(i).cloned().unwrap_or(serde_json::Value::Null),
                    );
                }
                serde_json::from_value(serde_json::Value::Object(map))
                    .map_err(|e| format!("row {row_idx}: {e}"))
            })
            .collect()
    }
}
/// Outcome of one SQL statement, as classified by [`LatticeSql::sql`].
#[derive(Debug, Clone)]
pub enum SqlResult {
    /// The reply contained a `columns` field.
    Query(QueryResult),
    /// The reply contained no `columns` field.
    Exec {
        /// The reply's `affected_rows` field, or `0` when it was absent.
        affected_rows: u64,
    },
}
impl SqlResult {
pub fn into_query(self) -> Result<QueryResult, Error> {
match self {
SqlResult::Query(r) => Ok(r),
SqlResult::Exec { .. } => Err(Error::WrongResultType(
"expected SELECT result but got exec/DDL result".into(),
)),
}
}
pub fn into_affected_rows(self) -> Result<u64, Error> {
match self {
SqlResult::Exec { affected_rows } => Ok(affected_rows),
SqlResult::Query(_) => Err(Error::WrongResultType(
"expected exec/DDL result but got SELECT result".into(),
)),
}
}
pub fn is_query(&self) -> bool {
matches!(self, SqlResult::Query(_))
}
pub fn is_exec(&self) -> bool {
matches!(self, SqlResult::Exec { .. })
}
}
/// JSON request body published to [`LatticeSql::SUBJECT`].
#[derive(Serialize)]
struct SqlReq<'a> {
    /// SQL text to run.
    sql: &'a str,
    /// Optional auth token, serialized as `_auth` and left out entirely when `None`.
    #[serde(rename = "_auth")]
    #[serde(skip_serializing_if = "Option::is_none")]
    auth: Option<&'a str>,
}
/// Reply body from the SQL service. Every field is optional; absent fields decode as `None`.
///
/// [`LatticeSql::sql`] checks `error` first, then `columns`, to decide the result kind.
#[derive(Deserialize)]
struct AnyResp {
    /// Column names; when present the reply is treated as a row set.
    columns: Option<Vec<String>>,
    /// Row values accompanying `columns`.
    rows: Option<Vec<Vec<serde_json::Value>>>,
    /// Affected-row count used when `columns` is absent.
    affected_rows: Option<u64>,
    /// Error message; when present it takes precedence over every other field.
    error: Option<String>,
}
/// Client that sends SQL statements to [`LatticeSql::SUBJECT`] via NATS request/reply.
pub struct LatticeSql {
    /// NATS client used for every request.
    client: Client,
    /// Timeout passed to each `Client::request` call.
    timeout: Duration,
    /// Optional token attached to every request as `_auth`.
    auth_token: Option<String>,
}
impl LatticeSql {
pub const SUBJECT: &'static str = "ldb.sql.query";
pub fn new(client: Client) -> Self {
Self { client, timeout: secs(10), auth_token: None }
}
pub fn with_timeout(client: Client, timeout: Duration) -> Self {
Self { client, timeout, auth_token: None }
}
pub fn with_auth(mut self, token: impl Into<String>) -> Self {
self.auth_token = Some(token.into());
self
}
pub async fn query(&self, sql: &str) -> Result<QueryResult, Error> {
self.sql(sql).await?.into_query()
}
pub async fn exec(&self, sql: &str) -> Result<u64, Error> {
self.sql(sql).await?.into_affected_rows()
}
pub async fn ddl(&self, sql: &str) -> Result<(), Error> {
self.exec(sql).await?;
Ok(())
}
pub async fn sql(&self, sql: &str) -> Result<SqlResult, Error> {
let resp: AnyResp = self.send(sql).await?;
if let Some(msg) = resp.error {
return Err(Error::Db(msg));
}
if let Some(columns) = resp.columns {
return Ok(SqlResult::Query(QueryResult {
columns,
rows: resp.rows.unwrap_or_default(),
}));
}
Ok(SqlResult::Exec {
affected_rows: resp.affected_rows.unwrap_or(0),
})
}
async fn send<R: DeserializeOwned>(&self, sql: &str) -> Result<R, Error> {
let body =
serde_json::to_vec(&SqlReq { sql, auth: self.auth_token.as_deref() })
.map_err(|e| Error::Json(e.to_string()))?;
let reply = self
.client
.request(Self::SUBJECT, &body, self.timeout)
.await?;
serde_json::from_slice(&reply.payload).map_err(|e| Error::Json(e.to_string()))
}
}