Skip to main content

atomcode_core/commands/
mod.rs

1// crates/atomcode-core/src/commands/mod.rs
2//
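// Custom slash-command registry. Users define commands as `.md` files with
// YAML-style frontmatter in two locations:
//
//   1. `$ATOMCODE_HOME/commands/`       — global (apply to every project)
//   2. `<project>/.atomcode/commands/`  — project-level (override global
//                                         when names collide)
//
// Installed plugins may additionally provide a `commands/` directory; their
// commands are registered as `<plugin>:<name>`, so they never override the
// global or project commands above.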
9//
10// Each file has the shape:
11//
12// ```markdown
13// ---
14// name: review
15// description: 对当前 diff 进行代码审查
16// args: optional
17// ---
18// 请对当前 git diff 中的所有改动进行代码审查。
19// 如有指定文件则只审查: $ARGUMENTS
20// ```
21//
22// The registry is loaded once at startup (or on `/mcp reload`-style events)
23// and queried by the TUI dispatch loop. Custom commands are NOT LLM-invocable
24// — they are user-invocable via `/command_name` in the input, sending the
25// rendered template as a user message to the agent.
26
27use std::collections::HashMap;
28use std::path::{Path, PathBuf};
29
30/// A single custom command parsed from a `.md` template file.
31#[derive(Debug, Clone)]
32pub struct CustomCommand {
33    pub name: String,
34    pub description: String,
35    pub args_requirement: ArgsRequirement,
36    pub template: String,
37    pub source: PathBuf,
38    pub namespace: Option<String>,
39}
40
/// Whether a custom command expects arguments from the user, as declared by
/// the `args:` frontmatter field.
///
/// The enum is fieldless, so it derives `Copy` and `Eq`: it can be passed
/// and compared by value without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ArgsRequirement {
    /// `args: required`.
    Required,
    /// `args: optional`.
    Optional,
    /// No `args:` field, or a value that is not recognized.
    None,
}
48
49/// In-memory registry of all custom commands discovered from global +
50/// project-level directories. Project commands override global ones when
51/// names collide (loaded second).
52pub struct CustomCommandRegistry {
53    commands: HashMap<String, CustomCommand>,
54}
55
56impl CustomCommandRegistry {
57    /// Scan both global (`$ATOMCODE_HOME/commands/`) and project-level
58    /// (`<project_root>/.atomcode/commands/`) directories, merging results.
59    /// Project entries win on name collision.
60    pub fn load(project_root: &Path) -> Self {
61        let config_dir = crate::config::Config::config_dir();
62        let mut commands = HashMap::new();
63        // Global first — project overrides on second pass.
64        Self::load_from_dir(&config_dir.join("commands"), None, &mut commands);
65        Self::load_from_dir(&project_root.join(".atomcode/commands"), None, &mut commands);
66        // Plugin layer
67        for assets in crate::plugin::loader::iter_installed_plugin_assets() {
68            Self::load_from_dir(&assets.commands_dir(), Some(&assets.plugin), &mut commands);
69        }
70        Self { commands }
71    }
72
73    /// An empty registry — useful for tests or when custom commands are
74    /// disabled.
75    pub fn empty() -> Self {
76        Self {
77            commands: HashMap::new(),
78        }
79    }
80
81    fn load_from_dir(
82        dir: &Path,
83        namespace: Option<&str>,
84        commands: &mut HashMap<String, CustomCommand>,
85    ) {
86        let entries = match std::fs::read_dir(dir) {
87            Ok(e) => e,
88            Err(_) => return,
89        };
90        for entry in entries.flatten() {
91            let path = entry.path();
92            if path.extension().and_then(|e| e.to_str()) != Some("md") {
93                continue;
94            }
95            if let Some(mut cmd) = Self::parse_command_file(&path) {
96                if let Some(ns) = namespace {
97                    cmd.namespace = Some(ns.to_string());
98                }
99                let key = match &cmd.namespace {
100                    Some(ns) => format!("{}:{}", ns, cmd.name),
101                    None => cmd.name.clone(),
102                };
103                commands.insert(key, cmd);
104            }
105        }
106    }
107
108    fn parse_command_file(path: &Path) -> Option<CustomCommand> {
109        let content = std::fs::read_to_string(path).ok()?;
110        let (frontmatter, template) = Self::split_frontmatter(&content)?;
111        let name = Self::extract_field(&frontmatter, "name")?;
112        let description = Self::extract_field(&frontmatter, "description")
113            .unwrap_or_else(|| format!("Custom command: {}", name));
114        let args_str = Self::extract_field(&frontmatter, "args").unwrap_or_else(|| "none".into());
115        let args_requirement = match args_str.as_str() {
116            "required" => ArgsRequirement::Required,
117            "optional" => ArgsRequirement::Optional,
118            _ => ArgsRequirement::None,
119        };
120        Some(CustomCommand {
121            name,
122            description,
123            args_requirement,
124            template: template.trim().to_string(),
125            source: path.to_path_buf(),
126            namespace: None,
127        })
128    }
129
130    /// Split `---\n..frontmatter..\n---\nbody` into (frontmatter, body).
131    /// Returns `None` when the content doesn't start with `---`.
132    fn split_frontmatter(content: &str) -> Option<(String, String)> {
133        let content = content.trim();
134        if !content.starts_with("---") {
135            return None;
136        }
137        let rest = &content[3..];
138        let end = rest.find("---")?;
139        Some((rest[..end].trim().to_string(), rest[end + 3..].to_string()))
140    }
141
142    /// Extract a `key: value` field from frontmatter text. Handles leading
143    /// whitespace but not quoted values — keeps parsing minimal.
144    fn extract_field(frontmatter: &str, key: &str) -> Option<String> {
145        for line in frontmatter.lines() {
146            let line = line.trim();
147            if let Some(rest) = line.strip_prefix(key) {
148                if let Some(value) = rest.trim_start().strip_prefix(':') {
149                    return Some(value.trim().to_string());
150                }
151            }
152        }
153        None
154    }
155
156    /// Look up a command by name.
157    pub fn get(&self, name: &str) -> Option<&CustomCommand> {
158        self.commands.get(name)
159    }
160
161    /// Render the template for `name`, replacing `$ARGUMENTS` /
162    /// `${ARGUMENTS}` with the provided args string.
163    pub fn render(&self, name: &str, args: &str) -> Option<String> {
164        self.commands.get(name).map(|cmd| {
165            cmd.template
166                .replace("$ARGUMENTS", args)
167                .replace("${ARGUMENTS}", args)
168        })
169    }
170
171    /// All commands, sorted by name.
172    pub fn list(&self) -> Vec<&CustomCommand> {
173        let mut cmds: Vec<_> = self.commands.values().collect();
174        cmds.sort_by_key(|c| &c.name);
175        cmds
176    }
177
178    /// `(name, description)` pairs for every registered custom command,
179    /// sorted by name. Convenient for feeding into completion / menu builders.
180    pub fn command_names_and_descriptions(&self) -> Vec<(String, String)> {
181        self.list()
182            .iter()
183            .map(|c| (c.name.clone(), c.description.clone()))
184            .collect()
185    }
186}
187
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a registry from exactly one directory.
    ///
    /// `CustomCommandRegistry::load` also reads `$ATOMCODE_HOME/commands` and
    /// every installed plugin. Tests that go through it can therefore observe
    /// the developer's own commands. They can also observe the plugin installed
    /// by `load_plugin_layer_namespaces_commands`, which rewrites
    /// `ATOMCODE_HOME` while non-serial tests run in parallel. Loading a single
    /// directory keeps these tests hermetic.
    fn registry_from_dir(dir: &Path) -> CustomCommandRegistry {
        let mut commands = HashMap::new();
        CustomCommandRegistry::load_from_dir(dir, None, &mut commands);
        CustomCommandRegistry { commands }
    }

    /// Sets an environment variable and, on drop, restores its previous
    /// value (or removes it), so a failing assertion cannot leak the
    /// override into later tests.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &Path) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn split_frontmatter_works() {
        let content = "---\nname: review\ndescription: Code review\n---\nTemplate body here";
        let (fm, body) = CustomCommandRegistry::split_frontmatter(content).unwrap();
        assert!(fm.contains("name: review"));
        assert!(fm.contains("description: Code review"));
        assert_eq!(body.trim(), "Template body here");
    }

    #[test]
    fn split_frontmatter_returns_none_without_delimiters() {
        assert!(CustomCommandRegistry::split_frontmatter("no frontmatter here").is_none());
        assert!(CustomCommandRegistry::split_frontmatter("--- only opening").is_none());
    }

    #[test]
    fn extract_field_works() {
        let fm = "name: review\ndescription: Code review\nargs: optional";
        assert_eq!(
            CustomCommandRegistry::extract_field(fm, "name"),
            Some("review".into())
        );
        assert_eq!(
            CustomCommandRegistry::extract_field(fm, "description"),
            Some("Code review".into())
        );
        assert_eq!(
            CustomCommandRegistry::extract_field(fm, "args"),
            Some("optional".into())
        );
        assert_eq!(CustomCommandRegistry::extract_field(fm, "missing"), None);
    }

    #[test]
    fn parse_command_file_works() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("review.md");
        std::fs::write(
            &path,
            "---\nname: review\ndescription: Code review\nargs: optional\n---\nReview: $ARGUMENTS",
        )
        .unwrap();
        let cmd = CustomCommandRegistry::parse_command_file(&path).unwrap();
        assert_eq!(cmd.name, "review");
        assert_eq!(cmd.description, "Code review");
        assert_eq!(cmd.args_requirement, ArgsRequirement::Optional);
        assert_eq!(cmd.template, "Review: $ARGUMENTS");
        assert_eq!(cmd.source, path);
    }

    #[test]
    fn render_replaces_arguments() {
        let dir = tempfile::tempdir().unwrap();
        let cmd_dir = dir.path().join(".atomcode/commands");
        std::fs::create_dir_all(&cmd_dir).unwrap();
        std::fs::write(
            cmd_dir.join("review.md"),
            "---\nname: review\ndescription: Review\nargs: optional\n---\nReview $ARGUMENTS and ${ARGUMENTS} done",
        )
        .unwrap();
        let reg = registry_from_dir(&cmd_dir);
        let rendered = reg.render("review", "main.rs").unwrap();
        assert_eq!(rendered, "Review main.rs and main.rs done");
    }

    #[test]
    fn render_empty_args() {
        let dir = tempfile::tempdir().unwrap();
        let cmd_dir = dir.path().join(".atomcode/commands");
        std::fs::create_dir_all(&cmd_dir).unwrap();
        std::fs::write(
            cmd_dir.join("test.md"),
            "---\nname: test\ndescription: Run tests\n---\nRun all tests. Focus: $ARGUMENTS",
        )
        .unwrap();
        let reg = registry_from_dir(&cmd_dir);
        let rendered = reg.render("test", "").unwrap();
        assert_eq!(rendered, "Run all tests. Focus: ");
    }

    #[test]
    fn load_from_dir_skips_non_md() {
        let dir = tempfile::tempdir().unwrap();
        let cmd_dir = dir.path().join("commands");
        std::fs::create_dir_all(&cmd_dir).unwrap();
        // Valid .md file
        std::fs::write(
            cmd_dir.join("valid.md"),
            "---\nname: valid\ndescription: Valid cmd\n---\nTemplate",
        )
        .unwrap();
        // Non-md file — should be skipped
        std::fs::write(
            cmd_dir.join("skip.txt"),
            "---\nname: skip\ndescription: Skip\n---\nNope",
        )
        .unwrap();
        // No extension — should be skipped
        std::fs::write(
            cmd_dir.join("noext"),
            "---\nname: noext\ndescription: No ext\n---\nNope",
        )
        .unwrap();
        let mut commands = HashMap::new();
        CustomCommandRegistry::load_from_dir(&cmd_dir, None, &mut commands);
        assert_eq!(commands.len(), 1);
        assert!(commands.contains_key("valid"));
    }

    #[test]
    fn project_overrides_global_same_name() {
        let root = tempfile::tempdir().unwrap();

        // Simulate global dir
        let global_dir = root.path().join("global_commands");
        std::fs::create_dir_all(&global_dir).unwrap();
        std::fs::write(
            global_dir.join("review.md"),
            "---\nname: review\ndescription: Global review\n---\nGlobal template",
        )
        .unwrap();

        // Simulate project dir
        let project_dir = root.path().join("project_commands");
        std::fs::create_dir_all(&project_dir).unwrap();
        std::fs::write(
            project_dir.join("review.md"),
            "---\nname: review\ndescription: Project review\n---\nProject template",
        )
        .unwrap();

        let mut commands = HashMap::new();
        // Load global first, then project — project should override.
        CustomCommandRegistry::load_from_dir(&global_dir, None, &mut commands);
        CustomCommandRegistry::load_from_dir(&project_dir, None, &mut commands);

        let cmd = commands.get("review").unwrap();
        assert_eq!(cmd.description, "Project review");
        assert_eq!(cmd.template, "Project template");
    }

    #[test]
    fn list_returns_sorted_commands() {
        let dir = tempfile::tempdir().unwrap();
        let cmd_dir = dir.path().join(".atomcode/commands");
        std::fs::create_dir_all(&cmd_dir).unwrap();
        std::fs::write(
            cmd_dir.join("zebra.md"),
            "---\nname: zebra\ndescription: Z\n---\nZ",
        )
        .unwrap();
        std::fs::write(
            cmd_dir.join("alpha.md"),
            "---\nname: alpha\ndescription: A\n---\nA",
        )
        .unwrap();
        // Hermetic: `load` would also pick up global/plugin commands, making
        // this exact-list assertion environment-dependent.
        let reg = registry_from_dir(&cmd_dir);
        let names: Vec<_> = reg.list().iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, vec!["alpha", "zebra"]);
    }

    #[test]
    fn empty_registry_has_no_commands() {
        let reg = CustomCommandRegistry::empty();
        assert!(reg.list().is_empty());
        assert!(reg.get("anything").is_none());
        assert!(reg.render("anything", "").is_none());
    }

    #[test]
    fn parse_file_without_frontmatter_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("bad.md");
        std::fs::write(&path, "No frontmatter here, just text.").unwrap();
        assert!(CustomCommandRegistry::parse_command_file(&path).is_none());
    }

    #[test]
    fn parse_file_without_name_returns_none() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("noname.md");
        std::fs::write(
            &path,
            "---\ndescription: Missing name field\n---\nTemplate",
        )
        .unwrap();
        assert!(CustomCommandRegistry::parse_command_file(&path).is_none());
    }

    #[test]
    fn default_description_when_missing() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("nodesc.md");
        std::fs::write(&path, "---\nname: nodesc\n---\nTemplate").unwrap();
        let cmd = CustomCommandRegistry::parse_command_file(&path).unwrap();
        assert_eq!(cmd.description, "Custom command: nodesc");
    }

    #[test]
    fn default_args_requirement_is_none() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("noargs.md");
        std::fs::write(&path, "---\nname: noargs\n---\nTemplate").unwrap();
        let cmd = CustomCommandRegistry::parse_command_file(&path).unwrap();
        assert_eq!(cmd.args_requirement, ArgsRequirement::None);
    }

    #[test]
    #[serial_test::serial]
    fn load_plugin_layer_namespaces_commands() {
        // Set up an installed plugin on disk via the plugin module's state.
        let tmp = tempfile::tempdir().unwrap();
        // Declared after `tmp` so it drops first: the env var is restored
        // before the temp dir disappears, even if an assertion fails.
        let _home = EnvVarGuard::set("ATOMCODE_HOME", tmp.path());

        let plugin_dir = tmp.path().join("plugins/marketplaces/p");
        let cmd_dir = plugin_dir.join("commands");
        std::fs::create_dir_all(&cmd_dir).unwrap();
        std::fs::write(
            cmd_dir.join("greet.md"),
            "---\nname: greet\ndescription: hi\n---\nhello $ARGUMENTS",
        )
        .unwrap();
        std::fs::write(
            tmp.path().join("plugins/installed_plugins.json"),
            r#"{"version":1,"plugins":{"p@p":{"marketplace":"p","plugin":"p","plugin_dir":"marketplaces/p","installed_at":"x"}}}"#,
        )
        .unwrap();

        // A project-level command alongside, covering the project layer of `load`.
        let working = tempfile::tempdir().unwrap();
        let project_cmds = working.path().join(".atomcode/commands");
        std::fs::create_dir_all(&project_cmds).unwrap();
        std::fs::write(
            project_cmds.join("local.md"),
            "---\nname: local\ndescription: here\n---\nlocal",
        )
        .unwrap();

        let reg = CustomCommandRegistry::load(working.path());
        assert!(reg.get("p:greet").is_some());
        assert!(reg.get("local").is_some());
    }
}