use crate::debug_log;
use anyhow::Result;
use dashmap::DashMap;
use hnsw_rs::prelude::*;
use parking_lot::RwLock;
use rayon::prelude::*;
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
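
/// A single embedded text chunk together with its vector and provenance
/// metadata (model, provider, originating file, and chunk position).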
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorEntry {
    pub id: i64,
    pub text: String,
    pub vector: Vec<f64>,
    pub model: String,
    pub provider: String,
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub file_path: Option<String>,
    pub chunk_index: Option<i32>,
    pub total_chunks: Option<i32>,
}

type HnswIndex = Hnsw<'static, f64, DistCosine>;
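
/// SQLite-backed vector store with an in-memory HNSW index for approximate
/// nearest-neighbor search and a concurrent cache of all entries.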
pub struct VectorDatabase {
    db_path: PathBuf,
    /// Lazily built HNSW index; `None` until the first search.
    hnsw_index: Arc<RwLock<Option<HnswIndex>>>,
    /// In-memory copy of every row, keyed by rowid.
    vector_cache: Arc<DashMap<i64, VectorEntry>>,
    /// Set whenever a vector is added; triggers an index rebuild.
    index_dirty: Arc<RwLock<bool>>,
}

impl std::fmt::Debug for VectorDatabase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VectorDatabase")
            .field("db_path", &self.db_path)
            .field("vector_cache_len", &self.vector_cache.len())
            .field("index_dirty", &self.index_dirty)
            .finish()
    }
}

impl VectorDatabase {
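    /// Open (or create) the database file `<embeddings_dir>/<name>.db` and
    /// run schema initialization.
    ///
    /// A minimal usage sketch; the model and provider strings are
    /// placeholders, not values this crate defines:
    ///
    /// ```ignore
    /// let db = VectorDatabase::new("notes")?;
    /// let id = db.add_vector("hello world", &[0.1, 0.2, 0.3], "some-model", "some-provider")?;
    /// let top5 = db.find_similar(&[0.1, 0.2, 0.3], 5)?;
    /// ```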
    pub fn new(name: &str) -> Result<Self> {
        let embeddings_dir = Self::embeddings_dir()?;
        fs::create_dir_all(&embeddings_dir)?;

        let db_path = embeddings_dir.join(format!("{}.db", name));

        let db = Self {
            db_path,
            hnsw_index: Arc::new(RwLock::new(None)),
            vector_cache: Arc::new(DashMap::new()),
            index_dirty: Arc::new(RwLock::new(true)),
        };

        db.initialize()?;
        Ok(db)
    }

    pub fn embeddings_dir() -> Result<PathBuf> {
        let config_dir = crate::config::Config::config_dir()?;
        Ok(config_dir.join("embeddings"))
    }

    pub fn list_databases() -> Result<Vec<String>> {
        let embeddings_dir = Self::embeddings_dir()?;
        Self::list_databases_in_dir(&embeddings_dir)
    }

    pub fn list_databases_in_dir(embeddings_dir: &std::path::Path) -> Result<Vec<String>> {
        if !embeddings_dir.exists() {
            return Ok(Vec::new());
        }

        let mut databases = Vec::new();

        for entry in fs::read_dir(embeddings_dir)? {
            let path = entry?.path();

            if path.is_file() && path.extension().map_or(false, |ext| ext == "db") {
                if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
                    databases.push(name.to_string());
                }
            }
        }

        databases.sort();
        Ok(databases)
    }

    pub fn delete_database(name: &str) -> Result<()> {
        let embeddings_dir = Self::embeddings_dir()?;
        Self::delete_database_in_dir(name, &embeddings_dir)
    }

    pub fn delete_database_in_dir(name: &str, embeddings_dir: &std::path::Path) -> Result<()> {
        let db_path = embeddings_dir.join(format!("{}.db", name));

        if db_path.exists() {
            fs::remove_file(db_path)?;
        }

        Ok(())
    }
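
    /// Create the `vectors` table if needed and migrate older databases by
    /// adding the chunk-metadata columns that later versions expect.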
    fn initialize(&self) -> Result<()> {
        let conn = Connection::open(&self.db_path)?;

        conn.execute(
            "CREATE TABLE IF NOT EXISTS vectors (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                text TEXT NOT NULL,
                vector BLOB NOT NULL,
                model TEXT NOT NULL,
                provider TEXT NOT NULL,
                created_at TEXT NOT NULL
            )",
            [],
        )?;

        let mut has_file_path = false;
        let mut has_chunk_index = false;
        let mut has_total_chunks = false;

        let mut stmt = conn.prepare("PRAGMA table_info(vectors)")?;
        let column_iter = stmt.query_map([], |row| {
            let column_name: String = row.get(1)?;
            Ok(column_name)
        })?;

        for column_result in column_iter {
            let column_name = column_result?;
            match column_name.as_str() {
                "file_path" => has_file_path = true,
                "chunk_index" => has_chunk_index = true,
                "total_chunks" => has_total_chunks = true,
                _ => {}
            }
        }

        if !has_file_path {
            conn.execute("ALTER TABLE vectors ADD COLUMN file_path TEXT", [])?;
        }
        if !has_chunk_index {
            conn.execute("ALTER TABLE vectors ADD COLUMN chunk_index INTEGER", [])?;
        }
        if !has_total_chunks {
            conn.execute("ALTER TABLE vectors ADD COLUMN total_chunks INTEGER", [])?;
        }

        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_model_provider ON vectors(model, provider)",
            [],
        )?;

        conn.execute(
            "CREATE INDEX IF NOT EXISTS idx_file_path ON vectors(file_path)",
            [],
        )?;

        Ok(())
    }

    pub fn add_vector(
        &self,
        text: &str,
        vector: &[f64],
        model: &str,
        provider: &str,
    ) -> Result<i64> {
        self.add_vector_with_metadata(text, vector, model, provider, None, None, None)
    }
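
    /// Insert a vector with optional file/chunk metadata, update the cache,
    /// and mark the HNSW index dirty so it is rebuilt before the next search.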
    pub fn add_vector_with_metadata(
        &self,
        text: &str,
        vector: &[f64],
        model: &str,
        provider: &str,
        file_path: Option<&str>,
        chunk_index: Option<i32>,
        total_chunks: Option<i32>,
    ) -> Result<i64> {
        let conn = Connection::open(&self.db_path)?;

        let vector_json = serde_json::to_string(vector)?;
        // Take one timestamp so the stored row and the cached entry agree.
        let now = chrono::Utc::now();
        let created_at = now.to_rfc3339();

        conn.execute(
            "INSERT INTO vectors (text, vector, model, provider, created_at, file_path, chunk_index, total_chunks) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
            params![text, vector_json, model, provider, created_at, file_path, chunk_index, total_chunks],
        )?;

        let id = conn.last_insert_rowid();

        let vector_entry = VectorEntry {
            id,
            text: text.to_string(),
            vector: vector.to_vec(),
            model: model.to_string(),
            provider: provider.to_string(),
            created_at: now,
            file_path: file_path.map(|s| s.to_string()),
            chunk_index,
            total_chunks,
        };

        self.vector_cache.insert(id, vector_entry);

        *self.index_dirty.write() = true;

        Ok(id)
    }
    pub fn get_all_vectors(&self) -> Result<Vec<VectorEntry>> {
        let conn = Connection::open(&self.db_path)?;

        let mut stmt = conn.prepare(
            "SELECT id, text, vector, model, provider, created_at, file_path, chunk_index, total_chunks FROM vectors ORDER BY created_at DESC"
        )?;

        let vector_iter = stmt.query_map([], |row| {
            let vector_json: String = row.get(2)?;
            let vector: Vec<f64> = serde_json::from_str(&vector_json).map_err(|_e| {
                rusqlite::Error::InvalidColumnType(
                    2,
                    "vector".to_string(),
                    rusqlite::types::Type::Text,
                )
            })?;

            let created_at_str: String = row.get(5)?;
            let created_at = chrono::DateTime::parse_from_rfc3339(&created_at_str)
                .map_err(|_| {
                    rusqlite::Error::InvalidColumnType(
                        5,
                        "created_at".to_string(),
                        rusqlite::types::Type::Text,
                    )
                })?
                .with_timezone(&chrono::Utc);

            Ok(VectorEntry {
                id: row.get(0)?,
                text: row.get(1)?,
                vector,
                model: row.get(3)?,
                provider: row.get(4)?,
                created_at,
                // These columns are nullable; read them as Options rather
                // than swallowing every error with `.ok()`.
                file_path: row.get(6)?,
                chunk_index: row.get(7)?,
                total_chunks: row.get(8)?,
            })
        })?;

        let mut vectors = Vec::new();
        for vector in vector_iter {
            vectors.push(vector?);
        }

        Ok(vectors)
    }

    pub fn get_model_info(&self) -> Result<Option<(String, String)>> {
        let conn = Connection::open(&self.db_path)?;

        let mut stmt = conn.prepare("SELECT model, provider FROM vectors LIMIT 1")?;

        let mut rows = stmt.query_map([], |row| {
            Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
        })?;

        if let Some(row) = rows.next() {
            Ok(Some(row?))
        } else {
            Ok(None)
        }
    }
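
    /// Find the `limit` entries most similar to `query_vector` (cosine
    /// similarity, higher is more similar). Uses the HNSW index when it is
    /// usable and falls back to an exact linear scan on dimension mismatch
    /// or when the approximate search comes up short.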
    pub fn find_similar(
        &self,
        query_vector: &[f64],
        limit: usize,
    ) -> Result<Vec<(VectorEntry, f64)>> {
        self.ensure_index_built()?;

        if let Some(index) = self.hnsw_index.read().as_ref() {
            // The index can only answer queries of the indexed dimension;
            // anything else must go through the exact linear scan.
            if !self.vector_cache.is_empty() {
                let first_entry = self.vector_cache.iter().next();
                if let Some(entry) = first_entry {
                    let stored_dimension = entry.vector.len();
                    if query_vector.len() != stored_dimension {
                        debug_log!("Dimension mismatch: query={}, stored={}, falling back to linear search",
                            query_vector.len(), stored_dimension);
                        return self.find_similar_linear_optimized(query_vector, limit);
                    }
                }
            }

            // Over-fetch to compensate for approximate recall.
            let hnsw_limit = std::cmp::min(limit * 2, self.vector_cache.len());
            let search_results = index.search(query_vector, hnsw_limit, 50);

            let mut results = Vec::with_capacity(limit);
            for neighbor in search_results {
                if let Some(entry) = self.vector_cache.get(&(neighbor.d_id as i64)) {
                    // Convert cosine distance to cosine similarity.
                    let similarity = 1.0 - neighbor.distance as f64;
                    results.push((entry.value().clone(), similarity));

                    if results.len() >= limit {
                        break;
                    }
                }
            }

            if results.len() < limit && results.len() < self.vector_cache.len() {
                debug_log!(
                    "HNSW returned only {} results, falling back to linear search",
                    results.len()
                );
                return self.find_similar_linear_optimized(query_vector, limit);
            }

            return Ok(results);
        }

        self.find_similar_linear_optimized(query_vector, limit)
    }
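
    /// Exact search: score every cached vector in parallel, then use a
    /// partial selection (`select_nth_unstable_by`) so only the top `limit`
    /// results need a full sort.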
    fn find_similar_linear_optimized(
        &self,
        query_vector: &[f64],
        limit: usize,
    ) -> Result<Vec<(VectorEntry, f64)>> {
        let vectors = if self.vector_cache.is_empty() {
            self.get_all_vectors()?
        } else {
            self.vector_cache
                .iter()
                .map(|entry| entry.value().clone())
                .collect::<Vec<_>>()
        };

        let mut similarities: Vec<(VectorEntry, f64)> = vectors
            .into_par_iter()
            .map(|vector_entry| {
                let similarity = cosine_similarity_simd(query_vector, &vector_entry.vector);
                (vector_entry, similarity)
            })
            .collect();

        if limit < similarities.len() {
            // Partition so the `limit` highest similarities come first,
            // then sort only that prefix.
            similarities.select_nth_unstable_by(limit, |a, b| {
                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
            });
            similarities[..limit]
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            similarities.truncate(limit);
        } else {
            similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        }

        Ok(similarities)
    }

    fn ensure_index_built(&self) -> Result<()> {
        let index_dirty = *self.index_dirty.read();

        if index_dirty || self.hnsw_index.read().is_none() {
            self.rebuild_index()?;
        }

        Ok(())
    }
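
    /// Rebuild the HNSW index from the cache, loading rows from SQLite first
    /// if the cache is cold. No-op when the database is empty.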
    fn rebuild_index(&self) -> Result<()> {
        debug_log!("Rebuilding HNSW index...");

        if self.vector_cache.is_empty() {
            let vectors = self.get_all_vectors()?;
            for vector in vectors {
                self.vector_cache.insert(vector.id, vector);
            }
        }

        if self.vector_cache.is_empty() {
            return Ok(());
        }

        // hnsw_rs expects (max connections per node, expected element count,
        // max layer count, ef_construction, distance) - not the vector
        // dimension, which it infers from the inserted data.
        let nb_elements = self.vector_cache.len();
        let hnsw = Hnsw::new(16, nb_elements, 16, 200, DistCosine {});

        for entry in self.vector_cache.iter() {
            let vector_entry = entry.value();
            hnsw.insert((&vector_entry.vector, vector_entry.id as usize));
        }

        *self.hnsw_index.write() = Some(hnsw);
        *self.index_dirty.write() = false;

        debug_log!(
            "HNSW index rebuilt with {} vectors",
            self.vector_cache.len()
        );

        Ok(())
    }

    pub fn count(&self) -> Result<usize> {
        let conn = Connection::open(&self.db_path)?;

        let count: i64 = conn.query_row("SELECT COUNT(*) FROM vectors", [], |row| row.get(0))?;

        Ok(count as usize)
    }
}
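
/// Cosine similarity over `f64` slices. Despite the name there are no
/// explicit SIMD intrinsics: the loop is blocked in chunks of four so the
/// compiler can auto-vectorize it. Returns 0.0 on dimension mismatch or
/// zero-norm input.
///
/// ```ignore
/// assert!((cosine_similarity_simd(&[1.0, 0.0], &[1.0, 0.0]) - 1.0).abs() < 1e-12);
/// ```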
pub fn cosine_similarity_simd(a: &[f64], b: &[f64]) -> f64 {
    if a.len() != b.len() {
        debug_log!(
            "Vector dimension mismatch: query={}, stored={}",
            a.len(),
            b.len()
        );
        return 0.0;
    }

    if a.is_empty() {
        return 0.0;
    }

    let mut dot_product = 0.0f64;
    let mut norm_a_sq = 0.0f64;
    let mut norm_b_sq = 0.0f64;

    // Process in blocks of four to help the optimizer vectorize.
    let chunk_size = 4;
    let chunks = a.len() / chunk_size;

    for i in 0..chunks {
        let start = i * chunk_size;
        let end = start + chunk_size;

        for j in start..end {
            let av = a[j];
            let bv = b[j];
            dot_product += av * bv;
            norm_a_sq += av * av;
            norm_b_sq += bv * bv;
        }
    }

    // Handle the tail that does not fill a whole block.
    for i in (chunks * chunk_size)..a.len() {
        let av = a[i];
        let bv = b[i];
        dot_product += av * bv;
        norm_a_sq += av * av;
        norm_b_sq += bv * bv;
    }

    let norm_a = norm_a_sq.sqrt();
    let norm_b = norm_b_sq.sqrt();

    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    dot_product / (norm_a * norm_b)
}
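
/// Helpers for discovering text files, reading them, and splitting their
/// contents into overlapping chunks for embedding.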
pub struct FileProcessor;

impl FileProcessor {
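    /// Decide whether a path likely points at a text file, first by
    /// extension, then (for extensionless files) by sniffing the content.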
    pub fn is_text_file(path: &std::path::Path) -> bool {
        if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
            let ext = ext.to_lowercase();
            match ext.as_str() {
                // Documentation and prose formats.
                "txt" | "md" | "markdown" | "rst" | "org" | "tex" | "rtf" => true,
                // Source code and configuration.
                "rs" | "py" | "js" | "ts" | "java" | "cpp" | "c" | "h" | "hpp" | "go" | "rb"
                | "php" | "swift" | "kt" | "scala" | "sh" | "bash" | "zsh" | "fish" | "ps1"
                | "bat" | "cmd" | "html" | "css" | "scss" | "sass" | "less" | "xml" | "json"
                | "yaml" | "yml" | "toml" | "ini" | "cfg" | "conf" | "sql" | "r" | "m" | "mm"
                | "pl" | "pm" | "lua" | "vim" | "dockerfile" | "makefile" | "cmake" | "gradle" => {
                    true
                }
                // Logs and captured output.
                "log" | "out" | "err" => true,
                // Known binary formats.
                "exe" | "dll" | "so" | "dylib" | "bin" | "obj" | "o" | "a" | "lib" | "zip"
                | "tar" | "gz" | "bz2" | "xz" | "7z" | "rar" | "pdf" | "doc" | "docx" | "xls"
                | "xlsx" | "ppt" | "pptx" | "jpg" | "jpeg" | "png" | "gif" | "bmp" | "tiff"
                | "svg" | "ico" | "mp3" | "mp4" | "avi" | "mov" | "wmv" | "flv" | "mkv" | "wav"
                | "flac" | "ogg" => false,
                _ => {
                    // Unknown extension: treat as text only if the name
                    // contains no dot at all.
                    path.file_name()
                        .and_then(|name| name.to_str())
                        .map(|name| !name.contains('.'))
                        .unwrap_or(false)
                }
            }
        } else {
            Self::is_text_content(path).unwrap_or(false)
        }
    }

    fn is_text_content(path: &std::path::Path) -> Result<bool> {
        use std::fs::File;
        use std::io::Read;

        // Sniff the first 512 bytes.
        let mut file = File::open(path)?;
        let mut buffer = [0; 512];
        let bytes_read = file.read(&mut buffer)?;

        if bytes_read == 0 {
            // An empty file is trivially text.
            return Ok(true);
        }

        // Any NUL byte is a strong binary signal.
        let null_count = buffer[..bytes_read].iter().filter(|&&b| b == 0).count();
        if null_count > 0 {
            return Ok(false);
        }

        // Otherwise require that most bytes are printable ASCII or common
        // whitespace (tab, newline, carriage return).
        let printable_count = buffer[..bytes_read]
            .iter()
            .filter(|&&b| (32..=126).contains(&b) || b == 9 || b == 10 || b == 13)
            .count();

        let printable_ratio = printable_count as f64 / bytes_read as f64;
        Ok(printable_ratio > 0.7)
    }
    pub fn expand_file_patterns(patterns: &[String]) -> Result<Vec<std::path::PathBuf>> {
        use glob::glob;

        let mut files = Vec::new();

        for pattern in patterns {
            debug_log!("Processing file pattern: {}", pattern);

            match glob(pattern) {
                Ok(paths) => {
                    for path_result in paths {
                        match path_result {
                            Ok(path) => {
                                if path.is_file() && Self::is_text_file(&path) {
                                    debug_log!("Adding text file: {}", path.display());
                                    files.push(path);
                                } else if path.is_file() {
                                    debug_log!("Skipping non-text file: {}", path.display());
                                } else {
                                    debug_log!("Skipping non-file: {}", path.display());
                                }
                            }
                            Err(e) => {
                                eprintln!(
                                    "Warning: Error processing path in pattern '{}': {}",
                                    pattern, e
                                );
                            }
                        }
                    }
                }
                Err(e) => {
                    eprintln!("Warning: Invalid glob pattern '{}': {}", pattern, e);
                }
            }
        }

        // Sort so dedup removes paths matched by multiple patterns.
        files.sort();
        files.dedup();
        Ok(files)
    }
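
    /// Split `text` into chunks of at most `chunk_size` bytes with `overlap`
    /// bytes carried between consecutive chunks, preferring to break at
    /// sentence, paragraph, or line boundaries.
    ///
    /// A small illustration (sizes are byte counts, not characters):
    ///
    /// ```ignore
    /// let chunks = FileProcessor::chunk_text("First sentence. Second sentence.", 20, 5);
    /// assert!(chunks.len() > 1);
    /// ```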
    pub fn chunk_text(text: &str, chunk_size: usize, overlap: usize) -> Vec<String> {
        debug_log!(
            "Chunking text: {} chars, chunk_size: {}, overlap: {}",
            text.len(),
            chunk_size,
            overlap
        );

        if text.len() <= chunk_size {
            debug_log!("Text is smaller than chunk size, returning single chunk");
            return vec![text.to_string()];
        }

        let mut chunks = Vec::new();
        let mut start = 0;
        let mut iteration = 0;

        while start < text.len() {
            iteration += 1;
            debug_log!(
                "Chunk iteration {}: start={}, text.len()={}",
                iteration,
                start,
                text.len()
            );

            // Clamp the tentative end back to a char boundary so slicing
            // cannot panic on multi-byte UTF-8 sequences.
            let mut end = std::cmp::min(start + chunk_size, text.len());
            while !text.is_char_boundary(end) {
                end -= 1;
            }
            let mut chunk_end = end;

            // Prefer to break at a sentence, paragraph, or line boundary.
            if end < text.len() {
                if let Some(sentence_end) = text[start..end].rfind(". ") {
                    chunk_end = start + sentence_end + 1;
                } else if let Some(para_end) = text[start..end].rfind("\n\n") {
                    chunk_end = start + para_end + 1;
                } else if let Some(line_end) = text[start..end].rfind('\n') {
                    chunk_end = start + line_end + 1;
                }
            }

            let chunk = text[start..chunk_end].trim().to_string();
            if !chunk.is_empty() {
                let chunk_len = chunk.len();
                chunks.push(chunk);
                debug_log!("Added chunk {}: {} chars", chunks.len(), chunk_len);
            }

            if chunk_end >= text.len() {
                debug_log!("Reached end of text, breaking");
                break;
            }

            // Step back by `overlap` bytes, again snapping to a char boundary.
            let mut new_start = if chunk_end > overlap {
                chunk_end - overlap
            } else {
                chunk_end
            };
            while !text.is_char_boundary(new_start) {
                new_start -= 1;
            }

            if new_start <= start {
                // Guarantee forward progress to prevent an infinite loop.
                start += 1;
                while start < text.len() && !text.is_char_boundary(start) {
                    start += 1;
                }
                debug_log!(
                    "Preventing infinite loop: moving start from {} to {}",
                    new_start,
                    start
                );
            } else {
                start = new_start;
            }

            debug_log!("Next start position: {}", start);

            if iteration > 1000 {
                debug_log!("WARNING: Too many iterations, breaking to prevent infinite loop");
                break;
            }
        }

        debug_log!("Chunking complete: {} chunks created", chunks.len());
        chunks
    }
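
    /// Read a file and chunk it (1200-byte chunks, 200-byte overlap).
    /// Reuses the current Tokio runtime when one is available; note that
    /// `Handle::block_on` panics if called from within an async task, so
    /// the runtime path is intended for `spawn_blocking` threads.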
    pub fn process_file(path: &std::path::Path) -> Result<Vec<String>> {
        if let Ok(handle) = tokio::runtime::Handle::try_current() {
            handle.block_on(Self::process_file_async(path))
        } else {
            debug_log!("Reading file synchronously: {}", path.display());
            let content = std::fs::read_to_string(path)?;
            debug_log!("File content length: {} characters", content.len());

            debug_log!("Starting text chunking with 1200 char chunks, 200 char overlap");
            let chunks = Self::chunk_text(&content, 1200, 200);

            debug_log!(
                "File '{}' split into {} chunks",
                path.display(),
                chunks.len()
            );

            Ok(chunks)
        }
    }
    pub async fn process_file_async(path: &std::path::Path) -> Result<Vec<String>> {
        debug_log!("Reading file: {}", path.display());

        let content = Self::read_file_optimized(path).await?;
        debug_log!("File content length: {} characters", content.len());

        debug_log!("Starting text chunking with 1200 char chunks, 200 char overlap");
        let chunks = Self::chunk_text(&content, 1200, 200);

        debug_log!(
            "File '{}' split into {} chunks",
            path.display(),
            chunks.len()
        );

        Ok(chunks)
    }
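
    /// Read a file to a `String`, memory-mapping files larger than 1 MiB and
    /// validating UTF-8 off the async executor via `spawn_blocking`.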
    async fn read_file_optimized(path: &std::path::Path) -> Result<String> {
        let metadata = tokio::fs::metadata(path).await?;
        let file_size = metadata.len();

        if file_size > 1_048_576 {
            debug_log!("Using memory mapping for large file: {} bytes", file_size);

            let file = tokio::fs::File::open(path).await?;
            let std_file = file.into_std().await;
            // Safety: the file is not expected to be mutated or truncated
            // while mapped.
            let mmap = unsafe { memmap2::Mmap::map(&std_file)? };

            // UTF-8 validation of a large buffer is CPU-bound; run it off
            // the async executor.
            let content = tokio::task::spawn_blocking(move || {
                std::str::from_utf8(&mmap)
                    .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in file: {}", e))
                    .map(|s| s.to_string())
            })
            .await??;

            Ok(content)
        } else {
            debug_log!(
                "Using async file reading for small file: {} bytes",
                file_size
            );
            Ok(tokio::fs::read_to_string(path).await?)
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.0, 2.0, 3.0];
        assert!((cosine_similarity_simd(&a, &b) - 1.0).abs() < 1e-10);

        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!((cosine_similarity_simd(&a, &b) - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_chunk_text() {
        let text = "This is sentence one. This is sentence two. This is sentence three.";
        let chunks = FileProcessor::chunk_text(text, 30, 10);

        assert!(chunks.len() > 1);
        assert!(chunks[0].contains("sentence one"));
    }

    #[test]
    fn test_is_text_file() {
        use std::path::Path;

        assert!(FileProcessor::is_text_file(Path::new("test.txt")));
        assert!(FileProcessor::is_text_file(Path::new("test.rs")));
        assert!(FileProcessor::is_text_file(Path::new("test.py")));
        assert!(!FileProcessor::is_text_file(Path::new("test.exe")));
        assert!(!FileProcessor::is_text_file(Path::new("test.jpg")));
    }
}