1use super::LanguageParser;
2use cnvx_core::{LinExpr, Model, Objective, VarId};
3use std::collections::HashMap;
4
/// Parser for optimisation models written in the MPS format.
///
/// The type carries no state; construct it with [`MPSLanguage::new`] or
/// [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct MPSLanguage;
7
8impl MPSLanguage {
9 pub fn new() -> Self {
10 Self {}
11 }
12}
13
14impl LanguageParser for MPSLanguage {
15 fn parse(&self, src: &str) -> Result<Model, String> {
16 let mut model = Model::new();
17 let mut section = "";
18
19 let mut rows: HashMap<String, char> = HashMap::new();
20 let mut col_exprs: HashMap<String, LinExpr> = HashMap::new();
21 let mut rhs_map: HashMap<String, f64> = HashMap::new();
22 let mut var_map: HashMap<String, VarId> = HashMap::new();
23
24 for raw in src.lines() {
25 let line = raw.trim();
26 if line.is_empty() {
27 continue;
28 }
29 if line.eq_ignore_ascii_case("ROWS") {
30 section = "ROWS";
31 continue;
32 } else if line.eq_ignore_ascii_case("COLUMNS") {
33 section = "COLUMNS";
34 continue;
35 } else if line.eq_ignore_ascii_case("RHS") {
36 section = "RHS";
37 continue;
38 } else if line.eq_ignore_ascii_case("BOUNDS") {
39 section = "BOUNDS";
40 continue;
41 } else if line.eq_ignore_ascii_case("ENDATA") {
42 break;
43 }
44
45 match section {
46 "ROWS" => {
47 let parts: Vec<_> = line.split_whitespace().collect();
48 if parts.len() >= 2 {
49 if parts[0].ends_with('.') && parts.len() >= 3 {
50 let rtype = parts[1].chars().next().unwrap_or(' ');
51 let name = parts[2].to_string();
52 rows.insert(name, rtype);
53 } else {
54 let rtype = parts[0].chars().next().unwrap_or(' ');
55 let name = parts[1].to_string();
56 rows.insert(name, rtype);
57 }
58 }
59 }
60 "COLUMNS" => {
61 let parts: Vec<_> = line.split_whitespace().collect();
62 if parts.len() < 2 {
63 continue;
64 }
65 let mut idx = 0;
66 if parts[0].ends_with('.') && parts.len() >= 2 {
67 idx = 1;
68 }
69 if parts.len() <= idx {
70 continue;
71 }
72 let col = parts[idx].to_string();
73 let varid = *var_map
74 .entry(col.clone())
75 .or_insert_with(|| model.add_var().finish());
76 let mut i = idx + 1;
77 while i + 1 < parts.len() {
78 let row = parts[i].to_string();
79 let val = parts[i + 1].parse::<f64>().map_err(|_| {
80 format!("invalid number in COLUMNS: {}", parts[i + 1])
81 })?;
82 let entry = col_exprs
83 .entry(row.clone())
84 .or_insert(LinExpr::constant(0.0));
85 *entry += LinExpr::new(varid, val);
86 i += 2;
87 }
88 }
89 "RHS" => {
90 let parts: Vec<_> = line.split_whitespace().collect();
91 if parts.len() < 3 {
92 continue;
93 }
94 let mut idx = 0;
95 if parts[0].ends_with('.') && parts.len() >= 2 {
96 idx = 1;
97 }
98 let mut i = idx + 1;
99 while i + 1 < parts.len() {
100 let row = parts[i].to_string();
101 let val = parts[i + 1].parse::<f64>().map_err(|_| {
102 format!("invalid number in RHS: {}", parts[i + 1])
103 })?;
104 rhs_map.insert(row, val);
105 i += 2;
106 }
107 }
108 "BOUNDS" => {
109 let parts: Vec<_> = line.split_whitespace().collect();
110 if parts.len() < 3 {
111 continue;
112 }
113 let mut idx = 0;
114 if parts[0].ends_with('.') && parts.len() >= 2 {
115 idx = 1;
116 }
117 let btype = parts[idx];
118 if parts.len() <= idx + 2 {
119 continue;
120 }
121 let varname = parts[idx + 2].to_string();
122 let varid = *var_map
123 .entry(varname.clone())
124 .or_insert_with(|| model.add_var().finish());
125 match btype {
126 "UP" => {
127 if parts.len() >= idx + 4
128 && let Ok(v) = parts[idx + 3].parse::<f64>()
129 {
130 model.vars[varid.0].ub = Some(v);
131 }
132 }
133 "LO" => {
134 if parts.len() >= idx + 4
135 && let Ok(v) = parts[idx + 3].parse::<f64>()
136 {
137 model.vars[varid.0].lb = Some(v);
138 }
139 }
140 "FR" => {
141 model.vars[varid.0].lb = None;
142 model.vars[varid.0].ub = None;
143 }
144 "MI" => {
145 model.vars[varid.0].lb = None;
146 }
147 "BV" => {
148 model.vars[varid.0].is_integer = true;
149 model.vars[varid.0].lb = Some(0.0);
150 model.vars[varid.0].ub = Some(1.0);
151 }
152 "FX" => {
153 if parts.len() >= idx + 4
154 && let Ok(v) = parts[idx + 3].parse::<f64>()
155 {
156 model.vars[varid.0].lb = Some(v);
157 model.vars[varid.0].ub = Some(v);
158 }
159 }
160 _ => {}
161 }
162 }
163 _ => {}
164 }
165 }
166
167 for (rname, rtype) in &rows {
168 let expr = col_exprs.get(rname).cloned().unwrap_or(LinExpr::constant(0.0));
169 let rhs = *rhs_map.get(rname).unwrap_or(&0.0);
170 match *rtype {
171 'N' => {
172 model.add_objective(Objective::minimize(expr).name("Z"));
173 }
174 'L' => {
175 model += expr.leq(rhs);
176 }
177 'G' => {
178 model += expr.geq(rhs);
179 }
180 'E' => {
181 model += expr.eq(rhs);
182 }
183 _ => {}
184 }
185 }
186
187 Ok(model)
188 }
189}