1use crate::nl_tape::Tape;
36use pounce_common::types::{Index, Number};
37use pounce_nlp::tnlp::{
38 BoundsInfo, IndexStyle, IpoptCq, IpoptData, Linearity, MetaData, NlpInfo, Solution,
39 SparsityRequest, StartingPoint, IDX_NAMES, TNLP,
40};
41use std::cell::RefCell;
42use std::collections::{BTreeMap, BTreeSet, HashMap};
43use std::path::Path;
44use std::rc::Rc;
45use std::sync::Arc;
46
47#[derive(Debug, Clone)]
48pub enum Expr {
49 Const(Number),
51 Var(usize),
53 Binary(BinOp, Box<Expr>, Box<Expr>),
55 Unary(UnaryOp, Box<Expr>),
57 Sum(Vec<Expr>),
60 Cse(Arc<Expr>),
71 Funcall { id: usize, args: Vec<FuncallArg> },
75 Compare(CmpOp, Box<Expr>, Box<Expr>),
81 And(Box<Expr>, Box<Expr>),
84 Or(Box<Expr>, Box<Expr>),
87 Not(Box<Expr>),
90 Cond {
96 cond: Box<Expr>,
97 then_: Box<Expr>,
98 else_: Box<Expr>,
99 },
100 MinList(Vec<Expr>),
106 MaxList(Vec<Expr>),
109}
110
/// Relational operator carried by [`Expr::Compare`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CmpOp {
    /// `<` (opcode `o22`).
    Lt,
    /// `<=` (opcode `o23`).
    Le,
    /// `==` (opcode `o24`).
    Eq,
    /// `>=` (opcode `o28`).
    Ge,
    /// `>` (opcode `o29`).
    Gt,
    /// `!=` (opcode `o30`).
    Ne,
}
123
124#[derive(Debug, Clone)]
128pub enum FuncallArg {
129 Real(Expr),
130 Str(String),
131}
132
/// Declaration of an imported function, taken from an `F` segment header of
/// the form `F<id> <kind> <nargs> <name>`.
#[derive(Debug, Clone)]
pub struct ImportedFunc {
    /// Function id, referenced by `f<id>` call tokens.
    pub id: usize,
    /// Second header field; 0 when missing or unparsable.
    pub kind: usize,
    /// Declared argument count (third header field; 0 when missing). Signed,
    /// presumably so a negative "variable arity" value can be stored — confirm.
    pub nargs: i64,
    /// Function name (fourth header field); empty when missing.
    pub name: String,
}
144
/// Two-operand operator carried by [`Expr::Binary`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BinOp {
    /// `a + b` (opcode `o0`).
    Add,
    /// `a - b` (opcode `o1`).
    Sub,
    /// `a * b` (opcode `o2`).
    Mul,
    /// `a / b` (opcode `o3`).
    Div,
    /// `a ^ b` (opcode `o5`).
    Pow,
    /// `atan2(a, b)` (opcode `o48`).
    Atan2,
}
155
/// One-operand operator carried by [`Expr::Unary`]; the opcode that
/// produces each variant is noted alongside it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UnaryOp {
    /// `-a` (`o16`).
    Neg,
    /// `sqrt(a)` (`o39`).
    Sqrt,
    /// Natural logarithm (`o43`).
    Log,
    /// `exp(a)` (`o44`).
    Exp,
    /// `|a|` (`o15`).
    Abs,
    /// `sin(a)` (`o41`).
    Sin,
    /// `cos(a)` (`o46`).
    Cos,
    /// Base-10 logarithm (`o42`).
    Log10,
    /// `tan(a)` (`o38`).
    Tan,
    /// `atan(a)` (`o49`).
    Atan,
    /// `acos(a)` (`o53`).
    Acos,
    /// `sinh(a)` (`o40`).
    Sinh,
    /// `cosh(a)` (`o45`).
    Cosh,
    /// `tanh(a)` (`o37`).
    Tanh,
    /// `asin(a)` (`o51`).
    Asin,
    /// `acosh(a)` (`o52`).
    Acosh,
    /// `asinh(a)` (`o50`).
    Asinh,
    /// `atanh(a)` (`o47`).
    Atanh,
}
177
178#[derive(Debug, Clone)]
180pub struct NlProblem {
181 pub n: usize,
182 pub m: usize,
183 pub num_obj: usize,
184 pub minimize: bool,
185 pub obj_nonlinear: Expr,
186 pub obj_linear: Vec<(usize, Number)>,
187 pub obj_constant: Number,
188 pub con_nonlinear: Vec<Expr>,
190 pub con_linear: Vec<Vec<(usize, Number)>>,
192 pub x_l: Vec<Number>,
193 pub x_u: Vec<Number>,
194 pub g_l: Vec<Number>,
195 pub g_u: Vec<Number>,
196 pub x0: Vec<Number>,
197 pub lambda0: Vec<Number>,
198 pub suffixes: NlSuffixes,
206 pub imported_funcs: Vec<ImportedFunc>,
210 pub var_names: Vec<String>,
221 pub con_names: Vec<String>,
225}
226
227#[derive(Debug, Clone, Default)]
232pub struct NlSuffixes {
233 pub var_int: BTreeMap<String, Vec<Index>>,
236 pub con_int: BTreeMap<String, Vec<Index>>,
238 pub obj_int: BTreeMap<String, Vec<Index>>,
240 pub problem_int: BTreeMap<String, Index>,
242 pub var_real: BTreeMap<String, Vec<Number>>,
244 pub con_real: BTreeMap<String, Vec<Number>>,
246 pub obj_real: BTreeMap<String, Vec<Number>>,
248 pub problem_real: BTreeMap<String, Number>,
250}
251
252pub fn read_nl_file(path: &Path) -> Result<NlProblem, String> {
263 let resolved = if path.exists() {
271 path.to_path_buf()
272 } else {
273 let with_nl = append_extension(path, "nl");
274 if with_nl.exists() {
275 with_nl
276 } else {
277 path.to_path_buf()
278 }
279 };
280 let txt = std::fs::read_to_string(&resolved)
281 .map_err(|e| format!("could not read {}: {}", resolved.display(), e))?;
282 let mut prob = parse_nl_text(&txt)?;
283 prob.var_names = read_name_file(&resolved.with_extension("col"), prob.n);
284 prob.con_names = read_name_file(&resolved.with_extension("row"), prob.m);
285 Ok(prob)
286}
287
/// Appends `.<ext>` to the full file name, keeping any existing extension
/// (`model` -> `model.nl`, `a.b` -> `a.b.nl`). `Path::with_extension` would
/// replace the existing extension instead.
fn append_extension(path: &Path, ext: &str) -> std::path::PathBuf {
    let mut full = path.as_os_str().to_owned();
    for piece in [".", ext] {
        full.push(piece);
    }
    full.into()
}
299
/// Reads the first `expected` lines of `path` as names.
///
/// Returns an empty vector, meaning "no names available", when the file
/// cannot be read or has fewer than `expected` lines. Lines beyond the first
/// `expected` are ignored.
fn read_name_file(path: &Path, expected: usize) -> Vec<String> {
    std::fs::read_to_string(path)
        .ok()
        .map(|contents| {
            contents
                .lines()
                .take(expected)
                .map(String::from)
                .collect::<Vec<String>>()
        })
        .filter(|names| names.len() == expected)
        .unwrap_or_default()
}
320
321pub fn parse_nl_text(txt: &str) -> Result<NlProblem, String> {
323 let mut p = Parser::new(txt);
324 p.parse_header()?;
325 let n = p.n;
326 let m = p.m;
327 let num_obj = p.num_obj;
328
329 let mut con_nonlinear: Vec<Expr> = (0..m).map(|_| Expr::Const(0.0)).collect();
330 let mut obj_nonlinear = Expr::Const(0.0);
331 let mut minimize = true;
332 let mut obj_linear: Vec<(usize, Number)> = Vec::new();
333 let mut con_linear: Vec<Vec<(usize, Number)>> = vec![Vec::new(); m];
334 let mut x_l = vec![-1e19; n];
335 let mut x_u = vec![1e19; n];
336 let mut g_l = vec![-1e19; m];
337 let mut g_u = vec![1e19; m];
338 let mut x0 = vec![0.0; n];
339 let mut lambda0 = vec![0.0; m];
340 let mut suffixes = NlSuffixes::default();
341 let mut imported_funcs: Vec<ImportedFunc> = Vec::new();
342
343 while let Some(line) = p.peek_segment_line() {
344 let tag = line
345 .trim_start()
346 .chars()
347 .next()
348 .ok_or("unexpected blank segment header")?;
349 match tag {
350 'C' => {
351 let (_hdr, rest) = p.eat_segment_header()?;
352 let _ = rest;
353 let idx = parse_segment_index(&_hdr, 'C')?;
354 if idx >= m {
355 return Err(format!("C{idx} out of range; m={m}"));
356 }
357 con_nonlinear[idx] = p.parse_expr()?;
358 }
359 'O' => {
360 let (hdr, _rest) = p.eat_segment_header()?;
361 let parts: Vec<&str> = hdr.split_whitespace().collect();
362 if parts.len() < 2 {
363 return Err(format!("malformed O-segment header: {hdr}"));
364 }
365 let idx = parse_segment_index(parts[0], 'O')?;
366 let kind: i32 = parts[1].parse().map_err(|e| format!("O kind: {e}"))?;
367 if idx == 0 {
368 minimize = kind == 0;
369 obj_nonlinear = p.parse_expr()?;
370 } else {
371 let _ = p.parse_expr()?;
373 }
374 }
375 'r' => {
376 p.eat_segment_header()?;
377 for i in 0..m {
378 let line = p.next_data_line()?;
379 let (lo, hi) = parse_bound_line(&line)?;
380 g_l[i] = lo;
381 g_u[i] = hi;
382 }
383 }
384 'b' => {
385 p.eat_segment_header()?;
386 for i in 0..n {
387 let line = p.next_data_line()?;
388 let (lo, hi) = parse_bound_line(&line)?;
389 x_l[i] = lo;
390 x_u[i] = hi;
391 }
392 }
393 'k' => {
394 let (hdr, _) = p.eat_segment_header()?;
407 let declared = parse_segment_index(&hdr, 'k')?;
408 let expected = if n == 0 { 0 } else { n - 1 };
409 if declared != expected {
410 return Err(format!(
411 "k-segment declares {declared} column-count lines but \
412 the standard count for n={n} variables is {expected}"
413 ));
414 }
415 for _ in 0..declared {
416 p.next_data_line()?;
417 }
418 }
419 'J' => {
420 let (hdr, _) = p.eat_segment_header()?;
421 let parts: Vec<&str> = hdr.split_whitespace().collect();
422 if parts.len() < 2 {
423 return Err(format!("malformed J-segment header: {hdr}"));
424 }
425 let row = parse_segment_index(parts[0], 'J')?;
426 let nz: usize = parts[1].parse().map_err(|e| format!("J nz: {e}"))?;
427 if row >= m {
428 return Err(format!("J{row} out of range"));
429 }
430 for _ in 0..nz {
431 let line = p.next_data_line()?;
432 let (var, coef) = parse_var_coef(&line)?;
433 if var >= n {
438 return Err(format!(
439 "J{row} entry variable index {var} out of range (n={n})"
440 ));
441 }
442 con_linear[row].push((var, coef));
443 }
444 }
445 'G' => {
446 let (hdr, _) = p.eat_segment_header()?;
447 let parts: Vec<&str> = hdr.split_whitespace().collect();
448 if parts.len() < 2 {
449 return Err(format!("malformed G-segment header: {hdr}"));
450 }
451 let idx = parse_segment_index(parts[0], 'G')?;
452 let nz: usize = parts[1].parse().map_err(|e| format!("G nz: {e}"))?;
453 let mut acc = Vec::with_capacity(nz);
454 for _ in 0..nz {
455 let line = p.next_data_line()?;
456 let (var, coef) = parse_var_coef(&line)?;
457 if var >= n {
460 return Err(format!(
461 "G{idx} entry variable index {var} out of range (n={n})"
462 ));
463 }
464 acc.push((var, coef));
465 }
466 if idx == 0 {
467 obj_linear = acc;
468 }
469 }
470 'x' => {
471 let (hdr, _) = p.eat_segment_header()?;
472 let parts: Vec<&str> = hdr.split_whitespace().collect();
473 let nx: usize = parts
474 .first()
475 .and_then(|s| s.trim_start_matches('x').parse().ok())
476 .ok_or_else(|| format!("malformed x-segment header: {hdr}"))?;
477 for _ in 0..nx {
478 let line = p.next_data_line()?;
479 let (idx, val) = parse_var_coef(&line)?;
480 if idx >= n {
484 return Err(format!(
485 "x-segment variable index {idx} out of range (n={n})"
486 ));
487 }
488 x0[idx] = val;
489 }
490 }
491 'd' => {
492 let (hdr, _) = p.eat_segment_header()?;
493 let parts: Vec<&str> = hdr.split_whitespace().collect();
494 let nd: usize = parts
495 .first()
496 .and_then(|s| s.trim_start_matches('d').parse().ok())
497 .ok_or_else(|| format!("malformed d-segment header: {hdr}"))?;
498 for _ in 0..nd {
499 let line = p.next_data_line()?;
500 let (idx, val) = parse_var_coef(&line)?;
501 if idx >= m {
505 return Err(format!(
506 "d-segment constraint index {idx} out of range (m={m})"
507 ));
508 }
509 lambda0[idx] = val;
510 }
511 }
512 'V' => p.parse_v_segment()?,
513 'S' => {
514 parse_suffix_segment(&mut p, n, m, num_obj, &mut suffixes)?;
515 }
516 'F' => {
517 let (hdr, _rest) = p.eat_segment_header()?;
520 let parts: Vec<&str> = hdr.split_whitespace().collect();
521 if parts.is_empty() {
522 return Err(format!("malformed F-segment header: '{hdr}'"));
523 }
524 let id = parse_segment_index(parts[0], 'F')?;
525 let kind: usize = parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
526 let nargs: i64 = parts.get(2).and_then(|s| s.parse().ok()).unwrap_or(0);
527 let name = parts.get(3).copied().unwrap_or("").to_string();
528 imported_funcs.push(ImportedFunc {
529 id,
530 kind,
531 nargs,
532 name,
533 });
534 }
535 other => return Err(format!("unknown .nl segment tag '{other}'")),
536 }
537 }
538
539 Ok(NlProblem {
540 n,
541 m,
542 num_obj,
543 minimize,
544 obj_nonlinear,
545 obj_linear,
546 obj_constant: 0.0,
547 con_nonlinear,
548 con_linear,
549 x_l,
550 x_u,
551 g_l,
552 g_u,
553 x0,
554 lambda0,
555 suffixes,
556 imported_funcs,
557 var_names: Vec::new(),
560 con_names: Vec::new(),
561 })
562}
563
564fn parse_suffix_segment(
581 p: &mut Parser,
582 n: usize,
583 m: usize,
584 num_obj: usize,
585 out: &mut NlSuffixes,
586) -> Result<(), String> {
587 let (hdr, _) = p.eat_segment_header()?;
588 let parts: Vec<&str> = hdr.split_whitespace().collect();
589 if parts.len() < 3 {
590 return Err(format!(
591 "malformed S-segment header: '{hdr}' (expected `S<kind> <n> <name>`)"
592 ));
593 }
594 let kind_str = parts[0].trim_start_matches('S');
595 let kind: u32 = kind_str
596 .parse()
597 .map_err(|e| format!("S kind '{kind_str}': {e}"))?;
598 let nentries: usize = parts[1].parse().map_err(|e| format!("S nentries: {e}"))?;
599 let name = parts[2].to_string();
600
601 let is_real = (kind & 0x4) != 0;
602 let target = kind & 0x3;
603 let target_dim = match target {
604 0 => n,
605 1 => m,
606 2 => num_obj,
607 3 => 0, _ => unreachable!("kind & 0x3 is in 0..=3"),
609 };
610
611 let mut int_buf: Vec<Index> = if !is_real && target != 3 {
615 vec![0; target_dim]
616 } else {
617 Vec::new()
618 };
619 let mut real_buf: Vec<Number> = if is_real && target != 3 {
620 vec![0.0; target_dim]
621 } else {
622 Vec::new()
623 };
624 let mut problem_int: Index = 0;
625 let mut problem_real: Number = 0.0;
626
627 for _ in 0..nentries {
628 let line = p.next_data_line()?;
629 let parts: Vec<&str> = line.split_whitespace().collect();
630 if parts.len() < 2 {
631 return Err(format!(
632 "malformed S-segment entry '{line}' (expected `<idx> <value>`)"
633 ));
634 }
635 let idx: usize = parts[0]
636 .parse()
637 .map_err(|e| format!("S entry idx '{}': {e}", parts[0]))?;
638 if target != 3 && idx >= target_dim {
639 return Err(format!(
640 "S-suffix '{name}' index {idx} out of range for target dim {target_dim}"
641 ));
642 }
643 if is_real {
644 let v: Number = parts[1]
645 .parse()
646 .map_err(|e| format!("S real entry value '{}': {e}", parts[1]))?;
647 if target == 3 {
648 problem_real = v;
649 } else {
650 real_buf[idx] = v;
651 }
652 } else {
653 let v: Index = parts[1]
654 .parse()
655 .map_err(|e| format!("S int entry value '{}': {e}", parts[1]))?;
656 if target == 3 {
657 problem_int = v;
658 } else {
659 int_buf[idx] = v;
660 }
661 }
662 }
663
664 match (target, is_real) {
665 (0, false) => {
666 out.var_int.insert(name, int_buf);
667 }
668 (1, false) => {
669 out.con_int.insert(name, int_buf);
670 }
671 (2, false) => {
672 out.obj_int.insert(name, int_buf);
673 }
674 (3, false) => {
675 out.problem_int.insert(name, problem_int);
676 }
677 (0, true) => {
678 out.var_real.insert(name, real_buf);
679 }
680 (1, true) => {
681 out.con_real.insert(name, real_buf);
682 }
683 (2, true) => {
684 out.obj_real.insert(name, real_buf);
685 }
686 (3, true) => {
687 out.problem_real.insert(name, problem_real);
688 }
689 _ => unreachable!(),
690 }
691 Ok(())
692}
693
/// Parses the index after a segment tag, e.g. `"J12"` with tag `'J'` -> 12.
///
/// Every leading occurrence of `tag` is removed before parsing; whitespace is
/// not trimmed. The error message quotes the original text `s`.
fn parse_segment_index(s: &str, tag: char) -> Result<usize, String> {
    match s.trim_start_matches(tag).parse::<usize>() {
        Ok(index) => Ok(index),
        Err(e) => Err(format!("malformed {tag}-segment index '{s}': {e}")),
    }
}
700
701fn parse_bound_line(line: &str) -> Result<(Number, Number), String> {
702 let parts: Vec<&str> = line.split_whitespace().collect();
703 if parts.is_empty() {
704 return Err("empty bound line".into());
705 }
706 let kind: i32 = parts[0].parse().map_err(|e| format!("bound kind: {e}"))?;
707 let lo;
708 let hi;
709 match kind {
710 0 => {
711 if parts.len() < 3 {
713 return Err(format!("bound kind 0 needs 2 values: '{line}'"));
714 }
715 lo = parts[1].parse().map_err(|e| format!("lo: {e}"))?;
716 hi = parts[2].parse().map_err(|e| format!("hi: {e}"))?;
717 }
718 1 => {
719 if parts.len() < 2 {
721 return Err(format!("bound kind 1 needs 1 value: '{line}'"));
722 }
723 lo = -1e19;
724 hi = parts[1].parse().map_err(|e| format!("hi: {e}"))?;
725 }
726 2 => {
727 if parts.len() < 2 {
729 return Err(format!("bound kind 2 needs 1 value: '{line}'"));
730 }
731 lo = parts[1].parse().map_err(|e| format!("lo: {e}"))?;
732 hi = 1e19;
733 }
734 3 => {
735 lo = -1e19;
737 hi = 1e19;
738 }
739 4 => {
740 if parts.len() < 2 {
742 return Err(format!("bound kind 4 needs 1 value: '{line}'"));
743 }
744 let v: Number = parts[1].parse().map_err(|e| format!("eq: {e}"))?;
745 lo = v;
746 hi = v;
747 }
748 5 => return Err("complementarity (kind 5) bounds are not supported".into()),
749 other => return Err(format!("unknown bound kind {other}")),
750 }
751 Ok((lo, hi))
752}
753
754fn parse_var_coef(line: &str) -> Result<(usize, Number), String> {
755 let parts: Vec<&str> = line.split_whitespace().collect();
756 if parts.len() < 2 {
757 return Err(format!("malformed var/coef line: '{line}'"));
758 }
759 let v: usize = parts[0].parse().map_err(|e| format!("var idx: {e}"))?;
760 let c: Number = parts[1].parse().map_err(|e| format!("coef: {e}"))?;
761 Ok((v, c))
762}
763
764struct Parser<'a> {
765 lines: Vec<&'a str>,
766 pos: usize,
767 n: usize,
768 m: usize,
769 num_obj: usize,
770 n_funcs: usize,
772 cses: Vec<Arc<Expr>>,
775}
776
777impl<'a> Parser<'a> {
778 fn new(txt: &'a str) -> Self {
779 let lines: Vec<&str> = txt.lines().collect();
780 Self {
781 lines,
782 pos: 0,
783 n: 0,
784 m: 0,
785 num_obj: 0,
786 n_funcs: 0,
787 cses: Vec::new(),
788 }
789 }
790
791 fn next_line(&mut self) -> Option<&'a str> {
792 while self.pos < self.lines.len() {
793 let l = self.lines[self.pos];
794 self.pos += 1;
795 let trimmed = strip_comment(l).trim();
799 if !trimmed.is_empty() {
800 return Some(l);
801 }
802 }
803 None
804 }
805
806 fn next_data_line(&mut self) -> Result<String, String> {
807 let raw = self
808 .next_line()
809 .ok_or_else(|| "unexpected end of file in data line".to_string())?;
810 Ok(strip_comment(raw).trim().to_string())
811 }
812
813 fn parse_header(&mut self) -> Result<(), String> {
814 let line0 = self.next_line().ok_or("empty .nl file")?;
815 let trimmed = strip_comment(line0).trim();
816 let first = trimmed.chars().next().ok_or("empty header line")?;
817 if first != 'g' {
818 return Err(format!(
819 "only ASCII (g-) .nl files supported; got header '{trimmed}'"
820 ));
821 }
822
823 let l2 = self.next_data_line()?;
825 let nums: Vec<&str> = l2.split_whitespace().collect();
826 if nums.len() < 3 {
827 return Err(format!("malformed line 2: '{l2}'"));
828 }
829 self.n = nums[0].parse().map_err(|e| format!("n: {e}"))?;
830 self.m = nums[1].parse().map_err(|e| format!("m: {e}"))?;
831 self.num_obj = nums[2].parse().map_err(|e| format!("num_obj: {e}"))?;
832
833 for _ in 0..3 {
835 self.next_data_line()?;
836 }
837 let l5 = self.next_data_line()?;
839 let nums5: Vec<&str> = l5.split_whitespace().collect();
840 self.n_funcs = nums5.get(1).and_then(|s| s.parse().ok()).unwrap_or(0);
841 for _ in 0..4 {
843 self.next_data_line()?;
844 }
845 Ok(())
846 }
847
848 fn peek_segment_line(&mut self) -> Option<&'a str> {
849 let saved = self.pos;
850 let l = self.next_line()?;
851 self.pos = saved;
852 Some(l)
853 }
854
855 fn eat_segment_header(&mut self) -> Result<(String, String), String> {
858 let raw = self
859 .next_line()
860 .ok_or_else(|| "expected segment header".to_string())?;
861 let (hdr, comment) = split_comment(raw);
862 Ok((hdr.trim().to_string(), comment.trim().to_string()))
863 }
864
865 fn parse_expr(&mut self) -> Result<Expr, String> {
866 let raw = self
867 .next_line()
868 .ok_or_else(|| "expected expression token".to_string())?;
869 let tok = strip_comment(raw).trim().to_string();
870 if tok.is_empty() {
871 return Err("empty expression token".into());
872 }
873 let first = tok.chars().next().ok_or("empty expression token")?;
874 match first {
875 'n' => {
876 let v: Number = tok[1..]
877 .trim()
878 .parse()
879 .map_err(|e| format!("n value: {e}"))?;
880 Ok(Expr::Const(v))
881 }
882 'v' => {
883 let i: usize = tok[1..]
884 .trim()
885 .parse()
886 .map_err(|e| format!("v index: {e}"))?;
887 Ok(self.var_or_cse(i)?)
888 }
889 'o' => {
890 let code: i32 = tok[1..]
891 .trim()
892 .parse()
893 .map_err(|e| format!("opcode: {e}"))?;
894 self.parse_opcode(code)
895 }
896 'f' => {
897 let rest = &tok[1..];
900 let mut parts = rest.split_whitespace();
901 let id_str = parts
902 .next()
903 .ok_or_else(|| format!("missing function id in '{tok}'"))?;
904 let nargs_str = parts
905 .next()
906 .ok_or_else(|| format!("missing nargs in '{tok}'"))?;
907 let id: usize = id_str
908 .parse()
909 .map_err(|e| format!("bad function id '{id_str}': {e}"))?;
910 let nargs: usize = nargs_str
911 .parse()
912 .map_err(|e| format!("bad funcall nargs '{nargs_str}': {e}"))?;
913 let mut args: Vec<FuncallArg> = Vec::with_capacity(nargs);
914 for _ in 0..nargs {
915 args.push(self.parse_funcall_arg()?);
916 }
917 Ok(Expr::Funcall { id, args })
918 }
919 't' | 'u' => Err(format!("unsupported expression token '{tok}'")),
920 other => Err(format!(
921 "unexpected expression token start '{other}': '{tok}'"
922 )),
923 }
924 }
925
926 fn parse_funcall_arg(&mut self) -> Result<FuncallArg, String> {
932 let saved = self.pos;
934 let raw = self
935 .next_line()
936 .ok_or_else(|| "expected funcall argument".to_string())?;
937 let lead = raw.trim_start();
945 if let Some(after_h) = lead.strip_prefix('h') {
946 let colon = after_h
947 .find(':')
948 .ok_or_else(|| format!("malformed Hollerith string arg (no ':'): {lead:?}"))?;
949 let len: usize = after_h[..colon]
950 .trim()
951 .parse()
952 .map_err(|e| format!("Hollerith length in {lead:?}: {e}"))?;
953 let chars = &after_h[colon + 1..];
954 if chars.len() < len {
955 return Err(format!(
956 "Hollerith string shorter than declared length {len}: {chars:?}"
957 ));
958 }
959 if !chars.is_char_boundary(len) {
962 return Err(format!(
963 "Hollerith length {len} splits a multibyte char in {chars:?}"
964 ));
965 }
966 Ok(FuncallArg::Str(chars[..len].to_string()))
967 } else {
968 self.pos = saved;
970 Ok(FuncallArg::Real(self.parse_expr()?))
971 }
972 }
973
974 fn parse_opcode(&mut self, code: i32) -> Result<Expr, String> {
975 match code {
976 0 => {
977 let a = self.parse_expr()?;
978 let b = self.parse_expr()?;
979 Ok(Expr::Binary(BinOp::Add, Box::new(a), Box::new(b)))
980 }
981 1 => {
982 let a = self.parse_expr()?;
983 let b = self.parse_expr()?;
984 Ok(Expr::Binary(BinOp::Sub, Box::new(a), Box::new(b)))
985 }
986 2 => {
987 let a = self.parse_expr()?;
988 let b = self.parse_expr()?;
989 Ok(Expr::Binary(BinOp::Mul, Box::new(a), Box::new(b)))
990 }
991 3 => {
992 let a = self.parse_expr()?;
993 let b = self.parse_expr()?;
994 Ok(Expr::Binary(BinOp::Div, Box::new(a), Box::new(b)))
995 }
996 5 => {
997 let a = self.parse_expr()?;
998 let b = self.parse_expr()?;
999 Ok(Expr::Binary(BinOp::Pow, Box::new(a), Box::new(b)))
1000 }
1001 15 => Ok(Expr::Unary(UnaryOp::Abs, Box::new(self.parse_expr()?))),
1002 16 => Ok(Expr::Unary(UnaryOp::Neg, Box::new(self.parse_expr()?))),
1003 39 => Ok(Expr::Unary(UnaryOp::Sqrt, Box::new(self.parse_expr()?))),
1004 41 => Ok(Expr::Unary(UnaryOp::Sin, Box::new(self.parse_expr()?))),
1005 42 => Ok(Expr::Unary(UnaryOp::Log10, Box::new(self.parse_expr()?))),
1006 43 => Ok(Expr::Unary(UnaryOp::Log, Box::new(self.parse_expr()?))),
1007 44 => Ok(Expr::Unary(UnaryOp::Exp, Box::new(self.parse_expr()?))),
1008 46 => Ok(Expr::Unary(UnaryOp::Cos, Box::new(self.parse_expr()?))),
1009 38 => Ok(Expr::Unary(UnaryOp::Tan, Box::new(self.parse_expr()?))),
1010 49 => Ok(Expr::Unary(UnaryOp::Atan, Box::new(self.parse_expr()?))),
1011 53 => Ok(Expr::Unary(UnaryOp::Acos, Box::new(self.parse_expr()?))),
1012 40 => Ok(Expr::Unary(UnaryOp::Sinh, Box::new(self.parse_expr()?))),
1013 45 => Ok(Expr::Unary(UnaryOp::Cosh, Box::new(self.parse_expr()?))),
1014 37 => Ok(Expr::Unary(UnaryOp::Tanh, Box::new(self.parse_expr()?))),
1015 51 => Ok(Expr::Unary(UnaryOp::Asin, Box::new(self.parse_expr()?))),
1016 52 => Ok(Expr::Unary(UnaryOp::Acosh, Box::new(self.parse_expr()?))),
1017 50 => Ok(Expr::Unary(UnaryOp::Asinh, Box::new(self.parse_expr()?))),
1018 47 => Ok(Expr::Unary(UnaryOp::Atanh, Box::new(self.parse_expr()?))),
1019 48 => {
1021 let a = self.parse_expr()?;
1022 let b = self.parse_expr()?;
1023 Ok(Expr::Binary(BinOp::Atan2, Box::new(a), Box::new(b)))
1024 }
1025 22 => self.parse_compare(CmpOp::Lt),
1028 23 => self.parse_compare(CmpOp::Le),
1029 24 => self.parse_compare(CmpOp::Eq),
1030 28 => self.parse_compare(CmpOp::Ge),
1031 29 => self.parse_compare(CmpOp::Gt),
1032 30 => self.parse_compare(CmpOp::Ne),
1033 20 => {
1035 let a = self.parse_expr()?;
1036 let b = self.parse_expr()?;
1037 Ok(Expr::Or(Box::new(a), Box::new(b)))
1038 }
1039 21 => {
1040 let a = self.parse_expr()?;
1041 let b = self.parse_expr()?;
1042 Ok(Expr::And(Box::new(a), Box::new(b)))
1043 }
1044 34 => Ok(Expr::Not(Box::new(self.parse_expr()?))),
1045 35 => {
1047 let cond = self.parse_expr()?;
1048 let then_ = self.parse_expr()?;
1049 let else_ = self.parse_expr()?;
1050 Ok(Expr::Cond {
1051 cond: Box::new(cond),
1052 then_: Box::new(then_),
1053 else_: Box::new(else_),
1054 })
1055 }
1056 54 => {
1057 let count_line = self.next_data_line()?;
1059 let count: usize = count_line
1060 .split_whitespace()
1061 .next()
1062 .ok_or_else(|| "missing variadic count".to_string())?
1063 .parse()
1064 .map_err(|e| format!("variadic count: {e}"))?;
1065 let mut args = Vec::with_capacity(count);
1066 for _ in 0..count {
1067 args.push(self.parse_expr()?);
1068 }
1069 Ok(Expr::Sum(args))
1070 }
1071 11 | 12 => {
1074 let count_line = self.next_data_line()?;
1075 let count: usize = count_line
1076 .split_whitespace()
1077 .next()
1078 .ok_or_else(|| "missing min/max list count".to_string())?
1079 .parse()
1080 .map_err(|e| format!("min/max list count: {e}"))?;
1081 let mut args = Vec::with_capacity(count);
1082 for _ in 0..count {
1083 args.push(self.parse_expr()?);
1084 }
1085 if code == 11 {
1086 Ok(Expr::MinList(args))
1087 } else {
1088 Ok(Expr::MaxList(args))
1089 }
1090 }
1091 other => Err(format!("unsupported opcode o{other}")),
1092 }
1093 }
1094
1095 fn parse_compare(&mut self, op: CmpOp) -> Result<Expr, String> {
1098 let a = self.parse_expr()?;
1099 let b = self.parse_expr()?;
1100 Ok(Expr::Compare(op, Box::new(a), Box::new(b)))
1101 }
1102
1103 fn var_or_cse(&self, i: usize) -> Result<Expr, String> {
1106 if i < self.n {
1107 Ok(Expr::Var(i))
1108 } else {
1109 let local = i - self.n;
1110 self.cses
1111 .get(local)
1112 .map(|rc| Expr::Cse(rc.clone()))
1113 .ok_or_else(|| {
1114 format!(
1115 "v{i} references CSE {local} but only {} have been defined",
1116 self.cses.len()
1117 )
1118 })
1119 }
1120 }
1121
1122 fn parse_v_segment(&mut self) -> Result<(), String> {
1126 let (hdr, _) = self.eat_segment_header()?;
1127 let parts: Vec<&str> = hdr.split_whitespace().collect();
1128 if parts.len() < 2 {
1129 return Err(format!("malformed V-segment header: {hdr}"));
1130 }
1131 let cse_idx = parse_segment_index(parts[0], 'V')?;
1132 let nlin: usize = parts[1].parse().map_err(|e| format!("V nlin: {e}"))?;
1133 let mut linear: Vec<(usize, Number)> = Vec::with_capacity(nlin);
1135 for _ in 0..nlin {
1136 let line = self.next_data_line()?;
1137 let (var, coef) = parse_var_coef(&line)?;
1138 linear.push((var, coef));
1139 }
1140 let nonlin = self.parse_expr()?;
1141 let mut combined = nonlin;
1144 for (var, coef) in linear {
1145 let v_expr = self.var_or_cse(var)?;
1146 let term = if coef == 1.0 {
1147 v_expr
1148 } else {
1149 Expr::Binary(BinOp::Mul, Box::new(Expr::Const(coef)), Box::new(v_expr))
1150 };
1151 combined = Expr::Binary(BinOp::Add, Box::new(combined), Box::new(term));
1152 }
1153 if cse_idx < self.n {
1154 return Err(format!("V{cse_idx} below n={}", self.n));
1155 }
1156 let local = cse_idx - self.n;
1157 if local != self.cses.len() {
1158 return Err(format!(
1159 "V-segment index V{cse_idx} out of order; expected V{}",
1160 self.n + self.cses.len()
1161 ));
1162 }
1163 self.cses.push(Arc::new(combined));
1164 Ok(())
1165 }
1166}
1167
/// Returns the part of `s` before the first `#`, or all of `s` if it has none.
fn strip_comment(s: &str) -> &str {
    s.find('#').map_or(s, |hash| &s[..hash])
}
1174
/// Splits `s` at its first `#` into `(before, after)`, with neither side
/// containing that `#`. Returns `(s, "")` when there is no `#`.
fn split_comment(s: &str) -> (&str, &str) {
    s.split_once('#').unwrap_or((s, ""))
}
1181
1182pub fn eval_expr(e: &Expr, x: &[Number]) -> Number {
1190 match e {
1191 Expr::Const(c) => *c,
1192 Expr::Var(i) => x[*i],
1193 Expr::Binary(op, a, b) => {
1194 let va = eval_expr(a, x);
1195 let vb = eval_expr(b, x);
1196 match op {
1197 BinOp::Add => va + vb,
1198 BinOp::Sub => va - vb,
1199 BinOp::Mul => va * vb,
1200 BinOp::Div => va / vb,
1201 BinOp::Pow => va.powf(vb),
1202 BinOp::Atan2 => va.atan2(vb),
1203 }
1204 }
1205 Expr::Unary(op, a) => {
1206 let va = eval_expr(a, x);
1207 match op {
1208 UnaryOp::Neg => -va,
1209 UnaryOp::Sqrt => va.sqrt(),
1210 UnaryOp::Log => va.ln(),
1211 UnaryOp::Log10 => va.log10(),
1212 UnaryOp::Exp => va.exp(),
1213 UnaryOp::Abs => va.abs(),
1214 UnaryOp::Sin => va.sin(),
1215 UnaryOp::Cos => va.cos(),
1216 UnaryOp::Tan => va.tan(),
1217 UnaryOp::Atan => va.atan(),
1218 UnaryOp::Acos => va.acos(),
1219 UnaryOp::Sinh => va.sinh(),
1220 UnaryOp::Cosh => va.cosh(),
1221 UnaryOp::Tanh => va.tanh(),
1222 UnaryOp::Asin => va.asin(),
1223 UnaryOp::Acosh => va.acosh(),
1224 UnaryOp::Asinh => va.asinh(),
1225 UnaryOp::Atanh => va.atanh(),
1226 }
1227 }
1228 Expr::Sum(args) => args.iter().map(|a| eval_expr(a, x)).sum(),
1229 Expr::MinList(args) => args
1230 .iter()
1231 .map(|a| eval_expr(a, x))
1232 .fold(Number::INFINITY, Number::min),
1233 Expr::MaxList(args) => args
1234 .iter()
1235 .map(|a| eval_expr(a, x))
1236 .fold(Number::NEG_INFINITY, Number::max),
1237 Expr::Compare(op, a, b) => {
1238 let va = eval_expr(a, x);
1239 let vb = eval_expr(b, x);
1240 let truth = match op {
1241 CmpOp::Lt => va < vb,
1242 CmpOp::Le => va <= vb,
1243 CmpOp::Eq => va == vb,
1244 CmpOp::Ge => va >= vb,
1245 CmpOp::Gt => va > vb,
1246 CmpOp::Ne => va != vb,
1247 };
1248 if truth {
1249 1.0
1250 } else {
1251 0.0
1252 }
1253 }
1254 Expr::And(a, b) => {
1255 if eval_expr(a, x) != 0.0 && eval_expr(b, x) != 0.0 {
1256 1.0
1257 } else {
1258 0.0
1259 }
1260 }
1261 Expr::Or(a, b) => {
1262 if eval_expr(a, x) != 0.0 || eval_expr(b, x) != 0.0 {
1263 1.0
1264 } else {
1265 0.0
1266 }
1267 }
1268 Expr::Not(a) => {
1269 if eval_expr(a, x) == 0.0 {
1270 1.0
1271 } else {
1272 0.0
1273 }
1274 }
1275 Expr::Cond { cond, then_, else_ } => {
1276 if eval_expr(cond, x) != 0.0 {
1277 eval_expr(then_, x)
1278 } else {
1279 eval_expr(else_, x)
1280 }
1281 }
1282 Expr::Cse(body) => eval_expr(body, x),
1283 Expr::Funcall { .. } => panic!(
1284 "eval_expr: AMPL imported function called without an external resolver; \
1285 evaluate through the tape AD path (Tape::build_with_externals) instead"
1286 ),
1287 }
1288}
1289
1290fn argmin_argmax(args: &[Expr], x: &[Number], want_min: bool) -> Option<usize> {
1295 let mut best: Option<(usize, Number)> = None;
1296 for (i, a) in args.iter().enumerate() {
1297 let v = eval_expr(a, x);
1298 match best {
1299 None => best = Some((i, v)),
1300 Some((_, bv)) => {
1301 if (want_min && v < bv) || (!want_min && v > bv) {
1305 best = Some((i, v));
1306 }
1307 }
1308 }
1309 }
1310 best.map(|(i, _)| i)
1311}
1312
1313pub fn grad_expr(e: &Expr, x: &[Number], seed: Number, grad: &mut [Number]) {
1315 match e {
1316 Expr::Const(_) => {}
1317 Expr::Var(i) => grad[*i] += seed,
1318 Expr::Binary(op, a, b) => {
1319 let va = eval_expr(a, x);
1320 let vb = eval_expr(b, x);
1321 match op {
1322 BinOp::Add => {
1323 grad_expr(a, x, seed, grad);
1324 grad_expr(b, x, seed, grad);
1325 }
1326 BinOp::Sub => {
1327 grad_expr(a, x, seed, grad);
1328 grad_expr(b, x, -seed, grad);
1329 }
1330 BinOp::Mul => {
1331 grad_expr(a, x, seed * vb, grad);
1332 grad_expr(b, x, seed * va, grad);
1333 }
1334 BinOp::Div => {
1335 grad_expr(a, x, seed / vb, grad);
1336 grad_expr(b, x, -seed * va / (vb * vb), grad);
1337 }
1338 BinOp::Pow => {
1339 let dpa = vb * va.powf(vb - 1.0);
1341 grad_expr(a, x, seed * dpa, grad);
1342 if va > 0.0 {
1344 let dpb = va.powf(vb) * va.ln();
1345 grad_expr(b, x, seed * dpb, grad);
1346 }
1347 }
1348 BinOp::Atan2 => {
1349 let d = va * va + vb * vb;
1351 grad_expr(a, x, seed * vb / d, grad);
1352 grad_expr(b, x, -seed * va / d, grad);
1353 }
1354 }
1355 }
1356 Expr::Unary(op, a) => {
1357 let va = eval_expr(a, x);
1358 let d = match op {
1359 UnaryOp::Neg => -1.0,
1360 UnaryOp::Sqrt => 0.5 / va.sqrt(),
1361 UnaryOp::Log => 1.0 / va,
1362 UnaryOp::Log10 => 1.0 / (va * std::f64::consts::LN_10),
1363 UnaryOp::Exp => va.exp(),
1364 UnaryOp::Abs => {
1365 if va > 0.0 {
1366 1.0
1367 } else if va < 0.0 {
1368 -1.0
1369 } else {
1370 0.0
1371 }
1372 }
1373 UnaryOp::Sin => va.cos(),
1374 UnaryOp::Cos => -va.sin(),
1375 UnaryOp::Tan => {
1376 let t = va.tan();
1377 1.0 + t * t
1378 }
1379 UnaryOp::Atan => 1.0 / (1.0 + va * va),
1380 UnaryOp::Acos => -1.0 / (1.0 - va * va).sqrt(),
1381 UnaryOp::Sinh => va.cosh(),
1382 UnaryOp::Cosh => va.sinh(),
1383 UnaryOp::Tanh => {
1384 let t = va.tanh();
1385 1.0 - t * t
1386 }
1387 UnaryOp::Asin => 1.0 / (1.0 - va * va).sqrt(),
1388 UnaryOp::Acosh => 1.0 / (va * va - 1.0).sqrt(),
1389 UnaryOp::Asinh => 1.0 / (va * va + 1.0).sqrt(),
1390 UnaryOp::Atanh => 1.0 / (1.0 - va * va),
1391 };
1392 grad_expr(a, x, seed * d, grad);
1393 }
1394 Expr::Sum(args) => {
1395 for arg in args {
1396 grad_expr(arg, x, seed, grad);
1397 }
1398 }
1399 Expr::MinList(args) => {
1404 if let Some(k) = argmin_argmax(args, x, true) {
1405 grad_expr(&args[k], x, seed, grad);
1406 }
1407 }
1408 Expr::MaxList(args) => {
1409 if let Some(k) = argmin_argmax(args, x, false) {
1410 grad_expr(&args[k], x, seed, grad);
1411 }
1412 }
1413 Expr::Compare(_, _, _) | Expr::And(_, _) | Expr::Or(_, _) | Expr::Not(_) => {}
1416 Expr::Cond { cond, then_, else_ } => {
1419 if eval_expr(cond, x) != 0.0 {
1420 grad_expr(then_, x, seed, grad);
1421 } else {
1422 grad_expr(else_, x, seed, grad);
1423 }
1424 }
1425 Expr::Cse(body) => grad_expr(body, x, seed, grad),
1426 Expr::Funcall { .. } => {
1427 panic!("grad_expr: AMPL imported function called without an external resolver")
1428 }
1429 }
1430}
1431
1432pub fn collect_vars(e: &Expr, out: &mut BTreeSet<usize>) {
1434 match e {
1435 Expr::Const(_) => {}
1436 Expr::Var(i) => {
1437 out.insert(*i);
1438 }
1439 Expr::Binary(_, a, b) => {
1440 collect_vars(a, out);
1441 collect_vars(b, out);
1442 }
1443 Expr::Unary(_, a) => collect_vars(a, out),
1444 Expr::Sum(args) | Expr::MinList(args) | Expr::MaxList(args) => {
1445 for a in args {
1446 collect_vars(a, out);
1447 }
1448 }
1449 Expr::Compare(_, a, b) | Expr::And(a, b) | Expr::Or(a, b) => {
1455 collect_vars(a, out);
1456 collect_vars(b, out);
1457 }
1458 Expr::Not(a) => collect_vars(a, out),
1459 Expr::Cond { cond, then_, else_ } => {
1460 collect_vars(cond, out);
1461 collect_vars(then_, out);
1462 collect_vars(else_, out);
1463 }
1464 Expr::Cse(body) => collect_vars(body, out),
1465 Expr::Funcall { args, .. } => {
1466 for a in args {
1467 if let FuncallArg::Real(e) = a {
1468 collect_vars(e, out);
1469 }
1470 }
1471 }
1472 }
1473}
1474
/// One entry of a per-color Hessian decoding table.
///
/// After the Hessian-vector products for a color have been accumulated into
/// that color's compressed buffer, component `row` of the buffer is added to
/// `values[hess_idx]` of the Hessian triplet output (see the decode loop at
/// the end of `eval_h`).
#[derive(Debug, Clone)]
struct ColorWrite {
    /// Index into the owning color's compressed buffer: the Hessian row `i`
    /// of the pattern entry `(i, j)` this write recovers.
    row: u32,
    /// Position of that pattern entry in the triplet arrays `h_irow`/`h_jcol`.
    hess_idx: u32,
}
1491
/// [`TNLP`] adapter over a parsed [`NlProblem`].
///
/// The nonlinear parts of the objective and of each constraint are split into
/// additive summands (`split_top_sums`), and each summand is recorded as its
/// own [`Tape`]. Hessians are evaluated with column coloring: columns that
/// share a Hessian row get different colors (`greedy_hessian_coloring`), so
/// one Hessian-vector product per color recovers every pattern entry.
#[derive(Debug, Clone)]
pub struct NlTnlp {
    /// The problem as read from the `.nl` file. `NlTnlp::variant` may have
    /// overridden its `x0` and bounds.
    prob: NlProblem,
    /// One tape per additive summand of `prob.obj_nonlinear`.
    obj_tapes: Vec<Tape>,
    /// `con_tapes[k]`: one tape per additive summand of `prob.con_nonlinear[k]`.
    con_tapes: Vec<Vec<Tape>>,
    /// Row indices of the Hessian-of-the-Lagrangian pattern, sorted by `(row, col)`.
    h_irow: Vec<i32>,
    /// Column indices matching `h_irow` entry by entry.
    h_jcol: Vec<i32>,
    /// `jac_cols[i]`: sorted variable indices with a structural nonzero in
    /// Jacobian row `i`. These are the row's tape variables plus its linear terms.
    jac_cols: Vec<Vec<usize>>,
    /// Total Jacobian nonzeros, i.e. the sum of `jac_cols[i].len()`.
    jac_nnz: usize,
    /// `seeds[c]`: length-`n` 0/1 vector selecting the variables of color `c`.
    seeds: Vec<Vec<f64>>,
    /// `decoding[c]`: where each needed component of color `c`'s compressed
    /// product is written in the Hessian triplet output.
    decoding: Vec<Vec<ColorWrite>>,
    /// For each objective tape, the sorted, deduplicated colors of the variables it reads.
    obj_tape_colors: Vec<Vec<u32>>,
    /// Same as `obj_tape_colors`, indexed `[constraint][tape]`.
    con_tape_colors: Vec<Vec<Vec<u32>>>,
    /// Primal solution stored by `finalize_solution`; `None` until then.
    final_x: Option<Vec<Number>>,
    /// Objective value stored by `finalize_solution`.
    final_obj: Number,
    /// Dense gradient buffer for one Jacobian row; grown to `n` on first use
    /// in `eval_jac_g`.
    scratch_row_grad: Vec<f64>,
    /// Forward-sweep buffer, sized to the longest tape (in ops).
    vals_scratch: Vec<f64>,
    /// Scratch passed to `Tape::hessian_directional`, sized to the longest tape.
    dot_scratch: Vec<f64>,
    /// Adjoint scratch for gradients and Hessian products, sized to the longest tape.
    adj_scratch: Vec<f64>,
    /// Scratch passed to `Tape::hessian_directional`, sized to the longest tape.
    adj_dot_scratch: Vec<f64>,
    /// `compressed[c]`: length-`n` accumulator for color `c`'s Hessian-vector
    /// products in `eval_h`.
    compressed: Vec<Vec<f64>>,
}
1543
// Operator binding strengths for the expression pretty-printer
// (`expr_prec` / `render_prec`). A sub-expression is parenthesised when its
// strength is below the minimum its context requires. Larger values bind
// tighter.
/// `+`, `-` and n-ary sums.
const P_ADD: u8 = 10;
/// `*` and `/`.
const P_MUL: u8 = 20;
/// Unary negation.
const P_NEG: u8 = 30;
/// `^` (exponentiation).
const P_POW: u8 = 40;
/// Default for every node `expr_prec` does not single out: leaves,
/// function-call syntax, and the fully parenthesised logical/relational forms.
const P_ATOM: u8 = 100;
1565
1566fn fmt_num(x: Number) -> String {
1569 if x.is_finite() && x == x.trunc() && x.abs() < 1e15 {
1570 format!("{}", x as i64)
1571 } else {
1572 format!("{x}")
1573 }
1574}
1575
/// Display label for variable `i`: its name from `var_names` when present and
/// non-empty, otherwise the positional fallback `x[i]`.
fn var_label(i: usize, var_names: &[String]) -> String {
    var_names
        .get(i)
        .filter(|name| !name.is_empty())
        .cloned()
        .unwrap_or_else(|| format!("x[{i}]"))
}
1584
1585fn expr_prec(e: &Expr) -> u8 {
1587 match e {
1588 Expr::Binary(BinOp::Add, ..) | Expr::Binary(BinOp::Sub, ..) | Expr::Sum(_) => P_ADD,
1589 Expr::Binary(BinOp::Mul, ..) | Expr::Binary(BinOp::Div, ..) => P_MUL,
1590 Expr::Unary(UnaryOp::Neg, _) => P_NEG,
1591 Expr::Binary(BinOp::Pow, ..) => P_POW,
1592 Expr::Cse(inner) => expr_prec(inner),
1593 _ => P_ATOM,
1595 }
1596}
1597
1598fn render_prec(e: &Expr, min_prec: u8, vn: &[String], funcs: &[ImportedFunc]) -> String {
1601 let s = render_expr(e, vn, funcs);
1602 if expr_prec(e) < min_prec {
1603 format!("({s})")
1604 } else {
1605 s
1606 }
1607}
1608
1609fn unary_name(op: UnaryOp) -> &'static str {
1610 match op {
1611 UnaryOp::Neg => "-",
1612 UnaryOp::Sqrt => "sqrt",
1613 UnaryOp::Log => "log",
1614 UnaryOp::Exp => "exp",
1615 UnaryOp::Abs => "abs",
1616 UnaryOp::Sin => "sin",
1617 UnaryOp::Cos => "cos",
1618 UnaryOp::Log10 => "log10",
1619 UnaryOp::Tan => "tan",
1620 UnaryOp::Atan => "atan",
1621 UnaryOp::Acos => "acos",
1622 UnaryOp::Sinh => "sinh",
1623 UnaryOp::Cosh => "cosh",
1624 UnaryOp::Tanh => "tanh",
1625 UnaryOp::Asin => "asin",
1626 UnaryOp::Acosh => "acosh",
1627 UnaryOp::Asinh => "asinh",
1628 UnaryOp::Atanh => "atanh",
1629 }
1630}
1631
1632fn cmp_sym(op: CmpOp) -> &'static str {
1633 match op {
1634 CmpOp::Lt => "<",
1635 CmpOp::Le => "<=",
1636 CmpOp::Eq => "==",
1637 CmpOp::Ge => ">=",
1638 CmpOp::Gt => ">",
1639 CmpOp::Ne => "!=",
1640 }
1641}
1642
/// Appends one additive term to `out`.
///
/// The first term is appended verbatim. A later term is joined with ` + `,
/// or, when its rendering starts with `-`, that sign is folded into the
/// separator (` - `) so the output reads `a - b` rather than `a + -b`.
fn push_additive(out: &mut String, rendered: &str, first: bool) {
    let (separator, text) = match (first, rendered.strip_prefix('-')) {
        (true, _) => ("", rendered),
        (false, Some(magnitude)) => (" - ", magnitude),
        (false, None) => (" + ", rendered),
    };
    out.push_str(separator);
    out.push_str(text);
}
1658
1659fn render_expr(e: &Expr, vn: &[String], funcs: &[ImportedFunc]) -> String {
1661 match e {
1662 Expr::Const(c) => fmt_num(*c),
1663 Expr::Var(i) => var_label(*i, vn),
1664 Expr::Binary(op, l, r) => match op {
1665 BinOp::Add => {
1666 let mut s = render_prec(l, P_ADD, vn, funcs);
1667 push_additive(&mut s, &render_prec(r, P_ADD, vn, funcs), false);
1668 s
1669 }
1670 BinOp::Sub => format!(
1672 "{} - {}",
1673 render_prec(l, P_ADD, vn, funcs),
1674 render_prec(r, P_ADD + 1, vn, funcs)
1675 ),
1676 BinOp::Mul => format!(
1677 "{}*{}",
1678 render_prec(l, P_MUL, vn, funcs),
1679 render_prec(r, P_MUL, vn, funcs)
1680 ),
1681 BinOp::Div => format!(
1682 "{}/{}",
1683 render_prec(l, P_MUL, vn, funcs),
1684 render_prec(r, P_MUL + 1, vn, funcs)
1685 ),
1686 BinOp::Pow => format!(
1688 "{}^{}",
1689 render_prec(l, P_POW + 1, vn, funcs),
1690 render_prec(r, P_POW, vn, funcs)
1691 ),
1692 BinOp::Atan2 => format!(
1693 "atan2({}, {})",
1694 render_expr(l, vn, funcs),
1695 render_expr(r, vn, funcs)
1696 ),
1697 },
1698 Expr::Unary(UnaryOp::Neg, a) => format!("-{}", render_prec(a, P_NEG, vn, funcs)),
1699 Expr::Unary(op, a) => format!("{}({})", unary_name(*op), render_expr(a, vn, funcs)),
1700 Expr::Sum(xs) => {
1701 if xs.is_empty() {
1702 "0".to_string()
1703 } else {
1704 let mut s = String::new();
1705 for (k, x) in xs.iter().enumerate() {
1706 push_additive(&mut s, &render_prec(x, P_ADD, vn, funcs), k == 0);
1707 }
1708 s
1709 }
1710 }
1711 Expr::Cse(inner) => render_expr(inner, vn, funcs),
1712 Expr::Funcall { id, args } => {
1713 let name = funcs
1714 .iter()
1715 .find(|f| f.id == *id)
1716 .map(|f| f.name.clone())
1717 .unwrap_or_else(|| format!("extern#{id}"));
1718 let parts: Vec<String> = args
1719 .iter()
1720 .map(|a| match a {
1721 FuncallArg::Real(x) => render_expr(x, vn, funcs),
1722 FuncallArg::Str(s) => format!("{s:?}"),
1723 })
1724 .collect();
1725 format!("{name}({})", parts.join(", "))
1726 }
1727 Expr::Compare(op, a, b) => format!(
1728 "({} {} {})",
1729 render_expr(a, vn, funcs),
1730 cmp_sym(*op),
1731 render_expr(b, vn, funcs)
1732 ),
1733 Expr::And(a, b) => format!(
1734 "({} && {})",
1735 render_expr(a, vn, funcs),
1736 render_expr(b, vn, funcs)
1737 ),
1738 Expr::Or(a, b) => format!(
1739 "({} || {})",
1740 render_expr(a, vn, funcs),
1741 render_expr(b, vn, funcs)
1742 ),
1743 Expr::Not(a) => format!("!({})", render_expr(a, vn, funcs)),
1744 Expr::Cond { cond, then_, else_ } => format!(
1745 "if({}, {}, {})",
1746 render_expr(cond, vn, funcs),
1747 render_expr(then_, vn, funcs),
1748 render_expr(else_, vn, funcs)
1749 ),
1750 Expr::MinList(xs) => format!(
1751 "min({})",
1752 xs.iter()
1753 .map(|x| render_expr(x, vn, funcs))
1754 .collect::<Vec<_>>()
1755 .join(", ")
1756 ),
1757 Expr::MaxList(xs) => format!(
1758 "max({})",
1759 xs.iter()
1760 .map(|x| render_expr(x, vn, funcs))
1761 .collect::<Vec<_>>()
1762 .join(", ")
1763 ),
1764 }
1765}
1766
1767fn render_linear(linear: &[(usize, Number)], vn: &[String]) -> String {
1770 let mut out = String::new();
1771 let mut first = true;
1776 for (var, coef) in linear {
1777 if *coef == 0.0 {
1778 continue;
1779 }
1780 let neg = *coef < 0.0;
1781 let mag = coef.abs();
1782 let term = if mag == 1.0 {
1783 var_label(*var, vn)
1784 } else {
1785 format!("{}*{}", fmt_num(mag), var_label(*var, vn))
1786 };
1787 if first {
1788 if neg {
1789 out.push('-');
1790 }
1791 out.push_str(&term);
1792 first = false;
1793 } else {
1794 out.push_str(if neg { " - " } else { " + " });
1795 out.push_str(&term);
1796 }
1797 }
1798 out
1799}
1800
1801fn render_body(linear: &[(usize, Number)], nonlinear: &Expr, prob: &NlProblem) -> String {
1803 let mut s = render_linear(linear, &prob.var_names);
1804 let nl_is_zero = matches!(nonlinear, Expr::Const(c) if *c == 0.0);
1805 if !nl_is_zero {
1806 let nl = render_prec(nonlinear, P_ADD, &prob.var_names, &prob.imported_funcs);
1807 if s.is_empty() {
1808 s = nl;
1809 } else {
1810 push_additive(&mut s, &nl, false);
1811 }
1812 }
1813 if s.is_empty() {
1814 s = "0".to_string();
1815 }
1816 s
1817}
1818
1819pub fn render_constraint_equation(prob: &NlProblem, k: usize) -> String {
1823 let body = render_body(&prob.con_linear[k], &prob.con_nonlinear[k], prob);
1824 let lo = prob.g_l[k];
1825 let hi = prob.g_u[k];
1826 const INF: Number = 1.0e19;
1827 let has_lo = lo > -INF;
1828 let has_hi = hi < INF;
1829 match (has_lo, has_hi) {
1830 (true, true) if lo == hi => format!("{body} = {}", fmt_num(lo)),
1831 (true, true) => format!("{} <= {body} <= {}", fmt_num(lo), fmt_num(hi)),
1832 (true, false) => format!("{body} >= {}", fmt_num(lo)),
1833 (false, true) => format!("{body} <= {}", fmt_num(hi)),
1834 (false, false) => format!("{body} (free)"),
1835 }
1836}
1837
1838pub fn render_all_constraint_equations(prob: &NlProblem) -> Vec<String> {
1841 (0..prob.m)
1842 .map(|k| render_constraint_equation(prob, k))
1843 .collect()
1844}
1845
1846pub fn constraint_jacobian_sparsity(prob: &NlProblem) -> (Vec<Index>, Vec<Index>) {
1860 let mut irow: Vec<Index> = Vec::new();
1861 let mut jcol: Vec<Index> = Vec::new();
1862 let mut support: BTreeSet<usize> = BTreeSet::new();
1863 for k in 0..prob.m {
1864 support.clear();
1865 for &(j, _coef) in &prob.con_linear[k] {
1866 support.insert(j);
1867 }
1868 collect_vars(&prob.con_nonlinear[k], &mut support);
1869 for &j in &support {
1870 irow.push(k as Index);
1871 jcol.push(j as Index);
1872 }
1873 }
1874 (irow, jcol)
1875}
1876
1877fn split_top_sums(expr: &Expr) -> Vec<Expr> {
1904 let mut out = Vec::new();
1905 fn push_leaf(e: &Expr, factor: f64, out: &mut Vec<Expr>) {
1906 if factor == 1.0 {
1907 out.push(e.clone());
1908 } else if factor == -1.0 {
1909 out.push(Expr::Unary(UnaryOp::Neg, Box::new(e.clone())));
1910 } else {
1911 out.push(Expr::Binary(
1912 BinOp::Mul,
1913 Box::new(Expr::Const(factor)),
1914 Box::new(e.clone()),
1915 ));
1916 }
1917 }
1918 fn go(e: &Expr, factor: f64, out: &mut Vec<Expr>) {
1919 match e {
1920 Expr::Sum(terms) => {
1921 for t in terms {
1922 go(t, factor, out);
1923 }
1924 }
1925 Expr::Binary(BinOp::Add, l, r) => {
1926 go(l, factor, out);
1927 go(r, factor, out);
1928 }
1929 Expr::Binary(BinOp::Sub, l, r) => {
1930 go(l, factor, out);
1931 go(r, -factor, out);
1932 }
1933 Expr::Unary(UnaryOp::Neg, x) => {
1934 go(x, -factor, out);
1935 }
1936 Expr::Binary(BinOp::Mul, l, r) => match (l.as_ref(), r.as_ref()) {
1939 (Expr::Const(c), _) => go(r, factor * c, out),
1940 (_, Expr::Const(c)) => go(l, factor * c, out),
1941 _ => push_leaf(e, factor, out),
1942 },
1943 Expr::Binary(BinOp::Div, l, r) => match r.as_ref() {
1944 Expr::Const(c) if *c != 0.0 => go(l, factor / c, out),
1945 _ => push_leaf(e, factor, out),
1946 },
1947 _ => push_leaf(e, factor, out),
1948 }
1949 }
1950 go(expr, 1.0, &mut out);
1951 if out.is_empty() {
1952 out.push(Expr::Const(0.0));
1953 }
1954 out
1955}
1956
/// Greedy column coloring of a symmetric sparsity pattern for compressed
/// Hessian evaluation.
///
/// `lower_pairs` lists the structural nonzeros `(row, col)` of one triangle of
/// an `n x n` symmetric matrix. Each off-diagonal entry is mirrored
/// internally. Columns are visited in index order, and each gets the smallest
/// color not used by another column that shares a row with it. Same-colored
/// columns are therefore structurally orthogonal, and one Hessian-vector
/// product per color recovers every entry directly.
///
/// Returns `(color_of_column, number_of_colors)`. Columns with no nonzero keep
/// the sentinel `u32::MAX` and do not count towards the total.
///
/// # Panics
/// When `n > 0`, panics if a pair references an index `>= n`.
fn greedy_hessian_coloring(n: usize, lower_pairs: &[(usize, usize)]) -> (Vec<u32>, usize) {
    const UNCOLORED: u32 = u32::MAX;

    if n == 0 {
        return (Vec::new(), 0);
    }

    // Mirror the triangle so both views cover the full symmetric pattern:
    // `rows_in_col[c]` lists the rows with a nonzero in column c, and
    // `cols_in_row[r]` the columns with a nonzero in row r.
    let mut rows_in_col: Vec<Vec<u32>> = vec![Vec::new(); n];
    let mut cols_in_row: Vec<Vec<u32>> = vec![Vec::new(); n];
    for &(r, c) in lower_pairs {
        rows_in_col[c].push(r as u32);
        cols_in_row[r].push(c as u32);
        if r != c {
            rows_in_col[r].push(c as u32);
            cols_in_row[c].push(r as u32);
        }
    }

    let mut color_of = vec![UNCOLORED; n];
    // `blocked[k] == stamp` means color k is taken by a neighbour of the column
    // being colored. Stamping with the column index avoids clearing the array
    // between columns.
    let mut blocked = vec![UNCOLORED; n + 1];
    let mut used: u32 = 0;

    for (col, rows) in rows_in_col.iter().enumerate() {
        if rows.is_empty() {
            // Column never appears in the pattern: leave it uncolored.
            continue;
        }
        let stamp = col as u32;
        for &r in rows {
            let neighbours = cols_in_row[r as usize].iter().map(|&c| c as usize);
            for other in neighbours.filter(|&other| other != col) {
                let k = color_of[other];
                if k != UNCOLORED {
                    blocked[k as usize] = stamp;
                }
            }
        }
        let color = blocked
            .iter()
            .position(|&mark| mark != stamp)
            .unwrap_or(blocked.len()) as u32;
        color_of[col] = color;
        used = used.max(color + 1);
    }

    (color_of, used as usize)
}
2028
impl NlTnlp {
    /// Builds the adapter for `prob`.
    ///
    /// # Panics
    /// Panics if the AMPL imported functions referenced by the problem cannot
    /// be resolved. Use [`NlTnlp::try_new`] to receive that failure as an `Err`.
    pub fn new(prob: NlProblem) -> Self {
        Self::try_new(prob)
            .unwrap_or_else(|e| panic!("failed to resolve AMPL external functions: {e}"))
    }

    /// Builds the adapter for `prob`.
    ///
    /// Construction resolves imported functions and tapes each additive
    /// summand of the objective and constraints. It then precomputes the
    /// Jacobian row patterns and the Hessian pattern with its coloring, seeds
    /// and decoding tables, and sizes the scratch buffers.
    ///
    /// # Errors
    /// Returns the resolver's error message if an imported function
    /// referenced by the expressions cannot be resolved.
    pub fn try_new(prob: NlProblem) -> Result<Self, String> {
        // Ids of every imported function called anywhere in the problem.
        let mut referenced: BTreeSet<usize> = BTreeSet::new();
        super::nl_external::collect_funcall_ids(&prob.obj_nonlinear, &mut referenced);
        for c in &prob.con_nonlinear {
            super::nl_external::collect_funcall_ids(c, &mut referenced);
        }
        // Only build a real resolver when some expression actually calls out.
        let resolver = if referenced.is_empty() {
            super::nl_external::ExternalResolver::default()
        } else {
            super::nl_external::ExternalResolver::build_for_problem(
                &prob.imported_funcs,
                &referenced,
            )?
        };

        // Tape the objective one additive summand at a time.
        let obj_summands = split_top_sums(&prob.obj_nonlinear);
        let obj_tapes: Vec<Tape> = obj_summands
            .iter()
            .map(|e| Tape::build_with_externals(e, &resolver))
            .collect();

        // Same for each constraint's nonlinear part.
        let mut con_tapes: Vec<Vec<Tape>> = Vec::with_capacity(prob.m);
        for k in 0..prob.m {
            let summands = split_top_sums(&prob.con_nonlinear[k]);
            con_tapes.push(
                summands
                    .iter()
                    .map(|e| Tape::build_with_externals(e, &resolver))
                    .collect(),
            );
        }

        // Union of the per-tape Hessian patterns. The BTreeSet keeps the
        // `(hi, lo)` pairs sorted and deduplicated.
        let mut pairs: BTreeSet<(usize, usize)> = BTreeSet::new();
        for t in &obj_tapes {
            pairs.extend(t.hessian_sparsity());
        }
        for row in &con_tapes {
            for t in row {
                pairs.extend(t.hessian_sparsity());
            }
        }
        // Triplet arrays in that sorted order, plus a pair -> triplet-position map.
        let mut h_irow = Vec::with_capacity(pairs.len());
        let mut h_jcol = Vec::with_capacity(pairs.len());
        let mut hess_map = HashMap::with_capacity(pairs.len());
        for (k, (hi, lo)) in pairs.iter().enumerate() {
            h_irow.push(*hi as i32);
            h_jcol.push(*lo as i32);
            hess_map.insert((*hi, *lo), k);
        }

        // Color columns so that same-colored columns never share a row.
        let lower_pairs: Vec<(usize, usize)> = pairs.iter().copied().collect();
        let (var_color, n_colors) = greedy_hessian_coloring(prob.n, &lower_pairs);

        // seeds[c] is the 0/1 indicator of the variables with color c.
        // Uncolored variables (u32::MAX) appear in no seed.
        let mut seeds: Vec<Vec<f64>> = vec![vec![0.0; prob.n]; n_colors];
        for (k, &c) in var_color.iter().enumerate() {
            if c != u32::MAX {
                seeds[c as usize][k] = 1.0;
            }
        }

        // Pattern entry (i, j) is read from component i of column j's color
        // buffer. The coloring guarantees no other column of that color has a
        // nonzero in row i. The HashMap iteration order only affects the
        // order within each table; every triplet position is written exactly once.
        let mut decoding: Vec<Vec<ColorWrite>> = vec![Vec::new(); n_colors];
        for (&(i, j), &idx) in hess_map.iter() {
            let c = var_color[j];
            debug_assert!(
                c != u32::MAX,
                "column {j} has Hessian pair {idx} but no color"
            );
            decoding[c as usize].push(ColorWrite {
                row: i as u32,
                hess_idx: idx as u32,
            });
        }

        // Colors of the variables each tape reads, so eval_h only runs the
        // directional products a tape can contribute to.
        let tape_colors = |t: &Tape| -> Vec<u32> {
            let mut s: BTreeSet<u32> = BTreeSet::new();
            for v in t.variables() {
                let c = var_color[v];
                if c != u32::MAX {
                    s.insert(c);
                }
            }
            s.into_iter().collect()
        };
        let obj_tape_colors: Vec<Vec<u32>> = obj_tapes.iter().map(tape_colors).collect();
        let con_tape_colors: Vec<Vec<Vec<u32>>> = con_tapes
            .iter()
            .map(|row| row.iter().map(tape_colors).collect())
            .collect();

        // Jacobian row pattern: variables of the row's tapes plus its linear
        // terms, sorted by the BTreeSet.
        let mut jac_cols: Vec<Vec<usize>> = Vec::with_capacity(prob.m);
        let mut jac_nnz = 0;
        for i in 0..prob.m {
            let mut set: BTreeSet<usize> = BTreeSet::new();
            for t in &con_tapes[i] {
                for v in t.variables() {
                    set.insert(v);
                }
            }
            for (v, _) in &prob.con_linear[i] {
                set.insert(*v);
            }
            let cols: Vec<usize> = set.into_iter().collect();
            jac_nnz += cols.len();
            jac_cols.push(cols);
        }

        // Length (in ops) of the longest tape, used to size the per-op scratch buffers.
        let mut max_tape_n: usize = 0;
        for t in &obj_tapes {
            max_tape_n = max_tape_n.max(t.ops.len());
        }
        for row in &con_tapes {
            for t in row {
                max_tape_n = max_tape_n.max(t.ops.len());
            }
        }

        // Optional diagnostics, printed when POUNCE_DBG_TAPE_STATS is set.
        if std::env::var("POUNCE_DBG_TAPE_STATS").is_ok() {
            let n_obj = obj_tapes.len();
            let n_con: usize = con_tapes.iter().map(|r| r.len()).sum();
            let total = n_obj + n_con;
            let mut sum_ops: usize = 0;
            for t in &obj_tapes {
                sum_ops += t.ops.len();
            }
            for row in &con_tapes {
                for t in row {
                    sum_ops += t.ops.len();
                }
            }
            // Guards the average against division by zero.
            let t = total.max(1);
            let nnz_h = h_irow.len();
            let avg_decode =
                decoding.iter().map(|d| d.len()).sum::<usize>() as f64 / n_colors.max(1) as f64;
            eprintln!(
                "[tape stats] summands={total} (obj={n_obj} con={n_con}) \
                 total_ops={sum_ops} avg_ops={:.1} max_ops={max_tape_n} \
                 n_colors={n_colors} avg_decode_per_color={avg_decode:.1} nnz_h={nnz_h}",
                sum_ops as f64 / t as f64,
            );
        }

        // One length-n accumulator per color, cleared at the start of each eval_h.
        let compressed: Vec<Vec<f64>> = vec![vec![0.0; prob.n]; n_colors];

        Ok(Self {
            prob,
            obj_tapes,
            con_tapes,
            h_irow,
            h_jcol,
            jac_cols,
            jac_nnz,
            seeds,
            decoding,
            obj_tape_colors,
            con_tape_colors,
            final_x: None,
            final_obj: 0.0,
            scratch_row_grad: Vec::new(),
            vals_scratch: vec![0.0; max_tape_n],
            dot_scratch: vec![0.0; max_tape_n],
            adj_scratch: vec![0.0; max_tape_n],
            adj_dot_scratch: vec![0.0; max_tape_n],
            compressed,
        })
    }

    /// Primal solution recorded by `finalize_solution`, or `None` before the
    /// solver has finished.
    pub fn final_x(&self) -> Option<&[Number]> {
        self.final_x.as_deref()
    }

    /// Objective value recorded by `finalize_solution` (`0.0` before that).
    pub fn final_obj(&self) -> Number {
        self.final_obj
    }

    /// The underlying problem, including any overrides applied by `variant`.
    pub fn problem(&self) -> &NlProblem {
        &self.prob
    }

    /// Returns a copy of this adapter with the fields set in `v` overridden.
    ///
    /// Only `x0` and the variable/constraint bounds change, so tapes,
    /// sparsity patterns and coloring are reused from `self`. Any stored
    /// solution (`final_x`, `final_obj`) is cleared.
    ///
    /// # Errors
    /// Returns `Err` if a provided vector's length differs from `n`
    /// (`x0`, `x_l`, `x_u`) or from `m` (`g_l`, `g_u`).
    pub fn variant(&self, v: &NlVariation) -> Result<Self, String> {
        let check = |name: &str, got: usize, want: usize| -> Result<(), String> {
            if got == want {
                Ok(())
            } else {
                Err(format!(
                    "NlVariation.{name} has length {got}, expected {want}"
                ))
            }
        };
        let mut out = self.clone();
        out.final_x = None;
        out.final_obj = 0.0;
        if let Some(x0) = &v.x0 {
            check("x0", x0.len(), self.prob.n)?;
            out.prob.x0.clone_from(x0);
        }
        if let Some(x_l) = &v.x_l {
            check("x_l", x_l.len(), self.prob.n)?;
            out.prob.x_l.clone_from(x_l);
        }
        if let Some(x_u) = &v.x_u {
            check("x_u", x_u.len(), self.prob.n)?;
            out.prob.x_u.clone_from(x_u);
        }
        if let Some(g_l) = &v.g_l {
            check("g_l", g_l.len(), self.prob.m)?;
            out.prob.g_l.clone_from(g_l);
        }
        if let Some(g_u) = &v.g_u {
            check("g_u", g_u.len(), self.prob.m)?;
            out.prob.g_u.clone_from(g_u);
        }
        Ok(out)
    }

    /// Applies [`NlTnlp::variant`] to each element of `vs` in order, stopping
    /// at the first error.
    pub fn variants(&self, vs: &[NlVariation]) -> Result<Vec<Self>, String> {
        vs.iter().map(|v| self.variant(v)).collect()
    }
}
2310
/// Per-instance overrides applied by [`NlTnlp::variant`].
///
/// Every field is optional, and `None` keeps the base problem's value. A
/// provided vector must have the base problem's length (`n` for `x0`, `x_l`,
/// `x_u`; `m` for `g_l`, `g_u`), otherwise `variant` returns `Err`.
#[derive(Debug, Clone, Default)]
pub struct NlVariation {
    /// Starting point.
    pub x0: Option<Vec<Number>>,
    /// Variable lower bounds.
    pub x_l: Option<Vec<Number>>,
    /// Variable upper bounds.
    pub x_u: Option<Vec<Number>>,
    /// Constraint lower bounds.
    pub g_l: Option<Vec<Number>>,
    /// Constraint upper bounds.
    pub g_u: Option<Vec<Number>>,
}
2325
2326impl pounce_nlp::expression_provider::ExpressionProvider for NlTnlp {
2327 fn constraint_expression(&self, i: usize) -> Option<pounce_nlp::FbbtTape> {
2332 let nonlinear = self.prob.con_nonlinear.get(i)?;
2333 let linear = self
2334 .prob
2335 .con_linear
2336 .get(i)
2337 .map(|v| v.as_slice())
2338 .unwrap_or(&[]);
2339 crate::nl_fbbt_translate::translate_constraint(nonlinear, linear)
2340 }
2341
2342 fn variable_name(&self, i: usize) -> Option<&str> {
2345 self.prob.var_names.get(i).map(String::as_str)
2346 }
2347
2348 fn constraint_name(&self, i: usize) -> Option<&str> {
2351 self.prob.con_names.get(i).map(String::as_str)
2352 }
2353}
2354
2355impl TNLP for NlTnlp {
2356 fn get_nlp_info(&mut self) -> Option<NlpInfo> {
2357 Some(NlpInfo {
2358 n: self.prob.n as Index,
2359 m: self.prob.m as Index,
2360 nnz_jac_g: self.jac_nnz as Index,
2361 nnz_h_lag: self.h_irow.len() as Index,
2362 index_style: IndexStyle::C,
2363 })
2364 }
2365
2366 fn get_bounds_info(&mut self, b: BoundsInfo<'_>) -> bool {
2367 b.x_l.copy_from_slice(&self.prob.x_l);
2368 b.x_u.copy_from_slice(&self.prob.x_u);
2369 if !self.prob.g_l.is_empty() {
2370 b.g_l.copy_from_slice(&self.prob.g_l);
2371 b.g_u.copy_from_slice(&self.prob.g_u);
2372 }
2373 true
2374 }
2375
2376 fn get_starting_point(&mut self, sp: StartingPoint<'_>) -> bool {
2377 sp.x.copy_from_slice(&self.prob.x0);
2378 if sp.init_lambda {
2387 sp.lambda.copy_from_slice(&self.prob.lambda0);
2388 }
2389 true
2390 }
2391
2392 fn eval_f(&mut self, x: &[Number], _new_x: bool) -> Option<Number> {
2393 let mut nl: Number = 0.0;
2394 for t in &self.obj_tapes {
2395 nl += t.eval(x);
2396 }
2397 let lin: Number = self.prob.obj_linear.iter().map(|(i, c)| c * x[*i]).sum();
2398 let v = self.prob.obj_constant + nl + lin;
2399 let signed = if self.prob.minimize { v } else { -v };
2400 Some(signed)
2401 }
2402
2403 fn eval_grad_f(&mut self, x: &[Number], _new_x: bool, grad: &mut [Number]) -> bool {
2404 grad.fill(0.0);
2405 for t in &self.obj_tapes {
2409 t.gradient_seed_into(x, 1.0, grad, &mut self.vals_scratch, &mut self.adj_scratch);
2410 }
2411 for (i, c) in &self.prob.obj_linear {
2412 grad[*i] += c;
2413 }
2414 if !self.prob.minimize {
2415 for g in grad.iter_mut() {
2416 *g = -*g;
2417 }
2418 }
2419 true
2420 }
2421
2422 fn eval_g(&mut self, x: &[Number], _new_x: bool, g: &mut [Number]) -> bool {
2423 for i in 0..self.prob.m {
2424 let mut nl: Number = 0.0;
2425 for t in &self.con_tapes[i] {
2426 nl += t.eval(x);
2427 }
2428 let lin: Number = self.prob.con_linear[i].iter().map(|(j, c)| c * x[*j]).sum();
2429 g[i] = nl + lin;
2430 }
2431 true
2432 }
2433
2434 fn eval_jac_g(
2435 &mut self,
2436 x: Option<&[Number]>,
2437 _new_x: bool,
2438 mode: SparsityRequest<'_>,
2439 ) -> bool {
2440 match mode {
2441 SparsityRequest::Structure { irow, jcol } => {
2442 let mut k = 0;
2443 for i in 0..self.prob.m {
2444 for &j in &self.jac_cols[i] {
2445 irow[k] = i as Index;
2446 jcol[k] = j as Index;
2447 k += 1;
2448 }
2449 }
2450 true
2451 }
2452 SparsityRequest::Values { values } => {
2453 let n = self.prob.n;
2454 let xs = x.unwrap_or(&self.prob.x0);
2455 if self.scratch_row_grad.len() < n {
2456 self.scratch_row_grad.resize(n, 0.0);
2457 }
2458 let mut k = 0;
2459 for i in 0..self.prob.m {
2460 for &j in &self.jac_cols[i] {
2461 self.scratch_row_grad[j] = 0.0;
2462 }
2463 for t in &self.con_tapes[i] {
2464 t.gradient_seed_into(
2467 xs,
2468 1.0,
2469 &mut self.scratch_row_grad,
2470 &mut self.vals_scratch,
2471 &mut self.adj_scratch,
2472 );
2473 }
2474 for &(v, c) in &self.prob.con_linear[i] {
2475 self.scratch_row_grad[v] += c;
2476 }
2477 for &j in &self.jac_cols[i] {
2478 values[k] = self.scratch_row_grad[j];
2479 k += 1;
2480 }
2481 }
2482 true
2483 }
2484 }
2485 }
2486
2487 fn eval_h(
2488 &mut self,
2489 x: Option<&[Number]>,
2490 _new_x: bool,
2491 obj_factor: Number,
2492 lambda: Option<&[Number]>,
2493 _new_lambda: bool,
2494 mode: SparsityRequest<'_>,
2495 ) -> bool {
2496 match mode {
2497 SparsityRequest::Structure { irow, jcol } => {
2498 irow.copy_from_slice(&self.h_irow);
2499 jcol.copy_from_slice(&self.h_jcol);
2500 true
2501 }
2502 SparsityRequest::Values { values } => {
2503 let x = x.unwrap_or(&self.prob.x0);
2504 values.fill(0.0);
2505
2506 let obj_seed = if self.prob.minimize {
2507 obj_factor
2508 } else {
2509 -obj_factor
2510 };
2511 for buf in &mut self.compressed {
2520 buf.fill(0.0);
2521 }
2522
2523 if obj_seed != 0.0 {
2524 for (ti, t) in self.obj_tapes.iter().enumerate() {
2525 if t.ops.is_empty() {
2526 continue;
2527 }
2528 t.forward_into(x, &mut self.vals_scratch);
2529 for &c in &self.obj_tape_colors[ti] {
2530 t.hessian_directional(
2531 &self.vals_scratch,
2532 &self.seeds[c as usize],
2533 obj_seed,
2534 &mut self.compressed[c as usize],
2535 &mut self.dot_scratch,
2536 &mut self.adj_scratch,
2537 &mut self.adj_dot_scratch,
2538 );
2539 }
2540 }
2541 }
2542
2543 if let Some(lam) = lambda {
2544 for k in 0..self.prob.m {
2545 let w = lam[k];
2546 if w == 0.0 {
2547 continue;
2548 }
2549 for (ti, t) in self.con_tapes[k].iter().enumerate() {
2550 if t.ops.is_empty() {
2551 continue;
2552 }
2553 t.forward_into(x, &mut self.vals_scratch);
2554 for &c in &self.con_tape_colors[k][ti] {
2555 t.hessian_directional(
2556 &self.vals_scratch,
2557 &self.seeds[c as usize],
2558 w,
2559 &mut self.compressed[c as usize],
2560 &mut self.dot_scratch,
2561 &mut self.adj_scratch,
2562 &mut self.adj_dot_scratch,
2563 );
2564 }
2565 }
2566 }
2567 }
2568
2569 for (c, table) in self.decoding.iter().enumerate() {
2572 let comp = &self.compressed[c];
2573 for w in table {
2574 values[w.hess_idx as usize] += comp[w.row as usize];
2575 }
2576 }
2577 true
2578 }
2579 }
2580 }
2581
2582 fn finalize_solution(&mut self, sol: Solution<'_>, _d: &IpoptData, _q: &IpoptCq) {
2583 self.final_x = Some(sol.x.to_vec());
2584 self.final_obj = sol.obj_value;
2585 }
2586
2587 fn get_var_con_metadata(&mut self, var: &mut MetaData, con: &mut MetaData) -> bool {
2597 let mut any = false;
2598 if !self.prob.var_names.is_empty() {
2599 var.strings
2600 .insert(IDX_NAMES.to_string(), self.prob.var_names.clone());
2601 any = true;
2602 }
2603 if !self.prob.con_names.is_empty() {
2604 con.strings
2605 .insert(IDX_NAMES.to_string(), self.prob.con_names.clone());
2606 any = true;
2607 }
2608 any
2609 }
2610
2611 fn get_constraints_linearity(&mut self, types: &mut [Linearity]) -> bool {
2612 for (i, t) in types.iter_mut().enumerate() {
2616 *t = match &self.prob.con_nonlinear[i] {
2617 Expr::Const(c) if *c == 0.0 => Linearity::Linear,
2618 _ => Linearity::NonLinear,
2619 };
2620 }
2621 true
2622 }
2623
2624 fn get_variables_linearity(&mut self, types: &mut [Linearity]) -> bool {
2625 let mut nonlinear: BTreeSet<usize> = BTreeSet::new();
2635 collect_vars(&self.prob.obj_nonlinear, &mut nonlinear);
2636 for row in &self.prob.con_nonlinear {
2637 collect_vars(row, &mut nonlinear);
2638 }
2639 for (i, t) in types.iter_mut().enumerate() {
2640 *t = if nonlinear.contains(&i) {
2641 Linearity::NonLinear
2642 } else {
2643 Linearity::Linear
2644 };
2645 }
2646 true
2647 }
2648
2649 fn get_objective_variables_linearity(&mut self, types: &mut [Linearity]) -> bool {
2650 let mut nonlinear: BTreeSet<usize> = BTreeSet::new();
2661 collect_vars(&self.prob.obj_nonlinear, &mut nonlinear);
2662 for (i, t) in types.iter_mut().enumerate() {
2663 *t = if nonlinear.contains(&i) {
2664 Linearity::NonLinear
2665 } else {
2666 Linearity::Linear
2667 };
2668 }
2669 true
2670 }
2671}
2672
2673pub fn load_nl_as_tnlp(path: &Path) -> Result<Rc<RefCell<dyn TNLP>>, String> {
2675 let prob = read_nl_file(path)?;
2676 Ok(Rc::new(RefCell::new(NlTnlp::new(prob))))
2677}
2678
2679#[cfg(test)]
2680mod tests {
2681 use super::*;
2682
2683 #[test]
2688 fn nl_problem_and_tnlp_are_send() {
2689 fn assert_send<T: Send>() {}
2690 assert_send::<NlProblem>();
2691 assert_send::<NlTnlp>();
2692 assert_send::<Expr>();
2693 }
2694
2695 #[test]
2698 fn variant_overrides_bounds_and_x0() {
2699 let p = parse_nl_text(SIMPLE).expect("parse");
2700 let mut base = NlTnlp::new(p);
2701 let var = base
2702 .variant(&NlVariation {
2703 x0: Some(vec![3.0, 4.0]),
2704 x_l: Some(vec![-1.0, -2.0]),
2705 x_u: Some(vec![5.0, 6.0]),
2706 ..Default::default()
2707 })
2708 .expect("variant");
2709 let mut var = var;
2710 let (mut x_l, mut x_u) = ([0.0; 2], [0.0; 2]);
2711 let (mut g_l, mut g_u) = ([0.0; 0], [0.0; 0]);
2712 assert!(var.get_bounds_info(BoundsInfo {
2713 x_l: &mut x_l,
2714 x_u: &mut x_u,
2715 g_l: &mut g_l,
2716 g_u: &mut g_u,
2717 }));
2718 assert_eq!(x_l, [-1.0, -2.0]);
2719 assert_eq!(x_u, [5.0, 6.0]);
2720 let mut x = [0.0; 2];
2721 let (mut zl, mut zu, mut lam) = ([0.0; 2], [0.0; 2], [0.0; 0]);
2722 assert!(var.get_starting_point(StartingPoint {
2723 init_x: true,
2724 x: &mut x,
2725 init_z: false,
2726 z_l: &mut zl,
2727 z_u: &mut zu,
2728 init_lambda: false,
2729 lambda: &mut lam,
2730 }));
2731 assert_eq!(x, [3.0, 4.0]);
2732 assert!(base.problem().x_l[0] < -1.0e18);
2734 assert!(base
2736 .variant(&NlVariation {
2737 x0: Some(vec![1.0]),
2738 ..Default::default()
2739 })
2740 .is_err());
2741 }
2742
2743 const SIMPLE: &str = "g3 0 1 0
27612 0 1 0 0
27620 1
27630 0
27640 2 0
27650 0 0 1
27660 0 0 0 0
27670 0
27680 0
27690 0 0 0 0
2770O0 0
2771o0
2772o5
2773o1
2774v0
2775n1
2776n2
2777o5
2778o1
2779v1
2780n2
2781n2
2782b
27833
27843
2785";
2786
2787 #[test]
2788 fn parses_simple_quadratic() {
2789 let p = parse_nl_text(SIMPLE).expect("parse");
2790 assert_eq!(p.n, 2);
2791 assert_eq!(p.m, 0);
2792 assert_eq!(p.num_obj, 1);
2793 let f = eval_expr(&p.obj_nonlinear, &[0.0, 0.0]);
2795 assert!((f - 5.0).abs() < 1e-12);
2796 let f = eval_expr(&p.obj_nonlinear, &[1.0, 2.0]);
2798 assert!(f.abs() < 1e-12);
2799 }
2800
2801 #[test]
2802 fn gradient_matches_analytic() {
2803 let p = parse_nl_text(SIMPLE).expect("parse");
2804 let x = [0.5, 1.0];
2805 let mut g = [0.0_f64; 2];
2806 grad_expr(&p.obj_nonlinear, &x, 1.0, &mut g);
2807 assert!((g[0] - (-1.0)).abs() < 1e-12);
2810 assert!((g[1] - (-2.0)).abs() < 1e-12);
2811 }
2812
2813 #[test]
2824 fn variables_linearity_tags_obj_nonlinear_vs_linear_vars() {
2825 let obj_nl = Expr::Binary(
2827 BinOp::Pow,
2828 Box::new(Expr::Binary(
2829 BinOp::Sub,
2830 Box::new(Expr::Var(0)),
2831 Box::new(Expr::Const(1.0)),
2832 )),
2833 Box::new(Expr::Const(2.0)),
2834 );
2835 let prob = NlProblem {
2836 n: 2,
2837 m: 0,
2838 num_obj: 1,
2839 minimize: true,
2840 obj_nonlinear: obj_nl,
2841 obj_linear: vec![(1, 3.0)],
2842 obj_constant: 0.0,
2843 con_nonlinear: vec![],
2844 con_linear: vec![],
2845 x_l: vec![f64::NEG_INFINITY; 2],
2846 x_u: vec![f64::INFINITY; 2],
2847 g_l: vec![],
2848 g_u: vec![],
2849 x0: vec![0.0; 2],
2850 lambda0: vec![],
2851 suffixes: NlSuffixes::default(),
2852 imported_funcs: vec![],
2853 var_names: vec![],
2854 con_names: vec![],
2855 };
2856 let mut tnlp = NlTnlp::new(prob);
2857 let mut types = vec![Linearity::Linear; 2];
2858 let ok = tnlp.get_variables_linearity(&mut types);
2859 assert!(
2861 ok,
2862 "get_variables_linearity must report it filled the slice"
2863 );
2864 assert!(
2865 matches!(types[0], Linearity::NonLinear),
2866 "x0 is nonlinear in the objective"
2867 );
2868 assert!(
2869 matches!(types[1], Linearity::Linear),
2870 "x1 appears only in the linear part"
2871 );
2872 }
2873
2874 #[test]
2882 fn objective_variables_linearity_ignores_constraint_nonlinearity() {
2883 let con_nl = Expr::Binary(
2885 BinOp::Pow,
2886 Box::new(Expr::Var(0)),
2887 Box::new(Expr::Const(2.0)),
2888 );
2889 let prob = NlProblem {
2890 n: 2,
2891 m: 1,
2892 num_obj: 1,
2893 minimize: true,
2894 obj_nonlinear: Expr::Const(0.0),
2895 obj_linear: vec![(1, 3.0)],
2896 obj_constant: 0.0,
2897 con_nonlinear: vec![con_nl],
2898 con_linear: vec![vec![]],
2899 x_l: vec![f64::NEG_INFINITY; 2],
2900 x_u: vec![f64::INFINITY; 2],
2901 g_l: vec![4.0],
2902 g_u: vec![4.0],
2903 x0: vec![0.0; 2],
2904 lambda0: vec![0.0],
2905 suffixes: NlSuffixes::default(),
2906 imported_funcs: vec![],
2907 var_names: vec![],
2908 con_names: vec![],
2909 };
2910 let mut tnlp = NlTnlp::new(prob);
2911
2912 let mut global = vec![Linearity::Linear; 2];
2913 assert!(tnlp.get_variables_linearity(&mut global));
2914 assert!(
2915 matches!(global[0], Linearity::NonLinear),
2916 "global tags see x0's constraint nonlinearity"
2917 );
2918
2919 let mut obj = vec![Linearity::NonLinear; 2];
2920 assert!(tnlp.get_objective_variables_linearity(&mut obj));
2921 assert!(
2922 matches!(obj[0], Linearity::Linear),
2923 "x0 is linear w.r.t. the objective despite the nonlinear constraint"
2924 );
2925 assert!(
2926 matches!(obj[1], Linearity::Linear),
2927 "x1 is linear everywhere"
2928 );
2929 }
2930
2931 const EQ_LIN: &str = "g3 0 1 0
29502 1 1 0 0
29510 1
29520 0
29530 2 0
29540 0 0 1
29550 0 0 0 0
29562 0
29570 0
29580 0 0 0 0
2959C0
2960n0
2961O0 0
2962o0
2963o5
2964v0
2965n2
2966o5
2967v1
2968n2
2969r
29704 1
2971b
29723
29733
2974k1
29752
2976J0 2
29770 1
29781 1
2979";
2980
2981 #[test]
2982 fn parses_constrained_problem() {
2983 let p = parse_nl_text(EQ_LIN).expect("parse");
2984 assert_eq!(p.n, 2);
2985 assert_eq!(p.m, 1);
2986 assert!((p.g_l[0] - 1.0).abs() < 1e-12);
2988 assert!((p.g_u[0] - 1.0).abs() < 1e-12);
2989 assert_eq!(p.con_linear[0], vec![(0, 1.0), (1, 1.0)]);
2991 }
2992
#[test]
fn malformed_j_variable_index_is_parse_error_not_panic() {
    // Redirect row 0's second Jacobian entry from x1 to x5, which does not exist.
    let original_j = "J0 2\n0 1\n1 1\n";
    let corrupted_j = "J0 2\n0 1\n5 1\n";
    let bad = EQ_LIN.replace(original_j, corrupted_j);
    assert_ne!(bad, EQ_LIN, "fixture substitution must apply");
    // The parser must report this as an `Err`, not hit an index panic.
    let err = parse_nl_text(&bad).expect_err("out-of-range J var must error");
    assert!(err.contains("out of range"), "unexpected error: {err}");
}
3005
#[test]
fn out_of_range_x_segment_index_is_parse_error() {
    // Append an `x` segment naming variable 5 in a two-variable problem.
    let mut bad = String::from(EQ_LIN);
    bad.push_str("x1\n5 0.5\n");
    let err = parse_nl_text(&bad).expect_err("out-of-range x index must error");
    assert!(err.contains("out of range"), "unexpected error: {err}");
}
3015
#[test]
fn k_segment_nonstandard_count_is_parse_error_at_source() {
    // EQ_LIN declares `k1` with one count for its two variables; swap in an
    // empty `k0` header and expect the parser to reject it immediately.
    let (declared, nonstandard) = ("k1\n2\n", "k0\n");
    let bad = EQ_LIN.replace(declared, nonstandard);
    assert_ne!(bad, EQ_LIN, "fixture substitution must apply");
    let err = parse_nl_text(&bad).expect_err("nonstandard k count must error");
    let mentions_k_count = err.contains("k-segment declares");
    assert!(
        mentions_k_count,
        "expected a clear k-segment count error, got: {err}"
    );
}
3035
#[test]
fn get_starting_point_returns_nl_initial_duals() {
    // EQ_LIN plus a `d` segment giving constraint 0 the initial dual 2.5.
    let nl = format!("{EQ_LIN}\nd1\n0 2.5\n");
    let problem = parse_nl_text(&nl).expect("parse");
    assert_eq!(problem.lambda0, vec![2.5], "the `d` segment fills lambda0");

    let mut tnlp = NlTnlp::new(problem);
    let info = tnlp.get_nlp_info().unwrap();
    let n = info.n as usize;
    let m = info.m as usize;
    let mut x = vec![0.0; n];
    let mut z_l = vec![0.0; n];
    let mut z_u = vec![0.0; n];

    // Duals requested: they must be copied from the parsed lambda0.
    let mut lambda = vec![0.0; m];
    let warm_start = StartingPoint {
        init_x: true,
        x: &mut x,
        init_z: false,
        z_l: &mut z_l,
        z_u: &mut z_u,
        init_lambda: true,
        lambda: &mut lambda,
    };
    assert!(tnlp.get_starting_point(warm_start));
    assert_eq!(
        lambda,
        vec![2.5],
        "a warm start must use the `.nl` initial duals, not zero"
    );

    // Duals not requested: a sentinel-filled buffer must come back unchanged.
    let mut lambda_untouched = vec![7.0; m];
    let primal_only = StartingPoint {
        init_x: true,
        x: &mut x,
        init_z: false,
        z_l: &mut z_l,
        z_u: &mut z_u,
        init_lambda: false,
        lambda: &mut lambda_untouched,
    };
    assert!(tnlp.get_starting_point(primal_only));
    assert_eq!(
        lambda_untouched,
        vec![7.0],
        "without init_lambda the multiplier buffer must be untouched"
    );
}
3092
#[test]
fn constrained_tnlp_eval_g_jac_h() {
    let mut tnlp = NlTnlp::new(parse_nl_text(EQ_LIN).expect("parse"));
    let info = tnlp.get_nlp_info().unwrap();
    let near = |actual: f64, expected: f64| (actual - expected).abs() < 1e-12;
    let point = [0.3, 0.4];

    // Constraint value: g(x) = x0 + x1.
    assert_eq!(info.m, 1);
    let mut g = [0.0_f64; 1];
    assert!(tnlp.eval_g(&point, true, &mut g));
    assert!(near(g[0], 0.7));

    // Jacobian: one row with entries in columns x0 and x1, both equal to 1.
    assert_eq!(info.nnz_jac_g, 2);
    let mut jac_rows = [0_i32; 2];
    let mut jac_cols = [0_i32; 2];
    assert!(tnlp.eval_jac_g(
        None,
        true,
        SparsityRequest::Structure {
            irow: &mut jac_rows,
            jcol: &mut jac_cols,
        }
    ));
    assert_eq!((jac_rows, jac_cols), ([0, 0], [0, 1]));

    let mut jac_vals = [0.0_f64; 2];
    assert!(tnlp.eval_jac_g(
        Some(&point),
        true,
        SparsityRequest::Values {
            values: &mut jac_vals
        }
    ));
    assert!(jac_vals.iter().all(|&v| near(v, 1.0)));

    // Lagrangian Hessian: the objective x0^2 + x1^2 gives a constant diagonal
    // of 2, and the linear row contributes nothing regardless of lambda.
    assert_eq!(info.nnz_h_lag, 2);
    let mut h_rows = [0_i32; 2];
    let mut h_cols = [0_i32; 2];
    assert!(tnlp.eval_h(
        None,
        true,
        1.0,
        None,
        true,
        SparsityRequest::Structure {
            irow: &mut h_rows,
            jcol: &mut h_cols,
        }
    ));
    assert_eq!((h_rows, h_cols), ([0, 1], [0, 1]));

    let mut h_vals = [0.0_f64; 2];
    assert!(tnlp.eval_h(
        Some(&point),
        true,
        1.0,
        Some(&[0.5]),
        true,
        SparsityRequest::Values {
            values: &mut h_vals
        }
    ));
    assert!(h_vals.iter().all(|&v| near(v, 2.0)));
}
3162
// Fixture: an objective built on a common subexpression (`V` segment).
//   V2 = x0 + x1    (`V2 0 0` followed by `o0 v0 v1`)
//   f  = V2^2 + V2  (`O0 0`: `o0` over `o5 v2 n2` and `v2`)
// At (1, 2) this gives f = 12, and both partial derivatives equal 7.
const CSE_OBJ: &str = "g3 0 1 0
2 0 1 0 0
0 1
0 0
0 2 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 1 0 0 0
V2 0 0
o0
v0
v1
O0 0
o0
o5
v2
n2
v2
b
3
3
";
3190
#[test]
fn parses_v_segment_cse() {
    // f = (x0 + x1)^2 + (x0 + x1): at (1, 2) the value is 3^2 + 3 = 12 and
    // each partial derivative is 2*3 + 1 = 7.
    let prob = parse_nl_text(CSE_OBJ).expect("parse");
    assert_eq!(prob.n, 2);
    let objective = &prob.obj_nonlinear;
    let x = [1.0, 2.0];

    let f = eval_expr(objective, &x);
    assert!((f - 12.0).abs() < 1e-12, "got {f}");

    let mut g = [0.0_f64; 2];
    grad_expr(objective, &x, 1.0, &mut g);
    assert!((g[0] - 7.0).abs() < 1e-12, "g[0]={}", g[0]);
    assert!((g[1] - 7.0).abs() < 1e-12, "g[1]={}", g[1]);

    // Expanding the CSE must expose the underlying variables, not the V index.
    let mut seen = BTreeSet::new();
    collect_vars(objective, &mut seen);
    let seen: Vec<_> = seen.into_iter().collect();
    assert_eq!(seen, vec![0, 1]);
}
3208
// Fixture: one variable, no constraints, and two variable suffixes:
//   `S0 1 sens_state_1`        -> integer variable suffix, value 7 on x0
//   `S4 1 sens_state_value_1`  -> real variable suffix, value 4.5 on x0
// The objective `o5(o1(v0, n1), n2)` is presumably (x0 - 1)^2, with `o1` the
// binary-minus opcode. The suffix test does not evaluate it.
const WITH_SUFFIXES: &str = "g3 0 1 0
1 0 1 0 0
0 1
0 0
0 1 0
0 0 0 1
0 0 0 0 0
0 0
0 0
0 0 0 0 0
O0 0
o5
o1
v0
n1
n2
b
3
S0 1 sens_state_1
0 7
S4 1 sens_state_value_1
0 4.5
";
3237
#[test]
fn parses_var_int_and_var_real_suffixes() {
    let parsed = parse_nl_text(WITH_SUFFIXES).expect("parse");
    let suffixes = &parsed.suffixes;

    let state = suffixes.var_int.get("sens_state_1").expect("var_int");
    assert_eq!(state.as_slice(), &[7]);

    let value = suffixes
        .var_real
        .get("sens_state_value_1")
        .expect("var_real");
    assert_eq!(value.len(), 1);
    assert!((value[0] - 4.5).abs() < 1e-12);

    // Only variable suffixes were declared, so the constraint maps stay empty.
    assert!(suffixes.con_int.is_empty());
    assert!(suffixes.con_real.is_empty());
}
3256
// Fixture: two variables and two linear equality rows with empty nonlinear
// parts:
//   row 0: x0 + x1 = 0  (`J0`: coefficients 1, 1; `r` entry `4 0.0`)
//   row 1: x0 - x1 = 0  (`J1`: coefficients 1, -1; `r` entry `4 0.0`)
// It has a constant objective (`O0 0` / `n0`) and an integer constraint suffix
// `S1 2 sens_init_constr` that assigns 1 to row 0 and 2 to row 1.
// NOTE(review): the header line `2 0` (Jacobian / gradient nonzeros in the NL
// convention) declares 2 Jacobian entries, but the J segments list four.
// The parser presumably does not cross-check this count; confirm before relying on it.
const WITH_CON_SUFFIX: &str = "g3 0 1 0
2 2 1 0 0
0 0
0 0
0 2 0
0 0 0 1
0 0 0 0
2 0
0 0
0 0 0 0 0 0
C0
n0
C1
n0
O0 0
n0
r
4 0.0
4 0.0
b
3
3
k1
0
J0 2
0 1
1 1
J1 2
0 1
1 -1
S1 2 sens_init_constr
0 1
1 2
";
3293
#[test]
fn parses_con_int_suffix() {
    let problem = parse_nl_text(WITH_CON_SUFFIX).expect("parse");
    let values = problem
        .suffixes
        .con_int
        .get("sens_init_constr")
        .expect("con_int");
    // One integer per constraint row, in row order.
    assert_eq!(values.as_slice(), &[1, 2]);
}
3301
#[test]
fn rejects_suffix_with_out_of_range_index() {
    // Rewrite only the second entry of the `sens_init_constr` suffix so that it
    // names constraint 5 (the problem has 2).
    //
    // `str::replace` substitutes every match. A bare `replace("1 2\n", ..)`
    // would also turn the `J1 2` Jacobian header into `J5 2`. The J-segment
    // range check would then raise the "out of range" error, and the test
    // would pass even if suffix indices were never validated.
    let good_suffix = "S1 2 sens_init_constr\n0 1\n1 2\n";
    let bad_suffix = "S1 2 sens_init_constr\n0 1\n5 2\n";
    let bad = WITH_CON_SUFFIX.replace(good_suffix, bad_suffix);
    assert_ne!(bad, WITH_CON_SUFFIX, "fixture substitution must apply");
    assert!(
        bad.contains("J1 2\n"),
        "the Jacobian segment must stay intact so only the suffix is invalid"
    );
    let err = parse_nl_text(&bad).expect_err("must reject");
    assert!(
        err.contains("out of range"),
        "expected out-of-range error, got: {err}"
    );
}
3311
#[test]
fn tnlp_round_trip_solves() {
    // Drive SIMPLE through the TNLP interface: problem sizes, then the
    // objective value (5) and gradient (-2, -4) at the origin.
    let p = parse_nl_text(SIMPLE).expect("parse");
    let mut tnlp = NlTnlp::new(p);
    let info = tnlp.get_nlp_info().unwrap();
    assert_eq!(info.n, 2);
    assert_eq!(info.m, 0);
    let f0 = tnlp.eval_f(&[0.0, 0.0], true).unwrap();
    assert!((f0 - 5.0).abs() < 1e-12);
    let mut g = [0.0_f64; 2];
    // Assert the success flag instead of discarding it, as this module
    // already does for eval_g / eval_jac_g / eval_h.
    assert!(tnlp.eval_grad_f(&[0.0, 0.0], true, &mut g));
    assert!((g[0] + 2.0).abs() < 1e-12);
    assert!((g[1] + 4.0).abs() < 1e-12);
}
3327
3328 use pounce_nlp::expression_provider::ExpressionProvider;
3335 use std::sync::atomic::{AtomicUsize, Ordering};
3336
/// Create a fresh, empty scratch directory under the system temp dir.
///
/// The path is `pounce_nlnames_<pid>_<tag>_<seq>`. `seq` comes from a
/// process-wide counter, so every call within one test binary gets its own
/// directory even when tests run in parallel. A directory already at that
/// path (for example, one left by an earlier process that reused this pid)
/// is removed first, so tests never see stale name files from a previous run.
///
/// # Panics
/// Panics if the directory cannot be created.
fn scratch_dir(tag: &str) -> std::path::PathBuf {
    static N: AtomicUsize = AtomicUsize::new(0);
    let seq = N.fetch_add(1, Ordering::Relaxed);
    let dir = std::env::temp_dir().join(format!(
        "pounce_nlnames_{}_{}_{}",
        std::process::id(),
        tag,
        seq
    ));
    // Ignore the result: normally the directory does not exist yet. If removal
    // fails for another reason, the directory is reused as before.
    let _ = std::fs::remove_dir_all(&dir);
    std::fs::create_dir_all(&dir).expect("create scratch dir");
    dir
}
3350
#[test]
fn read_name_file_reads_in_order() {
    // Names come back one per line, in file order.
    let col_file = scratch_dir("col_order").join("m.col");
    std::fs::write(&col_file, "x_in\nT_reactor\nflow\n").unwrap();
    let got = read_name_file(&col_file, 3);
    assert_eq!(got, vec!["x_in", "T_reactor", "flow"]);
}
3358
#[test]
fn read_name_file_truncates_extra_lines() {
    // The file holds one more line than requested (a trailing `obj` entry);
    // only the first 2 names are returned.
    let row_file = scratch_dir("row_obj").join("m.row");
    std::fs::write(&row_file, "mass_balance\nenergy_balance\nobj\n").unwrap();
    let kept = read_name_file(&row_file, 2);
    assert_eq!(kept, vec!["mass_balance", "energy_balance"]);
}
3372
#[test]
fn read_name_file_empty_on_short_or_missing() {
    let dir = scratch_dir("short");

    // Fewer names than requested: the result is empty, not padded or partial.
    let short = dir.join("m.col");
    std::fs::write(&short, "only_one\n").unwrap();
    assert!(read_name_file(&short, 3).is_empty());

    // A file that does not exist also yields no names.
    let missing = dir.join("absent.col");
    assert!(read_name_file(&missing, 2).is_empty());
}
3383
#[test]
fn read_nl_file_captures_sibling_names() {
    // m.nl with a sibling m.col (variable names) and no m.row.
    let dir = scratch_dir("sibling");
    let nl_path = dir.join("m.nl");
    std::fs::write(&nl_path, SIMPLE).unwrap();
    std::fs::write(dir.join("m.col"), "alpha\nbeta\n").unwrap();

    let problem = read_nl_file(&nl_path).expect("parse + name capture");
    assert_eq!(problem.var_names, vec!["alpha", "beta"]);
    assert!(problem.con_names.is_empty());

    // The names are reachable through the TNLP, and out-of-range indices give None.
    let tnlp = NlTnlp::new(problem);
    for (index, expected) in [(0, Some("alpha")), (1, Some("beta")), (2, None)] {
        assert_eq!(tnlp.variable_name(index), expected);
    }
}
3402
#[test]
fn read_nl_file_without_names_yields_empty() {
    // No .col/.row siblings: both name tables stay empty and lookups give None.
    let nl_path = scratch_dir("noname").join("m.nl");
    std::fs::write(&nl_path, SIMPLE).unwrap();
    let problem = read_nl_file(&nl_path).expect("parse");
    assert!(problem.var_names.is_empty());
    assert!(problem.con_names.is_empty());
    assert_eq!(NlTnlp::new(problem).variable_name(0), None);
}
3414
#[test]
fn read_nl_file_resolves_extensionless_ampl_stub() {
    // The caller may pass an AMPL-style stub without `.nl`; the reader must
    // fall back to `<stub>.nl`.
    let dir = scratch_dir("stub");
    let stub = dir.join("mystub");
    std::fs::write(dir.join("mystub.nl"), SIMPLE).unwrap();
    assert!(!stub.exists(), "stub must be extensionless / absent");

    let first = read_nl_file(&stub).expect("stub should resolve to mystub.nl");
    assert_eq!((first.n, first.m), (2, 0));

    // Sibling name files are found next to the resolved `.nl` as well.
    std::fs::write(dir.join("mystub.col"), "alpha\nbeta\n").unwrap();
    let second = read_nl_file(&stub).expect("stub resolves, names ride along");
    assert_eq!(second.var_names, vec!["alpha", "beta"]);
}
3434
#[test]
fn read_nl_file_prefers_exact_path_over_nl_sibling() {
    let dir = scratch_dir("exact");
    let exact = dir.join("data");
    std::fs::write(&exact, SIMPLE).unwrap();
    // Decoy sibling with invalid contents; it must not be picked over the exact path.
    std::fs::write(dir.join("data.nl"), "not an nl file").unwrap();
    let problem = read_nl_file(&exact).expect("exact path wins");
    assert_eq!(problem.n, 2);
}
3447
#[test]
fn append_extension_appends_rather_than_replaces() {
    use std::path::Path;
    // An existing dot (as in `my.model`) is kept rather than treated as an
    // extension to swap out.
    let cases = [("mystub", "mystub.nl"), ("my.model", "my.model.nl")];
    for (input, expected) in cases {
        assert_eq!(append_extension(Path::new(input), "nl"), Path::new(expected));
    }
}
3462
/// Build an owned name table (as stored in `var_names` / `con_names`) from
/// string literals.
fn names(v: &[&str]) -> Vec<String> {
    v.iter().copied().map(String::from).collect()
}
3468
#[test]
fn render_uses_variable_names_when_present() {
    let product = Expr::Binary(BinOp::Mul, Box::new(Expr::Var(0)), Box::new(Expr::Var(1)));
    // With a name table the variables print by name ...
    let labelled = render_expr(&product, &names(&["T", "flow"]), &[]);
    assert_eq!(labelled, "T*flow");
    // ... and without one they fall back to `x[i]`.
    let anonymous = render_expr(&product, &[], &[]);
    assert_eq!(anonymous, "x[0]*x[1]");
}
3476
#[test]
fn render_parenthesizes_by_precedence() {
    let var = |i: usize| Box::new(Expr::Var(i));

    // (x0 + x1) * x2: the lower-precedence sum needs parentheses.
    let sum_times = Expr::Binary(
        BinOp::Mul,
        Box::new(Expr::Binary(BinOp::Add, var(0), var(1))),
        var(2),
    );
    assert_eq!(render_expr(&sum_times, &[], &[]), "(x[0] + x[1])*x[2]");

    // x0 + x1 * x2: the product binds tighter, so no parentheses.
    let plus_product = Expr::Binary(
        BinOp::Add,
        var(0),
        Box::new(Expr::Binary(BinOp::Mul, var(1), var(2))),
    );
    assert_eq!(render_expr(&plus_product, &[], &[]), "x[0] + x[1]*x[2]");
}
3489
#[test]
fn render_subtraction_right_assoc_parens() {
    // x0 - (x1 - x2): dropping the parentheses would change the value, so a
    // subtraction on the right of `-` keeps them even at equal precedence.
    let right = Expr::Binary(BinOp::Sub, Box::new(Expr::Var(1)), Box::new(Expr::Var(2)));
    let nested = Expr::Binary(BinOp::Sub, Box::new(Expr::Var(0)), Box::new(right));
    assert_eq!(render_expr(&nested, &[], &[]), "x[0] - (x[1] - x[2])");
}
3497
#[test]
fn render_functions_and_pow() {
    // exp(q^2): unary functions print call-style, powers with `^`.
    let square = Expr::Binary(
        BinOp::Pow,
        Box::new(Expr::Var(0)),
        Box::new(Expr::Const(2.0)),
    );
    let exp_of_square = Expr::Unary(UnaryOp::Exp, Box::new(square));
    let rendered = render_expr(&exp_of_square, &names(&["q"]), &[]);
    assert_eq!(rendered, "exp(q^2)");
}
3508
#[test]
fn render_linear_signs_are_tidy() {
    // Unit coefficients print bare; a negative coefficient becomes ` - 2*b`.
    let terms = vec![(0usize, 1.0), (1, -2.0), (2, 1.0)];
    let labels = names(&["a", "b", "c"]);
    let rendered = render_linear(&terms, &labels);
    assert_eq!(rendered, "a - 2*b + c");
}
3515
#[test]
fn render_linear_skips_zero_coefficients() {
    // A zero coefficient in the middle is dropped without leaving a stray sign.
    let abc = names(&["a", "b", "c"]);
    let middle_zero = vec![(0usize, 1.0), (1, 0.0), (2, -3.0)];
    assert_eq!(render_linear(&middle_zero, &abc), "a - 3*c");

    // A leading zero is dropped too, so the output starts at the first live term.
    let ab = names(&["a", "b"]);
    let leading_zero = vec![(0usize, 0.0), (1, 2.0)];
    assert_eq!(render_linear(&leading_zero, &ab), "2*b");
}
3526
#[test]
fn render_sum_folds_negative_terms() {
    // Sum of a^2, (-1)*b and -c. Negative summands print as subtractions; the
    // explicit -1 factor on b still shows as `1*b`.
    let square = |i: usize| {
        Expr::Binary(
            BinOp::Pow,
            Box::new(Expr::Var(i)),
            Box::new(Expr::Const(2.0)),
        )
    };
    let minus_one_times = |i: usize| {
        Expr::Binary(
            BinOp::Mul,
            Box::new(Expr::Const(-1.0)),
            Box::new(Expr::Var(i)),
        )
    };
    let negated = Expr::Unary(UnaryOp::Neg, Box::new(Expr::Var(2)));
    let total = Expr::Sum(vec![square(0), minus_one_times(1), negated]);

    let rendered = render_expr(&total, &names(&["a", "b", "c"]), &[]);
    assert_eq!(rendered, "a^2 - 1*b - c");
}
3554
#[test]
fn render_constraint_equation_forms() {
    // Reuse SIMPLE's problem shell, then install two purely linear rows by hand.
    let mut prob = parse_nl_text(SIMPLE).unwrap();
    prob.n = 2;
    prob.m = 2;
    prob.var_names = names(&["mass_in", "mass_out"]);
    prob.con_names = names(&["balance", "window"]);
    // balance: mass_in - mass_out with g_l = g_u = 0 (an equality).
    // window:  mass_in with bounds [0, 500] (a two-sided range).
    prob.con_linear = vec![vec![(0, 1.0), (1, -1.0)], vec![(0, 1.0)]];
    prob.con_nonlinear = vec![Expr::Const(0.0); 2];
    prob.g_l = vec![0.0; 2];
    prob.g_u = vec![0.0, 500.0];

    assert_eq!(
        render_constraint_equation(&prob, 0),
        "mass_in - mass_out = 0"
    );
    let window = "0 <= mass_in <= 500";
    assert_eq!(render_constraint_equation(&prob, 1), window);

    let all = render_all_constraint_equations(&prob);
    assert_eq!(all.len(), 2);
    assert_eq!(all[1], window);
}
3582
#[test]
fn constraint_jacobian_sparsity_unions_linear_and_nonlinear() {
    let mut prob = parse_nl_text(SIMPLE).unwrap();
    prob.n = 3;
    prob.m = 2;
    // Row 0: linear in x1, nonlinear x0*x2  -> columns 0, 1, 2 (ascending).
    // Row 1: linear in x2, constant nonlinear part -> column 2 only.
    prob.con_linear = vec![vec![(1, 4.0)], vec![(2, 1.0)]];
    prob.con_nonlinear = vec![
        Expr::Binary(BinOp::Mul, Box::new(Expr::Var(0)), Box::new(Expr::Var(2))),
        Expr::Const(0.0),
    ];
    prob.g_l = vec![0.0; 2];
    prob.g_u = vec![0.0; 2];

    let (rows, cols) = constraint_jacobian_sparsity(&prob);
    assert_eq!(rows, vec![0, 0, 0, 1]);
    assert_eq!(cols, vec![0, 1, 2, 2]);
}
3603
#[test]
fn funcall_string_arg_with_hash_is_not_truncated() {
    // `h3:` declares a Hollerith string of length 3, so the `#` inside it is
    // data, not the start of a comment.
    let mut parser = Parser::new("h3:a#b\n");
    let arg = parser.parse_funcall_arg().expect("parse hollerith arg");
    match arg {
        FuncallArg::Str(s) => assert_eq!(s, "a#b"),
        other => panic!("expected Str, got {other:?}"),
    }
}
3617
#[test]
fn funcall_string_arg_honors_declared_length() {
    // Only the 3 declared characters form the string; the trailing
    // ` # trailing comment` is not part of it.
    let mut parser = Parser::new("h3:abc # trailing comment\n");
    let arg = parser.parse_funcall_arg().expect("parse hollerith arg");
    match arg {
        FuncallArg::Str(s) => assert_eq!(s, "abc"),
        other => panic!("expected Str, got {other:?}"),
    }
}
3629}