use ocaml_boxroot_sys::boxroot_teardown;
use std::{
cell::UnsafeCell,
marker::PhantomData,
ops::{Deref, DerefMut},
};
use crate::{memory::OCamlRef, value::OCaml};
thread_local! {
    /// Per-thread storage for the zero-sized [`OCamlRuntime`] handle.
    ///
    /// `internal::recover_runtime_handle` and `internal::recover_runtime_handle_mut`
    /// hand out references into this cell.
    static TLS_RUNTIME: UnsafeCell<OCamlRuntime> = const {
        UnsafeCell::new(OCamlRuntime {
            _not_send_sync: PhantomData,
        })
    };
}
/// Guard returned by [`OCamlRuntime::init`] that owns the started OCaml runtime.
///
/// Dropping the guard leaves the blocking section, tears down boxroot and calls
/// `caml_shutdown`, so it must be kept alive for as long as the runtime is needed.
/// It dereferences to this thread's [`OCamlRuntime`] handle. The
/// `PhantomData<*const ()>` marker makes it `!Send` and `!Sync`, so it is always
/// dropped on the thread that created it.
// `#[must_use]` catches `OCamlRuntime::init().unwrap();`, which would otherwise
// start the runtime and shut it down again on the same line.
#[must_use = "dropping this guard immediately shuts down the OCaml runtime"]
pub struct OCamlRuntimeStartupGuard {
    _not_send_sync: PhantomData<*const ()>,
}
impl Deref for OCamlRuntimeStartupGuard {
    type Target = OCamlRuntime;

    /// Shared access to the current thread's runtime handle.
    fn deref(&self) -> &Self::Target {
        // SAFETY: the guard is `!Send`/`!Sync`, so this runs on the thread that owns
        // it; the `'static` handle is narrowed to the lifetime of `&self`.
        unsafe { internal::recover_runtime_handle() }
    }
}
impl DerefMut for OCamlRuntimeStartupGuard {
    /// Exclusive access to the current thread's runtime handle.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: the guard is `!Send`/`!Sync`, so this runs on the thread that owns
        // it; the `'static` handle is narrowed to the lifetime of `&mut self`.
        unsafe { internal::recover_runtime_handle_mut() }
    }
}
/// Handle to the OCaml runtime for the current thread.
///
/// It carries no data. The `PhantomData<*const ()>` marker only makes the type
/// `!Send` and `!Sync`. References to it are obtained through
/// `OCamlRuntimeStartupGuard` and `OCamlDomainLock` (via `Deref`/`DerefMut`) or
/// through `OCamlRuntime::with_domain_lock`.
pub struct OCamlRuntime {
    _not_send_sync: PhantomData<*const ()>,
}
impl OCamlRuntime {
    /// Starts the OCaml runtime and returns a guard that shuts it down on drop.
    ///
    /// Startup steps:
    /// 1. Boots the runtime with `argv = ["ocaml"]`.
    /// 2. Sets up boxroot.
    /// 3. Enters a blocking section on the calling thread.
    ///
    /// [`OCamlRuntime::with_domain_lock`] leaves that section again whenever OCaml
    /// code needs to run.
    ///
    /// # Errors
    ///
    /// - Returns `Err` if `init` was already called in this process. The flag is
    ///   never reset, not even after the guard is dropped.
    /// - With the `no-caml-startup` feature enabled, always returns `Err`.
    pub fn init() -> Result<OCamlRuntimeStartupGuard, String> {
        #[cfg(not(feature = "no-caml-startup"))]
        {
            use std::sync::atomic::{AtomicBool, Ordering};

            // Process-wide latch: only the very first caller gets to start the runtime.
            static INIT_CALLED: AtomicBool = AtomicBool::new(false);

            let already_started = INIT_CALLED.swap(true, Ordering::SeqCst);
            if already_started {
                return Err(String::from("OCaml runtime already initialized"));
            }

            // SAFETY: the latch above guarantees this runs at most once per process.
            // `argv` is a null-terminated array of NUL-terminated C strings that
            // outlives the `caml_startup` call.
            unsafe {
                let argv: [*const ocaml_sys::Char; 2] = [
                    c"ocaml".as_ptr() as *const ocaml_sys::Char,
                    std::ptr::null(),
                ];
                ocaml_sys::caml_startup(argv.as_ptr());
                ocaml_boxroot_sys::boxroot_setup();
                ocaml_sys::caml_enter_blocking_section();
            }

            Ok(OCamlRuntimeStartupGuard {
                _not_send_sync: PhantomData,
            })
        }
        #[cfg(feature = "no-caml-startup")]
        return Err(
            "Rust code called from OCaml should not try to initialize the runtime".to_string(),
        );
    }

    /// Runs `f` inside an OCaml blocking section and returns its result.
    ///
    /// The blocking section is left again when `f` returns or unwinds.
    pub fn releasing_runtime<T, F>(&mut self, f: F) -> T
    where
        F: FnOnce() -> T,
    {
        let section = OCamlBlockingSection::new();
        section.perform(f)
    }

    /// Reads the OCaml value behind `reference`.
    ///
    /// The returned value borrows this runtime handle for `'tmp`.
    pub fn get<'tmp, T>(&'tmp self, reference: OCamlRef<T>) -> OCaml<'tmp, T> {
        // NOTE(review): `get_raw` is unsafe. This relies on holding `&OCamlRuntime`
        // being what makes the read valid; confirm against `OCamlRef::get_raw`.
        let raw = unsafe { reference.get_raw() };
        OCaml {
            _marker: PhantomData,
            raw,
        }
    }

    /// Runs `f` with exclusive access to this thread's runtime handle.
    ///
    /// On first use per thread, this registers the thread with the OCaml runtime.
    /// It then leaves the blocking section before calling `f` and re-enters it
    /// when `f` returns or unwinds.
    pub fn with_domain_lock<F, T>(f: F) -> T
    where
        F: FnOnce(&mut Self) -> T,
    {
        let mut guard = OCamlDomainLock::new();
        f(&mut *guard)
    }
}
impl Drop for OCamlRuntimeStartupGuard {
    /// Shuts the runtime down, undoing the steps of `OCamlRuntime::init` in reverse.
    fn drop(&mut self) {
        // SAFETY: the guard is `!Send`, so this is the thread that ran `init`, which
        // left it inside a blocking section; leave that section first.
        unsafe { ocaml_sys::caml_leave_blocking_section() };
        // SAFETY: boxroot was set up by `init`; release it before the runtime goes away.
        unsafe { boxroot_teardown() };
        // SAFETY: runs last, after boxroot has been torn down.
        unsafe { ocaml_sys::caml_shutdown() };
    }
}
/// RAII guard for an OCaml blocking section.
///
/// Creating the guard (via `new`) enters the blocking section. Dropping it
/// leaves the section again.
#[must_use = "the blocking section is left as soon as the guard is dropped"]
struct OCamlBlockingSection;

impl OCamlBlockingSection {
    /// Enters the blocking section and returns the guard that will leave it.
    ///
    /// The enter happens here rather than in `perform`. That way the `Drop`
    /// impl's `caml_leave_blocking_section` is always balanced, even if a guard
    /// is dropped without `perform` ever being called.
    fn new() -> Self {
        // NOTE(review): assumes the caller currently holds the OCaml runtime.
        // `OCamlRuntime::releasing_runtime` requires `&mut OCamlRuntime`, which
        // presumably implies that.
        unsafe { ocaml_sys::caml_enter_blocking_section() };
        Self
    }

    /// Runs `f` while the blocking section is active and returns its result.
    ///
    /// `self` is dropped when this returns (or unwinds), which leaves the section.
    fn perform<T, F>(self, f: F) -> T
    where
        F: FnOnce() -> T,
    {
        f()
    }
}
impl Drop for OCamlBlockingSection {
    /// Leaves the blocking section, giving the calling thread the runtime back.
    fn drop(&mut self) {
        // SAFETY: balances the `caml_enter_blocking_section` issued for this guard.
        unsafe {
            ocaml_sys::caml_leave_blocking_section();
        }
    }
}
/// Guard that gives the current thread access to the OCaml runtime.
///
/// `new` leaves the blocking section and `Drop` enters it again. The
/// `PhantomData<*const ()>` marker makes the guard `!Send`/`!Sync`, so it is
/// released on the thread that acquired it.
struct OCamlDomainLock {
    _not_send_sync: PhantomData<*const ()>,
}
impl OCamlDomainLock {
    /// Makes sure this thread is registered with OCaml, then leaves the blocking
    /// section so the thread can run OCaml code until the guard is dropped.
    #[inline(always)]
    fn new() -> Self {
        // Lazily registers the thread (once per thread) via the thread-local guard.
        OCamlThreadRegistrationGuard::ensure();

        // SAFETY: the thread is registered at this point. `Drop` re-enters the
        // blocking section to balance this call.
        unsafe { ocaml_sys::caml_leave_blocking_section() };

        Self {
            _not_send_sync: PhantomData,
        }
    }
}
impl Drop for OCamlDomainLock {
    /// Re-enters the blocking section, undoing the leave performed in `new`.
    fn drop(&mut self) {
        // SAFETY: balances the `caml_leave_blocking_section` in `OCamlDomainLock::new`;
        // the guard is `!Send`, so this runs on the same thread.
        unsafe { ocaml_sys::caml_enter_blocking_section() };
    }
}
impl Deref for OCamlDomainLock {
    type Target = OCamlRuntime;

    /// Shared access to this thread's runtime handle while the lock is held.
    fn deref(&self) -> &Self::Target {
        // SAFETY: the lock is `!Send`/`!Sync`, so this runs on the thread that
        // acquired it; the `'static` handle is narrowed to the lifetime of `&self`.
        unsafe { internal::recover_runtime_handle() }
    }
}
impl DerefMut for OCamlDomainLock {
    /// Exclusive access to this thread's runtime handle while the lock is held.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: the lock is `!Send`/`!Sync`, so this runs on the thread that
        // acquired it; the `'static` handle is narrowed to the lifetime of `&mut self`.
        unsafe { internal::recover_runtime_handle_mut() }
    }
}
// OCaml's entry points for registering threads created outside OCaml.
// In C both are declared as returning `int`, so they must be bound with
// `c_int`. Binding them as `isize` reads a full 64-bit return register on 64-bit
// targets, but the C ABI only defines its low 32 bits for an `int` result.
extern "C" {
    /// Registers the calling thread with the OCaml runtime. Returns 1 on success
    /// and 0 otherwise (including when the thread is already registered).
    pub fn caml_c_thread_register() -> core::ffi::c_int;
    /// Unregisters the calling thread from the OCaml runtime. Returns 1 on
    /// success and 0 otherwise.
    pub fn caml_c_thread_unregister() -> core::ffi::c_int;
}
/// Per-thread record of whether `caml_c_thread_register` registered this thread.
///
/// Its `Drop` impl unregisters the thread only in that case.
struct OCamlThreadRegistrationGuard {
    // `true` only when `caml_c_thread_register` returned 1 for this thread.
    registered: bool,
}
thread_local! {
    /// Lazily registers the current thread with the OCaml runtime on first access.
    ///
    /// The guard's destructor runs when the thread exits. It unregisters the thread
    /// only if this initializer performed the registration.
    static OCAML_THREAD_REGISTRATION_GUARD: OCamlThreadRegistrationGuard =
        OCamlThreadRegistrationGuard {
            // SAFETY: plain FFI call without arguments; a return value of 1 means
            // the registration was performed by this call.
            registered: unsafe { caml_c_thread_register() } == 1,
        };
}
impl OCamlThreadRegistrationGuard {
    /// Registers the calling thread with OCaml if that has not happened yet.
    #[inline(always)]
    pub fn ensure() {
        // Touching the thread-local forces its initializer, which does the registration.
        OCAML_THREAD_REGISTRATION_GUARD.with(|_guard| ());
    }
}
impl Drop for OCamlThreadRegistrationGuard {
    /// Unregisters the thread, but only if this guard registered it.
    fn drop(&mut self) {
        if !self.registered {
            return;
        }
        // SAFETY: this thread was registered by `caml_c_thread_register`, so
        // unregistering it exactly once here is balanced. The return code is
        // ignored, as before.
        unsafe {
            caml_c_thread_unregister();
        }
    }
}
/// Exported no-op hook; ignores its argument and returns OCaml `unit`.
///
/// NOTE(review): presumably bound on the OCaml side via `external`. Confirm
/// with the OCaml bindings.
#[no_mangle]
extern "C" fn ocaml_interop_setup(_ignored: crate::RawOCaml) -> crate::RawOCaml {
    ocaml_sys::UNIT
}
/// Exported hook that tears down boxroot and returns OCaml `unit`.
///
/// NOTE(review): presumably called from OCaml when Rust-held roots are no
/// longer needed. Confirm that no boxroot is used after this call.
#[no_mangle]
extern "C" fn ocaml_interop_teardown(_ignored: crate::RawOCaml) -> crate::RawOCaml {
    // SAFETY: relies on the caller's guarantee (see NOTE above) that boxroot is
    // no longer in use.
    unsafe {
        boxroot_teardown();
    }
    ocaml_sys::UNIT
}
#[doc(hidden)]
pub mod internal {
    use super::{OCamlRuntime, TLS_RUNTIME};

    /// Returns a mutable reference to the current thread's [`OCamlRuntime`] handle.
    ///
    /// # Safety
    ///
    /// The reference is typed `'static` but points into thread-local storage. The
    /// caller must:
    /// - keep it on the current thread and not let it outlive that thread;
    /// - not create overlapping references to the same handle.
    #[inline(always)]
    pub unsafe fn recover_runtime_handle_mut() -> &'static mut OCamlRuntime {
        TLS_RUNTIME.with(|slot| {
            let handle: *mut OCamlRuntime = slot.get();
            &mut *handle
        })
    }

    /// Returns a shared reference to the current thread's [`OCamlRuntime`] handle.
    ///
    /// # Safety
    ///
    /// Same contract as [`recover_runtime_handle_mut`]. The reference must stay on
    /// the current thread and must not coexist with a mutable reference to the
    /// same handle.
    #[inline(always)]
    pub unsafe fn recover_runtime_handle() -> &'static OCamlRuntime {
        TLS_RUNTIME.with(|slot| {
            let handle: *const OCamlRuntime = slot.get();
            &*handle
        })
    }
}