use libsql::params;
use crate::db::Database;
use crate::errors::Result;
use crate::types::Node;
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if mag_a == 0.0 || mag_b == 0.0 {
return 0.0;
}
dot / (mag_a * mag_b)
}
pub async fn store_vector(
db: &Database,
node_id: &str,
embedding: &[f32],
model: &str,
) -> Result<()> {
let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
db.conn()
.execute(
"INSERT OR REPLACE INTO vectors (node_id, embedding, model, created_at) VALUES (?1, ?2, ?3, ?4)",
params![node_id, bytes, model, now],
)
.await?;
Ok(())
}
pub async fn get_vector(db: &Database, node_id: &str) -> Result<Option<Vec<f32>>> {
let mut rows = db
.conn()
.query(
"SELECT embedding FROM vectors WHERE node_id = ?1",
params![node_id],
)
.await?;
match rows.next().await? {
Some(row) => {
let bytes: Vec<u8> = row.get(0)?;
let floats = bytes_to_f32s(&bytes);
Ok(Some(floats))
}
None => Ok(None),
}
}
pub async fn brute_force_search(
db: &Database,
query: &[f32],
limit: usize,
) -> Result<Vec<(String, f32)>> {
let mut rows = db
.conn()
.query("SELECT node_id, embedding FROM vectors", ())
.await?;
let mut scored: Vec<(String, f32)> = Vec::new();
while let Some(row) = rows.next().await? {
let node_id: String = row.get(0)?;
let bytes: Vec<u8> = row.get(1)?;
let embedding = bytes_to_f32s(&bytes);
let score = cosine_similarity(query, &embedding);
scored.push((node_id, score));
}
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit);
Ok(scored)
}
pub async fn vector_count(db: &Database) -> Result<usize> {
let mut rows = db
.conn()
.query("SELECT COUNT(*) FROM vectors", ())
.await?;
let row = rows
.next()
.await?
.ok_or_else(|| crate::errors::TokenSaveError::Vector {
message: "COUNT query returned no rows".to_string(),
})?;
let count: i64 = row.get(0)?;
Ok(count as usize)
}
pub async fn delete_vector(db: &Database, node_id: &str) -> Result<()> {
db.conn()
.execute(
"DELETE FROM vectors WHERE node_id = ?1",
params![node_id],
)
.await?;
Ok(())
}
pub async fn clear_vectors(db: &Database) -> Result<()> {
db.conn().execute("DELETE FROM vectors", ()).await?;
Ok(())
}
pub fn create_node_text(node: &Node) -> String {
let mut parts = Vec::new();
parts.push(format!("kind: {}", node.kind.as_str()));
parts.push(format!("name: {}", node.name));
parts.push(format!("qualified_name: {}", node.qualified_name));
parts.push(format!("file: {}", node.file_path));
if let Some(ref sig) = node.signature {
parts.push(format!("signature: {sig}"));
}
if let Some(ref doc) = node.docstring {
parts.push(format!("docstring: {doc}"));
}
parts.join("\n")
}
fn bytes_to_f32s(bytes: &[u8]) -> Vec<f32> {
bytes
.chunks_exact(4)
.map(|chunk| {
let arr: [u8; 4] = [chunk[0], chunk[1], chunk[2], chunk[3]];
f32::from_le_bytes(arr)
})
.collect()
}