use std::cell::UnsafeCell;
use std::marker::PhantomData;
#[cfg_attr(php_zts, allow(unused_imports))]
use std::mem::MaybeUninit;
/// Per-module state stored as PHP module globals.
///
/// The GINIT callback builds a fresh value with [`Default`] and then calls
/// [`ModuleGlobal::ginit`] on it. The GSHUTDOWN callback calls
/// [`ModuleGlobal::gshutdown`] right before the value is dropped. Both hooks do
/// nothing unless overridden.
pub trait ModuleGlobal: 'static + Default {
    /// Called on a freshly default-constructed value.
    fn ginit(&mut self) {}

    /// Called just before the value is dropped.
    fn gshutdown(&mut self) {}
}
// Foreign shim used only in thread-safe (ZTS) builds. Every caller in this file
// is itself `cfg(php_zts)`, so the whole block can be gated.
#[cfg(php_zts)]
unsafe extern "C" {
    /// Presumably resolves TSRM resource `id` to the calling thread's globals
    /// storage (callers cast the result to `*mut T`). TODO confirm against the
    /// C wrapper.
    fn ext_php_rs_tsrmg_bulk(id: i32) -> *mut std::ffi::c_void;
}
/// Storage for a PHP extension's module globals of type `T`.
///
/// With `php_zts` enabled, this keeps only a TSRM resource id and looks the
/// value up through `ext_php_rs_tsrmg_bulk` on each access. Without it, the
/// value is stored inline. In both cases the storage starts out uninitialised
/// and must be filled by the GINIT callback before any accessor is used.
pub struct ModuleGlobals<T: ModuleGlobal> {
    /// TSRM resource id; `0` until the globals have been registered.
    #[cfg(php_zts)]
    id: UnsafeCell<i32>,
    /// Inline storage, filled through `data_ptr` by the GINIT callback.
    #[cfg(not(php_zts))]
    inner: UnsafeCell<MaybeUninit<T>>,
    _marker: PhantomData<T>,
}

// SAFETY: this type does no locking of its own. In ZTS builds each access goes
// through TSRM, which presumably gives every thread its own instance. Non-ZTS
// PHP is assumed to drive the engine from a single thread. NOTE(review): `T` is
// not required to be `Send`/`Sync`, so soundness rests on those PHP invariants.
unsafe impl<T: ModuleGlobal> Sync for ModuleGlobals<T> {}

impl<T: ModuleGlobal> Default for ModuleGlobals<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: ModuleGlobal> ModuleGlobals<T> {
    /// Creates unregistered, uninitialised storage.
    ///
    /// This is a `const fn`, so it can be used to initialise a `static`.
    #[must_use]
    pub const fn new() -> Self {
        Self {
            #[cfg(php_zts)]
            id: UnsafeCell::new(0),
            #[cfg(not(php_zts))]
            inner: UnsafeCell::new(MaybeUninit::uninit()),
            _marker: PhantomData,
        }
    }

    /// Returns a shared reference to the globals.
    ///
    /// The GINIT callback must already have initialised the storage. Without
    /// `php_zts`, an earlier call reads uninitialised memory, which is undefined
    /// behaviour. With `php_zts`, only a missing registration (id `0`) is caught,
    /// and only by a debug assertion.
    ///
    /// # Panics
    ///
    /// With `php_zts` in debug builds, panics if the globals were never registered.
    pub fn get(&self) -> &T {
        // SAFETY: once GINIT has run, `as_ptr` points at initialised storage.
        // That is the documented precondition of this accessor.
        unsafe { &*self.as_ptr() }
    }

    /// Returns a mutable reference to the globals.
    ///
    /// # Safety
    ///
    /// The GINIT callback must already have initialised the storage. The caller
    /// must also guarantee that no other reference to the same globals is alive
    /// while the returned reference is in use. This includes references from
    /// `get`, `get_mut`, or a dereferenced `as_ptr`.
    ///
    /// # Panics
    ///
    /// With `php_zts` in debug builds, panics if the globals were never registered.
    // Handing out `&mut T` from `&self` is deliberate. The globals typically live
    // in a shared `static`, so exclusivity is part of the caller's safety contract
    // rather than being enforced by `&mut self`.
    #[allow(clippy::mut_from_ref)]
    pub unsafe fn get_mut(&self) -> &mut T {
        // SAFETY: the caller guarantees initialisation and exclusivity.
        unsafe { &mut *self.as_ptr() }
    }

    /// Returns a raw pointer to the globals without creating any reference.
    ///
    /// The pointer may only be dereferenced after the GINIT callback has run.
    ///
    /// # Panics
    ///
    /// With `php_zts` in debug builds, panics if the globals were never registered.
    pub fn as_ptr(&self) -> *mut T {
        #[cfg(php_zts)]
        {
            // SAFETY: plain read of the id cell. Presumably it is only written
            // (through `id_ptr`) during module registration.
            let id = unsafe { *self.id.get() };
            debug_assert!(id != 0, "ModuleGlobals accessed before registration");
            // SAFETY: `id` is a registered TSRM resource id (checked above in
            // debug builds).
            unsafe { ext_php_rs_tsrmg_bulk(id).cast::<T>() }
        }
        #[cfg(not(php_zts))]
        {
            // `MaybeUninit<T>` has the same layout as `T`, so this cast is valid.
            // Staying with raw pointers avoids creating a `&mut MaybeUninit<T>`,
            // which would alias any `&T` previously returned by `get`.
            self.inner.get().cast::<T>()
        }
    }

    /// Pointer to the TSRM id slot, presumably filled in by the registration code.
    #[cfg(php_zts)]
    pub(crate) fn id_ptr(&self) -> *mut i32 {
        self.id.get()
    }

    /// Pointer to the inline storage, passed to the GINIT/GSHUTDOWN callbacks.
    #[cfg(not(php_zts))]
    pub(crate) fn data_ptr(&self) -> *mut std::ffi::c_void {
        self.inner.get().cast()
    }
}
/// GINIT hook. Constructs `T::default()` in the supplied storage, then runs
/// [`ModuleGlobal::ginit`] on it.
///
/// # Safety
///
/// `globals` must be non-null, aligned for `T`, and valid for writes of a `T`.
/// Any value already stored there is overwritten without being dropped.
pub(crate) unsafe extern "C" fn ginit_callback<T: ModuleGlobal>(globals: *mut std::ffi::c_void) {
    let slot: *mut T = globals.cast();
    // SAFETY: the caller guarantees `slot` is valid, aligned, writable storage
    // for a `T`. After the write it holds an initialised value we may borrow.
    let fresh: &mut T = unsafe {
        slot.write(T::default());
        &mut *slot
    };
    fresh.ginit();
}
/// GSHUTDOWN hook. Runs [`ModuleGlobal::gshutdown`] on the stored value, then
/// drops it in place.
///
/// # Safety
///
/// `globals` must point to a live, properly aligned `T` that nothing else is
/// borrowing, for example one set up by `ginit_callback`. The value must not be
/// used or dropped again afterwards.
pub(crate) unsafe extern "C" fn gshutdown_callback<T: ModuleGlobal>(
    globals: *mut std::ffi::c_void,
) {
    let slot: *mut T = globals.cast();
    // SAFETY: the caller guarantees `slot` refers to a live, unaliased `T`.
    // It is dropped exactly once here and never touched again.
    unsafe {
        T::gshutdown(&mut *slot);
        slot.drop_in_place();
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    #[derive(Default)]
    struct TestGlobals {
        value: i32,
        initialized: bool,
    }

    impl ModuleGlobal for TestGlobals {
        fn ginit(&mut self) {
            self.initialized = true;
            self.value = 42;
        }
        fn gshutdown(&mut self) {
            self.initialized = false;
        }
    }

    // Event log for `TrackedGlobals`. Each hook records the tick at which it
    // ran, so the test can check ordering. Only `gshutdown_callback_cleans_up`
    // uses this type, so parallel tests cannot disturb the counters. The hooks
    // record instead of asserting, because a panic inside the `extern "C"`
    // callbacks would abort the process rather than fail the test.
    static EVENT_CLOCK: AtomicUsize = AtomicUsize::new(0);
    static SHUTDOWN_AT: AtomicUsize = AtomicUsize::new(usize::MAX);
    static DROPPED_AT: AtomicUsize = AtomicUsize::new(usize::MAX);

    /// Globals that record when the shutdown hook and `Drop` run.
    #[derive(Default)]
    struct TrackedGlobals;

    impl ModuleGlobal for TrackedGlobals {
        fn gshutdown(&mut self) {
            SHUTDOWN_AT.store(EVENT_CLOCK.fetch_add(1, Ordering::SeqCst), Ordering::SeqCst);
        }
    }

    impl Drop for TrackedGlobals {
        fn drop(&mut self) {
            DROPPED_AT.store(EVENT_CLOCK.fetch_add(1, Ordering::SeqCst), Ordering::SeqCst);
        }
    }

    #[test]
    fn new_is_const() {
        static _G: ModuleGlobals<TestGlobals> = ModuleGlobals::new();
    }

    #[test]
    fn ginit_callback_initializes() {
        let mut storage = MaybeUninit::<TestGlobals>::uninit();
        unsafe {
            ginit_callback::<TestGlobals>(storage.as_mut_ptr().cast());
            let globals = storage.assume_init_ref();
            assert!(globals.initialized);
            assert_eq!(globals.value, 42);
            std::ptr::drop_in_place(storage.as_mut_ptr());
        }
    }

    #[test]
    fn gshutdown_callback_cleans_up() {
        let mut storage = MaybeUninit::<TrackedGlobals>::uninit();
        unsafe {
            ginit_callback::<TrackedGlobals>(storage.as_mut_ptr().cast());
            gshutdown_callback::<TrackedGlobals>(storage.as_mut_ptr().cast());
        }
        // The shutdown hook runs first, then the value is dropped exactly once.
        assert_eq!(SHUTDOWN_AT.load(Ordering::SeqCst), 0);
        assert_eq!(DROPPED_AT.load(Ordering::SeqCst), 1);
        assert_eq!(EVENT_CLOCK.load(Ordering::SeqCst), 2);
    }

    #[test]
    #[cfg(not(php_zts))]
    fn non_zts_get_after_init() {
        let globals: ModuleGlobals<TestGlobals> = ModuleGlobals::new();
        unsafe {
            ginit_callback::<TestGlobals>(globals.data_ptr());
        }
        assert!(globals.get().initialized);
        assert_eq!(globals.get().value, 42);
        unsafe {
            gshutdown_callback::<TestGlobals>(globals.data_ptr());
        }
    }

    #[test]
    #[cfg(not(php_zts))]
    fn non_zts_get_mut() {
        let globals: ModuleGlobals<TestGlobals> = ModuleGlobals::new();
        unsafe {
            ginit_callback::<TestGlobals>(globals.data_ptr());
            globals.get_mut().value = 99;
        }
        assert_eq!(globals.get().value, 99);
        unsafe {
            gshutdown_callback::<TestGlobals>(globals.data_ptr());
        }
    }

    #[test]
    #[cfg(not(php_zts))]
    fn non_zts_as_ptr() {
        let globals: ModuleGlobals<TestGlobals> = ModuleGlobals::new();
        unsafe {
            ginit_callback::<TestGlobals>(globals.data_ptr());
        }
        let ptr = globals.as_ptr();
        assert_eq!(unsafe { (*ptr).value }, 42);
        unsafe {
            gshutdown_callback::<TestGlobals>(globals.data_ptr());
        }
    }

    #[test]
    #[cfg(not(php_zts))]
    fn non_zts_as_ptr_aliases_live_get() {
        let globals: ModuleGlobals<TestGlobals> = ModuleGlobals::new();
        unsafe {
            ginit_callback::<TestGlobals>(globals.data_ptr());
        }
        // Calling `as_ptr` while a shared borrow is alive must be allowed, and
        // both must address the same value.
        let shared = globals.get();
        assert!(std::ptr::eq(shared, globals.as_ptr()));
        assert_eq!(shared.value, 42);
        unsafe {
            gshutdown_callback::<TestGlobals>(globals.data_ptr());
        }
    }

    #[derive(Default)]
    struct ZstGlobals;
    impl ModuleGlobal for ZstGlobals {}

    #[test]
    fn zst_size_is_zero() {
        assert_eq!(std::mem::size_of::<ZstGlobals>(), 0);
    }

    #[test]
    fn zst_callbacks_round_trip() {
        // Zero-sized writes and references only need a non-null, aligned pointer.
        let ptr = std::ptr::NonNull::<ZstGlobals>::dangling().as_ptr();
        unsafe {
            ginit_callback::<ZstGlobals>(ptr.cast());
            gshutdown_callback::<ZstGlobals>(ptr.cast());
        }
    }
}