1use soroban_sdk::{Address, BytesN, Env, Symbol, Vec};
2
3use super::error::AccountError;
4use super::types::SessionKey;
5
/// Symbol text used as the first element of the storage key under which an
/// account's list of session keys is saved.
const SESSION_KEYS_PREFIX: &str = "sess_keys";
8
/// Data-less namespace type whose associated functions read and write
/// per-account session keys in contract storage.
pub struct SessionStorage;
40
41impl SessionStorage {
42 pub fn store(env: &Env, account: &Address, key: &SessionKey) {
46 let keys = Self::load_all(env, account);
47 let mut new_keys: Vec<SessionKey> = Vec::new(env);
49 for i in 0..keys.len() {
50 if let Some(k) = keys.get(i) {
51 if k.key_id != key.key_id {
52 new_keys.push_back(k);
53 }
54 }
55 }
56 new_keys.push_back(key.clone());
57 let storage_key = Self::storage_key(env, account);
58 env.storage().persistent().set(&storage_key, &new_keys);
59 }
60
61 pub fn load(env: &Env, account: &Address, key_id: &BytesN<32>) -> Option<SessionKey> {
63 let keys = Self::load_all(env, account);
64 for i in 0..keys.len() {
65 if let Some(k) = keys.get(i) {
66 if &k.key_id == key_id {
67 return Some(k);
68 }
69 }
70 }
71 None
72 }
73
74 pub fn load_all(env: &Env, account: &Address) -> Vec<SessionKey> {
76 let storage_key = Self::storage_key(env, account);
77 env.storage()
78 .persistent()
79 .get(&storage_key)
80 .unwrap_or_else(|| Vec::new(env))
81 }
82
83 pub fn remove(env: &Env, account: &Address, key_id: &BytesN<32>) -> bool {
85 let keys = Self::load_all(env, account);
86 let mut new_keys: Vec<SessionKey> = Vec::new(env);
87 let mut found = false;
88
89 for i in 0..keys.len() {
90 if let Some(k) = keys.get(i) {
91 if &k.key_id == key_id {
92 found = true;
93 } else {
94 new_keys.push_back(k);
95 }
96 }
97 }
98
99 if found {
100 let storage_key = Self::storage_key(env, account);
101 if new_keys.is_empty() {
102 env.storage().persistent().remove(&storage_key);
103 } else {
104 env.storage().persistent().set(&storage_key, &new_keys);
105 }
106 }
107 found
108 }
109
110 pub fn increment_usage(
112 env: &Env,
113 account: &Address,
114 key_id: &BytesN<32>,
115 ) -> Result<(), AccountError> {
116 let keys = Self::load_all(env, account);
117 let mut new_keys: Vec<SessionKey> = Vec::new(env);
118 let mut found = false;
119
120 for i in 0..keys.len() {
121 if let Some(mut k) = keys.get(i) {
122 if &k.key_id == key_id {
123 found = true;
124 k.operations_used += 1;
125 new_keys.push_back(k);
126 } else {
127 new_keys.push_back(k);
128 }
129 }
130 }
131
132 if !found {
133 return Err(AccountError::InvalidScope);
134 }
135
136 let storage_key = Self::storage_key(env, account);
137 env.storage().persistent().set(&storage_key, &new_keys);
138 Ok(())
139 }
140
141 pub fn consume_authorized_session(
143 env: &Env,
144 account: &Address,
145 key_id: &BytesN<32>,
146 ) -> Result<SessionKey, AccountError> {
147 let keys = Self::load_all(env, account);
148 let mut new_keys: Vec<SessionKey> = Vec::new(env);
149 let mut updated: Option<SessionKey> = None;
150
151 for i in 0..keys.len() {
152 if let Some(mut k) = keys.get(i) {
153 if &k.key_id == key_id {
154 if env.ledger().timestamp() >= k.scope.expires_at {
155 return Err(AccountError::SessionExpired);
156 }
157 if k.operations_used >= k.scope.max_operations {
158 return Err(AccountError::SessionBudgetExceeded);
159 }
160 k.operations_used += 1;
161 k.next_nonce += 1;
162 updated = Some(k.clone());
163 }
164 new_keys.push_back(k);
165 }
166 }
167
168 let updated = updated.ok_or(AccountError::SessionRevoked)?;
169 let storage_key = Self::storage_key(env, account);
170 env.storage().persistent().set(&storage_key, &new_keys);
171 Ok(updated)
172 }
173
174 pub fn cleanup_expired(env: &Env, account: &Address) -> u32 {
177 let keys = Self::load_all(env, account);
178 let now = env.ledger().timestamp();
179 let mut new_keys: Vec<SessionKey> = Vec::new(env);
180 let mut removed: u32 = 0;
181
182 for i in 0..keys.len() {
183 if let Some(k) = keys.get(i) {
184 if now >= k.scope.expires_at {
185 removed += 1;
186 } else {
187 new_keys.push_back(k);
188 }
189 }
190 }
191
192 if removed > 0 {
193 let storage_key = Self::storage_key(env, account);
194 if new_keys.is_empty() {
195 env.storage().persistent().remove(&storage_key);
196 } else {
197 env.storage().persistent().set(&storage_key, &new_keys);
198 }
199 }
200
201 removed
202 }
203
204 fn storage_key(env: &Env, account: &Address) -> (Symbol, Address) {
206 (Symbol::new(env, SESSION_KEYS_PREFIX), account.clone())
207 }
208
209 pub(crate) fn is_action_allowed(
210 scope: &crate::accounts::types::SessionScope,
211 action: &Symbol,
212 ) -> bool {
213 if scope.allowed_actions.is_empty() {
214 return true;
215 }
216 for i in 0..scope.allowed_actions.len() {
217 if scope.allowed_actions.get(i).unwrap() == *action {
218 return true;
219 }
220 }
221 false
222 }
223}
224
#[cfg(test)]
mod tests {
    //! Unit tests for [`SessionStorage`].
    //!
    //! Storage access needs a contract context, so every test registers the
    //! empty `TestContract` and runs its storage calls inside `as_contract`.
    //! The tests never change the ledger timestamp of `Env::default()`; a key
    //! with `expires_at == 0` is therefore treated as expired and one with
    //! `expires_at == 99999` as still valid.

    use super::*;
    use soroban_sdk::{contract, contractimpl, symbol_short, testutils::Address as _, vec, Env};

    use crate::accounts::types::SessionScope;

    /// Empty contract that only provides a storage context for the tests.
    #[contract]
    pub struct TestContract;

    #[contractimpl]
    impl TestContract {}

    /// Builds an unused session key whose id is `id_byte` repeated 32 times,
    /// limited to the `move` action and a budget of 100 operations.
    fn make_session_key(env: &Env, id_byte: u8, expires_at: u64) -> SessionKey {
        SessionKey {
            key_id: BytesN::from_array(env, &[id_byte; 32]),
            scope: SessionScope {
                allowed_actions: vec![env, symbol_short!("move")],
                max_operations: 100,
                expires_at,
            },
            created_at: 0,
            operations_used: 0,
            next_nonce: 0,
        }
    }

    #[test]
    fn test_store_and_load() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let key = make_session_key(&env, 1, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);
            let loaded = SessionStorage::load(&env, &addr, &key.key_id);
            assert!(loaded.is_some());
            assert_eq!(loaded.unwrap().operations_used, 0);
        });
    }

    #[test]
    fn test_load_nonexistent() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let fake_id = BytesN::from_array(&env, &[99u8; 32]);

        env.as_contract(&contract_id, || {
            assert!(SessionStorage::load(&env, &addr, &fake_id).is_none());
        });
    }

    #[test]
    fn test_load_all() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &make_session_key(&env, 1, 99999));
            SessionStorage::store(&env, &addr, &make_session_key(&env, 2, 99999));

            let all = SessionStorage::load_all(&env, &addr);
            assert_eq!(all.len(), 2);
        });
    }

    #[test]
    fn test_remove() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let key = make_session_key(&env, 1, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);
            assert!(SessionStorage::remove(&env, &addr, &key.key_id));
            assert!(SessionStorage::load(&env, &addr, &key.key_id).is_none());
        });
    }

    #[test]
    fn test_remove_keeps_other_keys() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let first = make_session_key(&env, 1, 99999);
        let second = make_session_key(&env, 2, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &first);
            SessionStorage::store(&env, &addr, &second);

            assert!(SessionStorage::remove(&env, &addr, &first.key_id));

            assert_eq!(SessionStorage::load_all(&env, &addr).len(), 1);
            assert!(SessionStorage::load(&env, &addr, &second.key_id).is_some());
        });
    }

    #[test]
    fn test_remove_nonexistent() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let fake_id = BytesN::from_array(&env, &[99u8; 32]);

        env.as_contract(&contract_id, || {
            assert!(!SessionStorage::remove(&env, &addr, &fake_id));
        });
    }

    #[test]
    fn test_increment_usage() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let key = make_session_key(&env, 1, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);
            SessionStorage::increment_usage(&env, &addr, &key.key_id).unwrap();

            let loaded = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(loaded.operations_used, 1);
        });
    }

    #[test]
    fn test_increment_usage_nonexistent() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let fake_id = BytesN::from_array(&env, &[99u8; 32]);

        env.as_contract(&contract_id, || {
            let result = SessionStorage::increment_usage(&env, &addr, &fake_id);
            assert_eq!(result, Err(AccountError::InvalidScope));
        });
    }

    #[test]
    fn test_consume_authorized_session_updates_counters() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let key = make_session_key(&env, 1, 99999);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);

            let consumed =
                SessionStorage::consume_authorized_session(&env, &addr, &key.key_id).unwrap();
            assert_eq!(consumed.operations_used, 1);
            assert_eq!(consumed.next_nonce, 1);

            // The returned key must match what was written back to storage.
            let stored = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(stored.operations_used, 1);
            assert_eq!(stored.next_nonce, 1);
        });
    }

    #[test]
    fn test_consume_authorized_session_expired() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let key = make_session_key(&env, 1, 0);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);

            let result = SessionStorage::consume_authorized_session(&env, &addr, &key.key_id);
            assert_eq!(result.err(), Some(AccountError::SessionExpired));

            // A rejected call must not record a use.
            let stored = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(stored.operations_used, 0);
            assert_eq!(stored.next_nonce, 0);
        });
    }

    #[test]
    fn test_consume_authorized_session_budget_exceeded() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let mut key = make_session_key(&env, 1, 99999);
        key.operations_used = key.scope.max_operations;

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &key);

            let result = SessionStorage::consume_authorized_session(&env, &addr, &key.key_id);
            assert_eq!(result.err(), Some(AccountError::SessionBudgetExceeded));

            let stored = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(stored.operations_used, key.scope.max_operations);
        });
    }

    #[test]
    fn test_consume_authorized_session_nonexistent() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);
        let fake_id = BytesN::from_array(&env, &[99u8; 32]);

        env.as_contract(&contract_id, || {
            let result = SessionStorage::consume_authorized_session(&env, &addr, &fake_id);
            assert_eq!(result.err(), Some(AccountError::SessionRevoked));
        });
    }

    #[test]
    fn test_cleanup_expired() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &make_session_key(&env, 1, 0));
            SessionStorage::store(&env, &addr, &make_session_key(&env, 2, 99999));

            let removed = SessionStorage::cleanup_expired(&env, &addr);
            assert_eq!(removed, 1);

            let remaining = SessionStorage::load_all(&env, &addr);
            assert_eq!(remaining.len(), 1);
        });
    }

    #[test]
    fn test_cleanup_expired_removes_last_key() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);

        env.as_contract(&contract_id, || {
            SessionStorage::store(&env, &addr, &make_session_key(&env, 1, 0));

            assert_eq!(SessionStorage::cleanup_expired(&env, &addr), 1);
            assert!(SessionStorage::load_all(&env, &addr).is_empty());

            // Nothing is left to clean up on a second pass.
            assert_eq!(SessionStorage::cleanup_expired(&env, &addr), 0);
        });
    }

    #[test]
    fn test_store_overwrites_existing() {
        let env = Env::default();
        let contract_id = env.register(TestContract, ());
        let addr = Address::generate(&env);

        env.as_contract(&contract_id, || {
            let mut key = make_session_key(&env, 1, 99999);
            SessionStorage::store(&env, &addr, &key);

            key.operations_used = 42;
            SessionStorage::store(&env, &addr, &key);

            let all = SessionStorage::load_all(&env, &addr);
            assert_eq!(all.len(), 1);

            let loaded = SessionStorage::load(&env, &addr, &key.key_id).unwrap();
            assert_eq!(loaded.operations_used, 42);
        });
    }

    #[test]
    fn test_is_action_allowed_restricted_scope() {
        let env = Env::default();
        let scope = make_session_key(&env, 1, 99999).scope;

        assert!(SessionStorage::is_action_allowed(
            &scope,
            &symbol_short!("move")
        ));
        assert!(!SessionStorage::is_action_allowed(
            &scope,
            &symbol_short!("attack")
        ));
    }

    #[test]
    fn test_is_action_allowed_empty_scope_allows_everything() {
        let env = Env::default();
        let mut scope = make_session_key(&env, 1, 99999).scope;
        scope.allowed_actions = Vec::new(&env);

        assert!(SessionStorage::is_action_allowed(
            &scope,
            &symbol_short!("attack")
        ));
    }
}