1use crate::ChannelMessage;
2use std::collections::HashMap;
3
/// How the bot decides which messages in a group channel it responds to.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GroupReplyMode {
    /// Respond to every message in the channel.
    AllMessages,
    /// Respond only when the sender is explicitly allowed or the bot's name
    /// appears in the message.
    MentionOnly,
}

impl GroupReplyMode {
    /// Parses a mode name from configuration text. This never fails.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. Inner
    /// `-` and whitespace are treated like `_`. As a result, `mention_only`,
    /// `mention-only`, `Mention Only`, `MentionOnly` and `mention` all select
    /// [`GroupReplyMode::MentionOnly`]. Any other input, including an empty
    /// string, falls back to [`GroupReplyMode::AllMessages`].
    pub fn parse(s: &str) -> Self {
        // Normalise separators so common spellings such as "mention-only" do
        // not silently fall through to the permissive `AllMessages` default.
        let normalized: String = s
            .trim()
            .to_lowercase()
            .chars()
            .map(|c| if c == '-' || c.is_whitespace() { '_' } else { c })
            .collect();

        match normalized.as_str() {
            "mention_only" | "mentiononly" | "mention" => Self::MentionOnly,
            _ => Self::AllMessages,
        }
    }
}
21
22#[derive(Debug, Clone)]
24pub struct GroupReplyPolicy {
25 pub mode: GroupReplyMode,
26 pub allowed_sender_ids: Vec<String>,
27 pub bot_name: Option<String>,
28}
29
30impl Default for GroupReplyPolicy {
31 fn default() -> Self {
32 Self {
33 mode: GroupReplyMode::AllMessages,
34 allowed_sender_ids: Vec::new(),
35 bot_name: None,
36 }
37 }
38}
39
40#[derive(Debug, Clone, Default)]
42pub struct GroupReplyFilter {
43 policies: HashMap<String, GroupReplyPolicy>,
44}
45
46impl GroupReplyFilter {
47 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn with_policy(mut self, channel: impl Into<String>, policy: GroupReplyPolicy) -> Self {
52 self.policies.insert(channel.into(), policy);
53 self
54 }
55
56 pub fn set_policy(&mut self, channel: impl Into<String>, policy: GroupReplyPolicy) {
57 self.policies.insert(channel.into(), policy);
58 }
59
60 pub fn should_process(&self, msg: &ChannelMessage) -> bool {
63 let policy = match self.policies.get(&msg.channel) {
64 Some(p) => p,
65 None => return true, };
67
68 match policy.mode {
69 GroupReplyMode::AllMessages => true,
70 GroupReplyMode::MentionOnly => {
71 if !policy.allowed_sender_ids.is_empty()
73 && policy
74 .allowed_sender_ids
75 .iter()
76 .any(|id| id == "*" || id.eq_ignore_ascii_case(&msg.sender))
77 {
78 return true;
79 }
80
81 if let Some(bot_name) = &policy.bot_name {
83 let content_lower = msg.content.to_lowercase();
84 let bot_lower = bot_name.to_lowercase();
85
86 content_lower.contains(&format!("@{bot_lower}"))
88 || content_lower.contains(&bot_lower)
89 } else {
90 false
92 }
93 }
94 }
95 }
96}
97
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal message whose reply target is the sender itself.
    fn test_msg(channel: &str, sender: &str, content: &str) -> ChannelMessage {
        ChannelMessage {
            id: "1".into(),
            sender: sender.into(),
            reply_target: sender.into(),
            content: content.into(),
            channel: channel.into(),
            timestamp: 0,
            thread_ts: None,
            privacy_boundary: String::new(),
        }
    }

    /// Builds a filter with one `MentionOnly` policy on the "telegram" channel.
    fn telegram_mention_only(bot_name: Option<&str>, allowed: &[&str]) -> GroupReplyFilter {
        let policy = GroupReplyPolicy {
            mode: GroupReplyMode::MentionOnly,
            bot_name: bot_name.map(String::from),
            allowed_sender_ids: allowed.iter().map(|&id| id.to_owned()).collect(),
        };
        GroupReplyFilter::new().with_policy("telegram", policy)
    }

    #[test]
    fn no_policy_allows_all_messages() {
        let msg = test_msg("telegram", "alice", "hello");
        assert!(GroupReplyFilter::new().should_process(&msg));
    }

    #[test]
    fn all_messages_mode_allows_all() {
        let policy = GroupReplyPolicy {
            mode: GroupReplyMode::AllMessages,
            ..Default::default()
        };
        let filter = GroupReplyFilter::new().with_policy("telegram", policy);
        assert!(filter.should_process(&test_msg("telegram", "alice", "hello")));
    }

    #[test]
    fn mention_only_drops_non_mention() {
        let filter = telegram_mention_only(Some("MyBot"), &[]);
        assert!(!filter.should_process(&test_msg("telegram", "alice", "hello everyone")));
    }

    #[test]
    fn mention_only_allows_at_mention() {
        let filter = telegram_mention_only(Some("MyBot"), &[]);
        assert!(filter.should_process(&test_msg("telegram", "alice", "hey @mybot help me")));
    }

    #[test]
    fn mention_only_allows_name_mention() {
        let filter = telegram_mention_only(Some("MyBot"), &[]);
        assert!(filter.should_process(&test_msg("telegram", "alice", "MyBot can you help?")));
    }

    #[test]
    fn mention_only_allows_allowed_sender() {
        let filter = telegram_mention_only(Some("MyBot"), &["admin"]);
        // The allow-listed sender gets through without mentioning the bot;
        // anyone else does not.
        assert!(filter.should_process(&test_msg("telegram", "admin", "do something")));
        assert!(!filter.should_process(&test_msg("telegram", "alice", "do something")));
    }

    #[test]
    fn mention_only_wildcard_sender_allows_all() {
        let filter = telegram_mention_only(None, &["*"]);
        assert!(filter.should_process(&test_msg("telegram", "anyone", "anything")));
    }

    #[test]
    fn mention_only_no_bot_name_no_allowed_senders_drops() {
        let filter = telegram_mention_only(None, &[]);
        assert!(!filter.should_process(&test_msg("telegram", "alice", "hello")));
    }

    #[test]
    fn different_channel_not_affected() {
        let filter = telegram_mention_only(Some("MyBot"), &[]);
        // The policy is scoped to "telegram", so "discord" stays unfiltered.
        assert!(filter.should_process(&test_msg("discord", "alice", "hello")));
    }

    #[test]
    fn group_reply_mode_from_str() {
        let cases = [
            ("mention_only", GroupReplyMode::MentionOnly),
            ("MentionOnly", GroupReplyMode::MentionOnly),
            ("mention", GroupReplyMode::MentionOnly),
            ("all_messages", GroupReplyMode::AllMessages),
            ("anything_else", GroupReplyMode::AllMessages),
        ];
        for (input, expected) in &cases {
            assert_eq!(&GroupReplyMode::parse(input), expected, "input: {input:?}");
        }
    }
}