trailbase_wasm/
db.rs

1use base64::prelude::*;
2use serde::Serialize;
3use wstd::http::body::IntoBody;
4use wstd::http::{Client, Request};
5
6use crate::wit::trailbase::runtime::host_endpoint::{
7  tx_begin, tx_commit, tx_execute, tx_query, tx_rollback,
8};
9
10pub use crate::wit::trailbase::runtime::host_endpoint::{TxError, Value};
11pub use trailbase_wasm_common::{SqliteRequest, SqliteResponse};
12
13pub struct Transaction {
14  committed: bool,
15}
16
17impl Transaction {
18  pub fn begin() -> Result<Self, TxError> {
19    tx_begin()?;
20    return Ok(Self { committed: false });
21  }
22
23  pub fn query(&mut self, query: &str, params: &[Value]) -> Result<Vec<Vec<Value>>, TxError> {
24    return tx_query(query, params);
25  }
26
27  pub fn execute(&mut self, query: &str, params: &[Value]) -> Result<u64, TxError> {
28    return tx_execute(query, params);
29  }
30
31  pub fn commit(&mut self) -> Result<(), TxError> {
32    if !self.committed {
33      self.committed = true;
34      tx_commit()?;
35    }
36    return Ok(());
37  }
38}
39
40impl Drop for Transaction {
41  fn drop(&mut self) {
42    if !self.committed {
43      if let Err(err) = tx_rollback() {
44        log::warn!("TX rollback failed: {err}");
45      }
46    }
47  }
48}
49
50#[derive(Debug, thiserror::Error)]
51pub enum Error {
52  #[error("Unexpected Type")]
53  UnexpectedType,
54  #[error("Not a Number")]
55  NotANumber,
56  #[error("Decoding")]
57  Decording(#[from] base64::DecodeError),
58  #[error("Other: {0}")]
59  Other(String),
60}
61
62pub async fn query(
63  query: impl std::string::ToString,
64  params: impl Into<Vec<Value>>,
65) -> Result<Vec<Vec<Value>>, Error> {
66  let r = SqliteRequest {
67    query: query.to_string(),
68    params: params.into().into_iter().map(to_json_value).collect(),
69  };
70  let request = Request::builder()
71    .uri("http://__sqlite/query")
72    .method("POST")
73    .body(
74      serde_json::to_vec(&r)
75        .map_err(|_| Error::UnexpectedType)?
76        .into_body(),
77    )
78    .map_err(|err| Error::Other(err.to_string()))?;
79
80  let client = Client::new();
81  let (_parts, mut body) = client
82    .send(request)
83    .await
84    .map_err(|err| Error::Other(err.to_string()))?
85    .into_parts();
86
87  let bytes = body
88    .bytes()
89    .await
90    .map_err(|err| Error::Other(err.to_string()))?;
91
92  return match serde_json::from_slice(&bytes) {
93    Ok(SqliteResponse::Query { rows }) => Ok(
94      rows
95        .into_iter()
96        .map(|row| {
97          row
98            .into_iter()
99            .map(from_json_value)
100            .collect::<Result<Vec<_>, _>>()
101        })
102        .collect::<Result<Vec<_>, _>>()?,
103    ),
104    Ok(_) => Err(Error::UnexpectedType),
105    Err(err) => Err(Error::Other(err.to_string())),
106  };
107}
108
109pub async fn execute(
110  query: impl std::string::ToString,
111  params: impl Into<Vec<Value>>,
112) -> Result<usize, Error> {
113  let r = SqliteRequest {
114    query: query.to_string(),
115    params: params.into().into_iter().map(to_json_value).collect(),
116  };
117  let request = Request::builder()
118    .uri("http://__sqlite/execute")
119    .method("POST")
120    .body(
121      serde_json::to_vec(&r)
122        .map_err(|_| Error::UnexpectedType)?
123        .into_body(),
124    )
125    .map_err(|err| Error::Other(err.to_string()))?;
126
127  let client = Client::new();
128  let (_parts, mut body) = client
129    .send(request)
130    .await
131    .map_err(|err| Error::Other(err.to_string()))?
132    .into_parts();
133
134  let bytes = body
135    .bytes()
136    .await
137    .map_err(|err| Error::Other(err.to_string()))?;
138
139  return match serde_json::from_slice(&bytes) {
140    Ok(SqliteResponse::Execute { rows_affected }) => Ok(rows_affected),
141    Ok(_) => Err(Error::UnexpectedType),
142    Err(err) => Err(Error::Other(err.to_string())),
143  };
144}
145
146fn from_json_value(value: serde_json::Value) -> Result<Value, Error> {
147  return match value {
148    serde_json::Value::Null => Ok(Value::Null),
149    serde_json::Value::String(s) => Ok(Value::Text(s)),
150    serde_json::Value::Object(mut map) => match map.remove("blob") {
151      Some(serde_json::Value::String(str)) => Ok(Value::Blob(BASE64_URL_SAFE.decode(&str)?)),
152      _ => Err(Error::UnexpectedType),
153    },
154    serde_json::Value::Number(n) => {
155      if let Some(n) = n.as_i64() {
156        Ok(Value::Integer(n))
157      } else if let Some(n) = n.as_u64() {
158        Ok(Value::Integer(n as i64))
159      } else if let Some(n) = n.as_f64() {
160        Ok(Value::Real(n))
161      } else {
162        Err(Error::NotANumber)
163      }
164    }
165    _ => Err(Error::UnexpectedType),
166  };
167}
168
169#[derive(Serialize)]
170struct Blob {
171  blob: String,
172}
173
174impl serde::ser::Serialize for Value {
175  fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
176  where
177    S: serde::ser::Serializer,
178  {
179    return match self {
180      Value::Null => serializer.serialize_unit(),
181      Value::Text(s) => serializer.serialize_str(s),
182      Value::Integer(i) => serializer.serialize_i64(*i),
183      Value::Real(f) => serializer.serialize_f64(*f),
184      Value::Blob(blob) => serializer.serialize_some(&Blob {
185        blob: BASE64_URL_SAFE.encode(blob),
186      }),
187    };
188  }
189}
190
191pub fn to_json_value(value: Value) -> serde_json::Value {
192  return match value {
193    Value::Null => serde_json::Value::Null,
194    Value::Text(s) => serde_json::Value::String(s),
195    Value::Integer(i) => serde_json::Value::Number(serde_json::Number::from(i)),
196    Value::Real(f) => match serde_json::Number::from_f64(f) {
197      Some(n) => serde_json::Value::Number(n),
198      None => serde_json::Value::Null,
199    },
200    Value::Blob(blob) => serde_json::json!({
201        "blob": BASE64_URL_SAFE.encode(blob)
202    }),
203  };
204}