Skip to main content

cougr_core/accounts/
storage.rs

1use soroban_sdk::{Address, BytesN, Env, Symbol, Vec};
2
3use super::error::AccountError;
4use super::types::SessionKey;
5
/// Text of the symbol that prefixes every session-key storage entry.
///
/// `SessionStorage::storage_key` converts this string into a `Symbol` and
/// pairs it with the account address to form the persistent storage key.
const SESSION_KEYS_PREFIX: &str = "sess_keys";
8
/// Persistent session key storage using Soroban contract storage.
///
/// Session keys are stored in the contract's persistent storage, keyed
/// by account address. This allows session keys to survive across
/// contract invocations (unlike `ContractAccount`'s in-memory storage).
///
/// # Storage layout
/// All session keys of one account are kept together as a single
/// `Vec<SessionKey>` under one persistent entry whose key is the pair
/// `(Symbol("sess_keys"), account)`. Every mutating method loads that list,
/// builds the updated list and writes it back as a whole; when the last key
/// of an account is removed the entry itself is deleted.
///
/// # Example
/// ```no_run
/// use cougr_core::accounts::{SessionKey, SessionScope, SessionStorage};
/// use soroban_sdk::{symbol_short, testutils::Address as _, Address, BytesN, Env, Vec};
///
/// let env = Env::default();
/// let player_address = Address::generate(&env);
/// let key_id = BytesN::from_array(&env, &[7u8; 32]);
/// let session_key = SessionKey {
///     key_id: key_id.clone(),
///     scope: SessionScope {
///         allowed_actions: Vec::from_array(&env, [symbol_short!("move")]),
///         max_operations: 10,
///         expires_at: 1_000,
///     },
///     created_at: 0,
///     operations_used: 0,
///     next_nonce: 0,
/// };
///
/// SessionStorage::store(&env, &player_address, &session_key);
/// let loaded = SessionStorage::load(&env, &player_address, &key_id);
/// assert!(loaded.is_some());
/// ```
pub struct SessionStorage;
40
41impl SessionStorage {
42    /// Store a session key for an account.
43    ///
44    /// If a key with the same `key_id` already exists, it is overwritten.
45    pub fn store(env: &Env, account: &Address, key: &SessionKey) {
46        let keys = Self::load_all(env, account);
47        // Remove existing key with same ID if present
48        let mut new_keys: Vec<SessionKey> = Vec::new(env);
49        for i in 0..keys.len() {
50            if let Some(k) = keys.get(i) {
51                if k.key_id != key.key_id {
52                    new_keys.push_back(k);
53                }
54            }
55        }
56        new_keys.push_back(key.clone());
57        let storage_key = Self::storage_key(env, account);
58        env.storage().persistent().set(&storage_key, &new_keys);
59    }
60
61    /// Load a specific session key by ID.
62    pub fn load(env: &Env, account: &Address, key_id: &BytesN<32>) -> Option<SessionKey> {
63        let keys = Self::load_all(env, account);
64        for i in 0..keys.len() {
65            if let Some(k) = keys.get(i) {
66                if &k.key_id == key_id {
67                    return Some(k);
68                }
69            }
70        }
71        None
72    }
73
74    /// Load all session keys for an account.
75    pub fn load_all(env: &Env, account: &Address) -> Vec<SessionKey> {
76        let storage_key = Self::storage_key(env, account);
77        env.storage()
78            .persistent()
79            .get(&storage_key)
80            .unwrap_or_else(|| Vec::new(env))
81    }
82
83    /// Remove a session key by ID. Returns true if found and removed.
84    pub fn remove(env: &Env, account: &Address, key_id: &BytesN<32>) -> bool {
85        let keys = Self::load_all(env, account);
86        let mut new_keys: Vec<SessionKey> = Vec::new(env);
87        let mut found = false;
88
89        for i in 0..keys.len() {
90            if let Some(k) = keys.get(i) {
91                if &k.key_id == key_id {
92                    found = true;
93                } else {
94                    new_keys.push_back(k);
95                }
96            }
97        }
98
99        if found {
100            let storage_key = Self::storage_key(env, account);
101            if new_keys.is_empty() {
102                env.storage().persistent().remove(&storage_key);
103            } else {
104                env.storage().persistent().set(&storage_key, &new_keys);
105            }
106        }
107        found
108    }
109
110    /// Increment the usage count of a session key.
111    pub fn increment_usage(
112        env: &Env,
113        account: &Address,
114        key_id: &BytesN<32>,
115    ) -> Result<(), AccountError> {
116        let keys = Self::load_all(env, account);
117        let mut new_keys: Vec<SessionKey> = Vec::new(env);
118        let mut found = false;
119
120        for i in 0..keys.len() {
121            if let Some(mut k) = keys.get(i) {
122                if &k.key_id == key_id {
123                    found = true;
124                    k.operations_used += 1;
125                    new_keys.push_back(k);
126                } else {
127                    new_keys.push_back(k);
128                }
129            }
130        }
131
132        if !found {
133            return Err(AccountError::InvalidScope);
134        }
135
136        let storage_key = Self::storage_key(env, account);
137        env.storage().persistent().set(&storage_key, &new_keys);
138        Ok(())
139    }
140
141    /// Consume one authorized session action and advance its nonce.
142    pub fn consume_authorized_session(
143        env: &Env,
144        account: &Address,
145        key_id: &BytesN<32>,
146    ) -> Result<SessionKey, AccountError> {
147        let keys = Self::load_all(env, account);
148        let mut new_keys: Vec<SessionKey> = Vec::new(env);
149        let mut updated: Option<SessionKey> = None;
150
151        for i in 0..keys.len() {
152            if let Some(mut k) = keys.get(i) {
153                if &k.key_id == key_id {
154                    if env.ledger().timestamp() >= k.scope.expires_at {
155                        return Err(AccountError::SessionExpired);
156                    }
157                    if k.operations_used >= k.scope.max_operations {
158                        return Err(AccountError::SessionBudgetExceeded);
159                    }
160                    k.operations_used += 1;
161                    k.next_nonce += 1;
162                    updated = Some(k.clone());
163                }
164                new_keys.push_back(k);
165            }
166        }
167
168        let updated = updated.ok_or(AccountError::SessionRevoked)?;
169        let storage_key = Self::storage_key(env, account);
170        env.storage().persistent().set(&storage_key, &new_keys);
171        Ok(updated)
172    }
173
174    /// Remove all expired session keys for an account.
175    /// Returns the number of keys removed.
176    pub fn cleanup_expired(env: &Env, account: &Address) -> u32 {
177        let keys = Self::load_all(env, account);
178        let now = env.ledger().timestamp();
179        let mut new_keys: Vec<SessionKey> = Vec::new(env);
180        let mut removed: u32 = 0;
181
182        for i in 0..keys.len() {
183            if let Some(k) = keys.get(i) {
184                if now >= k.scope.expires_at {
185                    removed += 1;
186                } else {
187                    new_keys.push_back(k);
188                }
189            }
190        }
191
192        if removed > 0 {
193            let storage_key = Self::storage_key(env, account);
194            if new_keys.is_empty() {
195                env.storage().persistent().remove(&storage_key);
196            } else {
197                env.storage().persistent().set(&storage_key, &new_keys);
198            }
199        }
200
201        removed
202    }
203
204    /// Build the storage key for an account's session keys.
205    fn storage_key(env: &Env, account: &Address) -> (Symbol, Address) {
206        (Symbol::new(env, SESSION_KEYS_PREFIX), account.clone())
207    }
208
209    pub(crate) fn is_action_allowed(
210        scope: &crate::accounts::types::SessionScope,
211        action: &Symbol,
212    ) -> bool {
213        if scope.allowed_actions.is_empty() {
214            return true;
215        }
216        for i in 0..scope.allowed_actions.len() {
217            if scope.allowed_actions.get(i).unwrap() == *action {
218                return true;
219            }
220        }
221        false
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;
    use soroban_sdk::{contract, contractimpl, symbol_short, testutils::Address as _, vec, Env};

    use crate::accounts::types::SessionScope;

    // Dummy contract to provide a contract context for storage tests
    #[contract]
    pub struct TestContract;

    #[contractimpl]
    impl TestContract {}

    /// Build a fresh test environment.
    ///
    /// Returns the environment, the address of the registered dummy contract
    /// (needed for `env.as_contract`) and a generated account address.
    fn setup() -> (Env, Address, Address) {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        (env, contract_id, addr)
    }

    /// Build a session key whose ID is `id_byte` repeated 32 times, limited
    /// to the `move` action with a budget of 100 operations.
    fn make_session_key(env: &Env, id_byte: u8, expires_at: u64) -> SessionKey {
        SessionKey {
            key_id: BytesN::from_array(env, &[id_byte; 32]),
            scope: SessionScope {
                allowed_actions: vec![env, symbol_short!("move")],
                max_operations: 100,
                expires_at,
            },
            created_at: 0,
            operations_used: 0,
            next_nonce: 0,
        }
    }

    #[test]
    fn test_store_and_load() {
        let (env, contract_id, addr) = setup();
        let key = make_session_key(&env, 1, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);
            let loaded = SessionStorage::load(&env, &addr, &key.key_id);
            assert!(loaded.is_some());
            assert_eq!(loaded.unwrap().operations_used, 0);
        });
    }

    #[test]
    fn test_load_nonexistent() {
        let (env, contract_id, addr) = setup();
        let fake_id = BytesN::from_array(&env, &[99u8; 32]);

        env.as_contract(&contract_id, || {
            assert!(SessionStorage::load(&env, &addr, &fake_id).is_none());
        });
    }

    #[test]
    fn test_load_all() {
        let (env, contract_id, addr) = setup();

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &make_session_key(&env, 1, 99999));
            SessionStorage::store(&env, &addr, &make_session_key(&env, 2, 99999));

            let all = SessionStorage::load_all(&env, &addr);
            assert_eq!(all.len(), 2);
        });
    }

    #[test]
    fn test_remove() {
        let (env, contract_id, addr) = setup();
        let key = make_session_key(&env, 1, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);
            assert!(SessionStorage::remove(&env, &addr, &key.key_id));
            assert!(SessionStorage::load(&env, &addr, &key.key_id).is_none());
        });
    }

    #[test]
    fn test_remove_nonexistent() {
        let (env, contract_id, addr) = setup();
        let fake_id = BytesN::from_array(&env, &[99u8; 32]);

        env.as_contract(&contract_id, || {
            assert!(!SessionStorage::remove(&env, &addr, &fake_id));
        });
    }

    #[test]
    fn test_increment_usage() {
        let (env, contract_id, addr) = setup();
        let key = make_session_key(&env, 1, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);
            SessionStorage::increment_usage(&env, &addr, &key.key_id).unwrap();

            let loaded = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(loaded.operations_used, 1);
        });
    }

    #[test]
    fn test_increment_usage_nonexistent() {
        let (env, contract_id, addr) = setup();
        let fake_id = BytesN::from_array(&env, &[99u8; 32]);

        env.as_contract(&contract_id, || {
            let result = SessionStorage::increment_usage(&env, &addr, &fake_id);
            assert_eq!(result, Err(AccountError::InvalidScope));
        });
    }

    #[test]
    fn test_consume_authorized_session() {
        let (env, contract_id, addr) = setup();
        let key = make_session_key(&env, 1, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);

            let updated = SessionStorage::consume_authorized_session(&env, &addr, &key.key_id)
                .expect("an unexpired key with remaining budget must be consumable");
            assert_eq!(updated.operations_used, 1);
            assert_eq!(updated.next_nonce, 1);

            // The returned copy must match what was persisted.
            let stored = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(stored.operations_used, 1);
            assert_eq!(stored.next_nonce, 1);
        });
    }

    #[test]
    fn test_consume_authorized_session_expired() {
        let (env, contract_id, addr) = setup();
        // expires_at = 0 is already reached at the default ledger timestamp 0
        let key = make_session_key(&env, 1, 0);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);

            let result = SessionStorage::consume_authorized_session(&env, &addr, &key.key_id);
            assert_eq!(result.err(), Some(AccountError::SessionExpired));

            // A rejected consumption must leave the stored key untouched.
            let stored = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(stored.operations_used, 0);
            assert_eq!(stored.next_nonce, 0);
        });
    }

    #[test]
    fn test_consume_authorized_session_budget_exceeded() {
        let (env, contract_id, addr) = setup();
        let mut key = make_session_key(&env, 1, 99999);
        key.operations_used = key.scope.max_operations;

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);

            let result = SessionStorage::consume_authorized_session(&env, &addr, &key.key_id);
            assert_eq!(result.err(), Some(AccountError::SessionBudgetExceeded));

            let stored = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(stored.next_nonce, 0);
        });
    }

    #[test]
    fn test_consume_authorized_session_nonexistent() {
        let (env, contract_id, addr) = setup();
        let fake_id = BytesN::from_array(&env, &[99u8; 32]);

        env.as_contract(&contract_id, || {
            let result = SessionStorage::consume_authorized_session(&env, &addr, &fake_id);
            assert_eq!(result.err(), Some(AccountError::SessionRevoked));
        });
    }

    #[test]
    fn test_cleanup_expired() {
        let (env, contract_id, addr) = setup();

        env.as_contract(&contract_id, || {
            // Key with expires_at = 0 (already expired at ledger timestamp 0)
            SessionStorage::store(&env, &addr, &make_session_key(&env, 1, 0));
            // Key with expires_at = 99999 (not expired)
            SessionStorage::store(&env, &addr, &make_session_key(&env, 2, 99999));

            let removed = SessionStorage::cleanup_expired(&env, &addr);
            assert_eq!(removed, 1);

            let remaining = SessionStorage::load_all(&env, &addr);
            assert_eq!(remaining.len(), 1);
        });
    }

    #[test]
    fn test_store_overwrites_existing() {
        let (env, contract_id, addr) = setup();

        env.as_contract(&contract_id, || {
            let mut key = make_session_key(&env, 1, 99999);
            SessionStorage::store(&env, &addr, &key);

            // Store again with same key_id but different operations
            key.operations_used = 42;
            SessionStorage::store(&env, &addr, &key);

            let all = SessionStorage::load_all(&env, &addr);
            assert_eq!(all.len(), 1); // should not duplicate

            let loaded = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(loaded.operations_used, 42);
        });
    }

    #[test]
    fn test_is_action_allowed() {
        let env = Env::default();

        // Scope restricted to the single action "move"
        let restricted = make_session_key(&env, 1, 99999).scope;
        assert!(SessionStorage::is_action_allowed(
            &restricted,
            &symbol_short!("move")
        ));
        assert!(!SessionStorage::is_action_allowed(
            &restricted,
            &symbol_short!("attack")
        ));

        // An empty allow-list accepts any action
        let mut unrestricted = make_session_key(&env, 2, 99999).scope;
        unrestricted.allowed_actions = Vec::new(&env);
        assert!(SessionStorage::is_action_allowed(
            &unrestricted,
            &symbol_short!("jump")
        ));
    }
}