agpm_cli/resolver/
sha_conflict_detector.rs1use anyhow::Result;
8use std::collections::HashMap;
9
10use super::types::ResolutionMode;
11
12#[derive(Debug, Clone)]
14pub struct ResolvedRequirement {
15 pub source: String,
17 pub path: String,
19 pub resolved_sha: String,
21 pub requested_version: String,
23 pub required_by: String,
25 pub resolution_mode: ResolutionMode,
27}
28
29#[derive(Debug, Clone)]
31pub struct ShaConflict {
32 pub source: String,
34 pub path: String,
36 pub sha_groups: HashMap<String, Vec<ResolvedRequirement>>,
38}
39
40impl ShaConflict {
41 pub fn format_error(&self) -> String {
43 format!(
44 "SHA conflict for {}/{}:\n{}",
45 self.source,
46 self.path,
47 self.sha_groups
48 .iter()
49 .map(|(sha, reqs)| {
50 format!(
51 " SHA {} required by:\n{}",
52 &sha[..8.min(sha.len())],
53 reqs.iter()
54 .map(|r| format!(
55 " - {} (via {})",
56 r.required_by,
57 match r.resolution_mode {
58 ResolutionMode::Version =>
59 format!("version={}", r.requested_version),
60 ResolutionMode::GitRef =>
61 format!("git ref={}", r.requested_version),
62 }
63 ))
64 .collect::<Vec<_>>()
65 .join("\n")
66 )
67 })
68 .collect::<Vec<_>>()
69 .join("\n")
70 )
71 }
72}
73
74pub struct ShaConflictDetector {
80 requirements: HashMap<(String, String), Vec<ResolvedRequirement>>,
82}
83
84impl Default for ShaConflictDetector {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90impl ShaConflictDetector {
91 pub fn new() -> Self {
93 Self {
94 requirements: HashMap::new(),
95 }
96 }
97
98 pub fn add_requirement(&mut self, requirement: ResolvedRequirement) {
100 let key = (requirement.source.clone(), requirement.path.clone());
101 self.requirements.entry(key).or_default().push(requirement);
102 }
103
104 pub fn detect_conflicts(&self) -> Result<Vec<ShaConflict>> {
109 let mut conflicts = Vec::new();
110
111 for ((source, path), requirements) in &self.requirements {
112 let mut sha_groups: HashMap<String, Vec<ResolvedRequirement>> = HashMap::new();
114 for req in requirements {
115 sha_groups.entry(req.resolved_sha.clone()).or_default().push(req.clone());
116 }
117
118 if sha_groups.len() > 1 {
120 conflicts.push(ShaConflict {
121 source: source.clone(),
122 path: path.clone(),
123 sha_groups,
124 });
125 }
126 }
127
128 Ok(conflicts)
129 }
130
131 pub fn get_requirements(&self, source: &str, path: &str) -> Option<&[ResolvedRequirement]> {
133 self.requirements.get(&(source.to_string(), path.to_string())).map(|reqs| reqs.as_slice())
134 }
135
136 pub fn clear(&mut self) {
138 self.requirements.clear();
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a requirement on `test`/`agents/helper.md`; only the resolution
    /// details differ between the scenarios below.
    fn helper_md(
        resolved_sha: &str,
        requested_version: &str,
        required_by: &str,
        resolution_mode: ResolutionMode,
    ) -> ResolvedRequirement {
        ResolvedRequirement {
            source: "test".to_string(),
            path: "agents/helper.md".to_string(),
            resolved_sha: resolved_sha.to_string(),
            requested_version: requested_version.to_string(),
            required_by: required_by.to_string(),
            resolution_mode,
        }
    }

    /// Detector where `agent-a` (version `v1.0.0`) and `agent-b` (git ref
    /// `main`) resolved the same file to `sha_a` and `sha_b` respectively.
    fn detector_with(sha_a: &str, sha_b: &str) -> ShaConflictDetector {
        let mut detector = ShaConflictDetector::new();
        detector.add_requirement(helper_md(sha_a, "v1.0.0", "agent-a", ResolutionMode::Version));
        detector.add_requirement(helper_md(sha_b, "main", "agent-b", ResolutionMode::GitRef));
        detector
    }

    #[test]
    fn test_no_conflict_same_sha() {
        let conflicts = detector_with("abc123def456", "abc123def456")
            .detect_conflicts()
            .unwrap();
        assert!(conflicts.is_empty());
    }

    #[test]
    fn test_conflict_different_shas() {
        let conflicts = detector_with("abc123def456", "def456abc123")
            .detect_conflicts()
            .unwrap();
        assert_eq!(conflicts.len(), 1);

        let ShaConflict {
            source,
            path,
            sha_groups,
        } = &conflicts[0];
        assert_eq!(source, "test");
        assert_eq!(path, "agents/helper.md");
        assert_eq!(sha_groups.len(), 2);
    }

    #[test]
    fn test_conflict_formatting() {
        let conflicts = detector_with("abc123def456", "def456abc123")
            .detect_conflicts()
            .unwrap();
        let error_msg = conflicts[0].format_error();

        for needle in [
            "SHA conflict for test/agents/helper.md",
            "abc123de",
            "def456ab",
            "agent-a",
            "agent-b",
            "version=v1.0.0",
            "git ref=main",
        ] {
            assert!(error_msg.contains(needle), "missing {:?} in {}", needle, error_msg);
        }
    }
}