use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use chrono::Utc;
use anyhow::Result;
/// One tool invocation recorded against a node by
/// `KnowledgeManagement::record_tool_call`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallRecord {
    /// Name of the tool that was invoked (e.g. `"view_file"`).
    pub tool_name: String,
    /// When the call was recorded. `record_tool_call` fills this with the
    /// current UTC time formatted as RFC 3339.
    pub timestamp: String,
    /// Free-form description of the call. For `view_file` calls,
    /// `get_knowledge_context` lists this value as an accessed file.
    pub summary: String,
}

/// Knowledge attached to a single graph node: findings, cached file
/// contents and the tool calls made while working on it.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct KnowledgeNode {
    /// Key/value findings written by `store_finding`.
    pub findings: HashMap<String, String>,
    /// File contents keyed by path, written by `cache_file`.
    pub file_cache: HashMap<String, String>,
    /// Tool calls in the order `record_tool_call` appended them.
    pub tool_history: Vec<ToolCallRecord>,
}
/// Storage and topology primitives that `KnowledgeManagement` is built on.
///
/// An implementor exposes the [`KnowledgeNode`] stored for each node id and
/// the ids of each node's direct predecessors.
pub trait KnowledgeGraph {
    /// Shared access to the knowledge of `node_id`.
    ///
    /// `KnowledgeManagement` treats `None` as "node not found".
    fn get_knowledge(&self, node_id: &str) -> Option<&KnowledgeNode>;

    /// Mutable access to the knowledge of `node_id`.
    ///
    /// `KnowledgeManagement` write operations turn `None` into a
    /// `Node not found` error.
    fn get_knowledge_mut(&mut self, node_id: &str) -> Option<&mut KnowledgeNode>;

    /// Ids of the nodes that have an edge pointing into `node_id`, i.e. its
    /// direct upstream nodes.
    fn get_incoming_edges(&self, node_id: &str) -> Vec<String>;
}
/// Knowledge-sharing operations built on top of a [`KnowledgeGraph`].
///
/// Writes (`store_finding`, `cache_file`, `record_tool_call`) touch exactly one
/// node and fail when `get_knowledge_mut` returns `None` for it. Reads also
/// consult the node's *ancestors*: the nodes reachable by repeatedly following
/// `get_incoming_edges`.
///
/// Every traversal remembers which nodes it has already visited. A cyclic edge
/// set therefore terminates, and an ancestor reachable through several paths is
/// consulted (and, for histories, reported) only once.
pub trait KnowledgeManagement: KnowledgeGraph {
    /// Stores `value` under `key` in the findings of `node_id`, replacing any
    /// previous value for that key.
    ///
    /// # Errors
    ///
    /// Returns `Node not found: <node_id>` if the node does not exist.
    fn store_finding(&mut self, node_id: &str, key: &str, value: &str) -> Result<()> {
        let node = self
            .get_knowledge_mut(node_id)
            .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
        node.findings.insert(key.to_string(), value.to_string());
        Ok(())
    }

    /// Returns the finding for `key`.
    ///
    /// The node's own value is preferred. Otherwise the lookup falls back to
    /// [`get_upstream_findings`](Self::get_upstream_findings).
    fn get_finding(&self, node_id: &str, key: &str) -> Option<String> {
        self.get_knowledge(node_id)
            .and_then(|node| node.findings.get(key).cloned())
            .or_else(|| self.get_upstream_findings(node_id, key))
    }

    /// Searches the ancestors of `node_id` (not the node itself) for `key`.
    ///
    /// The search is depth-first in `get_incoming_edges` order: the first
    /// parent and all of its ancestors are tried before the second parent, and
    /// the first match wins. Nodes without a knowledge entry are still walked
    /// through to their own ancestors, like the other lookups in this trait.
    fn get_upstream_findings(&self, node_id: &str, key: &str) -> Option<String> {
        search_lineage(self, self.get_incoming_edges(node_id), |id| {
            self.get_knowledge(id)
                .and_then(|node| node.findings.get(key).cloned())
        })
    }

    /// Caches `content` for `path` on `node_id`, replacing any earlier entry.
    ///
    /// # Errors
    ///
    /// Returns `Node not found: <node_id>` if the node does not exist.
    fn cache_file(&mut self, node_id: &str, path: &str, content: &str) -> Result<()> {
        let node = self
            .get_knowledge_mut(node_id)
            .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
        node.file_cache.insert(path.to_string(), content.to_string());
        Ok(())
    }

    /// Returns the cached contents of `path`.
    ///
    /// `node_id` is checked first, then its ancestors depth-first in
    /// `get_incoming_edges` order. The first hit wins.
    fn get_cached_file(&self, node_id: &str, path: &str) -> Option<String> {
        search_lineage(self, vec![node_id.to_string()], |id| {
            self.get_knowledge(id)
                .and_then(|node| node.file_cache.get(path).cloned())
        })
    }

    /// Appends a [`ToolCallRecord`] for `tool_name` to the history of
    /// `node_id`, timestamped with the current UTC time in RFC 3339 format.
    ///
    /// # Errors
    ///
    /// Returns `Node not found: <node_id>` if the node does not exist.
    fn record_tool_call(&mut self, node_id: &str, tool_name: &str, summary: &str) -> Result<()> {
        let node = self
            .get_knowledge_mut(node_id)
            .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
        node.tool_history.push(ToolCallRecord {
            tool_name: tool_name.to_string(),
            timestamp: Utc::now().to_rfc3339(),
            summary: summary.to_string(),
        });
        Ok(())
    }

    /// Returns the combined tool history of `node_id` and all its ancestors.
    ///
    /// In an acyclic graph, every node's records come after those of its own
    /// ancestors, and `node_id`'s records come last. Each node's records keep
    /// their recording order. An ancestor reachable through several paths
    /// contributes its records once.
    fn get_tool_history(&self, node_id: &str) -> Vec<ToolCallRecord> {
        lineage_ancestors_first(self, node_id)
            .iter()
            .filter_map(|id| self.get_knowledge(id))
            .flat_map(|node| node.tool_history.iter().cloned())
            .collect()
    }

    /// Renders a text summary of what `node_id` can see.
    ///
    /// The summary lists all findings from the node and its ancestors, sorted
    /// by key, followed by the distinct `view_file` summaries from the combined
    /// tool history. Returns an empty string when there is nothing to report.
    fn get_knowledge_context(&self, node_id: &str) -> String {
        let mut context = String::new();

        let mut all_findings = HashMap::new();
        self.collect_upstream_findings_all(node_id, &mut all_findings);
        if !all_findings.is_empty() {
            // HashMap iteration order is unspecified; sort so the text is stable.
            let mut entries: Vec<_> = all_findings.iter().collect();
            entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
            context.push_str("**Knowledge from previous tasks:**\n");
            for (key, value) in entries {
                context.push_str(&format!("- {}: {}\n", key, value));
            }
            context.push('\n');
        }

        let tool_history = self.get_tool_history(node_id);
        let mut seen = std::collections::HashSet::new();
        let viewed_files: Vec<&str> = tool_history
            .iter()
            .filter(|record| record.tool_name == "view_file")
            .map(|record| record.summary.as_str())
            // `insert` returns false for repeats, keeping the first occurrence only.
            .filter(|summary| seen.insert(*summary))
            .collect();
        // Only emit the heading when there is at least one file to list.
        if !viewed_files.is_empty() {
            context.push_str("**Previously accessed files:**\n");
            for file in viewed_files {
                context.push_str(&format!("- {}\n", file));
            }
        }
        context
    }

    /// Merges into `findings` the findings of every ancestor of `node_id` and
    /// of `node_id` itself.
    ///
    /// In an acyclic graph, each node is merged after all of its own
    /// ancestors. A value set closer to `node_id` therefore overrides one
    /// inherited from further up, and `node_id`'s own values win. Between
    /// unrelated branches, the later parent in `get_incoming_edges` order wins.
    fn collect_upstream_findings_all(&self, node_id: &str, findings: &mut HashMap<String, String>) {
        for id in lineage_ancestors_first(self, node_id) {
            if let Some(node) = self.get_knowledge(&id) {
                findings.extend(node.findings.iter().map(|(k, v)| (k.clone(), v.clone())));
            }
        }
    }
}

/// Depth-first, pre-order search over `roots` and their ancestors.
///
/// Roots are tried in the given order. Once a node has been inspected, its
/// incoming edges are explored before the next sibling. Returns the first
/// `Some` that `lookup` yields for a node id. Each id is inspected at most
/// once, so the walk terminates even when the edges form a cycle.
fn search_lineage<G, F>(graph: &G, roots: Vec<String>, mut lookup: F) -> Option<String>
where
    G: KnowledgeGraph + ?Sized,
    F: FnMut(&str) -> Option<String>,
{
    let mut visited = std::collections::HashSet::new();
    // Reversed so the first root ends up on top of the stack.
    let mut stack: Vec<String> = roots.into_iter().rev().collect();
    while let Some(id) = stack.pop() {
        if !visited.insert(id.clone()) {
            continue;
        }
        if let Some(found) = lookup(&id) {
            return Some(found);
        }
        // Reversed so the first parent is explored first.
        stack.extend(graph.get_incoming_edges(&id).into_iter().rev());
    }
    None
}

/// Lists `node_id` and each of its ancestors exactly once.
///
/// The order is a post-order DFS following incoming edges in
/// `get_incoming_edges` order. In an acyclic graph, every node appears after
/// all of its own ancestors. `node_id` itself is always the last entry.
fn lineage_ancestors_first<G>(graph: &G, node_id: &str) -> Vec<String>
where
    G: KnowledgeGraph + ?Sized,
{
    let mut visited = std::collections::HashSet::new();
    let mut order = Vec::new();
    // The flag is `true` for the second push of a node, made after its
    // ancestors were scheduled above it. Popping that entry means the
    // ancestors are finished and the node can be emitted.
    let mut stack = vec![(node_id.to_string(), false)];
    while let Some((id, ancestors_done)) = stack.pop() {
        if ancestors_done {
            order.push(id);
            continue;
        }
        if !visited.insert(id.clone()) {
            continue;
        }
        stack.push((id.clone(), true));
        for upstream_id in graph.get_incoming_edges(&id).into_iter().rev() {
            if !visited.contains(&upstream_id) {
                stack.push((upstream_id, false));
            }
        }
    }
    order
}
/// Minimal in-memory knowledge graph.
///
/// Nodes live in a map keyed by node id, and edges are kept as a flat list of
/// directed pairs.
#[derive(Debug, Clone, Default)]
pub struct SimpleKnowledgeGraph {
    /// Knowledge for each registered node, keyed by node id.
    pub nodes: HashMap<String, KnowledgeNode>,
    /// Directed `(from, to)` pairs; `from` is an upstream node of `to`.
    pub edges: Vec<(String, String)>,
}
impl SimpleKnowledgeGraph {
pub fn new() -> Self {
Self::default()
}
pub fn add_node(&mut self, node_id: &str) {
self.nodes.insert(node_id.to_string(), KnowledgeNode::default());
}
pub fn add_edge(&mut self, from: &str, to: &str) {
self.edges.push((from.to_string(), to.to_string()));
}
}
/// Serves lookups from the `nodes` map and answers topology queries with a
/// scan of `edges`.
impl KnowledgeGraph for SimpleKnowledgeGraph {
    /// Looks the node up in `nodes`; `None` if it was never added.
    fn get_knowledge(&self, node_id: &str) -> Option<&KnowledgeNode> {
        self.nodes.get(node_id)
    }

    /// Mutable counterpart of `get_knowledge`.
    fn get_knowledge_mut(&mut self, node_id: &str) -> Option<&mut KnowledgeNode> {
        self.nodes.get_mut(node_id)
    }

    /// Returns the `from` end of every edge whose `to` end is `node_id`, in
    /// insertion order, with duplicates included.
    fn get_incoming_edges(&self, node_id: &str) -> Vec<String> {
        let mut sources = Vec::new();
        for (source, target) in &self.edges {
            if target == node_id {
                sources.push(source.clone());
            }
        }
        sources
    }
}
/// `SimpleKnowledgeGraph` gets every `KnowledgeManagement` operation from the
/// trait's default methods.
impl KnowledgeManagement for SimpleKnowledgeGraph {
    // Intentionally empty: no default method is overridden.
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_store_and_get_finding() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");
        graph.store_finding("node1", "key1", "value1").unwrap();
        assert_eq!(graph.get_finding("node1", "key1"), Some("value1".to_string()));
    }

    #[test]
    fn test_upstream_finding_lookup() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");
        graph.store_finding("parent", "shared_key", "parent_value").unwrap();
        assert_eq!(graph.get_finding("child", "shared_key"), Some("parent_value".to_string()));
    }

    #[test]
    fn test_file_cache() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");
        graph.cache_file("node1", "path/to/file.py", "file content").unwrap();
        assert_eq!(graph.get_cached_file("node1", "path/to/file.py"), Some("file content".to_string()));
    }

    #[test]
    fn test_tool_history() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");
        graph.record_tool_call("node1", "view_file", "foo.py").unwrap();
        graph.record_tool_call("node1", "edit_file", "bar.py").unwrap();
        let history = graph.get_tool_history("node1");
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].tool_name, "view_file");
        assert_eq!(history[1].tool_name, "edit_file");
    }

    #[test]
    fn test_missing_node_is_an_error() {
        let mut graph = SimpleKnowledgeGraph::new();
        assert!(graph.store_finding("missing", "k", "v").is_err());
        assert!(graph.cache_file("missing", "a.py", "x").is_err());
        assert!(graph.record_tool_call("missing", "view_file", "a.py").is_err());
        assert_eq!(graph.get_finding("missing", "k"), None);
        assert_eq!(graph.get_cached_file("missing", "a.py"), None);
        assert!(graph.get_tool_history("missing").is_empty());
    }

    #[test]
    fn test_local_finding_shadows_upstream() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");
        graph.store_finding("parent", "key", "parent_value").unwrap();
        graph.store_finding("child", "key", "child_value").unwrap();
        graph.store_finding("child", "child_only", "x").unwrap();
        assert_eq!(graph.get_finding("child", "key"), Some("child_value".to_string()));
        assert_eq!(graph.get_finding("parent", "key"), Some("parent_value".to_string()));
        // Lookups only go upstream, never downstream.
        assert_eq!(graph.get_finding("parent", "child_only"), None);
    }

    #[test]
    fn test_upstream_file_cache() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");
        graph.cache_file("parent", "shared.py", "parent content").unwrap();
        graph.cache_file("child", "child.py", "child content").unwrap();
        assert_eq!(graph.get_cached_file("child", "shared.py"), Some("parent content".to_string()));
        assert_eq!(graph.get_cached_file("parent", "child.py"), None);
    }

    #[test]
    fn test_upstream_tool_history_comes_first() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");
        graph.record_tool_call("parent", "view_file", "a.py").unwrap();
        graph.record_tool_call("child", "edit_file", "b.py").unwrap();

        let child_history = graph.get_tool_history("child");
        assert_eq!(child_history.len(), 2);
        assert_eq!(child_history[0].summary, "a.py");
        assert_eq!(child_history[1].summary, "b.py");

        let parent_history = graph.get_tool_history("parent");
        assert_eq!(parent_history.len(), 1);
        assert_eq!(parent_history[0].summary, "a.py");
    }

    #[test]
    fn test_knowledge_context_lists_findings_and_unique_files() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");
        graph.store_finding("parent", "lang", "rust").unwrap();
        graph.record_tool_call("parent", "view_file", "foo.py").unwrap();
        graph.record_tool_call("parent", "view_file", "foo.py").unwrap();
        graph.record_tool_call("child", "edit_file", "bar.py").unwrap();

        let context = graph.get_knowledge_context("child");
        assert!(context.contains("**Knowledge from previous tasks:**\n- lang: rust\n"));
        assert!(context.contains("**Previously accessed files:**\n"));
        assert_eq!(context.matches("- foo.py\n").count(), 1);
        // Only `view_file` calls are listed as accessed files.
        assert!(!context.contains("bar.py"));
    }

    #[test]
    fn test_empty_knowledge_context() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");
        assert_eq!(graph.get_knowledge_context("node1"), "");
    }
}