// matrixcode_core/compress/dependency.rs
dependency.rs1use std::collections::HashMap;
7
8use crate::providers::{ContentBlock, Message, MessageContent};
9
10use super::types::{DependencyGraph, MessageDependency};
11
12pub struct DependencyBuilder;
14
15impl DependencyBuilder {
16 pub fn build(messages: &[Message]) -> DependencyGraph {
18 let mut dependencies: Vec<MessageDependency> = Vec::new();
19 let mut pending_tool_use: HashMap<String, usize> = HashMap::new();
20 let mut tool_names: HashMap<String, String> = HashMap::new();
21
22 for (idx, msg) in messages.iter().enumerate() {
24 let blocks = get_content_blocks(&msg.content);
25
26 for block in blocks {
27 match block {
28 ContentBlock::ToolUse { id, name, .. } => {
29 pending_tool_use.insert(id.clone(), idx);
31 tool_names.insert(id.clone(), name.clone());
32 }
33 ContentBlock::ToolResult { tool_use_id, .. } => {
34 if let Some(tool_use_idx) = pending_tool_use.remove(tool_use_id.as_str()) {
36 let tool_name = tool_names
37 .get(tool_use_id.as_str())
38 .cloned()
39 .unwrap_or_else(|| "unknown".to_string());
40
41 dependencies.push(MessageDependency {
42 tool_use_idx,
43 tool_result_idx: idx,
44 tool_name: tool_name.clone(),
45 is_critical: is_critical_tool(&tool_name),
46 });
47 }
48 }
49 _ => {}
50 }
51 }
52 }
53
54 let message_to_deps = build_reverse_index(&dependencies);
56
57 DependencyGraph {
58 dependencies,
59 message_to_deps,
60 }
61 }
62
63 pub fn build_with_custom_critical(
65 messages: &[Message],
66 critical_tools: &[&str],
67 ) -> DependencyGraph {
68 let mut graph = Self::build(messages);
69
70 for dep in &mut graph.dependencies {
72 dep.is_critical = critical_tools.contains(&dep.tool_name.as_str());
73 }
74
75 graph
76 }
77}
78
79fn get_content_blocks(content: &MessageContent) -> Vec<&ContentBlock> {
81 match content {
82 MessageContent::Blocks(blocks) => blocks.iter().collect(),
83 _ => Vec::new(),
84 }
85}
86
/// Reports whether `name` is one of the built-in critical tools:
/// `write`, `edit`, `multi_edit` or `bash`.
///
/// The comparison is an exact, case-sensitive string match.
fn is_critical_tool(name: &str) -> bool {
    matches!(name, "write" | "edit" | "multi_edit" | "bash")
}

93fn build_reverse_index(dependencies: &[MessageDependency]) -> HashMap<usize, Vec<usize>> {
95 let mut index: HashMap<usize, Vec<usize>> = HashMap::new();
96
97 for (dep_idx, dep) in dependencies.iter().enumerate() {
98 index.entry(dep.tool_use_idx).or_default().push(dep_idx);
100
101 index.entry(dep.tool_result_idx).or_default().push(dep_idx);
103 }
104
105 index
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Role;

    #[test]
    fn test_build_empty() {
        let messages: Vec<Message> = Vec::new();
        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 0);
    }

    #[test]
    fn test_build_single_pair() {
        let messages = vec![
            Message {
                role: Role::User,
                content: MessageContent::Text("Read file".to_string()),
            },
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                    id: "t1".to_string(),
                    name: "read".to_string(),
                    input: serde_json::json!({"path": "test.rs"}),
                }]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "t1".to_string(),
                    content: "file content".to_string(),
                }]),
            },
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 1);

        let dep = &graph.dependencies[0];
        assert_eq!(dep.tool_use_idx, 1);
        assert_eq!(dep.tool_result_idx, 2);
        assert_eq!(dep.tool_name, "read");
        assert!(!dep.is_critical);
    }

    #[test]
    fn test_build_critical_tool() {
        let messages = vec![
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                    id: "t1".to_string(),
                    name: "write".to_string(),
                    input: serde_json::json!({"path": "test.rs"}),
                }]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "t1".to_string(),
                    content: "success".to_string(),
                }]),
            },
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 1);
        assert!(graph.dependencies[0].is_critical);
    }

    #[test]
    fn test_build_multiple_tools_same_message() {
        let messages = vec![
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![
                    ContentBlock::ToolUse {
                        id: "t1".to_string(),
                        name: "read".to_string(),
                        input: serde_json::json!({"path": "a.rs"}),
                    },
                    ContentBlock::ToolUse {
                        id: "t2".to_string(),
                        name: "read".to_string(),
                        input: serde_json::json!({"path": "b.rs"}),
                    },
                ]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![
                    ContentBlock::ToolResult {
                        tool_use_id: "t1".to_string(),
                        content: "content a".to_string(),
                    },
                    ContentBlock::ToolResult {
                        tool_use_id: "t2".to_string(),
                        content: "content b".to_string(),
                    },
                ]),
            },
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 2);

        assert!(graph.message_to_deps.contains_key(&0));
        assert_eq!(graph.message_to_deps.get(&0).unwrap().len(), 2);
    }

    #[test]
    fn test_missing_tool_result() {
        let messages = vec![Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                id: "t1".to_string(),
                name: "read".to_string(),
                input: serde_json::json!({"path": "test.rs"}),
            }]),
        }];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 0);
    }

    #[test]
    fn test_reverse_index() {
        let messages = vec![
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                    id: "t1".to_string(),
                    name: "read".to_string(),
                    input: serde_json::json!({}),
                }]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "t1".to_string(),
                    content: "result".to_string(),
                }]),
            },
        ];

        let graph = DependencyBuilder::build(&messages);

        assert!(graph.message_to_deps.contains_key(&0));
        assert!(graph.message_to_deps.contains_key(&1));
    }

    #[test]
    fn test_get_pair_indices() {
        let messages = vec![
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                    id: "t1".to_string(),
                    name: "read".to_string(),
                    input: serde_json::json!({}),
                }]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "t1".to_string(),
                    content: "result".to_string(),
                }]),
            },
        ];

        let graph = DependencyBuilder::build(&messages);

        let pairs = graph.get_pair_indices(0);
        assert_eq!(pairs, vec![1]);

        let pairs = graph.get_pair_indices(1);
        assert_eq!(pairs, vec![0]);
    }

    #[test]
    fn test_build_with_custom_critical() {
        let messages = vec![
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![
                    ContentBlock::ToolUse {
                        id: "t1".to_string(),
                        name: "read".to_string(),
                        input: serde_json::json!({}),
                    },
                    ContentBlock::ToolUse {
                        id: "t2".to_string(),
                        name: "write".to_string(),
                        input: serde_json::json!({}),
                    },
                ]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![
                    ContentBlock::ToolResult {
                        tool_use_id: "t1".to_string(),
                        content: "a".to_string(),
                    },
                    ContentBlock::ToolResult {
                        tool_use_id: "t2".to_string(),
                        content: "b".to_string(),
                    },
                ]),
            },
        ];

        let graph = DependencyBuilder::build_with_custom_critical(&messages, &["read"]);
        assert_eq!(graph.dependencies.len(), 2);
        // The custom list replaces the defaults: "read" becomes critical and
        // "write" (critical by default) does not.
        assert_eq!(graph.dependencies[0].tool_name, "read");
        assert!(graph.dependencies[0].is_critical);
        assert_eq!(graph.dependencies[1].tool_name, "write");
        assert!(!graph.dependencies[1].is_critical);
    }

    #[test]
    fn test_orphan_tool_result_ignored() {
        let messages = vec![Message {
            role: Role::Tool,
            content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                tool_use_id: "missing".to_string(),
                content: "result".to_string(),
            }]),
        }];

        let graph = DependencyBuilder::build(&messages);
        assert!(graph.dependencies.is_empty());
        assert!(graph.message_to_deps.is_empty());
    }

    #[test]
    fn test_text_content_ignored() {
        let messages = vec![Message {
            role: Role::User,
            content: MessageContent::Text("hello".to_string()),
        }];

        let graph = DependencyBuilder::build(&messages);
        assert!(graph.dependencies.is_empty());
        assert!(graph.message_to_deps.is_empty());
    }

    #[test]
    fn test_tool_use_paired_only_once() {
        let messages = vec![
            Message {
                role: Role::Assistant,
                content: MessageContent::Blocks(vec![ContentBlock::ToolUse {
                    id: "t1".to_string(),
                    name: "read".to_string(),
                    input: serde_json::json!({}),
                }]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "t1".to_string(),
                    content: "first".to_string(),
                }]),
            },
            Message {
                role: Role::Tool,
                content: MessageContent::Blocks(vec![ContentBlock::ToolResult {
                    tool_use_id: "t1".to_string(),
                    content: "duplicate".to_string(),
                }]),
            },
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 1);
        assert_eq!(graph.dependencies[0].tool_result_idx, 1);
        assert!(!graph.message_to_deps.contains_key(&2));
    }
}