1use std::{
2 borrow::{Borrow, Cow},
3 collections::HashSet,
4 fmt,
5 future::Future,
6 pin::Pin,
7 sync::Arc,
8};
9
10#[cfg(feature = "cache")]
11use robespierre_cache::{Cache, HasCache};
12use robespierre_models::{
13 channels::Channel,
14 channels::{ChannelPermissions, Message, MessageContent},
15 id::UserId,
16 servers::ServerPermissions,
17};
18
19use crate::{
20 model::{MessageExt, ServerIdExt},
21 Context, HasHttp, UserData,
22};
23
24use super::Framework;
25
/// Procedural macros for declaring framework commands.
///
/// Only compiled when the `framework-macros` feature is enabled.
#[cfg(feature = "framework-macros")]
pub mod macros {
    /// Re-export of the `command` macro from `robespierre_fw_macros`.
    pub use ::robespierre_fw_macros::command;
}
30
31pub mod extractors;
32
33#[derive(Default)]
34pub struct StdFwConfig {
35 prefix: Cow<'static, str>,
36 owners: HashSet<UserId>,
37}
38
39impl StdFwConfig {
40 pub fn prefix(self, prefix: impl Into<Cow<'static, str>>) -> Self {
41 Self {
42 prefix: prefix.into(),
43 ..self
44 }
45 }
46
47 pub fn owners(self, owners: HashSet<UserId>) -> Self {
48 Self { owners, ..self }
49 }
50}
51
52#[derive(Default)]
53pub struct StandardFramework {
54 root_group: RootGroup,
55 normal_message: Option<NormalMessageHandlerCode>,
56 unknown_command: Option<UnknownCommandHandlerCode>,
57 after: Option<AfterHandlerCode>,
58 config: StdFwConfig,
59}
60
61impl StandardFramework {
62 pub fn configure<F>(self, f: F) -> Self
63 where
64 F: FnOnce(StdFwConfig) -> StdFwConfig,
65 {
66 Self {
67 config: f(StdFwConfig::default()),
68 ..self
69 }
70 }
71
72 pub fn normal_message(self, handler: impl Into<NormalMessageHandlerCode>) -> Self {
73 Self {
74 normal_message: Some(handler.into()),
75 ..self
76 }
77 }
78
79 pub fn unknown_command(self, handler: impl Into<UnknownCommandHandlerCode>) -> Self {
80 Self {
81 unknown_command: Some(handler.into()),
82 ..self
83 }
84 }
85
86 pub fn after(self, handler: impl Into<AfterHandlerCode>) -> Self {
87 Self {
88 after: Some(handler.into()),
89 ..self
90 }
91 }
92
93 pub fn group<F>(mut self, f: F) -> Self
94 where
95 F: for<'a> FnOnce(Group) -> Group,
96 {
97 let group = Group {
98 name: "".into(),
99 commands: vec![],
100 default_invoke: None,
101 subgroups: vec![],
102 };
103 let group = f(group);
104 debug_assert!(
105 group.name.as_ref() != "",
106 "Name of group is \"\"; did you forget to set name of group?"
107 );
108
109 self.root_group.subgroups.push(group);
110 self
111 }
112
113 async fn invoke_unknown_command(&self, ctx: &FwContext, message: &Arc<Message>) {
114 if let Some(code) = self.unknown_command.as_ref() {
115 code.invoke(ctx, message).await;
116 }
117 }
118
119 async fn invoke_after<'a>(
120 &'a self,
121 ctx: &'a FwContext,
122 message: &'a Arc<Message>,
123 result: CommandResult,
124 ) {
125 if let Some(code) = self.after.as_ref() {
126 code.invoke(ctx, message, result).await;
127 }
128 }
129}
130
131#[async_trait::async_trait]
132impl Framework for StandardFramework {
133 type Context = FwContext;
134
135 async fn handle(&self, ctx: Self::Context, message: &Arc<Message>) {
136 let prefix: &str = self.config.prefix.borrow();
137 let message_content = match &message.content {
138 MessageContent::Content(c) => c,
139 MessageContent::SystemMessage(_) => return,
140 };
141 if let Some(command) = message_content.strip_prefix(prefix) {
142 let command = self.root_group.find_command(command);
143
144 match command {
145 Some((cmd, args)) => {
146 if cmd.owners_only && !self.config.owners.contains(&message.author) {
147 self.invoke_unknown_command(&ctx, message).await;
148 return;
149 }
150
151 let result = cmd.invoke(&ctx, message, args).await;
152
153 self.invoke_after(&ctx, message, result).await;
154 }
155 None => {
156 self.invoke_unknown_command(&ctx, message).await;
157 }
158 }
159 } else if let Some(code) = self.normal_message.as_ref() {
160 code.invoke(&ctx, message).await;
161 }
162 }
163}
164
165#[derive(Default)]
166pub struct RootGroup {
167 subgroups: Vec<Group>,
168}
169
170impl RootGroup {
171 pub(crate) fn find_command<'a, 'b>(
172 &'a self,
173 command: &'b str,
174 ) -> Option<(&'a Command, &'b str)> {
175 self.subgroups
176 .iter()
177 .find_map(|it| it.find_command(command))
178 }
179}
180
181#[derive(Default)]
182pub struct Group {
183 name: Cow<'static, str>,
184 subgroups: Vec<Group>,
185 commands: Vec<Command>,
186 default_invoke: Option<Command>,
187}
188
189impl Group {
190 pub fn name(self, name: impl Into<Cow<'static, str>>) -> Self {
191 Self {
192 name: name.into(),
193 ..self
194 }
195 }
196
197 pub fn subgroup<F>(mut self, f: F) -> Self
198 where
199 F: FnOnce(Group) -> Group,
200 {
201 let group = f(Group::default());
202 debug_assert!(
203 group.name.as_ref() != "",
204 "Name of group is \"\"; did you forget to set name of group?"
205 );
206
207 self.subgroups.push(group);
208 self
209 }
210
211 pub fn command<F>(mut self, f: F) -> Self
212 where
213 F: FnOnce() -> Command,
214 {
215 let command = f();
216 self.commands.push(command);
217 self
218 }
219
220 pub fn default_command<F>(self, f: F) -> Self
221 where
222 F: FnOnce() -> Command,
223 {
224 let command = f();
225 Self {
226 default_invoke: Some(command),
227 ..self
228 }
229 }
230}
231
232impl Group {
233 pub(crate) fn find_command<'a, 'b>(
234 &'a self,
235 command: &'b str,
236 ) -> Option<(&'a Command, &'b str)> {
237 self.subgroups
238 .iter()
239 .find_map(|group| {
240 let group_name: &str = group.name.borrow();
241 if let Some(rest) = command.strip_prefix(group_name) {
242 if rest.trim() == "" {
243 Some((group.default_invoke.as_ref()?, ""))
244 } else if rest.starts_with(char::is_whitespace) {
245 group.find_command(rest.trim_start())
246 } else {
247 None
248 }
249 } else {
250 None
251 }
252 })
253 .or_else(|| {
254 self.commands.iter().find_map(|c| {
255 let command_name: &str = c.name.borrow();
256 let rest = std::iter::once(command_name)
257 .chain(c.aliases.iter().map(|it| -> &str { it }))
258 .find_map(|name| command.strip_prefix(name));
259 if let Some(rest) = rest {
260 if rest.trim() == "" {
261 Some((c, ""))
262 } else if rest.starts_with(char::is_whitespace) {
263 Some((c, rest.trim_start()))
264 } else {
265 None
266 }
267 } else {
268 None
269 }
270 })
271 })
272 .or_else(|| Some((self.default_invoke.as_ref()?, command.trim_start())))
273 }
274}
275
276#[derive(Debug)]
277pub struct Command {
278 name: Cow<'static, str>,
279 aliases: smallvec::SmallVec<[Cow<'static, str>; 4]>,
280 code: CommandCode,
281 required_perms: (ServerPermissions, ChannelPermissions),
282 owners_only: bool,
283}
284
285impl Command {
286 pub fn new(name: impl Into<Cow<'static, str>>, code: impl Into<CommandCode>) -> Self {
287 Self {
288 name: name.into(),
289 aliases: smallvec::SmallVec::default(),
290 code: code.into(),
291 required_perms: (ServerPermissions::empty(), ChannelPermissions::empty()),
292 owners_only: false,
293 }
294 }
295
296 pub fn alias(mut self, alias: impl Into<Cow<'static, str>>) -> Self {
297 self.aliases.push(alias.into());
298 self
299 }
300
301 pub fn required_server_permissions(mut self, perms: impl Into<ServerPermissions>) -> Self {
302 self.required_perms.0 = perms.into();
303 self
304 }
305
306 pub fn required_channel_permissions(mut self, perms: impl Into<ChannelPermissions>) -> Self {
307 self.required_perms.1 = perms.into();
308 self
309 }
310
311 pub fn owners_only(self, owners_only: impl Into<bool>) -> Self {
312 Self {
313 owners_only: owners_only.into(),
314 ..self
315 }
316 }
317}
318
/// Error produced when the message author lacks permissions a command requires.
///
/// Holds the *required* server permissions and the *required* channel
/// permissions (not only the missing subset). For group channels only channel
/// permissions are checked, so the server part is empty there.
#[derive(Debug, thiserror::Error)]
#[error("one or more of the following permissions are missing: (server: {0:?}, channel: {1:?})")]
pub struct MissingPermissions(ServerPermissions, ChannelPermissions);
322
323async fn check_perms<'a>(
324 ctx: &'a FwContext,
325 message: &'a Message,
326 sp: ServerPermissions,
327 cp: ChannelPermissions,
328) -> CommandResult {
329 let channel = message.channel(ctx).await?;
330
331 match &channel {
332 Channel::SavedMessages(..) | Channel::DirectMessage(..) => Ok::<_, CommandError>(()),
333 ch @ Channel::Group(..) => {
334 let check = robespierre_models::permissions_utils::user_has_permissions_in_group(
335 message.author,
336 ch,
337 cp,
338 );
339
340 if check {
341 Ok::<_, CommandError>(())
342 } else {
343 Err(MissingPermissions(ServerPermissions::empty(), cp).into())
344 }
345 }
346 ch @ Channel::TextChannel { .. } | ch @ Channel::VoiceChannel { .. } => {
347 let server = ch.server_id().unwrap();
348 let member = server.member(ctx, message.author).await?;
349 let server = server.server(ctx).await?;
350
351 let check = robespierre_models::permissions_utils::member_has_permissions_in_channel(
352 &member, sp, &server, cp, ch,
353 );
354
355 if check {
356 Ok::<_, CommandError>(())
357 } else {
358 Err(MissingPermissions(sp, cp).into())
359 }
360 }
361 }
362}
363
364impl Command {
365 fn invoke<'a>(
366 &'a self,
367 ctx: &'a FwContext,
368 message: &'a Arc<Message>,
369 args: &'a str,
370 ) -> impl Future<Output = CommandResult> + Send + 'a {
371 let (sp, cp) = self.required_perms;
372 async move {
373 check_perms(ctx, message, sp, cp).await?;
374
375 self.code.invoke(ctx, message, args).await?;
376
377 Ok::<_, CommandError>(())
378 }
379 }
380}
381
382#[derive(Clone)]
383pub struct FwContext {
384 ctx: Context,
385}
386
387impl HasHttp for FwContext {
388 fn get_http(&self) -> &robespierre_http::Http {
389 self.ctx.get_http()
390 }
391}
392
393#[cfg(feature = "cache")]
394impl HasCache for FwContext {
395 fn get_cache(&self) -> Option<&Cache> {
396 self.ctx.get_cache()
397 }
398}
399
400impl AsRef<Context> for FwContext {
401 fn as_ref(&self) -> &Context {
402 &self.ctx
403 }
404}
405
406impl From<Context> for FwContext {
407 fn from(ctx: Context) -> Self {
408 Self { ctx }
409 }
410}
411
412#[async_trait::async_trait]
413impl UserData for FwContext {
414 async fn data_lock_read(&self) -> tokio::sync::RwLockReadGuard<typemap::ShareMap> {
415 self.ctx.data_lock_read().await
416 }
417
418 async fn data_lock_write(&self) -> tokio::sync::RwLockWriteGuard<typemap::ShareMap> {
419 self.ctx.data_lock_write().await
420 }
421}
422
423pub type AfterHandlerCodeFn = for<'a> fn(
424 ctx: &'a FwContext,
425 message: &'a Message,
426 result: CommandResult,
427) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
428
429pub enum AfterHandlerCode {
430 Binary(AfterHandlerCodeFn),
431 #[cfg(feature = "interpreter")]
432 Interpreted(String),
433}
434
435impl From<AfterHandlerCodeFn> for AfterHandlerCode {
436 fn from(code: AfterHandlerCodeFn) -> Self {
437 Self::Binary(code)
438 }
439}
440
441impl AfterHandlerCode {
442 pub async fn invoke<'a>(
443 &'a self,
444 ctx: &'a FwContext,
445 message: &'a Message,
446 result: CommandResult,
447 ) {
448 match self {
449 AfterHandlerCode::Binary(f) => f(ctx, message, result).await,
450 #[cfg(feature = "interpreter")]
451 AfterHandlerCode::Interpreted(code) => todo!(),
452 }
453 }
454}
455
456pub type CommandCodeFn = for<'a> fn(
457 ctx: &'a FwContext,
458 message: &'a Arc<Message>,
459 args: &'a str,
460) -> Pin<Box<dyn Future<Output = CommandResult> + Send + 'a>>;
461
462pub enum CommandCode {
463 Binary(CommandCodeFn),
464 #[cfg(feature = "interpreter")]
465 Interpreted(String),
466}
467
468impl From<CommandCodeFn> for CommandCode {
469 fn from(code: CommandCodeFn) -> Self {
470 Self::Binary(code)
471 }
472}
473
474impl fmt::Debug for CommandCode {
475 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476 match self {
477 Self::Binary(code) => f
478 .debug_tuple("Binary")
479 .field(&format_args!("{:p}", code as *const _))
480 .finish(),
481 #[cfg(feature = "interpreter")]
482 Self::Interpreted(code) => f.debug_tuple("Interpreted").field(code).finish(),
483 }
484 }
485}
486
487pub type CommandError = Box<dyn std::error::Error + Send + Sync + 'static>;
488pub type CommandResult<T = ()> = Result<T, CommandError>;
489
490impl CommandCode {
491 pub async fn invoke(
492 &self,
493 ctx: &FwContext,
494 message: &Arc<Message>,
495 args: &str,
496 ) -> CommandResult {
497 match self {
498 CommandCode::Binary(f) => f(ctx, message, args).await,
499 #[cfg(feature = "interpreter")]
500 CommandCode::Interpreted(code) => todo!(),
501 }
502 }
503}
504
505pub type UnknownCommandHandlerCodeFn = for<'a> fn(
506 ctx: &'a FwContext,
507 message: &'a Message,
508) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
509
510pub enum UnknownCommandHandlerCode {
511 Binary(UnknownCommandHandlerCodeFn),
512 #[cfg(feature = "interpreter")]
513 Interpreted(String),
514}
515
516impl UnknownCommandHandlerCode {
517 pub async fn invoke(&self, ctx: &FwContext, message: &Message) {
518 match self {
519 Self::Binary(f) => f(ctx, message).await,
520 #[cfg(feature = "interpreter")]
521 Self::Interpreted(code) => todo!(),
522 }
523 }
524}
525
526pub type NormalMessageHandlerCodeFn = for<'a> fn(
527 ctx: &'a FwContext,
528 message: &'a Message,
529) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
530pub enum NormalMessageHandlerCode {
531 Binary(NormalMessageHandlerCodeFn),
532 #[cfg(feature = "interpreter")]
533 Interpreted(String),
534}
535
536impl From<NormalMessageHandlerCodeFn> for NormalMessageHandlerCode {
537 fn from(code: NormalMessageHandlerCodeFn) -> Self {
538 Self::Binary(code)
539 }
540}
541
542impl NormalMessageHandlerCode {
543 pub async fn invoke(&self, ctx: &FwContext, message: &Message) {
544 match self {
545 Self::Binary(f) => f(ctx, message).await,
546 #[cfg(feature = "interpreter")]
547 Self::Interpreted(code) => todo!(),
548 }
549 }
550}
551
552#[cfg(test)]
553mod test;