use std::{
io::{Read, Write},
sync::{Arc, Once},
};
use log::{debug, warn};
use mbedtls::{
Result as TlsResult,
error::codes,
rng::CtrDrbg,
ssl::{
CipherSuite::EcdhePskWithSm4128GcmSm3,
Config, Context, Version,
config::{Endpoint, Preset, Transport},
},
};
use virga::client::{ClientConfig, VirgeClient};
use teec_protocol::{CHUNK_SIZE, PacketHeader, PacketType};
use xtee_psk::{
PskError, client_ecdh_negotiate, new_crypto_rng, virga_transport::VirgeClientTransport,
};
use crate::cc_client::{
tofu::verify_server_identity,
vsock_define::{get_vsock_cid, get_vsock_port},
};
/// Client-side session handle that owns an mbedtls TLS [`Context`] whose
/// underlying I/O type is [`VirgeClient`].
pub(crate) struct CcClient {
    /// TLS context used for every read and write performed by this client.
    ctx: Context<VirgeClient>,
}
// SAFETY: NOTE(review) — no invariant is stated in this file. This impl asserts
// that `Context<VirgeClient>` may be moved to another thread; presumably that is
// acceptable because the context is only reached through `&mut self` methods.
// TODO: confirm against the `VirgeClient` and mbedtls `Context` implementations.
unsafe impl Send for CcClient {}
/// Guards the logger set-up attempted in `CcClient::init` so that it runs at
/// most once per process, no matter how many clients are created.
static LOGGER_INIT: Once = Once::new();
impl CcClient {
/// Connects to the server over vsock and establishes a PSK-based TLS 1.2 session.
///
/// Steps, in order:
/// 1. Try to install `env_logger` (once per process; the outcome is ignored).
/// 2. Connect a [`VirgeClient`] to the CID/port returned by `get_vsock_cid` /
///    `get_vsock_port`.
/// 3. Run `client_ecdh_negotiate` over the connection to obtain the PSK, the
///    server's long-term point and the PSK identity.
/// 4. Check the long-term point with `verify_server_identity`.
/// 5. Perform the TLS handshake, pinned to TLS 1.2 and the
///    `EcdhePskWithSm4128GcmSm3` cipher suite.
///
/// # Errors
///
/// - `LowLevel(NetConnectFailed)` when the transport connection fails.
/// - `HighLevel(PkBadInputData)` when creating the crypto RNG, the ECDH
///   negotiation or the server identity check fails.
/// - Any mbedtls error produced while creating the DRBG, building the
///   configuration or during the handshake.
pub fn init() -> TlsResult<Self> {
    LOGGER_INIT.call_once(|| {
        // Best effort: the result of installing the logger is deliberately ignored.
        let _ = env_logger::try_init();
    });

    // The entropy source is only needed by the DRBG, so the `Arc` is moved into it.
    let entropy = Arc::new(mbedtls::rng::OsEntropy::new());
    let rng = CtrDrbg::new(entropy, None)?;
    let mut crypto_rng = new_crypto_rng().map_err(psk_to_mbedtls_error)?;

    let cid = get_vsock_cid();
    let port = get_vsock_port();
    debug!("[HOST] CcClient::init: vsock cid={cid} port={port}");
    let vconfig = ClientConfig::new(cid, port, CHUNK_SIZE as u32, false);
    let mut client = VirgeClient::new(vconfig);
    debug!("[HOST] CcClient::init: calling client.connect()");
    client.connect().map_err(|e| {
        warn!("VirgeClient 连接失败:{e}");
        mbedtls::Error::LowLevel(codes::NetConnectFailed)
    })?;
    debug!("[HOST] CcClient::init: connect done, starting TLS handshake");

    // The key negotiation runs on the raw connection first; the client is wrapped
    // for the negotiation and taken back out for the TLS layer afterwards.
    let mut transport = VirgeClientTransport(client);
    let (psk, long_term_point, psk_identity) =
        client_ecdh_negotiate(&mut transport, &mut crypto_rng).map_err(psk_to_mbedtls_error)?;
    let client = transport.0;
    verify_server_identity(&long_term_point).map_err(|e| {
        warn!("服务端身份校验失败:{e}");
        psk_to_mbedtls_error(e)
    })?;
    debug!("ECDH 密钥协商成功");

    // Single allowed suite followed by a trailing 0 entry.
    // NOTE(review): the 0 is presumably the list terminator mbedtls requires — confirm.
    let cipher_suites: Vec<i32> = vec![EcdhePskWithSm4128GcmSm3.into(), 0];
    let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
    config.set_min_version(Version::Tls1_2)?;
    config.set_max_version(Version::Tls1_2)?;
    config.set_ciphersuites(Arc::new(cipher_suites));
    // NOTE(review): presumably milliseconds, as in mbedtls' read-timeout setting — confirm.
    config.set_read_timeout(5000);
    config.set_rng(Arc::new(rng));
    config.set_psk(&*psk, psk_identity)?;

    let mut ctx = Context::new(Arc::new(config));
    ctx.establish(client, None)
        .inspect_err(|e| warn!("与服务端握手失败:{e}"))?;
    let ciphersuite = ctx.ciphersuite();
    debug!("[HOST] CcClient::init: 与服务端握手成功, ciphersuite={ciphersuite:4x?}");
    Ok(Self { ctx })
}
/// Sends a [`PacketHeader`] describing `data`, followed by `data` itself.
///
/// The header carries `packet_type` converted to `u64` and the payload length
/// in bytes.
///
/// # Errors
///
/// Returns the I/O error from the TLS stream when writing the header or the
/// payload fails. A header failure is additionally logged as a warning.
pub fn send_data_with_header(
    &mut self,
    packet_type: PacketType,
    data: &[u8],
) -> std::io::Result<()> {
    let header = PacketHeader {
        data_type: u64::from(packet_type),
        data_size: data.len() as u64,
    };
    // Take the byte view once and reuse it for both the log line and the write.
    let header_bytes = header.as_bytes();
    debug!(
        "[HOST] send_data_with_header: data_type={} data_len={} header={:x?}",
        header.data_type,
        data.len(),
        header_bytes
    );
    self.ctx
        .write_all(header_bytes)
        .inspect_err(|e| warn!("客户端:发送协议头失败:{e}"))?;
    self.send_data(data)
}
/// Writes all of `data` to the TLS stream.
///
/// On success the number of bytes sent is written to the debug log; on failure
/// the I/O error is returned unchanged.
pub fn send_data(&mut self, data: &[u8]) -> std::io::Result<()> {
    let sent = self.ctx.write_all(data);
    if sent.is_ok() {
        debug!("客户端:发送数据,大小:{}", data.len());
    }
    sent
}
/// Reads from the TLS stream until `data` is completely filled.
///
/// On success the buffer length is written to the debug log; on failure the
/// I/O error from `read_exact` is returned unchanged.
pub fn recv_data(&mut self, data: &mut [u8]) -> std::io::Result<()> {
    let received = self.ctx.read_exact(data);
    if received.is_ok() {
        debug!("[HOST] recv_data: len={}", data.len());
    }
    received
}
/// Closes the session by delegating to the TLS context's `close`.
pub fn close(&mut self) {
    self.ctx.close();
}
}
fn psk_to_mbedtls_error(e: PskError) -> mbedtls::Error {
debug!("PSK 协商错误:{e}");
mbedtls::Error::HighLevel(codes::PkBadInputData)
}
/// C-ABI probe that reports whether a client session can be established.
///
/// Returns `1` when `CcClient::init` succeeds (the session is closed again
/// immediately) and `0` when it fails or panics.
#[unsafe(no_mangle)]
pub extern "C" fn cc_check_enable() -> i32 {
    // A panic must not unwind out of an `extern "C"` function (the process would
    // be aborted instead), so it is caught here and reported as "not enabled".
    // The closure captures nothing, which makes it trivially unwind-safe.
    let probe = std::panic::catch_unwind(|| match CcClient::init() {
        Ok(mut client) => {
            client.close();
            1
        }
        Err(_) => 0,
    });
    probe.unwrap_or(0)
}
#[cfg(test)]
mod cc_client_tests {
use super::*;
use teec_protocol::PacketType;
/// The byte view of a `PacketHeader` must be exactly 16 bytes long.
#[test]
fn test_packet_header_serialization() {
    let serialized_len = PacketHeader {
        data_type: 1,
        data_size: 1024,
    }
    .as_bytes()
    .len();
    assert_eq!(serialized_len, 16);
}
/// Each listed `PacketType` variant converts to its numeric value and back.
#[test]
fn test_all_packet_types() {
    let cases: [(PacketType, u64); 5] = [
        (PacketType::Unknown, 0),
        (PacketType::OpenSession, 1),
        (PacketType::CloseSession, 2),
        (PacketType::InvokeCommand, 3),
        (PacketType::RequestCancellation, 4),
    ];
    for (variant, raw) in cases {
        // Compare by reference first; the `u64::from` conversion consumes `variant`.
        assert_eq!(PacketType::from(raw), variant);
        assert_eq!(u64::from(variant), raw);
    }
}
/// Round-trips the values 0 through 4 via `PacketType` and checks that an
/// out-of-range value falls back to `PacketType::Unknown` (numeric value 0).
#[test]
fn test_packet_type_conversion_consistency() {
    // value -> PacketType -> u64 must give back the starting value.
    for value in 0..=4 {
        let converted_value = u64::from(PacketType::from(value));
        assert_eq!(
            value, converted_value,
            "双向转换不一致: 原始值={}, 转换后值={}",
            value, converted_value
        );
    }

    // A value with no matching variant maps to `Unknown`, which maps to 0.
    let out_of_range_value = 99;
    let fallback = PacketType::from(out_of_range_value);
    assert_eq!(
        fallback,
        PacketType::Unknown,
        "超出范围的值 {} 应该转换为 PacketType::Unknown",
        out_of_range_value
    );
    assert_eq!(
        u64::from(fallback),
        0,
        "PacketType::Unknown 应该转换为 0"
    );
}
}