use minijinja::{Environment, Error, ErrorKind};
use sherpack_core::{SecretCharset, SecretGenerator, SecretState};
use std::sync::Arc;
/// Shared handle to the [`SecretGenerator`] that backs the `generate_secret`
/// template function.
///
/// The generator lives behind an `Arc`, so cloning this handle is cheap and every
/// clone (as well as any closure installed by `register`) operates on the same
/// generator.
#[derive(Clone, Debug)]
pub struct SecretFunctionState {
    // Mutex-guarded so the registered template closure and this handle can both
    // reach the generator.
    generator: Arc<std::sync::Mutex<SecretGenerator>>,
}
impl Default for SecretFunctionState {
fn default() -> Self {
Self::new()
}
}
impl SecretFunctionState {
pub fn new() -> Self {
Self {
generator: Arc::new(std::sync::Mutex::new(SecretGenerator::new())),
}
}
pub fn with_state(state: SecretState) -> Self {
Self {
generator: Arc::new(std::sync::Mutex::new(SecretGenerator::with_state(state))),
}
}
pub fn is_dirty(&self) -> bool {
self.generator.lock().unwrap().is_dirty()
}
pub fn take_state(&self) -> SecretState {
let mut generator = self.generator.lock().unwrap();
let state = std::mem::take(&mut *generator);
state.into_state()
}
pub fn register(&self, env: &mut Environment<'static>) {
let generator = Arc::clone(&self.generator);
env.add_function(
"generate_secret",
move |name: String, length: i64, charset: Option<String>| -> Result<String, Error> {
if name.is_empty() {
return Err(Error::new(
ErrorKind::InvalidOperation,
"generate_secret: name cannot be empty",
));
}
if length < 1 {
return Err(Error::new(
ErrorKind::InvalidOperation,
format!("generate_secret: length must be positive, got {}", length),
));
}
if length > 4096 {
return Err(Error::new(
ErrorKind::InvalidOperation,
format!("generate_secret: length {} exceeds maximum of 4096", length),
));
}
let charset = match charset {
Some(ref charset_str) => {
SecretCharset::parse(charset_str).ok_or_else(|| {
Error::new(
ErrorKind::InvalidOperation,
format!(
"generate_secret: unknown charset '{}'. Valid options: \
alphanumeric, alpha, numeric, hex, base64, urlsafe",
charset_str
),
)
})?
}
None => SecretCharset::default(),
};
let mut secret_gen = generator.lock().unwrap();
let secret =
secret_gen.get_or_generate_with_charset(&name, length as usize, charset);
Ok(secret)
},
);
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an environment with `generate_secret` registered against a fresh state.
    fn setup() -> (Environment<'static>, SecretFunctionState) {
        let mut env = Environment::new();
        let state = SecretFunctionState::new();
        state.register(&mut env);
        (env, state)
    }

    /// Renders `template` once in a brand-new environment/state pair.
    fn render(template: &str) -> Result<String, Error> {
        let (env, _state) = setup();
        env.render_str(template, ())
    }

    #[test]
    fn test_generate_secret_basic() {
        let secret = render(r#"{{ generate_secret("test-password", 16) }}"#).unwrap();
        assert_eq!(secret.len(), 16);
        assert!(secret.chars().all(|ch| ch.is_ascii_alphanumeric()));
    }

    #[test]
    fn test_generate_secret_idempotent() {
        let (env, _state) = setup();
        let template = r#"{{ generate_secret("db-password", 20) }}"#;
        let first = env.render_str(template, ()).unwrap();
        let second = env.render_str(template, ()).unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn test_generate_secret_different_names() {
        let rendered = render(
            r#"{{ generate_secret("password1", 16) }}-{{ generate_secret("password2", 16) }}"#,
        )
        .unwrap();
        let pieces: Vec<&str> = rendered.split('-').collect();
        assert_eq!(pieces.len(), 2);
        assert_ne!(pieces[0], pieces[1]);
    }

    #[test]
    fn test_generate_secret_hex_charset() {
        let token = render(r#"{{ generate_secret("hex-token", 32, "hex") }}"#).unwrap();
        assert_eq!(token.len(), 32);
        assert!(token.chars().all(|ch| ch.is_ascii_hexdigit()));
    }

    #[test]
    fn test_generate_secret_numeric_charset() {
        let pin = render(r#"{{ generate_secret("pin", 6, "numeric") }}"#).unwrap();
        assert_eq!(pin.len(), 6);
        assert!(pin.chars().all(|ch| ch.is_ascii_digit()));
    }

    #[test]
    fn test_generate_secret_invalid_charset() {
        let err = render(r#"{{ generate_secret("test", 16, "invalid") }}"#)
            .expect_err("an unknown charset must be rejected");
        assert!(err.to_string().contains("unknown charset"));
    }

    #[test]
    fn test_generate_secret_missing_args() {
        assert!(render(r#"{{ generate_secret("only-name") }}"#).is_err());
    }

    #[test]
    fn test_generate_secret_invalid_length() {
        let zero = render(r#"{{ generate_secret("test", 0) }}"#)
            .expect_err("a zero length must be rejected");
        assert!(zero.to_string().contains("must be positive"));

        assert!(render(r#"{{ generate_secret("test", -5) }}"#).is_err());

        let too_long = render(r#"{{ generate_secret("test", 10000) }}"#)
            .expect_err("an oversized length must be rejected");
        assert!(too_long.to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_state_is_dirty() {
        // Dirtiness is checked before registration, so build the pieces by hand.
        let state = SecretFunctionState::new();
        assert!(!state.is_dirty());
        let mut env = Environment::new();
        state.register(&mut env);
        env.render_str(r#"{{ generate_secret("new-secret", 16) }}"#, ())
            .unwrap();
        assert!(state.is_dirty());
    }

    #[test]
    fn test_state_persistence() {
        let template = r#"{{ generate_secret("db-password", 24) }}"#;

        let (first_env, first_state) = setup();
        let original = first_env.render_str(template, ()).unwrap();

        let serialized = serde_json::to_string(&first_state.take_state()).unwrap();
        let restored: SecretState = serde_json::from_str(&serialized).unwrap();

        let second_state = SecretFunctionState::with_state(restored);
        let mut second_env = Environment::new();
        second_state.register(&mut second_env);
        let reloaded = second_env.render_str(template, ()).unwrap();

        assert_eq!(original, reloaded);
        assert!(!second_state.is_dirty());
    }

    #[test]
    fn test_multiple_secrets_in_template() {
        let (env, _state) = setup();
        let template = concat!(
            "\n",
            r#"postgres-password: {{ generate_secret("postgres-password", 24) }}"#,
            "\n",
            r#"replication-password: {{ generate_secret("replication-password", 24) }}"#,
            "\n",
            r#"api-key: {{ generate_secret("api-key", 32, "hex") }}"#,
            "\n",
        );
        let rendered = env.render_str(template, ()).unwrap();
        let values: Vec<&str> = rendered
            .lines()
            .filter(|line| !line.is_empty())
            .map(|line| line.split(": ").nth(1).unwrap())
            .collect();
        assert_eq!(values.len(), 3);
        let (postgres_pw, repl_pw, api_key) = (values[0], values[1], values[2]);
        assert_ne!(postgres_pw, repl_pw);
        assert_ne!(postgres_pw, api_key);
        assert_eq!(postgres_pw.len(), 24);
        assert_eq!(repl_pw.len(), 24);
        assert_eq!(api_key.len(), 32);
    }
}