use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Rate-limiting configuration.
///
/// Every field may be omitted when deserializing. Omitted fields take the same
/// values that `RateLimitConfig::default()` produces, so a config built in code
/// and one parsed from a document without these keys agree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitConfig {
    /// Whether rate limiting is turned on. Defaults to `false`.
    #[serde(default)]
    pub enabled: bool,
    /// Default queries-per-second limit. Defaults to 100.
    // NOTE(review): presumably applied when no named tier matches; confirm with callers.
    #[serde(default = "default_qps")]
    pub default_qps: u64,
    /// Default burst allowance. Defaults to 200.
    #[serde(default = "default_burst")]
    pub default_burst: u64,
    /// Named tiers, looked up via [`RateLimitConfig::tier`]. Defaults to empty.
    #[serde(default)]
    pub tiers: HashMap<String, RateLimitTier>,
    /// Per-operation cost, keyed by operation name. Operations missing from the
    /// map cost 1 (see [`RateLimitConfig::operation_cost`]).
    ///
    /// When the field is omitted, it defaults to the built-in table from
    /// `default_endpoint_costs`. When it is supplied, it replaces that table
    /// entirely rather than merging with it.
    // A bare `#[serde(default)]` here would produce an empty map, diverging from
    // `Default::default()` and silently making every operation cost 1.
    #[serde(default = "default_endpoint_costs")]
    pub endpoint_costs: HashMap<String, u64>,
}
/// A named rate-limit tier, stored in [`RateLimitConfig::tiers`].
///
/// Neither field has a serde default, so both must be present when
/// deserializing a tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitTier {
    /// Queries-per-second limit for this tier.
    pub qps: u64,
    /// Burst allowance for this tier.
    pub burst: u64,
}
impl Default for RateLimitConfig {
fn default() -> Self {
Self {
enabled: false,
default_qps: 100,
default_burst: 200,
tiers: HashMap::new(),
endpoint_costs: default_endpoint_costs(),
}
}
}
/// Fallback value for `RateLimitConfig::default_qps`: 100.
fn default_qps() -> u64 {
    const FALLBACK_QPS: u64 = 100;
    FALLBACK_QPS
}
/// Fallback value for `RateLimitConfig::default_burst`: 200.
fn default_burst() -> u64 {
    const FALLBACK_BURST: u64 = 200;
    FALLBACK_BURST
}
/// Built-in per-operation cost table used by `RateLimitConfig::default`.
///
/// Point and key-value reads and writes cost 1, scans cost 5, and the search,
/// graph and aggregate operations cost between 10 and 25.
fn default_endpoint_costs() -> HashMap<String, u64> {
    const TABLE: [(&str, u64); 12] = [
        ("point_get", 1),
        ("point_put", 1),
        ("document_scan", 5),
        ("vector_search", 20),
        ("text_search", 10),
        ("hybrid_search", 25),
        ("graph_hop", 10),
        ("graph_path", 15),
        ("aggregate", 10),
        ("kv_get", 1),
        ("kv_put", 1),
        ("kv_scan", 5),
    ];
    TABLE
        .iter()
        .map(|&(operation, cost)| (operation.to_owned(), cost))
        .collect()
}
impl RateLimitConfig {
    /// Returns the cost configured for `operation` in `endpoint_costs`,
    /// falling back to 1 when the operation has no entry.
    pub fn operation_cost(&self, operation: &str) -> u64 {
        let configured = self.endpoint_costs.get(operation).copied();
        configured.unwrap_or(1)
    }

    /// Looks up the tier registered under `name`, or `None` if no such tier
    /// is configured.
    pub fn tier(&self, name: &str) -> Option<&RateLimitTier> {
        let tiers = &self.tiers;
        tiers.get(name)
    }
}