1use std::collections::HashMap;
21use std::ffi::{c_char, c_int, c_long, c_void, CStr, CString};
22use std::path::Path;
23use std::ptr;
24use std::sync::{Arc, Mutex, OnceLock};
25
26use libloading::{Library, Symbol};
27
28use crate::nl_reader::{Expr, FuncallArg, ImportedFunc};
29
30#[derive(Default, Clone)]
36pub struct ExternalResolver {
37 pub funcs_by_id: HashMap<usize, (Arc<ExternalLibrary>, String)>,
39}
40
41impl ExternalResolver {
42 pub fn is_empty(&self) -> bool {
43 self.funcs_by_id.is_empty()
44 }
45
46 pub fn build_for_problem(
55 imported_funcs: &[ImportedFunc],
56 referenced_ids: &std::collections::BTreeSet<usize>,
57 ) -> Result<Self, String> {
58 if referenced_ids.is_empty() {
59 return Ok(Self::default());
60 }
61 let amplfunc = std::env::var("AMPLFUNC").map_err(|_| {
62 "problem uses external functions but AMPLFUNC is not set; \
63 set AMPLFUNC to a newline-separated list of AMPL shared-library paths"
64 .to_string()
65 })?;
66 let mut libs: Vec<Arc<ExternalLibrary>> = Vec::new();
67 for path_str in amplfunc
68 .split('\n')
69 .map(|s| s.trim())
70 .filter(|s| !s.is_empty())
71 {
72 let path = std::path::Path::new(path_str);
73 let lib = ExternalLibrary::load(path).map_err(|e| format!("AMPLFUNC: {e}"))?;
74 libs.push(Arc::new(lib));
75 }
76
77 let mut funcs_by_id: HashMap<usize, (Arc<ExternalLibrary>, String)> = HashMap::new();
78 for id in referenced_ids {
79 let decl = imported_funcs
80 .iter()
81 .find(|f| f.id == *id)
82 .ok_or_else(|| format!("funcall id {id} has no F<{id}> declaration"))?;
83 let found = libs
84 .iter()
85 .find(|lib| lib.get(&decl.name).is_some())
86 .ok_or_else(|| {
87 format!(
88 "external function '{}' (id {}) not found in any library on AMPLFUNC",
89 decl.name, decl.id
90 )
91 })?;
92 funcs_by_id.insert(*id, (found.clone(), decl.name.clone()));
93 }
94 Ok(Self { funcs_by_id })
95 }
96}
97
98pub fn collect_funcall_ids(e: &Expr, out: &mut std::collections::BTreeSet<usize>) {
102 match e {
103 Expr::Const(_) | Expr::Var(_) => {}
104 Expr::Binary(_, a, b) => {
105 collect_funcall_ids(a, out);
106 collect_funcall_ids(b, out);
107 }
108 Expr::Unary(_, a) => collect_funcall_ids(a, out),
109 Expr::Sum(args) => {
110 for a in args {
111 collect_funcall_ids(a, out);
112 }
113 }
114 Expr::Cse(body) => collect_funcall_ids(body, out),
115 Expr::Funcall { id, args } => {
116 out.insert(*id);
117 for arg in args {
118 if let FuncallArg::Real(e) = arg {
119 collect_funcall_ids(e, out);
120 }
121 }
122 }
123 }
124}
125
/// Process-wide lock that serializes every call into an AMPL external library.
///
/// The mutex guards no data (`()`), so callers may safely recover it after a
/// poisoning panic.
fn ampl_lock() -> &'static Mutex<()> {
    static AMPL_GLOBAL_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    AMPL_GLOBAL_LOCK.get_or_init(Mutex::default)
}
135
/// `addfunc` type value: the function is real (`double`) valued.
pub const FUNCADD_REAL_VALUED: i32 = 0;
/// `addfunc` type bit: the function accepts symbolic (string) arguments.
pub const FUNCADD_STRING_ARGS: i32 = 1;
/// `addfunc` type value: the function returns a `char*` (AMPL only).
pub const FUNCADD_STRING_VALUED: i32 = 2;
/// `addfunc` type bit: the function accepts output arguments (AMPL only).
///
/// ASL's `funcadd.h` assigns this flag the value 16. The value 2 belongs to
/// [`FUNCADD_STRING_VALUED`], so testing `ty & 2` for output arguments would
/// misclassify string-valued functions.
pub const FUNCADD_OUTPUT_ARGS: i32 = 16;
/// `addfunc` type value: the function is real valued and random.
pub const FUNCADD_RANDOM_VALUED: i32 = 4;
143
144#[repr(C)]
146pub struct Arglist {
147 pub n: c_int, pub nr: c_int, pub at: *mut c_int, pub ra: *mut f64, pub sa: *mut *const c_char, pub derivs: *mut f64, pub hes: *mut f64, pub dig: *mut c_char, pub funcinfo: *mut c_void, pub ae: *mut AmplExports, pub f: *mut c_void, pub tva: *mut c_void, pub errmsg: *mut c_char, pub tmi: *mut c_void, pub private: *mut c_char,
162 pub nin: c_int,
163 pub nout: c_int,
164 pub nsin: c_int,
165 pub nsout: c_int,
166}
167
168pub type Rfunc = unsafe extern "C" fn(*mut Arglist) -> f64;
171
172pub type AddfuncFn = unsafe extern "C" fn(
174 name: *const c_char,
175 f: Rfunc,
176 ty: c_int,
177 nargs: c_int,
178 funcinfo: *mut c_void,
179 ae: *mut AmplExports,
180);
181
182pub type RandSeedSetter = unsafe extern "C" fn(*mut c_void, std::os::raw::c_ulong);
184
185pub type AddrandinitFn =
187 unsafe extern "C" fn(ae: *mut AmplExports, setter: RandSeedSetter, v: *mut c_void);
188
189pub type AtResetFn = unsafe extern "C" fn(ae: *mut AmplExports, f: *mut c_void, v: *mut c_void);
191
192#[repr(C)]
197pub struct AmplExports {
198 pub std_err: *mut c_void,
199 pub addfunc: Option<AddfuncFn>,
200 pub asl_date: c_long,
201 pub fprintf: *mut c_void,
202 pub printf: *mut c_void,
203 pub sprintf: *mut c_void,
204 pub vfprintf: *mut c_void,
205 pub vsprintf: *mut c_void,
206 pub strtod: *mut c_void,
207 pub crypto: *mut c_void,
208 pub asl: *mut c_char,
209 pub at_exit: *mut c_void,
210 pub at_reset: Option<AtResetFn>,
211 pub tempmem: *mut c_void,
212 pub add_table_handler: *mut c_void,
213 pub private_ae: *mut c_char,
214 pub qsortv: *mut c_void,
215
216 pub std_in: *mut c_void,
217 pub std_out: *mut c_void,
218 pub clearerr: *mut c_void,
219 pub fclose: *mut c_void,
220 pub fdopen: *mut c_void,
221 pub feof: *mut c_void,
222 pub ferror: *mut c_void,
223 pub fflush: *mut c_void,
224 pub fgetc: *mut c_void,
225 pub fgets: *mut c_void,
226 pub fileno: *mut c_void,
227 pub fopen: *mut c_void,
228 pub fputc: *mut c_void,
229 pub fputs: *mut c_void,
230 pub fread: *mut c_void,
231 pub freopen: *mut c_void,
232 pub fscanf: *mut c_void,
233 pub fseek: *mut c_void,
234 pub ftell: *mut c_void,
235 pub fwrite: *mut c_void,
236 pub pclose: *mut c_void,
237 pub perror: *mut c_void,
238 pub popen: *mut c_void,
239 pub puts: *mut c_void,
240 pub rewind: *mut c_void,
241 pub scanf: *mut c_void,
242 pub setbuf: *mut c_void,
243 pub setvbuf: *mut c_void,
244 pub sscanf: *mut c_void,
245 pub tempnam: *mut c_void,
246 pub tmpfile: *mut c_void,
247 pub tmpnam: *mut c_void,
248 pub ungetc: *mut c_void,
249 pub ai: *mut c_void,
250 pub getenv: *mut c_void,
251 pub breakfunc: *mut c_void,
252 pub breakarg: *mut c_char,
253
254 pub snprintf: *mut c_void,
256 pub vsnprintf: *mut c_void,
257
258 pub addrand: *mut c_void,
259 pub addrandinit: Option<AddrandinitFn>,
260}
261
262unsafe impl Send for AmplExports {}
268unsafe impl Sync for AmplExports {}
269
270#[derive(Debug, Clone)]
273pub struct RegisteredFunc {
274 pub name: String,
275 pub rfunc: Rfunc,
276 pub ty: i32,
278 pub nargs: i32,
281 pub funcinfo: *mut c_void,
283}
284
285unsafe impl Send for RegisteredFunc {}
289unsafe impl Sync for RegisteredFunc {}
290
291impl std::fmt::Debug for ExternalLibrary {
292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293 f.debug_struct("ExternalLibrary")
294 .field("funcs", &self.funcs.keys().collect::<Vec<_>>())
295 .finish()
296 }
297}
298
299pub struct ExternalLibrary {
301 _lib: Arc<Library>,
304 _ae: Box<AmplExports>,
308 funcs: HashMap<String, RegisteredFunc>,
310}
311
312impl ExternalLibrary {
313 pub fn load(path: &Path) -> Result<Self, String> {
316 let _guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
319 let lib = unsafe { Library::new(path) }
323 .map_err(|e| format!("failed to open '{}': {}", path.display(), e))?;
324
325 type FuncaddFn = unsafe extern "C" fn(*mut AmplExports);
328 let funcadd: Symbol<FuncaddFn> = unsafe { lib.get(b"funcadd_ASL\0") }
329 .map_err(|e| format!("no funcadd_ASL in '{}': {}", path.display(), e))?;
330
331 let mut ae = Box::new(AmplExports {
335 std_err: ptr::null_mut(),
336 addfunc: Some(trampoline_addfunc),
337 asl_date: 20160307,
340 fprintf: ptr::null_mut(),
341 printf: ptr::null_mut(),
342 sprintf: ptr::null_mut(),
343 vfprintf: ptr::null_mut(),
344 vsprintf: ptr::null_mut(),
345 strtod: ptr::null_mut(),
346 crypto: ptr::null_mut(),
347 asl: ptr::null_mut(),
348 at_exit: ptr::null_mut(),
349 at_reset: Some(trampoline_atreset),
350 tempmem: ptr::null_mut(),
351 add_table_handler: ptr::null_mut(),
352 private_ae: ptr::null_mut(),
353 qsortv: ptr::null_mut(),
354 std_in: ptr::null_mut(),
355 std_out: ptr::null_mut(),
356 clearerr: ptr::null_mut(),
357 fclose: ptr::null_mut(),
358 fdopen: ptr::null_mut(),
359 feof: ptr::null_mut(),
360 ferror: ptr::null_mut(),
361 fflush: ptr::null_mut(),
362 fgetc: ptr::null_mut(),
363 fgets: ptr::null_mut(),
364 fileno: ptr::null_mut(),
365 fopen: ptr::null_mut(),
366 fputc: ptr::null_mut(),
367 fputs: ptr::null_mut(),
368 fread: ptr::null_mut(),
369 freopen: ptr::null_mut(),
370 fscanf: ptr::null_mut(),
371 fseek: ptr::null_mut(),
372 ftell: ptr::null_mut(),
373 fwrite: ptr::null_mut(),
374 pclose: ptr::null_mut(),
375 perror: ptr::null_mut(),
376 popen: ptr::null_mut(),
377 puts: ptr::null_mut(),
378 rewind: ptr::null_mut(),
379 scanf: ptr::null_mut(),
380 setbuf: ptr::null_mut(),
381 setvbuf: ptr::null_mut(),
382 sscanf: ptr::null_mut(),
383 tempnam: ptr::null_mut(),
384 tmpfile: ptr::null_mut(),
385 tmpnam: ptr::null_mut(),
386 ungetc: ptr::null_mut(),
387 ai: ptr::null_mut(),
388 getenv: ptr::null_mut(),
389 breakfunc: ptr::null_mut(),
390 breakarg: ptr::null_mut(),
391 snprintf: ptr::null_mut(),
392 vsnprintf: ptr::null_mut(),
393 addrand: ptr::null_mut(),
394 addrandinit: Some(trampoline_addrandinit),
395 });
396
397 REGISTRY_SINK.with(|sink| {
400 let mut guard = sink.borrow_mut();
401 assert!(
402 guard.is_none(),
403 "nested ExternalLibrary::load is not supported"
404 );
405 *guard = Some(HashMap::new());
406 });
407
408 unsafe { funcadd(ae.as_mut()) };
411
412 let funcs = REGISTRY_SINK
413 .with(|sink| sink.borrow_mut().take())
414 .unwrap_or_default();
415
416 Ok(ExternalLibrary {
417 _lib: Arc::new(lib),
418 _ae: ae,
419 funcs,
420 })
421 }
422
423 pub fn function_names(&self) -> impl Iterator<Item = &str> {
425 self.funcs.keys().map(|s| s.as_str())
426 }
427
428 pub fn get(&self, name: &str) -> Option<&RegisteredFunc> {
430 self.funcs.get(name)
431 }
432
433 pub fn eval(
445 &self,
446 name: &str,
447 args: &[ExternalArg<'_>],
448 want_derivs: bool,
449 want_hes: bool,
450 ) -> Result<EvalResult, String> {
451 let rf = self
452 .funcs
453 .get(name)
454 .ok_or_else(|| format!("no such external function '{name}'"))?;
455
456 let n = args.len() as i32;
458 if rf.nargs >= 0 {
459 if rf.nargs != n {
460 return Err(format!(
461 "external '{name}' expects {} args, got {}",
462 rf.nargs, n
463 ));
464 }
465 } else {
466 let min_args = -(rf.nargs + 1);
468 if n < min_args {
469 return Err(format!(
470 "external '{name}' expects at least {min_args} args, got {n}"
471 ));
472 }
473 }
474
475 let mut at_vec: Vec<c_int> = Vec::with_capacity(args.len());
477 let mut ra_vec: Vec<f64> = Vec::new();
478 let mut sa_owned: Vec<CString> = Vec::new();
479 for a in args {
480 match a {
481 ExternalArg::Real(x) => {
482 at_vec.push(ra_vec.len() as c_int);
483 ra_vec.push(*x);
484 }
485 ExternalArg::Str(s) => {
486 let cs = CString::new(*s)
487 .map_err(|_| format!("external '{name}' string arg contains NUL"))?;
488 at_vec.push(-(sa_owned.len() as c_int + 1));
489 sa_owned.push(cs);
490 }
491 }
492 }
493 let nr = ra_vec.len() as c_int;
494 let sa_ptrs: Vec<*const c_char> = sa_owned.iter().map(|s| s.as_ptr()).collect();
495
496 let has_strings = !sa_owned.is_empty();
499 if has_strings && (rf.ty & FUNCADD_STRING_ARGS) == 0 {
500 return Err(format!(
501 "external '{name}' is not declared FUNCADD_STRING_ARGS but was \
502 called with string arguments"
503 ));
504 }
505
506 let mut derivs_buf: Vec<f64> = if want_derivs {
508 vec![0.0; nr as usize]
509 } else {
510 Vec::new()
511 };
512 let hes_len = if want_hes {
513 (nr as usize) * ((nr as usize) + 1) / 2
514 } else {
515 0
516 };
517 let mut hes_buf: Vec<f64> = if want_hes {
518 vec![0.0; hes_len]
519 } else {
520 Vec::new()
521 };
522
523 let mut errmsg_buf: Vec<c_char> = vec![0; 1024];
525
526 let mut al = Arglist {
530 n,
531 nr,
532 at: if at_vec.is_empty() {
533 ptr::null_mut()
534 } else {
535 at_vec.as_mut_ptr()
536 },
537 ra: if ra_vec.is_empty() {
538 ptr::null_mut()
539 } else {
540 ra_vec.as_mut_ptr()
541 },
542 sa: if sa_ptrs.is_empty() {
543 ptr::null_mut()
544 } else {
545 sa_ptrs.as_ptr() as *mut *const c_char
546 },
547 derivs: if want_derivs {
548 derivs_buf.as_mut_ptr()
549 } else {
550 ptr::null_mut()
551 },
552 hes: if want_hes {
553 hes_buf.as_mut_ptr()
554 } else {
555 ptr::null_mut()
556 },
557 dig: ptr::null_mut(),
558 funcinfo: rf.funcinfo,
559 ae: self._ae_ptr(),
562 f: ptr::null_mut(),
563 tva: ptr::null_mut(),
564 errmsg: errmsg_buf.as_mut_ptr(),
565 tmi: ptr::null_mut(),
566 private: ptr::null_mut(),
567 nin: 0,
568 nout: 0,
569 nsin: 0,
570 nsout: 0,
571 };
572
573 let _guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
577 let value = unsafe { (rf.rfunc)(&mut al as *mut Arglist) };
578 drop(_guard);
579
580 if errmsg_buf[0] != 0 {
583 let msg = unsafe { CStr::from_ptr(errmsg_buf.as_ptr()) }
586 .to_string_lossy()
587 .into_owned();
588 return Err(format!("external '{name}' reported: {msg}"));
589 }
590
591 Ok(EvalResult {
592 value,
593 derivs: if want_derivs { Some(derivs_buf) } else { None },
594 hessian: if want_hes { Some(hes_buf) } else { None },
595 })
596 }
597
598 fn _ae_ptr(&self) -> *mut AmplExports {
602 (&*self._ae as *const AmplExports) as *mut AmplExports
604 }
605}
606
/// One actual argument to an external-function call.
///
/// The lifetime `'a` ties string arguments to the caller's data; they are
/// copied into C strings only for the duration of a call.
#[derive(Debug, Clone, Copy)]
pub enum ExternalArg<'a> {
    /// A real (`double`) argument, placed in the function's `ra` array.
    Real(f64),
    /// A symbolic argument, placed in the function's `sa` array as a C string.
    Str(&'a str),
}
613
/// Outcome of a single external-function evaluation.
#[derive(Debug, Clone)]
pub struct EvalResult {
    /// The value returned by the external function.
    pub value: f64,
    /// First partials w.r.t. the real arguments, in argument order. `Some`
    /// only when derivatives were requested.
    pub derivs: Option<Vec<f64>>,
    /// Packed second partials (length `nr * (nr + 1) / 2`, triangle order as
    /// defined by ASL). `Some` only when the Hessian was requested.
    pub hessian: Option<Vec<f64>>,
}
625
626thread_local! {
636 static REGISTRY_SINK: std::cell::RefCell<Option<HashMap<String, RegisteredFunc>>> =
637 std::cell::RefCell::new(None);
638}
639
640unsafe extern "C" fn trampoline_addfunc(
642 name: *const c_char,
643 f: Rfunc,
644 ty: c_int,
645 nargs: c_int,
646 funcinfo: *mut c_void,
647 _ae: *mut AmplExports,
648) {
649 if name.is_null() {
650 return;
651 }
652 let cname = unsafe { CStr::from_ptr(name) };
654 let name_str = match cname.to_str() {
655 Ok(s) => s.to_owned(),
656 Err(_) => return, };
658 REGISTRY_SINK.with(|sink| {
659 if let Some(map) = sink.borrow_mut().as_mut() {
660 map.insert(
661 name_str.clone(),
662 RegisteredFunc {
663 name: name_str,
664 rfunc: f,
665 ty: ty as i32,
666 nargs: nargs as i32,
667 funcinfo,
668 },
669 );
670 }
671 });
672}
673
674unsafe extern "C" fn trampoline_atreset(_ae: *mut AmplExports, _f: *mut c_void, _v: *mut c_void) {
677 tracing::debug!("external library registered an AtReset callback; ignoring");
678}
679
680unsafe extern "C" fn trampoline_addrandinit(
683 _ae: *mut AmplExports,
684 setter: RandSeedSetter,
685 v: *mut c_void,
686) {
687 unsafe { setter(v, 1) };
688}
689
#[cfg(test)]
mod tests {
    use super::*;

    /// Path of the IDAES general Helmholtz external library, if present.
    ///
    /// `IDAES_HELMHOLTZ_LIB` overrides the lookup. Otherwise the default
    /// `~/.idaes/bin` install is probed using the platform's shared-library
    /// extension (`dylib` on macOS, `so` on Linux, `dll` on Windows).
    fn idaes_dylib() -> Option<std::path::PathBuf> {
        let candidate = match std::env::var_os("IDAES_HELMHOLTZ_LIB") {
            Some(explicit) => std::path::PathBuf::from(explicit),
            None => std::path::PathBuf::from(std::env::var_os("HOME")?).join(format!(
                ".idaes/bin/general_helmholtz_external.{}",
                std::env::consts::DLL_EXTENSION
            )),
        };
        candidate.exists().then_some(candidate)
    }

    /// Helmholtz component parameter directory, if present.
    ///
    /// `IDAES_HELMHOLTZ_PARAMS` overrides the lookup. Otherwise a
    /// developer-local virtualenv path under `$HOME` is tried. The string is
    /// passed to the library verbatim.
    fn idaes_params_dir() -> Option<String> {
        let dir = match std::env::var_os("IDAES_HELMHOLTZ_PARAMS") {
            Some(explicit) => std::path::PathBuf::from(explicit),
            None => std::path::PathBuf::from(std::env::var_os("HOME")?).join(
                "Dropbox/uv/.venv/lib/python3.12/site-packages/idaes/\
                 models/properties/general_helmholtz/components/parameters/",
            ),
        };
        if dir.exists() {
            dir.to_str().map(str::to_owned)
        } else {
            None
        }
    }

    /// Returns the library path and parameter directory, or prints why the
    /// calling test is being skipped and returns `None`.
    fn helmholtz_fixture() -> Option<(std::path::PathBuf, String)> {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return None;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return None;
        };
        Some((path, params_dir))
    }

    /// `vf_hp` arguments at the fixture's initial point: the component name,
    /// two real inputs copied from the fixture, and the parameter directory.
    fn vf_hp_args(params_dir: &str) -> [ExternalArg<'_>; 4] {
        [
            ExternalArg::Str("h2o"),
            ExternalArg::Real(1878.71 * 0.055508472036052976),
            ExternalArg::Real(101325.0 * 0.001),
            ExternalArg::Str(params_dir),
        ]
    }

    #[test]
    fn load_idaes_helmholtz_dylib_registers_known_functions() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };

        let lib = ExternalLibrary::load(&path).expect("load should succeed");
        let names: Vec<String> = lib.function_names().map(|s| s.to_owned()).collect();

        for required in &["vf_hp", "h_liq_hp", "h_vap_hp"] {
            assert!(
                names.iter().any(|n| n == required),
                "expected {required} in registered names: {names:?}"
            );
        }
    }

    #[test]
    fn eval_vf_hp_at_fixture_initial_point() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let res = lib
            .eval("vf_hp", &vf_hp_args(&params_dir), false, false)
            .expect("eval");
        assert!(
            res.value.is_finite(),
            "vf_hp returned non-finite value {}",
            res.value
        );
    }

    #[test]
    fn eval_vf_hp_with_derivatives() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let res = lib
            .eval("vf_hp", &vf_hp_args(&params_dir), true, false)
            .expect("eval");
        let derivs = res.derivs.expect("derivs requested");
        assert_eq!(derivs.len(), 2, "nr=2 reals -> 2 derivatives");
        for (i, d) in derivs.iter().enumerate() {
            assert!(d.is_finite(), "derivs[{i}] = {d} not finite");
        }
    }

    #[test]
    fn eval_vf_hp_with_hessian() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let res = lib
            .eval("vf_hp", &vf_hp_args(&params_dir), true, true)
            .expect("eval");
        let hes = res.hessian.expect("hessian requested");
        assert_eq!(hes.len(), 3, "nr=2 -> packed Hessian of length 3");
        for (i, h) in hes.iter().enumerate() {
            assert!(h.is_finite(), "hes[{i}] = {h} not finite");
        }
    }

    /// A Hessian-only request must still hand the library a gradient buffer,
    /// but must not return the gradient to the caller.
    #[test]
    fn eval_vf_hp_hessian_without_gradient() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let res = lib
            .eval("vf_hp", &vf_hp_args(&params_dir), false, true)
            .expect("eval");
        assert!(res.derivs.is_none(), "derivs were not requested");
        let hes = res.hessian.expect("hessian requested");
        assert_eq!(hes.len(), 3, "nr=2 -> packed Hessian of length 3");
    }
}