Skip to main content

envoy/rate_limit/
store.rs

1//! Rate limit persistence via sqlitegraph.
2//!
3//! Stores per-agent rate limit state in the database for:
4//! - Recovery across restarts
5//! - Cross-process coordination
6//! - L1 fallback
7
8use crate::error::Result;
9use crate::rate_limit::RateLimitState;
10
/// Entity kind used for the per-agent token-bucket records written by
/// `RateLimitStore::persist` and read back by `load` / `load_all`.
const KIND_RATE_LIMIT: &str = "EnvoyRateLimit";
/// Entity kind used for the per-agent ban records written by
/// `RateLimitStore::persist_ban` and read back by `load_bans`.
const KIND_RATE_LIMIT_BAN: &str = "EnvoyRateLimitBan";
13
/// Persistent store for rate limit state.
///
/// Follows AgentRegistry pattern: write-through to sqlitegraph on mutations,
/// load from DB on startup.
///
/// The type is a field-less handle: it holds no connection or cache of its
/// own, and every method receives the graph it should operate on. Because it
/// carries no data it is cheap to copy and safe to embed in types that derive
/// `Debug`.
#[derive(Debug, Clone, Copy)]
pub struct RateLimitStore;
19
20impl RateLimitStore {
21    pub fn new() -> Self {
22        Self
23    }
24
25    /// Persist rate limit state to sqlitegraph.
26    pub fn persist(&self, graph: &sqlitegraph::SqliteGraph, state: &RateLimitState) -> Result<()> {
27        use sqlitegraph::GraphEntity;
28
29        let data = serde_json::json!({
30            "tokens": state.bucket().tokens,
31            "max_tokens": state.bucket().max_tokens,
32            "replenish_rate": state.bucket().replenish_rate,
33        });
34
35        if let Some(mut entity) =
36            graph.find_entity_by_kind_and_name(KIND_RATE_LIMIT, &state.agent_id)?
37        {
38            entity.data = data;
39            graph.update_entity(&entity)?;
40        } else {
41            let entity = GraphEntity {
42                id: 0,
43                kind: KIND_RATE_LIMIT.to_string(),
44                name: state.agent_id.clone(),
45                file_path: None,
46                data,
47            };
48            graph.insert_entity(&entity)?;
49        }
50        Ok(())
51    }
52
53    /// Load rate limit state from sqlitegraph.
54    pub fn load(
55        &self,
56        graph: &sqlitegraph::SqliteGraph,
57        agent_id: &str,
58    ) -> Result<Option<RateLimitState>> {
59        use crate::rate_limit::TokenBucket;
60
61        if let Some(entity) = graph.find_entity_by_kind_and_name(KIND_RATE_LIMIT, agent_id)? {
62            let tokens = entity
63                .data
64                .get("tokens")
65                .and_then(|v| v.as_u64())
66                .unwrap_or(0);
67            let max_tokens = entity
68                .data
69                .get("max_tokens")
70                .and_then(|v| v.as_u64())
71                .unwrap_or(1000);
72            let replenish_rate = entity
73                .data
74                .get("replenish_rate")
75                .and_then(|v| v.as_u64())
76                .unwrap_or(100);
77
78            let bucket = TokenBucket {
79                tokens,
80                max_tokens,
81                replenish_rate,
82                last_replenish: std::time::Instant::now(),
83            };
84
85            return Ok(Some(RateLimitState::from_bucket(
86                agent_id.to_string(),
87                bucket,
88            )));
89        }
90        Ok(None)
91    }
92
93    /// Load all rate limit states from sqlitegraph.
94    pub fn load_all(&self, graph: &sqlitegraph::SqliteGraph) -> Result<Vec<RateLimitState>> {
95        let entities = graph.find_entities_by_kind(KIND_RATE_LIMIT)?;
96        let mut states = Vec::new();
97
98        for entity in &entities {
99            if let Some(state) = self.load(graph, &entity.name)? {
100                states.push(state);
101            }
102        }
103
104        Ok(states)
105    }
106
107    /// Persist a ban to sqlitegraph.
108    pub fn persist_ban(
109        &self,
110        graph: &sqlitegraph::SqliteGraph,
111        agent_id: &str,
112        reason: &str,
113    ) -> Result<()> {
114        use sqlitegraph::GraphEntity;
115
116        let data = serde_json::json!({
117            "reason": reason,
118            "banned_at": chrono::Utc::now().to_rfc3339(),
119        });
120
121        if let Some(mut entity) =
122            graph.find_entity_by_kind_and_name(KIND_RATE_LIMIT_BAN, agent_id)?
123        {
124            entity.data = data;
125            graph.update_entity(&entity)?;
126        } else {
127            let entity = GraphEntity {
128                id: 0,
129                kind: KIND_RATE_LIMIT_BAN.to_string(),
130                name: agent_id.to_string(),
131                file_path: None,
132                data,
133            };
134            graph.insert_entity(&entity)?;
135        }
136        Ok(())
137    }
138
139    /// Remove a ban from sqlitegraph.
140    pub fn remove_ban(&self, graph: &sqlitegraph::SqliteGraph, agent_id: &str) -> Result<()> {
141        if let Some(entity) = graph.find_entity_by_kind_and_name(KIND_RATE_LIMIT_BAN, agent_id)? {
142            graph.delete_entity(entity.id)?;
143        }
144        Ok(())
145    }
146
147    /// Load all bans from sqlitegraph.
148    pub fn load_bans(
149        &self,
150        graph: &sqlitegraph::SqliteGraph,
151    ) -> Result<std::collections::HashMap<String, String>> {
152        let entities = graph.find_entities_by_kind(KIND_RATE_LIMIT_BAN)?;
153        let mut bans = std::collections::HashMap::new();
154
155        for entity in &entities {
156            if let Some(reason) = entity.data.get("reason").and_then(|v| v.as_str()) {
157                bans.insert(entity.name.clone(), reason.to_string());
158            }
159        }
160
161        Ok(bans)
162    }
163}
164
165impl Default for RateLimitStore {
166    fn default() -> Self {
167        Self::new()
168    }
169}