use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use crate::error::{BitcoinError, Result};
/// Length of the window over which a transaction limit is evaluated.
///
/// The concrete length of each variant is defined by `LimitPeriod::duration`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum LimitPeriod {
/// A one-hour window.
Hourly,
/// A one-day window.
Daily,
/// A one-week window.
Weekly,
/// A fixed 30-day window (not a calendar month).
Monthly,
}
impl LimitPeriod {
    /// Returns the length of one window of this period.
    ///
    /// `Monthly` is a fixed 30 days rather than a calendar month.
    pub fn duration(&self) -> Duration {
        match *self {
            Self::Hourly => Duration::hours(1),
            Self::Daily => Duration::days(1),
            Self::Weekly => Duration::weeks(1),
            Self::Monthly => Duration::days(30),
        }
    }

    /// Returns the instant at which the window ending at `now` begins.
    ///
    /// The window is rolling: it is always `now` minus [`Self::duration`],
    /// not aligned to clock or calendar boundaries.
    pub fn period_start(&self, now: DateTime<Utc>) -> DateTime<Utc> {
        let window = self.duration();
        now - window
    }
}
/// A cap on both the total amount and the number of transactions within one
/// `LimitPeriod` window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransactionLimit {
/// Maximum cumulative amount, in sats, allowed within `period`.
pub max_amount_sats: u64,
/// Maximum number of transactions allowed within `period`.
pub max_count: u32,
/// Window over which the two caps above are measured.
pub period: LimitPeriod,
}
/// Complete set of limits applied by `LimitEnforcer`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LimitConfig {
/// Limits evaluated against a single user's recorded usage.
pub user_limits: Vec<TransactionLimit>,
/// Limits evaluated against the recorded usage of all users combined.
pub platform_limits: Vec<TransactionLimit>,
/// Largest amount, in sats, allowed for any one transaction.
pub single_tx_max_sats: u64,
}
impl Default for LimitConfig {
fn default() -> Self {
Self {
user_limits: vec![
TransactionLimit {
max_amount_sats: 10_000_000, max_count: 10,
period: LimitPeriod::Daily,
},
TransactionLimit {
max_amount_sats: 50_000_000, max_count: 100,
period: LimitPeriod::Monthly,
},
],
platform_limits: vec![
TransactionLimit {
max_amount_sats: 100_000_000, max_count: 100,
period: LimitPeriod::Hourly,
},
TransactionLimit {
max_amount_sats: 1_000_000_000, max_count: 1000,
period: LimitPeriod::Daily,
},
],
single_tx_max_sats: 50_000_000, }
}
}
/// One recorded transaction, used to compute usage within a limit window.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageRecord {
/// Identifier of the user the transaction is attributed to.
pub user_id: String,
/// Transaction amount in sats.
pub amount_sats: u64,
/// When the record was created (UTC).
pub timestamp: DateTime<Utc>,
/// Optional transaction id supplied by the caller when recording.
pub txid: Option<String>,
}
/// Structured description of a limit that was (or would be) exceeded.
///
/// NOTE(review): nothing in the visible source constructs this type; the
/// field descriptions below are inferred from the field names — confirm
/// against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LimitViolation {
/// Label for the kind of limit violated (format not defined here).
pub limit_type: String,
/// Window of the violated limit.
pub period: LimitPeriod,
/// Usage in sats at the time of the violation.
pub current_usage_sats: u64,
/// Amount cap, in sats, of the violated limit.
pub limit_sats: u64,
/// Transaction count at the time of the violation.
pub current_count: u32,
/// Count cap of the violated limit.
pub max_count: u32,
}
/// Checks transactions against a `LimitConfig` using an in-memory usage
/// history.
pub struct LimitEnforcer {
// Limits applied by the checks.
config: LimitConfig,
// Recorded transactions keyed by user id.
user_usage: Arc<RwLock<HashMap<String, Vec<UsageRecord>>>>,
// Recorded transactions of all users, in insertion order.
platform_usage: Arc<RwLock<Vec<UsageRecord>>>,
}
impl LimitEnforcer {
pub fn new(config: LimitConfig) -> Self {
Self {
config,
user_usage: Arc::new(RwLock::new(HashMap::new())),
platform_usage: Arc::new(RwLock::new(Vec::new())),
}
}
pub async fn check_transaction(&self, user_id: &str, amount_sats: u64) -> Result<()> {
if amount_sats > self.config.single_tx_max_sats {
return Err(BitcoinError::LimitExceeded(format!(
"Transaction amount {} sats exceeds single transaction limit of {} sats",
amount_sats, self.config.single_tx_max_sats
)));
}
self.check_user_limits(user_id, amount_sats).await?;
self.check_platform_limits(amount_sats).await?;
Ok(())
}
pub async fn record_transaction(&self, user_id: &str, amount_sats: u64, txid: Option<String>) {
let record = UsageRecord {
user_id: user_id.to_string(),
amount_sats,
timestamp: Utc::now(),
txid,
};
let mut user_usage = self.user_usage.write().await;
user_usage
.entry(user_id.to_string())
.or_insert_with(Vec::new)
.push(record.clone());
let mut platform_usage = self.platform_usage.write().await;
platform_usage.push(record);
drop(user_usage);
drop(platform_usage);
self.cleanup_old_records().await;
}
pub async fn get_user_usage(&self, user_id: &str, period: LimitPeriod) -> (u64, u32) {
let user_usage = self.user_usage.read().await;
let records = user_usage.get(user_id);
if records.is_none() {
return (0, 0);
}
let now = Utc::now();
let period_start = period.period_start(now);
let filtered: Vec<_> = records
.unwrap()
.iter()
.filter(|r| r.timestamp >= period_start)
.collect();
let total_amount: u64 = filtered.iter().map(|r| r.amount_sats).sum();
let count = filtered.len() as u32;
(total_amount, count)
}
pub async fn get_platform_usage(&self, period: LimitPeriod) -> (u64, u32) {
let platform_usage = self.platform_usage.read().await;
let now = Utc::now();
let period_start = period.period_start(now);
let filtered: Vec<_> = platform_usage
.iter()
.filter(|r| r.timestamp >= period_start)
.collect();
let total_amount: u64 = filtered.iter().map(|r| r.amount_sats).sum();
let count = filtered.len() as u32;
(total_amount, count)
}
async fn check_user_limits(&self, user_id: &str, amount_sats: u64) -> Result<()> {
for limit in &self.config.user_limits {
let (current_usage, current_count) = self.get_user_usage(user_id, limit.period).await;
if current_usage + amount_sats > limit.max_amount_sats {
return Err(BitcoinError::LimitExceeded(format!(
"User {:?} limit exceeded for amount: {} + {} > {} sats",
limit.period, current_usage, amount_sats, limit.max_amount_sats
)));
}
if current_count + 1 > limit.max_count {
return Err(BitcoinError::LimitExceeded(format!(
"User {:?} limit exceeded for count: {} + 1 > {}",
limit.period, current_count, limit.max_count
)));
}
}
Ok(())
}
async fn check_platform_limits(&self, amount_sats: u64) -> Result<()> {
for limit in &self.config.platform_limits {
let (current_usage, current_count) = self.get_platform_usage(limit.period).await;
if current_usage + amount_sats > limit.max_amount_sats {
return Err(BitcoinError::LimitExceeded(format!(
"Platform {:?} limit exceeded for amount: {} + {} > {} sats",
limit.period, current_usage, amount_sats, limit.max_amount_sats
)));
}
if current_count + 1 > limit.max_count {
return Err(BitcoinError::LimitExceeded(format!(
"Platform {:?} limit exceeded for count: {} + 1 > {}",
limit.period, current_count, limit.max_count
)));
}
}
Ok(())
}
async fn cleanup_old_records(&self) {
let now = Utc::now();
let max_period = Duration::days(30); let cutoff = now - max_period;
let mut user_usage = self.user_usage.write().await;
for records in user_usage.values_mut() {
records.retain(|r| r.timestamp >= cutoff);
}
let mut platform_usage = self.platform_usage.write().await;
platform_usage.retain(|r| r.timestamp >= cutoff);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every period maps to its documented fixed duration.
    #[test]
    fn test_limit_period_duration() {
        let expected = [
            (LimitPeriod::Hourly, Duration::hours(1)),
            (LimitPeriod::Daily, Duration::days(1)),
            (LimitPeriod::Weekly, Duration::weeks(1)),
            (LimitPeriod::Monthly, Duration::days(30)),
        ];
        for (period, want) in expected.iter() {
            assert_eq!(period.duration(), *want);
        }
    }

    /// The default configuration carries two limits per scope and the
    /// expected single-transaction cap.
    #[test]
    fn test_limit_config_defaults() {
        let LimitConfig {
            user_limits,
            platform_limits,
            single_tx_max_sats,
        } = LimitConfig::default();
        assert_eq!(user_limits.len(), 2);
        assert_eq!(platform_limits.len(), 2);
        assert_eq!(single_tx_max_sats, 50_000_000);
    }

    /// Amounts above the single-transaction cap are rejected; small ones pass.
    #[tokio::test]
    async fn test_single_tx_limit() {
        let enforcer = LimitEnforcer::new(LimitConfig::default());

        let over_cap = enforcer.check_transaction("user1", 100_000_000).await;
        assert!(over_cap.is_err());

        let within_cap = enforcer.check_transaction("user1", 1_000_000).await;
        assert!(within_cap.is_ok());
    }

    /// Recorded transactions are reflected in the user's daily usage.
    #[tokio::test]
    async fn test_usage_tracking() {
        let enforcer = LimitEnforcer::new(LimitConfig::default());
        for amount in [1_000_000u64, 2_000_000].iter() {
            enforcer.record_transaction("user1", *amount, None).await;
        }

        let (usage, count) = enforcer.get_user_usage("user1", LimitPeriod::Daily).await;
        assert_eq!(usage, 3_000_000);
        assert_eq!(count, 2);
    }
}