rust-libteec 0.5.0

Rust implementation of TEE Client API for secure communication with Trusted Applications.
Documentation
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2025-2026 KylinSoft Co., Ltd. <https://www.kylinos.cn/>
// See LICENSES for license details.

//! TEE 会话管理模块
//!
//! 负责 CA 与 TA 之间的会话生命周期管理,包括:
//! - 会话的打开、关闭
//! - 命令调用
//! - 取消请求
//!
//! 所有会话操作均通过机密通信通道(TLS + VSOCK)进行,
//! 请求和响应序列化为 `postcard` 格式后加密传输。

use std::ptr;

use log::{debug, warn};

use postcard::{take_from_bytes, to_allocvec};
use uuid::Uuid;

#[cfg(feature = "ca-sign-verify")]
use teec_protocol::CaAuthInfo;
use teec_protocol::{PacketType, TEE_Parameters, TEE_Request, TEE_Response};

use crate::{Error, ErrorKind, ErrorOrigin, Result, raw};

use super::context::ContextManager;
#[cfg(feature = "ca-sign-verify")]
use super::get_or_verify_ca;
use super::{build_parameters_from_operation, safe_ptr, update_operation_from_parameters};

pub(crate) fn open_session_impl(
    ctx: *mut raw::TEEC_Context,
    session: *mut raw::TEEC_Session,
    destination: *const raw::TEEC_UUID,
    connection_method: u32,
    operation: *mut raw::TEEC_Operation,
) -> Result<u32> {
    let _ = safe_ptr::deref_mut(ctx)?;
    let _ = safe_ptr::deref_mut(session)?;
    let uuid_nn = safe_ptr::deref(destination)?;
    // SAFETY: `uuid_nn` 已由 `deref` 验证为非空。`as_ref()` 返回对
    // 调用方提供的 UUID 的不可变引用,该引用在此处有效。
    let uuid = unsafe { uuid_nn.as_ref() };
    let uuid_str = uuid_to_string(uuid)?;

    // 在 OpenSession 开始时立即执行 CA 认证(带缓存)
    // 尽早认证可以提前发现 CA 文件问题,避免后续不必要的操作
    #[cfg(feature = "ca-sign-verify")]
    let ca_auth_info = {
        let auth_result = get_or_verify_ca();

        // 记录认证结果(警告式,不阻断)
        if !auth_result.verified {
            warn!("CA 签名验证失败: ca_uuid={}", auth_result.ca_uuid);
        } else {
            debug!("CA 签名验证通过: ca_uuid={}", auth_result.ca_uuid);
        }

        // 转换为协议格式(直接使用 CaAuthInfo,无需转换)
        Some(CaAuthInfo {
            ca_uuid: auth_result.ca_uuid,
            verified: auth_result.verified,
        })
    };

    #[cfg(not(feature = "ca-sign-verify"))]
    let ca_auth_info: Option<teec_protocol::CaAuthInfo> = None;

    let params = if operation.is_null() {
        TEE_Parameters::default()
    } else {
        build_parameters_from_operation(operation)?
    };

    let request = TEE_Request::OpenSession {
        uuid: uuid_str,
        connection_method,
        params,
        ca_auth_info,
    };

    let mut session_nn = safe_ptr::deref_mut(session)?;
    // SAFETY: `session_nn` 已验证非空;`as_mut()` 返回针对会话结构的
    // 可变引用,该引用在此作用域中有效。
    let session_ref = unsafe { session_nn.as_mut() };
    let response = send_request_and_recv_response(ctx, PacketType::OpenSession, &request)?;

    match response {
        TEE_Response::OpenSession { session_id, result } => {
            debug!("TEEC_OpenSession: 接收 session_id 和结果: {session_id}, {result}");

            session_ref.imp.ctx = ctx;
            session_ref.imp.session_id = session_id;

            if result != raw::TEEC_SUCCESS {
                return Err(Error::new(ErrorKind::from(result)).with_origin(ErrorOrigin::TEE));
            }
            Ok(result)
        }
        _ => Err(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API)),
    }
}

/// Closes the session described by `session` (backend of `TEEC_CloseSession`).
///
/// Reads the bound context and session id from the session struct and sends a
/// `CloseSession` request. When a `CloseSession` reply arrives, its result code is
/// logged (not checked) and the session struct is reset to the "not open" state
/// (null context, id 0).
///
/// # Errors
/// - pointer validation errors from `safe_ptr`;
/// - `BadParameters` if the session has no context (never opened or already closed);
/// - transport errors from the request round trip;
/// - `BadParameters` if the server answers with a different response variant
///   (the session struct is left untouched in that case).
pub(crate) fn close_session_impl(session: *mut raw::TEEC_Session) -> Result<()> {
    let mut handle = safe_ptr::deref_mut(session)?;
    // SAFETY: `handle` was checked to be non-null; `as_mut()` yields an exclusive
    // reference to the caller's session struct, used to read and reset its fields.
    let sess = unsafe { handle.as_mut() };
    let (session_id, ctx) = (sess.imp.session_id, sess.imp.ctx);

    if ctx.is_null() {
        return Err(Error::new(ErrorKind::BadParameters));
    }

    let reply = send_request_and_recv_response(
        ctx,
        PacketType::CloseSession,
        &TEE_Request::CloseSession { session_id },
    )?;

    let TEE_Response::CloseSession { result } = reply else {
        return Err(Error::new(ErrorKind::BadParameters));
    };

    debug!("TEEC_CloseSession: 接收结果: {result}");
    // Mark the session as closed so later calls fail the null-context check.
    sess.imp.ctx = ptr::null_mut();
    sess.imp.session_id = 0;
    Ok(())
}

pub(crate) fn invoke_command_impl(
    session: *mut raw::TEEC_Session,
    cmd_id: u32,
    operation: *mut raw::TEEC_Operation,
) -> Result<u32> {
    let mut session_nn = safe_ptr::deref_mut(session)?;
    // SAFETY: `session_nn` 已验证非空;`as_mut()` 返回的可变引用在此
    // 作用域内有效,可用于读取 `imp.session_id` 和 `imp.ctx`。
    let session_ref = unsafe { session_nn.as_mut() };
    let session_id = session_ref.imp.session_id;
    let ctx = session_ref.imp.ctx;

    if ctx.is_null() {
        return Err(Error::new(ErrorKind::BadParameters));
    }

    let params = if operation.is_null() {
        TEE_Parameters::default()
    } else {
        let mut operation_nn = safe_ptr::deref_mut(operation)?;
        // SAFETY: `operation_nn` 已验证非空,且 `deref_mut` 提供独占访问;
        // 调用 `as_mut()` 可获得可变引用,可安全修改 `imp.session` 字段。
        let operation_ref = unsafe { operation_nn.as_mut() };
        operation_ref.imp.session = session;

        build_parameters_from_operation(operation)?
    };

    let request = TEE_Request::InvokeCommand {
        session_id,
        cmd_id,
        params,
    };

    let response = send_request_and_recv_response(ctx, PacketType::InvokeCommand, &request)?;

    match response {
        TEE_Response::InvokeCommand { params, result } => {
            if result != raw::TEEC_SUCCESS {
                let origin = if result == raw::TEEC_ERROR_TARGET_DEAD {
                    ErrorOrigin::TEE
                } else {
                    ErrorOrigin::API
                };
                return Err(Error::new(ErrorKind::from(result)).with_origin(origin));
            }

            if !operation.is_null() {
                update_operation_from_parameters(operation, params)?;
            }
            Ok(result)
        }
        _ => Err(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API)),
    }
}

/// Requests cancellation of the operation `operation` (backend of `TEEC_RequestCancellation`).
///
/// Looks up the session recorded on the operation (set by `invoke_command_impl`)
/// and sends a `RequestCancellation` request for that session id. This is a silent
/// no-op (returns `Ok`) when the operation was never bound to a session or when
/// that session has no context. The server's result code is ignored.
///
/// NOTE(review): the transport lock in `send_request_and_recv_response` is held for
/// a whole request/response round trip. A cancellation issued while another call on
/// the same context is in flight therefore presumably waits for that call to finish.
/// Confirm this is acceptable.
///
/// # Errors
/// - pointer validation errors from `safe_ptr`;
/// - transport errors from the request round trip;
/// - `BadParameters` if the server answers with a different response variant.
pub(crate) fn request_cancellation_impl(operation: *mut raw::TEEC_Operation) -> Result<()> {
    let mut op_handle = safe_ptr::deref_mut(operation)?;
    // SAFETY: `op_handle` was checked to be non-null; the exclusive reference from
    // `as_mut()` is used only to read `imp.session`.
    let owner = unsafe { op_handle.as_mut() }.imp.session;

    if owner.is_null() {
        return Ok(());
    }

    let mut sess_handle = safe_ptr::deref_mut(owner)?;
    // SAFETY: `sess_handle` was checked to be non-null; the exclusive reference from
    // `as_mut()` is valid in this scope and only used to read `imp.session_id`/`imp.ctx`.
    let sess = unsafe { sess_handle.as_mut() };
    let (session_id, ctx) = (sess.imp.session_id, sess.imp.ctx);

    if ctx.is_null() {
        return Ok(());
    }

    let cancel = TEE_Request::RequestCancellation { session_id };
    match send_request_and_recv_response(ctx, PacketType::RequestCancellation, &cancel)? {
        TEE_Response::RequestCancellation { .. } => {
            debug!("TEEC_RequestCancellation: 接收结果");
            Ok(())
        }
        _ => Err(Error::new(ErrorKind::BadParameters)),
    }
}

/// Renders a GlobalPlatform `TEEC_UUID` as a standard UUID string.
///
/// The four GP fields (`timeLow`, `timeMid`, `timeHiAndVersion`, `clockSeqAndNode`)
/// are fed to `Uuid::from_fields`, and the result is formatted with `Uuid`'s
/// `Display` implementation. `open_session_impl` uses this to fill the `uuid`
/// field of the `OpenSession` request.
///
/// This never returns `Err`; the `Result` return type is kept for call-site compatibility.
fn uuid_to_string(uuid: &raw::TEEC_UUID) -> Result<String> {
    let ta_uuid = Uuid::from_fields(
        uuid.timeLow,
        uuid.timeMid,
        uuid.timeHiAndVersion,
        &uuid.clockSeqAndNode,
    );
    let rendered = ta_uuid.to_string();
    Ok(rendered)
}

/// 通过机密通信通道发送请求并接收响应。
///
/// 流程:
/// 1. 从 `ContextManager` 获取 TLS 客户端连接
/// 2. 将 `TEE_Request` 序列化为 `postcard` 格式
/// 3. 通过 TLS 通道发送(带协议头)
/// 4. 先接收 4 字节长度前缀,再接收响应体
/// 5. 反序列化为 `TEE_Response` 并返回
///
/// 请求和响应均限制最大 64MB,防止 OOM 攻击。
fn send_request_and_recv_response(
    ctx: *mut raw::TEEC_Context,
    packet_type: PacketType,
    request: &TEE_Request,
) -> Result<TEE_Response> {
    let client_arc = ContextManager::get_client(ctx)?;
    let mut client = client_arc
        .lock()
        .map_err(|_| Error::new(ErrorKind::Generic).with_origin(ErrorOrigin::API))?;

    let request_data = to_allocvec(request).map_err(|e| {
        warn!("序列化请求失败:{e}");
        Error::new(ErrorKind::BadFormat).with_origin(ErrorOrigin::API)
    })?;

    // 防止 OOM 攻击:限制请求数据最大为 64MB
    const MAX_REQUEST_SIZE: usize = 64 * 1024 * 1024; // 64MB

    if request_data.len() > MAX_REQUEST_SIZE {
        warn!(
            "请求数据过大:{} bytes (最大允许 {} bytes)",
            request_data.len(),
            MAX_REQUEST_SIZE
        );
        return Err(Error::new(ErrorKind::BadFormat).with_origin(ErrorOrigin::API));
    }

    client
        .send_data_with_header(packet_type, &request_data)
        .map_err(|e| {
            warn!("发送请求失败:{e}");
            Error::new(ErrorKind::Communication).with_origin(ErrorOrigin::COMMS)
        })?;

    let mut len_buf = [0u8; 4];

    client.recv_data(&mut len_buf).map_err(|e| {
        warn!("接收响应长度失败:{e}");
        Error::new(ErrorKind::Communication).with_origin(ErrorOrigin::COMMS)
    })?;

    // 防止整数溢出:先验证 u32 值在合理范围内,再转换为 usize
    let response_len_u32 = u32::from_ne_bytes(len_buf);
    const MAX_RESPONSE_SIZE: usize = 64 * 1024 * 1024; // 64MB

    if response_len_u32 > MAX_RESPONSE_SIZE as u32 {
        warn!(
            "响应数据过大:{} bytes (最大允许 {} bytes)",
            response_len_u32, MAX_RESPONSE_SIZE
        );
        return Err(Error::new(ErrorKind::BadFormat).with_origin(ErrorOrigin::COMMS));
    }

    let response_len = response_len_u32 as usize;
    let mut response_data = vec![0u8; response_len];

    client.recv_data(&mut response_data).map_err(|e| {
        warn!("接收响应数据失败:{e}");
        Error::new(ErrorKind::Communication).with_origin(ErrorOrigin::COMMS)
    })?;

    take_from_bytes::<TEE_Response>(&response_data)
        .map(|(response, _)| response)
        .map_err(|e| {
            warn!("反序列化响应失败:{e}");
            Error::new(ErrorKind::BadFormat).with_origin(ErrorOrigin::API)
        })
}