llm_shield_core/
vault.rs

//! Vault for cross-scanner state management
//!
//! ## SPARC Specification
//!
//! Provides thread-safe state storage for:
//! - Anonymization mappings
//! - Session context
//! - Cross-scanner communication
9
10use crate::Error;
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14/// Thread-safe state storage for scanners
15///
16/// ## Enterprise Design
17///
18/// - Thread-safe with RwLock
19/// - Type-safe value storage
20/// - Namespaced keys
21/// - Clone-friendly (Arc<RwLock>)
22#[derive(Clone)]
23pub struct Vault {
24    data: Arc<RwLock<HashMap<String, serde_json::Value>>>,
25}
26
27impl Vault {
28    /// Create a new vault
29    pub fn new() -> Self {
30        Self {
31            data: Arc::new(RwLock::new(HashMap::new())),
32        }
33    }
34
35    /// Store a value in the vault
36    pub fn set<K: Into<String>, V: serde::Serialize>(&self, key: K, value: V) -> Result<(), Error> {
37        let json_value = serde_json::to_value(value)
38            .map_err(|e| Error::vault(format!("Failed to serialize value: {}", e)))?;
39
40        self.data
41            .write()
42            .map_err(|e| Error::vault(format!("Failed to acquire write lock: {}", e)))?
43            .insert(key.into(), json_value);
44
45        Ok(())
46    }
47
48    /// Get a value from the vault
49    pub fn get<K: AsRef<str>, V: for<'de> serde::Deserialize<'de>>(
50        &self,
51        key: K,
52    ) -> Result<Option<V>, Error> {
53        let data = self
54            .data
55            .read()
56            .map_err(|e| Error::vault(format!("Failed to acquire read lock: {}", e)))?;
57
58        match data.get(key.as_ref()) {
59            Some(value) => {
60                let typed_value = serde_json::from_value(value.clone())
61                    .map_err(|e| Error::vault(format!("Failed to deserialize value: {}", e)))?;
62                Ok(Some(typed_value))
63            }
64            None => Ok(None),
65        }
66    }
67
68    /// Check if a key exists
69    pub fn contains_key<K: AsRef<str>>(&self, key: K) -> bool {
70        self.data
71            .read()
72            .map(|data| data.contains_key(key.as_ref()))
73            .unwrap_or(false)
74    }
75
76    /// Remove a value from the vault
77    pub fn remove<K: AsRef<str>>(&self, key: K) -> Result<(), Error> {
78        self.data
79            .write()
80            .map_err(|e| Error::vault(format!("Failed to acquire write lock: {}", e)))?
81            .remove(key.as_ref());
82
83        Ok(())
84    }
85
86    /// Clear all values
87    pub fn clear(&self) -> Result<(), Error> {
88        self.data
89            .write()
90            .map_err(|e| Error::vault(format!("Failed to acquire write lock: {}", e)))?
91            .clear();
92
93        Ok(())
94    }
95
96    /// Get all keys
97    pub fn keys(&self) -> Result<Vec<String>, Error> {
98        let data = self
99            .data
100            .read()
101            .map_err(|e| Error::vault(format!("Failed to acquire read lock: {}", e)))?;
102
103        Ok(data.keys().cloned().collect())
104    }
105
106    /// Get number of entries
107    pub fn len(&self) -> usize {
108        self.data.read().map(|data| data.len()).unwrap_or(0)
109    }
110
111    /// Check if vault is empty
112    pub fn is_empty(&self) -> bool {
113        self.len() == 0
114    }
115}
116
117impl Default for Vault {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips a single key through set / get / contains_key / remove.
    #[test]
    fn test_vault_basic_operations() {
        let vault = Vault::new();

        // Test set and get
        vault.set("key1", "value1").unwrap();
        assert_eq!(vault.get::<_, String>("key1").unwrap(), Some("value1".to_string()));

        // A key that was never stored yields `None`, not an error
        assert_eq!(vault.get::<_, String>("key2").unwrap(), None);

        // Test contains_key
        assert!(vault.contains_key("key1"));
        assert!(!vault.contains_key("key2"));

        // Test remove
        vault.remove("key1").unwrap();
        assert!(!vault.contains_key("key1"));
    }

    /// Values of different primitive types come back as the type they were stored as.
    #[test]
    fn test_vault_typed_values() {
        let vault = Vault::new();

        vault.set("int", 42i32).unwrap();
        // 2.5 is exactly representable, so the equality check below is exact;
        // it also avoids Clippy's `approx_constant` lint, which a literal
        // resembling pi would trigger.
        vault.set("float", 2.5f64).unwrap();
        vault.set("bool", true).unwrap();
        vault.set("string", "hello").unwrap();

        assert_eq!(vault.get::<_, i32>("int").unwrap(), Some(42));
        assert_eq!(vault.get::<_, f64>("float").unwrap(), Some(2.5));
        assert_eq!(vault.get::<_, bool>("bool").unwrap(), Some(true));
        assert_eq!(vault.get::<_, String>("string").unwrap(), Some("hello".to_string()));
    }

    /// Asking for a type the stored JSON cannot be converted to is an error.
    #[test]
    fn test_vault_type_mismatch() {
        let vault = Vault::new();

        vault.set("string", "hello").unwrap();

        assert!(vault.get::<_, i32>("string").is_err());
    }

    /// `keys` lists every stored key (order is unspecified, so sort first).
    #[test]
    fn test_vault_keys() {
        let vault = Vault::new();

        vault.set("a", 1i32).unwrap();
        vault.set("b", 2i32).unwrap();

        let mut keys = vault.keys().unwrap();
        keys.sort();
        assert_eq!(keys, vec!["a".to_string(), "b".to_string()]);
    }

    /// `clear` removes every entry and leaves the vault empty.
    #[test]
    fn test_vault_clear() {
        let vault = Vault::new();

        vault.set("key1", "value1").unwrap();
        vault.set("key2", "value2").unwrap();

        assert_eq!(vault.len(), 2);

        vault.clear().unwrap();

        assert_eq!(vault.len(), 0);
        assert!(vault.is_empty());
    }

    /// Clones are handles onto the same data, not independent copies.
    #[test]
    fn test_vault_clone() {
        let vault1 = Vault::new();
        vault1.set("key", "value").unwrap();

        let vault2 = vault1.clone();
        assert_eq!(vault2.get::<_, String>("key").unwrap(), Some("value".to_string()));

        // Both vaults share the same underlying data
        vault2.set("key2", "value2").unwrap();
        assert!(vault1.contains_key("key2"));
    }
}