use crate::models::CodeChunk;
use anyhow::{Context, Result};
use rusqlite::{ffi::sqlite3_auto_extension, params, Connection};
use sqlite_vec::sqlite3_vec_init;
use std::path::PathBuf;
/// Handle to the semantic index database; owns the underlying SQLite connection.
pub struct Database {
// Connection used by the methods of this type to run their SQL.
conn: Connection,
}
impl Database {
/// Opens the database without an explicit embedding dimension.
///
/// Delegates to `Self::init_with_dimension(None)`, which derives the dimension from
/// the embedding configuration.
pub fn init() -> Result<Self> {
Self::init_with_dimension(None)
}
/// Opens (creating if necessary) the index database at `.git/semantic.db` and makes
/// sure every table used by this type exists.
///
/// `embedding_dim` is the width of the `vec0` embedding columns used when the vector
/// tables are created. When it is `None` the width comes from the embedding
/// configuration: 1536 for the OpenAI provider, `config.gemma.embedding_dim` for
/// Gemma. Vector tables that already exist are left as they are, so the dimension
/// only takes effect the first time they are created.
///
/// # Errors
///
/// Returns an error if the embedding configuration cannot be loaded, the database
/// file cannot be opened, or any schema statement fails.
pub fn init_with_dimension(embedding_dim: Option<usize>) -> Result<Self> {
    use crate::embed::EmbeddingConfig;

    let config = EmbeddingConfig::load_or_default()?;
    let dim = match embedding_dim {
        Some(dim) => dim,
        None => match config.provider {
            crate::embed::EmbeddingProviderType::OpenAI => 1536,
            crate::embed::EmbeddingProviderType::Gemma => config.gemma.embedding_dim,
        },
    };

    // SAFETY: `sqlite3_vec_init` is the extension entry point provided by the
    // `sqlite_vec` crate; the transmute only reinterprets its address as the
    // entry-point function type that `sqlite3_auto_extension` accepts. The signature
    // is spelled with `c_char`/`c_int` rather than `i8`/`i32` so it also type-checks
    // on targets where C `char` is unsigned.
    unsafe {
        sqlite3_auto_extension(Some(std::mem::transmute::<
            *const (),
            unsafe extern "C" fn(
                *mut rusqlite::ffi::sqlite3,
                *mut *mut std::os::raw::c_char,
                *const rusqlite::ffi::sqlite3_api_routines,
            ) -> std::os::raw::c_int,
        >(sqlite3_vec_init as *const ())));
    }

    let db_path = PathBuf::from(".git/semantic.db");
    let conn = Connection::open(&db_path).context("Failed to open database connection")?;

    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS code_chunks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
content TEXT NOT NULL,
embedding BLOB
);",
    )
    .context("Failed to create code_chunks table")?;

    // `IF NOT EXISTS` replaces a hand-rolled `sqlite_master` probe whose query errors
    // were silently treated as "table missing".
    let create_vec_table = format!(
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_chunks USING vec0(embedding FLOAT[{}]);",
        dim
    );
    conn.execute_batch(&create_vec_table)
        .context("Failed to create vec_chunks virtual table")?;

    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS vec_metadata (
id INTEGER PRIMARY KEY AUTOINCREMENT,
chunk_id INTEGER NOT NULL,
file_path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
content TEXT NOT NULL
);",
    )
    .context("Failed to create vec_metadata table")?;

    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS clusters (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
description TEXT NOT NULL,
chunks_json TEXT NOT NULL
);",
    )
    .context("Failed to create clusters table")?;

    let create_cluster_vec = format!(
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_clusters USING vec0(embedding FLOAT[{}]);",
        dim
    );
    conn.execute_batch(&create_cluster_vec)
        .context("Failed to create vec_clusters virtual table")?;

    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS edges (
id INTEGER PRIMARY KEY AUTOINCREMENT,
from_file TEXT NOT NULL,
to_file TEXT NOT NULL,
via_json TEXT NOT NULL
);",
    )
    .context("Failed to create edges table")?;

    // Full-text index whose content lives in `code_chunks` (external-content table).
    conn.execute_batch(
        "CREATE VIRTUAL TABLE IF NOT EXISTS fts_chunks
USING fts5(file_path UNINDEXED, start_line UNINDEXED, end_line UNINDEXED, content, content=code_chunks, content_rowid=id);",
    )
    .context("Failed to create fts_chunks table")?;

    Ok(Database { conn })
}
/// Deletes every row from all tables managed by this type, leaving the schema in place.
///
/// The deletes are sent as one `execute_batch` call; the batch text contains no
/// explicit `BEGIN`/`COMMIT`.
///
/// # Errors
///
/// Returns an error with the context "Failed to clear database" if the batch fails.
pub fn clear(&self) -> Result<()> {
self.conn
.execute_batch(
// `fts_chunks` is declared with `content=code_chunks` and is emptied before
// `code_chunks` below.
// NOTE(review): presumably the order matters so FTS5 can still read the source
// rows while removing index entries; confirm before reordering.
"DELETE FROM fts_chunks;
DELETE FROM vec_metadata;
DELETE FROM vec_chunks;
DELETE FROM code_chunks;
DELETE FROM clusters;
DELETE FROM vec_clusters;
DELETE FROM edges;",
)
.context("Failed to clear database")
}
/// Stores a cluster together with the embedding of its description.
///
/// The name, description and JSON-serialized chunk list go into `clusters`; the
/// `description_embedding` bytes are written to `vec_clusters` under the same rowid so
/// vector matches can be joined back to the cluster row.
///
/// The two inserts are issued one after the other; no transaction is opened here.
///
/// # Errors
///
/// Returns an error if the chunk list cannot be serialized or either insert fails.
pub fn insert_cluster(&self, cluster: &crate::map::Cluster) -> Result<()> {
    use zerocopy::IntoBytes;

    let chunks_json =
        serde_json::to_string(&cluster.chunks).context("Failed to serialize cluster chunks")?;

    self.conn
        .execute(
            "INSERT INTO clusters (name, description, chunks_json) VALUES (?1, ?2, ?3)",
            params![&cluster.name, &cluster.description, &chunks_json],
        )
        .context("Failed to insert cluster")?;

    // The vector row reuses the id just assigned to the `clusters` row.
    let cluster_id = self.conn.last_insert_rowid();
    self.conn
        .execute(
            "INSERT INTO vec_clusters (rowid, embedding) VALUES (?1, ?2)",
            params![cluster_id, cluster.description_embedding.as_bytes()],
        )
        .context("Failed to insert into vec_clusters")?;

    Ok(())
}
/// Stores one edge between two files.
///
/// `edge.via` is serialized to JSON and written to the `via_json` column alongside
/// `edge.from` and `edge.to`.
///
/// # Errors
///
/// Returns an error if `via` cannot be serialized or the insert fails.
pub fn insert_edge(&self, edge: &crate::map::Edge) -> Result<()> {
    let via_json = serde_json::to_string(&edge.via).context("Failed to serialize edge via")?;
    self.conn
        .execute(
            "INSERT INTO edges (from_file, to_file, via_json) VALUES (?1, ?2, ?3)",
            params![&edge.from, &edge.to, &via_json],
        )
        .context("Failed to insert edge")?;
    Ok(())
}
/// Looks up the single cluster whose stored embedding best matches `query_embedding`.
///
/// Runs a `k = 1` match on `vec_clusters` joined to `clusters`. Returns `Ok(None)`
/// when the query yields no row. The returned cluster carries an empty
/// `description_embedding`: only name, description and chunk list are read back (the
/// selected `distance` column is not used).
///
/// # Errors
///
/// Returns an error if the query fails or the stored chunk JSON cannot be parsed.
pub fn query_map(&self, query_embedding: &[f32]) -> Result<Option<crate::map::Cluster>> {
    use zerocopy::IntoBytes;

    let mut stmt = self.conn.prepare(
        "SELECT s.name, s.description, s.chunks_json, v.distance
FROM vec_clusters v
JOIN clusters s ON v.rowid = s.id
WHERE v.embedding MATCH ?1 AND k = 1
ORDER BY distance",
    )?;
    let mut rows = stmt.query(params![query_embedding.as_bytes()])?;

    let row = match rows.next()? {
        Some(row) => row,
        None => return Ok(None),
    };
    let name: String = row.get(0)?;
    let description: String = row.get(1)?;
    let chunks_json: String = row.get(2)?;

    let chunks: Vec<crate::map::ChunkRef> = serde_json::from_str(&chunks_json)
        .map_err(|e| anyhow::anyhow!("Failed to parse chunks: {}", e))?;

    Ok(Some(crate::map::Cluster {
        name,
        description,
        description_embedding: vec![],
        chunks,
    }))
}
/// Returns every stored cluster in ascending `id` order.
///
/// Each cluster is rebuilt from its name, description and JSON chunk list;
/// `description_embedding` is left empty.
///
/// # Errors
///
/// Returns an error if the query fails or any row's chunk JSON cannot be parsed.
pub fn all_clusters(&self) -> Result<Vec<crate::map::Cluster>> {
    let mut stmt = self
        .conn
        .prepare("SELECT name, description, chunks_json FROM clusters ORDER BY id")?;
    let mut rows = stmt.query([])?;

    let mut clusters = Vec::new();
    while let Some(row) = rows.next()? {
        let name: String = row.get(0)?;
        let description: String = row.get(1)?;
        let chunks_json: String = row.get(2)?;
        let chunks: Vec<crate::map::ChunkRef> = serde_json::from_str(&chunks_json)
            .map_err(|e| anyhow::anyhow!("Failed to parse chunks: {}", e))?;
        clusters.push(crate::map::Cluster {
            name,
            description,
            description_embedding: vec![],
            chunks,
        });
    }
    Ok(clusters)
}
/// Returns the edges that point at one of `cluster_files` from a file outside that set.
///
/// An empty `cluster_files` yields an empty result without touching the database.
/// The file list is bound twice: once for the `IN` list (`?1..?n`) and once for the
/// `NOT IN` list (`?n+1..?2n`).
///
/// # Errors
///
/// Returns an error if the query fails or a row's `via_json` cannot be parsed.
pub fn edges_into(&self, cluster_files: &[&str]) -> Result<Vec<crate::map::Edge>> {
    if cluster_files.is_empty() {
        return Ok(vec![]);
    }

    let n = cluster_files.len();
    // Renders "?a, ?a+1, ..., ?b" for the given inclusive index range.
    let numbered = |range: std::ops::RangeInclusive<usize>| -> String {
        range
            .map(|i| format!("?{}", i))
            .collect::<Vec<_>>()
            .join(", ")
    };
    let sql = format!(
        "SELECT from_file, to_file, via_json FROM edges
WHERE to_file IN ({}) AND from_file NOT IN ({})",
        numbered(1..=n),
        numbered(n + 1..=2 * n)
    );
    let mut stmt = self.conn.prepare(&sql)?;

    let mut bound: Vec<&dyn rusqlite::ToSql> = Vec::with_capacity(2 * n);
    bound.extend(cluster_files.iter().map(|f| f as &dyn rusqlite::ToSql));
    bound.extend(cluster_files.iter().map(|f| f as &dyn rusqlite::ToSql));

    let mut rows = stmt.query(bound.as_slice())?;
    let mut edges = Vec::new();
    while let Some(row) = rows.next()? {
        let from: String = row.get(0)?;
        let to: String = row.get(1)?;
        let via_json: String = row.get(2)?;
        let via: Vec<String> = serde_json::from_str(&via_json)
            .map_err(|e| anyhow::anyhow!("Failed to parse via: {}", e))?;
        edges.push(crate::map::Edge { from, to, via });
    }
    Ok(edges)
}
/// Writes one chunk to all four chunk tables.
///
/// In order: the chunk and its bincode-serialized embedding go into `code_chunks`;
/// the raw embedding bytes go into `vec_chunks` under the new rowid; the location and
/// content are copied into `vec_metadata` (as `chunk_id`) and into `fts_chunks` (as
/// `rowid`). The inserts are issued one after another; no transaction is opened here.
///
/// # Errors
///
/// Returns an error if the embedding cannot be serialized or any insert fails.
pub fn insert_chunk(&self, chunk: &CodeChunk) -> Result<()> {
    use zerocopy::IntoBytes;

    let serialized =
        bincode::serialize(&chunk.embedding).context("Failed to serialize embedding")?;

    self.conn
        .execute(
            "INSERT INTO code_chunks (file_path, start_line, end_line, content, embedding)
VALUES (?1, ?2, ?3, ?4, ?5)",
            params![
                &chunk.file_path,
                chunk.start_line,
                chunk.end_line,
                &chunk.content,
                &serialized
            ],
        )
        .context("Failed to insert chunk into database")?;

    // Every dependent row is keyed by the id just assigned in `code_chunks`.
    let row_id = self.conn.last_insert_rowid();

    self.conn
        .execute(
            "INSERT INTO vec_chunks (rowid, embedding) VALUES (?1, ?2)",
            params![row_id, chunk.embedding.as_bytes()],
        )
        .context("Failed to insert into vec_chunks")?;

    // The metadata and full-text inserts take the same five values.
    let location_inserts = [
        (
            "INSERT INTO vec_metadata (chunk_id, file_path, start_line, end_line, content)
VALUES (?1, ?2, ?3, ?4, ?5)",
            "Failed to insert metadata",
        ),
        (
            "INSERT INTO fts_chunks (rowid, file_path, start_line, end_line, content)
VALUES (?1, ?2, ?3, ?4, ?5)",
            "Failed to insert into fts_chunks",
        ),
    ];
    for (sql, failure) in location_inserts {
        self.conn
            .execute(
                sql,
                params![
                    row_id,
                    &chunk.file_path,
                    chunk.start_line,
                    chunk.end_line,
                    &chunk.content,
                ],
            )
            .context(failure)?;
    }

    Ok(())
}
pub fn get_chunk_by_location(
&self,
file_path: &str,
start_line: i64,
end_line: i64,
) -> Result<Option<CodeChunk>> {
let mut stmt = self.conn.prepare(
"SELECT file_path, start_line, end_line, content, embedding
FROM code_chunks
WHERE file_path = ?1 AND start_line = ?2 AND end_line = ?3
LIMIT 1",
)?;
let mut rows = stmt.query_map(params![file_path, start_line, end_line], |row| {
let embedding_blob: Vec<u8> = row.get(4)?;
let embedding: Vec<f32> = bincode::deserialize(&embedding_blob)
.map_err(|_e| rusqlite::Error::InvalidQuery)?;
Ok(CodeChunk {
file_path: row.get(0)?,
start_line: row.get(1)?,
end_line: row.get(2)?,
content: row.get(3)?,
embedding,
distance: None,
})
})?;
if let Some(chunk) = rows.next().transpose()? {
return Ok(Some(chunk));
}
self.get_chunks_overlapping(file_path, start_line, end_line)
}
/// Merges every stored chunk of `file_path` that overlaps the given line range into
/// one synthetic chunk.
///
/// A chunk counts as overlapping when its `start_line` is strictly less than the
/// `end_line` argument and its `end_line` is strictly greater than the `start_line`
/// argument. Returns `Ok(None)` when nothing overlaps.
///
/// The merged chunk starts at the smallest `start_line`, ends at the largest
/// `end_line`, concatenates the contents in `start_line` order separated by `\n`, and
/// carries the embedding of the first chunk only. Its `distance` is `None`.
///
/// # Errors
///
/// Returns an error if the query fails. An embedding blob that cannot be
/// deserialized is reported as `rusqlite::Error::InvalidQuery`.
fn get_chunks_overlapping(
    &self,
    file_path: &str,
    start_line: i64,
    end_line: i64,
) -> Result<Option<CodeChunk>> {
    let mut stmt = self.conn.prepare(
        "SELECT file_path, start_line, end_line, content, embedding
FROM code_chunks
WHERE file_path = ?1
AND start_line < ?3
AND end_line > ?2
ORDER BY start_line",
    )?;
    let chunks: Vec<CodeChunk> = stmt
        .query_map(params![file_path, start_line, end_line], |row| {
            let embedding_blob: Vec<u8> = row.get(4)?;
            let embedding: Vec<f32> = bincode::deserialize(&embedding_blob)
                .map_err(|_e| rusqlite::Error::InvalidQuery)?;
            Ok(CodeChunk {
                file_path: row.get(0)?,
                start_line: row.get(1)?,
                end_line: row.get(2)?,
                content: row.get(3)?,
                embedding,
                distance: None,
            })
        })?
        .collect::<Result<Vec<_>, _>>()?;

    let mut remaining = chunks.into_iter();
    let first = match remaining.next() {
        Some(chunk) => chunk,
        None => return Ok(None),
    };

    // Rows are ordered by `start_line`, so the first row has the smallest start.
    let merged_start = first.start_line;
    let mut merged_end = first.end_line;
    let mut merged_content = first.content;
    for chunk in remaining {
        // The last row does not necessarily reach furthest (a chunk may be nested in
        // an earlier, longer one), so track the maximum instead of taking the last.
        if chunk.end_line > merged_end {
            merged_end = chunk.end_line;
        }
        merged_content.push('\n');
        merged_content.push_str(&chunk.content);
    }

    Ok(Some(CodeChunk {
        file_path: file_path.to_string(),
        start_line: merged_start,
        end_line: merged_end,
        content: merged_content,
        embedding: first.embedding,
        distance: None,
    }))
}
pub fn search_bm25(&self, query: &str, limit: i64) -> Result<Vec<CodeChunk>> {
let escaped = query.replace('"', "\"\"");
let fts_query = format!("\"{}\"", escaped);
let mut stmt = self.conn.prepare(
"SELECT c.file_path, c.start_line, c.end_line, c.content, c.embedding,
bm25(fts_chunks) AS score
FROM fts_chunks
JOIN code_chunks c ON fts_chunks.rowid = c.id
WHERE fts_chunks MATCH ?1
ORDER BY score
LIMIT ?2",
)?;
let chunks = stmt
.query_map(params![fts_query, limit], |row| {
let embedding_blob: Vec<u8> = row.get(4)?;
let embedding: Vec<f32> = bincode::deserialize(&embedding_blob)
.map_err(|_e| rusqlite::Error::InvalidQuery)?;
Ok(CodeChunk {
file_path: row.get(0)?,
start_line: row.get(1)?,
end_line: row.get(2)?,
content: row.get(3)?,
embedding,
distance: row.get::<_, Option<f64>>(5)?.map(|s| s as f32),
})
})?
.collect::<Result<Vec<_>, _>>()?;
Ok(chunks)
}
/// Combines vector search, full-text search and the file graph into one ranked list.
///
/// Each source contributes `1 / (60 + rank + 1)` to a chunk's score, where `rank` is
/// the chunk's zero-based position in that source:
///
/// 1. the `search_similar` results for `query_embedding`;
/// 2. the `search_bm25` results for `query` (a failure here is treated as "no hits");
/// 3. the merged candidates whose file is reported by `connected_files` for the files
///    of the top five vector hits (a failure here is treated as "no connections").
///
/// Candidates are de-duplicated by `file_path:start_line-end_line`, sorted by
/// descending score and cut to `limit`. On return each chunk's `distance` holds its
/// combined score, not a vector distance.
///
/// # Errors
///
/// Returns an error only if the vector search fails.
pub fn search_hybrid(
    &self,
    query: &str,
    query_embedding: &[f32],
    limit: i64,
) -> Result<Vec<CodeChunk>> {
    use std::collections::{HashMap, HashSet};

    // Smoothing constant of the reciprocal-rank formula above.
    const RRF_K: f32 = 60.0;

    let semantic = self.search_similar(query_embedding, limit)?;
    let bm25 = self.search_bm25(query, limit).unwrap_or_default();

    let key = |c: &CodeChunk| format!("{}:{}-{}", c.file_path, c.start_line, c.end_line);
    let rrf = |rank: usize| 1.0 / (RRF_K + rank as f32 + 1.0);

    let mut scores: HashMap<String, f32> = HashMap::new();
    for (rank, chunk) in semantic.iter().enumerate() {
        *scores.entry(key(chunk)).or_insert(0.0) += rrf(rank);
    }
    for (rank, chunk) in bm25.iter().enumerate() {
        *scores.entry(key(chunk)).or_insert(0.0) += rrf(rank);
    }

    // Files of the five best vector hits seed the graph lookup.
    let top_files: HashSet<&str> = semantic
        .iter()
        .take(5)
        .map(|c| c.file_path.as_str())
        .collect();
    let top_file_refs: Vec<&str> = top_files.into_iter().collect();
    let connected = self.connected_files(&top_file_refs).unwrap_or_default();

    // Keep the first occurrence of every chunk. A chunk found by both searches sits
    // in two different positions of the chained list, so an adjacent-only dedup
    // (`Vec::dedup_by`) would leave it in twice.
    let mut seen: HashSet<String> = HashSet::new();
    let mut all: Vec<CodeChunk> = semantic
        .into_iter()
        .chain(bm25)
        .filter(|c| seen.insert(key(c)))
        .collect();

    let graph_hits = all.iter().filter(|c| connected.contains(&c.file_path));
    for (rank, chunk) in graph_hits.enumerate() {
        *scores.entry(key(chunk)).or_insert(0.0) += rrf(rank);
    }

    // Record the final score on each chunk first so the sort does not have to rebuild
    // the string key for every comparison.
    for chunk in &mut all {
        chunk.distance = scores.get(&key(chunk)).copied();
    }
    all.sort_by(|a, b| {
        let sa = a.distance.unwrap_or(0.0);
        let sb = b.distance.unwrap_or(0.0);
        sb.partial_cmp(&sa).unwrap_or(std::cmp::Ordering::Equal)
    });
    all.truncate(limit as usize);

    Ok(all)
}
/// Computes one averaged embedding per file for the given file paths.
///
/// Every chunk embedding stored for a listed file is loaded and the element-wise mean
/// is taken per file. Files without stored chunks are absent from the result. The
/// result is built by iterating a `HashMap`, so its order is not specified. An empty
/// `files` slice returns an empty vector without querying.
///
/// # Errors
///
/// Returns an error if the query fails or an embedding blob cannot be deserialized.
///
/// # Panics
///
/// Indexing panics if a later embedding of a file is longer than that file's first
/// embedding.
pub fn file_embeddings_for(&self, files: &[&str]) -> Result<Vec<(String, Vec<f32>)>> {
    use std::collections::HashMap;

    if files.is_empty() {
        return Ok(vec![]);
    }

    // Build "?1, ?2, ..., ?n".
    let mut placeholders = String::new();
    for i in 1..=files.len() {
        if i > 1 {
            placeholders.push_str(", ");
        }
        placeholders.push_str(&format!("?{}", i));
    }
    let sql = format!(
        "SELECT file_path, embedding FROM code_chunks WHERE file_path IN ({})",
        placeholders
    );
    let mut stmt = self.conn.prepare(&sql)?;
    let bound: Vec<&dyn rusqlite::ToSql> =
        files.iter().map(|f| f as &dyn rusqlite::ToSql).collect();

    // First pass: group every embedding by file.
    let mut grouped: HashMap<String, Vec<Vec<f32>>> = HashMap::new();
    let mut rows = stmt.query(bound.as_slice())?;
    while let Some(row) = rows.next()? {
        let file: String = row.get(0)?;
        let blob: Vec<u8> = row.get(1)?;
        let emb: Vec<f32> = bincode::deserialize(&blob)
            .map_err(|_| anyhow::anyhow!("Failed to deserialize embedding"))?;
        grouped.entry(file).or_default().push(emb);
    }

    // Second pass: element-wise mean per file. Every group holds at least one
    // embedding, so indexing the first one is safe.
    let mut averaged = Vec::with_capacity(grouped.len());
    for (file, embs) in grouped {
        let count = embs.len() as f32;
        let mut mean = vec![0.0f32; embs[0].len()];
        for emb in &embs {
            for (i, v) in emb.iter().enumerate() {
                mean[i] += v;
            }
        }
        for slot in mean.iter_mut() {
            *slot /= count;
        }
        averaged.push((file, mean));
    }
    Ok(averaged)
}
/// Returns `(name, description, embedding)` for every cluster that has file embeddings.
///
/// For each cluster the per-file averages from `file_embeddings_for` are averaged
/// again element-wise. A cluster is skipped when that lookup fails or returns
/// nothing.
///
/// # Errors
///
/// Returns an error only if the clusters themselves cannot be loaded.
pub fn cluster_embeddings(&self) -> Result<Vec<(String, String, Vec<f32>)>> {
    let mut out = Vec::new();
    for cluster in self.all_clusters()? {
        let files: Vec<&str> = cluster.chunks.iter().map(|c| c.file.as_str()).collect();
        // Lookup errors are deliberately folded into "no embeddings".
        let per_file = self.file_embeddings_for(&files).unwrap_or_default();

        let dim = match per_file.first() {
            Some((_, emb)) => emb.len(),
            None => continue,
        };
        let count = per_file.len() as f32;
        let mut mean = vec![0.0f32; dim];
        for (_, emb) in &per_file {
            for (i, v) in emb.iter().enumerate() {
                mean[i] += v;
            }
        }
        for slot in mean.iter_mut() {
            *slot /= count;
        }
        out.push((cluster.name, cluster.description, mean));
    }
    Ok(out)
}
/// Returns the edges whose `to_file` is `file_path`, excluding self-edges.
///
/// # Errors
///
/// Returns an error if the query fails or a row's `via_json` cannot be parsed.
pub fn edges_for_file(&self, file_path: &str) -> Result<Vec<crate::map::Edge>> {
    let mut stmt = self.conn.prepare(
        "SELECT from_file, to_file, via_json FROM edges WHERE to_file = ?1 AND from_file != ?1",
    )?;
    let mut rows = stmt.query(params![file_path])?;

    let mut edges = Vec::new();
    while let Some(row) = rows.next()? {
        let from: String = row.get(0)?;
        let to: String = row.get(1)?;
        let via_json: String = row.get(2)?;
        let via: Vec<String> = serde_json::from_str(&via_json)
            .map_err(|e| anyhow::anyhow!("Failed to parse via: {}", e))?;
        edges.push(crate::map::Edge { from, to, via });
    }
    Ok(edges)
}
/// Returns every file joined by an edge, in either direction, to one of `files`.
///
/// The result is the union of `from_file` for edges whose `to_file` is listed and
/// `to_file` for edges whose `from_file` is listed. An empty `files` slice returns an
/// empty set without querying.
///
/// # Errors
///
/// Returns an error if the query cannot be prepared or executed.
pub fn connected_files(&self, files: &[&str]) -> Result<std::collections::HashSet<String>> {
    if files.is_empty() {
        return Ok(std::collections::HashSet::new());
    }

    let placeholders = (1..=files.len())
        .map(|i| format!("?{}", i))
        .collect::<Vec<_>>()
        .join(", ");
    let sql = format!(
        "SELECT DISTINCT from_file FROM edges WHERE to_file IN ({0})
UNION
SELECT DISTINCT to_file FROM edges WHERE from_file IN ({0})",
        placeholders
    );
    let mut stmt = self.conn.prepare(&sql)?;

    // Both `IN` lists reuse the same numbered placeholders (?1..?n), so the statement
    // has exactly `files.len()` parameters. Binding the list twice would be rejected
    // by rusqlite as a parameter-count mismatch.
    let params: Vec<&dyn rusqlite::ToSql> =
        files.iter().map(|s| s as &dyn rusqlite::ToSql).collect();

    let connected = stmt
        .query_map(params.as_slice(), |row| row.get::<_, String>(0))?
        .collect::<Result<std::collections::HashSet<_>, _>>()?;
    Ok(connected)
}
/// Returns every row of the `edges` table as an `Edge`.
///
/// # Errors
///
/// Returns an error if the query fails or a row's `via_json` cannot be parsed.
pub fn all_edges(&self) -> Result<Vec<crate::map::Edge>> {
    let mut stmt = self
        .conn
        .prepare("SELECT from_file, to_file, via_json FROM edges")?;
    let mut rows = stmt.query([])?;

    let mut edges = Vec::new();
    while let Some(row) = rows.next()? {
        let from: String = row.get(0)?;
        let to: String = row.get(1)?;
        let via_json: String = row.get(2)?;
        let via: Vec<String> = serde_json::from_str(&via_json)
            .map_err(|e| anyhow::anyhow!("Failed to parse via: {}", e))?;
        edges.push(crate::map::Edge { from, to, via });
    }
    Ok(edges)
}
pub fn get_chunks_for_file(&self, file_path: &str) -> Result<Vec<CodeChunk>> {
let mut stmt = self.conn.prepare(
"SELECT file_path, start_line, end_line, content, embedding
FROM code_chunks
WHERE file_path = ?1
ORDER BY start_line",
)?;
let chunks = stmt
.query_map(params![file_path], |row| {
let embedding_blob: Vec<u8> = row.get(4)?;
let embedding: Vec<f32> = bincode::deserialize(&embedding_blob)
.map_err(|_e| rusqlite::Error::InvalidQuery)?;
Ok(CodeChunk {
file_path: row.get(0)?,
start_line: row.get(1)?,
end_line: row.get(2)?,
content: row.get(3)?,
embedding,
distance: None,
})
})?
.collect::<Result<Vec<_>, _>>()?;
Ok(chunks)
}
pub fn search_similar(&self, query_embedding: &[f32], limit: i64) -> Result<Vec<CodeChunk>> {
use zerocopy::IntoBytes;
let mut stmt = self.conn.prepare(
"SELECT m.file_path, m.start_line, m.end_line, m.content, c.embedding, distance
FROM vec_chunks v
JOIN vec_metadata m ON v.rowid = m.chunk_id
JOIN code_chunks c ON c.id = m.chunk_id
WHERE v.embedding MATCH ?1
AND k = ?2
ORDER BY distance",
)?;
let chunks = stmt
.query_map(params![query_embedding.as_bytes(), limit], |row| {
let embedding_blob: Vec<u8> = row.get(4)?;
let embedding: Vec<f32> = bincode::deserialize(&embedding_blob)
.map_err(|_e| rusqlite::Error::InvalidQuery)?;
Ok(CodeChunk {
file_path: row.get(0)?,
start_line: row.get(1)?,
end_line: row.get(2)?,
content: row.get(3)?,
embedding,
distance: row.get(5).ok(),
})
})?
.collect::<Result<Vec<_>, _>>()?;
Ok(chunks)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::models::CodeChunk;
use std::fs;
/// Builds a `Database` backed by a fresh file in the system temp directory.
///
/// The file name combines the process id and a nanosecond timestamp. Only the chunk
/// tables (`code_chunks`, `vec_chunks` with 1536-wide embeddings, `vec_metadata`,
/// `fts_chunks`) are created; the cluster and edge tables are not. The file is not
/// removed by this helper.
fn create_test_db() -> Result<Database> {
    use std::time::{SystemTime, UNIX_EPOCH};

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_nanos();
    let test_db_path = std::env::temp_dir().join(format!(
        "test_semantic_{}_{}.db",
        std::process::id(),
        timestamp
    ));
    // Best effort: a leftover file with the same name would carry old rows.
    let _ = fs::remove_file(&test_db_path);

    // SAFETY: `sqlite3_vec_init` is the extension entry point provided by the
    // `sqlite_vec` crate; the transmute only reinterprets its address as the
    // entry-point function type that `sqlite3_auto_extension` accepts. The signature
    // is spelled with `c_char`/`c_int` rather than `i8`/`i32` so it also type-checks
    // on targets where C `char` is unsigned.
    unsafe {
        sqlite3_auto_extension(Some(std::mem::transmute::<
            *const (),
            unsafe extern "C" fn(
                *mut rusqlite::ffi::sqlite3,
                *mut *mut std::os::raw::c_char,
                *const rusqlite::ffi::sqlite3_api_routines,
            ) -> std::os::raw::c_int,
        >(sqlite3_vec_init as *const ())));
    }

    let conn = Connection::open(&test_db_path)?;
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS code_chunks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
file_path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
content TEXT NOT NULL,
embedding BLOB
);",
    )?;
    conn.execute_batch(
        "CREATE VIRTUAL TABLE IF NOT EXISTS vec_chunks USING vec0(
embedding FLOAT[1536]
);",
    )?;
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS vec_metadata (
id INTEGER PRIMARY KEY AUTOINCREMENT,
chunk_id INTEGER NOT NULL,
file_path TEXT NOT NULL,
start_line INTEGER NOT NULL,
end_line INTEGER NOT NULL,
content TEXT NOT NULL
);",
    )?;
    conn.execute_batch(
        "CREATE VIRTUAL TABLE IF NOT EXISTS fts_chunks
USING fts5(file_path UNINDEXED, start_line UNINDEXED, end_line UNINDEXED, content, content=code_chunks, content_rowid=id);",
    )?;
    Ok(Database { conn })
}
/// The test database helper must succeed in opening a connection and creating its schema.
#[test]
fn test_database_init() {
    assert!(create_test_db().is_ok());
}
/// Inserting a single chunk with a 1536-wide embedding must succeed.
#[test]
fn test_insert_chunk() {
    let db = create_test_db().unwrap();
    let outcome = db.insert_chunk(&CodeChunk {
        file_path: "test.rs".to_string(),
        start_line: 1,
        end_line: 10,
        content: "test content".to_string(),
        embedding: vec![0.5; 1536],
        distance: None,
    });
    assert!(outcome.is_ok());
}
/// After inserting two chunks, a vector search with `limit = 2` returns both, and the
/// first result carries a distance.
#[test]
fn test_insert_and_search() {
    let db = create_test_db().unwrap();

    db.insert_chunk(&CodeChunk {
        file_path: "file1.rs".to_string(),
        start_line: 1,
        end_line: 5,
        content: "authentication logic".to_string(),
        embedding: vec![1.0; 1536],
        distance: None,
    })
    .unwrap();
    db.insert_chunk(&CodeChunk {
        file_path: "file2.rs".to_string(),
        start_line: 10,
        end_line: 20,
        content: "database connection".to_string(),
        embedding: vec![0.5; 1536],
        distance: None,
    })
    .unwrap();

    let probe = vec![0.9_f32; 1536];
    let hits = db.search_similar(&probe, 2).unwrap();
    assert_eq!(hits.len(), 2);
    assert!(hits[0].distance.is_some());
}
/// Searching with the same vector that was inserted returns that single chunk, and
/// any reported distance is non-negative.
#[test]
fn test_search_similar_ordering() {
    let db = create_test_db().unwrap();
    db.insert_chunk(&CodeChunk {
        file_path: "test.rs".to_string(),
        start_line: 1,
        end_line: 5,
        content: "test".to_string(),
        embedding: vec![1.0; 1536],
        distance: None,
    })
    .unwrap();

    let probe = vec![1.0_f32; 1536];
    let hits = db.search_similar(&probe, 1).unwrap();
    assert_eq!(hits.len(), 1);
    // A missing distance is tolerated; only a present one is checked.
    match hits[0].distance {
        Some(dist) => assert!(dist >= 0.0),
        None => {}
    }
}
}