1use std::sync::Arc;
18use std::time::{Duration, Instant};
19
20use dashmap::DashMap;
21
22use crate::adapter::PairingChannelAdapter;
23use crate::store::PairingStore;
24use crate::types::{Decision, PairingError, PairingPolicy};
25
/// How long a cached admission decision is reused before the store is
/// consulted again, unless overridden via `PairingGate::with_cache_ttl`.
const DEFAULT_CACHE_TTL: Duration = Duration::from_millis(30_000);
27
28#[derive(Clone)]
29struct CacheEntry {
30 decision: Decision,
31 expires_at: Instant,
32}
33
34pub struct PairingGate {
35 store: Arc<PairingStore>,
36 cache: DashMap<String, CacheEntry>,
37 cache_ttl: Duration,
38}
39
40impl PairingGate {
41 pub fn new(store: Arc<PairingStore>) -> Self {
42 Self {
43 store,
44 cache: DashMap::new(),
45 cache_ttl: DEFAULT_CACHE_TTL,
46 }
47 }
48
49 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
50 self.cache_ttl = ttl;
51 self
52 }
53
54 pub fn flush_cache(&self) {
57 self.cache.clear();
58 }
59
60 pub async fn should_admit(
64 &self,
65 channel: &str,
66 account_id: &str,
67 sender_id: &str,
68 policy: &PairingPolicy,
69 adapter: Option<&dyn PairingChannelAdapter>,
70 ) -> Result<Decision, PairingError> {
71 if !policy.auto_challenge {
74 return Ok(Decision::Admit);
75 }
76
77 let normalized: String = adapter
84 .and_then(|a| a.normalize_sender(sender_id))
85 .unwrap_or_else(|| sender_id.to_string());
86 let sender_id = normalized.as_str();
87
88 let key = cache_key(channel, account_id, sender_id);
89 if let Some(entry) = self.cache.get(&key) {
90 if entry.expires_at > Instant::now() {
91 return Ok(entry.decision.clone());
92 }
93 }
94
95 let decision = if self
96 .store
97 .is_allowed(channel, account_id, sender_id)
98 .await?
99 {
100 Decision::Admit
101 } else {
102 match self
103 .store
104 .upsert_pending(channel, account_id, sender_id, serde_json::Value::Null)
105 .await
106 {
107 Ok(out) => Decision::Challenge { code: out.code },
108 Err(PairingError::MaxPending { .. }) => Decision::Drop,
109 Err(e) => return Err(e),
110 }
111 };
112
113 self.cache.insert(
114 key,
115 CacheEntry {
116 decision: decision.clone(),
117 expires_at: Instant::now() + self.cache_ttl,
118 },
119 );
120 Ok(decision)
121 }
122}
123
/// Builds the cache key for one `(channel, account_id, sender_id)` triple.
///
/// The first two components are length-prefixed, so the key is unambiguous
/// even when a component itself contains `|` or `:`. With a plain
/// `"{channel}|{account_id}|{sender_id}"` join, `("wa", "p|x", "y")` and
/// `("wa", "p", "x|y")` would share a key, and therefore a cached decision.
fn cache_key(channel: &str, account_id: &str, sender_id: &str) -> String {
    format!(
        "{}:{channel}|{}:{account_id}|{sender_id}",
        channel.len(),
        account_id.len()
    )
}
127
#[cfg(test)]
mod tests {
    use super::*;

    /// Policy with automatic challenges enabled.
    fn allow() -> PairingPolicy {
        PairingPolicy {
            auto_challenge: true,
        }
    }

    /// Policy with automatic challenges disabled (everyone is admitted).
    fn off() -> PairingPolicy {
        PairingPolicy {
            auto_challenge: false,
        }
    }

    #[tokio::test]
    async fn gate_admits_when_policy_off() {
        let store = Arc::new(PairingStore::open_memory().await.unwrap());
        let gate = PairingGate::new(store);
        let d = gate
            .should_admit("wa", "p", "+57", &off(), None)
            .await
            .unwrap();
        assert!(matches!(d, Decision::Admit));
    }

    #[tokio::test]
    async fn first_unknown_sender_gets_challenge_with_code() {
        let store = Arc::new(PairingStore::open_memory().await.unwrap());
        let gate = PairingGate::new(store);
        let d = gate
            .should_admit("wa", "p", "+57", &allow(), None)
            .await
            .unwrap();
        match d {
            Decision::Challenge { code } => assert_eq!(code.len(), crate::code::LENGTH),
            other => panic!("expected challenge, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn approved_sender_admits_after_cache_flush() {
        let store = Arc::new(PairingStore::open_memory().await.unwrap());
        let gate = PairingGate::new(Arc::clone(&store));
        let d1 = gate
            .should_admit("wa", "p", "+57", &allow(), None)
            .await
            .unwrap();
        let code = match d1 {
            Decision::Challenge { code } => code,
            other => panic!("{other:?}"),
        };
        store.approve(&code).await.unwrap();
        gate.flush_cache();
        let d2 = gate
            .should_admit("wa", "p", "+57", &allow(), None)
            .await
            .unwrap();
        assert_eq!(d2, Decision::Admit);
    }

    #[tokio::test]
    async fn cache_returns_same_decision_within_ttl() {
        let store = Arc::new(PairingStore::open_memory().await.unwrap());
        let gate = PairingGate::new(Arc::clone(&store));
        let d1 = gate
            .should_admit("wa", "p", "+57", &allow(), None)
            .await
            .unwrap();
        let code = match &d1 {
            Decision::Challenge { code } => code.clone(),
            other => panic!("expected challenge, got {other:?}"),
        };
        // Approving in the store changes what an uncached lookup would return.
        // The only way to still see the challenge is a cache hit.
        store.approve(&code).await.unwrap();
        let d2 = gate
            .should_admit("wa", "p", "+57", &allow(), None)
            .await
            .unwrap();
        assert_eq!(d1, d2);
    }

    #[tokio::test]
    async fn zero_ttl_disables_caching() {
        let store = Arc::new(PairingStore::open_memory().await.unwrap());
        let gate = PairingGate::new(Arc::clone(&store)).with_cache_ttl(Duration::ZERO);
        let d1 = gate
            .should_admit("wa", "p", "+57", &allow(), None)
            .await
            .unwrap();
        let code = match d1 {
            Decision::Challenge { code } => code,
            other => panic!("expected challenge, got {other:?}"),
        };
        store.approve(&code).await.unwrap();
        // No flush: with a zero TTL the previous entry is already stale.
        let d2 = gate
            .should_admit("wa", "p", "+57", &allow(), None)
            .await
            .unwrap();
        assert_eq!(d2, Decision::Admit);
    }

    #[tokio::test]
    async fn fourth_unknown_sender_drops_due_to_max_pending() {
        let store = Arc::new(PairingStore::open_memory().await.unwrap());
        let gate = PairingGate::new(store);
        for i in 1..=3 {
            let s = format!("+5710000000{i}");
            let d = gate
                .should_admit("wa", "p", &s, &allow(), None)
                .await
                .unwrap();
            assert!(matches!(d, Decision::Challenge { .. }));
        }
        let d4 = gate
            .should_admit("wa", "p", "+571000000099", &allow(), None)
            .await
            .unwrap();
        assert_eq!(d4, Decision::Drop);
    }
}