use std::{
ffi::CString,
marker::PhantomData,
os::raw::c_void,
panic::{self, AssertUnwindSafe},
sync::Arc,
};
use crate::{Context, Error, Result, ValueType, check_err, sys};
/// Type-erased signature of the Rust closure behind a primop: it receives the
/// call's arguments and writes its result through a [`PrimOpRet`].
type PrimOpFn =
dyn Fn(&[PrimOpArg<'_>], &mut PrimOpRet<'_>) -> Result<()> + Send + Sync;
/// Payload boxed and leaked by [`PrimOp::new`] and handed to Nix as primop user data.
///
/// `trampoline` borrows it on every call; `drop_closure_finalizer` frees it.
struct ClosureData {
/// Number of argument pointers `trampoline` reads from the `args` array.
arity: usize,
/// The user closure invoked for each call.
f: Box<PrimOpFn>,
}
/// C-ABI entry point that Nix invokes for every primop created by [`PrimOp::new`].
///
/// `user_data` is the `ClosureData` pointer given to `nix_alloc_primop`. `args`
/// points at `data.arity` argument values, and `ret` is the slot the closure fills in.
/// Any `Err` returned by the closure, and any panic it raises, is reported back
/// to Nix via `nix_set_err_msg` on `context`, so no unwinding crosses the FFI
/// boundary.
unsafe extern "C" fn trampoline(
    user_data: *mut c_void,
    context: *mut sys::nix_c_context,
    state: *mut sys::EvalState,
    args: *mut *mut sys::nix_value,
    ret: *mut sys::nix_value,
) {
    // SAFETY: `user_data` is the `Box<ClosureData>` leaked in `PrimOp::new`; it
    // is only reclaimed by the GC finalizer registered on the primop.
    let data = unsafe { &*(user_data as *const ClosureData) };

    // `slice::from_raw_parts` requires a non-null pointer even for an empty
    // slice, so never hand it a null `args` (plausible when arity is 0).
    let arg_slice: &[*mut sys::nix_value] = if data.arity == 0 || args.is_null() {
        &[]
    } else {
        // SAFETY: Nix passes `arity` argument pointers for this primop.
        unsafe { std::slice::from_raw_parts(args, data.arity) }
    };

    let arg_wrappers: Vec<PrimOpArg<'_>> = arg_slice
        .iter()
        .map(|&inner| PrimOpArg {
            inner,
            ctx: context,
            state,
            _phantom: PhantomData,
        })
        .collect();
    let mut ret_wrapper = PrimOpRet {
        inner: ret,
        ctx: context,
        _phantom: PhantomData,
    };

    // Panics must not unwind into C: catch them and turn them into Nix errors.
    let result = panic::catch_unwind(AssertUnwindSafe(|| {
        (data.f)(&arg_wrappers, &mut ret_wrapper)
    }));
    let err_msg = match result {
        Ok(Ok(())) => return,
        Ok(Err(e)) => format!("primop error: {e}"),
        Err(_) => "primop panicked".to_string(),
    };

    // The message may carry interior NUL bytes (e.g. copied from user strings);
    // strip them rather than discarding the whole message.
    let msg_c = CString::new(err_msg).unwrap_or_else(|nul_err| {
        let mut bytes = nul_err.into_vec();
        bytes.retain(|&b| b != 0);
        CString::new(bytes).expect("all NUL bytes were removed")
    });
    // SAFETY: `context` is the context Nix passed to this callback, and `msg_c`
    // outlives the call.
    unsafe {
        sys::nix_set_err_msg(context, sys::nix_err_NIX_ERR_NIX_ERROR, msg_c.as_ptr());
    }
}
/// GC finalizer registered by [`PrimOp::new`] that reclaims the boxed closure.
///
/// `cd` is the `ClosureData` pointer produced by `Box::into_raw`. `_obj`, the
/// primop object the finalizer is attached to, is not needed.
unsafe extern "C" fn drop_closure_finalizer(_obj: *mut c_void, cd: *mut c_void) {
    let closure_data = cd.cast::<ClosureData>();
    // SAFETY: `cd` came from `Box::into_raw` in `PrimOp::new`; the GC is
    // expected to run this finalizer once, so ownership is reclaimed exactly once.
    drop(unsafe { Box::from_raw(closure_data) });
}
/// One argument passed to a primop callback.
///
/// Bundles the raw `nix_value` pointer with the context and evaluation state of
/// the current call. The `'a` lifetime is only a marker; no Rust data is borrowed.
pub struct PrimOpArg<'a> {
/// Argument value as supplied by Nix (presumably may still be a thunk; see `force`).
inner: *mut sys::nix_value,
/// Error context of the current call, passed to every Nix C API call.
ctx: *mut sys::nix_c_context,
/// Evaluation state used by `force` and `as_string`.
state: *mut sys::EvalState,
_phantom: PhantomData<&'a ()>,
}
// NOTE(review): these impls let the raw Nix pointers cross threads. Accessors
// pass the shared `ctx` to Nix calls that presumably write error state into it,
// so concurrent use from several threads looks racy. Confirm this is sound.
unsafe impl Send for PrimOpArg<'_> {}
unsafe impl Sync for PrimOpArg<'_> {}
impl PrimOpArg<'_> {
#[must_use]
pub fn value_type(&self) -> ValueType {
let c_type = unsafe { sys::nix_get_type(self.ctx, self.inner) };
ValueType::from_c(c_type)
}
pub fn force(&mut self) -> Result<()> {
unsafe {
check_err(
self.ctx,
sys::nix_value_force(self.ctx, self.state, self.inner),
)
}
}
pub fn as_int(&self) -> Result<i64> {
if self.value_type() != ValueType::Int {
return Err(Error::InvalidType {
expected: "int",
actual: self.value_type().to_string(),
});
}
Ok(unsafe { sys::nix_get_int(self.ctx, self.inner) })
}
pub fn as_float(&self) -> Result<f64> {
if self.value_type() != ValueType::Float {
return Err(Error::InvalidType {
expected: "float",
actual: self.value_type().to_string(),
});
}
Ok(unsafe { sys::nix_get_float(self.ctx, self.inner) })
}
pub fn as_bool(&self) -> Result<bool> {
if self.value_type() != ValueType::Bool {
return Err(Error::InvalidType {
expected: "bool",
actual: self.value_type().to_string(),
});
}
Ok(unsafe { sys::nix_get_bool(self.ctx, self.inner) })
}
pub fn as_string(&self) -> Result<String> {
if self.value_type() != ValueType::String {
return Err(Error::InvalidType {
expected: "string",
actual: self.value_type().to_string(),
});
}
let realised_str = unsafe {
sys::nix_string_realise(self.ctx, self.state, self.inner, false)
};
if realised_str.is_null() {
return Err(Error::NullPointer);
}
let buffer_start =
unsafe { sys::nix_realised_string_get_buffer_start(realised_str) };
let buffer_size =
unsafe { sys::nix_realised_string_get_buffer_size(realised_str) };
if buffer_start.is_null() {
unsafe { sys::nix_realised_string_free(realised_str) };
return Err(Error::NullPointer);
}
let bytes = unsafe {
std::slice::from_raw_parts(buffer_start.cast::<u8>(), buffer_size)
};
let s = std::str::from_utf8(bytes)
.map_err(|_| Error::Unknown("Invalid UTF-8 in string".into()))?
.to_owned();
unsafe { sys::nix_realised_string_free(realised_str) };
Ok(s)
}
}
/// Output slot for a primop's result; the `set_*` methods initialise it.
pub struct PrimOpRet<'a> {
/// Value the primop is expected to initialise.
inner: *mut sys::nix_value,
/// Error context of the current call.
ctx: *mut sys::nix_c_context,
_phantom: PhantomData<&'a mut ()>,
}
impl PrimOpRet<'_> {
    /// Stores the integer `i` as the primop result.
    ///
    /// # Errors
    ///
    /// Propagates any failure code from `nix_init_int`.
    pub fn set_int(&mut self, i: i64) -> Result<()> {
        // SAFETY: `inner`/`ctx` are the result slot and context of this call.
        let status = unsafe { sys::nix_init_int(self.ctx, self.inner, i) };
        check_err(self.ctx, status)
    }

    /// Stores the float `f` as the primop result.
    ///
    /// # Errors
    ///
    /// Propagates any failure code from `nix_init_float`.
    pub fn set_float(&mut self, f: f64) -> Result<()> {
        // SAFETY: `inner`/`ctx` are the result slot and context of this call.
        let status = unsafe { sys::nix_init_float(self.ctx, self.inner, f) };
        check_err(self.ctx, status)
    }

    /// Stores the boolean `b` as the primop result.
    ///
    /// # Errors
    ///
    /// Propagates any failure code from `nix_init_bool`.
    pub fn set_bool(&mut self, b: bool) -> Result<()> {
        // SAFETY: `inner`/`ctx` are the result slot and context of this call.
        let status = unsafe { sys::nix_init_bool(self.ctx, self.inner, b) };
        check_err(self.ctx, status)
    }

    /// Stores `null` as the primop result.
    ///
    /// # Errors
    ///
    /// Propagates any failure code from `nix_init_null`.
    pub fn set_null(&mut self) -> Result<()> {
        // SAFETY: `inner`/`ctx` are the result slot and context of this call.
        let status = unsafe { sys::nix_init_null(self.ctx, self.inner) };
        check_err(self.ctx, status)
    }

    /// Stores a copy of `s` as the primop result.
    ///
    /// # Errors
    ///
    /// Fails if `s` contains an interior NUL byte, or propagates a failure code
    /// from `nix_init_string`.
    pub fn set_string(&mut self, s: &str) -> Result<()> {
        let c_text = CString::new(s)?;
        // SAFETY: `c_text` stays alive until after the call returns.
        let status = unsafe { sys::nix_init_string(self.ctx, self.inner, c_text.as_ptr()) };
        check_err(self.ctx, status)
    }

    /// Copies the Nix value at `src` into the result slot.
    ///
    /// # Safety
    ///
    /// `src` must be a valid pointer to an initialised `nix_value` that stays
    /// alive for the duration of the call.
    ///
    /// # Errors
    ///
    /// Propagates any failure code from `nix_copy_value`.
    pub unsafe fn copy_from_raw(&mut self, src: *mut sys::nix_value) -> Result<()> {
        // SAFETY: the caller guarantees `src` is valid (see `# Safety`).
        let status = unsafe { sys::nix_copy_value(self.ctx, self.inner, src) };
        check_err(self.ctx, status)
    }
}
/// Owning handle to a Nix primop backed by a Rust closure.
pub struct PrimOp {
/// Primop allocated by `nix_alloc_primop`.
inner: *mut sys::PrimOp,
/// Context used when releasing `inner` on drop; kept alive by this `Arc`.
context: Arc<Context>,
/// Set once `register` succeeds; `Drop` then skips the GC decref.
registered: bool,
}
// NOTE(review): asserts the raw primop pointer may be moved and shared across
// threads. Confirm against the Nix C API's threading guarantees.
unsafe impl Send for PrimOp {}
unsafe impl Sync for PrimOp {}
impl PrimOp {
pub fn new<F>(
context: &Arc<Context>,
name: &str,
arity: u32,
doc: Option<&str>,
f: F,
) -> Result<Self>
where
F: Fn(&[PrimOpArg<'_>], &mut PrimOpRet<'_>) -> Result<()>
+ Send
+ Sync
+ 'static,
{
let name_c = CString::new(name)?;
let doc_c = doc.map(CString::new).transpose()?;
let doc_ptr = doc_c.as_ref().map_or(std::ptr::null(), |c| c.as_ptr());
let data = Box::new(ClosureData {
arity: arity as usize,
f: Box::new(f),
});
let data_raw = Box::into_raw(data) as *mut c_void;
let primop_ptr = unsafe {
sys::nix_alloc_primop(
context.as_ptr(),
Some(trampoline),
arity as std::os::raw::c_int,
name_c.as_ptr(),
std::ptr::null_mut(), doc_ptr,
data_raw,
)
};
if primop_ptr.is_null() {
let _ = unsafe { Box::from_raw(data_raw as *mut ClosureData) };
return Err(Error::NullPointer);
}
unsafe {
sys::nix_gc_register_finalizer(
primop_ptr as *mut c_void,
data_raw,
Some(drop_closure_finalizer),
);
}
Ok(PrimOp {
inner: primop_ptr,
context: Arc::clone(context),
registered: false,
})
}
pub fn register(mut self, context: &Context) -> Result<()> {
let err = unsafe { sys::nix_register_primop(context.as_ptr(), self.inner) };
check_err(unsafe { self.context.as_ptr() }, err)?;
self.registered = true;
Ok(())
}
pub fn into_value(
self,
state: &crate::EvalState,
) -> Result<crate::Value<'_>> {
let v = state.alloc_value()?;
unsafe {
check_err(
state.context.as_ptr(),
sys::nix_init_primop(
state.context.as_ptr(),
v.inner.as_ptr(),
self.inner,
),
)?;
}
Ok(v)
}
}
impl Drop for PrimOp {
    /// Releases the GC reference held by an unregistered primop handle.
    ///
    /// Registered handles and null pointers are left untouched.
    fn drop(&mut self) {
        if self.registered || self.inner.is_null() {
            return;
        }
        let primop = self.inner.cast::<c_void>().cast_const();
        // Best effort: a failing decref cannot be reported from `drop`.
        // SAFETY: `primop` is the live, unregistered primop allocated in `new`.
        unsafe {
            let _ = sys::nix_gc_decref(self.context.as_ptr(), primop);
        }
    }
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use serial_test::serial;

    use super::*;
    use crate::{Context, EvalStateBuilder, Store};

    /// A primop embedded as a value can be called like a Nix function.
    #[test]
    #[serial]
    fn test_primop_into_value() {
        let context = Arc::new(Context::new().expect("Failed to create context"));
        let opened_store = Store::open(&context, None).expect("Failed to open store");
        let store = Arc::new(opened_store);
        let eval_state = EvalStateBuilder::new(&store)
            .expect("Failed to create builder")
            .build()
            .expect("Failed to build state");

        let negate = PrimOp::new(
            &context,
            "negate",
            1,
            Some("Negate an integer"),
            |inputs, out| out.set_int(-inputs[0].as_int()?),
        )
        .expect("Failed to create primop");

        let function = negate
            .into_value(&eval_state)
            .expect("Failed to embed primop as value");
        let seven = eval_state.make_int(7).unwrap();
        let negated = function.call(&seven).expect("Failed to call primop");
        assert_eq!(negated.as_int().unwrap(), -7);
    }

    /// A primop can return a string regardless of its argument.
    #[test]
    #[serial]
    fn test_primop_into_value_string() {
        let context = Arc::new(Context::new().expect("Failed to create context"));
        let opened_store = Store::open(&context, None).expect("Failed to open store");
        let store = Arc::new(opened_store);
        let eval_state = EvalStateBuilder::new(&store)
            .expect("Failed to create builder")
            .build()
            .expect("Failed to build state");

        let greeter = PrimOp::new(&context, "hello", 1, None, |_inputs, out| {
            out.set_string("hello from primop")
        })
        .expect("Failed to create primop");

        let function = greeter
            .into_value(&eval_state)
            .expect("Failed to embed primop as value");
        let null_arg = eval_state.make_null().unwrap();
        let greeting = function.call(&null_arg).expect("Failed to call primop");
        assert_eq!(greeting.as_string().unwrap(), "hello from primop");
    }
}