Skip to main content

pounce_nl/
nl_external.rs

1//! AMPL imported (external) function support via the `funcadd_ASL` ABI.
2//!
3//! This module implements enough of AMPL's `funcadd.h` ABI to:
4//!
5//! 1. `dlopen` a user-supplied shared library;
6//! 2. resolve the `funcadd_ASL` symbol and call it;
7//! 3. receive registration callbacks of the form `Addfunc(name, rfunc, type,
8//!    nargs, funcinfo, ae)` and record them;
9//! 4. later call back into the registered `rfunc` with an `arglist` to obtain
10//!    function values, gradients, and Hessians.
11//!
12//! The `AmplExports` and `Arglist` struct layouts are taken from
13//! AMPL-MP/ASL `funcadd.h`; cross-checked against the ctypes mapping in
14//! `pyomo.core.base.external`. Fields we don't populate are left null —
15//! Pyomo does the same and it is sufficient for IDAES's Helmholtz library
16//! (see issue #15).
17//!
18//! All unsafe FFI is contained in this module. Public surface is safe.
19
20use std::collections::HashMap;
21use std::ffi::{c_char, c_int, c_long, c_void, CStr, CString};
22use std::path::Path;
23use std::ptr;
24use std::sync::{Arc, Mutex, OnceLock};
25
26use libloading::{Library, Symbol};
27
28use crate::nl_reader::{Expr, FuncallArg, ImportedFunc};
29
30/// Resolved AMPL imported function: shared library + registered name.
31/// `NlProblem` carries one of these per `ImportedFunc` id when external
32/// functions are wired up at problem-build time. The same `Arc<ExternalLibrary>`
33/// may be shared across many funcall ids (one library typically registers
34/// several functions).
35#[derive(Default, Clone)]
36pub struct ExternalResolver {
37    /// `Funcall { id }` -> (library, registered function name).
38    pub funcs_by_id: HashMap<usize, (Arc<ExternalLibrary>, String)>,
39}
40
41impl ExternalResolver {
42    pub fn is_empty(&self) -> bool {
43        self.funcs_by_id.is_empty()
44    }
45
46    /// Build a resolver for every `ImportedFunc` declared in the `.nl` file
47    /// that is *actually referenced* somewhere in the problem's expressions.
48    ///
49    /// Library paths are resolved through the `AMPLFUNC` environment variable
50    /// (a `\n`-separated list of shared-library paths, matching AMPL/IPOPT
51    /// conventions). Each path is loaded once and queried for every name we
52    /// need. Returns an error if a referenced name cannot be found in any
53    /// listed library, or if `AMPLFUNC` is missing.
54    pub fn build_for_problem(
55        imported_funcs: &[ImportedFunc],
56        referenced_ids: &std::collections::BTreeSet<usize>,
57    ) -> Result<Self, String> {
58        if referenced_ids.is_empty() {
59            return Ok(Self::default());
60        }
61        let amplfunc = std::env::var("AMPLFUNC").map_err(|_| {
62            "problem uses external functions but AMPLFUNC is not set; \
63             set AMPLFUNC to a newline-separated list of AMPL shared-library paths"
64                .to_string()
65        })?;
66        let mut libs: Vec<Arc<ExternalLibrary>> = Vec::new();
67        for path_str in amplfunc
68            .split('\n')
69            .map(|s| s.trim())
70            .filter(|s| !s.is_empty())
71        {
72            let path = std::path::Path::new(path_str);
73            let lib = ExternalLibrary::load(path).map_err(|e| format!("AMPLFUNC: {e}"))?;
74            libs.push(Arc::new(lib));
75        }
76
77        let mut funcs_by_id: HashMap<usize, (Arc<ExternalLibrary>, String)> = HashMap::new();
78        for id in referenced_ids {
79            let decl = imported_funcs
80                .iter()
81                .find(|f| f.id == *id)
82                .ok_or_else(|| format!("funcall id {id} has no F<{id}> declaration"))?;
83            let found = libs
84                .iter()
85                .find(|lib| lib.get(&decl.name).is_some())
86                .ok_or_else(|| {
87                    format!(
88                        "external function '{}' (id {}) not found in any library on AMPLFUNC",
89                        decl.name, decl.id
90                    )
91                })?;
92            funcs_by_id.insert(*id, (found.clone(), decl.name.clone()));
93        }
94        Ok(Self { funcs_by_id })
95    }
96}
97
98/// Walk an `Expr` and collect every funcall id it references (including
99/// through CSEs). Used to build an `ExternalResolver` covering exactly the
100/// functions a problem actually uses.
101pub fn collect_funcall_ids(e: &Expr, out: &mut std::collections::BTreeSet<usize>) {
102    match e {
103        Expr::Const(_) | Expr::Var(_) => {}
104        Expr::Binary(_, a, b) => {
105            collect_funcall_ids(a, out);
106            collect_funcall_ids(b, out);
107        }
108        Expr::Unary(_, a) => collect_funcall_ids(a, out),
109        Expr::Sum(args) | Expr::MinList(args) | Expr::MaxList(args) => {
110            for a in args {
111                collect_funcall_ids(a, out);
112            }
113        }
114        Expr::Compare(_, a, b) | Expr::And(a, b) | Expr::Or(a, b) => {
115            collect_funcall_ids(a, out);
116            collect_funcall_ids(b, out);
117        }
118        Expr::Not(a) => collect_funcall_ids(a, out),
119        Expr::Cond { cond, then_, else_ } => {
120            collect_funcall_ids(cond, out);
121            collect_funcall_ids(then_, out);
122            collect_funcall_ids(else_, out);
123        }
124        Expr::Cse(body) => collect_funcall_ids(body, out),
125        Expr::Funcall { id, args } => {
126            out.insert(*id);
127            for arg in args {
128                if let FuncallArg::Real(e) = arg {
129                    collect_funcall_ids(e, out);
130                }
131            }
132        }
133    }
134}
135
/// Global mutex that every crossing of the AMPL external-function ABI must
/// hold.
///
/// Real AMPL libraries (e.g. IDAES general_helmholtz) keep mutable global
/// state (cached parameters, tabulated lookups) and are not safe for
/// concurrent entry. Python's `pyomo.core.base.external` relies on the GIL
/// for the same guarantee. The mutex is created lazily on first use and
/// every call returns the same instance.
fn ampl_lock() -> &'static Mutex<()> {
    static AMPL_ABI_LOCK: OnceLock<Mutex<()>> = OnceLock::new();
    AMPL_ABI_LOCK.get_or_init(Mutex::default)
}
145
/// `FUNCADD_TYPE` values passed as the `type` argument of `Addfunc`
/// (mirrors `enum FUNCADD_TYPE` in AMPL's `funcadd.h`).
///
/// Real (double) valued function: the base value with no other bit set.
pub const FUNCADD_REAL_VALUED: i32 = 0;
/// Set if the function accepts string (symbolic) arguments. Value is still real.
pub const FUNCADD_STRING_ARGS: i32 = 1;
/// `char*`-valued function (AMPL only). Such functions cannot be evaluated
/// through `rfunc`'s `double` return.
pub const FUNCADD_STRING_VALUED: i32 = 2;
/// Real, random-valued function.
pub const FUNCADD_RANDOM_VALUED: i32 = 4;
/// Set if the function allows output arguments (AMPL only).
///
/// Note: the value `2` is `FUNCADD_STRING_VALUED`, not an output-args flag;
/// `funcadd.h` assigns this bit the value `16`.
pub const FUNCADD_OUTPUT_ARGS: i32 = 16;
153
/// The `arglist` struct from AMPL's `funcadd.h`. Layout must match exactly:
/// `#[repr(C)]` plus the `funcadd.h` field order, so never reorder, insert,
/// or remove a field.
///
/// `ExternalLibrary::eval` builds one per call. Its pointer fields refer to
/// buffers owned by that call, so an `Arglist` is only meaningful for the
/// duration of the single `rfunc` invocation it was built for.
#[repr(C)]
pub struct Arglist {
    pub n: c_int,               // number of args (real + string)
    pub nr: c_int,              // number of real input args, i.e. length of `ra`
    pub at: *mut c_int,         // per-arg slot: >= 0 indexes ra[], < 0 means sa[-(at+1)]
    pub ra: *mut f64,           // pure real args (IN/OUT/INOUT)
    pub sa: *mut *const c_char, // symbolic IN args
    pub derivs: *mut f64,       // partial derivatives, length nr (if non-null)
    pub hes: *mut f64,          // packed second partials, length nr*(nr+1)/2 (if non-null)
    pub dig: *mut c_char,       // skip-derivatives flags; eval passes null
    pub funcinfo: *mut c_void,  // per-function cookie (set by Addfunc)
    pub ae: *mut AmplExports,   // points back at our AmplExports
    pub f: *mut c_void,         // AMPL-internal; eval passes null
    pub tva: *mut c_void,       // AMPL-internal; eval passes null
    pub errmsg: *mut c_char,    // error signal: library writes into it OR reassigns it
    pub tmi: *mut c_void,       // Tempmem cookie; eval passes null
    pub private: *mut c_char,   // eval passes null
    pub nin: c_int,             // nin/nout/nsin/nsout: eval sets all four to 0
    pub nout: c_int,
    pub nsin: c_int,
    pub nsout: c_int,
}
177
178/// Pointer to a user-defined real-valued function, matching
179/// `typedef real (*rfunc)(arglist*)`.
180pub type Rfunc = unsafe extern "C" fn(*mut Arglist) -> f64;
181
182/// Pointer to the `Addfunc` callback provided by the caller.
183pub type AddfuncFn = unsafe extern "C" fn(
184    name: *const c_char,
185    f: Rfunc,
186    ty: c_int,
187    nargs: c_int,
188    funcinfo: *mut c_void,
189    ae: *mut AmplExports,
190);
191
192/// Pointer to the `RandSeedSetter` callback.
193pub type RandSeedSetter = unsafe extern "C" fn(*mut c_void, std::os::raw::c_ulong);
194
195/// Pointer to the `Addrandinit` callback.
196pub type AddrandinitFn =
197    unsafe extern "C" fn(ae: *mut AmplExports, setter: RandSeedSetter, v: *mut c_void);
198
199/// Pointer to the `AtReset` callback.
200pub type AtResetFn = unsafe extern "C" fn(ae: *mut AmplExports, f: *mut c_void, v: *mut c_void);
201
/// The `AmplExports` struct from AMPL's `funcadd.h`. Layout must match
/// exactly: `#[repr(C)]` plus the `funcadd.h` field order, so never reorder,
/// insert, or remove a field.
///
/// Function pointers we don't implement are held as `*mut c_void` (null).
/// AMPL's ABI does not require a caller to populate them unless the loaded
/// library actually invokes them. `ExternalLibrary::load` populates only
/// `addfunc`, `at_reset`, and `addrandinit`.
///
/// NOTE(review): a library that does call one of the null slots (e.g. a
/// printf-family hook on an error path) would call through a null pointer
/// and presumably crash the process.
#[repr(C)]
pub struct AmplExports {
    // C `FILE*` for stderr; left null.
    pub std_err: *mut c_void,
    // Registration hook the library calls once per function it registers.
    pub addfunc: Option<AddfuncFn>,
    // ASLdate: tells the library which optional trailing slots are present.
    pub asl_date: c_long,
    pub fprintf: *mut c_void,
    pub printf: *mut c_void,
    pub sprintf: *mut c_void,
    pub vfprintf: *mut c_void,
    pub vsprintf: *mut c_void,
    pub strtod: *mut c_void,
    pub crypto: *mut c_void,
    pub asl: *mut c_char,
    pub at_exit: *mut c_void,
    // Serviced by a stub that logs and ignores the request.
    pub at_reset: Option<AtResetFn>,
    pub tempmem: *mut c_void,
    pub add_table_handler: *mut c_void,
    pub private_ae: *mut c_char,
    pub qsortv: *mut c_void,

    // stdio table for libraries built as DLLs; all left null.
    pub std_in: *mut c_void,
    pub std_out: *mut c_void,
    pub clearerr: *mut c_void,
    pub fclose: *mut c_void,
    pub fdopen: *mut c_void,
    pub feof: *mut c_void,
    pub ferror: *mut c_void,
    pub fflush: *mut c_void,
    pub fgetc: *mut c_void,
    pub fgets: *mut c_void,
    pub fileno: *mut c_void,
    pub fopen: *mut c_void,
    pub fputc: *mut c_void,
    pub fputs: *mut c_void,
    pub fread: *mut c_void,
    pub freopen: *mut c_void,
    pub fscanf: *mut c_void,
    pub fseek: *mut c_void,
    pub ftell: *mut c_void,
    pub fwrite: *mut c_void,
    pub pclose: *mut c_void,
    pub perror: *mut c_void,
    pub popen: *mut c_void,
    pub puts: *mut c_void,
    pub rewind: *mut c_void,
    pub scanf: *mut c_void,
    pub setbuf: *mut c_void,
    pub setvbuf: *mut c_void,
    pub sscanf: *mut c_void,
    pub tempnam: *mut c_void,
    pub tmpfile: *mut c_void,
    pub tmpnam: *mut c_void,
    pub ungetc: *mut c_void,
    pub ai: *mut c_void,
    pub getenv: *mut c_void,
    pub breakfunc: *mut c_void,
    pub breakarg: *mut c_char,

    // Items available with ASLdate >= 20020501.
    pub snprintf: *mut c_void,
    pub vsnprintf: *mut c_void,

    pub addrand: *mut c_void,
    // Serviced by a stub that seeds the library's generator with 1.
    pub addrandinit: Option<AddrandinitFn>,
}

// SAFETY: AmplExports holds only raw pointers, `Option` function pointers,
// and an integer. Rust code never mutates it after construction: it is only
// handed to the library, as `&mut` to `funcadd_ASL` and as `arglist.ae`.
// Every entry into the library is serialised by the process-wide
// `ampl_lock`: `funcadd_ASL` in `ExternalLibrary::load` and each `rfunc`
// call in `ExternalLibrary::eval`. Sharing it across threads inside an
// `Arc<ExternalLibrary>` therefore cannot race on our side. This assumes the
// library does not touch it from threads of its own. The bounds are needed
// because `ExternalLibrary` (which owns this struct) is shared via `Arc`.
unsafe impl Send for AmplExports {}
unsafe impl Sync for AmplExports {}
279
/// One registration received through `Addfunc` while a library's
/// `funcadd_ASL` was running. The `ty` field carries the ASL `FUNCADD_TYPE`
/// bits from `funcadd.h`.
#[derive(Debug, Clone)]
pub struct RegisteredFunc {
    /// Name the library registered the function under; also its lookup key.
    pub name: String,
    /// Entry point invoked by `ExternalLibrary::eval`. It points into the
    /// library's code, so it is only callable while that library stays loaded.
    pub rfunc: Rfunc,
    /// OR of FUNCADD_TYPE bits.
    pub ty: i32,
    /// Declared arg count. >=0 means exactly that many, <=-1 means "at least
    /// -(nargs+1) args".
    pub nargs: i32,
    /// Cookie set by the library; must be passed through to arglist.funcinfo.
    pub funcinfo: *mut c_void,
}

// SAFETY: funcinfo is an opaque cookie owned by the library. We never
// dereference it; we only pass it back to the library's functions, which
// expect it. The library code that consumes it may not be thread-safe, but
// `ExternalLibrary::eval` holds the process-wide `ampl_lock` around every
// `rfunc` call. Moving or sharing this struct between threads therefore never
// lets two threads run the library concurrently.
unsafe impl Send for RegisteredFunc {}
unsafe impl Sync for RegisteredFunc {}
300
301impl std::fmt::Debug for ExternalLibrary {
302    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
303        f.debug_struct("ExternalLibrary")
304            .field("funcs", &self.funcs.keys().collect::<Vec<_>>())
305            .finish()
306    }
307}
308
/// A loaded external-function library plus its registered functions.
///
/// NOTE(review): fields drop in declaration order. Dropping `_lib` (releasing
/// the library handle) therefore happens before `_ae` is freed and before
/// `funcs` is dropped. Dropping `funcs` never calls through its function
/// pointers, so this is sound today. Re-check it before reordering fields.
pub struct ExternalLibrary {
    /// Keep the library alive — it owns the code pages the function pointers
    /// reference. Held behind an `Arc`; nothing visible in this file clones
    /// it (TODO confirm whether other code shares it).
    _lib: Arc<Library>,
    /// The AmplExports we handed to `funcadd_ASL`. Boxed so its address stays
    /// stable for the lifetime of `self`. Libraries may capture that address
    /// for later use (e.g. for `AtReset` bookkeeping), and `eval` passes it
    /// again as `arglist.ae`.
    _ae: Box<AmplExports>,
    /// Registrations collected during `funcadd_ASL`, keyed by function name.
    funcs: HashMap<String, RegisteredFunc>,
}
321
322impl ExternalLibrary {
323    /// Open a shared library at `path` and invoke its `funcadd_ASL` entry
324    /// point, collecting all functions it registers.
325    pub fn load(path: &Path) -> Result<Self, String> {
326        // Serialise all ABI crossings: library init code and registration
327        // may touch global state that isn't safe under concurrent entry.
328        let _guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
329        // SAFETY: libloading::Library::new is unsafe because it can run
330        // arbitrary initialisers from the shared object. We trust the user's
331        // AMPLFUNC path the same way AMPL/IPOPT do.
332        let lib = unsafe { Library::new(path) }
333            .map_err(|e| format!("failed to open '{}': {}", path.display(), e))?;
334
335        // Resolve `funcadd_ASL`. AMPL's macro `#define funcadd funcadd_ASL`
336        // means every conforming library exports this symbol.
337        type FuncaddFn = unsafe extern "C" fn(*mut AmplExports);
338        let funcadd: Symbol<FuncaddFn> = unsafe { lib.get(b"funcadd_ASL\0") }
339            .map_err(|e| format!("no funcadd_ASL in '{}': {}", path.display(), e))?;
340
341        // Build an AmplExports. Most fields null — the library doesn't call
342        // them (same assumption Pyomo makes). Only the three hooks we can
343        // realistically service are set.
344        let mut ae = Box::new(AmplExports {
345            std_err: ptr::null_mut(),
346            addfunc: Some(trampoline_addfunc),
347            // ASLdate >= 20020501 unlocks the SnprintF/VsnprintF slots.
348            // Pyomo uses 20160307; mirror that.
349            asl_date: 20160307,
350            fprintf: ptr::null_mut(),
351            printf: ptr::null_mut(),
352            sprintf: ptr::null_mut(),
353            vfprintf: ptr::null_mut(),
354            vsprintf: ptr::null_mut(),
355            strtod: ptr::null_mut(),
356            crypto: ptr::null_mut(),
357            asl: ptr::null_mut(),
358            at_exit: ptr::null_mut(),
359            at_reset: Some(trampoline_atreset),
360            tempmem: ptr::null_mut(),
361            add_table_handler: ptr::null_mut(),
362            private_ae: ptr::null_mut(),
363            qsortv: ptr::null_mut(),
364            std_in: ptr::null_mut(),
365            std_out: ptr::null_mut(),
366            clearerr: ptr::null_mut(),
367            fclose: ptr::null_mut(),
368            fdopen: ptr::null_mut(),
369            feof: ptr::null_mut(),
370            ferror: ptr::null_mut(),
371            fflush: ptr::null_mut(),
372            fgetc: ptr::null_mut(),
373            fgets: ptr::null_mut(),
374            fileno: ptr::null_mut(),
375            fopen: ptr::null_mut(),
376            fputc: ptr::null_mut(),
377            fputs: ptr::null_mut(),
378            fread: ptr::null_mut(),
379            freopen: ptr::null_mut(),
380            fscanf: ptr::null_mut(),
381            fseek: ptr::null_mut(),
382            ftell: ptr::null_mut(),
383            fwrite: ptr::null_mut(),
384            pclose: ptr::null_mut(),
385            perror: ptr::null_mut(),
386            popen: ptr::null_mut(),
387            puts: ptr::null_mut(),
388            rewind: ptr::null_mut(),
389            scanf: ptr::null_mut(),
390            setbuf: ptr::null_mut(),
391            setvbuf: ptr::null_mut(),
392            sscanf: ptr::null_mut(),
393            tempnam: ptr::null_mut(),
394            tmpfile: ptr::null_mut(),
395            tmpnam: ptr::null_mut(),
396            ungetc: ptr::null_mut(),
397            ai: ptr::null_mut(),
398            getenv: ptr::null_mut(),
399            breakfunc: ptr::null_mut(),
400            breakarg: ptr::null_mut(),
401            snprintf: ptr::null_mut(),
402            vsnprintf: ptr::null_mut(),
403            addrand: ptr::null_mut(),
404            addrandinit: Some(trampoline_addrandinit),
405        });
406
407        // Drive registrations into a thread-local sink so the C trampoline
408        // has somewhere to deposit them without capturing Rust state.
409        REGISTRY_SINK.with(|sink| {
410            let mut guard = sink.borrow_mut();
411            assert!(
412                guard.is_none(),
413                "nested ExternalLibrary::load is not supported"
414            );
415            *guard = Some(HashMap::new());
416        });
417
418        // SAFETY: funcadd is a valid C function from the loaded library; we
419        // pass it a correctly-shaped AmplExports.
420        unsafe { funcadd(ae.as_mut()) };
421
422        let funcs = REGISTRY_SINK
423            .with(|sink| sink.borrow_mut().take())
424            .unwrap_or_default();
425
426        Ok(ExternalLibrary {
427            _lib: Arc::new(lib),
428            _ae: ae,
429            funcs,
430        })
431    }
432
433    /// Names of all functions registered by this library.
434    pub fn function_names(&self) -> impl Iterator<Item = &str> {
435        self.funcs.keys().map(|s| s.as_str())
436    }
437
438    /// Look up a registered function by name.
439    pub fn get(&self, name: &str) -> Option<&RegisteredFunc> {
440        self.funcs.get(name)
441    }
442
443    /// Evaluate a registered function with the given positional arguments.
444    ///
445    /// Arguments are encoded per the AMPL `arglist` ABI: real args are stored
446    /// in `ra[]`, string args in `sa[]`, and `at[i]` maps argument position
447    /// `i` to either a real-slot index (`at[i] >= 0`) or a string-slot index
448    /// (`at[i] < 0`, decoded as `-(at[i]+1)`).
449    ///
450    /// If `want_derivs` is set, a length-`nr` derivative buffer is allocated
451    /// and returned on success. If `want_hes` is set, a length-`nr*(nr+1)/2`
452    /// Hessian buffer is also allocated and returned. The library is told to
453    /// fill both by the non-null `arglist.derivs` / `arglist.hes` pointers.
454    pub fn eval(
455        &self,
456        name: &str,
457        args: &[ExternalArg<'_>],
458        want_derivs: bool,
459        want_hes: bool,
460    ) -> Result<EvalResult, String> {
461        let rf = self
462            .funcs
463            .get(name)
464            .ok_or_else(|| format!("no such external function '{name}'"))?;
465
466        // Validate arity against the registered signature.
467        let n = args.len() as i32;
468        if rf.nargs >= 0 {
469            if rf.nargs != n {
470                return Err(format!(
471                    "external '{name}' expects {} args, got {}",
472                    rf.nargs, n
473                ));
474            }
475        } else {
476            // Negative: minimum -(nargs+1) args.
477            let min_args = -(rf.nargs + 1);
478            if n < min_args {
479                return Err(format!(
480                    "external '{name}' expects at least {min_args} args, got {n}"
481                ));
482            }
483        }
484
485        // Bucket args: build at[], ra[], sa[] in lockstep with their indices.
486        let mut at_vec: Vec<c_int> = Vec::with_capacity(args.len());
487        let mut ra_vec: Vec<f64> = Vec::new();
488        let mut sa_owned: Vec<CString> = Vec::new();
489        for a in args {
490            match a {
491                ExternalArg::Real(x) => {
492                    at_vec.push(ra_vec.len() as c_int);
493                    ra_vec.push(*x);
494                }
495                ExternalArg::Str(s) => {
496                    let cs = CString::new(*s)
497                        .map_err(|_| format!("external '{name}' string arg contains NUL"))?;
498                    at_vec.push(-(sa_owned.len() as c_int + 1));
499                    sa_owned.push(cs);
500                }
501            }
502        }
503        let nr = ra_vec.len() as c_int;
504        let sa_ptrs: Vec<*const c_char> = sa_owned.iter().map(|s| s.as_ptr()).collect();
505
506        // If the library declared FUNCADD_STRING_ARGS we let it see sa; if it
507        // did not, the library shouldn't be called with strings. Surface that.
508        let has_strings = !sa_owned.is_empty();
509        if has_strings && (rf.ty & FUNCADD_STRING_ARGS) == 0 {
510            return Err(format!(
511                "external '{name}' is not declared FUNCADD_STRING_ARGS but was \
512                 called with string arguments"
513            ));
514        }
515
516        // Optional output buffers.
517        let mut derivs_buf: Vec<f64> = if want_derivs {
518            vec![0.0; nr as usize]
519        } else {
520            Vec::new()
521        };
522        let hes_len = if want_hes {
523            (nr as usize) * ((nr as usize) + 1) / 2
524        } else {
525            0
526        };
527        let mut hes_buf: Vec<f64> = if want_hes {
528            vec![0.0; hes_len]
529        } else {
530            Vec::new()
531        };
532
533        // Space for a library-set error message. The ABI lets a library
534        // signal an error two ways (see `decode_external_errmsg`): by writing
535        // into this buffer, OR — the canonical conforming path — by
536        // *reassigning* `arglist.errmsg` to its own string. We seed the field
537        // with this buffer's address and remember it so the reassignment is
538        // detectable afterwards.
539        let mut errmsg_buf: Vec<c_char> = vec![0; 1024];
540        let errmsg_orig_ptr = errmsg_buf.as_ptr();
541
542        // Build the arglist. Pointers into Rust-owned buffers are valid for
543        // the duration of the call since we hold those Vecs in this stack
544        // frame and the callee runs synchronously.
545        let mut al = Arglist {
546            n,
547            nr,
548            at: if at_vec.is_empty() {
549                ptr::null_mut()
550            } else {
551                at_vec.as_mut_ptr()
552            },
553            ra: if ra_vec.is_empty() {
554                ptr::null_mut()
555            } else {
556                ra_vec.as_mut_ptr()
557            },
558            sa: if sa_ptrs.is_empty() {
559                ptr::null_mut()
560            } else {
561                sa_ptrs.as_ptr() as *mut *const c_char
562            },
563            derivs: if want_derivs {
564                derivs_buf.as_mut_ptr()
565            } else {
566                ptr::null_mut()
567            },
568            hes: if want_hes {
569                hes_buf.as_mut_ptr()
570            } else {
571                ptr::null_mut()
572            },
573            dig: ptr::null_mut(),
574            funcinfo: rf.funcinfo,
575            // Some libraries read arglist.ae (e.g. to call fprintf); point at
576            // the same AmplExports we handed to funcadd_ASL.
577            ae: self._ae_ptr(),
578            f: ptr::null_mut(),
579            tva: ptr::null_mut(),
580            errmsg: errmsg_buf.as_mut_ptr(),
581            tmi: ptr::null_mut(),
582            private: ptr::null_mut(),
583            nin: 0,
584            nout: 0,
585            nsin: 0,
586            nsout: 0,
587        };
588
589        // SAFETY: rfunc is a valid extern "C" function pointer provided by
590        // the loaded library; arglist layout matches funcadd.h exactly.
591        // The AMPL lock serialises concurrent entry into the library.
592        let _guard = ampl_lock().lock().unwrap_or_else(|e| e.into_inner());
593        let value = unsafe { (rf.rfunc)(&mut al as *mut Arglist) };
594        drop(_guard);
595
596        // Surface a library-reported error from *either* ABI channel: the
597        // reassigned `arglist.errmsg` pointer (the conforming path) or our
598        // pre-pointed buffer. Checking only the buffer would miss every
599        // library that does `al->Errmsg = "...";`, silently consuming garbage.
600        // SAFETY: `al.errmsg` is either our zeroed NUL-terminated buffer or a
601        // C string the library assigned; both are valid to read as a CStr.
602        if let Some(msg) =
603            unsafe { decode_external_errmsg(al.errmsg, errmsg_orig_ptr, errmsg_buf[0]) }
604        {
605            return Err(format!("external '{name}' reported: {msg}"));
606        }
607
608        Ok(EvalResult {
609            value,
610            derivs: if want_derivs { Some(derivs_buf) } else { None },
611            hessian: if want_hes { Some(hes_buf) } else { None },
612        })
613    }
614
615    // Raw mutable pointer to the owned AmplExports. Used when building an
616    // arglist so the library can call back through the same table it was
617    // registered with. The Box is pinned for the lifetime of self.
618    fn _ae_ptr(&self) -> *mut AmplExports {
619        // Cast away the const; we never mutate the AmplExports ourselves.
620        (&*self._ae as *const AmplExports) as *mut AmplExports
621    }
622}
623
/// A single positional argument passed to an external function.
///
/// In `ExternalLibrary::eval`, `Real` values are routed into `arglist.ra`
/// and `Str` values into `arglist.sa`.
#[derive(Clone, Copy, Debug)]
pub enum ExternalArg<'a> {
    /// A real (double) argument.
    Real(f64),
    /// A borrowed string argument; copied into a `CString` at call time.
    Str(&'a str),
}
630
/// Outcome of a successful [`ExternalLibrary::eval`] call.
#[derive(Clone, Debug)]
pub struct EvalResult {
    /// Value returned by the library's `rfunc`.
    pub value: f64,
    /// Gradient `df/dx_i`, one entry per real argument in `ra[]` order.
    /// `Some` only when `want_derivs` was requested.
    pub derivs: Option<Vec<f64>>,
    /// Hessian packed as AMPL's upper triangle: entry `(i, j)` with
    /// `0 <= i <= j < nr` is stored at `hes[i + j*(j+1)/2]`.
    /// `Some` only when `want_hes` was requested.
    pub hessian: Option<Vec<f64>>,
}
642
/// Decode an external function's error signal after its `rfunc` returns.
///
/// The AMPL `funcadd` ABI gives a library two ways to report an error:
///
/// 1. **Reassign** `arglist.errmsg` to its own (usually static) C string,
///    e.g. `al->Errmsg = "T out of range";`. This is the conforming path used
///    by real libraries (e.g. IDAES Helmholtz on out-of-domain evals); the
///    caller's buffer is left untouched.
/// 2. Write a string into the buffer the caller pointed `errmsg` at before
///    the call.
///
/// The caller seeds `arglist.errmsg` with its buffer's address
/// (`orig_buf_ptr`). The function checks the channels in order:
/// - if the field is non-null and differs from that address, the library
///   reassigned it, so the message is read from the new pointer;
/// - otherwise, if the buffer's first byte (`buf_first`) is non-zero, the
///   message is read from the buffer;
/// - otherwise neither channel carries a message and `None` is returned.
///
/// Checking only the buffer would drop every channel-1 error and let the
/// IPM consume NaN/garbage f/∇f/∇²f.
///
/// # Safety
/// `errmsg_field` (when reassigned) and `orig_buf_ptr` must each point at a
/// readable NUL-terminated C string for the duration of the read.
unsafe fn decode_external_errmsg(
    errmsg_field: *const c_char,
    orig_buf_ptr: *const c_char,
    buf_first: c_char,
) -> Option<String> {
    let reassigned = !errmsg_field.is_null() && errmsg_field != orig_buf_ptr;
    let message_ptr = if reassigned {
        errmsg_field
    } else if buf_first != 0 {
        orig_buf_ptr
    } else {
        return None;
    };
    // SAFETY: the caller guarantees whichever pointer was selected above is
    // a NUL-terminated C string.
    let text = unsafe { CStr::from_ptr(message_ptr) };
    Some(text.to_string_lossy().into_owned())
}
690
691// ---------------------------------------------------------------------------
692// Registration trampoline.
693//
694// `funcadd_ASL` can call Addfunc multiple times (once per registered name).
695// Rust closures can't be converted to `extern "C"` function pointers, so we
696// route each call through a free function that deposits into a thread-local
697// sink populated by `ExternalLibrary::load`.
698// ---------------------------------------------------------------------------
699
700thread_local! {
701    static REGISTRY_SINK: std::cell::RefCell<Option<HashMap<String, RegisteredFunc>>> =
702        std::cell::RefCell::new(None);
703}
704
705/// C-callable trampoline that receives Addfunc calls from the shared library.
706unsafe extern "C" fn trampoline_addfunc(
707    name: *const c_char,
708    f: Rfunc,
709    ty: c_int,
710    nargs: c_int,
711    funcinfo: *mut c_void,
712    _ae: *mut AmplExports,
713) {
714    if name.is_null() {
715        return;
716    }
717    // SAFETY: AMPL guarantees name is a NUL-terminated C string.
718    let cname = unsafe { CStr::from_ptr(name) };
719    let name_str = match cname.to_str() {
720        Ok(s) => s.to_owned(),
721        Err(_) => return, // non-UTF8 name — skip; real libs use ASCII.
722    };
723    REGISTRY_SINK.with(|sink| {
724        if let Some(map) = sink.borrow_mut().as_mut() {
725            map.insert(
726                name_str.clone(),
727                RegisteredFunc {
728                    name: name_str,
729                    rfunc: f,
730                    ty: ty as i32,
731                    nargs: nargs as i32,
732                    funcinfo,
733                },
734            );
735        }
736    });
737}
738
739/// Stub — some libraries ask us to register an AtReset callback. Pyomo logs a
740/// warning and does nothing. We do the same.
741unsafe extern "C" fn trampoline_atreset(_ae: *mut AmplExports, _f: *mut c_void, _v: *mut c_void) {
742    tracing::debug!("external library registered an AtReset callback; ignoring");
743}
744
745/// Stub — invoked by libraries that use random-valued externals. We just
746/// seed with 1 (matches Pyomo's default; no randomness in KKT paths).
747unsafe extern "C" fn trampoline_addrandinit(
748    _ae: *mut AmplExports,
749    setter: RandSeedSetter,
750    v: *mut c_void,
751) {
752    unsafe { setter(v, 1) };
753}
754
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable that, when set, names the IDAES Helmholtz shared
    /// library directly (overrides the `$HOME/.idaes/bin` default).
    const IDAES_DYLIB_ENV: &str = "IDAES_HELMHOLTZ_LIB";

    /// Environment variable that, when set, names the IDAES Helmholtz
    /// parameter directory (overrides the maintainer-local default path).
    const IDAES_PARAMS_ENV: &str = "IDAES_HELMHOLTZ_PARAMS";

    /// Enthalpy argument used by the fixture-point tests: 1878.71 scaled by
    /// 0.0555… (≈ 1/18.015 — presumably 1/M_w of water, a per-mole to
    /// per-mass conversion; TODO confirm against the issue #15 `.nl` fixture).
    const FIXTURE_H: f64 = 1878.71 * 0.055508472036052976;

    /// Pressure argument used by the fixture-point tests: 101325 scaled by
    /// 0.001 (i.e. 101.325 — 1 atm in kPa if the raw value is in Pa).
    const FIXTURE_P: f64 = 101325.0 * 0.001;

    /// Locates the IDAES general-Helmholtz shared library, if installed.
    ///
    /// Checks `$IDAES_HELMHOLTZ_LIB` first, otherwise
    /// `$HOME/.idaes/bin/general_helmholtz_external.<ext>`, where `<ext>` is
    /// the platform's shared-library extension (`dylib` on macOS, `so` on
    /// Linux). Returns `None` when the file does not exist so tests can skip.
    fn idaes_dylib() -> Option<std::path::PathBuf> {
        let path = match std::env::var_os(IDAES_DYLIB_ENV) {
            Some(explicit) => std::path::PathBuf::from(explicit),
            None => std::path::PathBuf::from(std::env::var_os("HOME")?)
                .join(".idaes/bin")
                .join(format!(
                    "general_helmholtz_external.{}",
                    std::env::consts::DLL_EXTENSION
                )),
        };
        path.exists().then_some(path)
    }

    /// Locates the IDAES Helmholtz parameter directory, if present.
    ///
    /// Checks `$IDAES_HELMHOLTZ_PARAMS` first, otherwise falls back to a
    /// maintainer-local venv path under `$HOME`. The returned string always
    /// ends in a path separator, matching the shape of the default path.
    /// Returns `None` if the directory is missing or not valid UTF-8.
    fn idaes_params_dir() -> Option<String> {
        let path = match std::env::var_os(IDAES_PARAMS_ENV) {
            Some(explicit) => std::path::PathBuf::from(explicit),
            None => std::path::PathBuf::from(std::env::var_os("HOME")?).join(
                "Dropbox/uv/.venv/lib/python3.12/site-packages/idaes/\
                 models/properties/general_helmholtz/components/parameters/",
            ),
        };
        if !path.exists() {
            return None;
        }
        let mut dir = path.to_str()?.to_owned();
        // NOTE(review): the library presumably appends file names directly
        // onto this string (the default path ends in '/'); normalize
        // env-provided paths to the same trailing-separator form.
        if !dir.ends_with(std::path::is_separator) {
            dir.push(std::path::MAIN_SEPARATOR);
        }
        Some(dir)
    }

    /// Returns `(library path, parameter dir)` when both IDAES artifacts are
    /// available; otherwise prints which one is missing and returns `None`.
    fn helmholtz_fixture() -> Option<(std::path::PathBuf, String)> {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return None;
        };
        let Some(params_dir) = idaes_params_dir() else {
            eprintln!("skipping: IDAES parameters directory not present");
            return None;
        };
        Some((path, params_dir))
    }

    /// Opening the IDAES Helmholtz library (when present locally) should
    /// surface the three functions used by the issue #15 fixture.
    #[test]
    fn load_idaes_helmholtz_dylib_registers_known_functions() {
        let Some(path) = idaes_dylib() else {
            eprintln!("skipping: IDAES dylib not present");
            return;
        };

        let lib = ExternalLibrary::load(&path).expect("load should succeed");
        let names: Vec<String> = lib.function_names().map(|s| s.to_owned()).collect();

        for required in &["vf_hp", "h_liq_hp", "h_vap_hp"] {
            assert!(
                names.iter().any(|n| n == required),
                "expected {required} in registered names: {names:?}"
            );
        }
    }

    /// Evaluate vf_hp at the NL fixture's initial guess. We don't assert the
    /// exact numeric value (that's an IDAES invariant, not one this crate
    /// owns), but `eval` must succeed and the returned value must be finite.
    #[test]
    fn eval_vf_hp_at_fixture_initial_point() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };

        let lib = ExternalLibrary::load(&path).expect("load");
        // The fixture's initial guess, already multiplied by the same scale
        // factors the fixture applies before the h/p slots (see FIXTURE_H,
        // FIXTURE_P).
        let args = [
            ExternalArg::Str("h2o"),
            ExternalArg::Real(FIXTURE_H),
            ExternalArg::Real(FIXTURE_P),
            ExternalArg::Str(&params_dir),
        ];
        let res = lib.eval("vf_hp", &args, false, false).expect("eval");
        assert!(
            res.value.is_finite(),
            "vf_hp returned non-finite value {}",
            res.value
        );
    }

    /// Same call path, but asking for first derivatives. derivs must be a
    /// length-2 buffer (nr=2) of finite values.
    #[test]
    fn eval_vf_hp_with_derivatives() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };

        let lib = ExternalLibrary::load(&path).expect("load");
        let args = [
            ExternalArg::Str("h2o"),
            ExternalArg::Real(FIXTURE_H),
            ExternalArg::Real(FIXTURE_P),
            ExternalArg::Str(&params_dir),
        ];
        let res = lib.eval("vf_hp", &args, true, false).expect("eval");
        let derivs = res.derivs.expect("derivs requested");
        assert_eq!(derivs.len(), 2, "nr=2 reals -> 2 derivatives");
        for (i, d) in derivs.iter().enumerate() {
            assert!(d.is_finite(), "derivs[{i}] = {d} not finite");
        }
    }

    /// Also request the packed Hessian. For nr=2 reals, that's 3 entries
    /// (H00, H01, H11) in AMPL's packed upper-triangular layout.
    #[test]
    fn eval_vf_hp_with_hessian() {
        let Some((path, params_dir)) = helmholtz_fixture() else {
            return;
        };

        let lib = ExternalLibrary::load(&path).expect("load");
        let args = [
            ExternalArg::Str("h2o"),
            ExternalArg::Real(FIXTURE_H),
            ExternalArg::Real(FIXTURE_P),
            ExternalArg::Str(&params_dir),
        ];
        let res = lib.eval("vf_hp", &args, true, true).expect("eval");
        let hes = res.hessian.expect("hessian requested");
        assert_eq!(hes.len(), 3, "nr=2 -> packed Hessian of length 3");
        for (i, h) in hes.iter().enumerate() {
            assert!(h.is_finite(), "hes[{i}] = {h} not finite");
        }
    }

    // --- H5: errmsg detection across both funcadd ABI channels ---

    /// A conforming `rfunc` that signals an error the canonical AMPL way: by
    /// **reassigning** `al->Errmsg` to its own static C string (leaving any
    /// caller-provided buffer untouched), and returning NaN like an
    /// out-of-domain evaluation.
    unsafe extern "C" fn rfunc_reassigns_errmsg(al: *mut Arglist) -> f64 {
        static MSG: &[u8] = b"T out of range\0";
        // SAFETY: `al` is a valid, exclusively-borrowed Arglist for the call.
        unsafe {
            (*al).errmsg = MSG.as_ptr() as *mut c_char;
        }
        f64::NAN
    }

    /// Build an `Arglist` with every pointer null except `errmsg`. Sufficient
    /// for a `rfunc` that only manipulates the error channel.
    fn null_arglist(errmsg: *mut c_char) -> Arglist {
        Arglist {
            n: 1,
            nr: 1,
            at: ptr::null_mut(),
            ra: ptr::null_mut(),
            sa: ptr::null_mut(),
            derivs: ptr::null_mut(),
            hes: ptr::null_mut(),
            dig: ptr::null_mut(),
            funcinfo: ptr::null_mut(),
            ae: ptr::null_mut(),
            f: ptr::null_mut(),
            tva: ptr::null_mut(),
            errmsg,
            tmi: ptr::null_mut(),
            private: ptr::null_mut(),
            nin: 0,
            nout: 0,
            nsin: 0,
            nsout: 0,
        }
    }

    /// Exercises the real `Arglist` layout, a real `extern "C"` rfunc, and the
    /// real `decode_external_errmsg` helper (the rfunc is called directly, not
    /// through `eval`). A library that reports an error by reassigning
    /// `al->Errmsg` (channel 1) must be detected. Pre-fix, `eval` only
    /// inspected the caller buffer — which a reassigning library never
    /// touches — so the error was invisible and the IPM consumed the NaN
    /// return as a valid value.
    #[test]
    fn reassigned_errmsg_pointer_is_detected_end_to_end() {
        let mut errmsg_buf: Vec<c_char> = vec![0; 1024];
        let orig_ptr = errmsg_buf.as_ptr();
        let mut al = null_arglist(errmsg_buf.as_mut_ptr());

        // SAFETY: the rfunc matches the ABI and only writes `al.errmsg`.
        let v = unsafe { rfunc_reassigns_errmsg(&mut al) };
        assert!(v.is_nan(), "the failing eval returned NaN");

        // A reassigning library leaves the caller buffer zeroed, so the old
        // `errmsg_buf[0] != 0` check (the bug) saw nothing.
        assert_eq!(
            errmsg_buf[0], 0,
            "a reassigning library must not touch the caller buffer"
        );

        // The fixed decode reads the reassigned pointer and surfaces the error.
        // SAFETY: `al.errmsg` now points at the static NUL-terminated `MSG`,
        // and `errmsg_buf` (which `orig_ptr` refers to) is still alive.
        let decoded = unsafe { decode_external_errmsg(al.errmsg, orig_ptr, errmsg_buf[0]) };
        assert_eq!(
            decoded.as_deref(),
            Some("T out of range"),
            "the reassigned errmsg pointer must be surfaced as an error"
        );
    }

    /// The buffer channel (a library that writes into the caller buffer) and
    /// the no-error cases still behave correctly.
    #[test]
    fn decode_external_errmsg_buffer_and_none_channels() {
        // Channel 2: library wrote a string into the caller buffer.
        let mut buf: Vec<c_char> = vec![0; 16];
        for (i, b) in b"bad input".iter().enumerate() {
            buf[i] = *b as c_char;
        }
        let orig = buf.as_ptr();
        // SAFETY: `orig` points at a live 16-byte buffer holding 9 bytes of
        // text followed by zeros, so it is NUL-terminated.
        let decoded = unsafe { decode_external_errmsg(orig, orig, buf[0]) };
        assert_eq!(decoded.as_deref(), Some("bad input"));

        // No error: field still points at the (zeroed) buffer.
        let zero: Vec<c_char> = vec![0; 16];
        let z = zero.as_ptr();
        // SAFETY: `z` points at a live, all-zero (hence NUL-terminated) buffer.
        assert_eq!(unsafe { decode_external_errmsg(z, z, zero[0]) }, None);

        // No error via an explicitly NULL field (some libraries zero it).
        // SAFETY: a NULL field is one of the inputs the decoder accepts;
        // `z` is still a live buffer.
        assert_eq!(unsafe { decode_external_errmsg(ptr::null(), z, 0) }, None);
    }
}