rust-libteec 0.4.0

Rust implementation of TEE Client API for secure communication with Trusted Applications.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2025-2026 KylinSoft Co., Ltd. <https://www.kylinos.cn/>
// See LICENSES for license details.

//! 机密通信客户端实现,基于 TLS 和 VSOCK 实现与 TEE OS 的安全通信通道。

use std::{
    io::{Read, Result, Write},
    sync::Arc,
};

use log::{debug, warn};
use mbedtls::{
    Result as TlsResult,
    error::codes,
    rng::CtrDrbg,
    ssl::{
        CipherSuite::{
            DhePskWithSm4128GcmSm3, EcdhePskWithSm4128GcmSm3, PskWithSm4128GcmSm3,
            RsaPskWithSm4128GcmSm3,
        },
        Config, Context,
        config::{Endpoint, Preset, Transport},
    },
};
use virga::client::{ClientConfig, VirgeClient};
use zeroize::Zeroize;

use crate::cc_client::{
    psk::{generate_psk, get_psk_identity},
    vsock_define::{get_vsock_cid, get_vsock_port},
};
use teec_protocol::{CHUNK_SIZE, PacketHeader, PacketType};

pub(crate) struct CcClient {
    pub ctx: Context<VirgeClient>,
}

/// SAFETY: CcClient 可以安全地在线程间发送 (Send)
///
/// 理由:
/// 1. `Context<VirgeClient>` 内部使用 Arc 管理共享状态
/// 2. Virga 的 ClientConfig 和连接状态都有内部锁保护
/// 3. 所有对 ctx 的可变访问都通过 TEEC_InvokeCommand 等函数序列化
/// 4. 没有线程不安全的裸指针或可变静态变量
///
/// 注意:虽然 CcClient 是 Send,但并发调用时需要外部同步(由 teec.rs 中的 Mutex 保证)
unsafe impl Send for CcClient {}

/// SAFETY: CcClient 可以安全地在线程间共享引用 (Sync)
///
/// 理由:
/// 1. 同 Send 的理由,内部状态通过 Arc 和锁保护
/// 2. mbedtls 的 Context 设计为线程安全
/// 3. VirgeClient 的连接操作有内部同步机制
/// 4. 实际使用中,CcClient 被包裹在 Arc<Mutex<CcClient>> 中,提供额外的同步保障
///
/// 注意:Sync 仅表示可以共享引用,实际的可变访问仍需要 Mutex
unsafe impl Sync for CcClient {}

impl CcClient {
    /// 初始化客户端并建立 TLS 连接
    /// 返回建立连接的客户端实例,失败返回TLS错误
    pub fn init() -> TlsResult<Self> {
        // 尝试初始化日志系统,如果已经初始化则忽略错误
        // 使用 RUST_LOG 环境变量控制级别,如 RUST_LOG=debug
        let _ = env_logger::try_init();
        let entropy = Arc::new(mbedtls::rng::OsEntropy::new());
        let rng = Arc::new(CtrDrbg::new(entropy, None)?);
        let cipher_suites: Vec<i32> = vec![
            EcdhePskWithSm4128GcmSm3.into(),
            DhePskWithSm4128GcmSm3.into(),
            RsaPskWithSm4128GcmSm3.into(),
            PskWithSm4128GcmSm3.into(),
            0,
        ];
        let mut psk = generate_psk()?;
        let psk_identify = get_psk_identity();
        let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);

        config.set_rng(rng);
        config.set_ciphersuites(Arc::new(cipher_suites));
        config.set_psk(&psk, psk_identify)?;

        // 敏感数据使用后立即清零
        psk.zeroize();

        config.set_read_timeout(5000); // 5 秒读取超时

        let mut ctx = Context::new(Arc::new(config));

        // 从环境变量获取 VSOCK 配置,支持自定义 CID 和 Port
        let cid = get_vsock_cid();
        let port = get_vsock_port();
        let vconfig = ClientConfig::new(cid, port, CHUNK_SIZE as u32, false);

        let mut client = VirgeClient::new(vconfig);

        client.connect().map_err(|e| {
            warn!("VirgeClient 连接失败:{e}");
            mbedtls::Error::LowLevel(codes::NetConnectFailed)
        })?;

        // 进行握手
        ctx.establish(client, None).map_err(|e| {
            warn!("与服务端握手失败:{e}");
            e
        })?;

        let ciphersuite = ctx.ciphersuite();
        debug!("与服务端握手成功,ciphersuite: {:4x?}", ciphersuite);

        Ok(Self { ctx })
    }

    /// 发送带协议头的数据包
    /// packet_type: 数据包类型,用于服务端解析处理
    /// data: 要发送的原始数据
    pub fn send_data_with_header(&mut self, packet_type: PacketType, data: &[u8]) -> Result<()> {
        let header = PacketHeader {
            data_type: u64::from(packet_type),
            data_size: data.len() as u64,
        };

        debug!("客户端:发送协议头:{:x?}", header.as_bytes());

        // 先发送协议头
        self.ctx.write_all(header.as_bytes()).map_err(|e| {
            warn!("客户端:发送协议头失败:{e}");
            e
        })?;

        // 再发送数据体
        self.send_data(data)
    }

    /// 发送无协议头的原始数据(用于连续数据流)
    pub fn send_data(&mut self, data: &[u8]) -> Result<()> {
        self.ctx.write_all(data)?;
        debug!("客户端:发送数据,大小:{}", data.len());
        Ok(())
    }

    /// 从服务器接收数据,支持分块接收
    pub fn recv_data(&mut self, data: &mut [u8]) -> Result<()> {
        self.ctx.read_exact(data)?;
        debug!("客户端:接收数据,实际大小:{}", data.len());
        Ok(())
    }
}

/// C 接口:检查机密通信功能是否可用
/// 返回1表示可用,0 表示不可用
#[unsafe(no_mangle)]
pub extern "C" fn cc_check_enable() -> i32 {
    let mut ctx = CcClient::init();
    match &mut ctx {
        Ok(ctx) => {
            ctx.ctx.close();
            1
        }
        Err(_) => 0,
    }
}

#[cfg(test)]
mod cc_client_tests {
    use super::*;
    use std::io::{Error, ErrorKind};
    use teec_protocol::{CHUNK_SIZE, PacketType};

    // 测试用模拟 vsock 流
    struct MockVsockStream {
        read_data: Vec<u8>,
        write_data: Vec<u8>,
        read_pos: usize,
        should_fail: bool,
    }

    impl MockVsockStream {
        fn new() -> Self {
            Self {
                read_data: vec![0u8; 4], // 默认返回 4 字节的临时数据
                write_data: Vec::new(),
                read_pos: 0,
                should_fail: false,
            }
        }
    }

    impl Read for MockVsockStream {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            if self.should_fail {
                return Err(Error::new(ErrorKind::ConnectionReset, "模拟连接错误"));
            }

            let remaining = self.read_data.len() - self.read_pos;
            if remaining == 0 {
                return Ok(0);
            }

            let to_read = buf.len().min(remaining);
            buf[..to_read].copy_from_slice(&self.read_data[self.read_pos..self.read_pos + to_read]);
            self.read_pos += to_read;
            Ok(to_read)
        }
    }

    impl Write for MockVsockStream {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            if self.should_fail {
                return Err(Error::new(ErrorKind::ConnectionReset, "模拟写入错误"));
            }

            self.write_data.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> Result<()> {
            Ok(())
        }
    }

    // 模拟 TLS 上下文,用于测试业务逻辑
    struct MockContext {
        stream: MockVsockStream,
    }

    impl MockContext {
        fn new(stream: MockVsockStream) -> Self {
            Self { stream }
        }
    }

    impl Read for MockContext {
        fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
            self.stream.read(buf)
        }
    }

    impl Write for MockContext {
        fn write(&mut self, buf: &[u8]) -> Result<usize> {
            self.stream.write(buf)
        }

        fn flush(&mut self) -> Result<()> {
            self.stream.flush()
        }
    }

    #[test]
    fn test_packet_header_serialization() {
        let header = PacketHeader {
            data_type: 1,
            data_size: 1024,
        };

        let bytes = header.as_bytes();
        assert_eq!(bytes.len(), 16); // u64 + u64 = 16 字节
    }

    #[test]
    fn test_write_chunks_small_data() {
        let mock_stream = MockVsockStream::new();
        let _mock_ctx = MockContext::new(mock_stream);

        let small_data = [1, 2, 3, 4, 5];

        // 测试 chunk 逻辑
        let chunk_size = CHUNK_SIZE as usize;
        let chunks: Vec<&[u8]> = small_data.chunks(chunk_size).collect();

        assert_eq!(chunks.len(), 1); // 小数据应该只有 1 个 chunk
        assert_eq!(chunks[0].len(), small_data.len());
    }

    #[test]
    fn test_write_chunks_large_data() {
        let large_data_size = (CHUNK_SIZE * 3) as usize;
        let large_data: Vec<u8> = (0..large_data_size).map(|i| (i % 256) as u8).collect();

        let chunk_size = CHUNK_SIZE as usize;
        let chunks: Vec<&[u8]> = large_data.chunks(chunk_size).collect();

        assert_eq!(chunks.len(), 3); // 大数据应该分成 3 个 chunk
        assert_eq!(chunks[0].len(), chunk_size);
        assert_eq!(chunks[1].len(), chunk_size);
        assert_eq!(chunks[2].len(), large_data_size - 2 * chunk_size);
    }

    #[test]
    fn test_recv_data_logic() {
        // 测试接收逻辑的分支
        let small_buffer_size = (CHUNK_SIZE / 2) as usize;
        let large_buffer_size = (CHUNK_SIZE * 2) as usize;

        // 测试小数据接收路径
        assert!(small_buffer_size <= CHUNK_SIZE as usize);

        // 测试大数据接收路径
        assert!(large_buffer_size > CHUNK_SIZE as usize);
    }

    #[test]
    fn test_cc_check_enable_logic() {
        // 测试返回值的逻辑
        let success_result = 1;
        let fail_result = 0;

        assert_ne!(success_result, fail_result);
    }

    #[test]
    fn test_packet_type_conversion() {
        // 使用实际的PacketType变体进行测试
        let open_session_type = PacketType::OpenSession;
        let invoke_command_type = PacketType::InvokeCommand;

        assert_ne!(u64::from(open_session_type), u64::from(invoke_command_type));

        // 测试从u64转换
        assert_eq!(PacketType::from(1), PacketType::OpenSession);
        assert_eq!(PacketType::from(3), PacketType::InvokeCommand);
        assert_eq!(PacketType::from(99), PacketType::Unknown); // 测试未知类型
    }

    #[test]
    fn test_error_handling() {
        // 测试错误处理路径
        let error = Error::new(ErrorKind::ConnectionReset, "测试连接错误");

        // 验证错误信息包含预期内容
        assert!(format!("{}", error).contains("测试连接错误"));

        // 测试 mbedtls 错误
        let mbedtls_error = mbedtls::Error::LowLevel(codes::NetConnectFailed);

        // 检查错误码
        match mbedtls_error {
            mbedtls::Error::LowLevel(code) => {
                // 验证错误码是 NetConnectFailed
                assert_eq!(code, codes::NetConnectFailed);
            }
            _ => panic!("Expected LowLevel error"),
        }
    }

    #[test]
    fn test_chunk_boundary_cases() {
        // 测试边界情况
        let exact_chunk_size = CHUNK_SIZE as usize;
        let one_less_than_chunk = (CHUNK_SIZE - 1) as usize;
        let one_more_than_chunk = (CHUNK_SIZE + 1) as usize;

        assert!(exact_chunk_size <= CHUNK_SIZE as usize); // 应该走小数据路径
        assert!(one_less_than_chunk <= CHUNK_SIZE as usize);
        assert!(one_more_than_chunk > CHUNK_SIZE as usize);
    }

    #[test]
    fn test_all_packet_types() {
        assert_eq!(u64::from(PacketType::Unknown), 0);
        assert_eq!(PacketType::from(0), PacketType::Unknown);

        assert_eq!(u64::from(PacketType::OpenSession), 1);
        assert_eq!(PacketType::from(1), PacketType::OpenSession);

        assert_eq!(u64::from(PacketType::CloseSession), 2);
        assert_eq!(PacketType::from(2), PacketType::CloseSession);

        assert_eq!(u64::from(PacketType::InvokeCommand), 3);
        assert_eq!(PacketType::from(3), PacketType::InvokeCommand);

        assert_eq!(u64::from(PacketType::RequestCancellation), 4);
        assert_eq!(PacketType::from(4), PacketType::RequestCancellation);
    }

    #[test]
    fn test_packet_type_conversion_consistency() {
        // 测试 PacketType 转换的一致性
        // 这个测试验证双向转换的一致性
        for value in 0..=4 {
            let packet_type = PacketType::from(value);
            let converted_value = u64::from(packet_type);
            assert_eq!(
                value, converted_value,
                "双向转换不一致: 原始值={}, 转换后值={}",
                value, converted_value
            );
        }

        // 测试超出范围的值的转换
        let out_of_range_value = 99;
        let packet_type = PacketType::from(out_of_range_value);
        assert_eq!(
            packet_type,
            PacketType::Unknown,
            "超出范围的值 {} 应该转换为 PacketType::Unknown",
            out_of_range_value
        );
        assert_eq!(
            u64::from(packet_type),
            0,
            "PacketType::Unknown 应该转换为 0"
        );
    }
}