Skip to main content

cnvx_core/
expr.rs

1use crate::{Constraint, VarId};
2use std::{
3    fmt::Display,
4    ops::{Add, AddAssign},
5};
6
7/// A single term in a linear expression: `coeff * var`.
8#[derive(Clone, Debug)]
9pub struct LinTerm {
10    /// The variable involved in this term.
11    pub var: VarId,
12    /// The coefficient for the variable.
13    pub coeff: f64,
14}
15
16/// Represents a linear expression of the form `a1*x1 + a2*x2 + ... + c`.
17#[derive(Clone, Debug, Default)]
18pub struct LinExpr {
19    /// All variable terms in the expression.
20    pub terms: Vec<LinTerm>,
21    /// Constant term in the expression.
22    pub constant: f64,
23}
24
25// TODO: Currently LinExpr only implements addition, but we want support for subtraction and negation.
26// This will later likely pivot to a more general `Expr` type for non-linear support.
27
28impl LinExpr {
29    /// Creates a new linear expression from a single variable and coefficient.
30    ///
31    /// # Example
32    ///
33    /// ```rust
34    /// # use cnvx_core::{LinExpr, VarId};
35    /// let x = VarId(0);
36    /// let expr = LinExpr::new(x, 3.0); // 3*VarId(0)
37    /// ```
38    pub fn new(var: VarId, coeff: f64) -> Self {
39        Self { terms: vec![LinTerm { var, coeff }], constant: 0.0 }
40    }
41
42    /// Creates a constant-only linear expression.
43    ///
44    /// # Example
45    ///
46    /// ```rust
47    /// # use cnvx_core::LinExpr;
48    /// let expr = LinExpr::constant(5.0); // 5
49    /// ```
50    pub fn constant(c: f64) -> Self {
51        Self { terms: vec![], constant: c }
52    }
53
54    /// Creates a `<=` constraint from this linear expression.
55    pub fn leq(self, rhs: f64) -> Constraint {
56        Constraint::leq(self, rhs)
57    }
58
59    /// Creates a `>=` constraint from this linear expression.
60    pub fn geq(self, rhs: f64) -> Constraint {
61        Constraint::geq(self, rhs)
62    }
63
64    /// Creates a `==` constraint from this linear expression.
65    pub fn eq(self, rhs: f64) -> Constraint {
66        Constraint::eq(self, rhs)
67    }
68}
69
70impl Display for LinExpr {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        let mut parts = Vec::new();
73        for term in &self.terms {
74            parts.push(format!("{}*VarId({})", term.coeff, term.var.0));
75        }
76        if self.constant != 0.0 || parts.is_empty() {
77            parts.push(self.constant.to_string());
78        }
79        write!(f, "{}", parts.join(" + "))
80    }
81}
82
83/////////////////////////////////////////////////////////////////////////////
84// Operator Overloads for LinExpr
85/////////////////////////////////////////////////////////////////////////////
86
87/// LinExpr + LinExpr
88impl Add for LinExpr {
89    type Output = LinExpr;
90
91    fn add(self, rhs: LinExpr) -> LinExpr {
92        let mut terms = self.terms;
93        terms.extend(rhs.terms);
94        LinExpr { terms, constant: self.constant + rhs.constant }
95    }
96}
97
98/// LinExpr + VarId
99impl Add<VarId> for LinExpr {
100    type Output = LinExpr;
101
102    fn add(mut self, rhs: VarId) -> LinExpr {
103        self.terms.push(LinTerm { var: rhs, coeff: 1.0 });
104        self
105    }
106}
107
108/// VarId + LinExpr
109impl Add<LinExpr> for VarId {
110    type Output = LinExpr;
111
112    fn add(self, rhs: LinExpr) -> LinExpr {
113        let mut terms = vec![LinTerm { var: self, coeff: 1.0 }];
114        terms.extend(rhs.terms);
115        LinExpr { terms, constant: rhs.constant }
116    }
117}
118
119/// VarId + VarId
120impl Add for VarId {
121    type Output = LinExpr;
122
123    fn add(self, rhs: VarId) -> LinExpr {
124        LinExpr {
125            terms: vec![
126                LinTerm { var: self, coeff: 1.0 },
127                LinTerm { var: rhs, coeff: 1.0 },
128            ],
129            constant: 0.0,
130        }
131    }
132}
133
134/// LinExpr += LinExpr
135impl AddAssign for LinExpr {
136    fn add_assign(&mut self, rhs: LinExpr) {
137        self.terms.extend(rhs.terms);
138        self.constant += rhs.constant;
139    }
140}
141
142/// LinExpr += VarId
143impl AddAssign<VarId> for LinExpr {
144    fn add_assign(&mut self, rhs: VarId) {
145        self.terms.push(LinTerm { var: rhs, coeff: 1.0 });
146    }
147}
148
149/// f64 + LinExpr
150impl Add<LinExpr> for f64 {
151    type Output = LinExpr;
152
153    fn add(self, rhs: LinExpr) -> LinExpr {
154        let mut expr = rhs.clone();
155        expr.constant += self;
156        expr
157    }
158}
159
160/// LinExpr + f64
161impl Add<f64> for LinExpr {
162    type Output = LinExpr;
163
164    fn add(mut self, rhs: f64) -> LinExpr {
165        self.constant += rhs;
166        self
167    }
168}
169
170/// Allows converting a single variable into a linear expression with coefficient 1.0.
171impl From<VarId> for LinExpr {
172    fn from(var: VarId) -> Self {
173        LinExpr::new(var, 1.0)
174    }
175}