use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use regex::Regex;
use rust_tg_bot_raw::types::update::Update;
use super::base::{Handler, HandlerCallback, HandlerResult, MatchResult};
use crate::context::CallbackContext;
/// Handler that tests the text of an update's effective message against a
/// regular expression and passes the resulting capture groups to a callback.
pub struct StringRegexHandler {
/// Compiled pattern that `check_update` runs against the message text.
pattern: Regex,
/// Callback invoked by `handle_update` with the update and the match result.
callback: HandlerCallback,
/// Flag returned unchanged by `Handler::block`.
block: bool,
}
impl StringRegexHandler {
    /// Creates a handler from an already compiled `pattern`.
    ///
    /// `callback` is stored and later invoked by `handle_update`; `block` is
    /// stored and returned as-is by `Handler::block`.
    pub fn new(pattern: Regex, callback: HandlerCallback, block: bool) -> Self {
        Self { pattern, callback, block }
    }
}
impl Handler for StringRegexHandler {
    /// Runs the handler's pattern against the text of the update's effective
    /// message.
    ///
    /// Returns `None` when the update has no effective message, the message
    /// has no text, or the pattern does not match.
    ///
    /// On a match, `positional` holds exactly one entry per capture group of
    /// the pattern (index 0 is the whole match), so `positional[i]` always
    /// refers to group `i`. A group that did not participate in the match
    /// (for example the first group of `(a)?(b)` on the input `"b"`) is
    /// represented by an empty string instead of being dropped, which would
    /// shift every later group to a lower index.
    ///
    /// If at least one named group participated, the result is
    /// `MatchResult::RegexMatchWithNames`; otherwise it is
    /// `MatchResult::RegexMatch`.
    fn check_update(&self, update: &Update) -> Option<MatchResult> {
        let message = update.effective_message()?;
        let text = message.text.as_ref()?;
        let caps = self.pattern.captures(text)?;

        // One slot per group keeps indices aligned with the pattern's group
        // numbers even when optional groups are absent from the match.
        let positional: Vec<String> = caps
            .iter()
            .map(|group| group.map_or_else(String::new, |m| m.as_str().to_owned()))
            .collect();

        // Only named groups that actually participated are recorded.
        let named: HashMap<String, String> = self
            .pattern
            .capture_names()
            .flatten()
            .filter_map(|name| {
                caps.name(name)
                    .map(|m| (name.to_owned(), m.as_str().to_owned()))
            })
            .collect();

        if named.is_empty() {
            Some(MatchResult::RegexMatch(positional))
        } else {
            Some(MatchResult::RegexMatchWithNames { positional, named })
        }
    }

    /// Invokes the stored callback with `update` and `match_result` and
    /// returns the future it produces.
    fn handle_update(
        &self,
        update: Arc<Update>,
        match_result: MatchResult,
    ) -> Pin<Box<dyn Future<Output = HandlerResult> + Send>> {
        (self.callback)(update, match_result)
    }

    /// Returns the `block` flag this handler was constructed with.
    fn block(&self) -> bool {
        self.block
    }

    /// Copies the regex captures from `match_result` into `context`.
    ///
    /// `RegexMatch` fills `context.matches`; `RegexMatchWithNames` fills both
    /// `context.matches` and `context.named_matches`. Any other variant leaves
    /// `context` untouched.
    fn collect_additional_context(
        &self,
        context: &mut CallbackContext,
        match_result: &MatchResult,
    ) {
        match match_result {
            MatchResult::RegexMatch(groups) => {
                context.matches = Some(groups.clone());
            }
            MatchResult::RegexMatchWithNames { positional, named } => {
                context.matches = Some(positional.clone());
                context.named_matches = Some(named.clone());
            }
            // Variants without regex captures carry nothing to copy here.
            _ => {}
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Callback that ignores its arguments and resolves to
    /// `HandlerResult::Continue`.
    fn noop_callback() -> HandlerCallback {
        Arc::new(|_update, _mr| Box::pin(async { HandlerResult::Continue }))
    }

    /// Deserializes a private-chat message update whose text is `text`.
    ///
    /// `text` is spliced into the JSON verbatim, so it must not contain
    /// characters that need JSON escaping.
    fn make_message_update(text: &str) -> Update {
        serde_json::from_str(&format!(
            r#"{{"update_id":1,"message":{{"message_id":1,"date":0,"chat":{{"id":1,"type":"private"}},"text":"{text}"}}}}"#
        ))
        .unwrap()
    }

    /// Builds a `CallbackContext` around a bot created from `mock_request()`
    /// and three fresh, empty `RwLock<HashMap>` stores.
    fn make_context() -> CallbackContext {
        use crate::ext_bot::test_support::mock_request;
        use rust_tg_bot_raw::bot::Bot;

        let bot = Arc::new(crate::ext_bot::ExtBot::from_bot(Bot::new(
            "test",
            mock_request(),
        )));
        let stores = (
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
            Arc::new(tokio::sync::RwLock::new(HashMap::new())),
        );
        CallbackContext::new(bot, None, None, stores.0, stores.1, stores.2)
    }

    #[test]
    fn matches_regex() {
        let handler =
            StringRegexHandler::new(Regex::new(r"^hello (\w+)").unwrap(), noop_callback(), true);
        let update = make_message_update("hello world");

        let outcome = handler.check_update(&update);
        assert!(outcome.is_some());
        match outcome {
            Some(MatchResult::RegexMatch(groups)) => {
                assert_eq!(groups.len(), 2);
                assert_eq!(groups[0], "hello world");
                assert_eq!(groups[1], "world");
            }
            _ => panic!("expected RegexMatch"),
        }
    }

    #[test]
    fn no_match_returns_none() {
        let handler =
            StringRegexHandler::new(Regex::new(r"^goodbye").unwrap(), noop_callback(), true);
        let update = make_message_update("hello world");

        assert!(handler.check_update(&update).is_none());
    }

    #[test]
    fn named_group_returns_regex_match_with_names() {
        let handler = StringRegexHandler::new(
            Regex::new(r"^hello (?P<name>\w+)").unwrap(),
            noop_callback(),
            true,
        );
        let update = make_message_update("hello alice");

        match handler.check_update(&update) {
            Some(MatchResult::RegexMatchWithNames { positional, named }) => {
                assert_eq!(positional[0], "hello alice");
                assert_eq!(named.get("name").map(String::as_str), Some("alice"));
            }
            other => panic!("expected RegexMatchWithNames, got {other:?}"),
        }
    }

    #[test]
    fn collect_context_populates_matches() {
        let mut ctx = make_context();
        let handler = StringRegexHandler::new(Regex::new(r"x").unwrap(), noop_callback(), true);
        let match_result = MatchResult::RegexMatch(vec!["hello".into()]);

        handler.collect_additional_context(&mut ctx, &match_result);

        assert_eq!(ctx.matches, Some(vec!["hello".into()]));
        assert!(ctx.named_matches.is_none());
    }

    #[test]
    fn collect_context_populates_named_matches() {
        let mut ctx = make_context();
        let handler = StringRegexHandler::new(Regex::new(r"x").unwrap(), noop_callback(), true);
        let match_result = MatchResult::RegexMatchWithNames {
            positional: vec!["hello alice".into(), "alice".into()],
            named: HashMap::from([("name".to_owned(), "alice".to_owned())]),
        };

        handler.collect_additional_context(&mut ctx, &match_result);

        assert_eq!(
            ctx.matches,
            Some(vec!["hello alice".into(), "alice".into()])
        );
        let captured_name = ctx
            .named_matches
            .as_ref()
            .and_then(|names| names.get("name"))
            .map(String::as_str);
        assert_eq!(captured_name, Some("alice"));
    }
}