use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fmt;
/// Public, stable error code of the ASH protocol.
///
/// Each variant has a canonical `ASH_*` string ([`AshErrorCode::as_str`]), which is what
/// `Display`, `Serialize` and `Deserialize` use, and an HTTP status
/// ([`AshErrorCode::http_status`]). Only `TimestampInvalid` and `InternalError` are
/// classified as retryable ([`AshErrorCode::retryable`]).
///
/// NOTE(review): variants carry implicit discriminants in declaration order; the derived
/// traits here do not depend on that order, but check for `as` casts elsewhere before
/// reordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AshErrorCode {
/// `ASH_CTX_NOT_FOUND`, HTTP 450.
CtxNotFound,
/// `ASH_CTX_EXPIRED`, HTTP 451.
CtxExpired,
/// `ASH_CTX_ALREADY_USED`, HTTP 452.
CtxAlreadyUsed,
/// `ASH_BINDING_MISMATCH`, HTTP 461.
BindingMismatch,
/// `ASH_PROOF_MISSING`, HTTP 483.
ProofMissing,
/// `ASH_PROOF_INVALID`, HTTP 460.
ProofInvalid,
/// `ASH_CANONICALIZATION_ERROR`, HTTP 484.
CanonicalizationError,
/// `ASH_VALIDATION_ERROR`, HTTP 485.
ValidationError,
/// `ASH_MODE_VIOLATION`, HTTP 486.
ModeViolation,
/// `ASH_UNSUPPORTED_CONTENT_TYPE`, HTTP 415.
UnsupportedContentType,
/// `ASH_SCOPE_MISMATCH`, HTTP 473.
ScopeMismatch,
/// `ASH_CHAIN_BROKEN`, HTTP 474.
ChainBroken,
/// `ASH_INTERNAL_ERROR`, HTTP 500; retryable.
InternalError,
/// `ASH_TIMESTAMP_INVALID`, HTTP 482; retryable.
TimestampInvalid,
/// `ASH_SCOPED_FIELD_MISSING`, HTTP 475.
ScopedFieldMissing,
}
impl Serialize for AshErrorCode {
fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
serializer.serialize_str(self.as_str())
}
}
impl<'de> Deserialize<'de> for AshErrorCode {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let s = String::deserialize(deserializer)?;
match s.as_str() {
"ASH_CTX_NOT_FOUND" => Ok(AshErrorCode::CtxNotFound),
"ASH_CTX_EXPIRED" => Ok(AshErrorCode::CtxExpired),
"ASH_CTX_ALREADY_USED" => Ok(AshErrorCode::CtxAlreadyUsed),
"ASH_BINDING_MISMATCH" => Ok(AshErrorCode::BindingMismatch),
"ASH_PROOF_MISSING" => Ok(AshErrorCode::ProofMissing),
"ASH_PROOF_INVALID" => Ok(AshErrorCode::ProofInvalid),
"ASH_CANONICALIZATION_ERROR" => Ok(AshErrorCode::CanonicalizationError),
"ASH_VALIDATION_ERROR" => Ok(AshErrorCode::ValidationError),
"ASH_MODE_VIOLATION" => Ok(AshErrorCode::ModeViolation),
"ASH_UNSUPPORTED_CONTENT_TYPE" => Ok(AshErrorCode::UnsupportedContentType),
"ASH_SCOPE_MISMATCH" => Ok(AshErrorCode::ScopeMismatch),
"ASH_CHAIN_BROKEN" => Ok(AshErrorCode::ChainBroken),
"ASH_INTERNAL_ERROR" => Ok(AshErrorCode::InternalError),
"ASH_TIMESTAMP_INVALID" => Ok(AshErrorCode::TimestampInvalid),
"ASH_SCOPED_FIELD_MISSING" => Ok(AshErrorCode::ScopedFieldMissing),
_ => Err(serde::de::Error::unknown_variant(
&s,
&[
"ASH_CTX_NOT_FOUND", "ASH_CTX_EXPIRED", "ASH_CTX_ALREADY_USED",
"ASH_BINDING_MISMATCH", "ASH_PROOF_MISSING", "ASH_PROOF_INVALID",
"ASH_CANONICALIZATION_ERROR", "ASH_VALIDATION_ERROR", "ASH_MODE_VIOLATION",
"ASH_UNSUPPORTED_CONTENT_TYPE", "ASH_SCOPE_MISMATCH", "ASH_CHAIN_BROKEN",
"ASH_INTERNAL_ERROR", "ASH_TIMESTAMP_INVALID", "ASH_SCOPED_FIELD_MISSING",
],
)),
}
}
}
impl AshErrorCode {
pub fn http_status(&self) -> u16 {
match self {
AshErrorCode::CtxNotFound => 450,
AshErrorCode::CtxExpired => 451,
AshErrorCode::CtxAlreadyUsed => 452,
AshErrorCode::ProofInvalid => 460,
AshErrorCode::BindingMismatch => 461,
AshErrorCode::ScopeMismatch => 473,
AshErrorCode::ChainBroken => 474,
AshErrorCode::ScopedFieldMissing => 475,
AshErrorCode::TimestampInvalid => 482,
AshErrorCode::ProofMissing => 483,
AshErrorCode::CanonicalizationError => 484,
AshErrorCode::ValidationError => 485,
AshErrorCode::ModeViolation => 486,
AshErrorCode::UnsupportedContentType => 415,
AshErrorCode::InternalError => 500,
}
}
pub fn retryable(&self) -> bool {
matches!(
self,
AshErrorCode::TimestampInvalid | AshErrorCode::InternalError
)
}
pub fn as_str(&self) -> &'static str {
match self {
AshErrorCode::CtxNotFound => "ASH_CTX_NOT_FOUND",
AshErrorCode::CtxExpired => "ASH_CTX_EXPIRED",
AshErrorCode::CtxAlreadyUsed => "ASH_CTX_ALREADY_USED",
AshErrorCode::BindingMismatch => "ASH_BINDING_MISMATCH",
AshErrorCode::ProofMissing => "ASH_PROOF_MISSING",
AshErrorCode::ProofInvalid => "ASH_PROOF_INVALID",
AshErrorCode::CanonicalizationError => "ASH_CANONICALIZATION_ERROR",
AshErrorCode::ValidationError => "ASH_VALIDATION_ERROR",
AshErrorCode::ModeViolation => "ASH_MODE_VIOLATION",
AshErrorCode::UnsupportedContentType => "ASH_UNSUPPORTED_CONTENT_TYPE",
AshErrorCode::ScopeMismatch => "ASH_SCOPE_MISMATCH",
AshErrorCode::ChainBroken => "ASH_CHAIN_BROKEN",
AshErrorCode::InternalError => "ASH_INTERNAL_ERROR",
AshErrorCode::TimestampInvalid => "ASH_TIMESTAMP_INVALID",
AshErrorCode::ScopedFieldMissing => "ASH_SCOPED_FIELD_MISSING",
}
}
}
/// Formats the code as its canonical `ASH_*` string (see [`AshErrorCode::as_str`]).
///
/// Delegates to [`fmt::Formatter::pad`] rather than `write!`, so width, fill, alignment and
/// precision flags behave exactly as they do for a plain `&str`. For example, `{:<28}` pads
/// for column-aligned logs. Without flags the output is unchanged: just the code string.
impl fmt::Display for AshErrorCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// Finer-grained reason recorded on an [`AshError`] next to its public [`AshErrorCode`].
///
/// `AshError::new` sets `General`; `AshError::with_reason` can set any variant. `Display`
/// renders each variant as an upper-snake-case label (e.g. `TS_SKEW`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InternalReason {
// Header reasons (labels `HDR_*`).
HdrMissing,
HdrMultiValue,
HdrInvalidChars,
// Timestamp reasons (labels `TS_*`).
TsParse,
TsSkew,
TsLeadingZeros,
TsOverflow,
// Nonce reasons (labels `NONCE_*`).
NonceTooShort,
NonceTooLong,
NonceInvalidChars,
/// Default reason used by `AshError::new` (label `GENERAL`).
General,
}
impl InternalReason {
    /// Stable upper-snake-case label for this reason (e.g. `"TS_SKEW"`).
    ///
    /// Mirrors [`AshErrorCode::as_str`]: callers get the label as a `&'static str` without
    /// going through `Display`/`to_string()` and allocating.
    pub fn as_str(&self) -> &'static str {
        match *self {
            InternalReason::HdrMissing => "HDR_MISSING",
            InternalReason::HdrMultiValue => "HDR_MULTI_VALUE",
            InternalReason::HdrInvalidChars => "HDR_INVALID_CHARS",
            InternalReason::TsParse => "TS_PARSE",
            InternalReason::TsSkew => "TS_SKEW",
            InternalReason::TsLeadingZeros => "TS_LEADING_ZEROS",
            InternalReason::TsOverflow => "TS_OVERFLOW",
            InternalReason::NonceTooShort => "NONCE_TOO_SHORT",
            InternalReason::NonceTooLong => "NONCE_TOO_LONG",
            InternalReason::NonceInvalidChars => "NONCE_INVALID_CHARS",
            InternalReason::General => "GENERAL",
        }
    }
}

/// Formats the reason as its label ([`InternalReason::as_str`]).
///
/// Uses [`fmt::Formatter::pad`], so width, fill, alignment and precision flags are honoured
/// as they are for `&str`. Without flags the output is exactly the label.
impl fmt::Display for InternalReason {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// Error value carrying a public [`AshErrorCode`], a human-readable message, an
/// [`InternalReason`] and optional key/value details.
///
/// Build it with `AshError::new` / `AshError::with_reason`, or with a convenience constructor
/// such as `AshError::ctx_not_found`. Attach details with `AshError::with_detail`.
#[derive(Debug, Clone)]
pub struct AshError {
/// Public code; drives `http_status()`, `retryable()` and the `Display` prefix.
code: AshErrorCode,
/// Human-readable text printed after the code by `Display`.
message: String,
/// Finer-grained reason; `InternalReason::General` unless set via `with_reason`.
reason: InternalReason,
/// Extra context keyed by static names; `None` until the first `with_detail` call.
/// Stored in a `BTreeMap`, so iteration is in key order and re-adding a key overwrites it.
details: Option<BTreeMap<&'static str, String>>,
}
impl AshError {
pub fn new(code: AshErrorCode, message: impl Into<String>) -> Self {
Self {
code,
message: message.into(),
reason: InternalReason::General,
details: None,
}
}
pub fn with_reason(code: AshErrorCode, reason: InternalReason, message: impl Into<String>) -> Self {
Self {
code,
message: message.into(),
reason,
details: None,
}
}
pub fn with_detail(mut self, key: &'static str, value: impl Into<String>) -> Self {
let map = self.details.get_or_insert_with(BTreeMap::new);
map.insert(key, value.into());
self
}
pub fn code(&self) -> AshErrorCode {
self.code
}
pub fn message(&self) -> &str {
&self.message
}
pub fn http_status(&self) -> u16 {
self.code.http_status()
}
pub fn reason(&self) -> InternalReason {
self.reason
}
pub fn details(&self) -> Option<&BTreeMap<&'static str, String>> {
self.details.as_ref()
}
pub fn retryable(&self) -> bool {
self.code.retryable()
}
}
/// Renders as `<CODE>: <message>`, e.g. `ASH_CTX_NOT_FOUND: Context not found`.
///
/// Formatter flags (width, alignment, …) on the outer `{}` are not applied; the text is
/// written as-is.
impl fmt::Display for AshError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{code}: {text}", code = self.code, text = self.message)
    }
}

/// Marker impl so `AshError` works with `?`, `Box<dyn Error>` and friends.
/// It wraps no underlying error, so the default `source()` (`None`) is kept.
impl std::error::Error for AshError {}
impl AshError {
pub fn ctx_not_found() -> Self {
Self::new(AshErrorCode::CtxNotFound, "Context not found")
}
pub fn ctx_expired() -> Self {
Self::new(AshErrorCode::CtxExpired, "Context has expired")
}
pub fn ctx_already_used() -> Self {
Self::new(AshErrorCode::CtxAlreadyUsed, "Context already consumed")
}
pub fn binding_mismatch() -> Self {
Self::new(
AshErrorCode::BindingMismatch,
"Binding does not match endpoint",
)
}
pub fn proof_missing() -> Self {
Self::new(AshErrorCode::ProofMissing, "Required proof not provided")
}
pub fn proof_invalid() -> Self {
Self::new(AshErrorCode::ProofInvalid, "Proof verification failed")
}
pub fn canonicalization_error() -> Self {
Self::new(
AshErrorCode::CanonicalizationError,
"Failed to canonicalize payload",
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `AshErrorCode` variant, in declaration order.
    const ALL_CODES: [AshErrorCode; 15] = [
        AshErrorCode::CtxNotFound,
        AshErrorCode::CtxExpired,
        AshErrorCode::CtxAlreadyUsed,
        AshErrorCode::BindingMismatch,
        AshErrorCode::ProofMissing,
        AshErrorCode::ProofInvalid,
        AshErrorCode::CanonicalizationError,
        AshErrorCode::ValidationError,
        AshErrorCode::ModeViolation,
        AshErrorCode::UnsupportedContentType,
        AshErrorCode::ScopeMismatch,
        AshErrorCode::ChainBroken,
        AshErrorCode::InternalError,
        AshErrorCode::TimestampInvalid,
        AshErrorCode::ScopedFieldMissing,
    ];

    #[test]
    fn test_error_code_http_status() {
        let expected = [
            (AshErrorCode::CtxNotFound, 450),
            (AshErrorCode::CtxExpired, 451),
            (AshErrorCode::CtxAlreadyUsed, 452),
            (AshErrorCode::ProofInvalid, 460),
            (AshErrorCode::BindingMismatch, 461),
            (AshErrorCode::ScopeMismatch, 473),
            (AshErrorCode::ChainBroken, 474),
            (AshErrorCode::ScopedFieldMissing, 475),
            (AshErrorCode::TimestampInvalid, 482),
            (AshErrorCode::ProofMissing, 483),
            (AshErrorCode::CanonicalizationError, 484),
            (AshErrorCode::ValidationError, 485),
            (AshErrorCode::ModeViolation, 486),
            (AshErrorCode::UnsupportedContentType, 415),
            (AshErrorCode::InternalError, 500),
        ];
        for &(code, status) in expected.iter() {
            assert_eq!(code.http_status(), status, "unexpected status for {:?}", code);
        }
    }

    #[test]
    fn test_error_code_as_str() {
        let cases = [
            (AshErrorCode::CtxNotFound, "ASH_CTX_NOT_FOUND"),
            (AshErrorCode::CtxAlreadyUsed, "ASH_CTX_ALREADY_USED"),
        ];
        for &(code, name) in cases.iter() {
            assert_eq!(code.as_str(), name);
        }
    }

    #[test]
    fn test_error_display() {
        let rendered = AshError::ctx_not_found().to_string();
        assert_eq!(rendered, "ASH_CTX_NOT_FOUND: Context not found");
    }

    #[test]
    fn test_error_convenience_functions() {
        let cases: [(fn() -> AshError, AshErrorCode); 3] = [
            (AshError::ctx_not_found, AshErrorCode::CtxNotFound),
            (AshError::ctx_expired, AshErrorCode::CtxExpired),
            (AshError::ctx_already_used, AshErrorCode::CtxAlreadyUsed),
        ];
        for &(make, code) in cases.iter() {
            assert_eq!(make().code(), code);
        }
    }

    #[test]
    fn test_error_code_serde_serialization() {
        let cases = [
            (AshErrorCode::CtxNotFound, r#""ASH_CTX_NOT_FOUND""#),
            (AshErrorCode::ValidationError, r#""ASH_VALIDATION_ERROR""#),
            (AshErrorCode::ScopedFieldMissing, r#""ASH_SCOPED_FIELD_MISSING""#),
        ];
        for &(code, json) in cases.iter() {
            let serialized = serde_json::to_string(&code).unwrap();
            assert_eq!(serialized, json);
        }
    }

    #[test]
    fn test_error_code_serde_deserialization() {
        let cases = [
            (r#""ASH_CTX_NOT_FOUND""#, AshErrorCode::CtxNotFound),
            (r#""ASH_PROOF_INVALID""#, AshErrorCode::ProofInvalid),
            (r#""ASH_INTERNAL_ERROR""#, AshErrorCode::InternalError),
        ];
        for &(json, expected) in cases.iter() {
            let code: AshErrorCode = serde_json::from_str(json).unwrap();
            assert_eq!(code, expected);
        }
    }

    #[test]
    fn test_error_code_serde_roundtrip_all_variants() {
        for code in ALL_CODES.iter() {
            let serialized = serde_json::to_string(code).unwrap();
            assert!(serialized.contains("ASH_"), "Missing ASH_ prefix for {:?}: {}", code, serialized);
            let deserialized: AshErrorCode = serde_json::from_str(&serialized).unwrap();
            assert_eq!(*code, deserialized, "Roundtrip failed for {:?}", code);
            let expected = format!("\"{}\"", code.as_str());
            assert_eq!(serialized, expected, "Serde output doesn't match as_str() for {:?}", code);
        }
    }

    #[test]
    fn test_retryable_timestamp_invalid() {
        assert!(AshErrorCode::TimestampInvalid.retryable());
    }

    #[test]
    fn test_retryable_internal_error() {
        assert!(AshErrorCode::InternalError.retryable());
    }

    #[test]
    fn test_not_retryable_proof_invalid() {
        assert!(!AshErrorCode::ProofInvalid.retryable());
    }

    #[test]
    fn test_not_retryable_validation_error() {
        assert!(!AshErrorCode::ValidationError.retryable());
    }

    #[test]
    fn test_not_retryable_all_permanent_codes() {
        // Everything except the two retryable codes is permanent.
        let permanent = ALL_CODES.iter().filter(|code| {
            !matches!(**code, AshErrorCode::TimestampInvalid | AshErrorCode::InternalError)
        });
        for code in permanent {
            assert!(!code.retryable(), "{:?} should not be retryable", code);
        }
    }

    #[test]
    fn test_ash_error_retryable_delegates() {
        assert!(AshError::new(AshErrorCode::TimestampInvalid, "skew").retryable());
        assert!(!AshError::new(AshErrorCode::ProofInvalid, "bad proof").retryable());
    }

    #[test]
    fn test_error_code_serde_rejects_invalid() {
        for &input in [r#""INVALID_CODE""#, r#""CTX_NOT_FOUND""#].iter() {
            let result: Result<AshErrorCode, _> = serde_json::from_str(input);
            assert!(result.is_err(), "{} should be rejected", input);
        }
    }
}