Skip to main content

omnia_wasi_sql/host/
default_impl.rs

//! Default `SQLite` implementation for `wasi:sql`.
//!
//! This is a lightweight implementation for development use only: every
//! connection opened through it shares a single mutex-guarded SQLite connection.
4
5#![allow(clippy::significant_drop_tightening)]
6#![allow(clippy::cast_possible_truncation)]
7#![allow(clippy::cast_possible_wrap)]
8#![allow(clippy::cast_lossless)]
9#![allow(missing_docs)]
10
11use std::sync::Arc;
12
13use anyhow::{Context, Result};
14use fromenv::FromEnv;
15use futures::FutureExt;
16use omnia::Backend;
17use rusqlite::types::ValueRef;
18use rusqlite::{Connection as SqliteConnection, params_from_iter};
19use tracing::instrument;
20
21use crate::host::resource::{Connection, FutureResult};
22use crate::host::{DataType, Field, Row, WasiSqlCtx};
23
24/// Options used to connect to the SQL database.
25///
26/// This struct is used to load connection options from environment variables.
27#[derive(Debug, Clone, FromEnv)]
28pub struct ConnectOptions {
29    #[env(from = "SQL_DATABASE", default = "file::memory:?cache=shared")]
30    pub database: String,
31}
32
/// Bridges the `#[derive(FromEnv)]` loader into `omnia`'s `FromEnv` trait.
#[allow(missing_docs)]
impl omnia::FromEnv for ConnectOptions {
    /// Loads [`ConnectOptions`] from the process environment.
    ///
    /// # Errors
    ///
    /// Returns the loader's error wrapped with "issue loading connection options"
    /// when the environment cannot be turned into a valid `ConnectOptions`.
    fn from_env() -> Result<Self> {
        // Not a recursive call: `Self::from_env()` yields something with `.finalize()`,
        // so it resolves to the builder generated by `#[derive(FromEnv)]` (presumably an
        // inherent associated fn, which takes precedence over this trait fn).
        Self::from_env().finalize().context("issue loading connection options")
    }
}
39
40/// Default implementation for `wasi:sql`.
41#[derive(Debug, Clone)]
42pub struct SqlDefault {
43    // Store the database path to create new connections on demand
44    // Mutex is necessary since rusqlite::Connection isn't `Sync`
45    conn: Arc<parking_lot::Mutex<SqliteConnection>>,
46}
47
48impl Backend for SqlDefault {
49    type ConnectOptions = ConnectOptions;
50
51    #[instrument]
52    async fn connect_with(options: Self::ConnectOptions) -> Result<Self> {
53        tracing::debug!("initializing SQLite connection to: {}", options.database);
54
55        // Create initial connection to validate database path
56        let conn = Arc::new(parking_lot::Mutex::new(
57            SqliteConnection::open(&options.database).context("failed to open SQLite database")?,
58        ));
59
60        Ok(Self { conn })
61    }
62}
63
64impl WasiSqlCtx for SqlDefault {
65    fn open(&self, _name: String) -> FutureResult<Arc<dyn Connection>> {
66        tracing::debug!("opening SQL connection");
67        let conn = Arc::clone(&self.conn);
68
69        async move {
70            let connection = SqliteConnectionImpl { conn };
71            Ok(Arc::new(connection) as Arc<dyn Connection>)
72        }
73        .boxed()
74    }
75}
76
77#[derive(Debug, Clone)]
78struct SqliteConnectionImpl {
79    conn: Arc<parking_lot::Mutex<SqliteConnection>>,
80}
81
82impl Connection for SqliteConnectionImpl {
83    fn query(&self, query: String, params: Vec<DataType>) -> FutureResult<Vec<Row>> {
84        tracing::debug!("executing query: {}", query);
85        let conn = Arc::clone(&self.conn);
86
87        async move {
88            let conn = conn.lock();
89            let mut stmt = conn.prepare(&query).context("failed to prepare statement")?;
90
91            // Convert DataType to rusqlite values
92            let rusqlite_params: Vec<_> = params.iter().map(datatype_to_rusqlite_value).collect();
93
94            // Get column names
95            let column_names: Vec<String> =
96                stmt.column_names().iter().map(ToString::to_string).collect();
97
98            // Execute query and collect rows
99            let mut rows = stmt
100                .query(params_from_iter(rusqlite_params.iter()))
101                .context("failed to execute query")?;
102
103            let mut result_rows = Vec::new();
104            let mut index = 0;
105            while let Some(row) = rows.next().context("failed to fetch row")? {
106                let mut fields = Vec::new();
107
108                for (i, name) in column_names.iter().enumerate() {
109                    let value = row.get_ref(i).context("failed to get column value")?;
110                    let data_type = rusqlite_value_to_datatype(value)?;
111
112                    fields.push(Field {
113                        name: name.clone(),
114                        value: data_type,
115                    });
116                }
117
118                result_rows.push(Row {
119                    index: index.to_string(),
120                    fields,
121                });
122                index += 1;
123            }
124
125            Ok(result_rows)
126        }
127        .boxed()
128    }
129
130    fn exec(&self, query: String, params: Vec<DataType>) -> FutureResult<u32> {
131        tracing::debug!("executing statement: {}", query);
132        let conn = Arc::clone(&self.conn);
133
134        async move {
135            let conn = conn.lock();
136            let mut stmt = conn.prepare(&query).context("failed to prepare statement")?;
137
138            // Convert DataType to rusqlite values
139            let rusqlite_params: Vec<_> = params.iter().map(datatype_to_rusqlite_value).collect();
140
141            let rows_affected = stmt
142                .execute(params_from_iter(rusqlite_params.iter()))
143                .context("failed to execute statement")?;
144
145            #[allow(clippy::cast_possible_truncation)]
146            Ok(rows_affected as u32)
147        }
148        .boxed()
149    }
150}
151
152fn datatype_to_rusqlite_value(dt: &DataType) -> rusqlite::types::Value {
153    match dt {
154        DataType::Boolean(Some(b)) => rusqlite::types::Value::Integer(i64::from(*b)),
155        DataType::Int32(Some(i)) => rusqlite::types::Value::Integer(i64::from(*i)),
156        DataType::Int64(Some(i)) => rusqlite::types::Value::Integer(*i),
157        DataType::Uint32(Some(u)) => rusqlite::types::Value::Integer(i64::from(*u)),
158        DataType::Uint64(Some(u)) => rusqlite::types::Value::Integer(*u as i64),
159        DataType::Float(Some(f)) => rusqlite::types::Value::Real(f64::from(*f)),
160        DataType::Double(Some(f)) => rusqlite::types::Value::Real(*f),
161        DataType::Str(Some(s)) => rusqlite::types::Value::Text(s.clone()),
162        DataType::Binary(Some(b)) => rusqlite::types::Value::Blob(b.clone()),
163        DataType::Timestamp(Some(ts)) => rusqlite::types::Value::Text(ts.clone()),
164        // All None variants map to NULL
165        _ => rusqlite::types::Value::Null,
166    }
167}
168
169fn rusqlite_value_to_datatype(value: ValueRef) -> Result<DataType> {
170    match value {
171        ValueRef::Null => Ok(DataType::Str(None)),
172        ValueRef::Integer(i) => Ok(DataType::Int64(Some(i))),
173        ValueRef::Real(f) => Ok(DataType::Double(Some(f))),
174        ValueRef::Text(t) => {
175            let s = std::str::from_utf8(t).context("invalid UTF-8 in text value")?;
176            Ok(DataType::Str(Some(s.to_string())))
177        }
178        ValueRef::Blob(b) => Ok(DataType::Binary(Some(b.to_vec()))),
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises DDL, parameterised inserts and a query against an in-memory database.
    #[tokio::test]
    async fn sqlite_operations() {
        let options = ConnectOptions {
            database: ":memory:".to_string(),
        };
        let ctx = SqlDefault::connect_with(options).await.expect("connect");
        let conn = ctx.open("test".to_string()).await.expect("open connection");

        // DDL reports no affected rows.
        let created = conn
            .exec(
                "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)".to_string(),
                vec![],
            )
            .await
            .expect("create table");
        assert_eq!(created, 0);

        // Each single-row insert reports exactly one affected row.
        for (name, age) in [("Alice", 30), ("Bob", 25)] {
            let inserted = conn
                .exec(
                    "INSERT INTO users (name, age) VALUES (?, ?)".to_string(),
                    vec![DataType::Str(Some(name.to_string())), DataType::Int32(Some(age))],
                )
                .await
                .expect("insert");
            assert_eq!(inserted, 1);
        }

        let rows = conn
            .query("SELECT id, name, age FROM users ORDER BY name".to_string(), vec![])
            .await
            .expect("query");
        assert_eq!(rows.len(), 2);

        // Rows are ordered by name, so the first one is Alice.
        let name_field = &rows[0].fields[1];
        assert_eq!(name_field.name, "name");
        let DataType::Str(Some(name)) = &name_field.value else {
            panic!("Expected string value");
        };
        assert_eq!(name, "Alice");
    }
}