use super::{codegen, parse};
/// Parses macro-generated tokens as a complete Rust file and pretty-prints them.
///
/// Turns raw derive output into stable, human-readable text suitable for
/// snapshot comparison.
///
/// # Panics
///
/// Panics if `tokens` do not parse as a [`syn::File`]. The panic message
/// includes both the parse error and the raw token text, so a codegen
/// regression that emits malformed Rust shows exactly what was produced.
fn pretty_print(tokens: proc_macro2::TokenStream) -> String {
    // Parse a clone so the original tokens remain available for the panic
    // message; the cost is irrelevant in test code.
    let file: syn::File = syn::parse2(tokens.clone()).unwrap_or_else(|err| {
        panic!("generated tokens are not a valid Rust file: {err}\n{tokens}")
    });
    prettyplease::unparse(&file)
}
/// Runs the derive pipeline on `input`: parse it into the macro's IR, then
/// generate the output tokens from that IR.
///
/// Any error reported by `parse::parse_error_derive` is propagated with `?`.
fn derive_error_enum(input: syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    Ok(codegen::generate(&parse::parse_error_derive(input)?))
}
/// Runs the same parse → codegen pipeline as `derive_error_enum`, named
/// separately so the struct-based tests read naturally.
///
/// Any error reported by `parse::parse_error_derive` is propagated with `?`.
fn derive_error_struct(input: syn::DeriveInput) -> syn::Result<proc_macro2::TokenStream> {
    let model = parse::parse_error_derive(input)?;
    let generated = codegen::generate(&model);
    Ok(generated)
}
/// Snapshot tests for the code generated by the derive.
///
/// Each test feeds a hand-written item through `derive_error_enum` or
/// `derive_error_struct` and unwraps the result, so any parse or validation
/// error fails the test. It then pretty-prints the generated tokens with
/// `pretty_print` and compares the text against a stored `insta` snapshot.
///
/// NOTE(review): `assert_snapshot!` is called without an explicit name, so
/// insta keys each stored snapshot by the enclosing test function. Renaming a
/// test, or moving the macro call into a helper, orphans its `.snap` file.
mod expansion {
    use super::*;

    /// Expansion of enum inputs, mostly one attribute feature per test.
    mod enums {
        use super::*;

        /// Baseline: one unit variant with only `#[error]` and `#[diagnostic(code(..))]`.
        #[test]
        fn simple() {
            let output = derive_error_enum(syn::parse_quote! {
                enum SimpleError {
                    #[error("Something went wrong")]
                    #[diagnostic(code(errors::simple))]
                    Simple,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// A struct-like variant whose named fields are all marked `#[extension]`.
        #[test]
        fn with_fields() {
            let output = derive_error_enum(syn::parse_quote! {
                enum ErrorWithFields {
                    #[error("Invalid port")]
                    #[diagnostic(code(config::invalid_port))]
                    InvalidPort {
                        #[extension]
                        port: u16,
                        #[extension]
                        config_file: String,
                    },
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// `help(..)` supplied alongside `code(..)` in `#[diagnostic]`.
        #[test]
        fn with_help() {
            let output = derive_error_enum(syn::parse_quote! {
                enum ErrorWithHelp {
                    #[error("Invalid configuration")]
                    #[diagnostic(code(config::invalid), help("Check your configuration file for syntax errors"))]
                    InvalidConfig,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// `#[http_status(..)]` on some variants, with `InternalError` left
        /// without one so the no-status path is covered in the same output.
        #[test]
        fn with_http_status() {
            let output = derive_error_enum(syn::parse_quote! {
                enum ErrorWithHttpStatus {
                    #[error("Resource not found")]
                    #[diagnostic(code(resource::not_found))]
                    #[http_status(404)]
                    NotFound,
                    #[error("Unauthorized access")]
                    #[diagnostic(code(auth::unauthorized))]
                    #[http_status(401)]
                    Unauthorized,
                    #[error("Internal server error")]
                    #[diagnostic(code(server::internal))]
                    InternalError,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// `severity(Warning)` inside `#[diagnostic]`.
        #[test]
        fn with_severity() {
            let output = derive_error_enum(syn::parse_quote! {
                enum ErrorWithSeverity {
                    #[error("Deprecated API usage")]
                    #[diagnostic(code(api::deprecated), severity(Warning))]
                    DeprecatedApi,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// `url(..)` inside `#[diagnostic]`.
        #[test]
        fn with_url() {
            let output = derive_error_enum(syn::parse_quote! {
                enum ErrorWithUrl {
                    #[error("Database connection failed")]
                    #[diagnostic(code(db::connection_failed), url("https://docs.example.com/errors/db-connection"))]
                    ConnectionFailed,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// Several variant shapes in one enum: a plain unit variant, a variant
        /// with an `#[extension]` field, and one with both `help` and `http_status`.
        #[test]
        fn multiple_variants() {
            let output = derive_error_enum(syn::parse_quote! {
                enum MultiError {
                    #[error("First error")]
                    #[diagnostic(code(multi::first))]
                    First,
                    #[error("Second error with field")]
                    #[diagnostic(code(multi::second))]
                    Second {
                        #[extension]
                        value: i32,
                    },
                    #[error("Third error")]
                    #[diagnostic(code(multi::third), help("Try something else"))]
                    #[http_status(400)]
                    Third,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// A variant with a field marked `#[source]`.
        #[test]
        fn with_source() {
            let output = derive_error_enum(syn::parse_quote! {
                enum ErrorWithSource {
                    #[error("IO error occurred")]
                    #[diagnostic(code(io::error))]
                    IoError {
                        #[source]
                        source: std::io::Error,
                    },
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// A `#[diagnostic(transparent)]` tuple variant (no `#[error]` of its own)
        /// next to a regular variant.
        #[test]
        fn transparent_variant() {
            let output = derive_error_enum(syn::parse_quote! {
                enum ErrorWithTransparent {
                    #[error("Regular error")]
                    #[diagnostic(code(errors::regular))]
                    Regular,
                    #[diagnostic(transparent)]
                    Inner(std::io::Error),
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// A variant with a field marked `#[from]`.
        #[test]
        fn with_from() {
            let output = derive_error_enum(syn::parse_quote! {
                enum ErrorWithFrom {
                    #[error("IO error")]
                    #[diagnostic(code(io::error))]
                    Io {
                        #[from]
                        source: std::io::Error,
                    },
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }
    }

    /// Expansion of struct inputs, where `#[error]`, `#[diagnostic]` and
    /// `#[http_status]` are placed on the struct itself.
    mod structs {
        use super::*;

        /// Baseline: a unit struct with only `#[error]` and `#[diagnostic(code(..))]`.
        #[test]
        fn simple_struct() {
            let output = derive_error_struct(syn::parse_quote! {
                #[error("Something went wrong")]
                #[diagnostic(code(errors::struct_error))]
                struct SimpleStructError;
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// A struct whose named fields are all marked `#[extension]`.
        #[test]
        fn struct_with_fields() {
            let output = derive_error_struct(syn::parse_quote! {
                #[error("Invalid configuration")]
                #[diagnostic(code(config::invalid))]
                struct ConfigError {
                    #[extension]
                    field: String,
                    #[extension]
                    line: u32,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// `help(..)` in `#[diagnostic]` on a struct with an `#[extension]` field.
        #[test]
        fn struct_with_help() {
            let output = derive_error_struct(syn::parse_quote! {
                #[error("Missing required field")]
                #[diagnostic(code(validation::missing_field), help("Ensure all required fields are provided"))]
                struct MissingFieldError {
                    #[extension]
                    field_name: String,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// `#[http_status(..)]` on a struct with an `#[extension]` field.
        #[test]
        fn struct_with_http_status() {
            let output = derive_error_struct(syn::parse_quote! {
                #[error("Not found")]
                #[diagnostic(code(http::not_found))]
                #[http_status(404)]
                struct NotFoundError {
                    #[extension]
                    resource: String,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }

        /// A struct with a field marked `#[source]`.
        #[test]
        fn struct_with_source() {
            let output = derive_error_struct(syn::parse_quote! {
                #[error("Database error")]
                #[diagnostic(code(db::error))]
                struct DbError {
                    #[source]
                    source: std::io::Error,
                }
            })
            .unwrap();
            insta::assert_snapshot!(pretty_print(output));
        }
    }
}
mod errors {
use super::*;
#[test]
fn union_not_supported() {
let result = derive_error_enum(syn::parse_quote! {
union NotSupported {
a: u32,
b: f32,
}
});
let Err(err) = result else {
panic!("union should return error");
};
assert!(err.to_string().contains("cannot be used on unions"));
}
#[test]
fn error_code_too_few_segments() {
let result = derive_error_enum(syn::parse_quote! {
enum MyError {
#[error("Bad code")]
#[diagnostic(code(just_one))]
Bad,
}
});
let Err(err) = result else {
panic!("too few segments should return error");
};
assert!(err.to_string().contains("at least 2 segments"));
}
#[test]
fn error_code_uppercase_rejected() {
let result = derive_error_enum(syn::parse_quote! {
enum MyError {
#[error("Bad code")]
#[diagnostic(code(TEST::UPPERCASE))]
Bad,
}
});
let Err(err) = result else {
panic!("uppercase code should return error");
};
assert!(err.to_string().contains("lowercase"));
}
#[test]
fn optional_field_in_enum_message_rejected() {
let result = derive_error_enum(syn::parse_quote! {
enum MyError {
#[error("Retry after {retry_after}")]
#[diagnostic(code(rate::limited))]
RateLimited {
retry_after: Option<u64>,
},
}
});
let Err(err) = result else {
panic!("optional field in message should return error");
};
assert!(
err.to_string()
.contains("optional field `retry_after` cannot be used in error message")
);
}
#[test]
fn optional_field_in_struct_message_rejected() {
let result = derive_error_struct(syn::parse_quote! {
#[error("Error at line {line}")]
#[diagnostic(code(parse::error))]
struct ParseError {
line: Option<u32>,
}
});
let Err(err) = result else {
panic!("optional field in message should return error");
};
assert!(
err.to_string()
.contains("optional field `line` cannot be used in error message")
);
}
#[test]
fn optional_field_not_in_message_allowed() {
let result = derive_error_struct(syn::parse_quote! {
#[error("Something went wrong")]
#[diagnostic(code(errors::test))]
struct TestError {
#[extension]
context: Option<String>,
}
});
assert!(
result.is_ok(),
"optional field not in message should be allowed"
);
}
#[test]
fn non_optional_field_in_message_allowed() {
let result = derive_error_struct(syn::parse_quote! {
#[error("Invalid value: {value}")]
#[diagnostic(code(errors::invalid))]
struct InvalidError {
value: String,
}
});
assert!(
result.is_ok(),
"non-optional field in message should be allowed"
);
}
}