use super::{Filter, FilterResult};
use crate::{
client::{Bot, Session},
errors::SessionErrorKind,
methods::GetMe,
types::BotCommand,
FromContext, Request,
};
use regex::Regex;
use std::{borrow::Cow, iter::once};
use tracing::{event, instrument, Level};
/// A single pattern that a [`Command`] filter matches command names against.
#[derive(Clone, Debug)]
pub enum PatternType {
    /// Command name compared for equality (without prefix and mention).
    Text(Cow<'static, str>),
    /// Bot command object; only its `command` name is used, it is turned into a text pattern
    /// when the filter is constructed.
    Object(BotCommand),
    /// Regular expression checked with `Regex::is_match` against the command name, so it is
    /// not anchored unless the pattern itself is.
    Regex(Regex),
}
/// Wraps an already-built `Cow` string as a text pattern.
impl From<Cow<'static, str>> for PatternType {
    #[inline]
    fn from(text: Cow<'static, str>) -> Self {
        Self::Text(text)
    }
}

/// Wraps a string literal as a text pattern without allocating.
impl From<&'static str> for PatternType {
    #[inline]
    fn from(text: &'static str) -> Self {
        Self::Text(Cow::Borrowed(text))
    }
}

/// Wraps an owned string (e.g. a command name built at runtime) as a text pattern.
impl From<String> for PatternType {
    #[inline]
    fn from(text: String) -> Self {
        Self::Text(Cow::Owned(text))
    }
}

/// Wraps a bot command object; the filter later keeps only its command name.
impl From<BotCommand> for PatternType {
    #[inline]
    fn from(command: BotCommand) -> Self {
        Self::Object(command)
    }
}

/// Wraps a compiled regular expression.
impl From<Regex> for PatternType {
    #[inline]
    fn from(regex: Regex) -> Self {
        Self::Regex(regex)
    }
}
/// Filter that passes when a message's text (or caption) is a command matching one of its
/// patterns, with the expected prefix and, unless ignored, a mention of this bot.
///
/// Build it with [`Command::new`], one of the shortcut constructors, or [`Command::builder`].
#[derive(Clone, Debug)]
pub struct Command {
    /// Accepted patterns; `Command::new` turns `PatternType::Object` into `PatternType::Text`.
    commands: Vec<PatternType>,
    /// Character the command has to start with.
    prefix: char,
    /// Whether command names are compared case-insensitively.
    ignore_case: bool,
    /// Whether the `@mention` part of a command is left unchecked.
    ignore_mention: bool,
}
impl Command {
    /// Creates a command filter.
    ///
    /// # Arguments
    ///
    /// * `commands` - accepted patterns; anything convertible into [`PatternType`].
    /// * `prefix` - character the command has to start with, e.g. `/`.
    /// * `ignore_case` - compare command names case-insensitively. Text patterns (including the
    ///   names taken from [`BotCommand`] objects) are lowercased here once; regex patterns are
    ///   kept unchanged and are later matched against the lowercased command name.
    /// * `ignore_mention` - do not check the `@mention` part of the command. A warning is
    ///   logged for every regex pattern when this is set.
    ///
    /// [`PatternType::Object`] patterns are always flattened into [`PatternType::Text`] using
    /// the bot command's name, so validation never has to deal with them.
    #[must_use]
    #[instrument(skip(commands))]
    pub fn new<CommandType, Commands>(
        commands: Commands,
        prefix: char,
        ignore_case: bool,
        ignore_mention: bool,
    ) -> Self
    where
        CommandType: Into<PatternType>,
        Commands: IntoIterator<Item = CommandType>,
    {
        // Single place that decides whether text patterns get lowercased, so both the
        // case-sensitive and case-insensitive paths share the same conversion logic.
        let normalize = |text: Cow<'static, str>| -> Cow<'static, str> {
            if ignore_case {
                Cow::Owned(text.to_lowercase())
            } else {
                text
            }
        };

        let commands = commands
            .into_iter()
            .map(|pattern| match pattern.into() {
                PatternType::Text(text) => PatternType::Text(normalize(text)),
                PatternType::Object(bot_command) => PatternType::Text(normalize(Cow::Owned(
                    bot_command.command.into_string(),
                ))),
                PatternType::Regex(regex) => {
                    if ignore_mention {
                        event!(Level::WARN, "Ignore mention flag doesn't work with regexes");
                    }
                    PatternType::Regex(regex)
                }
            })
            .collect();

        Self {
            commands,
            prefix,
            ignore_case,
            ignore_mention,
        }
    }

    /// Creates a filter for a single pattern with the default `/` prefix.
    #[inline]
    #[must_use]
    pub fn one(command: impl Into<PatternType>) -> Self {
        Self::builder().command(command).build()
    }

    /// Creates a filter for a single pattern with a custom prefix.
    #[inline]
    #[must_use]
    pub fn one_with_prefix(command: impl Into<PatternType>, prefix: char) -> Self {
        Self::builder().command(command).prefix(prefix).build()
    }

    /// Creates a filter for several patterns with the default `/` prefix.
    #[inline]
    #[must_use]
    pub fn many<T, I>(commands: I) -> Self
    where
        T: Into<PatternType>,
        I: IntoIterator<Item = T>,
    {
        Self::builder().commands(commands).build()
    }

    /// Creates a filter for several patterns with a custom prefix.
    #[inline]
    #[must_use]
    pub fn many_with_prefix<T, I>(commands: I, prefix: char) -> Self
    where
        T: Into<PatternType>,
        I: IntoIterator<Item = T>,
    {
        Self::builder().commands(commands).prefix(prefix).build()
    }

    /// Starts a [`Builder`] with no patterns, the `/` prefix and both `ignore_*` flags off.
    #[inline]
    #[must_use]
    pub fn builder() -> Builder {
        Builder::new()
    }
}
impl Default for Command {
#[inline]
fn default() -> Self {
Self {
commands: vec![],
prefix: '/',
ignore_case: false,
ignore_mention: false,
}
}
}
/// Step-by-step constructor for [`Command`]; finish with [`Builder::build`].
#[derive(Clone, Debug)]
pub struct Builder {
    /// Patterns collected so far, in insertion order.
    commands: Vec<PatternType>,
    /// Prefix passed to the resulting filter.
    prefix: char,
    /// Case-insensitivity flag passed to the resulting filter.
    ignore_case: bool,
    /// Mention-skipping flag passed to the resulting filter.
    ignore_mention: bool,
}
impl Builder {
    /// Creates a builder with no patterns, the `/` prefix and both `ignore_*` flags off.
    #[inline]
    #[must_use]
    pub fn new() -> Builder {
        Self::default()
    }

    /// Appends a single pattern.
    #[must_use]
    pub fn command(self, val: impl Into<PatternType>) -> Self {
        self.commands(once(val))
    }

    /// Appends several patterns, keeping their order.
    ///
    /// Extends the existing vector in place instead of rebuilding it, so chaining many calls
    /// stays linear in the total number of patterns.
    #[must_use]
    pub fn commands<T, I>(mut self, val: I) -> Self
    where
        T: Into<PatternType>,
        I: IntoIterator<Item = T>,
    {
        self.commands.extend(val.into_iter().map(Into::into));
        self
    }

    /// Sets the character the command has to start with.
    #[inline]
    #[must_use]
    pub fn prefix(mut self, val: char) -> Self {
        self.prefix = val;
        self
    }

    /// Sets whether command names are compared case-insensitively.
    #[inline]
    #[must_use]
    pub fn ignore_case(mut self, val: bool) -> Self {
        self.ignore_case = val;
        self
    }

    /// Sets whether the `@mention` part of a command is left unchecked.
    #[inline]
    #[must_use]
    pub fn ignore_mention(mut self, val: bool) -> Self {
        self.ignore_mention = val;
        self
    }

    /// Builds the filter via [`Command::new`], which normalizes the collected patterns.
    #[inline]
    #[must_use]
    pub fn build(self) -> Command {
        Command::new(
            self.commands,
            self.prefix,
            self.ignore_case,
            self.ignore_mention,
        )
    }
}
impl Default for Builder {
#[inline]
fn default() -> Self {
Self {
commands: vec![],
prefix: '/',
ignore_case: false,
ignore_mention: false,
}
}
}
impl Command {
#[inline]
#[must_use]
pub fn validate_prefix(&self, command: &CommandObject) -> bool {
command.prefix == self.prefix
}
#[allow(clippy::missing_panics_doc)]
pub async fn validate_mention(
&self,
command: &CommandObject,
bot: &Bot<impl Session>,
) -> Result<bool, SessionErrorKind> {
if self.ignore_mention {
Ok(true)
} else if let Some(ref mention) = command.mention {
bot.send(GetMe {}).await.map(|user| {
user.username.unwrap().eq(mention)
})
} else {
Ok(true)
}
}
#[must_use]
pub fn validate_command(&self, command: &CommandObject) -> bool {
let command = if self.ignore_case {
command.command.to_lowercase().into_boxed_str()
} else {
command.command.clone()
};
let command_ref = command.as_ref();
for pattern in &*self.commands {
match pattern {
PatternType::Text(allowed_command) => {
if command_ref == allowed_command {
return true;
}
}
PatternType::Regex(regex) => {
if regex.is_match(&command) {
return true;
}
}
PatternType::Object(_) => {
unreachable!(
"`PatternType::Object` should be converted to `PatternType::Text` before \
validation"
)
}
}
}
false
}
pub async fn validate_command_object(
&self,
command: &CommandObject,
bot: &Bot<impl Session>,
) -> Result<bool, SessionErrorKind> {
Ok(self.validate_prefix(command)
&& self.validate_command(command)
&& self.validate_mention(command, bot).await?)
}
}
/// A command parsed from message text, e.g. `/start@bot_username arg1 arg2`.
// The type name repeats the module name (presumably `command`); kept for clarity at use sites.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Clone, Hash, PartialEq, Eq, FromContext)]
#[context(
    key = "command",
    description = "Parsed command object. This type is available only if the command filter is \
                   used and passes."
)]
pub struct CommandObject {
    /// Command name without the prefix and mention, e.g. `start`.
    pub command: Box<str>,
    /// First character of the command text, e.g. `/`.
    pub prefix: char,
    /// Text after `@` in the command token, if present and non-empty, e.g. `bot_username`.
    pub mention: Option<Box<str>>,
    /// Arguments following the command, as split by [`CommandObject::extract`].
    pub args: Box<[Box<str>]>,
}
impl CommandObject {
    /// Parses `text` into a command object.
    ///
    /// The first whitespace-separated token is the command. Its first character is the
    /// prefix, and the rest is the command name, optionally followed by `@mention`. A name
    /// that starts with `@` (such as `/@`) is taken literally and has no mention. The
    /// remaining whitespace-separated tokens become [`CommandObject::args`].
    ///
    /// Returns `None` if `text` has no token or its first token is a lone prefix character.
    #[must_use]
    pub fn extract(text: &str) -> Option<Self> {
        // Split on any whitespace (spaces, tabs, newlines), so "/start\narg" yields the
        // command "start" and repeated spaces don't produce empty arguments.
        let mut tokens = text.split_whitespace();
        let full_command = tokens.next()?;
        let args: Box<[Box<str>]> = tokens.map(Box::<str>::from).collect();

        let mut chars = full_command.chars();
        let prefix = chars.next()?;
        let command = chars.as_str();
        if command.is_empty() {
            return None;
        }

        // An '@' only separates a mention when something precedes it; an empty mention
        // (e.g. "/start@") is treated as no mention.
        let (command, mention) = match command.split_once('@') {
            Some((name, mention)) if !name.is_empty() => {
                (name, (!mention.is_empty()).then_some(mention))
            }
            _ => (command, None),
        };

        Some(CommandObject {
            command: command.into(),
            prefix,
            mention: mention.map(Into::into),
            args,
        })
    }
}
impl<Client> Filter<Client> for Command
where
    Client: Session + 'static,
{
    type Error = SessionErrorKind;

    /// Passes if the update holds a message whose text (or, failing that, caption) parses as
    /// a command accepted by this filter. On success the parsed [`CommandObject`] is stored in
    /// the request context under the `"command"` key.
    ///
    /// # Errors
    ///
    /// Logs and propagates the error returned by command validation.
    #[instrument]
    async fn check(&mut self, request: &mut Request<Client>) -> FilterResult<Self::Error> {
        let Some(msg) = request.update.message() else {
            return Ok(false);
        };
        let Some(raw) = msg.text().or_else(|| msg.caption()) else {
            return Ok(false);
        };
        let Some(parsed) = CommandObject::extract(raw) else {
            return Ok(false);
        };

        let accepted = match self.validate_command_object(&parsed, &request.bot).await {
            Ok(accepted) => accepted,
            Err(err) => {
                event!(Level::ERROR, error = %err, "Failed to validate command object");
                return Err(err);
            }
        };

        if accepted {
            request.context.insert("command", parsed);
        }
        Ok(accepted)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `text` and asserts every field of the resulting [`CommandObject`].
    fn assert_extracted(
        text: &str,
        command: &str,
        prefix: char,
        mention: Option<&str>,
        args: &[&str],
    ) {
        let parsed = CommandObject::extract(text)
            .unwrap_or_else(|| panic!("expected a command in {text:?}"));
        assert_eq!(&*parsed.command, command, "command of {text:?}");
        assert_eq!(parsed.prefix, prefix, "prefix of {text:?}");
        assert_eq!(parsed.mention.as_deref(), mention, "mention of {text:?}");
        let parsed_args: Vec<&str> = parsed.args.iter().map(|arg| &**arg).collect();
        assert_eq!(parsed_args, args, "args of {text:?}");
    }

    #[test]
    fn test_command_extract() {
        assert_extracted("/start", "start", '/', None, &[]);
        assert_extracted("/start@bot_username", "start", '/', Some("bot_username"), &[]);
        assert_extracted("/start@", "start", '/', None, &[]);
        assert_extracted("/@", "@", '/', None, &[]);
        assert_extracted("@/", "/", '@', None, &[]);
        assert_extracted("/@ arg1 arg2", "@", '/', None, &["arg1", "arg2"]);
        assert_extracted("/@bot_username", "@bot_username", '/', None, &[]);
        assert_extracted("@start@bot_username", "start", '@', Some("bot_username"), &[]);
        assert_extracted(
            "/start@bot_username arg1 arg2",
            "start",
            '/',
            Some("bot_username"),
            &["arg1", "arg2"],
        );
        // Any leading character is accepted as a prefix at parse time.
        assert_extracted("Telegram says: 123", "elegram", 'T', None, &["says:", "123"]);
        assert_extracted("One two", "ne", 'O', None, &["two"]);
        assert_extracted("Один два", "дин", 'О', None, &["два"]);
    }

    #[test]
    fn test_command_extract_rejects_missing_command() {
        // Neither an empty text nor a lone prefix contains a command name.
        assert!(CommandObject::extract("").is_none());
        assert!(CommandObject::extract("/").is_none());
    }

    #[test]
    fn test_validate_prefix() {
        let filter = Command::builder().prefix('/').command("start").build();
        for (text, expected) in [("/start", true), ("/start_other", true), ("!start", false)] {
            let parsed = CommandObject::extract(text).unwrap();
            assert_eq!(filter.validate_prefix(&parsed), expected, "prefix check of {text:?}");
        }
    }

    #[test]
    fn test_validate_command() {
        let inputs = ["/start", "/START", "/stop", "/STOP"];
        // (pattern, ignore_case, expected result for each input above)
        let cases = [
            ("start", false, [true, false, false, false]),
            ("start", true, [true, true, false, false]),
            ("Start", true, [true, true, false, false]),
        ];
        for (pattern, ignore_case, expectations) in cases {
            let filter = Command::builder()
                .prefix('/')
                .command(pattern)
                .ignore_case(ignore_case)
                .build();
            for (text, &expected) in inputs.iter().zip(expectations.iter()) {
                let parsed = CommandObject::extract(text).unwrap();
                assert_eq!(
                    filter.validate_command(&parsed),
                    expected,
                    "{text:?} against pattern {pattern:?} (ignore_case = {ignore_case})"
                );
            }
        }
    }
}