1use chrono::{DateTime, Duration, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12use crate::error::{BitcoinError, Result};
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16pub enum LimitPeriod {
17 Hourly,
19 Daily,
21 Weekly,
23 Monthly,
25}
26
27impl LimitPeriod {
28 pub fn duration(&self) -> Duration {
30 match self {
31 LimitPeriod::Hourly => Duration::hours(1),
32 LimitPeriod::Daily => Duration::days(1),
33 LimitPeriod::Weekly => Duration::weeks(1),
34 LimitPeriod::Monthly => Duration::days(30),
35 }
36 }
37
38 pub fn period_start(&self, now: DateTime<Utc>) -> DateTime<Utc> {
40 match self {
41 LimitPeriod::Hourly => now - Duration::hours(1),
42 LimitPeriod::Daily => now - Duration::days(1),
43 LimitPeriod::Weekly => now - Duration::weeks(1),
44 LimitPeriod::Monthly => now - Duration::days(30),
45 }
46 }
47}
48
49#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct TransactionLimit {
52 pub max_amount_sats: u64,
54 pub max_count: u32,
56 pub period: LimitPeriod,
58}
59
60#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct LimitConfig {
63 pub user_limits: Vec<TransactionLimit>,
65 pub platform_limits: Vec<TransactionLimit>,
67 pub single_tx_max_sats: u64,
69}
70
71impl Default for LimitConfig {
72 fn default() -> Self {
73 Self {
74 user_limits: vec![
75 TransactionLimit {
76 max_amount_sats: 10_000_000, max_count: 10,
78 period: LimitPeriod::Daily,
79 },
80 TransactionLimit {
81 max_amount_sats: 50_000_000, max_count: 100,
83 period: LimitPeriod::Monthly,
84 },
85 ],
86 platform_limits: vec![
87 TransactionLimit {
88 max_amount_sats: 100_000_000, max_count: 100,
90 period: LimitPeriod::Hourly,
91 },
92 TransactionLimit {
93 max_amount_sats: 1_000_000_000, max_count: 1000,
95 period: LimitPeriod::Daily,
96 },
97 ],
98 single_tx_max_sats: 50_000_000, }
100 }
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct UsageRecord {
106 pub user_id: String,
108 pub amount_sats: u64,
110 pub timestamp: DateTime<Utc>,
112 pub txid: Option<String>,
114}
115
116#[derive(Debug, Clone, Serialize, Deserialize)]
118pub struct LimitViolation {
119 pub limit_type: String,
121 pub period: LimitPeriod,
123 pub current_usage_sats: u64,
125 pub limit_sats: u64,
127 pub current_count: u32,
129 pub max_count: u32,
131}
132
133pub struct LimitEnforcer {
135 config: LimitConfig,
136 user_usage: Arc<RwLock<HashMap<String, Vec<UsageRecord>>>>,
138 platform_usage: Arc<RwLock<Vec<UsageRecord>>>,
140}
141
142impl LimitEnforcer {
143 pub fn new(config: LimitConfig) -> Self {
145 Self {
146 config,
147 user_usage: Arc::new(RwLock::new(HashMap::new())),
148 platform_usage: Arc::new(RwLock::new(Vec::new())),
149 }
150 }
151
152 pub async fn check_transaction(&self, user_id: &str, amount_sats: u64) -> Result<()> {
154 if amount_sats > self.config.single_tx_max_sats {
156 return Err(BitcoinError::LimitExceeded(format!(
157 "Transaction amount {} sats exceeds single transaction limit of {} sats",
158 amount_sats, self.config.single_tx_max_sats
159 )));
160 }
161
162 self.check_user_limits(user_id, amount_sats).await?;
164
165 self.check_platform_limits(amount_sats).await?;
167
168 Ok(())
169 }
170
171 pub async fn record_transaction(&self, user_id: &str, amount_sats: u64, txid: Option<String>) {
173 let record = UsageRecord {
174 user_id: user_id.to_string(),
175 amount_sats,
176 timestamp: Utc::now(),
177 txid,
178 };
179
180 let mut user_usage = self.user_usage.write().await;
182 user_usage
183 .entry(user_id.to_string())
184 .or_insert_with(Vec::new)
185 .push(record.clone());
186
187 let mut platform_usage = self.platform_usage.write().await;
189 platform_usage.push(record);
190
191 drop(user_usage);
193 drop(platform_usage);
194 self.cleanup_old_records().await;
195 }
196
197 pub async fn get_user_usage(&self, user_id: &str, period: LimitPeriod) -> (u64, u32) {
199 let user_usage = self.user_usage.read().await;
200 let records = user_usage.get(user_id);
201
202 if records.is_none() {
203 return (0, 0);
204 }
205
206 let now = Utc::now();
207 let period_start = period.period_start(now);
208
209 let filtered: Vec<_> = records
210 .unwrap()
211 .iter()
212 .filter(|r| r.timestamp >= period_start)
213 .collect();
214
215 let total_amount: u64 = filtered.iter().map(|r| r.amount_sats).sum();
216 let count = filtered.len() as u32;
217
218 (total_amount, count)
219 }
220
221 pub async fn get_platform_usage(&self, period: LimitPeriod) -> (u64, u32) {
223 let platform_usage = self.platform_usage.read().await;
224
225 let now = Utc::now();
226 let period_start = period.period_start(now);
227
228 let filtered: Vec<_> = platform_usage
229 .iter()
230 .filter(|r| r.timestamp >= period_start)
231 .collect();
232
233 let total_amount: u64 = filtered.iter().map(|r| r.amount_sats).sum();
234 let count = filtered.len() as u32;
235
236 (total_amount, count)
237 }
238
239 async fn check_user_limits(&self, user_id: &str, amount_sats: u64) -> Result<()> {
241 for limit in &self.config.user_limits {
242 let (current_usage, current_count) = self.get_user_usage(user_id, limit.period).await;
243
244 if current_usage + amount_sats > limit.max_amount_sats {
246 return Err(BitcoinError::LimitExceeded(format!(
247 "User {:?} limit exceeded for amount: {} + {} > {} sats",
248 limit.period, current_usage, amount_sats, limit.max_amount_sats
249 )));
250 }
251
252 if current_count + 1 > limit.max_count {
254 return Err(BitcoinError::LimitExceeded(format!(
255 "User {:?} limit exceeded for count: {} + 1 > {}",
256 limit.period, current_count, limit.max_count
257 )));
258 }
259 }
260
261 Ok(())
262 }
263
264 async fn check_platform_limits(&self, amount_sats: u64) -> Result<()> {
266 for limit in &self.config.platform_limits {
267 let (current_usage, current_count) = self.get_platform_usage(limit.period).await;
268
269 if current_usage + amount_sats > limit.max_amount_sats {
271 return Err(BitcoinError::LimitExceeded(format!(
272 "Platform {:?} limit exceeded for amount: {} + {} > {} sats",
273 limit.period, current_usage, amount_sats, limit.max_amount_sats
274 )));
275 }
276
277 if current_count + 1 > limit.max_count {
279 return Err(BitcoinError::LimitExceeded(format!(
280 "Platform {:?} limit exceeded for count: {} + 1 > {}",
281 limit.period, current_count, limit.max_count
282 )));
283 }
284 }
285
286 Ok(())
287 }
288
289 async fn cleanup_old_records(&self) {
291 let now = Utc::now();
292 let max_period = Duration::days(30); let cutoff = now - max_period;
294
295 let mut user_usage = self.user_usage.write().await;
297 for records in user_usage.values_mut() {
298 records.retain(|r| r.timestamp >= cutoff);
299 }
300
301 let mut platform_usage = self.platform_usage.write().await;
303 platform_usage.retain(|r| r.timestamp >= cutoff);
304 }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_limit_period_duration() {
        assert_eq!(LimitPeriod::Hourly.duration(), Duration::hours(1));
        assert_eq!(LimitPeriod::Daily.duration(), Duration::days(1));
        assert_eq!(LimitPeriod::Weekly.duration(), Duration::weeks(1));
        assert_eq!(LimitPeriod::Monthly.duration(), Duration::days(30));
    }

    #[test]
    fn test_period_start_is_rolling_window() {
        let now = Utc::now();
        let periods = [
            LimitPeriod::Hourly,
            LimitPeriod::Daily,
            LimitPeriod::Weekly,
            LimitPeriod::Monthly,
        ];
        for period in periods.iter() {
            assert_eq!(period.period_start(now), now - period.duration());
        }
    }

    #[test]
    fn test_limit_config_defaults() {
        let config = LimitConfig::default();
        assert_eq!(config.user_limits.len(), 2);
        assert_eq!(config.platform_limits.len(), 2);
        assert_eq!(config.single_tx_max_sats, 50_000_000);
    }

    #[tokio::test]
    async fn test_single_tx_limit() {
        let enforcer = LimitEnforcer::new(LimitConfig::default());

        // Above the 50_000_000 single-transaction cap.
        let result = enforcer.check_transaction("user1", 100_000_000).await;
        assert!(result.is_err());

        let result = enforcer.check_transaction("user1", 1_000_000).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_tracking() {
        let enforcer = LimitEnforcer::new(LimitConfig::default());

        enforcer.record_transaction("user1", 1_000_000, None).await;
        enforcer.record_transaction("user1", 2_000_000, None).await;

        let (usage, count) = enforcer.get_user_usage("user1", LimitPeriod::Daily).await;
        assert_eq!(usage, 3_000_000);
        assert_eq!(count, 2);

        // Users with no history report zero usage.
        assert_eq!(enforcer.get_user_usage("nobody", LimitPeriod::Daily).await, (0, 0));
    }

    #[tokio::test]
    async fn test_user_daily_amount_limit() {
        let enforcer = LimitEnforcer::new(LimitConfig::default());
        enforcer.record_transaction("user1", 9_000_000, None).await;

        // 9M + 1M == 10M daily cap: allowed, the cap is inclusive.
        assert!(enforcer.check_transaction("user1", 1_000_000).await.is_ok());
        // 9M + 2M > 10M daily cap: rejected.
        assert!(enforcer.check_transaction("user1", 2_000_000).await.is_err());
    }

    #[tokio::test]
    async fn test_user_daily_count_limit() {
        let enforcer = LimitEnforcer::new(LimitConfig::default());
        for _ in 0..10 {
            enforcer.record_transaction("user1", 1_000, None).await;
        }

        // The 11th transaction of the day exceeds the 10-per-day cap.
        assert!(enforcer.check_transaction("user1", 1_000).await.is_err());
        // Per-user counts are independent.
        assert!(enforcer.check_transaction("user2", 1_000).await.is_ok());
    }

    #[tokio::test]
    async fn test_platform_usage_aggregates_users() {
        let enforcer = LimitEnforcer::new(LimitConfig::default());
        enforcer.record_transaction("user1", 1_000_000, None).await;
        enforcer.record_transaction("user2", 2_000_000, Some("tx".to_string())).await;

        let (usage, count) = enforcer.get_platform_usage(LimitPeriod::Hourly).await;
        assert_eq!(usage, 3_000_000);
        assert_eq!(count, 2);
    }
}
353}