chess_vector_engine/
persistence.rs

1#![allow(clippy::type_complexity)]
2use rusqlite::{params, Connection, Result as SqlResult};
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5
/// One chess position as stored in the `positions` table.
///
/// The fields correspond to the columns read and written by the position
/// load/save methods on [`Database`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PositionData {
    /// FEN string of the position; the `positions.fen` column is declared `UNIQUE`.
    pub fen: String,
    /// Vector for the position; persisted as a bincode-encoded BLOB.
    pub vector: Vec<f64>,
    /// Optional evaluation; persisted in a nullable `REAL` column.
    pub evaluation: Option<f64>,
    /// Optional second vector, persisted as a nullable bincode BLOB
    /// (presumably the reduced-dimension form of `vector` — confirm with the producer).
    pub compressed_vector: Option<Vec<f64>>,
    /// Caller-supplied creation timestamp, stored verbatim
    /// (presumably Unix seconds, like the `updated_at` values this module writes — confirm).
    pub created_at: i64,
}
14
/// A single LSH hash function as persisted inside `lsh_config.hash_functions`.
///
/// Values of this type are bincode-encoded by `save_lsh_config`, so the field
/// order below is part of the stored format; reordering fields would make
/// existing rows undecodable.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LSHHashFunction {
    /// Vector defining the hash function
    /// (presumably a random projection direction — confirm against the LSH code).
    pub random_vector: Vec<f64>,
    /// Threshold belonging to this hash function; how it is applied is not
    /// visible in this file.
    pub threshold: f64,
}
20
/// LSH configuration as stored in the single-row `lsh_config` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LSHTableData {
    /// Hash functions; persisted together as one bincode-encoded BLOB.
    pub hash_functions: Vec<LSHHashFunction>,
    /// Value of the `num_tables` column.
    pub num_tables: usize,
    /// Value of the `num_hash_functions` column
    /// (presumably hash functions per table — confirm against the LSH code).
    pub num_hash_functions: usize,
    /// Value of the `vector_dim` column
    /// (presumably the length of each `random_vector` — confirm).
    pub vector_dim: usize,
}
28
/// Handle to the SQLite database holding positions, LSH data and manifold models.
///
/// Construct it with [`Database::new`] (file-backed) or [`Database::in_memory`];
/// both create the schema if it does not exist yet.
pub struct Database {
    // The underlying rusqlite connection; every method runs its SQL through it.
    conn: Connection,
}
32
33impl Database {
34    pub fn new<P: AsRef<Path>>(db_path: P) -> SqlResult<Self> {
35        let conn = Connection::open(db_path)?;
36
37        // Enable basic optimizations
38        conn.execute("PRAGMA foreign_keys=ON", [])?;
39
40        let db = Database { conn };
41        db.create_tables()?;
42        Ok(db)
43    }
44
45    pub fn in_memory() -> SqlResult<Self> {
46        let conn = Connection::open_in_memory()?;
47        let db = Database { conn };
48        db.create_tables()?;
49        Ok(db)
50    }
51
52    fn create_tables(&self) -> SqlResult<()> {
53        // Positions table - stores chess positions and their vectors
54        self.conn.execute(
55            "CREATE TABLE IF NOT EXISTS positions (
56                id INTEGER PRIMARY KEY AUTOINCREMENT,
57                fen TEXT NOT NULL UNIQUE,
58                vector BLOB NOT NULL,
59                evaluation REAL,
60                compressed_vector BLOB,
61                created_at INTEGER NOT NULL,
62                updated_at INTEGER NOT NULL DEFAULT 0
63            )",
64            [],
65        )?;
66
67        // LSH tables data - stores LSH configuration and hash functions
68        self.conn.execute(
69            "CREATE TABLE IF NOT EXISTS lsh_config (
70                id INTEGER PRIMARY KEY,
71                num_tables INTEGER NOT NULL,
72                num_hash_functions INTEGER NOT NULL,
73                vector_dim INTEGER NOT NULL,
74                hash_functions BLOB NOT NULL,
75                created_at INTEGER NOT NULL,
76                updated_at INTEGER NOT NULL DEFAULT 0
77            )",
78            [],
79        )?;
80
81        // LSH buckets - stores position assignments to hash buckets
82        self.conn.execute(
83            "CREATE TABLE IF NOT EXISTS lsh_buckets (
84                id INTEGER PRIMARY KEY AUTOINCREMENT,
85                table_id INTEGER NOT NULL,
86                bucket_hash TEXT NOT NULL,
87                position_id INTEGER NOT NULL,
88                UNIQUE(table_id, bucket_hash, position_id)
89            )",
90            [],
91        )?;
92
93        // Manifold model data - stores trained autoencoder weights
94        self.conn.execute(
95            "CREATE TABLE IF NOT EXISTS manifold_models (
96                id INTEGER PRIMARY KEY,
97                input_dim INTEGER NOT NULL,
98                compressed_dim INTEGER NOT NULL,
99                model_weights BLOB NOT NULL,
100                training_metadata BLOB,
101                created_at INTEGER NOT NULL,
102                updated_at INTEGER NOT NULL DEFAULT 0
103            )",
104            [],
105        )?;
106
107        // Create indexes for better query performance
108        self.conn.execute(
109            "CREATE INDEX IF NOT EXISTS idx_positions_fen ON positions(fen)",
110            [],
111        )?;
112
113        self.conn.execute(
114            "CREATE INDEX IF NOT EXISTS idx_lsh_buckets_table_bucket ON lsh_buckets(table_id, bucket_hash)",
115            [],
116        )?;
117
118        self.conn.execute(
119            "CREATE INDEX IF NOT EXISTS idx_positions_created_at ON positions(created_at)",
120            [],
121        )?;
122
123        Ok(())
124    }
125
126    pub fn save_position(&self, position_data: &PositionData) -> SqlResult<i64> {
127        let vector_bytes = bincode::serialize(&position_data.vector)
128            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
129
130        let compressed_vector_bytes = position_data
131            .compressed_vector
132            .as_ref()
133            .map(bincode::serialize)
134            .transpose()
135            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
136
137        let current_time = std::time::SystemTime::now()
138            .duration_since(std::time::UNIX_EPOCH)
139            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
140            .as_secs() as i64;
141
142        self.conn.execute(
143            "INSERT OR REPLACE INTO positions (fen, vector, evaluation, compressed_vector, created_at, updated_at)
144             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
145            params![
146                position_data.fen,
147                vector_bytes,
148                position_data.evaluation,
149                compressed_vector_bytes,
150                position_data.created_at,
151                current_time
152            ],
153        )?;
154
155        Ok(self.conn.last_insert_rowid())
156    }
157
158    pub fn load_position(&self, fen: &str) -> SqlResult<Option<PositionData>> {
159        let mut stmt = self.conn.prepare(
160            "SELECT fen, vector, evaluation, compressed_vector, created_at 
161             FROM positions WHERE fen = ?1",
162        )?;
163
164        let mut rows = stmt.query_map([fen], |row| {
165            let vector_bytes: Vec<u8> = row.get(1)?;
166            let vector: Vec<f64> = bincode::deserialize(&vector_bytes).map_err(|e| {
167                rusqlite::Error::FromSqlConversionFailure(
168                    1,
169                    rusqlite::types::Type::Blob,
170                    Box::new(e),
171                )
172            })?;
173
174            let compressed_vector =
175                if let Ok(Some(compressed_bytes)) = row.get::<_, Option<Vec<u8>>>(3) {
176                    Some(bincode::deserialize(&compressed_bytes).map_err(|e| {
177                        rusqlite::Error::FromSqlConversionFailure(
178                            3,
179                            rusqlite::types::Type::Blob,
180                            Box::new(e),
181                        )
182                    })?)
183                } else {
184                    None
185                };
186
187            Ok(PositionData {
188                fen: row.get(0)?,
189                vector,
190                evaluation: row.get(2)?,
191                compressed_vector,
192                created_at: row.get(4)?,
193            })
194        })?;
195
196        match rows.next() {
197            Some(Ok(position)) => Ok(Some(position)),
198            Some(Err(e)) => Err(e),
199            None => Ok(None),
200        }
201    }
202
203    pub fn load_all_positions(&self) -> SqlResult<Vec<PositionData>> {
204        let mut stmt = self.conn.prepare(
205            "SELECT fen, vector, evaluation, compressed_vector, created_at 
206             FROM positions ORDER BY created_at",
207        )?;
208
209        let rows = stmt.query_map([], |row| {
210            let vector_bytes: Vec<u8> = row.get(1)?;
211            let vector: Vec<f64> = bincode::deserialize(&vector_bytes).map_err(|e| {
212                rusqlite::Error::FromSqlConversionFailure(
213                    1,
214                    rusqlite::types::Type::Blob,
215                    Box::new(e),
216                )
217            })?;
218
219            let compressed_vector =
220                if let Ok(Some(compressed_bytes)) = row.get::<_, Option<Vec<u8>>>(3) {
221                    Some(bincode::deserialize(&compressed_bytes).map_err(|e| {
222                        rusqlite::Error::FromSqlConversionFailure(
223                            3,
224                            rusqlite::types::Type::Blob,
225                            Box::new(e),
226                        )
227                    })?)
228                } else {
229                    None
230                };
231
232            Ok(PositionData {
233                fen: row.get(0)?,
234                vector,
235                evaluation: row.get(2)?,
236                compressed_vector,
237                created_at: row.get(4)?,
238            })
239        })?;
240
241        rows.collect()
242    }
243
244    pub fn save_lsh_config(&self, config: &LSHTableData) -> SqlResult<()> {
245        let hash_functions_bytes = bincode::serialize(&config.hash_functions)
246            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
247
248        let current_time = std::time::SystemTime::now()
249            .duration_since(std::time::UNIX_EPOCH)
250            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
251            .as_secs() as i64;
252
253        self.conn.execute(
254            "INSERT OR REPLACE INTO lsh_config (id, num_tables, num_hash_functions, vector_dim, hash_functions, created_at, updated_at)
255             VALUES (1, ?1, ?2, ?3, ?4, ?5, ?6)",
256            params![
257                config.num_tables,
258                config.num_hash_functions,
259                config.vector_dim,
260                hash_functions_bytes,
261                current_time,
262                current_time
263            ],
264        )?;
265
266        Ok(())
267    }
268
269    pub fn load_lsh_config(&self) -> SqlResult<Option<LSHTableData>> {
270        let mut stmt = self.conn.prepare(
271            "SELECT num_tables, num_hash_functions, vector_dim, hash_functions 
272             FROM lsh_config WHERE id = 1",
273        )?;
274
275        let mut rows = stmt.query_map([], |row| {
276            let hash_functions_bytes: Vec<u8> = row.get(3)?;
277            let hash_functions: Vec<LSHHashFunction> = bincode::deserialize(&hash_functions_bytes)
278                .map_err(|e| {
279                    rusqlite::Error::FromSqlConversionFailure(
280                        3,
281                        rusqlite::types::Type::Blob,
282                        Box::new(e),
283                    )
284                })?;
285
286            Ok(LSHTableData {
287                num_tables: row.get(0)?,
288                num_hash_functions: row.get(1)?,
289                vector_dim: row.get(2)?,
290                hash_functions,
291            })
292        })?;
293
294        match rows.next() {
295            Some(Ok(config)) => Ok(Some(config)),
296            Some(Err(e)) => Err(e),
297            None => Ok(None),
298        }
299    }
300
301    pub fn save_lsh_bucket(
302        &self,
303        table_id: usize,
304        bucket_hash: &str,
305        position_id: i64,
306    ) -> SqlResult<()> {
307        self.conn.execute(
308            "INSERT OR IGNORE INTO lsh_buckets (table_id, bucket_hash, position_id)
309             VALUES (?1, ?2, ?3)",
310            params![table_id, bucket_hash, position_id],
311        )?;
312        Ok(())
313    }
314
315    pub fn load_lsh_buckets(&self, table_id: usize, bucket_hash: &str) -> SqlResult<Vec<i64>> {
316        let mut stmt = self.conn.prepare(
317            "SELECT position_id FROM lsh_buckets WHERE table_id = ?1 AND bucket_hash = ?2",
318        )?;
319
320        let rows = stmt.query_map(params![table_id, bucket_hash], |row| row.get(0))?;
321
322        rows.collect()
323    }
324
325    pub fn clear_lsh_buckets(&self) -> SqlResult<()> {
326        self.conn.execute("DELETE FROM lsh_buckets", [])?;
327        Ok(())
328    }
329
330    pub fn get_position_count(&self) -> SqlResult<i64> {
331        let mut stmt = self.conn.prepare("SELECT COUNT(*) FROM positions")?;
332        let count: i64 = stmt.query_row([], |row| row.get(0))?;
333        Ok(count)
334    }
335
336    pub fn vacuum(&self) -> SqlResult<()> {
337        self.conn.execute("VACUUM", [])?;
338        Ok(())
339    }
340
341    pub fn save_manifold_model(
342        &self,
343        input_dim: usize,
344        compressed_dim: usize,
345        model_weights: &[u8],
346        training_metadata: Option<&[u8]>,
347    ) -> SqlResult<()> {
348        let metadata_bytes = training_metadata.unwrap_or(&[]);
349
350        let current_time = std::time::SystemTime::now()
351            .duration_since(std::time::UNIX_EPOCH)
352            .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
353            .as_secs() as i64;
354
355        self.conn.execute(
356            "INSERT OR REPLACE INTO manifold_models (id, input_dim, compressed_dim, model_weights, training_metadata, created_at, updated_at)
357             VALUES (1, ?1, ?2, ?3, ?4, ?5, ?6)",
358            params![
359                input_dim,
360                compressed_dim,
361                model_weights,
362                metadata_bytes,
363                current_time,
364                current_time
365            ],
366        )?;
367        Ok(())
368    }
369
370    pub fn load_manifold_model(&self) -> SqlResult<Option<(usize, usize, Vec<u8>, Vec<u8>)>> {
371        let mut stmt = self.conn.prepare(
372            "SELECT input_dim, compressed_dim, model_weights, training_metadata 
373             FROM manifold_models WHERE id = 1",
374        )?;
375
376        let mut rows = stmt.query_map([], |row| {
377            Ok((
378                row.get::<_, usize>(0)?,
379                row.get::<_, usize>(1)?,
380                row.get::<_, Vec<u8>>(2)?,
381                row.get::<_, Vec<u8>>(3)?,
382            ))
383        })?;
384
385        match rows.next() {
386            Some(Ok(model)) => Ok(Some(model)),
387            Some(Err(e)) => Err(e),
388            None => Ok(None),
389        }
390    }
391
392    pub fn get_position_by_id(&self, id: i64) -> SqlResult<Option<PositionData>> {
393        let mut stmt = self.conn.prepare(
394            "SELECT fen, vector, evaluation, compressed_vector, created_at 
395             FROM positions WHERE id = ?1",
396        )?;
397
398        let mut rows = stmt.query_map([id], |row| {
399            let vector_bytes: Vec<u8> = row.get(1)?;
400            let vector: Vec<f64> = bincode::deserialize(&vector_bytes).map_err(|e| {
401                rusqlite::Error::FromSqlConversionFailure(
402                    1,
403                    rusqlite::types::Type::Blob,
404                    Box::new(e),
405                )
406            })?;
407
408            let compressed_vector =
409                if let Ok(Some(compressed_bytes)) = row.get::<_, Option<Vec<u8>>>(3) {
410                    Some(bincode::deserialize(&compressed_bytes).map_err(|e| {
411                        rusqlite::Error::FromSqlConversionFailure(
412                            3,
413                            rusqlite::types::Type::Blob,
414                            Box::new(e),
415                        )
416                    })?)
417                } else {
418                    None
419                };
420
421            Ok(PositionData {
422                fen: row.get(0)?,
423                vector,
424                evaluation: row.get(2)?,
425                compressed_vector,
426                created_at: row.get(4)?,
427            })
428        })?;
429
430        match rows.next() {
431            Some(Ok(position)) => Ok(Some(position)),
432            Some(Err(e)) => Err(e),
433            None => Ok(None),
434        }
435    }
436
437    /// Save multiple positions in a single transaction for much better performance
438    pub fn save_positions_batch(&self, positions: &[PositionData]) -> SqlResult<usize> {
439        if positions.is_empty() {
440            return Ok(0);
441        }
442
443        let tx = self.conn.unchecked_transaction()?;
444
445        {
446            let mut stmt = tx.prepare(
447                "INSERT OR REPLACE INTO positions (fen, vector, evaluation, compressed_vector, created_at, updated_at)
448                 VALUES (?1, ?2, ?3, ?4, ?5, ?6)"
449            )?;
450
451            let current_time = std::time::SystemTime::now()
452                .duration_since(std::time::UNIX_EPOCH)
453                .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
454                .as_secs() as i64;
455
456            for position_data in positions {
457                let vector_bytes = bincode::serialize(&position_data.vector)
458                    .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
459
460                let compressed_vector_bytes = position_data
461                    .compressed_vector
462                    .as_ref()
463                    .map(bincode::serialize)
464                    .transpose()
465                    .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
466
467                stmt.execute(params![
468                    position_data.fen,
469                    vector_bytes,
470                    position_data.evaluation,
471                    compressed_vector_bytes,
472                    position_data.created_at,
473                    current_time
474                ])?;
475            }
476        }
477
478        tx.commit()?;
479        Ok(positions.len())
480    }
481
482    /// Load positions in batches for better memory efficiency
483    pub fn load_positions_batch(
484        &self,
485        limit: usize,
486        offset: usize,
487    ) -> SqlResult<Vec<PositionData>> {
488        let mut stmt = self.conn.prepare(
489            "SELECT fen, vector, evaluation, compressed_vector, created_at 
490             FROM positions ORDER BY id LIMIT ?1 OFFSET ?2",
491        )?;
492
493        let rows = stmt.query_map([limit, offset], |row| {
494            let vector_bytes: Vec<u8> = row.get(1)?;
495            let vector: Vec<f64> = bincode::deserialize(&vector_bytes).map_err(|e| {
496                rusqlite::Error::FromSqlConversionFailure(
497                    1,
498                    rusqlite::types::Type::Blob,
499                    Box::new(e),
500                )
501            })?;
502
503            let compressed_vector =
504                if let Ok(Some(compressed_bytes)) = row.get::<_, Option<Vec<u8>>>(3) {
505                    Some(bincode::deserialize(&compressed_bytes).map_err(|e| {
506                        rusqlite::Error::FromSqlConversionFailure(
507                            3,
508                            rusqlite::types::Type::Blob,
509                            Box::new(e),
510                        )
511                    })?)
512                } else {
513                    None
514                };
515
516            Ok(PositionData {
517                fen: row.get(0)?,
518                vector,
519                evaluation: row.get(2)?,
520                compressed_vector,
521                created_at: row.get(4)?,
522            })
523        })?;
524
525        rows.collect()
526    }
527
528    /// Get the total count of positions in the database (as usize)
529    pub fn get_total_position_count(&self) -> SqlResult<usize> {
530        let mut stmt = self.conn.prepare("SELECT COUNT(*) FROM positions")?;
531        let count: i64 = stmt.query_row([], |row| row.get(0))?;
532        Ok(count as usize)
533    }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created in-memory database holds no positions.
    #[test]
    fn test_database_creation() {
        let database = Database::in_memory().unwrap();
        let stored_positions = database.get_position_count().unwrap();
        assert_eq!(stored_positions, 0);
    }

    /// A saved position can be read back by FEN with its data intact.
    #[test]
    fn test_position_storage() {
        let database = Database::in_memory().unwrap();

        let start_fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";
        let original = PositionData {
            fen: start_fen.to_string(),
            vector: vec![1.0, 2.0, 3.0],
            evaluation: Some(0.5),
            compressed_vector: Some(vec![0.1, 0.2]),
            created_at: 1234567890,
        };

        let row_id = database.save_position(&original).unwrap();
        assert!(row_id > 0);

        let round_tripped = database.load_position(&original.fen).unwrap().unwrap();
        assert_eq!(round_tripped.fen, original.fen);
        assert_eq!(round_tripped.vector, original.vector);
        assert_eq!(round_tripped.evaluation, original.evaluation);
        assert_eq!(round_tripped.compressed_vector, original.compressed_vector);
    }

    /// A saved LSH configuration can be read back with the same dimensions.
    #[test]
    fn test_lsh_config_storage() {
        let database = Database::in_memory().unwrap();

        let only_function = LSHHashFunction {
            random_vector: vec![1.0, -1.0, 0.5],
            threshold: 0.0,
        };
        let original = LSHTableData {
            hash_functions: vec![only_function],
            num_tables: 10,
            num_hash_functions: 5,
            vector_dim: 1024,
        };

        database.save_lsh_config(&original).unwrap();

        let round_tripped = database.load_lsh_config().unwrap().unwrap();
        assert_eq!(round_tripped.num_tables, original.num_tables);
        assert_eq!(
            round_tripped.num_hash_functions,
            original.num_hash_functions
        );
        assert_eq!(round_tripped.vector_dim, original.vector_dim);
        assert_eq!(
            round_tripped.hash_functions.len(),
            original.hash_functions.len()
        );
    }
}