1use nalgebra::{DMatrix, DVector};
4use runmat_builtins::{
5 BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6 BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7 StructValue, Tensor, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::spec::{
12 BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13 ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::math::optim::type_resolvers::linear_programming_type;
16use crate::{build_runtime_error, BuiltinResult, RuntimeError};
17
18const NAME: &str = "linprog";
19const ALGORITHM: &str = "active-set vertex enumeration";
20const TOL: f64 = 1.0e-8;
21
22const LINPROG_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
23 name: "x",
24 ty: BuiltinParamType::NumericArray,
25 arity: BuiltinParamArity::Required,
26 default: None,
27 description: "Optimal decision vector.",
28}];
29
30const LINPROG_OUTPUT_X_FVAL: [BuiltinParamDescriptor; 2] = [
31 BuiltinParamDescriptor {
32 name: "x",
33 ty: BuiltinParamType::NumericArray,
34 arity: BuiltinParamArity::Required,
35 default: None,
36 description: "Optimal decision vector.",
37 },
38 BuiltinParamDescriptor {
39 name: "fval",
40 ty: BuiltinParamType::NumericScalar,
41 arity: BuiltinParamArity::Required,
42 default: None,
43 description: "Objective value f'*x at the solution.",
44 },
45];
46
47const LINPROG_OUTPUT_X_FVAL_EXITFLAG: [BuiltinParamDescriptor; 3] = [
48 BuiltinParamDescriptor {
49 name: "x",
50 ty: BuiltinParamType::NumericArray,
51 arity: BuiltinParamArity::Required,
52 default: None,
53 description: "Optimal decision vector.",
54 },
55 BuiltinParamDescriptor {
56 name: "fval",
57 ty: BuiltinParamType::NumericScalar,
58 arity: BuiltinParamArity::Required,
59 default: None,
60 description: "Objective value f'*x at the solution.",
61 },
62 BuiltinParamDescriptor {
63 name: "exitflag",
64 ty: BuiltinParamType::NumericScalar,
65 arity: BuiltinParamArity::Required,
66 default: None,
67 description: "Solver status code.",
68 },
69];
70
71const LINPROG_OUTPUT_ALL: [BuiltinParamDescriptor; 4] = [
72 BuiltinParamDescriptor {
73 name: "x",
74 ty: BuiltinParamType::NumericArray,
75 arity: BuiltinParamArity::Required,
76 default: None,
77 description: "Optimal decision vector.",
78 },
79 BuiltinParamDescriptor {
80 name: "fval",
81 ty: BuiltinParamType::NumericScalar,
82 arity: BuiltinParamArity::Required,
83 default: None,
84 description: "Objective value f'*x at the solution.",
85 },
86 BuiltinParamDescriptor {
87 name: "exitflag",
88 ty: BuiltinParamType::NumericScalar,
89 arity: BuiltinParamArity::Required,
90 default: None,
91 description: "Solver status code.",
92 },
93 BuiltinParamDescriptor {
94 name: "output",
95 ty: BuiltinParamType::Any,
96 arity: BuiltinParamArity::Required,
97 default: None,
98 description: "Diagnostic metadata struct.",
99 },
100];
101
102const LINPROG_INPUTS_CORE: [BuiltinParamDescriptor; 3] = [
103 BuiltinParamDescriptor {
104 name: "f",
105 ty: BuiltinParamType::NumericArray,
106 arity: BuiltinParamArity::Required,
107 default: None,
108 description: "Linear objective vector.",
109 },
110 BuiltinParamDescriptor {
111 name: "A",
112 ty: BuiltinParamType::NumericArray,
113 arity: BuiltinParamArity::Required,
114 default: None,
115 description: "Inequality constraint matrix.",
116 },
117 BuiltinParamDescriptor {
118 name: "b",
119 ty: BuiltinParamType::NumericArray,
120 arity: BuiltinParamArity::Required,
121 default: None,
122 description: "Inequality constraint right-hand side.",
123 },
124];
125
126const LINPROG_INPUTS_EQ: [BuiltinParamDescriptor; 5] = [
127 BuiltinParamDescriptor {
128 name: "f",
129 ty: BuiltinParamType::NumericArray,
130 arity: BuiltinParamArity::Required,
131 default: None,
132 description: "Linear objective vector.",
133 },
134 BuiltinParamDescriptor {
135 name: "A",
136 ty: BuiltinParamType::NumericArray,
137 arity: BuiltinParamArity::Required,
138 default: None,
139 description: "Inequality constraint matrix.",
140 },
141 BuiltinParamDescriptor {
142 name: "b",
143 ty: BuiltinParamType::NumericArray,
144 arity: BuiltinParamArity::Required,
145 default: None,
146 description: "Inequality constraint right-hand side.",
147 },
148 BuiltinParamDescriptor {
149 name: "Aeq",
150 ty: BuiltinParamType::NumericArray,
151 arity: BuiltinParamArity::Optional,
152 default: Some("[]"),
153 description: "Equality constraint matrix.",
154 },
155 BuiltinParamDescriptor {
156 name: "beq",
157 ty: BuiltinParamType::NumericArray,
158 arity: BuiltinParamArity::Optional,
159 default: Some("[]"),
160 description: "Equality constraint right-hand side.",
161 },
162];
163
164const LINPROG_INPUTS_BOUNDS: [BuiltinParamDescriptor; 7] = [
165 BuiltinParamDescriptor {
166 name: "f",
167 ty: BuiltinParamType::NumericArray,
168 arity: BuiltinParamArity::Required,
169 default: None,
170 description: "Linear objective vector.",
171 },
172 BuiltinParamDescriptor {
173 name: "A",
174 ty: BuiltinParamType::NumericArray,
175 arity: BuiltinParamArity::Required,
176 default: None,
177 description: "Inequality constraint matrix.",
178 },
179 BuiltinParamDescriptor {
180 name: "b",
181 ty: BuiltinParamType::NumericArray,
182 arity: BuiltinParamArity::Required,
183 default: None,
184 description: "Inequality constraint right-hand side.",
185 },
186 BuiltinParamDescriptor {
187 name: "Aeq",
188 ty: BuiltinParamType::NumericArray,
189 arity: BuiltinParamArity::Optional,
190 default: Some("[]"),
191 description: "Equality constraint matrix.",
192 },
193 BuiltinParamDescriptor {
194 name: "beq",
195 ty: BuiltinParamType::NumericArray,
196 arity: BuiltinParamArity::Optional,
197 default: Some("[]"),
198 description: "Equality constraint right-hand side.",
199 },
200 BuiltinParamDescriptor {
201 name: "lb",
202 ty: BuiltinParamType::NumericArray,
203 arity: BuiltinParamArity::Optional,
204 default: Some("[]"),
205 description: "Lower bounds.",
206 },
207 BuiltinParamDescriptor {
208 name: "ub",
209 ty: BuiltinParamType::NumericArray,
210 arity: BuiltinParamArity::Optional,
211 default: Some("[]"),
212 description: "Upper bounds.",
213 },
214];
215
216const LINPROG_SIGNATURES: [BuiltinSignatureDescriptor; 12] = [
217 BuiltinSignatureDescriptor {
218 label: "x = linprog(f, A, b)",
219 inputs: &LINPROG_INPUTS_CORE,
220 outputs: &LINPROG_OUTPUT_X,
221 },
222 BuiltinSignatureDescriptor {
223 label: "x = linprog(f, A, b, Aeq, beq)",
224 inputs: &LINPROG_INPUTS_EQ,
225 outputs: &LINPROG_OUTPUT_X,
226 },
227 BuiltinSignatureDescriptor {
228 label: "x = linprog(f, A, b, Aeq, beq, lb, ub)",
229 inputs: &LINPROG_INPUTS_BOUNDS,
230 outputs: &LINPROG_OUTPUT_X,
231 },
232 BuiltinSignatureDescriptor {
233 label: "[x, fval] = linprog(f, A, b)",
234 inputs: &LINPROG_INPUTS_CORE,
235 outputs: &LINPROG_OUTPUT_X_FVAL,
236 },
237 BuiltinSignatureDescriptor {
238 label: "[x, fval] = linprog(f, A, b, Aeq, beq)",
239 inputs: &LINPROG_INPUTS_EQ,
240 outputs: &LINPROG_OUTPUT_X_FVAL,
241 },
242 BuiltinSignatureDescriptor {
243 label: "[x, fval] = linprog(f, A, b, Aeq, beq, lb, ub)",
244 inputs: &LINPROG_INPUTS_BOUNDS,
245 outputs: &LINPROG_OUTPUT_X_FVAL,
246 },
247 BuiltinSignatureDescriptor {
248 label: "[x, fval, exitflag] = linprog(f, A, b)",
249 inputs: &LINPROG_INPUTS_CORE,
250 outputs: &LINPROG_OUTPUT_X_FVAL_EXITFLAG,
251 },
252 BuiltinSignatureDescriptor {
253 label: "[x, fval, exitflag] = linprog(f, A, b, Aeq, beq)",
254 inputs: &LINPROG_INPUTS_EQ,
255 outputs: &LINPROG_OUTPUT_X_FVAL_EXITFLAG,
256 },
257 BuiltinSignatureDescriptor {
258 label: "[x, fval, exitflag] = linprog(f, A, b, Aeq, beq, lb, ub)",
259 inputs: &LINPROG_INPUTS_BOUNDS,
260 outputs: &LINPROG_OUTPUT_X_FVAL_EXITFLAG,
261 },
262 BuiltinSignatureDescriptor {
263 label: "[x, fval, exitflag, output] = linprog(f, A, b)",
264 inputs: &LINPROG_INPUTS_CORE,
265 outputs: &LINPROG_OUTPUT_ALL,
266 },
267 BuiltinSignatureDescriptor {
268 label: "[x, fval, exitflag, output] = linprog(f, A, b, Aeq, beq)",
269 inputs: &LINPROG_INPUTS_EQ,
270 outputs: &LINPROG_OUTPUT_ALL,
271 },
272 BuiltinSignatureDescriptor {
273 label: "[x, fval, exitflag, output] = linprog(f, A, b, Aeq, beq, lb, ub)",
274 inputs: &LINPROG_INPUTS_BOUNDS,
275 outputs: &LINPROG_OUTPUT_ALL,
276 },
277];
278
279const LINPROG_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
280 code: "RM.LINPROG.INVALID_ARGUMENT",
281 identifier: Some("RunMat:linprog:InvalidArgument"),
282 when: "The argument count or optional argument grammar is invalid.",
283 message: "linprog: invalid argument",
284};
285
286const LINPROG_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
287 code: "RM.LINPROG.INVALID_INPUT",
288 identifier: Some("RunMat:linprog:InvalidInput"),
289 when: "Objective, constraint, or bound dimensions/types are invalid.",
290 message: "linprog: invalid input",
291};
292
293const LINPROG_ERRORS: [BuiltinErrorDescriptor; 2] =
294 [LINPROG_ERROR_INVALID_ARGUMENT, LINPROG_ERROR_INVALID_INPUT];
295
296pub const LINPROG_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
297 signatures: &LINPROG_SIGNATURES,
298 output_mode: BuiltinOutputMode::ByRequestedOutputCount,
299 completion_policy: BuiltinCompletionPolicy::Public,
300 errors: &LINPROG_ERRORS,
301};
302
303fn linprog_error_with_detail(
304 error: &'static BuiltinErrorDescriptor,
305 detail: impl AsRef<str>,
306) -> RuntimeError {
307 let detail = detail.as_ref();
308 let message = if detail.starts_with("linprog:") {
309 detail.to_string()
310 } else {
311 format!("{}: {detail}", error.message)
312 };
313 let mut builder = build_runtime_error(message).with_builtin(NAME);
314 if let Some(identifier) = error.identifier {
315 builder = builder.with_identifier(identifier);
316 }
317 builder.build()
318}
319
320#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::linprog")]
321pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
322 name: "linprog",
323 op_kind: GpuOpKind::Custom("linear-programming"),
324 supported_precisions: &[],
325 broadcast: BroadcastSemantics::None,
326 provider_hooks: &[],
327 constant_strategy: ConstantStrategy::InlineLiteral,
328 residency: ResidencyPolicy::GatherImmediately,
329 nan_mode: ReductionNaN::Include,
330 two_pass_threshold: None,
331 workgroup_size: None,
332 accepts_nan_mode: false,
333 notes: "Host active-set LP solver. GPU-resident numeric inputs are gathered before solving.",
334};
335
336#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::linprog")]
337pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
338 name: "linprog",
339 shape: ShapeRequirements::Any,
340 constant_strategy: ConstantStrategy::InlineLiteral,
341 elementwise: None,
342 reduction: None,
343 emits_nan: false,
344 notes: "Linear programming is a solver boundary and terminates fusion planning.",
345};
346
347#[runtime_builtin(
348 name = "linprog",
349 category = "math/optim",
350 summary = "Solve a linear programming minimization problem with linear constraints and bounds.",
351 keywords = "linprog,linear programming,optimization,linear constraints,bounds",
352 accel = "sink",
353 type_resolver(linear_programming_type),
354 descriptor(crate::builtins::math::optim::linprog::LINPROG_DESCRIPTOR),
355 builtin_path = "crate::builtins::math::optim::linprog"
356)]
357async fn linprog_builtin(f: Value, a: Value, b: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
358 if rest.len() > 4 {
359 return Err(linprog_error_with_detail(
360 &LINPROG_ERROR_INVALID_ARGUMENT,
361 "too many input arguments",
362 ));
363 }
364
365 let f = numeric_vector("f", f, FiniteMode::Finite).await?;
366 if f.is_empty() {
367 return Err(linprog_error_with_detail(
368 &LINPROG_ERROR_INVALID_INPUT,
369 "f must be a nonempty numeric vector",
370 ));
371 }
372 let n = f.len();
373
374 let (mut a_ineq, mut b_ineq) = parse_constraint_pair("A", a, "b", b, n).await?;
375 let (a_eq, b_eq) = parse_optional_equality(rest.first(), rest.get(1), n).await?;
376 let (lb, ub) = parse_bounds(rest.get(2), rest.get(3), n).await?;
377
378 for i in 0..n {
379 if lb[i] > ub[i] + TOL {
380 return Ok(finalize(LinprogOutcome::infeasible(
381 "No feasible point found: lower bound exceeds upper bound.",
382 )));
383 }
384 if lb[i].is_finite() {
385 let mut row = vec![0.0; n];
386 row[i] = -1.0;
387 a_ineq.push(row);
388 b_ineq.push(-lb[i]);
389 }
390 if ub[i].is_finite() {
391 let mut row = vec![0.0; n];
392 row[i] = 1.0;
393 a_ineq.push(row);
394 b_ineq.push(ub[i]);
395 }
396 }
397
398 let problem = LinearProgram {
399 f,
400 a_ineq,
401 b_ineq,
402 a_eq,
403 b_eq,
404 };
405 Ok(finalize(solve_linprog(&problem)))
406}
407
408#[derive(Clone, Copy)]
409enum FiniteMode {
410 Finite,
411 Bounds,
412}
413
414#[derive(Clone)]
415struct MatrixInput {
416 rows: usize,
417 cols: usize,
418 data: Vec<f64>,
419}
420
421impl MatrixInput {
422 fn row(&self, row: usize) -> Vec<f64> {
423 (0..self.cols)
424 .map(|col| self.data[row + col * self.rows])
425 .collect()
426 }
427}
428
429async fn gather(value: Value) -> BuiltinResult<Value> {
430 crate::dispatcher::gather_if_needed_async(&value)
431 .await
432 .map_err(|err| linprog_error_with_detail(&LINPROG_ERROR_INVALID_INPUT, err.message()))
433}
434
435fn is_empty_value(value: &Value) -> bool {
436 matches!(value, Value::Tensor(t) if t.data.is_empty())
437}
438
439async fn numeric_vector(
440 label: &str,
441 value: Value,
442 finite_mode: FiniteMode,
443) -> BuiltinResult<Vec<f64>> {
444 let value = gather(value).await?;
445 if is_empty_value(&value) {
446 return Ok(Vec::new());
447 }
448 let data = match value {
449 Value::Num(n) => vec![n],
450 Value::Int(i) => vec![i.to_f64()],
451 Value::Tensor(t) => {
452 let dims = t.shape.len();
453 if dims > 2 || (t.rows() != 1 && t.cols() != 1) {
454 return Err(linprog_error_with_detail(
455 &LINPROG_ERROR_INVALID_INPUT,
456 format!("{label} must be a vector"),
457 ));
458 }
459 t.data
460 }
461 other => {
462 return Err(linprog_error_with_detail(
463 &LINPROG_ERROR_INVALID_INPUT,
464 format!("{label} must be a real numeric vector, got {other:?}"),
465 ))
466 }
467 };
468 validate_numbers(label, &data, finite_mode)?;
469 Ok(data)
470}
471
472async fn numeric_matrix(label: &str, value: Value) -> BuiltinResult<Option<MatrixInput>> {
473 let value = gather(value).await?;
474 if is_empty_value(&value) {
475 return Ok(None);
476 }
477 match value {
478 Value::Num(n) => {
479 validate_numbers(label, &[n], FiniteMode::Finite)?;
480 Ok(Some(MatrixInput {
481 rows: 1,
482 cols: 1,
483 data: vec![n],
484 }))
485 }
486 Value::Int(i) => {
487 let value = i.to_f64();
488 validate_numbers(label, &[value], FiniteMode::Finite)?;
489 Ok(Some(MatrixInput {
490 rows: 1,
491 cols: 1,
492 data: vec![value],
493 }))
494 }
495 Value::Tensor(t) => {
496 if t.shape.len() > 2 {
497 return Err(linprog_error_with_detail(
498 &LINPROG_ERROR_INVALID_INPUT,
499 format!("{label} must be a numeric matrix"),
500 ));
501 }
502 validate_numbers(label, &t.data, FiniteMode::Finite)?;
503 Ok(Some(MatrixInput {
504 rows: t.rows(),
505 cols: t.cols(),
506 data: t.data,
507 }))
508 }
509 other => Err(linprog_error_with_detail(
510 &LINPROG_ERROR_INVALID_INPUT,
511 format!("{label} must be a real numeric matrix, got {other:?}"),
512 )),
513 }
514}
515
516fn validate_numbers(label: &str, data: &[f64], finite_mode: FiniteMode) -> BuiltinResult<()> {
517 for value in data {
518 match finite_mode {
519 FiniteMode::Finite if !value.is_finite() => {
520 return Err(linprog_error_with_detail(
521 &LINPROG_ERROR_INVALID_INPUT,
522 format!("{label} values must be finite"),
523 ))
524 }
525 FiniteMode::Bounds if value.is_nan() => {
526 return Err(linprog_error_with_detail(
527 &LINPROG_ERROR_INVALID_INPUT,
528 format!("{label} bounds cannot be NaN"),
529 ))
530 }
531 _ => {}
532 }
533 }
534 Ok(())
535}
536
537async fn parse_constraint_pair(
538 matrix_label: &str,
539 matrix: Value,
540 rhs_label: &str,
541 rhs: Value,
542 n: usize,
543) -> BuiltinResult<(Vec<Vec<f64>>, Vec<f64>)> {
544 let matrix = numeric_matrix(matrix_label, matrix).await?;
545 let rhs = numeric_vector(rhs_label, rhs, FiniteMode::Finite).await?;
546 match (matrix, rhs.is_empty()) {
547 (None, true) => Ok((Vec::new(), Vec::new())),
548 (None, false) => Err(linprog_error_with_detail(
549 &LINPROG_ERROR_INVALID_INPUT,
550 format!("{matrix_label} cannot be empty when {rhs_label} is nonempty"),
551 )),
552 (Some(matrix), _) => {
553 if matrix.cols != n {
554 return Err(linprog_error_with_detail(
555 &LINPROG_ERROR_INVALID_INPUT,
556 format!("{matrix_label} must have one column per element of f"),
557 ));
558 }
559 if rhs.len() != matrix.rows {
560 return Err(linprog_error_with_detail(
561 &LINPROG_ERROR_INVALID_INPUT,
562 format!("{rhs_label} length must match rows of {matrix_label}"),
563 ));
564 }
565 let rows = (0..matrix.rows).map(|row| matrix.row(row)).collect();
566 Ok((rows, rhs))
567 }
568 }
569}
570
571async fn parse_optional_equality(
572 aeq: Option<&Value>,
573 beq: Option<&Value>,
574 n: usize,
575) -> BuiltinResult<(Vec<Vec<f64>>, Vec<f64>)> {
576 match (aeq, beq) {
577 (None, None) => Ok((Vec::new(), Vec::new())),
578 (Some(aeq), None) if is_empty_value(aeq) => Ok((Vec::new(), Vec::new())),
579 (Some(_), None) => Err(linprog_error_with_detail(
580 &LINPROG_ERROR_INVALID_ARGUMENT,
581 "Aeq requires a matching beq argument",
582 )),
583 (None, Some(_)) => Err(linprog_error_with_detail(
584 &LINPROG_ERROR_INVALID_ARGUMENT,
585 "beq requires a matching Aeq argument",
586 )),
587 (Some(aeq), Some(beq)) => {
588 parse_constraint_pair("Aeq", aeq.clone(), "beq", beq.clone(), n).await
589 }
590 }
591}
592
593async fn parse_bounds(
594 lb: Option<&Value>,
595 ub: Option<&Value>,
596 n: usize,
597) -> BuiltinResult<(Vec<f64>, Vec<f64>)> {
598 let lb = match lb {
599 None => vec![f64::NEG_INFINITY; n],
600 Some(value) if is_empty_value(value) => vec![f64::NEG_INFINITY; n],
601 Some(value) => {
602 let values = numeric_vector("lb", value.clone(), FiniteMode::Bounds).await?;
603 normalize_bound("lb", values, n)?
604 }
605 };
606 let ub = match ub {
607 None => vec![f64::INFINITY; n],
608 Some(value) if is_empty_value(value) => vec![f64::INFINITY; n],
609 Some(value) => {
610 let values = numeric_vector("ub", value.clone(), FiniteMode::Bounds).await?;
611 normalize_bound("ub", values, n)?
612 }
613 };
614 Ok((lb, ub))
615}
616
617fn normalize_bound(label: &str, values: Vec<f64>, n: usize) -> BuiltinResult<Vec<f64>> {
618 if values.len() == n {
619 Ok(values)
620 } else {
621 Err(linprog_error_with_detail(
622 &LINPROG_ERROR_INVALID_INPUT,
623 format!("{label} length must match f"),
624 ))
625 }
626}
627
628struct LinearProgram {
629 f: Vec<f64>,
630 a_ineq: Vec<Vec<f64>>,
631 b_ineq: Vec<f64>,
632 a_eq: Vec<Vec<f64>>,
633 b_eq: Vec<f64>,
634}
635
636#[derive(Clone)]
637struct LinprogOutcome {
638 x: Option<Vec<f64>>,
639 fval: Option<f64>,
640 exitflag: i32,
641 iterations: usize,
642 constrviolation: f64,
643 message: String,
644}
645
646impl LinprogOutcome {
647 fn infeasible(message: &str) -> Self {
648 Self {
649 x: None,
650 fval: None,
651 exitflag: -2,
652 iterations: 0,
653 constrviolation: 0.0,
654 message: message.to_string(),
655 }
656 }
657
658 fn unbounded(iterations: usize) -> Self {
659 Self {
660 x: None,
661 fval: None,
662 exitflag: -3,
663 iterations,
664 constrviolation: 0.0,
665 message: "Problem is unbounded.".to_string(),
666 }
667 }
668}
669
670fn solve_linprog(problem: &LinearProgram) -> LinprogOutcome {
671 let n = problem.f.len();
672 let Some(face) = equality_face(problem, n) else {
673 return LinprogOutcome::infeasible("No feasible point found.");
674 };
675 let reduced = reduce_to_equality_face(problem, &face);
676 let k = reduced.f.len();
677 let mut candidates = Vec::new();
678 let mut combinations = 0usize;
679
680 enumerate_vertices(&reduced, |y| {
681 combinations += 1;
682 if is_feasible(&reduced, &y) {
683 candidates.push(y);
684 }
685 });
686
687 let feasible_fallback = if candidates.is_empty() {
688 let y0 = vec![0.0; k];
689 is_feasible(&reduced, &y0).then_some(y0)
690 } else {
691 None
692 };
693 let has_feasible_point = !candidates.is_empty() || feasible_fallback.is_some();
694 if !has_feasible_point {
695 return LinprogOutcome::infeasible("No feasible point found.");
696 }
697 if has_unbounded_descent_direction(&reduced) {
698 return LinprogOutcome::unbounded(combinations);
699 }
700
701 if let Some(x) = feasible_fallback {
702 candidates.push(x);
703 }
704
705 let mut best_y = candidates[0].clone();
706 let mut best_fval = dot(&reduced.f, &best_y);
707 for candidate in candidates.into_iter().skip(1) {
708 let fval = dot(&reduced.f, &candidate);
709 if fval < best_fval - TOL {
710 best_y = candidate;
711 best_fval = fval;
712 }
713 }
714
715 let best = lift_from_equality_face(&face, &best_y);
716 let best_fval = dot(&problem.f, &best);
717 let constrviolation = constraint_violation(problem, &best);
718 LinprogOutcome {
719 x: Some(best),
720 fval: Some(best_fval),
721 exitflag: 1,
722 iterations: combinations,
723 constrviolation,
724 message: "Optimal solution found.".to_string(),
725 }
726}
727
728struct EqualityFace {
729 x0: Vec<f64>,
730 basis: Vec<Vec<f64>>,
731}
732
733fn equality_face(problem: &LinearProgram, n: usize) -> Option<EqualityFace> {
734 let x0 = if problem.a_eq.is_empty() {
735 vec![0.0; n]
736 } else {
737 pseudo_solve(&problem.a_eq, &problem.b_eq, n)?
738 };
739 Some(EqualityFace {
740 x0,
741 basis: nullspace_basis(&problem.a_eq, n),
742 })
743}
744
745fn reduce_to_equality_face(problem: &LinearProgram, face: &EqualityFace) -> LinearProgram {
746 LinearProgram {
747 f: face
748 .basis
749 .iter()
750 .map(|basis_vector| dot(&problem.f, basis_vector))
751 .collect(),
752 a_ineq: problem
753 .a_ineq
754 .iter()
755 .map(|row| {
756 face.basis
757 .iter()
758 .map(|basis_vector| dot(row, basis_vector))
759 .collect()
760 })
761 .collect(),
762 b_ineq: problem
763 .a_ineq
764 .iter()
765 .zip(&problem.b_ineq)
766 .map(|(row, rhs)| rhs - dot(row, &face.x0))
767 .collect(),
768 a_eq: Vec::new(),
769 b_eq: Vec::new(),
770 }
771}
772
773fn lift_from_equality_face(face: &EqualityFace, y: &[f64]) -> Vec<f64> {
774 let mut x = face.x0.clone();
775 for (coeff, basis_vector) in y.iter().zip(&face.basis) {
776 for (x_j, basis_j) in x.iter_mut().zip(basis_vector) {
777 *x_j += coeff * basis_j;
778 }
779 }
780 x
781}
782
783fn enumerate_vertices(problem: &LinearProgram, mut visit: impl FnMut(Vec<f64>)) {
784 let n = problem.f.len();
785 let max_active = problem.a_ineq.len().min(n);
786 for active_count in 0..=max_active {
787 enumerate_combinations(problem.a_ineq.len(), active_count, |active| {
788 let mut rows = problem.a_eq.clone();
789 let mut rhs = problem.b_eq.clone();
790 for &idx in active {
791 rows.push(problem.a_ineq[idx].clone());
792 rhs.push(problem.b_ineq[idx]);
793 }
794 if let Some(x) = pseudo_solve(&rows, &rhs, n) {
795 visit(x);
796 }
797 });
798 }
799}
800
801fn has_unbounded_descent_direction(problem: &LinearProgram) -> bool {
802 let n = problem.f.len();
803 let max_active = problem.a_ineq.len().min(n.saturating_sub(1));
804 for active_count in 0..=max_active {
805 let mut found = false;
806 enumerate_combinations(problem.a_ineq.len(), active_count, |active| {
807 if found {
808 return;
809 }
810 let mut rows = problem.a_eq.clone();
811 for &idx in active {
812 rows.push(problem.a_ineq[idx].clone());
813 }
814 for direction in candidate_nullspace_descent_directions(&rows, &problem.f, n) {
815 if is_recession_direction(problem, &direction) && dot(&problem.f, &direction) < -TOL
816 {
817 found = true;
818 return;
819 }
820 }
821 });
822 if found {
823 return true;
824 }
825 }
826 false
827}
828
829fn candidate_nullspace_descent_directions(rows: &[Vec<f64>], f: &[f64], n: usize) -> Vec<Vec<f64>> {
830 let basis = nullspace_basis(rows, n);
831 if basis.is_empty() {
832 return Vec::new();
833 }
834 let mut directions = Vec::new();
835 let mut projected = vec![0.0; n];
836 for basis_vector in &basis {
837 let coeff = -dot(f, basis_vector);
838 for i in 0..n {
839 projected[i] += coeff * basis_vector[i];
840 }
841 directions.push(basis_vector.clone());
842 directions.push(basis_vector.iter().map(|v| -*v).collect());
843 }
844 if norm(&projected) > TOL {
845 directions.push(projected);
846 }
847 directions
848}
849
850fn is_recession_direction(problem: &LinearProgram, direction: &[f64]) -> bool {
851 norm(direction) > TOL
852 && problem
853 .a_eq
854 .iter()
855 .all(|row| dot(row, direction).abs() <= TOL)
856 && problem.a_ineq.iter().all(|row| dot(row, direction) <= TOL)
857}
858
859fn is_feasible(problem: &LinearProgram, x: &[f64]) -> bool {
860 constraint_violation(problem, x) <= 1.0e-7
861}
862
863fn constraint_violation(problem: &LinearProgram, x: &[f64]) -> f64 {
864 let eq = problem
865 .a_eq
866 .iter()
867 .zip(&problem.b_eq)
868 .map(|(row, rhs)| (dot(row, x) - rhs).abs())
869 .fold(0.0, f64::max);
870 let ineq = problem
871 .a_ineq
872 .iter()
873 .zip(&problem.b_ineq)
874 .map(|(row, rhs)| (dot(row, x) - rhs).max(0.0))
875 .fold(0.0, f64::max);
876 eq.max(ineq)
877}
878
879fn nullspace_basis(rows: &[Vec<f64>], n: usize) -> Vec<Vec<f64>> {
880 if n == 0 {
881 return Vec::new();
882 }
883 if rows.is_empty() {
884 return (0..n)
885 .map(|i| {
886 let mut basis = vec![0.0; n];
887 basis[i] = 1.0;
888 basis
889 })
890 .collect();
891 }
892
893 let (reduced, pivots) = rref(rows, n);
894 let free_cols = (0..n)
895 .filter(|col| !pivots.contains(col))
896 .collect::<Vec<_>>();
897 free_cols
898 .into_iter()
899 .filter_map(|free_col| {
900 let mut basis = vec![0.0; n];
901 basis[free_col] = 1.0;
902 for (row, pivot_col) in pivots.iter().enumerate() {
903 basis[*pivot_col] = -reduced[row][free_col];
904 }
905 let length = norm(&basis);
906 (length > TOL).then(|| basis.into_iter().map(|value| value / length).collect())
907 })
908 .collect()
909}
910
911fn rref(rows: &[Vec<f64>], n: usize) -> (Vec<Vec<f64>>, Vec<usize>) {
912 let mut matrix = rows.to_vec();
913 let mut pivots = Vec::new();
914 let mut pivot_row = 0usize;
915
916 for col in 0..n {
917 let Some(best_row) = (pivot_row..matrix.len()).max_by(|&a, &b| {
918 matrix[a][col]
919 .abs()
920 .partial_cmp(&matrix[b][col].abs())
921 .unwrap_or(std::cmp::Ordering::Equal)
922 }) else {
923 break;
924 };
925 if matrix[best_row][col].abs() <= TOL {
926 continue;
927 }
928
929 matrix.swap(pivot_row, best_row);
930 let pivot = matrix[pivot_row][col];
931 for value in &mut matrix[pivot_row] {
932 *value /= pivot;
933 }
934
935 for row in 0..matrix.len() {
936 if row == pivot_row {
937 continue;
938 }
939 let factor = matrix[row][col];
940 if factor.abs() <= TOL {
941 continue;
942 }
943 for j in col..n {
944 matrix[row][j] -= factor * matrix[pivot_row][j];
945 }
946 }
947
948 pivots.push(col);
949 pivot_row += 1;
950 if pivot_row == matrix.len() {
951 break;
952 }
953 }
954
955 (matrix, pivots)
956}
957
958fn pseudo_solve(rows: &[Vec<f64>], rhs: &[f64], n: usize) -> Option<Vec<f64>> {
959 if rows.is_empty() {
960 return Some(vec![0.0; n]);
961 }
962 let matrix = dmatrix_from_rows(rows, n);
963 let rhs_vec = DVector::from_column_slice(rhs);
964 let svd = matrix.svd(true, true);
965 let u = svd.u.as_ref()?;
966 let v_t = svd.v_t.as_ref()?;
967 let mut x = vec![0.0; n];
968 for (i, sigma) in svd.singular_values.iter().enumerate() {
969 if *sigma <= TOL {
970 continue;
971 }
972 let coeff = (0..rows.len())
973 .map(|row| u[(row, i)] * rhs_vec[row])
974 .sum::<f64>()
975 / sigma;
976 for col in 0..n {
977 x[col] += v_t[(i, col)] * coeff;
978 }
979 }
980 let residual = rows
981 .iter()
982 .zip(rhs)
983 .map(|(row, target)| (dot(row, &x) - target).abs())
984 .fold(0.0, f64::max);
985 (residual <= 1.0e-7).then_some(x)
986}
987
988fn dmatrix_from_rows(rows: &[Vec<f64>], n: usize) -> DMatrix<f64> {
989 let data = rows
990 .iter()
991 .flat_map(|row| row.iter().copied())
992 .collect::<Vec<_>>();
993 DMatrix::from_row_slice(rows.len(), n, &data)
994}
995
996fn enumerate_combinations(len: usize, choose: usize, mut visit: impl FnMut(&[usize])) {
997 fn rec(
998 len: usize,
999 choose: usize,
1000 start: usize,
1001 current: &mut Vec<usize>,
1002 visit: &mut dyn FnMut(&[usize]),
1003 ) {
1004 if current.len() == choose {
1005 visit(current);
1006 return;
1007 }
1008 let remaining = choose - current.len();
1009 for idx in start..=len - remaining {
1010 current.push(idx);
1011 rec(len, choose, idx + 1, current, visit);
1012 current.pop();
1013 }
1014 }
1015
1016 if choose > len {
1017 return;
1018 }
1019 let mut current = Vec::with_capacity(choose);
1020 rec(len, choose, 0, &mut current, &mut visit);
1021}
1022
1023fn dot(a: &[f64], b: &[f64]) -> f64 {
1024 a.iter().zip(b).map(|(x, y)| x * y).sum()
1025}
1026
1027fn norm(values: &[f64]) -> f64 {
1028 dot(values, values).sqrt()
1029}
1030
1031fn finalize(outcome: LinprogOutcome) -> Value {
1032 let x = outcome
1033 .x
1034 .clone()
1035 .map(vector_value)
1036 .unwrap_or_else(empty_double);
1037 let fval = outcome.fval.map(Value::Num).unwrap_or_else(empty_double);
1038 let exitflag = Value::Num(outcome.exitflag as f64);
1039 let output = Value::Struct(build_output_struct(&outcome));
1040
1041 match crate::output_count::current_output_count() {
1042 None => x,
1043 Some(0) => Value::OutputList(Vec::new()),
1044 Some(1) => crate::output_count::output_list_with_padding(1, vec![x]),
1045 Some(2) => crate::output_count::output_list_with_padding(2, vec![x, fval]),
1046 Some(3) => crate::output_count::output_list_with_padding(3, vec![x, fval, exitflag]),
1047 Some(n) if n >= 4 => {
1048 crate::output_count::output_list_with_padding(n, vec![x, fval, exitflag, output])
1049 }
1050 Some(_) => x,
1051 }
1052}
1053
1054fn vector_value(values: Vec<f64>) -> Value {
1055 let n = values.len();
1056 Tensor::new(values, vec![n, 1])
1057 .map(Value::Tensor)
1058 .unwrap_or_else(|_| empty_double())
1059}
1060
1061fn empty_double() -> Value {
1062 Value::Tensor(Tensor::zeros(vec![0, 0]))
1063}
1064
1065fn build_output_struct(outcome: &LinprogOutcome) -> StructValue {
1066 let mut fields = StructValue::new();
1067 fields.insert("iterations", Value::Num(outcome.iterations as f64));
1068 fields.insert("algorithm", Value::from(ALGORITHM));
1069 fields.insert("constrviolation", Value::Num(outcome.constrviolation));
1070 fields.insert("message", Value::from(outcome.message.clone()));
1071 fields
1072}
1073
1074#[cfg(test)]
1075mod tests {
1076 use super::*;
1077 use futures::executor::block_on;
1078 use runmat_builtins::Value as V;
1079
1080 fn tensor(data: Vec<f64>, rows: usize, cols: usize) -> V {
1081 V::Tensor(Tensor::new(data, vec![rows, cols]).unwrap())
1082 }
1083
1084 fn empty() -> V {
1085 V::Tensor(Tensor::zeros(vec![0, 0]))
1086 }
1087
1088 fn run(f: V, a: V, b: V, rest: Vec<V>, outputs: usize) -> Vec<V> {
1089 let _guard = crate::output_count::push_output_count(Some(outputs));
1090 let value = block_on(linprog_builtin(f, a, b, rest)).expect("linprog");
1091 match value {
1092 V::OutputList(values) => values,
1093 other => vec![other],
1094 }
1095 }
1096
1097 #[test]
1098 fn solves_bounded_feasible_problem() {
1099 let outputs = run(
1100 tensor(vec![-1.0, -2.0], 2, 1),
1101 tensor(vec![1.0, 1.0], 1, 2),
1102 V::Num(4.0),
1103 vec![empty(), empty(), tensor(vec![0.0, 0.0], 2, 1), empty()],
1104 3,
1105 );
1106 match &outputs[0] {
1107 V::Tensor(x) => {
1108 assert!((x.data[0] - 0.0).abs() < 1.0e-7, "{x:?}");
1109 assert!((x.data[1] - 4.0).abs() < 1.0e-7, "{x:?}");
1110 }
1111 other => panic!("unexpected x {other:?}"),
1112 }
1113 assert!(matches!(&outputs[1], V::Num(fval) if (*fval + 8.0).abs() < 1.0e-7));
1114 assert!(matches!(&outputs[2], V::Num(flag) if *flag == 1.0));
1115 }
1116
1117 #[test]
1118 fn solves_equality_constrained_problem() {
1119 let outputs = run(
1120 tensor(vec![1.0, 2.0], 2, 1),
1121 empty(),
1122 empty(),
1123 vec![
1124 tensor(vec![1.0, 1.0], 1, 2),
1125 V::Num(3.0),
1126 tensor(vec![1.0, 0.0], 2, 1),
1127 empty(),
1128 ],
1129 2,
1130 );
1131 match &outputs[0] {
1132 V::Tensor(x) => {
1133 assert!((x.data[0] - 3.0).abs() < 1.0e-7, "{x:?}");
1134 assert!((x.data[1] - 0.0).abs() < 1.0e-7, "{x:?}");
1135 }
1136 other => panic!("unexpected x {other:?}"),
1137 }
1138 assert!(matches!(&outputs[1], V::Num(fval) if (*fval - 3.0).abs() < 1.0e-7));
1139 }
1140
1141 #[test]
1142 fn reports_infeasible_bounds() {
1143 let outputs = run(
1144 V::Num(1.0),
1145 empty(),
1146 empty(),
1147 vec![empty(), empty(), V::Num(2.0), V::Num(1.0)],
1148 4,
1149 );
1150 assert!(matches!(&outputs[0], V::Tensor(t) if t.data.is_empty()));
1151 assert!(matches!(&outputs[1], V::Tensor(t) if t.data.is_empty()));
1152 assert!(matches!(&outputs[2], V::Num(flag) if *flag == -2.0));
1153 assert!(matches!(&outputs[3], V::Struct(s) if s.fields.contains_key("message")));
1154 }
1155
1156 #[test]
1157 fn reports_unbounded_problem() {
1158 let outputs = run(V::Num(-1.0), empty(), empty(), Vec::new(), 3);
1159 assert!(matches!(&outputs[0], V::Tensor(t) if t.data.is_empty()));
1160 assert!(matches!(&outputs[1], V::Tensor(t) if t.data.is_empty()));
1161 assert!(matches!(&outputs[2], V::Num(flag) if *flag == -3.0));
1162 }
1163
1164 #[test]
1165 fn accepts_empty_optional_placeholders() {
1166 let outputs = run(
1167 tensor(vec![1.0, 1.0], 2, 1),
1168 empty(),
1169 empty(),
1170 vec![empty(), empty(), tensor(vec![2.0, 3.0], 2, 1), empty()],
1171 2,
1172 );
1173 match &outputs[0] {
1174 V::Tensor(x) => {
1175 assert!((x.data[0] - 2.0).abs() < 1.0e-7, "{x:?}");
1176 assert!((x.data[1] - 3.0).abs() < 1.0e-7, "{x:?}");
1177 }
1178 other => panic!("unexpected x {other:?}"),
1179 }
1180 assert!(matches!(&outputs[1], V::Num(fval) if (*fval - 5.0).abs() < 1.0e-7));
1181 }
1182
1183 #[test]
1184 fn solves_one_sided_bound_with_fewer_rows_than_variables() {
1185 let outputs = run(
1186 tensor(vec![1.0, 0.0], 2, 1),
1187 empty(),
1188 empty(),
1189 vec![
1190 empty(),
1191 empty(),
1192 tensor(vec![2.0, f64::NEG_INFINITY], 2, 1),
1193 empty(),
1194 ],
1195 3,
1196 );
1197 match &outputs[0] {
1198 V::Tensor(x) => {
1199 assert!((x.data[0] - 2.0).abs() < 1.0e-7, "{x:?}");
1200 assert!(x.data[1].abs() < 1.0e-7, "{x:?}");
1201 }
1202 other => panic!("unexpected x {other:?}"),
1203 }
1204 assert!(matches!(&outputs[1], V::Num(fval) if (*fval - 2.0).abs() < 1.0e-7));
1205 assert!(matches!(&outputs[2], V::Num(flag) if *flag == 1.0));
1206 }
1207
1208 #[test]
1209 fn optimizes_along_equality_face_when_particular_solution_is_suboptimal() {
1210 let outputs = run(
1211 tensor(vec![-1.0, 0.0, 0.0], 3, 1),
1212 tensor(vec![1.0, 0.0, 0.0], 1, 3),
1213 V::Num(1.0),
1214 vec![
1215 tensor(vec![0.0, 0.0, 1.0], 1, 3),
1216 V::Num(0.0),
1217 empty(),
1218 empty(),
1219 ],
1220 3,
1221 );
1222 match &outputs[0] {
1223 V::Tensor(x) => {
1224 assert!((x.data[0] - 1.0).abs() < 1.0e-7, "{x:?}");
1225 assert!(x.data[1].abs() < 1.0e-7, "{x:?}");
1226 assert!(x.data[2].abs() < 1.0e-7, "{x:?}");
1227 }
1228 other => panic!("unexpected x {other:?}"),
1229 }
1230 assert!(matches!(&outputs[1], V::Num(fval) if (*fval + 1.0).abs() < 1.0e-7));
1231 assert!(matches!(&outputs[2], V::Num(flag) if *flag == 1.0));
1232 }
1233
1234 #[test]
1235 fn validates_matrix_dimensions() {
1236 let err = block_on(linprog_builtin(
1237 tensor(vec![1.0, 1.0], 2, 1),
1238 tensor(vec![1.0, 1.0, 1.0], 1, 3),
1239 V::Num(1.0),
1240 Vec::new(),
1241 ))
1242 .unwrap_err();
1243 assert_eq!(err.identifier(), Some("RunMat:linprog:InvalidInput"));
1244 }
1245}