use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use regex::Regex;
use rust_tg_bot_raw::types::update::Update;
use super::base::{Handler, HandlerCallback, HandlerResult, MatchResult};
use crate::context::CallbackContext;
#[derive(Clone)]
#[non_exhaustive]
pub enum CallbackPattern {
Data(Regex),
Game(Regex),
Both {
data: Regex,
game: Regex,
},
Predicate(Arc<dyn Fn(&str) -> bool + Send + Sync>),
}
pub struct CallbackQueryHandler {
callback: HandlerCallback,
pattern: Option<CallbackPattern>,
block: bool,
}
impl CallbackQueryHandler {
pub fn new(callback: HandlerCallback, pattern: Option<CallbackPattern>, block: bool) -> Self {
Self {
callback,
pattern,
block,
}
}
fn try_regex(re: &Regex, text: &str) -> Option<MatchResult> {
let caps = re.captures(text)?;
let positional: Vec<String> = caps
.iter()
.filter_map(|m| m.map(|m| m.as_str().to_owned()))
.collect();
let mut named: HashMap<String, String> = HashMap::new();
for name in re.capture_names().flatten() {
if let Some(m) = caps.name(name) {
named.insert(name.to_owned(), m.as_str().to_owned());
}
}
if named.is_empty() {
Some(MatchResult::RegexMatch(positional))
} else {
Some(MatchResult::RegexMatchWithNames { positional, named })
}
}
}
impl Handler for CallbackQueryHandler {
fn check_update(&self, update: &Update) -> Option<MatchResult> {
let cq = update.callback_query()?;
match &self.pattern {
None => {
Some(MatchResult::Empty)
}
Some(CallbackPattern::Data(re)) => {
let data = cq.data.as_ref()?;
Self::try_regex(re, data)
}
Some(CallbackPattern::Game(re)) => {
let game = cq.game_short_name.as_ref()?;
Self::try_regex(re, game)
}
Some(CallbackPattern::Both {
data: dre,
game: gre,
}) => {
if let Some(data) = cq.data.as_ref() {
if let Some(mr) = Self::try_regex(dre, data) {
return Some(mr);
}
}
if let Some(game) = cq.game_short_name.as_ref() {
return Self::try_regex(gre, game);
}
None
}
Some(CallbackPattern::Predicate(pred)) => {
let data = cq.data.as_ref()?;
if pred(data) {
Some(MatchResult::Empty)
} else {
None
}
}
}
}
fn handle_update(
&self,
update: Arc<Update>,
match_result: MatchResult,
) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
(self.callback)(update, match_result)
}
fn block(&self) -> bool {
self.block
}
fn collect_additional_context(
&self,
context: &mut CallbackContext,
match_result: &MatchResult,
) {
match match_result {
MatchResult::RegexMatch(groups) => {
context.matches = Some(groups.clone());
}
MatchResult::RegexMatchWithNames { positional, named } => {
context.matches = Some(positional.clone());
context.named_matches = Some(named.clone());
}
_ => {}
}
}
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use super::*;
fn noop_callback() -> HandlerCallback {
Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }))
}
fn make_callback_query_update(data: &str) -> Update {
serde_json::from_value(serde_json::json!({
"update_id": 1,
"callback_query": {
"id": "42",
"from": {"id": 1, "is_bot": false, "first_name": "Test"},
"chat_instance": "ci",
"data": data
}
}))
.expect("valid callback query update")
}
#[test]
fn no_callback_query_returns_none() {
let h = CallbackQueryHandler::new(noop_callback(), None, true);
let update: Update = serde_json::from_str(r#"{"update_id": 1}"#).unwrap();
assert!(h.check_update(&update).is_none());
}
#[test]
fn predicate_accepts_matching_data() {
let h = CallbackQueryHandler::new(
noop_callback(),
Some(CallbackPattern::Predicate(Arc::new(|data| {
data.starts_with("btn_")
}))),
true,
);
let update = make_callback_query_update("btn_42");
assert!(h.check_update(&update).is_some());
}
#[test]
fn predicate_rejects_non_matching_data() {
let h = CallbackQueryHandler::new(
noop_callback(),
Some(CallbackPattern::Predicate(Arc::new(|data| {
data.starts_with("btn_")
}))),
true,
);
let update = make_callback_query_update("action_42");
assert!(h.check_update(&update).is_none());
}
#[test]
fn predicate_requires_data_field() {
let h = CallbackQueryHandler::new(
noop_callback(),
Some(CallbackPattern::Predicate(Arc::new(|_| true))),
true,
);
let update: Update = serde_json::from_value(serde_json::json!({
"update_id": 1,
"callback_query": {
"id": "42",
"from": {"id": 1, "is_bot": false, "first_name": "Test"},
"chat_instance": "ci",
"game_short_name": "mygame"
}
}))
.expect("valid");
assert!(h.check_update(&update).is_none());
}
#[test]
fn no_pattern_accepts_any_callback_query() {
let h = CallbackQueryHandler::new(noop_callback(), None, true);
let update = make_callback_query_update("anything");
assert!(h.check_update(&update).is_some());
}
#[test]
fn regex_data_pattern_matches() {
let h = CallbackQueryHandler::new(
noop_callback(),
Some(CallbackPattern::Data(Regex::new(r"^btn_(\d+)$").unwrap())),
true,
);
let update = make_callback_query_update("btn_123");
let result = h.check_update(&update);
assert!(result.is_some());
if let Some(MatchResult::RegexMatch(groups)) = result {
assert_eq!(groups[0], "btn_123");
assert_eq!(groups[1], "123");
} else {
panic!("expected RegexMatch");
}
}
#[test]
fn named_group_pattern_returns_regex_match_with_names() {
let re = Regex::new(r"^btn_(?P<id>\d+)$").unwrap();
let h = CallbackQueryHandler::new(noop_callback(), Some(CallbackPattern::Data(re)), true);
let update = make_callback_query_update("btn_99");
match h.check_update(&update) {
Some(MatchResult::RegexMatchWithNames { positional, named }) => {
assert_eq!(positional[0], "btn_99");
assert_eq!(named.get("id").map(String::as_str), Some("99"));
}
other => panic!("expected RegexMatchWithNames, got {other:?}"),
}
}
#[test]
fn collect_context_positional() {
use crate::context::CallbackContext;
use crate::ext_bot::test_support::mock_request;
use rust_tg_bot_raw::bot::Bot;
let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
"test",
mock_request(),
)));
let stores = (
Arc::new(tokio::sync::RwLock::new(HashMap::new())),
Arc::new(tokio::sync::RwLock::new(HashMap::new())),
Arc::new(tokio::sync::RwLock::new(HashMap::new())),
);
let mut ctx = CallbackContext::new(bot, None, None, stores.0, stores.1, stores.2);
let h = CallbackQueryHandler::new(noop_callback(), None, true);
let mr = MatchResult::RegexMatch(vec!["full".into(), "group1".into()]);
h.collect_additional_context(&mut ctx, &mr);
assert_eq!(ctx.matches, Some(vec!["full".into(), "group1".into()]));
assert!(ctx.named_matches.is_none());
}
#[test]
fn collect_context_named() {
use crate::context::CallbackContext;
use crate::ext_bot::test_support::mock_request;
use rust_tg_bot_raw::bot::Bot;
let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
"test",
mock_request(),
)));
let stores = (
Arc::new(tokio::sync::RwLock::new(HashMap::new())),
Arc::new(tokio::sync::RwLock::new(HashMap::new())),
Arc::new(tokio::sync::RwLock::new(HashMap::new())),
);
let mut ctx = CallbackContext::new(bot, None, None, stores.0, stores.1, stores.2);
let h = CallbackQueryHandler::new(noop_callback(), None, true);
let mut named = HashMap::new();
named.insert("id".into(), "99".into());
let mr = MatchResult::RegexMatchWithNames {
positional: vec!["btn_99".into(), "99".into()],
named,
};
h.collect_additional_context(&mut ctx, &mr);
assert_eq!(ctx.matches, Some(vec!["btn_99".into(), "99".into()]));
assert_eq!(
ctx.named_matches
.as_ref()
.and_then(|m| m.get("id"))
.map(String::as_str),
Some("99")
);
}
}