1use std::collections::HashMap;
10use std::fmt;
11use std::path::Path;
12use std::sync::Arc;
13
14use dashmap::DashMap;
15use serde::{Deserialize, Serialize};
16use zeroize::Zeroize;
17
18pub struct Secret<T: Zeroize> {
27 inner: T,
28}
29
30impl<T: Zeroize> Secret<T> {
31 pub fn new(value: T) -> Self {
33 Self { inner: value }
34 }
35
36 pub fn expose(&self) -> &T {
38 &self.inner
39 }
40}
41
42impl<T: Zeroize> Drop for Secret<T> {
43 fn drop(&mut self) {
44 self.inner.zeroize();
45 }
46}
47
48impl<T: Zeroize> fmt::Debug for Secret<T> {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.write_str("Secret(***)")
51 }
52}
53
54impl<T: Zeroize> fmt::Display for Secret<T> {
55 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56 f.write_str("***")
57 }
58}
59
60impl<T: Zeroize + Clone> Clone for Secret<T> {
61 fn clone(&self) -> Self {
62 Self {
63 inner: self.inner.clone(),
64 }
65 }
66}
67
68pub type SecretString = Secret<String>;
74
75#[derive(Debug, Clone)]
85pub struct SecretStore {
86 secrets: Arc<DashMap<String, String>>,
87}
88
89impl Default for SecretStore {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl SecretStore {
96 pub fn new() -> Self {
98 Self {
99 secrets: Arc::new(DashMap::new()),
100 }
101 }
102
103 pub fn store_secret(&self, name: &str, value: &str) {
106 if let Some(mut old) = self.secrets.get_mut(name) {
107 old.value_mut().zeroize();
108 }
109 self.secrets.insert(name.to_string(), value.to_string());
110 }
111
112 pub fn get_secret(&self, name: &str) -> Option<SecretString> {
116 self.secrets
117 .get(name)
118 .map(|entry| Secret::new(entry.value().clone()))
119 }
120
121 pub fn delete_secret(&self, name: &str) {
123 if let Some((_, mut value)) = self.secrets.remove(name) {
124 value.zeroize();
125 }
126 }
127
128 pub fn list_secret_names(&self) -> Vec<String> {
130 self.secrets
131 .iter()
132 .map(|entry| entry.key().clone())
133 .collect()
134 }
135
136 pub fn len(&self) -> usize {
138 self.secrets.len()
139 }
140
141 pub fn is_empty(&self) -> bool {
143 self.secrets.is_empty()
144 }
145}
146
/// Masks a secret for display, keeping only its first two and last two characters.
///
/// Values shorter than five characters are masked completely. Lengths are counted in
/// `char`s (Unicode scalar values), not bytes, so multi-byte input never panics or leaks.
/// The mask has exactly one `*` per hidden character.
///
/// ```text
/// "abcdefgh" -> "ab****gh"
/// "abcd"     -> "****"
/// ""         -> ""
/// ```
pub fn mask_secret(value: &str) -> String {
    let chars: Vec<char> = value.chars().collect();
    let total = chars.len();
    // Measuring the byte length here (as `value.len()` would) disagrees with the char-based
    // slicing below: e.g. "ééé" would underflow `total - 4`, and "éééé" would be shown
    // entirely unmasked.
    if total < 5 {
        return "*".repeat(total);
    }
    // Each hidden char becomes a 1-byte `*`, so the result never exceeds the input's size.
    let mut masked = String::with_capacity(value.len());
    masked.extend(&chars[..2]);
    masked.push_str(&"*".repeat(total - 4));
    masked.extend(&chars[total - 2..]);
    masked
}
165
/// A source of secrets that can populate a [`SecretStore`].
///
/// Implementors must be `Send + Sync`.
pub trait SecretProvider: Send + Sync {
    /// Loads this provider's secrets into `store`.
    ///
    /// Returns a count of loaded secrets. NOTE(review): the implementations in this file
    /// return the number of entries they stored; confirm other implementors follow suit.
    ///
    /// # Errors
    ///
    /// Returns a [`SecretProviderError`] if the source cannot be read or is malformed.
    fn load_secrets(&self, store: &SecretStore) -> Result<usize, SecretProviderError>;
}
175
/// Errors returned by [`SecretProvider::load_secrets`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum SecretProviderError {
    /// Reading the underlying source failed (e.g. the secrets file could not be read).
    /// Converted automatically from `std::io::Error` by `?`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The source was readable but its contents were not in the expected format.
    #[error("invalid format: {0}")]
    InvalidFormat(String),
}
187
188#[derive(Debug, Clone, Serialize, Deserialize)]
197pub struct EnvSecretProvider {
198 pub prefix: String,
200}
201
202impl EnvSecretProvider {
203 pub fn new(prefix: &str) -> Self {
205 Self {
206 prefix: prefix.to_string(),
207 }
208 }
209}
210
211impl SecretProvider for EnvSecretProvider {
212 fn load_secrets(&self, store: &SecretStore) -> Result<usize, SecretProviderError> {
213 let mut count = 0;
214 for (key, value) in std::env::vars() {
215 if key.starts_with(&self.prefix) {
216 let name = &key[self.prefix.len()..];
217 if !name.is_empty() {
218 store.store_secret(name, &value);
219 count += 1;
220 }
221 }
222 }
223 Ok(count)
224 }
225}
226
227#[derive(Debug, Clone, Serialize, Deserialize)]
236pub struct FileSecretProvider {
237 pub path: String,
239}
240
241impl FileSecretProvider {
242 pub fn new(path: &str) -> Self {
244 Self {
245 path: path.to_string(),
246 }
247 }
248
249 pub fn parse_secrets(content: &str) -> Result<HashMap<String, String>, SecretProviderError> {
251 let mut secrets = HashMap::new();
252 for (line_num, line) in content.lines().enumerate() {
253 let trimmed = line.trim();
254 if trimmed.is_empty() || trimmed.starts_with('#') {
255 continue;
256 }
257 if let Some(eq_pos) = trimmed.find('=') {
258 let key = trimmed[..eq_pos].trim();
259 let value = trimmed[eq_pos + 1..].trim();
260 if key.is_empty() {
261 return Err(SecretProviderError::InvalidFormat(format!(
262 "empty key on line {}",
263 line_num + 1
264 )));
265 }
266 secrets.insert(key.to_string(), value.to_string());
267 } else {
268 return Err(SecretProviderError::InvalidFormat(format!(
269 "missing '=' on line {}",
270 line_num + 1
271 )));
272 }
273 }
274 Ok(secrets)
275 }
276}
277
278impl SecretProvider for FileSecretProvider {
279 fn load_secrets(&self, store: &SecretStore) -> Result<usize, SecretProviderError> {
280 let content = std::fs::read_to_string(Path::new(&self.path))?;
281 let secrets = Self::parse_secrets(&content)?;
282 let count = secrets.len();
283 for (key, value) in secrets {
284 store.store_secret(&key, &value);
285 }
286 Ok(count)
287 }
288}
289
#[cfg(test)]
mod tests {
    // Unit tests for `Secret`, `SecretStore`, `mask_secret`, and both providers.
    use super::*;

    // Both formatting paths must hide the wrapped value.
    #[test]
    fn test_secret_display_is_masked() {
        let secret = Secret::new("super-secret-value".to_string());
        assert_eq!(format!("{}", secret), "***");
        assert_eq!(format!("{:?}", secret), "Secret(***)");
    }

    #[test]
    fn test_secret_expose() {
        let secret = Secret::new("my-value".to_string());
        assert_eq!(secret.expose(), "my-value");
    }

    #[test]
    fn test_store_and_retrieve() {
        let store = SecretStore::new();
        store.store_secret("API_KEY", "abc123");
        let retrieved = store.get_secret("API_KEY").unwrap();
        assert_eq!(retrieved.expose(), "abc123");
    }

    #[test]
    fn test_store_missing_key() {
        let store = SecretStore::new();
        assert!(store.get_secret("NONEXISTENT").is_none());
    }

    #[test]
    fn test_delete_secret() {
        let store = SecretStore::new();
        store.store_secret("TO_DELETE", "value");
        assert!(store.get_secret("TO_DELETE").is_some());
        store.delete_secret("TO_DELETE");
        assert!(store.get_secret("TO_DELETE").is_none());
    }

    // Iteration order of the underlying map is unspecified, so sort before comparing.
    #[test]
    fn test_list_secret_names() {
        let store = SecretStore::new();
        store.store_secret("ALPHA", "a");
        store.store_secret("BETA", "b");
        store.store_secret("GAMMA", "c");
        let mut names = store.list_secret_names();
        names.sort();
        assert_eq!(names, vec!["ALPHA", "BETA", "GAMMA"]);
    }

    #[test]
    fn test_store_len_and_empty() {
        let store = SecretStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
        store.store_secret("KEY", "val");
        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }

    // Overwriting replaces the value without adding a second entry.
    #[test]
    fn test_overwrite_secret() {
        let store = SecretStore::new();
        store.store_secret("KEY", "old");
        store.store_secret("KEY", "new");
        let retrieved = store.get_secret("KEY").unwrap();
        assert_eq!(retrieved.expose(), "new");
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_mask_secret_normal() {
        assert_eq!(mask_secret("abcdefgh"), "ab****gh");
        assert_eq!(mask_secret("12345"), "12*45");
    }

    // Inputs shorter than five characters are masked completely.
    #[test]
    fn test_mask_secret_short() {
        assert_eq!(mask_secret("ab"), "**");
        assert_eq!(mask_secret("abc"), "***");
        assert_eq!(mask_secret("abcd"), "****");
    }

    #[test]
    fn test_mask_secret_empty() {
        assert_eq!(mask_secret(""), "");
    }

    // NOTE(review): despite its name, this only checks that `String::zeroize` empties the
    // string; it does not exercise `Secret`'s `Drop` impl.
    #[test]
    fn test_zeroization_on_drop() {
        let mut value = String::from("sensitive-data");
        value.zeroize();
        assert!(value.is_empty());
    }

    #[test]
    fn test_env_secret_provider() {
        let prefix = "PUNCH_TEST_SECRET_STORE_";
        // SAFETY: NOTE(review): mutating the environment is only sound while no other thread
        // reads or writes it concurrently; this assumes no other running test touches the
        // environment — confirm. The prefix is specific to this test.
        unsafe {
            std::env::set_var(format!("{}DB_PASS", prefix), "hunter2");
            std::env::set_var(format!("{}API_KEY", prefix), "key123");
        }

        let provider = EnvSecretProvider::new(prefix);
        let store = SecretStore::new();
        let count = provider.load_secrets(&store).unwrap();
        // `>=` rather than `==`: presumably tolerates pre-existing variables with this prefix.
        assert!(count >= 2);
        assert_eq!(store.get_secret("DB_PASS").unwrap().expose(), "hunter2");
        assert_eq!(store.get_secret("API_KEY").unwrap().expose(), "key123");

        // NOTE(review): cleanup is skipped if an assertion above fails.
        // SAFETY: same single-threaded-environment assumption as above.
        unsafe {
            std::env::remove_var(format!("{}DB_PASS", prefix));
            std::env::remove_var(format!("{}API_KEY", prefix));
        }
    }

    // Comments and blank lines are skipped; only the three assignments are parsed.
    #[test]
    fn test_file_secret_provider_parse() {
        let content = r#"
# Database credentials
DB_HOST=localhost
DB_PASS=supersecret

# API keys
API_KEY=abc123
"#;
        let secrets = FileSecretProvider::parse_secrets(content).unwrap();
        assert_eq!(secrets.len(), 3);
        assert_eq!(secrets.get("DB_HOST").unwrap(), "localhost");
        assert_eq!(secrets.get("DB_PASS").unwrap(), "supersecret");
        assert_eq!(secrets.get("API_KEY").unwrap(), "abc123");
    }

    // A line without `=` makes the whole parse fail.
    #[test]
    fn test_file_secret_provider_invalid_format() {
        let content = "VALID=ok\nINVALID_LINE_NO_EQUALS";
        let result = FileSecretProvider::parse_secrets(content);
        assert!(result.is_err());
    }

    #[test]
    fn test_secret_clone() {
        let secret = Secret::new("cloneable".to_string());
        let cloned = secret.clone();
        assert_eq!(cloned.expose(), "cloneable");
    }
}