1#[cfg(feature = "internal-commands")]
5use crate::commands::{FrameworkCommand, HelloCommand, HelpCommand, PingCommand};
6use crate::output::hook;
7
8#[cfg(feature = "async")]
9use crate::command::AsyncCommand;
10use crate::command::Command;
11#[allow(unused_imports)]
12use crate::error::ModCliError;
13use std::collections::{HashMap, HashSet};
14
15type PreHookFn = dyn Fn(&str, &[String]) + Send + Sync;
17type PostHookFn = dyn Fn(&str, &[String], Result<(), &str>) + Send + Sync;
18type ErrorFmtFn = dyn Fn(&crate::error::ModCliError) -> String + Send + Sync;
19type VisibilityPolicyFn = dyn Fn(&dyn Command, &HashSet<String>) -> bool + Send + Sync;
20type AuthorizePolicyFn =
21 dyn Fn(&dyn Command, &HashSet<String>, &[String]) -> Result<(), String> + Send + Sync;
22
23pub struct CommandRegistry {
41 prefix: String,
42 commands: HashMap<String, Box<dyn Command>>,
43 aliases: HashMap<String, String>,
44 #[cfg(feature = "async")]
45 async_commands: HashMap<String, Box<dyn AsyncCommand>>, #[cfg(feature = "async")]
47 async_aliases: HashMap<String, String>,
48 caps: HashSet<String>,
49 visibility_policy: Option<Box<VisibilityPolicyFn>>,
50 authorize_policy: Option<Box<AuthorizePolicyFn>>,
51 pre_hook: Option<Box<PreHookFn>>, post_hook: Option<Box<PostHookFn>>, error_formatter: Option<Box<ErrorFmtFn>>,
54 #[cfg(feature = "dispatch-cache")]
55 cache: std::sync::Mutex<Option<(String, String)>>,
56}
57
58impl Default for CommandRegistry {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl CommandRegistry {
65 pub fn new() -> Self {
67 let mut reg = Self {
68 prefix: String::new(),
69 commands: HashMap::new(),
70 aliases: HashMap::new(),
71 #[cfg(feature = "async")]
72 async_commands: HashMap::new(),
73 #[cfg(feature = "async")]
74 async_aliases: HashMap::new(),
75 caps: HashSet::new(),
76 visibility_policy: None,
77 authorize_policy: None,
78 pre_hook: None,
79 post_hook: None,
80 error_formatter: None,
81 #[cfg(feature = "dispatch-cache")]
82 cache: std::sync::Mutex::new(None),
83 };
84
85 #[cfg(feature = "custom-commands")]
86 reg.load_custom_commands();
87
88 #[cfg(feature = "internal-commands")]
89 reg.load_internal_commands();
90
91 reg
92 }
93
94 #[cfg(feature = "async")]
96 pub fn register_async(&mut self, cmd: Box<dyn AsyncCommand>) {
97 let name = cmd.name().to_string();
98 self.async_commands.insert(name.clone(), cmd);
99 for &alias in self.async_commands[&name].aliases() {
100 if !self.async_commands.contains_key(alias) {
101 self.async_aliases.insert(alias.to_string(), name.clone());
102 }
103 }
104 }
105
106 pub fn set_prefix(&mut self, prefix: &str) {
109 self.prefix = prefix.to_string();
110 }
111
112 pub fn get_prefix(&self) -> &str {
115 &self.prefix
116 }
117
118 #[inline(always)]
121 pub fn get(&self, name: &str) -> Option<&dyn Command> {
122 self.commands.get(name).map(|b| b.as_ref())
123 }
124
125 #[inline(always)]
128 pub fn register(&mut self, cmd: Box<dyn Command>) {
129 let name = cmd.name().to_string();
131 self.commands.insert(name.clone(), cmd);
132
133 for &alias in self.commands[&name].aliases() {
135 if !self.commands.contains_key(alias) {
137 self.aliases.insert(alias.to_string(), name.clone());
139 }
140 }
141 }
142
143 pub fn all(&self) -> impl Iterator<Item = &Box<dyn Command>> {
146 self.commands.values()
147 }
148
149 #[cfg(feature = "async")]
151 pub fn all_async(&self) -> impl Iterator<Item = &Box<dyn AsyncCommand>> {
152 self.async_commands.values()
153 }
154
155 #[cfg(feature = "async")]
157 #[inline(always)]
158 pub async fn try_execute_async(&self, cmd: &str, args: &[String]) -> Result<(), ModCliError> {
159 if let Some(ref pre) = self.pre_hook {
160 pre(cmd, args);
161 }
162
163 let token: &str = if !self.prefix.is_empty() && cmd.len() > self.prefix.len() + 1 {
165 let (maybe_prefix, rest_with_colon) = cmd.split_at(self.prefix.len());
166 if maybe_prefix == self.prefix && rest_with_colon.as_bytes().first() == Some(&b':') {
167 &rest_with_colon[1..]
168 } else {
169 cmd
170 }
171 } else {
172 cmd
173 };
174
175 if let Some(command) = self.async_commands.get(token) {
177 if let Err(e) = self.is_authorized_async(args) {
178 return Err(ModCliError::InvalidUsage(e));
179 }
180 command.execute_async(args).await?;
181 if let Some(ref post) = self.post_hook {
182 post(cmd, args, Ok(()));
183 }
184 return Ok(());
185 }
186
187 if let Some(primary) = self.async_aliases.get(token) {
189 if let Some(command) = self.async_commands.get(primary.as_str()) {
190 if let Err(e) = self.is_authorized_async(args) {
191 return Err(ModCliError::InvalidUsage(e));
192 }
193 command.execute_async(args).await?;
194 if let Some(ref post) = self.post_hook {
195 post(cmd, args, Ok(()));
196 }
197 return Ok(());
198 }
199 }
200
201 if !args.is_empty() {
203 let combined = format!("{token}:{}", args[0]);
204 if let Some(command) = self.async_commands.get(combined.as_str()) {
205 let rest = &args[1..];
206 if let Err(e) = self.is_authorized_async(rest) {
207 return Err(ModCliError::InvalidUsage(e));
208 }
209 command.execute_async(rest).await?;
210 if let Some(ref post) = self.post_hook {
211 post(cmd, args, Ok(()));
212 }
213 return Ok(());
214 }
215 }
216
217 if let Some(ref post) = self.post_hook {
218 post(cmd, args, Err("unknown"));
219 }
220 Err(ModCliError::UnknownCommand(cmd.to_string()))
221 }
222
223 #[cfg(feature = "async")]
225 #[inline(always)]
226 pub async fn execute_async(&self, cmd: &str, args: &[String]) {
227 if let Err(err) = self.try_execute_async(cmd, args).await {
228 if let Some(ref fmt) = self.error_formatter {
229 hook::error(&fmt(&err));
230 } else {
231 match err {
232 ModCliError::InvalidUsage(msg) => hook::error(&format!("Invalid usage: {msg}")),
233 ModCliError::UnknownCommand(name) => hook::unknown(&format!(
234 "[{name}]. Type `help` or `--help` for a list of available commands."
235 )),
236 other => hook::error(&format!("{other}")),
237 }
238 }
239 }
240 }
241
242 #[cfg(feature = "async")]
244 #[inline(always)]
245 fn is_authorized_async(&self, args: &[String]) -> Result<(), String> {
246 if let Some(ref pol) = self.authorize_policy {
247 struct Dummy;
248 impl Command for Dummy {
249 fn name(&self) -> &str {
250 "__async_dummy__"
251 }
252 fn execute(&self, _args: &[String]) {}
253 }
254 return pol(&Dummy, &self.caps, args);
255 }
256 Ok(())
257 }
258
259 pub fn grant_cap<S: Into<String>>(&mut self, cap: S) {
261 self.caps.insert(cap.into());
262 }
263 pub fn revoke_cap(&mut self, cap: &str) {
264 self.caps.remove(cap);
265 }
266 pub fn has_cap(&self, cap: &str) -> bool {
267 self.caps.contains(cap)
268 }
269 pub fn set_caps<I, S>(&mut self, caps: I)
270 where
271 I: IntoIterator<Item = S>,
272 S: Into<String>,
273 {
274 self.caps.clear();
275 for c in caps {
276 self.caps.insert(c.into());
277 }
278 }
279
280 pub fn set_visibility_policy<F>(&mut self, f: F)
281 where
282 F: Fn(&dyn Command, &HashSet<String>) -> bool + Send + Sync + 'static,
283 {
284 self.visibility_policy = Some(Box::new(f));
285 }
286
287 pub fn set_authorize_policy<F>(&mut self, f: F)
288 where
289 F: Fn(&dyn Command, &HashSet<String>, &[String]) -> Result<(), String>
290 + Send
291 + Sync
292 + 'static,
293 {
294 self.authorize_policy = Some(Box::new(f));
295 }
296
297 pub fn set_pre_hook<F>(&mut self, f: F)
298 where
299 F: Fn(&str, &[String]) + Send + Sync + 'static,
300 {
301 self.pre_hook = Some(Box::new(f));
302 }
303
304 pub fn set_post_hook<F>(&mut self, f: F)
305 where
306 F: Fn(&str, &[String], Result<(), &str>) + Send + Sync + 'static,
307 {
308 self.post_hook = Some(Box::new(f));
309 }
310
311 pub fn set_error_formatter<F>(&mut self, f: F)
312 where
313 F: Fn(&crate::error::ModCliError) -> String + Send + Sync + 'static,
314 {
315 self.error_formatter = Some(Box::new(f));
316 }
317
318 #[inline(always)]
319 pub fn is_visible(&self, cmd: &dyn Command) -> bool {
320 if let Some(ref pol) = self.visibility_policy {
321 return pol(cmd, &self.caps);
322 }
323 if cmd.hidden() {
324 return false;
325 }
326 cmd.required_caps().iter().all(|c| self.caps.contains(*c))
327 }
328
329 #[inline(always)]
330 pub fn is_authorized(&self, cmd: &dyn Command, args: &[String]) -> Result<(), String> {
331 if let Some(ref pol) = self.authorize_policy {
332 return pol(cmd, &self.caps, args);
333 }
334 if cmd.required_caps().iter().all(|c| self.caps.contains(*c)) {
335 Ok(())
336 } else {
337 Err("Not authorized".into())
338 }
339 }
340
341 #[inline(always)]
360 pub fn execute(&self, cmd: &str, args: &[String]) {
361 if let Err(err) = self.try_execute(cmd, args) {
362 if let Some(ref fmt) = self.error_formatter {
363 hook::error(&fmt(&err));
364 } else {
365 match err {
366 ModCliError::InvalidUsage(msg) => hook::error(&format!("Invalid usage: {msg}")),
367 ModCliError::UnknownCommand(name) => hook::unknown(&format!(
368 "[{name}]. Type `help` or `--help` for a list of available commands."
369 )),
370 other => hook::error(&format!("{other}")),
371 }
372 }
373 }
374 }
375
376 #[inline(always)]
398 pub fn try_execute(&self, cmd: &str, args: &[String]) -> Result<(), ModCliError> {
399 if let Some(ref pre) = self.pre_hook {
400 pre(cmd, args);
401 }
402 let token: &str = if !self.prefix.is_empty() && cmd.len() > self.prefix.len() + 1 {
404 let (maybe_prefix, rest_with_colon) = cmd.split_at(self.prefix.len());
405 if maybe_prefix == self.prefix && rest_with_colon.as_bytes().first() == Some(&b':') {
406 &rest_with_colon[1..]
407 } else {
408 cmd
409 }
410 } else {
411 cmd
412 };
413
414 #[cfg(feature = "dispatch-cache")]
415 if let Ok(guard) = self.cache.lock() {
416 if let Some((ref t, ref p)) = *guard {
417 if t == token {
418 if let Some(command) = self.commands.get(p.as_str()) {
419 command.validate(args)?;
420 command.execute_with(args, self);
421 return Ok(());
422 }
423 }
424 }
425 }
426
427 if let Some(command) = self.commands.get(token) {
429 if let Err(err) = self.is_authorized(command.as_ref(), args) {
430 return Err(ModCliError::InvalidUsage(err));
431 }
432 command.validate(args)?;
433 command.execute_with(args, self);
434 #[cfg(feature = "dispatch-cache")]
435 if let Ok(mut guard) = self.cache.lock() {
436 *guard = Some((token.to_string(), token.to_string()));
437 }
438 if let Some(ref post) = self.post_hook {
439 post(cmd, args, Ok(()));
440 }
441 return Ok(());
442 }
443
444 if let Some(primary) = self.aliases.get(token) {
446 if let Some(command) = self.commands.get(primary.as_str()) {
447 if let Err(err) = self.is_authorized(command.as_ref(), args) {
448 return Err(ModCliError::InvalidUsage(err));
449 }
450 command.validate(args)?;
451 command.execute_with(args, self);
452 #[cfg(feature = "dispatch-cache")]
453 if let Ok(mut guard) = self.cache.lock() {
454 *guard = Some((token.to_string(), primary.clone()));
455 }
456 if let Some(ref post) = self.post_hook {
457 post(cmd, args, Ok(()));
458 }
459 return Ok(());
460 }
461 }
462
463 if !args.is_empty() {
465 let combined = format!("{token}:{}", args[0]);
466 if let Some(command) = self.commands.get(combined.as_str()) {
467 let rest = &args[1..];
468 if let Err(err) = self.is_authorized(command.as_ref(), rest) {
469 return Err(ModCliError::InvalidUsage(err));
470 }
471 command.validate(rest)?;
472 command.execute_with(rest, self);
473 if let Some(ref post) = self.post_hook {
474 post(cmd, args, Ok(()));
475 }
476 return Ok(());
477 }
478 }
479 let err = ModCliError::UnknownCommand(cmd.to_string());
480 if let Some(ref post) = self.post_hook {
481 post(cmd, args, Err("unknown"));
482 }
483 Err(err)
484 }
485
486 #[cfg(feature = "internal-commands")]
487 pub fn load_internal_commands(&mut self) {
488 self.register(Box::new(PingCommand));
489 self.register(Box::new(HelloCommand));
490 self.register(Box::new(FrameworkCommand));
491 self.register(Box::new(HelpCommand::new()));
492 }
493
494 pub fn len(&self) -> usize {
497 self.commands.len()
498 }
499
500 pub fn is_empty(&self) -> bool {
501 self.commands.is_empty()
502 }
503
504 #[cfg(feature = "custom-commands")]
505 pub fn load_custom_commands(&mut self) {
506 }
508}