use anyhow::Result;
use chrono::Utc;
use sea_orm::entity::prelude::*;
use sea_orm::sea_query::Expr;
use sea_orm::{ActiveValue::Set, IntoActiveModel};
use serde::{Deserialize, Serialize};
/// Kind of subject a rate-limit counter is keyed on.
///
/// Persisted in a string column (`db_type = "String(StringLen::None)"`) using the
/// lowercase `string_value` tokens below; these are the same strings this type's
/// `Display` impl writes.
///
/// NOTE(review): there is no `#[serde(rename_all = ...)]` attribute, so serde
/// presumably (de)serializes the variant names `User` / `Channel` rather than the
/// lowercase DB tokens — confirm that is intended.
#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
pub enum RateLimitScope {
/// Stored in the database as `"user"`.
#[sea_orm(string_value = "user")]
User,
/// Stored in the database as `"channel"`.
#[sea_orm(string_value = "channel")]
Channel,
}
impl std::fmt::Display for RateLimitScope {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
RateLimitScope::User => write!(f, "user"),
RateLimitScope::Channel => write!(f, "channel"),
}
}
}
/// One fixed-window rate-limit counter, stored as a row of the `rate_limits` table.
///
/// Rows are looked up by `(scope_type, scope_id)` in `Model::check_and_increment`.
///
/// NOTE(review): nothing in this file declares a unique index on
/// `(scope_type, scope_id)`; unless the migration adds one, concurrent first
/// requests for the same scope may insert duplicate rows — confirm.
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "rate_limits")]
pub struct Model {
/// Primary key of the row.
#[sea_orm(primary_key)]
pub id: i32,
/// Which kind of subject `scope_id` refers to.
pub scope_type: RateLimitScope,
/// Identifier of the limited subject (presumably a user or channel id,
/// depending on `scope_type` — verify against callers).
pub scope_id: String,
/// Start of the current window as an RFC 3339 string, written from `Utc::now()`.
pub window_start: String,
/// Number of requests admitted in the current window; rejected requests are not
/// counted.
pub request_count: i32,
}
/// Relations of the `rate_limits` entity; none are declared.
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}
// Empty impl: the active model uses the trait's default `ActiveModelBehavior` hooks.
impl ActiveModelBehavior for ActiveModel {}
impl Model {
pub async fn check_and_increment(
db: &DatabaseConnection,
scope_type: RateLimitScope,
scope_id: &str,
window_secs: i64,
max_requests: i32,
) -> Result<bool> {
let now = Utc::now();
let now_str = now.to_rfc3339();
let existing = Entity::find()
.filter(Column::ScopeType.eq(scope_type.clone()))
.filter(Column::ScopeId.eq(scope_id))
.one(db)
.await?;
match existing {
None => {
let active = ActiveModel {
scope_type: Set(scope_type),
scope_id: Set(scope_id.to_string()),
window_start: Set(now_str),
request_count: Set(1),
..Default::default()
};
active.insert(db).await?;
Ok(true)
}
Some(record) => {
let window_start = chrono::DateTime::parse_from_rfc3339(&record.window_start)
.map(|dt| dt.with_timezone(&Utc))
.unwrap_or(now);
let elapsed = (now - window_start).num_seconds();
if elapsed >= window_secs {
let mut active = record.into_active_model();
active.window_start = Set(now_str);
active.request_count = Set(1);
active.update(db).await?;
Ok(true)
} else {
let new_count = record.request_count + 1;
if new_count > max_requests {
Ok(false)
} else {
let mut active = record.into_active_model();
active.request_count = Set(new_count);
active.update(db).await?;
Ok(true)
}
}
}
}
}
}