Skip to main content

gid_core/
task_graph_knowledge.rs

1//! Knowledge management extension for task graphs
2//!
3//! Provides per-node knowledge storage, file caching, and tool call tracking
4//! for building up context during graph-driven exploration.
5
6use std::collections::HashMap;
7use serde::{Deserialize, Serialize};
8use chrono::Utc;
9use anyhow::Result;
10
11/// Record of a tool call made during exploration
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ToolCallRecord {
14    pub tool_name: String,
15    pub timestamp: String,
16    pub summary: String,
17}
18
19/// Task node with knowledge storage capabilities
20#[derive(Debug, Clone, Default, Serialize, Deserialize)]
21pub struct KnowledgeNode {
22    /// Findings attached to this node (key-value pairs)
23    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
24    pub findings: HashMap<String, String>,
25    /// Cached file contents
26    #[serde(default, skip_serializing_if = "HashMap::is_empty")]
27    pub file_cache: HashMap<String, String>,
28    /// History of tool calls made for this node
29    #[serde(default, skip_serializing_if = "Vec::is_empty")]
30    pub tool_history: Vec<ToolCallRecord>,
31}
32
33impl KnowledgeNode {
34    /// Returns true if this knowledge node has no data stored.
35    pub fn is_empty(&self) -> bool {
36        self.findings.is_empty() && self.file_cache.is_empty() && self.tool_history.is_empty()
37    }
38}
39
40/// A graph that supports knowledge management on nodes
41pub trait KnowledgeGraph {
42    /// Get mutable access to a node's knowledge storage
43    fn get_knowledge_mut(&mut self, node_id: &str) -> Option<&mut KnowledgeNode>;
44    
45    /// Get read access to a node's knowledge storage
46    fn get_knowledge(&self, node_id: &str) -> Option<&KnowledgeNode>;
47    
48    /// Get edges pointing to a node (for upstream lookups)
49    fn get_incoming_edges(&self, node_id: &str) -> Vec<String>;
50}
51
52/// Knowledge management functions
53pub trait KnowledgeManagement: KnowledgeGraph {
54    /// Store a finding in a node
55    fn store_finding(&mut self, node_id: &str, key: &str, value: &str) -> Result<()> {
56        let node = self.get_knowledge_mut(node_id)
57            .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
58        node.findings.insert(key.to_string(), value.to_string());
59        Ok(())
60    }
61    
62    /// Get a finding from a node or any upstream node
63    fn get_finding(&self, node_id: &str, key: &str) -> Option<String> {
64        // First check current node
65        if let Some(node) = self.get_knowledge(node_id) {
66            if let Some(value) = node.findings.get(key) {
67                return Some(value.clone());
68            }
69        }
70        
71        // Check upstream nodes (dependencies)
72        self.get_upstream_findings(node_id, key)
73    }
74    
75    /// Get finding from upstream nodes recursively
76    fn get_upstream_findings(&self, node_id: &str, key: &str) -> Option<String> {
77        for upstream_id in self.get_incoming_edges(node_id) {
78            if let Some(node) = self.get_knowledge(&upstream_id) {
79                if let Some(value) = node.findings.get(key) {
80                    return Some(value.clone());
81                }
82                // Recursively check further upstream
83                if let Some(value) = self.get_upstream_findings(&upstream_id, key) {
84                    return Some(value);
85                }
86            }
87        }
88        None
89    }
90    
91    /// Cache file content in a node
92    fn cache_file(&mut self, node_id: &str, path: &str, content: &str) -> Result<()> {
93        let node = self.get_knowledge_mut(node_id)
94            .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
95        node.file_cache.insert(path.to_string(), content.to_string());
96        Ok(())
97    }
98    
99    /// Get cached file from this node or upstream
100    fn get_cached_file(&self, node_id: &str, path: &str) -> Option<String> {
101        // Check current node
102        if let Some(node) = self.get_knowledge(node_id) {
103            if let Some(content) = node.file_cache.get(path) {
104                return Some(content.clone());
105            }
106        }
107        
108        // Check upstream nodes
109        for upstream_id in self.get_incoming_edges(node_id) {
110            if let Some(content) = self.get_cached_file(&upstream_id, path) {
111                return Some(content);
112            }
113        }
114        None
115    }
116    
117    /// Record a tool call
118    fn record_tool_call(&mut self, node_id: &str, tool_name: &str, summary: &str) -> Result<()> {
119        let node = self.get_knowledge_mut(node_id)
120            .ok_or_else(|| anyhow::anyhow!("Node not found: {}", node_id))?;
121        node.tool_history.push(ToolCallRecord {
122            tool_name: tool_name.to_string(),
123            timestamp: Utc::now().to_rfc3339(),
124            summary: summary.to_string(),
125        });
126        Ok(())
127    }
128    
129    /// Get all tool calls from this and upstream nodes
130    fn get_tool_history(&self, node_id: &str) -> Vec<ToolCallRecord> {
131        let mut history = Vec::new();
132        
133        // Get from upstream first (chronological order)
134        for upstream_id in self.get_incoming_edges(node_id) {
135            history.extend(self.get_tool_history(&upstream_id));
136        }
137        
138        // Add current node
139        if let Some(node) = self.get_knowledge(node_id) {
140            history.extend(node.tool_history.clone());
141        }
142        
143        history
144    }
145    
146    /// Get all findings from this and upstream nodes as formatted context
147    fn get_knowledge_context(&self, node_id: &str) -> String {
148        let mut context = String::new();
149        
150        // Collect all findings from upstream
151        let mut all_findings = HashMap::new();
152        self.collect_upstream_findings_all(node_id, &mut all_findings);
153        
154        if !all_findings.is_empty() {
155            context.push_str("**Knowledge from previous tasks:**\n");
156            for (key, value) in &all_findings {
157                context.push_str(&format!("- {}: {}\n", key, value));
158            }
159            context.push('\n');
160        }
161        
162        // Add tool history summary
163        let tool_history = self.get_tool_history(node_id);
164        if !tool_history.is_empty() {
165            context.push_str("**Previously accessed files:**\n");
166            let mut seen = std::collections::HashSet::new();
167            for record in &tool_history {
168                if record.tool_name == "view_file" && seen.insert(&record.summary) {
169                    context.push_str(&format!("- {}\n", record.summary));
170                }
171            }
172        }
173        
174        context
175    }
176    
177    /// Helper to collect all findings recursively
178    fn collect_upstream_findings_all(&self, node_id: &str, findings: &mut HashMap<String, String>) {
179        // Get from upstream first
180        for upstream_id in self.get_incoming_edges(node_id) {
181            self.collect_upstream_findings_all(&upstream_id, findings);
182        }
183        
184        // Add current node findings
185        if let Some(node) = self.get_knowledge(node_id) {
186            findings.extend(node.findings.clone());
187        }
188    }
189}
190
191/// Simple in-memory knowledge graph for testing
192#[derive(Debug, Clone, Default)]
193pub struct SimpleKnowledgeGraph {
194    pub nodes: HashMap<String, KnowledgeNode>,
195    pub edges: Vec<(String, String)>, // (from, to)
196}
197
198impl SimpleKnowledgeGraph {
199    pub fn new() -> Self {
200        Self::default()
201    }
202    
203    pub fn add_node(&mut self, node_id: &str) {
204        self.nodes.insert(node_id.to_string(), KnowledgeNode::default());
205    }
206    
207    pub fn add_edge(&mut self, from: &str, to: &str) {
208        self.edges.push((from.to_string(), to.to_string()));
209    }
210}
211
212impl KnowledgeGraph for SimpleKnowledgeGraph {
213    fn get_knowledge_mut(&mut self, node_id: &str) -> Option<&mut KnowledgeNode> {
214        self.nodes.get_mut(node_id)
215    }
216    
217    fn get_knowledge(&self, node_id: &str) -> Option<&KnowledgeNode> {
218        self.nodes.get(node_id)
219    }
220    
221    fn get_incoming_edges(&self, node_id: &str) -> Vec<String> {
222        self.edges.iter()
223            .filter(|(_, to)| to == node_id)
224            .map(|(from, _)| from.clone())
225            .collect()
226    }
227}
228
229impl KnowledgeManagement for SimpleKnowledgeGraph {}
230
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_store_and_get_finding() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");

        graph.store_finding("node1", "key1", "value1").unwrap();
        assert_eq!(graph.get_finding("node1", "key1"), Some("value1".to_string()));
    }

    #[test]
    fn test_upstream_finding_lookup() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");

        graph.store_finding("parent", "shared_key", "parent_value").unwrap();

        // Child should find parent's finding
        assert_eq!(graph.get_finding("child", "shared_key"), Some("parent_value".to_string()));
    }

    #[test]
    fn test_file_cache() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");

        graph.cache_file("node1", "path/to/file.py", "file content").unwrap();
        assert_eq!(graph.get_cached_file("node1", "path/to/file.py"), Some("file content".to_string()));
    }

    #[test]
    fn test_tool_history() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");

        graph.record_tool_call("node1", "view_file", "foo.py").unwrap();
        graph.record_tool_call("node1", "edit_file", "bar.py").unwrap();

        let history = graph.get_tool_history("node1");
        assert_eq!(history.len(), 2);
        assert_eq!(history[0].tool_name, "view_file");
        assert_eq!(history[1].tool_name, "edit_file");
    }

    #[test]
    fn test_missing_node_is_an_error() {
        let mut graph = SimpleKnowledgeGraph::new();

        assert!(graph.store_finding("ghost", "k", "v").is_err());
        assert!(graph.cache_file("ghost", "a.rs", "x").is_err());
        assert!(graph.record_tool_call("ghost", "view_file", "a.rs").is_err());
        assert_eq!(graph.get_finding("ghost", "k"), None);
    }

    #[test]
    fn test_upstream_cached_file_lookup() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");

        graph.cache_file("parent", "lib.rs", "fn main() {}").unwrap();
        assert_eq!(graph.get_cached_file("child", "lib.rs"), Some("fn main() {}".to_string()));

        // Lookups only flow downstream: the parent cannot see the child's cache.
        graph.cache_file("child", "only_child.rs", "x").unwrap();
        assert_eq!(graph.get_cached_file("parent", "only_child.rs"), None);
    }

    #[test]
    fn test_knowledge_node_is_empty() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("node1");
        assert!(graph.nodes["node1"].is_empty());

        graph.record_tool_call("node1", "view_file", "foo.py").unwrap();
        assert!(!graph.nodes["node1"].is_empty());
    }

    #[test]
    fn test_knowledge_context_lists_findings_and_viewed_files() {
        let mut graph = SimpleKnowledgeGraph::new();
        graph.add_node("parent");
        graph.add_node("child");
        graph.add_edge("parent", "child");

        graph.store_finding("parent", "lang", "rust").unwrap();
        graph.record_tool_call("parent", "view_file", "main.rs").unwrap();
        graph.record_tool_call("child", "view_file", "main.rs").unwrap();

        let context = graph.get_knowledge_context("child");
        assert!(context.contains("- lang: rust\n"));
        // The same file viewed on two nodes is listed only once.
        assert_eq!(context.matches("- main.rs\n").count(), 1);
    }
}