use alloc::boxed::Box;
use alloc::vec::Vec;
use rustls::crypto::cipher::{
make_tls13_aad, AeadKey, EncodedMessage, InboundOpaque, Iv, MessageDecrypter, MessageEncrypter,
Nonce, OutboundOpaque, OutboundPlain, Tls13AeadAlgorithm, UnsupportedOperationError, NONCE_LEN,
};
use rustls::crypto::tls13::HkdfUsingHmac;
use rustls::crypto::CipherSuite;
use rustls::enums::{ContentType, ProtocolVersion};
use rustls::error::Error;
use rustls::version::TLS13_VERSION;
use rustls::{CipherSuiteCommon, ConnectionTrafficSecrets, Tls13CipherSuite};
use crate::sm4::{sm4_decrypt_ccm, sm4_decrypt_gcm, sm4_encrypt_ccm, sm4_encrypt_gcm};
/// HKDF built from HMAC-SM3; used as the `hkdf_provider` of both SM4 TLS 1.3 suites in this file.
pub(crate) static HKDF_SM3: HkdfUsingHmac<'static> = HkdfUsingHmac(&super::hmac::HMAC_SM3);
/// TLS 1.3 cipher suite `TLS_SM4_GCM_SM3`: SM4-GCM record protection, SM3 as the
/// suite hash, and HMAC-SM3-based HKDF for the key schedule.
pub static TLS13_SM4_GCM_SM3: &Tls13CipherSuite = &Tls13CipherSuite {
    protocol_version: TLS13_VERSION,
    common: CipherSuiteCommon {
        suite: CipherSuite::TLS13_SM4_GCM_SM3,
        hash_provider: &super::hash::SM3,
        // Number of records that may be sealed under a single traffic key.
        confidentiality_limit: 1 << 24,
    },
    hkdf_provider: &HKDF_SM3,
    aead_alg: &Sm4GcmAead,
    // This suite is not offered for QUIC.
    quic: None,
};
/// TLS 1.3 cipher suite `TLS_SM4_CCM_SM3`: SM4-CCM record protection, SM3 as the
/// suite hash, and HMAC-SM3-based HKDF for the key schedule.
pub static TLS13_SM4_CCM_SM3: &Tls13CipherSuite = &Tls13CipherSuite {
    protocol_version: TLS13_VERSION,
    common: CipherSuiteCommon {
        suite: CipherSuite::TLS13_SM4_CCM_SM3,
        hash_provider: &super::hash::SM3,
        // Number of records that may be sealed under a single traffic key.
        confidentiality_limit: 1 << 24,
    },
    hkdf_provider: &HKDF_SM3,
    aead_alg: &Sm4CcmAead,
    // This suite is not offered for QUIC.
    quic: None,
};
/// `Tls13AeadAlgorithm` implementation for SM4 in GCM mode.
struct Sm4GcmAead;

impl Tls13AeadAlgorithm for Sm4GcmAead {
    /// Builds a record sealer from the derived traffic key and IV.
    fn encrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageEncrypter> {
        let sm4_key = aead_key_to_16(&key);
        let sealer = Sm4GcmEncrypter { key: sm4_key, iv };
        Box::new(sealer)
    }

    /// Builds a record opener from the derived traffic key and IV.
    fn decrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageDecrypter> {
        let sm4_key = aead_key_to_16(&key);
        let opener = Sm4GcmDecrypter { key: sm4_key, iv };
        Box::new(opener)
    }

    /// SM4 takes a 128-bit (16-byte) key.
    fn key_len(&self) -> usize {
        16
    }

    /// Exposes the raw traffic key and IV as the `Sm4Gcm` secrets variant.
    fn extract_keys(
        &self,
        key: AeadKey,
        iv: Iv,
    ) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError> {
        let secrets = ConnectionTrafficSecrets::Sm4Gcm { key, iv };
        Ok(secrets)
    }
}
/// Seals outbound TLS 1.3 records with SM4-GCM and a 16-byte tag.
struct Sm4GcmEncrypter {
    /// Raw SM4 key copied out of the `AeadKey` given to `Sm4GcmAead::encrypter`.
    /// NOTE(review): nothing wipes this copy when the encrypter is dropped —
    /// consider zeroizing it.
    key: [u8; 16],
    /// Traffic IV; combined with the record sequence number by `Nonce::new`.
    iv: Iv,
}

impl MessageEncrypter for Sm4GcmEncrypter {
    /// Encrypts one record.
    ///
    /// The inner plaintext is the payload followed by the real content-type byte (no
    /// padding is added). It is sealed under the per-record nonce, and the 16-byte
    /// GCM tag is appended. The outer record is always labelled `ApplicationData` /
    /// TLS 1.2, as TLS 1.3 requires.
    ///
    /// # Errors
    ///
    /// Propagates any error from building the per-record nonce.
    fn encrypt(
        &mut self,
        msg: EncodedMessage<OutboundPlain<'_>>,
        seq: u64,
    ) -> Result<EncodedMessage<OutboundOpaque>, Error> {
        let total_len = self.encrypted_payload_len(msg.payload.len());
        let nonce = Nonce::new(&self.iv, seq).to_array::<NONCE_LEN>()?;
        // The AAD is the outer record header, which carries the *ciphertext* length.
        let aad = make_tls13_aad(total_len);

        // Flatten the (possibly chunked) payload and append the content type in a
        // single buffer, then hand that buffer straight to the cipher. Previously it
        // was copied into a second, separately allocated Vec first.
        let mut inner = OutboundOpaque::with_capacity(msg.payload.len() + 1);
        inner.extend_from_chunks(&msg.payload);
        inner.extend_from_slice(&msg.typ.to_array());
        let plaintext: &[u8] = inner.as_ref();

        let (ciphertext, tag) = sm4_encrypt_gcm(&self.key, &nonce, &aad, plaintext);

        // Wire format: ciphertext || tag, exactly `total_len` bytes.
        let mut out = OutboundOpaque::with_capacity(total_len);
        out.extend_from_slice(&ciphertext);
        out.extend_from_slice(&tag);
        Ok(EncodedMessage {
            typ: ContentType::ApplicationData,
            version: ProtocolVersion::TLSv1_2,
            payload: out,
        })
    }

    /// Adds one byte for the inner content type and 16 bytes for the GCM tag.
    fn encrypted_payload_len(&self, payload_len: usize) -> usize {
        payload_len + 1 + 16
    }
}
struct Sm4GcmDecrypter {
key: [u8; 16],
iv: Iv,
}
impl MessageDecrypter for Sm4GcmDecrypter {
fn decrypt<'a>(
&mut self,
mut msg: EncodedMessage<InboundOpaque<'a>>,
seq: u64,
) -> Result<EncodedMessage<&'a [u8]>, Error> {
let payload = &mut msg.payload;
if payload.len() < 16 {
return Err(Error::DecryptError);
}
let nonce = Nonce::new(&self.iv, seq).to_array::<NONCE_LEN>()?;
let aad = make_tls13_aad(payload.len());
let ct_len = payload.len() - 16;
let tag: [u8; 16] = payload[ct_len..]
.try_into()
.map_err(|_| Error::DecryptError)?;
let plaintext = sm4_decrypt_gcm(&self.key, &nonce, &aad, &payload[..ct_len], &tag)
.map_err(|_| Error::DecryptError)?;
let plain_len = plaintext.len();
payload[..plain_len].copy_from_slice(&plaintext);
payload.truncate(plain_len);
msg.into_tls13_unpadded_message()
}
}
/// `Tls13AeadAlgorithm` implementation for SM4 in CCM mode.
struct Sm4CcmAead;

impl Tls13AeadAlgorithm for Sm4CcmAead {
    /// Builds a record sealer from the derived traffic key and IV.
    fn encrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageEncrypter> {
        let sm4_key = aead_key_to_16(&key);
        let sealer = Sm4CcmEncrypter { key: sm4_key, iv };
        Box::new(sealer)
    }

    /// Builds a record opener from the derived traffic key and IV.
    fn decrypter(&self, key: AeadKey, iv: Iv) -> Box<dyn MessageDecrypter> {
        let sm4_key = aead_key_to_16(&key);
        let opener = Sm4CcmDecrypter { key: sm4_key, iv };
        Box::new(opener)
    }

    /// SM4 takes a 128-bit (16-byte) key.
    fn key_len(&self) -> usize {
        16
    }

    /// Exposes the raw traffic key and IV as the `Sm4Ccm` secrets variant.
    fn extract_keys(
        &self,
        key: AeadKey,
        iv: Iv,
    ) -> Result<ConnectionTrafficSecrets, UnsupportedOperationError> {
        let secrets = ConnectionTrafficSecrets::Sm4Ccm { key, iv };
        Ok(secrets)
    }
}
/// Seals outbound TLS 1.3 records with SM4-CCM and a 16-byte tag.
struct Sm4CcmEncrypter {
    /// Raw SM4 key copied out of the `AeadKey` given to `Sm4CcmAead::encrypter`.
    /// NOTE(review): nothing wipes this copy when the encrypter is dropped —
    /// consider zeroizing it.
    key: [u8; 16],
    /// Traffic IV; combined with the record sequence number by `Nonce::new`.
    iv: Iv,
}

impl MessageEncrypter for Sm4CcmEncrypter {
    /// Encrypts one record.
    ///
    /// The inner plaintext is the payload followed by the real content-type byte (no
    /// padding is added). It is sealed with SM4-CCM using a 16-byte tag. The outer
    /// record is always labelled `ApplicationData` / TLS 1.2, as TLS 1.3 requires.
    ///
    /// # Errors
    ///
    /// Returns `Error::EncryptError` if `sm4_encrypt_ccm` fails. Nonce-construction
    /// errors are propagated.
    fn encrypt(
        &mut self,
        msg: EncodedMessage<OutboundPlain<'_>>,
        seq: u64,
    ) -> Result<EncodedMessage<OutboundOpaque>, Error> {
        let total_len = self.encrypted_payload_len(msg.payload.len());
        let nonce = Nonce::new(&self.iv, seq).to_array::<NONCE_LEN>()?;
        // The AAD is the outer record header, which carries the *ciphertext* length.
        let aad = make_tls13_aad(total_len);

        // Flatten the (possibly chunked) payload and append the content type in a
        // single buffer, then hand that buffer straight to the cipher. Previously it
        // was copied into a second, separately allocated Vec first.
        let mut inner = OutboundOpaque::with_capacity(msg.payload.len() + 1);
        inner.extend_from_chunks(&msg.payload);
        inner.extend_from_slice(&msg.typ.to_array());
        let plaintext: &[u8] = inner.as_ref();

        // `sm4_encrypt_ccm` is expected to return ciphertext and tag as one buffer
        // (the decrypter passes the whole record back in the same form).
        let combined = sm4_encrypt_ccm(&self.key, &nonce, &aad, plaintext, 16)
            .map_err(|_| Error::EncryptError)?;

        let mut out = OutboundOpaque::with_capacity(combined.len());
        out.extend_from_slice(&combined);
        Ok(EncodedMessage {
            typ: ContentType::ApplicationData,
            version: ProtocolVersion::TLSv1_2,
            payload: out,
        })
    }

    /// Adds one byte for the inner content type and 16 bytes for the CCM tag.
    fn encrypted_payload_len(&self, payload_len: usize) -> usize {
        payload_len + 1 + 16
    }
}
struct Sm4CcmDecrypter {
key: [u8; 16],
iv: Iv,
}
impl MessageDecrypter for Sm4CcmDecrypter {
fn decrypt<'a>(
&mut self,
mut msg: EncodedMessage<InboundOpaque<'a>>,
seq: u64,
) -> Result<EncodedMessage<&'a [u8]>, Error> {
let payload = &mut msg.payload;
if payload.len() < 16 {
return Err(Error::DecryptError);
}
let nonce = Nonce::new(&self.iv, seq).to_array::<NONCE_LEN>()?;
let aad = make_tls13_aad(payload.len());
let plaintext = sm4_decrypt_ccm(&self.key, &nonce, &aad, &payload[..], 16)
.map_err(|_| Error::DecryptError)?;
let plain_len = plaintext.len();
payload[..plain_len].copy_from_slice(&plaintext);
payload.truncate(plain_len);
msg.into_tls13_unpadded_message()
}
}
/// Copies the SM4 key bytes out of a rustls `AeadKey` into a fixed-size array.
///
/// # Panics
///
/// Panics if the key is not exactly 16 bytes. Both SM4 AEADs in this module report
/// `key_len() == 16`, so a key of any other length indicates a wiring bug. Such a
/// key is rejected loudly rather than silently truncated to its first 16 bytes,
/// which would yield a wrong key.
fn aead_key_to_16(key: &AeadKey) -> [u8; 16] {
    let bytes: &[u8] = key.as_ref();
    <[u8; 16]>::try_from(bytes).expect("SM4 AEAD key must be exactly 16 bytes (key_len)")
}