Skip to main content

robson_core/entities/
rate_limit.rs

1use anyhow::Result;
2use chrono::Utc;
3use sea_orm::entity::prelude::*;
4use sea_orm::{ActiveValue::Set, IntoActiveModel};
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)]
8#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
9pub enum RateLimitScope {
10    #[sea_orm(string_value = "user")]
11    User,
12    #[sea_orm(string_value = "channel")]
13    Channel,
14}
15
16impl std::fmt::Display for RateLimitScope {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        match self {
19            RateLimitScope::User => write!(f, "user"),
20            RateLimitScope::Channel => write!(f, "channel"),
21        }
22    }
23}
24
25#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize, Deserialize)]
26#[sea_orm(table_name = "rate_limits")]
27pub struct Model {
28    #[sea_orm(primary_key)]
29    pub id: i32,
30    pub scope_type: RateLimitScope,
31    pub scope_id: String,
32    pub window_start: String,
33    pub request_count: i32,
34}
35
36#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
37pub enum Relation {}
38
39impl ActiveModelBehavior for ActiveModel {}
40
41impl Model {
42    /// Returns true if the request is allowed (within rate limit), false if rate limited.
43    /// Creates or updates the rate limit record for the given scope.
44    pub async fn check_and_increment(
45        db: &DatabaseConnection,
46        scope_type: RateLimitScope,
47        scope_id: &str,
48        window_secs: i64,
49        max_requests: i32,
50    ) -> Result<bool> {
51        let now = Utc::now();
52        let now_str = now.to_rfc3339();
53
54        let existing = Entity::find()
55            .filter(Column::ScopeType.eq(scope_type.clone()))
56            .filter(Column::ScopeId.eq(scope_id))
57            .one(db)
58            .await?;
59
60        match existing {
61            None => {
62                // First request: insert with count=1
63                let active = ActiveModel {
64                    scope_type: Set(scope_type),
65                    scope_id: Set(scope_id.to_string()),
66                    window_start: Set(now_str),
67                    request_count: Set(1),
68                    ..Default::default()
69                };
70                active.insert(db).await?;
71                Ok(true)
72            }
73            Some(record) => {
74                // Parse window_start to check if we're still in the window
75                let window_start = chrono::DateTime::parse_from_rfc3339(&record.window_start)
76                    .map(|dt| dt.with_timezone(&Utc))
77                    .unwrap_or(now);
78
79                let elapsed = (now - window_start).num_seconds();
80
81                if elapsed >= window_secs {
82                    // Window expired: reset
83                    let mut active = record.into_active_model();
84                    active.window_start = Set(now_str);
85                    active.request_count = Set(1);
86                    active.update(db).await?;
87                    Ok(true)
88                } else {
89                    let new_count = record.request_count + 1;
90                    if new_count > max_requests {
91                        // Rate limited
92                        Ok(false)
93                    } else {
94                        let mut active = record.into_active_model();
95                        active.request_count = Set(new_count);
96                        active.update(db).await?;
97                        Ok(true)
98                    }
99                }
100            }
101        }
102    }
103}