robson_core/entities/
rate_limit.rs1use anyhow::Result;
2use chrono::Utc;
3use sea_orm::entity::prelude::*;
4use sea_orm::{ActiveValue::Set, IntoActiveModel};
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)]
8#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
9pub enum RateLimitScope {
10 #[sea_orm(string_value = "user")]
11 User,
12 #[sea_orm(string_value = "channel")]
13 Channel,
14}
15
16impl std::fmt::Display for RateLimitScope {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self {
19 RateLimitScope::User => write!(f, "user"),
20 RateLimitScope::Channel => write!(f, "channel"),
21 }
22 }
23}
24
25#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize, Deserialize)]
26#[sea_orm(table_name = "rate_limits")]
27pub struct Model {
28 #[sea_orm(primary_key)]
29 pub id: i32,
30 pub scope_type: RateLimitScope,
31 pub scope_id: String,
32 pub window_start: String,
33 pub request_count: i32,
34}
35
36#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
37pub enum Relation {}
38
39impl ActiveModelBehavior for ActiveModel {}
40
41impl Model {
42 pub async fn check_and_increment(
45 db: &DatabaseConnection,
46 scope_type: RateLimitScope,
47 scope_id: &str,
48 window_secs: i64,
49 max_requests: i32,
50 ) -> Result<bool> {
51 let now = Utc::now();
52 let now_str = now.to_rfc3339();
53
54 let existing = Entity::find()
55 .filter(Column::ScopeType.eq(scope_type.clone()))
56 .filter(Column::ScopeId.eq(scope_id))
57 .one(db)
58 .await?;
59
60 match existing {
61 None => {
62 let active = ActiveModel {
64 scope_type: Set(scope_type),
65 scope_id: Set(scope_id.to_string()),
66 window_start: Set(now_str),
67 request_count: Set(1),
68 ..Default::default()
69 };
70 active.insert(db).await?;
71 Ok(true)
72 }
73 Some(record) => {
74 let window_start = chrono::DateTime::parse_from_rfc3339(&record.window_start)
76 .map(|dt| dt.with_timezone(&Utc))
77 .unwrap_or(now);
78
79 let elapsed = (now - window_start).num_seconds();
80
81 if elapsed >= window_secs {
82 let mut active = record.into_active_model();
84 active.window_start = Set(now_str);
85 active.request_count = Set(1);
86 active.update(db).await?;
87 Ok(true)
88 } else {
89 let new_count = record.request_count + 1;
90 if new_count > max_requests {
91 Ok(false)
93 } else {
94 let mut active = record.into_active_model();
95 active.request_count = Set(new_count);
96 active.update(db).await?;
97 Ok(true)
98 }
99 }
100 }
101 }
102 }
103}