1use std::path::Path;
16
17use crate::error::{Result, StoreError};
18
/// A dynamically typed SQL value passed to, and read back from, a [`Backend`].
///
/// The variants correspond one-to-one with the cases of
/// `rusqlite::types::ValueRef` that this module converts to and from.
#[derive(Debug, Clone, PartialEq)]
pub enum Value {
    /// SQL `NULL`.
    Null,
    /// A signed 64-bit integer.
    Int(i64),
    /// A 64-bit floating point number.
    Real(f64),
    /// An owned text string.
    Text(String),
    /// An owned byte string.
    Blob(Vec<u8>),
}
28
29pub type Row = Vec<Value>;
31
/// Identifies the SQL dialect a [`Backend`] reports through
/// [`Backend::dialect`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Dialect {
    /// Reported by [`SqliteBackend`].
    Sqlite,
    /// PostgreSQL; no backend in this part of the file reports it.
    Postgres,
}
39
40pub trait Backend {
45 fn dialect(&self) -> Dialect;
47
48 fn exec(&self, sql: &str, params: &[Value]) -> Result<u64>;
50
51 fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>>;
53}
54
55pub struct SqliteBackend {
61 conn: rusqlite::Connection,
62}
63
64impl SqliteBackend {
65 pub fn open(path: &Path) -> Result<Self> {
68 let conn = rusqlite::Connection::open(path)?;
69 Self::apply_pragmas(&conn)?;
70 Ok(Self { conn })
71 }
72
73 pub fn in_memory() -> Result<Self> {
75 let conn = rusqlite::Connection::open_in_memory()?;
76 Ok(Self { conn })
77 }
78
79 fn apply_pragmas(conn: &rusqlite::Connection) -> Result<()> {
80 conn.busy_timeout(std::time::Duration::from_secs(5))?;
85 conn.pragma_update(None, "journal_mode", "WAL")?;
86 conn.pragma_update(None, "synchronous", "NORMAL")?;
87 Ok(())
88 }
89}
90
91impl Backend for SqliteBackend {
92 fn dialect(&self) -> Dialect {
93 Dialect::Sqlite
94 }
95
96 fn exec(&self, sql: &str, params: &[Value]) -> Result<u64> {
97 let n = self
98 .conn
99 .execute(sql, rusqlite::params_from_iter(params.iter()))?;
100 Ok(n as u64)
101 }
102
103 fn query(&self, sql: &str, params: &[Value]) -> Result<Vec<Row>> {
104 let mut stmt = self.conn.prepare(sql)?;
105 let ncols = stmt.column_count();
106 let rows = stmt
107 .query_map(rusqlite::params_from_iter(params.iter()), |row| {
108 (0..ncols)
109 .map(|i| row.get_ref(i).map(value_from_ref))
110 .collect::<rusqlite::Result<Row>>()
111 })?
112 .collect::<rusqlite::Result<Vec<Row>>>()?;
113 Ok(rows)
114 }
115}
116
117fn value_from_ref(v: rusqlite::types::ValueRef<'_>) -> Value {
118 use rusqlite::types::ValueRef;
119 match v {
120 ValueRef::Null => Value::Null,
121 ValueRef::Integer(i) => Value::Int(i),
122 ValueRef::Real(f) => Value::Real(f),
123 ValueRef::Text(t) => Value::Text(String::from_utf8_lossy(t).into_owned()),
124 ValueRef::Blob(b) => Value::Blob(b.to_vec()),
125 }
126}
127
128impl rusqlite::ToSql for Value {
129 fn to_sql(&self) -> rusqlite::Result<rusqlite::types::ToSqlOutput<'_>> {
130 use rusqlite::types::{ToSqlOutput, Value as SqlValue, ValueRef};
131 Ok(match self {
132 Value::Null => ToSqlOutput::Owned(SqlValue::Null),
133 Value::Int(i) => ToSqlOutput::Owned(SqlValue::Integer(*i)),
134 Value::Real(f) => ToSqlOutput::Owned(SqlValue::Real(*f)),
135 Value::Text(s) => ToSqlOutput::Borrowed(ValueRef::Text(s.as_bytes())),
136 Value::Blob(b) => ToSqlOutput::Borrowed(ValueRef::Blob(b)),
137 })
138 }
139}
140
141pub(crate) fn blob32(v: &Value) -> Result<[u8; 32]> {
143 match v {
144 Value::Blob(b) if b.len() == 32 => {
145 let mut out = [0u8; 32];
146 out.copy_from_slice(b);
147 Ok(out)
148 }
149 Value::Blob(b) => Err(StoreError::MalformedRow(format!(
150 "expected 32-byte hash, got {} bytes",
151 b.len()
152 ))),
153 other => Err(StoreError::MalformedRow(format!(
154 "expected blob hash, got {other:?}"
155 ))),
156 }
157}
158
159pub(crate) fn as_u64(v: &Value) -> Result<u64> {
161 match v {
162 Value::Int(i) => Ok(*i as u64),
163 other => Err(StoreError::MalformedRow(format!(
164 "expected integer, got {other:?}"
165 ))),
166 }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `Value` variant written through `exec` must come back unchanged
    /// from `query`.
    #[test]
    fn round_trips_values() {
        let db = SqliteBackend::in_memory().unwrap();
        db.exec(
            "CREATE TABLE t (i INTEGER, r REAL, s TEXT, b BLOB, n INTEGER)",
            &[],
        )
        .unwrap();

        // One value per column; reused as the expected row below.
        let expected = vec![
            Value::Int(42),
            Value::Real(1.5),
            Value::Text("hi".into()),
            Value::Blob(vec![1, 2, 3]),
            Value::Null,
        ];
        db.exec(
            "INSERT INTO t (i, r, s, b, n) VALUES (?, ?, ?, ?, ?)",
            &expected,
        )
        .unwrap();

        let rows = db.query("SELECT i, r, s, b, n FROM t", &[]).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0], expected);
    }

    /// The SQLite backend reports the SQLite dialect.
    #[test]
    fn dialect_is_sqlite() {
        let db = SqliteBackend::in_memory().unwrap();
        assert_eq!(db.dialect(), Dialect::Sqlite);
    }
}