use std::collections::HashMap;
use std::fmt;
use std::ops::{BitAnd, BitOr, BitXor, Not};
use std::sync::Arc;
use serde_json::Value;
/// Alias for the update type defined in `rust_tg_bot_raw`; every filter in
/// this file inspects values of this type.
pub type Update = rust_tg_bot_raw::types::update::Update;
pub fn to_value(update: &Update) -> Value {
serde_json::to_value(update).unwrap_or_default()
}
/// Outcome of running a filter against an update.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum FilterResult {
    /// The filter rejected the update.
    NoMatch,
    /// The filter accepted the update and attached no data.
    Match,
    /// The filter accepted the update and attached lists of strings keyed by name.
    MatchWithData(HashMap<String, Vec<String>>),
}

impl FilterResult {
    /// Returns `true` for every variant except [`FilterResult::NoMatch`].
    #[must_use]
    pub fn is_match(&self) -> bool {
        match self {
            Self::NoMatch => false,
            Self::Match | Self::MatchWithData(_) => true,
        }
    }

    /// Combines two results with "both must match" semantics.
    ///
    /// * If either side is `NoMatch`, the result is `NoMatch`.
    /// * If both sides carry data, the values from `other` are appended after
    ///   the values of `self` under each key.
    /// * If only one side carries data, that data is kept unchanged.
    #[must_use]
    pub fn merge(self, other: FilterResult) -> FilterResult {
        match (self, other) {
            (Self::Match, Self::Match) => Self::Match,
            (Self::Match, Self::MatchWithData(payload))
            | (Self::MatchWithData(payload), Self::Match) => Self::MatchWithData(payload),
            (Self::MatchWithData(mut combined), Self::MatchWithData(extra)) => {
                for (key, values) in extra {
                    combined.entry(key).or_default().extend(values);
                }
                Self::MatchWithData(combined)
            }
            (Self::NoMatch, _) | (_, Self::NoMatch) => Self::NoMatch,
        }
    }
}
/// Finds the message-like object inside a JSON-encoded update.
///
/// The keys below are probed in order and the first one present wins.
/// Returns `None` when none of them exists in `v`.
pub fn effective_message_val(v: &Value) -> Option<&Value> {
    const MESSAGE_KEYS: [&str; 6] = [
        "message",
        "edited_message",
        "channel_post",
        "edited_channel_post",
        "business_message",
        "edited_business_message",
    ];
    MESSAGE_KEYS.iter().find_map(|key| v.get(*key))
}
/// Finds the acting user inside a JSON-encoded update.
///
/// Lookup order:
/// 1. the `from` field of the effective message (see
///    [`effective_message_val`]), if there is one;
/// 2. otherwise the first of the non-message update kinds listed below that
///    is present and carries a user.
///
/// For those update kinds the user is read from `from`, falling back to
/// `user`. The fallback exists because the Telegram Bot API stores the voter
/// of a `poll_answer` under `user`, so probing `from` alone could never match
/// that entry. NOTE(review): this assumes the JSON keeps the Bot API field
/// names — confirm against the crate's serde renames.
///
/// Returns `None` when no user can be found.
pub fn effective_user_val(v: &Value) -> Option<&Value> {
    const ACTOR_KEYS: [&str; 9] = [
        "callback_query",
        "inline_query",
        "chosen_inline_result",
        "shipping_query",
        "pre_checkout_query",
        "poll_answer",
        "my_chat_member",
        "chat_member",
        "chat_join_request",
    ];

    // A message without `from` falls through to the other update kinds.
    if let Some(sender) = effective_message_val(v).and_then(|msg| msg.get("from")) {
        return Some(sender);
    }

    ACTOR_KEYS.iter().find_map(|key| {
        let obj = v.get(*key)?;
        obj.get("from").or_else(|| obj.get("user"))
    })
}
/// Finds the chat an update belongs to inside a JSON-encoded update.
///
/// Lookup order:
/// 1. the `chat` field of the effective message (see
///    [`effective_message_val`]), if there is one;
/// 2. otherwise the first of the non-message update kinds listed below that
///    is present and carries a chat.
///
/// For those update kinds the chat is read from a direct `chat` field,
/// falling back to `message.chat`. The fallback exists because a Telegram
/// Bot API `callback_query` has no top-level `chat`; its chat sits on the
/// nested `message`, so probing `chat` alone could never match that entry.
/// NOTE(review): this assumes the JSON keeps the Bot API field names —
/// confirm against the crate's serde renames.
///
/// Returns `None` when no chat can be found.
pub fn effective_chat_val(v: &Value) -> Option<&Value> {
    const CHAT_KEYS: [&str; 4] = [
        "callback_query",
        "my_chat_member",
        "chat_member",
        "chat_join_request",
    ];

    // A message without `chat` falls through to the other update kinds.
    if let Some(chat) = effective_message_val(v).and_then(|msg| msg.get("chat")) {
        return Some(chat);
    }

    CHAT_KEYS.iter().find_map(|key| {
        let obj = v.get(*key)?;
        obj.get("chat")
            .or_else(|| obj.get("message").and_then(|msg| msg.get("chat")))
    })
}
/// Returns the update's effective message serialized as JSON.
///
/// Yields `None` when the update has no effective message or when
/// serialization fails.
pub fn effective_message(update: &Update) -> Option<Value> {
    let message = update.effective_message()?;
    serde_json::to_value(message).ok()
}
/// Returns the update's effective user serialized as JSON.
///
/// Yields `None` when the update has no effective user or when
/// serialization fails.
pub fn effective_user(update: &Update) -> Option<Value> {
    let user = update.effective_user()?;
    serde_json::to_value(user).ok()
}
/// Returns the update's effective chat serialized as JSON.
///
/// Yields `None` when the update has no effective chat or when
/// serialization fails.
pub fn effective_chat(update: &Update) -> Option<Value> {
    let chat = update.effective_chat()?;
    serde_json::to_value(chat).ok()
}
/// Reports whether `update.effective_message()` yields a message.
pub fn has_effective_message(update: &Update) -> bool {
    let message = update.effective_message();
    message.is_some()
}
/// A predicate over [`Update`]s.
///
/// Implementors must be `Send + Sync + 'static`, which is what allows them to
/// be stored behind `Arc<dyn Filter>` inside [`F`].
pub trait Filter: Send + Sync + 'static {
    /// Evaluates the filter against `update`.
    fn check_update(&self, update: &Update) -> FilterResult;

    /// Human-readable label for the filter.
    ///
    /// Defaults to the implementing type's name as reported by
    /// `std::any::type_name`.
    fn name(&self) -> &str {
        use std::any::type_name;
        type_name::<Self>()
    }
}
/// Type-erased, cloneable handle to any [`Filter`].
///
/// The filter is held in an `Arc`, so cloning an `F` shares the same
/// underlying filter. This is the type the `&`, `|`, `^` and `!` operators
/// are implemented on.
#[derive(Clone)]
pub struct F(pub Arc<dyn Filter>);

impl F {
    /// Wraps a concrete filter into a shareable [`F`].
    pub fn new(filter: impl Filter) -> Self {
        let shared: Arc<dyn Filter> = Arc::new(filter);
        F(shared)
    }
}
/// Debug output is the wrapped filter's name enclosed in `F(...)`.
impl fmt::Debug for F {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("F(")?;
        f.write_str(self.0.name())?;
        f.write_str(")")
    }
}
/// Display output is exactly the wrapped filter's name.
impl fmt::Display for F {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = self.0.name();
        f.write_str(label)
    }
}
/// `F` is itself a filter: both methods forward to the wrapped filter.
impl Filter for F {
    fn check_update(&self, update: &Update) -> FilterResult {
        let inner: &dyn Filter = &*self.0;
        inner.check_update(update)
    }

    fn name(&self) -> &str {
        let inner: &dyn Filter = &*self.0;
        inner.name()
    }
}
/// Conjunction of two filters; built by `F & F`.
struct AndFilter {
    left: F,
    right: F,
    /// Label returned from `name`.
    display: String,
}

impl Filter for AndFilter {
    /// Matches only when both sides match, merging their results.
    ///
    /// The right-hand filter is not evaluated when the left one rejects the
    /// update.
    fn check_update(&self, update: &Update) -> FilterResult {
        match self.left.check_update(update) {
            FilterResult::NoMatch => FilterResult::NoMatch,
            lhs => lhs.merge(self.right.check_update(update)),
        }
    }

    fn name(&self) -> &str {
        self.display.as_str()
    }
}
struct OrFilter {
left: F,
right: F,
display: String,
}
impl Filter for OrFilter {
fn check_update(&self, update: &Update) -> FilterResult {
let left = self.left.check_update(update);
if left.is_match() {
return left;
}
self.right.check_update(update)
}
fn name(&self) -> &str {
&self.display
}
}
/// Exclusive-or of two filters; built by `F ^ F`.
struct XorFilter {
    left: F,
    right: F,
    /// Label returned from `name`.
    display: String,
}

impl Filter for XorFilter {
    /// Matches when exactly one side matches and returns that side's result.
    ///
    /// Both filters are always evaluated, each exactly once. The result
    /// already computed for the matching side is returned as-is instead of
    /// running that filter a second time, so the returned value is the same
    /// one the match decision was based on.
    fn check_update(&self, update: &Update) -> FilterResult {
        let left = self.left.check_update(update);
        let right = self.right.check_update(update);
        match (left.is_match(), right.is_match()) {
            (true, false) => left,
            (false, true) => right,
            (true, true) | (false, false) => FilterResult::NoMatch,
        }
    }

    fn name(&self) -> &str {
        &self.display
    }
}
/// Negation of a filter; built by `!F`.
struct NotFilter {
    inner: F,
    /// Label returned from `name`.
    display: String,
}

impl Filter for NotFilter {
    /// Matches exactly when the inner filter does not.
    ///
    /// Any data attached by the inner filter is discarded: the result is
    /// always a plain `Match` or `NoMatch`.
    fn check_update(&self, update: &Update) -> FilterResult {
        match self.inner.check_update(update) {
            FilterResult::NoMatch => FilterResult::Match,
            _ => FilterResult::NoMatch,
        }
    }

    fn name(&self) -> &str {
        self.display.as_str()
    }
}
/// `a & b` builds a filter that requires both operands to match.
impl BitAnd for F {
    type Output = F;

    fn bitand(self, rhs: F) -> F {
        let conjunction = AndFilter {
            // The label is formatted first, while both operands are still
            // only borrowed; they are moved into the struct afterwards.
            display: format!("<{} and {}>", self.0.name(), rhs.0.name()),
            left: self,
            right: rhs,
        };
        F(Arc::new(conjunction))
    }
}
/// `a | b` builds a filter that matches when either operand matches.
impl BitOr for F {
    type Output = F;

    fn bitor(self, rhs: F) -> F {
        let disjunction = OrFilter {
            // The label is formatted first, while both operands are still
            // only borrowed; they are moved into the struct afterwards.
            display: format!("<{} or {}>", self.0.name(), rhs.0.name()),
            left: self,
            right: rhs,
        };
        F(Arc::new(disjunction))
    }
}
/// `a ^ b` builds a filter that matches when exactly one operand matches.
impl BitXor for F {
    type Output = F;

    fn bitxor(self, rhs: F) -> F {
        let exclusive = XorFilter {
            // The label is formatted first, while both operands are still
            // only borrowed; they are moved into the struct afterwards.
            display: format!("<{} xor {}>", self.0.name(), rhs.0.name()),
            left: self,
            right: rhs,
        };
        F(Arc::new(exclusive))
    }
}
/// `!a` builds a filter that matches when the operand does not.
impl Not for F {
    type Output = F;

    fn not(self) -> F {
        let negation = NotFilter {
            // The label is formatted first, while the operand is still only
            // borrowed; it is moved into the struct afterwards.
            display: format!("<not {}>", self.0.name()),
            inner: self,
        };
        F(Arc::new(negation))
    }
}
/// Adapter that turns a boolean closure into a [`Filter`].
pub struct FnFilter<Func> {
    /// The predicate to call for each update.
    func: Func,
    /// Static label returned from `name`.
    label: &'static str,
}

impl<Func> FnFilter<Func>
where
    Func: Fn(&Update) -> bool + Send + Sync + 'static,
{
    /// Creates a filter named `label` that delegates to `func`.
    pub fn new(label: &'static str, func: Func) -> Self {
        FnFilter { label, func }
    }
}
impl<Func> Filter for FnFilter<Func>
where
    Func: Fn(&Update) -> bool + Send + Sync + 'static,
{
    /// Maps the closure's `true` / `false` onto `Match` / `NoMatch`.
    fn check_update(&self, update: &Update) -> FilterResult {
        let accepted = (self.func)(update);
        if !accepted {
            return FilterResult::NoMatch;
        }
        FilterResult::Match
    }

    /// Returns the label supplied at construction time.
    fn name(&self) -> &str {
        self.label
    }
}
/// Filter that accepts every update that has an effective message.
pub struct All;

impl Filter for All {
    fn check_update(&self, update: &Update) -> FilterResult {
        if !has_effective_message(update) {
            return FilterResult::NoMatch;
        }
        FilterResult::Match
    }

    fn name(&self) -> &str {
        "filters.ALL"
    }
}

/// Constant [`All`] value.
pub const ALL: All = All;
// Generates a public unit struct plus its `Filter` impl from a struct name, a
// field of the effective message, and the string returned by `name()`.
//
// Two forms are accepted, each optionally preceded by attributes (for example
// doc comments) that are forwarded onto the generated struct:
//
// * `Name, field, "label"`       - matches when `field` is `Some`.
// * `bool: Name, field, "label"` - matches when the boolean `field` is `true`.
//
// In both forms an update without an effective message never matches.
macro_rules! message_presence_filter {
    // Presence form: the field is an `Option` and must be populated.
    (
        $(#[$meta:meta])*
        $struct_name:ident, $field:ident, $display:expr
    ) => {
        $(#[$meta])*
        pub struct $struct_name;

        impl Filter for $struct_name {
            fn check_update(&self, update: &Update) -> FilterResult {
                match update.effective_message() {
                    Some(message) if message.$field.is_some() => FilterResult::Match,
                    _ => FilterResult::NoMatch,
                }
            }

            fn name(&self) -> &str {
                $display
            }
        }
    };
    // Flag form: the field is a plain `bool` and must be `true`.
    (
        $(#[$meta:meta])*
        bool: $struct_name:ident, $field:ident, $display:expr
    ) => {
        $(#[$meta])*
        pub struct $struct_name;

        impl Filter for $struct_name {
            fn check_update(&self, update: &Update) -> FilterResult {
                match update.effective_message() {
                    Some(message) if message.$field => FilterResult::Match,
                    _ => FilterResult::NoMatch,
                }
            }

            fn name(&self) -> &str {
                $display
            }
        }
    };
}
message_presence_filter!(
/// Matches when the effective message has its `animation` field set.
AnimationFilter, animation, "filters.ANIMATION"
);
/// Constant [`AnimationFilter`] value.
pub const ANIMATION: AnimationFilter = AnimationFilter;
message_presence_filter!(
/// Matches when the effective message has its `audio` field set.
AudioFilter, audio, "filters.AUDIO"
);
/// Constant [`AudioFilter`] value.
pub const AUDIO: AudioFilter = AudioFilter;
message_presence_filter!(
/// Matches when the effective message has its `boost_added` field set.
BoostAdded, boost_added, "filters.BOOST_ADDED"
);
/// Constant [`BoostAdded`] value.
pub const BOOST_ADDED: BoostAdded = BoostAdded;
message_presence_filter!(
/// Matches when the effective message has its `checklist` field set.
ChecklistFilter, checklist, "filters.CHECKLIST"
);
/// Constant [`ChecklistFilter`] value.
pub const CHECKLIST: ChecklistFilter = ChecklistFilter;
message_presence_filter!(
/// Matches when the effective message has its `contact` field set.
ContactFilter, contact, "filters.CONTACT"
);
/// Constant [`ContactFilter`] value.
pub const CONTACT: ContactFilter = ContactFilter;
message_presence_filter!(
/// Matches when the effective message has its `effect_id` field set.
EffectId, effect_id, "filters.EFFECT_ID"
);
/// Constant [`EffectId`] value.
pub const EFFECT_ID: EffectId = EffectId;
message_presence_filter!(
/// Matches when the effective message has its `forward_origin` field set.
ForwardedPresence, forward_origin, "filters.FORWARDED"
);
/// Constant [`ForwardedPresence`] value.
pub const FORWARDED: ForwardedPresence = ForwardedPresence;
message_presence_filter!(
/// Matches when the effective message has its `game` field set.
GameFilter, game, "filters.GAME"
);
/// Constant [`GameFilter`] value.
pub const GAME: GameFilter = GameFilter;
message_presence_filter!(
/// Matches when the effective message has its `giveaway` field set.
GiveawayFilter, giveaway, "filters.GIVEAWAY"
);
/// Constant [`GiveawayFilter`] value.
pub const GIVEAWAY: GiveawayFilter = GiveawayFilter;
message_presence_filter!(
/// Matches when the effective message has its `giveaway_winners` field set.
GiveawayWinners, giveaway_winners, "filters.GIVEAWAY_WINNERS"
);
/// Constant [`GiveawayWinners`] value.
pub const GIVEAWAY_WINNERS: GiveawayWinners = GiveawayWinners;
message_presence_filter!(
/// Matches when the effective message has its `invoice` field set.
InvoiceFilter, invoice, "filters.INVOICE"
);
/// Constant [`InvoiceFilter`] value.
pub const INVOICE: InvoiceFilter = InvoiceFilter;
message_presence_filter!(
/// Matches when the effective message has its `location` field set.
LocationFilter, location, "filters.LOCATION"
);
/// Constant [`LocationFilter`] value.
pub const LOCATION: LocationFilter = LocationFilter;
message_presence_filter!(
/// Matches when the effective message has its `paid_media` field set.
PaidMediaFilter, paid_media, "filters.PAID_MEDIA"
);
/// Constant [`PaidMediaFilter`] value.
pub const PAID_MEDIA: PaidMediaFilter = PaidMediaFilter;
message_presence_filter!(
/// Matches when the effective message has its `passport_data` field set.
PassportDataFilter, passport_data, "filters.PASSPORT_DATA"
);
/// Constant [`PassportDataFilter`] value.
pub const PASSPORT_DATA: PassportDataFilter = PassportDataFilter;
message_presence_filter!(
/// Matches when the effective message has its `poll` field set.
PollFilter, poll, "filters.POLL"
);
/// Constant [`PollFilter`] value.
pub const POLL: PollFilter = PollFilter;
message_presence_filter!(
/// Matches when the effective message has its `reply_to_message` field set.
ReplyFilter, reply_to_message, "filters.REPLY"
);
/// Constant [`ReplyFilter`] value.
pub const REPLY: ReplyFilter = ReplyFilter;
message_presence_filter!(
/// Matches when the effective message has its `reply_to_story` field set.
ReplyToStory, reply_to_story, "filters.REPLY_TO_STORY"
);
/// Constant [`ReplyToStory`] value.
pub const REPLY_TO_STORY: ReplyToStory = ReplyToStory;
message_presence_filter!(
/// Matches when the effective message has its `story` field set.
StoryFilter, story, "filters.STORY"
);
/// Constant [`StoryFilter`] value.
pub const STORY: StoryFilter = StoryFilter;
message_presence_filter!(
/// Matches when the effective message has its `venue` field set.
VenueFilter, venue, "filters.VENUE"
);
/// Constant [`VenueFilter`] value.
pub const VENUE: VenueFilter = VenueFilter;
message_presence_filter!(
/// Matches when the effective message has its `video` field set.
VideoFilter, video, "filters.VIDEO"
);
/// Constant [`VideoFilter`] value.
pub const VIDEO: VideoFilter = VideoFilter;
message_presence_filter!(
/// Matches when the effective message has its `video_note` field set.
VideoNoteFilter, video_note, "filters.VIDEO_NOTE"
);
/// Constant [`VideoNoteFilter`] value.
pub const VIDEO_NOTE: VideoNoteFilter = VideoNoteFilter;
message_presence_filter!(
/// Matches when the effective message has its `voice` field set.
VoiceFilter, voice, "filters.VOICE"
);
/// Constant [`VoiceFilter`] value.
pub const VOICE: VoiceFilter = VoiceFilter;
message_presence_filter!(
/// Matches when the effective message has its `suggested_post_info` field set.
SuggestedPostInfo, suggested_post_info, "filters.SUGGESTED_POST_INFO"
);
/// Constant [`SuggestedPostInfo`] value.
pub const SUGGESTED_POST_INFO: SuggestedPostInfo = SuggestedPostInfo;
message_presence_filter!(
/// Matches when the effective message's `has_media_spoiler` flag is `true`.
bool: HasMediaSpoiler, has_media_spoiler, "filters.HAS_MEDIA_SPOILER"
);
/// Constant [`HasMediaSpoiler`] value.
pub const HAS_MEDIA_SPOILER: HasMediaSpoiler = HasMediaSpoiler;
message_presence_filter!(
/// Matches when the effective message's `has_protected_content` flag is `true`.
bool: HasProtectedContent, has_protected_content, "filters.HAS_PROTECTED_CONTENT"
);
/// Constant [`HasProtectedContent`] value.
pub const HAS_PROTECTED_CONTENT: HasProtectedContent = HasProtectedContent;
message_presence_filter!(
/// Matches when the effective message's `is_automatic_forward` flag is `true`.
bool: IsAutomaticForward, is_automatic_forward, "filters.IS_AUTOMATIC_FORWARD"
);
/// Constant [`IsAutomaticForward`] value.
pub const IS_AUTOMATIC_FORWARD: IsAutomaticForward = IsAutomaticForward;
message_presence_filter!(
/// Matches when the effective message's `is_topic_message` flag is `true`.
bool: IsTopicMessage, is_topic_message, "filters.IS_TOPIC_MESSAGE"
);
/// Constant [`IsTopicMessage`] value.
pub const IS_TOPIC_MESSAGE: IsTopicMessage = IsTopicMessage;
message_presence_filter!(
/// Matches when the effective message's `is_from_offline` flag is `true`.
bool: IsFromOffline, is_from_offline, "filters.IS_FROM_OFFLINE"
);
/// Constant [`IsFromOffline`] value.
pub const IS_FROM_OFFLINE: IsFromOffline = IsFromOffline;
/// Matches when the effective message has its `sender_boost_count` field set.
pub struct SenderBoostCount;

impl Filter for SenderBoostCount {
    fn check_update(&self, update: &Update) -> FilterResult {
        let boost_count = update
            .effective_message()
            .and_then(|msg| msg.sender_boost_count);
        match boost_count {
            Some(_) => FilterResult::Match,
            None => FilterResult::NoMatch,
        }
    }

    fn name(&self) -> &str {
        "filters.SENDER_BOOST_COUNT"
    }
}

/// Constant [`SenderBoostCount`] value.
pub const SENDER_BOOST_COUNT: SenderBoostCount = SenderBoostCount;
/// Matches when the effective message carries at least one of the attachment
/// fields checked below.
pub struct AttachmentFilter;

impl Filter for AttachmentFilter {
    fn check_update(&self, update: &Update) -> FilterResult {
        let has_attachment = match update.effective_message() {
            None => false,
            Some(msg) => {
                // An empty photo list does not count as an attachment.
                let has_photo = msg.photo.as_ref().map_or(false, |sizes| !sizes.is_empty());
                [
                    msg.animation.is_some(),
                    msg.audio.is_some(),
                    msg.contact.is_some(),
                    msg.dice.is_some(),
                    msg.document.is_some(),
                    msg.game.is_some(),
                    msg.invoice.is_some(),
                    msg.location.is_some(),
                    msg.paid_media.is_some(),
                    msg.passport_data.is_some(),
                    has_photo,
                    msg.poll.is_some(),
                    msg.sticker.is_some(),
                    msg.story.is_some(),
                    msg.venue.is_some(),
                    msg.video.is_some(),
                    msg.video_note.is_some(),
                    msg.voice.is_some(),
                ]
                .contains(&true)
            }
        };

        if has_attachment {
            FilterResult::Match
        } else {
            FilterResult::NoMatch
        }
    }

    fn name(&self) -> &str {
        "filters.ATTACHMENT"
    }
}

/// Constant [`AttachmentFilter`] value.
pub const ATTACHMENT: AttachmentFilter = AttachmentFilter;
/// Matches when the effective message's chat has `is_forum` set to `true`.
pub struct ForumFilter;

impl Filter for ForumFilter {
    fn check_update(&self, update: &Update) -> FilterResult {
        let is_forum = update.effective_message().and_then(|msg| msg.chat.is_forum);
        // A missing message or an unset flag is treated the same as `false`.
        match is_forum {
            Some(true) => FilterResult::Match,
            _ => FilterResult::NoMatch,
        }
    }

    fn name(&self) -> &str {
        "filters.FORUM"
    }
}

/// Constant [`ForumFilter`] value.
pub const FORUM: ForumFilter = ForumFilter;
/// Matches when the effective message's chat has `is_direct_messages` set to
/// `true`.
pub struct DirectMessages;

impl Filter for DirectMessages {
    fn check_update(&self, update: &Update) -> FilterResult {
        let flag = update
            .effective_message()
            .and_then(|msg| msg.chat.is_direct_messages);
        // A missing message or an unset flag is treated the same as `false`.
        if flag == Some(true) {
            FilterResult::Match
        } else {
            FilterResult::NoMatch
        }
    }

    fn name(&self) -> &str {
        "filters.DIRECT_MESSAGES"
    }
}

/// Constant [`DirectMessages`] value.
pub const DIRECT_MESSAGES: DirectMessages = DirectMessages;
/// Matches when the effective message has its `from_user` field set.
pub struct UserPresence;

impl Filter for UserPresence {
    fn check_update(&self, update: &Update) -> FilterResult {
        match update.effective_message() {
            Some(msg) if msg.from_user.is_some() => FilterResult::Match,
            _ => FilterResult::NoMatch,
        }
    }

    fn name(&self) -> &str {
        "filters.USER"
    }
}

/// Constant [`UserPresence`] value.
pub const USER: UserPresence = UserPresence;
/// Matches when the effective user has `added_to_attachment_menu` set to
/// `true`.
pub struct UserAttachmentMenu;

impl Filter for UserAttachmentMenu {
    fn check_update(&self, update: &Update) -> FilterResult {
        let added = update
            .effective_user()
            .and_then(|user| user.added_to_attachment_menu);
        // A missing user or an unset flag is treated the same as `false`.
        match added {
            Some(true) => FilterResult::Match,
            _ => FilterResult::NoMatch,
        }
    }

    fn name(&self) -> &str {
        "filters.USER_ATTACHMENT"
    }
}

/// Constant [`UserAttachmentMenu`] value.
pub const USER_ATTACHMENT: UserAttachmentMenu = UserAttachmentMenu;
/// Matches when the effective user has `is_premium` set to `true`.
pub struct PremiumUser;

impl Filter for PremiumUser {
    fn check_update(&self, update: &Update) -> FilterResult {
        // A missing user or an unset flag is treated the same as `false`.
        let is_premium = update
            .effective_user()
            .and_then(|user| user.is_premium)
            .unwrap_or(false);
        if !is_premium {
            return FilterResult::NoMatch;
        }
        FilterResult::Match
    }

    fn name(&self) -> &str {
        "filters.PREMIUM_USER"
    }
}

/// Constant [`PremiumUser`] value.
pub const PREMIUM_USER: PremiumUser = PremiumUser;
/// Matches when the effective message has its `sender_chat` field set.
pub struct SenderChatPresence;

impl Filter for SenderChatPresence {
    fn check_update(&self, update: &Update) -> FilterResult {
        match update.effective_message() {
            Some(msg) if msg.sender_chat.is_some() => FilterResult::Match,
            _ => FilterResult::NoMatch,
        }
    }

    fn name(&self) -> &str {
        "filters.SenderChat.ALL"
    }
}
/// Matches when the effective message has its `via_bot` field set.
pub struct ViaBotPresence;

impl Filter for ViaBotPresence {
    fn check_update(&self, update: &Update) -> FilterResult {
        let sent_via_bot = update
            .effective_message()
            .map_or(false, |msg| msg.via_bot.is_some());
        if sent_via_bot {
            FilterResult::Match
        } else {
            FilterResult::NoMatch
        }
    }

    fn name(&self) -> &str {
        "filters.VIA_BOT"
    }
}

/// Constant [`ViaBotPresence`] value.
pub const VIA_BOT: ViaBotPresence = ViaBotPresence;
// Unit tests for the filter primitives, combinators and helper functions
// defined in this file. Updates are built by deserializing JSON literals.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Builds an update holding a private-chat message with the given text.
fn text_update(text: &str) -> Update {
serde_json::from_value(json!({
"update_id": 1,
"message": {
"message_id": 1,
"date": 0,
"chat": {"id": 1, "type": "private"},
"text": text
}
}))
.unwrap()
}
// Builds an update that carries only an `update_id` and no payload.
fn empty_update() -> Update {
serde_json::from_value(json!({"update_id": 1})).unwrap()
}
// `ALL` accepts an update that contains a message.
#[test]
fn all_matches_message() {
assert!(ALL.check_update(&text_update("hello")).is_match());
}
// `ALL` rejects an update without any message.
#[test]
fn all_rejects_empty() {
assert!(!ALL.check_update(&empty_update()).is_match());
}
// AND of two matching filters matches.
#[test]
fn and_combinator() {
let f = F::new(All) & F::new(All);
assert!(f.check_update(&text_update("hello")).is_match());
}
// OR of two filters that both reject the update also rejects it.
#[test]
fn or_combinator() {
let f = F::new(All) | F::new(All);
assert!(!f.check_update(&empty_update()).is_match());
}
// NOT of a matching filter does not match.
#[test]
fn not_combinator() {
let f = !F::new(All);
assert!(!f.check_update(&text_update("hi")).is_match());
}
// XOR does not match when both sides match.
#[test]
fn xor_both_true_is_false() {
let f = F::new(All) ^ F::new(All);
assert!(!f.check_update(&text_update("hi")).is_match());
}
// A closure-backed filter is consulted even for an update with no message.
#[test]
fn fn_filter_works() {
let f = FnFilter::new("always_true", |_| true);
assert!(f.check_update(&empty_update()).is_match());
}
// A message with an animation matches `ANIMATION` but not `VIDEO`.
#[test]
fn presence_animation() {
let update: Update = serde_json::from_value(json!({
"update_id": 1,
"message": {
"message_id": 1, "date": 0,
"chat": {"id": 1, "type": "private"},
"animation": {"file_id": "a", "file_unique_id": "b", "width": 1, "height": 1, "duration": 1}
}
})).unwrap();
assert!(ANIMATION.check_update(&update).is_match());
assert!(!VIDEO.check_update(&update).is_match());
}
// A message with a document counts as having an attachment.
#[test]
fn attachment_computed() {
let update: Update = serde_json::from_value(json!({
"update_id": 1,
"message": {
"message_id": 1, "date": 0,
"chat": {"id": 1, "type": "private"},
"document": {"file_id": "d", "file_unique_id": "e"}
}
}))
.unwrap();
assert!(ATTACHMENT.check_update(&update).is_match());
}
// An `edited_message` is picked up as the effective message.
#[test]
fn effective_message_from_edited() {
let update: Update = serde_json::from_value(json!({
"update_id": 1,
"edited_message": {"message_id": 2, "chat": {"id": 1, "type": "private"}, "date": 0}
}))
.unwrap();
assert!(effective_message(&update).is_some());
}
// The sender of a callback query is reported as the effective user.
#[test]
fn effective_user_from_callback() {
let update: Update = serde_json::from_value(json!({
"update_id": 1,
"callback_query": {
"id": "1",
"from": {"id": 42, "is_bot": false, "first_name": "Test"},
"chat_instance": "ci"
}
}))
.unwrap();
let user = effective_user(&update).unwrap();
assert_eq!(user.get("id").unwrap().as_i64().unwrap(), 42);
}
// `FORUM` matches when the message's chat has `is_forum: true`.
#[test]
fn forum_filter() {
let update: Update = serde_json::from_value(json!({
"update_id": 1,
"message": {
"message_id": 1, "date": 0,
"chat": {"id": 1, "type": "supergroup", "is_forum": true},
"text": "hello"
}
}))
.unwrap();
assert!(FORUM.check_update(&update).is_match());
}
// `PREMIUM_USER` matches when the sender has `is_premium: true`.
#[test]
fn premium_user_filter() {
let update: Update = serde_json::from_value(json!({
"update_id": 1,
"message": {
"message_id": 1, "date": 0,
"chat": {"id": 1, "type": "private"},
"from": {"id": 1, "is_bot": false, "first_name": "A", "is_premium": true},
"text": "hi"
}
}))
.unwrap();
assert!(PREMIUM_USER.check_update(&update).is_match());
}
// Merging two data-carrying results appends the right-hand values after the
// left-hand ones under the shared key.
#[test]
fn filter_result_merge() {
let a = FilterResult::MatchWithData(HashMap::from([("x".into(), vec!["1".into()])]));
let b = FilterResult::MatchWithData(HashMap::from([("x".into(), vec!["2".into()])]));
let merged = a.merge(b);
if let FilterResult::MatchWithData(m) = merged {
assert_eq!(m.get("x").unwrap(), &vec!["1".to_owned(), "2".to_owned()]);
} else {
panic!("expected MatchWithData");
}
}
// Merging anything with `NoMatch` yields `NoMatch`.
#[test]
fn filter_result_merge_nomatch() {
let a = FilterResult::Match;
let b = FilterResult::NoMatch;
assert_eq!(a.merge(b), FilterResult::NoMatch);
}
// NOTE(review): despite the name, this only asserts that the AND of two
// matching closure filters matches; no data payload is inspected.
#[test]
fn and_combinator_merges_data() {
let f1 = FnFilter::new("f1", |_| true);
let f2 = FnFilter::new("f2", |_| true);
let combined = F::new(f1) & F::new(f2);
assert!(combined.check_update(&text_update("hi")).is_match());
}
// OR matches when the left side matches and the right side does not.
#[test]
fn or_returns_first_match() {
let f1 = FnFilter::new("f1", |_| true);
let f2 = FnFilter::new("f2", |_| false);
let combined = F::new(f1) | F::new(f2);
assert!(combined.check_update(&text_update("hi")).is_match());
}
// XOR matches when exactly one side matches.
#[test]
fn xor_one_true() {
let f1 = FnFilter::new("f1", |_| true);
let f2 = FnFilter::new("f2", |_| false);
let combined = F::new(f1) ^ F::new(f2);
assert!(combined.check_update(&text_update("hi")).is_match());
}
}