Skip to main content

pounce_cli/
nl_reader.rs

1//! Minimal AMPL `.nl` ASCII-format reader.
2//!
3//! Implements the `g`-header text dialect for problems whose constraint
4//! and objective expressions are restricted to a polynomial-friendly
5//! subset of opcodes. This is **not** a full `.nl` reader — it is the
6//! smallest piece that lets `pounce --nl-file foo.nl` solve a real
7//! AMPL-emitted unconstrained problem.
8//!
9//! Supported:
10//! * Text header (`g…`).
11//! * Constraint and objective expression segments using opcodes
12//!   `o0` (add), `o1` (sub), `o2` (mul), `o3` (div), `o5` (pow),
13//!   `o16` (unary minus), `o39` (sqrt), `o42` (log10), `o43` (log),
14//!   `o44` (exp), `o15` (abs), `o41` (sin), `o46` (cos), plus
15//!   `n<num>` constants and `v<idx>` variables.
16//! * Linear-Jacobian (`J`) and linear-objective (`G`) segments.
17//! * Variable bounds (`b`) and constraint bounds (`r`).
//! * Optional initial primal (`x`) and initial dual (`d`) segments;
//!   the parsed values are stored in `NlProblem::x0` / `lambda0`.
20//! * Multiple objectives (we use only the first; per AMPL convention).
21//!
22//! Not supported (will return an error explaining what's missing):
23//! * Network / piecewise-linear constructs.
24//! * Complementarity rows.
25//! * Binary-format `.nl` files (`b…` header).
26//!
27//! References:
28//! * <https://ampl.com/REFS/hooking2.pdf> — "Hooking Your Solver to
29//!   AMPL" (David M. Gay), the canonical `.nl` spec.
30//! * `ref/Ipopt/test/mytoy.nl` — annotated example used for the unit
31//!   tests in this module.
32
33use crate::nl_tape::Tape;
34use pounce_common::types::{Index, Number};
35use pounce_nlp::tnlp::{
36    BoundsInfo, IndexStyle, IpoptCq, IpoptData, Linearity, NlpInfo, Solution, SparsityRequest,
37    StartingPoint, TNLP,
38};
39use std::cell::RefCell;
40use std::collections::{BTreeMap, BTreeSet, HashMap};
41use std::path::Path;
42use std::rc::Rc;
43
44#[derive(Debug, Clone)]
45pub enum Expr {
46    /// Numeric constant.
47    Const(Number),
48    /// Variable reference (0-based index into `x`).
49    Var(usize),
50    /// Binary op: `args = [lhs, rhs]`.
51    Binary(BinOp, Box<Expr>, Box<Expr>),
52    /// Unary op.
53    Unary(UnaryOp, Box<Expr>),
54    /// n-ary sum (opcode `o54` — variadic; we may emit it from `o0`
55    /// folding optimization, but the parser treats `o0` as binary).
56    Sum(Vec<Expr>),
57    /// Reference to a common subexpression (`.nl` `V` segment). The
58    /// payload is a shared body; many references to the same CSE share
59    /// one `Rc`, so the parsed problem is a DAG. Walking through `Cse`
60    /// is mathematically equivalent to inlining the body at each
61    /// occurrence (every reference is an independent occurrence in the
62    /// chain rule), so eval/grad/collect_vars just recurse into the
63    /// inner `Expr`.
64    Cse(Rc<Expr>),
65}
66
/// Two-operand operators the reader understands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `lhs + rhs` (opcode `o0`).
    Add,
    /// `lhs - rhs` (opcode `o1`).
    Sub,
    /// `lhs * rhs` (opcode `o2`).
    Mul,
    /// `lhs / rhs` (opcode `o3`).
    Div,
    /// `lhs` raised to the power `rhs` (opcode `o5`).
    Pow,
}
75
/// One-operand operators the reader understands.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Arithmetic negation (opcode `o16`).
    Neg,
    /// Square root (opcode `o39`).
    Sqrt,
    /// Natural logarithm (opcode `o43`).
    Log,
    /// Exponential (opcode `o44`).
    Exp,
    /// Absolute value (opcode `o15`).
    Abs,
    /// Sine (opcode `o41`).
    Sin,
    /// Cosine (opcode `o46`).
    Cos,
    /// Base-10 logarithm (opcode `o42`).
    Log10,
}
87
88/// Parsed `.nl` problem in the form needed by `NlTnlp`.
89#[derive(Debug, Clone)]
90pub struct NlProblem {
91    pub n: usize,
92    pub m: usize,
93    pub num_obj: usize,
94    pub minimize: bool,
95    pub obj_nonlinear: Expr,
96    pub obj_linear: Vec<(usize, Number)>,
97    pub obj_constant: Number,
98    /// Per-constraint nonlinear part (length m).
99    pub con_nonlinear: Vec<Expr>,
100    /// Per-constraint linear part (length m), each a list of (var, coef).
101    pub con_linear: Vec<Vec<(usize, Number)>>,
102    pub x_l: Vec<Number>,
103    pub x_u: Vec<Number>,
104    pub g_l: Vec<Number>,
105    pub g_u: Vec<Number>,
106    pub x0: Vec<Number>,
107    pub lambda0: Vec<Number>,
108    /// AMPL suffix dictionaries. Variable / constraint / objective
109    /// suffixes are stored as dense vectors (length n / m / num_obj)
110    /// with the sparse `.nl` `S`-segment entries scattered in, default
111    /// zero. The integer / real split matches the `S`-segment header's
112    /// kind bit (`0x4` ⇒ real, else integer). See
113    /// <https://ampl.com/REFS/hooking2.pdf> §6 and the upstream `.nl`
114    /// reader in `ref/Ipopt/src/Apps/AmplSolver/AmplTNLP.cpp`.
115    pub suffixes: NlSuffixes,
116}
117
118/// Suffix data parsed out of `S`-segments. Sparse entries are scattered
119/// into dense vectors at problem load time so callers can index by
120/// variable / constraint number directly. Empty maps when the `.nl`
121/// file declared no suffixes.
122#[derive(Debug, Clone, Default)]
123pub struct NlSuffixes {
124    /// Variable-level integer suffixes (kind = 0). Each vector has
125    /// length `n_full` (problem variables).
126    pub var_int: BTreeMap<String, Vec<Index>>,
127    /// Constraint-level integer suffixes (kind = 1). Length `m_full`.
128    pub con_int: BTreeMap<String, Vec<Index>>,
129    /// Objective-level integer suffixes (kind = 2). Length `num_obj`.
130    pub obj_int: BTreeMap<String, Vec<Index>>,
131    /// Problem-level integer suffixes (kind = 3). Single value per name.
132    pub problem_int: BTreeMap<String, Index>,
133    /// Variable-level real suffixes (kind = 4). Length `n_full`.
134    pub var_real: BTreeMap<String, Vec<Number>>,
135    /// Constraint-level real suffixes (kind = 5). Length `m_full`.
136    pub con_real: BTreeMap<String, Vec<Number>>,
137    /// Objective-level real suffixes (kind = 6). Length `num_obj`.
138    pub obj_real: BTreeMap<String, Vec<Number>>,
139    /// Problem-level real suffixes (kind = 7). Single value per name.
140    pub problem_real: BTreeMap<String, Number>,
141}
142
143/// Parse an `.nl` file from disk.
144pub fn read_nl_file(path: &Path) -> Result<NlProblem, String> {
145    let txt = std::fs::read_to_string(path)
146        .map_err(|e| format!("could not read {}: {}", path.display(), e))?;
147    parse_nl_text(&txt)
148}
149
150/// Parse `.nl` text content. Public so tests can use string literals.
151pub fn parse_nl_text(txt: &str) -> Result<NlProblem, String> {
152    let mut p = Parser::new(txt);
153    p.parse_header()?;
154    let n = p.n;
155    let m = p.m;
156    let num_obj = p.num_obj;
157
158    let mut con_nonlinear: Vec<Expr> = (0..m).map(|_| Expr::Const(0.0)).collect();
159    let mut obj_nonlinear = Expr::Const(0.0);
160    let mut minimize = true;
161    let mut obj_linear: Vec<(usize, Number)> = Vec::new();
162    let mut con_linear: Vec<Vec<(usize, Number)>> = vec![Vec::new(); m];
163    let mut x_l = vec![-1e19; n];
164    let mut x_u = vec![1e19; n];
165    let mut g_l = vec![-1e19; m];
166    let mut g_u = vec![1e19; m];
167    let mut x0 = vec![0.0; n];
168    let mut lambda0 = vec![0.0; m];
169    let mut suffixes = NlSuffixes::default();
170
171    while let Some(line) = p.peek_segment_line() {
172        let tag = line
173            .trim_start()
174            .chars()
175            .next()
176            .ok_or("unexpected blank segment header")?;
177        match tag {
178            'C' => {
179                let (_hdr, rest) = p.eat_segment_header()?;
180                let _ = rest;
181                let idx = parse_segment_index(&_hdr, 'C')?;
182                if idx >= m {
183                    return Err(format!("C{idx} out of range; m={m}"));
184                }
185                con_nonlinear[idx] = p.parse_expr()?;
186            }
187            'O' => {
188                let (hdr, _rest) = p.eat_segment_header()?;
189                let parts: Vec<&str> = hdr.split_whitespace().collect();
190                if parts.len() < 2 {
191                    return Err(format!("malformed O-segment header: {hdr}"));
192                }
193                let idx = parse_segment_index(parts[0], 'O')?;
194                let kind: i32 = parts[1].parse().map_err(|e| format!("O kind: {e}"))?;
195                if idx == 0 {
196                    minimize = kind == 0;
197                    obj_nonlinear = p.parse_expr()?;
198                } else {
199                    // Extra objectives are read but ignored.
200                    let _ = p.parse_expr()?;
201                }
202            }
203            'r' => {
204                p.eat_segment_header()?;
205                for i in 0..m {
206                    let line = p.next_data_line()?;
207                    let (lo, hi) = parse_bound_line(&line)?;
208                    g_l[i] = lo;
209                    g_u[i] = hi;
210                }
211            }
212            'b' => {
213                p.eat_segment_header()?;
214                for i in 0..n {
215                    let line = p.next_data_line()?;
216                    let (lo, hi) = parse_bound_line(&line)?;
217                    x_l[i] = lo;
218                    x_u[i] = hi;
219                }
220            }
221            'k' => {
222                // Column counts in the Jacobian; we don't need them
223                // for evaluation since J segments give explicit lists.
224                p.eat_segment_header()?;
225                let count = if n == 0 { 0 } else { n - 1 };
226                for _ in 0..count {
227                    p.next_data_line()?;
228                }
229            }
230            'J' => {
231                let (hdr, _) = p.eat_segment_header()?;
232                let parts: Vec<&str> = hdr.split_whitespace().collect();
233                if parts.len() < 2 {
234                    return Err(format!("malformed J-segment header: {hdr}"));
235                }
236                let row = parse_segment_index(parts[0], 'J')?;
237                let nz: usize = parts[1].parse().map_err(|e| format!("J nz: {e}"))?;
238                if row >= m {
239                    return Err(format!("J{row} out of range"));
240                }
241                for _ in 0..nz {
242                    let line = p.next_data_line()?;
243                    let (var, coef) = parse_var_coef(&line)?;
244                    con_linear[row].push((var, coef));
245                }
246            }
247            'G' => {
248                let (hdr, _) = p.eat_segment_header()?;
249                let parts: Vec<&str> = hdr.split_whitespace().collect();
250                if parts.len() < 2 {
251                    return Err(format!("malformed G-segment header: {hdr}"));
252                }
253                let idx = parse_segment_index(parts[0], 'G')?;
254                let nz: usize = parts[1].parse().map_err(|e| format!("G nz: {e}"))?;
255                let mut acc = Vec::with_capacity(nz);
256                for _ in 0..nz {
257                    let line = p.next_data_line()?;
258                    let (var, coef) = parse_var_coef(&line)?;
259                    acc.push((var, coef));
260                }
261                if idx == 0 {
262                    obj_linear = acc;
263                }
264            }
265            'x' => {
266                let (hdr, _) = p.eat_segment_header()?;
267                let parts: Vec<&str> = hdr.split_whitespace().collect();
268                let nx: usize = parts
269                    .first()
270                    .and_then(|s| s.trim_start_matches('x').parse().ok())
271                    .ok_or_else(|| format!("malformed x-segment header: {hdr}"))?;
272                for _ in 0..nx {
273                    let line = p.next_data_line()?;
274                    let (idx, val) = parse_var_coef(&line)?;
275                    if idx < n {
276                        x0[idx] = val;
277                    }
278                }
279            }
280            'd' => {
281                let (hdr, _) = p.eat_segment_header()?;
282                let parts: Vec<&str> = hdr.split_whitespace().collect();
283                let nd: usize = parts
284                    .first()
285                    .and_then(|s| s.trim_start_matches('d').parse().ok())
286                    .ok_or_else(|| format!("malformed d-segment header: {hdr}"))?;
287                for _ in 0..nd {
288                    let line = p.next_data_line()?;
289                    let (idx, val) = parse_var_coef(&line)?;
290                    if idx < m {
291                        lambda0[idx] = val;
292                    }
293                }
294            }
295            'V' => p.parse_v_segment()?,
296            'S' => {
297                parse_suffix_segment(&mut p, n, m, num_obj, &mut suffixes)?;
298            }
299            'F' => return Err("F (imported function) segments are not supported".into()),
300            other => return Err(format!("unknown .nl segment tag '{other}'")),
301        }
302    }
303
304    Ok(NlProblem {
305        n,
306        m,
307        num_obj,
308        minimize,
309        obj_nonlinear,
310        obj_linear,
311        obj_constant: 0.0,
312        con_nonlinear,
313        con_linear,
314        x_l,
315        x_u,
316        g_l,
317        g_u,
318        x0,
319        lambda0,
320        suffixes,
321    })
322}
323
324/// Parse a single `S`-segment. Format (Gay 2005, "Hooking Your Solver
325/// to AMPL", §6, and `ref/Ipopt/src/Apps/AmplSolver/AmplTNLP.cpp`):
326///
327/// ```text
328/// S<kind> <nentries> <suffix_name>
329/// <idx> <value>      ... nentries lines
330/// ```
331///
332/// `<kind>` is a 3-bit encoding:
333/// * Bits 0-1 select the suffix target: 0 = variables, 1 = constraints,
334///   2 = objectives, 3 = problem-level.
335/// * Bit 2 (`0x4`) selects the value type: 0 = integer, 1 = real.
336///
337/// Sparse entries scatter into a freshly-allocated dense vector (zero
338/// default), sized for the target dimension. Problem-level suffixes
339/// (kind = 3 / 7) carry a single value.
340fn parse_suffix_segment(
341    p: &mut Parser,
342    n: usize,
343    m: usize,
344    num_obj: usize,
345    out: &mut NlSuffixes,
346) -> Result<(), String> {
347    let (hdr, _) = p.eat_segment_header()?;
348    let parts: Vec<&str> = hdr.split_whitespace().collect();
349    if parts.len() < 3 {
350        return Err(format!(
351            "malformed S-segment header: '{hdr}' (expected `S<kind> <n> <name>`)"
352        ));
353    }
354    let kind_str = parts[0].trim_start_matches('S');
355    let kind: u32 = kind_str
356        .parse()
357        .map_err(|e| format!("S kind '{kind_str}': {e}"))?;
358    let nentries: usize = parts[1].parse().map_err(|e| format!("S nentries: {e}"))?;
359    let name = parts[2].to_string();
360
361    let is_real = (kind & 0x4) != 0;
362    let target = kind & 0x3;
363    let target_dim = match target {
364        0 => n,
365        1 => m,
366        2 => num_obj,
367        3 => 0, // problem-level — entries are single-valued (idx=0)
368        _ => unreachable!("kind & 0x3 is in 0..=3"),
369    };
370
371    // Pre-allocate dense buffers (default zero). Problem-level kinds
372    // (3 / 7) hold a single scalar — we still read the (idx, value)
373    // pairs but only the value field is meaningful.
374    let mut int_buf: Vec<Index> = if !is_real && target != 3 {
375        vec![0; target_dim]
376    } else {
377        Vec::new()
378    };
379    let mut real_buf: Vec<Number> = if is_real && target != 3 {
380        vec![0.0; target_dim]
381    } else {
382        Vec::new()
383    };
384    let mut problem_int: Index = 0;
385    let mut problem_real: Number = 0.0;
386
387    for _ in 0..nentries {
388        let line = p.next_data_line()?;
389        let parts: Vec<&str> = line.split_whitespace().collect();
390        if parts.len() < 2 {
391            return Err(format!(
392                "malformed S-segment entry '{line}' (expected `<idx> <value>`)"
393            ));
394        }
395        let idx: usize = parts[0]
396            .parse()
397            .map_err(|e| format!("S entry idx '{}': {e}", parts[0]))?;
398        if target != 3 && idx >= target_dim {
399            return Err(format!(
400                "S-suffix '{name}' index {idx} out of range for target dim {target_dim}"
401            ));
402        }
403        if is_real {
404            let v: Number = parts[1]
405                .parse()
406                .map_err(|e| format!("S real entry value '{}': {e}", parts[1]))?;
407            if target == 3 {
408                problem_real = v;
409            } else {
410                real_buf[idx] = v;
411            }
412        } else {
413            let v: Index = parts[1]
414                .parse()
415                .map_err(|e| format!("S int entry value '{}': {e}", parts[1]))?;
416            if target == 3 {
417                problem_int = v;
418            } else {
419                int_buf[idx] = v;
420            }
421        }
422    }
423
424    match (target, is_real) {
425        (0, false) => {
426            out.var_int.insert(name, int_buf);
427        }
428        (1, false) => {
429            out.con_int.insert(name, int_buf);
430        }
431        (2, false) => {
432            out.obj_int.insert(name, int_buf);
433        }
434        (3, false) => {
435            out.problem_int.insert(name, problem_int);
436        }
437        (0, true) => {
438            out.var_real.insert(name, real_buf);
439        }
440        (1, true) => {
441            out.con_real.insert(name, real_buf);
442        }
443        (2, true) => {
444            out.obj_real.insert(name, real_buf);
445        }
446        (3, true) => {
447            out.problem_real.insert(name, problem_real);
448        }
449        _ => unreachable!(),
450    }
451    Ok(())
452}
453
/// Extract the numeric index from a segment token such as `C12`.
///
/// Exactly one leading `tag` character is removed before the rest is
/// parsed as an unsigned integer, so a doubled tag (`CC0`) is reported
/// as malformed instead of being read as index 0. A token without the
/// tag is parsed as it stands.
///
/// # Errors
/// Returns a message quoting the token when the remainder is not a
/// valid `usize`.
fn parse_segment_index(s: &str, tag: char) -> Result<usize, String> {
    let digits = s.strip_prefix(tag).unwrap_or(s);
    digits
        .parse()
        .map_err(|e| format!("malformed {tag}-segment index '{s}': {e}"))
}
460
461fn parse_bound_line(line: &str) -> Result<(Number, Number), String> {
462    let parts: Vec<&str> = line.split_whitespace().collect();
463    if parts.is_empty() {
464        return Err("empty bound line".into());
465    }
466    let kind: i32 = parts[0].parse().map_err(|e| format!("bound kind: {e}"))?;
467    let lo;
468    let hi;
469    match kind {
470        0 => {
471            // 0  lo  hi
472            if parts.len() < 3 {
473                return Err(format!("bound kind 0 needs 2 values: '{line}'"));
474            }
475            lo = parts[1].parse().map_err(|e| format!("lo: {e}"))?;
476            hi = parts[2].parse().map_err(|e| format!("hi: {e}"))?;
477        }
478        1 => {
479            // 1  hi
480            if parts.len() < 2 {
481                return Err(format!("bound kind 1 needs 1 value: '{line}'"));
482            }
483            lo = -1e19;
484            hi = parts[1].parse().map_err(|e| format!("hi: {e}"))?;
485        }
486        2 => {
487            // 2  lo
488            if parts.len() < 2 {
489                return Err(format!("bound kind 2 needs 1 value: '{line}'"));
490            }
491            lo = parts[1].parse().map_err(|e| format!("lo: {e}"))?;
492            hi = 1e19;
493        }
494        3 => {
495            // 3  (free)
496            lo = -1e19;
497            hi = 1e19;
498        }
499        4 => {
500            // 4  eq
501            if parts.len() < 2 {
502                return Err(format!("bound kind 4 needs 1 value: '{line}'"));
503            }
504            let v: Number = parts[1].parse().map_err(|e| format!("eq: {e}"))?;
505            lo = v;
506            hi = v;
507        }
508        5 => return Err("complementarity (kind 5) bounds are not supported".into()),
509        other => return Err(format!("unknown bound kind {other}")),
510    }
511    Ok((lo, hi))
512}
513
514fn parse_var_coef(line: &str) -> Result<(usize, Number), String> {
515    let parts: Vec<&str> = line.split_whitespace().collect();
516    if parts.len() < 2 {
517        return Err(format!("malformed var/coef line: '{line}'"));
518    }
519    let v: usize = parts[0].parse().map_err(|e| format!("var idx: {e}"))?;
520    let c: Number = parts[1].parse().map_err(|e| format!("coef: {e}"))?;
521    Ok((v, c))
522}
523
524struct Parser<'a> {
525    lines: Vec<&'a str>,
526    pos: usize,
527    n: usize,
528    m: usize,
529    num_obj: usize,
530    /// Common subexpressions (`V` segments). Index in this vec is the
531    /// CSE-local index, i.e. the global `.nl` index minus `n`.
532    cses: Vec<Rc<Expr>>,
533}
534
535impl<'a> Parser<'a> {
536    fn new(txt: &'a str) -> Self {
537        let lines: Vec<&str> = txt.lines().collect();
538        Self {
539            lines,
540            pos: 0,
541            n: 0,
542            m: 0,
543            num_obj: 0,
544            cses: Vec::new(),
545        }
546    }
547
548    fn next_line(&mut self) -> Option<&'a str> {
549        while self.pos < self.lines.len() {
550            let l = self.lines[self.pos];
551            self.pos += 1;
552            // Strip comment after '#' for header / data lines (but
553            // leave the segment-tag tokens untouched — they are the
554            // first token on the line).
555            let trimmed = strip_comment(l).trim();
556            if !trimmed.is_empty() {
557                return Some(l);
558            }
559        }
560        None
561    }
562
563    fn next_data_line(&mut self) -> Result<String, String> {
564        let raw = self
565            .next_line()
566            .ok_or_else(|| "unexpected end of file in data line".to_string())?;
567        Ok(strip_comment(raw).trim().to_string())
568    }
569
570    fn parse_header(&mut self) -> Result<(), String> {
571        let line0 = self.next_line().ok_or("empty .nl file")?;
572        let trimmed = strip_comment(line0).trim();
573        let first = trimmed.chars().next().ok_or("empty header line")?;
574        if first != 'g' {
575            return Err(format!(
576                "only ASCII (g-) .nl files supported; got header '{trimmed}'"
577            ));
578        }
579
580        // Header line 2: n_vars n_cons n_objs ranges eqns
581        let l2 = self.next_data_line()?;
582        let nums: Vec<&str> = l2.split_whitespace().collect();
583        if nums.len() < 3 {
584            return Err(format!("malformed line 2: '{l2}'"));
585        }
586        self.n = nums[0].parse().map_err(|e| format!("n: {e}"))?;
587        self.m = nums[1].parse().map_err(|e| format!("m: {e}"))?;
588        self.num_obj = nums[2].parse().map_err(|e| format!("num_obj: {e}"))?;
589
590        // Lines 3..10 are metadata we don't need — skip 8 more lines.
591        for _ in 0..8 {
592            self.next_data_line()?;
593        }
594        Ok(())
595    }
596
597    fn peek_segment_line(&mut self) -> Option<&'a str> {
598        let saved = self.pos;
599        let l = self.next_line()?;
600        self.pos = saved;
601        Some(l)
602    }
603
604    /// Eat the next non-blank line as a segment header. Returns the
605    /// whole header (after stripping comments) and the comment text.
606    fn eat_segment_header(&mut self) -> Result<(String, String), String> {
607        let raw = self
608            .next_line()
609            .ok_or_else(|| "expected segment header".to_string())?;
610        let (hdr, comment) = split_comment(raw);
611        Ok((hdr.trim().to_string(), comment.trim().to_string()))
612    }
613
614    fn parse_expr(&mut self) -> Result<Expr, String> {
615        let raw = self
616            .next_line()
617            .ok_or_else(|| "expected expression token".to_string())?;
618        let tok = strip_comment(raw).trim().to_string();
619        if tok.is_empty() {
620            return Err("empty expression token".into());
621        }
622        let first = tok.chars().next().ok_or("empty expression token")?;
623        match first {
624            'n' => {
625                let v: Number = tok[1..]
626                    .trim()
627                    .parse()
628                    .map_err(|e| format!("n value: {e}"))?;
629                Ok(Expr::Const(v))
630            }
631            'v' => {
632                let i: usize = tok[1..]
633                    .trim()
634                    .parse()
635                    .map_err(|e| format!("v index: {e}"))?;
636                Ok(self.var_or_cse(i)?)
637            }
638            'o' => {
639                let code: i32 = tok[1..]
640                    .trim()
641                    .parse()
642                    .map_err(|e| format!("opcode: {e}"))?;
643                self.parse_opcode(code)
644            }
645            'f' | 't' | 'u' => Err(format!("unsupported expression token '{tok}'")),
646            other => Err(format!(
647                "unexpected expression token start '{other}': '{tok}'"
648            )),
649        }
650    }
651
652    fn parse_opcode(&mut self, code: i32) -> Result<Expr, String> {
653        match code {
654            0 => {
655                let a = self.parse_expr()?;
656                let b = self.parse_expr()?;
657                Ok(Expr::Binary(BinOp::Add, Box::new(a), Box::new(b)))
658            }
659            1 => {
660                let a = self.parse_expr()?;
661                let b = self.parse_expr()?;
662                Ok(Expr::Binary(BinOp::Sub, Box::new(a), Box::new(b)))
663            }
664            2 => {
665                let a = self.parse_expr()?;
666                let b = self.parse_expr()?;
667                Ok(Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b)))
668            }
669            3 => {
670                let a = self.parse_expr()?;
671                let b = self.parse_expr()?;
672                Ok(Expr::Binary(BinOp::Div, Box::new(a), Box::new(b)))
673            }
674            5 => {
675                let a = self.parse_expr()?;
676                let b = self.parse_expr()?;
677                Ok(Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b)))
678            }
679            15 => Ok(Expr::Unary(UnaryOp::Abs, Box::new(self.parse_expr()?))),
680            16 => Ok(Expr::Unary(UnaryOp::Neg, Box::new(self.parse_expr()?))),
681            39 => Ok(Expr::Unary(UnaryOp::Sqrt, Box::new(self.parse_expr()?))),
682            41 => Ok(Expr::Unary(UnaryOp::Sin, Box::new(self.parse_expr()?))),
683            42 => Ok(Expr::Unary(UnaryOp::Log10, Box::new(self.parse_expr()?))),
684            43 => Ok(Expr::Unary(UnaryOp::Log, Box::new(self.parse_expr()?))),
685            44 => Ok(Expr::Unary(UnaryOp::Exp, Box::new(self.parse_expr()?))),
686            46 => Ok(Expr::Unary(UnaryOp::Cos, Box::new(self.parse_expr()?))),
687            54 => {
688                // Variadic sum: next data line gives the count.
689                let count_line = self.next_data_line()?;
690                let count: usize = count_line
691                    .split_whitespace()
692                    .next()
693                    .ok_or_else(|| "missing variadic count".to_string())?
694                    .parse()
695                    .map_err(|e| format!("variadic count: {e}"))?;
696                let mut args = Vec::with_capacity(count);
697                for _ in 0..count {
698                    args.push(self.parse_expr()?);
699                }
700                Ok(Expr::Sum(args))
701            }
702            other => Err(format!("unsupported opcode o{other}")),
703        }
704    }
705
706    /// Resolve a `v<i>` token into either a plain variable reference
707    /// (`i < n`) or a shared CSE reference (`i >= n`).
708    fn var_or_cse(&self, i: usize) -> Result<Expr, String> {
709        if i < self.n {
710            Ok(Expr::Var(i))
711        } else {
712            let local = i - self.n;
713            self.cses
714                .get(local)
715                .map(|rc| Expr::Cse(rc.clone()))
716                .ok_or_else(|| {
717                    format!(
718                        "v{i} references CSE {local} but only {} have been defined",
719                        self.cses.len()
720                    )
721                })
722        }
723    }
724
725    /// Parse a `V<k> <nlin> <type>` common-subexpression segment. The
726    /// CSE evaluates to `nonlinear_expr + sum_i coef_i * v_{var_i}`.
727    /// CSEs are numbered starting at `n` and must appear in order.
728    fn parse_v_segment(&mut self) -> Result<(), String> {
729        let (hdr, _) = self.eat_segment_header()?;
730        let parts: Vec<&str> = hdr.split_whitespace().collect();
731        if parts.len() < 2 {
732            return Err(format!("malformed V-segment header: {hdr}"));
733        }
734        let cse_idx = parse_segment_index(parts[0], 'V')?;
735        let nlin: usize = parts[1].parse().map_err(|e| format!("V nlin: {e}"))?;
736        // parts[2] (type) is ignored; values >0 just mark special-purpose CSEs.
737        let mut linear: Vec<(usize, Number)> = Vec::with_capacity(nlin);
738        for _ in 0..nlin {
739            let line = self.next_data_line()?;
740            let (var, coef) = parse_var_coef(&line)?;
741            linear.push((var, coef));
742        }
743        let nonlin = self.parse_expr()?;
744        // Build `nonlin + sum coef_i * v_{var_i}`. Linear terms can
745        // reference earlier CSEs as well as plain variables.
746        let mut combined = nonlin;
747        for (var, coef) in linear {
748            let v_expr = self.var_or_cse(var)?;
749            let term = if coef == 1.0 {
750                v_expr
751            } else {
752                Expr::Binary(BinOp::Mul, Box::new(Expr::Const(coef)), Box::new(v_expr))
753            };
754            combined = Expr::Binary(BinOp::Add, Box::new(combined), Box::new(term));
755        }
756        if cse_idx < self.n {
757            return Err(format!("V{cse_idx} below n={}", self.n));
758        }
759        let local = cse_idx - self.n;
760        if local != self.cses.len() {
761            return Err(format!(
762                "V-segment index V{cse_idx} out of order; expected V{}",
763                self.n + self.cses.len()
764            ));
765        }
766        self.cses.push(Rc::new(combined));
767        Ok(())
768    }
769}
770
/// Return the part of `s` that precedes the first `#`, or all of `s`
/// when it contains none. No trimming is done.
fn strip_comment(s: &str) -> &str {
    // `split` always yields at least one piece, so the fallback is
    // never taken; it only avoids an `unwrap`.
    s.split('#').next().unwrap_or(s)
}
777
/// Split `s` at its first `#` into `(before, after)`; the `#` itself
/// belongs to neither half. Without a `#` the result is `(s, "")`.
fn split_comment(s: &str) -> (&str, &str) {
    s.split_once('#').unwrap_or((s, ""))
}
784
785// --------------------------------------------------------------------
786// Expression evaluation and gradient (tree walkers, kept for tests).
787// The hot paths in `NlTnlp` use the flat `Tape` AD in `nl_tape.rs`
788// instead — see `Tape::gradient_seed` / `Tape::hessian_accumulate`.
789// --------------------------------------------------------------------
790
791/// Forward-mode value evaluation.
792pub fn eval_expr(e: &Expr, x: &[Number]) -> Number {
793    match e {
794        Expr::Const(c) => *c,
795        Expr::Var(i) => x[*i],
796        Expr::Binary(op, a, b) => {
797            let va = eval_expr(a, x);
798            let vb = eval_expr(b, x);
799            match op {
800                BinOp::Add => va + vb,
801                BinOp::Sub => va - vb,
802                BinOp::Mul => va * vb,
803                BinOp::Div => va / vb,
804                BinOp::Pow => va.powf(vb),
805            }
806        }
807        Expr::Unary(op, a) => {
808            let va = eval_expr(a, x);
809            match op {
810                UnaryOp::Neg => -va,
811                UnaryOp::Sqrt => va.sqrt(),
812                UnaryOp::Log => va.ln(),
813                UnaryOp::Log10 => va.log10(),
814                UnaryOp::Exp => va.exp(),
815                UnaryOp::Abs => va.abs(),
816                UnaryOp::Sin => va.sin(),
817                UnaryOp::Cos => va.cos(),
818            }
819        }
820        Expr::Sum(args) => args.iter().map(|a| eval_expr(a, x)).sum(),
821        Expr::Cse(body) => eval_expr(body, x),
822    }
823}
824
825/// Reverse-mode gradient: accumulates `seed * d(expr)/dx_i` into `grad`.
826pub fn grad_expr(e: &Expr, x: &[Number], seed: Number, grad: &mut [Number]) {
827    match e {
828        Expr::Const(_) => {}
829        Expr::Var(i) => grad[*i] += seed,
830        Expr::Binary(op, a, b) => {
831            let va = eval_expr(a, x);
832            let vb = eval_expr(b, x);
833            match op {
834                BinOp::Add => {
835                    grad_expr(a, x, seed, grad);
836                    grad_expr(b, x, seed, grad);
837                }
838                BinOp::Sub => {
839                    grad_expr(a, x, seed, grad);
840                    grad_expr(b, x, -seed, grad);
841                }
842                BinOp::Mul => {
843                    grad_expr(a, x, seed * vb, grad);
844                    grad_expr(b, x, seed * va, grad);
845                }
846                BinOp::Div => {
847                    grad_expr(a, x, seed / vb, grad);
848                    grad_expr(b, x, -seed * va / (vb * vb), grad);
849                }
850                BinOp::Pow => {
851                    // d/da: b * a^(b-1)
852                    let dpa = vb * va.powf(vb - 1.0);
853                    grad_expr(a, x, seed * dpa, grad);
854                    // d/db: a^b * ln(a) (only valid for a>0; simple branch)
855                    if va > 0.0 {
856                        let dpb = va.powf(vb) * va.ln();
857                        grad_expr(b, x, seed * dpb, grad);
858                    }
859                }
860            }
861        }
862        Expr::Unary(op, a) => {
863            let va = eval_expr(a, x);
864            let d = match op {
865                UnaryOp::Neg => -1.0,
866                UnaryOp::Sqrt => 0.5 / va.sqrt(),
867                UnaryOp::Log => 1.0 / va,
868                UnaryOp::Log10 => 1.0 / (va * std::f64::consts::LN_10),
869                UnaryOp::Exp => va.exp(),
870                UnaryOp::Abs => {
871                    if va > 0.0 {
872                        1.0
873                    } else if va < 0.0 {
874                        -1.0
875                    } else {
876                        0.0
877                    }
878                }
879                UnaryOp::Sin => va.cos(),
880                UnaryOp::Cos => -va.sin(),
881            };
882            grad_expr(a, x, seed * d, grad);
883        }
884        Expr::Sum(args) => {
885            for arg in args {
886                grad_expr(arg, x, seed, grad);
887            }
888        }
889        Expr::Cse(body) => grad_expr(body, x, seed, grad),
890    }
891}
892
893/// Walk `e` and insert every `Var(i)` index into `out`.
894pub fn collect_vars(e: &Expr, out: &mut BTreeSet<usize>) {
895    match e {
896        Expr::Const(_) => {}
897        Expr::Var(i) => {
898            out.insert(*i);
899        }
900        Expr::Binary(_, a, b) => {
901            collect_vars(a, out);
902            collect_vars(b, out);
903        }
904        Expr::Unary(_, a) => collect_vars(a, out),
905        Expr::Sum(args) => {
906            for a in args {
907                collect_vars(a, out);
908            }
909        }
910        Expr::Cse(body) => collect_vars(body, out),
911    }
912}
913
914// --------------------------------------------------------------------
915// TNLP wrapper — backed by `Tape` reverse-mode AD for value, gradient,
916// Jacobian, and Hessian. Built once at construction; every solve-time
917// callback is a tape sweep, no expression-tree recursion.
918// --------------------------------------------------------------------
919
/// One scatter instruction used when decoding a colored Hessian sweep
/// in `eval_h`.
///
/// For a color `c`, the directional product `compressed = H · s_c`
/// holds at position `row` the value of exactly one Hessian entry
/// (the coloring ensures no two columns of `c` share a nonzero row),
/// and that value belongs in `values[hess_idx]`.
#[derive(Clone, Debug)]
struct ColorWrite {
    /// Position to read in the color's compressed product vector.
    row: u32,
    /// Position to accumulate into in the sparse Hessian `values` array.
    hess_idx: u32,
}
930
931#[derive(Debug)]
932pub struct NlTnlp {
933    prob: NlProblem,
934    /// Per-summand objective tapes (one `Tape` per top-level
935    /// summand after `split_top_sums`).
936    obj_tapes: Vec<Tape>,
937    /// Per-constraint, per-summand tapes. Length `m`; row `i` holds
938    /// one `Tape` per summand of constraint `i`.
939    con_tapes: Vec<Vec<Tape>>,
940    /// Lower-triangle Hessian sparsity (row >= col), one entry per
941    /// structurally nonzero second derivative in the Lagrangian.
942    h_irow: Vec<i32>,
943    h_jcol: Vec<i32>,
944    /// Per-row sorted variable indices for the constraint Jacobian.
945    jac_cols: Vec<Vec<usize>>,
946    jac_nnz: usize,
947    /// Per-color seed vector: `seeds[c][k] = 1.0` iff variable `k`
948    /// is in color `c`, else `0.0`. Each color is a set of
949    /// variables whose Hessian columns have pairwise-disjoint
950    /// nonzero rows; one directional H·s product per color
951    /// recovers all those columns simultaneously. Dense for
952    /// O(1) lookup in the per-op forward tangent.
953    seeds: Vec<Vec<f64>>,
954    /// Per-color decoding table: for each `(row, hess_idx)` entry,
955    /// scatter `compressed_c[row] -> values[hess_idx]` after the
956    /// per-color directional product.
957    decoding: Vec<Vec<ColorWrite>>,
958    /// For each objective tape: the distinct colors of vars it
959    /// references. Lets us skip tape × color pairs where the tape
960    /// has zero overlap with the color's seed.
961    obj_tape_colors: Vec<Vec<u32>>,
962    /// Same as `obj_tape_colors` but per constraint × summand.
963    con_tape_colors: Vec<Vec<Vec<u32>>>,
964    final_x: Option<Vec<Number>>,
965    final_obj: Number,
966    /// Per-row Jacobian accumulator (length n).
967    scratch_row_grad: Vec<f64>,
968    /// Scratch buffers for `Tape::hessian_directional` (each sized
969    /// to `max_tape_n`).
970    vals_scratch: Vec<f64>,
971    dot_scratch: Vec<f64>,
972    adj_scratch: Vec<f64>,
973    adj_dot_scratch: Vec<f64>,
974    /// Per-color compressed Hessian-vector results, sized to
975    /// `prob.n`. Reused across `eval_h` calls but allocated once.
976    compressed: Vec<Vec<f64>>,
977}
978
979/// Recursively flatten top-level Sum and binary-Add nodes into a list
980/// of independent summands. Non-Sum/Add expressions are returned as a
981/// single-element vector. This lets `NlTnlp` build one small tape per
982/// term so the per-variable Hessian sweep only walks the term that
983/// actually depends on that variable.
984fn split_top_sums(expr: &Expr) -> Vec<Expr> {
985    let mut out = Vec::new();
986    fn go(e: &Expr, out: &mut Vec<Expr>) {
987        match e {
988            Expr::Sum(terms) => {
989                for t in terms {
990                    go(t, out);
991                }
992            }
993            Expr::Binary(BinOp::Add, l, r) => {
994                go(l, out);
995                go(r, out);
996            }
997            _ => out.push(e.clone()),
998        }
999    }
1000    go(expr, &mut out);
1001    if out.is_empty() {
1002        out.push(Expr::Const(0.0));
1003    }
1004    out
1005}
1006
/// Greedy column coloring of a symmetric sparsity pattern stored
/// as lower-triangle pairs.
///
/// Works on the column-intersection graph: columns `c1` and `c2` are
/// adjacent iff some row `r` has `H[r, c1] != 0` and `H[r, c2] != 0`.
/// A distance-1 greedy coloring of that graph satisfies the
/// direct-recovery condition for symmetric Hessians (Coleman-Moré):
/// the columns of one color have pairwise disjoint row supports, so a
/// single H·s product recovers them all unambiguously.
///
/// # Arguments
/// * `n` - number of variables (matrix dimension).
/// * `lower_pairs` - structural nonzeros `(i, j)`; each pair is
///   mirrored to `(j, i)` unless it is a diagonal entry.
///
/// # Returns
/// `(var_color, n_colors)` where `var_color[k]` is the color assigned
/// to variable `k`, or `u32::MAX` for variables that appear in no
/// pair (they contribute nothing and need no color).
///
/// # Panics
/// Panics if a pair references an index `>= n`.
fn greedy_hessian_coloring(n: usize, lower_pairs: &[(usize, usize)]) -> (Vec<u32>, usize) {
    if n == 0 {
        return (Vec::new(), 0);
    }
    // Indices and colors are stored as `u32`, with `u32::MAX` reserved
    // as the "no color" / "never stamped" sentinel.
    debug_assert!(
        u32::try_from(n).is_ok(),
        "greedy_hessian_coloring: n={n} does not fit in u32"
    );

    // `support[k]` lists the nonzero rows of column `k` in the FULL
    // symmetric matrix. By symmetry this is also the list of nonzero
    // columns of row `k`, so a single table serves both lookups.
    let mut support: Vec<Vec<u32>> = vec![Vec::new(); n];
    for &(i, j) in lower_pairs {
        support[j].push(i as u32);
        if i != j {
            support[i].push(j as u32);
        }
    }

    let mut var_color = vec![u32::MAX; n];
    // `stamp[c] == j` means color `c` is taken by a neighbor of the
    // column `j` currently being colored. Stamping with the column id
    // avoids clearing the table between columns.
    let mut stamp = vec![u32::MAX; n + 1];
    let mut n_colors: u32 = 0;

    for j in 0..n {
        // No Hessian entries in this column: leave it uncolored.
        if support[j].is_empty() {
            continue;
        }
        let tag = j as u32;
        // Every column that shares a row with `j` is a neighbor:
        // walk rows of column `j`, then the columns present in each row.
        for &r in &support[j] {
            for &c in &support[r as usize] {
                if c == tag {
                    continue;
                }
                let neighbor_color = var_color[c as usize];
                if neighbor_color != u32::MAX {
                    stamp[neighbor_color as usize] = tag;
                }
            }
        }
        // Smallest color not stamped for this column.
        let chosen = stamp
            .iter()
            .position(|&s| s != tag)
            .unwrap_or(stamp.len()) as u32;
        var_color[j] = chosen;
        n_colors = n_colors.max(chosen + 1);
    }

    (var_color, n_colors as usize)
}
1078
1079impl NlTnlp {
1080    pub fn new(prob: NlProblem) -> Self {
1081        // Flatten objective and each constraint into independent
1082        // summands. Each summand becomes its own `Tape` (CSE bodies
1083        // are deduplicated within a tape via Rc identity in
1084        // `Tape::build`; bodies shared across summands are
1085        // duplicated, which we accept as a simplicity tradeoff).
1086        let obj_summands = split_top_sums(&prob.obj_nonlinear);
1087        let obj_tapes: Vec<Tape> = obj_summands.iter().map(Tape::build).collect();
1088
1089        let mut con_tapes: Vec<Vec<Tape>> = Vec::with_capacity(prob.m);
1090        for k in 0..prob.m {
1091            let summands = split_top_sums(&prob.con_nonlinear[k]);
1092            con_tapes.push(summands.iter().map(Tape::build).collect());
1093        }
1094
1095        // Hessian-of-Lagrangian sparsity: union of each tape's own
1096        // structural Hessian sparsity.
1097        let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
1098        for t in &obj_tapes {
1099            pairs.extend(t.hessian_sparsity());
1100        }
1101        for row in &con_tapes {
1102            for t in row {
1103                pairs.extend(t.hessian_sparsity());
1104            }
1105        }
1106        let mut h_irow = Vec::with_capacity(pairs.len());
1107        let mut h_jcol = Vec::with_capacity(pairs.len());
1108        let mut hess_map = HashMap::with_capacity(pairs.len());
1109        for (k, (hi, lo)) in pairs.iter().enumerate() {
1110            h_irow.push(*hi as i32);
1111            h_jcol.push(*lo as i32);
1112            hess_map.insert((*hi, *lo), k);
1113        }
1114
1115        // Hessian column coloring. The chromatic number of the
1116        // column-intersection graph bounds how many directional
1117        // Hessian-vector products we need per `eval_h` call —
1118        // typically O(stencil) for PDE-mesh problems.
1119        let lower_pairs: Vec<(usize, usize)> = pairs.iter().copied().collect();
1120        let (var_color, n_colors) = greedy_hessian_coloring(prob.n, &lower_pairs);
1121
1122        // Per-color seed vectors (dense for O(1) Var lookup in
1123        // `Tape::hessian_directional`).
1124        let mut seeds: Vec<Vec<f64>> = vec![vec![0.0; prob.n]; n_colors];
1125        for (k, &c) in var_color.iter().enumerate() {
1126            if c != u32::MAX {
1127                seeds[c as usize][k] = 1.0;
1128            }
1129        }
1130
1131        // Per-color decoding table. For each lower-tri pair (i, j)
1132        // with i >= j, the entry belongs to column j's color: after
1133        // computing compressed_{c_j} = (H · s_{c_j}), the value at
1134        // row i is exactly H[i, j] (coloring guarantees no other
1135        // column in c_j has a nonzero at row i).
1136        let mut decoding: Vec<Vec<ColorWrite>> = vec![Vec::new(); n_colors];
1137        for (&(i, j), &idx) in hess_map.iter() {
1138            let c = var_color[j];
1139            debug_assert!(
1140                c != u32::MAX,
1141                "column {j} has Hessian pair {idx} but no color"
1142            );
1143            decoding[c as usize].push(ColorWrite {
1144                row: i as u32,
1145                hess_idx: idx as u32,
1146            });
1147        }
1148
1149        // Per-tape distinct color set: for each tape, the colors
1150        // its variables fall into. `eval_h` loops over only these
1151        // (tape, color) pairs instead of n_tapes × n_colors.
1152        let tape_colors = |t: &Tape| -> Vec<u32> {
1153            let mut s: BTreeSet<u32> = BTreeSet::new();
1154            for v in t.variables() {
1155                let c = var_color[v];
1156                if c != u32::MAX {
1157                    s.insert(c);
1158                }
1159            }
1160            s.into_iter().collect()
1161        };
1162        let obj_tape_colors: Vec<Vec<u32>> = obj_tapes.iter().map(tape_colors).collect();
1163        let con_tape_colors: Vec<Vec<Vec<u32>>> = con_tapes
1164            .iter()
1165            .map(|row| row.iter().map(tape_colors).collect())
1166            .collect();
1167
1168        // Per-row Jacobian sparsity = union of tape vars plus
1169        // linear-segment vars.
1170        let mut jac_cols: Vec<Vec<usize>> = Vec::with_capacity(prob.m);
1171        let mut jac_nnz = 0;
1172        for i in 0..prob.m {
1173            let mut set: BTreeSet<usize> = BTreeSet::new();
1174            for t in &con_tapes[i] {
1175                for v in t.variables() {
1176                    set.insert(v);
1177                }
1178            }
1179            for (v, _) in &prob.con_linear[i] {
1180                set.insert(*v);
1181            }
1182            let cols: Vec<usize> = set.into_iter().collect();
1183            jac_nnz += cols.len();
1184            jac_cols.push(cols);
1185        }
1186
1187        let mut max_tape_n: usize = 0;
1188        for t in &obj_tapes {
1189            max_tape_n = max_tape_n.max(t.ops.len());
1190        }
1191        for row in &con_tapes {
1192            for t in row {
1193                max_tape_n = max_tape_n.max(t.ops.len());
1194            }
1195        }
1196
1197        if std::env::var("POUNCE_DBG_TAPE_STATS").is_ok() {
1198            let n_obj = obj_tapes.len();
1199            let n_con: usize = con_tapes.iter().map(|r| r.len()).sum();
1200            let total = n_obj + n_con;
1201            let mut sum_ops: usize = 0;
1202            for t in &obj_tapes {
1203                sum_ops += t.ops.len();
1204            }
1205            for row in &con_tapes {
1206                for t in row {
1207                    sum_ops += t.ops.len();
1208                }
1209            }
1210            let t = total.max(1);
1211            let nnz_h = h_irow.len();
1212            let avg_decode =
1213                decoding.iter().map(|d| d.len()).sum::<usize>() as f64 / n_colors.max(1) as f64;
1214            eprintln!(
1215                "[tape stats] summands={total} (obj={n_obj} con={n_con}) \
1216                 total_ops={sum_ops} avg_ops={:.1} max_ops={max_tape_n} \
1217                 n_colors={n_colors} avg_decode_per_color={avg_decode:.1} nnz_h={nnz_h}",
1218                sum_ops as f64 / t as f64,
1219            );
1220        }
1221
1222        let compressed: Vec<Vec<f64>> = vec![vec![0.0; prob.n]; n_colors];
1223
1224        Self {
1225            prob,
1226            obj_tapes,
1227            con_tapes,
1228            h_irow,
1229            h_jcol,
1230            jac_cols,
1231            jac_nnz,
1232            seeds,
1233            decoding,
1234            obj_tape_colors,
1235            con_tape_colors,
1236            final_x: None,
1237            final_obj: 0.0,
1238            scratch_row_grad: Vec::new(),
1239            vals_scratch: vec![0.0; max_tape_n],
1240            dot_scratch: vec![0.0; max_tape_n],
1241            adj_scratch: vec![0.0; max_tape_n],
1242            adj_dot_scratch: vec![0.0; max_tape_n],
1243            compressed,
1244        }
1245    }
1246
1247    pub fn final_x(&self) -> Option<&[Number]> {
1248        self.final_x.as_deref()
1249    }
1250
1251    pub fn final_obj(&self) -> Number {
1252        self.final_obj
1253    }
1254}
1255
1256impl TNLP for NlTnlp {
1257    fn get_nlp_info(&mut self) -> Option<NlpInfo> {
1258        Some(NlpInfo {
1259            n: self.prob.n as Index,
1260            m: self.prob.m as Index,
1261            nnz_jac_g: self.jac_nnz as Index,
1262            nnz_h_lag: self.h_irow.len() as Index,
1263            index_style: IndexStyle::C,
1264        })
1265    }
1266
1267    fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool {
1268        b.x_l.copy_from_slice(&self.prob.x_l);
1269        b.x_u.copy_from_slice(&self.prob.x_u);
1270        if !self.prob.g_l.is_empty() {
1271            b.g_l.copy_from_slice(&self.prob.g_l);
1272            b.g_u.copy_from_slice(&self.prob.g_u);
1273        }
1274        true
1275    }
1276
1277    fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool {
1278        sp.x.copy_from_slice(&self.prob.x0);
1279        true
1280    }
1281
1282    fn eval_f(&mut self, x: &[Number], _new_x: bool) -> Option<Number> {
1283        let mut nl: Number = 0.0;
1284        for t in &self.obj_tapes {
1285            nl += t.eval(x);
1286        }
1287        let lin: Number = self.prob.obj_linear.iter().map(|(i, c)| c * x[*i]).sum();
1288        let v = self.prob.obj_constant + nl + lin;
1289        let signed = if self.prob.minimize { v } else { -v };
1290        Some(signed)
1291    }
1292
1293    fn eval_grad_f(&mut self, x: &[Number], _new_x: bool, grad: &mut [Number]) -> bool {
1294        grad.fill(0.0);
1295        for t in &self.obj_tapes {
1296            t.gradient_seed(x, 1.0, grad);
1297        }
1298        for (i, c) in &self.prob.obj_linear {
1299            grad[*i] += c;
1300        }
1301        if !self.prob.minimize {
1302            for g in grad.iter_mut() {
1303                *g = -*g;
1304            }
1305        }
1306        true
1307    }
1308
1309    fn eval_g(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
1310        for i in 0..self.prob.m {
1311            let mut nl: Number = 0.0;
1312            for t in &self.con_tapes[i] {
1313                nl += t.eval(x);
1314            }
1315            let lin: Number = self.prob.con_linear[i].iter().map(|(j, c)| c * x[*j]).sum();
1316            g[i] = nl + lin;
1317        }
1318        true
1319    }
1320
1321    fn eval_jac_g(
1322        &mut self,
1323        x: Option<&[Number]>,
1324        _new_x: bool,
1325        mode: SparsityRequest<'_>,
1326    ) -> bool {
1327        match mode {
1328            SparsityRequest::Structure { irow, jcol } => {
1329                let mut k = 0;
1330                for i in 0..self.prob.m {
1331                    for &j in &self.jac_cols[i] {
1332                        irow[k] = i as Index;
1333                        jcol[k] = j as Index;
1334                        k += 1;
1335                    }
1336                }
1337                true
1338            }
1339            SparsityRequest::Values { values } => {
1340                let n = self.prob.n;
1341                let xs = x.unwrap_or(&self.prob.x0);
1342                if self.scratch_row_grad.len() < n {
1343                    self.scratch_row_grad.resize(n, 0.0);
1344                }
1345                let mut k = 0;
1346                for i in 0..self.prob.m {
1347                    for &j in &self.jac_cols[i] {
1348                        self.scratch_row_grad[j] = 0.0;
1349                    }
1350                    for t in &self.con_tapes[i] {
1351                        t.gradient_seed(xs, 1.0, &mut self.scratch_row_grad);
1352                    }
1353                    for &(v, c) in &self.prob.con_linear[i] {
1354                        self.scratch_row_grad[v] += c;
1355                    }
1356                    for &j in &self.jac_cols[i] {
1357                        values[k] = self.scratch_row_grad[j];
1358                        k += 1;
1359                    }
1360                }
1361                true
1362            }
1363        }
1364    }
1365
1366    fn eval_h(
1367        &mut self,
1368        x: Option<&[Number]>,
1369        _new_x: bool,
1370        obj_factor: Number,
1371        lambda: Option<&[Number]>,
1372        _new_lambda: bool,
1373        mode: SparsityRequest<'_>,
1374    ) -> bool {
1375        match mode {
1376            SparsityRequest::Structure { irow, jcol } => {
1377                irow.copy_from_slice(&self.h_irow);
1378                jcol.copy_from_slice(&self.h_jcol);
1379                true
1380            }
1381            SparsityRequest::Values { values } => {
1382                let x = x.unwrap_or(&self.prob.x0);
1383                values.fill(0.0);
1384
1385                let obj_seed = if self.prob.minimize {
1386                    obj_factor
1387                } else {
1388                    -obj_factor
1389                };
1390                // Coloring path. For each (tape, weight) we do
1391                // one forward pass into `vals_scratch`, then one
1392                // forward-tangent+reverse-over-tangent per color
1393                // touched by that tape. Each pass accumulates a
1394                // weighted contribution of (H_tape · seed_c) into
1395                // `compressed[c]`. After all tapes done, we
1396                // decode each color's compressed vector into the
1397                // sparse `values` array.
1398                for buf in &mut self.compressed {
1399                    buf.fill(0.0);
1400                }
1401
1402                if obj_seed != 0.0 {
1403                    for (ti, t) in self.obj_tapes.iter().enumerate() {
1404                        if t.ops.is_empty() {
1405                            continue;
1406                        }
1407                        t.forward_into(x, &mut self.vals_scratch);
1408                        for &c in &self.obj_tape_colors[ti] {
1409                            t.hessian_directional(
1410                                &self.vals_scratch,
1411                                &self.seeds[c as usize],
1412                                obj_seed,
1413                                &mut self.compressed[c as usize],
1414                                &mut self.dot_scratch,
1415                                &mut self.adj_scratch,
1416                                &mut self.adj_dot_scratch,
1417                            );
1418                        }
1419                    }
1420                }
1421
1422                if let Some(lam) = lambda {
1423                    for k in 0..self.prob.m {
1424                        let w = lam[k];
1425                        if w == 0.0 {
1426                            continue;
1427                        }
1428                        for (ti, t) in self.con_tapes[k].iter().enumerate() {
1429                            if t.ops.is_empty() {
1430                                continue;
1431                            }
1432                            t.forward_into(x, &mut self.vals_scratch);
1433                            for &c in &self.con_tape_colors[k][ti] {
1434                                t.hessian_directional(
1435                                    &self.vals_scratch,
1436                                    &self.seeds[c as usize],
1437                                    w,
1438                                    &mut self.compressed[c as usize],
1439                                    &mut self.dot_scratch,
1440                                    &mut self.adj_scratch,
1441                                    &mut self.adj_dot_scratch,
1442                                );
1443                            }
1444                        }
1445                    }
1446                }
1447
1448                // Decode each color's compressed Hessian-vector
1449                // result into the lower-triangle `values` array.
1450                for (c, table) in self.decoding.iter().enumerate() {
1451                    let comp = &self.compressed[c];
1452                    for w in table {
1453                        values[w.hess_idx as usize] += comp[w.row as usize];
1454                    }
1455                }
1456                true
1457            }
1458        }
1459    }
1460
1461    fn finalize_solution(&mut self, sol: Solution<'_>, _d: &IpoptData, _q: &IpoptCq) {
1462        self.final_x = Some(sol.x.to_vec());
1463        self.final_obj = sol.obj_value;
1464    }
1465
1466    fn get_constraints_linearity(&mut self, types: &mut [Linearity]) -> bool {
1467        // A row is linear iff its nonlinear-part expression is the
1468        // identity zero left over from initial allocation (post-parse
1469        // identity for "no `C<idx>` segment touched this row").
1470        for (i, t) in types.iter_mut().enumerate() {
1471            *t = match &self.prob.con_nonlinear[i] {
1472                Expr::Const(c) if *c == 0.0 => Linearity::Linear,
1473                _ => Linearity::NonLinear,
1474            };
1475        }
1476        true
1477    }
1478}
1479
1480/// Convenience: read an `.nl` file and build a TNLP-compatible Rc.
1481pub fn load_nl_as_tnlp(path: &Path) -> Result<Rc<RefCell<dyn TNLP>>, String> {
1482    let prob = read_nl_file(path)?;
1483    Ok(Rc::new(RefCell::new(NlTnlp::new(prob))))
1484}
1485
1486#[cfg(test)]
1487mod tests {
1488    use super::*;
1489
1490    /// `min (x0 - 1)^2 + (x1 - 2)^2` written in `.nl` ASCII form.
1491    /// Header values:
1492    ///   line 2: n=2 m=0 num_obj=1 0 0
1493    ///   line 3: 0 1   (1 nonlinear objective)
1494    ///   line 4: 0 0
1495    ///   line 5: 0 2 0 (nonlinear vars in obj=2)
1496    ///   line 6: 0 0 0 1
1497    ///   line 7: 0 0 0 0 0
1498    ///   line 8: 0 0   (no Jacobian nonzeros, no linear obj)
1499    ///   line 9: 0 0
1500    ///   line 10: 0 0 0 0 0
1501    /// Then `O0 0` followed by an expression tree:
1502    /// `(x0 - 1)^2 + (x1 - 2)^2` =
1503    ///   o0
1504    ///     o5 (o1 v0 n1) n2
1505    ///     o5 (o1 v1 n2) n2
1506    /// Then `b` segment: free for both.
1507    const SIMPLE: &str = "g3 0 1 0
15082 0 1 0 0
15090 1
15100 0
15110 2 0
15120 0 0 1
15130 0 0 0 0
15140 0
15150 0
15160 0 0 0 0
1517O0 0
1518o0
1519o5
1520o1
1521v0
1522n1
1523n2
1524o5
1525o1
1526v1
1527n2
1528n2
1529b
15303
15313
1532";
1533
1534    #[test]
1535    fn parses_simple_quadratic() {
1536        let p = parse_nl_text(SIMPLE).expect("parse");
1537        assert_eq!(p.n, 2);
1538        assert_eq!(p.m, 0);
1539        assert_eq!(p.num_obj, 1);
1540        // f(0,0) = 1 + 4 = 5
1541        let f = eval_expr(&p.obj_nonlinear, &[0.0, 0.0]);
1542        assert!((f - 5.0).abs() < 1e-12);
1543        // f(1,2) = 0
1544        let f = eval_expr(&p.obj_nonlinear, &[1.0, 2.0]);
1545        assert!(f.abs() < 1e-12);
1546    }
1547
1548    #[test]
1549    fn gradient_matches_analytic() {
1550        let p = parse_nl_text(SIMPLE).expect("parse");
1551        let x = [0.5, 1.0];
1552        let mut g = [0.0_f64; 2];
1553        grad_expr(&p.obj_nonlinear, &x, 1.0, &mut g);
1554        // d/dx0 = 2*(x0-1) = -1.0
1555        // d/dx1 = 2*(x1-2) = -2.0
1556        assert!((g[0] - (-1.0)).abs() < 1e-12);
1557        assert!((g[1] - (-2.0)).abs() < 1e-12);
1558    }
1559
1560    /// `min x0^2 + x1^2  s.t.  x0 + x1 = 1`.
1561    /// One equality constraint with a purely linear Jacobian — exercises
1562    /// the constrained path (`eval_g`, `eval_jac_g`, `r`-segment bound
1563    /// kind 4).
1564    ///
1565    /// Header layout:
1566    ///   line 1: g3 0 1 0
1567    ///   line 2: 2 1 1 0 0   (n=2, m=1, num_obj=1)
1568    ///   line 3: 0 1         (1 nonlinear obj, 0 nonlinear cons)
1569    ///   line 4: 0 0
1570    ///   line 5: 0 2 0       (nonlinear vars in obj=2)
1571    ///   line 6: 0 0 0 1
1572    ///   line 7: 0 0 0 0 0
1573    ///   line 8: 2 0         (Jacobian nnz=2, no linear obj)
1574    ///   line 9: 0 0
1575    ///   line 10: 0 0 0 0 0
1576    /// Then C0 = const 0 (no nonlinear part), O0 = x0^2 + x1^2,
1577    /// r-segment kind 4 (eq) value 1, b-segment free, k-segment, J-row.
1578    const EQ_LIN: &str = "g3 0 1 0
15792 1 1 0 0
15800 1
15810 0
15820 2 0
15830 0 0 1
15840 0 0 0 0
15852 0
15860 0
15870 0 0 0 0
1588C0
1589n0
1590O0 0
1591o0
1592o5
1593v0
1594n2
1595o5
1596v1
1597n2
1598r
15994 1
1600b
16013
16023
1603k1
16042
1605J0 2
16060 1
16071 1
1608";
1609
1610    #[test]
1611    fn parses_constrained_problem() {
1612        let p = parse_nl_text(EQ_LIN).expect("parse");
1613        assert_eq!(p.n, 2);
1614        assert_eq!(p.m, 1);
1615        // r-segment kind 4 (equality with rhs=1).
1616        assert!((p.g_l[0] - 1.0).abs() < 1e-12);
1617        assert!((p.g_u[0] - 1.0).abs() < 1e-12);
1618        // J-row 0: x0 (coef 1), x1 (coef 1).
1619        assert_eq!(p.con_linear[0], vec![(0, 1.0), (1, 1.0)]);
1620    }
1621
1622    #[test]
1623    fn constrained_tnlp_eval_g_jac_h() {
1624        let p = parse_nl_text(EQ_LIN).expect("parse");
1625        let mut t = NlTnlp::new(p);
1626        let info = t.get_nlp_info().unwrap();
1627        assert_eq!(info.m, 1);
1628        assert_eq!(info.nnz_jac_g, 2);
1629
1630        // g(0.3, 0.4) = 0.3 + 0.4 = 0.7
1631        let mut g = [0.0_f64; 1];
1632        assert!(t.eval_g(&[0.3, 0.4], true, &mut g));
1633        assert!((g[0] - 0.7).abs() < 1e-12);
1634
1635        // Jacobian structure: row 0, cols [0, 1].
1636        let mut irow = [0_i32; 2];
1637        let mut jcol = [0_i32; 2];
1638        assert!(t.eval_jac_g(
1639            None,
1640            true,
1641            SparsityRequest::Structure {
1642                irow: &mut irow,
1643                jcol: &mut jcol
1644            }
1645        ));
1646        assert_eq!(irow, [0, 0]);
1647        assert_eq!(jcol, [0, 1]);
1648
1649        // Jacobian values: both 1.0.
1650        let mut vals = [0.0_f64; 2];
1651        assert!(t.eval_jac_g(
1652            Some(&[0.3, 0.4]),
1653            true,
1654            SparsityRequest::Values { values: &mut vals }
1655        ));
1656        assert!((vals[0] - 1.0).abs() < 1e-12);
1657        assert!((vals[1] - 1.0).abs() < 1e-12);
1658
1659        // Hessian of L = (x0^2 + x1^2) + λ*(x0 + x1 - 1) is diag(2,2);
1660        // λ contributes nothing because the constraint is linear, and
1661        // x0^2 + x1^2 is separable so there's no (1,0) entry in the
1662        // structural sparsity. nnz_h_lag = 2: (0,0) and (1,1).
1663        assert_eq!(info.nnz_h_lag, 2);
1664        let mut hirow = [0_i32; 2];
1665        let mut hjcol = [0_i32; 2];
1666        assert!(t.eval_h(
1667            None,
1668            true,
1669            1.0,
1670            None,
1671            true,
1672            SparsityRequest::Structure {
1673                irow: &mut hirow,
1674                jcol: &mut hjcol
1675            }
1676        ));
1677        assert_eq!(hirow, [0, 1]);
1678        assert_eq!(hjcol, [0, 1]);
1679        let mut hvals = [0.0_f64; 2];
1680        assert!(t.eval_h(
1681            Some(&[0.3, 0.4]),
1682            true,
1683            1.0,
1684            Some(&[0.5]),
1685            true,
1686            SparsityRequest::Values { values: &mut hvals }
1687        ));
1688        assert!((hvals[0] - 2.0).abs() < 1e-12);
1689        assert!((hvals[1] - 2.0).abs() < 1e-12);
1690    }
1691
    /// `min (x0 + x1)^2 + (x0 + x1)` with the shared sum `(x0 + x1)`
    /// encoded as common-subexpression `V2`. Header line 10 declares
    /// one obj-only CSE; expression tree references `v2` twice.
    ///
    /// Reading the body segment by segment:
    /// * `V2 0 0` followed by `o0 v0 v1` defines `v2 = x0 + x1`. The
    ///   problem has two variables (indices 0 and 1), so index 2 is the
    ///   first slot past them.
    /// * `O0 0` followed by `o0 (o5 v2 n2) v2` is the objective
    ///   `v2^2 + v2` (`o0` = add, `o5` = pow).
    /// * `b` followed by one kind-`3` line per variable.
    ///
    /// NOTE(review): the `1` on header line 10 (`0 1 0 0 0`) sits in the
    /// second field. The `.nl` spec lists that line's fields as
    /// `b,c,o,c1,o1`, so confirm this slot really is the "obj-only"
    /// count the sentence above describes.
    const CSE_OBJ: &str = "g3 0 1 0
2 0 1 0 0
0 1
0 0
0 2 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 1 0 0 0
V2 0 0
o0
v0
v1
O0 0
o0
o5
v2
n2
v2
b
3
3
";
1719
1720    #[test]
1721    fn parses_v_segment_cse() {
1722        let p = parse_nl_text(CSE_OBJ).expect("parse");
1723        assert_eq!(p.n, 2);
1724        // f(1,2) = 9 + 3 = 12
1725        let f = eval_expr(&p.obj_nonlinear, &[1.0, 2.0]);
1726        assert!((f - 12.0).abs() < 1e-12, "got {f}");
1727        // d/dx0 = 2*(x0+x1) + 1 = 7 at (1,2). Same for x1.
1728        let mut g = [0.0_f64; 2];
1729        grad_expr(&p.obj_nonlinear, &[1.0, 2.0], 1.0, &mut g);
1730        assert!((g[0] - 7.0).abs() < 1e-12, "g[0]={}", g[0]);
1731        assert!((g[1] - 7.0).abs() < 1e-12, "g[1]={}", g[1]);
1732        // collect_vars reaches into the CSE body and finds {0, 1}.
1733        let mut vs = BTreeSet::new();
1734        collect_vars(&p.obj_nonlinear, &mut vs);
1735        assert_eq!(vs.into_iter().collect::<Vec<_>>(), vec![0, 1]);
1736    }
1737
    /// `min (x0 - 1)^2` with two suffix segments attached, both on the
    /// single variable: an integer var-suffix `sens_state_1` (`S0`,
    /// target=0, kind=0) and a real var-suffix `sens_state_value_1`
    /// (`S4`, target=0, kind=4). There is deliberately no constraint
    /// suffix, so the constraint-suffix maps are expected to stay
    /// empty. The .nl format is `S<kind> <nentries> <name>` then
    /// `<idx> <value>` lines.
    const WITH_SUFFIXES: &str = "g3 0 1 0
1 0 1 0 0
0 1
0 0
0 1 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 0 0 0 0
O0 0
o5
o1
v0
n1
n2
b
3
S0 1 sens_state_1
0 7
S4 1 sens_state_value_1
0 4.5
";
1766
1767    #[test]
1768    fn parses_var_int_and_var_real_suffixes() {
1769        let p = parse_nl_text(WITH_SUFFIXES).expect("parse");
1770        // Integer var-suffix: dense length 1, slot 0 = 7.
1771        let v = p.suffixes.var_int.get("sens_state_1").expect("var_int");
1772        assert_eq!(v.as_slice(), &[7]);
1773        // Real var-suffix: dense length 1, slot 0 = 4.5.
1774        let r = p
1775            .suffixes
1776            .var_real
1777            .get("sens_state_value_1")
1778            .expect("var_real");
1779        assert_eq!(r.len(), 1);
1780        assert!((r[0] - 4.5).abs() < 1e-12);
1781        // Other suffix slots stay empty.
1782        assert!(p.suffixes.con_int.is_empty());
1783        assert!(p.suffixes.con_real.is_empty());
1784    }
1785
    /// Two-variable + two-constraint problem with a constraint-level
    /// integer suffix (kind=1). Sparse entries scatter to dense length 2.
    ///
    /// Layout of the body:
    /// * `C0`, `C1` and `O0 0` all carry the nonlinear part `n0`, so the
    ///   constraints are defined entirely by their linear `J` rows.
    /// * `r` holds one kind-`4` line (`4 0.0`) per constraint.
    /// * `J0 2` lists `x0` and `x1` with coefficients 1 and 1; `J1 2`
    ///   lists them with coefficients 1 and -1.
    /// * `S1 2 sens_init_constr` carries the sparse entries `0 1` and
    ///   `1 2`.
    ///
    /// Caution when deriving mutated fixtures by text replacement: the
    /// header `J1 2` ends with the same characters as the suffix entry
    /// `1 2`, so a pattern such as `"1 2\n"` matches both lines.
    const WITH_CON_SUFFIX: &str = "g3 0 1 0
2 2 1 0 0
0 0
0 0
0 2 0
0 0 0 1
0 0 0 0 0
2 0
0 0
0 0 0 0 0 0
C0
n0
C1
n0
O0 0
n0
r
4 0.0
4 0.0
b
3
3
k1
0
J0 2
0 1
1 1
J1 2
0 1
1 -1
S1 2 sens_init_constr
0 1
1 2
";
1822
1823    #[test]
1824    fn parses_con_int_suffix() {
1825        let p = parse_nl_text(WITH_CON_SUFFIX).expect("parse");
1826        let s = p.suffixes.con_int.get("sens_init_constr").expect("con_int");
1827        // Sparse {0:1, 1:2} → dense [1, 2] at length m=2.
1828        assert_eq!(s.as_slice(), &[1, 2]);
1829    }
1830
1831    #[test]
1832    fn rejects_suffix_with_out_of_range_index() {
1833        let bad = WITH_CON_SUFFIX.replace("1 2\n", "5 2\n"); // m=2, idx=5 invalid
1834        let err = parse_nl_text(&bad).expect_err("must reject");
1835        assert!(
1836            err.contains("out of range"),
1837            "expected out-of-range error, got: {err}"
1838        );
1839    }
1840
1841    #[test]
1842    fn tnlp_round_trip_solves() {
1843        let p = parse_nl_text(SIMPLE).expect("parse");
1844        let mut tnlp = NlTnlp::new(p);
1845        let info = tnlp.get_nlp_info().unwrap();
1846        assert_eq!(info.n, 2);
1847        assert_eq!(info.m, 0);
1848        let f0 = tnlp.eval_f(&[0.0, 0.0], true).unwrap();
1849        assert!((f0 - 5.0).abs() < 1e-12);
1850        let mut g = [0.0_f64; 2];
1851        tnlp.eval_grad_f(&[0.0, 0.0], true, &mut g);
1852        // d/dx0 at x=0: 2*(0-1) = -2; d/dx1: 2*(0-2) = -4
1853        assert!((g[0] - (-2.0)).abs() < 1e-12);
1854        assert!((g[1] - (-4.0)).abs() < 1e-12);
1855    }
1856}