#![allow(dead_code)]
use std::{
io::{Read, Result, Write},
sync::Arc,
};
use log::{debug, warn};
use mbedtls::{
Result as TlsResult,
error::codes,
rng::CtrDrbg,
ssl::{
CipherSuite::{
DhePskWithSm4128GcmSm3, EcdhePskWithSm4128GcmSm3, PskWithSm4128GcmSm3,
RsaPskWithSm4128GcmSm3,
},
Config, Context,
config::{Endpoint, Preset, Transport},
},
};
use virga::client::{ClientConfig, VirgeClient};
use crate::common::{
protocol::{CHUNK_SIZE, PacketHeader, PacketType},
psk::{generate_psk, get_psk_identity},
vsock_define::{get_vsock_cid, get_vsock_port},
};
/// Client side of a PSK-authenticated TLS channel carried over a vsock transport.
///
/// Normally created by [`CcClient::init`], which opens the transport and performs the
/// TLS handshake before returning.
pub struct CcClient {
    /// The established TLS session; every send/receive method of this type goes through it.
    pub ctx: Context<VirgeClient>,
}
// SAFETY: NOTE(review): the source gives no justification for these impls. They assert that
// `Context<VirgeClient>` may be moved to, and shared between, threads. Because `ctx` is `pub`,
// `&CcClient` hands `&Context` to every thread that holds it. mbedtls SSL contexts are generally
// not safe for concurrent use. Confirm that callers never use one client from two threads at
// once, or that `Context`/`VirgeClient` are actually thread-safe.
unsafe impl Send for CcClient {}
unsafe impl Sync for CcClient {}
impl CcClient {
pub fn init() -> TlsResult<Self> {
let _ = env_logger::try_init();
let entropy = Arc::new(mbedtls::rng::OsEntropy::new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let cipher_suites: Vec<i32> = vec![
EcdhePskWithSm4128GcmSm3.into(),
DhePskWithSm4128GcmSm3.into(),
RsaPskWithSm4128GcmSm3.into(),
PskWithSm4128GcmSm3.into(),
0,
];
let psk = generate_psk()?;
let psk_identify = get_psk_identity();
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
config.set_rng(rng);
config.set_ciphersuites(Arc::new(cipher_suites));
config.set_psk(&psk, psk_identify)?;
config.set_read_timeout(5000);
let mut ctx = Context::new(Arc::new(config));
let cid = get_vsock_cid();
let port = get_vsock_port();
let vconfig = ClientConfig::new(cid, port, CHUNK_SIZE as u32, false);
let mut client = VirgeClient::new(vconfig);
client.connect().map_err(|e| {
warn!("VirgeClient 连接失败:{e}");
mbedtls::Error::LowLevel(codes::NetConnectFailed)
})?;
ctx.establish(client, None).map_err(|e| {
warn!("与服务端握手失败:{e}");
e
})?;
let ciphersuite = ctx.ciphersuite();
debug!("与服务端握手成功,ciphersuite: {:4x?}", ciphersuite);
Ok(Self { ctx })
}
pub fn send_data_with_header(&mut self, packet_type: PacketType, data: &[u8]) -> Result<()> {
let header = PacketHeader {
data_type: u64::from(packet_type),
data_size: data.len() as u64,
};
debug!("客户端:发送协议头:{:x?}", header.as_bytes());
self.ctx.write_all(header.as_bytes()).map_err(|e| {
warn!("客户端:发送协议头失败:{e}");
e
})?;
self.write_chunks(data)
}
pub fn send_data(&mut self, data: &[u8]) -> Result<()> {
self.write_chunks(data)
}
fn write_chunks(&mut self, data: &[u8]) -> Result<()> {
let total = data.len();
let chunk_size = CHUNK_SIZE as usize;
let mut tmp = [0u8; 4];
if total <= chunk_size {
debug!("客户端:发送数据:{data:x?}");
self.ctx.write_all(data)?;
self.ctx.read_exact(&mut tmp).map_err(|e| {
warn!("客户端:接收临时数据失败:{e}");
e
})?;
} else {
for chunk in data.chunks(chunk_size) {
self.ctx.write_all(chunk)?;
self.ctx.read_exact(&mut tmp).map_err(|e| {
warn!("客户端:接收临时数据失败:{e}");
e
})?;
}
}
Ok(())
}
pub fn recv_data(&mut self, data: &mut [u8]) -> Result<()> {
let data_len = data.len();
if data_len <= CHUNK_SIZE as usize {
self.ctx.read_exact(data)?;
} else {
let mut offset: usize = 0;
while offset < data_len {
let chunk_size = (data_len - offset).min(CHUNK_SIZE as usize);
self.ctx
.read_exact(&mut data[offset..offset + chunk_size])?;
offset += chunk_size;
}
}
Ok(())
}
}
/// C-ABI probe: returns `1` if a TLS session with the server can be established, else `0`.
///
/// A successful probe session is closed immediately. The failure reason is not passed to the
/// caller; `CcClient::init` logs connect and handshake failures itself.
#[unsafe(no_mangle)]
pub extern "C" fn cc_check_enable() -> i32 {
    let Ok(mut client) = CcClient::init() else {
        return 0;
    };
    client.ctx.close();
    1
}
#[cfg(test)]
mod cc_client_tests {
    use super::*;
    use std::io::{Error, ErrorKind};

    /// In-memory stand-in for a vsock stream.
    ///
    /// It serves `read_data` sequentially and records everything written. Every call fails
    /// while `should_fail` is set.
    struct MockVsockStream {
        read_data: Vec<u8>,
        write_data: Vec<u8>,
        read_pos: usize,
        should_fail: bool,
    }

    impl MockVsockStream {
        /// A stream preloaded with four zero bytes to read and no recorded writes.
        fn new() -> Self {
            Self {
                read_data: vec![0u8; 4],
                write_data: Vec::new(),
                read_pos: 0,
                should_fail: false,
            }
        }
    }

    impl Read for MockVsockStream {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            if self.should_fail {
                return Err(Error::new(ErrorKind::ConnectionReset, "模拟连接错误"));
            }
            // Serve as much of the unread tail as fits; an exhausted stream yields Ok(0).
            let pending = &self.read_data[self.read_pos..];
            let count = pending.len().min(buf.len());
            buf[..count].copy_from_slice(&pending[..count]);
            self.read_pos += count;
            Ok(count)
        }
    }

    impl Write for MockVsockStream {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            if self.should_fail {
                Err(Error::new(ErrorKind::ConnectionReset, "模拟写入错误"))
            } else {
                self.write_data.extend_from_slice(buf);
                Ok(buf.len())
            }
        }

        fn flush(&mut self) -> Result<()> {
            Ok(())
        }
    }

    /// Thin wrapper that forwards all I/O to its inner [`MockVsockStream`].
    struct MockContext {
        stream: MockVsockStream,
    }

    impl MockContext {
        fn new(stream: MockVsockStream) -> Self {
            MockContext { stream }
        }
    }

    impl Read for MockContext {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            Read::read(&mut self.stream, buf)
        }
    }

    impl Write for MockContext {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            Write::write(&mut self.stream, buf)
        }

        fn flush(&mut self) -> Result<()> {
            Write::flush(&mut self.stream)
        }
    }

    #[test]
    fn test_packet_header_serialization() {
        let header = PacketHeader {
            data_type: 1,
            data_size: 1024,
        };
        // Two u64 fields serialize to 16 bytes.
        assert_eq!(header.as_bytes().len(), 16);
    }

    #[test]
    fn test_write_chunks_small_data() {
        let _mock_ctx = MockContext::new(MockVsockStream::new());
        let payload = vec![1, 2, 3, 4, 5];
        let pieces: Vec<&[u8]> = payload.chunks(CHUNK_SIZE as usize).collect();
        assert_eq!(pieces.len(), 1);
        assert_eq!(pieces[0].len(), payload.len());
    }

    #[test]
    fn test_write_chunks_large_data() {
        let chunk_size = CHUNK_SIZE as usize;
        let total = (CHUNK_SIZE * 3) as usize;
        let payload: Vec<u8> = (0..total).map(|i| (i % 256) as u8).collect();
        let lengths: Vec<usize> = payload.chunks(chunk_size).map(<[u8]>::len).collect();
        assert_eq!(
            lengths,
            vec![chunk_size, chunk_size, total - 2 * chunk_size]
        );
    }

    #[test]
    fn test_recv_data_logic() {
        let limit = CHUNK_SIZE as usize;
        assert!((CHUNK_SIZE / 2) as usize <= limit);
        assert!((CHUNK_SIZE * 2) as usize > limit);
    }

    #[test]
    fn test_cc_check_enable_logic() {
        let (success_result, fail_result) = (1, 0);
        assert_ne!(success_result, fail_result);
    }

    #[test]
    fn test_packet_type_conversion() {
        let open = u64::from(PacketType::OpenSession);
        let invoke = u64::from(PacketType::InvokeCommand);
        assert_ne!(open, invoke);
        for (raw, expected) in [
            (1u64, PacketType::OpenSession),
            (3, PacketType::InvokeCommand),
            (99, PacketType::Unknown),
        ] {
            assert_eq!(PacketType::from(raw), expected);
        }
    }

    #[test]
    fn test_error_handling() {
        let io_error = Error::new(ErrorKind::ConnectionReset, "测试连接错误");
        assert!(io_error.to_string().contains("测试连接错误"));
        let tls_error = mbedtls::Error::LowLevel(codes::NetConnectFailed);
        let mbedtls::Error::LowLevel(code) = tls_error else {
            panic!("Expected LowLevel error");
        };
        assert_eq!(code, codes::NetConnectFailed);
    }

    #[test]
    fn test_chunk_boundary_cases() {
        let limit = CHUNK_SIZE as usize;
        for at_or_below in [CHUNK_SIZE as usize, (CHUNK_SIZE - 1) as usize] {
            assert!(at_or_below <= limit);
        }
        assert!((CHUNK_SIZE + 1) as usize > limit);
    }

    #[test]
    fn test_all_packet_types() {
        let cases = [
            (0u64, PacketType::Unknown),
            (1, PacketType::OpenSession),
            (2, PacketType::CloseSession),
            (3, PacketType::InvokeCommand),
            (4, PacketType::RequestCancellation),
        ];
        for (raw, variant) in cases {
            assert_eq!(PacketType::from(raw), variant);
            assert_eq!(u64::from(variant), raw);
        }
    }

    #[test]
    fn test_packet_type_conversion_consistency() {
        for value in 0..=4u64 {
            let converted_value = u64::from(PacketType::from(value));
            assert_eq!(
                value, converted_value,
                "双向转换不一致: 原始值={}, 转换后值={}",
                value, converted_value
            );
        }
        let out_of_range_value = 99;
        let packet_type = PacketType::from(out_of_range_value);
        assert_eq!(
            packet_type,
            PacketType::Unknown,
            "超出范围的值 {} 应该转换为 PacketType::Unknown",
            out_of_range_value
        );
        assert_eq!(
            u64::from(packet_type),
            0,
            "PacketType::Unknown 应该转换为 0"
        );
    }
}