Skip to main content

roboticus_agent/
knowledge.rs

1use async_trait::async_trait;
2use roboticus_core::{Result, RoboticusError};
3use serde::{Deserialize, Serialize};
4use std::path::{Path, PathBuf};
5
6/// A chunk of knowledge retrieved from a source.
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct KnowledgeChunk {
9    pub content: String,
10    pub source: String,
11    pub relevance: f64,
12    pub metadata: Option<serde_json::Value>,
13}
14
15/// Configuration for a knowledge source.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct KnowledgeSourceConfig {
18    pub name: String,
19    pub source_type: String,
20    pub path: Option<PathBuf>,
21    pub url: Option<String>,
22    pub max_chunks: usize,
23}
24
25/// Trait for external knowledge sources the agent can query.
26#[async_trait]
27pub trait KnowledgeSource: Send + Sync {
28    fn name(&self) -> &str;
29    fn source_type(&self) -> &str;
30    async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>>;
31    async fn ingest(&self, content: &str, source: &str) -> Result<()>;
32    fn is_available(&self) -> bool;
33}
34
/// A knowledge source backed by a local directory of files.
///
/// Only files whose extension is in the configured list are considered, and only the root
/// directory plus its immediate subdirectories are scanned (see [`DirectorySource::scan_files`]).
pub struct DirectorySource {
    name: String,
    root: PathBuf,
    extensions: Vec<String>,
}

impl DirectorySource {
    /// Number of subdirectory levels below `root` that [`Self::scan_files`] visits.
    ///
    /// `1` means the root's own files plus the files directly inside its subdirectories;
    /// the bound keeps scans of large trees predictable.
    const MAX_SCAN_DEPTH: usize = 1;

    /// Creates a source named `name` over `root` that accepts a default set of text and
    /// source-code extensions: `md`, `txt`, `rs`, `py`, `js`, `ts`, `toml`, `yaml`, `json`.
    pub fn new(name: &str, root: PathBuf) -> Self {
        Self {
            name: name.to_string(),
            root,
            extensions: vec![
                "md".into(),
                "txt".into(),
                "rs".into(),
                "py".into(),
                "js".into(),
                "ts".into(),
                "toml".into(),
                "yaml".into(),
                "json".into(),
            ],
        }
    }

    /// Replaces the accepted extensions.
    ///
    /// Entries are matched ASCII case-insensitively and may be written with or without a
    /// leading dot, so `"md"`, `".md"` and `"MD"` are equivalent.
    #[must_use]
    pub fn with_extensions(mut self, exts: Vec<String>) -> Self {
        self.extensions = exts;
        self
    }

    /// Returns `true` if `path` has an extension in the accepted list.
    ///
    /// Comparison ignores ASCII case so files such as `README.MD` are not silently skipped.
    fn is_supported_extension(&self, path: &Path) -> bool {
        let Some(ext) = path.extension().and_then(|e| e.to_str()) else {
            return false;
        };
        self.extensions
            .iter()
            .any(|allowed| allowed.trim_start_matches('.').eq_ignore_ascii_case(ext))
    }

    /// Lists the supported files under the root directory, sorted by path.
    ///
    /// Files directly in the root and in its immediate subdirectories are included; deeper
    /// levels are not visited. Unreadable directories or entries are skipped, and a missing
    /// root yields an empty list. Sorting makes the output independent of the filesystem's
    /// directory-listing order, so repeated scans are deterministic.
    pub fn scan_files(&self) -> Vec<PathBuf> {
        let mut files = Vec::new();
        self.collect_files(&self.root, Self::MAX_SCAN_DEPTH, &mut files);
        files.sort();
        files
    }

    /// Appends the supported files found in `dir` to `out`, descending into at most
    /// `depth_left` further levels of subdirectories.
    fn collect_files(&self, dir: &Path, depth_left: usize, out: &mut Vec<PathBuf>) {
        let Ok(entries) = std::fs::read_dir(dir) else {
            return;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_file() {
                if self.is_supported_extension(&path) {
                    out.push(path);
                }
            } else if depth_left > 0 && path.is_dir() {
                self.collect_files(&path, depth_left - 1, out);
            }
        }
    }
}
97
98#[async_trait]
99impl KnowledgeSource for DirectorySource {
100    fn name(&self) -> &str {
101        &self.name
102    }
103
104    fn source_type(&self) -> &str {
105        "directory"
106    }
107
108    async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>> {
109        let query_lower = query.to_lowercase();
110        let files = self.scan_files();
111
112        let chunks = tokio::task::spawn_blocking(move || {
113            let mut chunks = Vec::new();
114            for path in files {
115                // Cap file reads at 10 MB to prevent OOM on huge files.
116                const MAX_FILE_BYTES: u64 = 10 * 1024 * 1024;
117                if let Ok(content) = (|| -> std::io::Result<String> {
118                    use std::io::Read;
119                    let file = std::fs::File::open(&path)?;
120                    let meta = file.metadata()?;
121                    if meta.len() > MAX_FILE_BYTES {
122                        return Err(std::io::Error::other("file too large for knowledge query"));
123                    }
124                    let mut buf = String::new();
125                    file.take(MAX_FILE_BYTES).read_to_string(&mut buf)?;
126                    Ok(buf)
127                })() {
128                    let content_lower = content.to_lowercase();
129                    if content_lower.contains(&query_lower) {
130                        let relevance = content_lower.matches(&query_lower).count() as f64
131                            / content.len().max(1) as f64;
132                        chunks.push(KnowledgeChunk {
133                            content: truncate(&content, 2000),
134                            source: path.display().to_string(),
135                            relevance,
136                            metadata: Some(serde_json::json!({
137                                "file_size": content.len(),
138                                "path": path.display().to_string(),
139                            })),
140                        });
141                    }
142                }
143            }
144            chunks.sort_by(|a, b| {
145                b.relevance
146                    .partial_cmp(&a.relevance)
147                    .unwrap_or(std::cmp::Ordering::Equal)
148            });
149            chunks.truncate(max_results);
150            chunks
151        })
152        .await
153        .map_err(|e| RoboticusError::Config(format!("blocking task failed: {e}")))?;
154
155        Ok(chunks)
156    }
157
158    async fn ingest(&self, _content: &str, _source: &str) -> Result<()> {
159        Ok(())
160    }
161
162    fn is_available(&self) -> bool {
163        self.root.exists() && self.root.is_dir()
164    }
165}
166
167/// A knowledge source backed by a Git repository.
168pub struct GitSource {
169    name: String,
170    repo_path: PathBuf,
171    inner: DirectorySource,
172}
173
174impl GitSource {
175    pub fn new(name: &str, repo_path: PathBuf) -> Self {
176        let inner = DirectorySource::new(name, repo_path.clone());
177        Self {
178            name: name.to_string(),
179            repo_path,
180            inner,
181        }
182    }
183
184    /// Check if the path is a Git repository.
185    pub fn is_git_repo(&self) -> bool {
186        self.repo_path.join(".git").exists()
187    }
188}
189
190#[async_trait]
191impl KnowledgeSource for GitSource {
192    fn name(&self) -> &str {
193        &self.name
194    }
195
196    fn source_type(&self) -> &str {
197        "git"
198    }
199
200    async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>> {
201        self.inner.query(query, max_results).await
202    }
203
204    async fn ingest(&self, _content: &str, _source: &str) -> Result<()> {
205        Ok(())
206    }
207
208    fn is_available(&self) -> bool {
209        self.is_git_repo()
210    }
211}
212
213/// A knowledge source backed by a vector database HTTP API.
214pub struct VectorDbSource {
215    name: String,
216    url: String,
217    http: reqwest::Client,
218    api_key: Option<String>,
219}
220
221impl VectorDbSource {
222    pub fn new(name: &str, url: &str) -> Result<Self> {
223        Ok(Self {
224            name: name.to_string(),
225            url: url.to_string(),
226            http: reqwest::Client::builder()
227                .timeout(std::time::Duration::from_secs(30))
228                .build()
229                .map_err(|e| RoboticusError::Config(format!("HTTP client build failed: {e}")))?,
230            api_key: None,
231        })
232    }
233
234    #[must_use]
235    pub fn with_api_key(mut self, key: String) -> Self {
236        self.api_key = Some(key);
237        self
238    }
239}
240
241#[derive(Deserialize)]
242struct VectorQueryResult {
243    #[serde(default)]
244    content: String,
245    #[serde(default)]
246    source: String,
247    #[serde(default)]
248    relevance: f64,
249}
250
251#[async_trait]
252impl KnowledgeSource for VectorDbSource {
253    fn name(&self) -> &str {
254        &self.name
255    }
256
257    fn source_type(&self) -> &str {
258        "vector_db"
259    }
260
261    async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>> {
262        let url = format!("{}/query", self.url);
263        let body = serde_json::json!({
264            "query": query,
265            "top_k": max_results,
266        });
267
268        let mut req = self.http.post(&url).json(&body);
269        if let Some(key) = &self.api_key {
270            req = req.bearer_auth(key);
271        }
272
273        let resp = req
274            .send()
275            .await
276            .map_err(|e| RoboticusError::Network(format!("vector DB query failed: {e}")))?;
277
278        if !resp.status().is_success() {
279            let status = resp.status();
280            let body = resp.text().await.unwrap_or_default();
281            return Err(RoboticusError::Network(format!(
282                "vector DB returned {status}: {body}"
283            )));
284        }
285
286        let results: Vec<VectorQueryResult> = resp
287            .json()
288            .await
289            .map_err(|e| RoboticusError::Network(format!("vector DB response parse error: {e}")))?;
290
291        Ok(results
292            .into_iter()
293            .map(|r| KnowledgeChunk {
294                content: r.content,
295                source: r.source,
296                relevance: r.relevance,
297                metadata: None,
298            })
299            .collect())
300    }
301
302    async fn ingest(&self, content: &str, source: &str) -> Result<()> {
303        let url = format!("{}/upsert", self.url);
304        let body = serde_json::json!({
305            "documents": [{
306                "content": content,
307                "source": source,
308            }],
309        });
310
311        let mut req = self.http.post(&url).json(&body);
312        if let Some(key) = &self.api_key {
313            req = req.bearer_auth(key);
314        }
315
316        let resp = req
317            .send()
318            .await
319            .map_err(|e| RoboticusError::Network(format!("vector DB ingest failed: {e}")))?;
320
321        if !resp.status().is_success() {
322            let status = resp.status();
323            let body = resp.text().await.unwrap_or_default();
324            return Err(RoboticusError::Network(format!(
325                "vector DB ingest returned {status}: {body}"
326            )));
327        }
328
329        Ok(())
330    }
331
332    fn is_available(&self) -> bool {
333        !self.url.is_empty()
334    }
335}
336
337/// A knowledge source backed by a Neo4j graph database.
338pub struct GraphSource {
339    name: String,
340    url: String,
341    http: reqwest::Client,
342    api_key: Option<String>,
343}
344
345impl GraphSource {
346    pub fn new(name: &str, url: &str) -> Result<Self> {
347        Ok(Self {
348            name: name.to_string(),
349            url: url.to_string(),
350            http: reqwest::Client::builder()
351                .timeout(std::time::Duration::from_secs(30))
352                .build()
353                .map_err(|e| RoboticusError::Config(format!("HTTP client build failed: {e}")))?,
354            api_key: None,
355        })
356    }
357
358    #[must_use]
359    pub fn with_api_key(mut self, key: String) -> Self {
360        self.api_key = Some(key);
361        self
362    }
363}
364
365#[async_trait]
366impl KnowledgeSource for GraphSource {
367    fn name(&self) -> &str {
368        &self.name
369    }
370
371    fn source_type(&self) -> &str {
372        "graph"
373    }
374
375    async fn query(&self, query: &str, max_results: usize) -> Result<Vec<KnowledgeChunk>> {
376        let url = format!("{}/db/neo4j/tx/commit", self.url);
377        let cypher = "MATCH (n) WHERE n.content CONTAINS $query RETURN n.content AS content, \
378             n.source AS source, 1.0 AS relevance LIMIT $limit"
379            .to_string();
380        let body = serde_json::json!({
381            "statements": [{
382                "statement": cypher,
383                "parameters": {
384                    "query": query,
385                    "limit": max_results,
386                },
387            }],
388        });
389
390        let mut req = self.http.post(&url).json(&body);
391        if let Some(key) = &self.api_key {
392            req = req.bearer_auth(key);
393        }
394
395        let resp = req
396            .send()
397            .await
398            .map_err(|e| RoboticusError::Network(format!("graph DB query failed: {e}")))?;
399
400        if !resp.status().is_success() {
401            let status = resp.status();
402            let body = resp.text().await.unwrap_or_default();
403            return Err(RoboticusError::Network(format!(
404                "graph DB returned {status}: {body}"
405            )));
406        }
407
408        let json: serde_json::Value = resp
409            .json()
410            .await
411            .map_err(|e| RoboticusError::Network(format!("graph DB response parse error: {e}")))?;
412
413        let mut chunks = Vec::new();
414        if let Some(results) = json.get("results").and_then(|r| r.as_array()) {
415            for result in results {
416                if let Some(data) = result.get("data").and_then(|d| d.as_array()) {
417                    for row in data {
418                        if let Some(row_vals) = row.get("row").and_then(|r| r.as_array()) {
419                            let content = row_vals
420                                .first()
421                                .and_then(|v| v.as_str())
422                                .unwrap_or_default()
423                                .to_string();
424                            let source = row_vals
425                                .get(1)
426                                .and_then(|v| v.as_str())
427                                .unwrap_or_default()
428                                .to_string();
429                            let relevance = row_vals.get(2).and_then(|v| v.as_f64()).unwrap_or(0.0);
430
431                            chunks.push(KnowledgeChunk {
432                                content,
433                                source,
434                                relevance,
435                                metadata: None,
436                            });
437                        }
438                    }
439                }
440            }
441        }
442
443        Ok(chunks)
444    }
445
446    async fn ingest(&self, content: &str, source: &str) -> Result<()> {
447        let url = format!("{}/db/neo4j/tx/commit", self.url);
448        let body = serde_json::json!({
449            "statements": [{
450                "statement": "MERGE (n:Knowledge {source: $source}) SET n.content = $content",
451                "parameters": {
452                    "content": content,
453                    "source": source,
454                },
455            }],
456        });
457
458        let mut req = self.http.post(&url).json(&body);
459        if let Some(key) = &self.api_key {
460            req = req.bearer_auth(key);
461        }
462
463        let resp = req
464            .send()
465            .await
466            .map_err(|e| RoboticusError::Network(format!("graph DB ingest failed: {e}")))?;
467
468        if !resp.status().is_success() {
469            let status = resp.status();
470            let body = resp.text().await.unwrap_or_default();
471            return Err(RoboticusError::Network(format!(
472                "graph DB ingest returned {status}: {body}"
473            )));
474        }
475
476        Ok(())
477    }
478
479    fn is_available(&self) -> bool {
480        !self.url.is_empty()
481    }
482}
483
484/// Registry of all knowledge sources.
485pub struct KnowledgeRegistry {
486    sources: Vec<Box<dyn KnowledgeSource>>,
487}
488
489impl KnowledgeRegistry {
490    pub fn new() -> Self {
491        Self {
492            sources: Vec::new(),
493        }
494    }
495
496    pub fn add(&mut self, source: Box<dyn KnowledgeSource>) {
497        self.sources.push(source);
498    }
499
500    pub fn list(&self) -> Vec<(&str, &str, bool)> {
501        self.sources
502            .iter()
503            .map(|s| (s.name(), s.source_type(), s.is_available()))
504            .collect()
505    }
506
507    pub async fn query_all(&self, query: &str, max_per_source: usize) -> Vec<KnowledgeChunk> {
508        let mut all_chunks = Vec::new();
509        for source in &self.sources {
510            if source.is_available() {
511                match source.query(query, max_per_source).await {
512                    Ok(chunks) => all_chunks.extend(chunks),
513                    Err(e) => tracing::warn!(
514                        source = %source.name(),
515                        error = %e,
516                        "knowledge query failed"
517                    ),
518                }
519            }
520        }
521        all_chunks.sort_by(|a, b| {
522            b.relevance
523                .partial_cmp(&a.relevance)
524                .unwrap_or(std::cmp::Ordering::Equal)
525        });
526        all_chunks
527    }
528
529    pub fn available_count(&self) -> usize {
530        self.sources.iter().filter(|s| s.is_available()).count()
531    }
532}
533
534impl Default for KnowledgeRegistry {
535    fn default() -> Self {
536        Self::new()
537    }
538}
539
/// Shortens `s` to at most `max` bytes, appending `"..."` when anything was cut off.
///
/// The cut point is moved back to the nearest UTF-8 character boundary so multi-byte
/// characters are never split. Because of the marker, a truncated result can be up to
/// `max + 3` bytes long.
fn truncate(s: &str, max: usize) -> String {
    if s.len() <= max {
        return s.to_owned();
    }
    // `max < s.len()` here, and index 0 is always a boundary, so this loop terminates.
    let mut cut = max;
    while !s.is_char_boundary(cut) {
        cut -= 1;
    }
    let mut out = String::with_capacity(cut + 3);
    out.push_str(&s[..cut]);
    out.push_str("...");
    out
}
548
#[cfg(test)]
mod tests {
    // Unit tests for the local (directory/git) sources, for the constructors and availability
    // checks of the HTTP-backed sources (no network requests are made), for the registry, and
    // for the `truncate` helper.
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    #[test]
    fn directory_source_scan_finds_files() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("readme.md"), "# Hello").unwrap();
        fs::write(dir.path().join("code.rs"), "fn main() {}").unwrap();
        fs::write(dir.path().join("image.png"), "binary").unwrap();

        let source = DirectorySource::new("test", dir.path().to_path_buf());
        let files = source.scan_files();
        assert_eq!(files.len(), 2);
    }

    #[test]
    fn directory_source_not_available_for_missing_dir() {
        let source = DirectorySource::new("test", PathBuf::from("/nonexistent/path"));
        assert!(!source.is_available());
    }

    #[tokio::test]
    async fn directory_source_query_finds_matching_content() {
        let dir = TempDir::new().unwrap();
        fs::write(
            dir.path().join("notes.md"),
            "Rust is a systems programming language",
        )
        .unwrap();
        fs::write(dir.path().join("other.txt"), "Python is interpreted").unwrap();

        let source = DirectorySource::new("test", dir.path().to_path_buf());
        let results = source.query("Rust", 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].content.contains("Rust"));
    }

    #[tokio::test]
    async fn directory_source_query_empty_for_no_match() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("notes.md"), "Hello world").unwrap();

        let source = DirectorySource::new("test", dir.path().to_path_buf());
        let results = source.query("nonexistent_query_term", 10).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn directory_source_query_is_case_insensitive() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("notes.md"), "RUST is fast").unwrap();

        let source = DirectorySource::new("test", dir.path().to_path_buf());
        let results = source.query("rust", 10).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn directory_source_query_caps_and_ranks_results() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("dense.md"), "rust rust").unwrap();
        fs::write(dir.path().join("medium.md"), "rust and some more").unwrap();
        fs::write(
            dir.path().join("sparse.md"),
            "rust appears once in this much longer file",
        )
        .unwrap();

        let source = DirectorySource::new("test", dir.path().to_path_buf());
        let results = source.query("rust", 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].source.ends_with("dense.md"));
        assert!(results[0].relevance >= results[1].relevance);
    }

    #[test]
    fn git_source_detects_repo() {
        let dir = TempDir::new().unwrap();
        fs::create_dir(dir.path().join(".git")).unwrap();

        let source = GitSource::new("test", dir.path().to_path_buf());
        assert!(source.is_git_repo());
        assert!(source.is_available());
    }

    #[test]
    fn git_source_not_repo() {
        let dir = TempDir::new().unwrap();
        let source = GitSource::new("test", dir.path().to_path_buf());
        assert!(!source.is_git_repo());
        assert!(!source.is_available());
    }

    #[tokio::test]
    async fn git_source_query_searches_working_tree() {
        let dir = TempDir::new().unwrap();
        fs::create_dir(dir.path().join(".git")).unwrap();
        fs::write(dir.path().join("lib.rs"), "pub fn knowledge() {}").unwrap();

        let source = GitSource::new("repo", dir.path().to_path_buf());
        let results = source.query("knowledge", 5).await.unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn vector_db_source_available_with_url() {
        let source = VectorDbSource::new("pinecone", "https://pinecone.io").unwrap();
        assert!(source.is_available());
        assert_eq!(source.source_type(), "vector_db");
    }

    #[test]
    fn vector_db_source_not_available_empty_url() {
        let source = VectorDbSource::new("empty", "").unwrap();
        assert!(!source.is_available());
    }

    #[test]
    fn vector_db_source_with_api_key() {
        let source = VectorDbSource::new("pinecone", "https://pinecone.io")
            .unwrap()
            .with_api_key("sk-test".to_string());
        assert!(source.api_key.is_some());
    }

    #[test]
    fn graph_source_available_with_url() {
        let source = GraphSource::new("neo4j", "http://localhost:7474").unwrap();
        assert!(source.is_available());
        assert_eq!(source.source_type(), "graph");
    }

    #[test]
    fn graph_source_not_available_empty_url() {
        let source = GraphSource::new("empty", "").unwrap();
        assert!(!source.is_available());
    }

    #[test]
    fn graph_source_with_api_key() {
        let source = GraphSource::new("neo4j", "http://localhost:7474")
            .unwrap()
            .with_api_key("token".to_string());
        assert!(source.api_key.is_some());
    }

    #[test]
    fn registry_empty() {
        let reg = KnowledgeRegistry::new();
        assert_eq!(reg.available_count(), 0);
        assert!(reg.list().is_empty());
    }

    #[test]
    fn registry_lists_sources() {
        let dir = TempDir::new().unwrap();
        let mut reg = KnowledgeRegistry::new();
        reg.add(Box::new(DirectorySource::new(
            "docs",
            dir.path().to_path_buf(),
        )));
        reg.add(Box::new(
            VectorDbSource::new("pinecone", "https://api.pinecone.io").unwrap(),
        ));

        let list = reg.list();
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].0, "docs");
        assert_eq!(list[1].0, "pinecone");
    }

    #[tokio::test]
    async fn registry_query_all_aggregates() {
        let dir = TempDir::new().unwrap();
        fs::write(dir.path().join("file.md"), "knowledge about Rust").unwrap();

        let mut reg = KnowledgeRegistry::new();
        reg.add(Box::new(DirectorySource::new(
            "docs",
            dir.path().to_path_buf(),
        )));

        let results = reg.query_all("Rust", 5).await;
        assert_eq!(results.len(), 1);
    }

    #[tokio::test]
    async fn registry_skips_unavailable_sources() {
        let mut reg = KnowledgeRegistry::default();
        reg.add(Box::new(DirectorySource::new(
            "missing",
            PathBuf::from("/nonexistent/path"),
        )));

        assert_eq!(reg.available_count(), 0);
        assert!(reg.query_all("anything", 5).await.is_empty());
    }

    #[test]
    fn chunk_serialization() {
        let chunk = KnowledgeChunk {
            content: "test content".into(),
            source: "test.md".into(),
            relevance: 0.95,
            metadata: None,
        };
        let json = serde_json::to_string(&chunk).unwrap();
        let decoded: KnowledgeChunk = serde_json::from_str(&json).unwrap();
        assert_eq!(decoded.content, "test content");
        assert_eq!(decoded.relevance, 0.95);
    }

    #[test]
    fn truncate_respects_char_boundaries() {
        assert_eq!(truncate("short", 10), "short");
        // 'é' occupies bytes 1..3, so a cut at byte 2 must fall back to byte 1.
        assert_eq!(truncate("héllo", 2), "h...");
        assert_eq!(truncate("héllo", 3), "hé...");
    }
}