rust-libteec 0.4.6

Rust implementation of TEE Client API for secure communication with Trusted Applications.
Documentation
// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2025-2026 KylinSoft Co., Ltd. <https://www.kylinos.cn/>
// See LICENSES for license details.

//! 安全指针操作模块,封装 unsafe 操作并提供错误处理
//!
//! **注意**: 此模块仅在 teec 模块内部使用,不对外公开。

use std::ptr::NonNull;

use crate::{Error, ErrorKind, ErrorOrigin, Result};

/// 安全地解引用可变指针,返回 NonNull 包装
pub(super) fn deref_mut<T>(ptr: *mut T) -> Result<NonNull<T>> {
    NonNull::new(ptr).ok_or(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API))
}

/// 安全地解引用不可变指针,返回 NonNull 包装
pub(super) fn deref<T>(ptr: *const T) -> Result<NonNull<T>> {
    NonNull::new(ptr as *mut T)
        .ok_or(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API))
}

/// 安全地向指针指向的内存写入值
pub(super) fn write_raw<T>(ptr: *mut T, value: T) -> Result<()> {
    let nn = NonNull::new(ptr)
        .ok_or(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API))?;
    // SAFETY: `nn` 已通过上面的 `NonNull::new` 检查为非空。
    // 调用方保证传入的 `ptr` 的所有权/有效性,因此向 `nn.as_ptr()` 写入
    // `value` 是安全的(内存对齐正确且可写)。
    unsafe { nn.as_ptr().write(value) };
    Ok(())
}

/// 安全地从指针指向的内存读取值
#[allow(dead_code)]
pub(super) fn read_raw<T: Copy>(ptr: *const T) -> Result<T> {
    let nn = NonNull::new(ptr as *mut T)
        .ok_or(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API))?;
    // SAFETY: `nn` 已为非空并指向有效的 `T`(已在上面检查)。
    // `T: Copy` 确保复制该值是安全的。
    Ok(unsafe { *nn.as_ptr() })
}

/// 安全地从指针读取指定长度的数据到 Vec
pub(super) fn read_to_vec<T: Copy>(src: *const T, len: usize) -> Result<Vec<T>> {
    if src.is_null() && len > 0 {
        return Err(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API));
    }
    let _nn = NonNull::new(src as *mut T)
        .ok_or(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API))?;
    // SAFETY: 当 `len > 0` 时 `src` 已被检查为非空。调用方保证
    // `src..src+len` 可安全读取且内存对齐正确。
    let slice = unsafe { std::slice::from_raw_parts(src, len) };
    Ok(slice.to_vec())
}

/// 安全地将切片数据写入指针指向的内存
pub(super) fn write_from_slice<T: Copy>(dst: *mut T, src: &[T]) -> Result<()> {
    if dst.is_null() && !src.is_empty() {
        return Err(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API));
    }

    if src.is_empty() {
        return Err(Error::new(ErrorKind::BadFormat).with_origin(ErrorOrigin::API));
    }

    let _nn = NonNull::new(dst)
        .ok_or(Error::new(ErrorKind::BadParameters).with_origin(ErrorOrigin::API))?;
    // SAFETY: 当 `src` 非空时已检查 `dst` 为非空。调用方保证 `dst` 指向
    // `src.len()` 个可写元素,且内存对齐适合 `T`。
    let dst_slice = unsafe { std::slice::from_raw_parts_mut(dst, src.len()) };
    dst_slice.copy_from_slice(src);
    Ok(())
}

/// 安全地进行指针加法运算
pub(super) fn add_ptr<T>(ptr: *const T, offset: usize) -> *const T {
    // SAFETY: 在同一已分配对象内进行指针运算;调用方保证 `ptr` 有效
    // 且 `offset` 在预期范围内。
    unsafe { ptr.add(offset) }
}

/// 安全地进行可变指针加法运算
pub(super) fn add_ptr_mut<T>(ptr: *mut T, offset: usize) -> *mut T {
    // SAFETY: 在同一已分配对象内进行指针运算;调用方保证 `ptr` 有效
    // 且 `offset` 在预期范围内。
    unsafe { ptr.add(offset) }
}

#[cfg(test)]
mod safe_ptr_tests {
    use super::*;

    #[test]
    fn test_write_and_read_raw() {
        // 测试安全指针写入和读取功能
        let mut v: u32 = 0;
        let p: *mut u32 = &mut v as *mut u32;

        write_raw(p, 0xDEADBEEF).expect("写入原始数据应该成功");
        let r = read_raw(p as *const u32).expect("读取原始数据应该成功");
        assert_eq!(r, 0xDEADBEEF, "写入后读取的值应该与写入值相同");
    }

    #[test]
    fn test_read_to_vec_and_write_from_slice() {
        // 测试从指针读取数据到 Vec 和从切片写入到指针
        let arr = [1u8, 2, 3, 4, 5];
        let vec = read_to_vec(arr.as_ptr(), arr.len()).expect("从指针读取到 Vec 应该成功");
        assert_eq!(vec, arr.to_vec(), "从数组读取的数据应该与原始数组相同");

        let mut buf = [0u8; 5];
        write_from_slice(buf.as_mut_ptr(), &vec).expect("从切片写入指针应该成功");
        assert_eq!(buf.to_vec(), vec, "写入缓冲区后应该与原始数据相同");
    }

    #[test]
    fn test_add_ptr() {
        // 测试指针加法运算
        let arr = [10u8, 20, 30, 40];
        let p = arr.as_ptr();
        let p2 = add_ptr(p, 2);

        assert_eq!(unsafe { *p2 }, 30, "偏移2个元素后应该指向第3个元素");
    }

    #[test]
    fn test_deref_null() {
        use std::ptr;
        assert!(deref(ptr::null::<u8>()).is_err());
        assert!(deref_mut(ptr::null_mut::<u8>()).is_err());
    }

    #[test]
    fn test_write_raw_null() {
        use std::ptr;
        let result = write_raw(ptr::null_mut::<u32>(), 42);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_raw_null() {
        use std::ptr;
        let result = read_raw(ptr::null::<u32>());
        assert!(result.is_err());
    }

    #[test]
    fn test_read_to_vec_null() {
        use std::ptr;
        let result = read_to_vec(ptr::null::<u8>(), 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_write_from_slice_empty() {
        let mut buf = [0u8; 5];
        let empty_slice: &[u8] = &[];
        let result = write_from_slice(buf.as_mut_ptr(), empty_slice);
        assert!(result.is_err());
    }

    #[test]
    fn test_add_ptr_mut() {
        let mut arr = [10u8, 20, 30, 40];
        let p = arr.as_mut_ptr();
        let p2 = add_ptr_mut(p, 2);

        assert_eq!(unsafe { *p2 }, 30, "可变指针偏移2个元素后应该指向第3个元素");
    }

    #[test]
    fn test_read_to_vec_zero_length() {
        // 测试读取零长度数据
        let arr = [1u8, 2, 3];
        let result = read_to_vec(arr.as_ptr(), 0);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 0, "零长度读取应该返回空向量");
    }

    #[test]
    fn test_write_from_slice_null_dst_with_empty_src() {
        // 测试向空指针写入空切片
        use std::ptr;
        let empty_slice: &[u8] = &[];
        let result = write_from_slice(ptr::null_mut(), empty_slice);
        // 根据实现,即使源为空,目标为空也应该失败
        assert!(result.is_err());
    }

    #[test]
    fn test_add_ptr_boundary() {
        // 测试指针加法边界情况
        let arr = [1u8, 2, 3, 4, 5];
        let p = arr.as_ptr();

        // 偏移0(应该指向第一个元素)
        let p0 = add_ptr(p, 0);
        assert_eq!(unsafe { *p0 }, 1);

        // 偏移到最后一个元素
        let p_last = add_ptr(p, 4);
        assert_eq!(unsafe { *p_last }, 5);
    }

    #[test]
    fn test_deref_valid_pointer() {
        // 测试解引用有效指针
        let value = 42u32;
        let p = &value as *const u32;
        let result = deref(p);
        assert!(result.is_ok());
        assert_eq!(unsafe { *result.unwrap().as_ptr() }, 42);
    }

    #[test]
    fn test_deref_mut_valid_pointer() {
        // 测试可变解引用有效指针
        let mut value = 42u32;
        let p = &mut value as *mut u32;
        let result = deref_mut(p);
        assert!(result.is_ok());
        unsafe {
            *result.unwrap().as_ptr() = 100;
        }
        assert_eq!(value, 100, "可变解引用应该能修改值");
    }

    #[test]
    fn test_read_raw_valid() {
        // 测试从有效指针读取
        let value = 0xDEADBEEFu32;
        let p = &value as *const u32;
        let result = read_raw(p);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 0xDEADBEEF);
    }

    #[test]
    fn test_write_raw_valid() {
        // 测试向有效指针写入
        let mut value = 0u32;
        let p = &mut value as *mut u32;
        let result = write_raw(p, 0xCAFEBABE);
        assert!(result.is_ok());
        assert_eq!(value, 0xCAFEBABE);
    }

    #[test]
    fn test_read_to_vec_large_size() {
        // 测试读取较大数据
        let large_arr = vec![42u8; 1000];
        let result = read_to_vec(large_arr.as_ptr(), large_arr.len());
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 1000);
    }

    #[test]
    fn test_write_from_slice_various_sizes() {
        // 测试不同大小的切片写入

        // 测试 1: 单字节
        let mut buf1 = [0u8; 1];
        let data1 = [1u8];
        assert!(write_from_slice(buf1.as_mut_ptr(), &data1).is_ok());
        assert_eq!(buf1[0], 1);

        // 测试 2: 中等大小
        let mut buf2 = [0u8; 100];
        let data2 = vec![42u8; 100];
        assert!(write_from_slice(buf2.as_mut_ptr(), &data2).is_ok());
        assert_eq!(buf2.to_vec(), data2);

        // 测试 3: 较大数据
        let mut buf3 = vec![0u8; 1000];
        let data3 = vec![99u8; 1000];
        assert!(write_from_slice(buf3.as_mut_ptr(), &data3).is_ok());
        assert_eq!(buf3, data3);
    }

    #[test]
    fn test_add_ptr_zero_offset() {
        // 测试偏移量为 0 的情况
        let arr = [5u8, 10, 15];
        let p = arr.as_ptr();
        let p0 = add_ptr(p, 0);
        assert_eq!(unsafe { *p0 }, 5, "偏移 0 应该指向第一个元素");
    }

    #[test]
    fn test_add_ptr_mut_various_offsets() {
        // 测试可变指针的不同偏移量
        let mut arr = [1u8, 2, 3, 4, 5];
        let p = arr.as_mut_ptr();

        for i in 0..5 {
            let p_i = add_ptr_mut(p, i);
            assert_eq!(unsafe { *p_i }, (i + 1) as u8);
        }
    }

    #[test]
    fn test_deref_and_deref_mut_consistency() {
        // 测试 deref 和 deref_mut 的一致性
        let mut value = 123u32;

        // 使用 deref 读取
        let const_ptr = &value as *const u32;
        let result_read = deref(const_ptr);
        assert!(result_read.is_ok());
        assert_eq!(unsafe { *result_read.unwrap().as_ptr() }, 123);

        // 使用 deref_mut 写入
        let mut_ptr = &mut value as *mut u32;
        let result_write = deref_mut(mut_ptr);
        assert!(result_write.is_ok());
        unsafe {
            *result_write.unwrap().as_ptr() = 456;
        }
        assert_eq!(value, 456);
    }

    #[test]
    fn test_error_paths_comprehensive() {
        // 综合测试所有错误路径
        use std::ptr;

        // 1. deref 空指针
        assert!(deref(ptr::null::<u8>()).is_err());

        // 2. deref_mut 空指针
        assert!(deref_mut(ptr::null_mut::<u8>()).is_err());

        // 3. read_raw 空指针
        assert!(read_raw(ptr::null::<u32>()).is_err());

        // 4. write_raw 空指针
        assert!(write_raw(ptr::null_mut::<u32>(), 42).is_err());

        // 5. read_to_vec 空指针且长度 > 0
        assert!(read_to_vec(ptr::null::<u8>(), 10).is_err());

        // 6. write_from_slice 空目标且源非空
        let data = [1u8, 2, 3];
        assert!(write_from_slice(ptr::null_mut(), &data).is_err());

        // 7. write_from_slice 空源
        let empty: &[u8] = &[];
        assert!(write_from_slice([0u8; 5].as_mut_ptr(), empty).is_err());
    }
}