use crate::error::{AadError, JsonType};
use crate::types::{
ExtensionValue, Extensions, FieldKey, Purpose, Resource, SafeInt, Tenant, RESERVED_KEYS,
};
use serde_json::{Map, Value};
use std::collections::HashSet;
/// Largest serialized AAD, in bytes (`str::len`), that [`parse_aad`] will accept (16 KiB).
pub const MAX_AAD_SIZE: usize = 16_384;
/// The only value of the `"v"` field that [`parse_aad`] accepts.
pub const CURRENT_VERSION: u64 = 1;
/// A validated AAD document, as produced by [`parse_aad`].
#[derive(Debug)]
pub struct ParsedAad {
/// Value of the required `"v"` field; `parse_aad` only accepts [`CURRENT_VERSION`].
pub version: SafeInt,
/// Value of the required `"tenant"` string field.
pub tenant: Tenant,
/// Value of the required `"resource"` string field.
pub resource: Resource,
/// Value of the required `"purpose"` string field.
pub purpose: Purpose,
/// Value of the optional `"ts"` field; `None` when the field is absent.
pub timestamp: Option<SafeInt>,
/// Every top-level field whose key starts with `x_`.
pub extensions: Extensions,
}
/// Parses `json` into a [`Value`] after first rejecting any object that repeats a key.
///
/// The duplicate scan runs on the raw text before deserialization (presumably because
/// `serde_json` does not reject repeated keys on its own).
///
/// # Errors
///
/// - [`AadError::DuplicateKey`] if any object repeats a key.
/// - [`AadError::InvalidJson`] if the pre-scan finds a malformed string, or `serde_json`
///   rejects the text; the message carries the underlying error text.
pub fn parse_json_with_duplicate_check(json: &str) -> Result<Value, AadError> {
    check_for_duplicate_keys(json)?;
    match serde_json::from_str::<Value>(json) {
        Ok(parsed) => Ok(parsed),
        Err(err) => Err(AadError::InvalidJson { message: err.to_string() }),
    }
}
/// Scans raw JSON text and fails on the first object that repeats a key.
///
/// This is a lightweight tokenizer, not a validator. Its behavior:
///
/// - It keeps one key set per open `{`.
/// - It decodes every string literal with [`parse_json_string`], so escaped and unescaped
///   spellings of the same key compare equal.
/// - It treats a string as a key only when the next non-whitespace character is `:`.
///
/// Other structural errors are left for the real JSON parser to report.
///
/// # Errors
///
/// - [`AadError::DuplicateKey`] when a key repeats within the same object.
/// - [`AadError::InvalidJson`] when a string literal or escape sequence is malformed.
fn check_for_duplicate_keys(json: &str) -> Result<(), AadError> {
    let chars: Vec<char> = json.chars().collect();
    // Characters that terminate a bare token (number, `true`, `false`, `null`, or stray text).
    let ends_bare_token =
        |c: char| c.is_whitespace() || matches!(c, '{' | '}' | '[' | ']' | ',' | ':' | '"');
    // One key set per currently open object; arrays get no entry.
    let mut open_objects: Vec<HashSet<String>> = Vec::new();
    let mut cursor = 0;

    while let Some(&current) = chars.get(cursor) {
        if current == '"' {
            let (text, after_quote) = parse_json_string(&chars, cursor)?;
            // A string is an object key exactly when the next significant character is `:`.
            let followed_by_colon =
                chars[after_quote..].iter().find(|c| !c.is_whitespace()) == Some(&':');
            if followed_by_colon {
                if let Some(seen) = open_objects.last_mut() {
                    if seen.contains(&text) {
                        return Err(AadError::DuplicateKey { key: text });
                    }
                    seen.insert(text);
                }
            }
            cursor = after_quote;
        } else if current == '{' {
            open_objects.push(HashSet::new());
            cursor += 1;
        } else if current == '}' {
            open_objects.pop();
            cursor += 1;
        } else if ends_bare_token(current) {
            // What remains here: `[`, `]`, `,`, `:` and whitespace — nothing to record.
            cursor += 1;
        } else {
            cursor += chars[cursor..].iter().take_while(|&&c| !ends_bare_token(c)).count();
        }
    }
    Ok(())
}
/// Decodes the JSON string literal whose opening quote is at `chars[start]`.
///
/// Returns the unescaped contents together with the index just past the closing quote.
///
/// # Errors
///
/// [`AadError::InvalidJson`] in any of these cases:
///
/// - `chars[start]` is not `"`.
/// - The literal is never closed.
/// - The literal contains a malformed escape sequence.
fn parse_json_string(chars: &[char], start: usize) -> Result<(String, usize), AadError> {
    if !matches!(chars.get(start), Some(&'"')) {
        return Err(AadError::InvalidJson { message: "expected string".to_string() });
    }
    let mut decoded = String::new();
    let mut cursor = start + 1;
    loop {
        match chars.get(cursor) {
            None => {
                return Err(AadError::InvalidJson { message: "unterminated string".to_string() })
            }
            Some(&'"') => return Ok((decoded, cursor + 1)),
            // The escape parser appends to `decoded` and tells us where to resume.
            Some(&'\\') => cursor = parse_escape_sequence(chars, cursor, &mut decoded)?,
            Some(&plain) => {
                decoded.push(plain);
                cursor += 1;
            }
        }
    }
}
/// Decodes the escape sequence whose backslash is at `chars[backslash_pos]`, appending the
/// decoded text to `result`.
///
/// Returns the index of the first character after the escape sequence.
///
/// # Errors
///
/// [`AadError::InvalidJson`] if the input ends right after the backslash, if the escape
/// letter is not one JSON defines, or if a `\u` escape is malformed.
fn parse_escape_sequence(
    chars: &[char],
    backslash_pos: usize,
    result: &mut String,
) -> Result<usize, AadError> {
    let selector_pos = backslash_pos + 1;
    let selector = match chars.get(selector_pos) {
        Some(&letter) => letter,
        None => {
            return Err(AadError::InvalidJson {
                message: "unterminated escape sequence".to_string(),
            })
        }
    };
    let decoded = match selector {
        '"' => '"',
        '\\' => '\\',
        '/' => '/',
        'b' => '\u{8}',
        'f' => '\u{c}',
        'n' => '\n',
        'r' => '\r',
        't' => '\t',
        // `\uXXXX` can span more than one escape (surrogate pairs), so it reports its own end.
        'u' => return parse_unicode_escape(chars, selector_pos, result),
        c => {
            return Err(AadError::InvalidJson {
                message: format!("invalid escape sequence: \\{c}"),
            })
        }
    };
    result.push(decoded);
    Ok(selector_pos + 1)
}
/// Decodes the `\uXXXX` escape whose `u` is at `chars[u_pos]`, appending the result to
/// `result`.
///
/// Surrogates are handled as follows:
///
/// - A high surrogate directly followed by a `\uXXXX` low surrogate is combined into a single
///   character.
/// - A code unit that is not a valid `char` on its own (an unpaired surrogate) is appended as
///   its literal `\uXXXX` text, so the scan can continue.
///
/// Returns the index just past the consumed input: `u_pos + 5` for a single escape, or
/// `u_pos + 11` for a surrogate pair.
///
/// # Errors
///
/// [`AadError::InvalidJson`] if fewer than four characters follow the `u`, or if those four
/// characters are not all ASCII hex digits.
fn parse_unicode_escape(
    chars: &[char],
    u_pos: usize,
    result: &mut String,
) -> Result<usize, AadError> {
    use std::fmt::Write;

    let digits = match chars.get(u_pos + 1..u_pos + 5) {
        Some(digits) => digits,
        None => {
            return Err(AadError::InvalidJson { message: "invalid unicode escape".to_string() })
        }
    };
    let hex: String = digits.iter().collect();
    // JSON requires exactly four hex digits. `u16::from_str_radix` alone also accepts a
    // leading `+` (e.g. `\u+041` would decode as `A`), so validate the digits explicitly.
    let code_point = if digits.iter().all(char::is_ascii_hexdigit) {
        u16::from_str_radix(&hex, 16).ok()
    } else {
        None
    };
    let code_point = code_point.ok_or_else(|| AadError::InvalidJson {
        message: format!("invalid unicode escape: \\u{hex}"),
    })?;

    if (0xD800..=0xDBFF).contains(&code_point) {
        // High surrogate: try to combine it with an immediately following low surrogate.
        if let Some(c) =
            try_parse_surrogate_pair(chars, u_pos, code_point).and_then(char::from_u32)
        {
            result.push(c);
            return Ok(u_pos + 11);
        }
    }
    match char::from_u32(u32::from(code_point)) {
        Some(c) => result.push(c),
        // Unpaired surrogate: keep its literal spelling so the scan can continue.
        // NOTE(review): serde_json presumably rejects lone surrogates when it parses the same
        // text afterwards — confirm.
        None => {
            let _ = write!(result, "\\u{hex}");
        }
    }
    Ok(u_pos + 5)
}
/// Tries to read a `\uXXXX` low surrogate right after the escape at `u_pos` and combine it
/// with the already-decoded high surrogate `high`.
///
/// `u_pos` is the index of the `u` in the first escape, so the second escape is expected at
/// `chars[u_pos + 5..u_pos + 11]`. Returns the combined code point, or `None` in these cases:
///
/// - That span is out of range.
/// - It does not start with `\u`.
/// - Its four characters do not parse to a value in `0xDC00..=0xDFFF`.
///
/// Callers pass a `high` in `0xD800..=0xDBFF`.
fn try_parse_surrogate_pair(chars: &[char], u_pos: usize, high: u16) -> Option<u32> {
    let low_digits = match chars.get(u_pos + 5..u_pos + 11)? {
        ['\\', 'u', rest @ ..] => rest,
        _ => return None,
    };
    let low_hex: String = low_digits.iter().collect();
    let low = u32::from(u16::from_str_radix(&low_hex, 16).ok()?);
    if !(0xDC00..=0xDFFF).contains(&low) {
        return None;
    }
    let high_bits = u32::from(high) - 0xD800;
    let low_bits = low - 0xDC00;
    Some(0x10000 + (high_bits << 10) + low_bits)
}
/// Parses and validates a serialized AAD document.
///
/// Checks run in this order:
///
/// 1. Byte-size limit ([`MAX_AAD_SIZE`]).
/// 2. JSON parse with duplicate-key rejection; the top level must be an object.
/// 3. Version check.
/// 4. Field-name validation.
/// 5. Extraction of the required `tenant` / `resource` / `purpose` strings, the optional `ts`,
///    and every `x_` extension.
///
/// # Errors
///
/// - [`AadError::SerializedTooLarge`] if `json.len()` exceeds [`MAX_AAD_SIZE`].
/// - [`AadError::DuplicateKey`] / [`AadError::InvalidJson`] from parsing, including a
///   top level that is not an object.
/// - Any error from version checking, field-name validation, field extraction, or the
///   `Tenant` / `Resource` / `Purpose` constructors.
pub fn parse_aad(json: &str) -> Result<ParsedAad, AadError> {
    let actual_bytes = json.len();
    if actual_bytes > MAX_AAD_SIZE {
        return Err(AadError::SerializedTooLarge { max_bytes: MAX_AAD_SIZE, actual_bytes });
    }

    let document = parse_json_with_duplicate_check(json)?;
    let obj = match document.as_object() {
        Some(map) => map,
        None => {
            return Err(AadError::InvalidJson {
                message: "AAD must be a JSON object".to_string(),
            })
        }
    };

    // The version is checked before field names, so an unsupported version wins over
    // field-level errors.
    let version = extract_version(obj)?;
    validate_field_names(obj)?;

    let tenant = Tenant::new(extract_string_field(obj, "tenant")?)?;
    let resource = Resource::new(extract_string_field(obj, "resource")?)?;
    let purpose = Purpose::new(extract_string_field(obj, "purpose")?)?;
    let timestamp = extract_optional_timestamp(obj)?;
    let extensions = extract_extensions(obj)?;

    Ok(ParsedAad { version, tenant, resource, purpose, timestamp, extensions })
}
/// Reads the required `"v"` field and checks it against [`CURRENT_VERSION`].
///
/// # Errors
///
/// - [`AadError::MissingRequiredField`] if `"v"` is absent.
/// - [`AadError::WrongFieldType`] if `"v"` is not an integer representable as `u64`.
/// - [`AadError::UnsupportedVersion`] for any other integer value.
/// - Any error returned by `SafeInt::new`.
fn extract_version(obj: &Map<String, Value>) -> Result<SafeInt, AadError> {
    let raw = obj.get("v").ok_or(AadError::MissingRequiredField { field: "v" })?;
    let version = match raw.as_u64() {
        Some(number) => number,
        None => {
            return Err(AadError::WrongFieldType {
                field: "v",
                expected: "integer",
                actual: JsonType::from(raw),
            })
        }
    };
    if version == CURRENT_VERSION {
        SafeInt::new(version)
    } else {
        Err(AadError::UnsupportedVersion { version })
    }
}
/// Checks every top-level key that is not in `RESERVED_KEYS`.
///
/// Each such key must, in this order:
///
/// 1. Be accepted by `FieldKey::new`.
/// 2. Start with `x_`.
/// 3. Pass `FieldKey::validate_as_extension`.
///
/// # Errors
///
/// - [`AadError::InvalidFieldKey`] when `FieldKey::new` rejects the key. The original error is
///   replaced by a fixed reason string.
/// - [`AadError::UnknownField`] for a well-formed key without the `x_` prefix.
/// - Any error from `validate_as_extension`.
fn validate_field_names(obj: &Map<String, Value>) -> Result<(), AadError> {
    obj.keys()
        .filter(|key| !RESERVED_KEYS.contains(&key.as_str()))
        .try_for_each(|key| -> Result<(), AadError> {
            let field_key = FieldKey::new(key.clone()).map_err(|_| AadError::InvalidFieldKey {
                key: key.clone(),
                reason: "field keys must match pattern [a-z_]+".to_string(),
            })?;
            if !key.starts_with("x_") {
                return Err(AadError::UnknownField { field: key.clone(), version: CURRENT_VERSION });
            }
            field_key.validate_as_extension()?;
            Ok(())
        })
}
fn extract_string_field(obj: &Map<String, Value>, field: &'static str) -> Result<String, AadError> {
obj.get(field).map_or_else(
|| Err(AadError::MissingRequiredField { field }),
|v| {
v.as_str().map(String::from).ok_or_else(|| AadError::WrongFieldType {
field,
expected: "string",
actual: JsonType::from(v),
})
},
)
}
/// Reads the optional `"ts"` field; returns `Ok(None)` when it is absent.
///
/// # Errors
///
/// - [`AadError::NegativeInteger`] if `"ts"` is a negative integer.
/// - [`AadError::WrongFieldType`] (expected `"integer"`) for any other non-`u64` value.
/// - Any error returned by `SafeInt::new`.
fn extract_optional_timestamp(obj: &Map<String, Value>) -> Result<Option<SafeInt>, AadError> {
    let raw = match obj.get("ts") {
        Some(raw) => raw,
        None => return Ok(None),
    };
    if let Some(unsigned) = raw.as_u64() {
        return SafeInt::new(unsigned).map(Some);
    }
    // Not a u64. If it still fits an i64 it must be negative; anything else
    // (e.g. a float or a non-number) is a type error.
    let error = match raw.as_i64() {
        Some(negative) => AadError::NegativeInteger { value: negative },
        None => AadError::WrongFieldType {
            field: "ts",
            expected: "integer",
            actual: JsonType::from(raw),
        },
    };
    Err(error)
}
/// Collects every top-level `x_`-prefixed field into an [`Extensions`] map.
///
/// Each key is built into a `FieldKey` and checked with `validate_as_extension` here as well,
/// so this function does not rely on an earlier call to `validate_field_names`.
///
/// # Errors
///
/// Any error from `FieldKey::new`, `validate_as_extension`, or [`parse_extension_value`].
fn extract_extensions(obj: &Map<String, Value>) -> Result<Extensions, AadError> {
    let mut collected = Extensions::new();
    for (key, value) in obj.iter().filter(|(name, _)| name.starts_with("x_")) {
        let field_key = FieldKey::new(key)?;
        field_key.validate_as_extension()?;
        collected.insert(field_key, parse_extension_value(value)?);
    }
    Ok(collected)
}
/// Converts a JSON value into an [`ExtensionValue`]. Only strings and non-negative integers
/// are allowed.
///
/// # Errors
///
/// - [`AadError::NegativeInteger`] for a negative integer.
/// - [`AadError::WrongFieldType`] (field `"extension"`) for non-integer numbers and for every
///   other JSON type.
/// - Any error from `ExtensionValue::string` / `ExtensionValue::integer`.
fn parse_extension_value(value: &Value) -> Result<ExtensionValue, AadError> {
    let number = match value {
        Value::String(text) => return ExtensionValue::string(text),
        Value::Number(number) => number,
        other => {
            return Err(AadError::WrongFieldType {
                field: "extension",
                expected: "string or integer",
                actual: JsonType::from(other),
            })
        }
    };
    if let Some(unsigned) = number.as_u64() {
        ExtensionValue::integer(unsigned)
    } else if let Some(signed) = number.as_i64() {
        // Fits i64 but not u64, so it is negative.
        Err(AadError::NegativeInteger { value: signed })
    } else {
        Err(AadError::WrongFieldType {
            field: "extension",
            expected: "string or integer",
            actual: JsonType::Number,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_valid_minimal() {
        let json = r#"{"v":1,"tenant":"org_abc","resource":"secrets/db","purpose":"encryption"}"#;
        let aad = parse_aad(json).expect("minimal AAD should parse");
        assert_eq!(aad.version.value(), 1);
        assert_eq!(aad.tenant.as_str(), "org_abc");
        assert_eq!(aad.resource.as_str(), "secrets/db");
        assert_eq!(aad.purpose.as_str(), "encryption");
        assert!(aad.timestamp.is_none());
        assert!(aad.extensions.is_empty());
    }

    #[test]
    fn test_parse_with_timestamp() {
        let json = r#"{"v":1,"tenant":"org","resource":"res","purpose":"test","ts":1706400000}"#;
        let aad = parse_aad(json).expect("AAD with timestamp should parse");
        assert_eq!(aad.timestamp.expect("ts should be present").value(), 1_706_400_000);
    }

    #[test]
    fn test_parse_with_extension() {
        let json = r#"{"v":1,"tenant":"org","resource":"res","purpose":"test","x_vault_cluster":"us-east-1"}"#;
        let aad = parse_aad(json).expect("AAD with extension should parse");
        assert_eq!(aad.extensions.len(), 1);
    }

    #[test]
    fn test_duplicate_key_detection() {
        let json = r#"{"v":1,"tenant":"org","tenant":"other","resource":"res","purpose":"test"}"#;
        let result = parse_aad(json);
        assert!(matches!(result, Err(AadError::DuplicateKey { key }) if key == "tenant"));
    }

    #[test]
    fn test_duplicate_key_in_nested_object() {
        let json = r#"{"v":1,"outer":{"k":1,"k":2}}"#;
        let result = parse_json_with_duplicate_check(json);
        assert!(matches!(result, Err(AadError::DuplicateKey { key }) if key == "k"));
    }

    #[test]
    fn test_same_key_in_different_objects_is_allowed() {
        let json = r#"{"k":{"k":1},"a":[{"k":2},{"k":3}]}"#;
        assert!(parse_json_with_duplicate_check(json).is_ok());
    }

    #[test]
    fn test_escaped_duplicate_key_detected() {
        // `\u0061` decodes to `a`, so the two keys are the same after unescaping.
        let json = r#"{"a":1,"\u0061":2}"#;
        let result = parse_json_with_duplicate_check(json);
        assert!(matches!(result, Err(AadError::DuplicateKey { key }) if key == "a"));
    }

    #[test]
    fn test_braces_and_keys_inside_string_values_ignored() {
        let json = r#"{"a":"{\"a\":1}","b":2}"#;
        assert!(parse_json_with_duplicate_check(json).is_ok());
    }

    #[test]
    fn test_duplicate_check_accepts_unique_keys() {
        let json = r#"{"v":1}"#;
        assert!(parse_json_with_duplicate_check(json).is_ok());
    }

    #[test]
    fn test_missing_required_field() {
        let json = r#"{"v":1,"tenant":"org","resource":"res"}"#;
        let result = parse_aad(json);
        assert!(matches!(result, Err(AadError::MissingRequiredField { field: "purpose" })));
    }

    #[test]
    fn test_unknown_field() {
        let json = r#"{"v":1,"tenant":"org","resource":"res","purpose":"test","unknown":"value"}"#;
        let result = parse_aad(json);
        assert!(
            matches!(result, Err(AadError::UnknownField { field, version: 1 }) if field == "unknown")
        );
    }

    #[test]
    fn test_invalid_extension_key() {
        let json = r#"{"v":1,"tenant":"org","resource":"res","purpose":"test","x_foo":"value"}"#;
        let result = parse_aad(json);
        assert!(matches!(result, Err(AadError::InvalidExtensionKeyFormat { .. })));
    }

    #[test]
    fn test_unsupported_version() {
        let json = r#"{"v":2,"tenant":"org","resource":"res","purpose":"test"}"#;
        let result = parse_aad(json);
        assert!(matches!(result, Err(AadError::UnsupportedVersion { version: 2 })));
    }

    #[test]
    fn test_wrong_field_type() {
        let json = r#"{"v":"1","tenant":"org","resource":"res","purpose":"test"}"#;
        let result = parse_aad(json);
        assert!(matches!(
            result,
            Err(AadError::WrongFieldType { field: "v", expected: "integer", .. })
        ));
    }

    #[test]
    fn test_integer_out_of_range() {
        let json =
            r#"{"v":1,"tenant":"org","resource":"res","purpose":"test","ts":9007199254740992}"#;
        let result = parse_aad(json);
        assert!(matches!(result, Err(AadError::IntegerOutOfRange { .. })));
    }

    #[test]
    fn test_size_limit() {
        let big_resource = "x".repeat(MAX_AAD_SIZE + 1);
        let json =
            format!(r#"{{"v":1,"tenant":"org","resource":"{}","purpose":"test"}}"#, big_resource);
        let result = parse_aad(&json);
        assert!(matches!(result, Err(AadError::SerializedTooLarge { .. })));
    }

    #[test]
    fn test_unicode_escapes() {
        let json = r#"{"v":1,"tenant":"test\u0041","resource":"res","purpose":"test"}"#;
        let result = parse_aad(json);
        assert!(result.is_ok());
    }
}