use std::path::Path;
use lean_semantic_search_contract::{DeclarationFeatureRow, OpaqueFeatureKey};
use lean_semantic_search_retrieval::Corpus;
use rusqlite::{Connection, OpenFlags, OptionalExtension, params, params_from_iter};
use crate::StoreError;
use crate::schema::{
META_CORPUS_TOKEN, META_POLICY_VERSION, META_SCHEMA_VERSION, META_TOTAL_DOCUMENTS, STORE_SCHEMA_VERSION,
};
/// Maximum number of distinct keys bound into one `key IN (...)` query when
/// `fanout` counts postings; longer key lists are split into several queries.
// NOTE(review): 900 presumably keeps each statement under SQLite's
// bound-parameter limit (999 by default in older builds) — confirm.
const FANOUT_CHUNK: usize = 900;
/// Handle to a SQLite-backed store together with the metadata values that
/// were read from its `metadata` table when the store was opened.
pub struct Store {
    /// Connection that every query in this type runs on.
    connection: Connection,
    /// Schema version recorded in the store's metadata.
    schema_version: String,
    /// Policy version recorded in the store's metadata.
    policy_version: String,
    /// Corpus token recorded in the store's metadata.
    corpus_token: String,
    /// Document total recorded in the store's metadata, parsed as an integer.
    total_documents: usize,
}
impl Store {
    /// Opens the SQLite database at `path` read-only and loads its metadata.
    ///
    /// The connection is opened with `SQLITE_OPEN_READ_ONLY | SQLITE_OPEN_NO_MUTEX`
    /// and `PRAGMA query_only` is switched on before anything is read. The
    /// schema version is checked first; the policy version, corpus token and
    /// document total are read afterwards, in that order.
    ///
    /// # Errors
    ///
    /// * Any SQLite failure while opening or querying, converted into a
    ///   [`StoreError`] by `?`.
    /// * [`StoreError::MissingMetadata`] when a required metadata key is absent.
    /// * [`StoreError::SchemaMismatch`] when the stored schema version differs
    ///   from `STORE_SCHEMA_VERSION`.
    /// * [`StoreError::Corrupt`] when the stored document total is not a valid
    ///   `usize`.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, StoreError> {
        let connection = Connection::open_with_flags(
            path,
            OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_NO_MUTEX,
        )?;
        connection.execute_batch("PRAGMA query_only = true;")?;

        // Reject an incompatible store before reading any further metadata.
        let stored_schema = read_metadata(&connection, META_SCHEMA_VERSION)?;
        if stored_schema != STORE_SCHEMA_VERSION {
            return Err(StoreError::SchemaMismatch {
                found: stored_schema,
                expected: STORE_SCHEMA_VERSION,
            });
        }

        let policy_version = read_metadata(&connection, META_POLICY_VERSION)?;
        let corpus_token = read_metadata(&connection, META_CORPUS_TOKEN)?;

        let raw_total = read_metadata(&connection, META_TOTAL_DOCUMENTS)?;
        let Ok(total_documents) = raw_total.parse::<usize>() else {
            return Err(StoreError::Corrupt(META_TOTAL_DOCUMENTS));
        };

        Ok(Self {
            connection,
            total_documents,
            schema_version: stored_schema,
            policy_version,
            corpus_token,
        })
    }

    /// Corpus token read from the store's metadata at open time.
    #[must_use]
    pub fn corpus_token(&self) -> &str {
        self.corpus_token.as_str()
    }

    /// Schema version read from the store's metadata at open time.
    #[must_use]
    pub fn schema_version(&self) -> &str {
        self.schema_version.as_str()
    }

    /// Policy version read from the store's metadata at open time.
    #[must_use]
    pub fn policy_version(&self) -> &str {
        self.policy_version.as_str()
    }

    /// Document total read from the store's metadata at open time.
    #[must_use]
    pub fn document_total(&self) -> usize {
        self.total_documents
    }
}
impl Corpus for Store {
    /// Document total read from the store's metadata at open time.
    fn document_total(&self) -> usize {
        self.total_documents
    }

    /// Returns, for each entry of `keys` (same order, duplicates included),
    /// the number of rows in `postings` carrying that key.
    ///
    /// Distinct keys are counted in batches of at most `FANOUT_CHUNK` bound
    /// parameters. This is best-effort: a batch whose statement cannot be
    /// prepared or executed is skipped, rows that fail to decode are dropped,
    /// and every key left without a count is reported as `0`.
    fn fanout(&self, keys: &[OpaqueFeatureKey]) -> Vec<usize> {
        use std::collections::HashMap;

        if keys.is_empty() {
            return Vec::new();
        }

        // Query each distinct key only once, however often it was requested.
        let mut unique: Vec<&str> = keys.iter().map(|key| key.as_str()).collect();
        unique.sort_unstable();
        unique.dedup();

        let mut counts: HashMap<String, usize> = HashMap::new();
        for batch in unique.chunks(FANOUT_CHUNK) {
            // One `?` per key, comma separated: "?,?,...,?".
            let mut marks = "?,".repeat(batch.len());
            marks.pop();
            let query = format!("SELECT key, COUNT(*) FROM postings WHERE key IN ({marks}) GROUP BY key");

            let mut statement = match self.connection.prepare_cached(&query) {
                Ok(prepared) => prepared,
                Err(_) => continue,
            };
            let rows = match statement.query_map(params_from_iter(batch.iter().copied()), |row| {
                Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
            }) {
                Ok(mapped) => mapped,
                Err(_) => continue,
            };
            counts.extend(
                rows.flatten()
                    .map(|(key, total)| (key, usize::try_from(total).unwrap_or(0))),
            );
        }

        let mut fanouts = Vec::with_capacity(keys.len());
        for key in keys {
            fanouts.push(counts.get(key.as_str()).copied().unwrap_or(0));
        }
        fanouts
    }

    /// Returns up to `limit` declaration ids posted under `key`, in ascending
    /// `declaration_id` order.
    ///
    /// A `limit` that does not fit in `i64` is clamped to `i64::MAX`. If the
    /// statement cannot be prepared or executed the result is empty, and rows
    /// that fail to decode are dropped.
    fn postings(&self, key: &OpaqueFeatureKey, limit: usize) -> Vec<String> {
        const QUERY: &str = "SELECT declaration_id FROM postings WHERE key = ?1 ORDER BY declaration_id LIMIT ?2";

        let bound = i64::try_from(limit).unwrap_or(i64::MAX);
        let mut statement = match self.connection.prepare_cached(QUERY) {
            Ok(prepared) => prepared,
            Err(_) => return Vec::new(),
        };
        let ids: Vec<String> =
            match statement.query_map(params![key.as_str(), bound], |row| row.get::<_, String>(0)) {
                Ok(rows) => rows.flatten().collect(),
                Err(_) => Vec::new(),
            };
        ids
    }

    /// Loads the feature row stored for `declaration_id` by deserializing the
    /// `row_json` column of `feature_rows`.
    ///
    /// Returns `None` when the statement cannot be prepared, the lookup does
    /// not yield a row, or the stored JSON does not deserialize.
    fn declaration_row(&self, declaration_id: &str) -> Option<DeclarationFeatureRow> {
        let Ok(mut statement) = self
            .connection
            .prepare_cached("SELECT row_json FROM feature_rows WHERE declaration_id = ?1")
        else {
            return None;
        };
        let Ok(row_json) = statement.query_row(params![declaration_id], |row| row.get::<_, String>(0)) else {
            return None;
        };
        serde_json::from_str(&row_json).ok()
    }
}
/// Reads the `value` stored under `key` in the `metadata` table.
///
/// # Errors
///
/// Returns [`StoreError::MissingMetadata`] carrying `key` when no row
/// matches, and converts any SQLite failure into a [`StoreError`] via `?`.
fn read_metadata(connection: &Connection, key: &str) -> Result<String, StoreError> {
    let mut statement = connection.prepare_cached("SELECT value FROM metadata WHERE key = ?1")?;
    let value = statement
        .query_row(params![key], |row| row.get::<_, String>(0))
        .optional()?;
    match value {
        Some(text) => Ok(text),
        None => Err(StoreError::MissingMetadata(key.to_owned())),
    }
}