// claude_agent_sdk/commands/mod.rs
use std::collections::HashMap;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

/// Errors produced while registering, looking up, or running slash commands.
///
/// Every variant carries a `String` payload; the `Display` impl prints it after
/// a fixed, variant-specific prefix.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandError {
    /// No command is registered under the requested name (payload: the name).
    NotFound(String),
    /// A command handler reported a failure (payload: the failure message).
    ExecutionFailed(String),
    /// A command name failed validation (payload: a description of the problem).
    InvalidName(String),
    /// A command with this name is already registered (payload: the name).
    AlreadyRegistered(String),
}

24impl fmt::Display for CommandError {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 match self {
27 CommandError::NotFound(name) => write!(f, "Command not found: {}", name),
28 CommandError::ExecutionFailed(msg) => write!(f, "Command execution failed: {}", msg),
29 CommandError::InvalidName(name) => write!(f, "Invalid command name: {}", name),
30 CommandError::AlreadyRegistered(name) => {
31 write!(f, "Command already registered: {}", name)
32 }
33 }
34 }
35}
36
37impl std::error::Error for CommandError {}
38
39pub type CommandHandler = Arc<
48 dyn Fn(&str, Vec<String>) -> Pin<Box<dyn Future<Output = Result<String, CommandError>> + Send>>
49 + Send
50 + Sync,
51>;
52
53#[derive(Clone)]
55pub struct SlashCommand {
56 pub name: String,
58 pub description: String,
60 pub handler: CommandHandler,
62}
63
64impl SlashCommand {
65 pub fn new(
72 name: impl Into<String>,
73 description: impl Into<String>,
74 handler: CommandHandler,
75 ) -> Self {
76 Self {
77 name: name.into(),
78 description: description.into(),
79 handler,
80 }
81 }
82
83 fn validate_name(name: &str) -> Result<(), CommandError> {
85 if name.is_empty() {
86 return Err(CommandError::InvalidName("Command name cannot be empty".to_string()));
87 }
88 if name.contains(' ') {
89 return Err(CommandError::InvalidName(
90 "Command name cannot contain spaces".to_string(),
91 ));
92 }
93 if !name.chars().next().unwrap().is_alphabetic() {
94 return Err(CommandError::InvalidName(
95 "Command name must start with a letter".to_string(),
96 ));
97 }
98 Ok(())
99 }
100}
101
102#[derive(Default)]
104pub struct CommandRegistry {
105 commands: HashMap<String, SlashCommand>,
106}
107
108impl CommandRegistry {
109 pub fn new() -> Self {
111 Self {
112 commands: HashMap::new(),
113 }
114 }
115
116 pub fn register(&mut self, command: SlashCommand) -> Result<(), CommandError> {
125 SlashCommand::validate_name(&command.name)?;
126
127 if self.commands.contains_key(&command.name) {
128 return Err(CommandError::AlreadyRegistered(command.name));
129 }
130
131 self.commands.insert(command.name.clone(), command);
132 Ok(())
133 }
134
135 pub async fn execute(&self, name: &str, args: Vec<String>) -> Result<String, CommandError> {
145 let command = self
146 .commands
147 .get(name)
148 .ok_or_else(|| CommandError::NotFound(name.to_string()))?;
149
150 (command.handler)(name, args).await
151 }
152
153 pub fn exists(&self, name: &str) -> bool {
155 self.commands.contains_key(name)
156 }
157
158 pub fn get(&self, name: &str) -> Option<&SlashCommand> {
160 self.commands.get(name)
161 }
162
163 pub fn list_names(&self) -> Vec<String> {
165 self.commands.keys().cloned().collect()
166 }
167
168 pub fn list_all(&self) -> Vec<&SlashCommand> {
170 self.commands.values().collect()
171 }
172
173 pub fn len(&self) -> usize {
175 self.commands.len()
176 }
177
178 pub fn is_empty(&self) -> bool {
180 self.commands.is_empty()
181 }
182
183 pub fn unregister(&mut self, name: &str) -> Result<(), CommandError> {
189 self.commands
190 .remove(name)
191 .ok_or_else(|| CommandError::NotFound(name.to_string()))?;
192 Ok(())
193 }
194
195 pub fn clear(&mut self) {
197 self.commands.clear();
198 }
199}
200
201impl fmt::Debug for SlashCommand {
202 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
203 f.debug_struct("SlashCommand")
204 .field("name", &self.name)
205 .field("description", &self.description)
206 .field("handler", &"<function>")
207 .finish()
208 }
209}
210
211impl fmt::Debug for CommandRegistry {
212 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
213 f.debug_struct("CommandRegistry")
214 .field("commands_count", &self.commands.len())
215 .field("command_names", &self.list_names())
216 .finish()
217 }
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a command whose handler reports the arguments it was called with.
    fn create_test_command(name: &str, description: &str) -> SlashCommand {
        SlashCommand::new(
            name,
            description,
            Arc::new(|_name, args| {
                Box::pin(async move { Ok(format!("Executed with args: {:?}", args)) })
            }),
        )
    }

    /// Builds a command with the given name and a handler that returns an empty string.
    fn command_named(name: &str) -> SlashCommand {
        SlashCommand::new(
            name,
            "description",
            Arc::new(|_name, _args| Box::pin(async { Ok(String::new()) })),
        )
    }

    /// Runs `fut` to completion on a fresh Tokio runtime.
    fn block_on<F: Future>(fut: F) -> F::Output {
        tokio::runtime::Runtime::new()
            .expect("failed to start Tokio runtime")
            .block_on(fut)
    }

    /// Asserts that `name` fails validation and that a registry refuses it.
    fn assert_rejected(name: &str) {
        assert!(matches!(
            SlashCommand::validate_name(name),
            Err(CommandError::InvalidName(_))
        ));

        let mut registry = CommandRegistry::new();
        assert!(matches!(
            registry.register(command_named(name)),
            Err(CommandError::InvalidName(_))
        ));
        assert!(registry.is_empty());
    }

    #[test]
    fn test_registry_creation() {
        let registry = CommandRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_registry_default() {
        let registry = CommandRegistry::default();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_register_command() {
        let mut registry = CommandRegistry::new();
        let cmd = create_test_command("test", "A test command");

        assert!(registry.register(cmd).is_ok());
        assert_eq!(registry.len(), 1);
        assert!(registry.exists("test"));
    }

    #[test]
    fn test_register_duplicate_fails() {
        let mut registry = CommandRegistry::new();
        let cmd1 = create_test_command("test", "First command");
        let cmd2 = create_test_command("test", "Duplicate command");

        assert!(registry.register(cmd1).is_ok());
        let result = registry.register(cmd2);
        assert!(matches!(result, Err(CommandError::AlreadyRegistered(_))));
        // The failed attempt must not add or replace anything.
        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_invalid_name_empty() {
        assert_rejected("");
    }

    #[test]
    fn test_invalid_name_contains_space() {
        assert_rejected("test command");
    }

    #[test]
    fn test_invalid_name_starts_with_number() {
        assert_rejected("123test");
    }

    #[test]
    fn test_valid_name() {
        assert!(SlashCommand::validate_name("test").is_ok());
        assert!(SlashCommand::validate_name("test_command").is_ok());
        assert!(SlashCommand::validate_name("test-command").is_ok());
        assert!(SlashCommand::validate_name("TestCommand").is_ok());
    }

    #[test]
    fn test_execute_command() {
        let mut registry = CommandRegistry::new();
        let cmd = create_test_command("echo", "Echo arguments");
        registry.register(cmd).unwrap();

        let result = block_on(registry.execute("echo", vec!["hello".to_string()]));

        assert!(result.is_ok());
        assert!(result.unwrap().contains("hello"));
    }

    #[test]
    fn test_execute_nonexistent_command() {
        let registry = CommandRegistry::new();

        let result = block_on(registry.execute("nonexistent", vec![]));

        assert!(matches!(result, Err(CommandError::NotFound(_))));
    }

    #[test]
    fn test_get_command() {
        let mut registry = CommandRegistry::new();
        let cmd = create_test_command("test", "A test command");
        registry.register(cmd).unwrap();

        let retrieved = registry.get("test");
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "test");
    }

    #[test]
    fn test_get_nonexistent_command() {
        let registry = CommandRegistry::new();
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_list_names() {
        let mut registry = CommandRegistry::new();
        registry.register(create_test_command("cmd1", "First")).unwrap();
        registry.register(create_test_command("cmd2", "Second")).unwrap();
        registry.register(create_test_command("cmd3", "Third")).unwrap();

        let names = registry.list_names();
        assert_eq!(names.len(), 3);
        assert!(names.contains(&"cmd1".to_string()));
        assert!(names.contains(&"cmd2".to_string()));
        assert!(names.contains(&"cmd3".to_string()));
    }

    #[test]
    fn test_list_all() {
        let mut registry = CommandRegistry::new();
        registry.register(create_test_command("cmd1", "First")).unwrap();
        registry.register(create_test_command("cmd2", "Second")).unwrap();

        let commands = registry.list_all();
        assert_eq!(commands.len(), 2);
    }

    #[test]
    fn test_unregister_command() {
        let mut registry = CommandRegistry::new();
        registry
            .register(create_test_command("test", "A test command"))
            .unwrap();

        assert!(registry.unregister("test").is_ok());
        assert!(!registry.exists("test"));
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_unregister_nonexistent_command() {
        let mut registry = CommandRegistry::new();
        let result = registry.unregister("nonexistent");
        assert!(matches!(result, Err(CommandError::NotFound(_))));
    }

    #[test]
    fn test_clear_commands() {
        let mut registry = CommandRegistry::new();
        registry.register(create_test_command("cmd1", "First")).unwrap();
        registry.register(create_test_command("cmd2", "Second")).unwrap();

        registry.clear();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_command_error_display() {
        assert!(CommandError::NotFound("test".to_string())
            .to_string()
            .contains("test"));
        assert!(CommandError::ExecutionFailed("error".to_string())
            .to_string()
            .contains("error"));
        assert!(CommandError::InvalidName("bad".to_string())
            .to_string()
            .contains("bad"));
        assert!(CommandError::AlreadyRegistered("cmd".to_string())
            .to_string()
            .contains("cmd"));
    }

    #[test]
    fn test_complex_command_handler() {
        let mut registry = CommandRegistry::new();

        let cmd = SlashCommand::new(
            "sum",
            "Sum numbers",
            Arc::new(|_name, args| {
                Box::pin(async move {
                    // Arguments that do not parse as integers count as zero.
                    let sum: i32 = args
                        .iter()
                        .map(|s| s.parse::<i32>().unwrap_or(0))
                        .sum();
                    Ok(format!("Sum: {}", sum))
                })
            }),
        );

        registry.register(cmd).unwrap();

        let result = block_on(registry.execute(
            "sum",
            vec!["10".to_string(), "20".to_string(), "30".to_string()],
        ));

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Sum: 60");
    }

    #[test]
    fn test_async_error_handling() {
        let mut registry = CommandRegistry::new();

        let cmd = SlashCommand::new(
            "failing",
            "Always fails",
            Arc::new(|_name, _args| {
                Box::pin(async move {
                    Err(CommandError::ExecutionFailed(
                        "Intentional failure".to_string(),
                    ))
                })
            }),
        );

        registry.register(cmd).unwrap();

        let result = block_on(registry.execute("failing", vec![]));

        assert!(matches!(result, Err(CommandError::ExecutionFailed(_))));
    }
}