use std::{
io::{Read, Write},
sync::Arc,
};
use log::{debug, warn};
use mbedtls::{
Result as TlsResult,
error::codes,
rng::CtrDrbg,
ssl::{
CipherSuite::EcdhePskWithSm4128GcmSm3,
Config, Context, Version,
config::{Endpoint, Preset, Transport},
},
};
use virga::client::{ClientConfig, VirgeClient};
use zeroize::Zeroize;
use teec_protocol::{CHUNK_SIZE, PacketHeader, PacketType};
use xtee_psk::{
CryptoRng, PSK_LEN, PskError, derive_psk, ecdh_compute_shared, extract_ec_point,
generate_ecdh_keypair, get_psk_identity, new_crypto_rng, parse_ecdh_response,
verify_ecdh_signature,
};
use crate::cc_client::{
tofu::verify_server_identity,
vsock_define::{get_vsock_cid, get_vsock_port},
};
/// TLS client session to the server, carried over a `VirgeClient` connection whose
/// vsock CID/port come from `get_vsock_cid`/`get_vsock_port` (see `CcClient::init`).
pub(crate) struct CcClient {
    // mbedtls TLS context; it takes ownership of the `VirgeClient` passed to `establish`.
    ctx: Context<VirgeClient>,
}
// SAFETY: NOTE(review): this file gives no justification for this impl. Presumably it is
// needed because `Context<VirgeClient>` is not automatically `Send`. Soundness relies on
// the mbedtls context and `VirgeClient` holding no thread-affine state (thread-local
// handles, non-atomic shared refcounts) so the value may be moved to another thread.
// TODO confirm against the mbedtls and virga crate internals.
unsafe impl Send for CcClient {}
impl CcClient {
pub fn init() -> TlsResult<Self> {
let _ = env_logger::try_init();
let entropy = Arc::new(mbedtls::rng::OsEntropy::new());
let rng = CtrDrbg::new(entropy.clone(), None)?;
let mut crypto_rng = new_crypto_rng().map_err(psk_to_mbedtls_error)?;
let cid = get_vsock_cid();
let port = get_vsock_port();
let vconfig = ClientConfig::new(cid, port, CHUNK_SIZE as u32, false);
let mut client = VirgeClient::new(vconfig);
client.connect().map_err(|e| {
warn!("VirgeClient 连接失败:{e}");
mbedtls::Error::LowLevel(codes::NetConnectFailed)
})?;
let (mut psk, psk_identity) = ecdh_negotiate(&mut client, &mut crypto_rng)?;
debug!("ECDH 密钥协商成功");
let cipher_suites: Vec<i32> = vec![EcdhePskWithSm4128GcmSm3.into(), 0];
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
config.set_min_version(Version::Tls1_2)?;
config.set_max_version(Version::Tls1_2)?;
config.set_ciphersuites(Arc::new(cipher_suites));
config.set_read_timeout(5000); config.set_rng(Arc::new(rng));
config.set_psk(&psk, psk_identity)?;
psk.zeroize();
let mut ctx = Context::new(Arc::new(config));
ctx.establish(client, None).map_err(|e| {
warn!("与服务端握手失败:{e}");
e
})?;
debug!("与服务端握手成功,ciphersuite: {:4x?}", ctx.ciphersuite());
Ok(Self { ctx })
}
pub fn send_data_with_header(
&mut self,
packet_type: PacketType,
data: &[u8],
) -> std::io::Result<()> {
let header = PacketHeader {
data_type: u64::from(packet_type),
data_size: data.len() as u64,
};
debug!("客户端:发送协议头:{:x?}", header.as_bytes());
self.ctx.write_all(header.as_bytes()).map_err(|e| {
warn!("客户端:发送协议头失败:{e}");
e
})?;
self.send_data(data)
}
pub fn send_data(&mut self, data: &[u8]) -> std::io::Result<()> {
self.ctx.write_all(data)?;
debug!("客户端:发送数据,大小:{}", data.len());
Ok(())
}
pub fn recv_data(&mut self, data: &mut [u8]) -> std::io::Result<()> {
self.ctx.read_exact(data)?;
debug!("客户端:接收数据,实际大小:{}", data.len());
Ok(())
}
pub fn close(&mut self) {
self.ctx.close();
}
}
fn psk_to_mbedtls_error(e: PskError) -> mbedtls::Error {
debug!("PSK 协商错误:{e}");
mbedtls::Error::HighLevel(codes::PkBadInputData)
}
fn ecdh_negotiate(
client: &mut VirgeClient,
ecdh_rng: &mut CryptoRng,
) -> TlsResult<([u8; PSK_LEN], &'static str)> {
let ecdh_key = generate_ecdh_keypair(ecdh_rng).map_err(|e| {
warn!("生成 SM2 密钥对失败:{e}");
psk_to_mbedtls_error(e)
})?;
let client_point = extract_ec_point(&ecdh_key).map_err(psk_to_mbedtls_error)?;
client.send(client_point.clone()).map_err(|e| {
warn!("发送 ECDH 请求失败:{e}");
mbedtls::Error::LowLevel(codes::NetSendFailed)
})?;
let recv_data = client.recv().map_err(|e| {
warn!("接收 ECDH 响应失败:{e}");
mbedtls::Error::LowLevel(codes::NetRecvFailed)
})?;
let resp = parse_ecdh_response(&recv_data).map_err(|e| {
warn!("解析 ECDH 响应失败:{e}");
psk_to_mbedtls_error(e)
})?;
verify_server_identity(resp.long_term_point).map_err(|e| {
warn!("服务端身份校验失败:{e}");
psk_to_mbedtls_error(e)
})?;
verify_ecdh_signature(
resp.long_term_point,
&client_point,
resp.server_point,
resp.signature,
)
.map_err(|e| {
warn!("服务端 ECDH 签名验证失败:{e}");
psk_to_mbedtls_error(e)
})?;
debug!("服务端 ECDH 签名验证通过");
let shared = ecdh_compute_shared(&ecdh_key, resp.server_point).map_err(|e| {
warn!("ECDH 共享秘密计算失败:{e}");
psk_to_mbedtls_error(e)
})?;
drop(ecdh_key);
let psk_result = derive_psk(&shared, &client_point, resp.server_point);
let psk = psk_result.map_err(|e| {
warn!("PSK 派生失败:{e}");
psk_to_mbedtls_error(e)
})?;
Ok((psk, get_psk_identity()))
}
/// C-ABI probe: returns `1` if a full [`CcClient::init`] succeeds, otherwise `0`.
///
/// A successful init covers the vsock connection, ECDH PSK negotiation and the TLS
/// handshake. The probe connection is closed before returning.
///
/// Any panic raised during initialization is caught and reported as `0`, under the
/// default `panic = "unwind"` strategy. A panic unwinding out of an `extern "C"` function
/// would otherwise abort the host process.
#[unsafe(no_mangle)]
pub extern "C" fn cc_check_enable() -> i32 {
    let probe = std::panic::catch_unwind(|| match CcClient::init() {
        Ok(mut client) => {
            client.close();
            1
        }
        // `init` already logs the cause on its failure paths.
        Err(_) => 0,
    });
    probe.unwrap_or(0)
}
#[cfg(test)]
mod cc_client_tests {
    use super::*;

    /// A serialized header is 16 bytes; both of its fields are `u64`.
    #[test]
    fn test_packet_header_serialization() {
        let serialized_len = PacketHeader {
            data_type: 1,
            data_size: 1024,
        }
        .as_bytes()
        .len();
        assert_eq!(serialized_len, 16);
    }

    /// Every known wire value maps to its variant, and each variant maps back.
    #[test]
    fn test_all_packet_types() {
        let table = [
            (0u64, PacketType::Unknown),
            (1, PacketType::OpenSession),
            (2, PacketType::CloseSession),
            (3, PacketType::InvokeCommand),
            (4, PacketType::RequestCancellation),
        ];
        for (raw, variant) in table {
            assert_eq!(PacketType::from(raw), variant);
            assert_eq!(u64::from(variant), raw);
        }
    }

    /// In-range values round-trip. Out-of-range values fall back to `Unknown` (0).
    #[test]
    fn test_packet_type_conversion_consistency() {
        (0..=4u64).for_each(|value| {
            let converted_value = u64::from(PacketType::from(value));
            assert_eq!(
                value, converted_value,
                "双向转换不一致: 原始值={}, 转换后值={}",
                value, converted_value
            );
        });

        let out_of_range_value = 99;
        let fallback = PacketType::from(out_of_range_value);
        assert_eq!(
            fallback,
            PacketType::Unknown,
            "超出范围的值 {} 应该转换为 PacketType::Unknown",
            out_of_range_value
        );
        assert_eq!(u64::from(fallback), 0, "PacketType::Unknown 应该转换为 0");
    }
}