Skip to main content

pounce_cli/
nl_reader.rs

//! Minimal AMPL `.nl` ASCII-format reader.
//!
//! Implements the `g`-header text dialect for problems whose constraint
//! and objective expressions are restricted to a polynomial-friendly
//! subset of opcodes. This is **not** a full `.nl` reader — it is the
//! smallest piece that lets `pounce --nl-file foo.nl` solve a real
//! AMPL-emitted problem.
//!
//! Supported:
//! * Text header (`g…`).
//! * Constraint and objective expression segments using opcodes
//!   `o0` (add), `o1` (sub), `o2` (mul), `o3` (div), `o5` (pow),
//!   `o16` (unary minus), `o39` (sqrt), `o42` (log10), `o43` (log),
//!   `o44` (exp), `o15` (abs), `o41` (sin), `o46` (cos), `o54`
//!   (n-ary sum), plus `n<num>` constants and `v<idx>` variables.
//! * Linear-Jacobian (`J`) and linear-objective (`G`) segments.
//! * Variable bounds (`b`) and constraint bounds (`r`).
//! * Optional initial primal (`x`) segment. Initial dual (`d`) values
//!   are read and stored in `NlProblem::lambda0`.
//! * Common subexpressions (`V`), suffixes (`S`), and imported
//!   (external) function declarations (`F`) with `f<id>` calls.
//! * Multiple objectives (we use only the first; per AMPL convention).
//!
//! Not supported (will return an error explaining what's missing):
//! * Network / piecewise-linear constructs.
//! * Complementarity rows.
//! * Binary-format `.nl` files (`b…` header).
//!
//! References:
//! * <https://ampl.com/REFS/hooking2.pdf> — "Hooking Your Solver to
//!   AMPL" (David M. Gay), the canonical `.nl` spec.
//! * `ref/Ipopt/test/mytoy.nl` — annotated example used for the unit
//!   tests in this module.
32
33use crate::nl_tape::Tape;
34use pounce_common::types::{Index, Number};
35use pounce_nlp::tnlp::{
36    BoundsInfo, IndexStyle, IpoptCq, IpoptData, Linearity, NlpInfo, Solution, SparsityRequest,
37    StartingPoint, TNLP,
38};
39use std::cell::RefCell;
40use std::collections::{BTreeMap, BTreeSet, HashMap};
41use std::path::Path;
42use std::rc::Rc;
43
/// Expression tree produced by the `.nl` expression parser.
///
/// Leaves are numeric constants and variable references; interior
/// nodes are binary / unary operators, n-ary sums, shared common
/// subexpressions, and imported-function calls.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Numeric constant.
    Const(Number),
    /// Variable reference (0-based index into `x`).
    Var(usize),
    /// Binary op: the two boxed operands are `lhs` then `rhs`.
    Binary(BinOp, Box<Expr>, Box<Expr>),
    /// Unary op applied to the single boxed operand.
    Unary(UnaryOp, Box<Expr>),
    /// n-ary sum (opcode `o54` — variadic; we may emit it from `o0`
    /// folding optimization, but the parser treats `o0` as binary).
    Sum(Vec<Expr>),
    /// Reference to a common subexpression (`.nl` `V` segment). The
    /// payload is a shared body; many references to the same CSE share
    /// one `Rc`, so the parsed problem is a DAG. Walking through `Cse`
    /// is mathematically equivalent to inlining the body at each
    /// occurrence (every reference is an independent occurrence in the
    /// chain rule), so eval/grad/collect_vars just recurse into the
    /// inner `Expr`.
    Cse(Rc<Expr>),
    /// AMPL imported (external) function call. `id` matches an entry in
    /// `NlProblem.imported_funcs`; resolution to a live shared library
    /// happens when the tape is built (see `nl_external::ExternalResolver`).
    Funcall { id: usize, args: Vec<FuncallArg> },
}
70
/// One positional argument to an AMPL imported function call. AMPL splits
/// arguments into reals (carried by `ra[]`) and strings (carried by `sa[]`);
/// `FuncallArg` mirrors that split. Real args are arbitrary expressions.
#[derive(Debug, Clone)]
pub enum FuncallArg {
    /// Real-valued argument, given as an arbitrary expression tree.
    Real(Expr),
    /// String argument, taken from an `h<len>:<chars>` literal in the file.
    Str(String),
}
79
/// An AMPL imported (external) function declaration from a top-level
/// `F<id> <type> <nargs> <name>` segment.
#[derive(Debug, Clone)]
pub struct ImportedFunc {
    /// Function id from the `F<id>` tag; `Expr::Funcall::id` refers to it.
    pub id: usize,
    /// 0 = real-valued, 1 = string-args (per AMPL's funcadd ABI).
    pub kind: usize,
    /// Declared arg count. >=0 exact arity; <=-1 means at least `-(nargs+1)`.
    pub nargs: i64,
    /// Function name as written in the segment header (fourth field).
    pub name: String,
}
91
/// Binary operators recognized by the expression parser, with the `.nl`
/// opcode each one is parsed from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `o0` — addition.
    Add,
    /// `o1` — subtraction.
    Sub,
    /// `o2` — multiplication.
    Mul,
    /// `o3` — division.
    Div,
    /// `o5` — exponentiation.
    Pow,
}
100
/// Unary operators recognized by the expression parser, with the `.nl`
/// opcode each one is parsed from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// `o16` — unary minus.
    Neg,
    /// `o39` — square root.
    Sqrt,
    /// `o43` — natural logarithm.
    Log,
    /// `o44` — exponential.
    Exp,
    /// `o15` — absolute value.
    Abs,
    /// `o41` — sine.
    Sin,
    /// `o46` — cosine.
    Cos,
    /// `o42` — base-10 logarithm.
    Log10,
}
112
/// Parsed `.nl` problem in the form needed by `NlTnlp`.
#[derive(Debug, Clone)]
pub struct NlProblem {
    /// Number of variables (first field of header line 2).
    pub n: usize,
    /// Number of constraints (second field of header line 2).
    pub m: usize,
    /// Number of objectives declared in the header (third field of line 2).
    pub num_obj: usize,
    /// `true` when objective 0's `O`-segment kind is 0; also the default
    /// when the file has no `O0` segment.
    pub minimize: bool,
    /// Nonlinear part of objective 0 (`Const(0.0)` if no `O0` segment).
    pub obj_nonlinear: Expr,
    /// Linear part of objective 0 from the `G0` segment, as (var, coef) pairs.
    pub obj_linear: Vec<(usize, Number)>,
    /// Constant objective offset; the text parser always sets this to 0.0.
    pub obj_constant: Number,
    /// Per-constraint nonlinear part (length m).
    pub con_nonlinear: Vec<Expr>,
    /// Per-constraint linear part (length m), each a list of (var, coef).
    pub con_linear: Vec<Vec<(usize, Number)>>,
    /// Variable lower bounds (length n); `-1e19` where unbounded.
    pub x_l: Vec<Number>,
    /// Variable upper bounds (length n); `1e19` where unbounded.
    pub x_u: Vec<Number>,
    /// Constraint lower bounds (length m); `-1e19` where unbounded.
    pub g_l: Vec<Number>,
    /// Constraint upper bounds (length m); `1e19` where unbounded.
    pub g_u: Vec<Number>,
    /// Initial primal point from the `x` segment (length n, default 0.0).
    pub x0: Vec<Number>,
    /// Initial duals from the `d` segment (length m, default 0.0).
    pub lambda0: Vec<Number>,
    /// AMPL suffix dictionaries. Variable / constraint / objective
    /// suffixes are stored as dense vectors (length n / m / num_obj)
    /// with the sparse `.nl` `S`-segment entries scattered in, default
    /// zero. The integer / real split matches the `S`-segment header's
    /// kind bit (`0x4` ⇒ real, else integer). See
    /// <https://ampl.com/REFS/hooking2.pdf> §6 and the upstream `.nl`
    /// reader in `ref/Ipopt/src/Apps/AmplSolver/AmplTNLP.cpp`.
    pub suffixes: NlSuffixes,
    /// AMPL imported (external) functions declared via top-level `F` segments.
    /// Empty unless the `.nl` file calls compiled-C user functions (typically
    /// emitted by IDAES property packages — see issue #49).
    pub imported_funcs: Vec<ImportedFunc>,
}
146
/// Suffix data parsed out of `S`-segments. Sparse entries are scattered
/// into dense vectors at problem load time so callers can index by
/// variable / constraint number directly. Empty maps when the `.nl`
/// file declared no suffixes.
///
/// Each map is keyed by suffix name; a later `S`-segment with the same
/// name and kind replaces the earlier entry.
#[derive(Debug, Clone, Default)]
pub struct NlSuffixes {
    /// Variable-level integer suffixes (kind = 0). Each vector has
    /// length `n_full` (problem variables).
    pub var_int: BTreeMap<String, Vec<Index>>,
    /// Constraint-level integer suffixes (kind = 1). Length `m_full`.
    pub con_int: BTreeMap<String, Vec<Index>>,
    /// Objective-level integer suffixes (kind = 2). Length `num_obj`.
    pub obj_int: BTreeMap<String, Vec<Index>>,
    /// Problem-level integer suffixes (kind = 3). Single value per name.
    pub problem_int: BTreeMap<String, Index>,
    /// Variable-level real suffixes (kind = 4). Length `n_full`.
    pub var_real: BTreeMap<String, Vec<Number>>,
    /// Constraint-level real suffixes (kind = 5). Length `m_full`.
    pub con_real: BTreeMap<String, Vec<Number>>,
    /// Objective-level real suffixes (kind = 6). Length `num_obj`.
    pub obj_real: BTreeMap<String, Vec<Number>>,
    /// Problem-level real suffixes (kind = 7). Single value per name.
    pub problem_real: BTreeMap<String, Number>,
}
171
172/// Parse an `.nl` file from disk.
173pub fn read_nl_file(path: &Path) -> Result<NlProblem, String> {
174    let txt = std::fs::read_to_string(path)
175        .map_err(|e| format!("could not read {}: {}", path.display(), e))?;
176    parse_nl_text(&txt)
177}
178
179/// Parse `.nl` text content. Public so tests can use string literals.
180pub fn parse_nl_text(txt: &str) -> Result<NlProblem, String> {
181    let mut p = Parser::new(txt);
182    p.parse_header()?;
183    let n = p.n;
184    let m = p.m;
185    let num_obj = p.num_obj;
186
187    let mut con_nonlinear: Vec<Expr> = (0..m).map(|_| Expr::Const(0.0)).collect();
188    let mut obj_nonlinear = Expr::Const(0.0);
189    let mut minimize = true;
190    let mut obj_linear: Vec<(usize, Number)> = Vec::new();
191    let mut con_linear: Vec<Vec<(usize, Number)>> = vec![Vec::new(); m];
192    let mut x_l = vec![-1e19; n];
193    let mut x_u = vec![1e19; n];
194    let mut g_l = vec![-1e19; m];
195    let mut g_u = vec![1e19; m];
196    let mut x0 = vec![0.0; n];
197    let mut lambda0 = vec![0.0; m];
198    let mut suffixes = NlSuffixes::default();
199    let mut imported_funcs: Vec<ImportedFunc> = Vec::new();
200
201    while let Some(line) = p.peek_segment_line() {
202        let tag = line
203            .trim_start()
204            .chars()
205            .next()
206            .ok_or("unexpected blank segment header")?;
207        match tag {
208            'C' => {
209                let (_hdr, rest) = p.eat_segment_header()?;
210                let _ = rest;
211                let idx = parse_segment_index(&_hdr, 'C')?;
212                if idx >= m {
213                    return Err(format!("C{idx} out of range; m={m}"));
214                }
215                con_nonlinear[idx] = p.parse_expr()?;
216            }
217            'O' => {
218                let (hdr, _rest) = p.eat_segment_header()?;
219                let parts: Vec<&str> = hdr.split_whitespace().collect();
220                if parts.len() < 2 {
221                    return Err(format!("malformed O-segment header: {hdr}"));
222                }
223                let idx = parse_segment_index(parts[0], 'O')?;
224                let kind: i32 = parts[1].parse().map_err(|e| format!("O kind: {e}"))?;
225                if idx == 0 {
226                    minimize = kind == 0;
227                    obj_nonlinear = p.parse_expr()?;
228                } else {
229                    // Extra objectives are read but ignored.
230                    let _ = p.parse_expr()?;
231                }
232            }
233            'r' => {
234                p.eat_segment_header()?;
235                for i in 0..m {
236                    let line = p.next_data_line()?;
237                    let (lo, hi) = parse_bound_line(&line)?;
238                    g_l[i] = lo;
239                    g_u[i] = hi;
240                }
241            }
242            'b' => {
243                p.eat_segment_header()?;
244                for i in 0..n {
245                    let line = p.next_data_line()?;
246                    let (lo, hi) = parse_bound_line(&line)?;
247                    x_l[i] = lo;
248                    x_u[i] = hi;
249                }
250            }
251            'k' => {
252                // Column counts in the Jacobian; we don't need them
253                // for evaluation since J segments give explicit lists.
254                p.eat_segment_header()?;
255                let count = if n == 0 { 0 } else { n - 1 };
256                for _ in 0..count {
257                    p.next_data_line()?;
258                }
259            }
260            'J' => {
261                let (hdr, _) = p.eat_segment_header()?;
262                let parts: Vec<&str> = hdr.split_whitespace().collect();
263                if parts.len() < 2 {
264                    return Err(format!("malformed J-segment header: {hdr}"));
265                }
266                let row = parse_segment_index(parts[0], 'J')?;
267                let nz: usize = parts[1].parse().map_err(|e| format!("J nz: {e}"))?;
268                if row >= m {
269                    return Err(format!("J{row} out of range"));
270                }
271                for _ in 0..nz {
272                    let line = p.next_data_line()?;
273                    let (var, coef) = parse_var_coef(&line)?;
274                    con_linear[row].push((var, coef));
275                }
276            }
277            'G' => {
278                let (hdr, _) = p.eat_segment_header()?;
279                let parts: Vec<&str> = hdr.split_whitespace().collect();
280                if parts.len() < 2 {
281                    return Err(format!("malformed G-segment header: {hdr}"));
282                }
283                let idx = parse_segment_index(parts[0], 'G')?;
284                let nz: usize = parts[1].parse().map_err(|e| format!("G nz: {e}"))?;
285                let mut acc = Vec::with_capacity(nz);
286                for _ in 0..nz {
287                    let line = p.next_data_line()?;
288                    let (var, coef) = parse_var_coef(&line)?;
289                    acc.push((var, coef));
290                }
291                if idx == 0 {
292                    obj_linear = acc;
293                }
294            }
295            'x' => {
296                let (hdr, _) = p.eat_segment_header()?;
297                let parts: Vec<&str> = hdr.split_whitespace().collect();
298                let nx: usize = parts
299                    .first()
300                    .and_then(|s| s.trim_start_matches('x').parse().ok())
301                    .ok_or_else(|| format!("malformed x-segment header: {hdr}"))?;
302                for _ in 0..nx {
303                    let line = p.next_data_line()?;
304                    let (idx, val) = parse_var_coef(&line)?;
305                    if idx < n {
306                        x0[idx] = val;
307                    }
308                }
309            }
310            'd' => {
311                let (hdr, _) = p.eat_segment_header()?;
312                let parts: Vec<&str> = hdr.split_whitespace().collect();
313                let nd: usize = parts
314                    .first()
315                    .and_then(|s| s.trim_start_matches('d').parse().ok())
316                    .ok_or_else(|| format!("malformed d-segment header: {hdr}"))?;
317                for _ in 0..nd {
318                    let line = p.next_data_line()?;
319                    let (idx, val) = parse_var_coef(&line)?;
320                    if idx < m {
321                        lambda0[idx] = val;
322                    }
323                }
324            }
325            'V' => p.parse_v_segment()?,
326            'S' => {
327                parse_suffix_segment(&mut p, n, m, num_obj, &mut suffixes)?;
328            }
329            'F' => {
330                // AMPL imported (external) function declaration:
331                // `F<k> <type> <nargs> <name>`.
332                let (hdr, _rest) = p.eat_segment_header()?;
333                let parts: Vec<&str> = hdr.split_whitespace().collect();
334                if parts.is_empty() {
335                    return Err(format!("malformed F-segment header: '{hdr}'"));
336                }
337                let id = parse_segment_index(parts[0], 'F')?;
338                let kind: usize = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
339                let nargs: i64 = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0);
340                let name = parts.get(3).copied().unwrap_or("").to_string();
341                imported_funcs.push(ImportedFunc {
342                    id,
343                    kind,
344                    nargs,
345                    name,
346                });
347            }
348            other => return Err(format!("unknown .nl segment tag '{other}'")),
349        }
350    }
351
352    Ok(NlProblem {
353        n,
354        m,
355        num_obj,
356        minimize,
357        obj_nonlinear,
358        obj_linear,
359        obj_constant: 0.0,
360        con_nonlinear,
361        con_linear,
362        x_l,
363        x_u,
364        g_l,
365        g_u,
366        x0,
367        lambda0,
368        suffixes,
369        imported_funcs,
370    })
371}
372
373/// Parse a single `S`-segment. Format (Gay 2005, "Hooking Your Solver
374/// to AMPL", §6, and `ref/Ipopt/src/Apps/AmplSolver/AmplTNLP.cpp`):
375///
376/// ```text
377/// S<kind> <nentries> <suffix_name>
378/// <idx> <value>      ... nentries lines
379/// ```
380///
381/// `<kind>` is a 3-bit encoding:
382/// * Bits 0-1 select the suffix target: 0 = variables, 1 = constraints,
383///   2 = objectives, 3 = problem-level.
384/// * Bit 2 (`0x4`) selects the value type: 0 = integer, 1 = real.
385///
386/// Sparse entries scatter into a freshly-allocated dense vector (zero
387/// default), sized for the target dimension. Problem-level suffixes
388/// (kind = 3 / 7) carry a single value.
389fn parse_suffix_segment(
390    p: &mut Parser,
391    n: usize,
392    m: usize,
393    num_obj: usize,
394    out: &mut NlSuffixes,
395) -> Result<(), String> {
396    let (hdr, _) = p.eat_segment_header()?;
397    let parts: Vec<&str> = hdr.split_whitespace().collect();
398    if parts.len() < 3 {
399        return Err(format!(
400            "malformed S-segment header: '{hdr}' (expected `S<kind> <n> <name>`)"
401        ));
402    }
403    let kind_str = parts[0].trim_start_matches('S');
404    let kind: u32 = kind_str
405        .parse()
406        .map_err(|e| format!("S kind '{kind_str}': {e}"))?;
407    let nentries: usize = parts[1].parse().map_err(|e| format!("S nentries: {e}"))?;
408    let name = parts[2].to_string();
409
410    let is_real = (kind & 0x4) != 0;
411    let target = kind & 0x3;
412    let target_dim = match target {
413        0 => n,
414        1 => m,
415        2 => num_obj,
416        3 => 0, // problem-level — entries are single-valued (idx=0)
417        _ => unreachable!("kind & 0x3 is in 0..=3"),
418    };
419
420    // Pre-allocate dense buffers (default zero). Problem-level kinds
421    // (3 / 7) hold a single scalar — we still read the (idx, value)
422    // pairs but only the value field is meaningful.
423    let mut int_buf: Vec<Index> = if !is_real && target != 3 {
424        vec![0; target_dim]
425    } else {
426        Vec::new()
427    };
428    let mut real_buf: Vec<Number> = if is_real && target != 3 {
429        vec![0.0; target_dim]
430    } else {
431        Vec::new()
432    };
433    let mut problem_int: Index = 0;
434    let mut problem_real: Number = 0.0;
435
436    for _ in 0..nentries {
437        let line = p.next_data_line()?;
438        let parts: Vec<&str> = line.split_whitespace().collect();
439        if parts.len() < 2 {
440            return Err(format!(
441                "malformed S-segment entry '{line}' (expected `<idx> <value>`)"
442            ));
443        }
444        let idx: usize = parts[0]
445            .parse()
446            .map_err(|e| format!("S entry idx '{}': {e}", parts[0]))?;
447        if target != 3 && idx >= target_dim {
448            return Err(format!(
449                "S-suffix '{name}' index {idx} out of range for target dim {target_dim}"
450            ));
451        }
452        if is_real {
453            let v: Number = parts[1]
454                .parse()
455                .map_err(|e| format!("S real entry value '{}': {e}", parts[1]))?;
456            if target == 3 {
457                problem_real = v;
458            } else {
459                real_buf[idx] = v;
460            }
461        } else {
462            let v: Index = parts[1]
463                .parse()
464                .map_err(|e| format!("S int entry value '{}': {e}", parts[1]))?;
465            if target == 3 {
466                problem_int = v;
467            } else {
468                int_buf[idx] = v;
469            }
470        }
471    }
472
473    match (target, is_real) {
474        (0, false) => {
475            out.var_int.insert(name, int_buf);
476        }
477        (1, false) => {
478            out.con_int.insert(name, int_buf);
479        }
480        (2, false) => {
481            out.obj_int.insert(name, int_buf);
482        }
483        (3, false) => {
484            out.problem_int.insert(name, problem_int);
485        }
486        (0, true) => {
487            out.var_real.insert(name, real_buf);
488        }
489        (1, true) => {
490            out.con_real.insert(name, real_buf);
491        }
492        (2, true) => {
493            out.obj_real.insert(name, real_buf);
494        }
495        (3, true) => {
496            out.problem_real.insert(name, problem_real);
497        }
498        _ => unreachable!(),
499    }
500    Ok(())
501}
502
/// Extract the numeric index from a segment tag token such as `C3` or `J12`.
///
/// `s` must consist of exactly one leading `tag` character followed by an
/// unsigned integer.
///
/// # Errors
///
/// Returns a `malformed <tag>-segment index` message when the tag prefix
/// is missing or when what follows it is not a valid `usize` (this
/// includes a doubled tag such as `CC3`).
fn parse_segment_index(s: &str, tag: char) -> Result<usize, String> {
    // `strip_prefix` removes exactly one tag character. The previous
    // `trim_start_matches` removed any number of them, so `CC3` (or a bare
    // `3`) was silently accepted as index 3.
    let digits = s.strip_prefix(tag).ok_or_else(|| {
        format!("malformed {tag}-segment index '{s}': missing '{tag}' prefix")
    })?;
    digits
        .parse()
        .map_err(|e| format!("malformed {tag}-segment index '{s}': {e}"))
}
509
510fn parse_bound_line(line: &str) -> Result<(Number, Number), String> {
511    let parts: Vec<&str> = line.split_whitespace().collect();
512    if parts.is_empty() {
513        return Err("empty bound line".into());
514    }
515    let kind: i32 = parts[0].parse().map_err(|e| format!("bound kind: {e}"))?;
516    let lo;
517    let hi;
518    match kind {
519        0 => {
520            // 0  lo  hi
521            if parts.len() < 3 {
522                return Err(format!("bound kind 0 needs 2 values: '{line}'"));
523            }
524            lo = parts[1].parse().map_err(|e| format!("lo: {e}"))?;
525            hi = parts[2].parse().map_err(|e| format!("hi: {e}"))?;
526        }
527        1 => {
528            // 1  hi
529            if parts.len() < 2 {
530                return Err(format!("bound kind 1 needs 1 value: '{line}'"));
531            }
532            lo = -1e19;
533            hi = parts[1].parse().map_err(|e| format!("hi: {e}"))?;
534        }
535        2 => {
536            // 2  lo
537            if parts.len() < 2 {
538                return Err(format!("bound kind 2 needs 1 value: '{line}'"));
539            }
540            lo = parts[1].parse().map_err(|e| format!("lo: {e}"))?;
541            hi = 1e19;
542        }
543        3 => {
544            // 3  (free)
545            lo = -1e19;
546            hi = 1e19;
547        }
548        4 => {
549            // 4  eq
550            if parts.len() < 2 {
551                return Err(format!("bound kind 4 needs 1 value: '{line}'"));
552            }
553            let v: Number = parts[1].parse().map_err(|e| format!("eq: {e}"))?;
554            lo = v;
555            hi = v;
556        }
557        5 => return Err("complementarity (kind 5) bounds are not supported".into()),
558        other => return Err(format!("unknown bound kind {other}")),
559    }
560    Ok((lo, hi))
561}
562
563fn parse_var_coef(line: &str) -> Result<(usize, Number), String> {
564    let parts: Vec<&str> = line.split_whitespace().collect();
565    if parts.len() < 2 {
566        return Err(format!("malformed var/coef line: '{line}'"));
567    }
568    let v: usize = parts[0].parse().map_err(|e| format!("var idx: {e}"))?;
569    let c: Number = parts[1].parse().map_err(|e| format!("coef: {e}"))?;
570    Ok((v, c))
571}
572
/// Line-oriented cursor over `.nl` text, plus the problem dimensions and
/// common subexpressions collected while parsing.
struct Parser<'a> {
    /// The input split into lines (blank and comment-only lines included).
    lines: Vec<&'a str>,
    /// Index into `lines` of the next line to be read.
    pos: usize,
    /// Number of variables, from header line 2.
    n: usize,
    /// Number of constraints, from header line 2.
    m: usize,
    /// Number of objectives, from header line 2.
    num_obj: usize,
    /// Number of AMPL imported (external) functions declared in the header.
    n_funcs: usize,
    /// Common subexpressions (`V` segments). Index in this vec is the
    /// CSE-local index, i.e. the global `.nl` index minus `n`.
    cses: Vec<Rc<Expr>>,
}
585
586impl<'a> Parser<'a> {
587    fn new(txt: &'a str) -> Self {
588        let lines: Vec<&str> = txt.lines().collect();
589        Self {
590            lines,
591            pos: 0,
592            n: 0,
593            m: 0,
594            num_obj: 0,
595            n_funcs: 0,
596            cses: Vec::new(),
597        }
598    }
599
600    fn next_line(&mut self) -> Option<&'a str> {
601        while self.pos < self.lines.len() {
602            let l = self.lines[self.pos];
603            self.pos += 1;
604            // Strip comment after '#' for header / data lines (but
605            // leave the segment-tag tokens untouched — they are the
606            // first token on the line).
607            let trimmed = strip_comment(l).trim();
608            if !trimmed.is_empty() {
609                return Some(l);
610            }
611        }
612        None
613    }
614
615    fn next_data_line(&mut self) -> Result<String, String> {
616        let raw = self
617            .next_line()
618            .ok_or_else(|| "unexpected end of file in data line".to_string())?;
619        Ok(strip_comment(raw).trim().to_string())
620    }
621
622    fn parse_header(&mut self) -> Result<(), String> {
623        let line0 = self.next_line().ok_or("empty .nl file")?;
624        let trimmed = strip_comment(line0).trim();
625        let first = trimmed.chars().next().ok_or("empty header line")?;
626        if first != 'g' {
627            return Err(format!(
628                "only ASCII (g-) .nl files supported; got header '{trimmed}'"
629            ));
630        }
631
632        // Header line 2: n_vars n_cons n_objs ranges eqns
633        let l2 = self.next_data_line()?;
634        let nums: Vec<&str> = l2.split_whitespace().collect();
635        if nums.len() < 3 {
636            return Err(format!("malformed line 2: '{l2}'"));
637        }
638        self.n = nums[0].parse().map_err(|e| format!("n: {e}"))?;
639        self.m = nums[1].parse().map_err(|e| format!("m: {e}"))?;
640        self.num_obj = nums[2].parse().map_err(|e| format!("num_obj: {e}"))?;
641
642        // Lines 3..5 are metadata we skip.
643        for _ in 0..3 {
644            self.next_data_line()?;
645        }
646        // Line 5 (0-indexed from `g`-header): `nwv nfunc arith flags`
647        let l5 = self.next_data_line()?;
648        let nums5: Vec<&str> = l5.split_whitespace().collect();
649        self.n_funcs = nums5.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
650        // Lines 6..10 are metadata we don't need — skip 4 more lines.
651        for _ in 0..4 {
652            self.next_data_line()?;
653        }
654        Ok(())
655    }
656
657    fn peek_segment_line(&mut self) -> Option<&'a str> {
658        let saved = self.pos;
659        let l = self.next_line()?;
660        self.pos = saved;
661        Some(l)
662    }
663
664    /// Eat the next non-blank line as a segment header. Returns the
665    /// whole header (after stripping comments) and the comment text.
666    fn eat_segment_header(&mut self) -> Result<(String, String), String> {
667        let raw = self
668            .next_line()
669            .ok_or_else(|| "expected segment header".to_string())?;
670        let (hdr, comment) = split_comment(raw);
671        Ok((hdr.trim().to_string(), comment.trim().to_string()))
672    }
673
674    fn parse_expr(&mut self) -> Result<Expr, String> {
675        let raw = self
676            .next_line()
677            .ok_or_else(|| "expected expression token".to_string())?;
678        let tok = strip_comment(raw).trim().to_string();
679        if tok.is_empty() {
680            return Err("empty expression token".into());
681        }
682        let first = tok.chars().next().ok_or("empty expression token")?;
683        match first {
684            'n' => {
685                let v: Number = tok[1..]
686                    .trim()
687                    .parse()
688                    .map_err(|e| format!("n value: {e}"))?;
689                Ok(Expr::Const(v))
690            }
691            'v' => {
692                let i: usize = tok[1..]
693                    .trim()
694                    .parse()
695                    .map_err(|e| format!("v index: {e}"))?;
696                Ok(self.var_or_cse(i)?)
697            }
698            'o' => {
699                let code: i32 = tok[1..]
700                    .trim()
701                    .parse()
702                    .map_err(|e| format!("opcode: {e}"))?;
703                self.parse_opcode(code)
704            }
705            'f' => {
706                // AMPL imported (external) function call: `f<id> <nargs>`
707                // followed by nargs child expressions (or string literals).
708                let rest = &tok[1..];
709                let mut parts = rest.split_whitespace();
710                let id_str = parts
711                    .next()
712                    .ok_or_else(|| format!("missing function id in '{tok}'"))?;
713                let nargs_str = parts
714                    .next()
715                    .ok_or_else(|| format!("missing nargs in '{tok}'"))?;
716                let id: usize = id_str
717                    .parse()
718                    .map_err(|e| format!("bad function id '{id_str}': {e}"))?;
719                let nargs: usize = nargs_str
720                    .parse()
721                    .map_err(|e| format!("bad funcall nargs '{nargs_str}': {e}"))?;
722                let mut args: Vec<FuncallArg> = Vec::with_capacity(nargs);
723                for _ in 0..nargs {
724                    args.push(self.parse_funcall_arg()?);
725                }
726                Ok(Expr::Funcall { id, args })
727            }
728            't' | 'u' => Err(format!("unsupported expression token '{tok}'")),
729            other => Err(format!(
730                "unexpected expression token start '{other}': '{tok}'"
731            )),
732        }
733    }
734
735    /// Parse one argument to an AMPL imported function. An argument
736    /// is either a normal expression (real-valued) or a string literal
737    /// in the form `h<len>:<chars>`. AMPL emits string args only when the
738    /// function was declared `FUNCADD_STRING_ARGS` (e.g. component name
739    /// or a parameters-directory path for IDAES Helmholtz functions).
740    fn parse_funcall_arg(&mut self) -> Result<FuncallArg, String> {
741        // Peek the next non-blank line so we can route `h...` differently.
742        let saved = self.pos;
743        let raw = self
744            .next_line()
745            .ok_or_else(|| "expected funcall argument".to_string())?;
746        let tok = strip_comment(raw).trim().to_string();
747        let first = tok.chars().next().ok_or("empty funcall arg token")?;
748        if first == 'h' {
749            // `h<len>:<chars>` — strip the `<len>:` prefix.
750            let rest = &tok[1..];
751            let s = match rest.find(':') {
752                Some(i) => rest[i + 1..].to_string(),
753                None => String::new(),
754            };
755            Ok(FuncallArg::Str(s))
756        } else {
757            // Rewind: parse_expr re-consumes the line we just peeked.
758            self.pos = saved;
759            Ok(FuncallArg::Real(self.parse_expr()?))
760        }
761    }
762
763    fn parse_opcode(&mut self, code: i32) -> Result<Expr, String> {
764        match code {
765            0 => {
766                let a = self.parse_expr()?;
767                let b = self.parse_expr()?;
768                Ok(Expr::Binary(BinOp::Add, Box::new(a), Box::new(b)))
769            }
770            1 => {
771                let a = self.parse_expr()?;
772                let b = self.parse_expr()?;
773                Ok(Expr::Binary(BinOp::Sub, Box::new(a), Box::new(b)))
774            }
775            2 => {
776                let a = self.parse_expr()?;
777                let b = self.parse_expr()?;
778                Ok(Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b)))
779            }
780            3 => {
781                let a = self.parse_expr()?;
782                let b = self.parse_expr()?;
783                Ok(Expr::Binary(BinOp::Div, Box::new(a), Box::new(b)))
784            }
785            5 => {
786                let a = self.parse_expr()?;
787                let b = self.parse_expr()?;
788                Ok(Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b)))
789            }
790            15 => Ok(Expr::Unary(UnaryOp::Abs, Box::new(self.parse_expr()?))),
791            16 => Ok(Expr::Unary(UnaryOp::Neg, Box::new(self.parse_expr()?))),
792            39 => Ok(Expr::Unary(UnaryOp::Sqrt, Box::new(self.parse_expr()?))),
793            41 => Ok(Expr::Unary(UnaryOp::Sin, Box::new(self.parse_expr()?))),
794            42 => Ok(Expr::Unary(UnaryOp::Log10, Box::new(self.parse_expr()?))),
795            43 => Ok(Expr::Unary(UnaryOp::Log, Box::new(self.parse_expr()?))),
796            44 => Ok(Expr::Unary(UnaryOp::Exp, Box::new(self.parse_expr()?))),
797            46 => Ok(Expr::Unary(UnaryOp::Cos, Box::new(self.parse_expr()?))),
798            54 => {
799                // Variadic sum: next data line gives the count.
800                let count_line = self.next_data_line()?;
801                let count: usize = count_line
802                    .split_whitespace()
803                    .next()
804                    .ok_or_else(|| "missing variadic count".to_string())?
805                    .parse()
806                    .map_err(|e| format!("variadic count: {e}"))?;
807                let mut args = Vec::with_capacity(count);
808                for _ in 0..count {
809                    args.push(self.parse_expr()?);
810                }
811                Ok(Expr::Sum(args))
812            }
813            other => Err(format!("unsupported opcode o{other}")),
814        }
815    }
816
817    /// Resolve a `v<i>` token into either a plain variable reference
818    /// (`i < n`) or a shared CSE reference (`i >= n`).
819    fn var_or_cse(&self, i: usize) -> Result<Expr, String> {
820        if i < self.n {
821            Ok(Expr::Var(i))
822        } else {
823            let local = i - self.n;
824            self.cses
825                .get(local)
826                .map(|rc| Expr::Cse(rc.clone()))
827                .ok_or_else(|| {
828                    format!(
829                        "v{i} references CSE {local} but only {} have been defined",
830                        self.cses.len()
831                    )
832                })
833        }
834    }
835
836    /// Parse a `V<k> <nlin> <type>` common-subexpression segment. The
837    /// CSE evaluates to `nonlinear_expr + sum_i coef_i * v_{var_i}`.
838    /// CSEs are numbered starting at `n` and must appear in order.
839    fn parse_v_segment(&mut self) -> Result<(), String> {
840        let (hdr, _) = self.eat_segment_header()?;
841        let parts: Vec<&str> = hdr.split_whitespace().collect();
842        if parts.len() < 2 {
843            return Err(format!("malformed V-segment header: {hdr}"));
844        }
845        let cse_idx = parse_segment_index(parts[0], 'V')?;
846        let nlin: usize = parts[1].parse().map_err(|e| format!("V nlin: {e}"))?;
847        // parts[2] (type) is ignored; values >0 just mark special-purpose CSEs.
848        let mut linear: Vec<(usize, Number)> = Vec::with_capacity(nlin);
849        for _ in 0..nlin {
850            let line = self.next_data_line()?;
851            let (var, coef) = parse_var_coef(&line)?;
852            linear.push((var, coef));
853        }
854        let nonlin = self.parse_expr()?;
855        // Build `nonlin + sum coef_i * v_{var_i}`. Linear terms can
856        // reference earlier CSEs as well as plain variables.
857        let mut combined = nonlin;
858        for (var, coef) in linear {
859            let v_expr = self.var_or_cse(var)?;
860            let term = if coef == 1.0 {
861                v_expr
862            } else {
863                Expr::Binary(BinOp::Mul, Box::new(Expr::Const(coef)), Box::new(v_expr))
864            };
865            combined = Expr::Binary(BinOp::Add, Box::new(combined), Box::new(term));
866        }
867        if cse_idx < self.n {
868            return Err(format!("V{cse_idx} below n={}", self.n));
869        }
870        let local = cse_idx - self.n;
871        if local != self.cses.len() {
872            return Err(format!(
873                "V-segment index V{cse_idx} out of order; expected V{}",
874                self.n + self.cses.len()
875            ));
876        }
877        self.cses.push(Rc::new(combined));
878        Ok(())
879    }
880}
881
/// Return the part of `s` before the first `#`, or all of `s` when it
/// contains no `#`.
fn strip_comment(s: &str) -> &str {
    s.find('#').map_or(s, |hash| &s[..hash])
}
888
/// Split `s` at its first `#` into `(before, after)`, dropping the `#`
/// itself. Without a `#` the result is `(s, "")`.
fn split_comment(s: &str) -> (&str, &str) {
    s.split_once('#').unwrap_or((s, ""))
}
895
896// --------------------------------------------------------------------
897// Expression evaluation and gradient (tree walkers, kept for tests).
898// The hot paths in `NlTnlp` use the flat `Tape` AD in `nl_tape.rs`
899// instead — see `Tape::gradient_seed` / `Tape::hessian_accumulate`.
900// --------------------------------------------------------------------
901
902/// Forward-mode value evaluation.
903pub fn eval_expr(e: &Expr, x: &[Number]) -> Number {
904    match e {
905        Expr::Const(c) => *c,
906        Expr::Var(i) => x[*i],
907        Expr::Binary(op, a, b) => {
908            let va = eval_expr(a, x);
909            let vb = eval_expr(b, x);
910            match op {
911                BinOp::Add => va + vb,
912                BinOp::Sub => va - vb,
913                BinOp::Mul => va * vb,
914                BinOp::Div => va / vb,
915                BinOp::Pow => va.powf(vb),
916            }
917        }
918        Expr::Unary(op, a) => {
919            let va = eval_expr(a, x);
920            match op {
921                UnaryOp::Neg => -va,
922                UnaryOp::Sqrt => va.sqrt(),
923                UnaryOp::Log => va.ln(),
924                UnaryOp::Log10 => va.log10(),
925                UnaryOp::Exp => va.exp(),
926                UnaryOp::Abs => va.abs(),
927                UnaryOp::Sin => va.sin(),
928                UnaryOp::Cos => va.cos(),
929            }
930        }
931        Expr::Sum(args) => args.iter().map(|a| eval_expr(a, x)).sum(),
932        Expr::Cse(body) => eval_expr(body, x),
933        Expr::Funcall { .. } => panic!(
934            "eval_expr: AMPL imported function called without an external resolver; \
935             evaluate through the tape AD path (Tape::build_with_externals) instead"
936        ),
937    }
938}
939
940/// Reverse-mode gradient: accumulates `seed * d(expr)/dx_i` into `grad`.
941pub fn grad_expr(e: &Expr, x: &[Number], seed: Number, grad: &mut [Number]) {
942    match e {
943        Expr::Const(_) => {}
944        Expr::Var(i) => grad[*i] += seed,
945        Expr::Binary(op, a, b) => {
946            let va = eval_expr(a, x);
947            let vb = eval_expr(b, x);
948            match op {
949                BinOp::Add => {
950                    grad_expr(a, x, seed, grad);
951                    grad_expr(b, x, seed, grad);
952                }
953                BinOp::Sub => {
954                    grad_expr(a, x, seed, grad);
955                    grad_expr(b, x, -seed, grad);
956                }
957                BinOp::Mul => {
958                    grad_expr(a, x, seed * vb, grad);
959                    grad_expr(b, x, seed * va, grad);
960                }
961                BinOp::Div => {
962                    grad_expr(a, x, seed / vb, grad);
963                    grad_expr(b, x, -seed * va / (vb * vb), grad);
964                }
965                BinOp::Pow => {
966                    // d/da: b * a^(b-1)
967                    let dpa = vb * va.powf(vb - 1.0);
968                    grad_expr(a, x, seed * dpa, grad);
969                    // d/db: a^b * ln(a) (only valid for a>0; simple branch)
970                    if va > 0.0 {
971                        let dpb = va.powf(vb) * va.ln();
972                        grad_expr(b, x, seed * dpb, grad);
973                    }
974                }
975            }
976        }
977        Expr::Unary(op, a) => {
978            let va = eval_expr(a, x);
979            let d = match op {
980                UnaryOp::Neg => -1.0,
981                UnaryOp::Sqrt => 0.5 / va.sqrt(),
982                UnaryOp::Log => 1.0 / va,
983                UnaryOp::Log10 => 1.0 / (va * std::f64::consts::LN_10),
984                UnaryOp::Exp => va.exp(),
985                UnaryOp::Abs => {
986                    if va > 0.0 {
987                        1.0
988                    } else if va < 0.0 {
989                        -1.0
990                    } else {
991                        0.0
992                    }
993                }
994                UnaryOp::Sin => va.cos(),
995                UnaryOp::Cos => -va.sin(),
996            };
997            grad_expr(a, x, seed * d, grad);
998        }
999        Expr::Sum(args) => {
1000            for arg in args {
1001                grad_expr(arg, x, seed, grad);
1002            }
1003        }
1004        Expr::Cse(body) => grad_expr(body, x, seed, grad),
1005        Expr::Funcall { .. } => {
1006            panic!("grad_expr: AMPL imported function called without an external resolver")
1007        }
1008    }
1009}
1010
1011/// Walk `e` and insert every `Var(i)` index into `out`.
1012pub fn collect_vars(e: &Expr, out: &mut BTreeSet<usize>) {
1013    match e {
1014        Expr::Const(_) => {}
1015        Expr::Var(i) => {
1016            out.insert(*i);
1017        }
1018        Expr::Binary(_, a, b) => {
1019            collect_vars(a, out);
1020            collect_vars(b, out);
1021        }
1022        Expr::Unary(_, a) => collect_vars(a, out),
1023        Expr::Sum(args) => {
1024            for a in args {
1025                collect_vars(a, out);
1026            }
1027        }
1028        Expr::Cse(body) => collect_vars(body, out),
1029        Expr::Funcall { args, .. } => {
1030            for a in args {
1031                if let FuncallArg::Real(e) = a {
1032                    collect_vars(e, out);
1033                }
1034            }
1035        }
1036    }
1037}
1038
1039// --------------------------------------------------------------------
1040// TNLP wrapper — backed by `Tape` reverse-mode AD for value, gradient,
1041// Jacobian, and Hessian. Built once at construction; every solve-time
1042// callback is a tape sweep, no expression-tree recursion.
1043// --------------------------------------------------------------------
1044
/// One scatter instruction used by `eval_h` when decoding a colored
/// Hessian-vector product.
///
/// For a color `c`, the product `compressed = H · s_c` holds at
/// position `row` the value of exactly one Hessian entry, because no
/// two columns of the same color share a nonzero row. That value is
/// added to `values[hess_idx]`.
#[derive(Clone, Debug)]
struct ColorWrite {
    /// Position in the compressed per-color vector to read from.
    row: u32,
    /// Position in the sparse Hessian `values` array to write to.
    hess_idx: u32,
}
1055
1056#[derive(Debug)]
1057pub struct NlTnlp {
1058    prob: NlProblem,
1059    /// Per-summand objective tapes (one `Tape` per top-level
1060    /// summand after `split_top_sums`).
1061    obj_tapes: Vec<Tape>,
1062    /// Per-constraint, per-summand tapes. Length `m`; row `i` holds
1063    /// one `Tape` per summand of constraint `i`.
1064    con_tapes: Vec<Vec<Tape>>,
1065    /// Lower-triangle Hessian sparsity (row >= col), one entry per
1066    /// structurally nonzero second derivative in the Lagrangian.
1067    h_irow: Vec<i32>,
1068    h_jcol: Vec<i32>,
1069    /// Per-row sorted variable indices for the constraint Jacobian.
1070    jac_cols: Vec<Vec<usize>>,
1071    jac_nnz: usize,
1072    /// Per-color seed vector: `seeds[c][k] = 1.0` iff variable `k`
1073    /// is in color `c`, else `0.0`. Each color is a set of
1074    /// variables whose Hessian columns have pairwise-disjoint
1075    /// nonzero rows; one directional H·s product per color
1076    /// recovers all those columns simultaneously. Dense for
1077    /// O(1) lookup in the per-op forward tangent.
1078    seeds: Vec<Vec<f64>>,
1079    /// Per-color decoding table: for each `(row, hess_idx)` entry,
1080    /// scatter `compressed_c[row] -> values[hess_idx]` after the
1081    /// per-color directional product.
1082    decoding: Vec<Vec<ColorWrite>>,
1083    /// For each objective tape: the distinct colors of vars it
1084    /// references. Lets us skip tape × color pairs where the tape
1085    /// has zero overlap with the color's seed.
1086    obj_tape_colors: Vec<Vec<u32>>,
1087    /// Same as `obj_tape_colors` but per constraint × summand.
1088    con_tape_colors: Vec<Vec<Vec<u32>>>,
1089    final_x: Option<Vec<Number>>,
1090    final_obj: Number,
1091    /// Per-row Jacobian accumulator (length n).
1092    scratch_row_grad: Vec<f64>,
1093    /// Scratch buffers for `Tape::hessian_directional` (each sized
1094    /// to `max_tape_n`).
1095    vals_scratch: Vec<f64>,
1096    dot_scratch: Vec<f64>,
1097    adj_scratch: Vec<f64>,
1098    adj_dot_scratch: Vec<f64>,
1099    /// Per-color compressed Hessian-vector results, sized to
1100    /// `prob.n`. Reused across `eval_h` calls but allocated once.
1101    compressed: Vec<Vec<f64>>,
1102}
1103
1104/// Recursively flatten top-level Sum and binary-Add nodes into a list
1105/// of independent summands. Non-Sum/Add expressions are returned as a
1106/// single-element vector. This lets `NlTnlp` build one small tape per
1107/// term so the per-variable Hessian sweep only walks the term that
1108/// actually depends on that variable.
1109fn split_top_sums(expr: &Expr) -> Vec<Expr> {
1110    let mut out = Vec::new();
1111    fn go(e: &Expr, out: &mut Vec<Expr>) {
1112        match e {
1113            Expr::Sum(terms) => {
1114                for t in terms {
1115                    go(t, out);
1116                }
1117            }
1118            Expr::Binary(BinOp::Add, l, r) => {
1119                go(l, out);
1120                go(r, out);
1121            }
1122            _ => out.push(e.clone()),
1123        }
1124    }
1125    go(expr, &mut out);
1126    if out.is_empty() {
1127        out.push(Expr::Const(0.0));
1128    }
1129    out
1130}
1131
/// Greedy column coloring of a symmetric sparsity pattern given as
/// lower-triangle `(row, col)` pairs.
///
/// Two columns conflict when some row has a nonzero in both (the
/// column-intersection graph). Columns are visited in index order and
/// each gets the smallest color not already used by a conflicting
/// column. Columns of one color therefore have pairwise-disjoint row
/// supports, so a single `H·s` product recovers all of them
/// unambiguously.
///
/// Returns `(var_color, n_colors)`. `var_color[k]` is the color of
/// variable `k`, or `u32::MAX` when `k` appears in no pair (such
/// variables contribute nothing and need no color).
///
/// # Panics
/// Panics if a pair references an index `>= n`.
fn greedy_hessian_coloring(n: usize, lower_pairs: &[(usize, usize)]) -> (Vec<u32>, usize) {
    const UNCOLORED: u32 = u32::MAX;

    if n == 0 {
        return (Vec::new(), 0);
    }

    // support[k] lists the rows where column k is nonzero in the FULL
    // symmetric matrix. By symmetry this is also the list of columns
    // where row k is nonzero, so one table serves both lookups.
    let mut support: Vec<Vec<u32>> = vec![Vec::new(); n];
    for &(row, col) in lower_pairs {
        support[col].push(row as u32);
        if row != col {
            support[row].push(col as u32);
        }
    }

    let mut var_color = vec![UNCOLORED; n];
    // taken_by[c] == j means color c is used by a neighbor of column j.
    // Stamping with the column index avoids clearing between columns.
    let mut taken_by = vec![u32::MAX; n + 1];
    let mut n_colors: u32 = 0;

    for (j, rows) in support.iter().enumerate() {
        // No Hessian entries in this column: leave it uncolored.
        if rows.is_empty() {
            continue;
        }
        let stamp = j as u32;

        // Every column that shares a row with `j`, other than `j` itself.
        let neighbors = rows
            .iter()
            .flat_map(|&r| support[r as usize].iter().copied())
            .filter(|&c| c as usize != j);
        for c in neighbors {
            let color = var_color[c as usize];
            if color != UNCOLORED {
                taken_by[color as usize] = stamp;
            }
        }

        // Smallest color not stamped for this column.
        let chosen = taken_by
            .iter()
            .position(|&owner| owner != stamp)
            .unwrap_or(taken_by.len()) as u32;
        var_color[j] = chosen;
        n_colors = n_colors.max(chosen + 1);
    }

    (var_color, n_colors as usize)
}
1203
1204impl NlTnlp {
1205    pub fn new(prob: NlProblem) -> Self {
1206        // Resolve any AMPL imported (external) functions. Walk every
1207        // nonlinear expression to collect the funcall ids actually
1208        // referenced; load the libraries named in $AMPLFUNC and bind
1209        // each id to its (library, registered-name) pair so the tape
1210        // builder can emit live `TapeOp::Funcall` ops.
1211        let mut referenced: BTreeSet<usize> = BTreeSet::new();
1212        super::nl_external::collect_funcall_ids(&prob.obj_nonlinear, &mut referenced);
1213        for c in &prob.con_nonlinear {
1214            super::nl_external::collect_funcall_ids(c, &mut referenced);
1215        }
1216        let resolver = if referenced.is_empty() {
1217            super::nl_external::ExternalResolver::default()
1218        } else {
1219            super::nl_external::ExternalResolver::build_for_problem(
1220                &prob.imported_funcs,
1221                &referenced,
1222            )
1223            .unwrap_or_else(|e| panic!("failed to resolve AMPL external functions: {e}"))
1224        };
1225
1226        // Flatten objective and each constraint into independent
1227        // summands. Each summand becomes its own `Tape` (CSE bodies
1228        // are deduplicated within a tape via Rc identity in
1229        // `Tape::build`; bodies shared across summands are
1230        // duplicated, which we accept as a simplicity tradeoff).
1231        let obj_summands = split_top_sums(&prob.obj_nonlinear);
1232        let obj_tapes: Vec<Tape> = obj_summands
1233            .iter()
1234            .map(|e| Tape::build_with_externals(e, &resolver))
1235            .collect();
1236
1237        let mut con_tapes: Vec<Vec<Tape>> = Vec::with_capacity(prob.m);
1238        for k in 0..prob.m {
1239            let summands = split_top_sums(&prob.con_nonlinear[k]);
1240            con_tapes.push(
1241                summands
1242                    .iter()
1243                    .map(|e| Tape::build_with_externals(e, &resolver))
1244                    .collect(),
1245            );
1246        }
1247
1248        // Hessian-of-Lagrangian sparsity: union of each tape's own
1249        // structural Hessian sparsity.
1250        let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
1251        for t in &obj_tapes {
1252            pairs.extend(t.hessian_sparsity());
1253        }
1254        for row in &con_tapes {
1255            for t in row {
1256                pairs.extend(t.hessian_sparsity());
1257            }
1258        }
1259        let mut h_irow = Vec::with_capacity(pairs.len());
1260        let mut h_jcol = Vec::with_capacity(pairs.len());
1261        let mut hess_map = HashMap::with_capacity(pairs.len());
1262        for (k, (hi, lo)) in pairs.iter().enumerate() {
1263            h_irow.push(*hi as i32);
1264            h_jcol.push(*lo as i32);
1265            hess_map.insert((*hi, *lo), k);
1266        }
1267
1268        // Hessian column coloring. The chromatic number of the
1269        // column-intersection graph bounds how many directional
1270        // Hessian-vector products we need per `eval_h` call —
1271        // typically O(stencil) for PDE-mesh problems.
1272        let lower_pairs: Vec<(usize, usize)> = pairs.iter().copied().collect();
1273        let (var_color, n_colors) = greedy_hessian_coloring(prob.n, &lower_pairs);
1274
1275        // Per-color seed vectors (dense for O(1) Var lookup in
1276        // `Tape::hessian_directional`).
1277        let mut seeds: Vec<Vec<f64>> = vec![vec![0.0; prob.n]; n_colors];
1278        for (k, &c) in var_color.iter().enumerate() {
1279            if c != u32::MAX {
1280                seeds[c as usize][k] = 1.0;
1281            }
1282        }
1283
1284        // Per-color decoding table. For each lower-tri pair (i, j)
1285        // with i >= j, the entry belongs to column j's color: after
1286        // computing compressed_{c_j} = (H · s_{c_j}), the value at
1287        // row i is exactly H[i, j] (coloring guarantees no other
1288        // column in c_j has a nonzero at row i).
1289        let mut decoding: Vec<Vec<ColorWrite>> = vec![Vec::new(); n_colors];
1290        for (&(i, j), &idx) in hess_map.iter() {
1291            let c = var_color[j];
1292            debug_assert!(
1293                c != u32::MAX,
1294                "column {j} has Hessian pair {idx} but no color"
1295            );
1296            decoding[c as usize].push(ColorWrite {
1297                row: i as u32,
1298                hess_idx: idx as u32,
1299            });
1300        }
1301
1302        // Per-tape distinct color set: for each tape, the colors
1303        // its variables fall into. `eval_h` loops over only these
1304        // (tape, color) pairs instead of n_tapes × n_colors.
1305        let tape_colors = |t: &Tape| -> Vec<u32> {
1306            let mut s: BTreeSet<u32> = BTreeSet::new();
1307            for v in t.variables() {
1308                let c = var_color[v];
1309                if c != u32::MAX {
1310                    s.insert(c);
1311                }
1312            }
1313            s.into_iter().collect()
1314        };
1315        let obj_tape_colors: Vec<Vec<u32>> = obj_tapes.iter().map(tape_colors).collect();
1316        let con_tape_colors: Vec<Vec<Vec<u32>>> = con_tapes
1317            .iter()
1318            .map(|row| row.iter().map(tape_colors).collect())
1319            .collect();
1320
1321        // Per-row Jacobian sparsity = union of tape vars plus
1322        // linear-segment vars.
1323        let mut jac_cols: Vec<Vec<usize>> = Vec::with_capacity(prob.m);
1324        let mut jac_nnz = 0;
1325        for i in 0..prob.m {
1326            let mut set: BTreeSet<usize> = BTreeSet::new();
1327            for t in &con_tapes[i] {
1328                for v in t.variables() {
1329                    set.insert(v);
1330                }
1331            }
1332            for (v, _) in &prob.con_linear[i] {
1333                set.insert(*v);
1334            }
1335            let cols: Vec<usize> = set.into_iter().collect();
1336            jac_nnz += cols.len();
1337            jac_cols.push(cols);
1338        }
1339
1340        let mut max_tape_n: usize = 0;
1341        for t in &obj_tapes {
1342            max_tape_n = max_tape_n.max(t.ops.len());
1343        }
1344        for row in &con_tapes {
1345            for t in row {
1346                max_tape_n = max_tape_n.max(t.ops.len());
1347            }
1348        }
1349
1350        if std::env::var("POUNCE_DBG_TAPE_STATS").is_ok() {
1351            let n_obj = obj_tapes.len();
1352            let n_con: usize = con_tapes.iter().map(|r| r.len()).sum();
1353            let total = n_obj + n_con;
1354            let mut sum_ops: usize = 0;
1355            for t in &obj_tapes {
1356                sum_ops += t.ops.len();
1357            }
1358            for row in &con_tapes {
1359                for t in row {
1360                    sum_ops += t.ops.len();
1361                }
1362            }
1363            let t = total.max(1);
1364            let nnz_h = h_irow.len();
1365            let avg_decode =
1366                decoding.iter().map(|d| d.len()).sum::<usize>() as f64 / n_colors.max(1) as f64;
1367            eprintln!(
1368                "[tape stats] summands={total} (obj={n_obj} con={n_con}) \
1369                 total_ops={sum_ops} avg_ops={:.1} max_ops={max_tape_n} \
1370                 n_colors={n_colors} avg_decode_per_color={avg_decode:.1} nnz_h={nnz_h}",
1371                sum_ops as f64 / t as f64,
1372            );
1373        }
1374
1375        let compressed: Vec<Vec<f64>> = vec![vec![0.0; prob.n]; n_colors];
1376
1377        Self {
1378            prob,
1379            obj_tapes,
1380            con_tapes,
1381            h_irow,
1382            h_jcol,
1383            jac_cols,
1384            jac_nnz,
1385            seeds,
1386            decoding,
1387            obj_tape_colors,
1388            con_tape_colors,
1389            final_x: None,
1390            final_obj: 0.0,
1391            scratch_row_grad: Vec::new(),
1392            vals_scratch: vec![0.0; max_tape_n],
1393            dot_scratch: vec![0.0; max_tape_n],
1394            adj_scratch: vec![0.0; max_tape_n],
1395            adj_dot_scratch: vec![0.0; max_tape_n],
1396            compressed,
1397        }
1398    }
1399
1400    pub fn final_x(&self) -> Option<&[Number]> {
1401        self.final_x.as_deref()
1402    }
1403
1404    pub fn final_obj(&self) -> Number {
1405        self.final_obj
1406    }
1407}
1408
1409impl pounce_nlp::expression_provider::ExpressionProvider for NlTnlp {
1410    /// Per-`.nl`-row constraint expression tape, with the linear
1411    /// part folded in. Returns `None` for constraints that contribute
1412    /// neither a nonlinear expression nor any linear coefficients
1413    /// (so FBBT skips them — there's nothing to tighten).
1414    fn constraint_expression(&self, i: usize) -> Option<pounce_nlp::FbbtTape> {
1415        let nonlinear = self.prob.con_nonlinear.get(i)?;
1416        let linear = self
1417            .prob
1418            .con_linear
1419            .get(i)
1420            .map(|v| v.as_slice())
1421            .unwrap_or(&[]);
1422        crate::nl_fbbt_translate::translate_constraint(nonlinear, linear)
1423    }
1424}
1425
1426impl TNLP for NlTnlp {
1427    fn get_nlp_info(&mut self) -> Option<NlpInfo> {
1428        Some(NlpInfo {
1429            n: self.prob.n as Index,
1430            m: self.prob.m as Index,
1431            nnz_jac_g: self.jac_nnz as Index,
1432            nnz_h_lag: self.h_irow.len() as Index,
1433            index_style: IndexStyle::C,
1434        })
1435    }
1436
1437    fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool {
1438        b.x_l.copy_from_slice(&self.prob.x_l);
1439        b.x_u.copy_from_slice(&self.prob.x_u);
1440        if !self.prob.g_l.is_empty() {
1441            b.g_l.copy_from_slice(&self.prob.g_l);
1442            b.g_u.copy_from_slice(&self.prob.g_u);
1443        }
1444        true
1445    }
1446
1447    fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool {
1448        sp.x.copy_from_slice(&self.prob.x0);
1449        true
1450    }
1451
1452    fn eval_f(&mut self, x: &[Number], _new_x: bool) -> Option<Number> {
1453        let mut nl: Number = 0.0;
1454        for t in &self.obj_tapes {
1455            nl += t.eval(x);
1456        }
1457        let lin: Number = self.prob.obj_linear.iter().map(|(i, c)| c * x[*i]).sum();
1458        let v = self.prob.obj_constant + nl + lin;
1459        let signed = if self.prob.minimize { v } else { -v };
1460        Some(signed)
1461    }
1462
1463    fn eval_grad_f(&mut self, x: &[Number], _new_x: bool, grad: &mut [Number]) -> bool {
1464        grad.fill(0.0);
1465        for t in &self.obj_tapes {
1466            t.gradient_seed(x, 1.0, grad);
1467        }
1468        for (i, c) in &self.prob.obj_linear {
1469            grad[*i] += c;
1470        }
1471        if !self.prob.minimize {
1472            for g in grad.iter_mut() {
1473                *g = -*g;
1474            }
1475        }
1476        true
1477    }
1478
1479    fn eval_g(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
1480        for i in 0..self.prob.m {
1481            let mut nl: Number = 0.0;
1482            for t in &self.con_tapes[i] {
1483                nl += t.eval(x);
1484            }
1485            let lin: Number = self.prob.con_linear[i].iter().map(|(j, c)| c * x[*j]).sum();
1486            g[i] = nl + lin;
1487        }
1488        true
1489    }
1490
1491    fn eval_jac_g(
1492        &mut self,
1493        x: Option<&[Number]>,
1494        _new_x: bool,
1495        mode: SparsityRequest<'_>,
1496    ) -> bool {
1497        match mode {
1498            SparsityRequest::Structure { irow, jcol } => {
1499                let mut k = 0;
1500                for i in 0..self.prob.m {
1501                    for &j in &self.jac_cols[i] {
1502                        irow[k] = i as Index;
1503                        jcol[k] = j as Index;
1504                        k += 1;
1505                    }
1506                }
1507                true
1508            }
1509            SparsityRequest::Values { values } => {
1510                let n = self.prob.n;
1511                let xs = x.unwrap_or(&self.prob.x0);
1512                if self.scratch_row_grad.len() < n {
1513                    self.scratch_row_grad.resize(n, 0.0);
1514                }
1515                let mut k = 0;
1516                for i in 0..self.prob.m {
1517                    for &j in &self.jac_cols[i] {
1518                        self.scratch_row_grad[j] = 0.0;
1519                    }
1520                    for t in &self.con_tapes[i] {
1521                        t.gradient_seed(xs, 1.0, &mut self.scratch_row_grad);
1522                    }
1523                    for &(v, c) in &self.prob.con_linear[i] {
1524                        self.scratch_row_grad[v] += c;
1525                    }
1526                    for &j in &self.jac_cols[i] {
1527                        values[k] = self.scratch_row_grad[j];
1528                        k += 1;
1529                    }
1530                }
1531                true
1532            }
1533        }
1534    }
1535
1536    fn eval_h(
1537        &mut self,
1538        x: Option<&[Number]>,
1539        _new_x: bool,
1540        obj_factor: Number,
1541        lambda: Option<&[Number]>,
1542        _new_lambda: bool,
1543        mode: SparsityRequest<'_>,
1544    ) -> bool {
1545        match mode {
1546            SparsityRequest::Structure { irow, jcol } => {
1547                irow.copy_from_slice(&self.h_irow);
1548                jcol.copy_from_slice(&self.h_jcol);
1549                true
1550            }
1551            SparsityRequest::Values { values } => {
1552                let x = x.unwrap_or(&self.prob.x0);
1553                values.fill(0.0);
1554
1555                let obj_seed = if self.prob.minimize {
1556                    obj_factor
1557                } else {
1558                    -obj_factor
1559                };
1560                // Coloring path. For each (tape, weight) we do
1561                // one forward pass into `vals_scratch`, then one
1562                // forward-tangent+reverse-over-tangent per color
1563                // touched by that tape. Each pass accumulates a
1564                // weighted contribution of (H_tape · seed_c) into
1565                // `compressed[c]`. After all tapes done, we
1566                // decode each color's compressed vector into the
1567                // sparse `values` array.
1568                for buf in &mut self.compressed {
1569                    buf.fill(0.0);
1570                }
1571
1572                if obj_seed != 0.0 {
1573                    for (ti, t) in self.obj_tapes.iter().enumerate() {
1574                        if t.ops.is_empty() {
1575                            continue;
1576                        }
1577                        t.forward_into(x, &mut self.vals_scratch);
1578                        for &c in &self.obj_tape_colors[ti] {
1579                            t.hessian_directional(
1580                                &self.vals_scratch,
1581                                &self.seeds[c as usize],
1582                                obj_seed,
1583                                &mut self.compressed[c as usize],
1584                                &mut self.dot_scratch,
1585                                &mut self.adj_scratch,
1586                                &mut self.adj_dot_scratch,
1587                            );
1588                        }
1589                    }
1590                }
1591
1592                if let Some(lam) = lambda {
1593                    for k in 0..self.prob.m {
1594                        let w = lam[k];
1595                        if w == 0.0 {
1596                            continue;
1597                        }
1598                        for (ti, t) in self.con_tapes[k].iter().enumerate() {
1599                            if t.ops.is_empty() {
1600                                continue;
1601                            }
1602                            t.forward_into(x, &mut self.vals_scratch);
1603                            for &c in &self.con_tape_colors[k][ti] {
1604                                t.hessian_directional(
1605                                    &self.vals_scratch,
1606                                    &self.seeds[c as usize],
1607                                    w,
1608                                    &mut self.compressed[c as usize],
1609                                    &mut self.dot_scratch,
1610                                    &mut self.adj_scratch,
1611                                    &mut self.adj_dot_scratch,
1612                                );
1613                            }
1614                        }
1615                    }
1616                }
1617
1618                // Decode each color's compressed Hessian-vector
1619                // result into the lower-triangle `values` array.
1620                for (c, table) in self.decoding.iter().enumerate() {
1621                    let comp = &self.compressed[c];
1622                    for w in table {
1623                        values[w.hess_idx as usize] += comp[w.row as usize];
1624                    }
1625                }
1626                true
1627            }
1628        }
1629    }
1630
1631    fn finalize_solution(&mut self, sol: Solution<'_>, _d: &IpoptData, _q: &IpoptCq) {
1632        self.final_x = Some(sol.x.to_vec());
1633        self.final_obj = sol.obj_value;
1634    }
1635
1636    fn get_constraints_linearity(&mut self, types: &mut [Linearity]) -> bool {
1637        // A row is linear iff its nonlinear-part expression is the
1638        // identity zero left over from initial allocation (post-parse
1639        // identity for "no `C<idx>` segment touched this row").
1640        for (i, t) in types.iter_mut().enumerate() {
1641            *t = match &self.prob.con_nonlinear[i] {
1642                Expr::Const(c) if *c == 0.0 => Linearity::Linear,
1643                _ => Linearity::NonLinear,
1644            };
1645        }
1646        true
1647    }
1648}
1649
1650/// Convenience: read an `.nl` file and build a TNLP-compatible Rc.
1651pub fn load_nl_as_tnlp(path: &Path) -> Result<Rc<RefCell<dyn TNLP>>, String> {
1652    let prob = read_nl_file(path)?;
1653    Ok(Rc::new(RefCell::new(NlTnlp::new(prob))))
1654}
1655
1656#[cfg(test)]
1657mod tests {
1658    use super::*;
1659
    /// `min (x0 - 1)^2 + (x1 - 2)^2` written in `.nl` ASCII form.
    ///
    /// Unconstrained two-variable fixture shared by the parse, gradient
    /// and TNLP tests in this module.
    ///
    /// Header values:
    ///   line 1: g3 0 1 0 (ASCII `g` header)
    ///   line 2: n=2 m=0 num_obj=1 0 0
    ///   line 3: 0 1   (1 nonlinear objective)
    ///   line 4: 0 0
    ///   line 5: 0 2 0 (nonlinear vars in obj=2)
    ///   line 6: 0 0 0 1
    ///   line 7: 0 0 0 0 0
    ///   line 8: 0 0   (no Jacobian nonzeros, no linear obj)
    ///   line 9: 0 0
    ///   line 10: 0 0 0 0 0
    /// Then `O0 0` followed by an expression tree in prefix order:
    /// `(x0 - 1)^2 + (x1 - 2)^2` =
    ///   o0
    ///     o5 (o1 v0 n1) n2
    ///     o5 (o1 v1 n2) n2
    /// Then `b` segment: free for both (bound kind `3`, one line per
    /// variable).
    const SIMPLE: &str = "g3 0 1 0
2 0 1 0 0
0 1
0 0
0 2 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 0 0 0 0
O0 0
o0
o5
o1
v0
n1
n2
o5
o1
v1
n2
n2
b
3
3
";
1703
1704    #[test]
1705    fn parses_simple_quadratic() {
1706        let p = parse_nl_text(SIMPLE).expect("parse");
1707        assert_eq!(p.n, 2);
1708        assert_eq!(p.m, 0);
1709        assert_eq!(p.num_obj, 1);
1710        // f(0,0) = 1 + 4 = 5
1711        let f = eval_expr(&p.obj_nonlinear, &[0.0, 0.0]);
1712        assert!((f - 5.0).abs() < 1e-12);
1713        // f(1,2) = 0
1714        let f = eval_expr(&p.obj_nonlinear, &[1.0, 2.0]);
1715        assert!(f.abs() < 1e-12);
1716    }
1717
1718    #[test]
1719    fn gradient_matches_analytic() {
1720        let p = parse_nl_text(SIMPLE).expect("parse");
1721        let x = [0.5, 1.0];
1722        let mut g = [0.0_f64; 2];
1723        grad_expr(&p.obj_nonlinear, &x, 1.0, &mut g);
1724        // d/dx0 = 2*(x0-1) = -1.0
1725        // d/dx1 = 2*(x1-2) = -2.0
1726        assert!((g[0] - (-1.0)).abs() < 1e-12);
1727        assert!((g[1] - (-2.0)).abs() < 1e-12);
1728    }
1729
    /// `min x0^2 + x1^2  s.t.  x0 + x1 = 1`.
    /// One equality constraint with a purely linear Jacobian — exercises
    /// the constrained path (`eval_g`, `eval_jac_g`, `r`-segment bound
    /// kind 4).
    ///
    /// Header layout:
    ///   line 1: g3 0 1 0
    ///   line 2: 2 1 1 0 0   (n=2, m=1, num_obj=1)
    ///   line 3: 0 1         (1 nonlinear obj, 0 nonlinear cons)
    ///   line 4: 0 0
    ///   line 5: 0 2 0       (nonlinear vars in obj=2)
    ///   line 6: 0 0 0 1
    ///   line 7: 0 0 0 0 0
    ///   line 8: 2 0         (Jacobian nnz=2, no linear obj)
    ///   line 9: 0 0
    ///   line 10: 0 0 0 0 0
    /// Then C0 = const 0 (no nonlinear part), O0 = x0^2 + x1^2,
    /// r-segment kind 4 (eq) value 1, b-segment free, k-segment, J-row.
    ///
    /// The single `J0 2` row carries the whole constraint: coefficient 1
    /// on `x0` and on `x1`.
    ///
    /// NOTE(review): the `k1` entry is `2`, although column 0 holds only
    /// one Jacobian nonzero; presumably the reader does not rely on `k`
    /// values — confirm before reusing this fixture elsewhere.
    const EQ_LIN: &str = "g3 0 1 0
2 1 1 0 0
0 1
0 0
0 2 0
0 0 0 1
0 0 0 0 0
2 0
0 0
0 0 0 0 0
C0
n0
O0 0
o0
o5
v0
n2
o5
v1
n2
r
4 1
b
3
3
k1
2
J0 2
0 1
1 1
";
1779
1780    #[test]
1781    fn parses_constrained_problem() {
1782        let p = parse_nl_text(EQ_LIN).expect("parse");
1783        assert_eq!(p.n, 2);
1784        assert_eq!(p.m, 1);
1785        // r-segment kind 4 (equality with rhs=1).
1786        assert!((p.g_l[0] - 1.0).abs() < 1e-12);
1787        assert!((p.g_u[0] - 1.0).abs() < 1e-12);
1788        // J-row 0: x0 (coef 1), x1 (coef 1).
1789        assert_eq!(p.con_linear[0], vec![(0, 1.0), (1, 1.0)]);
1790    }
1791
1792    #[test]
1793    fn constrained_tnlp_eval_g_jac_h() {
1794        let p = parse_nl_text(EQ_LIN).expect("parse");
1795        let mut t = NlTnlp::new(p);
1796        let info = t.get_nlp_info().unwrap();
1797        assert_eq!(info.m, 1);
1798        assert_eq!(info.nnz_jac_g, 2);
1799
1800        // g(0.3, 0.4) = 0.3 + 0.4 = 0.7
1801        let mut g = [0.0_f64; 1];
1802        assert!(t.eval_g(&[0.3, 0.4], true, &mut g));
1803        assert!((g[0] - 0.7).abs() < 1e-12);
1804
1805        // Jacobian structure: row 0, cols [0, 1].
1806        let mut irow = [0_i32; 2];
1807        let mut jcol = [0_i32; 2];
1808        assert!(t.eval_jac_g(
1809            None,
1810            true,
1811            SparsityRequest::Structure {
1812                irow: &mut irow,
1813                jcol: &mut jcol
1814            }
1815        ));
1816        assert_eq!(irow, [0, 0]);
1817        assert_eq!(jcol, [0, 1]);
1818
1819        // Jacobian values: both 1.0.
1820        let mut vals = [0.0_f64; 2];
1821        assert!(t.eval_jac_g(
1822            Some(&[0.3, 0.4]),
1823            true,
1824            SparsityRequest::Values { values: &mut vals }
1825        ));
1826        assert!((vals[0] - 1.0).abs() < 1e-12);
1827        assert!((vals[1] - 1.0).abs() < 1e-12);
1828
1829        // Hessian of L = (x0^2 + x1^2) + λ*(x0 + x1 - 1) is diag(2,2);
1830        // λ contributes nothing because the constraint is linear, and
1831        // x0^2 + x1^2 is separable so there's no (1,0) entry in the
1832        // structural sparsity. nnz_h_lag = 2: (0,0) and (1,1).
1833        assert_eq!(info.nnz_h_lag, 2);
1834        let mut hirow = [0_i32; 2];
1835        let mut hjcol = [0_i32; 2];
1836        assert!(t.eval_h(
1837            None,
1838            true,
1839            1.0,
1840            None,
1841            true,
1842            SparsityRequest::Structure {
1843                irow: &mut hirow,
1844                jcol: &mut hjcol
1845            }
1846        ));
1847        assert_eq!(hirow, [0, 1]);
1848        assert_eq!(hjcol, [0, 1]);
1849        let mut hvals = [0.0_f64; 2];
1850        assert!(t.eval_h(
1851            Some(&[0.3, 0.4]),
1852            true,
1853            1.0,
1854            Some(&[0.5]),
1855            true,
1856            SparsityRequest::Values { values: &mut hvals }
1857        ));
1858        assert!((hvals[0] - 2.0).abs() < 1e-12);
1859        assert!((hvals[1] - 2.0).abs() < 1e-12);
1860    }
1861
    /// `min (x0 + x1)^2 + (x0 + x1)` with the shared sum `(x0 + x1)`
    /// encoded as common-subexpression `V2`. Header line 10 declares
    /// one obj-only CSE; expression tree references `v2` twice.
    ///
    /// Segment layout after the header:
    ///   `V2 0 0` with body `o0 v0 v1`   (defines v2 = x0 + x1)
    ///   `O0 0`   with body `o0 (o5 v2 n2) v2`
    ///   `b`      with two kind-3 lines, one per variable
    /// With n=2, index 2 is the first index past the real variables,
    /// which is why the CSE is addressed as `v2`.
    const CSE_OBJ: &str = "g3 0 1 0
2 0 1 0 0
0 1
0 0
0 2 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 1 0 0 0
V2 0 0
o0
v0
v1
O0 0
o0
o5
v2
n2
v2
b
3
3
";
1889
1890    #[test]
1891    fn parses_v_segment_cse() {
1892        let p = parse_nl_text(CSE_OBJ).expect("parse");
1893        assert_eq!(p.n, 2);
1894        // f(1,2) = 9 + 3 = 12
1895        let f = eval_expr(&p.obj_nonlinear, &[1.0, 2.0]);
1896        assert!((f - 12.0).abs() < 1e-12, "got {f}");
1897        // d/dx0 = 2*(x0+x1) + 1 = 7 at (1,2). Same for x1.
1898        let mut g = [0.0_f64; 2];
1899        grad_expr(&p.obj_nonlinear, &[1.0, 2.0], 1.0, &mut g);
1900        assert!((g[0] - 7.0).abs() < 1e-12, "g[0]={}", g[0]);
1901        assert!((g[1] - 7.0).abs() < 1e-12, "g[1]={}", g[1]);
1902        // collect_vars reaches into the CSE body and finds {0, 1}.
1903        let mut vs = BTreeSet::new();
1904        collect_vars(&p.obj_nonlinear, &mut vs);
1905        assert_eq!(vs.into_iter().collect::<Vec<_>>(), vec![0, 1]);
1906    }
1907
    /// `min (x0 - 1)^2` with two suffix segments attached: an integer
    /// var-suffix (`S0`, kind=0) named `sens_state_1` and a real
    /// var-suffix (`S4`, kind=4) named `sens_state_value_1`. The problem
    /// has no constraints (m=0) and the fixture carries no
    /// constraint-suffix segment. The .nl format is
    /// `S<kind> <nentries> <name>` then `<idx> <value>` lines.
    const WITH_SUFFIXES: &str = "g3 0 1 0
1 0 1 0 0
0 1
0 0
0 1 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 0 0 0 0
O0 0
o5
o1
v0
n1
n2
b
3
S0 1 sens_state_1
0 7
S4 1 sens_state_value_1
0 4.5
";
1936
1937    #[test]
1938    fn parses_var_int_and_var_real_suffixes() {
1939        let p = parse_nl_text(WITH_SUFFIXES).expect("parse");
1940        // Integer var-suffix: dense length 1, slot 0 = 7.
1941        let v = p.suffixes.var_int.get("sens_state_1").expect("var_int");
1942        assert_eq!(v.as_slice(), &[7]);
1943        // Real var-suffix: dense length 1, slot 0 = 4.5.
1944        let r = p
1945            .suffixes
1946            .var_real
1947            .get("sens_state_value_1")
1948            .expect("var_real");
1949        assert_eq!(r.len(), 1);
1950        assert!((r[0] - 4.5).abs() < 1e-12);
1951        // Other suffix slots stay empty.
1952        assert!(p.suffixes.con_int.is_empty());
1953        assert!(p.suffixes.con_real.is_empty());
1954    }
1955
    /// Two-variable + two-constraint problem with a constraint-level
    /// integer suffix (kind=1). Sparse entries scatter to dense length 2.
    ///
    /// Segment layout after the header:
    ///   `C0`, `C1` = const 0 and `O0` = const 0 (no nonlinear parts)
    ///   `r`  two kind-4 rows with value 0.0
    ///   `b`  two kind-3 lines, one per variable
    ///   `J0` row: x0 coef 1, x1 coef 1;  `J1` row: x0 coef 1, x1 coef -1
    ///   `S1 2 sens_init_constr` with entries `0 1` and `1 2`
    ///
    /// NOTE(review): header line 8 declares 2 Jacobian nonzeros while
    /// the two `J` rows list 2 entries each, and header line 10 has six
    /// fields where the other fixtures have five. Presumably the reader
    /// does not cross-check these counts — confirm.
    const WITH_CON_SUFFIX: &str = "g3 0 1 0
2 2 1 0 0
0 0
0 0
0 2 0
0 0 0 1
0 0 0 0 0
2 0
0 0
0 0 0 0 0 0
C0
n0
C1
n0
O0 0
n0
r
4 0.0
4 0.0
b
3
3
k1
0
J0 2
0 1
1 1
J1 2
0 1
1 -1
S1 2 sens_init_constr
0 1
1 2
";
1992
1993    #[test]
1994    fn parses_con_int_suffix() {
1995        let p = parse_nl_text(WITH_CON_SUFFIX).expect("parse");
1996        let s = p.suffixes.con_int.get("sens_init_constr").expect("con_int");
1997        // Sparse {0:1, 1:2} → dense [1, 2] at length m=2.
1998        assert_eq!(s.as_slice(), &[1, 2]);
1999    }
2000
2001    #[test]
2002    fn rejects_suffix_with_out_of_range_index() {
2003        let bad = WITH_CON_SUFFIX.replace("1 2\n", "5 2\n"); // m=2, idx=5 invalid
2004        let err = parse_nl_text(&bad).expect_err("must reject");
2005        assert!(
2006            err.contains("out of range"),
2007            "expected out-of-range error, got: {err}"
2008        );
2009    }
2010
2011    #[test]
2012    fn tnlp_round_trip_solves() {
2013        let p = parse_nl_text(SIMPLE).expect("parse");
2014        let mut tnlp = NlTnlp::new(p);
2015        let info = tnlp.get_nlp_info().unwrap();
2016        assert_eq!(info.n, 2);
2017        assert_eq!(info.m, 0);
2018        let f0 = tnlp.eval_f(&[0.0, 0.0], true).unwrap();
2019        assert!((f0 - 5.0).abs() < 1e-12);
2020        let mut g = [0.0_f64; 2];
2021        tnlp.eval_grad_f(&[0.0, 0.0], true, &mut g);
2022        // d/dx0 at x=0: 2*(0-1) = -2; d/dx1: 2*(0-2) = -4
2023        assert!((g[0] - (-2.0)).abs() < 1e-12);
2024        assert!((g[1] - (-4.0)).abs() < 1e-12);
2025    }
2026}