1use crate::format::Format;
2use crate::OpenError;
3use crate::Savepointable;
4use crate::{db, identifier::Identifier};
5use rusqlite::{params, Connection, OptionalExtension, Savepoint};
6
7use std::marker::PhantomData;
8
9mod error;
10mod iter;
11pub use iter::Iter;
12
13pub use error::Error;
14
15#[derive(Debug, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)]
16pub struct Config<'db, 'tbl> {
17 pub database: Identifier<'db>,
18 pub table: Identifier<'tbl>,
19}
20
21impl Default for Config<'static, 'static> {
22 fn default() -> Self {
23 Config {
24 database: "main".try_into().unwrap(),
25 table: "ds::set".try_into().unwrap(),
26 }
27 }
28}
29
30pub struct Set<'db, 'tbl, S, C>
32where
33 S: Format,
34 C: Savepointable,
35{
36 connection: C,
37 database: Identifier<'db>,
38 table: Identifier<'tbl>,
39 serializer: PhantomData<S>,
40}
41
42impl<S, C> Set<'static, 'static, S, C>
43where
44 S: Format,
45 C: Savepointable,
46{
47 pub fn open(connection: C) -> Result<Self, OpenError> {
48 Set::open_with_config(connection, Config::default())
49 }
50
51 pub fn unchecked_open(connection: C) -> Self {
55 Set::unchecked_open_with_config(connection, Config::default())
56 }
57}
58impl<'db, 'tbl, S, C> Set<'db, 'tbl, S, C>
59where
60 S: Format,
61 C: Savepointable,
62{
63 pub fn open_with_config(
64 mut connection: C,
65 config: Config<'db, 'tbl>,
66 ) -> Result<Self, OpenError> {
67 let database = config.database;
68 let table = config.table;
69
70 {
71 let sp = connection.savepoint()?;
72
73 let mut version = db::setup(&sp, &database, &table, "ds::set")?;
74 if version < 0 {
75 return Err(OpenError::TableVersion(version));
76 }
77 let prev_version = version;
78 if version < 1 {
79 let trailer = db::strict_without_rowid();
80 let sql_type = S::sql_type();
81
82 sp.execute(
83 &format!(
84 "CREATE TABLE {database}.{table} (
85 key {sql_type} UNIQUE PRIMARY KEY NOT NULL
86 ){trailer}"
87 ),
88 [],
89 )?;
90 version = 1;
91 }
92 if version > 1 {
93 return Err(OpenError::TableVersion(version));
94 }
95 if prev_version != version {
96 db::set_version(&sp, &database, &table, version)?;
97 }
98
99 sp.commit()?;
100 }
101 Ok(Self {
102 connection,
103 database,
104 table,
105 serializer: PhantomData,
106 })
107 }
108
109 pub fn unchecked_open_with_config(connection: C, config: Config<'db, 'tbl>) -> Self {
113 let database = config.database;
114 let table = config.table;
115
116 Self {
117 connection,
118 database,
119 table,
120 serializer: PhantomData,
121 }
122 }
123
124 pub fn insert(&mut self, value: &S::In) -> Result<bool, Error<S>> {
125 let database = &self.database;
126 let table = &self.table;
127 let serialized = S::serialize(value).map_err(Error::Serialize)?;
128
129 let sp = self.connection.savepoint()?;
130 let ret = if db::has_upsert() {
131 sp.prepare_cached(&format!(
132 "INSERT INTO {database}.{table} (key) VALUES (?) ON CONFLICT DO NOTHING"
133 ))?
134 .execute(params![serialized])?;
135 sp.changes() > 0
136 } else if Self::contains_serialized(database, table, &sp, &serialized)? {
137 false
138 } else {
139 sp.prepare_cached(&format!("INSERT INTO {database}.{table} (key) VALUES (?)"))?
140 .execute(params![serialized])?;
141 true
142 };
143 sp.commit()?;
144 Ok(ret)
145 }
146
147 pub fn contains(&mut self, value: &S::In) -> Result<bool, Error<S>> {
148 let serialized = S::serialize(value).map_err(|e| Error::Serialize(e))?;
149 Self::contains_serialized(
150 &self.database,
151 &self.table,
152 &*self.connection.savepoint()?,
153 &serialized,
154 )
155 }
156
157 fn contains_serialized(
158 database: &Identifier,
159 table: &Identifier,
160 connection: &Connection,
161 value: &S::Buffer,
162 ) -> Result<bool, Error<S>> {
163 Ok(connection
164 .prepare_cached(&format!("SELECT 1 FROM {database}.{table} WHERE key = ?"))?
165 .query_row(params![value], |_| Ok(()))
166 .optional()?
167 .is_some())
168 }
169
170 pub fn remove(&mut self, value: &S::In) -> Result<bool, Error<S>> {
171 let database = &self.database;
172 let table = &self.table;
173 let serialized = S::serialize(value).map_err(Error::Serialize)?;
174
175 let sp = self.connection.savepoint()?;
176 let changes = sp
177 .prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?
178 .execute(params![serialized])?;
179
180 sp.commit()?;
181
182 Ok(changes > 0)
183 }
184
185 pub fn clear(&mut self) -> Result<(), Error<S>> {
186 let database = &self.database;
187 let table = &self.table;
188 let sp = self.connection.savepoint()?;
189 sp.prepare_cached(&format!("DELETE FROM {database}.{table}"))?
190 .execute([])?;
191 sp.commit()?;
192 Ok(())
193 }
194
195 pub fn first(&mut self) -> Result<Option<S::Out>, Error<S>> {
196 let database = &self.database;
197 let table = &self.table;
198
199 let serialized: Option<S::Buffer> = self
200 .connection
201 .savepoint()?
202 .prepare_cached(&format!(
203 "SELECT key FROM {database}.{table} ORDER BY key ASC"
204 ))?
205 .query_row([], |row| row.get(0))
206 .optional()?;
207
208 serialized
209 .map(|s| S::deserialize(&s))
210 .transpose()
211 .map_err(Error::Deserialize)
212 }
213
214 pub fn last(&mut self) -> Result<Option<S::Out>, Error<S>> {
215 let database = &self.database;
216 let table = &self.table;
217 let serialized: Option<S::Buffer> = self
218 .connection
219 .savepoint()?
220 .prepare_cached(&format!(
221 "SELECT key FROM {database}.{table} ORDER BY key DESC"
222 ))?
223 .query_row([], |row| row.get(0))
224 .optional()?;
225
226 serialized
227 .map(|s| S::deserialize(&s))
228 .transpose()
229 .map_err(Error::Deserialize)
230 }
231
232 pub fn len(&mut self) -> Result<u64, Error<S>> {
233 let database = &self.database;
234 let table = &self.table;
235 Ok(self
236 .connection
237 .savepoint()?
238 .prepare_cached(&format!("SELECT COUNT(*) FROM {database}.{table}"))?
239 .query_row([], |row| row.get(0))?)
240 }
241
242 pub fn is_empty(&mut self) -> Result<bool, Error<S>> {
243 let database = &self.database;
244 let table = &self.table;
245 Ok(self
246 .connection
247 .savepoint()?
248 .prepare_cached(&format!("SELECT 1 FROM {database}.{table} LIMIT 1"))?
249 .query_row([], |_| Ok(()))
250 .optional()?
251 .is_none())
252 }
253
254 pub fn iter(&mut self) -> Result<Iter<'db, 'tbl, S, Savepoint<'_>>, Error<S>> {
255 Ok(Iter::new(
256 self.connection.savepoint()?,
257 self.database.clone(),
258 self.table.clone(),
259 )?)
260 }
261
262 pub fn retain<F>(&mut self, mut f: F) -> Result<(), Error<S>>
269 where
270 F: FnMut(S::Out) -> bool,
271 {
272 let database = &self.database;
273 let table = &self.table;
274
275 let sp = self.connection.savepoint()?;
276 {
277 let mut maybe_serialized = sp
278 .prepare_cached(&format!(
279 "SELECT key FROM {database}.{table} ORDER BY key ASC LIMIT 1"
280 ))?
281 .query_row([], |row| row.get(0))
282 .optional()?;
283 let mut deleter =
284 sp.prepare_cached(&format!("DELETE FROM {database}.{table} WHERE key = ?"))?;
285 let mut select_next = sp.prepare_cached(&format!(
286 "SELECT key FROM {database}.{table} WHERE key > ? ORDER BY key ASC LIMIT 1"
287 ))?;
288 while let Some(serialized) = maybe_serialized {
289 let item = S::deserialize(&serialized).map_err(|e| Error::Deserialize(e))?;
290 if !f(item) {
291 deleter.execute(params![serialized])?;
292 }
293 maybe_serialized = select_next
294 .query_row(params![serialized], |row| row.get(0))
295 .optional()?;
296 }
297 }
298 sp.commit()?;
299 Ok(())
300 }
301}
302
303#[cfg(test)]
304mod test;