robson-core 0.1.0

Rust async agent orchestrator for automated development workflows
Documentation
use anyhow::Result;
use chrono::Utc;
use sea_orm::entity::prelude::*;
use sea_orm::{ActiveValue::Set, IntoActiveModel};
use serde::{Deserialize, Serialize};

/// Kind of entity a rate-limit counter is tracked for.
///
/// Stored in the `scope_type` column of `rate_limits` as a string (`rs_type = "String"`),
/// using the `string_value` given on each variant.
///
/// NOTE(review): serde's default representation of this enum is the variant name
/// (`"User"` / `"Channel"`), which differs from both the database strings and the
/// `Display` output (`"user"` / `"channel"`). Confirm that this is intended.
#[derive(Debug, Clone, PartialEq, Eq, EnumIter, DeriveActiveEnum, Serialize, Deserialize)]
#[sea_orm(rs_type = "String", db_type = "String(StringLen::None)")]
pub enum RateLimitScope {
    /// Per-user scope; stored as `"user"` (presumably `scope_id` is then a user id — verify
    /// against callers).
    #[sea_orm(string_value = "user")]
    User,
    /// Per-channel scope; stored as `"channel"` (presumably `scope_id` is then a channel id —
    /// verify against callers).
    #[sea_orm(string_value = "channel")]
    Channel,
}

impl std::fmt::Display for RateLimitScope {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RateLimitScope::User => write!(f, "user"),
            RateLimitScope::Channel => write!(f, "channel"),
        }
    }
}

/// One row of the `rate_limits` table: a fixed-window request counter for a single scope.
///
/// Rows are created and updated by `Model::check_and_increment`.
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel, Serialize, Deserialize)]
#[sea_orm(table_name = "rate_limits")]
pub struct Model {
    /// Surrogate primary key; left unset on insert so the database assigns it.
    #[sea_orm(primary_key)]
    pub id: i32,
    /// Which kind of entity this counter belongs to.
    pub scope_type: RateLimitScope,
    /// Identifier of the entity within `scope_type`.
    /// The format is opaque here and defined by callers (TODO confirm).
    pub scope_id: String,
    /// Start of the current window, stored as an RFC 3339 string (written from
    /// `Utc::now().to_rfc3339()` by `check_and_increment`).
    pub window_start: String,
    /// Number of allowed requests recorded since `window_start`.
    pub request_count: i32,
}

/// Relations from `rate_limits` to other entities. None are declared.
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
pub enum Relation {}

/// Uses SeaORM's default `ActiveModelBehavior` hooks; no custom save/delete behavior.
impl ActiveModelBehavior for ActiveModel {}

impl Model {
    /// Records one request for `scope_type`/`scope_id` and reports whether it is allowed.
    ///
    /// This is a fixed-window counter:
    /// - The first request for a scope opens a window with a count of 1.
    /// - So does the first request after the stored window has lasted at least `window_secs`
    ///   seconds.
    /// - Later requests in the same window increment the count while it stays at or below
    ///   `max_requests`.
    ///
    /// Denied requests are not counted and leave the stored row untouched.
    ///
    /// A stored `window_start` that cannot be parsed as RFC 3339 is treated as an expired
    /// window and reset. A `max_requests` of zero or less denies every request without
    /// touching the database.
    ///
    /// Returns `Ok(true)` if the request is allowed, `Ok(false)` if it is rate limited.
    ///
    /// # Errors
    ///
    /// Returns an error if the lookup, insert, or update query fails.
    ///
    /// NOTE(review): the read-then-write sequence below is not atomic. Concurrent calls for the
    /// same scope can read the same `request_count` and both be admitted (lost update). When no
    /// row exists yet, they can also both insert, creating duplicate rows unless the schema has
    /// a unique index on `(scope_type, scope_id)` — TODO confirm. Strict limits under
    /// concurrency would need a conditional `UPDATE ... WHERE request_count < max` instead.
    pub async fn check_and_increment(
        db: &DatabaseConnection,
        scope_type: RateLimitScope,
        scope_id: &str,
        window_secs: i64,
        max_requests: i32,
    ) -> Result<bool> {
        // A non-positive limit admits nothing. Without this guard, the first request of every
        // window would slip through, because a fresh window always starts at count = 1.
        if max_requests < 1 {
            return Ok(false);
        }

        let now = Utc::now();
        let now_str = now.to_rfc3339();

        let existing = Entity::find()
            .filter(Column::ScopeType.eq(scope_type.clone()))
            .filter(Column::ScopeId.eq(scope_id))
            .one(db)
            .await?;

        let record = match existing {
            Some(record) => record,
            None => {
                // First request for this scope: open a window with count = 1.
                let active = ActiveModel {
                    scope_type: Set(scope_type),
                    scope_id: Set(scope_id.to_string()),
                    window_start: Set(now_str),
                    request_count: Set(1),
                    ..Default::default()
                };
                active.insert(db).await?;
                return Ok(true);
            }
        };

        // An unparseable timestamp counts as an expired window, so the row self-heals on the
        // next request. Falling back to `now` instead would make the elapsed time 0 forever:
        // the scope would stay in its current window and be rate limited permanently once
        // the count reached the limit.
        let window_expired = chrono::DateTime::parse_from_rfc3339(&record.window_start)
            .map(|start| (now - start.with_timezone(&Utc)).num_seconds() >= window_secs)
            .unwrap_or(true);

        if window_expired {
            // Start a fresh window; this request is its first.
            let mut active = record.into_active_model();
            active.window_start = Set(now_str);
            active.request_count = Set(1);
            active.update(db).await?;
            return Ok(true);
        }

        // Compare before incrementing so `request_count + 1` can never overflow `i32`.
        // This is equivalent to checking `request_count + 1 > max_requests`.
        if record.request_count >= max_requests {
            return Ok(false);
        }

        let new_count = record.request_count + 1;
        let mut active = record.into_active_model();
        active.request_count = Set(new_count);
        active.update(db).await?;
        Ok(true)
    }
}