1use async_trait::async_trait;
2use roboticus_core::{Result, RoboticusError};
3use serde::{Deserialize, Serialize};
4use std::path::{Path, PathBuf};
5
6#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct KnowledgeChunk {
9 pub content: String,
10 pub source: String,
11 pub relevance: f64,
12 pub metadata: Option<serde_json::Value>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct KnowledgeSourceConfig {
18 pub name: String,
19 pub source_type: String,
20 pub path: Option<PathBuf>,
21 pub url: Option<String>,
22 pub max_chunks: usize,
23}
24
25#[async_trait]
27pub trait KnowledgeSource: Send + Sync {
28 fn name(&self) -> &str;
29 fn source_type(&self) -> &str;
30 async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>>;
31 async fn ingest(&self, content: &str, source: &str) -> Result<()>;
32 fn is_available(&self) -> bool;
33}
34
/// Knowledge source backed by text files under a directory on disk.
pub struct DirectorySource {
    /// Name reported through `KnowledgeSource::name`.
    name: String,
    /// Directory whose files are scanned.
    root: PathBuf,
    /// File extensions (without the dot) that are eligible for scanning.
    extensions: Vec<String>,
}
41
42impl DirectorySource {
43 pub fn new(name: &str, root: PathBuf) -> Self {
44 Self {
45 name: name.to_string(),
46 root,
47 extensions: vec![
48 "md".into(),
49 "txt".into(),
50 "rs".into(),
51 "py".into(),
52 "js".into(),
53 "ts".into(),
54 "toml".into(),
55 "yaml".into(),
56 "json".into(),
57 ],
58 }
59 }
60
61 #[must_use]
62 pub fn with_extensions(mut self, exts: Vec<String>) -> Self {
63 self.extensions = exts;
64 self
65 }
66
67 fn is_supported_extension(&self, path: &Path) -> bool {
68 path.extension()
69 .and_then(|e| e.to_str())
70 .map(|e| self.extensions.iter().any(|ext| ext == e))
71 .unwrap_or(false)
72 }
73
74 pub fn scan_files(&self) -> Vec<PathBuf> {
76 let mut files = Vec::new();
77 if let Ok(entries) = std::fs::read_dir(&self.root) {
78 for entry in entries.flatten() {
79 let path = entry.path();
80 if path.is_file() && self.is_supported_extension(&path) {
81 files.push(path);
82 } else if path.is_dir()
83 && let Ok(sub) = std::fs::read_dir(&path)
84 {
85 for sub_entry in sub.flatten() {
86 let sub_path = sub_entry.path();
87 if sub_path.is_file() && self.is_supported_extension(&sub_path) {
88 files.push(sub_path);
89 }
90 }
91 }
92 }
93 }
94 files
95 }
96}
97
98#[async_trait]
99impl KnowledgeSource for DirectorySource {
100 fn name(&self) -> &str {
101 &self.name
102 }
103
104 fn source_type(&self) -> &str {
105 "directory"
106 }
107
108 async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>> {
109 let query_lower = query.to_lowercase();
110 let files = self.scan_files();
111
112 let chunks = tokio::task::spawn_blocking(move || {
113 let mut chunks = Vec::new();
114 for path in files {
115 const MAX_FILE_BYTES: u64 = 10 * 1024 * 1024;
117 if let Ok(content) = (|| -> std::io::Result<String> {
118 use std::io::Read;
119 let file = std::fs::File::open(&path)?;
120 let meta = file.metadata()?;
121 if meta.len() > MAX_FILE_BYTES {
122 return Err(std::io::Error::other("file too large for knowledge query"));
123 }
124 let mut buf = String::new();
125 file.take(MAX_FILE_BYTES).read_to_string(&mut buf)?;
126 Ok(buf)
127 })() {
128 let content_lower = content.to_lowercase();
129 if content_lower.contains(&query_lower) {
130 let relevance = content_lower.matches(&query_lower).count() as f64
131 / content.len().max(1) as f64;
132 chunks.push(KnowledgeChunk {
133 content: truncate(&content, 2000),
134 source: path.display().to_string(),
135 relevance,
136 metadata: Some(serde_json::json!({
137 "file_size": content.len(),
138 "path": path.display().to_string(),
139 })),
140 });
141 }
142 }
143 }
144 chunks.sort_by(|a, b| {
145 b.relevance
146 .partial_cmp(&a.relevance)
147 .unwrap_or(std::cmp::Ordering::Equal)
148 });
149 chunks.truncate(max_results);
150 chunks
151 })
152 .await
153 .map_err(|e| RoboticusError::Config(format!("blocking task failed: {e}")))?;
154
155 Ok(chunks)
156 }
157
158 async fn ingest(&self, _content: &str, _source: &str) -> Result<()> {
159 Ok(())
160 }
161
162 fn is_available(&self) -> bool {
163 self.root.exists() && self.root.is_dir()
164 }
165}
166
167pub struct GitSource {
169 name: String,
170 repo_path: PathBuf,
171 inner: DirectorySource,
172}
173
174impl GitSource {
175 pub fn new(name: &str, repo_path: PathBuf) -> Self {
176 let inner = DirectorySource::new(name, repo_path.clone());
177 Self {
178 name: name.to_string(),
179 repo_path,
180 inner,
181 }
182 }
183
184 pub fn is_git_repo(&self) -> bool {
186 self.repo_path.join(".git").exists()
187 }
188}
189
190#[async_trait]
191impl KnowledgeSource for GitSource {
192 fn name(&self) -> &str {
193 &self.name
194 }
195
196 fn source_type(&self) -> &str {
197 "git"
198 }
199
200 async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>> {
201 self.inner.query(query, max_results).await
202 }
203
204 async fn ingest(&self, _content: &str, _source: &str) -> Result<()> {
205 Ok(())
206 }
207
208 fn is_available(&self) -> bool {
209 self.is_git_repo()
210 }
211}
212
213pub struct VectorDbSource {
215 name: String,
216 url: String,
217 http: reqwest::Client,
218 api_key: Option<String>,
219}
220
221impl VectorDbSource {
222 pub fn new(name: &str, url: &str) -> Result<Self> {
223 Ok(Self {
224 name: name.to_string(),
225 url: url.to_string(),
226 http: reqwest::Client::builder()
227 .timeout(std::time::Duration::from_secs(30))
228 .build()
229 .map_err(|e| RoboticusError::Config(format!("HTTP client build failed: {e}")))?,
230 api_key: None,
231 })
232 }
233
234 #[must_use]
235 pub fn with_api_key(mut self, key: String) -> Self {
236 self.api_key = Some(key);
237 self
238 }
239}
240
241#[derive(Deserialize)]
242struct VectorQueryResult {
243 #[serde(default)]
244 content: String,
245 #[serde(default)]
246 source: String,
247 #[serde(default)]
248 relevance: f64,
249}
250
251#[async_trait]
252impl KnowledgeSource for VectorDbSource {
253 fn name(&self) -> &str {
254 &self.name
255 }
256
257 fn source_type(&self) -> &str {
258 "vector_db"
259 }
260
261 async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>> {
262 let url = format!("{}/query", self.url);
263 let body = serde_json::json!({
264 "query": query,
265 "top_k": max_results,
266 });
267
268 let mut req = self.http.post(&url).json(&body);
269 if let Some(key) = &self.api_key {
270 req = req.bearer_auth(key);
271 }
272
273 let resp = req
274 .send()
275 .await
276 .map_err(|e| RoboticusError::Network(format!("vector DB query failed: {e}")))?;
277
278 if !resp.status().is_success() {
279 let status = resp.status();
280 let body = resp.text().await.unwrap_or_default();
281 return Err(RoboticusError::Network(format!(
282 "vector DB returned {status}: {body}"
283 )));
284 }
285
286 let results: Vec<VectorQueryResult> = resp
287 .json()
288 .await
289 .map_err(|e| RoboticusError::Network(format!("vector DB response parse error: {e}")))?;
290
291 Ok(results
292 .into_iter()
293 .map(|r| KnowledgeChunk {
294 content: r.content,
295 source: r.source,
296 relevance: r.relevance,
297 metadata: None,
298 })
299 .collect())
300 }
301
302 async fn ingest(&self, content: &str, source: &str) -> Result<()> {
303 let url = format!("{}/upsert", self.url);
304 let body = serde_json::json!({
305 "documents": [{
306 "content": content,
307 "source": source,
308 }],
309 });
310
311 let mut req = self.http.post(&url).json(&body);
312 if let Some(key) = &self.api_key {
313 req = req.bearer_auth(key);
314 }
315
316 let resp = req
317 .send()
318 .await
319 .map_err(|e| RoboticusError::Network(format!("vector DB ingest failed: {e}")))?;
320
321 if !resp.status().is_success() {
322 let status = resp.status();
323 let body = resp.text().await.unwrap_or_default();
324 return Err(RoboticusError::Network(format!(
325 "vector DB ingest returned {status}: {body}"
326 )));
327 }
328
329 Ok(())
330 }
331
332 fn is_available(&self) -> bool {
333 !self.url.is_empty()
334 }
335}
336
337pub struct GraphSource {
339 name: String,
340 url: String,
341 http: reqwest::Client,
342 api_key: Option<String>,
343}
344
345impl GraphSource {
346 pub fn new(name: &str, url: &str) -> Result<Self> {
347 Ok(Self {
348 name: name.to_string(),
349 url: url.to_string(),
350 http: reqwest::Client::builder()
351 .timeout(std::time::Duration::from_secs(30))
352 .build()
353 .map_err(|e| RoboticusError::Config(format!("HTTP client build failed: {e}")))?,
354 api_key: None,
355 })
356 }
357
358 #[must_use]
359 pub fn with_api_key(mut self, key: String) -> Self {
360 self.api_key = Some(key);
361 self
362 }
363}
364
365#[async_trait]
366impl KnowledgeSource for GraphSource {
367 fn name(&self) -> &str {
368 &self.name
369 }
370
371 fn source_type(&self) -> &str {
372 "graph"
373 }
374
375 async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>> {
376 let url = format!("{}/db/neo4j/tx/commit", self.url);
377 let cypher = "MATCH (n) WHERE n.content CONTAINS $query RETURN n.content AS content, \
378 n.source AS source, 1.0 AS relevance LIMIT $limit"
379 .to_string();
380 let body = serde_json::json!({
381 "statements": [{
382 "statement": cypher,
383 "parameters": {
384 "query": query,
385 "limit": max_results,
386 },
387 }],
388 });
389
390 let mut req = self.http.post(&url).json(&body);
391 if let Some(key) = &self.api_key {
392 req = req.bearer_auth(key);
393 }
394
395 let resp = req
396 .send()
397 .await
398 .map_err(|e| RoboticusError::Network(format!("graph DB query failed: {e}")))?;
399
400 if !resp.status().is_success() {
401 let status = resp.status();
402 let body = resp.text().await.unwrap_or_default();
403 return Err(RoboticusError::Network(format!(
404 "graph DB returned {status}: {body}"
405 )));
406 }
407
408 let json: serde_json::Value = resp
409 .json()
410 .await
411 .map_err(|e| RoboticusError::Network(format!("graph DB response parse error: {e}")))?;
412
413 let mut chunks = Vec::new();
414 if let Some(results) = json.get("results").and_then(|r| r.as_array()) {
415 for result in results {
416 if let Some(data) = result.get("data").and_then(|d| d.as_array()) {
417 for row in data {
418 if let Some(row_vals) = row.get("row").and_then(|r| r.as_array()) {
419 let content = row_vals
420 .first()
421 .and_then(|v| v.as_str())
422 .unwrap_or_default()
423 .to_string();
424 let source = row_vals
425 .get(1)
426 .and_then(|v| v.as_str())
427 .unwrap_or_default()
428 .to_string();
429 let relevance = row_vals.get(2).and_then(|v| v.as_f64()).unwrap_or(0.0);
430
431 chunks.push(KnowledgeChunk {
432 content,
433 source,
434 relevance,
435 metadata: None,
436 });
437 }
438 }
439 }
440 }
441 }
442
443 Ok(chunks)
444 }
445
446 async fn ingest(&self, content: &str, source: &str) -> Result<()> {
447 let url = format!("{}/db/neo4j/tx/commit", self.url);
448 let body = serde_json::json!({
449 "statements": [{
450 "statement": "MERGE (n:Knowledge {source: $source}) SET n.content = $content",
451 "parameters": {
452 "content": content,
453 "source": source,
454 },
455 }],
456 });
457
458 let mut req = self.http.post(&url).json(&body);
459 if let Some(key) = &self.api_key {
460 req = req.bearer_auth(key);
461 }
462
463 let resp = req
464 .send()
465 .await
466 .map_err(|e| RoboticusError::Network(format!("graph DB ingest failed: {e}")))?;
467
468 if !resp.status().is_success() {
469 let status = resp.status();
470 let body = resp.text().await.unwrap_or_default();
471 return Err(RoboticusError::Network(format!(
472 "graph DB ingest returned {status}: {body}"
473 )));
474 }
475
476 Ok(())
477 }
478
479 fn is_available(&self) -> bool {
480 !self.url.is_empty()
481 }
482}
483
484pub struct KnowledgeRegistry {
486 sources: Vec<Box<dyn KnowledgeSource>>,
487}
488
489impl KnowledgeRegistry {
490 pub fn new() -> Self {
491 Self {
492 sources: Vec::new(),
493 }
494 }
495
496 pub fn add(&mut self, source: Box<dyn KnowledgeSource>) {
497 self.sources.push(source);
498 }
499
500 pub fn list(&self) -> Vec<(&str, &str, bool)> {
501 self.sources
502 .iter()
503 .map(|s| (s.name(), s.source_type(), s.is_available()))
504 .collect()
505 }
506
507 pub async fn query_all(&self, query: &str, max_per_source: usize) -> Vec<KnowledgeChunk> {
508 let mut all_chunks = Vec::new();
509 for source in &self.sources {
510 if source.is_available() {
511 match source.query(query, max_per_source).await {
512 Ok(chunks) => all_chunks.extend(chunks),
513 Err(e) => tracing::warn!(
514 source = %source.name(),
515 error = %e,
516 "knowledge query failed"
517 ),
518 }
519 }
520 }
521 all_chunks.sort_by(|a, b| {
522 b.relevance
523 .partial_cmp(&a.relevance)
524 .unwrap_or(std::cmp::Ordering::Equal)
525 });
526 all_chunks
527 }
528
529 pub fn available_count(&self) -> usize {
530 self.sources.iter().filter(|s| s.is_available()).count()
531 }
532}
533
534impl Default for KnowledgeRegistry {
535 fn default() -> Self {
536 Self::new()
537 }
538}
539
/// Returns `s` unchanged when it is at most `max` bytes long; otherwise
/// returns the longest prefix of at most `max` bytes that ends on a character
/// boundary, followed by `"..."`.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_string();
    }

    // Walk back from `max` to the nearest character boundary so the slice
    // never splits a multi-byte character; index 0 is always a boundary.
    let mut cut = max;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }

    let mut shortened = String::with_capacity(cut + 3);
    shortened.push_str(&s[..cut]);
    shortened.push_str("...");
    shortened
}
548
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates a temporary directory pre-populated with `(file name, contents)`
    /// pairs.
    fn temp_dir_with(files: &[(&str, &str)]) -> TempDir {
        let dir = TempDir::new().unwrap();
        for (file_name, contents) in files {
            fs::write(dir.path().join(file_name), contents).unwrap();
        }
        dir
    }

    // Only files with a supported extension are reported by the scan.
    #[test]
    fn directory_source_scan_finds_files() {
        let workspace = temp_dir_with(&[
            ("readme.md", "# Hello"),
            ("code.rs", "fn main() {}"),
            ("image.png", "binary"),
        ]);

        let docs = DirectorySource::new("test", workspace.path().to_path_buf());
        let found = docs.scan_files();
        assert_eq!(found.len(), 2);
    }

    #[test]
    fn directory_source_not_available_for_missing_dir() {
        let missing = DirectorySource::new("test", PathBuf::from("/nonexistent/path"));
        assert!(!missing.is_available());
    }

    #[tokio::test]
    async fn directory_source_query_finds_matching_content() {
        let workspace = temp_dir_with(&[
            ("notes.md", "Rust is a systems programming language"),
            ("other.txt", "Python is interpreted"),
        ]);

        let docs = DirectorySource::new("test", workspace.path().to_path_buf());
        let hits = docs.query("Rust", 10).await.unwrap();
        assert_eq!(hits.len(), 1);
        assert!(hits[0].content.contains("Rust"));
    }

    #[tokio::test]
    async fn directory_source_query_empty_for_no_match() {
        let workspace = temp_dir_with(&[("notes.md", "Hello world")]);

        let docs = DirectorySource::new("test", workspace.path().to_path_buf());
        let hits = docs.query("nonexistent_query_term", 10).await.unwrap();
        assert!(hits.is_empty());
    }

    // A `.git` directory is what marks the path as a repository.
    #[test]
    fn git_source_detects_repo() {
        let workspace = TempDir::new().unwrap();
        fs::create_dir(workspace.path().join(".git")).unwrap();

        let repo = GitSource::new("test", workspace.path().to_path_buf());
        assert!(repo.is_git_repo());
        assert!(repo.is_available());
    }

    #[test]
    fn git_source_not_repo() {
        let workspace = TempDir::new().unwrap();
        let repo = GitSource::new("test", workspace.path().to_path_buf());
        assert!(!repo.is_git_repo());
        assert!(!repo.is_available());
    }

    #[test]
    fn vector_db_source_available_with_url() {
        let vectors = VectorDbSource::new("pinecone", "https://pinecone.io").unwrap();
        assert!(vectors.is_available());
        assert_eq!(vectors.source_type(), "vector_db");
    }

    #[test]
    fn vector_db_source_not_available_empty_url() {
        let vectors = VectorDbSource::new("empty", "").unwrap();
        assert!(!vectors.is_available());
    }

    #[test]
    fn vector_db_source_with_api_key() {
        let vectors = VectorDbSource::new("pinecone", "https://pinecone.io")
            .unwrap()
            .with_api_key("sk-test".to_string());
        assert!(vectors.api_key.is_some());
    }

    #[test]
    fn graph_source_available_with_url() {
        let graph = GraphSource::new("neo4j", "http://localhost:7474").unwrap();
        assert!(graph.is_available());
        assert_eq!(graph.source_type(), "graph");
    }

    #[test]
    fn graph_source_with_api_key() {
        let graph = GraphSource::new("neo4j", "http://localhost:7474")
            .unwrap()
            .with_api_key("token".to_string());
        assert!(graph.api_key.is_some());
    }

    #[test]
    fn registry_empty() {
        let registry = KnowledgeRegistry::new();
        assert_eq!(registry.available_count(), 0);
        assert!(registry.list().is_empty());
    }

    // Sources are listed in the order they were added.
    #[test]
    fn registry_lists_sources() {
        let workspace = TempDir::new().unwrap();
        let mut registry = KnowledgeRegistry::new();
        registry.add(Box::new(DirectorySource::new(
            "docs",
            workspace.path().to_path_buf(),
        )));
        registry.add(Box::new(
            VectorDbSource::new("pinecone", "https://api.pinecone.io").unwrap(),
        ));

        let listed = registry.list();
        assert_eq!(listed.len(), 2);
        assert_eq!(listed[0].0, "docs");
        assert_eq!(listed[1].0, "pinecone");
    }

    #[tokio::test]
    async fn registry_query_all_aggregates() {
        let workspace = temp_dir_with(&[("file.md", "knowledge about Rust")]);

        let mut registry = KnowledgeRegistry::new();
        registry.add(Box::new(DirectorySource::new(
            "docs",
            workspace.path().to_path_buf(),
        )));

        let hits = registry.query_all("Rust", 5).await;
        assert_eq!(hits.len(), 1);
    }

    // A chunk survives a JSON round trip with its fields intact.
    #[test]
    fn chunk_serialization() {
        let original = KnowledgeChunk {
            content: "test content".into(),
            source: "test.md".into(),
            relevance: 0.95,
            metadata: None,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: KnowledgeChunk = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.content, "test content");
        assert_eq!(decoded.relevance, 0.95);
    }
}