kora_lib/usage_limit/
usage_tracker.rs1use std::{collections::HashSet, sync::Arc};
2
3use deadpool_redis::Runtime;
4use redis::AsyncCommands;
5use solana_sdk::{pubkey::Pubkey, transaction::VersionedTransaction};
6use tokio::sync::OnceCell;
7
8use super::usage_store::{RedisUsageStore, UsageStore};
9use crate::{error::KoraError, sanitize_error, state::get_signer_pool};
10
11#[cfg(not(test))]
12use crate::state::get_config;
13
14#[cfg(test)]
15use crate::tests::config_mock::mock_state::get_config;
16
17const USAGE_CACHE_KEY: &str = "kora:usage_limit";
18
19static USAGE_LIMITER: OnceCell<Option<UsageTracker>> = OnceCell::const_new();
21
22pub struct UsageTracker {
23 store: Arc<dyn UsageStore>,
24 max_transactions: u64,
25 kora_signers: HashSet<Pubkey>,
26 fallback_if_unavailable: bool,
27}
28
29impl UsageTracker {
30 pub fn new(
31 store: Arc<dyn UsageStore>,
32 max_transactions: u64,
33 kora_signers: HashSet<Pubkey>,
34 fallback_if_unavailable: bool,
35 ) -> Self {
36 Self { store, max_transactions, kora_signers, fallback_if_unavailable }
37 }
38
39 fn get_usage_key(&self, wallet: &Pubkey) -> String {
40 format!("{USAGE_CACHE_KEY}:{wallet}")
41 }
42
43 fn handle_store_error(
45 &self,
46 error: KoraError,
47 operation: &str,
48 wallet: &Pubkey,
49 ) -> Result<(), KoraError> {
50 log::error!("Failed to {operation} for {wallet}: {error}");
51 if self.fallback_if_unavailable {
52 log::error!("Fallback enabled - allowing transaction due to store error");
53 Ok(()) } else {
55 Err(KoraError::InternalServerError(format!(
56 "Usage limit store unavailable and fallback disabled: {error}"
57 )))
58 }
59 }
60
61 async fn check_usage_limit(&self, wallet: &Pubkey) -> Result<(), KoraError> {
62 if self.max_transactions == 0 {
64 return Ok(());
65 }
66
67 let key = self.get_usage_key(wallet);
69
70 let current_count = match self.store.get(&key).await {
72 Ok(count) => count,
73 Err(e) => {
74 return self.handle_store_error(e, "get usage count", wallet);
75 }
76 };
77
78 if current_count >= self.max_transactions as u32 {
79 return Err(KoraError::UsageLimitExceeded(format!(
80 "Wallet {wallet} exceeded limit: {}/{}",
81 current_count + 1,
82 self.max_transactions
83 )));
84 }
85
86 let new_count = match self.store.increment(&key).await {
88 Ok(count) => count,
89 Err(e) => {
90 return self.handle_store_error(e, "increment usage count", wallet);
91 }
92 };
93
94 log::debug!("Usage check passed for {wallet}: {new_count}/{}", self.max_transactions);
95
96 Ok(())
97 }
98
99 fn get_usage_limiter() -> Result<Option<&'static UsageTracker>, KoraError> {
100 match USAGE_LIMITER.get() {
101 Some(limiter) => Ok(limiter.as_ref()),
102 None => {
103 Err(KoraError::InternalServerError("Usage limiter not initialized".to_string()))
104 }
105 }
106 }
107
108 fn extract_transaction_sender(
110 &self,
111 transaction: &VersionedTransaction,
112 ) -> Result<Option<Pubkey>, KoraError> {
113 let account_keys = transaction.message.static_account_keys();
114
115 if account_keys.is_empty() {
116 return Err(KoraError::InvalidTransaction(
117 "Transaction has no account keys".to_string(),
118 ));
119 }
120
121 let signers = account_keys
122 .iter()
123 .take(transaction.message.header().num_required_signatures as usize)
124 .collect::<Vec<_>>();
125
126 for signer in &signers {
127 if !self.kora_signers.contains(signer) {
128 return Ok(Some(**signer));
129 }
130 }
131
132 log::debug!(
133 "No user signers found when extracting transaction sender for usage limit: {signers:?}",
134 );
135
136 Ok(None)
137 }
138
139 pub async fn init_usage_limiter() -> Result<(), KoraError> {
141 let config = get_config()?;
142
143 if !config.kora.usage_limit.enabled {
144 log::info!("Usage limiting disabled");
145 USAGE_LIMITER.set(None).map_err(|_| {
146 KoraError::InternalServerError("Usage limiter already initialized".to_string())
147 })?;
148 return Ok(());
149 }
150
151 let usage_limiter = if let Some(cache_url) = &config.kora.usage_limit.cache_url {
152 let cfg = deadpool_redis::Config::from_url(cache_url);
153 let pool = cfg.create_pool(Some(Runtime::Tokio1)).map_err(|e| {
154 KoraError::InternalServerError(format!(
155 "Failed to create Redis pool: {}",
156 sanitize_error!(e)
157 ))
158 })?;
159
160 let mut conn = pool.get().await.map_err(|e| {
162 KoraError::InternalServerError(format!(
163 "Failed to connect to Redis: {}",
164 sanitize_error!(e)
165 ))
166 })?;
167
168 let _: Option<String> = conn.get("__usage_limiter_test__").await.map_err(|e| {
170 KoraError::InternalServerError(format!(
171 "Redis connection test failed: {}",
172 sanitize_error!(e)
173 ))
174 })?;
175
176 log::info!(
177 "Usage limiter initialized with max {} transactions",
178 config.kora.usage_limit.max_transactions
179 );
180
181 let kora_signers = get_signer_pool()?
182 .get_signers_info()
183 .iter()
184 .filter_map(|info| info.public_key.parse().ok())
185 .collect();
186
187 let store = Arc::new(RedisUsageStore::new(pool));
188 Some(UsageTracker::new(
189 store,
190 config.kora.usage_limit.max_transactions,
191 kora_signers,
192 config.kora.usage_limit.fallback_if_unavailable,
193 ))
194 } else {
195 log::info!("Usage limiting enabled but no cache_url configured - disabled");
196 None
197 };
198
199 USAGE_LIMITER.set(usage_limiter).map_err(|_| {
200 KoraError::InternalServerError("Usage limiter already initialized".to_string())
201 })?;
202
203 Ok(())
204 }
205
206 pub async fn check_transaction_usage_limit(
208 transaction: &VersionedTransaction,
209 ) -> Result<(), KoraError> {
210 let config = get_config()?;
211
212 if let Some(limiter) = Self::get_usage_limiter()? {
213 let sender = limiter.extract_transaction_sender(transaction)?;
214 if let Some(sender) = sender {
215 limiter.check_usage_limit(&sender).await?;
216 }
217 Ok(())
218 } else if config.kora.usage_limit.enabled
219 && !config.kora.usage_limit.fallback_if_unavailable
220 {
221 Err(KoraError::InternalServerError(
223 "Usage limiter unavailable and fallback disabled".to_string(),
224 ))
225 } else {
226 Ok(())
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        tests::{config_mock::ConfigMockBuilder, transaction_mock::create_mock_transaction},
        usage_limit::{usage_store::ErrorUsageStore, InMemoryUsageStore},
    };

    #[tokio::test]
    async fn test_get_usage_key_format() {
        // Call the real key builder instead of re-deriving the format inline, so a
        // change to `get_usage_key` is actually caught by this test.
        let tracker =
            UsageTracker::new(Arc::new(InMemoryUsageStore::new()), 2, HashSet::new(), true);
        let wallet = Pubkey::new_unique();

        assert_eq!(tracker.get_usage_key(&wallet), format!("kora:usage_limit:{wallet}"));
    }

    #[tokio::test]
    async fn test_usage_limit_enforcement() {
        let store = Arc::new(InMemoryUsageStore::new());
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, true);

        let wallet = Pubkey::new_unique();

        assert!(tracker.check_usage_limit(&wallet).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet).await.is_ok());

        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeded limit"));
    }

    #[tokio::test]
    async fn test_independent_wallet_limits() {
        let store = Arc::new(InMemoryUsageStore::new());
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, true);

        let wallet1 = Pubkey::new_unique();
        let wallet2 = Pubkey::new_unique();

        assert!(tracker.check_usage_limit(&wallet1).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet1).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet1).await.is_err());

        assert!(tracker.check_usage_limit(&wallet2).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet2).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet2).await.is_err());
    }

    #[tokio::test]
    async fn test_unlimited_usage() {
        let store = Arc::new(InMemoryUsageStore::new());
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 0, kora_signers, true);
        let wallet = Pubkey::new_unique();

        for _ in 0..10 {
            assert!(tracker.check_usage_limit(&wallet).await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_usage_limiter_disabled_fallback() {
        let _m = ConfigMockBuilder::new().with_usage_limit_enabled(false).build_and_setup();

        let _ = UsageTracker::init_usage_limiter().await;

        // Fail with the actual error text instead of only printing it.
        if let Err(e) =
            UsageTracker::check_transaction_usage_limit(&create_mock_transaction()).await
        {
            panic!("Test failed with error: {e}");
        }
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_allowed() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(true)
            .build_and_setup();

        let _ = UsageTracker::init_usage_limiter().await;

        let result = UsageTracker::check_transaction_usage_limit(&create_mock_transaction()).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_denied() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(false)
            .build_and_setup();

        let _ = UsageTracker::init_usage_limiter().await;

        let result = UsageTracker::check_transaction_usage_limit(&create_mock_transaction()).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limiter unavailable and fallback disabled"));
    }

    #[tokio::test]
    async fn test_usage_limit_store_get_error_fallback_enabled() {
        let store = Arc::new(ErrorUsageStore::new(true, false));
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, true);
        let wallet = Pubkey::new_unique();

        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limit_store_get_error_fallback_disabled() {
        let store = Arc::new(ErrorUsageStore::new(true, false));
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, false);
        let wallet = Pubkey::new_unique();

        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limit store unavailable and fallback disabled"));
    }

    #[tokio::test]
    async fn test_usage_limit_store_increment_error_fallback_enabled() {
        let store = Arc::new(ErrorUsageStore::new(false, true));
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, true);
        let wallet = Pubkey::new_unique();

        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limit_store_increment_error_fallback_disabled() {
        let store = Arc::new(ErrorUsageStore::new(false, true));
        let kora_signers = HashSet::new();
        let tracker = UsageTracker::new(store, 2, kora_signers, false);
        let wallet = Pubkey::new_unique();

        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limit store unavailable and fallback disabled"));
    }
}