Skip to main content

minion_engine/prompts/
detector.rs

1use std::collections::HashMap;
2use std::path::Path;
3
4use crate::error::StepError;
5use crate::prompts::registry::{Registry, StackDef};
6
/// Outcome of stack detection: the matched stack, its ancestry, and its merged tools.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StackInfo {
    /// Name of the matched stack (a key of the registry's `stacks` map).
    pub name: String,
    /// Ancestor stack names, nearest parent first; does not include `name` itself.
    pub parent_chain: Vec<String>,
    /// Tool key (e.g. `lint`, `test`, `build`) to command string, merged so that the
    /// matched stack's entries override those inherited from its ancestors.
    pub tools: HashMap<String, String>,
}
13
/// Stateless entry point for detecting a workspace's technology stack from a registry.
#[derive(Debug, Clone, Copy, Default)]
pub struct StackDetector;
15
16impl StackDetector {
17    /// Detect the technology stack for the given workspace path.
18    ///
19    /// Follows `detection_order` in the registry (most specific first).
20    /// Returns the first fully matching stack as a [`StackInfo`].
21    pub async fn detect(registry: &Registry, workspace_path: &Path) -> Result<StackInfo, StepError> {
22        let mut checked_markers: Vec<String> = Vec::new();
23
24        for stack_name in &registry.detection_order {
25            let stack_def = match registry.stacks.get(stack_name) {
26                Some(def) => def,
27                None => continue,
28            };
29
30            // Skip stacks with neither file_markers nor content_match (e.g. _default)
31            if stack_def.file_markers.is_empty() && stack_def.content_match.is_empty() {
32                continue;
33            }
34
35            // Check file_markers: any ONE matching is sufficient
36            if !stack_def.file_markers.is_empty() {
37                let mut any_marker_found = false;
38                for marker in &stack_def.file_markers {
39                    checked_markers.push(marker.clone());
40                    if tokio::fs::metadata(workspace_path.join(marker))
41                        .await
42                        .is_ok()
43                    {
44                        any_marker_found = true;
45                        break;
46                    }
47                }
48                if !any_marker_found {
49                    continue;
50                }
51            }
52
53            // Check content_match: ALL entries must match
54            if !Self::content_matches(stack_def, workspace_path).await {
55                continue;
56            }
57
58            // This stack matches — build and return StackInfo
59            return Ok(Self::build_stack_info(stack_name, registry));
60        }
61
62        let markers_list = checked_markers.join(", ");
63        Err(StepError::Fail(format!(
64            "Could not detect project stack in '{}'. Checked markers: [{}]. \
65             Create prompts/registry.yaml with your stack definition.",
66            workspace_path.display(),
67            markers_list
68        )))
69    }
70
71    /// Returns true if ALL content_match patterns satisfy the workspace files.
72    async fn content_matches(stack_def: &StackDef, workspace_path: &Path) -> bool {
73        for (filename, pattern) in &stack_def.content_match {
74            match tokio::fs::read_to_string(workspace_path.join(filename)).await {
75                Ok(content) if content.contains(pattern.as_str()) => {}
76                _ => return false,
77            }
78        }
79        true
80    }
81
82    /// Build a [`StackInfo`] by walking the parent chain and merging tools.
83    fn build_stack_info(name: &str, registry: &Registry) -> StackInfo {
84        // Walk parent chain from child to root
85        let mut parent_chain: Vec<String> = Vec::new();
86        let mut current = registry.stacks.get(name).and_then(|s| s.parent.as_deref());
87        while let Some(parent_name) = current {
88            parent_chain.push(parent_name.to_string());
89            current = registry
90                .stacks
91                .get(parent_name)
92                .and_then(|s| s.parent.as_deref());
93        }
94
95        // Merge tools root-first so child overrides parent
96        let mut full_chain: Vec<&str> = vec![name];
97        full_chain.extend(parent_chain.iter().map(String::as_str));
98        full_chain.reverse(); // root -> child
99
100        let mut tools: HashMap<String, String> = HashMap::new();
101        for stack_name in &full_chain {
102            if let Some(stack_def) = registry.stacks.get(*stack_name) {
103                tools.extend(stack_def.tools.clone());
104            }
105        }
106
107        StackInfo {
108            name: name.to_string(),
109            parent_chain,
110            tools,
111        }
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prompts::registry::{Registry, StackDef};
    use tempfile::{tempdir, TempDir};

    /// Collects borrowed `(key, value)` pairs into an owned `String -> String` map.
    fn string_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_owned(), value.to_owned()))
            .collect()
    }

    /// Compact constructor for [`StackDef`] fixtures.
    fn stack_def(
        parent: Option<&str>,
        file_markers: &[&str],
        content_match: &[(&str, &str)],
        tools: &[(&str, &str)],
    ) -> StackDef {
        StackDef {
            parent: parent.map(str::to_owned),
            file_markers: file_markers.iter().map(|&m| m.to_owned()).collect(),
            content_match: string_map(content_match),
            tools: string_map(tools),
        }
    }

    /// Registry fixture: `_default` (no markers) is the root; `rust`, `java` and
    /// `javascript` sit directly under it; `java-spring` sits under `java` and also
    /// requires `pom.xml` to mention `spring-boot`.
    fn make_registry() -> Registry {
        let stacks: HashMap<String, StackDef> = vec![
            (
                "_default",
                stack_def(
                    None,
                    &[],
                    &[],
                    &[
                        ("lint", "echo 'no linter'"),
                        ("test", "echo 'no test'"),
                        ("build", "echo 'no build'"),
                    ],
                ),
            ),
            (
                "rust",
                stack_def(
                    Some("_default"),
                    &["Cargo.toml"],
                    &[],
                    &[
                        ("lint", "cargo clippy -- -D warnings"),
                        ("test", "cargo test"),
                        ("build", "cargo build --release"),
                    ],
                ),
            ),
            (
                "java",
                stack_def(
                    Some("_default"),
                    &["pom.xml", "build.gradle"],
                    &[],
                    &[("test", "mvn test"), ("build", "mvn package -DskipTests")],
                ),
            ),
            (
                "java-spring",
                stack_def(
                    Some("java"),
                    &["pom.xml"],
                    &[("pom.xml", "spring-boot")],
                    &[("test", "mvn test -Dspring.profiles.active=test")],
                ),
            ),
            (
                "javascript",
                stack_def(Some("_default"), &["package.json"], &[], &[("test", "npm test")]),
            ),
        ]
        .into_iter()
        .map(|(name, def)| (name.to_owned(), def))
        .collect();

        Registry {
            version: 1,
            detection_order: vec!["java-spring", "java", "javascript", "rust"]
                .into_iter()
                .map(String::from)
                .collect(),
            stacks,
        }
    }

    /// Creates a temporary workspace holding a single file with the given contents.
    /// The returned guard must outlive every use of its path.
    fn workspace_with(file_name: &str, contents: &[u8]) -> TempDir {
        let workspace = tempdir().unwrap();
        std::fs::write(workspace.path().join(file_name), contents).unwrap();
        workspace
    }

    /// Runs [`StackDetector::detect`] for `path` against a fresh registry fixture.
    async fn detect_in(path: &Path) -> Result<StackInfo, StepError> {
        let registry = make_registry();
        StackDetector::detect(&registry, path).await
    }

    #[tokio::test]
    async fn test_detect_rust_project() {
        let workspace = workspace_with("Cargo.toml", b"");
        let info = detect_in(workspace.path()).await.unwrap();

        assert_eq!(info.name, "rust");
        assert_eq!(info.parent_chain, vec!["_default"]);
    }

    #[tokio::test]
    async fn test_detect_java_spring_project() {
        let workspace = workspace_with(
            "pom.xml",
            b"<project><parent><artifactId>spring-boot-starter-parent</artifactId></parent></project>",
        );
        let info = detect_in(workspace.path()).await.unwrap();

        assert_eq!(info.name, "java-spring");
    }

    #[tokio::test]
    async fn test_detection_order_java_spring_before_java() {
        // This pom.xml satisfies both java-spring and java; java-spring is listed first.
        let workspace = workspace_with("pom.xml", b"<project>spring-boot</project>");
        let info = detect_in(workspace.path()).await.unwrap();

        assert_eq!(info.name, "java-spring");
    }

    #[tokio::test]
    async fn test_content_match_failure_falls_through_to_less_specific_stack() {
        // No "spring-boot" in pom.xml: java-spring's content check fails, so java wins.
        let workspace = workspace_with(
            "pom.xml",
            b"<project><groupId>com.example</groupId></project>",
        );
        let info = detect_in(workspace.path()).await.unwrap();

        assert_eq!(info.name, "java");
    }

    #[tokio::test]
    async fn test_no_stack_detected_returns_step_error_fail() {
        // An empty workspace carries none of the registry's markers.
        let workspace = tempdir().unwrap();

        let message = match detect_in(workspace.path()).await {
            Ok(info) => panic!("expected detection to fail, got {info:?}"),
            Err(err) => err.to_string(),
        };
        assert!(
            message.contains("Could not detect project stack"),
            "Expected error message, got: {message}"
        );
    }

    #[tokio::test]
    async fn test_parent_chain_and_tool_merging_for_rust() {
        let workspace = workspace_with("Cargo.toml", b"");
        let info = detect_in(workspace.path()).await.unwrap();

        assert_eq!(info.parent_chain, vec!["_default"]);
        // Every key rust defines shadows the same key inherited from _default.
        assert_eq!(info.tools["test"], "cargo test");
        assert_eq!(info.tools["lint"], "cargo clippy -- -D warnings");
        assert_eq!(info.tools["build"], "cargo build --release");
    }

    #[tokio::test]
    async fn test_java_spring_parent_chain() {
        let workspace = workspace_with("pom.xml", b"spring-boot");
        let info = detect_in(workspace.path()).await.unwrap();

        // Ancestry: java-spring -> java -> _default.
        assert_eq!(info.parent_chain, vec!["java", "_default"]);
        // java-spring's own `test` command wins over java's.
        assert_eq!(info.tools["test"], "mvn test -Dspring.profiles.active=test");
    }
}