use std::{
alloc::{self, Layout},
ptr,
sync::{
LazyLock,
atomic::{AtomicI32, Ordering},
},
};
use dashmap::DashMap;
use log::warn;
use super::safe_ptr;
use crate::{Error, ErrorKind, Result, raw};
/// Marker stored in `TEEC_SharedMemory::imp.flags` once `allocate` has attached a buffer.
const SHM_FLAG_BUFFER_ALLOCED: u32 = 0b1;
/// Largest single shared-memory request accepted by `allocate`, in bytes (100 MiB).
const MAX_SHARED_MEMORY_SIZE: usize = 100 << 20;
/// Registry record that owns one buffer handed out as shared memory.
///
/// The buffer is released when the record is dropped, i.e. when it is removed
/// from the registry map.
struct ShmEntry {
    /// `imp.fd` of the context the buffer was allocated for; matched by
    /// `release_by_context`.
    context_id: i32,
    /// Start of a buffer obtained from the global allocator with
    /// `Layout::from_size_align(size, 1)`.
    ptr: *mut u8,
    /// Buffer length in bytes.
    size: usize,
}

// SAFETY: the entry is the only party that frees `ptr`, and it does so exactly once
// in `Drop`. NOTE(review): callers also hold this pointer through `shm.buffer`, so
// concurrent access to the bytes themselves relies on the TEEC API usage contract.
unsafe impl Send for ShmEntry {}
unsafe impl Sync for ShmEntry {}

impl Drop for ShmEntry {
    fn drop(&mut self) {
        // An empty or already-cleared entry owns no allocation.
        if self.ptr.is_null() || self.size == 0 {
            return;
        }
        let layout = Layout::from_size_align(self.size, 1).unwrap();
        // SAFETY: `ptr` came from the global allocator with exactly this layout and is
        // freed nowhere else.
        unsafe { alloc::dealloc(self.ptr, layout) };
        self.ptr = ptr::null_mut();
    }
}
type SharedMemoryMap = DashMap<i32, ShmEntry>;
/// Global registry of live shared-memory buffers, keyed by the id written to `shm.imp.id`.
static SHMS: LazyLock<SharedMemoryMap> = LazyLock::new(DashMap::new);
/// Source of shared-memory ids. `AtomicI32::new` is a `const fn`, so the counter
/// needs no lazy initialisation.
static SHM_ID_COUNTER: AtomicI32 = AtomicI32::new(0);
pub struct SharedMemoryManager;
impl SharedMemoryManager {
pub fn allocate(
ctx: *mut raw::TEEC_Context,
shm: *mut raw::TEEC_SharedMemory,
registe: bool,
) -> Result<()> {
let mut shm_nn = safe_ptr::deref_mut(shm)?;
let ctx_nn = safe_ptr::deref(ctx)?;
let shm_ref = unsafe { shm_nn.as_mut() };
let ctx_ref = unsafe { ctx_nn.as_ref() };
let context_id = ctx_ref.imp.fd;
let flags = shm_ref.flags;
let size = shm_ref.size;
if size == 0 || flags == 0 || flags & !(raw::TEEC_MEM_INPUT | raw::TEEC_MEM_OUTPUT) != 0 {
return Err(Error::new(ErrorKind::BadParameters));
}
if size > MAX_SHARED_MEMORY_SIZE {
warn!(
"共享内存请求过大: {} bytes (最大允许: {} bytes)",
size, MAX_SHARED_MEMORY_SIZE
);
return Err(Error::new(ErrorKind::OutOfMemory));
}
let layout =
Layout::from_size_align(size, 1).map_err(|_| Error::new(ErrorKind::BadParameters))?;
let buf_ptr = unsafe { alloc::alloc(layout) };
if buf_ptr.is_null() {
return Err(Error::new(ErrorKind::OutOfMemory));
}
unsafe { ptr::write_bytes(buf_ptr, 0, size) };
if registe {
if shm_ref.buffer.is_null() {
unsafe { alloc::dealloc(buf_ptr, layout) };
return Err(Error::new(ErrorKind::BadParameters));
}
let src_data = safe_ptr::read_to_vec(shm_ref.buffer as *const u8, size)?;
unsafe { ptr::copy_nonoverlapping(src_data.as_ptr(), buf_ptr, size) };
}
let shm_id = SHM_ID_COUNTER.fetch_add(1, Ordering::SeqCst);
shm_ref.buffer = buf_ptr as *mut std::ffi::c_void;
shm_ref.imp.id = shm_id;
shm_ref.imp.registered_fd = -1;
shm_ref.imp.shadow_buffer = ptr::null_mut();
shm_ref.imp.alloced_size = size;
shm_ref.imp.flags = SHM_FLAG_BUFFER_ALLOCED;
SHMS.insert(
shm_id,
ShmEntry {
context_id,
ptr: buf_ptr,
size,
},
);
Ok(())
}
pub fn release(shm: *mut raw::TEEC_SharedMemory) {
if let Ok(mut shm_nn) = safe_ptr::deref_mut(shm) {
let shm_ref = unsafe { shm_nn.as_mut() };
if shm_ref.imp.id < 0 {
return;
}
let id = shm_ref.imp.id;
shm_ref.imp.id = -1;
shm_ref.size = 0;
shm_ref.flags = 0;
shm_ref.buffer = ptr::null_mut();
SHMS.remove(&id);
}
}
pub fn get_buffer(shm: *const raw::TEEC_SharedMemory) -> Option<Vec<u8>> {
if shm.is_null() {
return None;
}
let id = unsafe { (*shm).imp.id };
SHMS.get(&id).map(|entry| {
let slice =
unsafe { std::slice::from_raw_parts(entry.value().ptr, entry.value().size) };
slice.to_vec()
})
}
pub fn release_by_context(ctx: *mut raw::TEEC_Context) {
if let Ok(ctx_nn) = safe_ptr::deref(ctx) {
let ctx_ref = unsafe { ctx_nn.as_ref() };
let context_id = ctx_ref.imp.fd;
SHMS.retain(|_, entry| entry.context_id != context_id);
}
}
}
#[cfg(test)]
mod shared_memory_tests {
    use super::*;

    /// Builds a context whose `imp.fd` (the value used as context id) is `id`.
    ///
    /// Every test uses its own id because tests run in parallel against the same
    /// global registry, and `release_by_context` matches on this value.
    fn create_test_context(id: i32) -> raw::TEEC_Context {
        raw::TEEC_Context {
            imp: raw::TEEC_Context__Imp {
                fd: id,
                memref_null: false,
                reg_mem: false,
            },
        }
    }

    /// Builds an unallocated descriptor: null buffer, zero size/flags, `imp.id == -1`.
    fn create_test_shm() -> raw::TEEC_SharedMemory {
        raw::TEEC_SharedMemory {
            buffer: ptr::null_mut(),
            size: 0,
            flags: 0,
            imp: raw::TEEC_SharedMemory__Imp {
                id: -1,
                registered_fd: -1,
                shadow_buffer: ptr::null_mut(),
                alloced_size: 0,
                flags: 0,
            },
        }
    }

    #[test]
    fn test_allocate_valid_memory() {
        let mut ctx = create_test_context(100);
        let mut shm = create_test_shm();
        shm.size = 64;
        shm.flags = raw::TEEC_MEM_INPUT;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_ok(), "应该成功分配共享内存");
        assert!(!shm.buffer.is_null(), "缓冲区指针应该非空");
        assert!(shm.imp.id >= 0, "ID 应该为非负的唯一共享内存 id");
        assert_eq!(shm.imp.alloced_size, 64, "分配大小应该正确");
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
    }

    #[test]
    fn test_allocate_with_registration() {
        let mut ctx = create_test_context(101);
        let mut data = vec![1u8, 2, 3, 4, 5];
        let mut shm = create_test_shm();
        shm.size = data.len();
        shm.flags = raw::TEEC_MEM_INPUT;
        shm.buffer = data.as_mut_ptr() as *mut std::ffi::c_void;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            true,
        );
        assert!(result.is_ok(), "应该成功注册共享内存");
        // Must not silently pass when the buffer is missing from the registry.
        let buffer = SharedMemoryManager::get_buffer(&shm).expect("注册后应该能获取缓冲区");
        assert_eq!(buffer, data, "数据应该被正确复制");
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
    }

    #[test]
    fn test_allocate_zero_size() {
        let mut ctx = create_test_context(102);
        let mut shm = create_test_shm();
        shm.size = 0;
        shm.flags = raw::TEEC_MEM_INPUT;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_err(), "零大小分配应该失败");
    }

    #[test]
    fn test_allocate_invalid_flags() {
        let mut ctx = create_test_context(103);
        let mut shm = create_test_shm();
        shm.size = 64;
        shm.flags = 0;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_err(), "无效标志位应该导致失败");
    }

    #[test]
    fn test_allocate_excessive_size() {
        let mut ctx = create_test_context(104);
        let mut shm = create_test_shm();
        shm.size = MAX_SHARED_MEMORY_SIZE + 1;
        shm.flags = raw::TEEC_MEM_INPUT;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_err(), "过大的分配应该失败");
    }

    #[test]
    fn test_allocate_null_context() {
        let mut shm = create_test_shm();
        shm.size = 64;
        shm.flags = raw::TEEC_MEM_INPUT;
        let result = SharedMemoryManager::allocate(
            ptr::null_mut(),
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_err(), "空上下文应该导致失败");
    }

    #[test]
    fn test_allocate_null_shm() {
        let mut ctx = create_test_context(105);
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            ptr::null_mut(),
            false,
        );
        assert!(result.is_err(), "空共享内存指针应该导致失败");
    }

    #[test]
    fn test_release_valid_memory() {
        let mut ctx = create_test_context(106);
        let mut shm = create_test_shm();
        shm.size = 64;
        shm.flags = raw::TEEC_MEM_INPUT;
        SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        )
        .unwrap();
        assert!(SharedMemoryManager::get_buffer(&shm).is_some());
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
        assert_eq!(shm.imp.id, -1, "ID 应该被重置为 -1");
        assert_eq!(shm.size, 0, "大小应该被重置为 0");
        assert!(shm.buffer.is_null(), "缓冲区指针应该为空");
        assert!(
            SharedMemoryManager::get_buffer(&shm).is_none(),
            "缓存应该被清除"
        );
    }

    #[test]
    fn test_release_null_shm() {
        SharedMemoryManager::release(ptr::null_mut());
    }

    #[test]
    fn test_release_already_released() {
        let mut ctx = create_test_context(107);
        let mut shm = create_test_shm();
        shm.size = 64;
        shm.flags = raw::TEEC_MEM_INPUT;
        SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        )
        .unwrap();
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
        assert_eq!(shm.imp.id, -1);
    }

    #[test]
    fn test_get_buffer_null_shm() {
        let result = SharedMemoryManager::get_buffer(ptr::null());
        assert!(result.is_none(), "空指针应该返回 None");
    }

    #[test]
    fn test_get_buffer_unregistered() {
        let shm = create_test_shm();
        let result = SharedMemoryManager::get_buffer(&shm);
        assert!(result.is_none(), "未注册的共享内存应该返回 None");
    }

    #[test]
    fn test_multiple_allocations() {
        let mut contexts = vec![];
        let mut shms = vec![];
        for i in 200..205 {
            let mut ctx = create_test_context(i);
            let mut shm = create_test_shm();
            shm.size = 32;
            shm.flags = raw::TEEC_MEM_INPUT;
            SharedMemoryManager::allocate(
                &mut ctx as *mut raw::TEEC_Context,
                &mut shm as *mut raw::TEEC_SharedMemory,
                false,
            )
            .unwrap();
            contexts.push(ctx);
            shms.push(shm);
        }
        for shm in &shms {
            assert!(SharedMemoryManager::get_buffer(shm).is_some());
        }
        for shm in &mut shms {
            SharedMemoryManager::release(shm as *mut raw::TEEC_SharedMemory);
        }
        for shm in &shms {
            assert!(SharedMemoryManager::get_buffer(shm).is_none());
        }
    }

    #[test]
    fn test_allocate_both_flags() {
        let mut ctx = create_test_context(300);
        let mut shm = create_test_shm();
        shm.size = 64;
        shm.flags = raw::TEEC_MEM_INPUT | raw::TEEC_MEM_OUTPUT;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_ok(), "同时设置 INPUT 和 OUTPUT 应该成功");
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
    }

    #[test]
    fn test_allocate_exactly_max_size() {
        let mut ctx = create_test_context(301);
        let mut shm = create_test_shm();
        shm.size = MAX_SHARED_MEMORY_SIZE;
        shm.flags = raw::TEEC_MEM_INPUT;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_ok(), "正好等于最大大小应该成功");
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
    }

    #[test]
    fn test_allocate_one_byte_over_max() {
        let mut ctx = create_test_context(302);
        let mut shm = create_test_shm();
        shm.size = MAX_SHARED_MEMORY_SIZE + 1;
        shm.flags = raw::TEEC_MEM_INPUT;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_err(), "超过最大大小应该失败");
    }

    #[test]
    fn test_allocate_invalid_flag_combinations() {
        let mut ctx = create_test_context(303);
        let mut shm = create_test_shm();
        shm.size = 64;
        shm.flags = raw::TEEC_MEM_OUTPUT;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_ok(), "仅设置 OUTPUT 应该成功");
        // Release so the successful allocation is not leaked in the global registry.
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
        // `release` zeroes `size`; restore it so that only the flags are invalid.
        shm.size = 64;
        shm.flags = 0x12345678;
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        );
        assert!(result.is_err(), "无效标志位应该失败");
    }

    #[test]
    fn test_release_with_different_states() {
        let mut shm1 = create_test_shm();
        SharedMemoryManager::release(&mut shm1 as *mut raw::TEEC_SharedMemory);
        assert_eq!(shm1.imp.id, -1);
        let mut ctx = create_test_context(304);
        let mut shm2 = create_test_shm();
        shm2.size = 32;
        shm2.flags = raw::TEEC_MEM_INPUT;
        SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm2 as *mut raw::TEEC_SharedMemory,
            false,
        )
        .unwrap();
        SharedMemoryManager::release(&mut shm2 as *mut raw::TEEC_SharedMemory);
        assert_eq!(shm2.imp.id, -1);
        assert_eq!(shm2.size, 0);
    }

    #[test]
    fn test_get_buffer_edge_cases() {
        // Obtain an id that is guaranteed not to be live by allocating and releasing it.
        // A hard-coded id such as 0 may belong to a test running in parallel.
        let mut ctx = create_test_context(306);
        let mut shm = create_test_shm();
        shm.size = 8;
        shm.flags = raw::TEEC_MEM_INPUT;
        SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            false,
        )
        .unwrap();
        let stale_id = shm.imp.id;
        SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);

        let mut probe = create_test_shm();
        probe.imp.id = stale_id;
        let result = SharedMemoryManager::get_buffer(&probe);
        assert!(result.is_none());
        probe.imp.id = -2;
        let result = SharedMemoryManager::get_buffer(&probe);
        assert!(result.is_none());
    }

    #[test]
    fn test_allocate_with_registration_null_buffer() {
        let mut ctx = create_test_context(305);
        let mut shm = create_test_shm();
        shm.size = 64;
        shm.flags = raw::TEEC_MEM_INPUT;
        shm.buffer = ptr::null_mut();
        let result = SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm as *mut raw::TEEC_SharedMemory,
            true,
        );
        assert!(result.is_err(), "注册空缓冲区应该失败");
        assert_eq!(shm.imp.id, -1, "失败时不应分配 id");
    }

    #[test]
    fn test_stress_multiple_rapid_allocations() {
        let mut ctx = create_test_context(400);
        for i in 0..100 {
            let mut shm = create_test_shm();
            shm.size = 16;
            shm.flags = raw::TEEC_MEM_INPUT;
            let result = SharedMemoryManager::allocate(
                &mut ctx as *mut raw::TEEC_Context,
                &mut shm as *mut raw::TEEC_SharedMemory,
                false,
            );
            assert!(result.is_ok(), "第 {} 次分配应该成功", i);
            SharedMemoryManager::release(&mut shm as *mut raw::TEEC_SharedMemory);
        }
    }

    #[test]
    fn test_multiple_allocations_same_context() {
        let mut ctx = create_test_context(500);
        let mut shm1 = create_test_shm();
        shm1.size = 32;
        shm1.flags = raw::TEEC_MEM_INPUT;
        let mut shm2 = create_test_shm();
        shm2.size = 64;
        shm2.flags = raw::TEEC_MEM_INPUT;
        SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm1 as *mut raw::TEEC_SharedMemory,
            false,
        )
        .unwrap();
        SharedMemoryManager::allocate(
            &mut ctx as *mut raw::TEEC_Context,
            &mut shm2 as *mut raw::TEEC_SharedMemory,
            false,
        )
        .unwrap();
        assert_ne!(
            shm1.imp.id, shm2.imp.id,
            "同一 context 下多次分配应该有不同的 shm id"
        );
        assert!(SharedMemoryManager::get_buffer(&shm1).is_some());
        assert!(SharedMemoryManager::get_buffer(&shm2).is_some());
        SharedMemoryManager::release(&mut shm1 as *mut raw::TEEC_SharedMemory);
        assert!(SharedMemoryManager::get_buffer(&shm1).is_none());
        assert!(
            SharedMemoryManager::get_buffer(&shm2).is_some(),
            "释放 shm1 不应影响 shm2"
        );
        SharedMemoryManager::release(&mut shm2 as *mut raw::TEEC_SharedMemory);
        assert!(SharedMemoryManager::get_buffer(&shm2).is_none());
    }

    #[test]
    fn test_release_by_context() {
        let mut ctx = create_test_context(600);
        let mut other_ctx = create_test_context(601);
        let mut shm1 = create_test_shm();
        shm1.size = 16;
        shm1.flags = raw::TEEC_MEM_INPUT;
        let mut shm2 = create_test_shm();
        shm2.size = 16;
        shm2.flags = raw::TEEC_MEM_INPUT;
        let mut other = create_test_shm();
        other.size = 16;
        other.flags = raw::TEEC_MEM_INPUT;
        for shm in [&mut shm1, &mut shm2] {
            SharedMemoryManager::allocate(
                &mut ctx as *mut raw::TEEC_Context,
                shm as *mut raw::TEEC_SharedMemory,
                false,
            )
            .unwrap();
        }
        SharedMemoryManager::allocate(
            &mut other_ctx as *mut raw::TEEC_Context,
            &mut other as *mut raw::TEEC_SharedMemory,
            false,
        )
        .unwrap();

        SharedMemoryManager::release_by_context(&mut ctx as *mut raw::TEEC_Context);
        assert!(
            SharedMemoryManager::get_buffer(&shm1).is_none(),
            "该 context 的共享内存应该被释放"
        );
        assert!(SharedMemoryManager::get_buffer(&shm2).is_none());
        assert!(
            SharedMemoryManager::get_buffer(&other).is_some(),
            "其他 context 的共享内存不应受影响"
        );

        SharedMemoryManager::release(&mut other as *mut raw::TEEC_SharedMemory);
        // The descriptors still carry their (now unregistered) ids; a later release
        // must be harmless.
        SharedMemoryManager::release(&mut shm1 as *mut raw::TEEC_SharedMemory);
        SharedMemoryManager::release(&mut shm2 as *mut raw::TEEC_SharedMemory);
        assert_eq!(shm1.imp.id, -1);
        assert_eq!(shm2.imp.id, -1);
    }

    #[test]
    fn test_release_by_context_null() {
        SharedMemoryManager::release_by_context(ptr::null_mut());
    }
}