use std::any::Any;
use std::collections::HashMap;
use std::ffi::{CStr, CString, c_char};
use std::ptr;
use std::sync::{Arc, RwLock};
/// Opaque container handle handed out to C callers.
///
/// Created by `di_container_new` / `di_container_scope` and released with
/// `di_container_free`. It is not `#[repr(C)]`, so foreign code must only ever hold a
/// pointer to it and never inspect its fields.
pub struct DiContainer {
    // Underlying crate container; `di_container_scope` derives child scopes from it.
    inner: crate::Container,
    // Services registered through this FFI layer, keyed by type name. The registration
    // code here stores every value as an `Arc<Vec<u8>>` holding the raw service bytes.
    services: RwLock<HashMap<String, Arc<dyn Any + Send + Sync>>>,
}
/// Owned snapshot of a resolved service, returned to C callers by `di_resolve`.
///
/// It holds its own copy of the service bytes, so it remains valid independently of
/// the container it was resolved from. Release it with `di_service_free`.
pub struct DiService {
    // Name the service was resolved under.
    type_name: String,
    // Copy of the registered service bytes.
    data: Vec<u8>,
}
/// Status codes returned across the C ABI.
///
/// `#[repr(C)]` with explicit discriminants so the numeric values stay stable for
/// foreign callers. Functions in this module that return a non-`Ok` code also record a
/// human-readable description retrievable via `di_error_message`.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiErrorCode {
    /// The operation succeeded.
    Ok = 0,
    /// No service is registered under the requested name.
    NotFound = 1,
    /// A null pointer, non-UTF-8 string, or inconsistent pointer/length was passed.
    InvalidArgument = 2,
    /// A service with the same name is already registered.
    AlreadyRegistered = 3,
    /// Unexpected internal state (e.g. stored data of an unexpected type).
    InternalError = 4,
    /// NOTE(review): not produced by any function visible in this module; presumably
    /// reserved for serialization failures — confirm before relying on it.
    SerializationError = 5,
}
/// Result of `di_resolve`: a status code plus, on success, an owned service.
///
/// `service` is non-null only when `code` is `DiErrorCode::Ok`; the caller then owns
/// it and must release it with `di_service_free`. On every other code it is null.
#[repr(C)]
pub struct DiResult {
    /// Outcome of the resolve call.
    pub code: DiErrorCode,
    /// Heap-allocated service on success, null otherwise.
    pub service: *mut DiService,
}
thread_local! {
    /// Most recent error message recorded on this thread by the FFI functions.
    ///
    /// Read with `di_error_message`, reset with `di_error_clear`. Successful calls do
    /// not clear it, so it may describe an earlier failure.
    static LAST_ERROR: std::cell::RefCell<Option<String>> = const { std::cell::RefCell::new(None) };
}
/// Stores `msg` as this thread's last error, replacing any previous message.
fn set_last_error(msg: impl Into<String>) {
    let message: String = msg.into();
    LAST_ERROR.with(|slot| {
        // The previous message (if any) is returned by `replace` and dropped here.
        slot.replace(Some(message));
    });
}
/// Allocates a new, empty container and transfers ownership to the caller.
///
/// The returned pointer is never null and must eventually be released with
/// `di_container_free`.
#[unsafe(no_mangle)]
pub extern "C" fn di_container_new() -> *mut DiContainer {
    let fresh = DiContainer {
        inner: crate::Container::new(),
        services: RwLock::default(),
    };
    Box::into_raw(Box::new(fresh))
}
/// Destroys a container created by `di_container_new` or `di_container_scope`.
///
/// Passing null is a no-op.
///
/// # Safety
/// `container` must be null or a pointer obtained from this library that has not
/// already been freed; it must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_container_free(container: *mut DiContainer) {
    if container.is_null() {
        return;
    }
    // SAFETY: non-null and, per the contract, produced by `Box::into_raw` and not yet freed.
    let owned = unsafe { Box::from_raw(container) };
    drop(owned);
}
/// Creates a child container derived from `container`.
///
/// The child wraps the inner container's `scope()` and a snapshot copy of the parent's
/// service map. Existing entries are shared via `Arc`, but services registered
/// afterwards in either container are not visible to the other.
///
/// Returns null (and records an error message) if `container` is null. A non-null
/// result must be released with `di_container_free`.
///
/// # Safety
/// `container` must be null or a live pointer from `di_container_new` /
/// `di_container_scope`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_container_scope(container: *mut DiContainer) -> *mut DiContainer {
    // SAFETY: the caller guarantees `container` is null or points to a live container.
    let Some(parent) = (unsafe { container.as_ref() }) else {
        set_last_error("Container pointer is null");
        return ptr::null_mut();
    };
    let scoped_inner = parent.inner.scope();
    let snapshot = parent.services.read().unwrap().clone();
    Box::into_raw(Box::new(DiContainer {
        inner: scoped_inner,
        services: RwLock::new(snapshot),
    }))
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_register_singleton(
container: *mut DiContainer,
type_name: *const c_char,
data: *const u8,
data_len: usize,
) -> DiErrorCode {
if container.is_null() {
set_last_error("Container pointer is null");
return DiErrorCode::InvalidArgument;
}
if type_name.is_null() {
set_last_error("Type name is null");
return DiErrorCode::InvalidArgument;
}
let type_name_str = if let Ok(s) = unsafe { CStr::from_ptr(type_name) }.to_str() {
s.to_string()
} else {
set_last_error("Type name is not valid UTF-8");
return DiErrorCode::InvalidArgument;
};
if data.is_null() && data_len > 0 {
set_last_error("Data pointer is null but length is non-zero");
return DiErrorCode::InvalidArgument;
}
let data_vec = if data_len > 0 {
unsafe { std::slice::from_raw_parts(data, data_len) }.to_vec()
} else {
Vec::new()
};
let container = unsafe { &*container };
{
let services = container.services.read().unwrap();
if services.contains_key(&type_name_str) {
set_last_error(format!("Service '{}' is already registered", type_name_str));
return DiErrorCode::AlreadyRegistered;
}
}
let service_data: Arc<dyn Any + Send + Sync> = Arc::new(data_vec);
container
.services
.write()
.unwrap()
.insert(type_name_str, service_data);
DiErrorCode::Ok
}
/// Registers a NUL-terminated UTF-8 JSON string as the singleton named `type_name`.
///
/// The JSON text is not parsed or validated here. Its bytes, without the terminating
/// NUL, are stored via `di_register_singleton`, which performs the remaining argument
/// checks and defines the other return codes.
///
/// # Safety
/// `json_data` must be null or point to a NUL-terminated C string. The `container` and
/// `type_name` requirements are those of `di_register_singleton`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_register_singleton_json(
    container: *mut DiContainer,
    type_name: *const c_char,
    json_data: *const c_char,
) -> DiErrorCode {
    if json_data.is_null() {
        set_last_error("JSON data is null");
        return DiErrorCode::InvalidArgument;
    }
    // SAFETY: non-null; the caller guarantees NUL termination.
    let Ok(json) = unsafe { CStr::from_ptr(json_data) }.to_str() else {
        set_last_error("JSON data is not valid UTF-8");
        return DiErrorCode::InvalidArgument;
    };
    // SAFETY: `json` borrows caller memory that stays valid for this call, and the
    // pointer/length pair describes exactly its bytes.
    unsafe { di_register_singleton(container, type_name, json.as_ptr(), json.len()) }
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_resolve(
container: *mut DiContainer,
type_name: *const c_char,
) -> DiResult {
if container.is_null() {
set_last_error("Container pointer is null");
return DiResult {
code: DiErrorCode::InvalidArgument,
service: ptr::null_mut(),
};
}
if type_name.is_null() {
set_last_error("Type name is null");
return DiResult {
code: DiErrorCode::InvalidArgument,
service: ptr::null_mut(),
};
}
let type_name_str = match unsafe { CStr::from_ptr(type_name) }.to_str() {
Ok(s) => s.to_string(),
Err(_) => {
set_last_error("Type name is not valid UTF-8");
return DiResult {
code: DiErrorCode::InvalidArgument,
service: ptr::null_mut(),
};
}
};
let container = unsafe { &*container };
let services = container.services.read().unwrap();
match services.get(&type_name_str) {
Some(service_arc) => {
if let Some(data) = service_arc.downcast_ref::<Vec<u8>>() {
let service = Box::new(DiService {
type_name: type_name_str,
data: data.clone(),
});
DiResult {
code: DiErrorCode::Ok,
service: Box::into_raw(service),
}
} else {
set_last_error("Internal error: service data type mismatch");
DiResult {
code: DiErrorCode::InternalError,
service: ptr::null_mut(),
}
}
}
None => {
set_last_error(format!("Service '{}' not found", type_name_str));
DiResult {
code: DiErrorCode::NotFound,
service: ptr::null_mut(),
}
}
}
}
/// Resolves `type_name` and returns its stored bytes as a newly allocated C string.
///
/// Intended for services registered with `di_register_singleton_json`. The result is
/// owned by the caller and must be released with `di_string_free`. Returns null (with a
/// message for `di_error_message`) if an argument is invalid, the service is missing,
/// its bytes are not UTF-8, or they contain an interior NUL.
///
/// # Safety
/// `container` must be null or a live container pointer from this library, and
/// `type_name` must be null or point to a NUL-terminated C string.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_resolve_json(
    container: *mut DiContainer,
    type_name: *const c_char,
) -> *mut c_char {
    // Records `msg` as this thread's last error and yields the null failure value.
    fn fail(msg: impl Into<String>) -> *mut c_char {
        set_last_error(msg);
        ptr::null_mut()
    }

    if container.is_null() {
        return fail("Container pointer is null");
    }
    if type_name.is_null() {
        return fail("Type name is null");
    }
    // SAFETY: `type_name` is non-null and the caller guarantees NUL termination.
    let Ok(name) = unsafe { CStr::from_ptr(type_name) }.to_str() else {
        return fail("Type name is not valid UTF-8");
    };
    // SAFETY: non-null (checked above); the caller guarantees it is a live container.
    let registry = unsafe { &*container };
    let services = registry.services.read().unwrap();
    let Some(entry) = services.get(name) else {
        return fail(format!("Service '{}' not found", name));
    };
    let Some(bytes) = entry.downcast_ref::<Vec<u8>>() else {
        return fail("Internal error: service data type mismatch");
    };
    let Ok(text) = std::str::from_utf8(bytes) else {
        return fail("Service data is not valid UTF-8");
    };
    match CString::new(text) {
        Ok(owned) => owned.into_raw(),
        Err(_) => fail("JSON string contains null bytes"),
    }
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_contains(container: *mut DiContainer, type_name: *const c_char) -> i32 {
if container.is_null() || type_name.is_null() {
return -1;
}
let type_name_str = match unsafe { CStr::from_ptr(type_name) }.to_str() {
Ok(s) => s,
Err(_) => return -1,
};
let container = unsafe { &*container };
let services = container.services.read().unwrap();
if services.contains_key(type_name_str) {
1
} else {
0
}
}
/// Returns a pointer to the service's bytes, or null if `service` is null.
///
/// The pointer is borrowed: it is valid for `di_service_data_len(service)` bytes until
/// `service` is freed. For an empty service it is non-null but must not be read.
///
/// # Safety
/// `service` must be null or a live pointer returned in a `DiResult` by `di_resolve`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_service_data(service: *const DiService) -> *const u8 {
    // SAFETY: the caller guarantees `service` is null or points to a live `DiService`.
    match unsafe { service.as_ref() } {
        Some(svc) => svc.data.as_ptr(),
        None => ptr::null(),
    }
}
/// Returns the number of bytes held by `service`, or 0 if `service` is null.
///
/// # Safety
/// `service` must be null or a live pointer returned in a `DiResult` by `di_resolve`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_service_data_len(service: *const DiService) -> usize {
    // SAFETY: the caller guarantees `service` is null or points to a live `DiService`.
    let maybe_service = unsafe { service.as_ref() };
    maybe_service.map_or(0, |svc| svc.data.len())
}
/// Returns a newly allocated, NUL-terminated copy of the service's type name.
///
/// The returned string is owned by the caller and must be released with
/// `di_string_free`. It is typed `*mut` to match that function and the other
/// owned-string returns in this module; it still coerces to `*const c_char` for
/// existing callers. It stays valid after `service` is freed. Returns null if `service`
/// is null or the name contains an interior NUL. Names created by `di_resolve` come
/// from a `CStr`, so they never do.
///
/// # Safety
/// `service` must be null or a live pointer returned in a `DiResult` by `di_resolve`.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_service_type_name(service: *const DiService) -> *mut c_char {
    // SAFETY: the caller guarantees `service` is null or points to a live `DiService`.
    let Some(service) = (unsafe { service.as_ref() }) else {
        return ptr::null_mut();
    };
    CString::new(service.type_name.as_str()).map_or(ptr::null_mut(), CString::into_raw)
}
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_service_free(service: *mut DiService) {
if !service.is_null() {
drop(unsafe { Box::from_raw(service) });
}
}
/// Returns a newly allocated copy of this thread's last error message, or null if none.
///
/// The message itself is left in place; call `di_error_clear` to reset it. The returned
/// string is owned by the caller and must be released with `di_string_free`. Null is
/// also returned if the stored message contains an interior NUL.
#[unsafe(no_mangle)]
pub extern "C" fn di_error_message() -> *mut c_char {
    LAST_ERROR.with_borrow(|slot| {
        slot.as_deref()
            .and_then(|msg| CString::new(msg).ok())
            .map_or(ptr::null_mut(), CString::into_raw)
    })
}
/// Forgets this thread's last error message, so `di_error_message` returns null.
#[unsafe(no_mangle)]
pub extern "C" fn di_error_clear() {
    LAST_ERROR.set(None);
}
/// Releases a string returned by this library (e.g. from `di_error_message` or
/// `di_resolve_json`). Passing null is a no-op.
///
/// # Safety
/// `s` must be null or a pointer produced by `CString::into_raw` in this library that
/// has not already been freed; it must not be used after this call.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_string_free(s: *mut c_char) {
    if s.is_null() {
        return;
    }
    // SAFETY: non-null and, per the contract, originally produced by `CString::into_raw`.
    let reclaimed = unsafe { CString::from_raw(s) };
    drop(reclaimed);
}
/// Returns the crate version as a static, NUL-terminated C string.
///
/// The pointer refers to static data: it is valid for the whole program and must NOT
/// be passed to `di_string_free`.
#[unsafe(no_mangle)]
pub extern "C" fn di_version() -> *const c_char {
    // The explicit "\0" turns the compile-time version string into a valid C string.
    static VERSION_Z: &str = concat!(env!("CARGO_PKG_VERSION"), "\0");
    VERSION_Z.as_ptr().cast()
}
/// Returns how many services are registered in `container`, or `-1` if it is null.
///
/// # Safety
/// `container` must be null or a live container pointer from this library.
#[unsafe(no_mangle)]
pub unsafe extern "C" fn di_service_count(container: *const DiContainer) -> i64 {
    // SAFETY: the caller guarantees `container` is null or points to a live container.
    let Some(registry) = (unsafe { container.as_ref() }) else {
        return -1;
    };
    let registered = registry.services.read().unwrap().len();
    registered as i64
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reads this thread's last error message (if any), freeing the returned copy.
    fn last_error() -> Option<String> {
        let raw = di_error_message();
        if raw.is_null() {
            return None;
        }
        // SAFETY: `raw` is a non-null C string from `di_error_message`, freed exactly once.
        let msg = unsafe { CStr::from_ptr(raw) }.to_string_lossy().into_owned();
        unsafe { di_string_free(raw) };
        Some(msg)
    }

    #[test]
    fn test_container_lifecycle() {
        unsafe {
            let container = di_container_new();
            assert!(!container.is_null());
            di_container_free(container);
        }
    }

    #[test]
    fn test_register_and_resolve() {
        unsafe {
            let container = di_container_new();
            let type_name = CString::new("TestService").unwrap();
            let data = b"hello world";
            let result =
                di_register_singleton(container, type_name.as_ptr(), data.as_ptr(), data.len());
            assert_eq!(result, DiErrorCode::Ok);
            let resolve_result = di_resolve(container, type_name.as_ptr());
            assert_eq!(resolve_result.code, DiErrorCode::Ok);
            assert!(!resolve_result.service.is_null());
            let service = resolve_result.service;
            assert_eq!(di_service_data_len(service), 11);
            let data_ptr = di_service_data(service);
            let resolved_data = std::slice::from_raw_parts(data_ptr, 11);
            assert_eq!(resolved_data, b"hello world");
            di_service_free(service);
            di_container_free(container);
        }
    }

    #[test]
    fn test_duplicate_registration_is_rejected() {
        unsafe {
            let container = di_container_new();
            let type_name = CString::new("Dup").unwrap();
            let first = b"first";
            let second = b"second";
            assert_eq!(
                di_register_singleton(container, type_name.as_ptr(), first.as_ptr(), first.len()),
                DiErrorCode::Ok
            );
            assert_eq!(
                di_register_singleton(container, type_name.as_ptr(), second.as_ptr(), second.len()),
                DiErrorCode::AlreadyRegistered
            );
            // The original registration must survive the rejected duplicate.
            let resolved = di_resolve(container, type_name.as_ptr());
            assert_eq!(resolved.code, DiErrorCode::Ok);
            let bytes = std::slice::from_raw_parts(
                di_service_data(resolved.service),
                di_service_data_len(resolved.service),
            );
            assert_eq!(bytes, first);
            di_service_free(resolved.service);
            assert_eq!(di_service_count(container), 1);
            di_container_free(container);
        }
    }

    #[test]
    fn test_json_roundtrip() {
        unsafe {
            let container = di_container_new();
            let type_name = CString::new("Config").unwrap();
            let json = CString::new(r#"{"port":8080}"#).unwrap();
            assert_eq!(
                di_register_singleton_json(container, type_name.as_ptr(), json.as_ptr()),
                DiErrorCode::Ok
            );
            let out = di_resolve_json(container, type_name.as_ptr());
            assert!(!out.is_null());
            assert_eq!(CStr::from_ptr(out).to_str().unwrap(), r#"{"port":8080}"#);
            di_string_free(out);
            di_container_free(container);
        }
    }

    #[test]
    fn test_not_found() {
        unsafe {
            let container = di_container_new();
            let type_name = CString::new("NonExistent").unwrap();
            let result = di_resolve(container, type_name.as_ptr());
            assert_eq!(result.code, DiErrorCode::NotFound);
            assert!(result.service.is_null());
            di_container_free(container);
        }
    }

    #[test]
    fn test_error_message_and_clear() {
        unsafe {
            let container = di_container_new();
            let type_name = CString::new("Missing").unwrap();
            di_error_clear();
            assert_eq!(last_error(), None);
            let result = di_resolve(container, type_name.as_ptr());
            assert_eq!(result.code, DiErrorCode::NotFound);
            assert_eq!(last_error().as_deref(), Some("Service 'Missing' not found"));
            di_error_clear();
            assert_eq!(last_error(), None);
            di_container_free(container);
        }
    }

    #[test]
    fn test_null_arguments() {
        unsafe {
            let type_name = CString::new("X").unwrap();
            assert_eq!(
                di_register_singleton(ptr::null_mut(), type_name.as_ptr(), ptr::null(), 0),
                DiErrorCode::InvalidArgument
            );
            assert_eq!(
                di_resolve(ptr::null_mut(), type_name.as_ptr()).code,
                DiErrorCode::InvalidArgument
            );
            assert!(di_container_scope(ptr::null_mut()).is_null());
            assert_eq!(di_contains(ptr::null_mut(), type_name.as_ptr()), -1);
            assert_eq!(di_service_count(ptr::null()), -1);
            assert!(di_service_data(ptr::null()).is_null());
            assert_eq!(di_service_data_len(ptr::null()), 0);
            // Freeing null must be a harmless no-op for every destructor.
            di_container_free(ptr::null_mut());
            di_service_free(ptr::null_mut());
            di_string_free(ptr::null_mut());
        }
    }

    #[test]
    fn test_contains() {
        unsafe {
            let container = di_container_new();
            let type_name = CString::new("TestService").unwrap();
            assert_eq!(di_contains(container, type_name.as_ptr()), 0);
            let data = b"test";
            assert_eq!(
                di_register_singleton(container, type_name.as_ptr(), data.as_ptr(), data.len()),
                DiErrorCode::Ok
            );
            assert_eq!(di_contains(container, type_name.as_ptr()), 1);
            di_container_free(container);
        }
    }

    #[test]
    fn test_scope() {
        unsafe {
            let parent = di_container_new();
            let type_name = CString::new("ParentService").unwrap();
            let data = b"parent";
            assert_eq!(
                di_register_singleton(parent, type_name.as_ptr(), data.as_ptr(), data.len()),
                DiErrorCode::Ok
            );
            let child = di_container_scope(parent);
            assert!(!child.is_null());
            assert_eq!(di_contains(child, type_name.as_ptr()), 1);
            // The child owns a snapshot of the map: its later registrations stay local.
            let child_name = CString::new("ChildService").unwrap();
            let child_data = b"child";
            assert_eq!(
                di_register_singleton(
                    child,
                    child_name.as_ptr(),
                    child_data.as_ptr(),
                    child_data.len()
                ),
                DiErrorCode::Ok
            );
            assert_eq!(di_contains(child, child_name.as_ptr()), 1);
            assert_eq!(di_contains(parent, child_name.as_ptr()), 0);
            di_container_free(child);
            di_container_free(parent);
        }
    }

    #[test]
    fn test_version() {
        // SAFETY: `di_version` returns a static, NUL-terminated string.
        let version = unsafe { CStr::from_ptr(di_version()) };
        assert_eq!(version.to_str().unwrap(), env!("CARGO_PKG_VERSION"));
    }
}