Skip to main content

pounce_cli/
nl_external.rs

1//! AMPL imported (external) function support via the `funcadd_ASL` ABI.
2//!
3//! This module implements enough of AMPL's `funcadd.h` ABI to:
4//!
5//! 1. `dlopen` a user-supplied shared library;
6//! 2. resolve the `funcadd_ASL` symbol and call it;
7//! 3. receive registration callbacks of the form `Addfunc(name, rfunc, type,
8//!    nargs, funcinfo, ae)` and record them;
9//! 4. later call back into the registered `rfunc` with an `arglist` to obtain
10//!    function values, gradients, and Hessians.
11//!
12//! The `AmplExports` and `Arglist` struct layouts are taken from
13//! AMPL-MP/ASL `funcadd.h`; cross-checked against the ctypes mapping in
14//! `pyomo.core.base.external`. Fields we don't populate are left null —
15//! Pyomo does the same and it is sufficient for IDAES's Helmholtz library
16//! (see issue #15).
17//!
18//! All unsafe FFI is contained in this module. Public surface is safe.
19
20use std::collections::HashMap;
21use std::ffi::{c_char, c_int, c_long, c_void, CStr, CString};
22use std::path::Path;
23use std::ptr;
24use std::sync::{Arc, Mutex, OnceLock};
25
26use libloading::{Library, Symbol};
27
28use crate::nl_reader::{Expr, FuncallArg, ImportedFunc};
29
30/// Resolved AMPL imported function: shared library + registered name.
31/// `NlProblem` carries one of these per `ImportedFunc` id when external
32/// functions are wired up at problem-build time. The same `Arc<ExternalLibrary>`
33/// may be shared across many funcall ids (one library typically registers
34/// several functions).
35#[derive(Default, Clone)]
36pub struct ExternalResolver {
37    /// `Funcall { id }` -> (library, registered function name).
38    pub funcs_by_id: HashMap<usize, (Arc<ExternalLibrary>, String)>,
39}
40
41impl ExternalResolver {
42    pub fn is_empty(&self) -> bool {
43        self.funcs_by_id.is_empty()
44    }
45
46    /// Build a resolver for every `ImportedFunc` declared in the `.nl` file
47    /// that is *actually referenced* somewhere in the problem's expressions.
48    ///
49    /// Library paths are resolved through the `AMPLFUNC` environment variable
50    /// (a `\n`-separated list of shared-library paths, matching AMPL/IPOPT
51    /// conventions). Each path is loaded once and queried for every name we
52    /// need. Returns an error if a referenced name cannot be found in any
53    /// listed library, or if `AMPLFUNC` is missing.
54    pub fn build_for_problem(
55        imported_funcs: &[ImportedFunc],
56        referenced_ids: &std::collections::BTreeSet<usize>,
57    ) -> Result<Self, String> {
58        if referenced_ids.is_empty() {
59            return Ok(Self::default());
60        }
61        let amplfunc = std::env::var("AMPLFUNC").map_err(|_| {
62            "problem uses external functions but AMPLFUNC is not set; \
63             set AMPLFUNC to a newline-separated list of AMPL shared-library paths"
64                .to_string()
65        })?;
66        let mut libs: Vec<Arc<ExternalLibrary>> = Vec::new();
67        for path_str in amplfunc
68            .split('\n')
69            .map(|s| s.trim())
70            .filter(|s| !s.is_empty())
71        {
72            let path = std::path::Path::new(path_str);
73            let lib = ExternalLibrary::load(path).map_err(|e| format!("AMPLFUNC: {e}"))?;
74            libs.push(Arc::new(lib));
75        }
76
77        let mut funcs_by_id: HashMap<usize, (Arc<ExternalLibrary>, String)> = HashMap::new();
78        for id in referenced_ids {
79            let decl = imported_funcs
80                .iter()
81                .find(|f| f.id == *id)
82                .ok_or_else(|| format!("funcall id {id} has no F<{id}> declaration"))?;
83            let found = libs
84                .iter()
85                .find(|lib| lib.get(&decl.name).is_some())
86                .ok_or_else(|| {
87                    format!(
88                        "external function '{}' (id {}) not found in any library on AMPLFUNC",
89                        decl.name, decl.id
90                    )
91                })?;
92            funcs_by_id.insert(*id, (found.clone(), decl.name.clone()));
93        }
94        Ok(Self { funcs_by_id })
95    }
96}
97
98/// Walk an `Expr` and collect every funcall id it references (including
99/// through CSEs). Used to build an `ExternalResolver` covering exactly the
100/// functions a problem actually uses.
101pub fn collect_funcall_ids(e: &Expr, out: &mut std::collections::BTreeSet<usize>) {
102    match e {
103        Expr::Const(_) | Expr::Var(_) => {}
104        Expr::Binary(_, a, b) => {
105            collect_funcall_ids(a, out);
106            collect_funcall_ids(b, out);
107        }
108        Expr::Unary(_, a) => collect_funcall_ids(a, out),
109        Expr::Sum(args) => {
110            for a in args {
111                collect_funcall_ids(a, out);
112            }
113        }
114        Expr::Cse(body) => collect_funcall_ids(body, out),
115        Expr::Funcall { id, args } => {
116            out.insert(*id);
117            for arg in args {
118                if let FuncallArg::Real(e) = arg {
119                    collect_funcall_ids(e, out);
120                }
121            }
122        }
123    }
124}
125
/// Returns the single process-wide mutex that every crossing of the AMPL
/// external ABI must hold.
///
/// Real AMPL libraries (e.g. IDAES general_helmholtz) keep mutable global
/// state (cached parameters, tabulated lookups) and are not safe for
/// concurrent entry. Python's `pyomo.core.base.external` relies on the GIL
/// for the same guarantee. The mutex guards no data (`()`); it exists purely
/// to serialise calls.
fn ampl_lock() -> &'static Mutex<()> {
    static AMPL_ABI_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    AMPL_ABI_LOCK.get_or_init(Mutex::default)
}
135
// FUNCADD_TYPE bits for the `ty` argument of `Addfunc`, mirroring AMPL's
// `funcadd.h`. A registration's type is one value kind (real, string or
// random valued), optionally OR-ed with the argument-kind flags
// `FUNCADD_STRING_ARGS` / `FUNCADD_OUTPUT_ARGS`.

/// Value kind: real (double) valued function.
pub const FUNCADD_REAL_VALUED: i32 = 0;
/// Flag: the function consumes string (symbolic) arguments. Value is still
/// determined by the value-kind bits (real unless stated otherwise).
pub const FUNCADD_STRING_ARGS: i32 = 1;
/// Value kind: string (`char*`) valued function (AMPL only per `funcadd.h`).
pub const FUNCADD_STRING_VALUED: i32 = 2;
/// Value kind: real valued function whose value is random.
pub const FUNCADD_RANDOM_VALUED: i32 = 4;
/// Flag: the function accepts output arguments (AMPL only per `funcadd.h`).
/// Note the value is 16, not 2 — 2 is `FUNCADD_STRING_VALUED`.
pub const FUNCADD_OUTPUT_ARGS: i32 = 16;
143
/// The `arglist` struct from AMPL's `funcadd.h`: the single argument handed
/// to every registered `Rfunc`. Layout must match exactly — never reorder,
/// add, or remove fields.
///
/// `ExternalLibrary::eval` fills `n`, `nr`, `at`, `ra`, `sa`, `derivs`,
/// `hes`, `funcinfo`, `ae` and `errmsg`. It leaves every other pointer null
/// and every other count at 0.
#[repr(C)]
pub struct Arglist {
    pub n: c_int,               // number of args (real + string)
    pub nr: c_int,              // number of real input args (length of ra)
    pub at: *mut c_int,         // argument types: at[i] >= 0 -> ra[at[i]], else sa[-(at[i]+1)]
    pub ra: *mut f64,           // pure real args (IN/OUT/INOUT)
    pub sa: *mut *const c_char, // symbolic IN args
    pub derivs: *mut f64,       // partial derivatives, length nr (if non-null)
    pub hes: *mut f64,          // second partials, packed upper triangle (if non-null)
    pub dig: *mut c_char,       // skip-derivatives flags (eval passes null)
    pub funcinfo: *mut c_void,  // per-function cookie (set by Addfunc)
    pub ae: *mut AmplExports,   // points back at our AmplExports
    pub f: *mut c_void,         // AMPL-internal
    pub tva: *mut c_void,       // AMPL-internal
    pub errmsg: *mut c_char,    // error description set by the function; eval inspects it after each call
    pub tmi: *mut c_void,       // Tempmem cookie
    pub private: *mut c_char,
    // The four counts below are left at 0 by eval.
    // NOTE(review): presumably only meaningful when AMPL itself drives the call.
    pub nin: c_int,
    pub nout: c_int,
    pub nsin: c_int,
    pub nsout: c_int,
}
167
168/// Pointer to a user-defined real-valued function, matching
169/// `typedef real (*rfunc)(arglist*)`.
170pub type Rfunc = unsafe extern "C" fn(*mut Arglist) -> f64;
171
172/// Pointer to the `Addfunc` callback provided by the caller.
173pub type AddfuncFn = unsafe extern "C" fn(
174    name: *const c_char,
175    f: Rfunc,
176    ty: c_int,
177    nargs: c_int,
178    funcinfo: *mut c_void,
179    ae: *mut AmplExports,
180);
181
182/// Pointer to the `RandSeedSetter` callback.
183pub type RandSeedSetter = unsafe extern "C" fn(*mut c_void, std::os::raw::c_ulong);
184
185/// Pointer to the `Addrandinit` callback.
186pub type AddrandinitFn =
187    unsafe extern "C" fn(ae: *mut AmplExports, setter: RandSeedSetter, v: *mut c_void);
188
189/// Pointer to the `AtReset` callback.
190pub type AtResetFn = unsafe extern "C" fn(ae: *mut AmplExports, f: *mut c_void, v: *mut c_void);
191
/// The `AmplExports` struct from AMPL's `funcadd.h`. This is the service table
/// handed to a library's `funcadd_ASL`. It is also passed to every call
/// through `Arglist::ae`. Layout must match exactly — never reorder, add, or
/// remove fields.
///
/// Function pointers we don't implement are held as `*mut c_void` (null).
/// AMPL's ABI does not require a caller to populate them unless the loaded
/// library actually invokes them. `ExternalLibrary::load` sets only
/// `addfunc`, `at_reset`, `addrandinit` and `asl_date`.
///
/// NOTE(review): a library that does invoke one of the null slots would call
/// through a null pointer. Nothing here guards against that.
#[repr(C)]
pub struct AmplExports {
    pub std_err: *mut c_void,
    pub addfunc: Option<AddfuncFn>, // registration hook used by funcadd_ASL
    pub asl_date: c_long,           // ASL version date advertised to the library
    pub fprintf: *mut c_void,
    pub printf: *mut c_void,
    pub sprintf: *mut c_void,
    pub vfprintf: *mut c_void,
    pub vsprintf: *mut c_void,
    pub strtod: *mut c_void,
    pub crypto: *mut c_void,
    pub asl: *mut c_char,
    pub at_exit: *mut c_void,
    pub at_reset: Option<AtResetFn>, // serviced by a logging stub
    pub tempmem: *mut c_void,
    pub add_table_handler: *mut c_void,
    pub private_ae: *mut c_char,
    pub qsortv: *mut c_void,

    // stdio replacement slots; all null.
    pub std_in: *mut c_void,
    pub std_out: *mut c_void,
    pub clearerr: *mut c_void,
    pub fclose: *mut c_void,
    pub fdopen: *mut c_void,
    pub feof: *mut c_void,
    pub ferror: *mut c_void,
    pub fflush: *mut c_void,
    pub fgetc: *mut c_void,
    pub fgets: *mut c_void,
    pub fileno: *mut c_void,
    pub fopen: *mut c_void,
    pub fputc: *mut c_void,
    pub fputs: *mut c_void,
    pub fread: *mut c_void,
    pub freopen: *mut c_void,
    pub fscanf: *mut c_void,
    pub fseek: *mut c_void,
    pub ftell: *mut c_void,
    pub fwrite: *mut c_void,
    pub pclose: *mut c_void,
    pub perror: *mut c_void,
    pub popen: *mut c_void,
    pub puts: *mut c_void,
    pub rewind: *mut c_void,
    pub scanf: *mut c_void,
    pub setbuf: *mut c_void,
    pub setvbuf: *mut c_void,
    pub sscanf: *mut c_void,
    pub tempnam: *mut c_void,
    pub tmpfile: *mut c_void,
    pub tmpnam: *mut c_void,
    pub ungetc: *mut c_void,
    pub ai: *mut c_void,
    pub getenv: *mut c_void,
    pub breakfunc: *mut c_void,
    pub breakarg: *mut c_char,

    // Items available with ASLdate >= 20020501.
    pub snprintf: *mut c_void,
    pub vsnprintf: *mut c_void,

    pub addrand: *mut c_void,
    pub addrandinit: Option<AddrandinitFn>, // serviced: seeds with a fixed value
}

// SAFETY: AmplExports contains only raw pointers, optional function
// pointers, and an integer. We never mutate it after construction. It lives
// inside an `ExternalLibrary`, which is held in `Arc`s and may be reached
// from several threads. Every call that hands this table to library code
// (`funcadd_ASL` in `load`, each `rfunc` in `eval`) holds `ampl_lock()`, so
// the library never sees concurrent use of it.
unsafe impl Send for AmplExports {}
unsafe impl Sync for AmplExports {}
269
270/// A function registered by a library via `Addfunc`. Mirrors the ASL
271/// `FUNCADD_TYPE` bits in `funcadd.h`.
272#[derive(Debug, Clone)]
273pub struct RegisteredFunc {
274    pub name: String,
275    pub rfunc: Rfunc,
276    /// OR of FUNCADD_TYPE bits.
277    pub ty: i32,
278    /// Declared arg count. >=0 means exactly that many, <=-1 means "at least
279    /// -(nargs+1) args".
280    pub nargs: i32,
281    /// Cookie set by the library; must be passed through to arglist.funcinfo.
282    pub funcinfo: *mut c_void,
283}
284
285// SAFETY: funcinfo is an opaque cookie owned by the library. We never
286// dereference it; we only pass it back to the library's functions, which
287// expect it. No thread-safety contract is violated by sending the struct.
288unsafe impl Send for RegisteredFunc {}
289unsafe impl Sync for RegisteredFunc {}
290
291impl std::fmt::Debug for ExternalLibrary {
292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
293        f.debug_struct("ExternalLibrary")
294            .field("funcs", &self.funcs.keys().collect::<Vec<_>>())
295            .finish()
296    }
297}
298
299/// A loaded external-function library plus its registered functions.
300pub struct ExternalLibrary {
301    /// Keep the library alive — it owns the code pages the function pointers
302    /// reference. Arc so `LoadedExternals` can share it.
303    _lib: Arc<Library>,
304    /// The AmplExports we handed to `funcadd_ASL`. Must be kept alive (pinned
305    /// in a Box) because some libraries may capture its address for later
306    /// use (e.g. for `AtReset` bookkeeping).
307    _ae: Box<AmplExports>,
308    /// Registrations collected during `funcadd_ASL`.
309    funcs: HashMap<String, RegisteredFunc>,
310}
311
312impl ExternalLibrary {
313    /// Open a shared library at `path` and invoke its `funcadd_ASL` entry
314    /// point, collecting all functions it registers.
315    pub fn load(path: &Path) -> Result<Self, String> {
316        // Serialise all ABI crossings: library init code and registration
317        // may touch global state that isn't safe under concurrent entry.
318        let _guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
319        // SAFETY: libloading::Library::new is unsafe because it can run
320        // arbitrary initialisers from the shared object. We trust the user's
321        // AMPLFUNC path the same way AMPL/IPOPT do.
322        let lib = unsafe { Library::new(path) }
323            .map_err(|e| format!("failed to open '{}': {}", path.display(), e))?;
324
325        // Resolve `funcadd_ASL`. AMPL's macro `#define funcadd funcadd_ASL`
326        // means every conforming library exports this symbol.
327        type FuncaddFn = unsafe extern "C" fn(*mut AmplExports);
328        let funcadd: Symbol<FuncaddFn> = unsafe { lib.get(b"funcadd_ASL\0") }
329            .map_err(|e| format!("no funcadd_ASL in '{}': {}", path.display(), e))?;
330
331        // Build an AmplExports. Most fields null — the library doesn't call
332        // them (same assumption Pyomo makes). Only the three hooks we can
333        // realistically service are set.
334        let mut ae = Box::new(AmplExports {
335            std_err: ptr::null_mut(),
336            addfunc: Some(trampoline_addfunc),
337            // ASLdate >= 20020501 unlocks the SnprintF/VsnprintF slots.
338            // Pyomo uses 20160307; mirror that.
339            asl_date: 20160307,
340            fprintf: ptr::null_mut(),
341            printf: ptr::null_mut(),
342            sprintf: ptr::null_mut(),
343            vfprintf: ptr::null_mut(),
344            vsprintf: ptr::null_mut(),
345            strtod: ptr::null_mut(),
346            crypto: ptr::null_mut(),
347            asl: ptr::null_mut(),
348            at_exit: ptr::null_mut(),
349            at_reset: Some(trampoline_atreset),
350            tempmem: ptr::null_mut(),
351            add_table_handler: ptr::null_mut(),
352            private_ae: ptr::null_mut(),
353            qsortv: ptr::null_mut(),
354            std_in: ptr::null_mut(),
355            std_out: ptr::null_mut(),
356            clearerr: ptr::null_mut(),
357            fclose: ptr::null_mut(),
358            fdopen: ptr::null_mut(),
359            feof: ptr::null_mut(),
360            ferror: ptr::null_mut(),
361            fflush: ptr::null_mut(),
362            fgetc: ptr::null_mut(),
363            fgets: ptr::null_mut(),
364            fileno: ptr::null_mut(),
365            fopen: ptr::null_mut(),
366            fputc: ptr::null_mut(),
367            fputs: ptr::null_mut(),
368            fread: ptr::null_mut(),
369            freopen: ptr::null_mut(),
370            fscanf: ptr::null_mut(),
371            fseek: ptr::null_mut(),
372            ftell: ptr::null_mut(),
373            fwrite: ptr::null_mut(),
374            pclose: ptr::null_mut(),
375            perror: ptr::null_mut(),
376            popen: ptr::null_mut(),
377            puts: ptr::null_mut(),
378            rewind: ptr::null_mut(),
379            scanf: ptr::null_mut(),
380            setbuf: ptr::null_mut(),
381            setvbuf: ptr::null_mut(),
382            sscanf: ptr::null_mut(),
383            tempnam: ptr::null_mut(),
384            tmpfile: ptr::null_mut(),
385            tmpnam: ptr::null_mut(),
386            ungetc: ptr::null_mut(),
387            ai: ptr::null_mut(),
388            getenv: ptr::null_mut(),
389            breakfunc: ptr::null_mut(),
390            breakarg: ptr::null_mut(),
391            snprintf: ptr::null_mut(),
392            vsnprintf: ptr::null_mut(),
393            addrand: ptr::null_mut(),
394            addrandinit: Some(trampoline_addrandinit),
395        });
396
397        // Drive registrations into a thread-local sink so the C trampoline
398        // has somewhere to deposit them without capturing Rust state.
399        REGISTRY_SINK.with(|sink| {
400            let mut guard = sink.borrow_mut();
401            assert!(
402                guard.is_none(),
403                "nested ExternalLibrary::load is not supported"
404            );
405            *guard = Some(HashMap::new());
406        });
407
408        // SAFETY: funcadd is a valid C function from the loaded library; we
409        // pass it a correctly-shaped AmplExports.
410        unsafe { funcadd(ae.as_mut()) };
411
412        let funcs = REGISTRY_SINK
413            .with(|sink| sink.borrow_mut().take())
414            .unwrap_or_default();
415
416        Ok(ExternalLibrary {
417            _lib: Arc::new(lib),
418            _ae: ae,
419            funcs,
420        })
421    }
422
423    /// Names of all functions registered by this library.
424    pub fn function_names(&self) -> impl Iterator<Item = &str> {
425        self.funcs.keys().map(|s| s.as_str())
426    }
427
428    /// Look up a registered function by name.
429    pub fn get(&self, name: &str) -> Option<&RegisteredFunc> {
430        self.funcs.get(name)
431    }
432
433    /// Evaluate a registered function with the given positional arguments.
434    ///
435    /// Arguments are encoded per the AMPL `arglist` ABI: real args are stored
436    /// in `ra[]`, string args in `sa[]`, and `at[i]` maps argument position
437    /// `i` to either a real-slot index (`at[i] >= 0`) or a string-slot index
438    /// (`at[i] < 0`, decoded as `-(at[i]+1)`).
439    ///
440    /// If `want_derivs` is set, a length-`nr` derivative buffer is allocated
441    /// and returned on success. If `want_hes` is set, a length-`nr*(nr+1)/2`
442    /// Hessian buffer is also allocated and returned. The library is told to
443    /// fill both by the non-null `arglist.derivs` / `arglist.hes` pointers.
444    pub fn eval(
445        &self,
446        name: &str,
447        args: &[ExternalArg<'_>],
448        want_derivs: bool,
449        want_hes: bool,
450    ) -> Result<EvalResult, String> {
451        let rf = self
452            .funcs
453            .get(name)
454            .ok_or_else(|| format!("no such external function '{name}'"))?;
455
456        // Validate arity against the registered signature.
457        let n = args.len() as i32;
458        if rf.nargs >= 0 {
459            if rf.nargs != n {
460                return Err(format!(
461                    "external '{name}' expects {} args, got {}",
462                    rf.nargs, n
463                ));
464            }
465        } else {
466            // Negative: minimum -(nargs+1) args.
467            let min_args = -(rf.nargs + 1);
468            if n < min_args {
469                return Err(format!(
470                    "external '{name}' expects at least {min_args} args, got {n}"
471                ));
472            }
473        }
474
475        // Bucket args: build at[], ra[], sa[] in lockstep with their indices.
476        let mut at_vec: Vec<c_int> = Vec::with_capacity(args.len());
477        let mut ra_vec: Vec<f64> = Vec::new();
478        let mut sa_owned: Vec<CString> = Vec::new();
479        for a in args {
480            match a {
481                ExternalArg::Real(x) => {
482                    at_vec.push(ra_vec.len() as c_int);
483                    ra_vec.push(*x);
484                }
485                ExternalArg::Str(s) => {
486                    let cs = CString::new(*s)
487                        .map_err(|_| format!("external '{name}' string arg contains NUL"))?;
488                    at_vec.push(-(sa_owned.len() as c_int + 1));
489                    sa_owned.push(cs);
490                }
491            }
492        }
493        let nr = ra_vec.len() as c_int;
494        let sa_ptrs: Vec<*const c_char> = sa_owned.iter().map(|s| s.as_ptr()).collect();
495
496        // If the library declared FUNCADD_STRING_ARGS we let it see sa; if it
497        // did not, the library shouldn't be called with strings. Surface that.
498        let has_strings = !sa_owned.is_empty();
499        if has_strings && (rf.ty & FUNCADD_STRING_ARGS) == 0 {
500            return Err(format!(
501                "external '{name}' is not declared FUNCADD_STRING_ARGS but was \
502                 called with string arguments"
503            ));
504        }
505
506        // Optional output buffers.
507        let mut derivs_buf: Vec<f64> = if want_derivs {
508            vec![0.0; nr as usize]
509        } else {
510            Vec::new()
511        };
512        let hes_len = if want_hes {
513            (nr as usize) * ((nr as usize) + 1) / 2
514        } else {
515            0
516        };
517        let mut hes_buf: Vec<f64> = if want_hes {
518            vec![0.0; hes_len]
519        } else {
520            Vec::new()
521        };
522
523        // Space for a library-set error message.
524        let mut errmsg_buf: Vec<c_char> = vec![0; 1024];
525
526        // Build the arglist. Pointers into Rust-owned buffers are valid for
527        // the duration of the call since we hold those Vecs in this stack
528        // frame and the callee runs synchronously.
529        let mut al = Arglist {
530            n,
531            nr,
532            at: if at_vec.is_empty() {
533                ptr::null_mut()
534            } else {
535                at_vec.as_mut_ptr()
536            },
537            ra: if ra_vec.is_empty() {
538                ptr::null_mut()
539            } else {
540                ra_vec.as_mut_ptr()
541            },
542            sa: if sa_ptrs.is_empty() {
543                ptr::null_mut()
544            } else {
545                sa_ptrs.as_ptr() as *mut *const c_char
546            },
547            derivs: if want_derivs {
548                derivs_buf.as_mut_ptr()
549            } else {
550                ptr::null_mut()
551            },
552            hes: if want_hes {
553                hes_buf.as_mut_ptr()
554            } else {
555                ptr::null_mut()
556            },
557            dig: ptr::null_mut(),
558            funcinfo: rf.funcinfo,
559            // Some libraries read arglist.ae (e.g. to call fprintf); point at
560            // the same AmplExports we handed to funcadd_ASL.
561            ae: self._ae_ptr(),
562            f: ptr::null_mut(),
563            tva: ptr::null_mut(),
564            errmsg: errmsg_buf.as_mut_ptr(),
565            tmi: ptr::null_mut(),
566            private: ptr::null_mut(),
567            nin: 0,
568            nout: 0,
569            nsin: 0,
570            nsout: 0,
571        };
572
573        // SAFETY: rfunc is a valid extern "C" function pointer provided by
574        // the loaded library; arglist layout matches funcadd.h exactly.
575        // The AMPL lock serialises concurrent entry into the library.
576        let _guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
577        let value = unsafe { (rf.rfunc)(&mut al as *mut Arglist) };
578        drop(_guard);
579
580        // If the library wrote into errmsg, surface that. AMPL convention:
581        // if errmsg[0] != 0 after the call, treat as error.
582        if errmsg_buf[0] != 0 {
583            // SAFETY: errmsg_buf is a NUL-terminated C buffer (we allocated
584            // and zeroed it); the library only writes a C string there.
585            let msg = unsafe { CStr::from_ptr(errmsg_buf.as_ptr()) }
586                .to_string_lossy()
587                .into_owned();
588            return Err(format!("external '{name}' reported: {msg}"));
589        }
590
591        Ok(EvalResult {
592            value,
593            derivs: if want_derivs { Some(derivs_buf) } else { None },
594            hessian: if want_hes { Some(hes_buf) } else { None },
595        })
596    }
597
598    // Raw mutable pointer to the owned AmplExports. Used when building an
599    // arglist so the library can call back through the same table it was
600    // registered with. The Box is pinned for the lifetime of self.
601    fn _ae_ptr(&self) -> *mut AmplExports {
602        // Cast away the const; we never mutate the AmplExports ourselves.
603        (&*self._ae as *const AmplExports) as *mut AmplExports
604    }
605}
606
/// One positional argument to an external function.
///
/// `ExternalLibrary::eval` places `Real` values in `arglist.ra` and `Str`
/// values in `arglist.sa`, keeping their original positions in `arglist.at`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ExternalArg<'a> {
    /// A real (double) argument.
    Real(f64),
    /// A string (symbolic) argument; must not contain NUL bytes.
    Str(&'a str),
}
613
/// Return value from [`ExternalLibrary::eval`].
#[derive(Debug, Clone, PartialEq)]
pub struct EvalResult {
    /// Function value.
    pub value: f64,
    /// `df/dx_i` for each real argument, in `ra[]` order (string arguments
    /// have no slot), if `want_derivs`.
    pub derivs: Option<Vec<f64>>,
    /// Packed upper-triangular Hessian in AMPL's convention,
    /// `hes[i + j*(j+1)/2]` for `0 <= i <= j < nr`, if `want_hes`.
    pub hessian: Option<Vec<f64>>,
}
625
626// ---------------------------------------------------------------------------
627// Registration trampoline.
628//
629// `funcadd_ASL` can call Addfunc multiple times (once per registered name).
630// Rust closures can't be converted to `extern "C"` function pointers, so we
631// route each call through a free function that deposits into a thread-local
632// sink populated by `ExternalLibrary::load`.
633// ---------------------------------------------------------------------------
634
635thread_local! {
636    static REGISTRY_SINK: std::cell::RefCell<Option<HashMap<String, RegisteredFunc>>> =
637        std::cell::RefCell::new(None);
638}
639
640/// C-callable trampoline that receives Addfunc calls from the shared library.
641unsafe extern "C" fn trampoline_addfunc(
642    name: *const c_char,
643    f: Rfunc,
644    ty: c_int,
645    nargs: c_int,
646    funcinfo: *mut c_void,
647    _ae: *mut AmplExports,
648) {
649    if name.is_null() {
650        return;
651    }
652    // SAFETY: AMPL guarantees name is a NUL-terminated C string.
653    let cname = unsafe { CStr::from_ptr(name) };
654    let name_str = match cname.to_str() {
655        Ok(s) => s.to_owned(),
656        Err(_) => return, // non-UTF8 name — skip; real libs use ASCII.
657    };
658    REGISTRY_SINK.with(|sink| {
659        if let Some(map) = sink.borrow_mut().as_mut() {
660            map.insert(
661                name_str.clone(),
662                RegisteredFunc {
663                    name: name_str,
664                    rfunc: f,
665                    ty: ty as i32,
666                    nargs: nargs as i32,
667                    funcinfo,
668                },
669            );
670        }
671    });
672}
673
674/// Stub — some libraries ask us to register an AtReset callback. Pyomo logs a
675/// warning and does nothing. We do the same.
676unsafe extern "C" fn trampoline_atreset(_ae: *mut AmplExports, _f: *mut c_void, _v: *mut c_void) {
677    tracing::debug!("external library registered an AtReset callback; ignoring");
678}
679
680/// Stub — invoked by libraries that use random-valued externals. We just
681/// seed with 1 (matches Pyomo's default; no randomness in KKT paths).
682unsafe extern "C" fn trampoline_addrandinit(
683    _ae: *mut AmplExports,
684    setter: RandSeedSetter,
685    v: *mut c_void,
686) {
687    unsafe { setter(v, 1) };
688}
689
#[cfg(test)]
mod tests {
    use super::*;

    /// File stem of the IDAES Helmholtz external-function library.
    const HELMHOLTZ_STEM: &str = "general_helmholtz_external";

    /// Locate the IDAES Helmholtz shared library.
    ///
    /// `IDAES_HELMHOLTZ_LIB` wins when set. Otherwise `~/.idaes/bin/` is
    /// searched for both the macOS (`.dylib`) and Linux (`.so`) builds.
    fn idaes_dylib() -> Option<std::path::PathBuf> {
        if let Some(explicit) = std::env::var_os("IDAES_HELMHOLTZ_LIB") {
            let p = std::path::PathBuf::from(explicit);
            return p.exists().then_some(p);
        }
        let bin = std::path::PathBuf::from(std::env::var_os("HOME")?).join(".idaes/bin");
        ["dylib", "so"]
            .iter()
            .map(|ext| bin.join(format!("{HELMHOLTZ_STEM}.{ext}")))
            .find(|p| p.exists())
    }

    /// Locate the IDAES Helmholtz parameter directory.
    ///
    /// `IDAES_HELMHOLTZ_PARAMS` wins when set. Otherwise a
    /// developer-specific default under `$HOME` is tried. The returned
    /// string always ends with `/`, like the default path.
    /// NOTE(review): the library appears to append file names directly to
    /// this string, hence the trailing separator — confirm.
    fn idaes_params_dir() -> Option<String> {
        let p = match std::env::var_os("IDAES_HELMHOLTZ_PARAMS") {
            Some(explicit) => std::path::PathBuf::from(explicit),
            None => std::path::PathBuf::from(std::env::var_os("HOME")?).join(
                "Dropbox/uv/.venv/lib/python3.12/site-packages/idaes/\
                 models/properties/general_helmholtz/components/parameters/",
            ),
        };
        if !p.exists() {
            return None;
        }
        let mut s = p.to_str()?.to_owned();
        if !s.ends_with('/') {
            s.push('/');
        }
        Some(s)
    }

    /// Library path and parameter directory for the `vf_hp` tests. Returns
    /// `None` (after printing why) when either is unavailable locally.
    fn helmholtz_fixture() -> Option<(std::path::PathBuf, String)> {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return None;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return None;
        };
        Some((path, params_dir))
    }

    /// Arguments for `vf_hp` at the NL fixture's initial guess.
    ///
    /// h = 1878.71 kJ/kg-scaled, p = 101.325 kPa. The scaled values actually
    /// passed through the v3/v4 slots are 1878.71 * 0.0555... and
    /// 101325 * 0.001 respectively.
    fn vf_hp_fixture_args(params_dir: &str) -> [ExternalArg<'_>; 4] {
        [
            ExternalArg::Str("h2o"),
            ExternalArg::Real(1878.71 * 0.055508472036052976),
            ExternalArg::Real(101325.0 * 0.001),
            ExternalArg::Str(params_dir),
        ]
    }

    /// Loading a nonexistent library must fail cleanly, with no panic.
    #[test]
    fn load_missing_library_reports_open_failure() {
        let missing = std::path::Path::new("/nonexistent/pounce-no-such-ampl-lib.so");
        let err = ExternalLibrary::load(missing).unwrap_err();
        assert!(err.contains("failed to open"), "unexpected error: {err}");
    }

    /// A problem that references no external functions never consults
    /// AMPLFUNC and yields an empty resolver.
    #[test]
    fn resolver_without_references_is_empty() {
        let resolver =
            ExternalResolver::build_for_problem(&[], &std::collections::BTreeSet::new())
                .expect("no references should always succeed");
        assert!(resolver.is_empty());
    }

    /// Opening the IDAES Helmholtz dylib (when present locally) should
    /// surface the three functions used by the issue #15 fixture.
    #[test]
    fn load_idaes_helmholtz_dylib_registers_known_functions() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };

        let lib = ExternalLibrary::load(&path).expect("load should succeed");
        let names: Vec<String> = lib.function_names().map(|s| s.to_owned()).collect();

        for required in &["vf_hp", "h_liq_hp", "h_vap_hp"] {
            assert!(
                names.iter().any(|n| n == required),
                "expected {required} in registered names: {names:?}"
            );
        }
    }

    /// Evaluate vf_hp at the NL fixture's initial guess. We don't assert the
    /// exact numeric value (that's an IDAES invariant, not a ripopt one), but
    /// the return value must be finite and the call must not set errmsg.
    #[test]
    fn eval_vf_hp_at_fixture_initial_point() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let res = lib
            .eval("vf_hp", &vf_hp_fixture_args(&params_dir), false, false)
            .expect("eval");
        assert!(
            res.value.is_finite(),
            "vf_hp returned non-finite value {}",
            res.value
        );
    }

    /// Same call path, but asking for first derivatives. derivs must be a
    /// length-2 buffer (nr=2) of finite values.
    #[test]
    fn eval_vf_hp_with_derivatives() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let res = lib
            .eval("vf_hp", &vf_hp_fixture_args(&params_dir), true, false)
            .expect("eval");
        let derivs = res.derivs.expect("derivs requested");
        assert_eq!(derivs.len(), 2, "nr=2 reals -> 2 derivatives");
        for (i, d) in derivs.iter().enumerate() {
            assert!(d.is_finite(), "derivs[{i}] = {d} not finite");
        }
    }

    /// Also request the packed Hessian. For nr=2 reals, that's 3 entries
    /// (H00, H01, H11) in AMPL's packed upper-triangular layout.
    #[test]
    fn eval_vf_hp_with_hessian() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };
        let lib = ExternalLibrary::load(&path).expect("load");
        let res = lib
            .eval("vf_hp", &vf_hp_fixture_args(&params_dir), true, true)
            .expect("eval");
        let hes = res.hessian.expect("hessian requested");
        assert_eq!(hes.len(), 3, "nr=2 -> packed Hessian of length 3");
        for (i, h) in hes.iter().enumerate() {
            assert!(h.is_finite(), "hes[{i}] = {h} not finite");
        }
    }
}