1use crate::nl_tape::Tape;
36use pounce_common::types::{Index, Number};
37use pounce_nlp::tnlp::{
38 BoundsInfo, IndexStyle, IpoptCq, IpoptData, Linearity, MetaData, NlpInfo, Solution,
39 SparsityRequest, StartingPoint, IDX_NAMES, TNLP,
40};
41use std::cell::RefCell;
42use std::collections::{BTreeMap, BTreeSet, HashMap};
43use std::path::Path;
44use std::rc::Rc;
45use std::sync::Arc;
46
/// Expression tree for the nonlinear part of an objective, a constraint, or a
/// common subexpression, as read from an ASCII `.nl` file.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Numeric constant (an `n<value>` expression token).
    Const(Number),
    /// Decision variable `i` (a `v<i>` token with `i < n`).
    Var(usize),
    /// Binary arithmetic operation applied to `(lhs, rhs)`.
    Binary(BinOp, Box<Expr>, Box<Expr>),
    /// Unary function applied to one operand.
    Unary(UnaryOp, Box<Expr>),
    /// n-ary sum (opcode `o54`).
    Sum(Vec<Expr>),
    /// Common subexpression: a `v<i>` token with `i >= n`, defined earlier by a
    /// `V` segment. The body is reference-counted so every use shares one tree.
    Cse(Arc<Expr>),
    /// Call to an AMPL imported function. `id` is the index from the
    /// `f<id> <nargs>` token (cf. [`ImportedFunc::id`]). Evaluating this node with
    /// `eval_expr`/`grad_expr` panics; it needs an external resolver.
    Funcall { id: usize, args: Vec<FuncallArg> },
    /// Comparison; evaluates to `1.0` when true and `0.0` when false.
    Compare(CmpOp, Box<Expr>, Box<Expr>),
    /// Logical AND (opcode `o21`). Operands count as true when non-zero; the
    /// result is `1.0`/`0.0`.
    And(Box<Expr>, Box<Expr>),
    /// Logical OR (opcode `o20`). Operands count as true when non-zero; the
    /// result is `1.0`/`0.0`.
    Or(Box<Expr>, Box<Expr>),
    /// Logical NOT (opcode `o34`): `1.0` when the operand is zero, else `0.0`.
    Not(Box<Expr>),
    /// Conditional (opcode `o35`): `then_` when `cond` is non-zero, else `else_`.
    /// Only the branch selected at the evaluation point contributes to gradients.
    Cond {
        cond: Box<Expr>,
        then_: Box<Expr>,
        else_: Box<Expr>,
    },
    /// Minimum over a list (opcode `o11`); an empty list evaluates to `+inf`.
    MinList(Vec<Expr>),
    /// Maximum over a list (opcode `o12`); an empty list evaluates to `-inf`.
    MaxList(Vec<Expr>),
}
110
/// Comparison operator carried by [`Expr::Compare`], with its `.nl` opcode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// `a < b` (`o22`).
    Lt,
    /// `a <= b` (`o23`).
    Le,
    /// `a == b` (`o24`).
    Eq,
    /// `a >= b` (`o28`).
    Ge,
    /// `a > b` (`o29`).
    Gt,
    /// `a != b` (`o30`).
    Ne,
}
123
/// One argument of an imported-function call ([`Expr::Funcall`]).
#[derive(Debug, Clone)]
pub enum FuncallArg {
    /// Numeric argument given by an expression tree.
    Real(Expr),
    /// String argument, read from a Hollerith token `h<len>:<chars>`.
    Str(String),
}
132
/// Imported (external) function declared by an `F` segment whose header is
/// `F<id> <kind> <nargs> <name>`.
#[derive(Debug, Clone)]
pub struct ImportedFunc {
    /// Function index from the `F<id>` tag.
    pub id: usize,
    /// Second header field; `0` when missing or unparsable.
    pub kind: usize,
    /// Third header field; `0` when missing or unparsable. Signed, so negative
    /// values are representable. NOTE(review): presumably a negative count marks
    /// a variable number of arguments, as in AMPL; confirm.
    pub nargs: i64,
    /// Fourth header field; empty when missing.
    pub name: String,
}
144
/// Binary arithmetic operator carried by [`Expr::Binary`], with the `.nl`
/// opcodes the parser maps to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `a + b` (`o0`).
    Add,
    /// `a - b` (`o1`).
    Sub,
    /// `a * b` (`o2`).
    Mul,
    /// `a / b` (`o3`).
    Div,
    /// `a ^ b` (`o5`, `o81`, `o83`; `o82` is parsed as `a ^ 2`).
    Pow,
    /// `atan2(a, b)` (`o48`).
    Atan2,
}
155
/// Unary function carried by [`Expr::Unary`], with its `.nl` opcode.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// Negation `-a` (`o16`).
    Neg,
    /// Square root (`o39`).
    Sqrt,
    /// Natural logarithm (`o43`).
    Log,
    /// Exponential (`o44`).
    Exp,
    /// Absolute value (`o15`).
    Abs,
    /// Sine (`o41`).
    Sin,
    /// Cosine (`o46`).
    Cos,
    /// Base-10 logarithm (`o42`).
    Log10,
    /// Tangent (`o38`).
    Tan,
    /// Arctangent (`o49`).
    Atan,
    /// Arccosine (`o53`).
    Acos,
    /// Hyperbolic sine (`o40`).
    Sinh,
    /// Hyperbolic cosine (`o45`).
    Cosh,
    /// Hyperbolic tangent (`o37`).
    Tanh,
    /// Arcsine (`o51`).
    Asin,
    /// Inverse hyperbolic cosine (`o52`).
    Acosh,
    /// Inverse hyperbolic sine (`o50`).
    Asinh,
    /// Inverse hyperbolic tangent (`o47`).
    Atanh,
}
177
/// Problem data read from an ASCII AMPL `.nl` file.
///
/// Each objective/constraint function is the sum of a nonlinear expression tree
/// and a sparse linear part. Only objective 0 is retained.
#[derive(Debug, Clone)]
pub struct NlProblem {
    /// Number of variables (first field of header line 2).
    pub n: usize,
    /// Number of constraints (second field of header line 2).
    pub m: usize,
    /// Number of objectives declared in the header; only objective 0 is kept below.
    pub num_obj: usize,
    /// `true` unless the `O0` segment header has a non-zero kind field (maximize).
    pub minimize: bool,
    /// Nonlinear part of objective 0 (`Const(0.0)` when there is no `O0` segment).
    pub obj_nonlinear: Expr,
    /// Linear part of objective 0 as `(variable, coefficient)` pairs from `G0`.
    pub obj_linear: Vec<(usize, Number)>,
    /// Constant objective offset; `parse_nl_text` always sets this to `0.0`.
    pub obj_constant: Number,
    /// Nonlinear part of each constraint (`Const(0.0)` when it has no `C` segment).
    pub con_nonlinear: Vec<Expr>,
    // Linear part of each constraint: `(variable, coefficient)` pairs from `J` segments.
    pub con_linear: Vec<Vec<(usize, Number)>>,
    // Variable bounds from the `b` segment; `-1e19`/`1e19` when absent or unbounded.
    pub x_l: Vec<Number>,
    pub x_u: Vec<Number>,
    // Constraint bounds from the `r` segment; `-1e19`/`1e19` when absent or unbounded.
    pub g_l: Vec<Number>,
    pub g_u: Vec<Number>,
    /// Primal starting point (zero where the `x` segment gives no value).
    pub x0: Vec<Number>,
    /// Dual starting point per constraint (zero where the `d` segment gives no value).
    pub lambda0: Vec<Number>,
    /// Suffix values from `S` segments.
    pub suffixes: NlSuffixes,
    /// Imported functions declared by `F` segments, in file order.
    pub imported_funcs: Vec<ImportedFunc>,
    /// Variable names from the sibling `.col` file. Filled by `read_nl_file`; empty
    /// when the file is unreadable or has fewer than `n` lines.
    pub var_names: Vec<String>,
    /// Constraint names from the sibling `.row` file. Filled by `read_nl_file`;
    /// empty when the file is unreadable or has fewer than `m` lines.
    pub con_names: Vec<String>,
}
226
/// Suffix values parsed from `S` segments, keyed by suffix name.
///
/// Per-entity tables hold a dense vector with one entry per variable (`var_*`),
/// constraint (`con_*`) or objective (`obj_*`). Indices not listed in the segment
/// are zero. `problem_*` tables hold one scalar per suffix name.
#[derive(Debug, Clone, Default)]
pub struct NlSuffixes {
    /// Integer suffixes on variables (length `n`).
    pub var_int: BTreeMap<String, Vec<Index>>,
    /// Integer suffixes on constraints (length `m`).
    pub con_int: BTreeMap<String, Vec<Index>>,
    /// Integer suffixes on objectives (length `num_obj`).
    pub obj_int: BTreeMap<String, Vec<Index>>,
    /// Integer problem-level suffixes.
    pub problem_int: BTreeMap<String, Index>,
    /// Real suffixes on variables (length `n`).
    pub var_real: BTreeMap<String, Vec<Number>>,
    /// Real suffixes on constraints (length `m`).
    pub con_real: BTreeMap<String, Vec<Number>>,
    /// Real suffixes on objectives (length `num_obj`).
    pub obj_real: BTreeMap<String, Vec<Number>>,
    /// Real problem-level suffixes.
    pub problem_real: BTreeMap<String, Number>,
}
251
252pub fn read_nl_file(path: &Path) -> Result<NlProblem, String> {
263 let resolved = if path.exists() {
271 path.to_path_buf()
272 } else {
273 let with_nl = append_extension(path, "nl");
274 if with_nl.exists() {
275 with_nl
276 } else {
277 path.to_path_buf()
278 }
279 };
280 let txt = std::fs::read_to_string(&resolved)
281 .map_err(|e| format!("could not read {}: {}", resolved.display(), e))?;
282 let mut prob = parse_nl_text(&txt)?;
283 prob.var_names = read_name_file(&resolved.with_extension("col"), prob.n);
284 prob.con_names = read_name_file(&resolved.with_extension("row"), prob.m);
285 Ok(prob)
286}
287
/// Returns `path` with `.{ext}` appended to its final component.
///
/// Unlike [`Path::with_extension`], any existing extension is kept:
/// `model.v1` + `nl` gives `model.v1.nl`.
fn append_extension(path: &Path, ext: &str) -> std::path::PathBuf {
    let mut raw = path.as_os_str().to_owned();
    for piece in [".", ext] {
        raw.push(piece);
    }
    raw.into()
}
299
/// Reads one name per line from `path` and returns the first `expected` lines.
///
/// Extra trailing lines are ignored. The result is empty when the file cannot be
/// read or has fewer than `expected` lines, so callers get either a full name
/// list or nothing.
fn read_name_file(path: &Path, expected: usize) -> Vec<String> {
    match std::fs::read_to_string(path) {
        Err(_) => Vec::new(),
        Ok(contents) => {
            let names: Vec<String> = contents
                .lines()
                .take(expected)
                .map(String::from)
                .collect();
            // `take` never yields more than `expected`, so a shorter result
            // means the file ran out of lines.
            if names.len() < expected {
                Vec::new()
            } else {
                names
            }
        }
    }
}
320
321pub fn parse_nl_text(txt: &str) -> Result<NlProblem, String> {
323 let mut p = Parser::new(txt);
324 p.parse_header()?;
325 let n = p.n;
326 let m = p.m;
327 let num_obj = p.num_obj;
328
329 let mut con_nonlinear: Vec<Expr> = (0..m).map(|_| Expr::Const(0.0)).collect();
330 let mut obj_nonlinear = Expr::Const(0.0);
331 let mut minimize = true;
332 let mut obj_linear: Vec<(usize, Number)> = Vec::new();
333 let mut con_linear: Vec<Vec<(usize, Number)>> = vec![Vec::new(); m];
334 let mut x_l = vec![-1e19; n];
335 let mut x_u = vec![1e19; n];
336 let mut g_l = vec![-1e19; m];
337 let mut g_u = vec![1e19; m];
338 let mut x0 = vec![0.0; n];
339 let mut lambda0 = vec![0.0; m];
340 let mut suffixes = NlSuffixes::default();
341 let mut imported_funcs: Vec<ImportedFunc> = Vec::new();
342
343 while let Some(line) = p.peek_segment_line() {
344 let tag = line
345 .trim_start()
346 .chars()
347 .next()
348 .ok_or("unexpected blank segment header")?;
349 match tag {
350 'C' => {
351 let (_hdr, rest) = p.eat_segment_header()?;
352 let _ = rest;
353 let idx = parse_segment_index(&_hdr, 'C')?;
354 if idx >= m {
355 return Err(format!("C{idx} out of range; m={m}"));
356 }
357 con_nonlinear[idx] = p.parse_expr()?;
358 }
359 'O' => {
360 let (hdr, _rest) = p.eat_segment_header()?;
361 let parts: Vec<&str> = hdr.split_whitespace().collect();
362 if parts.len() < 2 {
363 return Err(format!("malformed O-segment header: {hdr}"));
364 }
365 let idx = parse_segment_index(parts[0], 'O')?;
366 let kind: i32 = parts[1].parse().map_err(|e| format!("O kind: {e}"))?;
367 if idx == 0 {
368 minimize = kind == 0;
369 obj_nonlinear = p.parse_expr()?;
370 } else {
371 let _ = p.parse_expr()?;
373 }
374 }
375 'r' => {
376 p.eat_segment_header()?;
377 for i in 0..m {
378 let line = p.next_data_line()?;
379 let (lo, hi) = parse_bound_line(&line)?;
380 g_l[i] = lo;
381 g_u[i] = hi;
382 }
383 }
384 'b' => {
385 p.eat_segment_header()?;
386 for i in 0..n {
387 let line = p.next_data_line()?;
388 let (lo, hi) = parse_bound_line(&line)?;
389 x_l[i] = lo;
390 x_u[i] = hi;
391 }
392 }
393 'k' => {
394 let (hdr, _) = p.eat_segment_header()?;
407 let declared = parse_segment_index(&hdr, 'k')?;
408 let expected = if n == 0 { 0 } else { n - 1 };
409 if declared != expected {
410 return Err(format!(
411 "k-segment declares {declared} column-count lines but \
412 the standard count for n={n} variables is {expected}"
413 ));
414 }
415 for _ in 0..declared {
416 p.next_data_line()?;
417 }
418 }
419 'J' => {
420 let (hdr, _) = p.eat_segment_header()?;
421 let parts: Vec<&str> = hdr.split_whitespace().collect();
422 if parts.len() < 2 {
423 return Err(format!("malformed J-segment header: {hdr}"));
424 }
425 let row = parse_segment_index(parts[0], 'J')?;
426 let nz: usize = parts[1].parse().map_err(|e| format!("J nz: {e}"))?;
427 if row >= m {
428 return Err(format!("J{row} out of range"));
429 }
430 for _ in 0..nz {
431 let line = p.next_data_line()?;
432 let (var, coef) = parse_var_coef(&line)?;
433 if var >= n {
438 return Err(format!(
439 "J{row} entry variable index {var} out of range (n={n})"
440 ));
441 }
442 con_linear[row].push((var, coef));
443 }
444 }
445 'G' => {
446 let (hdr, _) = p.eat_segment_header()?;
447 let parts: Vec<&str> = hdr.split_whitespace().collect();
448 if parts.len() < 2 {
449 return Err(format!("malformed G-segment header: {hdr}"));
450 }
451 let idx = parse_segment_index(parts[0], 'G')?;
452 let nz: usize = parts[1].parse().map_err(|e| format!("G nz: {e}"))?;
453 let mut acc = Vec::with_capacity(nz);
454 for _ in 0..nz {
455 let line = p.next_data_line()?;
456 let (var, coef) = parse_var_coef(&line)?;
457 if var >= n {
460 return Err(format!(
461 "G{idx} entry variable index {var} out of range (n={n})"
462 ));
463 }
464 acc.push((var, coef));
465 }
466 if idx == 0 {
467 obj_linear = acc;
468 }
469 }
470 'x' => {
471 let (hdr, _) = p.eat_segment_header()?;
472 let parts: Vec<&str> = hdr.split_whitespace().collect();
473 let nx: usize = parts
474 .first()
475 .and_then(|s| s.trim_start_matches('x').parse().ok())
476 .ok_or_else(|| format!("malformed x-segment header: {hdr}"))?;
477 for _ in 0..nx {
478 let line = p.next_data_line()?;
479 let (idx, val) = parse_var_coef(&line)?;
480 if idx >= n {
484 return Err(format!(
485 "x-segment variable index {idx} out of range (n={n})"
486 ));
487 }
488 x0[idx] = val;
489 }
490 }
491 'd' => {
492 let (hdr, _) = p.eat_segment_header()?;
493 let parts: Vec<&str> = hdr.split_whitespace().collect();
494 let nd: usize = parts
495 .first()
496 .and_then(|s| s.trim_start_matches('d').parse().ok())
497 .ok_or_else(|| format!("malformed d-segment header: {hdr}"))?;
498 for _ in 0..nd {
499 let line = p.next_data_line()?;
500 let (idx, val) = parse_var_coef(&line)?;
501 if idx >= m {
505 return Err(format!(
506 "d-segment constraint index {idx} out of range (m={m})"
507 ));
508 }
509 lambda0[idx] = val;
510 }
511 }
512 'V' => p.parse_v_segment()?,
513 'S' => {
514 parse_suffix_segment(&mut p, n, m, num_obj, &mut suffixes)?;
515 }
516 'F' => {
517 let (hdr, _rest) = p.eat_segment_header()?;
520 let parts: Vec<&str> = hdr.split_whitespace().collect();
521 if parts.is_empty() {
522 return Err(format!("malformed F-segment header: '{hdr}'"));
523 }
524 let id = parse_segment_index(parts[0], 'F')?;
525 let kind: usize = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
526 let nargs: i64 = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0);
527 let name = parts.get(3).copied().unwrap_or("").to_string();
528 imported_funcs.push(ImportedFunc {
529 id,
530 kind,
531 nargs,
532 name,
533 });
534 }
535 other => return Err(format!("unknown .nl segment tag '{other}'")),
536 }
537 }
538
539 Ok(NlProblem {
540 n,
541 m,
542 num_obj,
543 minimize,
544 obj_nonlinear,
545 obj_linear,
546 obj_constant: 0.0,
547 con_nonlinear,
548 con_linear,
549 x_l,
550 x_u,
551 g_l,
552 g_u,
553 x0,
554 lambda0,
555 suffixes,
556 imported_funcs,
557 var_names: Vec::new(),
560 con_names: Vec::new(),
561 })
562}
563
564fn parse_suffix_segment(
581 p: &mut Parser,
582 n: usize,
583 m: usize,
584 num_obj: usize,
585 out: &mut NlSuffixes,
586) -> Result<(), String> {
587 let (hdr, _) = p.eat_segment_header()?;
588 let parts: Vec<&str> = hdr.split_whitespace().collect();
589 if parts.len() < 3 {
590 return Err(format!(
591 "malformed S-segment header: '{hdr}' (expected `S<kind> <n> <name>`)"
592 ));
593 }
594 let kind_str = parts[0].trim_start_matches('S');
595 let kind: u32 = kind_str
596 .parse()
597 .map_err(|e| format!("S kind '{kind_str}': {e}"))?;
598 let nentries: usize = parts[1].parse().map_err(|e| format!("S nentries: {e}"))?;
599 let name = parts[2].to_string();
600
601 let is_real = (kind & 0x4) != 0;
602 let target = kind & 0x3;
603 let target_dim = match target {
604 0 => n,
605 1 => m,
606 2 => num_obj,
607 3 => 0, _ => unreachable!("kind & 0x3 is in 0..=3"),
609 };
610
611 let mut int_buf: Vec<Index> = if !is_real && target != 3 {
615 vec![0; target_dim]
616 } else {
617 Vec::new()
618 };
619 let mut real_buf: Vec<Number> = if is_real && target != 3 {
620 vec![0.0; target_dim]
621 } else {
622 Vec::new()
623 };
624 let mut problem_int: Index = 0;
625 let mut problem_real: Number = 0.0;
626
627 for _ in 0..nentries {
628 let line = p.next_data_line()?;
629 let parts: Vec<&str> = line.split_whitespace().collect();
630 if parts.len() < 2 {
631 return Err(format!(
632 "malformed S-segment entry '{line}' (expected `<idx> <value>`)"
633 ));
634 }
635 let idx: usize = parts[0]
636 .parse()
637 .map_err(|e| format!("S entry idx '{}': {e}", parts[0]))?;
638 if target != 3 && idx >= target_dim {
639 return Err(format!(
640 "S-suffix '{name}' index {idx} out of range for target dim {target_dim}"
641 ));
642 }
643 if is_real {
644 let v: Number = parts[1]
645 .parse()
646 .map_err(|e| format!("S real entry value '{}': {e}", parts[1]))?;
647 if target == 3 {
648 problem_real = v;
649 } else {
650 real_buf[idx] = v;
651 }
652 } else {
653 let v: Index = parts[1]
654 .parse()
655 .map_err(|e| format!("S int entry value '{}': {e}", parts[1]))?;
656 if target == 3 {
657 problem_int = v;
658 } else {
659 int_buf[idx] = v;
660 }
661 }
662 }
663
664 match (target, is_real) {
665 (0, false) => {
666 out.var_int.insert(name, int_buf);
667 }
668 (1, false) => {
669 out.con_int.insert(name, int_buf);
670 }
671 (2, false) => {
672 out.obj_int.insert(name, int_buf);
673 }
674 (3, false) => {
675 out.problem_int.insert(name, problem_int);
676 }
677 (0, true) => {
678 out.var_real.insert(name, real_buf);
679 }
680 (1, true) => {
681 out.con_real.insert(name, real_buf);
682 }
683 (2, true) => {
684 out.obj_real.insert(name, real_buf);
685 }
686 (3, true) => {
687 out.problem_real.insert(name, problem_real);
688 }
689 _ => unreachable!(),
690 }
691 Ok(())
692}
693
/// Parses the numeric index of a segment tag such as `C12` (with `tag = 'C'`).
///
/// Every leading occurrence of `tag` is stripped before the rest is parsed as a
/// non-negative integer.
///
/// # Errors
/// Returns `malformed <tag>-segment index '<s>': <reason>` when the remainder
/// is not a valid `usize`.
fn parse_segment_index(s: &str, tag: char) -> Result<usize, String> {
    let digits = s.trim_start_matches(tag);
    match digits.parse::<usize>() {
        Ok(index) => Ok(index),
        Err(e) => Err(format!("malformed {tag}-segment index '{s}': {e}")),
    }
}
700
701fn parse_bound_line(line: &str) -> Result<(Number, Number), String> {
702 let parts: Vec<&str> = line.split_whitespace().collect();
703 if parts.is_empty() {
704 return Err("empty bound line".into());
705 }
706 let kind: i32 = parts[0].parse().map_err(|e| format!("bound kind: {e}"))?;
707 let lo;
708 let hi;
709 match kind {
710 0 => {
711 if parts.len() < 3 {
713 return Err(format!("bound kind 0 needs 2 values: '{line}'"));
714 }
715 lo = parts[1].parse().map_err(|e| format!("lo: {e}"))?;
716 hi = parts[2].parse().map_err(|e| format!("hi: {e}"))?;
717 }
718 1 => {
719 if parts.len() < 2 {
721 return Err(format!("bound kind 1 needs 1 value: '{line}'"));
722 }
723 lo = -1e19;
724 hi = parts[1].parse().map_err(|e| format!("hi: {e}"))?;
725 }
726 2 => {
727 if parts.len() < 2 {
729 return Err(format!("bound kind 2 needs 1 value: '{line}'"));
730 }
731 lo = parts[1].parse().map_err(|e| format!("lo: {e}"))?;
732 hi = 1e19;
733 }
734 3 => {
735 lo = -1e19;
737 hi = 1e19;
738 }
739 4 => {
740 if parts.len() < 2 {
742 return Err(format!("bound kind 4 needs 1 value: '{line}'"));
743 }
744 let v: Number = parts[1].parse().map_err(|e| format!("eq: {e}"))?;
745 lo = v;
746 hi = v;
747 }
748 5 => return Err("complementarity (kind 5) bounds are not supported".into()),
749 other => return Err(format!("unknown bound kind {other}")),
750 }
751 Ok((lo, hi))
752}
753
754fn parse_var_coef(line: &str) -> Result<(usize, Number), String> {
755 let parts: Vec<&str> = line.split_whitespace().collect();
756 if parts.len() < 2 {
757 return Err(format!("malformed var/coef line: '{line}'"));
758 }
759 let v: usize = parts[0].parse().map_err(|e| format!("var idx: {e}"))?;
760 let c: Number = parts[1].parse().map_err(|e| format!("coef: {e}"))?;
761 Ok((v, c))
762}
763
/// Line-oriented cursor over `.nl` text, holding the header counts and the
/// common subexpressions defined so far.
struct Parser<'a> {
    /// Input split into lines, borrowed from the source text.
    lines: Vec<&'a str>,
    /// Index of the next line to read.
    pos: usize,
    /// Number of variables from the header; a `v<i>` token with `i >= n` refers
    /// to a common subexpression.
    n: usize,
    /// Number of constraints from the header.
    m: usize,
    /// Number of objectives from the header.
    num_obj: usize,
    /// Second field of the sixth header line, as recorded by `parse_header`.
    n_funcs: usize,
    /// Common subexpressions from `V` segments in definition order;
    /// `v<n + k>` resolves to `cses[k]`.
    cses: Vec<Arc<Expr>>,
}
776
777impl<'a> Parser<'a> {
778 fn new(txt: &'a str) -> Self {
779 let lines: Vec<&str> = txt.lines().collect();
780 Self {
781 lines,
782 pos: 0,
783 n: 0,
784 m: 0,
785 num_obj: 0,
786 n_funcs: 0,
787 cses: Vec::new(),
788 }
789 }
790
791 fn next_line(&mut self) -> Option<&'a str> {
792 while self.pos < self.lines.len() {
793 let l = self.lines[self.pos];
794 self.pos += 1;
795 let trimmed = strip_comment(l).trim();
799 if !trimmed.is_empty() {
800 return Some(l);
801 }
802 }
803 None
804 }
805
806 fn next_data_line(&mut self) -> Result<String, String> {
807 let raw = self
808 .next_line()
809 .ok_or_else(|| "unexpected end of file in data line".to_string())?;
810 Ok(strip_comment(raw).trim().to_string())
811 }
812
813 fn parse_header(&mut self) -> Result<(), String> {
814 let line0 = self.next_line().ok_or("empty .nl file")?;
815 let trimmed = strip_comment(line0).trim();
816 let first = trimmed.chars().next().ok_or("empty header line")?;
817 if first != 'g' {
818 return Err(format!(
819 "only ASCII (g-) .nl files supported; got header '{trimmed}'"
820 ));
821 }
822
823 let l2 = self.next_data_line()?;
825 let nums: Vec<&str> = l2.split_whitespace().collect();
826 if nums.len() < 3 {
827 return Err(format!("malformed line 2: '{l2}'"));
828 }
829 self.n = nums[0].parse().map_err(|e| format!("n: {e}"))?;
830 self.m = nums[1].parse().map_err(|e| format!("m: {e}"))?;
831 self.num_obj = nums[2].parse().map_err(|e| format!("num_obj: {e}"))?;
832
833 for _ in 0..3 {
835 self.next_data_line()?;
836 }
837 let l5 = self.next_data_line()?;
839 let nums5: Vec<&str> = l5.split_whitespace().collect();
840 self.n_funcs = nums5.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
841 for _ in 0..4 {
843 self.next_data_line()?;
844 }
845 Ok(())
846 }
847
848 fn peek_segment_line(&mut self) -> Option<&'a str> {
849 let saved = self.pos;
850 let l = self.next_line()?;
851 self.pos = saved;
852 Some(l)
853 }
854
855 fn eat_segment_header(&mut self) -> Result<(String, String), String> {
858 let raw = self
859 .next_line()
860 .ok_or_else(|| "expected segment header".to_string())?;
861 let (hdr, comment) = split_comment(raw);
862 Ok((hdr.trim().to_string(), comment.trim().to_string()))
863 }
864
865 fn parse_expr(&mut self) -> Result<Expr, String> {
866 let raw = self
867 .next_line()
868 .ok_or_else(|| "expected expression token".to_string())?;
869 let tok = strip_comment(raw).trim().to_string();
870 if tok.is_empty() {
871 return Err("empty expression token".into());
872 }
873 let first = tok.chars().next().ok_or("empty expression token")?;
874 match first {
875 'n' => {
876 let v: Number = tok[1..]
877 .trim()
878 .parse()
879 .map_err(|e| format!("n value: {e}"))?;
880 Ok(Expr::Const(v))
881 }
882 'v' => {
883 let i: usize = tok[1..]
884 .trim()
885 .parse()
886 .map_err(|e| format!("v index: {e}"))?;
887 Ok(self.var_or_cse(i)?)
888 }
889 'o' => {
890 let code: i32 = tok[1..]
891 .trim()
892 .parse()
893 .map_err(|e| format!("opcode: {e}"))?;
894 self.parse_opcode(code)
895 }
896 'f' => {
897 let rest = &tok[1..];
900 let mut parts = rest.split_whitespace();
901 let id_str = parts
902 .next()
903 .ok_or_else(|| format!("missing function id in '{tok}'"))?;
904 let nargs_str = parts
905 .next()
906 .ok_or_else(|| format!("missing nargs in '{tok}'"))?;
907 let id: usize = id_str
908 .parse()
909 .map_err(|e| format!("bad function id '{id_str}': {e}"))?;
910 let nargs: usize = nargs_str
911 .parse()
912 .map_err(|e| format!("bad funcall nargs '{nargs_str}': {e}"))?;
913 let mut args: Vec<FuncallArg> = Vec::with_capacity(nargs);
914 for _ in 0..nargs {
915 args.push(self.parse_funcall_arg()?);
916 }
917 Ok(Expr::Funcall { id, args })
918 }
919 't' | 'u' => Err(format!("unsupported expression token '{tok}'")),
920 other => Err(format!(
921 "unexpected expression token start '{other}': '{tok}'"
922 )),
923 }
924 }
925
926 fn parse_funcall_arg(&mut self) -> Result<FuncallArg, String> {
932 let saved = self.pos;
934 let raw = self
935 .next_line()
936 .ok_or_else(|| "expected funcall argument".to_string())?;
937 let lead = raw.trim_start();
945 if let Some(after_h) = lead.strip_prefix('h') {
946 let colon = after_h
947 .find(':')
948 .ok_or_else(|| format!("malformed Hollerith string arg (no ':'): {lead:?}"))?;
949 let len: usize = after_h[..colon]
950 .trim()
951 .parse()
952 .map_err(|e| format!("Hollerith length in {lead:?}: {e}"))?;
953 let chars = &after_h[colon + 1..];
954 if chars.len() < len {
955 return Err(format!(
956 "Hollerith string shorter than declared length {len}: {chars:?}"
957 ));
958 }
959 if !chars.is_char_boundary(len) {
962 return Err(format!(
963 "Hollerith length {len} splits a multibyte char in {chars:?}"
964 ));
965 }
966 Ok(FuncallArg::Str(chars[..len].to_string()))
967 } else {
968 self.pos = saved;
970 Ok(FuncallArg::Real(self.parse_expr()?))
971 }
972 }
973
974 fn parse_opcode(&mut self, code: i32) -> Result<Expr, String> {
975 match code {
976 0 => {
977 let a = self.parse_expr()?;
978 let b = self.parse_expr()?;
979 Ok(Expr::Binary(BinOp::Add, Box::new(a), Box::new(b)))
980 }
981 1 => {
982 let a = self.parse_expr()?;
983 let b = self.parse_expr()?;
984 Ok(Expr::Binary(BinOp::Sub, Box::new(a), Box::new(b)))
985 }
986 2 => {
987 let a = self.parse_expr()?;
988 let b = self.parse_expr()?;
989 Ok(Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b)))
990 }
991 3 => {
992 let a = self.parse_expr()?;
993 let b = self.parse_expr()?;
994 Ok(Expr::Binary(BinOp::Div, Box::new(a), Box::new(b)))
995 }
996 5 => {
997 let a = self.parse_expr()?;
998 let b = self.parse_expr()?;
999 Ok(Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b)))
1000 }
1001 15 => Ok(Expr::Unary(UnaryOp::Abs, Box::new(self.parse_expr()?))),
1002 16 => Ok(Expr::Unary(UnaryOp::Neg, Box::new(self.parse_expr()?))),
1003 39 => Ok(Expr::Unary(UnaryOp::Sqrt, Box::new(self.parse_expr()?))),
1004 41 => Ok(Expr::Unary(UnaryOp::Sin, Box::new(self.parse_expr()?))),
1005 42 => Ok(Expr::Unary(UnaryOp::Log10, Box::new(self.parse_expr()?))),
1006 43 => Ok(Expr::Unary(UnaryOp::Log, Box::new(self.parse_expr()?))),
1007 44 => Ok(Expr::Unary(UnaryOp::Exp, Box::new(self.parse_expr()?))),
1008 46 => Ok(Expr::Unary(UnaryOp::Cos, Box::new(self.parse_expr()?))),
1009 38 => Ok(Expr::Unary(UnaryOp::Tan, Box::new(self.parse_expr()?))),
1010 49 => Ok(Expr::Unary(UnaryOp::Atan, Box::new(self.parse_expr()?))),
1011 53 => Ok(Expr::Unary(UnaryOp::Acos, Box::new(self.parse_expr()?))),
1012 40 => Ok(Expr::Unary(UnaryOp::Sinh, Box::new(self.parse_expr()?))),
1013 45 => Ok(Expr::Unary(UnaryOp::Cosh, Box::new(self.parse_expr()?))),
1014 37 => Ok(Expr::Unary(UnaryOp::Tanh, Box::new(self.parse_expr()?))),
1015 51 => Ok(Expr::Unary(UnaryOp::Asin, Box::new(self.parse_expr()?))),
1016 52 => Ok(Expr::Unary(UnaryOp::Acosh, Box::new(self.parse_expr()?))),
1017 50 => Ok(Expr::Unary(UnaryOp::Asinh, Box::new(self.parse_expr()?))),
1018 47 => Ok(Expr::Unary(UnaryOp::Atanh, Box::new(self.parse_expr()?))),
1019 48 => {
1021 let a = self.parse_expr()?;
1022 let b = self.parse_expr()?;
1023 Ok(Expr::Binary(BinOp::Atan2, Box::new(a), Box::new(b)))
1024 }
1025 22 => self.parse_compare(CmpOp::Lt),
1028 23 => self.parse_compare(CmpOp::Le),
1029 24 => self.parse_compare(CmpOp::Eq),
1030 28 => self.parse_compare(CmpOp::Ge),
1031 29 => self.parse_compare(CmpOp::Gt),
1032 30 => self.parse_compare(CmpOp::Ne),
1033 20 => {
1035 let a = self.parse_expr()?;
1036 let b = self.parse_expr()?;
1037 Ok(Expr::Or(Box::new(a), Box::new(b)))
1038 }
1039 21 => {
1040 let a = self.parse_expr()?;
1041 let b = self.parse_expr()?;
1042 Ok(Expr::And(Box::new(a), Box::new(b)))
1043 }
1044 34 => Ok(Expr::Not(Box::new(self.parse_expr()?))),
1045 35 => {
1047 let cond = self.parse_expr()?;
1048 let then_ = self.parse_expr()?;
1049 let else_ = self.parse_expr()?;
1050 Ok(Expr::Cond {
1051 cond: Box::new(cond),
1052 then_: Box::new(then_),
1053 else_: Box::new(else_),
1054 })
1055 }
1056 54 => {
1057 let count_line = self.next_data_line()?;
1059 let count: usize = count_line
1060 .split_whitespace()
1061 .next()
1062 .ok_or_else(|| "missing variadic count".to_string())?
1063 .parse()
1064 .map_err(|e| format!("variadic count: {e}"))?;
1065 let mut args = Vec::with_capacity(count);
1066 for _ in 0..count {
1067 args.push(self.parse_expr()?);
1068 }
1069 Ok(Expr::Sum(args))
1070 }
1071 11 | 12 => {
1074 let count_line = self.next_data_line()?;
1075 let count: usize = count_line
1076 .split_whitespace()
1077 .next()
1078 .ok_or_else(|| "missing min/max list count".to_string())?
1079 .parse()
1080 .map_err(|e| format!("min/max list count: {e}"))?;
1081 let mut args = Vec::with_capacity(count);
1082 for _ in 0..count {
1083 args.push(self.parse_expr()?);
1084 }
1085 if code == 11 {
1086 Ok(Expr::MinList(args))
1087 } else {
1088 Ok(Expr::MaxList(args))
1089 }
1090 }
1091 81 => {
1107 let base = self.parse_expr()?;
1108 let exp = self.parse_expr()?;
1109 Ok(Expr::Binary(BinOp::Pow, Box::new(base), Box::new(exp)))
1110 }
1111 82 => {
1113 let base = self.parse_expr()?;
1114 Ok(Expr::Binary(
1115 BinOp::Pow,
1116 Box::new(base),
1117 Box::new(Expr::Const(2.0)),
1118 ))
1119 }
1120 83 => {
1123 let base = self.parse_expr()?;
1124 let exp = self.parse_expr()?;
1125 Ok(Expr::Binary(BinOp::Pow, Box::new(base), Box::new(exp)))
1126 }
1127 other => Err(format!("unsupported opcode o{other}")),
1128 }
1129 }
1130
1131 fn parse_compare(&mut self, op: CmpOp) -> Result<Expr, String> {
1134 let a = self.parse_expr()?;
1135 let b = self.parse_expr()?;
1136 Ok(Expr::Compare(op, Box::new(a), Box::new(b)))
1137 }
1138
1139 fn var_or_cse(&self, i: usize) -> Result<Expr, String> {
1142 if i < self.n {
1143 Ok(Expr::Var(i))
1144 } else {
1145 let local = i - self.n;
1146 self.cses
1147 .get(local)
1148 .map(|rc| Expr::Cse(rc.clone()))
1149 .ok_or_else(|| {
1150 format!(
1151 "v{i} references CSE {local} but only {} have been defined",
1152 self.cses.len()
1153 )
1154 })
1155 }
1156 }
1157
1158 fn parse_v_segment(&mut self) -> Result<(), String> {
1162 let (hdr, _) = self.eat_segment_header()?;
1163 let parts: Vec<&str> = hdr.split_whitespace().collect();
1164 if parts.len() < 2 {
1165 return Err(format!("malformed V-segment header: {hdr}"));
1166 }
1167 let cse_idx = parse_segment_index(parts[0], 'V')?;
1168 let nlin: usize = parts[1].parse().map_err(|e| format!("V nlin: {e}"))?;
1169 let mut linear: Vec<(usize, Number)> = Vec::with_capacity(nlin);
1171 for _ in 0..nlin {
1172 let line = self.next_data_line()?;
1173 let (var, coef) = parse_var_coef(&line)?;
1174 linear.push((var, coef));
1175 }
1176 let nonlin = self.parse_expr()?;
1177 let mut combined = nonlin;
1180 for (var, coef) in linear {
1181 let v_expr = self.var_or_cse(var)?;
1182 let term = if coef == 1.0 {
1183 v_expr
1184 } else {
1185 Expr::Binary(BinOp::Mul, Box::new(Expr::Const(coef)), Box::new(v_expr))
1186 };
1187 combined = Expr::Binary(BinOp::Add, Box::new(combined), Box::new(term));
1188 }
1189 if cse_idx < self.n {
1190 return Err(format!("V{cse_idx} below n={}", self.n));
1191 }
1192 let local = cse_idx - self.n;
1193 if local != self.cses.len() {
1194 return Err(format!(
1195 "V-segment index V{cse_idx} out of order; expected V{}",
1196 self.n + self.cses.len()
1197 ));
1198 }
1199 self.cses.push(Arc::new(combined));
1200 Ok(())
1201 }
1202}
1203
/// Returns the part of `s` before its first `#`, or all of `s` if it has none.
/// Whitespace is not trimmed.
fn strip_comment(s: &str) -> &str {
    s.split_once('#').map_or(s, |(code, _)| code)
}
1210
/// Splits `s` at its first `#` into `(before, after)`, excluding the `#`.
/// Returns `(s, "")` when there is no comment.
fn split_comment(s: &str) -> (&str, &str) {
    s.split_once('#').unwrap_or((s, ""))
}
1217
1218pub fn eval_expr(e: &Expr, x: &[Number]) -> Number {
1226 match e {
1227 Expr::Const(c) => *c,
1228 Expr::Var(i) => x[*i],
1229 Expr::Binary(op, a, b) => {
1230 let va = eval_expr(a, x);
1231 let vb = eval_expr(b, x);
1232 match op {
1233 BinOp::Add => va + vb,
1234 BinOp::Sub => va - vb,
1235 BinOp::Mul => va * vb,
1236 BinOp::Div => va / vb,
1237 BinOp::Pow => va.powf(vb),
1238 BinOp::Atan2 => va.atan2(vb),
1239 }
1240 }
1241 Expr::Unary(op, a) => {
1242 let va = eval_expr(a, x);
1243 match op {
1244 UnaryOp::Neg => -va,
1245 UnaryOp::Sqrt => va.sqrt(),
1246 UnaryOp::Log => va.ln(),
1247 UnaryOp::Log10 => va.log10(),
1248 UnaryOp::Exp => va.exp(),
1249 UnaryOp::Abs => va.abs(),
1250 UnaryOp::Sin => va.sin(),
1251 UnaryOp::Cos => va.cos(),
1252 UnaryOp::Tan => va.tan(),
1253 UnaryOp::Atan => va.atan(),
1254 UnaryOp::Acos => va.acos(),
1255 UnaryOp::Sinh => va.sinh(),
1256 UnaryOp::Cosh => va.cosh(),
1257 UnaryOp::Tanh => va.tanh(),
1258 UnaryOp::Asin => va.asin(),
1259 UnaryOp::Acosh => va.acosh(),
1260 UnaryOp::Asinh => va.asinh(),
1261 UnaryOp::Atanh => va.atanh(),
1262 }
1263 }
1264 Expr::Sum(args) => args.iter().map(|a| eval_expr(a, x)).sum(),
1265 Expr::MinList(args) => args
1266 .iter()
1267 .map(|a| eval_expr(a, x))
1268 .fold(Number::INFINITY, Number::min),
1269 Expr::MaxList(args) => args
1270 .iter()
1271 .map(|a| eval_expr(a, x))
1272 .fold(Number::NEG_INFINITY, Number::max),
1273 Expr::Compare(op, a, b) => {
1274 let va = eval_expr(a, x);
1275 let vb = eval_expr(b, x);
1276 let truth = match op {
1277 CmpOp::Lt => va < vb,
1278 CmpOp::Le => va <= vb,
1279 CmpOp::Eq => va == vb,
1280 CmpOp::Ge => va >= vb,
1281 CmpOp::Gt => va > vb,
1282 CmpOp::Ne => va != vb,
1283 };
1284 if truth {
1285 1.0
1286 } else {
1287 0.0
1288 }
1289 }
1290 Expr::And(a, b) => {
1291 if eval_expr(a, x) != 0.0 && eval_expr(b, x) != 0.0 {
1292 1.0
1293 } else {
1294 0.0
1295 }
1296 }
1297 Expr::Or(a, b) => {
1298 if eval_expr(a, x) != 0.0 || eval_expr(b, x) != 0.0 {
1299 1.0
1300 } else {
1301 0.0
1302 }
1303 }
1304 Expr::Not(a) => {
1305 if eval_expr(a, x) == 0.0 {
1306 1.0
1307 } else {
1308 0.0
1309 }
1310 }
1311 Expr::Cond { cond, then_, else_ } => {
1312 if eval_expr(cond, x) != 0.0 {
1313 eval_expr(then_, x)
1314 } else {
1315 eval_expr(else_, x)
1316 }
1317 }
1318 Expr::Cse(body) => eval_expr(body, x),
1319 Expr::Funcall { .. } => panic!(
1320 "eval_expr: AMPL imported function called without an external resolver; \
1321 evaluate through the tape AD path (Tape::build_with_externals) instead"
1322 ),
1323 }
1324}
1325
1326fn argmin_argmax(args: &[Expr], x: &[Number], want_min: bool) -> Option<usize> {
1331 let mut best: Option<(usize, Number)> = None;
1332 for (i, a) in args.iter().enumerate() {
1333 let v = eval_expr(a, x);
1334 match best {
1335 None => best = Some((i, v)),
1336 Some((_, bv)) => {
1337 if (want_min && v < bv) || (!want_min && v > bv) {
1341 best = Some((i, v));
1342 }
1343 }
1344 }
1345 }
1346 best.map(|(i, _)| i)
1347}
1348
1349pub fn grad_expr(e: &Expr, x: &[Number], seed: Number, grad: &mut [Number]) {
1351 match e {
1352 Expr::Const(_) => {}
1353 Expr::Var(i) => grad[*i] += seed,
1354 Expr::Binary(op, a, b) => {
1355 let va = eval_expr(a, x);
1356 let vb = eval_expr(b, x);
1357 match op {
1358 BinOp::Add => {
1359 grad_expr(a, x, seed, grad);
1360 grad_expr(b, x, seed, grad);
1361 }
1362 BinOp::Sub => {
1363 grad_expr(a, x, seed, grad);
1364 grad_expr(b, x, -seed, grad);
1365 }
1366 BinOp::Mul => {
1367 grad_expr(a, x, seed * vb, grad);
1368 grad_expr(b, x, seed * va, grad);
1369 }
1370 BinOp::Div => {
1371 grad_expr(a, x, seed / vb, grad);
1372 grad_expr(b, x, -seed * va / (vb * vb), grad);
1373 }
1374 BinOp::Pow => {
1375 let dpa = vb * va.powf(vb - 1.0);
1377 grad_expr(a, x, seed * dpa, grad);
1378 if va > 0.0 {
1380 let dpb = va.powf(vb) * va.ln();
1381 grad_expr(b, x, seed * dpb, grad);
1382 }
1383 }
1384 BinOp::Atan2 => {
1385 let d = va * va + vb * vb;
1387 grad_expr(a, x, seed * vb / d, grad);
1388 grad_expr(b, x, -seed * va / d, grad);
1389 }
1390 }
1391 }
1392 Expr::Unary(op, a) => {
1393 let va = eval_expr(a, x);
1394 let d = match op {
1395 UnaryOp::Neg => -1.0,
1396 UnaryOp::Sqrt => 0.5 / va.sqrt(),
1397 UnaryOp::Log => 1.0 / va,
1398 UnaryOp::Log10 => 1.0 / (va * std::f64::consts::LN_10),
1399 UnaryOp::Exp => va.exp(),
1400 UnaryOp::Abs => {
1401 if va > 0.0 {
1402 1.0
1403 } else if va < 0.0 {
1404 -1.0
1405 } else {
1406 0.0
1407 }
1408 }
1409 UnaryOp::Sin => va.cos(),
1410 UnaryOp::Cos => -va.sin(),
1411 UnaryOp::Tan => {
1412 let t = va.tan();
1413 1.0 + t * t
1414 }
1415 UnaryOp::Atan => 1.0 / (1.0 + va * va),
1416 UnaryOp::Acos => -1.0 / (1.0 - va * va).sqrt(),
1417 UnaryOp::Sinh => va.cosh(),
1418 UnaryOp::Cosh => va.sinh(),
1419 UnaryOp::Tanh => {
1420 let t = va.tanh();
1421 1.0 - t * t
1422 }
1423 UnaryOp::Asin => 1.0 / (1.0 - va * va).sqrt(),
1424 UnaryOp::Acosh => 1.0 / (va * va - 1.0).sqrt(),
1425 UnaryOp::Asinh => 1.0 / (va * va + 1.0).sqrt(),
1426 UnaryOp::Atanh => 1.0 / (1.0 - va * va),
1427 };
1428 grad_expr(a, x, seed * d, grad);
1429 }
1430 Expr::Sum(args) => {
1431 for arg in args {
1432 grad_expr(arg, x, seed, grad);
1433 }
1434 }
1435 Expr::MinList(args) => {
1440 if let Some(k) = argmin_argmax(args, x, true) {
1441 grad_expr(&args[k], x, seed, grad);
1442 }
1443 }
1444 Expr::MaxList(args) => {
1445 if let Some(k) = argmin_argmax(args, x, false) {
1446 grad_expr(&args[k], x, seed, grad);
1447 }
1448 }
1449 Expr::Compare(_, _, _) | Expr::And(_, _) | Expr::Or(_, _) | Expr::Not(_) => {}
1452 Expr::Cond { cond, then_, else_ } => {
1455 if eval_expr(cond, x) != 0.0 {
1456 grad_expr(then_, x, seed, grad);
1457 } else {
1458 grad_expr(else_, x, seed, grad);
1459 }
1460 }
1461 Expr::Cse(body) => grad_expr(body, x, seed, grad),
1462 Expr::Funcall { .. } => {
1463 panic!("grad_expr: AMPL imported function called without an external resolver")
1464 }
1465 }
1466}
1467
1468pub fn collect_vars(e: &Expr, out: &mut BTreeSet<usize>) {
1470 match e {
1471 Expr::Const(_) => {}
1472 Expr::Var(i) => {
1473 out.insert(*i);
1474 }
1475 Expr::Binary(_, a, b) => {
1476 collect_vars(a, out);
1477 collect_vars(b, out);
1478 }
1479 Expr::Unary(_, a) => collect_vars(a, out),
1480 Expr::Sum(args) | Expr::MinList(args) | Expr::MaxList(args) => {
1481 for a in args {
1482 collect_vars(a, out);
1483 }
1484 }
1485 Expr::Compare(_, a, b) | Expr::And(a, b) | Expr::Or(a, b) => {
1491 collect_vars(a, out);
1492 collect_vars(b, out);
1493 }
1494 Expr::Not(a) => collect_vars(a, out),
1495 Expr::Cond { cond, then_, else_ } => {
1496 collect_vars(cond, out);
1497 collect_vars(then_, out);
1498 collect_vars(else_, out);
1499 }
1500 Expr::Cse(body) => collect_vars(body, out),
1501 Expr::Funcall { args, .. } => {
1502 for a in args {
1503 if let FuncallArg::Real(e) = a {
1504 collect_vars(e, out);
1505 }
1506 }
1507 }
1508 }
1509}
1510
#[derive(Debug, Clone)]
/// One scatter instruction used to decode a compressed Hessian color.
///
/// In `eval_h`, for the color this entry is stored under,
/// `compressed[color][row]` is added to `values[hess_idx]`.
struct ColorWrite {
    /// Entry of the color's compressed Hessian-vector product to read.
    row: u32,
    /// Position in the Hessian value array (same order as `h_irow`/`h_jcol`)
    /// to accumulate into.
    hess_idx: u32,
}
1527
#[derive(Debug, Clone)]
/// `TNLP` adapter that evaluates an [`NlProblem`] through per-summand tapes.
///
/// Hessian values are assembled by column coloring: one Hessian-vector product
/// per color is accumulated into `compressed`, and `decoding` scatters those
/// entries back into the sparse Hessian value array.
pub struct NlTnlp {
    /// The problem being solved (expressions, linear parts, bounds, starts, names).
    prob: NlProblem,
    /// One tape per top-level summand of the objective's nonlinear part.
    obj_tapes: Vec<Tape>,
    /// For each constraint, one tape per top-level summand of its nonlinear part.
    con_tapes: Vec<Vec<Tape>>,
    /// Row index of each Hessian nonzero (the `hi` element of each sorted pair).
    h_irow: Vec<i32>,
    /// Column index of each Hessian nonzero (the `lo` element of each sorted pair).
    h_jcol: Vec<i32>,
    /// Sorted Jacobian column indices of each constraint row.
    jac_cols: Vec<Vec<usize>>,
    /// Total Jacobian nonzeros (sum of the `jac_cols` lengths).
    jac_nnz: usize,
    /// Per color, a length-`n` 0/1 seed selecting that color's columns.
    seeds: Vec<Vec<f64>>,
    /// Per color, the scatter instructions that decode its compressed product.
    decoding: Vec<Vec<ColorWrite>>,
    /// Colors touched by each objective tape; only these products are formed.
    obj_tape_colors: Vec<Vec<u32>>,
    /// Colors touched by each constraint tape, indexed `[constraint][summand]`.
    con_tape_colors: Vec<Vec<Vec<u32>>>,
    /// Primal point recorded by `finalize_solution`, if it has run.
    final_x: Option<Vec<Number>>,
    /// Objective value recorded by `finalize_solution`.
    final_obj: Number,
    /// Dense per-row gradient scratch used by `eval_jac_g`.
    scratch_row_grad: Vec<f64>,
    /// Tape forward-value scratch, sized to the longest tape.
    vals_scratch: Vec<f64>,
    /// Scratch for `Tape::hessian_directional` (presumably tangents), sized to
    /// the longest tape.
    dot_scratch: Vec<f64>,
    /// Adjoint scratch for tape gradients/Hessians, sized to the longest tape.
    adj_scratch: Vec<f64>,
    /// Scratch for `Tape::hessian_directional` (presumably second-order
    /// adjoints), sized to the longest tape.
    adj_dot_scratch: Vec<f64>,
    /// Per color, the length-`n` accumulated Hessian-vector product.
    compressed: Vec<Vec<f64>>,
}
1579
// Precedence levels for the expression pretty-printer. `render_prec` wraps a
// sub-expression in parentheses when its `expr_prec` is below the level its
// parent requires, so larger numbers bind more tightly.
/// Additive forms: `+`, `-`, and n-ary sums.
const P_ADD: u8 = 10;
/// Multiplicative forms: `*` and `/`.
const P_MUL: u8 = 20;
/// Unary negation.
const P_NEG: u8 = 30;
/// Exponentiation `^`.
const P_POW: u8 = 40;
/// Forms that never need parentheses.
const P_ATOM: u8 = 100;
1601
1602fn fmt_num(x: Number) -> String {
1605 if x.is_finite() && x == x.trunc() && x.abs() < 1e15 {
1606 format!("{}", x as i64)
1607 } else {
1608 format!("{x}")
1609 }
1610}
1611
/// Display label for variable `i`: its name from `var_names` when present and
/// non-empty, otherwise the generic `x[i]`.
fn var_label(i: usize, var_names: &[String]) -> String {
    var_names
        .get(i)
        .filter(|name| !name.is_empty())
        .cloned()
        .unwrap_or_else(|| format!("x[{i}]"))
}
1620
1621fn expr_prec(e: &Expr) -> u8 {
1623 match e {
1624 Expr::Binary(BinOp::Add, ..) | Expr::Binary(BinOp::Sub, ..) | Expr::Sum(_) => P_ADD,
1625 Expr::Binary(BinOp::Mul, ..) | Expr::Binary(BinOp::Div, ..) => P_MUL,
1626 Expr::Unary(UnaryOp::Neg, _) => P_NEG,
1627 Expr::Binary(BinOp::Pow, ..) => P_POW,
1628 Expr::Cse(inner) => expr_prec(inner),
1629 _ => P_ATOM,
1631 }
1632}
1633
1634fn render_prec(e: &Expr, min_prec: u8, vn: &[String], funcs: &[ImportedFunc]) -> String {
1637 let s = render_expr(e, vn, funcs);
1638 if expr_prec(e) < min_prec {
1639 format!("({s})")
1640 } else {
1641 s
1642 }
1643}
1644
1645fn unary_name(op: UnaryOp) -> &'static str {
1646 match op {
1647 UnaryOp::Neg => "-",
1648 UnaryOp::Sqrt => "sqrt",
1649 UnaryOp::Log => "log",
1650 UnaryOp::Exp => "exp",
1651 UnaryOp::Abs => "abs",
1652 UnaryOp::Sin => "sin",
1653 UnaryOp::Cos => "cos",
1654 UnaryOp::Log10 => "log10",
1655 UnaryOp::Tan => "tan",
1656 UnaryOp::Atan => "atan",
1657 UnaryOp::Acos => "acos",
1658 UnaryOp::Sinh => "sinh",
1659 UnaryOp::Cosh => "cosh",
1660 UnaryOp::Tanh => "tanh",
1661 UnaryOp::Asin => "asin",
1662 UnaryOp::Acosh => "acosh",
1663 UnaryOp::Asinh => "asinh",
1664 UnaryOp::Atanh => "atanh",
1665 }
1666}
1667
1668fn cmp_sym(op: CmpOp) -> &'static str {
1669 match op {
1670 CmpOp::Lt => "<",
1671 CmpOp::Le => "<=",
1672 CmpOp::Eq => "==",
1673 CmpOp::Ge => ">=",
1674 CmpOp::Gt => ">",
1675 CmpOp::Ne => "!=",
1676 }
1677}
1678
/// Append `rendered` to `out` as the next term of an additive expression.
///
/// The first term is copied verbatim. Each later term is joined with ` + `,
/// except that a leading `-` on the term is folded into the separator
/// (` - term`) so sums read naturally.
fn push_additive(out: &mut String, rendered: &str, first: bool) {
    let (joiner, term) = match (first, rendered.strip_prefix('-')) {
        (true, _) => ("", rendered),
        (false, Some(rest)) => (" - ", rest),
        (false, None) => (" + ", rendered),
    };
    out.push_str(joiner);
    out.push_str(term);
}
1694
1695fn render_expr(e: &Expr, vn: &[String], funcs: &[ImportedFunc]) -> String {
1697 match e {
1698 Expr::Const(c) => fmt_num(*c),
1699 Expr::Var(i) => var_label(*i, vn),
1700 Expr::Binary(op, l, r) => match op {
1701 BinOp::Add => {
1702 let mut s = render_prec(l, P_ADD, vn, funcs);
1703 push_additive(&mut s, &render_prec(r, P_ADD, vn, funcs), false);
1704 s
1705 }
1706 BinOp::Sub => format!(
1708 "{} - {}",
1709 render_prec(l, P_ADD, vn, funcs),
1710 render_prec(r, P_ADD + 1, vn, funcs)
1711 ),
1712 BinOp::Mul => format!(
1713 "{}*{}",
1714 render_prec(l, P_MUL, vn, funcs),
1715 render_prec(r, P_MUL, vn, funcs)
1716 ),
1717 BinOp::Div => format!(
1718 "{}/{}",
1719 render_prec(l, P_MUL, vn, funcs),
1720 render_prec(r, P_MUL + 1, vn, funcs)
1721 ),
1722 BinOp::Pow => format!(
1724 "{}^{}",
1725 render_prec(l, P_POW + 1, vn, funcs),
1726 render_prec(r, P_POW, vn, funcs)
1727 ),
1728 BinOp::Atan2 => format!(
1729 "atan2({}, {})",
1730 render_expr(l, vn, funcs),
1731 render_expr(r, vn, funcs)
1732 ),
1733 },
1734 Expr::Unary(UnaryOp::Neg, a) => format!("-{}", render_prec(a, P_NEG, vn, funcs)),
1735 Expr::Unary(op, a) => format!("{}({})", unary_name(*op), render_expr(a, vn, funcs)),
1736 Expr::Sum(xs) => {
1737 if xs.is_empty() {
1738 "0".to_string()
1739 } else {
1740 let mut s = String::new();
1741 for (k, x) in xs.iter().enumerate() {
1742 push_additive(&mut s, &render_prec(x, P_ADD, vn, funcs), k == 0);
1743 }
1744 s
1745 }
1746 }
1747 Expr::Cse(inner) => render_expr(inner, vn, funcs),
1748 Expr::Funcall { id, args } => {
1749 let name = funcs
1750 .iter()
1751 .find(|f| f.id == *id)
1752 .map(|f| f.name.clone())
1753 .unwrap_or_else(|| format!("extern#{id}"));
1754 let parts: Vec<String> = args
1755 .iter()
1756 .map(|a| match a {
1757 FuncallArg::Real(x) => render_expr(x, vn, funcs),
1758 FuncallArg::Str(s) => format!("{s:?}"),
1759 })
1760 .collect();
1761 format!("{name}({})", parts.join(", "))
1762 }
1763 Expr::Compare(op, a, b) => format!(
1764 "({} {} {})",
1765 render_expr(a, vn, funcs),
1766 cmp_sym(*op),
1767 render_expr(b, vn, funcs)
1768 ),
1769 Expr::And(a, b) => format!(
1770 "({} && {})",
1771 render_expr(a, vn, funcs),
1772 render_expr(b, vn, funcs)
1773 ),
1774 Expr::Or(a, b) => format!(
1775 "({} || {})",
1776 render_expr(a, vn, funcs),
1777 render_expr(b, vn, funcs)
1778 ),
1779 Expr::Not(a) => format!("!({})", render_expr(a, vn, funcs)),
1780 Expr::Cond { cond, then_, else_ } => format!(
1781 "if({}, {}, {})",
1782 render_expr(cond, vn, funcs),
1783 render_expr(then_, vn, funcs),
1784 render_expr(else_, vn, funcs)
1785 ),
1786 Expr::MinList(xs) => format!(
1787 "min({})",
1788 xs.iter()
1789 .map(|x| render_expr(x, vn, funcs))
1790 .collect::<Vec<_>>()
1791 .join(", ")
1792 ),
1793 Expr::MaxList(xs) => format!(
1794 "max({})",
1795 xs.iter()
1796 .map(|x| render_expr(x, vn, funcs))
1797 .collect::<Vec<_>>()
1798 .join(", ")
1799 ),
1800 }
1801}
1802
1803fn render_linear(linear: &[(usize, Number)], vn: &[String]) -> String {
1806 let mut out = String::new();
1807 let mut first = true;
1812 for (var, coef) in linear {
1813 if *coef == 0.0 {
1814 continue;
1815 }
1816 let neg = *coef < 0.0;
1817 let mag = coef.abs();
1818 let term = if mag == 1.0 {
1819 var_label(*var, vn)
1820 } else {
1821 format!("{}*{}", fmt_num(mag), var_label(*var, vn))
1822 };
1823 if first {
1824 if neg {
1825 out.push('-');
1826 }
1827 out.push_str(&term);
1828 first = false;
1829 } else {
1830 out.push_str(if neg { " - " } else { " + " });
1831 out.push_str(&term);
1832 }
1833 }
1834 out
1835}
1836
1837fn render_body(linear: &[(usize, Number)], nonlinear: &Expr, prob: &NlProblem) -> String {
1839 let mut s = render_linear(linear, &prob.var_names);
1840 let nl_is_zero = matches!(nonlinear, Expr::Const(c) if *c == 0.0);
1841 if !nl_is_zero {
1842 let nl = render_prec(nonlinear, P_ADD, &prob.var_names, &prob.imported_funcs);
1843 if s.is_empty() {
1844 s = nl;
1845 } else {
1846 push_additive(&mut s, &nl, false);
1847 }
1848 }
1849 if s.is_empty() {
1850 s = "0".to_string();
1851 }
1852 s
1853}
1854
1855pub fn render_constraint_equation(prob: &NlProblem, k: usize) -> String {
1859 let body = render_body(&prob.con_linear[k], &prob.con_nonlinear[k], prob);
1860 let lo = prob.g_l[k];
1861 let hi = prob.g_u[k];
1862 const INF: Number = 1.0e19;
1863 let has_lo = lo > -INF;
1864 let has_hi = hi < INF;
1865 match (has_lo, has_hi) {
1866 (true, true) if lo == hi => format!("{body} = {}", fmt_num(lo)),
1867 (true, true) => format!("{} <= {body} <= {}", fmt_num(lo), fmt_num(hi)),
1868 (true, false) => format!("{body} >= {}", fmt_num(lo)),
1869 (false, true) => format!("{body} <= {}", fmt_num(hi)),
1870 (false, false) => format!("{body} (free)"),
1871 }
1872}
1873
1874pub fn render_all_constraint_equations(prob: &NlProblem) -> Vec<String> {
1877 (0..prob.m)
1878 .map(|k| render_constraint_equation(prob, k))
1879 .collect()
1880}
1881
1882pub fn constraint_jacobian_sparsity(prob: &NlProblem) -> (Vec<Index>, Vec<Index>) {
1896 let mut irow: Vec<Index> = Vec::new();
1897 let mut jcol: Vec<Index> = Vec::new();
1898 let mut support: BTreeSet<usize> = BTreeSet::new();
1899 for k in 0..prob.m {
1900 support.clear();
1901 for &(j, _coef) in &prob.con_linear[k] {
1902 support.insert(j);
1903 }
1904 collect_vars(&prob.con_nonlinear[k], &mut support);
1905 for &j in &support {
1906 irow.push(k as Index);
1907 jcol.push(j as Index);
1908 }
1909 }
1910 (irow, jcol)
1911}
1912
1913fn split_top_sums(expr: &Expr) -> Vec<Expr> {
1940 let mut out = Vec::new();
1941 fn push_leaf(e: &Expr, factor: f64, out: &mut Vec<Expr>) {
1942 if factor == 1.0 {
1943 out.push(e.clone());
1944 } else if factor == -1.0 {
1945 out.push(Expr::Unary(UnaryOp::Neg, Box::new(e.clone())));
1946 } else {
1947 out.push(Expr::Binary(
1948 BinOp::Mul,
1949 Box::new(Expr::Const(factor)),
1950 Box::new(e.clone()),
1951 ));
1952 }
1953 }
1954 fn go(e: &Expr, factor: f64, out: &mut Vec<Expr>) {
1955 match e {
1956 Expr::Sum(terms) => {
1957 for t in terms {
1958 go(t, factor, out);
1959 }
1960 }
1961 Expr::Binary(BinOp::Add, l, r) => {
1962 go(l, factor, out);
1963 go(r, factor, out);
1964 }
1965 Expr::Binary(BinOp::Sub, l, r) => {
1966 go(l, factor, out);
1967 go(r, -factor, out);
1968 }
1969 Expr::Unary(UnaryOp::Neg, x) => {
1970 go(x, -factor, out);
1971 }
1972 Expr::Binary(BinOp::Mul, l, r) => match (l.as_ref(), r.as_ref()) {
1975 (Expr::Const(c), _) => go(r, factor * c, out),
1976 (_, Expr::Const(c)) => go(l, factor * c, out),
1977 _ => push_leaf(e, factor, out),
1978 },
1979 Expr::Binary(BinOp::Div, l, r) => match r.as_ref() {
1980 Expr::Const(c) if *c != 0.0 => go(l, factor / c, out),
1981 _ => push_leaf(e, factor, out),
1982 },
1983 _ => push_leaf(e, factor, out),
1984 }
1985 }
1986 go(expr, 1.0, &mut out);
1987 if out.is_empty() {
1988 out.push(Expr::Const(0.0));
1989 }
1990 out
1991}
1992
/// Greedy distance-2 column coloring for direct recovery of a symmetric Hessian.
///
/// `lower_pairs` lists the structural nonzeros `(i, j)`. Each pair is treated
/// as present at both `(i, j)` and `(j, i)`. Columns are visited in index order.
/// Each column gets the smallest color not already used by another column
/// that shares a nonzero row with it, so two columns of the same color never
/// overlap in any row. Columns with no nonzeros stay uncolored (`u32::MAX`).
///
/// Returns the per-column colors and the number of colors used.
fn greedy_hessian_coloring(n: usize, lower_pairs: &[(usize, usize)]) -> (Vec<u32>, usize) {
    if n == 0 {
        return (Vec::new(), 0);
    }

    // Symmetric adjacency: rows touched by each column, columns touched by each row.
    let mut rows_of_col: Vec<Vec<u32>> = vec![Vec::new(); n];
    let mut cols_of_row: Vec<Vec<u32>> = vec![Vec::new(); n];
    for &(i, j) in lower_pairs {
        rows_of_col[j].push(i as u32);
        cols_of_row[i].push(j as u32);
        if i != j {
            rows_of_col[i].push(j as u32);
            cols_of_row[j].push(i as u32);
        }
    }

    const UNCOLORED: u32 = u32::MAX;
    let mut color = vec![UNCOLORED; n];
    // `stamp[c] == col` marks color `c` as unavailable for column `col`.
    // Stamping with the column index avoids clearing the array between columns.
    let mut stamp = vec![u32::MAX; n + 1];
    let mut used: u32 = 0;

    for (col, rows) in rows_of_col.iter().enumerate() {
        if rows.is_empty() {
            continue;
        }
        let mark = col as u32;
        for &r in rows {
            for &other in &cols_of_row[r as usize] {
                if other as usize == col {
                    continue;
                }
                let taken = color[other as usize];
                if taken != UNCOLORED {
                    stamp[taken as usize] = mark;
                }
            }
        }
        let pick = (0u32..)
            .find(|&c| !matches!(stamp.get(c as usize), Some(&s) if s == mark))
            .expect("an unbounded range always yields a free color");
        color[col] = pick;
        used = used.max(pick + 1);
    }

    (color, used as usize)
}
2064
2065impl NlTnlp {
2066 pub fn new(prob: NlProblem) -> Self {
2075 Self::try_new(prob)
2076 .unwrap_or_else(|e| panic!("failed to resolve AMPL external functions: {e}"))
2077 }
2078
2079 pub fn try_new(prob: NlProblem) -> Result<Self, String> {
2084 let mut referenced: BTreeSet<usize> = BTreeSet::new();
2090 super::nl_external::collect_funcall_ids(&prob.obj_nonlinear, &mut referenced);
2091 for c in &prob.con_nonlinear {
2092 super::nl_external::collect_funcall_ids(c, &mut referenced);
2093 }
2094 let resolver = if referenced.is_empty() {
2095 super::nl_external::ExternalResolver::default()
2096 } else {
2097 super::nl_external::ExternalResolver::build_for_problem(
2098 &prob.imported_funcs,
2099 &referenced,
2100 )?
2101 };
2102
2103 let obj_summands = split_top_sums(&prob.obj_nonlinear);
2109 let obj_tapes: Vec<Tape> = obj_summands
2110 .iter()
2111 .map(|e| Tape::build_with_externals(e, &resolver))
2112 .collect();
2113
2114 let mut con_tapes: Vec<Vec<Tape>> = Vec::with_capacity(prob.m);
2115 for k in 0..prob.m {
2116 let summands = split_top_sums(&prob.con_nonlinear[k]);
2117 con_tapes.push(
2118 summands
2119 .iter()
2120 .map(|e| Tape::build_with_externals(e, &resolver))
2121 .collect(),
2122 );
2123 }
2124
2125 let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
2128 for t in &obj_tapes {
2129 pairs.extend(t.hessian_sparsity());
2130 }
2131 for row in &con_tapes {
2132 for t in row {
2133 pairs.extend(t.hessian_sparsity());
2134 }
2135 }
2136 let mut h_irow = Vec::with_capacity(pairs.len());
2137 let mut h_jcol = Vec::with_capacity(pairs.len());
2138 let mut hess_map = HashMap::with_capacity(pairs.len());
2139 for (k, (hi, lo)) in pairs.iter().enumerate() {
2140 h_irow.push(*hi as i32);
2141 h_jcol.push(*lo as i32);
2142 hess_map.insert((*hi, *lo), k);
2143 }
2144
2145 let lower_pairs: Vec<(usize, usize)> = pairs.iter().copied().collect();
2150 let (var_color, n_colors) = greedy_hessian_coloring(prob.n, &lower_pairs);
2151
2152 let mut seeds: Vec<Vec<f64>> = vec![vec![0.0; prob.n]; n_colors];
2155 for (k, &c) in var_color.iter().enumerate() {
2156 if c != u32::MAX {
2157 seeds[c as usize][k] = 1.0;
2158 }
2159 }
2160
2161 let mut decoding: Vec<Vec<ColorWrite>> = vec![Vec::new(); n_colors];
2167 for (&(i, j), &idx) in hess_map.iter() {
2168 let c = var_color[j];
2169 debug_assert!(
2170 c != u32::MAX,
2171 "column {j} has Hessian pair {idx} but no color"
2172 );
2173 decoding[c as usize].push(ColorWrite {
2174 row: i as u32,
2175 hess_idx: idx as u32,
2176 });
2177 }
2178
2179 let tape_colors = |t: &Tape| -> Vec<u32> {
2183 let mut s: BTreeSet<u32> = BTreeSet::new();
2184 for v in t.variables() {
2185 let c = var_color[v];
2186 if c != u32::MAX {
2187 s.insert(c);
2188 }
2189 }
2190 s.into_iter().collect()
2191 };
2192 let obj_tape_colors: Vec<Vec<u32>> = obj_tapes.iter().map(tape_colors).collect();
2193 let con_tape_colors: Vec<Vec<Vec<u32>>> = con_tapes
2194 .iter()
2195 .map(|row| row.iter().map(tape_colors).collect())
2196 .collect();
2197
2198 let mut jac_cols: Vec<Vec<usize>> = Vec::with_capacity(prob.m);
2201 let mut jac_nnz = 0;
2202 for i in 0..prob.m {
2203 let mut set: BTreeSet<usize> = BTreeSet::new();
2204 for t in &con_tapes[i] {
2205 for v in t.variables() {
2206 set.insert(v);
2207 }
2208 }
2209 for (v, _) in &prob.con_linear[i] {
2210 set.insert(*v);
2211 }
2212 let cols: Vec<usize> = set.into_iter().collect();
2213 jac_nnz += cols.len();
2214 jac_cols.push(cols);
2215 }
2216
2217 let mut max_tape_n: usize = 0;
2218 for t in &obj_tapes {
2219 max_tape_n = max_tape_n.max(t.ops.len());
2220 }
2221 for row in &con_tapes {
2222 for t in row {
2223 max_tape_n = max_tape_n.max(t.ops.len());
2224 }
2225 }
2226
2227 if std::env::var("POUNCE_DBG_TAPE_STATS").is_ok() {
2228 let n_obj = obj_tapes.len();
2229 let n_con: usize = con_tapes.iter().map(|r| r.len()).sum();
2230 let total = n_obj + n_con;
2231 let mut sum_ops: usize = 0;
2232 for t in &obj_tapes {
2233 sum_ops += t.ops.len();
2234 }
2235 for row in &con_tapes {
2236 for t in row {
2237 sum_ops += t.ops.len();
2238 }
2239 }
2240 let t = total.max(1);
2241 let nnz_h = h_irow.len();
2242 let avg_decode =
2243 decoding.iter().map(|d| d.len()).sum::<usize>() as f64 / n_colors.max(1) as f64;
2244 eprintln!(
2245 "[tape stats] summands={total} (obj={n_obj} con={n_con}) \
2246 total_ops={sum_ops} avg_ops={:.1} max_ops={max_tape_n} \
2247 n_colors={n_colors} avg_decode_per_color={avg_decode:.1} nnz_h={nnz_h}",
2248 sum_ops as f64 / t as f64,
2249 );
2250 }
2251
2252 let compressed: Vec<Vec<f64>> = vec![vec![0.0; prob.n]; n_colors];
2253
2254 Ok(Self {
2255 prob,
2256 obj_tapes,
2257 con_tapes,
2258 h_irow,
2259 h_jcol,
2260 jac_cols,
2261 jac_nnz,
2262 seeds,
2263 decoding,
2264 obj_tape_colors,
2265 con_tape_colors,
2266 final_x: None,
2267 final_obj: 0.0,
2268 scratch_row_grad: Vec::new(),
2269 vals_scratch: vec![0.0; max_tape_n],
2270 dot_scratch: vec![0.0; max_tape_n],
2271 adj_scratch: vec![0.0; max_tape_n],
2272 adj_dot_scratch: vec![0.0; max_tape_n],
2273 compressed,
2274 })
2275 }
2276
2277 pub fn final_x(&self) -> Option<&[Number]> {
2278 self.final_x.as_deref()
2279 }
2280
2281 pub fn final_obj(&self) -> Number {
2282 self.final_obj
2283 }
2284
2285 pub fn problem(&self) -> &NlProblem {
2289 &self.prob
2290 }
2291
2292 pub fn variant(&self, v: &NlVariation) -> Result<Self, String> {
2304 let check = |name: &str, got: usize, want: usize| -> Result<(), String> {
2305 if got == want {
2306 Ok(())
2307 } else {
2308 Err(format!(
2309 "NlVariation.{name} has length {got}, expected {want}"
2310 ))
2311 }
2312 };
2313 let mut out = self.clone();
2314 out.final_x = None;
2315 out.final_obj = 0.0;
2316 if let Some(x0) = &v.x0 {
2317 check("x0", x0.len(), self.prob.n)?;
2318 out.prob.x0.clone_from(x0);
2319 }
2320 if let Some(x_l) = &v.x_l {
2321 check("x_l", x_l.len(), self.prob.n)?;
2322 out.prob.x_l.clone_from(x_l);
2323 }
2324 if let Some(x_u) = &v.x_u {
2325 check("x_u", x_u.len(), self.prob.n)?;
2326 out.prob.x_u.clone_from(x_u);
2327 }
2328 if let Some(g_l) = &v.g_l {
2329 check("g_l", g_l.len(), self.prob.m)?;
2330 out.prob.g_l.clone_from(g_l);
2331 }
2332 if let Some(g_u) = &v.g_u {
2333 check("g_u", g_u.len(), self.prob.m)?;
2334 out.prob.g_u.clone_from(g_u);
2335 }
2336 Ok(out)
2337 }
2338
2339 pub fn variants(&self, vs: &[NlVariation]) -> Result<Vec<Self>, String> {
2343 vs.iter().map(|v| self.variant(v)).collect()
2344 }
2345}
2346
#[derive(Debug, Clone, Default)]
/// Per-instance overrides applied by `NlTnlp::variant`.
///
/// A field left as `None` keeps the base problem's value. A `Some` vector must
/// have the base problem's length (`n` for variable vectors, `m` for
/// constraint vectors), otherwise `variant` returns an error.
pub struct NlVariation {
    /// Replacement starting point (length `n`).
    pub x0: Option<Vec<Number>>,
    /// Replacement variable lower bounds (length `n`).
    pub x_l: Option<Vec<Number>>,
    /// Replacement variable upper bounds (length `n`).
    pub x_u: Option<Vec<Number>>,
    /// Replacement constraint lower bounds (length `m`).
    pub g_l: Option<Vec<Number>>,
    /// Replacement constraint upper bounds (length `m`).
    pub g_u: Option<Vec<Number>>,
}
2361
2362impl pounce_nlp::expression_provider::ExpressionProvider for NlTnlp {
2363 fn constraint_expression(&self, i: usize) -> Option<pounce_nlp::FbbtTape> {
2368 let nonlinear = self.prob.con_nonlinear.get(i)?;
2369 let linear = self
2370 .prob
2371 .con_linear
2372 .get(i)
2373 .map(|v| v.as_slice())
2374 .unwrap_or(&[]);
2375 crate::nl_fbbt_translate::translate_constraint(nonlinear, linear)
2376 }
2377
2378 fn variable_name(&self, i: usize) -> Option<&str> {
2381 self.prob.var_names.get(i).map(String::as_str)
2382 }
2383
2384 fn constraint_name(&self, i: usize) -> Option<&str> {
2387 self.prob.con_names.get(i).map(String::as_str)
2388 }
2389}
2390
2391impl TNLP for NlTnlp {
2392 fn get_nlp_info(&mut self) -> Option<NlpInfo> {
2393 Some(NlpInfo {
2394 n: self.prob.n as Index,
2395 m: self.prob.m as Index,
2396 nnz_jac_g: self.jac_nnz as Index,
2397 nnz_h_lag: self.h_irow.len() as Index,
2398 index_style: IndexStyle::C,
2399 })
2400 }
2401
2402 fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool {
2403 b.x_l.copy_from_slice(&self.prob.x_l);
2404 b.x_u.copy_from_slice(&self.prob.x_u);
2405 if !self.prob.g_l.is_empty() {
2406 b.g_l.copy_from_slice(&self.prob.g_l);
2407 b.g_u.copy_from_slice(&self.prob.g_u);
2408 }
2409 true
2410 }
2411
2412 fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool {
2413 sp.x.copy_from_slice(&self.prob.x0);
2414 if sp.init_lambda {
2423 sp.lambda.copy_from_slice(&self.prob.lambda0);
2424 }
2425 true
2426 }
2427
2428 fn eval_f(&mut self, x: &[Number], _new_x: bool) -> Option<Number> {
2429 let mut nl: Number = 0.0;
2430 for t in &self.obj_tapes {
2431 nl += t.eval(x);
2432 }
2433 let lin: Number = self.prob.obj_linear.iter().map(|(i, c)| c * x[*i]).sum();
2434 let v = self.prob.obj_constant + nl + lin;
2435 let signed = if self.prob.minimize { v } else { -v };
2436 Some(signed)
2437 }
2438
2439 fn eval_grad_f(&mut self, x: &[Number], _new_x: bool, grad: &mut [Number]) -> bool {
2440 grad.fill(0.0);
2441 for t in &self.obj_tapes {
2445 t.gradient_seed_into(x, 1.0, grad, &mut self.vals_scratch, &mut self.adj_scratch);
2446 }
2447 for (i, c) in &self.prob.obj_linear {
2448 grad[*i] += c;
2449 }
2450 if !self.prob.minimize {
2451 for g in grad.iter_mut() {
2452 *g = -*g;
2453 }
2454 }
2455 true
2456 }
2457
2458 fn eval_g(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
2459 for i in 0..self.prob.m {
2460 let mut nl: Number = 0.0;
2461 for t in &self.con_tapes[i] {
2462 nl += t.eval(x);
2463 }
2464 let lin: Number = self.prob.con_linear[i].iter().map(|(j, c)| c * x[*j]).sum();
2465 g[i] = nl + lin;
2466 }
2467 true
2468 }
2469
2470 fn eval_jac_g(
2471 &mut self,
2472 x: Option<&[Number]>,
2473 _new_x: bool,
2474 mode: SparsityRequest<'_>,
2475 ) -> bool {
2476 match mode {
2477 SparsityRequest::Structure { irow, jcol } => {
2478 let mut k = 0;
2479 for i in 0..self.prob.m {
2480 for &j in &self.jac_cols[i] {
2481 irow[k] = i as Index;
2482 jcol[k] = j as Index;
2483 k += 1;
2484 }
2485 }
2486 true
2487 }
2488 SparsityRequest::Values { values } => {
2489 let n = self.prob.n;
2490 let xs = x.unwrap_or(&self.prob.x0);
2491 if self.scratch_row_grad.len() < n {
2492 self.scratch_row_grad.resize(n, 0.0);
2493 }
2494 let mut k = 0;
2495 for i in 0..self.prob.m {
2496 for &j in &self.jac_cols[i] {
2497 self.scratch_row_grad[j] = 0.0;
2498 }
2499 for t in &self.con_tapes[i] {
2500 t.gradient_seed_into(
2503 xs,
2504 1.0,
2505 &mut self.scratch_row_grad,
2506 &mut self.vals_scratch,
2507 &mut self.adj_scratch,
2508 );
2509 }
2510 for &(v, c) in &self.prob.con_linear[i] {
2511 self.scratch_row_grad[v] += c;
2512 }
2513 for &j in &self.jac_cols[i] {
2514 values[k] = self.scratch_row_grad[j];
2515 k += 1;
2516 }
2517 }
2518 true
2519 }
2520 }
2521 }
2522
2523 fn eval_h(
2524 &mut self,
2525 x: Option<&[Number]>,
2526 _new_x: bool,
2527 obj_factor: Number,
2528 lambda: Option<&[Number]>,
2529 _new_lambda: bool,
2530 mode: SparsityRequest<'_>,
2531 ) -> bool {
2532 match mode {
2533 SparsityRequest::Structure { irow, jcol } => {
2534 irow.copy_from_slice(&self.h_irow);
2535 jcol.copy_from_slice(&self.h_jcol);
2536 true
2537 }
2538 SparsityRequest::Values { values } => {
2539 let x = x.unwrap_or(&self.prob.x0);
2540 values.fill(0.0);
2541
2542 let obj_seed = if self.prob.minimize {
2543 obj_factor
2544 } else {
2545 -obj_factor
2546 };
2547 for buf in &mut self.compressed {
2556 buf.fill(0.0);
2557 }
2558
2559 if obj_seed != 0.0 {
2560 for (ti, t) in self.obj_tapes.iter().enumerate() {
2561 if t.ops.is_empty() {
2562 continue;
2563 }
2564 t.forward_into(x, &mut self.vals_scratch);
2565 for &c in &self.obj_tape_colors[ti] {
2566 t.hessian_directional(
2567 &self.vals_scratch,
2568 &self.seeds[c as usize],
2569 obj_seed,
2570 &mut self.compressed[c as usize],
2571 &mut self.dot_scratch,
2572 &mut self.adj_scratch,
2573 &mut self.adj_dot_scratch,
2574 );
2575 }
2576 }
2577 }
2578
2579 if let Some(lam) = lambda {
2580 for k in 0..self.prob.m {
2581 let w = lam[k];
2582 if w == 0.0 {
2583 continue;
2584 }
2585 for (ti, t) in self.con_tapes[k].iter().enumerate() {
2586 if t.ops.is_empty() {
2587 continue;
2588 }
2589 t.forward_into(x, &mut self.vals_scratch);
2590 for &c in &self.con_tape_colors[k][ti] {
2591 t.hessian_directional(
2592 &self.vals_scratch,
2593 &self.seeds[c as usize],
2594 w,
2595 &mut self.compressed[c as usize],
2596 &mut self.dot_scratch,
2597 &mut self.adj_scratch,
2598 &mut self.adj_dot_scratch,
2599 );
2600 }
2601 }
2602 }
2603 }
2604
2605 for (c, table) in self.decoding.iter().enumerate() {
2608 let comp = &self.compressed[c];
2609 for w in table {
2610 values[w.hess_idx as usize] += comp[w.row as usize];
2611 }
2612 }
2613 true
2614 }
2615 }
2616 }
2617
2618 fn finalize_solution(&mut self, sol: Solution<'_>, _d: &IpoptData, _q: &IpoptCq) {
2619 self.final_x = Some(sol.x.to_vec());
2620 self.final_obj = sol.obj_value;
2621 }
2622
2623 fn get_var_con_metadata(&mut self, var: &mut MetaData, con: &mut MetaData) -> bool {
2633 let mut any = false;
2634 if !self.prob.var_names.is_empty() {
2635 var.strings
2636 .insert(IDX_NAMES.to_string(), self.prob.var_names.clone());
2637 any = true;
2638 }
2639 if !self.prob.con_names.is_empty() {
2640 con.strings
2641 .insert(IDX_NAMES.to_string(), self.prob.con_names.clone());
2642 any = true;
2643 }
2644 any
2645 }
2646
2647 fn get_constraints_linearity(&mut self, types: &mut [Linearity]) -> bool {
2648 for (i, t) in types.iter_mut().enumerate() {
2652 *t = match &self.prob.con_nonlinear[i] {
2653 Expr::Const(c) if *c == 0.0 => Linearity::Linear,
2654 _ => Linearity::NonLinear,
2655 };
2656 }
2657 true
2658 }
2659
2660 fn get_variables_linearity(&mut self, types: &mut [Linearity]) -> bool {
2661 let mut nonlinear: BTreeSet<usize> = BTreeSet::new();
2671 collect_vars(&self.prob.obj_nonlinear, &mut nonlinear);
2672 for row in &self.prob.con_nonlinear {
2673 collect_vars(row, &mut nonlinear);
2674 }
2675 for (i, t) in types.iter_mut().enumerate() {
2676 *t = if nonlinear.contains(&i) {
2677 Linearity::NonLinear
2678 } else {
2679 Linearity::Linear
2680 };
2681 }
2682 true
2683 }
2684
2685 fn get_objective_variables_linearity(&mut self, types: &mut [Linearity]) -> bool {
2686 let mut nonlinear: BTreeSet<usize> = BTreeSet::new();
2697 collect_vars(&self.prob.obj_nonlinear, &mut nonlinear);
2698 for (i, t) in types.iter_mut().enumerate() {
2699 *t = if nonlinear.contains(&i) {
2700 Linearity::NonLinear
2701 } else {
2702 Linearity::Linear
2703 };
2704 }
2705 true
2706 }
2707}
2708
2709pub fn load_nl_as_tnlp(path: &Path) -> Result<Rc<RefCell<dyn TNLP>>, String> {
2711 let prob = read_nl_file(path)?;
2712 Ok(Rc::new(RefCell::new(NlTnlp::new(prob))))
2713}
2714
2715#[cfg(test)]
2716mod tests {
2717 use super::*;
2718
2719 #[test]
2724 fn nl_problem_and_tnlp_are_send() {
2725 fn assert_send<T: Send>() {}
2726 assert_send::<NlProblem>();
2727 assert_send::<NlTnlp>();
2728 assert_send::<Expr>();
2729 }
2730
2731 #[test]
2734 fn variant_overrides_bounds_and_x0() {
2735 let p = parse_nl_text(SIMPLE).expect("parse");
2736 let mut base = NlTnlp::new(p);
2737 let var = base
2738 .variant(&NlVariation {
2739 x0: Some(vec![3.0, 4.0]),
2740 x_l: Some(vec![-1.0, -2.0]),
2741 x_u: Some(vec![5.0, 6.0]),
2742 ..Default::default()
2743 })
2744 .expect("variant");
2745 let mut var = var;
2746 let (mut x_l, mut x_u) = ([0.0; 2], [0.0; 2]);
2747 let (mut g_l, mut g_u) = ([0.0; 0], [0.0; 0]);
2748 assert!(var.get_bounds_info(BoundsInfo {
2749 x_l: &mut x_l,
2750 x_u: &mut x_u,
2751 g_l: &mut g_l,
2752 g_u: &mut g_u,
2753 }));
2754 assert_eq!(x_l, [-1.0, -2.0]);
2755 assert_eq!(x_u, [5.0, 6.0]);
2756 let mut x = [0.0; 2];
2757 let (mut zl, mut zu, mut lam) = ([0.0; 2], [0.0; 2], [0.0; 0]);
2758 assert!(var.get_starting_point(StartingPoint {
2759 init_x: true,
2760 x: &mut x,
2761 init_z: false,
2762 z_l: &mut zl,
2763 z_u: &mut zu,
2764 init_lambda: false,
2765 lambda: &mut lam,
2766 }));
2767 assert_eq!(x, [3.0, 4.0]);
2768 assert!(base.problem().x_l[0] < -1.0e18);
2770 assert!(base
2772 .variant(&NlVariation {
2773 x0: Some(vec![1.0]),
2774 ..Default::default()
2775 })
2776 .is_err());
2777 }
2778
2779 const SIMPLE: &str = "g3 0 1 0
27972 0 1 0 0
27980 1
27990 0
28000 2 0
28010 0 0 1
28020 0 0 0 0
28030 0
28040 0
28050 0 0 0 0
2806O0 0
2807o0
2808o5
2809o1
2810v0
2811n1
2812n2
2813o5
2814o1
2815v1
2816n2
2817n2
2818b
28193
28203
2821";
2822
2823 #[test]
2824 fn parses_simple_quadratic() {
2825 let p = parse_nl_text(SIMPLE).expect("parse");
2826 assert_eq!(p.n, 2);
2827 assert_eq!(p.m, 0);
2828 assert_eq!(p.num_obj, 1);
2829 let f = eval_expr(&p.obj_nonlinear, &[0.0, 0.0]);
2831 assert!((f - 5.0).abs() < 1e-12);
2832 let f = eval_expr(&p.obj_nonlinear, &[1.0, 2.0]);
2834 assert!(f.abs() < 1e-12);
2835 }
2836
2837 #[test]
2838 fn gradient_matches_analytic() {
2839 let p = parse_nl_text(SIMPLE).expect("parse");
2840 let x = [0.5, 1.0];
2841 let mut g = [0.0_f64; 2];
2842 grad_expr(&p.obj_nonlinear, &x, 1.0, &mut g);
2843 assert!((g[0] - (-1.0)).abs() < 1e-12);
2846 assert!((g[1] - (-2.0)).abs() < 1e-12);
2847 }
2848
2849 #[test]
2860 fn variables_linearity_tags_obj_nonlinear_vs_linear_vars() {
2861 let obj_nl = Expr::Binary(
2863 BinOp::Pow,
2864 Box::new(Expr::Binary(
2865 BinOp::Sub,
2866 Box::new(Expr::Var(0)),
2867 Box::new(Expr::Const(1.0)),
2868 )),
2869 Box::new(Expr::Const(2.0)),
2870 );
2871 let prob = NlProblem {
2872 n: 2,
2873 m: 0,
2874 num_obj: 1,
2875 minimize: true,
2876 obj_nonlinear: obj_nl,
2877 obj_linear: vec![(1, 3.0)],
2878 obj_constant: 0.0,
2879 con_nonlinear: vec![],
2880 con_linear: vec![],
2881 x_l: vec![f64::NEG_INFINITY; 2],
2882 x_u: vec![f64::INFINITY; 2],
2883 g_l: vec![],
2884 g_u: vec![],
2885 x0: vec![0.0; 2],
2886 lambda0: vec![],
2887 suffixes: NlSuffixes::default(),
2888 imported_funcs: vec![],
2889 var_names: vec![],
2890 con_names: vec![],
2891 };
2892 let mut tnlp = NlTnlp::new(prob);
2893 let mut types = vec![Linearity::Linear; 2];
2894 let ok = tnlp.get_variables_linearity(&mut types);
2895 assert!(
2897 ok,
2898 "get_variables_linearity must report it filled the slice"
2899 );
2900 assert!(
2901 matches!(types[0], Linearity::NonLinear),
2902 "x0 is nonlinear in the objective"
2903 );
2904 assert!(
2905 matches!(types[1], Linearity::Linear),
2906 "x1 appears only in the linear part"
2907 );
2908 }
2909
2910 #[test]
2918 fn objective_variables_linearity_ignores_constraint_nonlinearity() {
2919 let con_nl = Expr::Binary(
2921 BinOp::Pow,
2922 Box::new(Expr::Var(0)),
2923 Box::new(Expr::Const(2.0)),
2924 );
2925 let prob = NlProblem {
2926 n: 2,
2927 m: 1,
2928 num_obj: 1,
2929 minimize: true,
2930 obj_nonlinear: Expr::Const(0.0),
2931 obj_linear: vec![(1, 3.0)],
2932 obj_constant: 0.0,
2933 con_nonlinear: vec![con_nl],
2934 con_linear: vec![vec![]],
2935 x_l: vec![f64::NEG_INFINITY; 2],
2936 x_u: vec![f64::INFINITY; 2],
2937 g_l: vec![4.0],
2938 g_u: vec![4.0],
2939 x0: vec![0.0; 2],
2940 lambda0: vec![0.0],
2941 suffixes: NlSuffixes::default(),
2942 imported_funcs: vec![],
2943 var_names: vec![],
2944 con_names: vec![],
2945 };
2946 let mut tnlp = NlTnlp::new(prob);
2947
2948 let mut global = vec![Linearity::Linear; 2];
2949 assert!(tnlp.get_variables_linearity(&mut global));
2950 assert!(
2951 matches!(global[0], Linearity::NonLinear),
2952 "global tags see x0's constraint nonlinearity"
2953 );
2954
2955 let mut obj = vec![Linearity::NonLinear; 2];
2956 assert!(tnlp.get_objective_variables_linearity(&mut obj));
2957 assert!(
2958 matches!(obj[0], Linearity::Linear),
2959 "x0 is linear w.r.t. the objective despite the nonlinear constraint"
2960 );
2961 assert!(
2962 matches!(obj[1], Linearity::Linear),
2963 "x1 is linear everywhere"
2964 );
2965 }
2966
2967 const EQ_LIN: &str = "g3 0 1 0
29862 1 1 0 0
29870 1
29880 0
29890 2 0
29900 0 0 1
29910 0 0 0 0
29922 0
29930 0
29940 0 0 0 0
2995C0
2996n0
2997O0 0
2998o0
2999o5
3000v0
3001n2
3002o5
3003v1
3004n2
3005r
30064 1
3007b
30083
30093
3010k1
30112
3012J0 2
30130 1
30141 1
3015";
3016
3017 #[test]
3018 fn parses_constrained_problem() {
3019 let p = parse_nl_text(EQ_LIN).expect("parse");
3020 assert_eq!(p.n, 2);
3021 assert_eq!(p.m, 1);
3022 assert!((p.g_l[0] - 1.0).abs() < 1e-12);
3024 assert!((p.g_u[0] - 1.0).abs() < 1e-12);
3025 assert_eq!(p.con_linear[0], vec![(0, 1.0), (1, 1.0)]);
3027 }
3028
3029 #[test]
3030 fn malformed_j_variable_index_is_parse_error_not_panic() {
3031 let bad = EQ_LIN.replace("J0 2\n0 1\n1 1\n", "J0 2\n0 1\n5 1\n");
3037 assert_ne!(bad, EQ_LIN, "fixture substitution must apply");
3038 let err = parse_nl_text(&bad).expect_err("out-of-range J var must error");
3039 assert!(err.contains("out of range"), "unexpected error: {err}");
3040 }
3041
3042 #[test]
3043 fn out_of_range_x_segment_index_is_parse_error() {
3044 let bad = format!("{EQ_LIN}x1\n5 0.5\n");
3048 let err = parse_nl_text(&bad).expect_err("out-of-range x index must error");
3049 assert!(err.contains("out of range"), "unexpected error: {err}");
3050 }
3051
3052 #[test]
3053 fn k_segment_nonstandard_count_is_parse_error_at_source() {
3054 let bad = EQ_LIN.replace("k1\n2\n", "k0\n");
3064 assert_ne!(bad, EQ_LIN, "fixture substitution must apply");
3065 let err = parse_nl_text(&bad).expect_err("nonstandard k count must error");
3066 assert!(
3067 err.contains("k-segment declares"),
3068 "expected a clear k-segment count error, got: {err}"
3069 );
3070 }
3071
    #[test]
    /// Initial constraint duals read from the `.nl` `d` segment must be handed
    /// out by `get_starting_point` when `init_lambda` is set, and the
    /// multiplier buffer must be left alone when it is not.
    fn get_starting_point_returns_nl_initial_duals() {
        // Append a `d` segment giving constraint 0 an initial dual of 2.5.
        let nl = format!("{EQ_LIN}\nd1\n0 2.5\n");
        let p = parse_nl_text(&nl).expect("parse");
        assert_eq!(p.lambda0, vec![2.5], "the `d` segment fills lambda0");

        let mut t = NlTnlp::new(p);
        // Size every buffer from the TNLP's own reported dimensions.
        let info = t.get_nlp_info().unwrap();
        let (n, m) = (info.n as usize, info.m as usize);

        let mut x = vec![0.0; n];
        let mut z_l = vec![0.0; n];
        let mut z_u = vec![0.0; n];
        let mut lambda = vec![0.0; m];
        // Case 1: multipliers requested -> the `.nl` duals are copied out.
        assert!(t.get_starting_point(StartingPoint {
            init_x: true,
            x: &mut x,
            init_z: false,
            z_l: &mut z_l,
            z_u: &mut z_u,
            init_lambda: true,
            lambda: &mut lambda,
        }));
        assert_eq!(
            lambda,
            vec![2.5],
            "a warm start must use the `.nl` initial duals, not zero"
        );

        // Case 2: multipliers not requested -> the 7.0 sentinel must survive,
        // showing the buffer was not written.
        let mut lambda_untouched = vec![7.0; m];
        assert!(t.get_starting_point(StartingPoint {
            init_x: true,
            x: &mut x,
            init_z: false,
            z_l: &mut z_l,
            z_u: &mut z_u,
            init_lambda: false,
            lambda: &mut lambda_untouched,
        }));
        assert_eq!(
            lambda_untouched,
            vec![7.0],
            "without init_lambda the multiplier buffer must be untouched"
        );
    }
3128
    #[test]
    /// End-to-end TNLP evaluation on EQ_LIN (objective x0^2 + x1^2,
    /// constraint x0 + x1 = 1): constraint values, Jacobian structure/values,
    /// and Lagrangian Hessian structure/values.
    fn constrained_tnlp_eval_g_jac_h() {
        let p = parse_nl_text(EQ_LIN).expect("parse");
        let mut t = NlTnlp::new(p);
        let info = t.get_nlp_info().unwrap();
        assert_eq!(info.m, 1);
        assert_eq!(info.nnz_jac_g, 2);

        // g(x) = x0 + x1 = 0.3 + 0.4.
        let mut g = [0.0_f64; 1];
        assert!(t.eval_g(&[0.3, 0.4], true, &mut g));
        assert!((g[0] - 0.7).abs() < 1e-12);

        // Jacobian structure: row 0 touches columns 0 and 1.
        let mut irow = [0_i32; 2];
        let mut jcol = [0_i32; 2];
        assert!(t.eval_jac_g(
            None,
            true,
            SparsityRequest::Structure {
                irow: &mut irow,
                jcol: &mut jcol
            }
        ));
        assert_eq!(irow, [0, 0]);
        assert_eq!(jcol, [0, 1]);

        // Jacobian values of the linear row are the coefficients (1, 1).
        let mut vals = [0.0_f64; 2];
        assert!(t.eval_jac_g(
            Some(&[0.3, 0.4]),
            true,
            SparsityRequest::Values { values: &mut vals }
        ));
        assert!((vals[0] - 1.0).abs() < 1e-12);
        assert!((vals[1] - 1.0).abs() < 1e-12);

        // Hessian structure: only the diagonal entries (0,0) and (1,1).
        assert_eq!(info.nnz_h_lag, 2);
        let mut hirow = [0_i32; 2];
        let mut hjcol = [0_i32; 2];
        assert!(t.eval_h(
            None,
            true,
            1.0,
            None,
            true,
            SparsityRequest::Structure {
                irow: &mut hirow,
                jcol: &mut hjcol
            }
        ));
        assert_eq!(hirow, [0, 1]);
        assert_eq!(hjcol, [0, 1]);
        // Hessian values with obj_factor 1: d2/dxi2 of x0^2 + x1^2 is 2. The
        // constraint has no nonlinear part (`C0 n0`), so the 0.5 multiplier
        // is expected not to change the values.
        let mut hvals = [0.0_f64; 2];
        assert!(t.eval_h(
            Some(&[0.3, 0.4]),
            true,
            1.0,
            Some(&[0.5]),
            true,
            SparsityRequest::Values { values: &mut hvals }
        ));
        assert!((hvals[0] - 2.0).abs() < 1e-12);
        assert!((hvals[1] - 2.0).abs() < 1e-12);
    }
3198
    /// Two-variable fixture exercising a `V` (defined-variable / CSE) segment.
    ///
    /// `V2 0 0` defines v2 = o0(v0, v1) = x0 + x1, and the objective is
    /// o0(o5(v2, n2), v2) = v2^2 + v2. At (1, 2): v2 = 3, f = 12, and both
    /// partial derivatives equal 2*3 + 1 = 7.
    const CSE_OBJ: &str = "g3 0 1 0
2 0 1 0 0
0 1
0 0
0 2 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 1 0 0 0
V2 0 0
o0
v0
v1
O0 0
o0
o5
v2
n2
v2
b
3
3
";
3226
3227 #[test]
3228 fn parses_v_segment_cse() {
3229 let p = parse_nl_text(CSE_OBJ).expect("parse");
3230 assert_eq!(p.n, 2);
3231 let f = eval_expr(&p.obj_nonlinear, &[1.0, 2.0]);
3233 assert!((f - 12.0).abs() < 1e-12, "got {f}");
3234 let mut g = [0.0_f64; 2];
3236 grad_expr(&p.obj_nonlinear, &[1.0, 2.0], 1.0, &mut g);
3237 assert!((g[0] - 7.0).abs() < 1e-12, "g[0]={}", g[0]);
3238 assert!((g[1] - 7.0).abs() < 1e-12, "g[1]={}", g[1]);
3239 let mut vs = BTreeSet::new();
3241 collect_vars(&p.obj_nonlinear, &mut vs);
3242 assert_eq!(vs.into_iter().collect::<Vec<_>>(), vec![0, 1]);
3243 }
3244
    /// One-variable fixture with variable suffixes.
    ///
    /// The objective is `o5(o1(v0, n1), n2)`, i.e. (x0 - 1)^2. It is followed
    /// by an integer variable suffix `S0 1 sens_state_1` (x0 -> 7) and a real
    /// variable suffix `S4 1 sens_state_value_1` (x0 -> 4.5), which the parser
    /// is expected to file under `var_int` and `var_real` respectively.
    const WITH_SUFFIXES: &str = "g3 0 1 0
1 0 1 0 0
0 1
0 0
0 1 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 0 0 0 0
O0 0
o5
o1
v0
n1
n2
b
3
S0 1 sens_state_1
0 7
S4 1 sens_state_value_1
0 4.5
";
3273
3274 #[test]
3275 fn parses_var_int_and_var_real_suffixes() {
3276 let p = parse_nl_text(WITH_SUFFIXES).expect("parse");
3277 let v = p.suffixes.var_int.get("sens_state_1").expect("var_int");
3279 assert_eq!(v.as_slice(), &[7]);
3280 let r = p
3282 .suffixes
3283 .var_real
3284 .get("sens_state_value_1")
3285 .expect("var_real");
3286 assert_eq!(r.len(), 1);
3287 assert!((r[0] - 4.5).abs() < 1e-12);
3288 assert!(p.suffixes.con_int.is_empty());
3290 assert!(p.suffixes.con_real.is_empty());
3291 }
3292
    /// Two-variable, two-constraint fixture with an integer constraint suffix.
    ///
    /// - Both constraints have constant-zero nonlinear parts (`C0 n0`, `C1 n0`)
    ///   and `r` rows `4 0.0`.
    /// - Their linear parts are x0 + x1 (`J0`) and x0 - x1 (`J1`).
    /// - The objective is the constant `n0`.
    /// - `S1 2 sens_init_constr` assigns constraint rows 0 and 1 the values
    ///   1 and 2.
    ///
    /// Note: the text `1 2\n` occurs both in the suffix entry and inside the
    /// `J1 2` header, so a blanket `replace("1 2\n", ..)` rewrites both.
    const WITH_CON_SUFFIX: &str = "g3 0 1 0
2 2 1 0 0
0 0
0 0
0 2 0
0 0 0 1
0 0 0 0 0
2 0
0 0
0 0 0 0 0 0
C0
n0
C1
n0
O0 0
n0
r
4 0.0
4 0.0
b
3
3
k1
0
J0 2
0 1
1 1
J1 2
0 1
1 -1
S1 2 sens_init_constr
0 1
1 2
";
3329
3330 #[test]
3331 fn parses_con_int_suffix() {
3332 let p = parse_nl_text(WITH_CON_SUFFIX).expect("parse");
3333 let s = p.suffixes.con_int.get("sens_init_constr").expect("con_int");
3334 assert_eq!(s.as_slice(), &[1, 2]);
3336 }
3337
3338 #[test]
3339 fn rejects_suffix_with_out_of_range_index() {
3340 let bad = WITH_CON_SUFFIX.replace("1 2\n", "5 2\n"); let err = parse_nl_text(&bad).expect_err("must reject");
3342 assert!(
3343 err.contains("out of range"),
3344 "expected out-of-range error, got: {err}"
3345 );
3346 }
3347
3348 #[test]
3349 fn tnlp_round_trip_solves() {
3350 let p = parse_nl_text(SIMPLE).expect("parse");
3351 let mut tnlp = NlTnlp::new(p);
3352 let info = tnlp.get_nlp_info().unwrap();
3353 assert_eq!(info.n, 2);
3354 assert_eq!(info.m, 0);
3355 let f0 = tnlp.eval_f(&[0.0, 0.0], true).unwrap();
3356 assert!((f0 - 5.0).abs() < 1e-12);
3357 let mut g = [0.0_f64; 2];
3358 tnlp.eval_grad_f(&[0.0, 0.0], true, &mut g);
3359 assert!((g[0] - (-2.0)).abs() < 1e-12);
3361 assert!((g[1] - (-4.0)).abs() < 1e-12);
3362 }
3363
3364 use pounce_nlp::expression_provider::ExpressionProvider;
3371 use std::sync::atomic::{AtomicUsize, Ordering};
3372
3373 fn scratch_dir(tag: &str) -> std::path::PathBuf {
3375 static N: AtomicUsize = AtomicUsize::new(0);
3376 let seq = N.fetch_add(1, Ordering::Relaxed);
3377 let dir = std::env::temp_dir().join(format!(
3378 "pounce_nlnames_{}_{}_{}",
3379 std::process::id(),
3380 tag,
3381 seq
3382 ));
3383 std::fs::create_dir_all(&dir).expect("create scratch dir");
3384 dir
3385 }
3386
3387 #[test]
3388 fn read_name_file_reads_in_order() {
3389 let dir = scratch_dir("col_order");
3390 let p = dir.join("m.col");
3391 std::fs::write(&p, "x_in\nT_reactor\nflow\n").unwrap();
3392 assert_eq!(read_name_file(&p, 3), vec!["x_in", "T_reactor", "flow"]);
3393 }
3394
3395 #[test]
3396 fn read_name_file_truncates_extra_lines() {
3397 let dir = scratch_dir("row_obj");
3401 let p = dir.join("m.row");
3402 std::fs::write(&p, "mass_balance\nenergy_balance\nobj\n").unwrap();
3403 assert_eq!(
3404 read_name_file(&p, 2),
3405 vec!["mass_balance", "energy_balance"]
3406 );
3407 }
3408
3409 #[test]
3410 fn read_name_file_empty_on_short_or_missing() {
3411 let dir = scratch_dir("short");
3412 let short = dir.join("m.col");
3413 std::fs::write(&short, "only_one\n").unwrap();
3414 assert!(read_name_file(&short, 3).is_empty());
3416 assert!(read_name_file(&dir.join("absent.col"), 2).is_empty());
3418 }
3419
3420 #[test]
3421 fn read_nl_file_captures_sibling_names() {
3422 let dir = scratch_dir("sibling");
3425 let nl = dir.join("m.nl");
3426 std::fs::write(&nl, SIMPLE).unwrap();
3427 std::fs::write(dir.join("m.col"), "alpha\nbeta\n").unwrap();
3428
3429 let prob = read_nl_file(&nl).expect("parse + name capture");
3430 assert_eq!(prob.var_names, vec!["alpha", "beta"]);
3431 assert!(prob.con_names.is_empty()); let tnlp = NlTnlp::new(prob);
3434 assert_eq!(tnlp.variable_name(0), Some("alpha"));
3435 assert_eq!(tnlp.variable_name(1), Some("beta"));
3436 assert_eq!(tnlp.variable_name(2), None); }
3438
3439 #[test]
3440 fn read_nl_file_without_names_yields_empty() {
3441 let dir = scratch_dir("noname");
3442 let nl = dir.join("m.nl");
3443 std::fs::write(&nl, SIMPLE).unwrap();
3444 let prob = read_nl_file(&nl).expect("parse");
3445 assert!(prob.var_names.is_empty());
3446 assert!(prob.con_names.is_empty());
3447 let tnlp = NlTnlp::new(prob);
3448 assert_eq!(tnlp.variable_name(0), None);
3449 }
3450
3451 #[test]
3452 fn read_nl_file_resolves_extensionless_ampl_stub() {
3453 let dir = scratch_dir("stub");
3457 std::fs::write(dir.join("mystub.nl"), SIMPLE).unwrap();
3458 let stub = dir.join("mystub");
3460 assert!(!stub.exists(), "stub must be extensionless / absent");
3461 let prob = read_nl_file(&stub).expect("stub should resolve to mystub.nl");
3462 assert_eq!(prob.n, 2);
3463 assert_eq!(prob.m, 0);
3464
3465 std::fs::write(dir.join("mystub.col"), "alpha\nbeta\n").unwrap();
3467 let prob = read_nl_file(&stub).expect("stub resolves, names ride along");
3468 assert_eq!(prob.var_names, vec!["alpha", "beta"]);
3469 }
3470
3471 #[test]
3472 fn read_nl_file_prefers_exact_path_over_nl_sibling() {
3473 let dir = scratch_dir("exact");
3477 std::fs::write(dir.join("data"), SIMPLE).unwrap();
3479 std::fs::write(dir.join("data.nl"), "not an nl file").unwrap();
3480 let prob = read_nl_file(&dir.join("data")).expect("exact path wins");
3481 assert_eq!(prob.n, 2);
3482 }
3483
3484 #[test]
3485 fn append_extension_appends_rather_than_replaces() {
3486 use std::path::Path;
3487 assert_eq!(
3488 append_extension(Path::new("mystub"), "nl"),
3489 Path::new("mystub.nl")
3490 );
3491 assert_eq!(
3494 append_extension(Path::new("my.model"), "nl"),
3495 Path::new("my.model.nl")
3496 );
3497 }
3498
3499 fn names(v: &[&str]) -> Vec<String> {
3502 v.iter().map(|s| s.to_string()).collect()
3503 }
3504
3505 #[test]
3506 fn render_uses_variable_names_when_present() {
3507 let e = Expr::Binary(BinOp::Mul, Box::new(Expr::Var(0)), Box::new(Expr::Var(1)));
3508 assert_eq!(render_expr(&e, &names(&["T", "flow"]), &[]), "T*flow");
3509 assert_eq!(render_expr(&e, &[], &[]), "x[0]*x[1]");
3511 }
3512
3513 #[test]
3514 fn render_parenthesizes_by_precedence() {
3515 let sum = Expr::Binary(BinOp::Add, Box::new(Expr::Var(0)), Box::new(Expr::Var(1)));
3517 let e = Expr::Binary(BinOp::Mul, Box::new(sum), Box::new(Expr::Var(2)));
3518 assert_eq!(render_expr(&e, &[], &[]), "(x[0] + x[1])*x[2]");
3519
3520 let mul = Expr::Binary(BinOp::Mul, Box::new(Expr::Var(1)), Box::new(Expr::Var(2)));
3522 let e2 = Expr::Binary(BinOp::Add, Box::new(Expr::Var(0)), Box::new(mul));
3523 assert_eq!(render_expr(&e2, &[], &[]), "x[0] + x[1]*x[2]");
3524 }
3525
3526 #[test]
3527 fn render_subtraction_right_assoc_parens() {
3528 let inner = Expr::Binary(BinOp::Sub, Box::new(Expr::Var(1)), Box::new(Expr::Var(2)));
3530 let e = Expr::Binary(BinOp::Sub, Box::new(Expr::Var(0)), Box::new(inner));
3531 assert_eq!(render_expr(&e, &[], &[]), "x[0] - (x[1] - x[2])");
3532 }
3533
3534 #[test]
3535 fn render_functions_and_pow() {
3536 let sq = Expr::Binary(
3537 BinOp::Pow,
3538 Box::new(Expr::Var(0)),
3539 Box::new(Expr::Const(2.0)),
3540 );
3541 let e = Expr::Unary(UnaryOp::Exp, Box::new(sq));
3542 assert_eq!(render_expr(&e, &names(&["q"]), &[]), "exp(q^2)");
3543 }
3544
3545 #[test]
3546 fn render_linear_signs_are_tidy() {
3547 let lin = vec![(0usize, 1.0), (1, -2.0), (2, 1.0)];
3549 assert_eq!(render_linear(&lin, &names(&["a", "b", "c"])), "a - 2*b + c");
3550 }
3551
3552 #[test]
3553 fn render_linear_skips_zero_coefficients() {
3554 let lin = vec![(0usize, 1.0), (1, 0.0), (2, -3.0)];
3557 assert_eq!(render_linear(&lin, &names(&["a", "b", "c"])), "a - 3*c");
3558 let lin = vec![(0usize, 0.0), (1, 2.0)];
3560 assert_eq!(render_linear(&lin, &names(&["a", "b"])), "2*b");
3561 }
3562
3563 #[test]
3564 fn render_sum_folds_negative_terms() {
3565 let sq = |i| {
3567 Expr::Binary(
3568 BinOp::Pow,
3569 Box::new(Expr::Var(i)),
3570 Box::new(Expr::Const(2.0)),
3571 )
3572 };
3573 let neg = |i| {
3574 Expr::Binary(
3575 BinOp::Mul,
3576 Box::new(Expr::Const(-1.0)),
3577 Box::new(Expr::Var(i)),
3578 )
3579 };
3580 let e = Expr::Sum(vec![
3581 sq(0),
3582 neg(1),
3583 Expr::Unary(UnaryOp::Neg, Box::new(Expr::Var(2))),
3584 ]);
3585 assert_eq!(
3586 render_expr(&e, &names(&["a", "b", "c"]), &[]),
3587 "a^2 - 1*b - c"
3588 );
3589 }
3590
3591 #[test]
3592 fn render_constraint_equation_forms() {
3593 let mut prob = parse_nl_text(SIMPLE).unwrap();
3595 prob.n = 2;
3597 prob.m = 2;
3598 prob.var_names = names(&["mass_in", "mass_out"]);
3599 prob.con_names = names(&["balance", "window"]);
3600 prob.con_linear = vec![
3601 vec![(0, 1.0), (1, -1.0)], vec![(0, 1.0)], ];
3604 prob.con_nonlinear = vec![Expr::Const(0.0), Expr::Const(0.0)];
3605 prob.g_l = vec![0.0, 0.0];
3606 prob.g_u = vec![0.0, 500.0];
3607
3608 assert_eq!(
3609 render_constraint_equation(&prob, 0),
3610 "mass_in - mass_out = 0"
3611 );
3612 assert_eq!(render_constraint_equation(&prob, 1), "0 <= mass_in <= 500");
3613
3614 let all = render_all_constraint_equations(&prob);
3615 assert_eq!(all.len(), 2);
3616 assert_eq!(all[1], "0 <= mass_in <= 500");
3617 }
3618
3619 #[test]
3620 fn constraint_jacobian_sparsity_unions_linear_and_nonlinear() {
3621 let mut prob = parse_nl_text(SIMPLE).unwrap();
3622 prob.n = 3;
3623 prob.m = 2;
3624 prob.con_linear = vec![vec![(1, 4.0)], vec![(2, 1.0)]];
3627 prob.con_nonlinear = vec![
3628 Expr::Binary(BinOp::Mul, Box::new(Expr::Var(0)), Box::new(Expr::Var(2))),
3629 Expr::Const(0.0),
3630 ];
3631 prob.g_l = vec![0.0, 0.0];
3632 prob.g_u = vec![0.0, 0.0];
3633
3634 let (irow, jcol) = constraint_jacobian_sparsity(&prob);
3635 assert_eq!(irow, vec![0, 0, 0, 1]);
3637 assert_eq!(jcol, vec![0, 1, 2, 2]);
3638 }
3639
3640 #[test]
3641 fn funcall_string_arg_with_hash_is_not_truncated() {
3642 let mut p = Parser::new("h3:a#b\n");
3648 match p.parse_funcall_arg().expect("parse hollerith arg") {
3649 FuncallArg::Str(s) => assert_eq!(s, "a#b"),
3650 other => panic!("expected Str, got {other:?}"),
3651 }
3652 }
3653
3654 #[test]
3655 fn funcall_string_arg_honors_declared_length() {
3656 let mut p = Parser::new("h3:abc # trailing comment\n");
3660 match p.parse_funcall_arg().expect("parse hollerith arg") {
3661 FuncallArg::Str(s) => assert_eq!(s, "abc"),
3662 other => panic!("expected Str, got {other:?}"),
3663 }
3664 }
3665
3666 fn parse_one_expr(n: usize, expr_src: &str) -> Expr {
3678 let mut p = Parser::new(expr_src);
3679 p.n = n;
3680 p.parse_expr().expect("parse expression")
3681 }
3682
3683 #[test]
3684 fn opcode_o82_square_is_unary_pow_of_two() {
3685 let e = parse_one_expr(1, "o82\nv0\n");
3687 match &e {
3688 Expr::Binary(BinOp::Pow, base, exp) => {
3689 assert!(matches!(**base, Expr::Var(0)));
3690 match **exp {
3691 Expr::Const(c) => assert!((c - 2.0).abs() < 1e-12, "exp const = {c}"),
3692 ref other => panic!("o82 exponent must be Const(2.0), got {other:?}"),
3693 }
3694 }
3695 other => panic!("o82 must parse to Pow(base, 2), got {other:?}"),
3696 }
3697 assert!((eval_expr(&e, &[3.0]) - 9.0).abs() < 1e-12);
3700 assert!((eval_expr(&e, &[-3.0]) - 9.0).abs() < 1e-12);
3701 let mut g = [0.0_f64; 1];
3703 grad_expr(&e, &[3.0], 1.0, &mut g);
3704 assert!((g[0] - 6.0).abs() < 1e-9, "grad at 3 = {}", g[0]);
3705 g[0] = 0.0;
3706 grad_expr(&e, &[-3.0], 1.0, &mut g);
3707 assert!((g[0] + 6.0).abs() < 1e-9, "grad at -3 = {}", g[0]);
3708 }
3709
3710 #[test]
3711 fn opcode_o81_const_exponent_is_base_pow_const() {
3712 let e = parse_one_expr(1, "o81\nv0\nn3\n");
3714 match &e {
3715 Expr::Binary(BinOp::Pow, base, exp) => {
3716 assert!(matches!(**base, Expr::Var(0)), "base must be the variable");
3717 match **exp {
3718 Expr::Const(c) => assert!((c - 3.0).abs() < 1e-12, "exp const = {c}"),
3719 ref other => panic!("o81 exponent must be Const(3.0), got {other:?}"),
3720 }
3721 }
3722 other => panic!("o81 must parse to Pow(var, const), got {other:?}"),
3723 }
3724 assert!((eval_expr(&e, &[2.0]) - 8.0).abs() < 1e-12);
3726 assert!((eval_expr(&e, &[-2.0]) + 8.0).abs() < 1e-12);
3729 let mut g = [0.0_f64; 1];
3731 grad_expr(&e, &[2.0], 1.0, &mut g);
3732 assert!((g[0] - 12.0).abs() < 1e-9, "grad at 2 = {}", g[0]);
3733 }
3734
3735 #[test]
3736 fn opcode_o83_const_base_is_const_pow_exp() {
3737 let e = parse_one_expr(1, "o83\nn2\nv0\n");
3739 match &e {
3740 Expr::Binary(BinOp::Pow, base, exp) => {
3741 match **base {
3742 Expr::Const(c) => assert!((c - 2.0).abs() < 1e-12, "base const = {c}"),
3743 ref other => panic!("o83 base must be Const(2.0), got {other:?}"),
3744 }
3745 assert!(
3746 matches!(**exp, Expr::Var(0)),
3747 "exponent must be the variable"
3748 );
3749 }
3750 other => panic!("o83 must parse to Pow(const, var), got {other:?}"),
3751 }
3752 assert!((eval_expr(&e, &[3.0]) - 8.0).abs() < 1e-12);
3754 assert!((eval_expr(&e, &[0.0]) - 1.0).abs() < 1e-12);
3755 let mut g = [0.0_f64; 1];
3757 grad_expr(&e, &[3.0], 1.0, &mut g);
3758 assert!(
3759 (g[0] - 8.0 * 2.0_f64.ln()).abs() < 1e-9,
3760 "grad at 3 = {} (want {})",
3761 g[0],
3762 8.0 * 2.0_f64.ln()
3763 );
3764 }
3765
3766 #[test]
3767 fn power_specializations_agree_with_general_o5() {
3768 let o5_sq = parse_one_expr(1, "o5\nv0\nn2\n"); let o82 = parse_one_expr(1, "o82\nv0\n");
3773 let o5_cube = parse_one_expr(1, "o5\nv0\nn3\n"); let o81 = parse_one_expr(1, "o81\nv0\nn3\n");
3775 let o5_exp = parse_one_expr(1, "o5\nn2\nv0\n"); let o83 = parse_one_expr(1, "o83\nn2\nv0\n");
3777 for &x in &[-2.0_f64, -0.5, 0.0, 1.0, 2.5, 4.0] {
3778 assert!((eval_expr(&o82, &[x]) - eval_expr(&o5_sq, &[x])).abs() < 1e-12);
3779 assert!((eval_expr(&o81, &[x]) - eval_expr(&o5_cube, &[x])).abs() < 1e-12);
3780 assert!((eval_expr(&o83, &[x]) - eval_expr(&o5_exp, &[x])).abs() < 1e-12);
3782 }
3783 }
3784
3785 #[test]
3786 fn power_opcodes_round_trip_through_parse_nl_text() {
3787 let nl = SIMPLE.replace(
3791 "o0\no5\no1\nv0\nn1\nn2\no5\no1\nv1\nn2\nn2\n",
3792 "o0\no82\nv0\no82\nv1\n",
3793 );
3794 assert_ne!(nl, SIMPLE, "fixture substitution must apply");
3795 let p = parse_nl_text(&nl).expect("parse o82 objective");
3796 assert!((eval_expr(&p.obj_nonlinear, &[3.0, 4.0]) - 25.0).abs() < 1e-12);
3798 assert!((eval_expr(&p.obj_nonlinear, &[-3.0, -4.0]) - 25.0).abs() < 1e-12);
3799 }
3800
3801 #[test]
3802 fn power_opcode_o81_evaluates_through_the_tape_at_negative_base() {
3803 let nl = SIMPLE.replace(
3809 "o0\no5\no1\nv0\nn1\nn2\no5\no1\nv1\nn2\nn2\n",
3810 "o0\no81\nv0\nn3\no81\nv1\nn3\n",
3811 );
3812 assert_ne!(nl, SIMPLE, "fixture substitution must apply");
3813 let p = parse_nl_text(&nl).expect("parse o81 objective");
3814 let mut tnlp = NlTnlp::new(p);
3815 tnlp.get_nlp_info().unwrap();
3816 let f = tnlp.eval_f(&[-2.0, 1.0], true).unwrap();
3818 assert!((f + 7.0).abs() < 1e-12, "f(-2,1) = {f}");
3819 let mut g = [0.0_f64; 2];
3821 assert!(tnlp.eval_grad_f(&[-2.0, 1.0], true, &mut g));
3822 assert!((g[0] - 12.0).abs() < 1e-9, "df/dx0 = {}", g[0]);
3823 assert!((g[1] - 3.0).abs() < 1e-9, "df/dx1 = {}", g[1]);
3824 }
3825}