omnia_wasi_sql/host/
default_impl.rs1#![allow(clippy::significant_drop_tightening)]
6#![allow(clippy::cast_possible_truncation)]
7#![allow(clippy::cast_possible_wrap)]
8#![allow(clippy::cast_lossless)]
9#![allow(missing_docs)]
10
11use std::sync::Arc;
12
13use anyhow::{Context, Result};
14use fromenv::FromEnv;
15use futures::FutureExt;
16use omnia::Backend;
17use rusqlite::types::ValueRef;
18use rusqlite::{Connection as SqliteConnection, params_from_iter};
19use tracing::instrument;
20
21use crate::host::resource::{Connection, FutureResult};
22use crate::host::{DataType, Field, Row, WasiSqlCtx};
23
24#[derive(Debug, Clone, FromEnv)]
28pub struct ConnectOptions {
29 #[env(from = "SQL_DATABASE", default = "file::memory:?cache=shared")]
30 pub database: String,
31}
32
33#[allow(missing_docs)]
34impl omnia::FromEnv for ConnectOptions {
35 fn from_env() -> Result<Self> {
36 Self::from_env().finalize().context("issue loading connection options")
37 }
38}
39
40#[derive(Debug, Clone)]
42pub struct SqlDefault {
43 conn: Arc<parking_lot::Mutex<SqliteConnection>>,
46}
47
48impl Backend for SqlDefault {
49 type ConnectOptions = ConnectOptions;
50
51 #[instrument]
52 async fn connect_with(options: Self::ConnectOptions) -> Result<Self> {
53 tracing::debug!("initializing SQLite connection to: {}", options.database);
54
55 let conn = Arc::new(parking_lot::Mutex::new(
57 SqliteConnection::open(&options.database).context("failed to open SQLite database")?,
58 ));
59
60 Ok(Self { conn })
61 }
62}
63
64impl WasiSqlCtx for SqlDefault {
65 fn open(&self, _name: String) -> FutureResult<Arc<dyn Connection>> {
66 tracing::debug!("opening SQL connection");
67 let conn = Arc::clone(&self.conn);
68
69 async move {
70 let connection = SqliteConnectionImpl { conn };
71 Ok(Arc::new(connection) as Arc<dyn Connection>)
72 }
73 .boxed()
74 }
75}
76
77#[derive(Debug, Clone)]
78struct SqliteConnectionImpl {
79 conn: Arc<parking_lot::Mutex<SqliteConnection>>,
80}
81
82impl Connection for SqliteConnectionImpl {
83 fn query(&self, query: String, params: Vec<DataType>) -> FutureResult<Vec<Row>> {
84 tracing::debug!("executing query: {}", query);
85 let conn = Arc::clone(&self.conn);
86
87 async move {
88 let conn = conn.lock();
89 let mut stmt = conn.prepare(&query).context("failed to prepare statement")?;
90
91 let rusqlite_params: Vec<_> = params.iter().map(datatype_to_rusqlite_value).collect();
93
94 let column_names: Vec<String> =
96 stmt.column_names().iter().map(ToString::to_string).collect();
97
98 let mut rows = stmt
100 .query(params_from_iter(rusqlite_params.iter()))
101 .context("failed to execute query")?;
102
103 let mut result_rows = Vec::new();
104 let mut index = 0;
105 while let Some(row) = rows.next().context("failed to fetch row")? {
106 let mut fields = Vec::new();
107
108 for (i, name) in column_names.iter().enumerate() {
109 let value = row.get_ref(i).context("failed to get column value")?;
110 let data_type = rusqlite_value_to_datatype(value)?;
111
112 fields.push(Field {
113 name: name.clone(),
114 value: data_type,
115 });
116 }
117
118 result_rows.push(Row {
119 index: index.to_string(),
120 fields,
121 });
122 index += 1;
123 }
124
125 Ok(result_rows)
126 }
127 .boxed()
128 }
129
130 fn exec(&self, query: String, params: Vec<DataType>) -> FutureResult<u32> {
131 tracing::debug!("executing statement: {}", query);
132 let conn = Arc::clone(&self.conn);
133
134 async move {
135 let conn = conn.lock();
136 let mut stmt = conn.prepare(&query).context("failed to prepare statement")?;
137
138 let rusqlite_params: Vec<_> = params.iter().map(datatype_to_rusqlite_value).collect();
140
141 let rows_affected = stmt
142 .execute(params_from_iter(rusqlite_params.iter()))
143 .context("failed to execute statement")?;
144
145 #[allow(clippy::cast_possible_truncation)]
146 Ok(rows_affected as u32)
147 }
148 .boxed()
149 }
150}
151
152fn datatype_to_rusqlite_value(dt: &DataType) -> rusqlite::types::Value {
153 match dt {
154 DataType::Boolean(Some(b)) => rusqlite::types::Value::Integer(i64::from(*b)),
155 DataType::Int32(Some(i)) => rusqlite::types::Value::Integer(i64::from(*i)),
156 DataType::Int64(Some(i)) => rusqlite::types::Value::Integer(*i),
157 DataType::Uint32(Some(u)) => rusqlite::types::Value::Integer(i64::from(*u)),
158 DataType::Uint64(Some(u)) => rusqlite::types::Value::Integer(*u as i64),
159 DataType::Float(Some(f)) => rusqlite::types::Value::Real(f64::from(*f)),
160 DataType::Double(Some(f)) => rusqlite::types::Value::Real(*f),
161 DataType::Str(Some(s)) => rusqlite::types::Value::Text(s.clone()),
162 DataType::Binary(Some(b)) => rusqlite::types::Value::Blob(b.clone()),
163 DataType::Timestamp(Some(ts)) => rusqlite::types::Value::Text(ts.clone()),
164 _ => rusqlite::types::Value::Null,
166 }
167}
168
169fn rusqlite_value_to_datatype(value: ValueRef) -> Result<DataType> {
170 match value {
171 ValueRef::Null => Ok(DataType::Str(None)),
172 ValueRef::Integer(i) => Ok(DataType::Int64(Some(i))),
173 ValueRef::Real(f) => Ok(DataType::Double(Some(f))),
174 ValueRef::Text(t) => {
175 let s = std::str::from_utf8(t).context("invalid UTF-8 in text value")?;
176 Ok(DataType::Str(Some(s.to_string())))
177 }
178 ValueRef::Blob(b) => Ok(DataType::Binary(Some(b.to_vec()))),
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check against an in-memory database: create a table, insert two
    /// rows, then read them back ordered by name.
    #[tokio::test]
    async fn sqlite_operations() {
        let options = ConnectOptions {
            database: ":memory:".to_string(),
        };
        let backend = SqlDefault::connect_with(options).await.expect("connect");

        let conn = backend.open("test".to_string()).await.expect("open connection");

        let created = conn
            .exec(
                "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)".to_string(),
                vec![],
            )
            .await
            .expect("create table");
        assert_eq!(created, 0);

        // Insert Alice first, then Bob; each insert must report exactly one row.
        for (name, age) in [("Alice", 30), ("Bob", 25)] {
            let inserted = conn
                .exec(
                    "INSERT INTO users (name, age) VALUES (?, ?)".to_string(),
                    vec![DataType::Str(Some(name.to_string())), DataType::Int32(Some(age))],
                )
                .await
                .expect("insert");
            assert_eq!(inserted, 1);
        }

        let rows = conn
            .query("SELECT id, name, age FROM users ORDER BY name".to_string(), vec![])
            .await
            .expect("query");

        assert_eq!(rows.len(), 2);

        // Column 1 of the first row is `name`, and ordering by name puts Alice first.
        let name_field = &rows[0].fields[1];
        assert_eq!(name_field.name, "name");
        match &name_field.value {
            DataType::Str(Some(name)) => assert_eq!(name, "Alice"),
            _ => panic!("Expected string value"),
        }
    }
}