Skip to main content

punch_types/
secret_store.rs

//! Secret management with zeroization — keeps secrets locked in the vault.
//!
//! Provides a zero-on-drop `Secret` wrapper that wipes sensitive data from
//! memory when it goes out of scope. The `SecretStore` offers a concurrent,
//! named vault for storing and retrieving secrets, while `SecretProvider`
//! implementations load secrets from environment variables, files, or other
//! sources.
8
9use std::collections::HashMap;
10use std::fmt;
11use std::path::Path;
12use std::sync::Arc;
13
14use dashmap::DashMap;
15use serde::{Deserialize, Serialize};
16use zeroize::Zeroize;
17
18// ---------------------------------------------------------------------------
19// Secret<T> wrapper
20// ---------------------------------------------------------------------------
21
22/// A wrapper that zeroizes its inner value when dropped.
23///
24/// Prevents secrets from lingering in memory after they are no longer needed.
25/// Like wiping the blood from the canvas between bouts.
26pub struct Secret<T: Zeroize> {
27    inner: T,
28}
29
30impl<T: Zeroize> Secret<T> {
31    /// Wrap a value in the secret container.
32    pub fn new(value: T) -> Self {
33        Self { inner: value }
34    }
35
36    /// Access the secret value. Handle with care — the contents are sensitive.
37    pub fn expose(&self) -> &T {
38        &self.inner
39    }
40}
41
42impl<T: Zeroize> Drop for Secret<T> {
43    fn drop(&mut self) {
44        self.inner.zeroize();
45    }
46}
47
48impl<T: Zeroize> fmt::Debug for Secret<T> {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.write_str("Secret(***)")
51    }
52}
53
54impl<T: Zeroize> fmt::Display for Secret<T> {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        f.write_str("***")
57    }
58}
59
60impl<T: Zeroize + Clone> Clone for Secret<T> {
61    fn clone(&self) -> Self {
62        Self {
63            inner: self.inner.clone(),
64        }
65    }
66}
67
68// ---------------------------------------------------------------------------
69// SecretString alias
70// ---------------------------------------------------------------------------
71
72/// A `Secret<String>` — the most common secret type.
73pub type SecretString = Secret<String>;
74
75// ---------------------------------------------------------------------------
76// SecretStore
77// ---------------------------------------------------------------------------
78
79/// A concurrent, named vault for storing secrets.
80///
81/// Backed by `DashMap` for lock-free concurrent access. Secret values are
82/// wrapped in `SecretString` so they are zeroized when removed or when the
83/// store is dropped.
84#[derive(Debug, Clone)]
85pub struct SecretStore {
86    secrets: Arc<DashMap<String, String>>,
87}
88
89impl Default for SecretStore {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl SecretStore {
96    /// Create an empty vault.
97    pub fn new() -> Self {
98        Self {
99            secrets: Arc::new(DashMap::new()),
100        }
101    }
102
103    /// Store a named secret. If a secret with the same name already exists,
104    /// the old value is zeroized and replaced.
105    pub fn store_secret(&self, name: &str, value: &str) {
106        if let Some(mut old) = self.secrets.get_mut(name) {
107            old.value_mut().zeroize();
108        }
109        self.secrets.insert(name.to_string(), value.to_string());
110    }
111
112    /// Retrieve a named secret wrapped in a `SecretString`.
113    ///
114    /// Returns `None` if the secret does not exist.
115    pub fn get_secret(&self, name: &str) -> Option<SecretString> {
116        self.secrets
117            .get(name)
118            .map(|entry| Secret::new(entry.value().clone()))
119    }
120
121    /// Delete a named secret, zeroizing its value before removal.
122    pub fn delete_secret(&self, name: &str) {
123        if let Some((_, mut value)) = self.secrets.remove(name) {
124            value.zeroize();
125        }
126    }
127
128    /// List all secret names without exposing their values.
129    pub fn list_secret_names(&self) -> Vec<String> {
130        self.secrets
131            .iter()
132            .map(|entry| entry.key().clone())
133            .collect()
134    }
135
136    /// Return the number of secrets in the store.
137    pub fn len(&self) -> usize {
138        self.secrets.len()
139    }
140
141    /// Check if the store is empty.
142    pub fn is_empty(&self) -> bool {
143        self.secrets.is_empty()
144    }
145}
146
// ---------------------------------------------------------------------------
// Masking
// ---------------------------------------------------------------------------

/// Mask a secret value for safe display.
///
/// Shows the first 2 and last 2 characters, with one `*` for every hidden
/// character in between. Values shorter than 5 characters are fully masked
/// (one `*` per character).
///
/// Lengths are counted in `char`s (Unicode scalar values), not bytes, so
/// multi-byte input such as `"日本"` is masked correctly instead of panicking.
pub fn mask_secret(value: &str) -> String {
    let char_count = value.chars().count();
    if char_count < 5 {
        return "*".repeat(char_count);
    }
    // Output never exceeds the input's byte length: every char is either kept
    // as-is or replaced by the 1-byte '*'.
    let mut masked = String::with_capacity(value.len());
    for (index, ch) in value.chars().enumerate() {
        if index < 2 || index >= char_count - 2 {
            masked.push(ch);
        } else {
            masked.push('*');
        }
    }
    masked
}
165
166// ---------------------------------------------------------------------------
167// SecretProvider trait
168// ---------------------------------------------------------------------------
169
170/// A source of secrets — loads secrets from an external provider.
171pub trait SecretProvider: Send + Sync {
172    /// Load all available secrets into the given store.
173    fn load_secrets(&self, store: &SecretStore) -> Result<usize, SecretProviderError>;
174}
175
176/// Errors from secret providers.
177#[derive(Debug, thiserror::Error)]
178pub enum SecretProviderError {
179    /// An I/O error occurred while reading secrets.
180    #[error("I/O error: {0}")]
181    Io(#[from] std::io::Error),
182
183    /// The secret source format is invalid.
184    #[error("invalid format: {0}")]
185    InvalidFormat(String),
186}
187
188// ---------------------------------------------------------------------------
189// EnvSecretProvider
190// ---------------------------------------------------------------------------
191
192/// Loads secrets from environment variables with a configurable prefix.
193///
194/// For example, with prefix `"PUNCH_SECRET_"`, the env var
195/// `PUNCH_SECRET_API_KEY=xyz` becomes a secret named `API_KEY` with value `xyz`.
196#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct EnvSecretProvider {
198    /// The prefix to filter environment variables.
199    pub prefix: String,
200}
201
202impl EnvSecretProvider {
203    /// Create a provider that reads env vars with the given prefix.
204    pub fn new(prefix: &str) -> Self {
205        Self {
206            prefix: prefix.to_string(),
207        }
208    }
209}
210
211impl SecretProvider for EnvSecretProvider {
212    fn load_secrets(&self, store: &SecretStore) -> Result<usize, SecretProviderError> {
213        let mut count = 0;
214        for (key, value) in std::env::vars() {
215            if key.starts_with(&self.prefix) {
216                let name = &key[self.prefix.len()..];
217                if !name.is_empty() {
218                    store.store_secret(name, &value);
219                    count += 1;
220                }
221            }
222        }
223        Ok(count)
224    }
225}
226
227// ---------------------------------------------------------------------------
228// FileSecretProvider
229// ---------------------------------------------------------------------------
230
231/// Loads secrets from a file in `KEY=VALUE` format (one per line).
232///
233/// Lines starting with `#` are treated as comments. Empty lines are skipped.
234/// Leading and trailing whitespace on keys and values is trimmed.
235#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct FileSecretProvider {
237    /// Path to the secrets file.
238    pub path: String,
239}
240
241impl FileSecretProvider {
242    /// Create a provider that reads secrets from the given file path.
243    pub fn new(path: &str) -> Self {
244        Self {
245            path: path.to_string(),
246        }
247    }
248
249    /// Parse a secrets file content into key-value pairs.
250    pub fn parse_secrets(content: &str) -> Result<HashMap<String, String>, SecretProviderError> {
251        let mut secrets = HashMap::new();
252        for (line_num, line) in content.lines().enumerate() {
253            let trimmed = line.trim();
254            if trimmed.is_empty() || trimmed.starts_with('#') {
255                continue;
256            }
257            if let Some(eq_pos) = trimmed.find('=') {
258                let key = trimmed[..eq_pos].trim();
259                let value = trimmed[eq_pos + 1..].trim();
260                if key.is_empty() {
261                    return Err(SecretProviderError::InvalidFormat(format!(
262                        "empty key on line {}",
263                        line_num + 1
264                    )));
265                }
266                secrets.insert(key.to_string(), value.to_string());
267            } else {
268                return Err(SecretProviderError::InvalidFormat(format!(
269                    "missing '=' on line {}",
270                    line_num + 1
271                )));
272            }
273        }
274        Ok(secrets)
275    }
276}
277
278impl SecretProvider for FileSecretProvider {
279    fn load_secrets(&self, store: &SecretStore) -> Result<usize, SecretProviderError> {
280        let content = std::fs::read_to_string(Path::new(&self.path))?;
281        let secrets = Self::parse_secrets(&content)?;
282        let count = secrets.len();
283        for (key, value) in secrets {
284            store.store_secret(&key, &value);
285        }
286        Ok(count)
287    }
288}
289
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Test double whose `zeroize` flips a shared flag, so a test can observe
    /// whether `Secret`'s `Drop` impl actually invoked it.
    struct WipeProbe(Arc<AtomicBool>);

    impl Zeroize for WipeProbe {
        fn zeroize(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_secret_display_is_masked() {
        let secret = Secret::new("super-secret-value".to_string());
        assert_eq!(format!("{}", secret), "***");
        assert_eq!(format!("{:?}", secret), "Secret(***)");
    }

    #[test]
    fn test_secret_expose() {
        let secret = Secret::new("my-value".to_string());
        assert_eq!(secret.expose(), "my-value");
    }

    #[test]
    fn test_store_and_retrieve() {
        let store = SecretStore::new();
        store.store_secret("API_KEY", "abc123");
        let retrieved = store.get_secret("API_KEY").unwrap();
        assert_eq!(retrieved.expose(), "abc123");
    }

    #[test]
    fn test_store_missing_key() {
        let store = SecretStore::new();
        assert!(store.get_secret("NONEXISTENT").is_none());
    }

    #[test]
    fn test_delete_secret() {
        let store = SecretStore::new();
        store.store_secret("TO_DELETE", "value");
        assert!(store.get_secret("TO_DELETE").is_some());
        store.delete_secret("TO_DELETE");
        assert!(store.get_secret("TO_DELETE").is_none());
    }

    #[test]
    fn test_list_secret_names() {
        let store = SecretStore::new();
        store.store_secret("ALPHA", "a");
        store.store_secret("BETA", "b");
        store.store_secret("GAMMA", "c");
        let mut names = store.list_secret_names();
        names.sort();
        assert_eq!(names, vec!["ALPHA", "BETA", "GAMMA"]);
    }

    #[test]
    fn test_store_len_and_empty() {
        let store = SecretStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        store.store_secret("KEY", "val");
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_overwrite_secret() {
        let store = SecretStore::new();
        store.store_secret("KEY", "old");
        store.store_secret("KEY", "new");
        let retrieved = store.get_secret("KEY").unwrap();
        assert_eq!(retrieved.expose(), "new");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_store_clones_share_state() {
        // `SecretStore` clones are handles to the same vault.
        let store = SecretStore::new();
        let handle = store.clone();
        store.store_secret("SHARED", "v");
        assert_eq!(handle.get_secret("SHARED").unwrap().expose(), "v");
        handle.delete_secret("SHARED");
        assert!(store.get_secret("SHARED").is_none());
    }

    #[test]
    fn test_mask_secret_normal() {
        assert_eq!(mask_secret("abcdefgh"), "ab****gh");
        assert_eq!(mask_secret("12345"), "12*45");
    }

    #[test]
    fn test_mask_secret_short() {
        assert_eq!(mask_secret("ab"), "**");
        assert_eq!(mask_secret("abc"), "***");
        assert_eq!(mask_secret("abcd"), "****");
    }

    #[test]
    fn test_mask_secret_empty() {
        assert_eq!(mask_secret(""), "");
    }

    #[test]
    fn test_zeroization_on_drop() {
        // Freed memory cannot be inspected, so use a probe type that records
        // when `zeroize` is invoked and check that dropping the `Secret`
        // triggers it (and that merely wrapping the value does not).
        let wiped = Arc::new(AtomicBool::new(false));
        let secret = Secret::new(WipeProbe(Arc::clone(&wiped)));
        assert!(!wiped.load(Ordering::SeqCst));
        drop(secret);
        assert!(wiped.load(Ordering::SeqCst));
    }

    #[test]
    fn test_string_zeroize_clears() {
        // Sanity check on the `Zeroize` impl that `SecretString` relies on.
        let mut value = String::from("sensitive-data");
        value.zeroize();
        assert!(value.is_empty());
    }

    #[test]
    fn test_env_secret_provider() {
        let prefix = "PUNCH_TEST_SECRET_STORE_";
        // Set up test env vars.
        unsafe {
            std::env::set_var(format!("{}DB_PASS", prefix), "hunter2");
            std::env::set_var(format!("{}API_KEY", prefix), "key123");
        }

        let provider = EnvSecretProvider::new(prefix);
        let store = SecretStore::new();
        let count = provider.load_secrets(&store).unwrap();
        assert!(count >= 2);
        assert_eq!(store.get_secret("DB_PASS").unwrap().expose(), "hunter2");
        assert_eq!(store.get_secret("API_KEY").unwrap().expose(), "key123");

        // Clean up.
        unsafe {
            std::env::remove_var(format!("{}DB_PASS", prefix));
            std::env::remove_var(format!("{}API_KEY", prefix));
        }
    }

    #[test]
    fn test_file_secret_provider_parse() {
        let content = r#"
# Database credentials
DB_HOST=localhost
DB_PASS=supersecret

# API keys
API_KEY=abc123
"#;
        let secrets = FileSecretProvider::parse_secrets(content).unwrap();
        assert_eq!(secrets.len(), 3);
        assert_eq!(secrets.get("DB_HOST").unwrap(), "localhost");
        assert_eq!(secrets.get("DB_PASS").unwrap(), "supersecret");
        assert_eq!(secrets.get("API_KEY").unwrap(), "abc123");
    }

    #[test]
    fn test_file_secret_provider_invalid_format() {
        let content = "VALID=ok\nINVALID_LINE_NO_EQUALS";
        let result = FileSecretProvider::parse_secrets(content);
        assert!(result.is_err());
    }

    #[test]
    fn test_secret_clone() {
        let secret = Secret::new("cloneable".to_string());
        let cloned = secret.clone();
        assert_eq!(cloned.expose(), "cloneable");
    }
}