1#![cfg(feature = "sqlite")]
2use super::general::Bucket;
3use crate::constants::DESCRIBE_MAX;
4use crate::data::{Integer, Numeric};
5use crate::prelude::*;
6use fnv::FnvHashSet;
7use rusqlite::{params, Connection};
8use serde::de::DeserializeOwned;
9use std::marker::PhantomData;
10use serde::Serialize;
11use std::cell::Cell;
12
/// Reinterprets a slice of `T` as its raw in-memory bytes (native endianness)
/// so it can be bound as a SQLite `BLOB` parameter.
///
/// Intended for slices of primitive integers (hash components). Using it with
/// a `T` that contains padding bytes would expose uninitialised memory.
fn vec_to_blob<T>(hash: &[T]) -> &[u8] {
    let data = hash.as_ptr() as *const u8;
    // SAFETY: `data` points to exactly `size_of_val(hash)` readable bytes owned
    // by `hash`; `u8` has alignment 1, so any pointer is suitably aligned; the
    // returned slice borrows `hash`, so it cannot outlive the source. Callers
    // pass integer slices, whose bytes are all initialised.
    unsafe { std::slice::from_raw_parts(data, std::mem::size_of_val(hash)) }
}
17
/// Reinterprets a byte blob (as produced by `vec_to_blob`) as a slice of `T`.
///
/// Trailing bytes that do not form a whole `T` are ignored. A blob too short
/// to hold a single `T` yields an empty slice.
///
/// # Panics
///
/// Panics if `T` is zero-sized, or if `blob` is not aligned for `T`. Borrowing
/// a misaligned buffer as `&[T]` would be undefined behaviour, so this panics
/// instead.
fn blob_to_vec<T>(blob: &[u8]) -> &[T] {
    let elem_size = std::mem::size_of::<T>();
    assert!(elem_size != 0, "blob_to_vec: zero-sized element type");
    let len = blob.len() / elem_size;
    if len == 0 {
        // Empty `Vec<u8>`s carry a dangling pointer aligned only to 1, so
        // return early rather than building a `&[T]` from it.
        return &[];
    }
    let data = blob.as_ptr() as *const T;
    assert!(
        (data as usize) % std::mem::align_of::<T>() == 0,
        "blob_to_vec: blob is not aligned for the target element type"
    );
    // SAFETY: `data` is non-null and aligned for `T` (checked above). The
    // `len` elements span `len * elem_size <= blob.len()` bytes inside `blob`.
    // The returned slice borrows `blob`, so it cannot outlive the data. Every
    // bit pattern must be a valid `T`; callers only use primitive integers,
    // for which this holds.
    unsafe { std::slice::from_raw_parts(data, len) }
}
22
23fn query_bucket(blob: &[u8], table_name: &str, connection: &Connection) -> Result<Bucket> {
24 let mut stmt = connection.prepare_cached(&format!(
25 "
26SELECT (id) FROM {}
27WHERE hash = ?
28 ",
29 table_name
30 ))?;
31 let mut rows = stmt.query(params![blob])?;
32
33 let mut bucket = FnvHashSet::default();
34 while let Some(row) = rows.next()? {
35 bucket.insert(row.get(0)?);
36 }
37 Ok(bucket)
38}
39
40fn make_table(table_name: &str, connection: &Connection) -> Result<()> {
41 connection.execute_batch(&format!(
42 "CREATE TABLE IF NOT EXISTS {} (
43 hash BLOB,
44 id INTEGER
45 )
46 ",
47 table_name
48 ))?;
49 Ok(())
50}
51
52fn insert_table<K>(
53 table_name: &str,
54 hash: &Vec<K>,
55 idx: u32,
56 connection: &Connection,
57) -> Result<usize> {
58 let blob = vec_to_blob(hash);
59 let mut stmt = connection.prepare_cached(&format!(
60 "
61INSERT INTO {} (hash, id)
62VALUES (?1, ?2)
63 ",
64 table_name
65 ))?;
66 let idx = stmt.execute(params![blob, idx])?;
67 Ok(idx)
68}
69
70fn hash_table_stats(
71 table_name: &str,
72 limit: u32,
73 conn: &Connection,
74) -> Result<(f64, f64, u32, u32)> {
75 let mut stmt = conn.prepare_cached(&format!(
76 "
77SELECT
78 avg(c) as mean,
79 avg(c * c) - avg(c) * avg(c) as variance,
80 min(c) as minimum,
81 max(c) as maximum
82
83FROM (
84 SELECT count(id) as c
85 FROM {}
86 GROUP BY hash
87 LIMIT ?
88);
89 ",
90 table_name
91 ))?;
92 let out = stmt.query_row(params![limit], |row| {
93 let mean: f64 = row.get(0)?;
94 let variance: f64 = row.get(1)?;
95 let stdev = variance.powf(0.5);
96 let minimum: u32 = row.get(2)?;
97 let maximum: u32 = row.get(3)?;
98 Ok((mean, stdev, minimum, maximum))
99 })?;
100 Ok(out)
101}
102
103pub struct SqlTable<N, K>
108where
109 N: Numeric,
110 K: Integer,
111{
112 n_hash_tables: usize,
113 only_index_storage: bool, counter: u32,
115 pub conn: Connection,
116 table_names: Vec<String>,
117 pub committed: Cell<bool>,
118 phantom: PhantomData<(N, K)>,
119}
120
/// Builds the SQL table name for hash table number `hash_table`,
/// e.g. `3` -> `"hash_table_3"`.
fn fmt_table_name(hash_table: usize) -> String {
    let mut name = String::from("hash_table_");
    name.push_str(&hash_table.to_string());
    name
}
124
125fn get_table_names(n_hash_tables: usize) -> Vec<String> {
126 let mut table_names = Vec::with_capacity(n_hash_tables);
127 for idx in 0..n_hash_tables {
128 let table_name = fmt_table_name(idx);
129 table_names.push(table_name);
130 }
131 table_names
132}
133
134fn get_unique_hash_int(n_hash_tables: usize, conn: &Connection) -> Result<FnvHashSet<i32>> {
135 let mut hash_numbers = FnvHashSet::default();
136 for table_name in get_table_names(n_hash_tables) {
137 let mut stmt = conn.prepare(&format!["SELECT hash FROM {} LIMIT 100;", table_name])?;
138 let mut rows = stmt.query([])?;
139
140 while let Some(r) = rows.next()? {
141 let blob: Vec<u8> = r.get(0)?;
142 let hash = blob_to_vec(&blob);
143 hash.iter().for_each(|&v| {
144 hash_numbers.insert(v);
145 })
146 }
147 }
148 Ok(hash_numbers)
149}
150
151fn init_table(conn: &Connection, table_names: &[String]) -> Result<()> {
152 for table_name in table_names {
153 make_table(&table_name, &conn)?;
154 }
155 Ok(())
156}
157
158fn init_db_setttings(conn: &Connection) -> Result<()> {
159 conn.execute_batch(
160 "PRAGMA journal_mode = OFF;
161 PRAGMA synchronous = OFF;
162 PRAGMA cache_size = 100000;
163 PRAGMA main.locking_mode=EXCLUSIVE;",
164 )?;
165 Ok(())
166}
167
168impl<N, K> SqlTable<N, K>
169where
170 N: Numeric,
171 K: Integer,
172{
173 fn get_table_name_put(&self, hash_table: usize) -> Result<&str> {
174 let opt = self.table_names.get(hash_table);
175 match opt {
176 Some(tbl_name) => Ok(&tbl_name[..]),
177 None => Err(Error::TableNotExist),
178 }
179 }
180
181 pub fn init_from_conn(
182 n_hash_tables: usize,
183 only_index_storage: bool,
184 conn: Connection,
185 ) -> Result<Self> {
186 let table_names = get_table_names(n_hash_tables);
187 init_db_setttings(&conn)?;
188 init_table(&conn, &table_names)?;
189 let sql = SqlTable {
190 n_hash_tables,
191 only_index_storage,
192 counter: 0,
193 conn,
194 table_names,
195 committed: Cell::new(false),
196 phantom: PhantomData,
197 };
198 sql.init_transaction()?;
199 Ok(sql)
200 }
201
202 pub fn commit(&self) -> Result<()> {
203 if !self.committed.replace(true) {
204 self.conn.execute_batch("COMMIT TRANSACTION;")?;
205 }
206 Ok(())
207 }
208
209 pub fn init_transaction(&self) -> Result<()> {
210 self.committed.set(false);
211 self.conn.execute_batch("BEGIN TRANSACTION;")?;
212 Ok(())
213 }
214
215 pub fn to_mem(&mut self) -> Result<()> {
216 let mut new_con = rusqlite::Connection::open_in_memory()?;
217 {
218 let backup = rusqlite::backup::Backup::new(&self.conn, &mut new_con)?;
219 backup.step(-1)?;
220 }
221 self.conn = new_con;
222 self.committed.set(true);
223 Ok(())
224 }
225
226 pub fn index_hash(&self) -> Result<()> {
227 self.commit()?;
228 for tbl_name in get_table_names(self.n_hash_tables) {
229 self.conn.execute_batch(&format!(
230 "
231 CREATE INDEX hash_index_{}
232 ON {} (hash);",
233 tbl_name, tbl_name
234 ))?;
235 }
236 Ok(())
237 }
238}
239
240impl<N, K> HashTables<N, K> for SqlTable<N, K>
241where
242 N: Numeric,
243 K: Integer,
244{
245 fn new(n_hash_tables: usize, only_index_storage: bool, db_path: &str) -> Result<Box<Self>> {
246 let path = std::path::Path::new(db_path);
247 let conn = Connection::open(path)?;
248 SqlTable::init_from_conn(n_hash_tables, only_index_storage, conn).map(|tbl| Box::new(tbl))
249 }
250
251 fn put(&mut self, hash: Vec<K>, _d: &[N], hash_table: usize) -> Result<u32> {
252 let idx = self.counter;
254
255 let table_name = self.get_table_name_put(hash_table)?;
257 let r = insert_table(&table_name, &hash, idx, &self.conn);
258
259 if hash_table == self.n_hash_tables - 1 {
261 self.counter += 1
262 };
263
264 match r {
265 Ok(_) => Ok(idx),
266 Err(Error::SqlFailure(_)) => Ok(idx), Err(e) => Err(Error::Failed(format!("{:?}", e))),
268 }
269 }
270
271 fn query_bucket(&self, hash: &[K], hash_table: usize) -> Result<Bucket> {
273 self.commit()?;
274 let table_name = fmt_table_name(hash_table);
275 let blob = vec_to_blob(hash);
276 let res = query_bucket(blob, &table_name, &self.conn);
277
278 match res {
279 Ok(bucket) => Ok(bucket),
280 Err(e) => Err(Error::Failed(format!("{:?}", e))),
281 }
282 }
283
284 fn describe(&self) -> Result<String> {
285 let mut stmt = self.conn.prepare(
286 r#"SELECT count(*) FROM sqlite_master
287WHERE type='table' AND type LIKE '%hash%';"#,
288 )?;
289
290 let row: String = stmt.query_row([], |row| {
291 let i: i64 = row.get_unwrap(0);
292 Ok(i.to_string())
293 })?;
294 let mut out = String::from(format!("No. of tables: {}\n", row));
295
296 out.push_str("Unique hash values:\n");
297 let hv = get_unique_hash_int(self.n_hash_tables, &self.conn).unwrap();
298 out.push_str(&format!("{:?}", hv));
299
300 let tables = get_table_names(self.n_hash_tables);
301 let mut avg = Vec::with_capacity(self.n_hash_tables);
302 let mut std_dev = Vec::with_capacity(self.n_hash_tables);
303 let mut min = Vec::with_capacity(self.n_hash_tables);
304 let mut max = Vec::with_capacity(self.n_hash_tables);
305
306 let i = std::cmp::min(3, self.n_hash_tables);
308 for table_name in &tables[..i] {
309 let stats = hash_table_stats(&table_name, DESCRIBE_MAX, &self.conn)?;
310 avg.push(stats.0);
311 std_dev.push(stats.1);
312 min.push(stats.2);
313 max.push(stats.3);
314 }
315 out.push_str("\nHash collisions:\n");
316 out.push_str(&format!("avg:\t{:?}\n", avg));
317 out.push_str(&format!("std-dev:\t{:?}\n", std_dev));
318 out.push_str(&format!("min:\t{:?}\n", min));
319 out.push_str(&format!("max:\t{:?}\n", max));
320 Ok(out)
321 }
322
323 fn store_hashers<H: VecHash<N, K> + Serialize>(&mut self, hashers: &[H]) -> Result<()> {
324 let buf: Vec<u8> = bincode::serialize(hashers)?;
325
326 self.conn.execute_batch(
328 "CREATE TABLE state (
329 hashers BLOB
330 )",
331 )?;
332 let mut stmt = self
333 .conn
334 .prepare("INSERT INTO state (hashers) VALUES (?1)")?;
335
336 self.commit()?;
338 stmt.execute(params![buf])?;
339 self.init_transaction()?;
340 Ok(())
341 }
342
343 fn load_hashers<H: VecHash<N, K> + DeserializeOwned>(&self) -> Result<Vec<H>> {
344 let mut stmt = self.conn.prepare("SELECT * FROM state;")?;
345 let buf: Vec<u8> = stmt.query_row([], |row| {
346 let v: Vec<u8> = row.get_unwrap(0);
347 Ok(v)
348 })?;
349 let hashers: Vec<H> = bincode::deserialize(&buf)?;
350 Ok(hashers)
351 }
352
353 fn get_unique_hash_int(&self) -> FnvHashSet<i32> {
354 get_unique_hash_int(self.n_hash_tables, &self.conn).unwrap()
355 }
356}
357
#[cfg(test)]
mod test {
    use super::*;
    use crate::table::sqlite_mem::SqlTableMem;

    #[test]
    fn test_sql_table_init() {
        let sql = SqlTableMem::<f32, i8>::new(1, true, ".").unwrap();
        let mut stmt = sql
            .conn
            .prepare(&format!("SELECT * FROM {}", sql.table_names[0]))
            .expect("query failed");
        stmt.query([]).expect("query failed");
    }

    #[test]
    fn test_sql_crud() {
        let mut sql = *SqlTableMem::new(1, true, ".").unwrap();
        let v = vec![1., 2.];
        for hash in &[vec![1, 2], vec![2, 3]] {
            sql.put(hash.clone(), &v, 0).unwrap();
        }
        let hash = vec![1, 2];
        sql.put(hash.clone(), &v, 0).unwrap();
        let bucket = sql.query_bucket(&hash, 0);
        println!("{:?}", &bucket);
        match bucket {
            Ok(b) => assert!(b.contains(&0)),
            Err(e) => panic!("query_bucket failed: {:?}", e),
        }
    }

    #[test]
    fn test_blob_hash_casting() {
        for hash in vec![
            &vec![2, 3, 4],
            &vec![-124, 32, 89],
            &vec![1, 2, 3, 4, 5, 6],
            &vec![-12, -2, -3, 1, 2, 3, 4, 5, 6],
        ] {
            let hash = &hash[..];
            let blob = vec_to_blob(hash);
            let hash_back: &[i32] = blob_to_vec(blob);
            assert_eq!(hash, hash_back)
        }
    }

    #[test]
    fn test_in_mem_to_disk() {
        let mut sql = *SqlTableMem::<f32, i8>::new(1, true, ".").unwrap();
        let v = vec![1., 2.];
        for hash in &[vec![1, 2], vec![2, 3]] {
            sql.put(hash.clone(), &v, 0).unwrap();
        }
        sql.commit().unwrap();
        // Use the OS temp dir instead of the working directory, and clear any
        // file a previously failed run left behind.
        let path = std::env::temp_dir().join("sqlite_table_test_in_mem_to_disk.db3");
        let p = path.to_str().expect("temp dir path is not valid UTF-8");
        let _ = std::fs::remove_file(p);
        sql.to_db(p).unwrap();

        let mut sql = SqlTable::<f32, i8>::new(1, true, p).unwrap();
        sql.to_mem().unwrap();
        assert_eq!(sql.query_bucket(&[1, 2], 0).unwrap().take(&0), Some(0));
        std::fs::remove_file(p).unwrap();
    }
}