1use super::embedding::EmbeddingProvider;
27use super::vector_store::{VectorEntry, VectorMetadata, VectorStore};
28use super::{ContextItem, ContextProvider, ContextQuery, ContextResult, ContextType};
29use anyhow::Result;
30use async_trait::async_trait;
31use ignore::WalkBuilder;
32use std::path::{Path, PathBuf};
33use std::sync::Arc;
34use tokio::sync::RwLock;
35
/// Configuration for [`VectorContextProvider`]: which workspace files are indexed, how
/// they are chunked, and how query hits are filtered.
///
/// Build it with [`VectorContextConfig::new`] and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct VectorContextConfig {
    /// Directory walked by the indexer (via `ignore::WalkBuilder` with `.gitignore` support).
    pub root_path: PathBuf,
    /// Glob patterns a walked file's path must match to be indexed; an empty list
    /// includes every file.
    pub include_patterns: Vec<String>,
    /// Glob patterns that exclude a walked file's path from indexing; an empty list
    /// excludes nothing.
    pub exclude_patterns: Vec<String>,
    /// Files whose size in bytes exceeds this value are skipped.
    pub max_file_size: usize,
    /// Target maximum size of a single chunk, measured with `str::len` (UTF-8 bytes,
    /// despite the field name).
    pub max_chunk_chars: usize,
    /// Search hits scoring below this value are dropped from query results.
    /// `with_min_relevance` clamps it to `[0.0, 1.0]`; direct assignment is not clamped.
    pub min_relevance: f32,
    /// When `true`, the first query indexes the workspace if it has not been indexed yet.
    pub auto_index: bool,
}
54
55impl VectorContextConfig {
56 pub fn new(root_path: impl Into<PathBuf>) -> Self {
58 Self {
59 root_path: root_path.into(),
60 include_patterns: vec![
61 "**/*.rs".to_string(),
62 "**/*.py".to_string(),
63 "**/*.ts".to_string(),
64 "**/*.js".to_string(),
65 "**/*.go".to_string(),
66 "**/*.md".to_string(),
67 "**/*.toml".to_string(),
68 "**/*.yaml".to_string(),
69 "**/*.yml".to_string(),
70 ],
71 exclude_patterns: vec![
72 "**/target/**".to_string(),
73 "**/node_modules/**".to_string(),
74 "**/.git/**".to_string(),
75 "**/dist/**".to_string(),
76 "**/build/**".to_string(),
77 "**/*.lock".to_string(),
78 ],
79 max_file_size: 512 * 1024,
80 max_chunk_chars: 2000,
81 min_relevance: 0.3,
82 auto_index: true,
83 }
84 }
85
86 pub fn with_include_patterns(mut self, patterns: Vec<String>) -> Self {
88 self.include_patterns = patterns;
89 self
90 }
91
92 pub fn with_exclude_patterns(mut self, patterns: Vec<String>) -> Self {
94 self.exclude_patterns = patterns;
95 self
96 }
97
98 pub fn with_max_file_size(mut self, size: usize) -> Self {
100 self.max_file_size = size;
101 self
102 }
103
104 pub fn with_max_chunk_chars(mut self, chars: usize) -> Self {
106 self.max_chunk_chars = chars;
107 self
108 }
109
110 pub fn with_min_relevance(mut self, score: f32) -> Self {
112 self.min_relevance = score.clamp(0.0, 1.0);
113 self
114 }
115
116 pub fn with_auto_index(mut self, enabled: bool) -> Self {
118 self.auto_index = enabled;
119 self
120 }
121}
122
/// One piece of a workspace file, produced while indexing and embedded as a single vector.
#[derive(Debug, Clone)]
struct CodeChunk {
    /// File the chunk was read from, as yielded by the directory walker.
    path: PathBuf,
    /// Text of the chunk as returned by `chunk_text`.
    content: String,
    /// Zero-based position of this chunk within its file; combined with `path` to build
    /// the vector id (`"<path>#<chunk_index>"`).
    chunk_index: usize,
}
133
/// [`ContextProvider`] that answers queries by embedding them and searching a vector store
/// populated from the files under [`VectorContextConfig::root_path`].
///
/// Generic over the embedding backend `E` and the vector store `S`.
pub struct VectorContextProvider<E: EmbeddingProvider, S: VectorStore> {
    /// Indexing and query settings.
    config: VectorContextConfig,
    /// Embedding backend; held in an `Arc` so it can be shared with other owners
    /// (see `with_shared`).
    embedder: Arc<E>,
    /// Store holding the indexed chunk vectors; likewise shareable.
    store: Arc<S>,
    /// Becomes `true` once an indexing pass has completed successfully.
    indexed: RwLock<bool>,
}
144
145impl<E: EmbeddingProvider, S: VectorStore> VectorContextProvider<E, S> {
146 pub fn new(config: VectorContextConfig, embedder: E, store: S) -> Self {
148 Self {
149 config,
150 embedder: Arc::new(embedder),
151 store: Arc::new(store),
152 indexed: RwLock::new(false),
153 }
154 }
155
156 pub fn with_shared(config: VectorContextConfig, embedder: Arc<E>, store: Arc<S>) -> Self {
158 Self {
159 config,
160 embedder,
161 store,
162 indexed: RwLock::new(false),
163 }
164 }
165
166 pub async fn index(&self) -> Result<usize> {
168 let chunks = self.collect_chunks().await?;
169 let chunk_count = chunks.len();
170
171 if chunks.is_empty() {
172 tracing::debug!("No files to index for vector context");
173 *self.indexed.write().await = true;
174 return Ok(0);
175 }
176
177 tracing::info!(
178 chunks = chunk_count,
179 root = %self.config.root_path.display(),
180 "Indexing workspace for vector context"
181 );
182
183 let batch_size = 32;
185 let mut entries = Vec::with_capacity(chunk_count);
186
187 for batch_start in (0..chunks.len()).step_by(batch_size) {
188 let batch_end = (batch_start + batch_size).min(chunks.len());
189 let batch_texts: Vec<&str> = chunks[batch_start..batch_end]
190 .iter()
191 .map(|c| c.content.as_str())
192 .collect();
193
194 let embeddings = self.embedder.embed_batch(&batch_texts).await?;
195
196 for (i, embedding) in embeddings.into_iter().enumerate() {
197 let chunk = &chunks[batch_start + i];
198 let id = format!("{}#{}", chunk.path.display(), chunk.chunk_index);
199
200 entries.push(VectorEntry {
201 id,
202 embedding,
203 metadata: VectorMetadata {
204 source: format!("file:{}", chunk.path.display()),
205 content_type: detect_content_type(&chunk.path),
206 content: chunk.content.clone(),
207 token_count: chunk.content.split_whitespace().count(),
208 extra: {
209 let mut m = std::collections::HashMap::new();
210 m.insert(
211 "path".to_string(),
212 serde_json::Value::String(chunk.path.to_string_lossy().to_string()),
213 );
214 m.insert(
215 "chunk_index".to_string(),
216 serde_json::Value::Number(chunk.chunk_index.into()),
217 );
218 m
219 },
220 },
221 });
222 }
223 }
224
225 self.store.insert_batch(entries).await?;
226 *self.indexed.write().await = true;
227
228 tracing::info!(chunks = chunk_count, "Vector context indexing complete");
229
230 Ok(chunk_count)
231 }
232
233 async fn collect_chunks(&self) -> Result<Vec<CodeChunk>> {
235 let root = self.config.root_path.clone();
236 let max_file_size = self.config.max_file_size;
237 let max_chunk_chars = self.config.max_chunk_chars;
238 let include = self.config.include_patterns.clone();
239 let exclude = self.config.exclude_patterns.clone();
240
241 tokio::task::spawn_blocking(move || {
243 let mut chunks = Vec::new();
244
245 let walker = WalkBuilder::new(&root)
246 .hidden(false)
247 .git_ignore(true)
248 .build();
249
250 for entry in walker {
251 let entry = entry.map_err(|e| anyhow::anyhow!("Walk error: {}", e))?;
252 let path = entry.path();
253
254 if !path.is_file() {
255 continue;
256 }
257
258 let metadata = std::fs::metadata(path)
259 .map_err(|e| anyhow::anyhow!("Metadata error for {}: {}", path.display(), e))?;
260
261 if metadata.len() > max_file_size as u64 {
262 continue;
263 }
264
265 if !matches_patterns(path, &include, true) {
266 continue;
267 }
268
269 if matches_patterns(path, &exclude, false) {
270 continue;
271 }
272
273 let content = match std::fs::read_to_string(path) {
274 Ok(c) => c,
275 Err(_) => continue, };
277
278 if content.trim().is_empty() {
279 continue;
280 }
281
282 let file_chunks = chunk_text(&content, max_chunk_chars);
284 for (i, chunk_content) in file_chunks.into_iter().enumerate() {
285 chunks.push(CodeChunk {
286 path: path.to_path_buf(),
287 content: chunk_content,
288 chunk_index: i,
289 });
290 }
291 }
292
293 Ok::<_, anyhow::Error>(chunks)
294 })
295 .await
296 .map_err(|e| anyhow::anyhow!("Spawn blocking failed: {}", e))?
297 }
298
299 async fn ensure_indexed(&self) -> Result<()> {
301 if *self.indexed.read().await {
302 return Ok(());
303 }
304 if self.config.auto_index {
305 self.index().await?;
306 }
307 Ok(())
308 }
309}
310
311#[async_trait]
312impl<E: EmbeddingProvider + 'static, S: VectorStore + 'static> ContextProvider
313 for VectorContextProvider<E, S>
314{
315 fn name(&self) -> &str {
316 "vector-rag"
317 }
318
319 async fn query(&self, query: &ContextQuery) -> Result<ContextResult> {
320 self.ensure_indexed().await?;
321
322 let query_embedding = self.embedder.embed(&query.query).await?;
324
325 let search_results = self
327 .store
328 .search(&query_embedding, query.max_results)
329 .await?;
330
331 let mut result = ContextResult::new("vector-rag");
333 let mut total_tokens = 0usize;
334
335 for sr in search_results {
336 if sr.score < self.config.min_relevance {
337 continue;
338 }
339
340 if total_tokens >= query.max_tokens {
341 result.truncated = true;
342 break;
343 }
344
345 let content = match query.depth {
346 super::ContextDepth::Abstract => {
347 sr.metadata.content.chars().take(500).collect::<String>()
348 }
349 super::ContextDepth::Overview => {
350 sr.metadata.content.chars().take(2000).collect::<String>()
351 }
352 super::ContextDepth::Full => sr.metadata.content.clone(),
353 };
354
355 let token_count = content.split_whitespace().count();
356 total_tokens += token_count;
357
358 result.add_item(
359 ContextItem::new(sr.id, ContextType::Resource, content)
360 .with_token_count(token_count)
361 .with_relevance(sr.score)
362 .with_source(&sr.metadata.source)
363 .with_metadata("content_type", serde_json::json!(sr.metadata.content_type)),
364 );
365 }
366
367 Ok(result)
368 }
369}
370
371fn matches_patterns(path: &Path, patterns: &[String], default_if_empty: bool) -> bool {
377 if patterns.is_empty() {
378 return default_if_empty;
379 }
380 let path_str = path.to_string_lossy();
381 patterns.iter().any(|pattern| {
382 glob::Pattern::new(pattern)
383 .map(|p| p.matches(&path_str))
384 .unwrap_or(false)
385 })
386}
387
/// Splits `text` into chunks of at most `max_chars` bytes (measured with `str::len`).
///
/// - Text that already fits is returned unchanged as a single chunk, so `""` yields `vec![""]`.
/// - Longer text is packed paragraph by paragraph (paragraphs are separated by `"\n\n"`).
/// - A paragraph that alone exceeds the limit is packed line by line.
/// - A single line that exceeds the limit is hard-split on UTF-8 character boundaries.
///
/// Every emitted chunk is trimmed, and whitespace-only chunks are dropped. The only chunk
/// that can exceed `max_chars` bytes is a lone character wider than the limit; it is
/// emitted whole so splitting always makes progress.
fn chunk_text(text: &str, max_chars: usize) -> Vec<String> {
    /// Moves the trimmed contents of `current` into `chunks` unless it is blank, then clears it.
    fn flush(chunks: &mut Vec<String>, current: &mut String) {
        let trimmed = current.trim();
        if !trimmed.is_empty() {
            chunks.push(trimmed.to_string());
        }
        current.clear();
    }

    /// Pushes `line` in pieces of at most `max_bytes` bytes without cutting a character.
    fn push_hard_split(chunks: &mut Vec<String>, line: &str, max_bytes: usize) {
        let mut rest = line;
        while !rest.is_empty() {
            let mut end = max_bytes.min(rest.len());
            while !rest.is_char_boundary(end) {
                end -= 1;
            }
            if end == 0 {
                // The next character alone is wider than the limit: emit it whole so the
                // loop always advances.
                end = rest.chars().next().map_or(rest.len(), char::len_utf8);
            }
            let (piece, tail) = rest.split_at(end);
            let piece = piece.trim();
            if !piece.is_empty() {
                chunks.push(piece.to_string());
            }
            rest = tail;
        }
    }

    if text.len() <= max_chars {
        return vec![text.to_string()];
    }

    let mut chunks = Vec::new();
    let mut current = String::new();

    for paragraph in text.split("\n\n") {
        // Close the current chunk if appending this paragraph (plus separator) would overflow.
        if current.len() + paragraph.len() + 2 > max_chars {
            flush(&mut chunks, &mut current);
        }

        if paragraph.len() <= max_chars {
            if !current.is_empty() {
                current.push_str("\n\n");
            }
            current.push_str(paragraph);
            continue;
        }

        // The paragraph alone is too large: fall back to packing it line by line.
        flush(&mut chunks, &mut current);
        for line in paragraph.split('\n') {
            if line.len() > max_chars {
                // Even one line is too large (e.g. minified code): hard-split it.
                flush(&mut chunks, &mut current);
                push_hard_split(&mut chunks, line, max_chars);
                continue;
            }
            if current.len() + line.len() + 1 > max_chars {
                flush(&mut chunks, &mut current);
            }
            if !current.is_empty() {
                current.push('\n');
            }
            current.push_str(line);
        }
    }

    flush(&mut chunks, &mut current);
    chunks
}
437
/// Maps a file's extension to a coarse content-type label stored with each vector.
///
/// Matching is case-sensitive. Files with no extension, a non-UTF-8 extension, or an
/// unrecognised one are labelled `"text"`.
fn detect_content_type(path: &Path) -> String {
    let extension = path
        .extension()
        .and_then(|ext| ext.to_str())
        .unwrap_or_default();

    let label = match extension {
        "rs" => "rust",
        "py" => "python",
        "ts" | "tsx" => "typescript",
        "js" | "jsx" => "javascript",
        "go" => "go",
        "md" | "mdx" => "markdown",
        "toml" | "yaml" | "yml" | "json" => "config",
        _ => "text",
    };
    String::from(label)
}
452
/// Tests for the configuration builders, the chunking / glob / content-type helpers, and
/// end-to-end indexing and querying against an `InMemoryVectorStore` over a small
/// temporary workspace.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::vector_store::InMemoryVectorStore;
    use anyhow::Result;
    use std::fs::{self, File};
    use std::io::Write;
    use tempfile::TempDir;

    /// Deterministic, offline stand-in for a real embedding model.
    ///
    /// Byte `i` of the input is added (scaled to `[0, 1]`) into slot `i % dim`, and the
    /// vector is then L2-normalised. Only `embed` is implemented here. Batch calls made by
    /// `index()` presumably go through the trait's default `embed_batch` (TODO confirm).
    struct MockEmbeddingProvider {
        /// Length of every vector this mock returns.
        dim: usize,
    }

    impl MockEmbeddingProvider {
        fn new(dim: usize) -> Self {
            Self { dim }
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        fn name(&self) -> &str {
            "mock-embedding"
        }

        fn dimension(&self) -> usize {
            self.dim
        }

        async fn embed(&self, text: &str) -> Result<super::super::embedding::Embedding> {
            let mut embedding = vec![0.0f32; self.dim];
            for (i, byte) in text.bytes().enumerate() {
                embedding[i % self.dim] += (byte as f32) / 255.0;
            }
            // Normalise so scores are comparable across inputs; the zero vector
            // (empty text) is left as-is to avoid dividing by zero.
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            if norm > 0.0 {
                for v in &mut embedding {
                    *v /= norm;
                }
            }
            Ok(embedding)
        }
    }

    /// Creates a temporary workspace containing `main.rs`, `lib.rs`, `README.md` and
    /// `src/auth.rs`.
    ///
    /// Keep the returned `TempDir` bound while the files are in use. Tempfile's `TempDir`
    /// is expected to remove the directory when dropped.
    fn setup_test_workspace() -> TempDir {
        let dir = TempDir::new().unwrap();
        let root = dir.path();

        let mut f1 = File::create(root.join("main.rs")).unwrap();
        writeln!(f1, "fn main() {{\n println!(\"Hello, world!\");\n}}").unwrap();

        let mut f2 = File::create(root.join("lib.rs")).unwrap();
        writeln!(
            f2,
            "pub mod auth;\npub mod database;\n\npub fn init() -> Result<()> {{\n Ok(())\n}}"
        )
        .unwrap();

        let mut f3 = File::create(root.join("README.md")).unwrap();
        writeln!(
            f3,
            "# My Project\n\nA Rust project for testing vector RAG context."
        )
        .unwrap();

        fs::create_dir(root.join("src")).unwrap();
        let mut f4 = File::create(root.join("src/auth.rs")).unwrap();
        writeln!(
            f4,
            "use jwt::Token;\n\npub fn verify_token(token: &str) -> Result<Claims> {{\n // JWT verification logic\n todo!()\n}}"
        )
        .unwrap();

        dir
    }

    // ---- Configuration ----

    #[test]
    fn test_config_defaults() {
        let config = VectorContextConfig::new("/tmp/test");
        assert_eq!(config.root_path, PathBuf::from("/tmp/test"));
        assert!(!config.include_patterns.is_empty());
        assert!(!config.exclude_patterns.is_empty());
        assert_eq!(config.max_file_size, 512 * 1024);
        assert_eq!(config.max_chunk_chars, 2000);
        assert!((config.min_relevance - 0.3).abs() < f32::EPSILON);
        assert!(config.auto_index);
    }

    #[test]
    fn test_config_builders() {
        let config = VectorContextConfig::new("/tmp")
            .with_include_patterns(vec!["**/*.rs".to_string()])
            .with_exclude_patterns(vec!["**/test/**".to_string()])
            .with_max_file_size(1024)
            .with_max_chunk_chars(500)
            .with_min_relevance(0.5)
            .with_auto_index(false);

        assert_eq!(config.include_patterns, vec!["**/*.rs"]);
        assert_eq!(config.exclude_patterns, vec!["**/test/**"]);
        assert_eq!(config.max_file_size, 1024);
        assert_eq!(config.max_chunk_chars, 500);
        assert!((config.min_relevance - 0.5).abs() < f32::EPSILON);
        assert!(!config.auto_index);
    }

    #[test]
    fn test_config_min_relevance_clamping() {
        // Out-of-range scores are clamped into [0.0, 1.0] by the builder.
        let c1 = VectorContextConfig::new("/tmp").with_min_relevance(1.5);
        assert!((c1.min_relevance - 1.0).abs() < f32::EPSILON);

        let c2 = VectorContextConfig::new("/tmp").with_min_relevance(-0.5);
        assert!(c2.min_relevance.abs() < f32::EPSILON);
    }

    // ---- chunk_text ----

    #[test]
    fn test_chunk_text_small() {
        let chunks = chunk_text("hello world", 100);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "hello world");
    }

    #[test]
    fn test_chunk_text_splits_on_double_newline() {
        let text = "paragraph one\n\nparagraph two\n\nparagraph three";
        let chunks = chunk_text(text, 20);
        assert!(chunks.len() >= 2);
        assert!(chunks[0].contains("paragraph one"));
    }

    #[test]
    fn test_chunk_text_large_paragraph() {
        // A single paragraph larger than the limit is split on line boundaries.
        let text = (0..100)
            .map(|i| format!("line {}", i))
            .collect::<Vec<_>>()
            .join("\n");
        let chunks = chunk_text(&text, 50);
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            // Deliberately loose bound: only checks that no chunk is wildly oversized.
            assert!(chunk.len() < 200, "Chunk too large: {} chars", chunk.len());
        }
    }

    #[test]
    fn test_chunk_text_empty() {
        // Empty input fits within the limit, so it comes back as one (empty) chunk.
        let chunks = chunk_text("", 100);
        assert_eq!(chunks.len(), 1);
    }

    // ---- detect_content_type ----

    #[test]
    fn test_detect_content_type() {
        assert_eq!(detect_content_type(Path::new("main.rs")), "rust");
        assert_eq!(detect_content_type(Path::new("app.py")), "python");
        assert_eq!(detect_content_type(Path::new("index.ts")), "typescript");
        assert_eq!(detect_content_type(Path::new("app.tsx")), "typescript");
        assert_eq!(detect_content_type(Path::new("main.go")), "go");
        assert_eq!(detect_content_type(Path::new("README.md")), "markdown");
        assert_eq!(detect_content_type(Path::new("Cargo.toml")), "config");
        assert_eq!(detect_content_type(Path::new("config.yaml")), "config");
        assert_eq!(detect_content_type(Path::new("unknown.xyz")), "text");
    }

    // ---- matches_patterns ----

    #[test]
    fn test_matches_patterns_empty_default_true() {
        // Empty include list: everything is included.
        assert!(matches_patterns(Path::new("test.rs"), &[], true));
    }

    #[test]
    fn test_matches_patterns_empty_default_false() {
        // Empty exclude list: nothing is excluded.
        assert!(!matches_patterns(Path::new("test.rs"), &[], false));
    }

    // ---- Provider: indexing and querying ----

    #[tokio::test]
    async fn test_provider_index() {
        let dir = setup_test_workspace();
        let config = VectorContextConfig::new(dir.path()).with_auto_index(false);
        let embedder = MockEmbeddingProvider::new(8);
        let store = InMemoryVectorStore::new();

        let provider = VectorContextProvider::new(config, embedder, store);
        let count = provider.index().await.unwrap();
        assert!(count > 0, "Should have indexed some chunks");
    }

    #[tokio::test]
    async fn test_provider_query() {
        let dir = setup_test_workspace();
        let config = VectorContextConfig::new(dir.path())
            .with_min_relevance(0.0) // mock vectors carry no real semantics; keep every hit
            .with_auto_index(false);
        let provider = VectorContextProvider::new(
            config,
            MockEmbeddingProvider::new(8),
            InMemoryVectorStore::new(),
        );
        provider.index().await.unwrap();

        let query = ContextQuery::new("authentication JWT token");
        let result = ContextProvider::query(&provider, &query).await.unwrap();

        assert_eq!(result.provider, "vector-rag");
        assert!(!result.items.is_empty());
        for item in &result.items {
            // Every hit is reported as a file-backed resource with a source attached.
            assert_eq!(item.context_type, ContextType::Resource);
            assert!(item.source.is_some());
        }
    }

    #[tokio::test]
    async fn test_provider_auto_index() {
        let dir = setup_test_workspace();
        let config = VectorContextConfig::new(dir.path())
            .with_min_relevance(0.0)
            .with_auto_index(true);
        let provider = VectorContextProvider::new(
            config,
            MockEmbeddingProvider::new(8),
            InMemoryVectorStore::new(),
        );

        let query = ContextQuery::new("hello");
        // No explicit index() call: the first query must trigger indexing itself.
        let result = ContextProvider::query(&provider, &query).await.unwrap();
        assert!(!result.items.is_empty());
    }

    #[tokio::test]
    async fn test_provider_empty_workspace() {
        let dir = TempDir::new().unwrap();
        let config = VectorContextConfig::new(dir.path()).with_auto_index(false);
        let provider = VectorContextProvider::new(
            config,
            MockEmbeddingProvider::new(8),
            InMemoryVectorStore::new(),
        );
        let count = provider.index().await.unwrap();
        assert_eq!(count, 0);

        let query = ContextQuery::new("anything");
        let result = ContextProvider::query(&provider, &query).await.unwrap();
        assert!(result.items.is_empty());
    }

    #[tokio::test]
    async fn test_provider_respects_max_results() {
        let dir = setup_test_workspace();
        let config = VectorContextConfig::new(dir.path())
            .with_min_relevance(0.0)
            .with_auto_index(false);
        let provider = VectorContextProvider::new(
            config,
            MockEmbeddingProvider::new(8),
            InMemoryVectorStore::new(),
        );
        provider.index().await.unwrap();

        let query = ContextQuery::new("test").with_max_results(1);
        let result = ContextProvider::query(&provider, &query).await.unwrap();
        assert!(result.items.len() <= 1);
    }

    #[tokio::test]
    async fn test_provider_respects_max_tokens() {
        let dir = setup_test_workspace();
        let config = VectorContextConfig::new(dir.path())
            .with_min_relevance(0.0)
            .with_auto_index(false);
        let provider = VectorContextProvider::new(
            config,
            MockEmbeddingProvider::new(8),
            InMemoryVectorStore::new(),
        );
        provider.index().await.unwrap();

        let query = ContextQuery::new("test").with_max_tokens(5);
        let result = ContextProvider::query(&provider, &query).await.unwrap();
        // query() checks the budget only before adding each item, so the total can overshoot
        // max_tokens by one item; 50 is a loose upper bound for this small fixture.
        assert!(result.total_tokens <= 50);
    }

    #[tokio::test]
    async fn test_provider_with_shared() {
        let dir = setup_test_workspace();
        let config = VectorContextConfig::new(dir.path())
            .with_min_relevance(0.0)
            .with_auto_index(false);
        let embedder = Arc::new(MockEmbeddingProvider::new(8));
        let store = Arc::new(InMemoryVectorStore::new());

        let provider =
            VectorContextProvider::with_shared(config, Arc::clone(&embedder), Arc::clone(&store));
        provider.index().await.unwrap();

        // The caller's handle sees the inserted entries: the store is shared, not copied.
        assert!(!store.is_empty().await);
    }

    #[tokio::test]
    async fn test_provider_name() {
        let dir = TempDir::new().unwrap();
        let config = VectorContextConfig::new(dir.path());
        let provider = VectorContextProvider::new(
            config,
            MockEmbeddingProvider::new(4),
            InMemoryVectorStore::new(),
        );
        assert_eq!(ContextProvider::name(&provider), "vector-rag");
    }

    #[tokio::test]
    async fn test_provider_context_depth_abstract() {
        let dir = setup_test_workspace();
        let config = VectorContextConfig::new(dir.path())
            .with_min_relevance(0.0)
            .with_auto_index(false);
        let provider = VectorContextProvider::new(
            config,
            MockEmbeddingProvider::new(8),
            InMemoryVectorStore::new(),
        );
        provider.index().await.unwrap();

        let query = ContextQuery::new("test").with_depth(crate::context::ContextDepth::Abstract);
        let result = ContextProvider::query(&provider, &query).await.unwrap();
        for item in &result.items {
            // Abstract depth keeps the first 500 chars; the fixture is ASCII, so checking the
            // byte length is equivalent here.
            assert!(item.content.len() <= 500);
        }
    }
}