rust-libteec 0.5.0

Rust implementation of TEE Client API for secure communication with Trusted Applications.
Documentation
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2025-2026 KylinSoft Co., Ltd. <https://www.kylinos.cn/>
// See LICENSES for license details.

//! 机密通信客户端实现,基于 TLS 和 VSOCK 实现与 TEE OS 的安全通信通道。

use std::{
    io::{Read, Write},
    sync::Arc,
};

use log::{debug, warn};
use mbedtls::{
    Result as TlsResult,
    error::codes,
    rng::CtrDrbg,
    ssl::{
        CipherSuite::EcdhePskWithSm4128GcmSm3,
        Config, Context, Version,
        config::{Endpoint, Preset, Transport},
    },
};
use virga::client::{ClientConfig, VirgeClient};
use zeroize::Zeroize;

use teec_protocol::{CHUNK_SIZE, PacketHeader, PacketType};
use xtee_psk::{
    CryptoRng, PSK_LEN, PskError, derive_psk, ecdh_compute_shared, extract_ec_point,
    generate_ecdh_keypair, get_psk_identity, new_crypto_rng, parse_ecdh_response,
    verify_ecdh_signature,
};

use crate::cc_client::{
    tofu::verify_server_identity,
    vsock_define::{get_vsock_cid, get_vsock_port},
};

/// 机密通信客户端,封装基于 TLS + VSOCK 的安全通信通道。
///
/// 内部通过 `mbedtls::ssl::Context<VirgeClient>` 管理 TLS 连接,
/// 使用 ECDH 动态协商的 PSK 进行身份认证和加密通信。
pub(crate) struct CcClient {
    ctx: Context<VirgeClient>,
}

/// SAFETY: CcClient 可以安全地在线程间发送 (Send)
///
/// 理由:
/// 1. `Context<VirgeClient>` 内部使用 Arc 管理共享状态
/// 2. Virga 的 ClientConfig 和连接状态都有内部锁保护
/// 3. 所有对 ctx 的可变访问都通过 TEEC_InvokeCommand 等函数序列化
/// 4. 没有线程不安全的裸指针或可变静态变量
///
/// 注意:CcClient 不实现 Sync,因为 Context<VirgeClient> 的方法需要 &mut self,
/// 并发共享引用无意义。实际使用中 CcClient 被包裹在 Mutex<CcClient> 中,
/// Mutex<T>: Send 要求 T: Send,不需要 T: Sync。
unsafe impl Send for CcClient {}

impl CcClient {
    /// 初始化客户端并建立 TLS 连接
    ///
    /// 流程:
    /// 1. 建立 VSOCK 连接
    /// 2. 在 VSOCK 上通过 ECDH 协商动态 PSK(SM2P256R1 曲线)
    /// 3. 派生 PSK 后建立 TLS 握手
    ///
    /// 返回建立连接的客户端实例,失败返回TLS错误
    pub fn init() -> TlsResult<Self> {
        // 尝试初始化日志系统,如果已经初始化则忽略错误
        // 使用 RUST_LOG 环境变量控制级别,如 RUST_LOG=debug
        let _ = env_logger::try_init();
        let entropy = Arc::new(mbedtls::rng::OsEntropy::new());
        let rng = CtrDrbg::new(entropy.clone(), None)?;
        let mut crypto_rng = new_crypto_rng().map_err(psk_to_mbedtls_error)?;

        // 从环境变量获取 VSOCK 配置,支持自定义 CID 和 Port
        let cid = get_vsock_cid();
        let port = get_vsock_port();
        let vconfig = ClientConfig::new(cid, port, CHUNK_SIZE as u32, false);
        let mut client = VirgeClient::new(vconfig);

        client.connect().map_err(|e| {
            warn!("VirgeClient 连接失败:{e}");
            mbedtls::Error::LowLevel(codes::NetConnectFailed)
        })?;

        //  ECDH 密钥协商(在 VSOCK 上)
        let (mut psk, psk_identity) = ecdh_negotiate(&mut client, &mut crypto_rng)?;

        debug!("ECDH 密钥协商成功");

        let cipher_suites: Vec<i32> = vec![EcdhePskWithSm4128GcmSm3.into(), 0];
        let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);

        config.set_min_version(Version::Tls1_2)?;
        config.set_max_version(Version::Tls1_2)?;
        config.set_ciphersuites(Arc::new(cipher_suites));
        config.set_read_timeout(5000); // 5 秒读取超时
        config.set_rng(Arc::new(rng));
        config.set_psk(&psk, psk_identity)?;

        // 清除敏感数据
        psk.zeroize();

        let mut ctx = Context::new(Arc::new(config));

        // 进行 TLS 握手(使用协商出的动态 PSK)
        ctx.establish(client, None).map_err(|e| {
            warn!("与服务端握手失败:{e}");
            e
        })?;

        debug!("与服务端握手成功,ciphersuite: {:4x?}", ctx.ciphersuite());

        Ok(Self { ctx })
    }

    /// 发送带协议头的数据包
    /// packet_type: 数据包类型,用于服务端解析处理
    /// data: 要发送的原始数据
    pub fn send_data_with_header(
        &mut self,
        packet_type: PacketType,
        data: &[u8],
    ) -> std::io::Result<()> {
        let header = PacketHeader {
            data_type: u64::from(packet_type),
            data_size: data.len() as u64,
        };

        debug!("客户端:发送协议头:{:x?}", header.as_bytes());

        // 先发送协议头
        self.ctx.write_all(header.as_bytes()).map_err(|e| {
            warn!("客户端:发送协议头失败:{e}");
            e
        })?;

        // 再发送数据体
        self.send_data(data)
    }

    /// 发送无协议头的原始数据(用于连续数据流)
    pub fn send_data(&mut self, data: &[u8]) -> std::io::Result<()> {
        self.ctx.write_all(data)?;
        debug!("客户端:发送数据,大小:{}", data.len());
        Ok(())
    }

    /// 从服务器接收数据,支持分块接收
    pub fn recv_data(&mut self, data: &mut [u8]) -> std::io::Result<()> {
        self.ctx.read_exact(data)?;
        debug!("客户端:接收数据,实际大小:{}", data.len());
        Ok(())
    }

    /// 关闭 TLS 连接,发送 close_notify 通知服务端连接即将关闭。
    ///
    /// 调用后连接资源被释放,后续操作将失败。
    pub fn close(&mut self) {
        self.ctx.close();
    }
}

/// 将 PSK 协商错误([`PskError`])映射为 mbedtls 错误码,
/// 统一为 `PkBadInputData` 高层错误。
///
/// 此函数是内部转换辅助函数,便于在 TLS 初始化流程中
/// 以统一的 `TlsResult` 类型传播错误。
fn psk_to_mbedtls_error(e: PskError) -> mbedtls::Error {
    debug!("PSK 协商错误:{e}");
    mbedtls::Error::HighLevel(codes::PkBadInputData)
}

/// ECDH 密钥协商(在 VSOCK 上)
///
/// 流程:
/// 1. 生成临时 ECDH 密钥对,发送裸 EC point
/// 2. 接收服务端响应并解析
/// 3. TOFU 校验服务端长期公钥
/// 4. 验证服务端 ECDH 签名
/// 5. 通过 ECDH 派生 PSK
fn ecdh_negotiate(
    client: &mut VirgeClient,
    ecdh_rng: &mut CryptoRng,
) -> TlsResult<([u8; PSK_LEN], &'static str)> {
    // 1. 生成临时 ECDH 密钥对
    let ecdh_key = generate_ecdh_keypair(ecdh_rng).map_err(|e| {
        warn!("生成 SM2 密钥对失败:{e}");
        psk_to_mbedtls_error(e)
    })?;
    let client_point = extract_ec_point(&ecdh_key).map_err(psk_to_mbedtls_error)?;

    // 2. 发送裸 EC point(65 字节)
    client.send(client_point.clone()).map_err(|e| {
        warn!("发送 ECDH 请求失败:{e}");
        mbedtls::Error::LowLevel(codes::NetSendFailed)
    })?;

    // 3. 接收并解析服务端响应
    let recv_data = client.recv().map_err(|e| {
        warn!("接收 ECDH 响应失败:{e}");
        mbedtls::Error::LowLevel(codes::NetRecvFailed)
    })?;
    let resp = parse_ecdh_response(&recv_data).map_err(|e| {
        warn!("解析 ECDH 响应失败:{e}");
        psk_to_mbedtls_error(e)
    })?;

    // 4. TOFU 校验服务端长期公钥
    verify_server_identity(resp.long_term_point).map_err(|e| {
        warn!("服务端身份校验失败:{e}");
        psk_to_mbedtls_error(e)
    })?;

    // 5. 验证服务端 ECDH 签名
    verify_ecdh_signature(
        resp.long_term_point,
        &client_point,
        resp.server_point,
        resp.signature,
    )
    .map_err(|e| {
        warn!("服务端 ECDH 签名验证失败:{e}");
        psk_to_mbedtls_error(e)
    })?;
    debug!("服务端 ECDH 签名验证通过");

    // 6. 通过 ECDH 协商共享秘密
    let shared = ecdh_compute_shared(&ecdh_key, resp.server_point).map_err(|e| {
        warn!("ECDH 共享秘密计算失败:{e}");
        psk_to_mbedtls_error(e)
    })?;
    drop(ecdh_key);

    // 7. 使用 HKDF-SM3 派生 PSK
    let psk_result = derive_psk(&shared, &client_point, resp.server_point);
    // shared: Zeroizing<Vec<u8>> 在离开作用域时自动清零,无需显式调用 zeroize
    let psk = psk_result.map_err(|e| {
        warn!("PSK 派生失败:{e}");
        psk_to_mbedtls_error(e)
    })?;

    Ok((psk, get_psk_identity()))
}

/// C 接口:检查机密通信功能是否可用。
///
/// 尝试建立一次完整的 TLS 连接(包括 ECDH PSK 协商),成功则立即关闭。
///
/// # 返回值
/// - 1:机密通信通道可用
/// - 0:连接失败,机密通信不可用
#[unsafe(no_mangle)]
pub extern "C" fn cc_check_enable() -> i32 {
    match CcClient::init() {
        Ok(mut client) => {
            client.close();
            1
        }
        Err(_) => 0,
    }
}

#[cfg(test)]
mod cc_client_tests {
    use super::*;
    use teec_protocol::PacketType;

    #[test]
    fn test_packet_header_serialization() {
        let header = PacketHeader {
            data_type: 1,
            data_size: 1024,
        };

        let bytes = header.as_bytes();
        assert_eq!(bytes.len(), 16); // u64 + u64 = 16 字节
    }

    #[test]
    fn test_all_packet_types() {
        assert_eq!(u64::from(PacketType::Unknown), 0);
        assert_eq!(PacketType::from(0), PacketType::Unknown);

        assert_eq!(u64::from(PacketType::OpenSession), 1);
        assert_eq!(PacketType::from(1), PacketType::OpenSession);

        assert_eq!(u64::from(PacketType::CloseSession), 2);
        assert_eq!(PacketType::from(2), PacketType::CloseSession);

        assert_eq!(u64::from(PacketType::InvokeCommand), 3);
        assert_eq!(PacketType::from(3), PacketType::InvokeCommand);

        assert_eq!(u64::from(PacketType::RequestCancellation), 4);
        assert_eq!(PacketType::from(4), PacketType::RequestCancellation);
    }

    #[test]
    fn test_packet_type_conversion_consistency() {
        // 测试 PacketType 转换的一致性
        // 这个测试验证双向转换的一致性
        for value in 0..=4 {
            let packet_type = PacketType::from(value);
            let converted_value = u64::from(packet_type);
            assert_eq!(
                value, converted_value,
                "双向转换不一致: 原始值={}, 转换后值={}",
                value, converted_value
            );
        }

        // 测试超出范围的值的转换
        let out_of_range_value = 99;
        let packet_type = PacketType::from(out_of_range_value);
        assert_eq!(
            packet_type,
            PacketType::Unknown,
            "超出范围的值 {} 应该转换为 PacketType::Unknown",
            out_of_range_value
        );
        assert_eq!(
            u64::from(packet_type),
            0,
            "PacketType::Unknown 应该转换为 0"
        );
    }
}