rust-libteec 0.4.6

Rust implementation of TEE Client API for secure communication with Trusted Applications.
Documentation
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2025-2026 KylinSoft Co., Ltd. <https://www.kylinos.cn/>
// See LICENSES for license details.

//! 机密通信客户端实现,基于 TLS 和 VSOCK 实现与 TEE OS 的安全通信通道。

use std::{
    io::{Read, Result, Write},
    sync::Arc,
};

use log::{debug, warn};
use mbedtls::{
    Result as TlsResult,
    error::codes,
    rng::CtrDrbg,
    ssl::{
        CipherSuite::EcdhePskWithSm4128GcmSm3,
        Config, Context, Version,
        config::{Endpoint, Preset, Transport},
    },
};
use virga::client::{ClientConfig, VirgeClient};
use zeroize::Zeroize;

use crate::cc_client::{
    psk::{generate_psk, get_psk_identity},
    vsock_define::{get_vsock_cid, get_vsock_port},
};
use teec_protocol::{CHUNK_SIZE, PacketHeader, PacketType};

pub(crate) struct CcClient {
    pub ctx: Context<VirgeClient>,
}

/// SAFETY: CcClient 可以安全地在线程间发送 (Send)
///
/// 理由:
/// 1. `Context<VirgeClient>` 内部使用 Arc 管理共享状态
/// 2. Virga 的 ClientConfig 和连接状态都有内部锁保护
/// 3. 所有对 ctx 的可变访问都通过 TEEC_InvokeCommand 等函数序列化
/// 4. 没有线程不安全的裸指针或可变静态变量
///
/// 注意:虽然 CcClient 是 Send,但并发调用时需要外部同步(由 teec.rs 中的 Mutex 保证)
unsafe impl Send for CcClient {}

/// SAFETY: CcClient 可以安全地在线程间共享引用 (Sync)
///
/// 理由:
/// 1. 同 Send 的理由,内部状态通过 Arc 和锁保护
/// 2. mbedtls 的 Context 设计为线程安全
/// 3. VirgeClient 的连接操作有内部同步机制
/// 4. 实际使用中,CcClient 被包裹在 Arc<Mutex<CcClient>> 中,提供额外的同步保障
///
/// 注意:Sync 仅表示可以共享引用,实际的可变访问仍需要 Mutex
unsafe impl Sync for CcClient {}

impl CcClient {
    /// 初始化客户端并建立 TLS 连接
    /// 返回建立连接的客户端实例,失败返回TLS错误
    pub fn init() -> TlsResult<Self> {
        // 尝试初始化日志系统,如果已经初始化则忽略错误
        // 使用 RUST_LOG 环境变量控制级别,如 RUST_LOG=debug
        let _ = env_logger::try_init();
        let entropy = Arc::new(mbedtls::rng::OsEntropy::new());
        let rng = Arc::new(CtrDrbg::new(entropy, None)?);
        let cipher_suites: Vec<i32> = vec![EcdhePskWithSm4128GcmSm3.into(), 0];
        let mut psk = generate_psk()?;
        let psk_identify = get_psk_identity();
        let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);

        config.set_min_version(Version::Tls1_2)?;
        config.set_max_version(Version::Tls1_2)?;
        config.set_rng(rng);
        config.set_ciphersuites(Arc::new(cipher_suites));
        config.set_psk(&psk, psk_identify)?;
        config.set_read_timeout(5000); // 5 秒读取超时

        // 敏感数据使用后立即清零
        psk.zeroize();

        let mut ctx = Context::new(Arc::new(config));

        // 从环境变量获取 VSOCK 配置,支持自定义 CID 和 Port
        let cid = get_vsock_cid();
        let port = get_vsock_port();
        let vconfig = ClientConfig::new(cid, port, CHUNK_SIZE as u32, false);

        let mut client = VirgeClient::new(vconfig);

        client.connect().map_err(|e| {
            warn!("VirgeClient 连接失败:{e}");
            mbedtls::Error::LowLevel(codes::NetConnectFailed)
        })?;

        // 进行握手
        ctx.establish(client, None).map_err(|e| {
            warn!("与服务端握手失败:{e}");
            e
        })?;

        let ciphersuite = ctx.ciphersuite();
        debug!("与服务端握手成功,ciphersuite: {:4x?}", ciphersuite);

        Ok(Self { ctx })
    }

    /// 发送带协议头的数据包
    /// packet_type: 数据包类型,用于服务端解析处理
    /// data: 要发送的原始数据
    pub fn send_data_with_header(&mut self, packet_type: PacketType, data: &[u8]) -> Result<()> {
        let header = PacketHeader {
            data_type: u64::from(packet_type),
            data_size: data.len() as u64,
        };

        debug!("客户端:发送协议头:{:x?}", header.as_bytes());

        // 先发送协议头
        self.ctx.write_all(header.as_bytes()).map_err(|e| {
            warn!("客户端:发送协议头失败:{e}");
            e
        })?;

        // 再发送数据体
        self.send_data(data)
    }

    /// 发送无协议头的原始数据(用于连续数据流)
    pub fn send_data(&mut self, data: &[u8]) -> Result<()> {
        self.ctx.write_all(data)?;
        debug!("客户端:发送数据,大小:{}", data.len());
        Ok(())
    }

    /// 从服务器接收数据,支持分块接收
    pub fn recv_data(&mut self, data: &mut [u8]) -> Result<()> {
        self.ctx.read_exact(data)?;
        debug!("客户端:接收数据,实际大小:{}", data.len());
        Ok(())
    }
}

/// C 接口:检查机密通信功能是否可用
/// 返回1表示可用,0 表示不可用
#[unsafe(no_mangle)]
pub extern "C" fn cc_check_enable() -> i32 {
    let mut ctx = CcClient::init();
    match &mut ctx {
        Ok(ctx) => {
            ctx.ctx.close();
            1
        }
        Err(_) => 0,
    }
}

#[cfg(test)]
mod cc_client_tests {
    use super::*;
    use std::io::{Error, ErrorKind};
    use teec_protocol::{CHUNK_SIZE, PacketType};

    // 测试用模拟 vsock 流
    struct MockVsockStream {
        read_data: Vec<u8>,
        write_data: Vec<u8>,
        read_pos: usize,
        should_fail: bool,
    }

    impl MockVsockStream {
        fn new() -> Self {
            Self {
                read_data: vec![0u8; 4], // 默认返回 4 字节的临时数据
                write_data: Vec::new(),
                read_pos: 0,
                should_fail: false,
            }
        }
    }

    impl Read for MockVsockStream {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            if self.should_fail {
                return Err(Error::new(ErrorKind::ConnectionReset, "模拟连接错误"));
            }

            let remaining = self.read_data.len() - self.read_pos;
            if remaining == 0 {
                return Ok(0);
            }

            let to_read = buf.len().min(remaining);
            buf[..to_read].copy_from_slice(&self.read_data[self.read_pos..self.read_pos + to_read]);
            self.read_pos += to_read;
            Ok(to_read)
        }
    }

    impl Write for MockVsockStream {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            if self.should_fail {
                return Err(Error::new(ErrorKind::ConnectionReset, "模拟写入错误"));
            }

            self.write_data.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<()> {
            Ok(())
        }
    }

    // 模拟 TLS 上下文,用于测试业务逻辑
    struct MockContext {
        stream: MockVsockStream,
    }

    impl MockContext {
        fn new(stream: MockVsockStream) -> Self {
            Self { stream }
        }
    }

    impl Read for MockContext {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            self.stream.read(buf)
        }
    }

    impl Write for MockContext {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            self.stream.write(buf)
        }

        fn flush(&mut self) -> Result<()> {
            self.stream.flush()
        }
    }

    #[test]
    fn test_packet_header_serialization() {
        let header = PacketHeader {
            data_type: 1,
            data_size: 1024,
        };

        let bytes = header.as_bytes();
        assert_eq!(bytes.len(), 16); // u64 + u64 = 16 字节
    }

    #[test]
    fn test_write_chunks_small_data() {
        let mock_stream = MockVsockStream::new();
        let _mock_ctx = MockContext::new(mock_stream);

        let small_data = [1, 2, 3, 4, 5];

        // 测试 chunk 逻辑
        let chunk_size = CHUNK_SIZE as usize;
        let chunks: Vec<&[u8]> = small_data.chunks(chunk_size).collect();

        assert_eq!(chunks.len(), 1); // 小数据应该只有 1 个 chunk
        assert_eq!(chunks[0].len(), small_data.len());
    }

    #[test]
    fn test_write_chunks_large_data() {
        let large_data_size = (CHUNK_SIZE * 3) as usize;
        let large_data: Vec<u8> = (0..large_data_size).map(|i| (i % 256) as u8).collect();

        let chunk_size = CHUNK_SIZE as usize;
        let chunks: Vec<&[u8]> = large_data.chunks(chunk_size).collect();

        assert_eq!(chunks.len(), 3); // 大数据应该分成 3 个 chunk
        assert_eq!(chunks[0].len(), chunk_size);
        assert_eq!(chunks[1].len(), chunk_size);
        assert_eq!(chunks[2].len(), large_data_size - 2 * chunk_size);
    }

    #[test]
    fn test_recv_data_logic() {
        // 测试接收逻辑的分支
        let small_buffer_size = (CHUNK_SIZE / 2) as usize;
        let large_buffer_size = (CHUNK_SIZE * 2) as usize;

        // 测试小数据接收路径
        assert!(small_buffer_size <= CHUNK_SIZE as usize);

        // 测试大数据接收路径
        assert!(large_buffer_size > CHUNK_SIZE as usize);
    }

    #[test]
    fn test_cc_check_enable_logic() {
        // 测试返回值的逻辑
        let success_result = 1;
        let fail_result = 0;

        assert_ne!(success_result, fail_result);
    }

    #[test]
    fn test_packet_type_conversion() {
        // 使用实际的PacketType变体进行测试
        let open_session_type = PacketType::OpenSession;
        let invoke_command_type = PacketType::InvokeCommand;

        assert_ne!(u64::from(open_session_type), u64::from(invoke_command_type));

        // 测试从u64转换
        assert_eq!(PacketType::from(1), PacketType::OpenSession);
        assert_eq!(PacketType::from(3), PacketType::InvokeCommand);
        assert_eq!(PacketType::from(99), PacketType::Unknown); // 测试未知类型
    }

    #[test]
    fn test_error_handling() {
        // 测试错误处理路径
        let error = Error::new(ErrorKind::ConnectionReset, "测试连接错误");

        // 验证错误信息包含预期内容
        assert!(format!("{}", error).contains("测试连接错误"));

        // 测试 mbedtls 错误
        let mbedtls_error = mbedtls::Error::LowLevel(codes::NetConnectFailed);

        // 检查错误码
        match mbedtls_error {
            mbedtls::Error::LowLevel(code) => {
                // 验证错误码是 NetConnectFailed
                assert_eq!(code, codes::NetConnectFailed);
            }
            _ => panic!("Expected LowLevel error"),
        }
    }

    #[test]
    fn test_chunk_boundary_cases() {
        // 测试边界情况
        let exact_chunk_size = CHUNK_SIZE as usize;
        let one_less_than_chunk = (CHUNK_SIZE - 1) as usize;
        let one_more_than_chunk = (CHUNK_SIZE + 1) as usize;

        assert!(exact_chunk_size <= CHUNK_SIZE as usize); // 应该走小数据路径
        assert!(one_less_than_chunk <= CHUNK_SIZE as usize);
        assert!(one_more_than_chunk > CHUNK_SIZE as usize);
    }

    #[test]
    fn test_all_packet_types() {
        assert_eq!(u64::from(PacketType::Unknown), 0);
        assert_eq!(PacketType::from(0), PacketType::Unknown);

        assert_eq!(u64::from(PacketType::OpenSession), 1);
        assert_eq!(PacketType::from(1), PacketType::OpenSession);

        assert_eq!(u64::from(PacketType::CloseSession), 2);
        assert_eq!(PacketType::from(2), PacketType::CloseSession);

        assert_eq!(u64::from(PacketType::InvokeCommand), 3);
        assert_eq!(PacketType::from(3), PacketType::InvokeCommand);

        assert_eq!(u64::from(PacketType::RequestCancellation), 4);
        assert_eq!(PacketType::from(4), PacketType::RequestCancellation);
    }

    #[test]
    fn test_packet_type_conversion_consistency() {
        // 测试 PacketType 转换的一致性
        // 这个测试验证双向转换的一致性
        for value in 0..=4 {
            let packet_type = PacketType::from(value);
            let converted_value = u64::from(packet_type);
            assert_eq!(
                value, converted_value,
                "双向转换不一致: 原始值={}, 转换后值={}",
                value, converted_value
            );
        }

        // 测试超出范围的值的转换
        let out_of_range_value = 99;
        let packet_type = PacketType::from(out_of_range_value);
        assert_eq!(
            packet_type,
            PacketType::Unknown,
            "超出范围的值 {} 应该转换为 PacketType::Unknown",
            out_of_range_value
        );
        assert_eq!(
            u64::from(packet_type),
            0,
            "PacketType::Unknown 应该转换为 0"
        );
    }
}