use std::{collections::HashMap, sync::Arc};
/// Parsed information about a single command invocation, handed to handlers and hooks.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandContext {
    /// Command name used for the handler lookup (lower-cased when the framework
    /// matches case-insensitively).
    pub command: String,
    /// Whitespace-separated tokens that followed the command name.
    pub args: Vec<String>,
    /// The complete original message, prefix included.
    pub raw: String,
    /// Identifier of the message author, as supplied to `dispatch`.
    pub author_id: String,
    /// Identifier of the channel the message arrived in, as supplied to `dispatch`.
    pub channel_id: String,
}
/// A registered command handler; it takes ownership of the parsed context.
type Handler = Arc<dyn Fn(CommandContext) + Send + Sync>;

/// Hook run before a matched command's handler; returning `false` cancels the call.
type BeforeHook = Arc<dyn Fn(&CommandContext) -> bool + Send + Sync>;

/// Hook run once a matched command's handler has returned.
type AfterHook = Arc<dyn Fn(&CommandContext) + Send + Sync>;

/// Hook run when a prefixed command name matches no registered handler.
type UnrecognisedHook = Arc<dyn Fn(&CommandContext) + Send + Sync>;
/// Prefix-based command dispatcher with optional before/after/unrecognised hooks.
pub struct CommandFramework {
    /// Text a message must start with to be treated as a command.
    prefix: String,
    /// Registered handlers, keyed by (possibly lower-cased) command name.
    commands: HashMap<String, Handler>,
    /// Optional gate evaluated before a matched handler runs.
    before: Option<BeforeHook>,
    /// Optional hook evaluated after a matched handler runs.
    after: Option<AfterHook>,
    /// Optional hook for prefixed messages whose command name is unknown.
    on_unrecognised: Option<UnrecognisedHook>,
    /// When `true`, command names are lower-cased for registration and lookup.
    case_insensitive: bool,
    /// Whether commands may be used in direct messages (not read by `dispatch`).
    allow_dm: bool,
}
impl CommandFramework {
    /// Creates a framework that recognises messages starting with `prefix`.
    ///
    /// Defaults: case-insensitive command matching, DMs allowed, no hooks.
    pub fn new(prefix: impl Into<String>) -> Self {
        Self {
            prefix: prefix.into(),
            commands: HashMap::new(),
            before: None,
            after: None,
            on_unrecognised: None,
            case_insensitive: true,
            allow_dm: true,
        }
    }

    /// Sets whether command names are matched case-insensitively.
    ///
    /// Enabling it re-keys already-registered commands to lower case so they remain
    /// reachable; otherwise a command registered as `"Ping"` while matching was
    /// case-sensitive could never match the lower-cased lookup key. If two
    /// registered names collapse to the same lower-case key, only one survives and
    /// which one is unspecified.
    ///
    /// Disabling it cannot restore the original casing of commands registered while
    /// it was enabled (they stay lower-case), so prefer configuring this before
    /// registering commands.
    #[must_use]
    pub fn case_insensitive(mut self, value: bool) -> Self {
        self.case_insensitive = value;
        if value {
            self.commands = std::mem::take(&mut self.commands)
                .into_iter()
                .map(|(name, handler)| (name.to_lowercase(), handler))
                .collect();
        }
        self
    }

    /// Records whether commands may be used in direct messages.
    ///
    /// NOTE: `dispatch` does not consult this flag; it only receives a channel id
    /// and has no way to tell a DM apart from any other channel. Callers that can
    /// should check [`CommandFramework::dm_allowed`] before dispatching.
    #[must_use]
    pub fn allow_dm(mut self, value: bool) -> Self {
        self.allow_dm = value;
        self
    }

    /// Returns the value configured via [`CommandFramework::allow_dm`] (default `true`).
    #[must_use]
    pub fn dm_allowed(&self) -> bool {
        self.allow_dm
    }

    /// Registers `handler` under `name`, replacing any handler already registered
    /// under the same (normalised) name.
    pub fn command<F>(&mut self, name: &str, handler: F)
    where
        F: Fn(CommandContext) + Send + Sync + 'static,
    {
        let key = self.normalise(name);
        self.commands.insert(key, Arc::new(handler));
    }

    /// Installs a hook run before every matched handler; returning `false` cancels
    /// the handler (and the after hook) and makes `dispatch` return `false`.
    pub fn before<F>(&mut self, hook: F)
    where
        F: Fn(&CommandContext) -> bool + Send + Sync + 'static,
    {
        self.before = Some(Arc::new(hook));
    }

    /// Installs a hook run after every matched handler has returned.
    pub fn after<F>(&mut self, hook: F)
    where
        F: Fn(&CommandContext) + Send + Sync + 'static,
    {
        self.after = Some(Arc::new(hook));
    }

    /// Installs a hook run when a prefixed command name has no registered handler.
    pub fn on_unrecognised<F>(&mut self, hook: F)
    where
        F: Fn(&CommandContext) + Send + Sync + 'static,
    {
        self.on_unrecognised = Some(Arc::new(hook));
    }

    /// Parses `content` and runs the matching command handler.
    ///
    /// `content` must start with the prefix. Whitespace between the prefix and the
    /// command name is ignored. The command name is the first whitespace-delimited
    /// token, and the remaining text is split on whitespace into `args`.
    ///
    /// Hook order for a known command is `before` -> handler -> `after`. For an
    /// unknown command only `on_unrecognised` runs. Messages without the prefix,
    /// or with nothing after it, trigger no hooks at all.
    ///
    /// Returns `true` only when a registered handler actually ran.
    pub fn dispatch(&self, content: &str, author_id: &str, channel_id: &str) -> bool {
        let (name, args) = match self.parse(content) {
            Some(parsed) => parsed,
            None => return false,
        };
        let key = self.normalise(name);
        // `get` borrows only the map, so `key` can be moved into the context below.
        let handler = self.commands.get(&key);
        let ctx = CommandContext {
            command: key,
            args,
            raw: content.to_string(),
            author_id: author_id.to_string(),
            channel_id: channel_id.to_string(),
        };

        let handler = match handler {
            Some(handler) => handler,
            None => {
                if let Some(hook) = &self.on_unrecognised {
                    hook(&ctx);
                }
                return false;
            }
        };

        if let Some(before) = &self.before {
            if !before(&ctx) {
                return false;
            }
        }

        match &self.after {
            Some(after) => {
                // The handler consumes its context, so keep a copy for the after hook.
                handler(ctx.clone());
                after(&ctx);
            }
            None => handler(ctx),
        }
        true
    }

    /// Splits `content` into `(command name, args)`, or `None` when the prefix is
    /// missing or no command name follows it.
    fn parse<'a>(&self, content: &'a str) -> Option<(&'a str, Vec<String>)> {
        let rest = content.strip_prefix(self.prefix.as_str())?.trim_start();
        let mut parts = rest.splitn(2, char::is_whitespace);
        let name = parts.next().filter(|name| !name.is_empty())?;
        // `split_whitespace` yields nothing for empty or blank input, so no special case.
        let args = parts
            .next()
            .unwrap_or("")
            .split_whitespace()
            .map(str::to_string)
            .collect();
        Some((name, args))
    }

    /// Applies the case-sensitivity setting to a command name.
    fn normalise(&self, name: &str) -> String {
        if self.case_insensitive {
            name.to_lowercase()
        } else {
            name.to_string()
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::{
        atomic::{AtomicBool, AtomicUsize, Ordering},
        Arc, Mutex,
    };

    use super::*;

    #[test]
    fn basic_dispatch() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired2 = Arc::clone(&fired);
        let mut fw = CommandFramework::new("!");
        fw.command("ping", move |_ctx| {
            fired2.store(true, Ordering::Relaxed);
        });
        assert!(fw.dispatch("!ping", "u1", "c1"));
        assert!(fired.load(Ordering::Relaxed));
    }

    #[test]
    fn unrecognised_command() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired2 = Arc::clone(&fired);
        let mut fw = CommandFramework::new("!");
        fw.on_unrecognised(move |_ctx| {
            fired2.store(true, Ordering::Relaxed);
        });
        assert!(!fw.dispatch("!unknown", "u1", "c1"));
        assert!(fired.load(Ordering::Relaxed));
    }

    #[test]
    fn no_prefix_not_dispatched() {
        let mut fw = CommandFramework::new("!");
        fw.command("ping", |_| {});
        assert!(!fw.dispatch("ping", "u1", "c1"));
    }

    #[test]
    fn bare_prefix_not_dispatched() {
        let unrecognised = Arc::new(AtomicBool::new(false));
        let unrecognised2 = Arc::clone(&unrecognised);
        let mut fw = CommandFramework::new("!");
        fw.command("ping", |_| {});
        fw.on_unrecognised(move |_ctx| {
            unrecognised2.store(true, Ordering::Relaxed);
        });
        assert!(!fw.dispatch("!", "u1", "c1"));
        assert!(!fw.dispatch("!   ", "u1", "c1"));
        // No command name followed the prefix, so not even the unrecognised hook fires.
        assert!(!unrecognised.load(Ordering::Relaxed));
    }

    #[test]
    fn args_are_split() {
        let args_out = Arc::new(Mutex::new(vec![]));
        let args_clone = Arc::clone(&args_out);
        let mut fw = CommandFramework::new("!");
        fw.command("echo", move |ctx| {
            *args_clone.lock().unwrap() = ctx.args.clone();
        });
        assert!(fw.dispatch("!echo hello world", "u1", "c1"));
        assert_eq!(*args_out.lock().unwrap(), vec!["hello", "world"]);
    }

    #[test]
    fn extra_whitespace_is_ignored() {
        let args_out = Arc::new(Mutex::new(Vec::new()));
        let args_clone = Arc::clone(&args_out);
        let mut fw = CommandFramework::new("!");
        fw.command("echo", move |ctx| {
            *args_clone.lock().unwrap() = ctx.args;
        });
        assert!(fw.dispatch("!  echo   hello \t world  ", "u1", "c1"));
        assert_eq!(*args_out.lock().unwrap(), vec!["hello", "world"]);
    }

    #[test]
    fn context_is_populated() {
        let seen: Arc<Mutex<Option<CommandContext>>> = Arc::new(Mutex::new(None));
        let seen2 = Arc::clone(&seen);
        let mut fw = CommandFramework::new("!");
        fw.command("ping", move |ctx| {
            *seen2.lock().unwrap() = Some(ctx);
        });
        assert!(fw.dispatch("!PING now", "u9", "c7"));
        let guard = seen.lock().unwrap();
        let ctx = guard.as_ref().expect("handler should have run");
        assert_eq!(ctx.command, "ping");
        assert_eq!(ctx.args, vec!["now"]);
        assert_eq!(ctx.raw, "!PING now");
        assert_eq!(ctx.author_id, "u9");
        assert_eq!(ctx.channel_id, "c7");
    }

    #[test]
    fn before_hook_can_abort() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired2 = Arc::clone(&fired);
        let mut fw = CommandFramework::new("!");
        fw.before(|_ctx| false);
        fw.command("ping", move |_| {
            fired2.store(true, Ordering::Relaxed);
        });
        assert!(!fw.dispatch("!ping", "u1", "c1"));
        assert!(!fired.load(Ordering::Relaxed));
    }

    #[test]
    fn before_hook_only_runs_for_known_commands() {
        let calls = Arc::new(AtomicUsize::new(0));
        let calls2 = Arc::clone(&calls);
        let mut fw = CommandFramework::new("!");
        fw.command("ping", |_| {});
        fw.before(move |_ctx| {
            calls2.fetch_add(1, Ordering::Relaxed);
            true
        });
        assert!(!fw.dispatch("!unknown", "u1", "c1"));
        assert_eq!(calls.load(Ordering::Relaxed), 0);
        assert!(fw.dispatch("!ping", "u1", "c1"));
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn after_hook_fires() {
        let count = Arc::new(AtomicUsize::new(0));
        let c2 = Arc::clone(&count);
        let mut fw = CommandFramework::new("!");
        fw.command("ping", |_| {});
        fw.after(move |_ctx| {
            c2.fetch_add(1, Ordering::Relaxed);
        });
        assert!(fw.dispatch("!ping", "u1", "c1"));
        assert_eq!(count.load(Ordering::Relaxed), 1);
    }

    #[test]
    fn after_hook_skipped_when_before_aborts() {
        let count = Arc::new(AtomicUsize::new(0));
        let c2 = Arc::clone(&count);
        let mut fw = CommandFramework::new("!");
        fw.command("ping", |_| {});
        fw.before(|_ctx| false);
        fw.after(move |_ctx| {
            c2.fetch_add(1, Ordering::Relaxed);
        });
        assert!(!fw.dispatch("!ping", "u1", "c1"));
        assert_eq!(count.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn case_insensitive_match() {
        let fired = Arc::new(AtomicBool::new(false));
        let fired2 = Arc::clone(&fired);
        let mut fw = CommandFramework::new("!");
        fw.command("ping", move |_| {
            fired2.store(true, Ordering::Relaxed);
        });
        assert!(fw.dispatch("!PING", "u1", "c1"));
        assert!(fired.load(Ordering::Relaxed));
    }

    #[test]
    fn case_sensitive_mode() {
        let mut fw = CommandFramework::new("!").case_insensitive(false);
        fw.command("Ping", |_| {});
        assert!(fw.dispatch("!Ping", "u1", "c1"));
        assert!(!fw.dispatch("!ping", "u1", "c1"));
    }
}