// matrixcode_core/compress/dependency.rs

//! Message dependency tracking for preserving conversation coherence.
//!
//! Builds a graph of ToolUse ↔ ToolResult pairs so that each pair is
//! preserved together during compression.

6use std::collections::HashMap;
7
8use crate::providers::{ContentBlock, Message, MessageContent};
9
10use super::types::{DependencyGraph, MessageDependency};
11
12/// Builder for dependency graphs.
13pub struct DependencyBuilder;
14
15impl DependencyBuilder {
16    /// Build a dependency graph from message history.
17    pub fn build(messages: &[Message]) -> DependencyGraph {
18        let mut dependencies: Vec<MessageDependency> = Vec::new();
19        let mut pending_tool_use: HashMap<String, usize> = HashMap::new();
20        let mut tool_names: HashMap<String, String> = HashMap::new();
21
22        // Scan all messages to find ToolUse-ToolResult pairs
23        for (idx, msg) in messages.iter().enumerate() {
24            let blocks = get_content_blocks(&msg.content);
25
26            for block in blocks {
27                match block {
28                    ContentBlock::ToolUse { id, name, .. } => {
29                        // Record pending ToolUse
30                        pending_tool_use.insert(id.clone(), idx);
31                        tool_names.insert(id.clone(), name.clone());
32                    }
33                    ContentBlock::ToolResult { tool_use_id, .. } => {
34                        // Match with pending ToolUse
35                        if let Some(tool_use_idx) = pending_tool_use.remove(tool_use_id.as_str()) {
36                            let tool_name = tool_names
37                                .get(tool_use_id.as_str())
38                                .cloned()
39                                .unwrap_or_else(|| "unknown".to_string());
40
41                            dependencies.push(MessageDependency {
42                                tool_use_idx,
43                                tool_result_idx: idx,
44                                tool_name: tool_name.clone(),
45                                is_critical: is_critical_tool(&tool_name),
46                            });
47                        }
48                    }
49                    _ => {}
50                }
51            }
52        }
53
54        // Build reverse index for quick lookup
55        let message_to_deps = build_reverse_index(&dependencies);
56
57        DependencyGraph {
58            dependencies,
59            message_to_deps,
60        }
61    }
62
63    /// Build dependency graph with custom critical tool detection.
64    pub fn build_with_custom_critical(
65        messages: &[Message],
66        critical_tools: &[&str],
67    ) -> DependencyGraph {
68        let mut graph = Self::build(messages);
69
70        // Update is_critical based on custom list
71        for dep in &mut graph.dependencies {
72            dep.is_critical = critical_tools.contains(&dep.tool_name.as_str());
73        }
74
75        graph
76    }
77}
78
79/// Extract content blocks from a message.
80fn get_content_blocks(content: &MessageContent) -> Vec<&ContentBlock> {
81    match content {
82        MessageContent::Blocks(blocks) => blocks.iter().collect(),
83        _ => Vec::new(),
84    }
85}
86
/// Report whether `name` is one of the built-in critical tools.
///
/// The module treats these tools as state-modifying: `write`, `edit`,
/// `multi_edit` and `bash`. Matching is exact and case-sensitive.
fn is_critical_tool(name: &str) -> bool {
    matches!(name, "write" | "edit" | "multi_edit" | "bash")
}

93/// Build reverse index: message idx -> dependency indices.
94fn build_reverse_index(dependencies: &[MessageDependency]) -> HashMap<usize, Vec<usize>> {
95    let mut index: HashMap<usize, Vec<usize>> = HashMap::new();
96
97    for (dep_idx, dep) in dependencies.iter().enumerate() {
98        // Add ToolUse index
99        index.entry(dep.tool_use_idx).or_default().push(dep_idx);
100
101        // Add ToolResult index
102        index.entry(dep.tool_result_idx).or_default().push(dep_idx);
103    }
104
105    index
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Role;

    /// Build a `ToolUse` block. The input payload is irrelevant to dependency tracking.
    fn tool_use(id: &str, name: &str) -> ContentBlock {
        ContentBlock::ToolUse {
            id: id.to_string(),
            name: name.to_string(),
            input: serde_json::json!({}),
        }
    }

    /// Build a `ToolResult` block answering the call with id `tool_use_id`.
    fn tool_result(tool_use_id: &str, content: &str) -> ContentBlock {
        ContentBlock::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
        }
    }

    /// Wrap content blocks in a message with the given role.
    fn blocks_msg(role: Role, blocks: Vec<ContentBlock>) -> Message {
        Message {
            role,
            content: MessageContent::Blocks(blocks),
        }
    }

    /// One assistant `ToolUse` ("t1", `name`) followed by its tool `ToolResult`.
    fn single_pair(name: &str) -> Vec<Message> {
        vec![
            blocks_msg(Role::Assistant, vec![tool_use("t1", name)]),
            blocks_msg(Role::Tool, vec![tool_result("t1", "result")]),
        ]
    }

    #[test]
    fn test_build_empty() {
        let messages: Vec<Message> = Vec::new();
        let graph = DependencyBuilder::build(&messages);
        assert!(graph.dependencies.is_empty());
        assert!(graph.message_to_deps.is_empty());
    }

    #[test]
    fn test_build_single_pair() {
        let messages = vec![
            Message {
                role: Role::User,
                content: MessageContent::Text("Read file".to_string()),
            },
            blocks_msg(Role::Assistant, vec![tool_use("t1", "read")]),
            blocks_msg(Role::Tool, vec![tool_result("t1", "file content")]),
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 1);

        let dep = &graph.dependencies[0];
        assert_eq!(dep.tool_use_idx, 1);
        assert_eq!(dep.tool_result_idx, 2);
        assert_eq!(dep.tool_name, "read");
        assert!(!dep.is_critical);
    }

    #[test]
    fn test_build_critical_tool() {
        let graph = DependencyBuilder::build(&single_pair("write"));
        assert_eq!(graph.dependencies.len(), 1);
        assert!(graph.dependencies[0].is_critical);
    }

    #[test]
    fn test_build_multiple_tools_same_message() {
        let messages = vec![
            blocks_msg(
                Role::Assistant,
                vec![tool_use("t1", "read"), tool_use("t2", "read")],
            ),
            blocks_msg(
                Role::Tool,
                vec![tool_result("t1", "content a"), tool_result("t2", "content b")],
            ),
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 2);

        // Both messages participate in both dependencies.
        assert_eq!(graph.message_to_deps.get(&0).map(Vec::len), Some(2));
        assert_eq!(graph.message_to_deps.get(&1).map(Vec::len), Some(2));
    }

    #[test]
    fn test_missing_tool_result() {
        let messages = vec![blocks_msg(Role::Assistant, vec![tool_use("t1", "read")])];
        let graph = DependencyBuilder::build(&messages);
        assert!(graph.dependencies.is_empty());
    }

    #[test]
    fn test_result_before_use_is_ignored() {
        // A result that precedes its call has nothing pending to match.
        let messages = vec![
            blocks_msg(Role::Tool, vec![tool_result("t1", "early")]),
            blocks_msg(Role::Assistant, vec![tool_use("t1", "read")]),
        ];
        let graph = DependencyBuilder::build(&messages);
        assert!(graph.dependencies.is_empty());
    }

    #[test]
    fn test_reverse_index() {
        let graph = DependencyBuilder::build(&single_pair("read"));

        // Both the ToolUse and ToolResult messages map to dependency 0.
        assert_eq!(graph.message_to_deps.get(&0), Some(&vec![0]));
        assert_eq!(graph.message_to_deps.get(&1), Some(&vec![0]));
    }

    #[test]
    fn test_get_pair_indices() {
        let graph = DependencyBuilder::build(&single_pair("read"));

        // ToolUse side points at the result message, and vice versa.
        assert_eq!(graph.get_pair_indices(0), vec![1]);
        assert_eq!(graph.get_pair_indices(1), vec![0]);
    }

    #[test]
    fn test_build_with_custom_critical() {
        let messages = vec![
            blocks_msg(
                Role::Assistant,
                vec![tool_use("t1", "read"), tool_use("t2", "write")],
            ),
            blocks_msg(Role::Tool, vec![tool_result("t1", "a"), tool_result("t2", "b")]),
        ];

        let graph = DependencyBuilder::build_with_custom_critical(&messages, &["read"]);
        assert_eq!(graph.dependencies.len(), 2);

        // The custom list replaces the defaults: "read" becomes critical and
        // "write" no longer is.
        assert_eq!(graph.dependencies[0].tool_name, "read");
        assert!(graph.dependencies[0].is_critical);
        assert_eq!(graph.dependencies[1].tool_name, "write");
        assert!(!graph.dependencies[1].is_critical);
    }
}