use std::sync::{
    Arc, LazyLock, Mutex,
    atomic::{AtomicI32, Ordering},
};

use dashmap::DashMap;
use dashmap::mapref::entry::Entry;

use super::safe_ptr;
use crate::{Error, ErrorKind, ErrorOrigin, Result, cc_client::client::CcClient, raw};
type ContextMap = DashMap<i32, Arc<Mutex<CcClient>>>;
static CONTEXTS: LazyLock<ContextMap> = LazyLock::new(DashMap::new);
pub struct ContextManager;
impl ContextManager {
pub fn add_context(ctx: *mut raw::TEEC_Context, client: CcClient) -> Result<()> {
let mut ctx_nn = safe_ptr::deref_mut(ctx)?;
let ctx_ref = unsafe { ctx_nn.as_mut() };
let ctx_id: i32 = ctx_ref.imp.fd;
let client_arc = Arc::new(Mutex::new(client));
if CONTEXTS.contains_key(&ctx_id) {
log::warn!(
"ContextManager: context {} already exists, rejecting duplicate",
ctx_id
);
return Err(Error::new(ErrorKind::BadState).with_origin(ErrorOrigin::API));
}
CONTEXTS.insert(ctx_id, client_arc);
Ok(())
}
pub fn remove_context(ctx: *mut raw::TEEC_Context) {
if let Ok(mut ctx_nn) = safe_ptr::deref_mut(ctx) {
let ctx_ref = unsafe { ctx_nn.as_mut() };
let ctx_id = ctx_ref.imp.fd;
if let Some((_, client_arc)) = CONTEXTS.remove(&ctx_id)
&& let Ok(mut client) = client_arc.lock()
{
client.close();
}
ctx_ref.imp.fd = -1;
}
}
pub fn get_client(ctx: *mut raw::TEEC_Context) -> Result<Arc<Mutex<CcClient>>> {
let ctx_nn = safe_ptr::deref_mut(ctx)?;
let ctx_ref = unsafe { ctx_nn.as_ref() };
let ctx_id = ctx_ref.imp.fd;
CONTEXTS
.get(&ctx_id)
.map(|entry| entry.value().clone())
.ok_or_else(|| Error::new(ErrorKind::Generic).with_origin(ErrorOrigin::API))
}
}
/// Source of context ids handed out by `initialize_context_impl`.
/// `AtomicI32::new` is a `const fn`, so no lazy initialization is needed.
static CONTEXT_ID_COUNTER: AtomicI32 = AtomicI32::new(0);
/// Creates a `CcClient` for `ctx` and assigns the context a fresh id (stored in
/// `imp.fd`). It then sets `imp.reg_mem` and `imp.memref_null` and registers the
/// client with `ContextManager`.
///
/// # Errors
///
/// - The error from `safe_ptr::deref_mut` if `ctx` is rejected (e.g. null);
///   `ctx` is not modified.
/// - `ErrorKind::Communication` if `CcClient::init` fails; `ctx` is not modified.
/// - The error from `ContextManager::add_context` if the id is already
///   registered; `imp.fd` is then reset to `-1` so `ctx` does not alias the
///   context that owns that id.
pub(crate) fn initialize_context_impl(ctx: *mut raw::TEEC_Context) -> Result<()> {
    let mut ctx_nn = safe_ptr::deref_mut(ctx)?;
    let client = CcClient::init().map_err(|e| {
        log::warn!("TEEC_InitializeContext:初始化机密通信上下文失败:{e}");
        Error::new(ErrorKind::Communication)
    })?;
    // `fetch_add` wraps on overflow. Masking off the sign bit keeps ids in
    // `0..=i32::MAX`, so a wrapped counter can never produce `-1`, the value
    // `ContextManager::remove_context` writes to mark a context as closed.
    let id = CONTEXT_ID_COUNTER.fetch_add(1, Ordering::SeqCst) & i32::MAX;
    {
        // SAFETY: `safe_ptr::deref_mut` rejected an invalid (e.g. null) pointer;
        // the caller must guarantee `ctx` points to a live `TEEC_Context`. This
        // borrow ends before `add_context` re-derives its own reference.
        let ctx_ref = unsafe { ctx_nn.as_mut() };
        ctx_ref.imp.fd = id;
        ctx_ref.imp.reg_mem = true;
        ctx_ref.imp.memref_null = true;
    }
    ContextManager::add_context(ctx, client).inspect_err(|_| {
        // Registration failed (id already taken). Do not leave `ctx` carrying an
        // id that belongs to another live context.
        // SAFETY: same pointer as above; the reference `add_context` created has
        // already been dropped.
        let ctx_ref = unsafe { ctx_nn.as_mut() };
        ctx_ref.imp.fd = -1;
    })
}
#[cfg(test)]
mod context_tests {
    use super::*;
    use std::ptr;

    /// Builds an unregistered `TEEC_Context` whose `fd` is `id` and whose
    /// flags are cleared.
    fn create_test_context(id: i32) -> raw::TEEC_Context {
        let imp = raw::TEEC_Context__Imp {
            fd: id,
            reg_mem: false,
            memref_null: false,
        };
        raw::TEEC_Context { imp }
    }

    #[test]
    fn test_remove_context_null_ptr() {
        // A null pointer must be ignored without panicking.
        let null_ctx: *mut raw::TEEC_Context = ptr::null_mut();
        ContextManager::remove_context(null_ctx);
    }

    #[test]
    fn test_get_client_null_ptr() {
        let null_ctx: *mut raw::TEEC_Context = ptr::null_mut();
        let outcome = ContextManager::get_client(null_ctx);
        // A null pointer should cause a failure.
        assert!(outcome.is_err(), "空指针应该导致失败");
    }

    #[test]
    fn test_get_client_unregistered() {
        let mut unregistered = create_test_context(999);
        let outcome = ContextManager::get_client(ptr::addr_of_mut!(unregistered));
        // A context that was never registered should cause a failure.
        assert!(outcome.is_err(), "未注册的上下文应该导致失败");
    }

    #[test]
    fn test_context_lifecycle() {
        let mut context = create_test_context(1000);
        assert_eq!(context.imp.fd, 1000);
        ContextManager::remove_context(ptr::addr_of_mut!(context));
        // After removal the id should be reset to -1.
        assert_eq!(context.imp.fd, -1, "移除后 ID 应该被重置为 -1");
    }
}