Skip to main content

forge_sandbox/
stash.rs

1//! Session stash — a per-session key/value store with TTL and group isolation.
2//!
3//! The stash lets sandbox executions persist data across calls within the same
4//! session. Entries are scoped to a server group (optional) and expire after a
5//! configurable TTL.
6//!
7//! # Group access rules
8//!
9//! - `source_group = None` → readable by **any** execution (public within session)
10//! - `source_group = Some("A")`, `current_group = Some("A")` → OK (same group)
11//! - `source_group = Some("A")`, `current_group = Some("B")` → [`StashError::CrossGroupAccess`]
12//! - `source_group = Some("A")`, `current_group = None` → [`StashError::CrossGroupAccess`]
13
14use std::collections::HashMap;
15use std::sync::LazyLock;
16use std::time::{Duration, Instant};
17
18use regex::Regex;
19use serde_json::Value;
20
21/// Key validation regex: alphanumerics plus `_`, `-`, `.`, `:`, 1–256 chars.
22static KEY_RE: LazyLock<Regex> =
23    LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_\-.:]{1,256}$").expect("static regex is valid"));
24
/// Tunable limits for a session stash.
///
/// Every field has a default (see the `Default` impl below); individual
/// fields can be overridden with struct-update syntax, e.g.
/// `StashConfig { max_keys: 16, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct StashConfig {
    /// Upper bound on the number of distinct keys held at once (default: 256).
    pub max_keys: usize,
    /// Upper bound, in bytes, on one value's JSON encoding (default: 16 MiB).
    pub max_value_size: usize,
    /// Upper bound, in bytes, on the sum of all stored values' JSON encodings
    /// (default: 128 MiB).
    pub max_total_size: usize,
    /// Lifetime given to entries stored without an explicit TTL (default: 1 hour).
    pub default_ttl: Duration,
    /// Longest TTL a caller is allowed to request (default: 24 hours).
    pub max_ttl: Duration,
}

impl Default for StashConfig {
    fn default() -> Self {
        const KIB: usize = 1024;
        const MIB: usize = KIB * KIB;
        const HOUR_SECS: u64 = 60 * 60;

        Self {
            max_keys: 256,
            max_value_size: 16 * MIB,
            max_total_size: 128 * MIB,
            default_ttl: Duration::from_secs(HOUR_SECS),
            max_ttl: Duration::from_secs(24 * HOUR_SECS),
        }
    }
}
51
52/// A single entry stored in the stash.
53struct StashEntry {
54    value: Value,
55    size_bytes: usize,
56    created_at: Instant,
57    ttl: Duration,
58    source_group: Option<String>,
59}
60
61impl StashEntry {
62    fn is_expired(&self) -> bool {
63        self.created_at.elapsed() >= self.ttl
64    }
65}
66
67/// Per-session key/value store with TTL-based expiry and group isolation.
68pub struct SessionStash {
69    entries: HashMap<String, StashEntry>,
70    total_size: usize,
71    config: StashConfig,
72}
73
74/// Errors returned by [`SessionStash`] operations.
75#[derive(Debug, thiserror::Error)]
76#[non_exhaustive]
77pub enum StashError {
78    /// The stash already holds the maximum number of keys.
79    #[error("stash key limit exceeded (max {max} keys)")]
80    KeyLimitExceeded {
81        /// Configured maximum.
82        max: usize,
83    },
84    /// The serialised value exceeds the per-value size limit.
85    #[error("stash value too large ({size} bytes, max {max} bytes)")]
86    ValueTooLarge {
87        /// Actual size.
88        size: usize,
89        /// Configured maximum.
90        max: usize,
91    },
92    /// Adding the value would push total stash size past the limit.
93    #[error("stash total size exceeded ({total} bytes, max {max} bytes)")]
94    TotalSizeExceeded {
95        /// Projected total.
96        total: usize,
97        /// Configured maximum.
98        max: usize,
99    },
100    /// The key exceeds 256 characters.
101    #[error("stash key too long ({len} chars, max 256)")]
102    KeyTooLong {
103        /// Actual length.
104        len: usize,
105    },
106    /// The key contains characters outside `[a-zA-Z0-9_\-.:]{1,256}`.
107    #[error("stash key contains invalid characters")]
108    InvalidKey,
109    /// The caller-supplied TTL exceeds `max_ttl`.
110    #[error("stash TTL exceeds maximum ({requested_secs}s, max {max_secs}s)")]
111    TtlTooLong {
112        /// Requested TTL in seconds.
113        requested_secs: u64,
114        /// Configured maximum TTL in seconds.
115        max_secs: u64,
116    },
117    /// The current execution belongs to a different group than the entry.
118    #[error(
119        "cross-group stash access denied: entry belongs to group '{entry_group}', \
120         current execution is in group '{current_group}'"
121    )]
122    CrossGroupAccess {
123        /// Group that owns the entry.
124        entry_group: String,
125        /// Group (or representation of ungrouped) attempting access.
126        current_group: String,
127    },
128}
129
130// ---------------------------------------------------------------------------
131// Helpers
132// ---------------------------------------------------------------------------
133
134/// Validate a stash key, returning an appropriate error on failure.
135pub(crate) fn validate_key(key: &str) -> Result<(), StashError> {
136    if key.is_empty() {
137        return Err(StashError::InvalidKey);
138    }
139    if key.len() > 256 {
140        return Err(StashError::KeyTooLong { len: key.len() });
141    }
142    if !KEY_RE.is_match(key) {
143        return Err(StashError::InvalidKey);
144    }
145    Ok(())
146}
147
148/// Check whether `current_group` may access an entry owned by `source_group`.
149fn check_group_read(
150    source_group: &Option<String>,
151    current_group: Option<&str>,
152) -> Result<(), StashError> {
153    match source_group {
154        None => Ok(()), // public entry — anyone may read
155        Some(entry_group) => match current_group {
156            Some(cg) if cg == entry_group => Ok(()),
157            other => Err(StashError::CrossGroupAccess {
158                entry_group: entry_group.clone(),
159                current_group: other.unwrap_or("<ungrouped>").to_string(),
160            }),
161        },
162    }
163}
164
165// ---------------------------------------------------------------------------
166// SessionStash implementation
167// ---------------------------------------------------------------------------
168
169impl SessionStash {
170    /// Create a new stash with the given configuration.
171    pub fn new(config: StashConfig) -> Self {
172        Self {
173            entries: HashMap::new(),
174            total_size: 0,
175            config,
176        }
177    }
178
179    /// Store a value under `key`.
180    ///
181    /// If the key already exists its value is replaced and size accounting is
182    /// updated accordingly. The `current_group` is recorded as the entry's
183    /// owner — future reads from other groups will be denied.
184    ///
185    /// # Errors
186    ///
187    /// Returns [`StashError`] if any limit is exceeded or the key is invalid.
188    pub fn put(
189        &mut self,
190        key: &str,
191        value: Value,
192        ttl: Option<Duration>,
193        current_group: Option<&str>,
194    ) -> Result<(), StashError> {
195        // --- Key validation ---
196        validate_key(key)?;
197
198        // --- Value size ---
199        let serialised = serde_json::to_vec(&value).unwrap_or_default();
200        let value_size = serialised.len();
201        if value_size > self.config.max_value_size {
202            return Err(StashError::ValueTooLarge {
203                size: value_size,
204                max: self.config.max_value_size,
205            });
206        }
207
208        // --- TTL validation ---
209        let effective_ttl = match ttl {
210            Some(d) => {
211                if d.is_zero() {
212                    return Err(StashError::TtlTooLong {
213                        requested_secs: 0,
214                        max_secs: self.config.max_ttl.as_secs(),
215                    });
216                }
217                if d > self.config.max_ttl {
218                    return Err(StashError::TtlTooLong {
219                        requested_secs: d.as_secs(),
220                        max_secs: self.config.max_ttl.as_secs(),
221                    });
222                }
223                d
224            }
225            None => self.config.default_ttl,
226        };
227
228        // --- Replace vs. new key ---
229        let is_replacement = self.entries.contains_key(key);
230        if is_replacement {
231            // Subtract old size before checking limits
232            let old_size = self.entries[key].size_bytes;
233            self.total_size -= old_size;
234        } else {
235            // Only check key count for brand-new keys
236            if self.entries.len() >= self.config.max_keys {
237                return Err(StashError::KeyLimitExceeded {
238                    max: self.config.max_keys,
239                });
240            }
241        }
242
243        // --- Total size check ---
244        let new_total = self.total_size + value_size;
245        if new_total > self.config.max_total_size {
246            // Roll back the subtraction we did for a replacement
247            if is_replacement {
248                self.total_size += self.entries[key].size_bytes;
249            }
250            return Err(StashError::TotalSizeExceeded {
251                total: new_total,
252                max: self.config.max_total_size,
253            });
254        }
255
256        // --- Commit ---
257        self.total_size = new_total;
258        self.entries.insert(
259            key.to_string(),
260            StashEntry {
261                value,
262                size_bytes: value_size,
263                created_at: Instant::now(),
264                ttl: effective_ttl,
265                source_group: current_group.map(str::to_string),
266            },
267        );
268        Ok(())
269    }
270
271    /// Retrieve the value stored under `key`.
272    ///
273    /// Returns `Ok(None)` if the key does not exist or has expired.
274    ///
275    /// # Errors
276    ///
277    /// Returns [`StashError::CrossGroupAccess`] if the entry is owned by a
278    /// different group than `current_group`.
279    pub fn get(
280        &self,
281        key: &str,
282        current_group: Option<&str>,
283    ) -> Result<Option<&Value>, StashError> {
284        match self.entries.get(key) {
285            None => Ok(None),
286            Some(entry) if entry.is_expired() => Ok(None),
287            Some(entry) => {
288                check_group_read(&entry.source_group, current_group)?;
289                Ok(Some(&entry.value))
290            }
291        }
292    }
293
294    /// Remove the entry stored under `key`.
295    ///
296    /// Returns `Ok(true)` if the entry was present and removed, `Ok(false)` if
297    /// the key did not exist.
298    ///
299    /// # Errors
300    ///
301    /// Returns [`StashError::CrossGroupAccess`] if the entry is owned by a
302    /// different group.
303    pub fn delete(&mut self, key: &str, current_group: Option<&str>) -> Result<bool, StashError> {
304        match self.entries.get(key) {
305            None => Ok(false),
306            Some(entry) => {
307                check_group_read(&entry.source_group, current_group)?;
308                let size = entry.size_bytes;
309                self.entries.remove(key);
310                self.total_size -= size;
311                Ok(true)
312            }
313        }
314    }
315
316    /// Return the keys currently visible to `current_group`.
317    ///
318    /// Expired entries and entries belonging to a different strict group are
319    /// excluded.
320    pub fn keys(&self, current_group: Option<&str>) -> Vec<&str> {
321        self.entries
322            .iter()
323            .filter(|(_, entry)| {
324                if entry.is_expired() {
325                    return false;
326                }
327                // Apply group visibility rules (same logic as get, but no error)
328                match &entry.source_group {
329                    None => true, // public entry
330                    Some(eg) => match current_group {
331                        Some(cg) => cg == eg,
332                        None => false, // ungrouped can't see grouped entries
333                    },
334                }
335            })
336            .map(|(k, _)| k.as_str())
337            .collect()
338    }
339
340    /// Remove all expired entries and return how many were removed.
341    pub fn reap_expired(&mut self) -> usize {
342        let before = self.entries.len();
343        let to_remove: Vec<String> = self
344            .entries
345            .iter()
346            .filter(|(_, e)| e.is_expired())
347            .map(|(k, _)| k.clone())
348            .collect();
349        for key in &to_remove {
350            if let Some(e) = self.entries.remove(key) {
351                self.total_size -= e.size_bytes;
352            }
353        }
354        before - self.entries.len()
355    }
356}
357
358// ---------------------------------------------------------------------------
359// Tests
360// ---------------------------------------------------------------------------
361
// Unit tests for `SessionStash`. `use super::*` brings private items into
// scope, so several tests inspect `total_size` and `entries` directly.
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use serde_json::json;

    use super::*;

    /// Build a stash with the default `StashConfig` limits.
    fn default_stash() -> SessionStash {
        SessionStash::new(StashConfig::default())
    }

    // ST-U01: put() stores value and get() retrieves it
    #[test]
    fn st_u01_put_and_get() {
        let mut stash = default_stash();
        stash
            .put("key1", json!({"hello": "world"}), None, None)
            .unwrap();
        let v = stash.get("key1", None).unwrap().unwrap();
        assert_eq!(v, &json!({"hello": "world"}));
    }

    // ST-U02: put() replaces existing key (updates size accounting)
    #[test]
    fn st_u02_put_replaces_existing_key() {
        let mut stash = default_stash();
        stash
            .put("k", json!("a big string that takes space"), None, None)
            .unwrap();
        let size_after_first = stash.total_size;
        stash.put("k", json!(1), None, None).unwrap();
        let size_after_second = stash.total_size;
        assert!(
            size_after_second < size_after_first,
            "total_size should shrink when replacing with smaller value"
        );
        let v = stash.get("k", None).unwrap().unwrap();
        assert_eq!(v, &json!(1));
    }

    // ST-U03: put() rejects key exceeding 256 chars
    #[test]
    fn st_u03_put_rejects_key_too_long() {
        let mut stash = default_stash();
        let long_key = "a".repeat(257);
        let err = stash.put(&long_key, json!(null), None, None).unwrap_err();
        assert!(matches!(err, StashError::KeyTooLong { len: 257 }));
    }

    // ST-U04: put() rejects key with invalid characters (spaces, slashes, null bytes)
    #[test]
    fn st_u04_put_rejects_invalid_key_characters() {
        let mut stash = default_stash();
        for bad in &["key with space", "key/slash", "key\0null"] {
            let err = stash.put(bad, json!(null), None, None).unwrap_err();
            assert!(
                matches!(err, StashError::InvalidKey),
                "expected InvalidKey for {:?}",
                bad
            );
        }
    }

    // Key with trailing invalid char rejected (validates $ end anchor)
    #[test]
    fn st_key_regex_rejects_trailing_invalid_char() {
        let mut stash = default_stash();
        let err = stash
            .put("valid_key!", json!(null), None, None)
            .unwrap_err();
        assert!(
            matches!(err, StashError::InvalidKey),
            "expected InvalidKey for key with trailing '!', got: {err}"
        );
    }

    // ST-U05: put() rejects value exceeding max_value_size
    #[test]
    fn st_u05_put_rejects_oversized_value() {
        let config = StashConfig {
            max_value_size: 10,
            ..Default::default()
        };
        let mut stash = SessionStash::new(config);
        // A string with 11+ chars will serialise to more than 10 bytes
        // (the surrounding JSON quotes add 2 more bytes on top of the chars).
        let big_value = json!("this is definitely more than ten bytes");
        let err = stash.put("k", big_value, None, None).unwrap_err();
        assert!(matches!(err, StashError::ValueTooLarge { .. }));
    }

    // ST-U06: put() rejects when total_size would exceed max_total_size
    #[test]
    fn st_u06_put_rejects_when_total_size_exceeded() {
        // max_total_size is 30 bytes, max_value_size is 100 bytes so individual
        // values pass the per-value check but the combination exceeds the total.
        let config = StashConfig {
            max_total_size: 30,
            max_value_size: 100,
            ..Default::default()
        };
        let mut stash = SessionStash::new(config);
        // First put — "12345" serialises as "\"12345\"" = 7 bytes, fits fine
        stash.put("k1", json!("12345"), None, None).unwrap();
        // Second put — "abcdefghijklmnopqrstuvwxyz" = 28 bytes serialised
        // (26 letters + 2 quotes); combined with k1 that's 7 + 28 = 35 bytes,
        // which exceeds max_total_size=30
        let err = stash
            .put("k2", json!("abcdefghijklmnopqrstuvwxyz"), None, None)
            .unwrap_err();
        assert!(matches!(err, StashError::TotalSizeExceeded { .. }));
    }

    // ST-U07: put() rejects when key count exceeds max_keys
    #[test]
    fn st_u07_put_rejects_when_key_count_exceeded() {
        let config = StashConfig {
            max_keys: 2,
            ..Default::default()
        };
        let mut stash = SessionStash::new(config);
        stash.put("k1", json!(1), None, None).unwrap();
        stash.put("k2", json!(2), None, None).unwrap();
        let err = stash.put("k3", json!(3), None, None).unwrap_err();
        assert!(matches!(err, StashError::KeyLimitExceeded { max: 2 }));
    }

    // ST-U08: put() rejects TTL exceeding max_ttl
    #[test]
    fn st_u08_put_rejects_ttl_exceeding_max() {
        let config = StashConfig {
            max_ttl: Duration::from_secs(60),
            ..Default::default()
        };
        let mut stash = SessionStash::new(config);
        let err = stash
            .put("k", json!(1), Some(Duration::from_secs(61)), None)
            .unwrap_err();
        assert!(matches!(err, StashError::TtlTooLong { .. }));
    }

    // ST-U09: get() returns None for nonexistent key
    #[test]
    fn st_u09_get_returns_none_for_missing_key() {
        let stash = default_stash();
        assert!(stash.get("no-such-key", None).unwrap().is_none());
    }

    // ST-U10: get() returns None for expired key
    // Timing-based: a 1 ms TTL followed by a 10 ms real sleep.
    #[test]
    fn st_u10_get_returns_none_for_expired_key() {
        let mut stash = default_stash();
        stash
            .put("k", json!("v"), Some(Duration::from_millis(1)), None)
            .unwrap();
        std::thread::sleep(Duration::from_millis(10));
        assert!(stash.get("k", None).unwrap().is_none());
    }

    // ST-U11: get() returns CrossGroupAccess error for different strict group
    #[test]
    fn st_u11_get_cross_group_access_denied() {
        let mut stash = default_stash();
        stash.put("k", json!(1), None, Some("group-a")).unwrap();
        let err = stash.get("k", Some("group-b")).unwrap_err();
        assert!(
            matches!(err, StashError::CrossGroupAccess { .. }),
            "unexpected error: {err}"
        );
    }

    // ST-U12: get() allows access from same strict group
    #[test]
    fn st_u12_get_same_group_allowed() {
        let mut stash = default_stash();
        stash.put("k", json!(42), None, Some("team-a")).unwrap();
        let v = stash.get("k", Some("team-a")).unwrap().unwrap();
        assert_eq!(v, &json!(42));
    }

    // ST-U13: get() allows access from ungrouped execution to ungrouped entries (open group)
    #[test]
    fn st_u13_ungrouped_entry_accessible_to_ungrouped() {
        let mut stash = default_stash();
        stash.put("k", json!("public"), None, None).unwrap();
        let v = stash.get("k", None).unwrap().unwrap();
        assert_eq!(v, &json!("public"));
    }

    // ST-U14: get() allows a grouped execution to read an ungrouped (public) entry
    #[test]
    fn st_u14_grouped_execution_can_read_ungrouped_entry() {
        let mut stash = default_stash();
        stash.put("k", json!("public"), None, None).unwrap();
        // A grouped execution should also be able to read a public entry
        let v = stash.get("k", Some("any-group")).unwrap().unwrap();
        assert_eq!(v, &json!("public"));
    }

    // ST-U15: delete() removes entry and updates size accounting
    #[test]
    fn st_u15_delete_removes_entry_and_updates_size() {
        let mut stash = default_stash();
        stash.put("k", json!("value"), None, None).unwrap();
        let size_before = stash.total_size;
        assert!(size_before > 0);
        let removed = stash.delete("k", None).unwrap();
        assert!(removed);
        assert_eq!(stash.total_size, 0);
        assert!(stash.get("k", None).unwrap().is_none());
    }

    // ST-U16: delete() returns false for nonexistent key
    #[test]
    fn st_u16_delete_returns_false_for_missing_key() {
        let mut stash = default_stash();
        let removed = stash.delete("no-such-key", None).unwrap();
        assert!(!removed);
    }

    // ST-U17: delete() enforces group isolation (cannot delete cross-group entries)
    #[test]
    fn st_u17_delete_cross_group_denied() {
        let mut stash = default_stash();
        stash.put("k", json!(1), None, Some("group-a")).unwrap();
        let err = stash.delete("k", Some("group-b")).unwrap_err();
        assert!(matches!(err, StashError::CrossGroupAccess { .. }));
        // Entry should still be present
        let v = stash.get("k", Some("group-a")).unwrap().unwrap();
        assert_eq!(v, &json!(1));
    }

    // ST-U18: keys() returns only keys visible to current group
    // (keys() order is unspecified, hence the sorts before comparing)
    #[test]
    fn st_u18_keys_filtered_by_group() {
        let mut stash = default_stash();
        stash.put("pub", json!(1), None, None).unwrap();
        stash.put("a-key", json!(2), None, Some("group-a")).unwrap();
        stash.put("b-key", json!(3), None, Some("group-b")).unwrap();

        let mut keys_a: Vec<&str> = stash.keys(Some("group-a"));
        keys_a.sort();
        assert_eq!(keys_a, vec!["a-key", "pub"]);

        let mut keys_b: Vec<&str> = stash.keys(Some("group-b"));
        keys_b.sort();
        assert_eq!(keys_b, vec!["b-key", "pub"]);

        let keys_none: Vec<&str> = stash.keys(None);
        assert_eq!(keys_none, vec!["pub"]);
    }

    // ST-U19: keys() excludes expired keys
    // Timing-based: a 1 ms TTL followed by a 10 ms real sleep.
    #[test]
    fn st_u19_keys_excludes_expired() {
        let mut stash = default_stash();
        stash.put("alive", json!(1), None, None).unwrap();
        stash
            .put("dead", json!(2), Some(Duration::from_millis(1)), None)
            .unwrap();
        std::thread::sleep(Duration::from_millis(10));
        let mut keys: Vec<&str> = stash.keys(None);
        keys.sort();
        assert_eq!(keys, vec!["alive"]);
    }

    // ST-U20: reap_expired() removes all expired entries and updates total_size
    // Timing-based: 1 ms TTLs followed by a 10 ms real sleep.
    #[test]
    fn st_u20_reap_expired() {
        let mut stash = default_stash();
        stash
            .put("k1", json!("a"), Some(Duration::from_millis(1)), None)
            .unwrap();
        stash
            .put("k2", json!("b"), Some(Duration::from_millis(1)), None)
            .unwrap();
        stash.put("k3", json!("c"), None, None).unwrap();
        let size_before = stash.total_size;
        assert!(size_before > 0);

        std::thread::sleep(Duration::from_millis(10));

        let removed = stash.reap_expired();
        assert_eq!(removed, 2);
        assert_eq!(stash.entries.len(), 1);
        assert!(stash.total_size < size_before);
        assert!(stash.get("k3", None).unwrap().is_some());
    }

    // ST-U21: put() with explicit TTL=0 is rejected (must be positive)
    // A zero TTL is reported through the TtlTooLong variant with requested_secs = 0.
    #[test]
    fn st_u21_put_ttl_zero_rejected() {
        let mut stash = default_stash();
        let err = stash
            .put("k", json!(1), Some(Duration::from_secs(0)), None)
            .unwrap_err();
        assert!(matches!(
            err,
            StashError::TtlTooLong {
                requested_secs: 0,
                ..
            }
        ));
    }

    // ST-U22: concurrent put/get from multiple tasks (Arc<Mutex<>> safety)
    // NOTE(review): `#[tokio::test]` uses a current-thread runtime by default,
    // so these tasks interleave on one thread rather than running in parallel;
    // `#[tokio::test(flavor = "multi_thread")]` would be needed for real
    // multi-threaded contention. Each std Mutex guard is dropped at the end of
    // its inner block, and no guard is held across an `.await`.
    #[tokio::test]
    async fn st_u22_concurrent_put_get() {
        let stash = Arc::new(Mutex::new(default_stash()));

        let mut handles = Vec::new();
        for i in 0..8usize {
            let stash = stash.clone();
            handles.push(tokio::spawn(async move {
                let key = format!("key-{i}");
                {
                    let mut s = stash.lock().unwrap();
                    s.put(&key, json!(i), None, None).unwrap();
                }
                {
                    let s = stash.lock().unwrap();
                    let v = s.get(&key, None).unwrap().unwrap();
                    assert_eq!(v, &json!(i));
                }
            }));
        }
        for h in handles {
            h.await.unwrap();
        }
    }

    // ST-U23: replacing a large value with a small one correctly decrements total_size
    #[test]
    fn st_u23_replace_large_with_small_decrements_total_size() {
        let mut stash = default_stash();
        let big = json!("x".repeat(1000));
        stash.put("k", big, None, None).unwrap();
        let size_after_big = stash.total_size;

        stash.put("k", json!(1), None, None).unwrap();
        let size_after_small = stash.total_size;

        assert!(
            size_after_small < size_after_big,
            "total_size ({size_after_small}) should be less than after big insert ({size_after_big})"
        );
        // Verify the new size matches the small value's serialised length
        let expected = serde_json::to_vec(&json!(1)).unwrap().len();
        assert_eq!(stash.total_size, expected);
    }
}