use std::collections::HashSet;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use rust_tg_bot_raw::types::update::Update;
use super::base::{Handler, HandlerCallback, HandlerResult, MatchResult};
/// Predicate that decides whether a [`PrefixHandler`] should inspect an update at all.
///
/// The closure must be `Send + Sync` so it can be shared across threads behind the `Arc`.
pub type UpdateFilter = Arc<dyn Fn(&Update) -> bool + Sync + Send>;
/// Filter used when no custom [`UpdateFilter`] was supplied via `with_filter`.
///
/// Accepts updates that carry a new or edited message, or a new or edited channel
/// post. These are the kinds of update whose text a prefix command can appear in.
/// Callback queries and other update kinds are rejected. Checking the four kinds
/// explicitly (rather than `effective_message()`) keeps a callback query attached to
/// a message that starts with a prefix command from triggering the handler.
fn default_update_filter(update: &Update) -> bool {
    update.message().is_some()
        || update.edited_message().is_some()
        || update.channel_post().is_some()
        || update.edited_channel_post().is_some()
}
/// Handler that fires when the first whitespace-delimited word of a message's text
/// equals one of a set of `prefix + command` strings, compared case-insensitively.
pub struct PrefixHandler {
    // Every lowercased `prefix + command` combination built by `PrefixHandler::new`.
    commands: HashSet<String>,
    // Invoked with the update and the matched arguments from `handle_update`.
    callback: HandlerCallback,
    // Value reported by `Handler::block`.
    block: bool,
    // Custom update filter; `None` means `default_update_filter` is used.
    filter_fn: Option<UpdateFilter>,
}
impl PrefixHandler {
pub fn new(
prefixes: Vec<String>,
commands: Vec<String>,
callback: HandlerCallback,
block: bool,
) -> Self {
let mut combined = HashSet::new();
for p in &prefixes {
for c in &commands {
combined.insert(format!("{}{}", p.to_lowercase(), c.to_lowercase()));
}
}
Self {
commands: combined,
callback,
block,
filter_fn: None,
}
}
pub fn with_filter(mut self, filter: UpdateFilter) -> Self {
self.filter_fn = Some(filter);
self
}
}
impl Handler for PrefixHandler {
    /// Matches when the update passes the active filter and its effective message's
    /// text starts (after leading whitespace) with one of the configured
    /// `prefix + command` words, compared after lowercasing.
    ///
    /// On a match, the remaining whitespace-separated words are returned as
    /// `MatchResult::Args`.
    fn check_update(&self, update: &Update) -> Option<MatchResult> {
        let accepted = if let Some(filter) = &self.filter_fn {
            filter(update)
        } else {
            default_update_filter(update)
        };
        if !accepted {
            return None;
        }

        let message = update.effective_message()?;
        let text = message.text.as_ref()?;
        let mut tokens = text.split_whitespace();
        let trigger = tokens.next()?.to_lowercase();

        self.commands
            .contains(&trigger)
            .then(|| MatchResult::Args(tokens.map(str::to_owned).collect::<Vec<String>>()))
    }

    /// Forwards the update and its match result to the stored callback.
    fn handle_update(
        &self,
        update: Arc<Update>,
        match_result: MatchResult,
    ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
        let callback = &self.callback;
        callback(update, match_result)
    }

    /// Returns the `block` flag supplied to `PrefixHandler::new`.
    fn block(&self) -> bool {
        self.block
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;
    use serde_json::json;

    /// Callback that does nothing and lets dispatch continue.
    fn noop_callback() -> HandlerCallback {
        Arc::new(|_u, _m| Box::pin(async { HandlerResult::Continue }))
    }

    /// Handler for the single command `!test`, using the default filter.
    fn bang_test_handler() -> PrefixHandler {
        PrefixHandler::new(vec!["!".into()], vec!["test".into()], noop_callback(), true)
    }

    /// Deserializes a raw Bot API update fixture, panicking if it is malformed.
    fn parse_update(raw: serde_json::Value) -> Update {
        serde_json::from_value(raw).expect("valid")
    }

    #[test]
    fn cartesian_product_commands() {
        let h = PrefixHandler::new(
            vec!["!".into(), "#".into()],
            vec!["test".into(), "help".into()],
            noop_callback(),
            true,
        );
        for expected in ["!test", "#test", "!help", "#help"] {
            assert!(h.commands.contains(expected), "missing {}", expected);
        }
        assert_eq!(h.commands.len(), 4);
    }

    #[test]
    fn default_filter_accepts_message() {
        let update = parse_update(json!({
            "update_id": 1,
            "message": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": 1, "type": "private"},
                "text": "!test arg1"
            }
        }));
        assert!(bang_test_handler().check_update(&update).is_some());
    }

    #[test]
    fn default_filter_accepts_edited_message() {
        let update = parse_update(json!({
            "update_id": 1,
            "edited_message": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": 1, "type": "private"},
                "text": "!test arg1"
            }
        }));
        assert!(bang_test_handler().check_update(&update).is_some());
    }

    #[test]
    fn default_filter_accepts_channel_post() {
        let update = parse_update(json!({
            "update_id": 1,
            "channel_post": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": -100, "type": "channel"},
                "text": "!test"
            }
        }));
        assert!(bang_test_handler().check_update(&update).is_some());
    }

    #[test]
    fn custom_filter_rejects() {
        let h = bang_test_handler().with_filter(Arc::new(|_u| false));
        let update = parse_update(json!({
            "update_id": 1,
            "message": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": 1, "type": "private"},
                "text": "!test"
            }
        }));
        assert!(h.check_update(&update).is_none());
    }

    #[test]
    fn custom_filter_allows_channel_post() {
        let h = bang_test_handler().with_filter(Arc::new(|u| {
            u.message().is_some() || u.channel_post().is_some()
        }));
        let update = parse_update(json!({
            "update_id": 1,
            "channel_post": {
                "message_id": 1,
                "date": 0,
                "chat": {"id": -100, "type": "channel"},
                "text": "!test arg"
            }
        }));
        assert!(h.check_update(&update).is_some());
    }
}