1use std::collections::HashMap;
15use std::sync::LazyLock;
16use std::time::{Duration, Instant};
17
18use regex::Regex;
19use serde_json::Value;
20
21static KEY_RE: LazyLock<Regex> =
23 LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_\-.:]{1,256}$").expect("static regex is valid"));
24
/// Limits enforced by a session stash.
#[derive(Debug, Clone)]
pub struct StashConfig {
    /// Maximum number of entries the stash may hold at once.
    pub max_keys: usize,
    /// Maximum serialised size of a single value, in bytes.
    pub max_value_size: usize,
    /// Maximum combined serialised size of all stored values, in bytes.
    pub max_total_size: usize,
    /// TTL applied when a write does not specify one.
    pub default_ttl: Duration,
    /// Largest TTL a write may request.
    pub max_ttl: Duration,
}

impl Default for StashConfig {
    /// 256 keys, 16 MiB per value, 128 MiB in total, 1 h default TTL, 24 h maximum TTL.
    fn default() -> Self {
        const MIB: usize = 1024 * 1024;
        const HOUR_SECS: u64 = 60 * 60;

        StashConfig {
            max_keys: 256,
            max_value_size: 16 * MIB,
            max_total_size: 128 * MIB,
            default_ttl: Duration::from_secs(HOUR_SECS),
            max_ttl: Duration::from_secs(24 * HOUR_SECS),
        }
    }
}
51
52struct StashEntry {
54 value: Value,
55 size_bytes: usize,
56 created_at: Instant,
57 ttl: Duration,
58 source_group: Option<String>,
59}
60
61impl StashEntry {
62 fn is_expired(&self) -> bool {
63 self.created_at.elapsed() >= self.ttl
64 }
65}
66
/// Key/value store of JSON values with per-entry TTLs, size quotas and group-based access
/// checks.
pub struct SessionStash {
    // Stored entries keyed by stash key. Expired entries stay here until removed
    // (e.g. by `reap_expired`); readers filter them out.
    entries: HashMap<String, StashEntry>,
    // Running sum of `size_bytes` over every entry in `entries`.
    total_size: usize,
    // Limits this stash enforces.
    config: StashConfig,
}
73
/// Errors returned by stash operations.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum StashError {
    /// A new key could not be added because the stash already holds `max` entries.
    #[error("stash key limit exceeded (max {max} keys)")]
    KeyLimitExceeded {
        /// Configured maximum number of keys.
        max: usize,
    },
    /// A single value's serialised size exceeds the per-value limit.
    #[error("stash value too large ({size} bytes, max {max} bytes)")]
    ValueTooLarge {
        /// Serialised size of the rejected value, in bytes.
        size: usize,
        /// Configured per-value limit, in bytes.
        max: usize,
    },
    /// The write would push the stash's combined size past its limit.
    #[error("stash total size exceeded ({total} bytes, max {max} bytes)")]
    TotalSizeExceeded {
        /// Combined size the stash would have had after the write, in bytes.
        total: usize,
        /// Configured total-size limit, in bytes.
        max: usize,
    },
    /// The key is longer than the 256-character maximum.
    #[error("stash key too long ({len} chars, max 256)")]
    KeyTooLong {
        /// Length of the rejected key.
        len: usize,
    },
    /// The key is empty or contains a character outside the allowed key alphabet.
    #[error("stash key contains invalid characters")]
    InvalidKey,
    /// The requested TTL is longer than the configured maximum.
    ///
    /// Also returned, with `requested_secs: 0`, when the caller requests a zero TTL.
    #[error("stash TTL exceeds maximum ({requested_secs}s, max {max_secs}s)")]
    TtlTooLong {
        /// Requested TTL in whole seconds (sub-second parts are truncated).
        requested_secs: u64,
        /// Configured maximum TTL in whole seconds.
        max_secs: u64,
    },
    /// An execution tried to access an entry owned by a different group.
    #[error(
        "cross-group stash access denied: entry belongs to group '{entry_group}', \
current execution is in group '{current_group}'"
    )]
    CrossGroupAccess {
        /// Group that owns the entry.
        entry_group: String,
        /// Group of the caller, or `<ungrouped>` when the caller has no group.
        current_group: String,
    },
}
129
130pub(crate) fn validate_key(key: &str) -> Result<(), StashError> {
136 if key.is_empty() {
137 return Err(StashError::InvalidKey);
138 }
139 if key.len() > 256 {
140 return Err(StashError::KeyTooLong { len: key.len() });
141 }
142 if !KEY_RE.is_match(key) {
143 return Err(StashError::InvalidKey);
144 }
145 Ok(())
146}
147
148fn check_group_read(
150 source_group: &Option<String>,
151 current_group: Option<&str>,
152) -> Result<(), StashError> {
153 match source_group {
154 None => Ok(()), Some(entry_group) => match current_group {
156 Some(cg) if cg == entry_group => Ok(()),
157 other => Err(StashError::CrossGroupAccess {
158 entry_group: entry_group.clone(),
159 current_group: other.unwrap_or("<ungrouped>").to_string(),
160 }),
161 },
162 }
163}
164
165impl SessionStash {
170 pub fn new(config: StashConfig) -> Self {
172 Self {
173 entries: HashMap::new(),
174 total_size: 0,
175 config,
176 }
177 }
178
179 pub fn put(
189 &mut self,
190 key: &str,
191 value: Value,
192 ttl: Option<Duration>,
193 current_group: Option<&str>,
194 ) -> Result<(), StashError> {
195 validate_key(key)?;
197
198 let serialised = serde_json::to_vec(&value).unwrap_or_default();
200 let value_size = serialised.len();
201 if value_size > self.config.max_value_size {
202 return Err(StashError::ValueTooLarge {
203 size: value_size,
204 max: self.config.max_value_size,
205 });
206 }
207
208 let effective_ttl = match ttl {
210 Some(d) => {
211 if d.is_zero() {
212 return Err(StashError::TtlTooLong {
213 requested_secs: 0,
214 max_secs: self.config.max_ttl.as_secs(),
215 });
216 }
217 if d > self.config.max_ttl {
218 return Err(StashError::TtlTooLong {
219 requested_secs: d.as_secs(),
220 max_secs: self.config.max_ttl.as_secs(),
221 });
222 }
223 d
224 }
225 None => self.config.default_ttl,
226 };
227
228 let is_replacement = self.entries.contains_key(key);
230 if is_replacement {
231 let old_size = self.entries[key].size_bytes;
233 self.total_size -= old_size;
234 } else {
235 if self.entries.len() >= self.config.max_keys {
237 return Err(StashError::KeyLimitExceeded {
238 max: self.config.max_keys,
239 });
240 }
241 }
242
243 let new_total = self.total_size + value_size;
245 if new_total > self.config.max_total_size {
246 if is_replacement {
248 self.total_size += self.entries[key].size_bytes;
249 }
250 return Err(StashError::TotalSizeExceeded {
251 total: new_total,
252 max: self.config.max_total_size,
253 });
254 }
255
256 self.total_size = new_total;
258 self.entries.insert(
259 key.to_string(),
260 StashEntry {
261 value,
262 size_bytes: value_size,
263 created_at: Instant::now(),
264 ttl: effective_ttl,
265 source_group: current_group.map(str::to_string),
266 },
267 );
268 Ok(())
269 }
270
271 pub fn get(
280 &self,
281 key: &str,
282 current_group: Option<&str>,
283 ) -> Result<Option<&Value>, StashError> {
284 match self.entries.get(key) {
285 None => Ok(None),
286 Some(entry) if entry.is_expired() => Ok(None),
287 Some(entry) => {
288 check_group_read(&entry.source_group, current_group)?;
289 Ok(Some(&entry.value))
290 }
291 }
292 }
293
294 pub fn delete(&mut self, key: &str, current_group: Option<&str>) -> Result<bool, StashError> {
304 match self.entries.get(key) {
305 None => Ok(false),
306 Some(entry) => {
307 check_group_read(&entry.source_group, current_group)?;
308 let size = entry.size_bytes;
309 self.entries.remove(key);
310 self.total_size -= size;
311 Ok(true)
312 }
313 }
314 }
315
316 pub fn keys(&self, current_group: Option<&str>) -> Vec<&str> {
321 self.entries
322 .iter()
323 .filter(|(_, entry)| {
324 if entry.is_expired() {
325 return false;
326 }
327 match &entry.source_group {
329 None => true, Some(eg) => match current_group {
331 Some(cg) => cg == eg,
332 None => false, },
334 }
335 })
336 .map(|(k, _)| k.as_str())
337 .collect()
338 }
339
340 pub fn reap_expired(&mut self) -> usize {
342 let before = self.entries.len();
343 let to_remove: Vec<String> = self
344 .entries
345 .iter()
346 .filter(|(_, e)| e.is_expired())
347 .map(|(k, _)| k.clone())
348 .collect();
349 for key in &to_remove {
350 if let Some(e) = self.entries.remove(key) {
351 self.total_size -= e.size_bytes;
352 }
353 }
354 before - self.entries.len()
355 }
356}
357
// Unit tests for `SessionStash`: key validation, value/total/key-count quotas, TTL handling,
// group isolation, and `total_size` bookkeeping.
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::json;

    use super::*;

    // Fresh stash with the default limits.
    fn default_stash() -> SessionStash {
        SessionStash::new(StashConfig::default())
    }

    // Round trip: a stored value is returned unchanged.
    #[test]
    fn st_u01_put_and_get() {
        let mut stash = default_stash();
        stash
            .put("key1", json!({"hello": "world"}), None, None)
            .unwrap();
        let v = stash.get("key1", None).unwrap().unwrap();
        assert_eq!(v, &json!({"hello": "world"}));
    }

    // Overwriting a key replaces the value and credits back the old entry's size.
    #[test]
    fn st_u02_put_replaces_existing_key() {
        let mut stash = default_stash();
        stash
            .put("k", json!("a big string that takes space"), None, None)
            .unwrap();
        let size_after_first = stash.total_size;
        stash.put("k", json!(1), None, None).unwrap();
        let size_after_second = stash.total_size;
        assert!(
            size_after_second < size_after_first,
            "total_size should shrink when replacing with smaller value"
        );
        let v = stash.get("k", None).unwrap().unwrap();
        assert_eq!(v, &json!(1));
    }

    // 257 characters is one over the key-length limit.
    #[test]
    fn st_u03_put_rejects_key_too_long() {
        let mut stash = default_stash();
        let long_key = "a".repeat(257);
        let err = stash.put(&long_key, json!(null), None, None).unwrap_err();
        assert!(matches!(err, StashError::KeyTooLong { len: 257 }));
    }

    // Space, slash and NUL are all outside the key alphabet.
    #[test]
    fn st_u04_put_rejects_invalid_key_characters() {
        let mut stash = default_stash();
        for bad in &["key with space", "key/slash", "key\0null"] {
            let err = stash.put(bad, json!(null), None, None).unwrap_err();
            assert!(
                matches!(err, StashError::InvalidKey),
                "expected InvalidKey for {:?}",
                bad
            );
        }
    }

    // The key regex is anchored at both ends, so a single bad trailing char is rejected.
    #[test]
    fn st_key_regex_rejects_trailing_invalid_char() {
        let mut stash = default_stash();
        let err = stash
            .put("valid_key!", json!(null), None, None)
            .unwrap_err();
        assert!(
            matches!(err, StashError::InvalidKey),
            "expected InvalidKey for key with trailing '!', got: {err}"
        );
    }

    // The string serialises to far more than the 10-byte per-value limit.
    #[test]
    fn st_u05_put_rejects_oversized_value() {
        let config = StashConfig {
            max_value_size: 10,
            ..Default::default()
        };
        let mut stash = SessionStash::new(config);
        let big_value = json!("this is definitely more than ten bytes");
        let err = stash.put("k", big_value, None, None).unwrap_err();
        assert!(matches!(err, StashError::ValueTooLarge { .. }));
    }

    // "12345" serialises to 7 bytes (quotes included) and the alphabet string to 28, so the
    // second put would bring the total to 35 > 30 while each value is under 100.
    #[test]
    fn st_u06_put_rejects_when_total_size_exceeded() {
        let config = StashConfig {
            max_total_size: 30,
            max_value_size: 100,
            ..Default::default()
        };
        let mut stash = SessionStash::new(config);
        stash.put("k1", json!("12345"), None, None).unwrap();
        let err = stash
            .put("k2", json!("abcdefghijklmnopqrstuvwxyz"), None, None)
            .unwrap_err();
        assert!(matches!(err, StashError::TotalSizeExceeded { .. }));
    }

    // A third distinct key exceeds a two-key limit.
    #[test]
    fn st_u07_put_rejects_when_key_count_exceeded() {
        let config = StashConfig {
            max_keys: 2,
            ..Default::default()
        };
        let mut stash = SessionStash::new(config);
        stash.put("k1", json!(1), None, None).unwrap();
        stash.put("k2", json!(2), None, None).unwrap();
        let err = stash.put("k3", json!(3), None, None).unwrap_err();
        assert!(matches!(err, StashError::KeyLimitExceeded { max: 2 }));
    }

    // 61 s is one second over a 60 s maximum TTL.
    #[test]
    fn st_u08_put_rejects_ttl_exceeding_max() {
        let config = StashConfig {
            max_ttl: Duration::from_secs(60),
            ..Default::default()
        };
        let mut stash = SessionStash::new(config);
        let err = stash
            .put("k", json!(1), Some(Duration::from_secs(61)), None)
            .unwrap_err();
        assert!(matches!(err, StashError::TtlTooLong { .. }));
    }

    // A missing key is `Ok(None)`, not an error.
    #[test]
    fn st_u09_get_returns_none_for_missing_key() {
        let stash = default_stash();
        assert!(stash.get("no-such-key", None).unwrap().is_none());
    }

    // After sleeping past a 1 ms TTL, the entry reads as absent.
    #[test]
    fn st_u10_get_returns_none_for_expired_key() {
        let mut stash = default_stash();
        stash
            .put("k", json!("v"), Some(Duration::from_millis(1)), None)
            .unwrap();
        std::thread::sleep(Duration::from_millis(10));
        assert!(stash.get("k", None).unwrap().is_none());
    }

    // An entry written by group-a cannot be read from group-b.
    #[test]
    fn st_u11_get_cross_group_access_denied() {
        let mut stash = default_stash();
        stash.put("k", json!(1), None, Some("group-a")).unwrap();
        let err = stash.get("k", Some("group-b")).unwrap_err();
        assert!(
            matches!(err, StashError::CrossGroupAccess { .. }),
            "unexpected error: {err}"
        );
    }

    // The owning group can read its own entry.
    #[test]
    fn st_u12_get_same_group_allowed() {
        let mut stash = default_stash();
        stash.put("k", json!(42), None, Some("team-a")).unwrap();
        let v = stash.get("k", Some("team-a")).unwrap().unwrap();
        assert_eq!(v, &json!(42));
    }

    // Ungrouped writer, ungrouped reader.
    #[test]
    fn st_u13_ungrouped_entry_accessible_to_ungrouped() {
        let mut stash = default_stash();
        stash.put("k", json!("public"), None, None).unwrap();
        let v = stash.get("k", None).unwrap().unwrap();
        assert_eq!(v, &json!("public"));
    }

    // Ungrouped entries are readable from any group.
    #[test]
    fn st_u14_grouped_execution_can_read_ungrouped_entry() {
        let mut stash = default_stash();
        stash.put("k", json!("public"), None, None).unwrap();
        let v = stash.get("k", Some("any-group")).unwrap().unwrap();
        assert_eq!(v, &json!("public"));
    }

    // Deleting the only entry returns `true` and brings `total_size` back to zero.
    #[test]
    fn st_u15_delete_removes_entry_and_updates_size() {
        let mut stash = default_stash();
        stash.put("k", json!("value"), None, None).unwrap();
        let size_before = stash.total_size;
        assert!(size_before > 0);
        let removed = stash.delete("k", None).unwrap();
        assert!(removed);
        assert_eq!(stash.total_size, 0);
        assert!(stash.get("k", None).unwrap().is_none());
    }

    // Deleting a key that was never stored is `Ok(false)`.
    #[test]
    fn st_u16_delete_returns_false_for_missing_key() {
        let mut stash = default_stash();
        let removed = stash.delete("no-such-key", None).unwrap();
        assert!(!removed);
    }

    // A denied cross-group delete must leave the entry intact for its owner.
    #[test]
    fn st_u17_delete_cross_group_denied() {
        let mut stash = default_stash();
        stash.put("k", json!(1), None, Some("group-a")).unwrap();
        let err = stash.delete("k", Some("group-b")).unwrap_err();
        assert!(matches!(err, StashError::CrossGroupAccess { .. }));
        let v = stash.get("k", Some("group-a")).unwrap().unwrap();
        assert_eq!(v, &json!(1));
    }

    // Each group sees the ungrouped key plus its own; an ungrouped caller sees only the
    // ungrouped key. Results are sorted because key order is not guaranteed.
    #[test]
    fn st_u18_keys_filtered_by_group() {
        let mut stash = default_stash();
        stash.put("pub", json!(1), None, None).unwrap();
        stash.put("a-key", json!(2), None, Some("group-a")).unwrap();
        stash.put("b-key", json!(3), None, Some("group-b")).unwrap();

        let mut keys_a: Vec<&str> = stash.keys(Some("group-a"));
        keys_a.sort();
        assert_eq!(keys_a, vec!["a-key", "pub"]);

        let mut keys_b: Vec<&str> = stash.keys(Some("group-b"));
        keys_b.sort();
        assert_eq!(keys_b, vec!["b-key", "pub"]);

        let keys_none: Vec<&str> = stash.keys(None);
        assert_eq!(keys_none, vec!["pub"]);
    }

    // Expired keys are hidden from `keys` even before they are reaped.
    #[test]
    fn st_u19_keys_excludes_expired() {
        let mut stash = default_stash();
        stash.put("alive", json!(1), None, None).unwrap();
        stash
            .put("dead", json!(2), Some(Duration::from_millis(1)), None)
            .unwrap();
        std::thread::sleep(Duration::from_millis(10));
        let mut keys: Vec<&str> = stash.keys(None);
        keys.sort();
        assert_eq!(keys, vec!["alive"]);
    }

    // Reaping removes exactly the two short-TTL entries and releases their size.
    #[test]
    fn st_u20_reap_expired() {
        let mut stash = default_stash();
        stash
            .put("k1", json!("a"), Some(Duration::from_millis(1)), None)
            .unwrap();
        stash
            .put("k2", json!("b"), Some(Duration::from_millis(1)), None)
            .unwrap();
        stash.put("k3", json!("c"), None, None).unwrap();
        let size_before = stash.total_size;
        assert!(size_before > 0);

        std::thread::sleep(Duration::from_millis(10));

        let removed = stash.reap_expired();
        assert_eq!(removed, 2);
        assert_eq!(stash.entries.len(), 1);
        assert!(stash.total_size < size_before);
        assert!(stash.get("k3", None).unwrap().is_some());
    }

    // A zero TTL is rejected and reported through `TtlTooLong` with `requested_secs: 0`.
    #[test]
    fn st_u21_put_ttl_zero_rejected() {
        let mut stash = default_stash();
        let err = stash
            .put("k", json!(1), Some(Duration::from_secs(0)), None)
            .unwrap_err();
        assert!(matches!(
            err,
            StashError::TtlTooLong {
                requested_secs: 0,
                ..
            }
        ));
    }

    // Eight tasks share one stash behind a std `Mutex`; each guard lives only inside its
    // inner block, so it is dropped before the next lock is taken.
    #[tokio::test]
    async fn st_u22_concurrent_put_get() {
        let stash = Arc::new(Mutex::new(default_stash()));

        let mut handles = Vec::new();
        for i in 0..8usize {
            let stash = stash.clone();
            handles.push(tokio::spawn(async move {
                let key = format!("key-{i}");
                {
                    let mut s = stash.lock().unwrap();
                    s.put(&key, json!(i), None, None).unwrap();
                }
                {
                    let s = stash.lock().unwrap();
                    let v = s.get(&key, None).unwrap().unwrap();
                    assert_eq!(v, &json!(i));
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
    }

    // After replacing a ~1 KB value with `1`, `total_size` equals exactly the new value's
    // serialised length, i.e. the old entry's bytes were fully credited back.
    #[test]
    fn st_u23_replace_large_with_small_decrements_total_size() {
        let mut stash = default_stash();
        let big = json!("x".repeat(1000));
        stash.put("k", big, None, None).unwrap();
        let size_after_big = stash.total_size;

        stash.put("k", json!(1), None, None).unwrap();
        let size_after_small = stash.total_size;

        assert!(
            size_after_small < size_after_big,
            "total_size ({size_after_small}) should be less than after big insert ({size_after_big})"
        );
        let expected = serde_json::to_vec(&json!(1)).unwrap().len();
        assert_eq!(stash.total_size, expected);
    }
}