// ubt_cli/detect.rs

1use std::path::{Path, PathBuf};
2
3use crate::error::{Result, UbtError};
4use crate::plugin::{Plugin, PluginRegistry, PluginSource};
5
6/// Result of tool detection.
7#[derive(Debug)]
8pub struct DetectionResult {
9    pub plugin_name: String,
10    pub variant_name: String,
11    pub source: PluginSource,
12    pub project_root: PathBuf,
13}
14
15/// Detect the active tool using the SPEC §7.1 priority chain:
16/// 1. CLI override (`--tool`)
17/// 2. Environment variable (`UBT_TOOL`)
18/// 3. Config `[project].tool`
19/// 4. Auto-detection (walk CWD upward, check detect files)
20pub fn detect_tool(
21    cli_tool: Option<&str>,
22    config_tool: Option<&str>,
23    start_dir: &Path,
24    registry: &PluginRegistry,
25) -> Result<DetectionResult> {
26    // 1. CLI override
27    if let Some(tool) = cli_tool {
28        return resolve_explicit_tool(tool, start_dir, registry);
29    }
30
31    // 2. UBT_TOOL env var (already handled by clap's env feature on --tool,
32    //    but also check explicitly for programmatic use)
33    if let Ok(tool) = std::env::var("UBT_TOOL") {
34        if !tool.is_empty() {
35            return resolve_explicit_tool(&tool, start_dir, registry);
36        }
37    }
38
39    // 3. Config [project].tool
40    if let Some(tool) = config_tool {
41        return resolve_explicit_tool(tool, start_dir, registry);
42    }
43
44    // 4. Auto-detection
45    auto_detect(start_dir, registry)
46}
47
48/// Resolve an explicitly named tool (from CLI, env, or config).
49/// The tool name can be either a plugin name or a variant name (e.g., "pnpm").
50fn resolve_explicit_tool(
51    tool: &str,
52    start_dir: &Path,
53    registry: &PluginRegistry,
54) -> Result<DetectionResult> {
55    // First check if it matches a plugin name directly
56    if let Some((plugin, source)) = registry.get(tool) {
57        return Ok(DetectionResult {
58            plugin_name: plugin.name.clone(),
59            variant_name: detect_variant(plugin, start_dir)
60                .unwrap_or_else(|| plugin.default_variant.clone()),
61            source: source.clone(),
62            project_root: start_dir.to_path_buf(),
63        });
64    }
65
66    // Check if it matches a variant name within any plugin
67    for (_name, (plugin, source)) in registry.iter() {
68        if plugin.variants.contains_key(tool) {
69            return Ok(DetectionResult {
70                plugin_name: plugin.name.clone(),
71                variant_name: tool.to_string(),
72                source: source.clone(),
73                project_root: start_dir.to_path_buf(),
74            });
75        }
76    }
77
78    Err(UbtError::PluginLoadError {
79        name: tool.to_string(),
80        detail: "no plugin or variant found with this name".into(),
81    })
82}
83
84/// Auto-detect tool by walking from start_dir upward.
85fn auto_detect(start_dir: &Path, registry: &PluginRegistry) -> Result<DetectionResult> {
86    let mut current = start_dir.to_path_buf();
87
88    loop {
89        let matches = detect_at_dir(&current, registry);
90        if !matches.is_empty() {
91            return resolve_matches(matches, &current);
92        }
93        if !current.pop() {
94            break;
95        }
96    }
97
98    Err(UbtError::NoPluginMatch)
99}
100
101/// A detection match at a specific directory.
102#[derive(Debug)]
103struct DetectMatch {
104    plugin_name: String,
105    variant_name: String,
106    priority: i32,
107    source: PluginSource,
108}
109
110/// Check all plugins for matches in the given directory.
111fn detect_at_dir(dir: &Path, registry: &PluginRegistry) -> Vec<DetectMatch> {
112    let mut matches = Vec::new();
113
114    for (_name, (plugin, source)) in registry.iter() {
115        if plugin_matches_dir(plugin, dir) {
116            let variant =
117                detect_variant(plugin, dir).unwrap_or_else(|| plugin.default_variant.clone());
118            matches.push(DetectMatch {
119                plugin_name: plugin.name.clone(),
120                variant_name: variant,
121                priority: plugin.priority,
122                source: source.clone(),
123            });
124        }
125    }
126
127    matches
128}
129
130/// Check if a plugin's detect files are present in the given directory.
131fn plugin_matches_dir(plugin: &Plugin, dir: &Path) -> bool {
132    plugin.detect.files.iter().any(|pattern| {
133        if pattern.contains('*') {
134            // Glob pattern (e.g., "*.csproj")
135            glob_matches(dir, pattern)
136        } else {
137            dir.join(pattern).exists()
138        }
139    })
140}
141
142/// Check if a glob pattern matches any file in the directory.
143fn glob_matches(dir: &Path, pattern: &str) -> bool {
144    let Ok(matcher) = globset::GlobBuilder::new(pattern)
145        .literal_separator(true)
146        .build()
147        .map(|g| g.compile_matcher())
148    else {
149        return false;
150    };
151
152    let Ok(entries) = std::fs::read_dir(dir) else {
153        return false;
154    };
155
156    entries.filter_map(|e| e.ok()).any(|entry| {
157        entry
158            .file_name()
159            .to_str()
160            .map(|name| matcher.is_match(name))
161            .unwrap_or(false)
162    })
163}
164
165/// Detect which variant to use based on lockfile presence.
166fn detect_variant(plugin: &Plugin, dir: &Path) -> Option<String> {
167    for (variant_name, variant) in &plugin.variants {
168        for detect_file in &variant.detect_files {
169            if detect_file.contains('*') {
170                if glob_matches(dir, detect_file) {
171                    return Some(variant_name.clone());
172                }
173            } else if dir.join(detect_file).exists() {
174                return Some(variant_name.clone());
175            }
176        }
177    }
178    None
179}
180
181/// Resolve multiple matches using priority. Error on ties.
182fn resolve_matches(matches: Vec<DetectMatch>, dir: &Path) -> Result<DetectionResult> {
183    assert!(!matches.is_empty());
184
185    if matches.len() == 1 {
186        let m = matches.into_iter().next().unwrap();
187        return Ok(DetectionResult {
188            plugin_name: m.plugin_name,
189            variant_name: m.variant_name,
190            source: m.source,
191            project_root: dir.to_path_buf(),
192        });
193    }
194
195    // Sort by priority descending
196    let mut sorted = matches;
197    sorted.sort_by(|a, b| b.priority.cmp(&a.priority));
198
199    // Check if the top two have the same priority
200    if sorted[0].priority == sorted[1].priority {
201        let plugins: Vec<_> = sorted.iter().map(|m| m.plugin_name.as_str()).collect();
202        return Err(UbtError::PluginConflict {
203            plugins: plugins.join(", "),
204            suggested_tool: sorted[0].plugin_name.clone(),
205        });
206    }
207
208    let winner = sorted.into_iter().next().unwrap();
209    Ok(DetectionResult {
210        plugin_name: winner.plugin_name,
211        variant_name: winner.variant_name,
212        source: winner.source,
213        project_root: dir.to_path_buf(),
214    })
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;
    use tempfile::TempDir;

    /// Serializes the tests in this module that read or modify `UBT_TOOL`.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Puts `UBT_TOOL` back to its captured value when dropped, so the
    /// environment is restored even when the test body panics (as a failed
    /// assertion does).
    struct RestoreUbtTool(Option<std::ffi::OsString>);

    impl Drop for RestoreUbtTool {
        fn drop(&mut self) {
            // SAFETY: only constructed inside `with_clean_env`, and dropped
            // while ENV_MUTEX is still held. That serializes environment
            // mutation among the tests in this module. NOTE(review): any
            // other test in the crate that touches the environment must take
            // the same lock for this to be fully sound.
            unsafe {
                match self.0.take() {
                    Some(value) => std::env::set_var("UBT_TOOL", value),
                    None => std::env::remove_var("UBT_TOOL"),
                }
            }
        }
    }

    /// Run `f` with `UBT_TOOL` unset. ENV_MUTEX is held for the duration, and
    /// the previous value (including non-UTF-8 values) is restored afterwards,
    /// also on panic.
    fn with_clean_env<F, R>(f: F) -> R
    where
        F: FnOnce() -> R,
    {
        // A poisoned lock only means an earlier test panicked; the guarded
        // data is `()`, so continuing is fine.
        let _lock = ENV_MUTEX.lock().unwrap_or_else(|e| e.into_inner());
        // Declared after `_lock`, so it is dropped first, while the lock is held.
        let _restore = RestoreUbtTool(std::env::var_os("UBT_TOOL"));
        // SAFETY: ENV_MUTEX is held; see `RestoreUbtTool::drop`.
        unsafe {
            std::env::remove_var("UBT_TOOL");
        }
        f()
    }

    #[test]
    fn detect_go_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module example.com/foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
            assert_eq!(result.variant_name, "go");
        });
    }

    #[test]
    fn detect_node_npm() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();
            std::fs::write(dir.path().join("package-lock.json"), "{}").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "npm");
        });
    }

    #[test]
    fn detect_node_pnpm() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();
            std::fs::write(dir.path().join("pnpm-lock.yaml"), "").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "pnpm");
        });
    }

    #[test]
    fn detect_node_default_variant_when_no_lockfile() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("package.json"), "{}").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "npm");
        });
    }

    #[test]
    fn detect_rust_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("Cargo.toml"), "[package]\nname = \"foo\"").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "rust");
            assert_eq!(result.variant_name, "cargo");
        });
    }

    #[test]
    fn detect_cli_override() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            // Even with go.mod present, --tool=node should win.
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("node"), None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_config_override() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, Some("node"), dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
        });
    }

    #[test]
    fn detect_variant_name_as_tool() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("pnpm"), None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "node");
            assert_eq!(result.variant_name, "pnpm");
        });
    }

    #[test]
    fn detect_walks_upward() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("go.mod"), "module foo").unwrap();
            let nested = dir.path().join("a").join("b").join("c");
            std::fs::create_dir_all(&nested).unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, &nested, &registry).unwrap();

            assert_eq!(result.plugin_name, "go");
            assert_eq!(result.project_root, dir.path());
        });
    }

    #[test]
    fn detect_no_match_errors() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry);

            assert!(matches!(result, Err(UbtError::NoPluginMatch)));
        });
    }

    #[test]
    fn detect_unknown_tool_errors() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(Some("nonexistent"), None, dir.path(), &registry);

            // Pin the error kind, not just "some error".
            assert!(matches!(result, Err(UbtError::PluginLoadError { .. })));
        });
    }

    #[test]
    fn detect_dotnet_glob() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("MyApp.csproj"), "<Project/>").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "dotnet");
        });
    }

    #[test]
    fn detect_ruby_project() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("Gemfile"), "source 'https://rubygems.org'").unwrap();
            std::fs::write(dir.path().join("Gemfile.lock"), "").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "ruby");
            assert_eq!(result.variant_name, "bundler");
        });
    }

    #[test]
    fn detect_python_pip() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("requirements.txt"), "flask").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "python");
        });
    }

    #[test]
    fn detect_java_maven() {
        with_clean_env(|| {
            let dir = TempDir::new().unwrap();
            std::fs::write(dir.path().join("pom.xml"), "<project/>").unwrap();

            let registry = PluginRegistry::new().unwrap();
            let result = detect_tool(None, None, dir.path(), &registry).unwrap();

            assert_eq!(result.plugin_name, "java");
            assert_eq!(result.variant_name, "mvn");
        });
    }
}
447}