use super::super::DbPool;
use super::{RateLimit, DEFAULT_ACCOUNT_ID};
use crate::error::StorageError;
use chrono::{DateTime, Utc};
/// Reports whether `action_type` is currently under its rate limit for
/// `account_id`, without consuming a request.
///
/// Returns `Ok(true)` when no `rate_limits` row exists for the pair (the
/// action is unlimited) or when the current window still has room, and
/// `Ok(false)` when the window's quota is used up.
///
/// This is not purely read-only. Suppose the stored window has elapsed, or
/// its `period_start` cannot be parsed as a `DateTime<Utc>`. Then the counter
/// is reset to zero and the window restarts at the current time before the
/// function answers.
///
/// The check runs in its own transaction, separate from any later
/// [`increment_rate_limit_for`] call. Use
/// [`check_and_increment_rate_limit_for`] when the check and the consume
/// step must happen together.
///
/// # Errors
///
/// Returns [`StorageError::Connection`] if the transaction cannot be started
/// or committed, and [`StorageError::Query`] if a statement fails.
pub async fn check_rate_limit_for(
    pool: &DbPool,
    account_id: &str,
    action_type: &str,
) -> Result<bool, StorageError> {
    let mut tx = pool
        .begin()
        .await
        .map_err(|e| StorageError::Connection { source: e })?;

    let row = sqlx::query_as::<_, RateLimit>(
        "SELECT action_type, request_count, period_start, max_requests, period_seconds \
         FROM rate_limits WHERE account_id = ? AND action_type = ?",
    )
    .bind(account_id)
    .bind(action_type)
    .fetch_optional(&mut *tx)
    .await
    .map_err(|e| StorageError::Query { source: e })?;

    let limit = match row {
        Some(l) => l,
        None => {
            // No configured limit for this action: it is always allowed.
            tx.commit()
                .await
                .map_err(|e| StorageError::Connection { source: e })?;
            return Ok(true);
        }
    };

    let now = Utc::now();
    let window_expired = match limit.period_start.parse::<DateTime<Utc>>() {
        Ok(start) => now.signed_duration_since(start).num_seconds() >= limit.period_seconds,
        // Falling back to "started just now" would make the elapsed time always
        // zero, so the window would never expire and an exhausted quota would
        // never recover. Treat it as expired instead. The reset below then
        // rewrites `period_start` in the `%Y-%m-%dT%H:%M:%SZ` form.
        Err(_) => true,
    };

    let under_limit = if window_expired {
        sqlx::query(
            "UPDATE rate_limits SET request_count = 0, period_start = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') \
             WHERE account_id = ? AND action_type = ?",
        )
        .bind(account_id)
        .bind(action_type)
        .execute(&mut *tx)
        .await
        .map_err(|e| StorageError::Query { source: e })?;
        // The counter is now zero. The only way to still be at the limit is a
        // non-positive ceiling, which must be honoured just like in the branch below.
        0 < limit.max_requests
    } else {
        limit.request_count < limit.max_requests
    };

    tx.commit()
        .await
        .map_err(|e| StorageError::Connection { source: e })?;
    Ok(under_limit)
}
/// Checks the rate limit for `action_type` on the default account.
///
/// Convenience wrapper around [`check_rate_limit_for`] that supplies
/// `DEFAULT_ACCOUNT_ID`. See that function for semantics and errors.
pub async fn check_rate_limit(pool: &DbPool, action_type: &str) -> Result<bool, StorageError> {
    let account_id = DEFAULT_ACCOUNT_ID;
    check_rate_limit_for(pool, account_id, action_type).await
}
/// Checks whether `action_type` is under its rate limit for `account_id` and,
/// if so, consumes one request from the current window, all in one transaction.
///
/// Returns `Ok(true)` when the request is allowed and counted, and `Ok(false)`
/// when the window's quota is exhausted. Actions with no `rate_limits` row are
/// unlimited: they return `Ok(true)` and nothing is written.
///
/// Suppose the stored window has elapsed, or its `period_start` cannot be
/// parsed as a `DateTime<Utc>`. Then the counter is reset and the window
/// restarts at the current time before this request is counted.
///
/// # Errors
///
/// Returns [`StorageError::Connection`] if the transaction cannot be started
/// or committed, and [`StorageError::Query`] if a statement fails.
pub async fn check_and_increment_rate_limit_for(
    pool: &DbPool,
    account_id: &str,
    action_type: &str,
) -> Result<bool, StorageError> {
    let mut tx = pool
        .begin()
        .await
        .map_err(|e| StorageError::Connection { source: e })?;

    let row = sqlx::query_as::<_, RateLimit>(
        "SELECT action_type, request_count, period_start, max_requests, period_seconds \
         FROM rate_limits WHERE account_id = ? AND action_type = ?",
    )
    .bind(account_id)
    .bind(action_type)
    .fetch_optional(&mut *tx)
    .await
    .map_err(|e| StorageError::Query { source: e })?;

    let limit = match row {
        Some(l) => l,
        None => {
            // No configured limit for this action: it is always allowed.
            tx.commit()
                .await
                .map_err(|e| StorageError::Connection { source: e })?;
            return Ok(true);
        }
    };

    let now = Utc::now();
    let window_expired = match limit.period_start.parse::<DateTime<Utc>>() {
        Ok(start) => now.signed_duration_since(start).num_seconds() >= limit.period_seconds,
        // Falling back to "started just now" would make the elapsed time always
        // zero, so the window would never expire and an exhausted quota would
        // never recover. Treat it as expired instead. The reset below then
        // rewrites `period_start` in the `%Y-%m-%dT%H:%M:%SZ` form.
        Err(_) => true,
    };

    if window_expired {
        // Start a fresh, empty window. The guarded increment below counts
        // this request against it, subject to the same ceiling check.
        sqlx::query(
            "UPDATE rate_limits SET request_count = 0, period_start = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') \
             WHERE account_id = ? AND action_type = ?",
        )
        .bind(account_id)
        .bind(action_type)
        .execute(&mut *tx)
        .await
        .map_err(|e| StorageError::Query { source: e })?;
    } else if limit.request_count >= limit.max_requests {
        // Quota for the current window is already used up; nothing to write.
        tx.commit()
            .await
            .map_err(|e| StorageError::Connection { source: e })?;
        return Ok(false);
    }

    // The ceiling check is repeated inside the UPDATE. The request therefore
    // counts only if the stored counter is still below `max_requests` when the
    // row is written. The affected-row count, not the earlier SELECT, decides
    // the outcome.
    let incremented = sqlx::query(
        "UPDATE rate_limits SET request_count = request_count + 1 \
         WHERE account_id = ? AND action_type = ? AND request_count < max_requests",
    )
    .bind(account_id)
    .bind(action_type)
    .execute(&mut *tx)
    .await
    .map_err(|e| StorageError::Query { source: e })?;

    tx.commit()
        .await
        .map_err(|e| StorageError::Connection { source: e })?;
    Ok(incremented.rows_affected() > 0)
}
/// Checks and consumes one request of `action_type` for the default account.
///
/// Shorthand for [`check_and_increment_rate_limit_for`] with
/// `DEFAULT_ACCOUNT_ID`. See that function for semantics and errors.
pub async fn check_and_increment_rate_limit(
    pool: &DbPool,
    action_type: &str,
) -> Result<bool, StorageError> {
    let account_id = DEFAULT_ACCOUNT_ID;
    check_and_increment_rate_limit_for(pool, account_id, action_type).await
}
/// Adds one to the stored request counter for `action_type` under `account_id`.
///
/// This is a blind increment: it reads neither the window nor the ceiling.
/// When no `rate_limits` row matches, the UPDATE touches nothing and the
/// function still returns `Ok(())`. To check and consume in a single
/// transaction, use `check_and_increment_rate_limit_for` instead.
///
/// # Errors
///
/// Returns [`StorageError::Query`] if the UPDATE fails.
pub async fn increment_rate_limit_for(
    pool: &DbPool,
    account_id: &str,
    action_type: &str,
) -> Result<(), StorageError> {
    const BUMP_SQL: &str = "UPDATE rate_limits SET request_count = request_count + 1 WHERE account_id = ? AND action_type = ?";

    let outcome = sqlx::query(BUMP_SQL)
        .bind(account_id)
        .bind(action_type)
        .execute(pool)
        .await;

    match outcome {
        Ok(_) => Ok(()),
        Err(source) => Err(StorageError::Query { source }),
    }
}
/// Adds one to the request counter of `action_type` for the default account.
///
/// Delegates to [`increment_rate_limit_for`] with `DEFAULT_ACCOUNT_ID`.
/// See that function for semantics and errors.
pub async fn increment_rate_limit(pool: &DbPool, action_type: &str) -> Result<(), StorageError> {
    let account_id = DEFAULT_ACCOUNT_ID;
    increment_rate_limit_for(pool, account_id, action_type).await
}