use aes::cipher::{BlockDecryptMut, BlockEncryptMut, KeyIvInit};
use base64::{
Engine, alphabet,
engine::{GeneralPurpose, GeneralPurposeConfig, general_purpose::STANDARD as BASE64},
};
use rand::Rng;
use sha1::{Digest, Sha1};
use crate::error::{Result, WeChatError};
/// AES-256-CBC encryptor; used by `encrypt_message`.
type Aes256CbcEnc = cbc::Encryptor<aes::Aes256>;
/// AES-256-CBC decryptor; used by `decrypt_message`.
type Aes256CbcDec = cbc::Decryptor<aes::Aes256>;
/// AES cipher block size in bytes.
///
/// Used to validate ciphertext length, to size the IV, and as the PKCS#7
/// block size when encrypting.
const AES_BLOCK_SIZE: usize = 16;
/// Standard-alphabet base64 engine that accepts non-zero trailing bits.
///
/// Used by `decode_aes_key`. NOTE(review): presumably required because some
/// EncodingAESKey values do not decode under the strict canonical engine —
/// confirm before tightening.
const LENIENT_BASE64: GeneralPurpose = GeneralPurpose::new(
    &alphabet::STANDARD,
    GeneralPurposeConfig::new().with_decode_allow_trailing_bits(true),
);
/// Verifies a plain (non-encrypted) callback signature.
///
/// Recomputes the signature from `token`, `timestamp` and `nonce` via
/// `compute_signature` and compares it with the supplied `signature`.
///
/// The byte comparison does not stop at the first mismatch. Response timing
/// therefore does not reveal how many leading characters of a forged
/// signature were correct. Only the length check can exit early, and the
/// expected length is not secret.
pub fn check_signature(token: &str, signature: &str, timestamp: &str, nonce: &str) -> bool {
    let computed = compute_signature(token, timestamp, nonce);
    let (expected, given) = (computed.as_bytes(), signature.as_bytes());
    // OR together the XOR of every byte pair; the result is 0 only if all match.
    expected.len() == given.len()
        && expected
            .iter()
            .zip(given)
            .fold(0u8, |acc, (a, b)| acc | (a ^ b))
            == 0
}
/// Computes the plain callback signature.
///
/// The three inputs are sorted lexicographically (byte order) and
/// concatenated. The result is SHA-1 hashed and returned hex-encoded.
fn compute_signature(token: &str, timestamp: &str, nonce: &str) -> String {
    let mut fields = [token, timestamp, nonce];
    // Equal elements are identical strings, so an unstable sort yields the
    // same concatenation as a stable one.
    fields.sort_unstable();
    let digest = Sha1::digest(fields.concat().as_bytes());
    hex::encode(digest)
}
/// Computes the signature for an encrypted message.
///
/// `token`, `timestamp`, `nonce` and `encrypt_msg` are sorted
/// lexicographically (byte order) and hashed in that order with SHA-1. The
/// digest is returned hex-encoded.
pub fn compute_msg_signature(
    token: &str,
    timestamp: &str,
    nonce: &str,
    encrypt_msg: &str,
) -> String {
    let mut fields = [token, timestamp, nonce, encrypt_msg];
    fields.sort_unstable();
    // Feeding the fields one by one hashes exactly the same bytes as hashing
    // their concatenation, without building an intermediate String.
    let mut hasher = Sha1::new();
    for field in fields {
        hasher.update(field.as_bytes());
    }
    hex::encode(hasher.finalize())
}
/// Verifies the signature of an encrypted message.
///
/// Recomputes the signature with `compute_msg_signature` and compares it
/// against `msg_signature`.
///
/// The byte comparison does not stop at the first mismatch. Response timing
/// therefore does not leak how long a correct prefix of a forged signature
/// is. Only the length check can exit early, and the expected length is not
/// secret.
pub fn check_msg_signature(
    token: &str,
    msg_signature: &str,
    timestamp: &str,
    nonce: &str,
    encrypt_msg: &str,
) -> bool {
    let computed = compute_msg_signature(token, timestamp, nonce, encrypt_msg);
    let (expected, given) = (computed.as_bytes(), msg_signature.as_bytes());
    // OR together the XOR of every byte pair; the result is 0 only if all match.
    expected.len() == given.len()
        && expected
            .iter()
            .zip(given)
            .fold(0u8, |acc, (a, b)| acc | (a ^ b))
            == 0
}
/// Decodes a 43-character `EncodingAESKey` into the raw 32-byte AES key.
///
/// The key is base64 with its trailing `=` removed. One `=` is appended
/// before decoding with `LENIENT_BASE64`, which accepts non-zero trailing bits.
///
/// # Errors
/// Returns `WeChatError::InvalidAesKey` in any of these cases:
/// - the input is not exactly 43 bytes long;
/// - it is not valid base64;
/// - it does not decode to exactly 32 bytes.
pub fn decode_aes_key(encoding_aes_key: &str) -> Result<[u8; 32]> {
    if encoding_aes_key.len() != 43 {
        return Err(WeChatError::InvalidAesKey);
    }
    let mut with_padding = String::with_capacity(44);
    with_padding.push_str(encoding_aes_key);
    with_padding.push('=');
    LENIENT_BASE64
        .decode(with_padding)
        .ok()
        .and_then(|bytes| <[u8; 32]>::try_from(bytes).ok())
        .ok_or(WeChatError::InvalidAesKey)
}
/// Decrypts a base64-encoded encrypted message.
///
/// After AES-256-CBC decryption (IV = first 16 bytes of `aes_key`) and PKCS#7
/// unpadding, the plaintext is laid out as:
/// `16 random bytes | u32 big-endian message length | message | AppID`.
///
/// Returns `(xml, app_id)`. The AppID is returned as-is and is not checked here.
///
/// # Errors
/// - `DecryptionFailed` in any of these cases:
///   - base64 decoding fails;
///   - the ciphertext is empty or not a multiple of the AES block size;
///   - the AES step fails;
///   - the padding is invalid;
///   - the message or AppID is not UTF-8.
/// - `InvalidMessageFormat` if the plaintext is too short to hold the header,
///   or its length prefix points past the end of the data.
pub fn decrypt_message(aes_key: &[u8; 32], encrypted: &str) -> Result<(String, String)> {
    // Decrypt in place in the freshly decoded buffer; no extra copy is needed.
    let mut buf = BASE64
        .decode(encrypted)
        .map_err(|e| WeChatError::DecryptionFailed(format!("Base64 decode failed: {}", e)))?;
    if buf.len() < AES_BLOCK_SIZE || buf.len() % AES_BLOCK_SIZE != 0 {
        return Err(WeChatError::DecryptionFailed(
            "Invalid ciphertext length".to_string(),
        ));
    }

    // The IV is the first block's worth of the key itself.
    let mut iv = [0u8; AES_BLOCK_SIZE];
    iv.copy_from_slice(&aes_key[..AES_BLOCK_SIZE]);
    let decrypted = Aes256CbcDec::new(aes_key.into(), &iv.into())
        .decrypt_padded_mut::<aes::cipher::block_padding::NoPadding>(&mut buf)
        .map_err(|e| WeChatError::DecryptionFailed(format!("AES decrypt failed: {}", e)))?;
    let unpadded = pkcs7_unpad(decrypted)?;

    // 16 random bytes + 4-byte length prefix is the minimum valid plaintext.
    if unpadded.len() < 20 {
        return Err(WeChatError::InvalidMessageFormat);
    }
    let (len_bytes, rest) = unpadded[16..].split_at(4);
    // u32 -> usize is lossless on 32- and 64-bit targets.
    let msg_len =
        u32::from_be_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]) as usize;
    // Compare against the remaining length rather than computing `4 + msg_len`,
    // which could overflow usize on 32-bit targets for a hostile length prefix.
    if msg_len > rest.len() {
        return Err(WeChatError::InvalidMessageFormat);
    }
    let (msg, app_id) = rest.split_at(msg_len);

    let xml = String::from_utf8(msg.to_vec())
        .map_err(|e| WeChatError::DecryptionFailed(format!("Invalid UTF-8: {}", e)))?;
    let app_id = String::from_utf8(app_id.to_vec())
        .map_err(|e| WeChatError::DecryptionFailed(format!("Invalid UTF-8 in AppID: {}", e)))?;
    Ok((xml, app_id))
}
/// Encrypts an XML message together with an AppID.
///
/// Builds the plaintext `16 random bytes | u32 big-endian xml length | xml | app_id`.
/// The plaintext is PKCS#7-padded to the AES block size, then encrypted with
/// AES-256-CBC (IV = first 16 bytes of `aes_key`). The ciphertext is returned
/// as standard base64.
///
/// # Errors
/// Returns `EncryptionFailed` in either of these cases:
/// - `xml` is longer than `u32::MAX` bytes, so it cannot be represented in
///   the 4-byte length prefix;
/// - the AES step fails.
pub fn encrypt_message(aes_key: &[u8; 32], app_id: &str, xml: &str) -> Result<String> {
    let xml_bytes = xml.as_bytes();
    let app_id_bytes = app_id.as_bytes();
    // Reject oversize input instead of silently truncating the length prefix.
    let msg_len = u32::try_from(xml_bytes.len()).map_err(|_| {
        WeChatError::EncryptionFailed("Message too long for 4-byte length prefix".to_string())
    })?;
    let random_bytes: [u8; 16] = rand::random();

    let mut plaintext = Vec::with_capacity(16 + 4 + xml_bytes.len() + app_id_bytes.len());
    plaintext.extend_from_slice(&random_bytes);
    plaintext.extend_from_slice(&msg_len.to_be_bytes());
    plaintext.extend_from_slice(xml_bytes);
    plaintext.extend_from_slice(app_id_bytes);

    // Padding is applied manually, so the cipher runs with NoPadding.
    let mut buf = pkcs7_pad(&plaintext, AES_BLOCK_SIZE);
    let buf_len = buf.len();
    let mut iv = [0u8; AES_BLOCK_SIZE];
    iv.copy_from_slice(&aes_key[..AES_BLOCK_SIZE]);
    let ciphertext = Aes256CbcEnc::new(aes_key.into(), &iv.into())
        .encrypt_padded_mut::<aes::cipher::block_padding::NoPadding>(&mut buf, buf_len)
        .map_err(|e| WeChatError::EncryptionFailed(format!("AES encrypt failed: {}", e)))?;
    Ok(BASE64.encode(ciphertext))
}
/// Applies PKCS#7 padding so the result length is a multiple of `block_size`.
///
/// Between 1 and `block_size` bytes are always appended. A whole extra block
/// is appended when `data` is already aligned. Each appended byte holds the
/// padding length.
///
/// # Panics
/// Panics if `block_size` is 0 or greater than 255. The padding length must
/// fit in a single byte; otherwise a 256-byte pad would be written as zeros.
fn pkcs7_pad(data: &[u8], block_size: usize) -> Vec<u8> {
    assert!(
        (1..=255).contains(&block_size),
        "PKCS#7 block size must be in 1..=255, got {block_size}"
    );
    let padding_len = block_size - data.len() % block_size;
    let mut padded = Vec::with_capacity(data.len() + padding_len);
    padded.extend_from_slice(data);
    // padding_len <= block_size <= 255, so the cast is lossless.
    padded.resize(data.len() + padding_len, padding_len as u8);
    padded
}
/// Strips PKCS#7 padding from decrypted message plaintext.
///
/// WeChat pads plaintext to a multiple of 32 bytes, not the 16-byte AES block.
/// Its official SDKs use `block_size = 32`, so a valid pad length is 1..=32.
/// Every padding byte must equal the pad length.
///
/// # Errors
/// Returns `DecryptionFailed` in any of these cases:
/// - `data` is empty;
/// - the pad length is 0, exceeds 32, or exceeds `data.len()`;
/// - any padding byte differs from the pad length.
fn pkcs7_unpad(data: &[u8]) -> Result<&[u8]> {
    // WeChat's PKCS#7 block size, larger than the AES block size.
    const MAX_PADDING: usize = 32;

    let Some(&last) = data.last() else {
        return Err(WeChatError::DecryptionFailed("Empty data".to_string()));
    };
    let padding_len = usize::from(last);
    if padding_len == 0 || padding_len > MAX_PADDING || padding_len > data.len() {
        return Err(WeChatError::DecryptionFailed("Invalid padding".to_string()));
    }
    let (body, padding) = data.split_at(data.len() - padding_len);
    if padding.iter().any(|&byte| usize::from(byte) != padding_len) {
        return Err(WeChatError::DecryptionFailed("Invalid padding".to_string()));
    }
    Ok(body)
}
/// Generates a random 16-character nonce drawn from `[a-zA-Z0-9]`.
pub fn generate_nonce() -> String {
    const ALPHABET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
    const NONCE_LEN: usize = 16;

    let mut rng = rand::thread_rng();
    let mut nonce = String::with_capacity(NONCE_LEN);
    for _ in 0..NONCE_LEN {
        let idx = rng.gen_range(0..ALPHABET.len());
        nonce.push(char::from(ALPHABET[idx]));
    }
    nonce
}
/// Builds the XML envelope carrying an encrypted message and its signature fields.
///
/// `encrypt_msg`, `msg_signature` and `nonce` are wrapped in CDATA sections.
/// `timestamp` is inserted as plain element text. No escaping is performed,
/// so callers must not pass values containing `]]>`, and `timestamp` must not
/// contain XML markup.
pub fn generate_encrypted_xml(
    encrypt_msg: &str,
    msg_signature: &str,
    timestamp: &str,
    nonce: &str,
) -> String {
    format!(
        r#"<xml>
<Encrypt><![CDATA[{encrypt_msg}]]></Encrypt>
<MsgSignature><![CDATA[{msg_signature}]]></MsgSignature>
<TimeStamp>{timestamp}</TimeStamp>
<Nonce><![CDATA[{nonce}]]></Nonce>
</xml>"#
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample 43-character EncodingAESKey shared by the tests.
    const TEST_KEY: &str = "MDEyMzQ1Njc4OTAxMjM0NTY3ODkwMTIzNDU2Nzg5MEE";

    #[test]
    fn test_check_signature() {
        let token = "test_token";
        let timestamp = "1234567890";
        let nonce = "abc123";
        let signature = compute_signature(token, timestamp, nonce);
        assert!(check_signature(token, &signature, timestamp, nonce));
        assert!(!check_signature(token, "wrong_signature", timestamp, nonce));
        // Empty and length-mismatched signatures must be rejected.
        assert!(!check_signature(token, "", timestamp, nonce));
    }

    #[test]
    fn test_decode_aes_key() {
        let result = decode_aes_key(TEST_KEY);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 32);
        let short_key = "abc";
        assert!(decode_aes_key(short_key).is_err());
    }

    #[test]
    fn test_decode_aes_key_rejects_invalid_base64() {
        // Correct length, but '!' is outside the base64 alphabet.
        let key = "!".repeat(43);
        assert!(decode_aes_key(&key).is_err());
    }

    #[test]
    fn test_pkcs7_pad_unpad() {
        let data = b"hello";
        let padded = pkcs7_pad(data, 16);
        assert_eq!(padded.len(), 16);
        assert_eq!(padded[5..], [11u8; 11]);
        let unpadded = pkcs7_unpad(&padded).unwrap();
        assert_eq!(unpadded, data);
    }

    #[test]
    fn test_pkcs7_full_block_padding() {
        // Block-aligned input gains a whole block of padding.
        let padded = pkcs7_pad(&[0xAB; 16], 16);
        assert_eq!(padded.len(), 32);
        assert_eq!(padded[16..], [16u8; 16]);
        assert_eq!(pkcs7_unpad(&padded).unwrap(), &[0xAB; 16]);
    }

    #[test]
    fn test_pkcs7_unpad_rejects_invalid_padding() {
        assert!(pkcs7_unpad(&[]).is_err());
        // A pad length of zero is never valid.
        assert!(pkcs7_unpad(&[1, 2, 3, 0]).is_err());
        // Claims 3 bytes of padding, but they are [2, 3, 3].
        assert!(pkcs7_unpad(&[1, 2, 3, 3]).is_err());
        // Pad length longer than the data itself.
        assert!(pkcs7_unpad(&[5, 5]).is_err());
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip() {
        let aes_key = decode_aes_key(TEST_KEY).unwrap();
        let app_id = "wx1234567890";
        let xml = "<xml><Content>Hello</Content></xml>";
        let encrypted = encrypt_message(&aes_key, app_id, xml).unwrap();
        let (decrypted_xml, decrypted_app_id) = decrypt_message(&aes_key, &encrypted).unwrap();
        assert_eq!(decrypted_xml, xml);
        assert_eq!(decrypted_app_id, app_id);
    }

    #[test]
    fn test_roundtrip_various_lengths() {
        // Exercise every padding length and the empty-message case.
        let aes_key = decode_aes_key(TEST_KEY).unwrap();
        for len in 0..48 {
            let xml = "x".repeat(len);
            let encrypted = encrypt_message(&aes_key, "wxapp", &xml).unwrap();
            let (out_xml, out_app_id) = decrypt_message(&aes_key, &encrypted).unwrap();
            assert_eq!(out_xml, xml);
            assert_eq!(out_app_id, "wxapp");
        }
    }

    #[test]
    fn test_decrypt_rejects_malformed_input() {
        let aes_key = decode_aes_key(TEST_KEY).unwrap();
        // Not base64 at all.
        assert!(decrypt_message(&aes_key, "not base64!").is_err());
        // Valid base64, but 3 bytes is not a multiple of the AES block size.
        assert!(decrypt_message(&aes_key, "AAAA").is_err());
        // A single 16-byte block can never hold the 20-byte header after unpadding.
        assert!(decrypt_message(&aes_key, "AAAAAAAAAAAAAAAAAAAAAA==").is_err());
    }

    #[test]
    fn test_msg_signature() {
        let token = "test_token";
        let timestamp = "1234567890";
        let nonce = "abc123";
        let encrypt_msg = "encrypted_content";
        let sig = compute_msg_signature(token, timestamp, nonce, encrypt_msg);
        assert!(check_msg_signature(
            token,
            &sig,
            timestamp,
            nonce,
            encrypt_msg
        ));
        assert!(!check_msg_signature(
            token,
            "wrong",
            timestamp,
            nonce,
            encrypt_msg
        ));
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce();
        let nonce2 = generate_nonce();
        assert_eq!(nonce1.len(), 16);
        assert_eq!(nonce2.len(), 16);
        assert!(nonce1.chars().all(|c| c.is_ascii_alphanumeric()));
        assert_ne!(nonce1, nonce2);
    }
}