use std::collections::BTreeMap;
use std::ffi::{CStr, CString};
use std::os::raw::{c_char, c_int, c_uchar, c_void};
use std::path::Path;
use std::ptr;
use std::slice;
use serde::{Deserialize, Serialize};
use crate::model::Chunk;
use super::{path_string, AsrError, AsrResult};
// Index status labels exported to the rest of the crate. They are not referenced
// in this part of the file; presumably they are the values stored in
// `IndexStateRecord::status` — TODO confirm against callers.
pub(crate) const INDEX_STATUS_READY: &str = "ready";
pub(crate) const INDEX_STATUS_FAILED: &str = "failed";
// SQLite result codes that this module compares FFI return values against.
/// Success code checked after the open/exec/prepare/bind/reset calls below.
const SQLITE_OK: c_int = 0;
/// `sqlite3_step` result that `Statement::step` maps to `Step::Row`.
const SQLITE_ROW: c_int = 100;
/// `sqlite3_step` result that `Statement::step` maps to `Step::Done`.
const SQLITE_DONE: c_int = 101;
/// Opaque stand-in for the C `sqlite3` connection type.
///
/// It has a single zero-length private field and is only used behind
/// `*mut sqlite3` pointers in the FFI declarations of this file.
#[repr(C)]
struct sqlite3 {
_private: [u8; 0],
}
/// Opaque stand-in for the C `sqlite3_stmt` prepared-statement type.
///
/// It has a single zero-length private field and is only used behind
/// `*mut sqlite3_stmt` pointers in the FFI declarations of this file.
#[repr(C)]
struct sqlite3_stmt {
_private: [u8; 0],
}
// Declarations for the subset of the SQLite C API used by this module,
// linked against the `sqlite3` library.
#[link(name = "sqlite3")]
extern "C" {
// Called by `Database::open`; writes the connection handle through `pp_db`.
fn sqlite3_open(filename: *const c_char, pp_db: *mut *mut sqlite3) -> c_int;
// Called by `Database::open` to release the handle when opening failed.
fn sqlite3_close(db: *mut sqlite3) -> c_int;
// Called by `Drop for Database`.
fn sqlite3_close_v2(db: *mut sqlite3) -> c_int;
// Read by `sqlite_error` to obtain the connection's current error text.
fn sqlite3_errmsg(db: *mut sqlite3) -> *const c_char;
// Called by `Database::exec`, always with no callback, a null callback
// argument and a null `errmsg` out-pointer.
fn sqlite3_exec(
db: *mut sqlite3,
sql: *const c_char,
callback: Option<
unsafe extern "C" fn(*mut c_void, c_int, *mut *mut c_char, *mut *mut c_char) -> c_int,
>,
arg: *mut c_void,
errmsg: *mut *mut c_char,
) -> c_int;
// Called by `Database::prepare` with `n_byte = -1` and a null `pz_tail`.
fn sqlite3_prepare_v2(
db: *mut sqlite3,
sql: *const c_char,
n_byte: c_int,
pp_stmt: *mut *mut sqlite3_stmt,
pz_tail: *mut *const c_char,
) -> c_int;
// Called by `Drop for Statement`.
fn sqlite3_finalize(stmt: *mut sqlite3_stmt) -> c_int;
// Called by `Statement::step`; only `SQLITE_ROW` and `SQLITE_DONE` are
// treated as success there.
fn sqlite3_step(stmt: *mut sqlite3_stmt) -> c_int;
// `sqlite3_reset` and `sqlite3_clear_bindings` are both called by
// `Statement::reset`, in that order.
fn sqlite3_reset(stmt: *mut sqlite3_stmt) -> c_int;
fn sqlite3_clear_bindings(stmt: *mut sqlite3_stmt) -> c_int;
// Called by `Statement::bind_text` with `n = -1` and no destructor.
fn sqlite3_bind_text(
stmt: *mut sqlite3_stmt,
index: c_int,
value: *const c_char,
n: c_int,
destructor: Option<unsafe extern "C" fn(*mut c_void)>,
) -> c_int;
// Called by `Statement::bind_optional_text` for `None` values.
fn sqlite3_bind_null(stmt: *mut sqlite3_stmt, index: c_int) -> c_int;
// `sqlite3_column_text` and `sqlite3_column_bytes` are used together by
// `Statement::column_text` to read a column as a (pointer, length) pair.
fn sqlite3_column_text(stmt: *mut sqlite3_stmt, index: c_int) -> *const c_uchar;
fn sqlite3_column_bytes(stmt: *mut sqlite3_stmt, index: c_int) -> c_int;
// Called by `Statement::column_i64`.
fn sqlite3_column_int64(stmt: *mut sqlite3_stmt, index: c_int) -> i64;
}
/// One row of the `repos` table, as written by `Store::insert_repo` and read
/// back by `Store::get_repo` / `Store::list_repos`.
///
/// The `Option` fields correspond to nullable columns; when serialized they
/// are omitted entirely while `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RepoRecord {
// Stored in the `name TEXT NOT NULL UNIQUE` column; other tables reference it.
pub name: String,
pub source_type: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub remote_url: Option<String>,
pub local_path: String,
pub git_root: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub default_branch: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub head_commit: Option<String>,
// Timestamps are stored and returned as opaque text; the format is not
// established by this file.
pub created_at: String,
pub updated_at: String,
}
/// One row of the `index_state` table (keyed by `repo_name`), as written by
/// `Store::write_index_state` and read back by `Store::get_index_state`.
///
/// The `Option` fields correspond to nullable columns; when serialized they
/// are omitted entirely while `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IndexStateRecord {
pub repo_name: String,
// NOTE(review): presumably one of the `INDEX_STATUS_*` constants — confirm with callers.
pub status: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub head_commit: Option<String>,
// The three flags are persisted as the text "1" / "0" (see `bool_text`).
pub dirty: bool,
pub untracked: bool,
pub modified: bool,
pub worktree_fingerprint: String,
// The two counters are persisted as decimal text (see `parse_usize`).
pub indexed_files: usize,
pub total_chunks: usize,
// Persisted as a JSON object in the `languages_json` column.
pub languages: BTreeMap<String, usize>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content_hash: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
pub indexed_at: String,
}
/// Schema version that `Store::init_schema` compares against, and then writes
/// to, the database's `PRAGMA user_version`.
const SCHEMA_VERSION: i64 = 2;
/// Persistence API for repositories, index state and chunks, backed by a
/// single SQLite connection.
pub(crate) struct Store {
// The owned connection; it is closed when the `Database` is dropped.
db: Database,
}
impl Store {
/// Opens the store backed by the SQLite file at `db_path`.
///
/// When `db_path` has a parent component, that directory (and any missing
/// ancestors) is created before the database is opened.
///
/// # Errors
/// Returns `asr_home_create_failed`, tagged with the parent path, when the
/// directory cannot be created; otherwise propagates the error from
/// `Database::open`.
pub(crate) fn open(db_path: &Path) -> AsrResult<Self> {
    if let Some(dir) = db_path.parent() {
        if let Err(io_err) = std::fs::create_dir_all(dir) {
            return Err(AsrError::with_path(
                "asr_home_create_failed",
                format!("Failed to create database parent directory: {io_err}"),
                path_string(dir),
            ));
        }
    }
    let db = Database::open(db_path)?;
    Ok(Self { db })
}
/// Creates the tables and indexes and migrates older databases, then stamps
/// the database with [`SCHEMA_VERSION`].
///
/// Does nothing when `PRAGMA user_version` already equals `SCHEMA_VERSION`.
/// Otherwise the idempotent `CREATE ... IF NOT EXISTS` script is run, the
/// `worktree_fingerprint` column is added to `index_state` when it is
/// missing, and `user_version` is updated.
///
/// # Errors
/// Propagates any SQLite failure. The only tolerated failure is the
/// "duplicate column name" error from the `ALTER TABLE` step, which means the
/// column already exists. Any other migration failure is returned *before*
/// `user_version` is written, so a partially migrated database is never
/// marked as up to date.
pub(crate) fn init_schema(&self) -> AsrResult<()> {
    if self.db.user_version()? == SCHEMA_VERSION {
        return Ok(());
    }
    self.db.exec(
        r#"
PRAGMA journal_mode = WAL;
CREATE TABLE IF NOT EXISTS repos (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
source_type TEXT NOT NULL,
remote_url TEXT,
local_path TEXT NOT NULL,
git_root TEXT NOT NULL,
default_branch TEXT,
head_commit TEXT,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS repos_name_idx ON repos(name);
CREATE INDEX IF NOT EXISTS repos_source_type_idx ON repos(source_type);
CREATE TABLE IF NOT EXISTS index_state (
repo_name TEXT PRIMARY KEY,
status TEXT NOT NULL,
head_commit TEXT,
dirty TEXT NOT NULL,
untracked TEXT NOT NULL,
modified TEXT NOT NULL,
worktree_fingerprint TEXT NOT NULL,
indexed_files TEXT NOT NULL,
total_chunks TEXT NOT NULL,
languages_json TEXT NOT NULL,
content_hash TEXT,
error TEXT,
indexed_at TEXT NOT NULL,
FOREIGN KEY(repo_name) REFERENCES repos(name) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS index_state_status_idx ON index_state(status);
CREATE INDEX IF NOT EXISTS index_state_head_idx ON index_state(head_commit);
CREATE TABLE IF NOT EXISTS chunks (
repo_name TEXT NOT NULL,
head_commit TEXT,
path TEXT NOT NULL,
start_line TEXT NOT NULL,
end_line TEXT NOT NULL,
language TEXT,
content TEXT NOT NULL,
FOREIGN KEY(repo_name) REFERENCES repos(name) ON DELETE CASCADE
);
CREATE INDEX IF NOT EXISTS chunks_repo_idx ON chunks(repo_name);
CREATE INDEX IF NOT EXISTS chunks_repo_head_idx ON chunks(repo_name, head_commit);
CREATE INDEX IF NOT EXISTS chunks_path_idx ON chunks(repo_name, path);
"#,
    )?;
    // Databases whose `index_state` table predates this column need it added.
    // On databases that already have it (including freshly created ones) the
    // statement fails with "duplicate column name", which is expected.
    if let Err(err) = self.db.exec(
        "ALTER TABLE index_state ADD COLUMN worktree_fingerprint TEXT NOT NULL DEFAULT ''",
    ) {
        if !err.message.contains("duplicate column name") {
            // A real failure (locked, read-only, I/O, ...): do not stamp the
            // schema version on a database that was not actually migrated.
            return Err(err);
        }
        log::debug!("Schema migration (worktree_fingerprint): {}", err.message);
    }
    self.db
        .exec(&format!("PRAGMA user_version = {SCHEMA_VERSION}"))?;
    Ok(())
}
/// Inserts `repo` as a new row of the `repos` table.
///
/// `None` fields are stored as SQL NULL.
///
/// # Errors
/// Propagates prepare/bind/step failures, and returns `store_error` if the
/// statement unexpectedly yields a row.
pub(crate) fn insert_repo(&self, repo: &RepoRecord) -> AsrResult<()> {
    let mut insert = self.db.prepare(
        "INSERT INTO repos (name, source_type, remote_url, local_path, git_root, default_branch, head_commit, created_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
    )?;
    // Values in placeholder order; `Some` binds text and `None` binds NULL.
    let values: [Option<&str>; 9] = [
        Some(repo.name.as_str()),
        Some(repo.source_type.as_str()),
        repo.remote_url.as_deref(),
        Some(repo.local_path.as_str()),
        Some(repo.git_root.as_str()),
        repo.default_branch.as_deref(),
        repo.head_commit.as_deref(),
        Some(repo.created_at.as_str()),
        Some(repo.updated_at.as_str()),
    ];
    // SQLite placeholders are numbered from 1.
    for (position, value) in (1..).zip(values) {
        insert.bind_optional_text(position, value)?;
    }
    let outcome = insert.step()?;
    expect_done(outcome, "Unexpected row while inserting repository")
}
/// Overwrites `default_branch`, `head_commit` and `updated_at` for the
/// repository called `name`.
///
/// A `None` for `branch` or `head_commit` stores SQL NULL in that column.
///
/// # Errors
/// Propagates prepare/bind/step failures, and returns `store_error` if the
/// statement unexpectedly yields a row.
pub(crate) fn update_repo_head(
    &self,
    name: &str,
    branch: Option<&str>,
    head_commit: Option<&str>,
    updated_at: &str,
) -> AsrResult<()> {
    const SQL: &str =
        "UPDATE repos SET default_branch = ?, head_commit = ?, updated_at = ? WHERE name = ?";
    let mut update = self.db.prepare(SQL)?;
    update.bind_optional_text(1, branch)?;
    update.bind_optional_text(2, head_commit)?;
    update.bind_text(3, updated_at)?;
    update.bind_text(4, name)?;
    let outcome = update.step()?;
    expect_done(outcome, "Unexpected row while updating repository head")
}
/// Looks up the repository called `name`.
///
/// Returns `Ok(None)` when the query yields no row.
///
/// # Errors
/// Propagates prepare/bind/step failures.
pub(crate) fn get_repo(&self, name: &str) -> AsrResult<Option<RepoRecord>> {
    const SQL: &str = "SELECT name, source_type, remote_url, local_path, git_root, default_branch, head_commit, created_at, updated_at FROM repos WHERE name = ? LIMIT 1";
    let mut lookup = self.db.prepare(SQL)?;
    lookup.bind_text(1, name)?;
    if let Step::Row = lookup.step()? {
        return Ok(Some(repo_from_statement(&lookup)));
    }
    Ok(None)
}
/// Returns every repository, ordered by name using binary collation.
///
/// # Errors
/// Propagates prepare/step failures.
pub(crate) fn list_repos(&self) -> AsrResult<Vec<RepoRecord>> {
    let mut rows = self.db.prepare(
        "SELECT name, source_type, remote_url, local_path, git_root, default_branch, head_commit, created_at, updated_at FROM repos ORDER BY name COLLATE BINARY ASC",
    )?;
    let mut found = Vec::new();
    loop {
        match rows.step()? {
            Step::Row => found.push(repo_from_statement(&rows)),
            Step::Done => return Ok(found),
        }
    }
}
pub(crate) fn put_index_state(&self, state: &IndexStateRecord) -> AsrResult<()> {
self.write_index_state(state)
}
/// Atomically replaces the stored index of a repository.
///
/// Inside a single `BEGIN IMMEDIATE` transaction this deletes all chunks of
/// `repo_name`, writes `state` (keyed by `state.repo_name`), and inserts
/// `chunks` tagged with `state.head_commit`.
///
/// # Errors
/// Returns the first failure from any step, including `COMMIT`. On any
/// failure a `ROLLBACK` is attempted so that the connection is not left
/// inside an open transaction. A transaction left open after a failed
/// `COMMIT` would make the next `BEGIN IMMEDIATE` on this connection fail. A
/// failing `ROLLBACK` is only logged; the original error is what the caller
/// receives.
pub(crate) fn replace_index(
    &self,
    repo_name: &str,
    state: &IndexStateRecord,
    chunks: &[Chunk],
) -> AsrResult<()> {
    self.db.exec("BEGIN IMMEDIATE")?;
    let result = (|| -> AsrResult<()> {
        self.delete_chunks(repo_name)?;
        self.write_index_state(state)?;
        self.insert_chunks(repo_name, state.head_commit.as_deref(), chunks)?;
        // COMMIT is part of the guarded section so that its failure is
        // followed by the same rollback handling as any other step.
        self.db.exec("COMMIT")
    })();
    if result.is_err() {
        if let Err(rb_err) = self.db.exec("ROLLBACK") {
            log::warn!(
                "SQLite ROLLBACK failed: {}; connection will rollback on close",
                rb_err.message
            );
        }
    }
    result
}
/// Loads the index state stored for `repo_name`.
///
/// Returns `Ok(None)` when the query yields no row.
///
/// # Errors
/// Propagates prepare/bind/step failures.
pub(crate) fn get_index_state(&self, repo_name: &str) -> AsrResult<Option<IndexStateRecord>> {
    const SQL: &str = "SELECT repo_name, status, head_commit, dirty, untracked, modified, worktree_fingerprint, indexed_files, total_chunks, languages_json, content_hash, error, indexed_at FROM index_state WHERE repo_name = ? LIMIT 1";
    let mut lookup = self.db.prepare(SQL)?;
    lookup.bind_text(1, repo_name)?;
    let record = match lookup.step()? {
        Step::Done => None,
        Step::Row => Some(index_state_from_statement(&lookup)),
    };
    Ok(record)
}
/// Returns all chunks stored for `repo_name`.
///
/// Rows are ordered by path (binary collation), then by start and end line
/// compared as integers.
///
/// # Errors
/// Propagates prepare/bind/step failures.
pub(crate) fn list_chunks(&self, repo_name: &str) -> AsrResult<Vec<Chunk>> {
    let mut rows = self.db.prepare(
        "SELECT content, path, start_line, end_line, language FROM chunks WHERE repo_name = ? ORDER BY path COLLATE BINARY ASC, CAST(start_line AS INTEGER) ASC, CAST(end_line AS INTEGER) ASC",
    )?;
    rows.bind_text(1, repo_name)?;
    let mut collected = Vec::new();
    loop {
        match rows.step()? {
            Step::Row => collected.push(chunk_from_statement(&rows)),
            Step::Done => return Ok(collected),
        }
    }
}
/// Deletes every chunk row whose `repo_name` equals the argument.
///
/// # Errors
/// Propagates prepare/bind/step failures, and returns `store_error` if the
/// statement unexpectedly yields a row.
fn delete_chunks(&self, repo_name: &str) -> AsrResult<()> {
    const SQL: &str = "DELETE FROM chunks WHERE repo_name = ?";
    let mut delete = self.db.prepare(SQL)?;
    delete.bind_text(1, repo_name)?;
    let outcome = delete.step()?;
    expect_done(outcome, "Unexpected row while deleting chunks")
}
/// Inserts `chunks` for `repo_name`, tagging each row with `head_commit`.
///
/// One prepared statement is reused for all rows and reset after each insert.
/// Line numbers are stored as decimal text.
///
/// # Errors
/// Stops at the first prepare/bind/step/reset failure and returns it. Returns
/// `store_error` if an insert unexpectedly yields a row.
fn insert_chunks(
    &self,
    repo_name: &str,
    head_commit: Option<&str>,
    chunks: &[Chunk],
) -> AsrResult<()> {
    let mut insert = self.db.prepare(
        "INSERT INTO chunks (repo_name, head_commit, path, start_line, end_line, language, content) VALUES (?, ?, ?, ?, ?, ?, ?)",
    )?;
    chunks.iter().try_for_each(|chunk| -> AsrResult<()> {
        let first_line = chunk.start_line.to_string();
        let last_line = chunk.end_line.to_string();
        insert.bind_text(1, repo_name)?;
        insert.bind_optional_text(2, head_commit)?;
        insert.bind_text(3, &chunk.file_path)?;
        insert.bind_text(4, &first_line)?;
        insert.bind_text(5, &last_line)?;
        insert.bind_optional_text(6, chunk.language.as_deref())?;
        insert.bind_text(7, &chunk.content)?;
        let outcome = insert.step()?;
        expect_done(outcome, "Unexpected row while inserting chunk")?;
        // Make the statement reusable for the next chunk.
        insert.reset()
    })
}
/// Inserts or replaces the `index_state` row for `state.repo_name`.
///
/// Booleans are stored as "1"/"0", counters as decimal text, and the language
/// map as a JSON string. `None` fields are stored as SQL NULL.
///
/// # Errors
/// Returns `store_error` when the language map cannot be serialized or the
/// statement unexpectedly yields a row; propagates prepare/bind/step failures.
fn write_index_state(&self, state: &IndexStateRecord) -> AsrResult<()> {
    let languages_json = match serde_json::to_string(&state.languages) {
        Ok(json) => json,
        Err(err) => {
            return Err(AsrError::new(
                "store_error",
                format!("Failed to serialize index language metadata: {err}"),
            ))
        }
    };
    let indexed_files = state.indexed_files.to_string();
    let total_chunks = state.total_chunks.to_string();
    let mut upsert = self.db.prepare(
        "INSERT OR REPLACE INTO index_state (repo_name, status, head_commit, dirty, untracked, modified, worktree_fingerprint, indexed_files, total_chunks, languages_json, content_hash, error, indexed_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
    )?;
    // Values in placeholder order; `Some` binds text and `None` binds NULL.
    let values: [Option<&str>; 13] = [
        Some(state.repo_name.as_str()),
        Some(state.status.as_str()),
        state.head_commit.as_deref(),
        Some(bool_text(state.dirty)),
        Some(bool_text(state.untracked)),
        Some(bool_text(state.modified)),
        Some(state.worktree_fingerprint.as_str()),
        Some(indexed_files.as_str()),
        Some(total_chunks.as_str()),
        Some(languages_json.as_str()),
        state.content_hash.as_deref(),
        state.error.as_deref(),
        Some(state.indexed_at.as_str()),
    ];
    // SQLite placeholders are numbered from 1.
    for (position, value) in (1..).zip(values) {
        upsert.bind_optional_text(position, value)?;
    }
    let outcome = upsert.step()?;
    expect_done(outcome, "Unexpected row while writing index state")
}
}
fn repo_from_statement(stmt: &Statement<'_>) -> RepoRecord {
RepoRecord {
name: stmt.column_text(0).unwrap_or_default(),
source_type: stmt.column_text(1).unwrap_or_default(),
remote_url: stmt.column_text(2),
local_path: stmt.column_text(3).unwrap_or_default(),
git_root: stmt.column_text(4).unwrap_or_default(),
default_branch: stmt.column_text(5),
head_commit: stmt.column_text(6),
created_at: stmt.column_text(7).unwrap_or_default(),
updated_at: stmt.column_text(8).unwrap_or_default(),
}
}
fn index_state_from_statement(stmt: &Statement<'_>) -> IndexStateRecord {
let languages_json = stmt.column_text(9).unwrap_or_else(|| "{}".to_string());
let languages = serde_json::from_str(&languages_json).unwrap_or_default();
IndexStateRecord {
repo_name: stmt.column_text(0).unwrap_or_default(),
status: stmt.column_text(1).unwrap_or_default(),
head_commit: stmt.column_text(2),
dirty: parse_bool(stmt.column_text(3)),
untracked: parse_bool(stmt.column_text(4)),
modified: parse_bool(stmt.column_text(5)),
worktree_fingerprint: stmt.column_text(6).unwrap_or_default(),
indexed_files: parse_usize(stmt.column_text(7)),
total_chunks: parse_usize(stmt.column_text(8)),
languages,
content_hash: stmt.column_text(10),
error: stmt.column_text(11),
indexed_at: stmt.column_text(12).unwrap_or_default(),
}
}
/// Builds a [`Chunk`] from the current row of `stmt`.
///
/// Columns are read by position, in the order used by `list_chunks`: content,
/// path, start line, end line, language. Line numbers that are missing,
/// unparsable or zero are raised to 1.
fn chunk_from_statement(stmt: &Statement<'_>) -> Chunk {
    let content = stmt.column_text(0).unwrap_or_default();
    let path = stmt.column_text(1).unwrap_or_default();
    let start_line = std::cmp::max(parse_usize(stmt.column_text(2)), 1);
    let end_line = std::cmp::max(parse_usize(stmt.column_text(3)), 1);
    let language = stmt.column_text(4);
    Chunk::new(content, path, start_line, end_line, language)
}
/// Encodes a boolean as the text stored in the database: "1" for `true` and
/// "0" for `false`.
fn bool_text(value: bool) -> &'static str {
    match value {
        true => "1",
        false => "0",
    }
}
/// Decodes a stored boolean column.
///
/// Only the exact texts "1", "true" and "TRUE" are `true`. Everything else,
/// including `None`, is `false`.
fn parse_bool(value: Option<String>) -> bool {
    const TRUTHY: [&str; 3] = ["1", "true", "TRUE"];
    value
        .as_deref()
        .map_or(false, |text| TRUTHY.contains(&text))
}
/// Decodes a stored counter column.
///
/// Returns the parsed `usize`, or `0` when the value is `None` or is not a
/// valid `usize`.
fn parse_usize(value: Option<String>) -> usize {
    match value {
        Some(text) => text.parse::<usize>().unwrap_or(0),
        None => 0,
    }
}
/// Checks that a statement which should not produce rows finished.
///
/// # Errors
/// Returns a `store_error` carrying `message` when `step` is [`Step::Row`].
fn expect_done(step: Step, message: &'static str) -> AsrResult<()> {
    if matches!(step, Step::Row) {
        return Err(AsrError::new("store_error", message));
    }
    Ok(())
}
/// Owning wrapper around a raw SQLite connection handle.
///
/// The handle is obtained in `Database::open` and closed in `Drop`.
struct Database {
// Raw connection pointer passed to every FFI call; `Drop` skips closing it when null.
handle: *mut sqlite3,
}
impl Database {
/// Opens the SQLite database at `path` and applies the per-connection
/// pragmas `busy_timeout = 5000` and `foreign_keys = ON`.
///
/// # Errors
/// Returns `invalid_nul_byte` when the path contains a NUL byte, and
/// `sqlite_open_failed` (tagged with the path) when SQLite cannot open the
/// file. Also propagates a failure of either pragma, in which case the
/// connection is closed by `Drop`.
fn open(path: &Path) -> AsrResult<Self> {
    let c_path = cstring(path_string(path), "db_path")?;
    let mut handle: *mut sqlite3 = ptr::null_mut();
    // SAFETY: `c_path` is a NUL-terminated string that outlives the call, and
    // `&mut handle` is a valid out-pointer.
    let rc = unsafe { sqlite3_open(c_path.as_ptr(), &mut handle) };

    if rc == SQLITE_OK && !handle.is_null() {
        let db = Self { handle };
        db.exec("PRAGMA busy_timeout = 5000")?;
        db.exec("PRAGMA foreign_keys = ON")?;
        return Ok(db);
    }

    let message = if handle.is_null() {
        "SQLite open failed".to_string()
    } else {
        // SAFETY: `handle` is the non-null pointer written by `sqlite3_open`.
        // The error text is copied out before the handle is closed.
        let text = unsafe { sqlite_error(handle) };
        unsafe {
            sqlite3_close(handle);
        }
        text
    };
    Err(AsrError::with_path(
        "sqlite_open_failed",
        message,
        path_string(path),
    ))
}
/// Reads `PRAGMA user_version`, returning 0 when the pragma yields no row.
///
/// # Errors
/// Propagates prepare/step failures.
fn user_version(&self) -> AsrResult<i64> {
    let mut pragma = self.prepare("PRAGMA user_version")?;
    let version = if let Step::Row = pragma.step()? {
        pragma.column_i64(0)
    } else {
        0
    };
    Ok(version)
}
/// Runs `sql` through `sqlite3_exec`, discarding any result rows.
///
/// # Errors
/// Returns `invalid_nul_byte` when `sql` contains a NUL byte, and
/// `sqlite_exec_failed` carrying the connection's error text when SQLite
/// reports anything other than `SQLITE_OK`.
fn exec(&self, sql: &str) -> AsrResult<()> {
    let c_sql = cstring(sql, "sql")?;
    // SAFETY: `self.handle` is the connection owned by this `Database` and
    // `c_sql` is NUL-terminated and outlives the call. No callback and no
    // error out-pointer are supplied.
    let rc = unsafe {
        sqlite3_exec(
            self.handle,
            c_sql.as_ptr(),
            None,
            ptr::null_mut(),
            ptr::null_mut(),
        )
    };
    if rc == SQLITE_OK {
        return Ok(());
    }
    // SAFETY: `self.handle` is the connection the failing call ran on.
    let detail = unsafe { sqlite_error(self.handle) };
    Err(AsrError::new("sqlite_exec_failed", detail))
}
/// Compiles `sql` into a [`Statement`] that borrows this connection.
///
/// # Errors
/// Returns `invalid_nul_byte` when `sql` contains a NUL byte, and
/// `sqlite_prepare_failed` when SQLite reports an error or produces no
/// statement.
fn prepare(&self, sql: &str) -> AsrResult<Statement<'_>> {
    let c_sql = cstring(sql, "sql")?;
    let mut raw_stmt: *mut sqlite3_stmt = ptr::null_mut();
    // SAFETY: `self.handle` is the connection owned by this `Database`,
    // `c_sql` is NUL-terminated (hence the -1 length), and `&mut raw_stmt` is
    // a valid out-pointer.
    let rc = unsafe {
        sqlite3_prepare_v2(
            self.handle,
            c_sql.as_ptr(),
            -1,
            &mut raw_stmt,
            ptr::null_mut(),
        )
    };
    if rc == SQLITE_OK && !raw_stmt.is_null() {
        return Ok(Statement {
            db: self,
            stmt: raw_stmt,
            bindings: Vec::new(),
        });
    }
    // SAFETY: `self.handle` is the connection the failing call ran on.
    let detail = unsafe { sqlite_error(self.handle) };
    Err(AsrError::new("sqlite_prepare_failed", detail))
}
}
impl Drop for Database {
    /// Closes the connection with `sqlite3_close_v2`, ignoring its result.
    /// A null handle is left alone.
    fn drop(&mut self) {
        if self.handle.is_null() {
            return;
        }
        // SAFETY: `self.handle` is non-null and owned by this value. Nothing
        // can use it after `drop` returns.
        unsafe {
            sqlite3_close_v2(self.handle);
        }
    }
}
/// A prepared statement tied to the lifetime of the [`Database`] that created it.
///
/// The statement is finalized in `Drop`.
struct Statement<'db> {
// Connection the statement belongs to; used to fetch error text.
db: &'db Database,
// Raw statement pointer produced by `sqlite3_prepare_v2`.
stmt: *mut sqlite3_stmt,
// Owned copies of bound text values. `bind_text` hands SQLite a pointer into
// these buffers with no destructor, so they are kept until `reset` clears
// them or the statement is dropped.
bindings: Vec<CString>,
}
impl Statement<'_> {
/// Binds `value` as text to the placeholder at `index`.
///
/// The value is copied into a `CString` that this statement keeps alive in
/// `self.bindings`.
///
/// # Errors
/// Returns `invalid_nul_byte` when `value` contains a NUL byte, and
/// `sqlite_bind_failed` when SQLite rejects the binding.
fn bind_text(&mut self, index: c_int, value: &str) -> AsrResult<()> {
    let owned = cstring(value, "bind_value")?;
    // The pointer targets the CString's heap buffer, which does not move when
    // the CString itself is moved into the vector.
    let text_ptr = owned.as_ptr();
    self.bindings.push(owned);
    // SAFETY: `text_ptr` is NUL-terminated (hence the -1 length) and stays
    // valid while the CString lives in `self.bindings`. No destructor is
    // passed, so keeping the buffer alive is this statement's job.
    let rc = unsafe { sqlite3_bind_text(self.stmt, index, text_ptr, -1, None) };
    if rc == SQLITE_OK {
        return Ok(());
    }
    // SAFETY: `self.db.handle` is the connection this statement belongs to.
    let detail = unsafe { sqlite_error(self.db.handle) };
    Err(AsrError::new("sqlite_bind_failed", detail))
}
/// Binds `value` to the placeholder at `index`: text for `Some`, SQL NULL for
/// `None`.
///
/// # Errors
/// Same as `bind_text` for `Some`; returns `sqlite_bind_failed` when binding
/// NULL is rejected.
fn bind_optional_text(&mut self, index: c_int, value: Option<&str>) -> AsrResult<()> {
    if let Some(text) = value {
        return self.bind_text(index, text);
    }
    // SAFETY: `self.stmt` is the statement owned by this value.
    let rc = unsafe { sqlite3_bind_null(self.stmt, index) };
    if rc == SQLITE_OK {
        Ok(())
    } else {
        // SAFETY: `self.db.handle` is the connection this statement belongs to.
        let detail = unsafe { sqlite_error(self.db.handle) };
        Err(AsrError::new("sqlite_bind_failed", detail))
    }
}
fn step(&mut self) -> AsrResult<Step> {
match unsafe { sqlite3_step(self.stmt) } {
SQLITE_ROW => Ok(Step::Row),
SQLITE_DONE => Ok(Step::Done),
_ => Err(AsrError::new("sqlite_step_failed", unsafe {
sqlite_error(self.db.handle)
})),
}
}
/// Resets the statement and clears its bindings so it can be run again.
///
/// The owned text buffers are released only after SQLite has cleared its
/// bindings.
///
/// # Errors
/// Returns `sqlite_reset_failed` or `sqlite_clear_bindings_failed` when the
/// corresponding call fails. In either case the owned buffers are kept.
fn reset(&mut self) -> AsrResult<()> {
    // SAFETY: `self.stmt` is the statement owned by this value.
    if unsafe { sqlite3_reset(self.stmt) } != SQLITE_OK {
        // SAFETY: `self.db.handle` is the connection this statement belongs to.
        let detail = unsafe { sqlite_error(self.db.handle) };
        return Err(AsrError::new("sqlite_reset_failed", detail));
    }
    // SAFETY: `self.stmt` is the statement owned by this value.
    if unsafe { sqlite3_clear_bindings(self.stmt) } != SQLITE_OK {
        // SAFETY: as above.
        let detail = unsafe { sqlite_error(self.db.handle) };
        return Err(AsrError::new("sqlite_clear_bindings_failed", detail));
    }
    // SQLite no longer references the buffers, so they can be freed.
    self.bindings.clear();
    Ok(())
}
/// Reads the column at `index` of the current row as a 64-bit integer.
fn column_i64(&self, index: c_int) -> i64 {
    // SAFETY: `self.stmt` is the statement owned by this value.
    let value = unsafe { sqlite3_column_int64(self.stmt, index) };
    value
}
/// Reads the column at `index` of the current row as an owned string.
///
/// Returns `None` when SQLite hands back a null pointer or a negative byte
/// count. Invalid UTF-8 is replaced lossily.
fn column_text(&self, index: c_int) -> Option<String> {
    // SAFETY: `self.stmt` is the statement owned by this value.
    let text_ptr = unsafe { sqlite3_column_text(self.stmt, index) };
    if text_ptr.is_null() {
        return None;
    }
    // SAFETY: same statement and column as the pointer fetched above.
    let byte_len = unsafe { sqlite3_column_bytes(self.stmt, index) };
    if byte_len >= 0 {
        // SAFETY: `text_ptr` is non-null and `byte_len` is the non-negative
        // length reported for the same column. The bytes are copied into an
        // owned `String` before this method returns.
        let raw = unsafe { slice::from_raw_parts(text_ptr, byte_len as usize) };
        Some(String::from_utf8_lossy(raw).into_owned())
    } else {
        None
    }
}
}
impl Drop for Statement<'_> {
    /// Finalizes the prepared statement, ignoring the result. A null
    /// statement pointer is left alone.
    fn drop(&mut self) {
        if self.stmt.is_null() {
            return;
        }
        // SAFETY: `self.stmt` is non-null and owned by this value. The bound
        // text buffers in `self.bindings` are dropped only after this returns.
        unsafe {
            sqlite3_finalize(self.stmt);
        }
    }
}
/// Successful outcome of `Statement::step`.
enum Step {
// `sqlite3_step` returned `SQLITE_ROW`: a result row can be read.
Row,
// `sqlite3_step` returned `SQLITE_DONE`: the statement has finished.
Done,
}
/// Converts `value` into a NUL-terminated `CString` for the FFI calls.
///
/// `field` names the value in the error message.
///
/// # Errors
/// Returns `invalid_nul_byte` when `value` contains an interior NUL byte.
fn cstring(value: impl AsRef<str>, field: &'static str) -> AsrResult<CString> {
    match CString::new(value.as_ref()) {
        Ok(converted) => Ok(converted),
        Err(_) => Err(AsrError::new(
            "invalid_nul_byte",
            format!("Value for {field} contains an unsupported NUL byte"),
        )),
    }
}
/// Copies the current error text of the connection `handle` into a `String`.
///
/// Returns "unknown SQLite error" when SQLite provides no message. Invalid
/// UTF-8 is replaced lossily.
///
/// # Safety
/// `handle` is passed straight to `sqlite3_errmsg`, so it must be a
/// connection pointer that call accepts.
unsafe fn sqlite_error(handle: *mut sqlite3) -> String {
    let raw_message = sqlite3_errmsg(handle);
    if !raw_message.is_null() {
        CStr::from_ptr(raw_message).to_string_lossy().into_owned()
    } else {
        "unknown SQLite error".to_string()
    }
}