lsh_rs2/table/
sqlite.rs

1#![cfg(feature = "sqlite")]
2use super::general::Bucket;
3use crate::constants::DESCRIBE_MAX;
4use crate::data::{Integer, Numeric};
5use crate::prelude::*;
6use fnv::FnvHashSet;
7use rusqlite::{params, Connection};
8use serde::de::DeserializeOwned;
9use std::marker::PhantomData;
10use serde::Serialize;
11use std::cell::Cell;
12
/// View a slice of hash components as its raw in-memory bytes (native endianness).
///
/// The returned byte slice borrows `hash` and is exactly `size_of_val(hash)`
/// bytes long. Intended for primitive integer components: a `T` with padding
/// bytes would expose uninitialised memory.
fn vec_to_blob<T>(hash: &[T]) -> &[u8] {
    let byte_len = std::mem::size_of_val(hash);
    let start = hash.as_ptr().cast::<u8>();
    // SAFETY: `start` points at `byte_len` bytes owned by `hash` (all
    // initialised when `T` has no padding), `u8` has alignment 1 so any
    // address is valid, and the result's lifetime is tied to `hash`.
    unsafe { std::slice::from_raw_parts(start, byte_len) }
}
17
/// Reinterpret raw bytes (as produced by `vec_to_blob`) as a slice of `T`.
///
/// Trailing bytes that do not form a whole `T` are ignored, and a blob shorter
/// than one element (including an empty blob) yields an empty slice.
///
/// `T` must be valid for every bit pattern (e.g. a primitive integer).
///
/// # Panics
///
/// Panics if `T` is zero-sized, or if the blob's start address is not aligned
/// for `T`. Reading such a blob in place would be undefined behaviour.
fn blob_to_vec<T>(blob: &[u8]) -> &[T] {
    let len = blob.len() / std::mem::size_of::<T>();
    if len == 0 {
        // An empty byte slice (e.g. an empty `Vec<u8>` read back from SQLite)
        // carries a dangling pointer that is only aligned for `u8`; never
        // hand that to `from_raw_parts::<T>`, even with length zero.
        return &[];
    }
    let data = blob.as_ptr() as *const T;
    assert_eq!(
        data as usize % std::mem::align_of::<T>(),
        0,
        "blob is not aligned for the target element type"
    );
    // SAFETY: `data` is non-null and aligned for `T` (checked above), the
    // `len * size_of::<T>() <= blob.len()` bytes it covers are initialised and
    // borrowed from `blob` for the returned lifetime, and the visible callers
    // decode primitive integers (`i32`), for which every bit pattern is valid.
    unsafe { std::slice::from_raw_parts(data, len) }
}
22
23fn query_bucket(blob: &[u8], table_name: &str, connection: &Connection) -> Result<Bucket> {
24    let mut stmt = connection.prepare_cached(&format!(
25        "
26SELECT (id) FROM {}
27WHERE hash = ?
28        ",
29        table_name
30    ))?;
31    let mut rows = stmt.query(params![blob])?;
32
33    let mut bucket = FnvHashSet::default();
34    while let Some(row) = rows.next()? {
35        bucket.insert(row.get(0)?);
36    }
37    Ok(bucket)
38}
39
40fn make_table(table_name: &str, connection: &Connection) -> Result<()> {
41    connection.execute_batch(&format!(
42        "CREATE TABLE IF NOT EXISTS {} (
43             hash       BLOB,
44             id         INTEGER
45            )
46                ",
47        table_name
48    ))?;
49    Ok(())
50}
51
52fn insert_table<K>(
53    table_name: &str,
54    hash: &Vec<K>,
55    idx: u32,
56    connection: &Connection,
57) -> Result<usize> {
58    let blob = vec_to_blob(hash);
59    let mut stmt = connection.prepare_cached(&format!(
60        "
61INSERT INTO {} (hash, id)
62VALUES (?1, ?2)
63        ",
64        table_name
65    ))?;
66    let idx = stmt.execute(params![blob, idx])?;
67    Ok(idx)
68}
69
70fn hash_table_stats(
71    table_name: &str,
72    limit: u32,
73    conn: &Connection,
74) -> Result<(f64, f64, u32, u32)> {
75    let mut stmt = conn.prepare_cached(&format!(
76        "
77SELECT
78	avg(c) as mean,
79	avg(c * c) - avg(c) * avg(c) as variance,
80	min(c) as minimum,
81	max(c) as maximum
82
83FROM (
84	SELECT count(id) as c
85	FROM {}
86	GROUP BY hash
87	LIMIT ?
88);
89    ",
90        table_name
91    ))?;
92    let out = stmt.query_row(params![limit], |row| {
93        let mean: f64 = row.get(0)?;
94        let variance: f64 = row.get(1)?;
95        let stdev = variance.powf(0.5);
96        let minimum: u32 = row.get(2)?;
97        let maximum: u32 = row.get(3)?;
98        Ok((mean, stdev, minimum, maximum))
99    })?;
100    Ok(out)
101}
102
/// SQLite backend for [LSH](struct.LSH.html).
///
/// State is saved between sessions: the database is loaded automatically if
/// [LSH](struct.LSH.html) can find the database file (per the original docs,
/// the default path is `./lsh.db3`).
///
/// Each of the `n_hash_tables` hash tables is stored as its own SQL table
/// (`hash_table_0`, `hash_table_1`, ...) with `hash BLOB` and `id INTEGER`
/// columns.
pub struct SqlTable<N, K>
where
    N: Numeric,
    K: Integer,
{
    // Number of hash tables; one SQL table exists per hash table.
    n_hash_tables: usize,
    only_index_storage: bool, // only index (id) storage is supported for now
    // Id assigned to the next vector; `put` increments it after writing the
    // last hash table.
    counter: u32,
    // Connection backing all hash tables.
    pub conn: Connection,
    // SQL table names, indexed by hash-table number.
    table_names: Vec<String>,
    // `false` while a transaction started by `init_transaction` is still open;
    // `true` once `commit` has run (or no explicit transaction is expected).
    pub committed: Cell<bool>,
    // Ties the otherwise unused `N` and `K` type parameters to the struct.
    phantom: PhantomData<(N, K)>,
}
120
/// Name of the SQL table backing hash table number `hash_table`
/// (`hash_table_<n>`).
fn fmt_table_name(hash_table: usize) -> String {
    let mut name = String::from("hash_table_");
    name.push_str(&hash_table.to_string());
    name
}
124
125fn get_table_names(n_hash_tables: usize) -> Vec<String> {
126    let mut table_names = Vec::with_capacity(n_hash_tables);
127    for idx in 0..n_hash_tables {
128        let table_name = fmt_table_name(idx);
129        table_names.push(table_name);
130    }
131    table_names
132}
133
134fn get_unique_hash_int(n_hash_tables: usize, conn: &Connection) -> Result<FnvHashSet<i32>> {
135    let mut hash_numbers = FnvHashSet::default();
136    for table_name in get_table_names(n_hash_tables) {
137        let mut stmt = conn.prepare(&format!["SELECT hash FROM {} LIMIT 100;", table_name])?;
138        let mut rows = stmt.query([])?;
139
140        while let Some(r) = rows.next()? {
141            let blob: Vec<u8> = r.get(0)?;
142            let hash = blob_to_vec(&blob);
143            hash.iter().for_each(|&v| {
144                hash_numbers.insert(v);
145            })
146        }
147    }
148    Ok(hash_numbers)
149}
150
151fn init_table(conn: &Connection, table_names: &[String]) -> Result<()> {
152    for table_name in table_names {
153        make_table(&table_name, &conn)?;
154    }
155    Ok(())
156}
157
158fn init_db_setttings(conn: &Connection) -> Result<()> {
159    conn.execute_batch(
160        "PRAGMA journal_mode = OFF;
161    PRAGMA synchronous = OFF;
162    PRAGMA cache_size = 100000;
163    PRAGMA main.locking_mode=EXCLUSIVE;",
164    )?;
165    Ok(())
166}
167
168impl<N, K> SqlTable<N, K>
169where
170    N: Numeric,
171    K: Integer,
172{
173    fn get_table_name_put(&self, hash_table: usize) -> Result<&str> {
174        let opt = self.table_names.get(hash_table);
175        match opt {
176            Some(tbl_name) => Ok(&tbl_name[..]),
177            None => Err(Error::TableNotExist),
178        }
179    }
180
181    pub fn init_from_conn(
182        n_hash_tables: usize,
183        only_index_storage: bool,
184        conn: Connection,
185    ) -> Result<Self> {
186        let table_names = get_table_names(n_hash_tables);
187        init_db_setttings(&conn)?;
188        init_table(&conn, &table_names)?;
189        let sql = SqlTable {
190            n_hash_tables,
191            only_index_storage,
192            counter: 0,
193            conn,
194            table_names,
195            committed: Cell::new(false),
196            phantom: PhantomData,
197        };
198        sql.init_transaction()?;
199        Ok(sql)
200    }
201
202    pub fn commit(&self) -> Result<()> {
203        if !self.committed.replace(true) {
204            self.conn.execute_batch("COMMIT TRANSACTION;")?;
205        }
206        Ok(())
207    }
208
209    pub fn init_transaction(&self) -> Result<()> {
210        self.committed.set(false);
211        self.conn.execute_batch("BEGIN TRANSACTION;")?;
212        Ok(())
213    }
214
215    pub fn to_mem(&mut self) -> Result<()> {
216        let mut new_con = rusqlite::Connection::open_in_memory()?;
217        {
218            let backup = rusqlite::backup::Backup::new(&self.conn, &mut new_con)?;
219            backup.step(-1)?;
220        }
221        self.conn = new_con;
222        self.committed.set(true);
223        Ok(())
224    }
225
226    pub fn index_hash(&self) -> Result<()> {
227        self.commit()?;
228        for tbl_name in get_table_names(self.n_hash_tables) {
229            self.conn.execute_batch(&format!(
230                "
231                CREATE INDEX hash_index_{}
232                ON {} (hash);",
233                tbl_name, tbl_name
234            ))?;
235        }
236        Ok(())
237    }
238}
239
240impl<N, K> HashTables<N, K> for SqlTable<N, K>
241where
242    N: Numeric,
243    K: Integer,
244{
245    fn new(n_hash_tables: usize, only_index_storage: bool, db_path: &str) -> Result<Box<Self>> {
246        let path = std::path::Path::new(db_path);
247        let conn = Connection::open(path)?;
248        SqlTable::init_from_conn(n_hash_tables, only_index_storage, conn).map(|tbl| Box::new(tbl))
249    }
250
251    fn put(&mut self, hash: Vec<K>, _d: &[N], hash_table: usize) -> Result<u32> {
252        // the unique id of the unique vector
253        let idx = self.counter;
254
255        // Get the table name to store this id
256        let table_name = self.get_table_name_put(hash_table)?;
257        let r = insert_table(&table_name, &hash, idx, &self.conn);
258
259        // Once we've traversed the last table we increment the id counter.
260        if hash_table == self.n_hash_tables - 1 {
261            self.counter += 1
262        };
263
264        match r {
265            Ok(_) => Ok(idx),
266            Err(Error::SqlFailure(_)) => Ok(idx), // duplicates
267            Err(e) => Err(Error::Failed(format!("{:?}", e))),
268        }
269    }
270
271    /// Query the whole bucket
272    fn query_bucket(&self, hash: &[K], hash_table: usize) -> Result<Bucket> {
273        self.commit()?;
274        let table_name = fmt_table_name(hash_table);
275        let blob = vec_to_blob(hash);
276        let res = query_bucket(blob, &table_name, &self.conn);
277
278        match res {
279            Ok(bucket) => Ok(bucket),
280            Err(e) => Err(Error::Failed(format!("{:?}", e))),
281        }
282    }
283
284    fn describe(&self) -> Result<String> {
285        let mut stmt = self.conn.prepare(
286            r#"SELECT count(*) FROM sqlite_master
287WHERE type='table' AND type LIKE '%hash%';"#,
288        )?;
289
290        let row: String = stmt.query_row([], |row| {
291            let i: i64 = row.get_unwrap(0);
292            Ok(i.to_string())
293        })?;
294        let mut out = String::from(format!("No. of tables: {}\n", row));
295
296        out.push_str("Unique hash values:\n");
297        let hv = get_unique_hash_int(self.n_hash_tables, &self.conn).unwrap();
298        out.push_str(&format!("{:?}", hv));
299
300        let tables = get_table_names(self.n_hash_tables);
301        let mut avg = Vec::with_capacity(self.n_hash_tables);
302        let mut std_dev = Vec::with_capacity(self.n_hash_tables);
303        let mut min = Vec::with_capacity(self.n_hash_tables);
304        let mut max = Vec::with_capacity(self.n_hash_tables);
305
306        // maximum 3 tables will be used in stats
307        let i = std::cmp::min(3, self.n_hash_tables);
308        for table_name in &tables[..i] {
309            let stats = hash_table_stats(&table_name, DESCRIBE_MAX, &self.conn)?;
310            avg.push(stats.0);
311            std_dev.push(stats.1);
312            min.push(stats.2);
313            max.push(stats.3);
314        }
315        out.push_str("\nHash collisions:\n");
316        out.push_str(&format!("avg:\t{:?}\n", avg));
317        out.push_str(&format!("std-dev:\t{:?}\n", std_dev));
318        out.push_str(&format!("min:\t{:?}\n", min));
319        out.push_str(&format!("max:\t{:?}\n", max));
320        Ok(out)
321    }
322
323    fn store_hashers<H: VecHash<N, K> + Serialize>(&mut self, hashers: &[H]) -> Result<()> {
324        let buf: Vec<u8> = bincode::serialize(hashers)?;
325
326        // fails if already exists
327        self.conn.execute_batch(
328            "CREATE TABLE state (
329            hashers     BLOB
330        )",
331        )?;
332        let mut stmt = self
333            .conn
334            .prepare("INSERT INTO state (hashers) VALUES (?1)")?;
335
336        // unlock database by committing any running transaction.
337        self.commit()?;
338        stmt.execute(params![buf])?;
339        self.init_transaction()?;
340        Ok(())
341    }
342
343    fn load_hashers<H: VecHash<N, K> + DeserializeOwned>(&self) -> Result<Vec<H>> {
344        let mut stmt = self.conn.prepare("SELECT * FROM state;")?;
345        let buf: Vec<u8> = stmt.query_row([], |row| {
346            let v: Vec<u8> = row.get_unwrap(0);
347            Ok(v)
348        })?;
349        let hashers: Vec<H> = bincode::deserialize(&buf)?;
350        Ok(hashers)
351    }
352
353    fn get_unique_hash_int(&self) -> FnvHashSet<i32> {
354        get_unique_hash_int(self.n_hash_tables, &self.conn).unwrap()
355    }
356}
357
#[cfg(test)]
mod test {
    use super::*;
    use crate::table::sqlite_mem::SqlTableMem;

    /// A freshly initialised store has a selectable hash table.
    #[test]
    fn test_sql_table_init() {
        let sql = SqlTableMem::<f32, i8>::new(1, true, ".").unwrap();
        let mut stmt = sql
            .conn
            .prepare(&format!("SELECT * FROM {}", sql.table_names[0]))
            .expect("query failed");
        stmt.query([]).expect("query failed");
    }

    /// Ids put under a colliding hash land in the same bucket.
    #[test]
    fn test_sql_crud() {
        let mut sql = *SqlTableMem::new(1, true, ".").unwrap();
        let v = vec![1., 2.];
        for hash in &[vec![1, 2], vec![2, 3]] {
            sql.put(hash.clone(), &v, 0).unwrap();
        }
        // make one hash collision by repeating one hash
        let hash = vec![1, 2];
        sql.put(hash.clone(), &v, 0).unwrap();
        let bucket = sql.query_bucket(&hash, 0);
        println!("{:?}", &bucket);
        let bucket = bucket.expect("query_bucket failed");
        assert!(bucket.contains(&0));
    }

    /// Encoding a hash to a blob and decoding it back is lossless.
    #[test]
    fn test_blob_hash_casting() {
        let hashes: [&[i32]; 4] = [
            &[2, 3, 4],
            &[-124, 32, 89],
            &[1, 2, 3, 4, 5, 6],
            &[-12, -2, -3, 1, 2, 3, 4, 5, 6],
        ];
        for &hash in hashes.iter() {
            let blob = vec_to_blob(hash);
            let hash_back: &[i32] = blob_to_vec(blob);
            assert_eq!(hash, hash_back)
        }
    }

    /// An in-memory store survives a round trip through a database file.
    #[test]
    fn test_in_mem_to_disk() {
        let p = "./delete.db3";
        // A file left behind by an earlier failed run would otherwise be
        // reopened with its old tables and rows; it is fine if none exists.
        let _ = std::fs::remove_file(p);

        let mut sql = *SqlTableMem::<f32, i8>::new(1, true, ".").unwrap();
        let v = vec![1., 2.];
        for hash in &[vec![1, 2], vec![2, 3]] {
            sql.put(hash.clone(), &v, 0).unwrap();
        }
        sql.commit().unwrap();
        sql.to_db(p).unwrap();

        let mut sql = SqlTable::<f32, i8>::new(1, true, p).unwrap();
        sql.to_mem().unwrap();
        assert_eq!(sql.query_bucket(&vec![1, 2], 0).unwrap().take(&0), Some(0));
        std::fs::remove_file(p).unwrap();
    }
}
422}