Skip to main content

runmat_runtime/builtins/math/optim/
linprog.rs

1//! MATLAB-compatible `linprog` builtin for small and medium linear programs.
2
3use nalgebra::{DMatrix, DVector};
4use runmat_builtins::{
5    BuiltinCompletionPolicy, BuiltinDescriptor, BuiltinErrorDescriptor, BuiltinOutputMode,
6    BuiltinParamArity, BuiltinParamDescriptor, BuiltinParamType, BuiltinSignatureDescriptor,
7    StructValue, Tensor, Value,
8};
9use runmat_macros::runtime_builtin;
10
11use crate::builtins::common::spec::{
12    BroadcastSemantics, BuiltinFusionSpec, BuiltinGpuSpec, ConstantStrategy, GpuOpKind,
13    ReductionNaN, ResidencyPolicy, ShapeRequirements,
14};
15use crate::builtins::math::optim::type_resolvers::linear_programming_type;
16use crate::{build_runtime_error, BuiltinResult, RuntimeError};
17
/// Builtin name attached to every runtime error raised by this module.
const NAME: &str = "linprog";
/// Name of the solver strategy implemented below, kept as a string constant.
const ALGORITHM: &str = "active-set vertex enumeration";
/// Absolute tolerance shared by the pivot, rank, feasibility-direction and
/// objective-comparison tests in this module.
const TOL: f64 = 1e-8;
21
22const LINPROG_OUTPUT_X: [BuiltinParamDescriptor; 1] = [BuiltinParamDescriptor {
23    name: "x",
24    ty: BuiltinParamType::NumericArray,
25    arity: BuiltinParamArity::Required,
26    default: None,
27    description: "Optimal decision vector.",
28}];
29
30const LINPROG_OUTPUT_X_FVAL: [BuiltinParamDescriptor; 2] = [
31    BuiltinParamDescriptor {
32        name: "x",
33        ty: BuiltinParamType::NumericArray,
34        arity: BuiltinParamArity::Required,
35        default: None,
36        description: "Optimal decision vector.",
37    },
38    BuiltinParamDescriptor {
39        name: "fval",
40        ty: BuiltinParamType::NumericScalar,
41        arity: BuiltinParamArity::Required,
42        default: None,
43        description: "Objective value f'*x at the solution.",
44    },
45];
46
47const LINPROG_OUTPUT_X_FVAL_EXITFLAG: [BuiltinParamDescriptor; 3] = [
48    BuiltinParamDescriptor {
49        name: "x",
50        ty: BuiltinParamType::NumericArray,
51        arity: BuiltinParamArity::Required,
52        default: None,
53        description: "Optimal decision vector.",
54    },
55    BuiltinParamDescriptor {
56        name: "fval",
57        ty: BuiltinParamType::NumericScalar,
58        arity: BuiltinParamArity::Required,
59        default: None,
60        description: "Objective value f'*x at the solution.",
61    },
62    BuiltinParamDescriptor {
63        name: "exitflag",
64        ty: BuiltinParamType::NumericScalar,
65        arity: BuiltinParamArity::Required,
66        default: None,
67        description: "Solver status code.",
68    },
69];
70
71const LINPROG_OUTPUT_ALL: [BuiltinParamDescriptor; 4] = [
72    BuiltinParamDescriptor {
73        name: "x",
74        ty: BuiltinParamType::NumericArray,
75        arity: BuiltinParamArity::Required,
76        default: None,
77        description: "Optimal decision vector.",
78    },
79    BuiltinParamDescriptor {
80        name: "fval",
81        ty: BuiltinParamType::NumericScalar,
82        arity: BuiltinParamArity::Required,
83        default: None,
84        description: "Objective value f'*x at the solution.",
85    },
86    BuiltinParamDescriptor {
87        name: "exitflag",
88        ty: BuiltinParamType::NumericScalar,
89        arity: BuiltinParamArity::Required,
90        default: None,
91        description: "Solver status code.",
92    },
93    BuiltinParamDescriptor {
94        name: "output",
95        ty: BuiltinParamType::Any,
96        arity: BuiltinParamArity::Required,
97        default: None,
98        description: "Diagnostic metadata struct.",
99    },
100];
101
102const LINPROG_INPUTS_CORE: [BuiltinParamDescriptor; 3] = [
103    BuiltinParamDescriptor {
104        name: "f",
105        ty: BuiltinParamType::NumericArray,
106        arity: BuiltinParamArity::Required,
107        default: None,
108        description: "Linear objective vector.",
109    },
110    BuiltinParamDescriptor {
111        name: "A",
112        ty: BuiltinParamType::NumericArray,
113        arity: BuiltinParamArity::Required,
114        default: None,
115        description: "Inequality constraint matrix.",
116    },
117    BuiltinParamDescriptor {
118        name: "b",
119        ty: BuiltinParamType::NumericArray,
120        arity: BuiltinParamArity::Required,
121        default: None,
122        description: "Inequality constraint right-hand side.",
123    },
124];
125
126const LINPROG_INPUTS_EQ: [BuiltinParamDescriptor; 5] = [
127    BuiltinParamDescriptor {
128        name: "f",
129        ty: BuiltinParamType::NumericArray,
130        arity: BuiltinParamArity::Required,
131        default: None,
132        description: "Linear objective vector.",
133    },
134    BuiltinParamDescriptor {
135        name: "A",
136        ty: BuiltinParamType::NumericArray,
137        arity: BuiltinParamArity::Required,
138        default: None,
139        description: "Inequality constraint matrix.",
140    },
141    BuiltinParamDescriptor {
142        name: "b",
143        ty: BuiltinParamType::NumericArray,
144        arity: BuiltinParamArity::Required,
145        default: None,
146        description: "Inequality constraint right-hand side.",
147    },
148    BuiltinParamDescriptor {
149        name: "Aeq",
150        ty: BuiltinParamType::NumericArray,
151        arity: BuiltinParamArity::Optional,
152        default: Some("[]"),
153        description: "Equality constraint matrix.",
154    },
155    BuiltinParamDescriptor {
156        name: "beq",
157        ty: BuiltinParamType::NumericArray,
158        arity: BuiltinParamArity::Optional,
159        default: Some("[]"),
160        description: "Equality constraint right-hand side.",
161    },
162];
163
164const LINPROG_INPUTS_BOUNDS: [BuiltinParamDescriptor; 7] = [
165    BuiltinParamDescriptor {
166        name: "f",
167        ty: BuiltinParamType::NumericArray,
168        arity: BuiltinParamArity::Required,
169        default: None,
170        description: "Linear objective vector.",
171    },
172    BuiltinParamDescriptor {
173        name: "A",
174        ty: BuiltinParamType::NumericArray,
175        arity: BuiltinParamArity::Required,
176        default: None,
177        description: "Inequality constraint matrix.",
178    },
179    BuiltinParamDescriptor {
180        name: "b",
181        ty: BuiltinParamType::NumericArray,
182        arity: BuiltinParamArity::Required,
183        default: None,
184        description: "Inequality constraint right-hand side.",
185    },
186    BuiltinParamDescriptor {
187        name: "Aeq",
188        ty: BuiltinParamType::NumericArray,
189        arity: BuiltinParamArity::Optional,
190        default: Some("[]"),
191        description: "Equality constraint matrix.",
192    },
193    BuiltinParamDescriptor {
194        name: "beq",
195        ty: BuiltinParamType::NumericArray,
196        arity: BuiltinParamArity::Optional,
197        default: Some("[]"),
198        description: "Equality constraint right-hand side.",
199    },
200    BuiltinParamDescriptor {
201        name: "lb",
202        ty: BuiltinParamType::NumericArray,
203        arity: BuiltinParamArity::Optional,
204        default: Some("[]"),
205        description: "Lower bounds.",
206    },
207    BuiltinParamDescriptor {
208        name: "ub",
209        ty: BuiltinParamType::NumericArray,
210        arity: BuiltinParamArity::Optional,
211        default: Some("[]"),
212        description: "Upper bounds.",
213    },
214];
215
216const LINPROG_SIGNATURES: [BuiltinSignatureDescriptor; 12] = [
217    BuiltinSignatureDescriptor {
218        label: "x = linprog(f, A, b)",
219        inputs: &LINPROG_INPUTS_CORE,
220        outputs: &LINPROG_OUTPUT_X,
221    },
222    BuiltinSignatureDescriptor {
223        label: "x = linprog(f, A, b, Aeq, beq)",
224        inputs: &LINPROG_INPUTS_EQ,
225        outputs: &LINPROG_OUTPUT_X,
226    },
227    BuiltinSignatureDescriptor {
228        label: "x = linprog(f, A, b, Aeq, beq, lb, ub)",
229        inputs: &LINPROG_INPUTS_BOUNDS,
230        outputs: &LINPROG_OUTPUT_X,
231    },
232    BuiltinSignatureDescriptor {
233        label: "[x, fval] = linprog(f, A, b)",
234        inputs: &LINPROG_INPUTS_CORE,
235        outputs: &LINPROG_OUTPUT_X_FVAL,
236    },
237    BuiltinSignatureDescriptor {
238        label: "[x, fval] = linprog(f, A, b, Aeq, beq)",
239        inputs: &LINPROG_INPUTS_EQ,
240        outputs: &LINPROG_OUTPUT_X_FVAL,
241    },
242    BuiltinSignatureDescriptor {
243        label: "[x, fval] = linprog(f, A, b, Aeq, beq, lb, ub)",
244        inputs: &LINPROG_INPUTS_BOUNDS,
245        outputs: &LINPROG_OUTPUT_X_FVAL,
246    },
247    BuiltinSignatureDescriptor {
248        label: "[x, fval, exitflag] = linprog(f, A, b)",
249        inputs: &LINPROG_INPUTS_CORE,
250        outputs: &LINPROG_OUTPUT_X_FVAL_EXITFLAG,
251    },
252    BuiltinSignatureDescriptor {
253        label: "[x, fval, exitflag] = linprog(f, A, b, Aeq, beq)",
254        inputs: &LINPROG_INPUTS_EQ,
255        outputs: &LINPROG_OUTPUT_X_FVAL_EXITFLAG,
256    },
257    BuiltinSignatureDescriptor {
258        label: "[x, fval, exitflag] = linprog(f, A, b, Aeq, beq, lb, ub)",
259        inputs: &LINPROG_INPUTS_BOUNDS,
260        outputs: &LINPROG_OUTPUT_X_FVAL_EXITFLAG,
261    },
262    BuiltinSignatureDescriptor {
263        label: "[x, fval, exitflag, output] = linprog(f, A, b)",
264        inputs: &LINPROG_INPUTS_CORE,
265        outputs: &LINPROG_OUTPUT_ALL,
266    },
267    BuiltinSignatureDescriptor {
268        label: "[x, fval, exitflag, output] = linprog(f, A, b, Aeq, beq)",
269        inputs: &LINPROG_INPUTS_EQ,
270        outputs: &LINPROG_OUTPUT_ALL,
271    },
272    BuiltinSignatureDescriptor {
273        label: "[x, fval, exitflag, output] = linprog(f, A, b, Aeq, beq, lb, ub)",
274        inputs: &LINPROG_INPUTS_BOUNDS,
275        outputs: &LINPROG_OUTPUT_ALL,
276    },
277];
278
279const LINPROG_ERROR_INVALID_ARGUMENT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
280    code: "RM.LINPROG.INVALID_ARGUMENT",
281    identifier: Some("RunMat:linprog:InvalidArgument"),
282    when: "The argument count or optional argument grammar is invalid.",
283    message: "linprog: invalid argument",
284};
285
286const LINPROG_ERROR_INVALID_INPUT: BuiltinErrorDescriptor = BuiltinErrorDescriptor {
287    code: "RM.LINPROG.INVALID_INPUT",
288    identifier: Some("RunMat:linprog:InvalidInput"),
289    when: "Objective, constraint, or bound dimensions/types are invalid.",
290    message: "linprog: invalid input",
291};
292
293const LINPROG_ERRORS: [BuiltinErrorDescriptor; 2] =
294    [LINPROG_ERROR_INVALID_ARGUMENT, LINPROG_ERROR_INVALID_INPUT];
295
296pub const LINPROG_DESCRIPTOR: BuiltinDescriptor = BuiltinDescriptor {
297    signatures: &LINPROG_SIGNATURES,
298    output_mode: BuiltinOutputMode::ByRequestedOutputCount,
299    completion_policy: BuiltinCompletionPolicy::Public,
300    errors: &LINPROG_ERRORS,
301};
302
303fn linprog_error_with_detail(
304    error: &'static BuiltinErrorDescriptor,
305    detail: impl AsRef<str>,
306) -> RuntimeError {
307    let detail = detail.as_ref();
308    let message = if detail.starts_with("linprog:") {
309        detail.to_string()
310    } else {
311        format!("{}: {detail}", error.message)
312    };
313    let mut builder = build_runtime_error(message).with_builtin(NAME);
314    if let Some(identifier) = error.identifier {
315        builder = builder.with_identifier(identifier);
316    }
317    builder.build()
318}
319
320#[runmat_macros::register_gpu_spec(builtin_path = "crate::builtins::math::optim::linprog")]
321pub const GPU_SPEC: BuiltinGpuSpec = BuiltinGpuSpec {
322    name: "linprog",
323    op_kind: GpuOpKind::Custom("linear-programming"),
324    supported_precisions: &[],
325    broadcast: BroadcastSemantics::None,
326    provider_hooks: &[],
327    constant_strategy: ConstantStrategy::InlineLiteral,
328    residency: ResidencyPolicy::GatherImmediately,
329    nan_mode: ReductionNaN::Include,
330    two_pass_threshold: None,
331    workgroup_size: None,
332    accepts_nan_mode: false,
333    notes: "Host active-set LP solver. GPU-resident numeric inputs are gathered before solving.",
334};
335
336#[runmat_macros::register_fusion_spec(builtin_path = "crate::builtins::math::optim::linprog")]
337pub const FUSION_SPEC: BuiltinFusionSpec = BuiltinFusionSpec {
338    name: "linprog",
339    shape: ShapeRequirements::Any,
340    constant_strategy: ConstantStrategy::InlineLiteral,
341    elementwise: None,
342    reduction: None,
343    emits_nan: false,
344    notes: "Linear programming is a solver boundary and terminates fusion planning.",
345};
346
347#[runtime_builtin(
348    name = "linprog",
349    category = "math/optim",
350    summary = "Solve a linear programming minimization problem with linear constraints and bounds.",
351    keywords = "linprog,linear programming,optimization,linear constraints,bounds",
352    accel = "sink",
353    type_resolver(linear_programming_type),
354    descriptor(crate::builtins::math::optim::linprog::LINPROG_DESCRIPTOR),
355    builtin_path = "crate::builtins::math::optim::linprog"
356)]
357async fn linprog_builtin(f: Value, a: Value, b: Value, rest: Vec<Value>) -> BuiltinResult<Value> {
358    if rest.len() > 4 {
359        return Err(linprog_error_with_detail(
360            &LINPROG_ERROR_INVALID_ARGUMENT,
361            "too many input arguments",
362        ));
363    }
364
365    let f = numeric_vector("f", f, FiniteMode::Finite).await?;
366    if f.is_empty() {
367        return Err(linprog_error_with_detail(
368            &LINPROG_ERROR_INVALID_INPUT,
369            "f must be a nonempty numeric vector",
370        ));
371    }
372    let n = f.len();
373
374    let (mut a_ineq, mut b_ineq) = parse_constraint_pair("A", a, "b", b, n).await?;
375    let (a_eq, b_eq) = parse_optional_equality(rest.first(), rest.get(1), n).await?;
376    let (lb, ub) = parse_bounds(rest.get(2), rest.get(3), n).await?;
377
378    for i in 0..n {
379        if lb[i] > ub[i] + TOL {
380            return Ok(finalize(LinprogOutcome::infeasible(
381                "No feasible point found: lower bound exceeds upper bound.",
382            )));
383        }
384        if lb[i].is_finite() {
385            let mut row = vec![0.0; n];
386            row[i] = -1.0;
387            a_ineq.push(row);
388            b_ineq.push(-lb[i]);
389        }
390        if ub[i].is_finite() {
391            let mut row = vec![0.0; n];
392            row[i] = 1.0;
393            a_ineq.push(row);
394            b_ineq.push(ub[i]);
395        }
396    }
397
398    let problem = LinearProgram {
399        f,
400        a_ineq,
401        b_ineq,
402        a_eq,
403        b_eq,
404    };
405    Ok(finalize(solve_linprog(&problem)))
406}
407
/// Selects which non-finite values `validate_numbers` rejects.
#[derive(Clone, Copy)]
enum FiniteMode {
    /// Every value must be finite: NaN and ±Inf are both rejected.
    Finite,
    /// Only NaN is rejected. ±Inf is allowed because bounds use it to mean
    /// "no bound on this variable".
    Bounds,
}
413
/// Dense real matrix in column-major storage: element `(r, c)` lives at
/// `data[r + c * rows]`.
#[derive(Clone)]
struct MatrixInput {
    rows: usize,
    cols: usize,
    data: Vec<f64>,
}

impl MatrixInput {
    /// Copies row `row` out of the column-major storage, one entry per column.
    fn row(&self, row: usize) -> Vec<f64> {
        let stride = self.rows;
        (0..self.cols)
            .map(|col| col * stride + row)
            .map(|index| self.data[index])
            .collect()
    }
}
428
429async fn gather(value: Value) -> BuiltinResult<Value> {
430    crate::dispatcher::gather_if_needed_async(&value)
431        .await
432        .map_err(|err| linprog_error_with_detail(&LINPROG_ERROR_INVALID_INPUT, err.message()))
433}
434
435fn is_empty_value(value: &Value) -> bool {
436    matches!(value, Value::Tensor(t) if t.data.is_empty())
437}
438
439async fn numeric_vector(
440    label: &str,
441    value: Value,
442    finite_mode: FiniteMode,
443) -> BuiltinResult<Vec<f64>> {
444    let value = gather(value).await?;
445    if is_empty_value(&value) {
446        return Ok(Vec::new());
447    }
448    let data = match value {
449        Value::Num(n) => vec![n],
450        Value::Int(i) => vec![i.to_f64()],
451        Value::Tensor(t) => {
452            let dims = t.shape.len();
453            if dims > 2 || (t.rows() != 1 && t.cols() != 1) {
454                return Err(linprog_error_with_detail(
455                    &LINPROG_ERROR_INVALID_INPUT,
456                    format!("{label} must be a vector"),
457                ));
458            }
459            t.data
460        }
461        other => {
462            return Err(linprog_error_with_detail(
463                &LINPROG_ERROR_INVALID_INPUT,
464                format!("{label} must be a real numeric vector, got {other:?}"),
465            ))
466        }
467    };
468    validate_numbers(label, &data, finite_mode)?;
469    Ok(data)
470}
471
472async fn numeric_matrix(label: &str, value: Value) -> BuiltinResult<Option<MatrixInput>> {
473    let value = gather(value).await?;
474    if is_empty_value(&value) {
475        return Ok(None);
476    }
477    match value {
478        Value::Num(n) => {
479            validate_numbers(label, &[n], FiniteMode::Finite)?;
480            Ok(Some(MatrixInput {
481                rows: 1,
482                cols: 1,
483                data: vec![n],
484            }))
485        }
486        Value::Int(i) => {
487            let value = i.to_f64();
488            validate_numbers(label, &[value], FiniteMode::Finite)?;
489            Ok(Some(MatrixInput {
490                rows: 1,
491                cols: 1,
492                data: vec![value],
493            }))
494        }
495        Value::Tensor(t) => {
496            if t.shape.len() > 2 {
497                return Err(linprog_error_with_detail(
498                    &LINPROG_ERROR_INVALID_INPUT,
499                    format!("{label} must be a numeric matrix"),
500                ));
501            }
502            validate_numbers(label, &t.data, FiniteMode::Finite)?;
503            Ok(Some(MatrixInput {
504                rows: t.rows(),
505                cols: t.cols(),
506                data: t.data,
507            }))
508        }
509        other => Err(linprog_error_with_detail(
510            &LINPROG_ERROR_INVALID_INPUT,
511            format!("{label} must be a real numeric matrix, got {other:?}"),
512        )),
513    }
514}
515
516fn validate_numbers(label: &str, data: &[f64], finite_mode: FiniteMode) -> BuiltinResult<()> {
517    for value in data {
518        match finite_mode {
519            FiniteMode::Finite if !value.is_finite() => {
520                return Err(linprog_error_with_detail(
521                    &LINPROG_ERROR_INVALID_INPUT,
522                    format!("{label} values must be finite"),
523                ))
524            }
525            FiniteMode::Bounds if value.is_nan() => {
526                return Err(linprog_error_with_detail(
527                    &LINPROG_ERROR_INVALID_INPUT,
528                    format!("{label} bounds cannot be NaN"),
529                ))
530            }
531            _ => {}
532        }
533    }
534    Ok(())
535}
536
537async fn parse_constraint_pair(
538    matrix_label: &str,
539    matrix: Value,
540    rhs_label: &str,
541    rhs: Value,
542    n: usize,
543) -> BuiltinResult<(Vec<Vec<f64>>, Vec<f64>)> {
544    let matrix = numeric_matrix(matrix_label, matrix).await?;
545    let rhs = numeric_vector(rhs_label, rhs, FiniteMode::Finite).await?;
546    match (matrix, rhs.is_empty()) {
547        (None, true) => Ok((Vec::new(), Vec::new())),
548        (None, false) => Err(linprog_error_with_detail(
549            &LINPROG_ERROR_INVALID_INPUT,
550            format!("{matrix_label} cannot be empty when {rhs_label} is nonempty"),
551        )),
552        (Some(matrix), _) => {
553            if matrix.cols != n {
554                return Err(linprog_error_with_detail(
555                    &LINPROG_ERROR_INVALID_INPUT,
556                    format!("{matrix_label} must have one column per element of f"),
557                ));
558            }
559            if rhs.len() != matrix.rows {
560                return Err(linprog_error_with_detail(
561                    &LINPROG_ERROR_INVALID_INPUT,
562                    format!("{rhs_label} length must match rows of {matrix_label}"),
563                ));
564            }
565            let rows = (0..matrix.rows).map(|row| matrix.row(row)).collect();
566            Ok((rows, rhs))
567        }
568    }
569}
570
571async fn parse_optional_equality(
572    aeq: Option<&Value>,
573    beq: Option<&Value>,
574    n: usize,
575) -> BuiltinResult<(Vec<Vec<f64>>, Vec<f64>)> {
576    match (aeq, beq) {
577        (None, None) => Ok((Vec::new(), Vec::new())),
578        (Some(aeq), None) if is_empty_value(aeq) => Ok((Vec::new(), Vec::new())),
579        (Some(_), None) => Err(linprog_error_with_detail(
580            &LINPROG_ERROR_INVALID_ARGUMENT,
581            "Aeq requires a matching beq argument",
582        )),
583        (None, Some(_)) => Err(linprog_error_with_detail(
584            &LINPROG_ERROR_INVALID_ARGUMENT,
585            "beq requires a matching Aeq argument",
586        )),
587        (Some(aeq), Some(beq)) => {
588            parse_constraint_pair("Aeq", aeq.clone(), "beq", beq.clone(), n).await
589        }
590    }
591}
592
593async fn parse_bounds(
594    lb: Option<&Value>,
595    ub: Option<&Value>,
596    n: usize,
597) -> BuiltinResult<(Vec<f64>, Vec<f64>)> {
598    let lb = match lb {
599        None => vec![f64::NEG_INFINITY; n],
600        Some(value) if is_empty_value(value) => vec![f64::NEG_INFINITY; n],
601        Some(value) => {
602            let values = numeric_vector("lb", value.clone(), FiniteMode::Bounds).await?;
603            normalize_bound("lb", values, n)?
604        }
605    };
606    let ub = match ub {
607        None => vec![f64::INFINITY; n],
608        Some(value) if is_empty_value(value) => vec![f64::INFINITY; n],
609        Some(value) => {
610            let values = numeric_vector("ub", value.clone(), FiniteMode::Bounds).await?;
611            normalize_bound("ub", values, n)?
612        }
613    };
614    Ok((lb, ub))
615}
616
617fn normalize_bound(label: &str, values: Vec<f64>, n: usize) -> BuiltinResult<Vec<f64>> {
618    if values.len() == n {
619        Ok(values)
620    } else {
621        Err(linprog_error_with_detail(
622            &LINPROG_ERROR_INVALID_INPUT,
623            format!("{label} length must match f"),
624        ))
625    }
626}
627
/// Dense LP: minimize `f·x` subject to `a_ineq·x <= b_ineq` and
/// `a_eq·x == b_eq`. Constraint matrices are stored row by row.
struct LinearProgram {
    /// Objective coefficients, one per decision variable.
    f: Vec<f64>,
    /// Inequality rows; row `i` pairs with `b_ineq[i]`.
    a_ineq: Vec<Vec<f64>>,
    /// Inequality right-hand sides.
    b_ineq: Vec<f64>,
    /// Equality rows; row `i` pairs with `b_eq[i]`.
    a_eq: Vec<Vec<f64>>,
    /// Equality right-hand sides.
    b_eq: Vec<f64>,
}
635
/// Result of a solve: optional solution and objective value, status code,
/// iteration count, maximum constraint violation and a status message.
#[derive(Clone)]
struct LinprogOutcome {
    x: Option<Vec<f64>>,
    fval: Option<f64>,
    exitflag: i32,
    iterations: usize,
    constrviolation: f64,
    message: String,
}

impl LinprogOutcome {
    /// No feasible point: exitflag -2, no solution, zero iterations.
    fn infeasible(message: &str) -> Self {
        Self::without_solution(-2, 0, message.to_string())
    }

    /// Objective unbounded below: exitflag -3 after `iterations` candidates.
    fn unbounded(iterations: usize) -> Self {
        Self::without_solution(-3, iterations, "Problem is unbounded.".to_string())
    }

    /// Shared constructor for outcomes that carry no solution vector.
    fn without_solution(exitflag: i32, iterations: usize, message: String) -> Self {
        Self {
            x: None,
            fval: None,
            exitflag,
            iterations,
            constrviolation: 0.0,
            message,
        }
    }
}
669
670fn solve_linprog(problem: &LinearProgram) -> LinprogOutcome {
671    let n = problem.f.len();
672    let Some(face) = equality_face(problem, n) else {
673        return LinprogOutcome::infeasible("No feasible point found.");
674    };
675    let reduced = reduce_to_equality_face(problem, &face);
676    let k = reduced.f.len();
677    let mut candidates = Vec::new();
678    let mut combinations = 0usize;
679
680    enumerate_vertices(&reduced, |y| {
681        combinations += 1;
682        if is_feasible(&reduced, &y) {
683            candidates.push(y);
684        }
685    });
686
687    let feasible_fallback = if candidates.is_empty() {
688        let y0 = vec![0.0; k];
689        is_feasible(&reduced, &y0).then_some(y0)
690    } else {
691        None
692    };
693    let has_feasible_point = !candidates.is_empty() || feasible_fallback.is_some();
694    if !has_feasible_point {
695        return LinprogOutcome::infeasible("No feasible point found.");
696    }
697    if has_unbounded_descent_direction(&reduced) {
698        return LinprogOutcome::unbounded(combinations);
699    }
700
701    if let Some(x) = feasible_fallback {
702        candidates.push(x);
703    }
704
705    let mut best_y = candidates[0].clone();
706    let mut best_fval = dot(&reduced.f, &best_y);
707    for candidate in candidates.into_iter().skip(1) {
708        let fval = dot(&reduced.f, &candidate);
709        if fval < best_fval - TOL {
710            best_y = candidate;
711            best_fval = fval;
712        }
713    }
714
715    let best = lift_from_equality_face(&face, &best_y);
716    let best_fval = dot(&problem.f, &best);
717    let constrviolation = constraint_violation(problem, &best);
718    LinprogOutcome {
719        x: Some(best),
720        fval: Some(best_fval),
721        exitflag: 1,
722        iterations: combinations,
723        constrviolation,
724        message: "Optimal solution found.".to_string(),
725    }
726}
727
/// Parameterisation of the equality-constrained set as
/// `x = x0 + Σ y_i · basis[i]`.
struct EqualityFace {
    /// A particular solution of `a_eq·x = b_eq` (all zeros when there are no
    /// equality rows).
    x0: Vec<f64>,
    /// Unit-length vectors spanning the null space of `a_eq`.
    basis: Vec<Vec<f64>>,
}
732
733fn equality_face(problem: &LinearProgram, n: usize) -> Option<EqualityFace> {
734    let x0 = if problem.a_eq.is_empty() {
735        vec![0.0; n]
736    } else {
737        pseudo_solve(&problem.a_eq, &problem.b_eq, n)?
738    };
739    Some(EqualityFace {
740        x0,
741        basis: nullspace_basis(&problem.a_eq, n),
742    })
743}
744
745fn reduce_to_equality_face(problem: &LinearProgram, face: &EqualityFace) -> LinearProgram {
746    LinearProgram {
747        f: face
748            .basis
749            .iter()
750            .map(|basis_vector| dot(&problem.f, basis_vector))
751            .collect(),
752        a_ineq: problem
753            .a_ineq
754            .iter()
755            .map(|row| {
756                face.basis
757                    .iter()
758                    .map(|basis_vector| dot(row, basis_vector))
759                    .collect()
760            })
761            .collect(),
762        b_ineq: problem
763            .a_ineq
764            .iter()
765            .zip(&problem.b_ineq)
766            .map(|(row, rhs)| rhs - dot(row, &face.x0))
767            .collect(),
768        a_eq: Vec::new(),
769        b_eq: Vec::new(),
770    }
771}
772
773fn lift_from_equality_face(face: &EqualityFace, y: &[f64]) -> Vec<f64> {
774    let mut x = face.x0.clone();
775    for (coeff, basis_vector) in y.iter().zip(&face.basis) {
776        for (x_j, basis_j) in x.iter_mut().zip(basis_vector) {
777            *x_j += coeff * basis_j;
778        }
779    }
780    x
781}
782
783fn enumerate_vertices(problem: &LinearProgram, mut visit: impl FnMut(Vec<f64>)) {
784    let n = problem.f.len();
785    let max_active = problem.a_ineq.len().min(n);
786    for active_count in 0..=max_active {
787        enumerate_combinations(problem.a_ineq.len(), active_count, |active| {
788            let mut rows = problem.a_eq.clone();
789            let mut rhs = problem.b_eq.clone();
790            for &idx in active {
791                rows.push(problem.a_ineq[idx].clone());
792                rhs.push(problem.b_ineq[idx]);
793            }
794            if let Some(x) = pseudo_solve(&rows, &rhs, n) {
795                visit(x);
796            }
797        });
798    }
799}
800
801fn has_unbounded_descent_direction(problem: &LinearProgram) -> bool {
802    let n = problem.f.len();
803    let max_active = problem.a_ineq.len().min(n.saturating_sub(1));
804    for active_count in 0..=max_active {
805        let mut found = false;
806        enumerate_combinations(problem.a_ineq.len(), active_count, |active| {
807            if found {
808                return;
809            }
810            let mut rows = problem.a_eq.clone();
811            for &idx in active {
812                rows.push(problem.a_ineq[idx].clone());
813            }
814            for direction in candidate_nullspace_descent_directions(&rows, &problem.f, n) {
815                if is_recession_direction(problem, &direction) && dot(&problem.f, &direction) < -TOL
816                {
817                    found = true;
818                    return;
819                }
820            }
821        });
822        if found {
823            return true;
824        }
825    }
826    false
827}
828
829fn candidate_nullspace_descent_directions(rows: &[Vec<f64>], f: &[f64], n: usize) -> Vec<Vec<f64>> {
830    let basis = nullspace_basis(rows, n);
831    if basis.is_empty() {
832        return Vec::new();
833    }
834    let mut directions = Vec::new();
835    let mut projected = vec![0.0; n];
836    for basis_vector in &basis {
837        let coeff = -dot(f, basis_vector);
838        for i in 0..n {
839            projected[i] += coeff * basis_vector[i];
840        }
841        directions.push(basis_vector.clone());
842        directions.push(basis_vector.iter().map(|v| -*v).collect());
843    }
844    if norm(&projected) > TOL {
845        directions.push(projected);
846    }
847    directions
848}
849
850fn is_recession_direction(problem: &LinearProgram, direction: &[f64]) -> bool {
851    norm(direction) > TOL
852        && problem
853            .a_eq
854            .iter()
855            .all(|row| dot(row, direction).abs() <= TOL)
856        && problem.a_ineq.iter().all(|row| dot(row, direction) <= TOL)
857}
858
859fn is_feasible(problem: &LinearProgram, x: &[f64]) -> bool {
860    constraint_violation(problem, x) <= 1.0e-7
861}
862
863fn constraint_violation(problem: &LinearProgram, x: &[f64]) -> f64 {
864    let eq = problem
865        .a_eq
866        .iter()
867        .zip(&problem.b_eq)
868        .map(|(row, rhs)| (dot(row, x) - rhs).abs())
869        .fold(0.0, f64::max);
870    let ineq = problem
871        .a_ineq
872        .iter()
873        .zip(&problem.b_ineq)
874        .map(|(row, rhs)| (dot(row, x) - rhs).max(0.0))
875        .fold(0.0, f64::max);
876    eq.max(ineq)
877}
878
879fn nullspace_basis(rows: &[Vec<f64>], n: usize) -> Vec<Vec<f64>> {
880    if n == 0 {
881        return Vec::new();
882    }
883    if rows.is_empty() {
884        return (0..n)
885            .map(|i| {
886                let mut basis = vec![0.0; n];
887                basis[i] = 1.0;
888                basis
889            })
890            .collect();
891    }
892
893    let (reduced, pivots) = rref(rows, n);
894    let free_cols = (0..n)
895        .filter(|col| !pivots.contains(col))
896        .collect::<Vec<_>>();
897    free_cols
898        .into_iter()
899        .filter_map(|free_col| {
900            let mut basis = vec![0.0; n];
901            basis[free_col] = 1.0;
902            for (row, pivot_col) in pivots.iter().enumerate() {
903                basis[*pivot_col] = -reduced[row][free_col];
904            }
905            let length = norm(&basis);
906            (length > TOL).then(|| basis.into_iter().map(|value| value / length).collect())
907        })
908        .collect()
909}
910
911fn rref(rows: &[Vec<f64>], n: usize) -> (Vec<Vec<f64>>, Vec<usize>) {
912    let mut matrix = rows.to_vec();
913    let mut pivots = Vec::new();
914    let mut pivot_row = 0usize;
915
916    for col in 0..n {
917        let Some(best_row) = (pivot_row..matrix.len()).max_by(|&a, &b| {
918            matrix[a][col]
919                .abs()
920                .partial_cmp(&matrix[b][col].abs())
921                .unwrap_or(std::cmp::Ordering::Equal)
922        }) else {
923            break;
924        };
925        if matrix[best_row][col].abs() <= TOL {
926            continue;
927        }
928
929        matrix.swap(pivot_row, best_row);
930        let pivot = matrix[pivot_row][col];
931        for value in &mut matrix[pivot_row] {
932            *value /= pivot;
933        }
934
935        for row in 0..matrix.len() {
936            if row == pivot_row {
937                continue;
938            }
939            let factor = matrix[row][col];
940            if factor.abs() <= TOL {
941                continue;
942            }
943            for j in col..n {
944                matrix[row][j] -= factor * matrix[pivot_row][j];
945            }
946        }
947
948        pivots.push(col);
949        pivot_row += 1;
950        if pivot_row == matrix.len() {
951            break;
952        }
953    }
954
955    (matrix, pivots)
956}
957
958fn pseudo_solve(rows: &[Vec<f64>], rhs: &[f64], n: usize) -> Option<Vec<f64>> {
959    if rows.is_empty() {
960        return Some(vec![0.0; n]);
961    }
962    let matrix = dmatrix_from_rows(rows, n);
963    let rhs_vec = DVector::from_column_slice(rhs);
964    let svd = matrix.svd(true, true);
965    let u = svd.u.as_ref()?;
966    let v_t = svd.v_t.as_ref()?;
967    let mut x = vec![0.0; n];
968    for (i, sigma) in svd.singular_values.iter().enumerate() {
969        if *sigma <= TOL {
970            continue;
971        }
972        let coeff = (0..rows.len())
973            .map(|row| u[(row, i)] * rhs_vec[row])
974            .sum::<f64>()
975            / sigma;
976        for col in 0..n {
977            x[col] += v_t[(i, col)] * coeff;
978        }
979    }
980    let residual = rows
981        .iter()
982        .zip(rhs)
983        .map(|(row, target)| (dot(row, &x) - target).abs())
984        .fold(0.0, f64::max);
985    (residual <= 1.0e-7).then_some(x)
986}
987
988fn dmatrix_from_rows(rows: &[Vec<f64>], n: usize) -> DMatrix<f64> {
989    let data = rows
990        .iter()
991        .flat_map(|row| row.iter().copied())
992        .collect::<Vec<_>>();
993    DMatrix::from_row_slice(rows.len(), n, &data)
994}
995
/// Calls `visit` once for every `choose`-element subset of `0..len`, in
/// lexicographic order, passing the subset as a strictly increasing index slice.
///
/// `choose == 0` yields a single empty subset. `choose > len` yields nothing.
fn enumerate_combinations(len: usize, choose: usize, mut visit: impl FnMut(&[usize])) {
    if choose > len {
        return;
    }

    // Lexicographically smallest subset: {0, 1, ..., choose - 1}.
    let mut combo: Vec<usize> = (0..choose).collect();
    loop {
        visit(combo.as_slice());

        // Slot i can hold at most len - choose + i. Advance the rightmost slot
        // that has not reached its ceiling; if none can, we are done.
        let Some(slot) = (0..choose).rev().find(|&i| combo[i] < len - choose + i) else {
            return;
        };
        combo[slot] += 1;

        // Reset everything to the right to the smallest increasing tail.
        let mut next = combo[slot];
        for entry in &mut combo[slot + 1..] {
            next += 1;
            *entry = next;
        }
    }
}
1022
/// Inner product of `a` and `b`. If the lengths differ, the extra trailing
/// entries of the longer slice are ignored.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    let products = a.iter().zip(b.iter()).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f64>()
}
1026
1027fn norm(values: &[f64]) -> f64 {
1028    dot(values, values).sqrt()
1029}
1030
1031fn finalize(outcome: LinprogOutcome) -> Value {
1032    let x = outcome
1033        .x
1034        .clone()
1035        .map(vector_value)
1036        .unwrap_or_else(empty_double);
1037    let fval = outcome.fval.map(Value::Num).unwrap_or_else(empty_double);
1038    let exitflag = Value::Num(outcome.exitflag as f64);
1039    let output = Value::Struct(build_output_struct(&outcome));
1040
1041    match crate::output_count::current_output_count() {
1042        None => x,
1043        Some(0) => Value::OutputList(Vec::new()),
1044        Some(1) => crate::output_count::output_list_with_padding(1, vec![x]),
1045        Some(2) => crate::output_count::output_list_with_padding(2, vec![x, fval]),
1046        Some(3) => crate::output_count::output_list_with_padding(3, vec![x, fval, exitflag]),
1047        Some(n) if n >= 4 => {
1048            crate::output_count::output_list_with_padding(n, vec![x, fval, exitflag, output])
1049        }
1050        Some(_) => x,
1051    }
1052}
1053
1054fn vector_value(values: Vec<f64>) -> Value {
1055    let n = values.len();
1056    Tensor::new(values, vec![n, 1])
1057        .map(Value::Tensor)
1058        .unwrap_or_else(|_| empty_double())
1059}
1060
1061fn empty_double() -> Value {
1062    Value::Tensor(Tensor::zeros(vec![0, 0]))
1063}
1064
1065fn build_output_struct(outcome: &LinprogOutcome) -> StructValue {
1066    let mut fields = StructValue::new();
1067    fields.insert("iterations", Value::Num(outcome.iterations as f64));
1068    fields.insert("algorithm", Value::from(ALGORITHM));
1069    fields.insert("constrviolation", Value::Num(outcome.constrviolation));
1070    fields.insert("message", Value::from(outcome.message.clone()));
1071    fields
1072}
1073
#[cfg(test)]
mod tests {
    use super::*;
    use futures::executor::block_on;
    use runmat_builtins::Value as V;

    fn tensor(data: Vec<f64>, rows: usize, cols: usize) -> V {
        V::Tensor(Tensor::new(data, vec![rows, cols]).unwrap())
    }

    fn empty() -> V {
        V::Tensor(Tensor::zeros(vec![0, 0]))
    }

    fn run(f: V, a: V, b: V, rest: Vec<V>, outputs: usize) -> Vec<V> {
        let _guard = crate::output_count::push_output_count(Some(outputs));
        let value = block_on(linprog_builtin(f, a, b, rest)).expect("linprog");
        match value {
            V::OutputList(values) => values,
            other => vec![other],
        }
    }

    #[test]
    fn solves_bounded_feasible_problem() {
        let outputs = run(
            tensor(vec![-1.0, -2.0], 2, 1),
            tensor(vec![1.0, 1.0], 1, 2),
            V::Num(4.0),
            vec![empty(), empty(), tensor(vec![0.0, 0.0], 2, 1), empty()],
            3,
        );
        match &outputs[0] {
            V::Tensor(x) => {
                assert!((x.data[0] - 0.0).abs() < 1.0e-7, "{x:?}");
                assert!((x.data[1] - 4.0).abs() < 1.0e-7, "{x:?}");
            }
            other => panic!("unexpected x {other:?}"),
        }
        assert!(matches!(&outputs[1], V::Num(fval) if (*fval + 8.0).abs() < 1.0e-7));
        assert!(matches!(&outputs[2], V::Num(flag) if *flag == 1.0));
    }

    #[test]
    fn solves_equality_constrained_problem() {
        let outputs = run(
            tensor(vec![1.0, 2.0], 2, 1),
            empty(),
            empty(),
            vec![
                tensor(vec![1.0, 1.0], 1, 2),
                V::Num(3.0),
                tensor(vec![1.0, 0.0], 2, 1),
                empty(),
            ],
            2,
        );
        match &outputs[0] {
            V::Tensor(x) => {
                assert!((x.data[0] - 3.0).abs() < 1.0e-7, "{x:?}");
                assert!((x.data[1] - 0.0).abs() < 1.0e-7, "{x:?}");
            }
            other => panic!("unexpected x {other:?}"),
        }
        assert!(matches!(&outputs[1], V::Num(fval) if (*fval - 3.0).abs() < 1.0e-7));
    }

    #[test]
    fn reports_infeasible_bounds() {
        let outputs = run(
            V::Num(1.0),
            empty(),
            empty(),
            vec![empty(), empty(), V::Num(2.0), V::Num(1.0)],
            4,
        );
        assert!(matches!(&outputs[0], V::Tensor(t) if t.data.is_empty()));
        assert!(matches!(&outputs[1], V::Tensor(t) if t.data.is_empty()));
        assert!(matches!(&outputs[2], V::Num(flag) if *flag == -2.0));
        assert!(matches!(&outputs[3], V::Struct(s) if s.fields.contains_key("message")));
    }

    #[test]
    fn reports_unbounded_problem() {
        let outputs = run(V::Num(-1.0), empty(), empty(), Vec::new(), 3);
        assert!(matches!(&outputs[0], V::Tensor(t) if t.data.is_empty()));
        assert!(matches!(&outputs[1], V::Tensor(t) if t.data.is_empty()));
        assert!(matches!(&outputs[2], V::Num(flag) if *flag == -3.0));
    }

    #[test]
    fn accepts_empty_optional_placeholders() {
        let outputs = run(
            tensor(vec![1.0, 1.0], 2, 1),
            empty(),
            empty(),
            vec![empty(), empty(), tensor(vec![2.0, 3.0], 2, 1), empty()],
            2,
        );
        match &outputs[0] {
            V::Tensor(x) => {
                assert!((x.data[0] - 2.0).abs() < 1.0e-7, "{x:?}");
                assert!((x.data[1] - 3.0).abs() < 1.0e-7, "{x:?}");
            }
            other => panic!("unexpected x {other:?}"),
        }
        assert!(matches!(&outputs[1], V::Num(fval) if (*fval - 5.0).abs() < 1.0e-7));
    }

    #[test]
    fn solves_one_sided_bound_with_fewer_rows_than_variables() {
        let outputs = run(
            tensor(vec![1.0, 0.0], 2, 1),
            empty(),
            empty(),
            vec![
                empty(),
                empty(),
                tensor(vec![2.0, f64::NEG_INFINITY], 2, 1),
                empty(),
            ],
            3,
        );
        match &outputs[0] {
            V::Tensor(x) => {
                assert!((x.data[0] - 2.0).abs() < 1.0e-7, "{x:?}");
                assert!(x.data[1].abs() < 1.0e-7, "{x:?}");
            }
            other => panic!("unexpected x {other:?}"),
        }
        assert!(matches!(&outputs[1], V::Num(fval) if (*fval - 2.0).abs() < 1.0e-7));
        assert!(matches!(&outputs[2], V::Num(flag) if *flag == 1.0));
    }

    #[test]
    fn optimizes_along_equality_face_when_particular_solution_is_suboptimal() {
        let outputs = run(
            tensor(vec![-1.0, 0.0, 0.0], 3, 1),
            tensor(vec![1.0, 0.0, 0.0], 1, 3),
            V::Num(1.0),
            vec![
                tensor(vec![0.0, 0.0, 1.0], 1, 3),
                V::Num(0.0),
                empty(),
                empty(),
            ],
            3,
        );
        match &outputs[0] {
            V::Tensor(x) => {
                assert!((x.data[0] - 1.0).abs() < 1.0e-7, "{x:?}");
                assert!(x.data[1].abs() < 1.0e-7, "{x:?}");
                assert!(x.data[2].abs() < 1.0e-7, "{x:?}");
            }
            other => panic!("unexpected x {other:?}"),
        }
        assert!(matches!(&outputs[1], V::Num(fval) if (*fval + 1.0).abs() < 1.0e-7));
        assert!(matches!(&outputs[2], V::Num(flag) if *flag == 1.0));
    }

    #[test]
    fn validates_matrix_dimensions() {
        let err = block_on(linprog_builtin(
            tensor(vec![1.0, 1.0], 2, 1),
            tensor(vec![1.0, 1.0, 1.0], 1, 3),
            V::Num(1.0),
            Vec::new(),
        ))
        .unwrap_err();
        assert_eq!(err.identifier(), Some("RunMat:linprog:InvalidInput"));
    }

    // --- Direct coverage for the numeric helpers the solver is built on. ---

    #[test]
    fn enumerate_combinations_visits_subsets_in_lexicographic_order() {
        let mut seen: Vec<Vec<usize>> = Vec::new();
        enumerate_combinations(4, 2, |combo| seen.push(combo.to_vec()));
        assert_eq!(
            seen,
            vec![vec![0, 1], vec![0, 2], vec![0, 3], vec![1, 2], vec![1, 3], vec![2, 3]]
        );
    }

    #[test]
    fn enumerate_combinations_handles_degenerate_sizes() {
        let mut empty_visits = 0;
        enumerate_combinations(3, 0, |combo| {
            assert!(combo.is_empty());
            empty_visits += 1;
        });
        assert_eq!(empty_visits, 1);

        let mut oversized_visits = 0;
        enumerate_combinations(2, 3, |_| oversized_visits += 1);
        assert_eq!(oversized_visits, 0);
    }

    #[test]
    fn rref_reports_pivots_for_full_and_deficient_rank() {
        let (reduced, pivots) = rref(&[vec![2.0, 4.0], vec![1.0, 3.0]], 2);
        assert_eq!(pivots, vec![0, 1]);
        assert_eq!(reduced, vec![vec![1.0, 0.0], vec![0.0, 1.0]]);

        let (reduced, pivots) = rref(&[vec![1.0, 2.0], vec![2.0, 4.0]], 2);
        assert_eq!(pivots, vec![0]);
        assert_eq!(reduced, vec![vec![1.0, 2.0], vec![0.0, 0.0]]);
    }

    #[test]
    fn nullspace_basis_returns_unit_vectors_annihilated_by_rows() {
        let rows = vec![vec![1.0, 2.0, 0.0]];
        let basis = nullspace_basis(&rows, 3);
        assert_eq!(basis.len(), 2);
        for vector in &basis {
            assert!(dot(&rows[0], vector).abs() < 1.0e-12, "{vector:?}");
            assert!((norm(vector) - 1.0).abs() < 1.0e-12, "{vector:?}");
        }
        assert_eq!(
            nullspace_basis(&[], 2),
            vec![vec![1.0, 0.0], vec![0.0, 1.0]]
        );
        assert!(nullspace_basis(&rows, 0).is_empty());
    }

    #[test]
    fn pseudo_solve_returns_minimum_norm_solution() {
        let x = pseudo_solve(&[vec![1.0, 1.0]], &[2.0], 2).expect("consistent system");
        assert!((x[0] - 1.0).abs() < 1.0e-10, "{x:?}");
        assert!((x[1] - 1.0).abs() < 1.0e-10, "{x:?}");
        assert_eq!(pseudo_solve(&[], &[], 3), Some(vec![0.0; 3]));
    }

    #[test]
    fn pseudo_solve_rejects_inconsistent_system() {
        // x0 = 0 and x0 = 1 cannot both hold; the least-squares residual is 0.5.
        assert!(pseudo_solve(&[vec![1.0, 0.0], vec![1.0, 0.0]], &[0.0, 1.0], 2).is_none());
    }
}