use crate::types::{ExecutionError, RiskLevel, TokenError};
use hmac::{Hmac, KeyInit, Mac};
use secrecy::{ExposeSecret, SecretBox};
use serde::{Deserialize, Serialize};
use sha2::Sha256;
use uuid::Uuid;
type HmacSha256 = Hmac<Sha256>;
/// Owned HMAC key material held in a [`SecretBox`], which zeroizes it on drop.
///
/// This type derives neither `Debug` nor `Clone`, so the key cannot be
/// formatted or duplicated through them.
pub struct TokenSecret(SecretBox<[u8]>);
impl TokenSecret {
    /// Wraps `secret` as key material.
    ///
    /// The incoming buffer is moved into the secret box instead of being
    /// copied. Copying (as `Box::from(slice)` would) leaves the original
    /// `Vec` to be freed without zeroization, which leaks the key into freed
    /// heap memory. A copy can still happen if the `Vec` has spare capacity
    /// and the allocator relocates it while shrinking.
    pub fn new(secret: impl Into<Vec<u8>>) -> Self {
        let bytes: Vec<u8> = secret.into();
        Self(SecretBox::new(bytes.into_boxed_slice()))
    }

    /// Reads the key from environment variable `var`, using its UTF-8 bytes.
    ///
    /// # Errors
    /// Returns [`std::env::VarError`] if the variable is unset or is not valid Unicode.
    pub fn from_env(var: &str) -> Result<Self, std::env::VarError> {
        let val = std::env::var(var)?;
        Ok(Self::new(val.into_bytes()))
    }

    /// Borrows the raw key bytes. Keep the borrow as short-lived as possible.
    pub fn expose_secret(&self) -> &[u8] {
        self.0.expose_secret()
    }
}
/// Signed approval to execute one specific piece of code.
///
/// Every field except `signature` is covered by the HMAC signature (see
/// `payload_bytes`). `encode`/`decode` move the token through JSON +
/// URL-safe base64. The serde derive emits JSON keys in the declaration
/// order below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApprovalToken {
/// Per-approval identifier; `HmacTokenGenerator::generate` sets it to a UUID v4 string.
pub request_id: String,
/// Hex SHA-256 of the canonicalized code (see `hash_code`).
pub code_hash: String,
/// User identifier bound into the signature.
pub user_id: String,
/// Session identifier bound into the signature.
pub session_id: String,
/// Server identifier bound into the signature.
pub server_id: String,
/// Context hash supplied by the caller of `generate`; presumably the output of
/// `compute_context_hash` — TODO confirm against callers.
pub context_hash: String,
/// Risk level bound into the signature (signed via its `Display` form).
pub risk_level: RiskLevel,
/// Issue time as a Unix timestamp in seconds (`chrono::Utc::now().timestamp()`).
pub created_at: i64,
/// Unix timestamp in seconds; `HmacTokenGenerator::verify` rejects the token once
/// the current time is strictly greater than this.
pub expires_at: i64,
/// Hex-encoded HMAC-SHA256 over the other fields, set by `HmacTokenGenerator::generate`.
pub signature: String,
}
impl ApprovalToken {
    /// Serializes the token to JSON and encodes it as URL-safe base64 without padding.
    ///
    /// # Errors
    /// Propagates any [`serde_json::Error`] raised during serialization.
    pub fn encode(&self) -> Result<String, serde_json::Error> {
        let json = serde_json::to_string(self)?;
        Ok(base64::Engine::encode(
            &base64::engine::general_purpose::URL_SAFE_NO_PAD,
            json.as_bytes(),
        ))
    }

    /// Parses a string produced by [`ApprovalToken::encode`].
    ///
    /// This only parses. It does not check the signature or expiry; pass the
    /// result to a [`TokenGenerator`] for that.
    ///
    /// # Errors
    /// Returns [`TokenDecodeError::InvalidBase64`], [`TokenDecodeError::InvalidUtf8`]
    /// or [`TokenDecodeError::InvalidJson`] depending on which decoding stage failed.
    pub fn decode(encoded: &str) -> Result<Self, TokenDecodeError> {
        let bytes =
            base64::Engine::decode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, encoded)
                .map_err(|_| TokenDecodeError::InvalidBase64)?;
        let json = String::from_utf8(bytes).map_err(|_| TokenDecodeError::InvalidUtf8)?;
        serde_json::from_str(&json).map_err(|_| TokenDecodeError::InvalidJson)
    }

    /// Builds the canonical byte string covered by the HMAC signature.
    ///
    /// All fields except `signature` are joined with `|`. Each field is
    /// escaped first (`\` becomes `\\`, `|` becomes `\|`), which makes the
    /// encoding injective. Without escaping, `user_id = "a|b", session_id = "c"`
    /// and `user_id = "a", session_id = "b|c"` produce the same payload and
    /// would therefore share a valid signature. Fields containing neither
    /// character (UUIDs, hex hashes, integers) are emitted unchanged, so such
    /// tokens sign exactly as they did under the plain `|`-joined format.
    fn payload_bytes(&self) -> Vec<u8> {
        let risk_level = self.risk_level.to_string();
        let created_at = self.created_at.to_string();
        let expires_at = self.expires_at.to_string();
        let fields: [&str; 9] = [
            self.request_id.as_str(),
            self.code_hash.as_str(),
            self.user_id.as_str(),
            self.session_id.as_str(),
            self.server_id.as_str(),
            self.context_hash.as_str(),
            risk_level.as_str(),
            created_at.as_str(),
            expires_at.as_str(),
        ];
        let mut payload = String::new();
        for (index, field) in fields.iter().enumerate() {
            if index > 0 {
                payload.push('|');
            }
            for ch in field.chars() {
                if ch == '\\' || ch == '|' {
                    payload.push('\\');
                }
                payload.push(ch);
            }
        }
        payload.into_bytes()
    }
}
/// Reasons [`ApprovalToken::decode`] can fail, listed in the order its checks run.
#[derive(Debug, thiserror::Error)]
pub enum TokenDecodeError {
/// The input is not valid URL-safe, unpadded base64.
#[error(
"Token is not valid base64 — it may have been truncated or corrupted during transport"
)]
InvalidBase64,
/// The base64 payload decoded to bytes that are not valid UTF-8.
#[error("Token contains invalid UTF-8 bytes after base64 decoding")]
InvalidUtf8,
/// The decoded text did not deserialize into an [`ApprovalToken`].
#[error("Token decoded to invalid JSON — the token string may have been truncated, double-encoded, or is not an approval token")]
InvalidJson,
}
/// Issues and checks [`ApprovalToken`]s.
///
/// The `Send + Sync` bound allows one generator to be shared across threads.
pub trait TokenGenerator: Send + Sync {
/// Creates a signed token for `code`, bound to the given user, session, server,
/// context hash and risk level. The HMAC implementation sets
/// `expires_at = now + ttl_seconds` (Unix seconds).
fn generate(
&self,
code: &str,
user_id: &str,
session_id: &str,
server_id: &str,
context_hash: &str,
risk_level: RiskLevel,
ttl_seconds: i64,
) -> ApprovalToken;
/// Checks that the token is still valid. The HMAC implementation checks expiry
/// first, then the signature.
///
/// # Errors
/// An [`ExecutionError`] describing why the token was rejected.
fn verify(&self, token: &ApprovalToken) -> Result<(), ExecutionError>;
/// Checks that `code` hashes to the token's `code_hash`.
///
/// In the HMAC implementation this does not check the signature or expiry;
/// call [`TokenGenerator::verify`] as well.
fn verify_code(&self, code: &str, token: &ApprovalToken) -> Result<(), ExecutionError>;
}
pub struct HmacTokenGenerator {
secret: TokenSecret,
}
impl HmacTokenGenerator {
pub const MIN_SECRET_LEN: usize = 16;
pub fn new(secret: TokenSecret) -> Result<Self, TokenError> {
if secret.expose_secret().len() < Self::MIN_SECRET_LEN {
return Err(TokenError::SecretTooShort {
minimum: Self::MIN_SECRET_LEN,
actual: secret.expose_secret().len(),
});
}
Ok(Self { secret })
}
pub fn new_from_bytes(bytes: impl Into<Vec<u8>>) -> Result<Self, TokenError> {
Self::new(TokenSecret::new(bytes))
}
pub fn from_env(env_var: &str) -> Result<Self, Box<dyn std::error::Error>> {
let secret = TokenSecret::from_env(env_var)?;
Ok(Self::new(secret)?)
}
fn sign(&self, payload: &[u8]) -> String {
let mut mac = HmacSha256::new_from_slice(self.secret.expose_secret())
.expect("HMAC can take key of any size");
mac.update(payload);
hex::encode(mac.finalize().into_bytes())
}
fn verify_signature(&self, payload: &[u8], signature: &str) -> bool {
let mut mac = HmacSha256::new_from_slice(self.secret.expose_secret())
.expect("HMAC can take key of any size");
mac.update(payload);
let expected = hex::decode(signature).unwrap_or_default();
mac.verify_slice(&expected).is_ok()
}
}
impl TokenGenerator for HmacTokenGenerator {
    /// Issues a signed token for the canonicalized `code` (stored by hash),
    /// valid from now for `ttl_seconds` seconds.
    fn generate(
        &self,
        code: &str,
        user_id: &str,
        session_id: &str,
        server_id: &str,
        context_hash: &str,
        risk_level: RiskLevel,
        ttl_seconds: i64,
    ) -> ApprovalToken {
        let now = chrono::Utc::now().timestamp();
        let mut token = ApprovalToken {
            request_id: Uuid::new_v4().to_string(),
            code_hash: hash_code(code),
            user_id: user_id.to_string(),
            session_id: session_id.to_string(),
            server_id: server_id.to_string(),
            context_hash: context_hash.to_string(),
            risk_level,
            created_at: now,
            // Saturate on absurd TTLs. Plain `+` panics in debug builds and wraps
            // to a past (already-expired) timestamp in release builds.
            expires_at: now.saturating_add(ttl_seconds),
            signature: String::new(),
        };
        token.signature = self.sign(&token.payload_bytes());
        token
    }

    /// Rejects the token if it has expired, or if its signature does not
    /// match its fields under this generator's key.
    ///
    /// Expiry is checked first, so an expired token reports `TokenExpired`
    /// even when its signature is also invalid.
    fn verify(&self, token: &ApprovalToken) -> Result<(), ExecutionError> {
        let now = chrono::Utc::now().timestamp();
        if now > token.expires_at {
            return Err(ExecutionError::TokenExpired);
        }
        if !self.verify_signature(&token.payload_bytes(), &token.signature) {
            return Err(ExecutionError::TokenInvalid(
                "signature verification failed".into(),
            ));
        }
        Ok(())
    }

    /// Checks that `code` canonicalizes to the same hash the token was issued for.
    ///
    /// On mismatch, the error carries only the first 12 characters of each hash.
    fn verify_code(&self, code: &str, token: &ApprovalToken) -> Result<(), ExecutionError> {
        // Number of leading characters of each hash reported in the mismatch error.
        const PREFIX_CHARS: usize = 12;
        // Take whole characters rather than byte-slicing. `token.code_hash` comes
        // from a decoded, possibly untrusted token, and `&s[..12]` panics when
        // byte 12 is not a UTF-8 char boundary. For ASCII hex, the result is identical.
        fn prefix(hash: &str) -> String {
            hash.chars().take(PREFIX_CHARS).collect()
        }

        let current_hash = hash_code(code);
        if current_hash != token.code_hash {
            return Err(ExecutionError::CodeMismatch {
                expected_hash: prefix(&token.code_hash),
                actual_hash: prefix(&current_hash),
            });
        }
        Ok(())
    }
}
/// Hex-encoded SHA-256 digest of [`canonicalize_code`] applied to `code`.
///
/// Snippets that differ only in per-line surrounding whitespace or in blank
/// lines therefore produce the same hash.
pub fn hash_code(code: &str) -> String {
    use sha2::Digest;
    let canonical = canonicalize_code(code);
    let digest = Sha256::digest(canonical.as_bytes());
    hex::encode(digest)
}
/// Normalizes code text before hashing.
///
/// Trims every line, drops lines that are empty after trimming, and joins the
/// survivors with `\n`. Because indentation is discarded, this is a faithful
/// identity only for languages where leading whitespace is insignificant (the
/// tests in this file use GraphQL-style queries).
pub fn canonicalize_code(code: &str) -> String {
    let kept_lines: Vec<&str> = code
        .trim()
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .collect();
    kept_lines.join("\n")
}
/// Combines a schema hash and a permissions hash into one context hash.
///
/// Returns the hex-encoded SHA-256 of `schema_hash`, a literal `|`, then
/// `permissions_hash`. The separator is not escaped, so the result is only
/// unambiguous if neither input contains `|`. NOTE(review): presumably both
/// inputs are hex digests — confirm with callers.
pub fn compute_context_hash(schema_hash: &str, permissions_hash: &str) -> String {
    use sha2::Digest;
    let digest = Sha256::new()
        .chain_update(schema_hash.as_bytes())
        .chain_update(b"|")
        .chain_update(permissions_hash.as_bytes())
        .finalize();
    hex::encode(digest)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 16-byte key: exactly `HmacTokenGenerator::MIN_SECRET_LEN`.
    const SECRET: &[u8] = b"test-secret-key!";
    const QUERY: &str = "query { users { id } }";

    fn generator() -> HmacTokenGenerator {
        HmacTokenGenerator::new(TokenSecret::new(SECRET.to_vec())).unwrap()
    }

    fn token_with_ttl(generator: &HmacTokenGenerator, ttl_seconds: i64) -> ApprovalToken {
        generator.generate(
            QUERY,
            "user-123",
            "session-456",
            "server-789",
            "context-hash",
            RiskLevel::Low,
            ttl_seconds,
        )
    }

    /// Base64-encodes arbitrary bytes with the same engine `ApprovalToken` uses.
    fn encode_raw(bytes: &[u8]) -> String {
        base64::Engine::encode(&base64::engine::general_purpose::URL_SAFE_NO_PAD, bytes)
    }

    #[test]
    fn test_token_generation_and_verification() {
        let generator = generator();
        let token = token_with_ttl(&generator, 300);
        assert!(generator.verify(&token).is_ok());
        assert!(generator.verify_code(QUERY, &token).is_ok());
    }

    #[test]
    fn test_code_mismatch() {
        let generator = generator();
        let token = token_with_ttl(&generator, 300);
        let result = generator.verify_code("query { orders { id } }", &token);
        match result {
            Err(ExecutionError::CodeMismatch {
                expected_hash,
                actual_hash,
            }) => {
                assert_eq!(expected_hash, &token.code_hash[..12]);
                assert_eq!(actual_hash.len(), 12);
            }
            _ => panic!("expected CodeMismatch"),
        }
    }

    #[test]
    fn test_token_encode_decode() {
        let generator = generator();
        let token = token_with_ttl(&generator, 300);
        let encoded = token.encode().unwrap();
        let decoded = ApprovalToken::decode(&encoded).unwrap();
        assert_eq!(token.request_id, decoded.request_id);
        assert_eq!(token.code_hash, decoded.code_hash);
        assert_eq!(token.signature, decoded.signature);
        // A round-tripped token must still carry a valid signature.
        assert!(generator.verify(&decoded).is_ok());
    }

    #[test]
    fn test_decode_errors() {
        assert!(matches!(
            ApprovalToken::decode("not base64!"),
            Err(TokenDecodeError::InvalidBase64)
        ));
        assert!(matches!(
            ApprovalToken::decode(&encode_raw(&[0xff, 0xfe])),
            Err(TokenDecodeError::InvalidUtf8)
        ));
        assert!(matches!(
            ApprovalToken::decode(&encode_raw(b"{\"not\":\"a token\"}")),
            Err(TokenDecodeError::InvalidJson)
        ));
    }

    #[test]
    fn test_tampered_token_rejected() {
        let generator = generator();
        let token = token_with_ttl(&generator, 300);

        let mut other_user = token.clone();
        other_user.user_id = "user-999".into();
        assert!(matches!(
            generator.verify(&other_user),
            Err(ExecutionError::TokenInvalid(_))
        ));

        let mut extended = token.clone();
        extended.expires_at += 3600;
        assert!(matches!(
            generator.verify(&extended),
            Err(ExecutionError::TokenInvalid(_))
        ));
    }

    #[test]
    fn test_wrong_key_rejected() {
        let token = token_with_ttl(&generator(), 300);
        let other = HmacTokenGenerator::new(TokenSecret::new(b"another-secret-key".to_vec()))
            .unwrap();
        assert!(matches!(
            other.verify(&token),
            Err(ExecutionError::TokenInvalid(_))
        ));
    }

    #[test]
    fn test_malformed_signature_rejected() {
        let generator = generator();
        let mut token = token_with_ttl(&generator, 300);
        token.signature = "zz-not-hex".into();
        assert!(matches!(
            generator.verify(&token),
            Err(ExecutionError::TokenInvalid(_))
        ));
    }

    #[test]
    fn test_expired_token_rejected() {
        let generator = generator();
        let token = token_with_ttl(&generator, -1);
        assert!(matches!(
            generator.verify(&token),
            Err(ExecutionError::TokenExpired)
        ));
    }

    #[test]
    fn test_canonicalize_code() {
        let code1 = "query { users { id } }";
        let code2 = " query { users { id } } ";
        let code3 = "query {\n users {\n id\n }\n}";
        assert_eq!(canonicalize_code(code1), canonicalize_code(code2));
        assert_eq!(canonicalize_code(code3), "query {\nusers {\nid\n}\n}");
        assert_eq!(canonicalize_code("\n\n  \n"), "");
    }

    #[test]
    fn test_code_hash_ignores_whitespace_only() {
        assert_eq!(hash_code("a\n  b"), hash_code("  a\n\nb  "));
        assert_ne!(hash_code("a b"), hash_code("a\nb"));
        // Hex-encoded SHA-256 is 64 characters.
        assert_eq!(hash_code("x").len(), 64);
    }

    #[test]
    fn test_empty_secret_rejected() {
        let result = HmacTokenGenerator::new(TokenSecret::new(b"".to_vec()));
        assert!(matches!(
            result,
            Err(TokenError::SecretTooShort {
                minimum: 16,
                actual: 0
            })
        ));
    }

    #[test]
    fn test_short_secret_rejected() {
        let result = HmacTokenGenerator::new(TokenSecret::new(b"short".to_vec()));
        assert!(matches!(
            result,
            Err(TokenError::SecretTooShort {
                minimum: 16,
                actual: 5
            })
        ));
    }
}