1use std::path::Path;
16
17use crate::error::{Result, StoreError};
18
/// A single SQL value exchanged with a [`Backend`].
///
/// The five variants correspond one-to-one with SQLite's storage classes
/// (`NULL`, `INTEGER`, `REAL`, `TEXT`, `BLOB`), which is how they are mapped
/// in both directions by the SQLite backend.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// SQL `NULL`.
    Null,
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Real(f64),
    /// Text, held as an owned UTF-8 `String`.
    Text(String),
    /// Raw bytes.
    Blob(Vec<u8>),
}

/// One result row: column values in the order the query selected them.
pub type Row = std::vec::Vec<Value>;
31
/// Identifies which SQL dialect a [`Backend`] implementation speaks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dialect {
    /// SQLite.
    Sqlite,
    /// PostgreSQL.
    Postgres,
}
39
40pub trait Backend {
45 fn dialect(&self) -> Dialect;
47
48 fn exec(&self, sql: &str, params: &[Value]) -> Result<u64>;
50
51 fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;
53}
54
55pub struct SqliteBackend {
61 conn: rusqlite::Connection,
62}
63
64impl SqliteBackend {
65 pub fn open(path: &Path) -> Result<Self> {
68 let conn = rusqlite::Connection::open(path)?;
69 Self::apply_pragmas(&conn)?;
70 Ok(Self { conn })
71 }
72
73 pub fn in_memory() -> Result<Self> {
75 let conn = rusqlite::Connection::open_in_memory()?;
76 Ok(Self { conn })
77 }
78
79 pub fn from_connection(conn: rusqlite::Connection) -> Self {
87 Self { conn }
88 }
89
90 pub fn connection(&self) -> &rusqlite::Connection {
95 &self.conn
96 }
97
98 fn apply_pragmas(conn: &rusqlite::Connection) -> Result<()> {
99 conn.busy_timeout(std::time::Duration::from_secs(5))?;
104 conn.pragma_update(None, "journal_mode", "WAL")?;
105 conn.pragma_update(None, "synchronous", "NORMAL")?;
106 Ok(())
107 }
108}
109
110impl Backend for SqliteBackend {
111 fn dialect(&self) -> Dialect {
112 Dialect::Sqlite
113 }
114
115 fn exec(&self, sql: &str, params: &[Value]) -> Result<u64> {
116 let n = self
117 .conn
118 .execute(sql, rusqlite::params_from_iter(params.iter()))?;
119 Ok(n as u64)
120 }
121
122 fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>> {
123 let mut stmt = self.conn.prepare(sql)?;
124 let ncols = stmt.column_count();
125 let rows = stmt
126 .query_map(rusqlite::params_from_iter(params.iter()), |row| {
127 (0..ncols)
128 .map(|i| row.get_ref(i).map(value_from_ref))
129 .collect::<rusqlite::Result<Row>>()
130 })?
131 .collect::<rusqlite::Result<Vec<Row>>>()?;
132 Ok(rows)
133 }
134}
135
136fn value_from_ref(v: rusqlite::types::ValueRef<'_>) -> Value {
137 use rusqlite::types::ValueRef;
138 match v {
139 ValueRef::Null => Value::Null,
140 ValueRef::Integer(i) => Value::Int(i),
141 ValueRef::Real(f) => Value::Real(f),
142 ValueRef::Text(t) => Value::Text(String::from_utf8_lossy(t).into_owned()),
143 ValueRef::Blob(b) => Value::Blob(b.to_vec()),
144 }
145}
146
147impl rusqlite::ToSql for Value {
148 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
149 use rusqlite::types::{ToSqlOutput, Value as SqlValue, ValueRef};
150 Ok(match self {
151 Value::Null => ToSqlOutput::Owned(SqlValue::Null),
152 Value::Int(i) => ToSqlOutput::Owned(SqlValue::Integer(*i)),
153 Value::Real(f) => ToSqlOutput::Owned(SqlValue::Real(*f)),
154 Value::Text(s) => ToSqlOutput::Borrowed(ValueRef::Text(s.as_bytes())),
155 Value::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)),
156 })
157 }
158}
159
160pub(crate) fn blob32(v: &Value) -> Result<[u8; 32]> {
162 match v {
163 Value::Blob(b) if b.len() == 32 => {
164 let mut out = [0u8; 32];
165 out.copy_from_slice(b);
166 Ok(out)
167 }
168 Value::Blob(b) => Err(StoreError::MalformedRow(format!(
169 "expected 32-byte hash, got {} bytes",
170 b.len()
171 ))),
172 other => Err(StoreError::MalformedRow(format!(
173 "expected blob hash, got {other:?}"
174 ))),
175 }
176}
177
178pub(crate) fn as_u64(v: &Value) -> Result<u64> {
180 match v {
181 Value::Int(i) => Ok(*i as u64),
182 other => Err(StoreError::MalformedRow(format!(
183 "expected integer, got {other:?}"
184 ))),
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty in-memory backend for each test.
    fn memory_db() -> SqliteBackend {
        SqliteBackend::in_memory().unwrap()
    }

    #[test]
    fn round_trips_values() {
        let db = memory_db();
        db.exec(
            "CREATE TABLE t (i INTEGER, r REAL, s TEXT, b BLOB, n INTEGER)",
            &[],
        )
        .unwrap();

        // One value of every variant; the row read back must match exactly.
        let sample: Row = vec![
            Value::Int(42),
            Value::Real(1.5),
            Value::Text("hi".into()),
            Value::Blob(vec![1, 2, 3]),
            Value::Null,
        ];
        db.exec(
            "INSERT INTO t (i, r, s, b, n) VALUES (?, ?, ?, ?, ?)",
            &sample,
        )
        .unwrap();

        let rows = db.query("SELECT i, r, s, b, n FROM t", &[]).unwrap();
        assert_eq!(rows, vec![sample]);
    }

    #[test]
    fn dialect_is_sqlite() {
        assert_eq!(memory_db().dialect(), Dialect::Sqlite);
    }

    #[test]
    fn from_connection_wraps_and_shares_the_database() {
        let raw = rusqlite::Connection::open_in_memory().unwrap();
        let db = SqliteBackend::from_connection(raw);

        db.exec("CREATE TABLE t (x INTEGER)", &[]).unwrap();
        // Write through the raw handle; the Backend view must see the same data.
        db.connection()
            .execute("INSERT INTO t VALUES (7)", [])
            .unwrap();

        let rows = db.query("SELECT x FROM t", &[]).unwrap();
        assert_eq!(rows, vec![vec![Value::Int(7)]]);
    }
}