1use std::collections::HashMap;
6use std::time::{Duration, Instant};
7
8use crate::auto_reply::types::TriggerType;
9
/// Outcome of asking a cooldown tracker whether a user may trigger an auto-reply now.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CooldownCheckResult {
    /// No cooldown is in effect; the trigger may proceed.
    Allowed,
    /// The user triggered too recently.
    InCooldown {
        /// Time left until the cooldown elapses.
        remaining: Duration,
    },
}
18
19pub struct CooldownTracker {
21 last_trigger: HashMap<String, Instant>,
23 default_cooldown: Duration,
25 type_cooldowns: HashMap<TriggerType, Duration>,
27}
28
29impl CooldownTracker {
30 pub fn new(default_cooldown: Duration) -> Self {
32 Self {
33 last_trigger: HashMap::new(),
34 default_cooldown,
35 type_cooldowns: HashMap::new(),
36 }
37 }
38
39 pub fn check_cooldown(&self, user_id: &str, trigger_type: TriggerType) -> CooldownCheckResult {
41 let cooldown = self
42 .type_cooldowns
43 .get(&trigger_type)
44 .copied()
45 .unwrap_or(self.default_cooldown);
46
47 match self.last_trigger.get(user_id) {
48 Some(last) => {
49 let elapsed = last.elapsed();
50 if elapsed < cooldown {
51 CooldownCheckResult::InCooldown {
52 remaining: cooldown - elapsed,
53 }
54 } else {
55 CooldownCheckResult::Allowed
56 }
57 }
58 None => CooldownCheckResult::Allowed,
59 }
60 }
61
62 pub fn record_trigger(&mut self, user_id: &str) {
64 self.last_trigger
65 .insert(user_id.to_string(), Instant::now());
66 }
67
68 pub fn set_type_cooldown(&mut self, trigger_type: TriggerType, duration: Duration) {
70 self.type_cooldowns.insert(trigger_type, duration);
71 }
72
73 pub fn reset_cooldown(&mut self, user_id: &str) {
75 self.last_trigger.remove(user_id);
76 }
77
78 pub fn cleanup_expired(&mut self) {
80 let max_cooldown = self
81 .type_cooldowns
82 .values()
83 .max()
84 .copied()
85 .unwrap_or(self.default_cooldown);
86
87 self.last_trigger
88 .retain(|_, instant| instant.elapsed() < max_cooldown * 2);
89 }
90
91 pub fn default_cooldown(&self) -> Duration {
93 self.default_cooldown
94 }
95
96 pub fn get_type_cooldown(&self, trigger_type: TriggerType) -> Duration {
98 self.type_cooldowns
99 .get(&trigger_type)
100 .copied()
101 .unwrap_or(self.default_cooldown)
102 }
103
104 pub fn has_trigger_record(&self, user_id: &str) -> bool {
106 self.last_trigger.contains_key(user_id)
107 }
108}
109
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;
    use std::thread;

    /// Non-empty user ids of up to 20 ASCII letters, digits or underscores.
    fn arb_user_id() -> impl Strategy<Value = String> {
        // A regex `&str` is already a `Strategy<Value = String>`; no mapping is needed.
        "[a-zA-Z0-9_]{1,20}"
    }

    /// Cooldown lengths in milliseconds, drawn from `10..500`.
    fn arb_cooldown_ms() -> impl Strategy<Value = u64> {
        10u64..500
    }

    /// Any of the trigger types exercised by these tests.
    fn arb_trigger_type() -> impl Strategy<Value = TriggerType> {
        prop_oneof![
            Just(TriggerType::Mention),
            Just(TriggerType::Keyword),
            Just(TriggerType::DirectMessage),
            Just(TriggerType::Schedule),
            Just(TriggerType::Webhook),
        ]
    }

    proptest! {
        // Some properties below sleep, so a small case count keeps the suite fast.
        #![proptest_config(ProptestConfig::with_cases(20))]

        /// A user with no recorded trigger is allowed for any cooldown and trigger type.
        #[test]
        fn prop_new_user_always_allowed(
            user_id in arb_user_id(),
            cooldown_ms in arb_cooldown_ms(),
            trigger_type in arb_trigger_type()
        ) {
            let tracker = CooldownTracker::new(Duration::from_millis(cooldown_ms));

            let result = tracker.check_cooldown(&user_id, trigger_type);
            prop_assert!(
                matches!(result, CooldownCheckResult::Allowed),
                "New user should always be allowed, got {:?}",
                result
            );
        }

        /// Checking immediately after recording a trigger reports a cooldown.
        #[test]
        fn prop_after_trigger_user_in_cooldown(
            user_id in arb_user_id(),
            cooldown_ms in 100u64..1000,
            trigger_type in arb_trigger_type()
        ) {
            let mut tracker = CooldownTracker::new(Duration::from_millis(cooldown_ms));

            tracker.record_trigger(&user_id);

            let result = tracker.check_cooldown(&user_id, trigger_type);
            prop_assert!(
                matches!(result, CooldownCheckResult::InCooldown { .. }),
                "User should be in cooldown immediately after trigger, got {:?}",
                result
            );
        }

        /// Reported remaining time is strictly positive and never exceeds the cooldown.
        #[test]
        fn prop_remaining_time_bounded_by_cooldown(
            user_id in arb_user_id(),
            cooldown_ms in 100u64..1000,
            trigger_type in arb_trigger_type()
        ) {
            let cooldown = Duration::from_millis(cooldown_ms);
            let mut tracker = CooldownTracker::new(cooldown);

            tracker.record_trigger(&user_id);

            let result = tracker.check_cooldown(&user_id, trigger_type);
            if let CooldownCheckResult::InCooldown { remaining } = result {
                prop_assert!(
                    remaining > Duration::ZERO,
                    "Remaining time {:?} should be positive",
                    remaining
                );
                prop_assert!(
                    remaining <= cooldown,
                    "Remaining time {:?} should be <= cooldown {:?}",
                    remaining,
                    cooldown
                );
            }
        }

        /// Per-type overrides are applied independently after a single recorded trigger.
        #[test]
        fn prop_different_trigger_types_different_cooldowns(
            user_id in arb_user_id(),
            default_ms in 100u64..500,
            mention_ms in 10u64..50,
            keyword_ms in 200u64..500
        ) {
            // The ranges are disjoint, so the Mention cooldown is always shorter than Keyword's.
            let mut tracker = CooldownTracker::new(Duration::from_millis(default_ms));
            tracker.set_type_cooldown(TriggerType::Mention, Duration::from_millis(mention_ms));
            tracker.set_type_cooldown(TriggerType::Keyword, Duration::from_millis(keyword_ms));

            tracker.record_trigger(&user_id);

            // Outlast the Mention cooldown while staying well inside the Keyword one.
            thread::sleep(Duration::from_millis(mention_ms + 10));

            let mention_result = tracker.check_cooldown(&user_id, TriggerType::Mention);
            prop_assert!(
                matches!(mention_result, CooldownCheckResult::Allowed),
                "Mention should be allowed after its cooldown expires, got {:?}",
                mention_result
            );

            let keyword_result = tracker.check_cooldown(&user_id, TriggerType::Keyword);
            prop_assert!(
                matches!(keyword_result, CooldownCheckResult::InCooldown { .. }),
                "Keyword should still be in cooldown, got {:?}",
                keyword_result
            );
        }

        /// `reset_cooldown` clears an active cooldown immediately.
        #[test]
        fn prop_reset_cooldown_allows_immediate_retrigger(
            user_id in arb_user_id(),
            cooldown_ms in 100u64..1000,
            trigger_type in arb_trigger_type()
        ) {
            let mut tracker = CooldownTracker::new(Duration::from_millis(cooldown_ms));

            tracker.record_trigger(&user_id);

            let result_before = tracker.check_cooldown(&user_id, trigger_type);
            prop_assert!(
                matches!(result_before, CooldownCheckResult::InCooldown { .. }),
                "Should be in cooldown before reset"
            );

            tracker.reset_cooldown(&user_id);

            let result_after = tracker.check_cooldown(&user_id, trigger_type);
            prop_assert!(
                matches!(result_after, CooldownCheckResult::Allowed),
                "Should be allowed after reset, got {:?}",
                result_after
            );
        }

        /// One user's trigger does not put a different user into cooldown.
        #[test]
        fn prop_independent_user_cooldowns(
            user1 in arb_user_id(),
            user2 in arb_user_id(),
            cooldown_ms in 100u64..1000,
            trigger_type in arb_trigger_type()
        ) {
            prop_assume!(user1 != user2);

            let mut tracker = CooldownTracker::new(Duration::from_millis(cooldown_ms));

            tracker.record_trigger(&user1);

            let result1 = tracker.check_cooldown(&user1, trigger_type);
            prop_assert!(
                matches!(result1, CooldownCheckResult::InCooldown { .. }),
                "User1 should be in cooldown"
            );

            let result2 = tracker.check_cooldown(&user2, trigger_type);
            prop_assert!(
                matches!(result2, CooldownCheckResult::Allowed),
                "User2 should be allowed, got {:?}",
                result2
            );
        }

        /// After the cooldown has elapsed, the user is allowed again.
        #[test]
        fn prop_cooldown_expires_allows_trigger(
            user_id in arb_user_id(),
            trigger_type in arb_trigger_type()
        ) {
            let cooldown = Duration::from_millis(20);
            let mut tracker = CooldownTracker::new(cooldown);

            tracker.record_trigger(&user_id);

            let result_immediate = tracker.check_cooldown(&user_id, trigger_type);
            prop_assert!(
                matches!(result_immediate, CooldownCheckResult::InCooldown { .. }),
                "Should be in cooldown immediately"
            );

            thread::sleep(Duration::from_millis(30));

            let result_after = tracker.check_cooldown(&user_id, trigger_type);
            prop_assert!(
                matches!(result_after, CooldownCheckResult::Allowed),
                "Should be allowed after cooldown expires, got {:?}",
                result_after
            );
        }

        /// `get_type_cooldown` falls back to the default until an override is set.
        #[test]
        fn prop_get_type_cooldown_consistent(
            default_ms in arb_cooldown_ms(),
            type_ms in arb_cooldown_ms(),
            trigger_type in arb_trigger_type()
        ) {
            let default_cooldown = Duration::from_millis(default_ms);
            let type_cooldown = Duration::from_millis(type_ms);

            let mut tracker = CooldownTracker::new(default_cooldown);

            prop_assert_eq!(
                tracker.get_type_cooldown(trigger_type),
                default_cooldown,
                "Should return default cooldown when type not set"
            );

            tracker.set_type_cooldown(trigger_type, type_cooldown);
            prop_assert_eq!(
                tracker.get_type_cooldown(trigger_type),
                type_cooldown,
                "Should return set cooldown for type"
            );
        }

        /// `has_trigger_record` tracks record/reset transitions.
        #[test]
        fn prop_trigger_record_consistency(
            user_id in arb_user_id(),
            cooldown_ms in arb_cooldown_ms()
        ) {
            let mut tracker = CooldownTracker::new(Duration::from_millis(cooldown_ms));

            prop_assert!(
                !tracker.has_trigger_record(&user_id),
                "Should not have trigger record initially"
            );

            tracker.record_trigger(&user_id);
            prop_assert!(
                tracker.has_trigger_record(&user_id),
                "Should have trigger record after recording"
            );

            tracker.reset_cooldown(&user_id);
            prop_assert!(
                !tracker.has_trigger_record(&user_id),
                "Should not have trigger record after reset"
            );
        }
    }

    /// Trigger records are kept separately per user.
    #[test]
    fn test_track_last_trigger_time_per_user() {
        let mut tracker = CooldownTracker::new(Duration::from_secs(60));

        assert!(!tracker.has_trigger_record("user1"));
        assert!(!tracker.has_trigger_record("user2"));

        tracker.record_trigger("user1");
        assert!(tracker.has_trigger_record("user1"));
        assert!(!tracker.has_trigger_record("user2"));

        tracker.record_trigger("user2");
        assert!(tracker.has_trigger_record("user1"));
        assert!(tracker.has_trigger_record("user2"));
    }

    /// A check right after a trigger is rejected with close to the full cooldown remaining.
    #[test]
    fn test_reject_trigger_within_cooldown() {
        let mut tracker = CooldownTracker::new(Duration::from_secs(60));

        tracker.record_trigger("user1");

        let result = tracker.check_cooldown("user1", TriggerType::Mention);
        match result {
            CooldownCheckResult::InCooldown { remaining } => {
                assert!(remaining.as_secs() <= 60);
                assert!(remaining.as_secs() >= 59);
            }
            CooldownCheckResult::Allowed => {
                panic!("Should be in cooldown");
            }
        }
    }

    /// The constructor's default cooldown is reported back unchanged.
    #[test]
    fn test_configurable_cooldown_duration() {
        let tracker_short = CooldownTracker::new(Duration::from_secs(10));
        let tracker_long = CooldownTracker::new(Duration::from_secs(300));

        assert_eq!(tracker_short.default_cooldown(), Duration::from_secs(10));
        assert_eq!(tracker_long.default_cooldown(), Duration::from_secs(300));
    }

    /// Overrides apply to their own type; other types fall back to the default.
    #[test]
    fn test_per_trigger_type_cooldown() {
        let mut tracker = CooldownTracker::new(Duration::from_secs(60));

        tracker.set_type_cooldown(TriggerType::Mention, Duration::from_secs(30));
        tracker.set_type_cooldown(TriggerType::Keyword, Duration::from_secs(120));

        assert_eq!(
            tracker.get_type_cooldown(TriggerType::Mention),
            Duration::from_secs(30)
        );
        assert_eq!(
            tracker.get_type_cooldown(TriggerType::Keyword),
            Duration::from_secs(120)
        );
        assert_eq!(
            tracker.get_type_cooldown(TriggerType::DirectMessage),
            Duration::from_secs(60)
        );
    }

    /// Once the cooldown has passed, the same user is allowed again.
    #[test]
    fn test_allow_after_cooldown_expires() {
        let mut tracker = CooldownTracker::new(Duration::from_millis(50));

        tracker.record_trigger("user1");

        let result = tracker.check_cooldown("user1", TriggerType::Mention);
        assert!(matches!(result, CooldownCheckResult::InCooldown { .. }));

        thread::sleep(Duration::from_millis(60));

        let result = tracker.check_cooldown("user1", TriggerType::Mention);
        assert!(matches!(result, CooldownCheckResult::Allowed));
    }

    /// A rejection carries a positive remaining time no larger than the cooldown.
    #[test]
    fn test_remaining_cooldown_time_in_rejection() {
        let mut tracker = CooldownTracker::new(Duration::from_secs(60));

        tracker.record_trigger("user1");

        let result = tracker.check_cooldown("user1", TriggerType::Mention);
        match result {
            CooldownCheckResult::InCooldown { remaining } => {
                assert!(remaining > Duration::ZERO);
                assert!(remaining <= Duration::from_secs(60));
            }
            CooldownCheckResult::Allowed => {
                panic!("Should be in cooldown with remaining time");
            }
        }
    }

    /// A never-seen user is allowed.
    #[test]
    fn test_new_user_allowed() {
        let tracker = CooldownTracker::new(Duration::from_secs(60));

        let result = tracker.check_cooldown("new_user", TriggerType::Mention);
        assert!(matches!(result, CooldownCheckResult::Allowed));
    }

    /// Resetting removes the record and re-allows the user.
    #[test]
    fn test_reset_cooldown() {
        let mut tracker = CooldownTracker::new(Duration::from_secs(60));

        tracker.record_trigger("user1");
        assert!(tracker.has_trigger_record("user1"));

        tracker.reset_cooldown("user1");
        assert!(!tracker.has_trigger_record("user1"));

        let result = tracker.check_cooldown("user1", TriggerType::Mention);
        assert!(matches!(result, CooldownCheckResult::Allowed));
    }

    /// Records older than the retention window are pruned by `cleanup_expired`.
    #[test]
    fn test_cleanup_expired() {
        let mut tracker = CooldownTracker::new(Duration::from_millis(10));

        tracker.record_trigger("user1");
        tracker.record_trigger("user2");

        // Longer than twice the 10 ms cooldown.
        thread::sleep(Duration::from_millis(30));

        tracker.cleanup_expired();

        assert!(!tracker.has_trigger_record("user1"));
        assert!(!tracker.has_trigger_record("user2"));
    }

    /// One user's cooldown does not affect another user.
    #[test]
    fn test_independent_user_cooldowns() {
        let mut tracker = CooldownTracker::new(Duration::from_secs(60));

        tracker.record_trigger("user1");

        let result1 = tracker.check_cooldown("user1", TriggerType::Mention);
        assert!(matches!(result1, CooldownCheckResult::InCooldown { .. }));

        let result2 = tracker.check_cooldown("user2", TriggerType::Mention);
        assert!(matches!(result2, CooldownCheckResult::Allowed));
    }

    /// A short override expires while a longer one is still active.
    #[test]
    fn test_different_cooldown_per_type() {
        let mut tracker = CooldownTracker::new(Duration::from_millis(100));
        tracker.set_type_cooldown(TriggerType::Mention, Duration::from_millis(20));
        tracker.set_type_cooldown(TriggerType::Keyword, Duration::from_millis(200));

        tracker.record_trigger("user1");

        thread::sleep(Duration::from_millis(30));

        let result_mention = tracker.check_cooldown("user1", TriggerType::Mention);
        assert!(matches!(result_mention, CooldownCheckResult::Allowed));

        let result_keyword = tracker.check_cooldown("user1", TriggerType::Keyword);
        assert!(matches!(
            result_keyword,
            CooldownCheckResult::InCooldown { .. }
        ));
    }

    /// The derived `Debug` output names the variant and its field.
    #[test]
    fn test_cooldown_check_result_debug() {
        let allowed = CooldownCheckResult::Allowed;
        let in_cooldown = CooldownCheckResult::InCooldown {
            remaining: Duration::from_secs(30),
        };

        assert_eq!(format!("{:?}", allowed), "Allowed");
        let rendered = format!("{:?}", in_cooldown);
        assert!(
            rendered.contains("InCooldown"),
            "unexpected Debug output: {}",
            rendered
        );
        assert!(
            rendered.contains("remaining"),
            "unexpected Debug output: {}",
            rendered
        );
    }

    /// Cloning preserves the variant and its payload.
    #[test]
    fn test_cooldown_check_result_clone() {
        let original = CooldownCheckResult::InCooldown {
            remaining: Duration::from_secs(30),
        };
        let cloned = original.clone();

        match cloned {
            CooldownCheckResult::InCooldown { remaining } => {
                assert_eq!(remaining, Duration::from_secs(30));
            }
            _ => panic!("Clone should preserve variant"),
        }
    }
}