kora_lib/usage_limit/usage_tracker.rs

1use std::{collections::HashSet, sync::Arc};
2
3use deadpool_redis::Runtime;
4use redis::AsyncCommands;
5use solana_sdk::{pubkey::Pubkey, transaction::VersionedTransaction};
6use tokio::sync::OnceCell;
7
8use super::usage_store::{RedisUsageStore, UsageStore};
9use crate::{config::Config, error::KoraError, sanitize_error, state::get_signer_pool};
10
11#[cfg(not(test))]
12use crate::state::get_config;
13
14#[cfg(test)]
15use crate::tests::config_mock::mock_state::get_config;
16
17const USAGE_CACHE_KEY: &str = "kora:usage_limit";
18
19/// Global usage limiter instance
20static USAGE_LIMITER: OnceCell<Option<UsageTracker>> = OnceCell::const_new();
21
22pub struct UsageTracker {
23    store: Arc<dyn UsageStore>,
24    max_transactions: u64,
25    kora_signers: HashSet<Pubkey>,
26    fallback_if_unavailable: bool,
27}
28
29impl UsageTracker {
30    pub fn new(
31        store: Arc<dyn UsageStore>,
32        max_transactions: u64,
33        kora_signers: HashSet<Pubkey>,
34        fallback_if_unavailable: bool,
35    ) -> Self {
36        Self { store, max_transactions, kora_signers, fallback_if_unavailable }
37    }
38
39    fn get_usage_key(&self, wallet: &Pubkey) -> String {
40        format!("{USAGE_CACHE_KEY}:{wallet}")
41    }
42
43    /// Handle store errors according to fallback configuration
44    fn handle_store_error(
45        &self,
46        error: KoraError,
47        operation: &str,
48        wallet: &Pubkey,
49    ) -> Result<(), KoraError> {
50        log::error!("Failed to {operation} for {wallet}: {error}");
51        if self.fallback_if_unavailable {
52            log::error!("Fallback enabled - allowing transaction due to store error");
53            Ok(()) // Allow transaction when fallback is enabled
54        } else {
55            Err(KoraError::InternalServerError(format!(
56                "Usage limit store unavailable and fallback disabled: {error}"
57            )))
58        }
59    }
60
61    async fn check_usage_limit(&self, wallet: &Pubkey) -> Result<(), KoraError> {
62        // Skip check if unlimited (0)
63        if self.max_transactions == 0 {
64            return Ok(());
65        }
66
67        // Atomically increment and then check against limit
68        let key = self.get_usage_key(wallet);
69
70        let new_count = match self.store.increment(&key).await {
71            Ok(count) => count,
72            Err(e) => {
73                return self.handle_store_error(e, "increment usage count", wallet);
74            }
75        };
76
77        if new_count > self.max_transactions as u32 {
78            return Err(KoraError::UsageLimitExceeded(format!(
79                "Wallet {wallet} exceeded limit: {}/{}",
80                new_count, self.max_transactions
81            )));
82        }
83
84        log::debug!("Usage check passed for {wallet}: {new_count}/{}", self.max_transactions);
85
86        Ok(())
87    }
88
89    fn get_usage_limiter() -> Result<Option<&'static UsageTracker>, KoraError> {
90        match USAGE_LIMITER.get() {
91            Some(limiter) => Ok(limiter.as_ref()),
92            None => {
93                Err(KoraError::InternalServerError("Usage limiter not initialized".to_string()))
94            }
95        }
96    }
97
98    /// Extract sender from transaction
99    fn extract_transaction_sender(
100        &self,
101        transaction: &VersionedTransaction,
102    ) -> Result<Option<Pubkey>, KoraError> {
103        let account_keys = transaction.message.static_account_keys();
104
105        if account_keys.is_empty() {
106            return Err(KoraError::InvalidTransaction(
107                "Transaction has no account keys".to_string(),
108            ));
109        }
110
111        let signers = account_keys
112            .iter()
113            .take(transaction.message.header().num_required_signatures as usize)
114            .collect::<Vec<_>>();
115
116        for signer in &signers {
117            if !self.kora_signers.contains(signer) {
118                return Ok(Some(**signer));
119            }
120        }
121
122        log::debug!(
123            "No user signers found when extracting transaction sender for usage limit: {signers:?}",
124        );
125
126        Ok(None)
127    }
128
129    /// Initialize the global usage limiter
130    pub async fn init_usage_limiter() -> Result<(), KoraError> {
131        let config = get_config()?;
132
133        if !config.kora.usage_limit.enabled {
134            log::info!("Usage limiting disabled");
135            USAGE_LIMITER.set(None).map_err(|_| {
136                KoraError::InternalServerError("Usage limiter already initialized".to_string())
137            })?;
138            return Ok(());
139        }
140
141        let usage_limiter = if let Some(cache_url) = &config.kora.usage_limit.cache_url {
142            let cfg = deadpool_redis::Config::from_url(cache_url);
143            let pool = cfg.create_pool(Some(Runtime::Tokio1)).map_err(|e| {
144                KoraError::InternalServerError(format!(
145                    "Failed to create Redis pool: {}",
146                    sanitize_error!(e)
147                ))
148            })?;
149
150            // Test Redis connection
151            let mut conn = pool.get().await.map_err(|e| {
152                KoraError::InternalServerError(format!(
153                    "Failed to connect to Redis: {}",
154                    sanitize_error!(e)
155                ))
156            })?;
157
158            // Simple connection test
159            let _: Option<String> = conn.get("__usage_limiter_test__").await.map_err(|e| {
160                KoraError::InternalServerError(format!(
161                    "Redis connection test failed: {}",
162                    sanitize_error!(e)
163                ))
164            })?;
165
166            log::info!(
167                "Usage limiter initialized with max {} transactions",
168                config.kora.usage_limit.max_transactions
169            );
170
171            let kora_signers = get_signer_pool()?
172                .get_signers_info()
173                .iter()
174                .filter_map(|info| info.public_key.parse().ok())
175                .collect();
176
177            let store = Arc::new(RedisUsageStore::new(pool));
178            Some(UsageTracker::new(
179                store,
180                config.kora.usage_limit.max_transactions,
181                kora_signers,
182                config.kora.usage_limit.fallback_if_unavailable,
183            ))
184        } else {
185            log::info!("Usage limiting enabled but no cache_url configured - disabled");
186            None
187        };
188
189        USAGE_LIMITER.set(usage_limiter).map_err(|_| {
190            KoraError::InternalServerError("Usage limiter already initialized".to_string())
191        })?;
192
193        Ok(())
194    }
195
196    /// Check usage limit for transaction sender
197    pub async fn check_transaction_usage_limit(
198        config: &Config,
199        transaction: &VersionedTransaction,
200    ) -> Result<(), KoraError> {
201        if let Some(limiter) = Self::get_usage_limiter()? {
202            let sender = limiter.extract_transaction_sender(transaction)?;
203            if let Some(sender) = sender {
204                limiter.check_usage_limit(&sender).await?;
205            }
206            Ok(())
207        } else if config.kora.usage_limit.enabled
208            && !config.kora.usage_limit.fallback_if_unavailable
209        {
210            // Usage limiting enabled but limiter unavailable and fallback disabled
211            Err(KoraError::InternalServerError(
212                "Usage limiter unavailable and fallback disabled".to_string(),
213            ))
214        } else {
215            // Usage limiting disabled or fallback allowed
216            Ok(())
217        }
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        tests::{config_mock::ConfigMockBuilder, transaction_mock::create_mock_transaction},
        usage_limit::{usage_store::ErrorUsageStore, InMemoryUsageStore},
    };

    // NOTE: `USAGE_LIMITER` is a process-wide `OnceCell`. Only the first successful
    // `init_usage_limiter` call in the test binary stores a value; later calls return an
    // "already initialized" error, which the tests below deliberately ignore (`let _ =`).
    // Their assertions assume the stored limiter is `None` (every setup here produces `None`)
    // — keep that in mind if a test that initializes a real limiter is ever added.

    /// Builds a tracker over a fresh in-memory store with no Kora signers and fallback enabled.
    fn in_memory_tracker(max_transactions: u64) -> UsageTracker {
        UsageTracker::new(
            Arc::new(InMemoryUsageStore::new()),
            max_transactions,
            HashSet::new(),
            true,
        )
    }

    #[tokio::test]
    async fn test_get_usage_key_format() {
        let tracker = in_memory_tracker(1);
        let wallet = Pubkey::new_unique();

        // Exercise the real method and pin the literal prefix, so a change to either the
        // format string or `USAGE_CACHE_KEY` is caught.
        assert_eq!(tracker.get_usage_key(&wallet), format!("kora:usage_limit:{wallet}"));
    }

    #[tokio::test]
    async fn test_usage_limit_enforcement() {
        let tracker = in_memory_tracker(2);
        let wallet = Pubkey::new_unique();

        // First transaction should succeed
        assert!(tracker.check_usage_limit(&wallet).await.is_ok());

        // Second transaction should succeed (at limit)
        assert!(tracker.check_usage_limit(&wallet).await.is_ok());

        // Third transaction should fail (over limit)
        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeded limit"));
    }

    #[tokio::test]
    async fn test_independent_wallet_limits() {
        let tracker = in_memory_tracker(2);

        let wallet1 = Pubkey::new_unique();
        let wallet2 = Pubkey::new_unique();

        // Use up wallet1's limit
        assert!(tracker.check_usage_limit(&wallet1).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet1).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet1).await.is_err());

        // Wallet2 should still be able to make transactions
        assert!(tracker.check_usage_limit(&wallet2).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet2).await.is_ok());
        assert!(tracker.check_usage_limit(&wallet2).await.is_err());
    }

    #[tokio::test]
    async fn test_unlimited_usage() {
        let tracker = in_memory_tracker(0); // 0 = unlimited
        let wallet = Pubkey::new_unique();

        // Should allow many transactions when unlimited
        for _ in 0..10 {
            assert!(tracker.check_usage_limit(&wallet).await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_usage_limiter_disabled_fallback() {
        // Test that when usage limiting is disabled, transactions are allowed
        let _m = ConfigMockBuilder::new().with_usage_limit_enabled(false).build_and_setup();

        // Initialize the usage limiter - it should set to None when disabled
        let _ = UsageTracker::init_usage_limiter().await;

        let config = get_config().unwrap();
        let result =
            UsageTracker::check_transaction_usage_limit(&config, &create_mock_transaction()).await;
        assert!(
            result.is_ok(),
            "Test failed with error: {}",
            result.as_ref().unwrap_err()
        );
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_allowed() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(true)
            .build_and_setup();

        // Initialize with no cache_url - should set limiter to None
        let _ = UsageTracker::init_usage_limiter().await;

        let config = get_config().unwrap();
        let result =
            UsageTracker::check_transaction_usage_limit(&config, &create_mock_transaction()).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limiter_fallback_denied() {
        let _m = ConfigMockBuilder::new()
            .with_usage_limit_enabled(true)
            .with_usage_limit_cache_url(None)
            .with_usage_limit_fallback(false)
            .build_and_setup();

        // Initialize with no cache_url - should set limiter to None
        let _ = UsageTracker::init_usage_limiter().await;

        let config = get_config().unwrap();
        let result =
            UsageTracker::check_transaction_usage_limit(&config, &create_mock_transaction()).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limiter unavailable and fallback disabled"));
    }

    #[tokio::test]
    async fn test_usage_limit_store_increment_error_fallback_enabled() {
        let store = Arc::new(ErrorUsageStore::new(false, true)); // increment() will error
        let tracker = UsageTracker::new(store, 2, HashSet::new(), true); // fallback enabled

        let wallet = Pubkey::new_unique();

        // Should succeed: increment() fails, but the enabled fallback allows the transaction
        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_usage_limit_store_increment_error_fallback_disabled() {
        let store = Arc::new(ErrorUsageStore::new(false, true)); // increment() will error
        let tracker = UsageTracker::new(store, 2, HashSet::new(), false); // fallback disabled

        let wallet = Pubkey::new_unique();

        // Should fail because fallback is disabled
        let result = tracker.check_usage_limit(&wallet).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Usage limit store unavailable and fallback disabled"));
    }
}