1use super::Database;
4use crate::error::Result;
5use chrono::Utc;
6use rusqlite::params;
7
/// Summary of one row of the `collections` table, as returned by
/// `Database::list_collections` and `Database::get_collection`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CollectionInfo {
    /// Collection name; lookups and updates in this module key on it
    /// (inserting a duplicate name is rejected by the database).
    pub name: String,
    /// Root location handed to the provider, e.g. a directory such as `/tmp/test`
    /// or a URL such as `https://github.com/test/repo`.
    pub path: String,
    /// Item-selection pattern handed to the provider, e.g. `**/*.md`.
    pub pattern: String,
    /// Number of documents in this collection with `active = 1`; computed by a
    /// sub-query when the collection is read, not stored in the table.
    pub document_count: usize,
    /// Creation time as an RFC 3339 string (written from `Utc::now().to_rfc3339()`).
    pub created_at: String,
    /// Last-modified time as an RFC 3339 string.
    pub updated_at: String,
    /// Provider name looked up in the provider registry, e.g. `"file"` or `"github"`.
    pub provider_type: String,
    /// Optional provider options; when present it is expected to be a JSON object of
    /// string keys to string values (other content is ignored at reindex time).
    pub provider_config: Option<String>,
}
20
21impl Database {
22 pub fn add_collection(
24 &self,
25 name: &str,
26 path: &str,
27 pattern: &str,
28 provider_type: &str,
29 provider_config: Option<&str>,
30 ) -> Result<()> {
31 let now = Utc::now().to_rfc3339();
32 self.conn.execute(
33 "INSERT INTO collections (name, path, pattern, created_at, updated_at, provider_type, provider_config)
34 VALUES (?1, ?2, ?3, ?4, ?4, ?5, ?6)",
35 params![name, path, pattern, now, provider_type, provider_config],
36 )?;
37 Ok(())
38 }
39
40 pub fn remove_collection(&self, name: &str) -> Result<bool> {
42 self.conn.execute(
44 "UPDATE documents SET active = 0 WHERE collection = ?1",
45 params![name],
46 )?;
47
48 let rows = self
50 .conn
51 .execute("DELETE FROM collections WHERE name = ?1", params![name])?;
52
53 Ok(rows > 0)
54 }
55
56 pub fn rename_collection(&self, old_name: &str, new_name: &str) -> Result<bool> {
58 let now = Utc::now().to_rfc3339();
59
60 self.conn.execute(
62 "UPDATE documents SET collection = ?2 WHERE collection = ?1",
63 params![old_name, new_name],
64 )?;
65
66 let rows = self.conn.execute(
68 "UPDATE collections SET name = ?2, updated_at = ?3 WHERE name = ?1",
69 params![old_name, new_name, now],
70 )?;
71
72 Ok(rows > 0)
73 }
74
75 pub fn list_collections(&self) -> Result<Vec<CollectionInfo>> {
77 let mut stmt = self.conn.prepare(
78 "SELECT c.name, c.path, c.pattern, c.created_at, c.updated_at,
79 (SELECT COUNT(*) FROM documents d WHERE d.collection = c.name AND d.active = 1),
80 c.provider_type, c.provider_config
81 FROM collections c
82 ORDER BY c.name",
83 )?;
84
85 let results = stmt
86 .query_map([], |row| {
87 Ok(CollectionInfo {
88 name: row.get(0)?,
89 path: row.get(1)?,
90 pattern: row.get(2)?,
91 created_at: row.get(3)?,
92 updated_at: row.get(4)?,
93 document_count: row.get::<_, i64>(5)? as usize,
94 provider_type: row.get(6)?,
95 provider_config: row.get(7)?,
96 })
97 })?
98 .collect::<std::result::Result<Vec<_>, _>>()?;
99
100 Ok(results)
101 }
102
103 pub fn get_collection(&self, name: &str) -> Result<Option<CollectionInfo>> {
105 let result = self.conn.query_row(
106 "SELECT c.name, c.path, c.pattern, c.created_at, c.updated_at,
107 (SELECT COUNT(*) FROM documents d WHERE d.collection = c.name AND d.active = 1),
108 c.provider_type, c.provider_config
109 FROM collections c WHERE c.name = ?1",
110 params![name],
111 |row| {
112 Ok(CollectionInfo {
113 name: row.get(0)?,
114 path: row.get(1)?,
115 pattern: row.get(2)?,
116 created_at: row.get(3)?,
117 updated_at: row.get(4)?,
118 document_count: row.get::<_, i64>(5)? as usize,
119 provider_type: row.get(6)?,
120 provider_config: row.get(7)?,
121 })
122 },
123 );
124 match result {
125 Ok(info) => Ok(Some(info)),
126 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
127 Err(e) => Err(e.into()),
128 }
129 }
130
131 pub fn touch_collection(&self, name: &str) -> Result<()> {
133 let now = Utc::now().to_rfc3339();
134 self.conn.execute(
135 "UPDATE collections SET updated_at = ?2 WHERE name = ?1",
136 params![name, now],
137 )?;
138 Ok(())
139 }
140
141 pub async fn reindex_collection(&self, name: &str) -> Result<usize> {
143 let coll = self
144 .get_collection(name)?
145 .ok_or_else(|| crate::error::AgentRootError::CollectionNotFound(name.to_string()))?;
146
147 let registry = crate::providers::ProviderRegistry::with_defaults();
148 let provider = registry.get(&coll.provider_type).ok_or_else(|| {
149 crate::error::AgentRootError::InvalidInput(format!(
150 "Unknown provider type: {}",
151 coll.provider_type
152 ))
153 })?;
154
155 let mut config =
156 crate::providers::ProviderConfig::new(coll.path.clone(), coll.pattern.clone());
157
158 if let Some(provider_config) = &coll.provider_config {
159 if let Ok(config_map) =
160 serde_json::from_str::<std::collections::HashMap<String, String>>(provider_config)
161 {
162 for (key, value) in config_map {
163 config = config.with_option(key, value);
164 }
165 }
166 }
167
168 let items = provider.list_items(&config).await?;
169 let mut updated = 0;
170
171 for item in items {
172 let now = Utc::now().to_rfc3339();
173
174 if let Some(existing) = self.find_active_document(name, &item.uri)? {
175 if existing.hash != item.hash {
176 self.insert_content(&item.hash, &item.content)?;
177 self.update_document(existing.id, &item.title, &item.hash, &now)?;
178 updated += 1;
179 }
180 } else {
181 self.insert_content(&item.hash, &item.content)?;
182 self.insert_document(
183 name,
184 &item.uri,
185 &item.title,
186 &item.hash,
187 &now,
188 &now,
189 &item.source_type,
190 item.metadata.get("source_uri").map(|s| s.as_str()),
191 )?;
192 updated += 1;
193 }
194 }
195
196 self.touch_collection(name)?;
197 Ok(updated)
198 }
199
200 pub async fn generate_or_fetch_metadata(
202 &self,
203 content_hash: &str,
204 content: &str,
205 context: crate::llm::MetadataContext,
206 generator: Option<&dyn crate::llm::MetadataGenerator>,
207 ) -> Result<Option<crate::llm::DocumentMetadata>> {
208 if generator.is_none() {
209 return Ok(None);
210 }
211
212 let cache_key = format!("metadata:v1:{}", content_hash);
213
214 if let Some(cached) = self.get_llm_cache(&cache_key)? {
215 if let Ok(metadata) = serde_json::from_str::<crate::llm::DocumentMetadata>(&cached) {
216 return Ok(Some(metadata));
217 }
218 }
219
220 let gen = generator.unwrap();
221 match gen.generate_metadata(content, &context).await {
222 Ok(metadata) => {
223 let cache_value = serde_json::to_string(&metadata)?;
224 self.set_llm_cache(&cache_key, &cache_value, gen.model_name())?;
225 Ok(Some(metadata))
226 }
227 Err(e) => {
228 eprintln!("Metadata generation failed: {}. Skipping metadata.", e);
229 Ok(None)
230 }
231 }
232 }
233
234 pub fn get_llm_cache_public(&self, key: &str) -> Result<Option<String>> {
236 self.get_llm_cache(key)
237 }
238
239 fn get_llm_cache(&self, key: &str) -> Result<Option<String>> {
241 let result = self.conn.query_row(
242 "SELECT value FROM llm_cache WHERE key = ?1",
243 params![key],
244 |row| row.get(0),
245 );
246
247 match result {
248 Ok(value) => Ok(Some(value)),
249 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
250 Err(e) => Err(e.into()),
251 }
252 }
253
254 fn set_llm_cache(&self, key: &str, value: &str, model: &str) -> Result<()> {
256 let now = Utc::now().to_rfc3339();
257 self.conn.execute(
258 "INSERT OR REPLACE INTO llm_cache (key, value, model, created_at) VALUES (?1, ?2, ?3, ?4)",
259 params![key, value, model, now],
260 )?;
261 Ok(())
262 }
263
264 fn build_metadata_context(
266 &self,
267 item: &crate::providers::SourceItem,
268 collection_name: &str,
269 coll: &CollectionInfo,
270 ) -> crate::llm::MetadataContext {
271 let path = std::path::Path::new(&item.uri);
272 let extension = path
273 .extension()
274 .and_then(|e| e.to_str())
275 .map(|s| s.to_string());
276
277 crate::llm::MetadataContext::new(item.source_type.clone(), collection_name.to_string())
278 .with_extension(extension.unwrap_or_default())
279 .with_provider_config(coll.provider_config.clone().unwrap_or_default())
280 }
281
282 fn insert_document_with_metadata(
284 &self,
285 collection: &str,
286 path: &str,
287 title: &str,
288 hash: &str,
289 created_at: &str,
290 modified_at: &str,
291 source_type: &str,
292 source_uri: Option<&str>,
293 metadata: &crate::llm::DocumentMetadata,
294 model_name: &str,
295 ) -> Result<i64> {
296 let keywords_json = serde_json::to_string(&metadata.keywords)?;
297 let concepts_json = serde_json::to_string(&metadata.concepts)?;
298 let queries_json = serde_json::to_string(&metadata.suggested_queries)?;
299 let now = Utc::now().to_rfc3339();
300
301 let doc = super::documents::DocumentInsert::new(
302 collection,
303 path,
304 title,
305 hash,
306 created_at,
307 modified_at,
308 )
309 .with_source_type(source_type)
310 .with_source_uri(source_uri.unwrap_or(""))
311 .with_llm_metadata_strings(
312 &metadata.summary,
313 &metadata.semantic_title,
314 &keywords_json,
315 &metadata.category,
316 &metadata.intent,
317 &concepts_json,
318 &metadata.difficulty,
319 &queries_json,
320 model_name,
321 &now,
322 );
323
324 self.insert_doc(&doc)
325 }
326
327 fn update_document_with_metadata(
329 &self,
330 id: i64,
331 title: &str,
332 hash: &str,
333 modified_at: &str,
334 metadata: &crate::llm::DocumentMetadata,
335 model_name: &str,
336 ) -> Result<()> {
337 let keywords_json = serde_json::to_string(&metadata.keywords)?;
338 let concepts_json = serde_json::to_string(&metadata.concepts)?;
339 let queries_json = serde_json::to_string(&metadata.suggested_queries)?;
340 let now = Utc::now().to_rfc3339();
341
342 self.conn.execute(
343 "UPDATE documents
344 SET title = ?2, hash = ?3, modified_at = ?4,
345 llm_summary = ?5, llm_title = ?6, llm_keywords = ?7, llm_category = ?8,
346 llm_intent = ?9, llm_concepts = ?10, llm_difficulty = ?11, llm_queries = ?12,
347 llm_metadata_generated_at = ?13, llm_model = ?14
348 WHERE id = ?1",
349 params![
350 id,
351 title,
352 hash,
353 modified_at,
354 metadata.summary,
355 metadata.semantic_title,
356 keywords_json,
357 metadata.category,
358 metadata.intent,
359 concepts_json,
360 metadata.difficulty,
361 queries_json,
362 now,
363 model_name
364 ],
365 )?;
366 Ok(())
367 }
368
369 pub async fn reindex_collection_with_metadata(
371 &self,
372 name: &str,
373 generator: Option<&dyn crate::llm::MetadataGenerator>,
374 ) -> Result<usize> {
375 let coll = self
376 .get_collection(name)?
377 .ok_or_else(|| crate::error::AgentRootError::CollectionNotFound(name.to_string()))?;
378
379 let registry = crate::providers::ProviderRegistry::with_defaults();
380 let provider = registry.get(&coll.provider_type).ok_or_else(|| {
381 crate::error::AgentRootError::InvalidInput(format!(
382 "Unknown provider type: {}",
383 coll.provider_type
384 ))
385 })?;
386
387 let mut config =
388 crate::providers::ProviderConfig::new(coll.path.clone(), coll.pattern.clone());
389
390 if let Some(provider_config) = &coll.provider_config {
391 if let Ok(config_map) =
392 serde_json::from_str::<std::collections::HashMap<String, String>>(provider_config)
393 {
394 for (key, value) in config_map {
395 config = config.with_option(key, value);
396 }
397 }
398 }
399
400 let items = provider.list_items(&config).await?;
401 let mut updated = 0;
402
403 for item in items {
404 let now = Utc::now().to_rfc3339();
405
406 if let Some(existing) = self.find_active_document(name, &item.uri)? {
407 if existing.hash != item.hash {
408 self.insert_content(&item.hash, &item.content)?;
409
410 let metadata_opt = if generator.is_some() {
411 let context = self.build_metadata_context(&item, name, &coll);
412 self.generate_or_fetch_metadata(
413 &item.hash,
414 &item.content,
415 context,
416 generator,
417 )
418 .await?
419 } else {
420 None
421 };
422
423 if let Some(metadata) = metadata_opt {
424 self.update_document_with_metadata(
425 existing.id,
426 &item.title,
427 &item.hash,
428 &now,
429 &metadata,
430 generator.unwrap().model_name(),
431 )?;
432 } else {
433 self.update_document(existing.id, &item.title, &item.hash, &now)?;
434 }
435 updated += 1;
436 }
437 } else {
438 self.insert_content(&item.hash, &item.content)?;
439
440 let metadata_opt = if generator.is_some() {
441 let context = self.build_metadata_context(&item, name, &coll);
442 self.generate_or_fetch_metadata(&item.hash, &item.content, context, generator)
443 .await?
444 } else {
445 None
446 };
447
448 if let Some(metadata) = metadata_opt {
449 self.insert_document_with_metadata(
450 name,
451 &item.uri,
452 &item.title,
453 &item.hash,
454 &now,
455 &now,
456 &item.source_type,
457 item.metadata.get("source_uri").map(|s| s.as_str()),
458 &metadata,
459 generator.unwrap().model_name(),
460 )?;
461 } else {
462 self.insert_document(
463 name,
464 &item.uri,
465 &item.title,
466 &item.hash,
467 &now,
468 &now,
469 &item.source_type,
470 item.metadata.get("source_uri").map(|s| s.as_str()),
471 )?;
472 }
473 updated += 1;
474 }
475 }
476
477 self.touch_collection(name)?;
478 Ok(updated)
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens an in-memory database with the schema already created.
    fn fresh_db() -> Database {
        let db = Database::open_in_memory().unwrap();
        db.initialize().unwrap();
        db
    }

    /// Runs a single-row, single-column query and returns column 0.
    fn scalar<T: rusqlite::types::FromSql>(
        db: &Database,
        sql: &str,
        args: impl rusqlite::Params,
    ) -> T {
        db.conn.query_row(sql, args, |row| row.get(0)).unwrap()
    }

    /// Finds a collection by name in a listing, panicking if it is missing.
    fn find_collection<'a>(all: &'a [CollectionInfo], wanted: &str) -> &'a CollectionInfo {
        all.iter().find(|c| c.name == wanted).unwrap()
    }

    /// Asserts the stored `source_type` and `source_uri` of document `doc_id`.
    fn assert_source(db: &Database, doc_id: i64, want_type: &str, want_uri: &str) {
        let stored_type: String = scalar(
            db,
            "SELECT source_type FROM documents WHERE id = ?1",
            [doc_id],
        );
        assert_eq!(stored_type, want_type);

        let stored_uri: Option<String> = scalar(
            db,
            "SELECT source_uri FROM documents WHERE id = ?1",
            [doc_id],
        );
        assert_eq!(stored_uri, Some(want_uri.to_string()));
    }

    #[test]
    fn test_database_stores_provider_info_correctly() {
        let db = fresh_db();
        let file_cfg = r#"{"exclude_hidden":"false"}"#;

        db.add_collection("test_file", "/tmp/test", "**/*.md", "file", Some(file_cfg))
            .unwrap();
        db.add_collection(
            "test_github",
            "https://github.com/test/repo",
            "**/*.md",
            "github",
            None,
        )
        .unwrap();

        // Raw column values.
        let kind: String = scalar(
            &db,
            "SELECT provider_type FROM collections WHERE name = 'test_file'",
            [],
        );
        assert_eq!(kind, "file");

        let cfg: Option<String> = scalar(
            &db,
            "SELECT provider_config FROM collections WHERE name = 'test_file'",
            [],
        );
        assert_eq!(cfg, Some(file_cfg.to_string()));

        let kind: String = scalar(
            &db,
            "SELECT provider_type FROM collections WHERE name = 'test_github'",
            [],
        );
        assert_eq!(kind, "github");

        let cfg: Option<String> = scalar(
            &db,
            "SELECT provider_config FROM collections WHERE name = 'test_github'",
            [],
        );
        assert_eq!(cfg, None);

        // The same values through the public listing API.
        let listing = db.list_collections().unwrap();
        assert_eq!(listing.len(), 2);

        let file_coll = find_collection(&listing, "test_file");
        assert_eq!(file_coll.provider_type, "file");
        assert_eq!(file_coll.provider_config.as_deref(), Some(file_cfg));

        let github_coll = find_collection(&listing, "test_github");
        assert_eq!(github_coll.provider_type, "github");
        assert_eq!(github_coll.provider_config, None);
    }

    #[test]
    fn test_documents_store_source_metadata() {
        use crate::db::hash_content;
        use chrono::Utc;

        let db = fresh_db();
        db.add_collection("test", "/tmp", "**/*.md", "file", None)
            .unwrap();

        let body = "# Test Document";
        let digest = hash_content(body);
        let stamp = Utc::now().to_rfc3339();

        db.insert_content(&digest, body).unwrap();
        let first = db
            .insert_document(
                "test",
                "doc1.md",
                "Test Document",
                &digest,
                &stamp,
                &stamp,
                "file",
                Some("/tmp/doc1.md"),
            )
            .unwrap();
        assert!(first > 0);
        assert_source(&db, first, "file", "/tmp/doc1.md");

        db.insert_content(&digest, body).unwrap();
        let second = db
            .insert_document(
                "test",
                "doc2.md",
                "Test Document 2",
                &digest,
                &stamp,
                &stamp,
                "github",
                Some("https://github.com/test/repo/doc2.md"),
            )
            .unwrap();
        assert_source(
            &db,
            second,
            "github",
            "https://github.com/test/repo/doc2.md",
        );
    }

    #[tokio::test]
    async fn test_reindex_collection_uses_provider_system() {
        use std::fs;
        use tempfile::TempDir;

        let scratch = TempDir::new().unwrap();
        let dir = scratch.path();
        fs::write(dir.join("doc1.md"), "# Document 1\nInitial content").unwrap();
        fs::write(dir.join("doc2.md"), "# Document 2\nInitial content").unwrap();

        let db = fresh_db();
        db.add_collection(
            "test",
            &dir.to_string_lossy(),
            "**/*.md",
            "file",
            Some(r#"{"exclude_hidden":"false"}"#),
        )
        .unwrap();

        // First pass indexes every file.
        let first_pass = db.reindex_collection("test").await.unwrap();
        assert_eq!(first_pass, 2, "Should index 2 files on first run");
        assert_eq!(db.list_collections().unwrap()[0].document_count, 2);

        let active: i64 = scalar(
            &db,
            "SELECT COUNT(*) FROM documents WHERE collection = 'test' AND active = 1",
            [],
        );
        assert_eq!(active, 2);

        let sources: Vec<(String, String)> = {
            let mut stmt = db
                .conn
                .prepare(
                    "SELECT path, source_type FROM documents WHERE collection = 'test' ORDER BY path",
                )
                .unwrap();
            let rows = stmt
                .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))
                .unwrap()
                .collect::<std::result::Result<Vec<_>, _>>()
                .unwrap();
            rows
        };
        let expected = vec![
            ("doc1.md".to_string(), "file".to_string()),
            ("doc2.md".to_string(), "file".to_string()),
        ];
        assert_eq!(sources, expected);

        // Changing one file re-indexes only that file.
        fs::write(dir.join("doc1.md"), "# Document 1\nUpdated content").unwrap();
        let second_pass = db.reindex_collection("test").await.unwrap();
        assert_eq!(second_pass, 1, "Should update only changed file");
        assert_eq!(
            db.list_collections().unwrap()[0].document_count,
            2,
            "Should still have 2 documents"
        );

        // A brand-new file is picked up.
        fs::write(dir.join("doc3.md"), "# Document 3\nNew content").unwrap();
        let third_pass = db.reindex_collection("test").await.unwrap();
        assert_eq!(third_pass, 1, "Should add new file");
        assert_eq!(
            db.list_collections().unwrap()[0].document_count,
            3,
            "Should now have 3 documents"
        );
    }

    #[tokio::test]
    async fn test_reindex_invalid_provider_type() {
        let db = fresh_db();
        db.add_collection("test", "/tmp", "**/*.md", "nonexistent_provider", None)
            .unwrap();

        let outcome = db.reindex_collection("test").await;
        assert!(outcome.is_err(), "Should error on invalid provider type");

        if let Err(crate::error::AgentRootError::InvalidInput(msg)) = outcome {
            assert!(msg.contains("Unknown provider type"));
            assert!(msg.contains("nonexistent_provider"));
        } else {
            panic!("Expected InvalidInput error");
        }
    }

    #[tokio::test]
    async fn test_reindex_nonexistent_collection() {
        let db = fresh_db();

        let outcome = db.reindex_collection("nonexistent").await;
        assert!(outcome.is_err(), "Should error on nonexistent collection");

        if let Err(crate::error::AgentRootError::CollectionNotFound(missing)) = outcome {
            assert_eq!(missing, "nonexistent");
        } else {
            panic!("Expected CollectionNotFound error");
        }
    }

    #[test]
    fn test_add_collection_duplicate_name() {
        let db = fresh_db();
        db.add_collection("test", "/tmp1", "**/*.md", "file", None)
            .unwrap();

        let duplicate = db.add_collection("test", "/tmp2", "**/*.md", "file", None);
        assert!(duplicate.is_err(), "Should error on duplicate collection name");
    }

    #[tokio::test]
    async fn test_reindex_with_malformed_provider_config() {
        use std::fs;
        use tempfile::TempDir;

        let scratch = TempDir::new().unwrap();
        let dir = scratch.path();
        fs::write(dir.join("test.md"), "# Test").unwrap();

        let db = fresh_db();
        db.add_collection(
            "test",
            &dir.to_string_lossy(),
            "**/*.md",
            "file",
            Some("malformed json that won't parse"),
        )
        .unwrap();

        let outcome = db.reindex_collection("test").await;
        assert!(
            outcome.is_ok(),
            "Should succeed despite malformed JSON config (uses defaults)"
        );
    }
}