matrixcode_core/compress/
dependency.rs1use std::collections::HashMap;
7
8use crate::providers::{ContentBlock, Message, MessageContent};
9
10use super::types::{DependencyGraph, MessageDependency};
11
/// Builds a [`DependencyGraph`] that pairs each tool-use block in a
/// conversation with the message holding its matching tool result.
///
/// This is a stateless namespace type: all functionality lives in associated
/// functions such as [`DependencyBuilder::build`].
#[derive(Debug, Clone, Copy, Default)]
pub struct DependencyBuilder;
14
15impl DependencyBuilder {
16 pub fn build(messages: &[Message]) -> DependencyGraph {
18 let mut dependencies: Vec<MessageDependency> = Vec::new();
19 let mut pending_tool_use: HashMap<String, usize> = HashMap::new();
20 let mut tool_names: HashMap<String, String> = HashMap::new();
21
22 for (idx, msg) in messages.iter().enumerate() {
24 let blocks = get_content_blocks(&msg.content);
25
26 for block in blocks {
27 match block {
28 ContentBlock::ToolUse { id, name, .. } => {
29 pending_tool_use.insert(id.clone(), idx);
31 tool_names.insert(id.clone(), name.clone());
32 }
33 ContentBlock::ToolResult { tool_use_id, .. } => {
34 if let Some(tool_use_idx) = pending_tool_use.remove(tool_use_id.as_str()) {
36 let tool_name = tool_names
37 .get(tool_use_id.as_str())
38 .cloned()
39 .unwrap_or_else(|| "unknown".to_string());
40
41 dependencies.push(MessageDependency {
42 tool_use_idx,
43 tool_result_idx: idx,
44 tool_name: tool_name.clone(),
45 is_critical: is_critical_tool(&tool_name),
46 });
47 }
48 }
49 _ => {}
50 }
51 }
52 }
53
54 let message_to_deps = build_reverse_index(&dependencies);
56
57 DependencyGraph {
58 dependencies,
59 message_to_deps,
60 }
61 }
62
63 pub fn build_with_custom_critical(
65 messages: &[Message],
66 critical_tools: &[&str],
67 ) -> DependencyGraph {
68 let mut graph = Self::build(messages);
69
70 for dep in &mut graph.dependencies {
72 dep.is_critical = critical_tools.contains(&dep.tool_name.as_str());
73 }
74
75 graph
76 }
77}
78
79fn get_content_blocks(content: &MessageContent) -> Vec<&ContentBlock> {
81 match content {
82 MessageContent::Blocks(blocks) => blocks.iter().collect(),
83 _ => Vec::new(),
84 }
85}
86
/// Reports whether `name` is one of the tools treated as critical by default:
/// `write`, `edit`, `multi_edit`, or `bash`.
///
/// The comparison is exact and case-sensitive.
fn is_critical_tool(name: &str) -> bool {
    matches!(name, "write" | "edit" | "multi_edit" | "bash")
}
92
93fn build_reverse_index(dependencies: &[MessageDependency]) -> HashMap<usize, Vec<usize>> {
95 let mut index: HashMap<usize, Vec<usize>> = HashMap::new();
96
97 for (dep_idx, dep) in dependencies.iter().enumerate() {
98 index.entry(dep.tool_use_idx).or_default().push(dep_idx);
100
101 index.entry(dep.tool_result_idx).or_default().push(dep_idx);
103 }
104
105 index
106}
107
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::Role;

    /// Builds an assistant message carrying the given blocks.
    fn assistant(blocks: Vec<ContentBlock>) -> Message {
        Message {
            role: Role::Assistant,
            content: MessageContent::Blocks(blocks),
        }
    }

    /// Builds a tool message carrying the given blocks.
    fn tool(blocks: Vec<ContentBlock>) -> Message {
        Message {
            role: Role::Tool,
            content: MessageContent::Blocks(blocks),
        }
    }

    /// A tool-use block. `build` ignores the input, so an empty object is used.
    fn tool_use(id: &str, name: &str) -> ContentBlock {
        ContentBlock::ToolUse {
            id: id.to_string(),
            name: name.to_string(),
            input: serde_json::json!({}),
        }
    }

    /// A tool-result block answering the use with `tool_use_id`.
    fn tool_result(tool_use_id: &str, content: &str) -> ContentBlock {
        ContentBlock::ToolResult {
            tool_use_id: tool_use_id.to_string(),
            content: content.to_string(),
        }
    }

    #[test]
    fn test_build_empty() {
        let messages: Vec<Message> = Vec::new();
        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 0);
    }

    #[test]
    fn test_build_single_pair() {
        let messages = vec![
            Message {
                role: Role::User,
                content: MessageContent::Text("Read file".to_string()),
            },
            assistant(vec![tool_use("t1", "read")]),
            tool(vec![tool_result("t1", "file content")]),
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 1);

        let dep = &graph.dependencies[0];
        assert_eq!(dep.tool_use_idx, 1);
        assert_eq!(dep.tool_result_idx, 2);
        assert_eq!(dep.tool_name, "read");
        assert!(!dep.is_critical);
    }

    #[test]
    fn test_build_critical_tool() {
        let messages = vec![
            assistant(vec![tool_use("t1", "write")]),
            tool(vec![tool_result("t1", "success")]),
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 1);
        assert!(graph.dependencies[0].is_critical);
    }

    #[test]
    fn test_build_multiple_tools_same_message() {
        let messages = vec![
            assistant(vec![tool_use("t1", "read"), tool_use("t2", "read")]),
            tool(vec![
                tool_result("t1", "content a"),
                tool_result("t2", "content b"),
            ]),
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 2);

        assert!(graph.message_to_deps.contains_key(&0));
        assert_eq!(graph.message_to_deps.get(&0).unwrap().len(), 2);
        assert_eq!(graph.message_to_deps.get(&1).unwrap().len(), 2);
    }

    #[test]
    fn test_missing_tool_result() {
        let messages = vec![assistant(vec![tool_use("t1", "read")])];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 0);
    }

    #[test]
    fn test_result_before_use_is_ignored() {
        let messages = vec![
            tool(vec![tool_result("t1", "too early")]),
            assistant(vec![tool_use("t1", "read")]),
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 0);
    }

    #[test]
    fn test_duplicate_result_is_ignored() {
        let messages = vec![
            assistant(vec![tool_use("t1", "read")]),
            tool(vec![tool_result("t1", "first")]),
            tool(vec![tool_result("t1", "second")]),
        ];

        let graph = DependencyBuilder::build(&messages);
        assert_eq!(graph.dependencies.len(), 1);
        assert_eq!(graph.dependencies[0].tool_result_idx, 1);
    }

    #[test]
    fn test_build_with_custom_critical() {
        let messages = vec![
            assistant(vec![tool_use("t1", "read"), tool_use("t2", "write")]),
            tool(vec![tool_result("t1", "a"), tool_result("t2", "b")]),
        ];

        let graph = DependencyBuilder::build_with_custom_critical(&messages, &["read"]);
        assert_eq!(graph.dependencies.len(), 2);
        assert_eq!(graph.dependencies[0].tool_name, "read");
        assert!(graph.dependencies[0].is_critical);
        assert_eq!(graph.dependencies[1].tool_name, "write");
        assert!(!graph.dependencies[1].is_critical);

        let graph = DependencyBuilder::build_with_custom_critical(&messages, &[]);
        assert!(graph.dependencies.iter().all(|dep| !dep.is_critical));
    }

    #[test]
    fn test_default_critical_tools() {
        for name in ["write", "edit", "multi_edit", "bash"] {
            assert!(is_critical_tool(name), "{name} should be critical");
        }
        assert!(!is_critical_tool("read"));
    }

    #[test]
    fn test_reverse_index() {
        let messages = vec![
            assistant(vec![tool_use("t1", "read")]),
            tool(vec![tool_result("t1", "result")]),
        ];

        let graph = DependencyBuilder::build(&messages);

        assert!(graph.message_to_deps.contains_key(&0));
        assert!(graph.message_to_deps.contains_key(&1));
    }

    #[test]
    fn test_get_pair_indices() {
        let messages = vec![
            assistant(vec![tool_use("t1", "read")]),
            tool(vec![tool_result("t1", "result")]),
        ];

        let graph = DependencyBuilder::build(&messages);

        let pairs = graph.get_pair_indices(0);
        assert_eq!(pairs, vec![1]);

        let pairs = graph.get_pair_indices(1);
        assert_eq!(pairs, vec![0]);
    }
}