use std::collections::HashMap;
use std::ffi::{c_char, c_int, c_long, c_void, CStr, CString};
use std::path::Path;
use std::ptr;
use std::sync::{Arc, Mutex, OnceLock};
use libloading::{Library, Symbol};
use crate::nl_reader::{Expr, FuncallArg, ImportedFunc};
/// Lookup table from external-function call ids to the library that
/// provides each function.
#[derive(Clone, Default)]
pub struct ExternalResolver {
    /// Keyed by funcall id. Each value pairs the library in which the
    /// function was found with the declared function name used to look it up.
    pub funcs_by_id: HashMap<usize, (Arc<ExternalLibrary>, String)>,
}
impl ExternalResolver {
pub fn is_empty(&self) -> bool {
self.funcs_by_id.is_empty()
}
pub fn build_for_problem(
imported_funcs: &[ImportedFunc],
referenced_ids: &std::collections::BTreeSet<usize>,
) -> Result<Self, String> {
if referenced_ids.is_empty() {
return Ok(Self::default());
}
let amplfunc = std::env::var("AMPLFUNC").map_err(|_| {
"problem uses external functions but AMPLFUNC is not set; \
set AMPLFUNC to a newline-separated list of AMPL shared-library paths"
.to_string()
})?;
let mut libs: Vec<Arc<ExternalLibrary>> = Vec::new();
for path_str in amplfunc
.split('\n')
.map(|s| s.trim())
.filter(|s| !s.is_empty())
{
let path = std::path::Path::new(path_str);
let lib = ExternalLibrary::load(path).map_err(|e| format!("AMPLFUNC: {e}"))?;
libs.push(Arc::new(lib));
}
let mut funcs_by_id: HashMap<usize, (Arc<ExternalLibrary>, String)> = HashMap::new();
for id in referenced_ids {
let decl = imported_funcs
.iter()
.find(|f| f.id == *id)
.ok_or_else(|| format!("funcall id {id} has no F<{id}> declaration"))?;
let found = libs
.iter()
.find(|lib| lib.get(&decl.name).is_some())
.ok_or_else(|| {
format!(
"external function '{}' (id {}) not found in any library on AMPLFUNC",
decl.name, decl.id
)
})?;
funcs_by_id.insert(*id, (found.clone(), decl.name.clone()));
}
Ok(Self { funcs_by_id })
}
}
/// Walks the expression tree rooted at `e` and inserts the id of every
/// `Expr::Funcall` node into `out`.
///
/// Function-call arguments are searched only when they are
/// `FuncallArg::Real` expressions; other argument kinds are skipped.
pub fn collect_funcall_ids(e: &Expr, out: &mut std::collections::BTreeSet<usize>) {
    match e {
        Expr::Funcall { id, args } => {
            out.insert(*id);
            args.iter().for_each(|arg| {
                if let FuncallArg::Real(inner) = arg {
                    collect_funcall_ids(inner, out);
                }
            });
        }
        Expr::Sum(terms) => terms
            .iter()
            .for_each(|term| collect_funcall_ids(term, out)),
        Expr::Binary(_, lhs, rhs) => {
            collect_funcall_ids(lhs, out);
            collect_funcall_ids(rhs, out);
        }
        Expr::Unary(_, operand) => collect_funcall_ids(operand, out),
        Expr::Cse(body) => collect_funcall_ids(body, out),
        // Leaves carry no sub-expressions.
        Expr::Const(_) | Expr::Var(_) => {}
    }
}
/// Returns the single process-wide mutex, created on first use.
///
/// Every call returns a reference to the same lazily initialised `Mutex<()>`.
fn ampl_lock() -> &'static Mutex<()> {
    static AMPL_GUARD: OnceLock<Mutex<()>> = OnceLock::new();
    AMPL_GUARD.get_or_init(Mutex::default)
}
// Bits of the `type` argument a library passes to `addfunc`. The values follow
// the `FUNCADD_TYPE` enum in AMPL's `funcadd.h`.

/// The function returns a real (`double`) value.
pub const FUNCADD_REAL_VALUED: i32 = 0;
/// The function accepts string arguments (may be or-ed with the value kind).
pub const FUNCADD_STRING_ARGS: i32 = 1;
/// The function accepts output arguments.
///
/// `funcadd.h` defines this bit as 16; the value 2 belongs to
/// `FUNCADD_STRING_VALUED` there.
pub const FUNCADD_OUTPUT_ARGS: i32 = 16;
/// The function is a real-valued random function.
pub const FUNCADD_RANDOM_VALUED: i32 = 4;
/// Argument block passed by pointer to an external function on every call.
///
/// The layout is `#[repr(C)]` and the field order is part of the binary
/// contract with the loaded library (presumably mirroring `struct arglist`
/// from AMPL's `funcadd.h` — confirm against that header before editing).
/// Do not reorder, insert or remove fields.
#[repr(C)]
pub struct Arglist {
// As filled in by `ExternalLibrary::eval`:
//   n        total number of arguments
//   nr       number of real arguments
//   at       one entry per argument: a value `k >= 0` selects `ra[k]`,
//            a negative value `-(k + 1)` selects `sa[k]`
//   ra       real argument values
//   sa       NUL-terminated string arguments
//   derivs   output buffer of `nr` first partials, or null when not requested
//   hes      output buffer of `nr * (nr + 1) / 2` second partials, or null
//   funcinfo opaque pointer the library supplied when registering the function
//   ae       the `AmplExports` table owned by the `ExternalLibrary`
//   errmsg   error-message slot inspected after the call returns
// `dig`, `f`, `tva`, `tmi` and `private` are always passed as null by `eval`.
pub n: c_int, pub nr: c_int, pub at: *mut c_int, pub ra: *mut f64, pub sa: *mut *const c_char, pub derivs: *mut f64, pub hes: *mut f64, pub dig: *mut c_char, pub funcinfo: *mut c_void, pub ae: *mut AmplExports, pub f: *mut c_void, pub tva: *mut c_void, pub errmsg: *mut c_char, pub tmi: *mut c_void, pub private: *mut c_char,
// The four counters below are always passed as zero by `eval`.
pub nin: c_int,
pub nout: c_int,
pub nsin: c_int,
pub nsout: c_int,
}
/// Entry point of an external function: receives the argument block and
/// returns the function value.
pub type Rfunc = unsafe extern "C" fn(*mut Arglist) -> f64;
/// Registration callback a library invokes once per function it exports.
///
/// `ty` carries the `FUNCADD_*` bits. `nargs` is the exact argument count
/// when non-negative; a negative value `-(k + 1)` means "at least `k`"
/// (this is how `ExternalLibrary::eval` interprets it).
pub type AddfuncFn = unsafe extern "C" fn(
name: *const c_char,
f: Rfunc,
ty: c_int,
nargs: c_int,
funcinfo: *mut c_void,
ae: *mut AmplExports,
);
/// Callback through which a library receives a random seed; takes the
/// library's own context pointer and the seed value.
pub type RandSeedSetter = unsafe extern "C" fn(*mut c_void, std::os::raw::c_ulong);
/// Hook a library calls to register a `RandSeedSetter` together with its
/// context pointer `v`.
pub type AddrandinitFn =
unsafe extern "C" fn(ae: *mut AmplExports, setter: RandSeedSetter, v: *mut c_void);
/// Hook a library calls to register a reset callback `f` with context `v`.
pub type AtResetFn = unsafe extern "C" fn(ae: *mut AmplExports, f: *mut c_void, v: *mut c_void);
/// Table of host services handed to a library's `funcadd_ASL` entry point
/// and to every external-function call (through `Arglist::ae`).
///
/// The layout is `#[repr(C)]`; field order and count are part of the binary
/// contract with the loaded library (presumably mirroring `AmplExports` in
/// AMPL's `funcadd.h` — confirm against that header before editing).
///
/// `ExternalLibrary::load` fills in only `addfunc`, `asl_date`, `at_reset`
/// and `addrandinit`; every other slot is passed as null.
/// NOTE(review): a library that calls through one of the null slots would
/// dereference a null function pointer.
#[repr(C)]
pub struct AmplExports {
pub std_err: *mut c_void,
// Registration callback used by the library to announce its functions.
pub addfunc: Option<AddfuncFn>,
pub asl_date: c_long,
pub fprintf: *mut c_void,
pub printf: *mut c_void,
pub sprintf: *mut c_void,
pub vfprintf: *mut c_void,
pub vsprintf: *mut c_void,
pub strtod: *mut c_void,
pub crypto: *mut c_void,
pub asl: *mut c_char,
pub at_exit: *mut c_void,
// Reset-callback registration hook.
pub at_reset: Option<AtResetFn>,
pub tempmem: *mut c_void,
pub add_table_handler: *mut c_void,
pub private_ae: *mut c_char,
pub qsortv: *mut c_void,
// The slots from here to `ungetc` are named after C stdio facilities.
pub std_in: *mut c_void,
pub std_out: *mut c_void,
pub clearerr: *mut c_void,
pub fclose: *mut c_void,
pub fdopen: *mut c_void,
pub feof: *mut c_void,
pub ferror: *mut c_void,
pub fflush: *mut c_void,
pub fgetc: *mut c_void,
pub fgets: *mut c_void,
pub fileno: *mut c_void,
pub fopen: *mut c_void,
pub fputc: *mut c_void,
pub fputs: *mut c_void,
pub fread: *mut c_void,
pub freopen: *mut c_void,
pub fscanf: *mut c_void,
pub fseek: *mut c_void,
pub ftell: *mut c_void,
pub fwrite: *mut c_void,
pub pclose: *mut c_void,
pub perror: *mut c_void,
pub popen: *mut c_void,
pub puts: *mut c_void,
pub rewind: *mut c_void,
pub scanf: *mut c_void,
pub setbuf: *mut c_void,
pub setvbuf: *mut c_void,
pub sscanf: *mut c_void,
pub tempnam: *mut c_void,
pub tmpfile: *mut c_void,
pub tmpnam: *mut c_void,
pub ungetc: *mut c_void,
pub ai: *mut c_void,
pub getenv: *mut c_void,
pub breakfunc: *mut c_void,
pub breakarg: *mut c_char,
pub snprintf: *mut c_void,
pub vsnprintf: *mut c_void,
pub addrand: *mut c_void,
// Random-seed setter registration hook.
pub addrandinit: Option<AddrandinitFn>,
}
// `AmplExports` consists of raw pointers and function pointers, so the
// compiler does not derive `Send`/`Sync`; these impls assert both manually.
// NOTE(review): soundness rests on how the table is used. In this file the
// calls into a loaded library are made while holding `ampl_lock()` — confirm
// that this covers every access before sharing the table across threads.
unsafe impl Send for AmplExports {}
unsafe impl Sync for AmplExports {}
/// One function announced by a library through the `addfunc` callback.
#[derive(Clone, Debug)]
pub struct RegisteredFunc {
    /// Name under which the library registered the function.
    pub name: String,
    /// Entry point to call with an `Arglist`.
    pub rfunc: Rfunc,
    /// `FUNCADD_*` type bits supplied at registration.
    pub ty: i32,
    /// Declared argument count: exact when non-negative, otherwise
    /// `-(minimum + 1)`.
    pub nargs: i32,
    /// Opaque pointer supplied at registration and passed back on every call.
    pub funcinfo: *mut c_void,
}
// `RegisteredFunc` holds a raw `funcinfo` pointer, so `Send`/`Sync` are not
// derived automatically; these impls assert both manually.
// NOTE(review): the pointer is owned by the loaded library and is only handed
// back to it; confirm the library tolerates being called from other threads.
unsafe impl Send for RegisteredFunc {}
unsafe impl Sync for RegisteredFunc {}
impl std::fmt::Debug for ExternalLibrary {
    /// Formats the library as `ExternalLibrary { funcs: [...] }`, listing
    /// only the registered function names (in map iteration order).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let registered: Vec<&String> = self.funcs.keys().collect();
        let mut builder = f.debug_struct("ExternalLibrary");
        builder.field("funcs", &registered);
        builder.finish()
    }
}
/// A loaded shared library together with the functions it registered.
///
/// NOTE(review): fields drop in declaration order, so the library handle is
/// released before the other fields; keep this order in mind when adding
/// fields whose destructors could call back into the library.
pub struct ExternalLibrary {
// Handle that keeps the shared library loaded for as long as this value lives.
_lib: Arc<Library>,
// Host-services table whose address was given to the library at load time;
// boxed so that address stays stable.
_ae: Box<AmplExports>,
// Functions the library registered, keyed by name.
funcs: HashMap<String, RegisteredFunc>,
}
impl ExternalLibrary {
    /// Opens the shared library at `path` and records every function it
    /// registers through its `funcadd_ASL` entry point.
    ///
    /// The whole load runs under the global `ampl_lock()`. Registrations are
    /// captured through the thread-local `REGISTRY_SINK`, which is armed just
    /// before `funcadd_ASL` is called and drained right after.
    ///
    /// # Errors
    ///
    /// Returns a message when the library cannot be opened or does not export
    /// `funcadd_ASL`.
    ///
    /// # Panics
    ///
    /// Panics if called while another `load` is already collecting
    /// registrations on the same thread.
    pub fn load(path: &Path) -> Result<Self, String> {
        // A poisoned lock only means another thread panicked while holding it;
        // the guarded data is `()`, so it is safe to continue.
        let _guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());

        // SAFETY: loading a library runs its initialisers; the caller vouches
        // for `path` by listing it on AMPLFUNC / passing it here.
        let lib = unsafe { Library::new(path) }
            .map_err(|e| format!("failed to open '{}': {}", path.display(), e))?;
        type FuncaddFn = unsafe extern "C" fn(*mut AmplExports);
        // SAFETY: the symbol is declared with the signature `funcadd_ASL` is
        // called with below.
        let funcadd: Symbol<FuncaddFn> = unsafe { lib.get(b"funcadd_ASL\0") }
            .map_err(|e| format!("no funcadd_ASL in '{}': {}", path.display(), e))?;

        let mut ae = Self::new_exports();

        REGISTRY_SINK.with(|sink| {
            let mut guard = sink.borrow_mut();
            assert!(
                guard.is_none(),
                "nested ExternalLibrary::load is not supported"
            );
            *guard = Some(HashMap::new());
        });
        // SAFETY: `ae` is a live, boxed table that outlives the call (it is
        // stored in the returned value), and the sink is armed so that
        // `trampoline_addfunc` can record registrations.
        unsafe { funcadd(ae.as_mut()) };
        let funcs = REGISTRY_SINK
            .with(|sink| sink.borrow_mut().take())
            .unwrap_or_default();

        Ok(ExternalLibrary {
            _lib: Arc::new(lib),
            _ae: ae,
            funcs,
        })
    }

    /// Builds the host-services table passed to the library.
    ///
    /// Only the registration hooks are populated; every other slot is null.
    fn new_exports() -> Box<AmplExports> {
        Box::new(AmplExports {
            std_err: ptr::null_mut(),
            addfunc: Some(trampoline_addfunc),
            asl_date: 20160307,
            fprintf: ptr::null_mut(),
            printf: ptr::null_mut(),
            sprintf: ptr::null_mut(),
            vfprintf: ptr::null_mut(),
            vsprintf: ptr::null_mut(),
            strtod: ptr::null_mut(),
            crypto: ptr::null_mut(),
            asl: ptr::null_mut(),
            at_exit: ptr::null_mut(),
            at_reset: Some(trampoline_atreset),
            tempmem: ptr::null_mut(),
            add_table_handler: ptr::null_mut(),
            private_ae: ptr::null_mut(),
            qsortv: ptr::null_mut(),
            std_in: ptr::null_mut(),
            std_out: ptr::null_mut(),
            clearerr: ptr::null_mut(),
            fclose: ptr::null_mut(),
            fdopen: ptr::null_mut(),
            feof: ptr::null_mut(),
            ferror: ptr::null_mut(),
            fflush: ptr::null_mut(),
            fgetc: ptr::null_mut(),
            fgets: ptr::null_mut(),
            fileno: ptr::null_mut(),
            fopen: ptr::null_mut(),
            fputc: ptr::null_mut(),
            fputs: ptr::null_mut(),
            fread: ptr::null_mut(),
            freopen: ptr::null_mut(),
            fscanf: ptr::null_mut(),
            fseek: ptr::null_mut(),
            ftell: ptr::null_mut(),
            fwrite: ptr::null_mut(),
            pclose: ptr::null_mut(),
            perror: ptr::null_mut(),
            popen: ptr::null_mut(),
            puts: ptr::null_mut(),
            rewind: ptr::null_mut(),
            scanf: ptr::null_mut(),
            setbuf: ptr::null_mut(),
            setvbuf: ptr::null_mut(),
            sscanf: ptr::null_mut(),
            tempnam: ptr::null_mut(),
            tmpfile: ptr::null_mut(),
            tmpnam: ptr::null_mut(),
            ungetc: ptr::null_mut(),
            ai: ptr::null_mut(),
            getenv: ptr::null_mut(),
            breakfunc: ptr::null_mut(),
            breakarg: ptr::null_mut(),
            snprintf: ptr::null_mut(),
            vsnprintf: ptr::null_mut(),
            addrand: ptr::null_mut(),
            addrandinit: Some(trampoline_addrandinit),
        })
    }

    /// Iterates over the names of all functions the library registered.
    pub fn function_names(&self) -> impl Iterator<Item = &str> {
        self.funcs.keys().map(|s| s.as_str())
    }

    /// Looks up a registered function by name.
    pub fn get(&self, name: &str) -> Option<&RegisteredFunc> {
        self.funcs.get(name)
    }

    /// Calls the external function `name` with `args`.
    ///
    /// Real arguments are packed into `ra` and string arguments into `sa`, in
    /// call order; `at[i]` is `k >= 0` for `ra[k]` and `-(k + 1)` for `sa[k]`.
    /// When `want_derivs` is set the result carries `nr` first partials; when
    /// `want_hes` is set it carries `nr * (nr + 1) / 2` packed second
    /// partials, where `nr` is the number of real arguments.
    ///
    /// # Errors
    ///
    /// Returns a message when the function is unknown, the argument count
    /// does not match the registered arity, a string argument contains NUL,
    /// strings are passed to a function not registered with
    /// `FUNCADD_STRING_ARGS`, or the function reports an error.
    pub fn eval(
        &self,
        name: &str,
        args: &[ExternalArg<'_>],
        want_derivs: bool,
        want_hes: bool,
    ) -> Result<EvalResult, String> {
        let rf = self
            .funcs
            .get(name)
            .ok_or_else(|| format!("no such external function '{name}'"))?;

        // Checked conversion: a bare `as` cast would silently wrap.
        let n = i32::try_from(args.len()).map_err(|_| {
            format!(
                "external '{name}' called with too many arguments ({})",
                args.len()
            )
        })?;
        if rf.nargs >= 0 {
            if rf.nargs != n {
                return Err(format!(
                    "external '{name}' expects {} args, got {}",
                    rf.nargs, n
                ));
            }
        } else {
            // Negative arity encodes a minimum: -(min + 1).
            let min_args = -(rf.nargs + 1);
            if n < min_args {
                return Err(format!(
                    "external '{name}' expects at least {min_args} args, got {n}"
                ));
            }
        }

        let mut at_vec: Vec<c_int> = Vec::with_capacity(args.len());
        let mut ra_vec: Vec<f64> = Vec::new();
        let mut sa_owned: Vec<CString> = Vec::new();
        for arg in args {
            match *arg {
                ExternalArg::Real(x) => {
                    // Lengths are bounded by `n`, which fits in `i32`.
                    at_vec.push(ra_vec.len() as c_int);
                    ra_vec.push(x);
                }
                ExternalArg::Str(s) => {
                    let cs = CString::new(s)
                        .map_err(|_| format!("external '{name}' string arg contains NUL"))?;
                    at_vec.push(-(sa_owned.len() as c_int + 1));
                    sa_owned.push(cs);
                }
            }
        }
        let real_count = ra_vec.len();
        let nr = real_count as c_int;
        // `sa_owned` must outlive the call: these pointers borrow from it.
        let mut sa_ptrs: Vec<*const c_char> = sa_owned.iter().map(|s| s.as_ptr()).collect();
        if !sa_owned.is_empty() && (rf.ty & FUNCADD_STRING_ARGS) == 0 {
            return Err(format!(
                "external '{name}' is not declared FUNCADD_STRING_ARGS but was \
                 called with string arguments"
            ));
        }

        // The AMPL calling convention only provides `hes` together with
        // `derivs`, so a Hessian request always supplies a derivative buffer
        // even if the caller does not want the first partials back.
        let pass_derivs = want_derivs || want_hes;
        let mut derivs_buf: Vec<f64> = if pass_derivs {
            vec![0.0; real_count]
        } else {
            Vec::new()
        };
        let mut hes_buf: Vec<f64> = if want_hes {
            vec![0.0; real_count * (real_count + 1) / 2]
        } else {
            Vec::new()
        };

        // Scratch buffer kept for libraries that write their message in
        // place; conforming libraries instead repoint `Errmsg` at their own
        // text, which is detected below by comparing against this pointer.
        let mut errmsg_buf: Vec<c_char> = vec![0; 1024];
        let errmsg_ptr = errmsg_buf.as_mut_ptr();

        let mut al = Arglist {
            n,
            nr,
            at: if at_vec.is_empty() {
                ptr::null_mut()
            } else {
                at_vec.as_mut_ptr()
            },
            ra: if ra_vec.is_empty() {
                ptr::null_mut()
            } else {
                ra_vec.as_mut_ptr()
            },
            sa: if sa_ptrs.is_empty() {
                ptr::null_mut()
            } else {
                sa_ptrs.as_mut_ptr()
            },
            derivs: if pass_derivs {
                derivs_buf.as_mut_ptr()
            } else {
                ptr::null_mut()
            },
            hes: if want_hes {
                hes_buf.as_mut_ptr()
            } else {
                ptr::null_mut()
            },
            dig: ptr::null_mut(),
            funcinfo: rf.funcinfo,
            ae: self._ae_ptr(),
            f: ptr::null_mut(),
            tva: ptr::null_mut(),
            errmsg: errmsg_ptr,
            tmi: ptr::null_mut(),
            private: ptr::null_mut(),
            nin: 0,
            nout: 0,
            nsin: 0,
            nsout: 0,
        };

        let guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
        // SAFETY: every pointer in `al` is either null or refers to a buffer
        // that stays alive until after the call; `rf.rfunc` came from the
        // library that `self._lib` keeps loaded.
        let value = unsafe { (rf.rfunc)(&mut al) };
        // Read the message while still holding the lock: a repointed `Errmsg`
        // refers to library-owned memory that a later call could overwrite.
        let reported = if !al.errmsg.is_null() && al.errmsg != errmsg_ptr {
            // SAFETY: the library set `Errmsg` to a NUL-terminated message.
            let text = unsafe { CStr::from_ptr(al.errmsg) };
            Some(text.to_string_lossy().into_owned())
        } else if errmsg_buf[0] != 0 {
            // Bounded read: never run past the buffer even if the library
            // left it unterminated.
            let bytes: Vec<u8> = errmsg_buf
                .iter()
                .take_while(|&&c| c != 0)
                .map(|&c| c as u8)
                .collect();
            Some(String::from_utf8_lossy(&bytes).into_owned())
        } else {
            None
        };
        drop(guard);

        if let Some(msg) = reported {
            return Err(format!("external '{name}' reported: {msg}"));
        }
        Ok(EvalResult {
            value,
            derivs: want_derivs.then_some(derivs_buf),
            hessian: want_hes.then_some(hes_buf),
        })
    }

    /// Returns the address of the boxed exports table as a mutable pointer.
    ///
    /// NOTE(review): the pointer is derived from a shared reference, so the
    /// library must not write through it.
    fn _ae_ptr(&self) -> *mut AmplExports {
        (&*self._ae as *const AmplExports) as *mut AmplExports
    }
}
/// A single argument passed to an external function.
#[derive(Clone, Copy, Debug)]
pub enum ExternalArg<'a> {
    /// A real-valued argument.
    Real(f64),
    /// A string argument, borrowed for the duration of the call.
    Str(&'a str),
}
/// Outcome of one external-function evaluation.
#[derive(Clone, Debug)]
pub struct EvalResult {
    /// The function value.
    pub value: f64,
    /// First partials, one per real argument; `None` unless requested.
    pub derivs: Option<Vec<f64>>,
    /// Packed second partials of length `nr * (nr + 1) / 2` for `nr` real
    /// arguments; `None` unless requested.
    pub hessian: Option<Vec<f64>>,
}
// Per-thread collection point for functions registered during a load.
// `ExternalLibrary::load` sets it to `Some(empty map)` before calling the
// library's `funcadd_ASL` and takes the map back out afterwards;
// `trampoline_addfunc` inserts into it while it is `Some` and does nothing
// while it is `None`.
thread_local! {
static REGISTRY_SINK: std::cell::RefCell<Option<HashMap<String, RegisteredFunc>>> =
std::cell::RefCell::new(None);
}
/// `addfunc` callback given to loaded libraries: records one registered
/// function in the thread-local `REGISTRY_SINK`.
///
/// The registration is dropped silently when `name` is null or not valid
/// UTF-8, or when no registry is currently armed on this thread. A later
/// registration under the same name replaces the earlier one.
///
/// # Safety
///
/// `name` must be null or point to a NUL-terminated C string that is valid
/// for the duration of the call.
unsafe extern "C" fn trampoline_addfunc(
    name: *const c_char,
    f: Rfunc,
    ty: c_int,
    nargs: c_int,
    funcinfo: *mut c_void,
    _ae: *mut AmplExports,
) {
    if name.is_null() {
        return;
    }
    // SAFETY: `name` is non-null and, per this function's contract, points to
    // a NUL-terminated string.
    let Ok(borrowed) = unsafe { CStr::from_ptr(name) }.to_str() else {
        return;
    };
    let key = borrowed.to_owned();
    REGISTRY_SINK.with(|sink| {
        let mut slot = sink.borrow_mut();
        let Some(registry) = slot.as_mut() else {
            return;
        };
        let entry = RegisteredFunc {
            name: key.clone(),
            rfunc: f,
            ty: ty as i32,
            nargs: nargs as i32,
            funcinfo,
        };
        registry.insert(key, entry);
    });
}
/// `at_reset` hook given to loaded libraries.
///
/// The callback and its context are not stored or invoked; the call is only
/// logged at debug level.
unsafe extern "C" fn trampoline_atreset(_ae: *mut AmplExports, _f: *mut c_void, _v: *mut c_void) {
tracing::debug!("external library registered an AtReset callback; ignoring");
}
/// `addrandinit` hook given to loaded libraries.
///
/// Immediately invokes the library's seed setter once with its own context
/// pointer `v` and the fixed seed `1`; the setter is not retained.
unsafe extern "C" fn trampoline_addrandinit(
_ae: *mut AmplExports,
setter: RandSeedSetter,
v: *mut c_void,
) {
// Fixed seed, passed straight through to the library's own setter.
unsafe { setter(v, 1) };
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Path of the IDAES Helmholtz library under `$HOME/.idaes/bin`, if it
    /// exists on this machine.
    fn idaes_dylib() -> Option<std::path::PathBuf> {
        let home = std::env::var_os("HOME")?;
        let p = std::path::PathBuf::from(home).join(".idaes/bin/general_helmholtz_external.dylib");
        if p.exists() {
            Some(p)
        } else {
            None
        }
    }

    /// Directory holding the IDAES Helmholtz parameter files, if it exists on
    /// this machine and its path is valid UTF-8.
    fn idaes_params_dir() -> Option<String> {
        let home = std::env::var_os("HOME")?;
        let p = std::path::PathBuf::from(home).join(
            "Dropbox/uv/.venv/lib/python3.12/site-packages/idaes/\
             models/properties/general_helmholtz/components/parameters/",
        );
        if p.exists() {
            p.to_str().map(|s| s.to_owned())
        } else {
            None
        }
    }

    /// Arguments for `vf_hp` at the fixture's initial point: component name,
    /// two real inputs, and the parameter directory.
    fn vf_hp_args(params_dir: &str) -> [ExternalArg<'_>; 4] {
        [
            ExternalArg::Str("h2o"),
            ExternalArg::Real(1878.71 * 0.055508472036052976),
            ExternalArg::Real(101325.0 * 0.001),
            ExternalArg::Str(params_dir),
        ]
    }

    #[test]
    fn load_idaes_helmholtz_dylib_registers_known_functions() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load should succeed");
        let names: Vec<String> = lib.function_names().map(|s| s.to_owned()).collect();
        for required in &["vf_hp", "h_liq_hp", "h_vap_hp"] {
            assert!(
                names.iter().any(|n| n == required),
                "expected {required} in registered names: {names:?}"
            );
        }
    }

    #[test]
    fn eval_vf_hp_at_fixture_initial_point() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let args = vf_hp_args(&params_dir);
        let res = lib.eval("vf_hp", &args, false, false).expect("eval");
        assert!(
            res.value.is_finite(),
            "vf_hp returned non-finite value {}",
            res.value
        );
    }

    #[test]
    fn eval_vf_hp_with_derivatives() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let args = vf_hp_args(&params_dir);
        let res = lib.eval("vf_hp", &args, true, false).expect("eval");
        let derivs = res.derivs.expect("derivs requested");
        assert_eq!(derivs.len(), 2, "nr=2 reals -> 2 derivatives");
        for (i, d) in derivs.iter().enumerate() {
            assert!(d.is_finite(), "derivs[{i}] = {d} not finite");
        }
    }

    #[test]
    fn eval_vf_hp_with_hessian() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let args = vf_hp_args(&params_dir);
        let res = lib.eval("vf_hp", &args, true, true).expect("eval");
        let hes = res.hessian.expect("hessian requested");
        assert_eq!(hes.len(), 3, "nr=2 -> packed Hessian of length 3");
        for (i, h) in hes.iter().enumerate() {
            assert!(h.is_finite(), "hes[{i}] = {h} not finite");
        }
    }
}