1use std::collections::HashMap;
7use serde::{Deserialize, Serialize};
8use chrono::Utc;
9use anyhow::Result;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ToolCallRecord {
14 pub tool_name: String,
15 pub timestamp: String,
16 pub summary: String,
17}
18
19#[derive(Debug, Clone, Default, Serialize, Deserialize)]
21pub struct KnowledgeNode {
22 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
24 pub findings: HashMap<String, String>,
25 #[serde(default, skip_serializing_if = "HashMap::is_empty")]
27 pub file_cache: HashMap<String, String>,
28 #[serde(default, skip_serializing_if = "Vec::is_empty")]
30 pub tool_history: Vec<ToolCallRecord>,
31}
32
33impl KnowledgeNode {
34 pub fn is_empty(&self) -> bool {
36 self.findings.is_empty() && self.file_cache.is_empty() && self.tool_history.is_empty()
37 }
38}
39
40pub trait KnowledgeGraph {
42 fn get_knowledge_mut(&mut self, node_id: &str) -> Option<&mut KnowledgeNode>;
44
45 fn get_knowledge(&self, node_id: &str) -> Option<&KnowledgeNode>;
47
48 fn get_incoming_edges(&self, node_id: &str) -> Vec<String>;
50}
51
52pub trait KnowledgeManagement: KnowledgeGraph {
54 fn store_finding(&mut self, node_id: &str, key: &str, value: &str) -> Result<()> {
56 let node = self.get_knowledge_mut(node_id)
57 .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
58 node.findings.insert(key.to_string(), value.to_string());
59 Ok(())
60 }
61
62 fn get_finding(&self, node_id: &str, key: &str) -> Option<String> {
64 if let Some(node) = self.get_knowledge(node_id) {
66 if let Some(value) = node.findings.get(key) {
67 return Some(value.clone());
68 }
69 }
70
71 self.get_upstream_findings(node_id, key)
73 }
74
75 fn get_upstream_findings(&self, node_id: &str, key: &str) -> Option<String> {
77 for upstream_id in self.get_incoming_edges(node_id) {
78 if let Some(node) = self.get_knowledge(&upstream_id) {
79 if let Some(value) = node.findings.get(key) {
80 return Some(value.clone());
81 }
82 if let Some(value) = self.get_upstream_findings(&upstream_id, key) {
84 return Some(value);
85 }
86 }
87 }
88 None
89 }
90
91 fn cache_file(&mut self, node_id: &str, path: &str, content: &str) -> Result<()> {
93 let node = self.get_knowledge_mut(node_id)
94 .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
95 node.file_cache.insert(path.to_string(), content.to_string());
96 Ok(())
97 }
98
99 fn get_cached_file(&self, node_id: &str, path: &str) -> Option<String> {
101 if let Some(node) = self.get_knowledge(node_id) {
103 if let Some(content) = node.file_cache.get(path) {
104 return Some(content.clone());
105 }
106 }
107
108 for upstream_id in self.get_incoming_edges(node_id) {
110 if let Some(content) = self.get_cached_file(&upstream_id, path) {
111 return Some(content);
112 }
113 }
114 None
115 }
116
117 fn record_tool_call(&mut self, node_id: &str, tool_name: &str, summary: &str) -> Result<()> {
119 let node = self.get_knowledge_mut(node_id)
120 .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
121 node.tool_history.push(ToolCallRecord {
122 tool_name: tool_name.to_string(),
123 timestamp: Utc::now().to_rfc3339(),
124 summary: summary.to_string(),
125 });
126 Ok(())
127 }
128
129 fn get_tool_history(&self, node_id: &str) -> Vec<ToolCallRecord> {
131 let mut history = Vec::new();
132
133 for upstream_id in self.get_incoming_edges(node_id) {
135 history.extend(self.get_tool_history(&upstream_id));
136 }
137
138 if let Some(node) = self.get_knowledge(node_id) {
140 history.extend(node.tool_history.clone());
141 }
142
143 history
144 }
145
146 fn get_knowledge_context(&self, node_id: &str) -> String {
148 let mut context = String::new();
149
150 let mut all_findings = HashMap::new();
152 self.collect_upstream_findings_all(node_id, &mut all_findings);
153
154 if !all_findings.is_empty() {
155 context.push_str("**Knowledge from previous tasks:**\n");
156 for (key, value) in &all_findings {
157 context.push_str(&format!("- {}: {}\n", key, value));
158 }
159 context.push('\n');
160 }
161
162 let tool_history = self.get_tool_history(node_id);
164 if !tool_history.is_empty() {
165 context.push_str("**Previously accessed files:**\n");
166 let mut seen = std::collections::HashSet::new();
167 for record in &tool_history {
168 if record.tool_name == "view_file" && seen.insert(&record.summary) {
169 context.push_str(&format!("- {}\n", record.summary));
170 }
171 }
172 }
173
174 context
175 }
176
177 fn collect_upstream_findings_all(&self, node_id: &str, findings: &mut HashMap<String, String>) {
179 for upstream_id in self.get_incoming_edges(node_id) {
181 self.collect_upstream_findings_all(&upstream_id, findings);
182 }
183
184 if let Some(node) = self.get_knowledge(node_id) {
186 findings.extend(node.findings.clone());
187 }
188 }
189}
190
191#[derive(Debug, Clone, Default)]
193pub struct SimpleKnowledgeGraph {
194 pub nodes: HashMap<String, KnowledgeNode>,
195 pub edges: Vec<(String, String)>, }
197
198impl SimpleKnowledgeGraph {
199 pub fn new() -> Self {
200 Self::default()
201 }
202
203 pub fn add_node(&mut self, node_id: &str) {
204 self.nodes.insert(node_id.to_string(), KnowledgeNode::default());
205 }
206
207 pub fn add_edge(&mut self, from: &str, to: &str) {
208 self.edges.push((from.to_string(), to.to_string()));
209 }
210}
211
212impl KnowledgeGraph for SimpleKnowledgeGraph {
213 fn get_knowledge_mut(&mut self, node_id: &str) -> Option<&mut KnowledgeNode> {
214 self.nodes.get_mut(node_id)
215 }
216
217 fn get_knowledge(&self, node_id: &str) -> Option<&KnowledgeNode> {
218 self.nodes.get(node_id)
219 }
220
221 fn get_incoming_edges(&self, node_id: &str) -> Vec<String> {
222 self.edges.iter()
223 .filter(|(_, to)| to == node_id)
224 .map(|(from, _)| from.clone())
225 .collect()
226 }
227}
228
229impl KnowledgeManagement for SimpleKnowledgeGraph {}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_store_and_get_finding() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");

        graph.store_finding("node1", "key1", "value1").unwrap();
        assert_eq!(graph.get_finding("node1", "key1"), Some("value1".to_string()));
    }

    #[test]
    fn test_upstream_finding_lookup() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");

        graph.store_finding("parent", "shared_key", "parent_value").unwrap();

        assert_eq!(graph.get_finding("child", "shared_key"), Some("parent_value".to_string()));
    }

    #[test]
    fn test_local_finding_shadows_upstream() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");

        graph.store_finding("parent", "k", "parent_value").unwrap();
        graph.store_finding("child", "k", "child_value").unwrap();

        assert_eq!(graph.get_finding("child", "k"), Some("child_value".to_string()));
        assert_eq!(graph.get_finding("parent", "k"), Some("parent_value".to_string()));
    }

    #[test]
    fn test_missing_node_is_an_error() {
        let mut graph = SimpleKnowledgeGraph::new();

        assert!(graph.store_finding("ghost", "k", "v").is_err());
        assert!(graph.cache_file("ghost", "a.py", "x").is_err());
        assert!(graph.record_tool_call("ghost", "view_file", "a.py").is_err());
        assert_eq!(graph.get_finding("ghost", "k"), None);
        assert_eq!(graph.get_cached_file("ghost", "a.py"), None);
        assert!(graph.get_tool_history("ghost").is_empty());
    }

    #[test]
    fn test_file_cache() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");

        graph.cache_file("node1", "path/to/file.py", "file content").unwrap();
        assert_eq!(graph.get_cached_file("node1", "path/to/file.py"), Some("file content".to_string()));
    }

    #[test]
    fn test_upstream_file_cache_lookup() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");

        graph.cache_file("parent", "lib.rs", "parent content").unwrap();

        assert_eq!(graph.get_cached_file("child", "lib.rs"), Some("parent content".to_string()));
        assert_eq!(graph.get_cached_file("child", "other.rs"), None);
    }

    #[test]
    fn test_tool_history() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");

        graph.record_tool_call("node1", "view_file", "foo.py").unwrap();
        graph.record_tool_call("node1", "edit_file", "bar.py").unwrap();

        let history = graph.get_tool_history("node1");
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].tool_name, "view_file");
        assert_eq!(history[1].tool_name, "edit_file");
    }

    #[test]
    fn test_upstream_tool_history_comes_first() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");

        graph.record_tool_call("parent", "view_file", "p.py").unwrap();
        graph.record_tool_call("child", "edit_file", "c.py").unwrap();

        let history = graph.get_tool_history("child");
        let summaries: Vec<&str> = history.iter().map(|r| r.summary.as_str()).collect();
        assert_eq!(summaries, vec!["p.py", "c.py"]);
    }

    #[test]
    fn test_knowledge_context() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");

        graph.store_finding("parent", "lang", "rust").unwrap();
        graph.record_tool_call("child", "view_file", "a.rs").unwrap();
        graph.record_tool_call("child", "view_file", "a.rs").unwrap();
        graph.record_tool_call("child", "edit_file", "b.rs").unwrap();

        assert_eq!(
            graph.get_knowledge_context("child"),
            "**Knowledge from previous tasks:**\n- lang: rust\n\n**Previously accessed files:**\n- a.rs\n"
        );
        assert_eq!(graph.get_knowledge_context("ghost"), "");
    }
}