sqlite_collections/ds/map.rs

1use crate::format::Format;
2use crate::OpenError;
3use crate::Savepointable;
4use crate::{db, identifier::Identifier};
5use rusqlite::Savepoint;
6use rusqlite::{params, Connection, OptionalExtension};
7
8use std::marker::PhantomData;
9
10mod error;
11pub mod iter;
12
13use iter::KeyIter;
14use iter::KeyValueIter;
15use iter::ValueIter;
16
17pub use error::Error;
18
19#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
20pub struct Config<'db, 'tbl> {
21    pub database: Identifier<'db>,
22    pub table: Identifier<'tbl>,
23}
24
25impl Default for Config<'static, 'static> {
26    fn default() -> Self {
27        Config {
28            database: "main".try_into().unwrap(),
29            table: "ds::map".try_into().unwrap(),
30        }
31    }
32}
33
34/// Deterministic store set.
35pub struct Map<'db, 'tbl, K, V, C>
36where
37    K: Format,
38    V: Format,
39    C: Savepointable,
40{
41    connection: C,
42    database: Identifier<'db>,
43    table: Identifier<'tbl>,
44    key_serializer: PhantomData<K>,
45    value_serializer: PhantomData<V>,
46}
47
48impl<K, V, C> Map<'static, 'static, K, V, C>
49where
50    K: Format,
51    V: Format,
52    C: Savepointable,
53{
54    pub fn open(connection: C) -> Result<Self, OpenError> {
55        Map::open_with_config(connection, Config::default())
56    }
57
58    /// Open a set without creating it or checking if it exists.  This is safe
59    /// if you call a safe open in (or under) the same transaction or savepoint
60    /// beforehand.
61    pub fn unchecked_open(connection: C) -> Self {
62        Map::unchecked_open_with_config(connection, Config::default())
63    }
64}
65
66impl<'db, 'tbl, K, V, C> Map<'db, 'tbl, K, V, C>
67where
68    K: Format,
69    V: Format,
70    C: Savepointable,
71{
72    pub fn open_with_config(
73        mut connection: C,
74        config: Config<'db, 'tbl>,
75    ) -> Result<Self, OpenError> {
76        let database = config.database;
77        let table = config.table;
78
79        {
80            let sp = connection.savepoint()?;
81
82            let mut version = db::setup(&sp, &database, &table, "ds::map")?;
83            if version < 0 {
84                return Err(OpenError::TableVersion(version));
85            }
86            let prev_version = version;
87            if version < 1 {
88                let trailer = db::strict_without_rowid();
89                let sql_type = K::sql_type();
90
91                sp.execute(
92                    &format!(
93                        "CREATE TABLE {database}.{table} (
94                            key {sql_type} UNIQUE PRIMARY KEY NOT NULL,
95                            value {sql_type} NOT NULL
96                        ){trailer}"
97                    ),
98                    [],
99                )?;
100                version = 1;
101            }
102            if version > 1 {
103                return Err(OpenError::TableVersion(version));
104            }
105            if prev_version != version {
106                db::set_version(&sp, &database, &table, version)?;
107            }
108
109            sp.commit()?;
110        }
111        Ok(Self {
112            connection,
113            database,
114            table,
115            key_serializer: PhantomData,
116            value_serializer: PhantomData,
117        })
118    }
119
120    /// Open a set without creating it or checking if it exists.  This is safe
121    /// if you call a safe open in (or under) the same transaction or savepoint
122    /// beforehand.
123    pub fn unchecked_open_with_config(connection: C, config: Config<'db, 'tbl>) -> Self {
124        let database = config.database;
125        let table = config.table;
126
127        Self {
128            connection,
129            database,
130            table,
131            key_serializer: PhantomData,
132            value_serializer: PhantomData,
133        }
134    }
135
136    pub fn insert(&mut self, key: &K::In, value: &V::In) -> Result<Option<V::Out>, Error<K, V>> {
137        let database = &self.database;
138        let table = &self.table;
139        let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
140        let value = V::serialize(value).map_err(|e| Error::ValueSerialize(e))?;
141
142        let sp = self.connection.savepoint()?;
143        let prev_value = Self::get_from_serialized(database, table, &sp, &key)?;
144        if db::has_upsert() {
145            sp.prepare_cached(&format!(
146                "INSERT INTO {database}.{table} (key, value) VALUES (?, ?) ON CONFLICT DO UPDATE SET value=excluded.value"
147            ))?
148            .execute(params![key, value])?;
149        } else if prev_value.is_some() {
150            sp.prepare_cached(&format!(
151                "UPDATE {database}.{table} SET value=? WHERE key=?"
152            ))?
153            .execute(params![value, key])?;
154        } else {
155            sp.prepare_cached(&format!(
156                "INSERT INTO {database}.{table} (key, value) VALUES (?, ?)"
157            ))?
158            .execute(params![key, value])?;
159        };
160        sp.commit()?;
161        Ok(prev_value)
162    }
163
164    pub fn len(&mut self) -> Result<u64, Error<K, V>> {
165        let database = &self.database;
166        let table = &self.table;
167        Ok(self
168            .connection
169            .savepoint()?
170            .prepare_cached(&format!("SELECT COUNT(*) FROM {database}.{table}"))?
171            .query_row([], |row| row.get(0))?)
172    }
173
174    pub fn is_empty(&mut self) -> Result<bool, Error<K, V>> {
175        let database = &self.database;
176        let table = &self.table;
177        Ok(self
178            .connection
179            .savepoint()?
180            .prepare_cached(&format!("SELECT 1 FROM {database}.{table} LIMIT 1"))?
181            .query_row([], |_| Ok(()))
182            .optional()?
183            .is_none())
184    }
185
186    fn contains_from_serialized(
187        database: &Identifier,
188        table: &Identifier,
189        connection: &Connection,
190        key: &K::Buffer,
191    ) -> Result<bool, Error<K, V>> {
192        Ok(connection
193            .prepare_cached(&format!("SELECT 1 FROM {database}.{table} WHERE key = ?"))?
194            .query_row(params![key], |_| Ok(()))
195            .optional()?
196            .is_some())
197    }
198
199    fn get_from_serialized(
200        database: &Identifier,
201        table: &Identifier,
202        connection: &Connection,
203        key: &K::Buffer,
204    ) -> Result<Option<V::Out>, Error<K, V>> {
205        Ok(
206            Self::get_serialized_from_serialized(database, table, connection, key)?
207                .map(|b| V::deserialize(&b).map_err(|e| Error::ValueDeserialize(e)))
208                .transpose()?,
209        )
210    }
211
212    fn get_serialized_from_serialized(
213        database: &Identifier,
214        table: &Identifier,
215        connection: &Connection,
216        key: &K::Buffer,
217    ) -> Result<Option<V::Buffer>, Error<K, V>> {
218        Ok(connection
219            .prepare_cached(&format!(
220                "SELECT value FROM {database}.{table} WHERE key = ?"
221            ))?
222            .query_row(params![key], |row| row.get(0))
223            .optional()?)
224    }
225
226    pub fn get(&mut self, key: &K::In) -> Result<Option<V::Out>, Error<K, V>> {
227        let database = &self.database;
228        let table = &self.table;
229        let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
230
231        let sp = self.connection.savepoint()?;
232        let result = Self::get_from_serialized(database, table, &sp, &key)?;
233        sp.commit()?;
234        Ok(result)
235    }
236
237    pub fn remove(&mut self, key: &K::In) -> Result<Option<V::Out>, Error<K, V>> {
238        let database = &self.database;
239        let table = &self.table;
240        let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
241
242        let sp = self.connection.savepoint()?;
243        let result = Self::get_from_serialized(database, table, &sp, &key)?;
244        sp.prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?
245            .execute(params![key])?;
246        sp.commit()?;
247        Ok(result)
248    }
249
250    pub fn contains_key(&mut self, key: &K::In) -> Result<bool, Error<K, V>> {
251        let database = &self.database;
252        let table = &self.table;
253        let key = K::serialize(key).map_err(|e| Error::KeySerialize(e))?;
254
255        let sp = self.connection.savepoint()?;
256        let result = Self::contains_from_serialized(database, table, &sp, &key)?;
257        sp.commit()?;
258        Ok(result)
259    }
260
261    pub fn clear(&mut self) -> Result<(), Error<K, V>> {
262        let database = &self.database;
263        let table = &self.table;
264        let sp = self.connection.savepoint()?;
265        sp.prepare_cached(&format!("DELETE FROM {database}.{table}"))?
266            .execute([])?;
267        sp.commit()?;
268        Ok(())
269    }
270
271    pub fn iter(&mut self) -> Result<KeyValueIter<'db, 'tbl, K, V, Savepoint<'_>>, Error<K, V>> {
272        Ok(KeyValueIter::new(
273            self.connection.savepoint()?,
274            self.database.clone(),
275            self.table.clone(),
276        )?)
277    }
278
279    pub fn keys(&mut self) -> Result<KeyIter<'db, 'tbl, K, V, Savepoint<'_>>, Error<K, V>> {
280        Ok(KeyIter::new(
281            self.connection.savepoint()?,
282            self.database.clone(),
283            self.table.clone(),
284        )?)
285    }
286    pub fn values(&mut self) -> Result<ValueIter<'db, 'tbl, K, V, Savepoint<'_>>, Error<K, V>> {
287        Ok(ValueIter::new(
288            self.connection.savepoint()?,
289            self.database.clone(),
290            self.table.clone(),
291        )?)
292    }
293
294    /// Retains only the elements specified by the predicate.
295    ///
296    /// In other words, remove all pairs (k, v) for which f(k, v) returns false.
297    /// The elements are visited in unsorted (and unspecified) order.
298    ///
299    /// This is all done in a single transaction.
300    pub fn retain<F>(&mut self, mut f: F) -> Result<(), Error<K, V>>
301    where
302        F: FnMut(K::Out, V::Out) -> bool,
303    {
304        let database = &self.database;
305        let table = &self.table;
306
307        let sp = self.connection.savepoint()?;
308        {
309            let mut maybe_serialized = sp
310                .prepare_cached(&format!(
311                    "SELECT key, value FROM {database}.{table} ORDER BY key ASC LIMIT 1"
312                ))?
313                .query_row([], |row| Ok((row.get(0)?, row.get(1)?)))
314                .optional()?;
315            let mut deleter =
316                sp.prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?;
317            let mut select_next = sp.prepare_cached(&format!(
318                "SELECT key, value FROM {database}.{table} WHERE key > ? ORDER BY key ASC LIMIT 1"
319            ))?;
320            while let Some((serialized_key, value)) = maybe_serialized {
321                let key = K::deserialize(&serialized_key).map_err(Error::KeyDeserialize)?;
322                let value = V::deserialize(&value).map_err(Error::ValueDeserialize)?;
323                if !f(key, value) {
324                    deleter.execute(params![serialized_key])?;
325                }
326                maybe_serialized = select_next
327                    .query_row(params![serialized_key], |row| {
328                        Ok((row.get(0)?, row.get(1)?))
329                    })
330                    .optional()?;
331            }
332        }
333        sp.commit()?;
334        Ok(())
335    }
336}
337
338#[cfg(test)]
339mod test;