openauth-core 0.0.2

Core types and primitives for OpenAuth.
Documentation
//! Router-level rate limiting.

use crate::context::AuthContext;
use crate::env::is_production;
use crate::error::OpenAuthError;
use crate::options::{RateLimitRecord, RateLimitRule, RateLimitStorage, RateLimitStorageOption};
use crate::utils::ip::{
    create_rate_limit_key, is_valid_ip, normalize_ip_with_options, NormalizeIpOptions,
};
use crate::utils::url::normalize_pathname;
use http::Request;
use std::collections::HashMap;
use std::sync::Mutex;

/// Request body type used by the rate limiting entry points: an owned byte buffer.
pub type Body = Vec<u8>;

/// Outcome reported by [`on_request_rate_limit`] when a request must be rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RateLimitRejection {
    /// Seconds remaining until the current window ends, counted from the last
    /// recorded request (`last_request + window - now`, never negative).
    pub retry_after: u64,
}

/// [`RateLimitStorage`] implementation that keeps records in a mutex-guarded
/// `HashMap` inside the process.
///
/// Entries carry an expiry timestamp; an expired entry is removed when it is
/// looked up via `get`, or replaced when the same key is written again.
#[derive(Debug, Default)]
pub struct MemoryRateLimitStorage {
    // Keyed by the rate limit key passed to `get`/`set`.
    entries: Mutex<HashMap<String, MemoryRateLimitEntry>>,
}

/// A stored record together with the moment it stops being valid.
#[derive(Debug, Clone)]
struct MemoryRateLimitEntry {
    // The rate limit record handed to `set`.
    record: RateLimitRecord,
    // Unix timestamp in seconds; the entry counts as expired once `now >= expires_at`.
    expires_at: i64,
}

impl MemoryRateLimitStorage {
    /// Creates an empty in-memory store (equivalent to the `Default` value).
    pub fn new() -> Self {
        Default::default()
    }
}

impl RateLimitStorage for MemoryRateLimitStorage {
    /// Returns the record stored under `key` if it has not expired yet.
    ///
    /// An entry whose expiry time has been reached is removed from the map and
    /// reported as absent.
    ///
    /// # Errors
    /// Returns [`OpenAuthError::Api`] when the internal mutex is poisoned.
    fn get(&self, key: &str) -> Result<Option<RateLimitRecord>, OpenAuthError> {
        let now = now_seconds();
        let mut entries = self
            .entries
            .lock()
            .map_err(|_| OpenAuthError::Api("rate limit storage lock poisoned".to_owned()))?;
        let Some(entry) = entries.get(key) else {
            return Ok(None);
        };
        if now >= entry.expires_at {
            // Lazy eviction: expired entries are only dropped when they are read.
            entries.remove(key);
            return Ok(None);
        }
        Ok(Some(entry.record.clone()))
    }

    /// Stores `value` under `key`, replacing any previous entry, and marks it
    /// as expiring `ttl_seconds` from now.
    ///
    /// The `_update` flag is ignored by this backend: inserting and updating
    /// are the same map operation.
    ///
    /// # Errors
    /// Returns [`OpenAuthError::Api`] when the internal mutex is poisoned.
    fn set(
        &self,
        key: &str,
        value: RateLimitRecord,
        ttl_seconds: u64,
        _update: bool,
    ) -> Result<(), OpenAuthError> {
        // Clamp before converting: a plain `as i64` turns TTLs above `i64::MAX`
        // into negative numbers, which would store the entry already expired.
        let ttl = ttl_seconds.min(i64::MAX as u64) as i64;
        let expires_at = now_seconds().saturating_add(ttl);
        let mut entries = self
            .entries
            .lock()
            .map_err(|_| OpenAuthError::Api("rate limit storage lock poisoned".to_owned()))?;
        entries.insert(
            key.to_owned(),
            MemoryRateLimitEntry {
                record: value,
                expires_at,
            },
        );
        Ok(())
    }
}

pub fn on_request_rate_limit(
    context: &AuthContext,
    request: &Request<Body>,
) -> Result<Option<RateLimitRejection>, OpenAuthError> {
    if !context.rate_limit.enabled {
        return Ok(None);
    }
    let Some(config) = resolve_config(context, request)? else {
        return Ok(None);
    };
    let storage = storage(context);
    let Some(record) = storage.get(&config.key)? else {
        return Ok(None);
    };

    if should_rate_limit(config.rule.max, config.rule.window, &record) {
        return Ok(Some(RateLimitRejection {
            retry_after: retry_after(record.last_request, config.rule.window),
        }));
    }

    Ok(None)
}

/// Records the current request against its rate limit key.
///
/// Does nothing when rate limiting is disabled or when no key/rule can be
/// resolved for the request. Otherwise the stored record is read: if its last
/// request happened at most `window` seconds ago the count is incremented,
/// otherwise counting restarts at 1. The new record is written back with a TTL
/// of `window` seconds.
///
/// # Errors
/// Propagates failures from rule resolution and from the storage backend.
pub fn on_response_rate_limit(
    context: &AuthContext,
    request: &Request<Body>,
) -> Result<(), OpenAuthError> {
    if !context.rate_limit.enabled {
        return Ok(());
    }
    let Some(config) = resolve_config(context, request)? else {
        return Ok(());
    };
    let storage = storage(context);
    let now = now_seconds();
    // Clamp before converting: a plain `as i64` turns windows above `i64::MAX`
    // into negative numbers, which would reset the count on every response.
    let window_secs = config.rule.window.min(i64::MAX as u64) as i64;
    let count = match storage.get(&config.key)? {
        Some(record) if now.saturating_sub(record.last_request) <= window_secs => {
            record.count.saturating_add(1)
        }
        // No record, or the previous window has elapsed: start over.
        _ => 1,
    };
    let next_record = RateLimitRecord {
        key: config.key.clone(),
        count,
        last_request: now,
    };

    // A count above 1 means an existing record was extended rather than created.
    let update = count > 1;
    storage.set(&config.key, next_record, config.rule.window, update)
}

/// The storage key and effective rule that apply to a single request.
#[derive(Debug)]
struct ResolvedRateLimit {
    // Built from the client IP and the normalized request path.
    key: String,
    // Window and maximum request count selected for the path.
    rule: RateLimitRule,
}

/// Works out the storage key and rule for `request`.
///
/// Returns `Ok(None)` when no client IP can be determined or when rule
/// resolution yields no rule, in which case the request is not rate limited.
fn resolve_config(
    context: &AuthContext,
    request: &Request<Body>,
) -> Result<Option<ResolvedRateLimit>, OpenAuthError> {
    let uri = request.uri().to_string();
    let path = normalize_pathname(&uri, &context.base_path);
    let ip = match request_ip(context, request) {
        Some(ip) => ip,
        None => return Ok(None),
    };
    let resolved = resolve_rule(context, request, &path)?.map(|rule| ResolvedRateLimit {
        key: create_rate_limit_key(&ip, &path),
        rule,
    });
    Ok(resolved)
}

/// Selects the rate limit rule that applies to `path`.
///
/// Precedence, from lowest to highest:
/// 1. the globally configured window/max,
/// 2. the built-in rule for sensitive paths,
/// 3. the first plugin rule whose pattern matches,
/// 4. the first matching custom rule, whose value is returned as-is (a `None`
///    value therefore means no rule applies to the request),
/// 5. the first matching dynamic rule, whose provider receives the rule chosen
///    in steps 1-3 and decides the final result.
fn resolve_rule(
    context: &AuthContext,
    request: &Request<Body>,
    path: &str,
) -> Result<Option<RateLimitRule>, OpenAuthError> {
    let settings = &context.rate_limit;

    let base = default_special_rule(path).unwrap_or_else(|| default_rule(context));
    let rule = settings
        .plugin_rules
        .iter()
        .find(|candidate| path_matches(&candidate.path, path))
        .map_or(base, |matched| matched.rule.clone());

    let custom = settings
        .custom_rules
        .iter()
        .find(|candidate| path_matches(&candidate.path, path));
    if let Some(matched) = custom {
        return Ok(matched.rule.clone());
    }

    let dynamic = settings
        .dynamic_rules
        .iter()
        .find(|candidate| path_matches(&candidate.path, path));
    match dynamic {
        Some(matched) => matched.provider.resolve(request, &rule),
        None => Ok(Some(rule)),
    }
}

/// Builds the fallback rule from the globally configured window and maximum.
fn default_rule(context: &AuthContext) -> RateLimitRule {
    let settings = &context.rate_limit;
    RateLimitRule {
        max: settings.max,
        window: settings.window,
    }
}

/// Returns the built-in stricter rule for sensitive endpoints, if `path` is one.
///
/// Sign-in, sign-up, change-password and change-email paths (prefix match) get
/// 3 requests per 10 seconds; password-reset and verification-mail paths get
/// 3 requests per 60 seconds. Any other path yields `None`.
fn default_special_rule(path: &str) -> Option<RateLimitRule> {
    const SHORT_WINDOW_PREFIXES: [&str; 4] = [
        "/sign-in",
        "/sign-up",
        "/change-password",
        "/change-email",
    ];
    const LONG_WINDOW_EXACT: [&str; 4] = [
        "/request-password-reset",
        "/send-verification-email",
        "/email-otp/send-verification-otp",
        "/email-otp/request-password-reset",
    ];

    let is_short_window = SHORT_WINDOW_PREFIXES
        .iter()
        .any(|prefix| path.starts_with(*prefix));
    if is_short_window {
        return Some(RateLimitRule { window: 10, max: 3 });
    }

    let is_long_window = LONG_WINDOW_EXACT.iter().any(|exact| path == *exact)
        || path.starts_with("/forget-password");
    if is_long_window {
        return Some(RateLimitRule { window: 60, max: 3 });
    }

    None
}

/// Determines the client IP used to build the rate limit key.
///
/// Returns `None` when IP tracking is disabled. Otherwise the configured
/// headers are inspected in order; for each header that is present and
/// readable as a string, the first comma-separated entry is trimmed and, if it
/// is a valid IP, returned in normalized form. When no header yields a valid
/// IP, `127.0.0.1` is used outside production and `None` in production.
fn request_ip(context: &AuthContext, request: &Request<Body>) -> Option<String> {
    let ip_options = &context.options.advanced.ip_address;
    if ip_options.disable_ip_tracking {
        return None;
    }

    let from_headers = ip_options
        .headers
        .iter()
        .filter_map(|header_name| request.headers().get(header_name)?.to_str().ok())
        // Only the first entry of a comma-separated list is considered.
        .filter_map(|value| value.split(',').next().map(str::trim))
        .find(|candidate| is_valid_ip(*candidate));

    if let Some(candidate) = from_headers {
        let options = NormalizeIpOptions {
            ipv6_subnet: ip_options.ipv6_subnet,
        };
        return Some(normalize_ip_with_options(candidate, options));
    }

    if context.options.production || is_production() {
        return None;
    }
    Some("127.0.0.1".to_owned())
}

/// Picks the storage backend for rate limit records.
///
/// A custom storage, when configured, always wins; otherwise the context's
/// built-in memory storage is used.
fn storage(context: &AuthContext) -> &dyn RateLimitStorage {
    let settings = &context.rate_limit;
    if let Some(custom) = &settings.custom_storage {
        return custom.as_ref();
    }
    // NOTE(review): every storage option currently resolves to the memory
    // store, including `Database` and `SecondaryStorage` — confirm this is the
    // intended fallback. The match is kept exhaustive so a new variant forces
    // a decision here.
    match settings.storage {
        RateLimitStorageOption::Memory
        | RateLimitStorageOption::Database
        | RateLimitStorageOption::SecondaryStorage => settings.memory_storage.as_ref(),
    }
}

/// Reports whether `record` has used up its quota inside the current window.
///
/// A request is limited when fewer than `window` seconds have passed since
/// `record.last_request` and `record.count` has already reached `max`.
fn should_rate_limit(max: u64, window: u64, record: &RateLimitRecord) -> bool {
    // Clamp before converting: a plain `as i64` turns windows above `i64::MAX`
    // into negative numbers, which would silently disable rate limiting.
    let window_secs = window.min(i64::MAX as u64) as i64;
    let elapsed = now_seconds().saturating_sub(record.last_request);
    elapsed < window_secs && record.count >= max
}

/// Computes how many seconds remain until the window that started at
/// `last_request` is over; returns 0 when it has already elapsed.
fn retry_after(last_request: i64, window: u64) -> u64 {
    // Clamp before converting: a plain `as i64` turns windows above `i64::MAX`
    // into negative numbers, which would make the result collapse to 0.
    let window_secs = window.min(i64::MAX as u64) as i64;
    let remaining = last_request
        .saturating_add(window_secs)
        .saturating_sub(now_seconds());
    // Non-negative after `max(0)`, so the conversion to `u64` is lossless.
    remaining.max(0) as u64
}

/// Tests `path` against `pattern`, where every `*` in the pattern matches any
/// (possibly empty) run of characters, including `/`.
///
/// A pattern without `*` must equal the path exactly. The literal pieces
/// between wildcards must appear in order and must not overlap, so `/a*a`
/// matches `/aa` and `/axa` but not `/a`.
fn path_matches(pattern: &str, path: &str) -> bool {
    let mut literals = pattern.split('*');

    // `split` always yields at least one piece; the first one anchors the start.
    let head = literals.next().unwrap_or_default();
    let Some(mut remaining) = path.strip_prefix(head) else {
        return false;
    };

    // No wildcard at all: the whole path must have been consumed.
    let Some(mut current) = literals.next() else {
        return remaining.is_empty();
    };

    // Every piece except the last is matched at its leftmost occurrence, which
    // leaves the longest possible tail for the pieces that follow.
    for next in literals {
        match remaining.find(current) {
            Some(index) => remaining = &remaining[index + current.len()..],
            None => return false,
        }
        current = next;
    }

    // The final piece anchors the end of the path.
    remaining.ends_with(current)
}

/// Returns the current UTC time as a Unix timestamp in whole seconds.
fn now_seconds() -> i64 {
    let now = time::OffsetDateTime::now_utc();
    now.unix_timestamp()
}