#![allow(unsafe_code)]
use std::fs;
use std::path::PathBuf;
use anyhow::{Context, Result};
use rusqlite::ffi::sqlite3_auto_extension;
use rusqlite::Connection;
use tracing::{debug, info};
use zerocopy::AsBytes;
use crate::setup::ManpageEntry;
/// File name of the SQLite database; `get_database_path` joins it onto the data directory.
const DB_FILENAME: &str = "index.db";
/// Resolves the full path of the index database file.
///
/// The data directory is looked up through `directories::ProjectDirs` for the
/// application name "ulm" and is created (including parents) if it is missing.
///
/// # Errors
///
/// Returns an error when no project directories can be determined or when the
/// data directory cannot be created.
pub fn get_database_path() -> Result<PathBuf> {
    let project_dirs = directories::ProjectDirs::from("", "", "ulm")
        .context("Could not determine data directory")?;
    let base = project_dirs.data_dir();
    let database_file = base.join(DB_FILENAME);

    // The directory must exist before SQLite can create the file inside it.
    fs::create_dir_all(base)
        .with_context(|| format!("Failed to create data directory: {}", base.display()))?;

    Ok(database_file)
}
/// Registers `sqlite_vec::sqlite3_vec_init` through `sqlite3_auto_extension`.
///
/// The registration is guarded by a function-local `Once`, so calling this function
/// repeatedly performs the registration a single time per process.
///
/// NOTE(review): presumably this makes the `vec0` module available on connections
/// opened afterwards — confirm against the SQLite auto-extension documentation.
fn init_sqlite_vec() {
use std::sync::Once;
static INIT: Once = Once::new();
INIT.call_once(|| {
// The transmute target type is inferred from the parameter type of
// `sqlite3_auto_extension`, hence the lint suppression below.
//
// SAFETY (NOTE(review)): this assumes `sqlite3_vec_init` has a signature that is
// ABI-compatible with the entry-point type `sqlite3_auto_extension` expects —
// verify against the sqlite-vec and rusqlite FFI declarations.
#[allow(clippy::missing_transmute_annotations)]
unsafe {
sqlite3_auto_extension(Some(std::mem::transmute(
sqlite_vec::sqlite3_vec_init as *const (),
)));
}
});
}
fn open_connection(path: &PathBuf) -> Result<Connection> {
init_sqlite_vec();
let conn = Connection::open(path)
.with_context(|| format!("Failed to open database: {}", path.display()))?;
Ok(conn)
}
#[allow(clippy::unused_async)] pub async fn create_index(entries: Vec<ManpageEntry>) -> Result<()> {
let db_path = get_database_path()?;
info!(path = %db_path.display(), entries = entries.len(), "Creating vector index");
let conn = open_connection(&db_path)?;
let vector_dim = entries.first().map_or(768, |e| e.vector.len());
conn.execute("DROP TABLE IF EXISTS manpages_vec", [])
.context("Failed to drop vector table")?;
conn.execute("DROP TABLE IF EXISTS manpages", [])
.context("Failed to drop manpages table")?;
conn.execute(
"CREATE TABLE manpages (
id INTEGER PRIMARY KEY,
tool_name TEXT NOT NULL,
section TEXT NOT NULL,
description TEXT NOT NULL
)",
[],
)
.context("Failed to create manpages table")?;
conn.execute(
&format!(
"CREATE VIRTUAL TABLE manpages_vec USING vec0(
id INTEGER PRIMARY KEY,
embedding FLOAT[{vector_dim}]
)"
),
[],
)
.context("Failed to create vector table")?;
let mut stmt = conn
.prepare("INSERT INTO manpages (tool_name, section, description) VALUES (?1, ?2, ?3)")
.context("Failed to prepare insert statement")?;
let mut vec_stmt = conn
.prepare("INSERT INTO manpages_vec (id, embedding) VALUES (?1, ?2)")
.context("Failed to prepare vector insert statement")?;
for entry in &entries {
stmt.execute(rusqlite::params![
entry.tool_name,
entry.section,
entry.description
])
.context("Failed to insert manpage")?;
let id = conn.last_insert_rowid();
let vector_blob = entry.vector.as_bytes();
vec_stmt
.execute(rusqlite::params![id, vector_blob])
.context("Failed to insert vector")?;
}
info!(
"Created index with {} entries (dimension: {})",
entries.len(),
vector_dim
);
Ok(())
}
/// Reports whether an index has already been built.
///
/// Yields `Ok(false)` without opening the database when the file does not exist;
/// otherwise counts the `sqlite_master` entries for a table named `manpages`.
///
/// # Errors
///
/// Returns an error when the database path cannot be resolved, the database cannot
/// be opened, or the lookup query fails.
#[allow(clippy::unused_async)]
pub async fn index_exists() -> Result<bool> {
    let database_file = get_database_path()?;

    if database_file.exists() {
        let conn = open_connection(&database_file)?;
        let matching_tables = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='manpages'",
                [],
                |row| row.get::<_, i64>(0),
            )
            .context("Failed to check table existence")?;
        Ok(matching_tables >= 1)
    } else {
        Ok(false)
    }
}
/// One row produced by `search`: manpage metadata plus its vector distance.
#[derive(Debug, Clone)]
pub struct SearchResult {
/// `tool_name` column of the matching `manpages` row.
pub tool_name: String,
/// `section` column of the matching `manpages` row.
pub section: String,
/// `description` column of the matching `manpages` row.
pub description: String,
/// The `distance` value reported by `manpages_vec` for this row. `search` orders
/// its results by this value ascending, so smaller values come first.
pub score: f32,
}
/// Runs a vector match against `manpages_vec` and returns the joined manpage rows.
///
/// `query_vector` is bound as a raw byte blob to the `MATCH` operand and `limit` is
/// bound to `k`; rows come back ordered by `distance`.
///
/// # Errors
///
/// Returns an error when the database cannot be located or opened, when the query
/// cannot be prepared or executed, or when a result row cannot be read.
#[allow(clippy::unused_async)]
pub async fn search(query_vector: &[f32], limit: usize) -> Result<Vec<SearchResult>> {
    let database_file = get_database_path()?;
    let conn = open_connection(&database_file)?;

    let mut knn = conn
        .prepare(
            "SELECT
m.tool_name,
m.section,
m.description,
v.distance
FROM manpages_vec v
JOIN manpages m ON m.id = v.id
WHERE v.embedding MATCH ?1 AND k = ?2
ORDER BY v.distance",
        )
        .context("Failed to prepare search query")?;

    let needle = query_vector.as_bytes();
    let hits = knn
        .query_map(rusqlite::params![needle, limit], |row| {
            Ok(SearchResult {
                tool_name: row.get(0)?,
                section: row.get(1)?,
                description: row.get(2)?,
                score: row.get(3)?,
            })
        })
        .context("Failed to execute search")?
        // Stops at the first unreadable row, exactly like an explicit push loop with `?`.
        .map(|row| row.context("Failed to read search result"))
        .collect::<Result<Vec<_>>>()?;

    debug!(count = hits.len(), "Vector search completed");
    Ok(hits)
}
/// Counts the rows stored in the `manpages` table.
///
/// The SQLite count arrives as an `i64` and is converted with a checked
/// `usize::try_from`, so a value that does not fit is reported as an error rather
/// than being silently wrapped or truncated.
///
/// # Errors
///
/// Returns an error when the database cannot be located or opened, when the count
/// query fails, or when the count cannot be represented as `usize`.
#[allow(clippy::unused_async)] // contains no `.await`; stays `async` for existing callers
pub async fn count_entries() -> Result<usize> {
    let db_path = get_database_path()?;
    let conn = open_connection(&db_path)?;
    let count: i64 = conn
        .query_row("SELECT COUNT(*) FROM manpages", [], |row| row.get(0))
        .context("Failed to count entries")?;
    usize::try_from(count)
        .with_context(|| format!("Entry count {count} does not fit in usize"))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `count` synthetic entries named `tool0`, `tool1`, … whose 8-element
    /// vectors are filled with `0.1 * index`.
    #[allow(dead_code)]
    fn create_test_entries(count: usize) -> Vec<ManpageEntry> {
        let mut entries = Vec::with_capacity(count);
        for i in 0..count {
            #[allow(clippy::cast_precision_loss)]
            let component = 0.1 * i as f32;
            entries.push(ManpageEntry {
                tool_name: format!("tool{i}"),
                section: "1".to_string(),
                description: format!("Description for tool {i}"),
                vector: vec![component; 8],
            });
        }
        entries
    }

    /// The resolved path mentions the application name and ends in the database file name.
    #[test]
    fn test_get_database_path() {
        let path = get_database_path().unwrap();
        let rendered = path.to_string_lossy();
        assert!(rendered.contains("ulm"));
        assert!(rendered.ends_with("index.db"));
    }
}