1#![allow(clippy::type_complexity)]
2use rusqlite::{params, Connection, Result as SqlResult};
3use serde::{Deserialize, Serialize};
4use std::path::Path;
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct PositionData {
8 pub fen: String,
9 pub vector: Vec<f64>,
10 pub evaluation: Option<f64>,
11 pub compressed_vector: Option<Vec<f64>>,
12 pub created_at: i64,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct LSHHashFunction {
17 pub random_vector: Vec<f64>,
18 pub threshold: f64,
19}
20
21#[derive(Debug, Clone, Serialize, Deserialize)]
22pub struct LSHTableData {
23 pub hash_functions: Vec<LSHHashFunction>,
24 pub num_tables: usize,
25 pub num_hash_functions: usize,
26 pub vector_dim: usize,
27}
28
29pub struct Database {
30 conn: Connection,
31}
32
33impl Database {
34 pub fn new<P: AsRef<Path>>(db_path: P) -> SqlResult<Self> {
35 let conn = Connection::open(db_path)?;
36
37 conn.execute("PRAGMA foreign_keys=ON", [])?;
39
40 let db = Database { conn };
41 db.create_tables()?;
42 Ok(db)
43 }
44
45 pub fn in_memory() -> SqlResult<Self> {
46 let conn = Connection::open_in_memory()?;
47 let db = Database { conn };
48 db.create_tables()?;
49 Ok(db)
50 }
51
52 fn create_tables(&self) -> SqlResult<()> {
53 self.conn.execute(
55 "CREATE TABLE IF NOT EXISTS positions (
56 id INTEGER PRIMARY KEY AUTOINCREMENT,
57 fen TEXT NOT NULL UNIQUE,
58 vector BLOB NOT NULL,
59 evaluation REAL,
60 compressed_vector BLOB,
61 created_at INTEGER NOT NULL,
62 updated_at INTEGER NOT NULL DEFAULT 0
63 )",
64 [],
65 )?;
66
67 self.conn.execute(
69 "CREATE TABLE IF NOT EXISTS lsh_config (
70 id INTEGER PRIMARY KEY,
71 num_tables INTEGER NOT NULL,
72 num_hash_functions INTEGER NOT NULL,
73 vector_dim INTEGER NOT NULL,
74 hash_functions BLOB NOT NULL,
75 created_at INTEGER NOT NULL,
76 updated_at INTEGER NOT NULL DEFAULT 0
77 )",
78 [],
79 )?;
80
81 self.conn.execute(
83 "CREATE TABLE IF NOT EXISTS lsh_buckets (
84 id INTEGER PRIMARY KEY AUTOINCREMENT,
85 table_id INTEGER NOT NULL,
86 bucket_hash TEXT NOT NULL,
87 position_id INTEGER NOT NULL,
88 UNIQUE(table_id, bucket_hash, position_id)
89 )",
90 [],
91 )?;
92
93 self.conn.execute(
95 "CREATE TABLE IF NOT EXISTS manifold_models (
96 id INTEGER PRIMARY KEY,
97 input_dim INTEGER NOT NULL,
98 compressed_dim INTEGER NOT NULL,
99 model_weights BLOB NOT NULL,
100 training_metadata BLOB,
101 created_at INTEGER NOT NULL,
102 updated_at INTEGER NOT NULL DEFAULT 0
103 )",
104 [],
105 )?;
106
107 self.conn.execute(
109 "CREATE INDEX IF NOT EXISTS idx_positions_fen ON positions(fen)",
110 [],
111 )?;
112
113 self.conn.execute(
114 "CREATE INDEX IF NOT EXISTS idx_lsh_buckets_table_bucket ON lsh_buckets(table_id, bucket_hash)",
115 [],
116 )?;
117
118 self.conn.execute(
119 "CREATE INDEX IF NOT EXISTS idx_positions_created_at ON positions(created_at)",
120 [],
121 )?;
122
123 Ok(())
124 }
125
126 pub fn save_position(&self, position_data: &PositionData) -> SqlResult<i64> {
127 let vector_bytes = bincode::serialize(&position_data.vector)
128 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
129
130 let compressed_vector_bytes = position_data
131 .compressed_vector
132 .as_ref()
133 .map(bincode::serialize)
134 .transpose()
135 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
136
137 let current_time = std::time::SystemTime::now()
138 .duration_since(std::time::UNIX_EPOCH)
139 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
140 .as_secs() as i64;
141
142 self.conn.execute(
143 "INSERT OR REPLACE INTO positions (fen, vector, evaluation, compressed_vector, created_at, updated_at)
144 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
145 params![
146 position_data.fen,
147 vector_bytes,
148 position_data.evaluation,
149 compressed_vector_bytes,
150 position_data.created_at,
151 current_time
152 ],
153 )?;
154
155 Ok(self.conn.last_insert_rowid())
156 }
157
158 pub fn load_position(&self, fen: &str) -> SqlResult<Option<PositionData>> {
159 let mut stmt = self.conn.prepare(
160 "SELECT fen, vector, evaluation, compressed_vector, created_at
161 FROM positions WHERE fen = ?1",
162 )?;
163
164 let mut rows = stmt.query_map([fen], |row| {
165 let vector_bytes: Vec<u8> = row.get(1)?;
166 let vector: Vec<f64> = bincode::deserialize(&vector_bytes).map_err(|e| {
167 rusqlite::Error::FromSqlConversionFailure(
168 1,
169 rusqlite::types::Type::Blob,
170 Box::new(e),
171 )
172 })?;
173
174 let compressed_vector =
175 if let Ok(Some(compressed_bytes)) = row.get::<_, Option<Vec<u8>>>(3) {
176 Some(bincode::deserialize(&compressed_bytes).map_err(|e| {
177 rusqlite::Error::FromSqlConversionFailure(
178 3,
179 rusqlite::types::Type::Blob,
180 Box::new(e),
181 )
182 })?)
183 } else {
184 None
185 };
186
187 Ok(PositionData {
188 fen: row.get(0)?,
189 vector,
190 evaluation: row.get(2)?,
191 compressed_vector,
192 created_at: row.get(4)?,
193 })
194 })?;
195
196 match rows.next() {
197 Some(Ok(position)) => Ok(Some(position)),
198 Some(Err(e)) => Err(e),
199 None => Ok(None),
200 }
201 }
202
203 pub fn load_all_positions(&self) -> SqlResult<Vec<PositionData>> {
204 let mut stmt = self.conn.prepare(
205 "SELECT fen, vector, evaluation, compressed_vector, created_at
206 FROM positions ORDER BY created_at",
207 )?;
208
209 let rows = stmt.query_map([], |row| {
210 let vector_bytes: Vec<u8> = row.get(1)?;
211 let vector: Vec<f64> = bincode::deserialize(&vector_bytes).map_err(|e| {
212 rusqlite::Error::FromSqlConversionFailure(
213 1,
214 rusqlite::types::Type::Blob,
215 Box::new(e),
216 )
217 })?;
218
219 let compressed_vector =
220 if let Ok(Some(compressed_bytes)) = row.get::<_, Option<Vec<u8>>>(3) {
221 Some(bincode::deserialize(&compressed_bytes).map_err(|e| {
222 rusqlite::Error::FromSqlConversionFailure(
223 3,
224 rusqlite::types::Type::Blob,
225 Box::new(e),
226 )
227 })?)
228 } else {
229 None
230 };
231
232 Ok(PositionData {
233 fen: row.get(0)?,
234 vector,
235 evaluation: row.get(2)?,
236 compressed_vector,
237 created_at: row.get(4)?,
238 })
239 })?;
240
241 rows.collect()
242 }
243
244 pub fn save_lsh_config(&self, config: &LSHTableData) -> SqlResult<()> {
245 let hash_functions_bytes = bincode::serialize(&config.hash_functions)
246 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
247
248 let current_time = std::time::SystemTime::now()
249 .duration_since(std::time::UNIX_EPOCH)
250 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
251 .as_secs() as i64;
252
253 self.conn.execute(
254 "INSERT OR REPLACE INTO lsh_config (id, num_tables, num_hash_functions, vector_dim, hash_functions, created_at, updated_at)
255 VALUES (1, ?1, ?2, ?3, ?4, ?5, ?6)",
256 params![
257 config.num_tables,
258 config.num_hash_functions,
259 config.vector_dim,
260 hash_functions_bytes,
261 current_time,
262 current_time
263 ],
264 )?;
265
266 Ok(())
267 }
268
269 pub fn load_lsh_config(&self) -> SqlResult<Option<LSHTableData>> {
270 let mut stmt = self.conn.prepare(
271 "SELECT num_tables, num_hash_functions, vector_dim, hash_functions
272 FROM lsh_config WHERE id = 1",
273 )?;
274
275 let mut rows = stmt.query_map([], |row| {
276 let hash_functions_bytes: Vec<u8> = row.get(3)?;
277 let hash_functions: Vec<LSHHashFunction> = bincode::deserialize(&hash_functions_bytes)
278 .map_err(|e| {
279 rusqlite::Error::FromSqlConversionFailure(
280 3,
281 rusqlite::types::Type::Blob,
282 Box::new(e),
283 )
284 })?;
285
286 Ok(LSHTableData {
287 num_tables: row.get(0)?,
288 num_hash_functions: row.get(1)?,
289 vector_dim: row.get(2)?,
290 hash_functions,
291 })
292 })?;
293
294 match rows.next() {
295 Some(Ok(config)) => Ok(Some(config)),
296 Some(Err(e)) => Err(e),
297 None => Ok(None),
298 }
299 }
300
301 pub fn save_lsh_bucket(
302 &self,
303 table_id: usize,
304 bucket_hash: &str,
305 position_id: i64,
306 ) -> SqlResult<()> {
307 self.conn.execute(
308 "INSERT OR IGNORE INTO lsh_buckets (table_id, bucket_hash, position_id)
309 VALUES (?1, ?2, ?3)",
310 params![table_id, bucket_hash, position_id],
311 )?;
312 Ok(())
313 }
314
315 pub fn load_lsh_buckets(&self, table_id: usize, bucket_hash: &str) -> SqlResult<Vec<i64>> {
316 let mut stmt = self.conn.prepare(
317 "SELECT position_id FROM lsh_buckets WHERE table_id = ?1 AND bucket_hash = ?2",
318 )?;
319
320 let rows = stmt.query_map(params![table_id, bucket_hash], |row| row.get(0))?;
321
322 rows.collect()
323 }
324
325 pub fn clear_lsh_buckets(&self) -> SqlResult<()> {
326 self.conn.execute("DELETE FROM lsh_buckets", [])?;
327 Ok(())
328 }
329
330 pub fn get_position_count(&self) -> SqlResult<i64> {
331 let mut stmt = self.conn.prepare("SELECT COUNT(*) FROM positions")?;
332 let count: i64 = stmt.query_row([], |row| row.get(0))?;
333 Ok(count)
334 }
335
336 pub fn vacuum(&self) -> SqlResult<()> {
337 self.conn.execute("VACUUM", [])?;
338 Ok(())
339 }
340
341 pub fn save_manifold_model(
342 &self,
343 input_dim: usize,
344 compressed_dim: usize,
345 model_weights: &[u8],
346 training_metadata: Option<&[u8]>,
347 ) -> SqlResult<()> {
348 let metadata_bytes = training_metadata.unwrap_or(&[]);
349
350 let current_time = std::time::SystemTime::now()
351 .duration_since(std::time::UNIX_EPOCH)
352 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
353 .as_secs() as i64;
354
355 self.conn.execute(
356 "INSERT OR REPLACE INTO manifold_models (id, input_dim, compressed_dim, model_weights, training_metadata, created_at, updated_at)
357 VALUES (1, ?1, ?2, ?3, ?4, ?5, ?6)",
358 params![
359 input_dim,
360 compressed_dim,
361 model_weights,
362 metadata_bytes,
363 current_time,
364 current_time
365 ],
366 )?;
367 Ok(())
368 }
369
370 pub fn load_manifold_model(&self) -> SqlResult<Option<(usize, usize, Vec<u8>, Vec<u8>)>> {
371 let mut stmt = self.conn.prepare(
372 "SELECT input_dim, compressed_dim, model_weights, training_metadata
373 FROM manifold_models WHERE id = 1",
374 )?;
375
376 let mut rows = stmt.query_map([], |row| {
377 Ok((
378 row.get::<_, usize>(0)?,
379 row.get::<_, usize>(1)?,
380 row.get::<_, Vec<u8>>(2)?,
381 row.get::<_, Vec<u8>>(3)?,
382 ))
383 })?;
384
385 match rows.next() {
386 Some(Ok(model)) => Ok(Some(model)),
387 Some(Err(e)) => Err(e),
388 None => Ok(None),
389 }
390 }
391
392 pub fn get_position_by_id(&self, id: i64) -> SqlResult<Option<PositionData>> {
393 let mut stmt = self.conn.prepare(
394 "SELECT fen, vector, evaluation, compressed_vector, created_at
395 FROM positions WHERE id = ?1",
396 )?;
397
398 let mut rows = stmt.query_map([id], |row| {
399 let vector_bytes: Vec<u8> = row.get(1)?;
400 let vector: Vec<f64> = bincode::deserialize(&vector_bytes).map_err(|e| {
401 rusqlite::Error::FromSqlConversionFailure(
402 1,
403 rusqlite::types::Type::Blob,
404 Box::new(e),
405 )
406 })?;
407
408 let compressed_vector =
409 if let Ok(Some(compressed_bytes)) = row.get::<_, Option<Vec<u8>>>(3) {
410 Some(bincode::deserialize(&compressed_bytes).map_err(|e| {
411 rusqlite::Error::FromSqlConversionFailure(
412 3,
413 rusqlite::types::Type::Blob,
414 Box::new(e),
415 )
416 })?)
417 } else {
418 None
419 };
420
421 Ok(PositionData {
422 fen: row.get(0)?,
423 vector,
424 evaluation: row.get(2)?,
425 compressed_vector,
426 created_at: row.get(4)?,
427 })
428 })?;
429
430 match rows.next() {
431 Some(Ok(position)) => Ok(Some(position)),
432 Some(Err(e)) => Err(e),
433 None => Ok(None),
434 }
435 }
436
437 pub fn save_positions_batch(&self, positions: &[PositionData]) -> SqlResult<usize> {
439 if positions.is_empty() {
440 return Ok(0);
441 }
442
443 let tx = self.conn.unchecked_transaction()?;
444
445 {
446 let mut stmt = tx.prepare(
447 "INSERT OR REPLACE INTO positions (fen, vector, evaluation, compressed_vector, created_at, updated_at)
448 VALUES (?1, ?2, ?3, ?4, ?5, ?6)"
449 )?;
450
451 let current_time = std::time::SystemTime::now()
452 .duration_since(std::time::UNIX_EPOCH)
453 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?
454 .as_secs() as i64;
455
456 for position_data in positions {
457 let vector_bytes = bincode::serialize(&position_data.vector)
458 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
459
460 let compressed_vector_bytes = position_data
461 .compressed_vector
462 .as_ref()
463 .map(bincode::serialize)
464 .transpose()
465 .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e)))?;
466
467 stmt.execute(params![
468 position_data.fen,
469 vector_bytes,
470 position_data.evaluation,
471 compressed_vector_bytes,
472 position_data.created_at,
473 current_time
474 ])?;
475 }
476 }
477
478 tx.commit()?;
479 Ok(positions.len())
480 }
481
482 pub fn load_positions_batch(
484 &self,
485 limit: usize,
486 offset: usize,
487 ) -> SqlResult<Vec<PositionData>> {
488 let mut stmt = self.conn.prepare(
489 "SELECT fen, vector, evaluation, compressed_vector, created_at
490 FROM positions ORDER BY id LIMIT ?1 OFFSET ?2",
491 )?;
492
493 let rows = stmt.query_map([limit, offset], |row| {
494 let vector_bytes: Vec<u8> = row.get(1)?;
495 let vector: Vec<f64> = bincode::deserialize(&vector_bytes).map_err(|e| {
496 rusqlite::Error::FromSqlConversionFailure(
497 1,
498 rusqlite::types::Type::Blob,
499 Box::new(e),
500 )
501 })?;
502
503 let compressed_vector =
504 if let Ok(Some(compressed_bytes)) = row.get::<_, Option<Vec<u8>>>(3) {
505 Some(bincode::deserialize(&compressed_bytes).map_err(|e| {
506 rusqlite::Error::FromSqlConversionFailure(
507 3,
508 rusqlite::types::Type::Blob,
509 Box::new(e),
510 )
511 })?)
512 } else {
513 None
514 };
515
516 Ok(PositionData {
517 fen: row.get(0)?,
518 vector,
519 evaluation: row.get(2)?,
520 compressed_vector,
521 created_at: row.get(4)?,
522 })
523 })?;
524
525 rows.collect()
526 }
527
528 pub fn get_total_position_count(&self) -> SqlResult<usize> {
530 let mut stmt = self.conn.prepare("SELECT COUNT(*) FROM positions")?;
531 let count: i64 = stmt.query_row([], |row| row.get(0))?;
532 Ok(count as usize)
533 }
534}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a throwaway in-memory database for a test.
    fn open_db() -> Database {
        Database::in_memory().unwrap()
    }

    /// A freshly created database contains no positions.
    #[test]
    fn test_database_creation() {
        let stored = open_db().get_position_count().unwrap();
        assert_eq!(stored, 0);
    }

    /// A saved position can be read back by its FEN with all payload fields intact.
    #[test]
    fn test_position_storage() {
        let database = open_db();

        let original = PositionData {
            fen: "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1".to_string(),
            vector: vec![1.0, 2.0, 3.0],
            evaluation: Some(0.5),
            compressed_vector: Some(vec![0.1, 0.2]),
            created_at: 1234567890,
        };

        let row_id = database.save_position(&original).unwrap();
        assert!(row_id > 0);

        let fetched = database.load_position(&original.fen).unwrap();
        let restored = fetched.unwrap();
        assert_eq!(restored.fen, original.fen);
        assert_eq!(restored.vector, original.vector);
        assert_eq!(restored.evaluation, original.evaluation);
        assert_eq!(restored.compressed_vector, original.compressed_vector);
    }

    /// A saved LSH configuration round-trips through the database.
    #[test]
    fn test_lsh_config_storage() {
        let database = open_db();

        let single_hash = LSHHashFunction {
            random_vector: vec![1.0, -1.0, 0.5],
            threshold: 0.0,
        };
        let original = LSHTableData {
            num_tables: 10,
            num_hash_functions: 5,
            vector_dim: 1024,
            hash_functions: vec![single_hash],
        };

        database.save_lsh_config(&original).unwrap();

        let restored = database.load_lsh_config().unwrap().unwrap();
        assert_eq!(restored.num_tables, original.num_tables);
        assert_eq!(restored.num_hash_functions, original.num_hash_functions);
        assert_eq!(restored.vector_dim, original.vector_dim);
        assert_eq!(
            restored.hash_functions.len(),
            original.hash_functions.len()
        );
    }
}