#![allow(warnings)]
use aes::cipher::{block_padding::Pkcs7, BlockDecryptMut, BlockEncryptMut, KeyIvInit};
use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
use rand::Rng;
use sha1::{Digest, Sha1};
use std::io::Cursor;
use thiserror::Error;
/// CBC-mode encryptor used for outgoing payloads.
///
/// NOTE: despite the `Aes128` in the alias name, the underlying block cipher
/// is `aes::Aes256`, which matches the 32-byte `AES_KEY_SIZE` below.
type Aes128CbcEnc = cbc::Encryptor<aes::Aes256>;
/// CBC-mode decryptor used for incoming payloads.
///
/// NOTE: despite the `Aes128` in the alias name, the underlying block cipher
/// is `aes::Aes256`.
type Aes128CbcDec = cbc::Decryptor<aes::Aes256>;
/// Required length, in bytes, of the decoded AES key.
pub const AES_KEY_SIZE: usize = 32;
/// Number of leading key bytes that are reused as the CBC initialization vector.
pub const AES_IV_SIZE: usize = 16;
/// Required length, in characters, of the Base64 encoding key passed to `WXBizMsgCrypt::new`.
pub const ENCODING_KEY_SIZE: usize = 43;
/// Number of random bytes placed at the start of every plaintext frame.
pub const RAND_ENCRYPT_STR_LEN: usize = 16;
/// Width, in bytes, of the big-endian message-length field that follows the random prefix.
pub const MSG_LEN: usize = 4;
/// Not referenced anywhere in this file.
/// NOTE(review): presumably an upper bound on Base64 payload size — confirm against callers.
pub const MAX_BASE64_SIZE: usize = 1_000_000_000;
/// Errors reported by `WXBizMsgCrypt`.
///
/// Every variant maps to a fixed negative integer through
/// `WXBizMsgCryptError::error_code`.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum WXBizMsgCryptError {
/// The supplied signature does not match the one computed locally.
#[error("Validate signature error")]
ValidateSignatureError,
/// A required XML field could not be extracted (also returned when an empty
/// reply is passed to `encrypt_msg`).
#[error("Parse XML error")]
ParseXmlError,
/// The signature could not be computed because one of its inputs was empty.
#[error("Compute signature error")]
ComputeSignatureError,
/// The encoding key has the wrong length or does not decode to a valid AES key.
#[error("Illegal AES key")]
IllegalAesKey,
/// The receiver id embedded in the decrypted payload differs from the configured one.
#[error("Validate corpId error")]
ValidateCorpidError,
/// AES encryption failed.
#[error("Encrypt AES error")]
EncryptAESError,
/// AES decryption failed.
#[error("Decrypt AES error")]
DecryptAESError,
/// The decrypted buffer is too short or its length field is inconsistent.
#[error("Illegal buffer")]
IllegalBuffer,
/// Base64 encoding failed. Not produced by any code in this file.
#[error("Encode Base64 error")]
EncodeBase64Error,
/// The encrypted payload is not valid Base64.
#[error("Decode Base64 error")]
DecodeBase64Error,
/// Building the reply XML failed. Not produced by any code in this file.
#[error("Generate return XML error")]
GenReturnXmlError,
}
impl WXBizMsgCryptError {
pub fn error_code(&self) -> i32 {
match self {
WXBizMsgCryptError::ValidateSignatureError => -40001,
WXBizMsgCryptError::ParseXmlError => -40002,
WXBizMsgCryptError::ComputeSignatureError => -40003,
WXBizMsgCryptError::IllegalAesKey => -40004,
WXBizMsgCryptError::ValidateCorpidError => -40005,
WXBizMsgCryptError::EncryptAESError => -40006,
WXBizMsgCryptError::DecryptAESError => -40007,
WXBizMsgCryptError::IllegalBuffer => -40008,
WXBizMsgCryptError::EncodeBase64Error => -40009,
WXBizMsgCryptError::DecodeBase64Error => -40010,
WXBizMsgCryptError::GenReturnXmlError => -40011,
}
}
}
/// Signs, encrypts, verifies and decrypts callback messages using a shared
/// token and AES key.
pub struct WXBizMsgCrypt {
// Shared secret mixed into every signature computation.
token: String,
// The Base64 encoding key exactly as supplied to `new`.
encoding_aes_key: String,
// Identifier appended to outgoing plaintext and required to match on incoming payloads.
receive_id: String,
// Raw AES key decoded from `encoding_aes_key`; its first `AES_IV_SIZE` bytes double as the CBC IV.
aes_key: Vec<u8>,
}
impl WXBizMsgCrypt {
/// Creates a crypt context from a token, a Base64 encoding key and a receiver id.
///
/// # Errors
///
/// Returns `WXBizMsgCryptError::IllegalAesKey` when `encoding_aes_key` is not
/// exactly `ENCODING_KEY_SIZE` characters long or cannot be decoded into an
/// AES key.
pub fn new(
    token: impl Into<String>,
    encoding_aes_key: impl Into<String>,
    receive_id: impl Into<String>,
) -> Result<Self, WXBizMsgCryptError> {
    // Convert all three arguments up front, in declaration order.
    let (token, encoding_aes_key, receive_id): (String, String, String) =
        (token.into(), encoding_aes_key.into(), receive_id.into());

    let aes_key = match encoding_aes_key.len() {
        ENCODING_KEY_SIZE => Self::gen_aes_key_from_encoding_key(&encoding_aes_key)?,
        _ => return Err(WXBizMsgCryptError::IllegalAesKey),
    };

    Ok(Self {
        token,
        encoding_aes_key,
        receive_id,
        aes_key,
    })
}
fn gen_aes_key_from_encoding_key(encoding_key: &str) -> Result<Vec<u8>, WXBizMsgCryptError> {
let base64_key = if encoding_key.len() % 4 == 3 {
format!("{}=", encoding_key)
} else if encoding_key.len() % 4 == 2 {
format!("{}==", encoding_key)
} else {
encoding_key.to_string()
};
match BASE64.decode(base64_key.as_bytes()) {
Ok(key) => {
if key.len() != AES_KEY_SIZE {
return Err(WXBizMsgCryptError::IllegalAesKey);
}
Ok(key)
}
Err(_) => Err(WXBizMsgCryptError::IllegalAesKey),
}
}
/// Verifies a signed, encrypted `echo_str` and returns its decrypted content.
///
/// The signature is checked over `echo_str` first; only then is the value
/// Base64-decoded, decrypted and split into message and receiver id.
///
/// # Errors
///
/// Propagates signature, Base64, AES and buffer errors from the individual
/// steps, and returns `WXBizMsgCryptError::ValidateCorpidError` when the
/// embedded receiver id differs from the configured one.
pub fn verify_url(
    &self,
    msg_signature: &str,
    timestamp: &str,
    nonce: &str,
    echo_str: &str,
) -> Result<String, WXBizMsgCryptError> {
    Self::validate_signature(&self.token, timestamp, nonce, echo_str, msg_signature)?;

    let cipher_bytes = match BASE64.decode(echo_str) {
        Ok(bytes) => bytes,
        Err(_) => return Err(WXBizMsgCryptError::DecodeBase64Error),
    };
    let plain = self.aes_cbc_decrypt(&cipher_bytes)?;

    match Self::parse_decrypted_data(&plain)? {
        (echo, embedded_id) if embedded_id == self.receive_id => Ok(echo),
        _ => Err(WXBizMsgCryptError::ValidateCorpidError),
    }
}
/// Extracts the `Encrypt` field from `post_data`, verifies its signature and
/// returns the decrypted message.
///
/// # Errors
///
/// Propagates XML, signature, Base64, AES and buffer errors from the
/// individual steps, and returns `WXBizMsgCryptError::ValidateCorpidError`
/// when the embedded receiver id differs from the configured one.
pub fn decrypt_msg(
    &self,
    msg_signature: &str,
    timestamp: &str,
    nonce: &str,
    post_data: &str,
) -> Result<String, WXBizMsgCryptError> {
    let encrypted = Self::get_xml_field(post_data, "Encrypt")?;
    Self::validate_signature(&self.token, timestamp, nonce, &encrypted, msg_signature)?;

    let raw = BASE64
        .decode(&encrypted)
        .or(Err(WXBizMsgCryptError::DecodeBase64Error))?;
    let plain = self.aes_cbc_decrypt(&raw)?;
    let (body, embedded_id) = Self::parse_decrypted_data(&plain)?;

    if embedded_id == self.receive_id {
        Ok(body)
    } else {
        Err(WXBizMsgCryptError::ValidateCorpidError)
    }
}
/// Encrypts `reply_msg`, signs the result and wraps both in the reply XML envelope.
///
/// # Errors
///
/// Returns `WXBizMsgCryptError::ParseXmlError` for an empty `reply_msg`, and
/// propagates any error from encryption, signing or envelope generation.
pub fn encrypt_msg(
    &self,
    reply_msg: &str,
    timestamp: &str,
    nonce: &str,
) -> Result<String, WXBizMsgCryptError> {
    if reply_msg.is_empty() {
        return Err(WXBizMsgCryptError::ParseXmlError);
    }

    let framed = Self::gen_need_encrypt_data(reply_msg, &self.receive_id);
    let encoded = self
        .aes_cbc_encrypt(&framed)
        .map(|ciphertext| BASE64.encode(&ciphertext))?;

    Self::compute_signature(&self.token, timestamp, nonce, &encoded)
        .and_then(|sig| Self::gen_return_xml(&encoded, &sig, timestamp, nonce))
}
fn aes_cbc_encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, WXBizMsgCryptError> {
let key = &self.aes_key[..];
let iv = &self.aes_key[..AES_IV_SIZE];
let cipher = Aes128CbcEnc::new(key.into(), iv.into());
let mut buf = vec
![0u8; plaintext.len()
+ 16];
buf[..plaintext.len()].copy_from_slice(plaintext);
let ct_len = cipher
.encrypt_padded_mut::<Pkcs7>(&mut buf, plaintext.len())
.map_err(|_| WXBizMsgCryptError::EncryptAESError)?
.len();
buf.truncate(ct_len);
Ok(buf)
}
fn aes_cbc_decrypt(&self, ciphertext: &[u8]) -> Result<Vec<u8>, WXBizMsgCryptError> {
let key = &self.aes_key[..];
let iv = &self.aes_key[..AES_IV_SIZE];
let cipher = Aes128CbcDec::new(key.into(), iv.into());
let mut buf = ciphertext.to_vec();
let pt_len = cipher
.decrypt_padded_mut::<Pkcs7>(&mut buf)
.map_err(|_| WXBizMsgCryptError::DecryptAESError)?
.len();
buf.truncate(pt_len);
Ok(buf)
}
/// Splits a decrypted frame into `(message, receive_id)`.
///
/// The frame layout is: `RAND_ENCRYPT_STR_LEN` random bytes, a `MSG_LEN`-byte
/// big-endian message length, the message itself, and finally the receiver id
/// occupying the rest of the buffer. Invalid UTF-8 in either part is replaced
/// lossily rather than rejected.
///
/// # Errors
///
/// Returns `WXBizMsgCryptError::IllegalBuffer` if the buffer holds nothing
/// beyond the header or if the declared message length exceeds the buffer.
fn parse_decrypted_data(data: &[u8]) -> Result<(String, String), WXBizMsgCryptError> {
    const HEADER_LEN: usize = RAND_ENCRYPT_STR_LEN + MSG_LEN;

    if data.len() <= HEADER_LEN {
        return Err(WXBizMsgCryptError::IllegalBuffer);
    }

    let mut len_field = [0u8; MSG_LEN];
    len_field.copy_from_slice(&data[RAND_ENCRYPT_STR_LEN..HEADER_LEN]);
    let msg_len = u32::from_be_bytes(len_field) as usize;

    // `checked_add` keeps a huge declared length from wrapping `usize` on
    // 32-bit targets, which would otherwise slip past the bounds check and
    // panic when slicing.
    let msg_end = HEADER_LEN
        .checked_add(msg_len)
        .filter(|&end| end <= data.len())
        .ok_or(WXBizMsgCryptError::IllegalBuffer)?;

    let msg = String::from_utf8_lossy(&data[HEADER_LEN..msg_end]).into_owned();
    let receive_id = String::from_utf8_lossy(&data[msg_end..]).into_owned();
    Ok((msg, receive_id))
}
/// Builds the plaintext frame that gets encrypted.
///
/// The frame is `RAND_ENCRYPT_STR_LEN` random bytes drawn from `33..128`,
/// then the length of `msg` as a big-endian `u32`, then `msg`, then
/// `receive_id`.
fn gen_need_encrypt_data(msg: &str, receive_id: &str) -> Vec<u8> {
    let mut rng = rand::thread_rng();
    let mut framed =
        Vec::with_capacity(RAND_ENCRYPT_STR_LEN + MSG_LEN + msg.len() + receive_id.len());

    // Random prefix.
    for _ in 0..RAND_ENCRYPT_STR_LEN {
        framed.push(rng.gen_range(33..128) as u8);
    }

    // Length header followed by the two variable-length parts.
    framed.extend_from_slice(&(msg.len() as u32).to_be_bytes());
    framed.extend(msg.bytes().chain(receive_id.bytes()));
    framed
}
/// Extracts the text content of the first `<field_name>…</field_name>` element in `xml_data`.
///
/// A CDATA-wrapped value (`<![CDATA[…]]>`) is tried first and returned
/// without the wrapper; otherwise the raw text between the tags is returned.
/// `field_name` is matched literally: it is escaped before being placed into
/// the pattern, so regex metacharacters in it can neither change the match
/// nor make pattern compilation panic.
///
/// # Errors
///
/// Returns `WXBizMsgCryptError::ParseXmlError` if the field is not found.
pub fn get_xml_field(xml_data: &str, field_name: &str) -> Result<String, WXBizMsgCryptError> {
    let tag = regex::escape(field_name);
    let patterns = [
        format!(r#"<{0}><!\[CDATA\[(.*?)\]\]></{0}>"#, tag),
        format!(r#"<{0}>(.*?)</{0}>"#, tag),
    ];

    for pattern in &patterns {
        // Compilation cannot fail for an escaped tag, but map the error
        // instead of panicking inside a public API.
        let re = regex::Regex::new(pattern).map_err(|_| WXBizMsgCryptError::ParseXmlError)?;
        if let Some(value) = re.captures(xml_data).and_then(|caps| caps.get(1)) {
            return Ok(value.as_str().to_string());
        }
    }
    Err(WXBizMsgCryptError::ParseXmlError)
}
/// Computes the lowercase hex SHA-1 signature over the four inputs.
///
/// The inputs are sorted lexicographically, concatenated without a separator
/// and hashed.
///
/// # Errors
///
/// Returns `WXBizMsgCryptError::ComputeSignatureError` if any input is empty.
fn compute_signature(
    token: &str,
    timestamp: &str,
    nonce: &str,
    message: &str,
) -> Result<String, WXBizMsgCryptError> {
    let mut parts = [token, timestamp, nonce, message];
    if parts.iter().any(|part| part.is_empty()) {
        return Err(WXBizMsgCryptError::ComputeSignatureError);
    }

    // Equal strings are interchangeable, so an unstable sort yields the same concatenation.
    parts.sort_unstable();
    let digest = Sha1::digest(parts.concat().as_bytes());
    Ok(hex::encode(digest))
}
/// Checks that `msg_signature` equals the signature computed from the other arguments.
///
/// `msg_signature` comes from the request, so the comparison visits every
/// byte instead of stopping at the first mismatch. This avoids revealing, through
/// response timing, how long a prefix of a forged signature was correct.
///
/// # Errors
///
/// Returns `WXBizMsgCryptError::ValidateSignatureError` on mismatch, and
/// propagates `WXBizMsgCryptError::ComputeSignatureError` from signature computation.
fn validate_signature(
    token: &str,
    timestamp: &str,
    nonce: &str,
    encrypt_msg: &str,
    msg_signature: &str,
) -> Result<(), WXBizMsgCryptError> {
    let signature = Self::compute_signature(token, timestamp, nonce, encrypt_msg)?;
    let (expected, provided) = (signature.as_bytes(), msg_signature.as_bytes());

    // Accumulate differences over the common prefix without branching per byte.
    let diff = expected
        .iter()
        .zip(provided)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));

    if expected.len() != provided.len() || diff != 0 {
        return Err(WXBizMsgCryptError::ValidateSignatureError);
    }
    Ok(())
}
/// Assembles the reply XML envelope from its four parts.
///
/// `encrypt_msg`, `signature` and `nonce` are wrapped in CDATA sections;
/// `timestamp` is inserted as plain text. The values are inserted verbatim,
/// without escaping. This function always returns `Ok`.
fn gen_return_xml(
    encrypt_msg: &str,
    signature: &str,
    timestamp: &str,
    nonce: &str,
) -> Result<String, WXBizMsgCryptError> {
    let pieces = [
        "<xml><Encrypt><![CDATA[",
        encrypt_msg,
        "]]></Encrypt><MsgSignature><![CDATA[",
        signature,
        "]]></MsgSignature><TimeStamp>",
        timestamp,
        "</TimeStamp><Nonce><![CDATA[",
        nonce,
        "]]></Nonce></xml>",
    ];
    Ok(pieces.concat())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    const TOKEN: &str = "QDG6eK";
    const ENCODING_AES_KEY: &str = "AQIDBAUGBwgJCgsMDQ4PEBESExQVFhcYGRobHB0eHyA";
    const RECEIVE_ID: &str = "wx5823bf96d3bd56c7";
    const TIMESTAMP: &str = "1409659813";
    const NONCE: &str = "263014780";

    /// Builds the crypt context shared by the round-trip tests.
    fn new_crypt() -> WXBizMsgCrypt {
        WXBizMsgCrypt::new(TOKEN, ENCODING_AES_KEY, RECEIVE_ID).unwrap()
    }

    /// An encrypted payload fed back through `verify_url` yields the original text.
    #[test]
    fn test_verify_url() {
        let crypt = new_crypt();
        let plain = "hello from testcase";

        let envelope = crypt.encrypt_msg(plain, TIMESTAMP, NONCE).unwrap();
        let signature = WXBizMsgCrypt::get_xml_field(&envelope, "MsgSignature").unwrap();
        let payload = WXBizMsgCrypt::get_xml_field(&envelope, "Encrypt").unwrap();

        let outcome = crypt.verify_url(&signature, TIMESTAMP, NONCE, &payload);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), plain);
    }

    /// `decrypt_msg` recovers exactly what `encrypt_msg` was given.
    #[test]
    fn test_encrypt_decrypt() {
        let crypt = new_crypt();
        let plain = "<xml><Content>Hello World</Content></xml>";

        let envelope = crypt.encrypt_msg(plain, TIMESTAMP, NONCE).unwrap();
        println!("Encrypted: {}", envelope);

        // Both fields must be present in the generated envelope.
        let _payload = WXBizMsgCrypt::get_xml_field(&envelope, "Encrypt").unwrap();
        let signature = WXBizMsgCrypt::get_xml_field(&envelope, "MsgSignature").unwrap();

        let recovered = crypt
            .decrypt_msg(&signature, TIMESTAMP, NONCE, &envelope)
            .unwrap();
        assert_eq!(recovered, plain);
    }

    /// Spot-checks the variant-to-code mapping.
    #[test]
    fn test_error_codes() {
        let expectations = [
            (WXBizMsgCryptError::ValidateSignatureError, -40001),
            (WXBizMsgCryptError::ParseXmlError, -40002),
            (WXBizMsgCryptError::IllegalAesKey, -40004),
        ];
        for (err, code) in expectations.iter() {
            assert_eq!(err.error_code(), *code);
        }
    }
}