envoy/rate_limit/
store.rs1use crate::error::Result;
9use crate::rate_limit::RateLimitState;
10
/// Graph entity kind under which per-agent token-bucket state is stored.
const KIND_RATE_LIMIT: &str = "EnvoyRateLimit";
/// Graph entity kind under which per-agent ban records are stored.
const KIND_RATE_LIMIT_BAN: &str = "EnvoyRateLimitBan";

/// Persistence layer for rate-limit buckets and bans.
///
/// The store itself carries no data; every operation reads from or writes to
/// the `sqlitegraph::SqliteGraph` passed in by the caller.
pub struct RateLimitStore;
19
20impl RateLimitStore {
21 pub fn new() -> Self {
22 Self
23 }
24
25 pub fn persist(&self, graph: &sqlitegraph::SqliteGraph, state: &RateLimitState) -> Result<()> {
27 use sqlitegraph::GraphEntity;
28
29 let data = serde_json::json!({
30 "tokens": state.bucket().tokens,
31 "max_tokens": state.bucket().max_tokens,
32 "replenish_rate": state.bucket().replenish_rate,
33 });
34
35 if let Some(mut entity) =
36 graph.find_entity_by_kind_and_name(KIND_RATE_LIMIT, &state.agent_id)?
37 {
38 entity.data = data;
39 graph.update_entity(&entity)?;
40 } else {
41 let entity = GraphEntity {
42 id: 0,
43 kind: KIND_RATE_LIMIT.to_string(),
44 name: state.agent_id.clone(),
45 file_path: None,
46 data,
47 };
48 graph.insert_entity(&entity)?;
49 }
50 Ok(())
51 }
52
53 pub fn load(
55 &self,
56 graph: &sqlitegraph::SqliteGraph,
57 agent_id: &str,
58 ) -> Result<Option<RateLimitState>> {
59 use crate::rate_limit::TokenBucket;
60
61 if let Some(entity) = graph.find_entity_by_kind_and_name(KIND_RATE_LIMIT, agent_id)? {
62 let tokens = entity
63 .data
64 .get("tokens")
65 .and_then(|v| v.as_u64())
66 .unwrap_or(0);
67 let max_tokens = entity
68 .data
69 .get("max_tokens")
70 .and_then(|v| v.as_u64())
71 .unwrap_or(1000);
72 let replenish_rate = entity
73 .data
74 .get("replenish_rate")
75 .and_then(|v| v.as_u64())
76 .unwrap_or(100);
77
78 let bucket = TokenBucket {
79 tokens,
80 max_tokens,
81 replenish_rate,
82 last_replenish: std::time::Instant::now(),
83 };
84
85 return Ok(Some(RateLimitState::from_bucket(
86 agent_id.to_string(),
87 bucket,
88 )));
89 }
90 Ok(None)
91 }
92
93 pub fn load_all(&self, graph: &sqlitegraph::SqliteGraph) -> Result<Vec<RateLimitState>> {
95 let entities = graph.find_entities_by_kind(KIND_RATE_LIMIT)?;
96 let mut states = Vec::new();
97
98 for entity in &entities {
99 if let Some(state) = self.load(graph, &entity.name)? {
100 states.push(state);
101 }
102 }
103
104 Ok(states)
105 }
106
107 pub fn persist_ban(
109 &self,
110 graph: &sqlitegraph::SqliteGraph,
111 agent_id: &str,
112 reason: &str,
113 ) -> Result<()> {
114 use sqlitegraph::GraphEntity;
115
116 let data = serde_json::json!({
117 "reason": reason,
118 "banned_at": chrono::Utc::now().to_rfc3339(),
119 });
120
121 if let Some(mut entity) =
122 graph.find_entity_by_kind_and_name(KIND_RATE_LIMIT_BAN, agent_id)?
123 {
124 entity.data = data;
125 graph.update_entity(&entity)?;
126 } else {
127 let entity = GraphEntity {
128 id: 0,
129 kind: KIND_RATE_LIMIT_BAN.to_string(),
130 name: agent_id.to_string(),
131 file_path: None,
132 data,
133 };
134 graph.insert_entity(&entity)?;
135 }
136 Ok(())
137 }
138
139 pub fn remove_ban(&self, graph: &sqlitegraph::SqliteGraph, agent_id: &str) -> Result<()> {
141 if let Some(entity) = graph.find_entity_by_kind_and_name(KIND_RATE_LIMIT_BAN, agent_id)? {
142 graph.delete_entity(entity.id)?;
143 }
144 Ok(())
145 }
146
147 pub fn load_bans(
149 &self,
150 graph: &sqlitegraph::SqliteGraph,
151 ) -> Result<std::collections::HashMap<String, String>> {
152 let entities = graph.find_entities_by_kind(KIND_RATE_LIMIT_BAN)?;
153 let mut bans = std::collections::HashMap::new();
154
155 for entity in &entities {
156 if let Some(reason) = entity.data.get("reason").and_then(|v| v.as_str()) {
157 bans.insert(entity.name.clone(), reason.to_string());
158 }
159 }
160
161 Ok(bans)
162 }
163}
164
165impl Default for RateLimitStore {
166 fn default() -> Self {
167 Self::new()
168 }
169}