use std::fmt;
use std::sync::Arc;
use thiserror::Error;
use crate::token::RETRYABLE_ERROR_CODES;
/// Errors produced by the HTTP layer: either a failure reported by `reqwest`
/// or a response that could not be decoded.
///
/// `Clone` is derived rather than hand-written: the `reqwest::Error` sits
/// behind an [`Arc`], so cloning `Reqwest` only clones the `Arc` handle, and
/// cloning `Decode` clones its message.
#[derive(Debug, Clone)]
pub enum HttpError {
    /// An error returned by `reqwest`, shared through an `Arc`.
    Reqwest(Arc<reqwest::Error>),
    /// A decoding failure; the string is the text displayed after
    /// `Response decode error: `.
    Decode(String),
}

impl fmt::Display for HttpError {
    /// Writes the wrapped `reqwest` error's text unchanged, or the decode
    /// message prefixed with `Response decode error: `.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HttpError::Reqwest(e) => write!(f, "{}", e),
            HttpError::Decode(msg) => write!(f, "Response decode error: {}", msg),
        }
    }
}
impl std::error::Error for HttpError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
HttpError::Reqwest(e) => Some(e.as_ref()),
HttpError::Decode(_) => None,
}
}
}
impl From<reqwest::Error> for HttpError {
fn from(e: reqwest::Error) -> Self {
HttpError::Reqwest(Arc::new(e))
}
}
impl HttpError {
    /// Reports whether this error is considered transient.
    ///
    /// * `Reqwest` errors that carry a status are transient when the status is
    ///   a server error (`is_server_error`) or is 429.
    /// * `Reqwest` errors without a status are always treated as transient.
    /// * `Decode` errors are never transient.
    pub fn is_transient(&self) -> bool {
        const TOO_MANY_REQUESTS: u16 = 429;

        match self {
            HttpError::Decode(_) => false,
            HttpError::Reqwest(inner) => {
                if let Some(code) = inner.status() {
                    code.as_u16() == TOO_MANY_REQUESTS || code.is_server_error()
                } else {
                    // No status attached to the error: treated as transient.
                    true
                }
            }
        }
    }
}
/// Top-level error type of the crate.
///
/// Each variant's display text is given by its `#[error(...)]` attribute.
#[derive(Debug, Error)]
pub enum WechatError {
    /// An HTTP-layer failure; displayed using the inner [`HttpError`]'s own text.
    #[error("{0}")]
    Http(HttpError),

    /// A `serde_json` error; convertible with `?` through the generated `From` impl.
    #[error("JSON serialization error: {0}")]
    Json(#[from] serde_json::Error),

    /// A non-zero error code together with its message (see `check_api`).
    #[error("WeChat API error (code={code}): {message}")]
    Api { code: i32, message: String },

    /// An access-token problem described by the contained message.
    #[error("Access token error: {0}")]
    Token(String),

    /// A configuration problem described by the contained message.
    #[error("Configuration error: {0}")]
    Config(String),

    /// A signature verification failure described by the contained message.
    #[error("Signature verification failed: {0}")]
    Signature(String),

    /// A failed crypto operation described by the contained message.
    #[error("Crypto operation error: {0}")]
    Crypto(String),

    /// An invalid AppId; the string is shown after `Invalid AppId: `.
    #[error("Invalid AppId: {0}")]
    InvalidAppId(String),

    /// An invalid OpenId; the string is shown after `Invalid OpenId: `.
    #[error("Invalid OpenId: {0}")]
    InvalidOpenId(String),

    /// An invalid AccessToken; the string is shown after `Invalid AccessToken: `.
    #[error("Invalid AccessToken: {0}")]
    InvalidAccessToken(String),

    /// An invalid AppSecret; the string is shown after `Invalid AppSecret: `.
    #[error("Invalid AppSecret: {0}")]
    InvalidAppSecret(String),

    /// An invalid SessionKey; the string is shown after `Invalid SessionKey: `.
    #[error("Invalid SessionKey: {0}")]
    InvalidSessionKey(String),

    /// An invalid UnionId; the string is shown after `Invalid UnionId: `.
    #[error("Invalid UnionId: {0}")]
    InvalidUnionId(String),
}
impl Clone for WechatError {
fn clone(&self) -> Self {
match self {
WechatError::Http(e) => WechatError::Http(e.clone()),
WechatError::Json(e) => WechatError::Json(serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::Other,
e.to_string(),
))),
WechatError::Api { code, message } => WechatError::Api {
code: *code,
message: message.clone(),
},
WechatError::Token(msg) => WechatError::Token(msg.clone()),
WechatError::Config(msg) => WechatError::Config(msg.clone()),
WechatError::Signature(msg) => WechatError::Signature(msg.clone()),
WechatError::Crypto(msg) => WechatError::Crypto(msg.clone()),
WechatError::InvalidAppId(msg) => WechatError::InvalidAppId(msg.clone()),
WechatError::InvalidOpenId(msg) => WechatError::InvalidOpenId(msg.clone()),
WechatError::InvalidAccessToken(msg) => WechatError::InvalidAccessToken(msg.clone()),
WechatError::InvalidAppSecret(msg) => WechatError::InvalidAppSecret(msg.clone()),
WechatError::InvalidSessionKey(msg) => WechatError::InvalidSessionKey(msg.clone()),
WechatError::InvalidUnionId(msg) => WechatError::InvalidUnionId(msg.clone()),
}
}
}
impl WechatError {
    /// Converts an error code and message pair into a `Result`.
    ///
    /// Returns `Ok(())` when `errcode` is `0`; otherwise returns
    /// `WechatError::Api` carrying `errcode` and an owned copy of `errmsg`.
    pub(crate) fn check_api(errcode: i32, errmsg: &str) -> Result<(), WechatError> {
        match errcode {
            0 => Ok(()),
            code => Err(WechatError::Api {
                code,
                message: errmsg.to_string(),
            }),
        }
    }

    /// Reports whether this error is considered transient.
    ///
    /// `Http` errors defer to [`HttpError::is_transient`]; `Api` errors are
    /// transient when their code is listed in `RETRYABLE_ERROR_CODES`; every
    /// other variant is not transient.
    pub fn is_transient(&self) -> bool {
        match self {
            WechatError::Http(inner) => inner.is_transient(),
            WechatError::Api { code, .. } => RETRYABLE_ERROR_CODES.contains(code),
            // Listed explicitly so that a newly added variant must be classified here.
            WechatError::Json(_)
            | WechatError::Token(_)
            | WechatError::Config(_)
            | WechatError::Signature(_)
            | WechatError::Crypto(_)
            | WechatError::InvalidAppId(_)
            | WechatError::InvalidOpenId(_)
            | WechatError::InvalidAccessToken(_)
            | WechatError::InvalidAppSecret(_)
            | WechatError::InvalidSessionKey(_)
            | WechatError::InvalidUnionId(_) => false,
        }
    }
}
impl From<reqwest::Error> for WechatError {
fn from(e: reqwest::Error) -> Self {
WechatError::Http(HttpError::Reqwest(Arc::new(e)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::token::RETRYABLE_ERROR_CODES;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Produces a `reqwest::Error` by building a request for the URL `http://`,
    /// whose `build()` call fails.
    fn request_build_error() -> reqwest::Error {
        reqwest::Client::new().get("http://").build().unwrap_err()
    }

    /// Serves `status` from a mock server at `/status-<status>` and returns the
    /// error that `error_for_status` yields for that response.
    async fn status_error(status: u16) -> reqwest::Error {
        let server = MockServer::start().await;
        let route = format!("/status-{}", status);
        Mock::given(method("GET"))
            .and(path(route.as_str()))
            .respond_with(ResponseTemplate::new(status))
            .mount(&server)
            .await;
        reqwest::Client::new()
            .get(format!("{}{}", server.uri(), route))
            .send()
            .await
            .unwrap()
            .error_for_status()
            .unwrap_err()
    }

    // --- Display text of the validation variants ---

    #[test]
    fn test_invalid_appid_error_message() {
        let rendered = WechatError::InvalidAppId(String::from("invalid")).to_string();
        assert_eq!(rendered, "Invalid AppId: invalid");
    }

    #[test]
    fn test_invalid_openid_error_message() {
        let rendered = WechatError::InvalidOpenId(String::from("short")).to_string();
        assert_eq!(rendered, "Invalid OpenId: short");
    }

    #[test]
    fn test_invalid_access_token_error_message() {
        let rendered = WechatError::InvalidAccessToken(String::new()).to_string();
        assert_eq!(rendered, "Invalid AccessToken: ");
    }

    #[test]
    fn test_invalid_app_secret_error_message() {
        let rendered = WechatError::InvalidAppSecret(String::from("wrong")).to_string();
        assert_eq!(rendered, "Invalid AppSecret: wrong");
    }

    #[test]
    fn test_invalid_session_key_error_message() {
        let rendered = WechatError::InvalidSessionKey(String::from("invalid")).to_string();
        assert_eq!(rendered, "Invalid SessionKey: invalid");
    }

    #[test]
    fn test_invalid_union_id_error_message() {
        let rendered = WechatError::InvalidUnionId(String::new()).to_string();
        assert_eq!(rendered, "Invalid UnionId: ");
    }

    // --- check_api ---

    #[test]
    fn test_check_api_success() {
        assert!(WechatError::check_api(0, "success").is_ok());
    }

    #[test]
    fn test_check_api_error() {
        let outcome = WechatError::check_api(40013, "invalid appid");
        assert!(outcome.is_err());
        match outcome {
            Err(WechatError::Api { code, message }) => {
                assert_eq!(code, 40013);
                assert_eq!(message, "invalid appid");
            }
            _ => panic!("Expected Api error"),
        }
    }

    // --- Clone and source ---

    #[test]
    fn test_wechat_error_clone() {
        let api_err = WechatError::Api {
            code: 40013,
            message: String::from("invalid appid"),
        };
        assert_eq!(api_err.to_string(), api_err.clone().to_string());

        let token_err = WechatError::Token(String::from("expired"));
        assert_eq!(token_err.to_string(), token_err.clone().to_string());
    }

    #[test]
    fn test_http_error_clone() {
        let original = HttpError::Decode(String::from("bad json"));
        let copy = original.clone();
        assert_eq!(original.to_string(), copy.to_string());
    }

    #[test]
    fn test_http_error_source_chain() {
        use std::error::Error;
        let decode_err = HttpError::Decode(String::from("test"));
        assert!(decode_err.source().is_none());
    }

    // --- is_transient ---

    #[test]
    fn test_http_error_is_transient() {
        let from_reqwest = HttpError::Reqwest(Arc::new(request_build_error()));
        assert!(from_reqwest.is_transient());

        let from_decode = HttpError::Decode(String::from("bad json"));
        assert!(!from_decode.is_transient());
    }

    #[test]
    fn test_wechat_error_is_transient_for_http_variants() {
        let transient = WechatError::Http(HttpError::Reqwest(Arc::new(request_build_error())));
        assert!(transient.is_transient());

        let not_transient = WechatError::Http(HttpError::Decode(String::from("bad json")));
        assert!(!not_transient.is_transient());
    }

    #[tokio::test]
    async fn test_http_reqwest_status_503_is_transient() {
        let http_error = HttpError::Reqwest(Arc::new(status_error(503).await));
        assert!(http_error.is_transient());
    }

    #[tokio::test]
    async fn test_http_reqwest_status_400_is_not_transient() {
        let http_error = HttpError::Reqwest(Arc::new(status_error(400).await));
        assert!(!http_error.is_transient());
    }

    #[test]
    fn test_wechat_error_is_transient_for_api_and_all_other_variants() {
        // Every code in the retry list must be reported as transient.
        for &code in RETRYABLE_ERROR_CODES {
            let retryable = WechatError::Api {
                code,
                message: String::from("retryable"),
            };
            assert!(
                retryable.is_transient(),
                "code {} should be transient",
                code
            );
        }

        let non_retryable_api = WechatError::Api {
            code: 40013,
            message: String::from("invalid appid"),
        };
        assert!(!non_retryable_api.is_transient());

        // All remaining variants are never transient.
        let json_error = serde_json::from_str::<serde_json::Value>("not json").unwrap_err();
        let non_transient_variants = [
            WechatError::Json(json_error),
            WechatError::Token(String::from("token")),
            WechatError::Config(String::from("config")),
            WechatError::Signature(String::from("sig")),
            WechatError::Crypto(String::from("crypto")),
            WechatError::InvalidAppId(String::from("appid")),
            WechatError::InvalidOpenId(String::from("openid")),
            WechatError::InvalidAccessToken(String::from("token")),
            WechatError::InvalidAppSecret(String::from("secret")),
            WechatError::InvalidSessionKey(String::from("session")),
            WechatError::InvalidUnionId(String::from("union")),
        ];
        for candidate in &non_transient_variants {
            assert!(!candidate.is_transient());
        }
    }
}