use serde_json::Value;
use unicode_normalization::UnicodeNormalization;
use crate::errors::{AshError, AshErrorCode};
/// Deepest nesting level accepted while canonicalizing a JSON value; deeper input is rejected.
const MAX_RECURSION_DEPTH: usize = 64;
/// Largest accepted input, in bytes (10 MiB), for JSON, URL-encoded and query-string payloads.
const MAX_PAYLOAD_SIZE: usize = 10 * 1024 * 1024;
/// Parses `input` as JSON and returns its canonical serialization.
///
/// # Errors
///
/// Returns a `CanonicalizationError` when `input` is longer than
/// `MAX_PAYLOAD_SIZE` bytes, is not valid JSON, fails canonicalization
/// (nesting too deep or an unsupported number), or cannot be serialized.
pub fn ash_canonicalize_json(input: &str) -> Result<String, AshError> {
    if input.len() > MAX_PAYLOAD_SIZE {
        let message = format!("Payload exceeds maximum size of {} bytes", MAX_PAYLOAD_SIZE);
        return Err(AshError::new(AshErrorCode::CanonicalizationError, message));
    }

    // The parser's own error text is not forwarded; callers only see a fixed message.
    let parsed = match serde_json::from_str::<Value>(input) {
        Ok(parsed) => parsed,
        Err(_) => {
            return Err(AshError::new(
                AshErrorCode::CanonicalizationError,
                "Invalid JSON format".to_string(),
            ));
        }
    };

    let normalized = ash_canonicalize_value_with_depth(&parsed, 0)?;
    match serde_json::to_string(&normalized) {
        Ok(text) => Ok(text),
        Err(err) => Err(AshError::new(
            AshErrorCode::CanonicalizationError,
            format!("Failed to serialize: {}", err),
        )),
    }
}
/// Recursively builds the canonical form of `value`.
///
/// * `null` and booleans are copied unchanged.
/// * Numbers go through `ash_canonicalize_number`.
/// * Strings and object keys go through `ash_canonicalize_string`.
/// * Arrays keep their element order; every element is canonicalized.
/// * Object members are ordered by their canonicalized key.
///
/// `depth` is the nesting level of `value`; callers start at `0` and each
/// array or object level adds one.
///
/// # Errors
///
/// Returns a `CanonicalizationError` when `depth` exceeds
/// `MAX_RECURSION_DEPTH`, or when a nested number cannot be canonicalized.
fn ash_canonicalize_value_with_depth(value: &Value, depth: usize) -> Result<Value, AshError> {
    if depth > MAX_RECURSION_DEPTH {
        return Err(AshError::new(
            AshErrorCode::CanonicalizationError,
            format!("JSON exceeds maximum nesting depth of {}", MAX_RECURSION_DEPTH),
        ));
    }
    match value {
        Value::Null => Ok(Value::Null),
        Value::Bool(b) => Ok(Value::Bool(*b)),
        Value::Number(n) => ash_canonicalize_number(n),
        Value::String(s) => Ok(Value::String(ash_canonicalize_string(s))),
        Value::Array(arr) => arr
            .iter()
            .map(|item| ash_canonicalize_value_with_depth(item, depth + 1))
            .collect::<Result<Vec<Value>, AshError>>()
            .map(Value::Array),
        Value::Object(obj) => {
            // Normalize the keys BEFORE sorting. Normalization can change the
            // relative order of keys, and when `serde_json::Map` keeps insertion
            // order (its `preserve_order` feature) sorting the raw keys would let
            // two equivalent documents serialize with different member order.
            let mut entries: Vec<(String, &String, &Value)> = obj
                .iter()
                .map(|(key, val)| (ash_canonicalize_string(key), key, val))
                .collect();
            // The raw key is the tie-break, so when several raw keys normalize to
            // the same key the entry with the greatest raw key is inserted last
            // and its value is the one kept.
            entries.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(b.1)));

            let mut canonical = serde_json::Map::new();
            for (canonical_key, _, val) in entries {
                let canonical_val = ash_canonicalize_value_with_depth(val, depth + 1)?;
                canonical.insert(canonical_key, canonical_val);
            }
            Ok(Value::Object(canonical))
        }
    }
}
/// Canonicalizes a single JSON number.
///
/// Values representable as `i64` or `u64` are emitted as integers. Floats are
/// handled as follows: negative zero becomes zero, whole values within
/// ±9007199254740991 (2^53 − 1) become integers, and everything else stays a float.
///
/// # Errors
///
/// Returns a `CanonicalizationError` for NaN, infinities, a number that
/// cannot be read as `f64`, or a float that `serde_json` refuses to store.
fn ash_canonicalize_number(n: &serde_json::Number) -> Result<Value, AshError> {
    const MAX_SAFE_INT: f64 = 9007199254740991.0;

    if let Some(signed) = n.as_i64() {
        return Ok(Value::Number(serde_json::Number::from(signed)));
    }
    if let Some(unsigned) = n.as_u64() {
        return Ok(Value::Number(serde_json::Number::from(unsigned)));
    }

    let raw = match n.as_f64() {
        Some(raw) => raw,
        None => {
            return Err(AshError::new(
                AshErrorCode::CanonicalizationError,
                "Unsupported number format",
            ));
        }
    };
    if raw.is_nan() {
        return Err(AshError::new(
            AshErrorCode::CanonicalizationError,
            "NaN is not supported in ASH canonicalization (RFC 8785)",
        ));
    }
    if raw.is_infinite() {
        return Err(AshError::new(
            AshErrorCode::CanonicalizationError,
            "Infinity is not supported in ASH canonicalization (RFC 8785)",
        ));
    }

    // Fold -0.0 into +0.0 so both spellings canonicalize identically.
    let float = if raw == 0.0 && raw.is_sign_negative() {
        0.0
    } else {
        raw
    };

    let is_whole = float.fract() == 0.0;
    if is_whole && (-MAX_SAFE_INT..=MAX_SAFE_INT).contains(&float) {
        // Inside this range the cast to i64 is exact.
        return Ok(Value::Number(serde_json::Number::from(float as i64)));
    }

    match serde_json::Number::from_f64(float) {
        Some(number) => Ok(Value::Number(number)),
        None => Err(AshError::new(
            AshErrorCode::CanonicalizationError,
            "Failed to canonicalize number",
        )),
    }
}
/// Returns `s` converted to Unicode Normalization Form C (NFC).
fn ash_canonicalize_string(s: &str) -> String {
    let mut normalized = String::new();
    normalized.extend(s.nfc());
    normalized
}
/// Canonicalizes an already-parsed JSON `value` and serializes the result.
///
/// Unlike `ash_canonicalize_json`, no payload-size check is performed here.
///
/// # Errors
///
/// Returns a `CanonicalizationError` when canonicalization fails (nesting too
/// deep or an unsupported number) or when the result cannot be serialized.
pub fn ash_canonicalize_json_value(value: &Value) -> Result<String, AshError> {
    ash_canonicalize_value_with_depth(value, 0).and_then(|normalized| {
        serde_json::to_string(&normalized).map_err(|err| {
            AshError::new(
                AshErrorCode::CanonicalizationError,
                format!("Failed to serialize: {}", err),
            )
        })
    })
}
/// Canonicalizes `value` after verifying that its serialized form fits in
/// `MAX_PAYLOAD_SIZE` bytes.
///
/// The size is measured on the serialization of `value` as given, before any
/// canonicalization takes place.
///
/// # Errors
///
/// Returns a `CanonicalizationError` when `value` cannot be serialized, when
/// its serialization is larger than `MAX_PAYLOAD_SIZE`, or when
/// `ash_canonicalize_json_value` fails.
pub fn ash_canonicalize_json_value_with_size_check(value: &Value) -> Result<String, AshError> {
    let serialized_len = match serde_json::to_string(value) {
        Ok(text) => text.len(),
        Err(err) => {
            return Err(AshError::new(
                AshErrorCode::CanonicalizationError,
                format!("Failed to serialize: {}", err),
            ));
        }
    };

    if serialized_len <= MAX_PAYLOAD_SIZE {
        return ash_canonicalize_json_value(value);
    }

    Err(AshError::new(
        AshErrorCode::CanonicalizationError,
        format!("Payload exceeds maximum size of {} bytes", MAX_PAYLOAD_SIZE),
    ))
}
/// Canonicalizes an `&`-separated list of `key=value` fields.
///
/// Each field is split at its first `=` (a field without `=` gets an empty
/// value), percent-decoded and NFC-normalized. Empty fields are skipped. The
/// pairs are then ordered by the bytes of the key and, for equal keys, by the
/// bytes of the value, and re-encoded as `key=value` joined with `&`.
///
/// # Errors
///
/// Returns a `CanonicalizationError` when `input` is longer than
/// `MAX_PAYLOAD_SIZE` bytes or contains an invalid percent escape.
pub fn ash_canonicalize_urlencoded(input: &str) -> Result<String, AshError> {
    if input.len() > MAX_PAYLOAD_SIZE {
        let message = format!("Payload exceeds maximum size of {} bytes", MAX_PAYLOAD_SIZE);
        return Err(AshError::new(AshErrorCode::CanonicalizationError, message));
    }
    if input.is_empty() {
        return Ok(String::new());
    }

    let mut fields = input
        .split('&')
        .filter(|segment| !segment.is_empty())
        .map(|segment| -> Result<(String, String), AshError> {
            // Only the first '=' separates key from value; later ones belong to the value.
            let mut halves = segment.splitn(2, '=');
            let raw_key = halves.next().unwrap_or("");
            let raw_value = halves.next().unwrap_or("");

            let decoded_key = ash_percent_decode_query(raw_key)?;
            let decoded_value = ash_percent_decode_query(raw_value)?;
            let key: String = decoded_key.nfc().collect();
            let value: String = decoded_value.nfc().collect();
            Ok((key, value))
        })
        .collect::<Result<Vec<(String, String)>, AshError>>()?;

    fields.sort_by(|left, right| {
        left.0
            .as_bytes()
            .cmp(right.0.as_bytes())
            .then_with(|| left.1.as_bytes().cmp(right.1.as_bytes()))
    });

    let mut canonical = String::new();
    for (index, (key, value)) in fields.iter().enumerate() {
        if index > 0 {
            canonical.push('&');
        }
        canonical.push_str(&ash_percent_encode_uppercase(key));
        canonical.push('=');
        canonical.push_str(&ash_percent_encode_uppercase(value));
    }
    Ok(canonical)
}
/// Decodes `%XX` escapes in `input` and returns the resulting UTF-8 string.
///
/// Every other character is copied through unchanged; in particular `+` stays
/// a literal plus sign and is not turned into a space.
///
/// Both characters after `%` must be ASCII hexadecimal digits. They are checked
/// explicitly because `u8::from_str_radix` also accepts a leading `+`, which
/// would let a malformed escape such as `%+A` decode silently.
///
/// # Errors
///
/// Returns a `CanonicalizationError` when a `%` is followed by fewer than two
/// characters, when either character is not a hexadecimal digit, or when the
/// decoded bytes are not valid UTF-8.
fn ash_percent_decode_query(input: &str) -> Result<String, AshError> {
    /// Maps one ASCII hexadecimal digit to its numeric value.
    fn hex_value(byte: u8) -> Option<u8> {
        match byte {
            b'0'..=b'9' => Some(byte - b'0'),
            b'a'..=b'f' => Some(byte - b'a' + 10),
            b'A'..=b'F' => Some(byte - b'A' + 10),
            _ => None,
        }
    }

    // Working on bytes is safe: '%' is ASCII and never occurs inside a
    // multi-byte UTF-8 sequence, and all other bytes are copied verbatim.
    let raw = input.as_bytes();
    let mut bytes = Vec::with_capacity(raw.len());
    let mut index = 0;
    while index < raw.len() {
        let current = raw[index];
        if current != b'%' {
            bytes.push(current);
            index += 1;
            continue;
        }

        let (high, low) = match (raw.get(index + 1), raw.get(index + 2)) {
            (Some(&high), Some(&low)) => (high, low),
            _ => {
                return Err(AshError::new(
                    AshErrorCode::CanonicalizationError,
                    "Invalid percent encoding",
                ));
            }
        };
        match (hex_value(high), hex_value(low)) {
            (Some(high), Some(low)) => bytes.push((high << 4) | low),
            _ => {
                return Err(AshError::new(
                    AshErrorCode::CanonicalizationError,
                    "Invalid percent encoding hex",
                ));
            }
        }
        index += 3;
    }

    String::from_utf8(bytes).map_err(|_| {
        AshError::new(
            AshErrorCode::CanonicalizationError,
            "Invalid UTF-8 in percent-decoded string",
        )
    })
}
/// Canonicalizes a URL query string.
///
/// One leading `?` and everything from the first `#` onward are discarded.
/// The remaining `&`-separated fields are split at their first `=` (a field
/// without `=` gets an empty value), percent-decoded, NFC-normalized, ordered
/// by key bytes and then value bytes, and re-encoded as `key=value` joined
/// with `&`. Empty fields are skipped.
///
/// # Errors
///
/// Returns a `CanonicalizationError` when `input` is longer than
/// `MAX_PAYLOAD_SIZE` bytes or contains an invalid percent escape.
pub fn ash_canonicalize_query(input: &str) -> Result<String, AshError> {
    if input.len() > MAX_PAYLOAD_SIZE {
        let message = format!(
            "Query string exceeds maximum size of {} bytes",
            MAX_PAYLOAD_SIZE
        );
        return Err(AshError::new(AshErrorCode::CanonicalizationError, message));
    }

    let without_marker = match input.strip_prefix('?') {
        Some(rest) => rest,
        None => input,
    };
    // Drop the fragment: keep only what precedes the first '#'.
    let query = match without_marker.find('#') {
        Some(hash_at) => &without_marker[..hash_at],
        None => without_marker,
    };
    if query.is_empty() {
        return Ok(String::new());
    }

    let mut params: Vec<(String, String)> = Vec::new();
    for field in query.split('&').filter(|field| !field.is_empty()) {
        let (raw_key, raw_value) = match field.find('=') {
            Some(eq_at) => (&field[..eq_at], &field[eq_at + 1..]),
            None => (field, ""),
        };
        let key_text = ash_percent_decode_query(raw_key)?;
        let value_text = ash_percent_decode_query(raw_value)?;
        params.push((key_text.nfc().collect(), value_text.nfc().collect()));
    }

    // Tuples compare element by element: key bytes first, then value bytes.
    params.sort_by(|left, right| {
        (left.0.as_bytes(), left.1.as_bytes()).cmp(&(right.0.as_bytes(), right.1.as_bytes()))
    });

    let rendered: Vec<String> = params
        .iter()
        .map(|(key, value)| {
            let mut field = ash_percent_encode_uppercase(key);
            field.push('=');
            field.push_str(&ash_percent_encode_uppercase(value));
            field
        })
        .collect();
    Ok(rendered.join("&"))
}
/// Percent-encodes `input`, using uppercase hexadecimal digits.
///
/// ASCII letters, ASCII digits and `-`, `_`, `.`, `~` are copied unchanged.
/// Every other byte of the UTF-8 encoding, including the space, is written
/// as `%XX`.
fn ash_percent_encode_uppercase(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(input.len() * 3);
    for &byte in input.as_bytes() {
        let unreserved =
            byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~');
        if unreserved {
            encoded.push(char::from(byte));
        } else {
            encoded.push('%');
            encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
            encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
        }
    }
    encoded
}
/// Unit tests for the JSON, URL-encoded and query-string canonicalizers.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- JSON canonicalization ----

    #[test]
    fn test_canonicalize_json_simple_object() {
        let input = r#"{"z":1,"a":2}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":2,"z":1}"#);
    }

    #[test]
    fn test_canonicalize_json_nested_object() {
        let input = r#"{"b":{"d":4,"c":3},"a":1}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":1,"b":{"c":3,"d":4}}"#);
    }

    #[test]
    fn test_canonicalize_json_with_whitespace() {
        let input = r#"{ "z" : 1 , "a" : 2 }"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":2,"z":1}"#);
    }

    #[test]
    fn test_canonicalize_json_array_preserved() {
        let input = r#"{"arr":[3,1,2]}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"arr":[3,1,2]}"#);
    }

    #[test]
    fn test_canonicalize_json_null() {
        let input = r#"{"a":null}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":null}"#);
    }

    #[test]
    fn test_canonicalize_json_boolean() {
        let input = r#"{"b":true,"a":false}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":false,"b":true}"#);
    }

    #[test]
    fn test_canonicalize_json_empty_object() {
        let input = r#"{}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{}"#);
    }

    #[test]
    fn test_canonicalize_json_empty_array() {
        let input = r#"[]"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"[]"#);
    }

    #[test]
    fn test_canonicalize_json_unicode() {
        // Code points are spelled out so the normalization form under test is
        // unambiguous regardless of how an editor stores this file.
        let expected = "{\"name\":\"caf\u{e9}\"}";

        // Precomposed (NFC) input passes through unchanged.
        let composed = "{\"name\":\"caf\u{e9}\"}";
        assert_eq!(ash_canonicalize_json(composed).unwrap(), expected);

        // Decomposed (NFD) input is normalized to the same output.
        let decomposed = "{\"name\":\"cafe\u{301}\"}";
        assert_eq!(ash_canonicalize_json(decomposed).unwrap(), expected);
    }

    #[test]
    fn test_canonicalize_json_invalid() {
        let input = r#"{"a":}"#;
        let result = ash_canonicalize_json(input);
        assert!(result.is_err());
        let err = result.unwrap_err();
        let err_msg = err.message();
        assert!(!err_msg.contains("{"), "Error should not contain JSON fragments");
        assert!(err_msg.contains("Invalid") || err_msg.contains("invalid"));
    }

    #[test]
    fn test_canonicalize_json_whole_float_becomes_integer() {
        let input = r#"{"a":5.0}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":5}"#);
    }

    #[test]
    fn test_canonicalize_json_negative_zero_becomes_zero() {
        let input = r#"{"a":-0.0}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":0}"#);
    }

    #[test]
    fn test_canonicalize_json_preserves_fractional() {
        let input = r#"{"a":5.5}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":5.5}"#);
    }

    #[test]
    fn test_canonicalize_json_large_whole_float() {
        let input = r#"{"a":1000000.0}"#;
        let output = ash_canonicalize_json(input).unwrap();
        assert_eq!(output, r#"{"a":1000000}"#);
    }

    // ---- URL-encoded body canonicalization ----

    #[test]
    fn test_canonicalize_urlencoded_simple() {
        let input = "z=3&a=1&b=2";
        let output = ash_canonicalize_urlencoded(input).unwrap();
        assert_eq!(output, "a=1&b=2&z=3");
    }

    #[test]
    fn test_canonicalize_urlencoded_duplicate_keys() {
        let input = "a=2&a=1&b=3";
        let output = ash_canonicalize_urlencoded(input).unwrap();
        assert_eq!(output, "a=1&a=2&b=3");
    }

    #[test]
    fn test_canonicalize_urlencoded_encoded_space() {
        let input = "a=hello%20world";
        let output = ash_canonicalize_urlencoded(input).unwrap();
        assert_eq!(output, "a=hello%20world");
    }

    #[test]
    fn test_canonicalize_urlencoded_plus_as_literal() {
        let input = "a=hello+world";
        let output = ash_canonicalize_urlencoded(input).unwrap();
        assert_eq!(output, "a=hello%2Bworld");
    }

    #[test]
    fn test_canonicalize_urlencoded_empty() {
        let input = "";
        let output = ash_canonicalize_urlencoded(input).unwrap();
        assert_eq!(output, "");
    }

    #[test]
    fn test_canonicalize_urlencoded_no_value() {
        let input = "a&b=2";
        let output = ash_canonicalize_urlencoded(input).unwrap();
        assert_eq!(output, "a=&b=2");
    }

    // ---- Query-string canonicalization ----

    #[test]
    fn test_canonicalize_query_strips_fragment() {
        let input = "z=3&a=1#section";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "a=1&z=3");
    }

    #[test]
    fn test_canonicalize_query_strips_fragment_with_question_mark() {
        let input = "?z=3&a=1#fragment";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "a=1&z=3");
    }

    #[test]
    fn test_canonicalize_query_plus_is_literal() {
        let input = "a=hello+world";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "a=hello%2Bworld");
    }

    #[test]
    fn test_canonicalize_query_space_is_percent20() {
        let input = "a=hello%20world";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "a=hello%20world");
    }

    #[test]
    fn test_canonicalize_query_preserves_empty_value() {
        let input = "a=&b=2";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "a=&b=2");
    }

    #[test]
    fn test_canonicalize_query_key_without_equals() {
        let input = "flag&b=2";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "b=2&flag=");
    }

    #[test]
    fn test_canonicalize_query_sorts_by_key_then_value() {
        let input = "a=2&a=1&a=3";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "a=1&a=2&a=3");
    }

    #[test]
    fn test_canonicalize_query_uppercase_hex() {
        // Lowercase escapes that contain hex letters must come back uppercased;
        // an escape made only of decimal digits could not reveal the letter case.
        let input = "a=%2f%3a&b=hello%20world";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "a=%2F%3A&b=hello%20world");
        assert!(!output.contains("%2f"));
        assert!(!output.contains("%3a"));
    }

    #[test]
    fn test_canonicalize_query_byte_order_sorting() {
        let input = "z=1&A=2&a=3&0=4";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "0=4&A=2&a=3&z=1");
    }

    #[test]
    fn test_canonicalize_query_only_fragment() {
        let input = "#onlyfragment";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "");
    }

    #[test]
    fn test_canonicalize_query_empty_with_question_mark() {
        let input = "?";
        let output = ash_canonicalize_query(input).unwrap();
        assert_eq!(output, "");
    }

    // ---- Resource limits ----

    #[test]
    fn test_rejects_deeply_nested_json() {
        let mut input = String::from("{\"a\":");
        for _ in 0..100 {
            input.push_str("{\"a\":");
        }
        input.push('1');
        for _ in 0..101 {
            input.push('}');
        }
        let result = ash_canonicalize_json(&input);
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("nesting depth"));
    }

    #[test]
    fn test_accepts_moderately_nested_json() {
        let mut input = String::from("{\"a\":");
        for _ in 0..30 {
            input.push_str("{\"a\":");
        }
        input.push('1');
        for _ in 0..31 {
            input.push('}');
        }
        let result = ash_canonicalize_json(&input);
        assert!(result.is_ok());
    }

    #[test]
    fn test_rejects_oversized_json_payload() {
        // 11 MiB of data is above the 10 MiB payload limit.
        let large_value = "x".repeat(11 * 1024 * 1024);
        let input = format!(r#"{{"data":"{}"}}"#, large_value);
        let result = ash_canonicalize_json(&input);
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("maximum size"));
    }

    #[test]
    fn test_rejects_oversized_query_string() {
        // 11 MiB of data is above the 10 MiB payload limit.
        let large_value = "x".repeat(11 * 1024 * 1024);
        let input = format!("a={}", large_value);
        let result = ash_canonicalize_query(&input);
        assert!(result.is_err());
        assert!(result.unwrap_err().message().contains("maximum size"));
    }
}