use std::collections::HashMap;
use std::ffi::{c_char, c_int, c_long, c_void, CStr, CString};
use std::path::Path;
use std::ptr;
use std::sync::{Arc, Mutex, OnceLock};
use libloading::{Library, Symbol};
/// Returns the single process-wide lock that this module takes around every
/// call into loaded external code (both `funcadd_ASL` registration and
/// function evaluation).
///
/// The lock is a plain `static` with a `const` initializer, so every call
/// hands back the same mutex.
fn ampl_lock() -> &'static Mutex<()> {
    static AMPL_LOCK: Mutex<()> = Mutex::new(());
    &AMPL_LOCK
}
/// Registration type for a plain real-valued function (no flag bits set).
pub const FUNCADD_REAL_VALUED: i32 = 0;
/// Flag bit: the function accepts string arguments.
pub const FUNCADD_STRING_ARGS: i32 = 1 << 0;
/// Flag bit: the function has output arguments.
pub const FUNCADD_OUTPUT_ARGS: i32 = 1 << 1;
/// Flag bit: the function is random-valued.
pub const FUNCADD_RANDOM_VALUED: i32 = 1 << 2;
/// Argument block handed to an external function (see [`Rfunc`]).
///
/// This is `#[repr(C)]` and must match the C argument-list struct the
/// libraries were compiled against. Never reorder, insert or remove fields.
#[repr(C)]
pub struct Arglist {
    /// Total number of arguments.
    pub n: c_int,
    /// Number of real-valued arguments (entries in `ra`).
    pub nr: c_int,
    /// Per-argument type code: a value `i >= 0` selects real argument `ra[i]`;
    /// a value `-(j + 1)` selects string argument `sa[j]`.
    pub at: *mut c_int,
    /// Real argument values.
    pub ra: *mut f64,
    /// NUL-terminated string argument values.
    pub sa: *mut *const c_char,
    /// Output buffer for first partials (one per real argument), or null.
    pub derivs: *mut f64,
    /// Output buffer for packed second partials (`nr * (nr + 1) / 2`), or null.
    pub hes: *mut f64,
    /// Optional per-real-argument flags; not used by this module (null).
    pub dig: *mut c_char,
    /// Opaque pointer supplied when the function was registered.
    pub funcinfo: *mut c_void,
    /// Exports table of the library the function belongs to.
    pub ae: *mut AmplExports,
    /// Host-internal slot; null here.
    pub f: *mut c_void,
    /// Host-internal slot; null here.
    pub tva: *mut c_void,
    /// Error channel: a function reports failure through this pointer (see
    /// `ExternalLibrary::eval` for how it is read back).
    pub errmsg: *mut c_char,
    /// Host-internal slot; null here.
    pub tmi: *mut c_void,
    /// Host-internal slot; null here.
    pub private: *mut c_char,
    /// Number of input arguments; zero here.
    pub nin: c_int,
    /// Number of output arguments; zero here.
    pub nout: c_int,
    /// Number of string input arguments; zero here.
    pub nsin: c_int,
    /// Number of string output arguments; zero here.
    pub nsout: c_int,
}
/// Signature of an external function. It receives the argument block and
/// returns the function value; derivatives and errors travel through the
/// [`Arglist`].
pub type Rfunc = unsafe extern "C" fn(*mut Arglist) -> f64;
/// Signature of the `addfunc` slot of [`AmplExports`].
///
/// A library's `funcadd_ASL` calls this once per function to register it. The
/// arguments are:
/// - `name`: the function name.
/// - `f`: the entry point.
/// - `ty`: the `FUNCADD_*` flag bits.
/// - `nargs`: the arity. `>= 0` means an exact count; a negative value `k`
///   means at least `-(k + 1)` arguments.
/// - `funcinfo`: an opaque pointer handed back in `Arglist::funcinfo`.
///
/// The parameter order must match the C declaration exactly.
pub type AddfuncFn = unsafe extern "C" fn(
name: *const c_char,
f: Rfunc,
ty: c_int,
nargs: c_int,
funcinfo: *mut c_void,
ae: *mut AmplExports,
);
/// Seed-setting callback a library supplies through `addrandinit`. It is
/// called with the library's state pointer and a seed value.
pub type RandSeedSetter = unsafe extern "C" fn(*mut c_void, std::os::raw::c_ulong);
/// Signature of the `addrandinit` slot of [`AmplExports`]. A library registers
/// a [`RandSeedSetter`] together with the state pointer it expects back.
pub type AddrandinitFn = unsafe extern "C" fn(
ae: *mut AmplExports,
setter: RandSeedSetter,
v: *mut c_void,
);
/// Signature of the `at_reset` slot of [`AmplExports`].
///
/// NOTE(review): this is presumably a request to invoke `f` with `v` when the
/// host resets. TODO confirm. This module ignores such requests.
pub type AtResetFn = unsafe extern "C" fn(
ae: *mut AmplExports,
f: *mut c_void,
v: *mut c_void,
);
/// C-compatible function table passed to a library's `funcadd_ASL`. External
/// functions also see it through `Arglist::ae`.
///
/// NOTE(review): this is meant to mirror ASL's `AmplExports` struct. With
/// `#[repr(C)]`, field order and types must match the C declaration exactly.
/// Never reorder, insert or remove fields.
///
/// Only `addfunc`, `at_reset` and `addrandinit` are typed function pointers;
/// every other slot is an opaque pointer. `ExternalLibrary::load` fills those
/// three with Rust trampolines and leaves every other slot null. A library
/// that calls through any other slot (e.g. `printf`) therefore calls a null
/// pointer.
#[repr(C)]
pub struct AmplExports {
pub std_err: *mut c_void,
/// Registration callback invoked by `funcadd_ASL` once per function.
pub addfunc: Option<AddfuncFn>,
/// Interface date visible to the library (`ExternalLibrary::load` sets 20160307).
pub asl_date: c_long,
pub fprintf: *mut c_void,
pub printf: *mut c_void,
pub sprintf: *mut c_void,
pub vfprintf: *mut c_void,
pub vsprintf: *mut c_void,
pub strtod: *mut c_void,
pub crypto: *mut c_void,
pub asl: *mut c_char,
pub at_exit: *mut c_void,
/// Hook through which a library registers a reset callback.
pub at_reset: Option<AtResetFn>,
pub tempmem: *mut c_void,
pub add_table_handler: *mut c_void,
pub private_ae: *mut c_char,
pub qsortv: *mut c_void,
// stdio-style entry points; all left null by `ExternalLibrary::load`.
pub std_in: *mut c_void,
pub std_out: *mut c_void,
pub clearerr: *mut c_void,
pub fclose: *mut c_void,
pub fdopen: *mut c_void,
pub feof: *mut c_void,
pub ferror: *mut c_void,
pub fflush: *mut c_void,
pub fgetc: *mut c_void,
pub fgets: *mut c_void,
pub fileno: *mut c_void,
pub fopen: *mut c_void,
pub fputc: *mut c_void,
pub fputs: *mut c_void,
pub fread: *mut c_void,
pub freopen: *mut c_void,
pub fscanf: *mut c_void,
pub fseek: *mut c_void,
pub ftell: *mut c_void,
pub fwrite: *mut c_void,
pub pclose: *mut c_void,
pub perror: *mut c_void,
pub popen: *mut c_void,
pub puts: *mut c_void,
pub rewind: *mut c_void,
pub scanf: *mut c_void,
pub setbuf: *mut c_void,
pub setvbuf: *mut c_void,
pub sscanf: *mut c_void,
pub tempnam: *mut c_void,
pub tmpfile: *mut c_void,
pub tmpnam: *mut c_void,
pub ungetc: *mut c_void,
pub ai: *mut c_void,
pub getenv: *mut c_void,
pub breakfunc: *mut c_void,
pub breakarg: *mut c_char,
pub snprintf: *mut c_void,
pub vsnprintf: *mut c_void,
pub addrand: *mut c_void,
/// Hook through which a library registers its random-seed setter.
pub addrandinit: Option<AddrandinitFn>,
}
// SAFETY (review): the table is filled in once when it is built and only read
// afterwards on the Rust side. Sharing it across threads is sound only if
// external code also treats it as read-only. TODO confirm per library.
unsafe impl Send for AmplExports {}
unsafe impl Sync for AmplExports {}
/// One function that a library registered through the `addfunc` callback.
///
/// NOTE(review): `rfunc` and `funcinfo` point into the loaded library. A clone
/// of this value must not be used after the owning `ExternalLibrary` (and
/// with it the last `Arc<Library>`) has been dropped.
#[derive(Debug, Clone)]
pub struct RegisteredFunc {
/// Name the function was registered under.
pub name: String,
/// Entry point invoked by `ExternalLibrary::eval`.
pub rfunc: Rfunc,
/// `FUNCADD_*` flag bits supplied at registration.
pub ty: i32,
/// Declared arity: `>= 0` means an exact count; a negative value `k` means at
/// least `-(k + 1)` arguments.
pub nargs: i32,
/// Opaque pointer supplied at registration, passed back in `Arglist::funcinfo`.
pub funcinfo: *mut c_void,
}
// SAFETY (review): the raw pointers are only handed back to the library. Calls
// through `rfunc` are serialized by `ampl_lock` inside `ExternalLibrary`.
// Thread-safety of `funcinfo` itself is assumed. TODO confirm.
unsafe impl Send for RegisteredFunc {}
unsafe impl Sync for RegisteredFunc {}
impl std::fmt::Debug for ExternalLibrary {
    /// Formats the library as the list of its registered function names.
    ///
    /// The names are sorted so the output is stable across runs. The backing
    /// `HashMap` iterates in an unspecified order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut names: Vec<&str> = self.funcs.keys().map(String::as_str).collect();
        names.sort_unstable();
        f.debug_struct("ExternalLibrary")
            .field("funcs", &names)
            .finish()
    }
}
/// A loaded external-function library together with the functions it
/// registered from `funcadd_ASL`.
///
/// Field order matters: Rust drops fields in declaration order, so `_lib` is
/// released first. That is fine because neither `_ae` nor `funcs` calls into
/// the library when dropped. However, the raw pointers held in `funcs` (and in
/// any cloned `RegisteredFunc`) dangle once the library is unloaded.
pub struct ExternalLibrary {
/// Keeps the shared library loaded at least as long as this value lives.
_lib: Arc<Library>,
/// Exports table given to `funcadd_ASL`. It is boxed so its address stays
/// fixed in case the library retained the pointer.
_ae: Box<AmplExports>,
/// Registered functions, keyed by name.
funcs: HashMap<String, RegisteredFunc>,
}
impl ExternalLibrary {
    /// Loads an external-function library from `path` and collects the
    /// functions it registers.
    ///
    /// The library's `funcadd_ASL` entry point is called once with a table
    /// built by [`Self::new_exports`]. Each `addfunc` call it makes is captured
    /// by `trampoline_addfunc` into this thread's `REGISTRY_SINK`. The whole
    /// load runs under the global `ampl_lock`.
    ///
    /// # Errors
    ///
    /// Returns a message if the library cannot be opened or does not export
    /// `funcadd_ASL`.
    ///
    /// # Panics
    ///
    /// Panics if a registration is already in progress on this thread.
    pub fn load(path: &Path) -> Result<Self, String> {
        let _guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
        // SAFETY: opening a library runs its initializers. The caller is
        // trusted to point at an external-function library.
        let lib = unsafe { Library::new(path) }
            .map_err(|e| format!("failed to open '{}': {}", path.display(), e))?;
        type FuncaddFn = unsafe extern "C" fn(*mut AmplExports);
        // SAFETY: `funcadd_ASL` is assumed to have the prototype
        // `void funcadd_ASL(AmplExports *)`, i.e. `FuncaddFn`.
        let funcadd: Symbol<FuncaddFn> = unsafe { lib.get(b"funcadd_ASL\0") }
            .map_err(|e| format!("no funcadd_ASL in '{}': {}", path.display(), e))?;
        let mut ae = Self::new_exports();
        REGISTRY_SINK.with(|sink| {
            let mut guard = sink.borrow_mut();
            assert!(guard.is_none(), "nested ExternalLibrary::load is not supported");
            *guard = Some(HashMap::new());
        });
        // SAFETY: `ae` is fully initialized, and it outlives the call because
        // the box moves into the returned value without changing its address.
        unsafe { funcadd(ae.as_mut()) };
        let funcs = REGISTRY_SINK
            .with(|sink| sink.borrow_mut().take())
            .unwrap_or_default();
        Ok(ExternalLibrary {
            _lib: Arc::new(lib),
            _ae: ae,
            funcs,
        })
    }

    /// Builds the exports table handed to `funcadd_ASL`.
    ///
    /// Only the three registration hooks (`addfunc`, `at_reset`,
    /// `addrandinit`) are populated. Every other slot is null, so a library
    /// that calls through one of them (e.g. `printf`) will crash.
    fn new_exports() -> Box<AmplExports> {
        Box::new(AmplExports {
            std_err: ptr::null_mut(),
            addfunc: Some(trampoline_addfunc),
            asl_date: 20160307,
            fprintf: ptr::null_mut(),
            printf: ptr::null_mut(),
            sprintf: ptr::null_mut(),
            vfprintf: ptr::null_mut(),
            vsprintf: ptr::null_mut(),
            strtod: ptr::null_mut(),
            crypto: ptr::null_mut(),
            asl: ptr::null_mut(),
            at_exit: ptr::null_mut(),
            at_reset: Some(trampoline_atreset),
            tempmem: ptr::null_mut(),
            add_table_handler: ptr::null_mut(),
            private_ae: ptr::null_mut(),
            qsortv: ptr::null_mut(),
            std_in: ptr::null_mut(),
            std_out: ptr::null_mut(),
            clearerr: ptr::null_mut(),
            fclose: ptr::null_mut(),
            fdopen: ptr::null_mut(),
            feof: ptr::null_mut(),
            ferror: ptr::null_mut(),
            fflush: ptr::null_mut(),
            fgetc: ptr::null_mut(),
            fgets: ptr::null_mut(),
            fileno: ptr::null_mut(),
            fopen: ptr::null_mut(),
            fputc: ptr::null_mut(),
            fputs: ptr::null_mut(),
            fread: ptr::null_mut(),
            freopen: ptr::null_mut(),
            fscanf: ptr::null_mut(),
            fseek: ptr::null_mut(),
            ftell: ptr::null_mut(),
            fwrite: ptr::null_mut(),
            pclose: ptr::null_mut(),
            perror: ptr::null_mut(),
            popen: ptr::null_mut(),
            puts: ptr::null_mut(),
            rewind: ptr::null_mut(),
            scanf: ptr::null_mut(),
            setbuf: ptr::null_mut(),
            setvbuf: ptr::null_mut(),
            sscanf: ptr::null_mut(),
            tempnam: ptr::null_mut(),
            tmpfile: ptr::null_mut(),
            tmpnam: ptr::null_mut(),
            ungetc: ptr::null_mut(),
            ai: ptr::null_mut(),
            getenv: ptr::null_mut(),
            breakfunc: ptr::null_mut(),
            breakarg: ptr::null_mut(),
            snprintf: ptr::null_mut(),
            vsnprintf: ptr::null_mut(),
            addrand: ptr::null_mut(),
            addrandinit: Some(trampoline_addrandinit),
        })
    }

    /// Returns the names of all registered functions, in unspecified order.
    pub fn function_names(&self) -> impl Iterator<Item = &str> {
        self.funcs.keys().map(String::as_str)
    }

    /// Looks up the registration record for `name`.
    pub fn get(&self, name: &str) -> Option<&RegisteredFunc> {
        self.funcs.get(name)
    }

    /// Evaluates the registered external function `name` at `args`.
    ///
    /// Real and string arguments are marshalled into an [`Arglist`] in the
    /// order given. When `want_derivs` is set, [`EvalResult::derivs`] holds one
    /// entry per real argument. When `want_hes` is set, [`EvalResult::hessian`]
    /// holds `nr * (nr + 1) / 2` packed entries, where `nr` is the number of
    /// real arguments.
    ///
    /// # Errors
    ///
    /// Returns a message in any of these cases:
    /// - `name` is not registered.
    /// - The argument count violates the registered arity.
    /// - A string argument contains NUL.
    /// - String arguments are passed to a function not registered with
    ///   [`FUNCADD_STRING_ARGS`].
    /// - The function reports an error through `Arglist::errmsg`.
    pub fn eval(
        &self,
        name: &str,
        args: &[ExternalArg<'_>],
        want_derivs: bool,
        want_hes: bool,
    ) -> Result<EvalResult, String> {
        let rf = self
            .funcs
            .get(name)
            .ok_or_else(|| format!("no such external function '{name}'"))?;
        let n = args.len() as i32;
        if rf.nargs >= 0 {
            if rf.nargs != n {
                return Err(format!(
                    "external '{name}' expects {} args, got {}",
                    rf.nargs, n
                ));
            }
        } else {
            // A negative arity `k` means "at least -(k + 1) arguments".
            let min_args = -(rf.nargs + 1);
            if n < min_args {
                return Err(format!(
                    "external '{name}' expects at least {min_args} args, got {n}"
                ));
            }
        }
        // `at[i]` encodes argument i: a non-negative index into `ra` for a
        // real, or `-(j + 1)` for the j-th entry of `sa`.
        let mut at_vec: Vec<c_int> = Vec::with_capacity(args.len());
        let mut ra_vec: Vec<f64> = Vec::new();
        let mut sa_owned: Vec<CString> = Vec::new();
        for a in args {
            match a {
                ExternalArg::Real(x) => {
                    at_vec.push(ra_vec.len() as c_int);
                    ra_vec.push(*x);
                }
                ExternalArg::Str(s) => {
                    let cs = CString::new(*s)
                        .map_err(|_| format!("external '{name}' string arg contains NUL"))?;
                    at_vec.push(-(sa_owned.len() as c_int + 1));
                    sa_owned.push(cs);
                }
            }
        }
        let nr = ra_vec.len() as c_int;
        let sa_ptrs: Vec<*const c_char> = sa_owned.iter().map(|s| s.as_ptr()).collect();
        let has_strings = !sa_owned.is_empty();
        if has_strings && (rf.ty & FUNCADD_STRING_ARGS) == 0 {
            return Err(format!(
                "external '{name}' is not declared FUNCADD_STRING_ARGS but was \
                 called with string arguments"
            ));
        }
        // AMPL's calling convention only ever pairs a non-null `hes` with a
        // non-null `derivs`. External functions typically compute the Hessian
        // inside their `derivs` branch, or write first partials whenever
        // `hes` is set. So a Hessian request always gets a derivative buffer
        // too; that buffer is returned only if the caller asked for it.
        let pass_derivs = want_derivs || want_hes;
        let mut derivs_buf: Vec<f64> = if pass_derivs {
            vec![0.0; nr as usize]
        } else {
            Vec::new()
        };
        let hes_len = if want_hes {
            (nr as usize) * ((nr as usize) + 1) / 2
        } else {
            0
        };
        let mut hes_buf: Vec<f64> = vec![0.0; hes_len];
        // A failing function normally *assigns* a message pointer to `errmsg`.
        // A scratch buffer is pre-installed as well, so a function that writes
        // into the pointer it was given still gets its message through.
        let mut errmsg_buf: Vec<c_char> = vec![0; 1024];
        let mut al = Arglist {
            n,
            nr,
            at: if at_vec.is_empty() {
                ptr::null_mut()
            } else {
                at_vec.as_mut_ptr()
            },
            ra: if ra_vec.is_empty() {
                ptr::null_mut()
            } else {
                ra_vec.as_mut_ptr()
            },
            sa: if sa_ptrs.is_empty() {
                ptr::null_mut()
            } else {
                sa_ptrs.as_ptr() as *mut *const c_char
            },
            derivs: if pass_derivs {
                derivs_buf.as_mut_ptr()
            } else {
                ptr::null_mut()
            },
            hes: if want_hes {
                hes_buf.as_mut_ptr()
            } else {
                ptr::null_mut()
            },
            dig: ptr::null_mut(),
            funcinfo: rf.funcinfo,
            ae: self._ae_ptr(),
            f: ptr::null_mut(),
            tva: ptr::null_mut(),
            errmsg: errmsg_buf.as_mut_ptr(),
            tmi: ptr::null_mut(),
            private: ptr::null_mut(),
            nin: 0,
            nout: 0,
            nsin: 0,
            nsout: 0,
        };
        let guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
        // SAFETY: every pointer in `al` refers to a live local buffer of the
        // size the argument-list contract requires (`nr` reals, `nr` partials,
        // packed Hessian). `rfunc` belongs to the library `self` keeps loaded.
        let value = unsafe { (rf.rfunc)(&mut al as *mut Arglist) };
        // Read the message while still holding the lock. An assigned pointer
        // may refer to storage inside the library that another call could
        // overwrite.
        let reported = Self::reported_error(al.errmsg, &errmsg_buf);
        drop(guard);
        if let Some(msg) = reported {
            return Err(format!("external '{name}' reported: {msg}"));
        }
        Ok(EvalResult {
            value,
            derivs: want_derivs.then_some(derivs_buf),
            hessian: want_hes.then_some(hes_buf),
        })
    }

    /// Extracts the message an external function left in `Arglist::errmsg`.
    ///
    /// `errmsg` is the field's value after the call, and `scratch` is the
    /// buffer the field was initialized to. Returns `None` in two cases:
    /// - the pointer is null;
    /// - it still points at `scratch` and `scratch` holds an empty string.
    fn reported_error(errmsg: *const c_char, scratch: &[c_char]) -> Option<String> {
        if errmsg.is_null() {
            return None;
        }
        if ptr::eq(errmsg, scratch.as_ptr()) {
            // Either the function wrote into our buffer or left it untouched.
            // Scan only within the buffer, so a missing terminator cannot
            // cause a read past its end.
            let bytes: Vec<u8> = scratch
                .iter()
                .take_while(|&&c| c != 0)
                .map(|&c| c as u8)
                .collect();
            return (!bytes.is_empty()).then(|| String::from_utf8_lossy(&bytes).into_owned());
        }
        // SAFETY: the function replaced the pointer with its own message,
        // which the argument-list contract requires to be NUL-terminated. It
        // is copied out immediately.
        Some(unsafe { CStr::from_ptr(errmsg) }.to_string_lossy().into_owned())
    }

    /// Returns the exports table pointer that goes into `Arglist::ae`.
    ///
    /// NOTE(review): the `*mut` is derived from a shared borrow, so external
    /// code must treat the table as read-only; writing through it would be
    /// UB. Consider storing the table as a raw pointer or behind `UnsafeCell`
    /// if a library is found to mutate it.
    fn _ae_ptr(&self) -> *mut AmplExports {
        (&*self._ae as *const AmplExports) as *mut AmplExports
    }
}
/// One argument to an external function call made with `ExternalLibrary::eval`.
#[derive(Debug, Clone, Copy)]
pub enum ExternalArg<'a> {
/// Real-valued argument, passed through `Arglist::ra`.
Real(f64),
/// String argument, passed NUL-terminated through `Arglist::sa`. It must not
/// contain interior NUL bytes, or `eval` returns an error.
Str(&'a str),
}
/// Result of a successful `ExternalLibrary::eval`.
#[derive(Debug, Clone)]
pub struct EvalResult {
/// Function value.
pub value: f64,
/// First partials, one entry per real argument; `Some` only if requested.
pub derivs: Option<Vec<f64>>,
/// Packed second partials, `nr * (nr + 1) / 2` entries for `nr` real
/// arguments; `Some` only if requested. The element ordering is whatever the
/// external function writes (presumably AMPL's packed-triangle layout;
/// TODO confirm).
pub hessian: Option<Vec<f64>>,
}
thread_local! {
/// Per-thread collection point for `trampoline_addfunc`. It is `Some` only
/// while `ExternalLibrary::load` is running `funcadd_ASL` on this thread, and
/// `None` otherwise; registrations that arrive while it is `None` are dropped.
static REGISTRY_SINK: std::cell::RefCell<Option<HashMap<String, RegisteredFunc>>> =
std::cell::RefCell::new(None);
}
/// `addfunc` callback handed to `funcadd_ASL`. It records one registered
/// function in this thread's `REGISTRY_SINK`.
///
/// A registration is dropped without error in any of these cases:
/// - its name is null;
/// - its name is not valid UTF-8;
/// - it arrives while no load is in progress on this thread.
///
/// A later registration under the same name replaces the earlier one.
unsafe extern "C" fn trampoline_addfunc(
    name: *const c_char,
    f: Rfunc,
    ty: c_int,
    nargs: c_int,
    funcinfo: *mut c_void,
    _ae: *mut AmplExports,
) {
    if name.is_null() {
        return;
    }
    // SAFETY: `name` is non-null and is expected to be a NUL-terminated C
    // string that stays valid for the duration of this call.
    let Ok(utf8_name) = unsafe { CStr::from_ptr(name) }.to_str() else {
        return;
    };
    let record = RegisteredFunc {
        name: utf8_name.to_owned(),
        rfunc: f,
        ty: ty as i32,
        nargs: nargs as i32,
        funcinfo,
    };
    REGISTRY_SINK.with(|cell| {
        if let Some(registry) = cell.borrow_mut().as_mut() {
            registry.insert(record.name.clone(), record);
        }
    });
}
/// `at_reset` callback handed to `funcadd_ASL`.
///
/// Reset hooks are not supported. The request is logged at debug level and the
/// supplied callback is never invoked.
unsafe extern "C" fn trampoline_atreset(
    _exports: *mut AmplExports,
    _callback: *mut c_void,
    _callback_arg: *mut c_void,
) {
    log::debug!("external library registered an AtReset callback; ignoring");
}
/// `addrandinit` callback handed to `funcadd_ASL`.
///
/// Seeds the library immediately by calling `set_seed(state, 1)`, so every
/// load starts from the same fixed seed.
unsafe extern "C" fn trampoline_addrandinit(
    _exports: *mut AmplExports,
    set_seed: RandSeedSetter,
    state: *mut c_void,
) {
    const FIXED_SEED: std::os::raw::c_ulong = 1;
    // SAFETY: the setter and its state pointer come from the library as a
    // pair, and the setter is invoked with exactly that pointer.
    unsafe { set_seed(state, FIXED_SEED) };
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns `path` when it exists on disk.
    fn existing(path: PathBuf) -> Option<PathBuf> {
        path.exists().then_some(path)
    }

    /// Locates an IDAES shared library.
    ///
    /// If `env_var` is set, it names the library file directly. Otherwise the
    /// library is looked up as `~/.idaes/bin/<stem>.<ext>`, where `<ext>` is
    /// the platform's shared-library extension (`dylib` on macOS, `so` on
    /// Linux). Returns `None` when the file does not exist, so callers can skip.
    fn idaes_lib(env_var: &str, stem: &str) -> Option<PathBuf> {
        if let Some(explicit) = std::env::var_os(env_var) {
            return existing(PathBuf::from(explicit));
        }
        let home = std::env::var_os("HOME")?;
        let file = format!("{stem}.{}", std::env::consts::DLL_EXTENSION);
        existing(PathBuf::from(home).join(".idaes/bin").join(file))
    }

    fn idaes_dylib() -> Option<PathBuf> {
        idaes_lib("IDAES_HELMHOLTZ_LIB", "general_helmholtz_external")
    }

    fn idaes_functions_dylib() -> Option<PathBuf> {
        idaes_lib("IDAES_FUNCTIONS_LIB", "functions")
    }

    /// Locates the Helmholtz parameter directory.
    ///
    /// `IDAES_HELMHOLTZ_PARAMS_DIR` overrides the default location under
    /// `$HOME`. The returned string always ends with a path separator, like the
    /// default does. The library presumably appends file names to it; TODO
    /// confirm.
    fn idaes_params_dir() -> Option<String> {
        let dir = match std::env::var_os("IDAES_HELMHOLTZ_PARAMS_DIR") {
            Some(explicit) => PathBuf::from(explicit),
            None => PathBuf::from(std::env::var_os("HOME")?).join(
                "Dropbox/uv/.venv/lib/python3.12/site-packages/idaes/\
                 models/properties/general_helmholtz/components/parameters/",
            ),
        };
        if !dir.exists() {
            return None;
        }
        let mut s = dir.to_str()?.to_owned();
        if !s.ends_with(std::path::MAIN_SEPARATOR) {
            s.push(std::path::MAIN_SEPARATOR);
        }
        Some(s)
    }

    /// Arguments for `vf_hp` at the fixture's initial point: the component
    /// name, two real inputs (presumably enthalpy and pressure, per the `_hp`
    /// suffix), and the parameter directory.
    fn vf_hp_args(params_dir: &str) -> [ExternalArg<'_>; 4] {
        [
            ExternalArg::Str("h2o"),
            ExternalArg::Real(1878.71 * 0.055508472036052976),
            ExternalArg::Real(101325.0 * 0.001),
            ExternalArg::Str(params_dir),
        ]
    }

    #[test]
    fn load_idaes_helmholtz_dylib_registers_known_functions() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load should succeed");
        let names: Vec<String> = lib.function_names().map(|s| s.to_owned()).collect();
        for required in &["vf_hp", "h_liq_hp", "h_vap_hp"] {
            assert!(
                names.iter().any(|n| n == required),
                "expected {required} in registered names: {names:?}"
            );
        }
    }

    #[test]
    fn eval_vf_hp_at_fixture_initial_point() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let args = vf_hp_args(&params_dir);
        let res = lib.eval("vf_hp", &args, false, false).expect("eval");
        assert!(
            res.value.is_finite(),
            "vf_hp returned non-finite value {}",
            res.value
        );
    }

    #[test]
    fn eval_vf_hp_with_derivatives() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let args = vf_hp_args(&params_dir);
        let res = lib.eval("vf_hp", &args, true, false).expect("eval");
        let derivs = res.derivs.expect("derivs requested");
        assert_eq!(derivs.len(), 2, "nr=2 reals -> 2 derivatives");
        for (i, d) in derivs.iter().enumerate() {
            assert!(d.is_finite(), "derivs[{i}] = {d} not finite");
        }
    }

    #[test]
    fn eval_vf_hp_with_hessian() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let args = vf_hp_args(&params_dir);
        let res = lib.eval("vf_hp", &args, true, true).expect("eval");
        let hes = res.hessian.expect("hessian requested");
        assert_eq!(hes.len(), 3, "nr=2 -> packed Hessian of length 3");
        for (i, h) in hes.iter().enumerate() {
            assert!(h.is_finite(), "hes[{i}] = {h} not finite");
        }
    }

    #[test]
    fn eval_cbrt_matches_closed_form() {
        let Some(path) = idaes_functions_dylib() else {
            eprintln!("skipping: IDAES functions.dylib not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let rf = lib.get("cbrt").expect("cbrt registered");
        assert_eq!(rf.ty, 0, "cbrt should be FUNCADD_REAL_VALUED (type=0)");
        assert_eq!(rf.nargs, 1, "cbrt should have exact arity 1");
        let args = [ExternalArg::Real(8.0)];
        let res = lib.eval("cbrt", &args, true, true).expect("eval");
        let derivs = res.derivs.expect("derivs requested");
        let hes = res.hessian.expect("hessian requested");
        assert_eq!(derivs.len(), 1);
        assert_eq!(hes.len(), 1, "nr=1 -> packed Hessian length 1");
        assert!((res.value - 2.0).abs() < 1e-12, "value {}", res.value);
        assert!((derivs[0] - 1.0 / 12.0).abs() < 1e-12, "deriv {}", derivs[0]);
        assert!((hes[0] + 1.0 / 144.0).abs() < 1e-12, "hes {}", hes[0]);
    }

    #[test]
    fn eval_cbrt_arity_mismatch_errors() {
        let Some(path) = idaes_functions_dylib() else {
            eprintln!("skipping: IDAES functions.dylib not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let too_few: [ExternalArg; 0] = [];
        assert!(lib.eval("cbrt", &too_few, false, false).is_err());
        let too_many = [ExternalArg::Real(1.0), ExternalArg::Real(2.0)];
        assert!(lib.eval("cbrt", &too_many, false, false).is_err());
    }

    #[test]
    fn eval_cbrt_rejects_string_arg() {
        let Some(path) = idaes_functions_dylib() else {
            eprintln!("skipping: IDAES functions.dylib not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let args = [ExternalArg::Str("nope")];
        let err = lib
            .eval("cbrt", &args, false, false)
            .err()
            .expect("string arg to type=0 function must error");
        assert!(
            err.to_lowercase().contains("string"),
            "error should mention strings, got: {err}"
        );
    }
}