robespierre/framework/
standard.rs

1use std::{
2    borrow::{Borrow, Cow},
3    collections::HashSet,
4    fmt,
5    future::Future,
6    pin::Pin,
7    sync::Arc,
8};
9
10#[cfg(feature = "cache")]
11use robespierre_cache::{Cache, HasCache};
12use robespierre_models::{
13    channels::Channel,
14    channels::{ChannelPermissions, Message, MessageContent},
15    id::UserId,
16    servers::ServerPermissions,
17};
18
19use crate::{
20    model::{MessageExt, ServerIdExt},
21    Context, HasHttp, UserData,
22};
23
24use super::Framework;
25
/// Procedural macros for declaring standard-framework commands.
///
/// Compiled only when the `framework-macros` feature is enabled.
#[cfg(feature = "framework-macros")]
pub mod macros {
    pub use ::robespierre_fw_macros::command;
}
30
31pub mod extractors;
32
33#[derive(Default)]
34pub struct StdFwConfig {
35    prefix: Cow<'static, str>,
36    owners: HashSet<UserId>,
37}
38
39impl StdFwConfig {
40    pub fn prefix(self, prefix: impl Into<Cow<'static, str>>) -> Self {
41        Self {
42            prefix: prefix.into(),
43            ..self
44        }
45    }
46
47    pub fn owners(self, owners: HashSet<UserId>) -> Self {
48        Self { owners, ..self }
49    }
50}
51
52#[derive(Default)]
53pub struct StandardFramework {
54    root_group: RootGroup,
55    normal_message: Option<NormalMessageHandlerCode>,
56    unknown_command: Option<UnknownCommandHandlerCode>,
57    after: Option<AfterHandlerCode>,
58    config: StdFwConfig,
59}
60
61impl StandardFramework {
62    pub fn configure<F>(self, f: F) -> Self
63    where
64        F: FnOnce(StdFwConfig) -> StdFwConfig,
65    {
66        Self {
67            config: f(StdFwConfig::default()),
68            ..self
69        }
70    }
71
72    pub fn normal_message(self, handler: impl Into<NormalMessageHandlerCode>) -> Self {
73        Self {
74            normal_message: Some(handler.into()),
75            ..self
76        }
77    }
78
79    pub fn unknown_command(self, handler: impl Into<UnknownCommandHandlerCode>) -> Self {
80        Self {
81            unknown_command: Some(handler.into()),
82            ..self
83        }
84    }
85
86    pub fn after(self, handler: impl Into<AfterHandlerCode>) -> Self {
87        Self {
88            after: Some(handler.into()),
89            ..self
90        }
91    }
92
93    pub fn group<F>(mut self, f: F) -> Self
94    where
95        F: for<'a> FnOnce(Group) -> Group,
96    {
97        let group = Group {
98            name: "".into(),
99            commands: vec![],
100            default_invoke: None,
101            subgroups: vec![],
102        };
103        let group = f(group);
104        debug_assert!(
105            group.name.as_ref() != "",
106            "Name of group is \"\"; did you forget to set name of group?"
107        );
108
109        self.root_group.subgroups.push(group);
110        self
111    }
112
113    async fn invoke_unknown_command(&self, ctx: &FwContext, message: &Arc<Message>) {
114        if let Some(code) = self.unknown_command.as_ref() {
115            code.invoke(ctx, message).await;
116        }
117    }
118
119    async fn invoke_after<'a>(
120        &'a self,
121        ctx: &'a FwContext,
122        message: &'a Arc<Message>,
123        result: CommandResult,
124    ) {
125        if let Some(code) = self.after.as_ref() {
126            code.invoke(ctx, message, result).await;
127        }
128    }
129}
130
131#[async_trait::async_trait]
132impl Framework for StandardFramework {
133    type Context = FwContext;
134
135    async fn handle(&self, ctx: Self::Context, message: &Arc<Message>) {
136        let prefix: &str = self.config.prefix.borrow();
137        let message_content = match &message.content {
138            MessageContent::Content(c) => c,
139            MessageContent::SystemMessage(_) => return,
140        };
141        if let Some(command) = message_content.strip_prefix(prefix) {
142            let command = self.root_group.find_command(command);
143
144            match command {
145                Some((cmd, args)) => {
146                    if cmd.owners_only && !self.config.owners.contains(&message.author) {
147                        self.invoke_unknown_command(&ctx, message).await;
148                        return;
149                    }
150
151                    let result = cmd.invoke(&ctx, message, args).await;
152
153                    self.invoke_after(&ctx, message, result).await;
154                }
155                None => {
156                    self.invoke_unknown_command(&ctx, message).await;
157                }
158            }
159        } else if let Some(code) = self.normal_message.as_ref() {
160            code.invoke(&ctx, message).await;
161        }
162    }
163}
164
165#[derive(Default)]
166pub struct RootGroup {
167    subgroups: Vec<Group>,
168}
169
170impl RootGroup {
171    pub(crate) fn find_command<'a, 'b>(
172        &'a self,
173        command: &'b str,
174    ) -> Option<(&'a Command, &'b str)> {
175        self.subgroups
176            .iter()
177            .find_map(|it| it.find_command(command))
178    }
179}
180
181#[derive(Default)]
182pub struct Group {
183    name: Cow<'static, str>,
184    subgroups: Vec<Group>,
185    commands: Vec<Command>,
186    default_invoke: Option<Command>,
187}
188
189impl Group {
190    pub fn name(self, name: impl Into<Cow<'static, str>>) -> Self {
191        Self {
192            name: name.into(),
193            ..self
194        }
195    }
196
197    pub fn subgroup<F>(mut self, f: F) -> Self
198    where
199        F: FnOnce(Group) -> Group,
200    {
201        let group = f(Group::default());
202        debug_assert!(
203            group.name.as_ref() != "",
204            "Name of group is \"\"; did you forget to set name of group?"
205        );
206
207        self.subgroups.push(group);
208        self
209    }
210
211    pub fn command<F>(mut self, f: F) -> Self
212    where
213        F: FnOnce() -> Command,
214    {
215        let command = f();
216        self.commands.push(command);
217        self
218    }
219
220    pub fn default_command<F>(self, f: F) -> Self
221    where
222        F: FnOnce() -> Command,
223    {
224        let command = f();
225        Self {
226            default_invoke: Some(command),
227            ..self
228        }
229    }
230}
231
232impl Group {
233    pub(crate) fn find_command<'a, 'b>(
234        &'a self,
235        command: &'b str,
236    ) -> Option<(&'a Command, &'b str)> {
237        self.subgroups
238            .iter()
239            .find_map(|group| {
240                let group_name: &str = group.name.borrow();
241                if let Some(rest) = command.strip_prefix(group_name) {
242                    if rest.trim() == "" {
243                        Some((group.default_invoke.as_ref()?, ""))
244                    } else if rest.starts_with(char::is_whitespace) {
245                        group.find_command(rest.trim_start())
246                    } else {
247                        None
248                    }
249                } else {
250                    None
251                }
252            })
253            .or_else(|| {
254                self.commands.iter().find_map(|c| {
255                    let command_name: &str = c.name.borrow();
256                    let rest = std::iter::once(command_name)
257                        .chain(c.aliases.iter().map(|it| -> &str { it }))
258                        .find_map(|name| command.strip_prefix(name));
259                    if let Some(rest) = rest {
260                        if rest.trim() == "" {
261                            Some((c, ""))
262                        } else if rest.starts_with(char::is_whitespace) {
263                            Some((c, rest.trim_start()))
264                        } else {
265                            None
266                        }
267                    } else {
268                        None
269                    }
270                })
271            })
272            .or_else(|| Some((self.default_invoke.as_ref()?, command.trim_start())))
273    }
274}
275
276#[derive(Debug)]
277pub struct Command {
278    name: Cow<'static, str>,
279    aliases: smallvec::SmallVec<[Cow<'static, str>; 4]>,
280    code: CommandCode,
281    required_perms: (ServerPermissions, ChannelPermissions),
282    owners_only: bool,
283}
284
285impl Command {
286    pub fn new(name: impl Into<Cow<'static, str>>, code: impl Into<CommandCode>) -> Self {
287        Self {
288            name: name.into(),
289            aliases: smallvec::SmallVec::default(),
290            code: code.into(),
291            required_perms: (ServerPermissions::empty(), ChannelPermissions::empty()),
292            owners_only: false,
293        }
294    }
295
296    pub fn alias(mut self, alias: impl Into<Cow<'static, str>>) -> Self {
297        self.aliases.push(alias.into());
298        self
299    }
300
301    pub fn required_server_permissions(mut self, perms: impl Into<ServerPermissions>) -> Self {
302        self.required_perms.0 = perms.into();
303        self
304    }
305
306    pub fn required_channel_permissions(mut self, perms: impl Into<ChannelPermissions>) -> Self {
307        self.required_perms.1 = perms.into();
308        self
309    }
310
311    pub fn owners_only(self, owners_only: impl Into<bool>) -> Self {
312        Self {
313            owners_only: owners_only.into(),
314            ..self
315        }
316    }
317}
318
/// Error returned by the permission check when the invoking user lacks a
/// command's required permissions.
///
/// Field `.0` holds the required server permissions and `.1` the required
/// channel permissions. The server part is empty when the check happened in a
/// group channel.
#[derive(Debug, thiserror::Error)]
#[error("one or more of the following permissions are missing: (server: {0:?}, channel: {1:?})")]
pub struct MissingPermissions(ServerPermissions, ChannelPermissions);
322
323async fn check_perms<'a>(
324    ctx: &'a FwContext,
325    message: &'a Message,
326    sp: ServerPermissions,
327    cp: ChannelPermissions,
328) -> CommandResult {
329    let channel = message.channel(ctx).await?;
330
331    match &channel {
332        Channel::SavedMessages(..) | Channel::DirectMessage(..) => Ok::<_, CommandError>(()),
333        ch @ Channel::Group(..) => {
334            let check = robespierre_models::permissions_utils::user_has_permissions_in_group(
335                message.author,
336                ch,
337                cp,
338            );
339
340            if check {
341                Ok::<_, CommandError>(())
342            } else {
343                Err(MissingPermissions(ServerPermissions::empty(), cp).into())
344            }
345        }
346        ch @ Channel::TextChannel { .. } | ch @ Channel::VoiceChannel { .. } => {
347            let server = ch.server_id().unwrap();
348            let member = server.member(ctx, message.author).await?;
349            let server = server.server(ctx).await?;
350
351            let check = robespierre_models::permissions_utils::member_has_permissions_in_channel(
352                &member, sp, &server, cp, ch,
353            );
354
355            if check {
356                Ok::<_, CommandError>(())
357            } else {
358                Err(MissingPermissions(sp, cp).into())
359            }
360        }
361    }
362}
363
364impl Command {
365    fn invoke<'a>(
366        &'a self,
367        ctx: &'a FwContext,
368        message: &'a Arc<Message>,
369        args: &'a str,
370    ) -> impl Future<Output = CommandResult> + Send + 'a {
371        let (sp, cp) = self.required_perms;
372        async move {
373            check_perms(ctx, message, sp, cp).await?;
374
375            self.code.invoke(ctx, message, args).await?;
376
377            Ok::<_, CommandError>(())
378        }
379    }
380}
381
382#[derive(Clone)]
383pub struct FwContext {
384    ctx: Context,
385}
386
387impl HasHttp for FwContext {
388    fn get_http(&self) -> &robespierre_http::Http {
389        self.ctx.get_http()
390    }
391}
392
393#[cfg(feature = "cache")]
394impl HasCache for FwContext {
395    fn get_cache(&self) -> Option<&Cache> {
396        self.ctx.get_cache()
397    }
398}
399
400impl AsRef<Context> for FwContext {
401    fn as_ref(&self) -> &Context {
402        &self.ctx
403    }
404}
405
406impl From<Context> for FwContext {
407    fn from(ctx: Context) -> Self {
408        Self { ctx }
409    }
410}
411
412#[async_trait::async_trait]
413impl UserData for FwContext {
414    async fn data_lock_read(&self) -> tokio::sync::RwLockReadGuard<typemap::ShareMap> {
415        self.ctx.data_lock_read().await
416    }
417
418    async fn data_lock_write(&self) -> tokio::sync::RwLockWriteGuard<typemap::ShareMap> {
419        self.ctx.data_lock_write().await
420    }
421}
422
423pub type AfterHandlerCodeFn = for<'a> fn(
424    ctx: &'a FwContext,
425    message: &'a Message,
426    result: CommandResult,
427) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
428
429pub enum AfterHandlerCode {
430    Binary(AfterHandlerCodeFn),
431    #[cfg(feature = "interpreter")]
432    Interpreted(String),
433}
434
435impl From<AfterHandlerCodeFn> for AfterHandlerCode {
436    fn from(code: AfterHandlerCodeFn) -> Self {
437        Self::Binary(code)
438    }
439}
440
441impl AfterHandlerCode {
442    pub async fn invoke<'a>(
443        &'a self,
444        ctx: &'a FwContext,
445        message: &'a Message,
446        result: CommandResult,
447    ) {
448        match self {
449            AfterHandlerCode::Binary(f) => f(ctx, message, result).await,
450            #[cfg(feature = "interpreter")]
451            AfterHandlerCode::Interpreted(code) => todo!(),
452        }
453    }
454}
455
456pub type CommandCodeFn = for<'a> fn(
457    ctx: &'a FwContext,
458    message: &'a Arc<Message>,
459    args: &'a str,
460) -> Pin<Box<dyn Future<Output = CommandResult> + Send + 'a>>;
461
462pub enum CommandCode {
463    Binary(CommandCodeFn),
464    #[cfg(feature = "interpreter")]
465    Interpreted(String),
466}
467
468impl From<CommandCodeFn> for CommandCode {
469    fn from(code: CommandCodeFn) -> Self {
470        Self::Binary(code)
471    }
472}
473
474impl fmt::Debug for CommandCode {
475    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476        match self {
477            Self::Binary(code) => f
478                .debug_tuple("Binary")
479                .field(&format_args!("{:p}", code as *const _))
480                .finish(),
481            #[cfg(feature = "interpreter")]
482            Self::Interpreted(code) => f.debug_tuple("Interpreted").field(code).finish(),
483        }
484    }
485}
486
487pub type CommandError = Box<dyn std::error::Error + Send + Sync + 'static>;
488pub type CommandResult<T = ()> = Result<T, CommandError>;
489
490impl CommandCode {
491    pub async fn invoke(
492        &self,
493        ctx: &FwContext,
494        message: &Arc<Message>,
495        args: &str,
496    ) -> CommandResult {
497        match self {
498            CommandCode::Binary(f) => f(ctx, message, args).await,
499            #[cfg(feature = "interpreter")]
500            CommandCode::Interpreted(code) => todo!(),
501        }
502    }
503}
504
505pub type UnknownCommandHandlerCodeFn = for<'a> fn(
506    ctx: &'a FwContext,
507    message: &'a Message,
508) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
509
510pub enum UnknownCommandHandlerCode {
511    Binary(UnknownCommandHandlerCodeFn),
512    #[cfg(feature = "interpreter")]
513    Interpreted(String),
514}
515
516impl UnknownCommandHandlerCode {
517    pub async fn invoke(&self, ctx: &FwContext, message: &Message) {
518        match self {
519            Self::Binary(f) => f(ctx, message).await,
520            #[cfg(feature = "interpreter")]
521            Self::Interpreted(code) => todo!(),
522        }
523    }
524}
525
526pub type NormalMessageHandlerCodeFn = for<'a> fn(
527    ctx: &'a FwContext,
528    message: &'a Message,
529) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
530pub enum NormalMessageHandlerCode {
531    Binary(NormalMessageHandlerCodeFn),
532    #[cfg(feature = "interpreter")]
533    Interpreted(String),
534}
535
536impl From<NormalMessageHandlerCodeFn> for NormalMessageHandlerCode {
537    fn from(code: NormalMessageHandlerCodeFn) -> Self {
538        Self::Binary(code)
539    }
540}
541
542impl NormalMessageHandlerCode {
543    pub async fn invoke(&self, ctx: &FwContext, message: &Message) {
544        match self {
545            Self::Binary(f) => f(ctx, message).await,
546            #[cfg(feature = "interpreter")]
547            Self::Interpreted(code) => todo!(),
548        }
549    }
550}
551
552#[cfg(test)]
553mod test;