symrs/expr/eq.rs

1use std::fmt;
2
3use itertools::Itertools;
4use log::debug;
5use schemars::{JsonSchema, json_schema};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use super::{ops::ParseExprError, *};
10
11#[derive(Clone)]
12pub struct Equation {
13    pub lhs: Box<dyn Expr>,
14    pub rhs: Box<dyn Expr>,
15}
16
17impl fmt::Display for Equation {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        write!(f, "{} = {}", self.lhs.str(), self.rhs.str())
20    }
21}
22
23impl Serialize for Equation {
24    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25    where
26        S: serde::Serializer,
27    {
28        serializer.serialize_str(&self.str())
29    }
30}
31
/// Serde visitor that turns a string into an [`Equation`] via `Equation::from_str`.
struct ExprDeserializeVisitor;
33
34#[derive(Error, Debug)]
35pub enum Error {
36    #[error("{0}")]
37    Message(String),
38    #[error("parsed expression is not an equation")]
39    NotAnEquation,
40    #[error("{0}")]
41    FailedParsing(#[from] ParseExprError),
42}
43
44impl<'de> serde::de::Visitor<'de> for ExprDeserializeVisitor {
45    type Value = Equation;
46
47    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
48        formatter.write_str("a properly written equation")
49    }
50
51    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
52    where
53        E: serde::de::Error,
54    {
55        Equation::from_str(v).map_err(|e| E::custom(e.to_string()))
56    }
57}
58
59impl<'de> Deserialize<'de> for Equation {
60    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
61    where
62        D: serde::Deserializer<'de>,
63    {
64        deserializer.deserialize_str(ExprDeserializeVisitor)
65    }
66}
67
68impl fmt::Debug for Equation {
69    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70        write!(f, "{}\n{}", self.str(), self.srepr())
71        // write!(f, "{}", self.str())
72    }
73}
74
75impl Equation {
76    pub fn into_new(lhs: &Box<dyn Expr>, rhs: &Box<dyn Expr>) -> Equation {
77        Equation {
78            lhs: lhs.clone(),
79            rhs: rhs.clone(),
80        }
81    }
82
83    pub fn new(lhs: &dyn Expr, rhs: &dyn Expr) -> Equation {
84        Equation {
85            lhs: lhs.clone_box(),
86            rhs: rhs.clone_box(),
87        }
88    }
89
90    pub fn new_box(lhs: Box<dyn Expr>, rhs: Box<dyn Expr>) -> Box<dyn Expr> {
91        Box::new(Equation { lhs, rhs })
92    }
93
94    pub fn from_str(s: &str) -> Result<Equation, Error> {
95        ops::parse_expr(s)?.as_eq().ok_or(Error::NotAnEquation)
96    }
97}
98
99impl Expr for Equation {
100    fn name(&self) -> String {
101        "Eq".to_string()
102    }
103
104    fn get_ref<'a>(&'a self) -> &'a dyn Expr {
105        self as &dyn Expr
106    }
107    fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
108        f(&*self.lhs);
109        f(&*self.rhs);
110    }
111
112    fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
113        Box::new(Equation {
114            lhs: args[0].clone().into(),
115            rhs: args[1].clone().into(),
116        })
117    }
118
119    fn clone_box(&self) -> Box<dyn Expr> {
120        Box::new(self.clone())
121    }
122
123    fn str(&self) -> String {
124        format!("{} = {}", self.lhs.str(), self.rhs.str())
125    }
126}
127
128impl std::ops::SubAssign<&dyn Expr> for Equation {
129    fn sub_assign(&mut self, rhs: &dyn Expr) {
130        self.lhs -= rhs;
131        self.rhs -= rhs;
132    }
133}
134
135impl std::ops::DivAssign<&dyn Expr> for Equation {
136    fn div_assign(&mut self, rhs: &dyn Expr) {
137        self.lhs /= rhs;
138        self.rhs /= rhs;
139    }
140}
141
142impl std::cmp::PartialEq for Equation {
143    fn eq(&self, other: &Self) -> bool {
144        &self.lhs == &other.lhs && &self.rhs == &other.rhs
145    }
146}
147
148#[derive(Error, Debug)]
149#[error("failed to solve equation {equation} for unknowns {unknowns} : {reason}")]
150pub struct SolvingError {
151    unknowns: String,
152    equation: String,
153    reason: String,
154}
155
156impl Equation {
157    // - Expand all terms containing solved symbols
158    // - Move them to the left
159    // - Others to the right
160    // - Factorize
161    // - Divide by factor of seeked symbol
162    pub fn solve<'a, S: IntoIterator<Item = &'a dyn Expr>>(
163        &self,
164        exprs: S,
165    ) -> Result<Equation, SolvingError> {
166        let eq = (self.expand()).as_eq().expect("Should remain an eqation");
167
168        let symbols: Vec<_> = exprs.into_iter().collect();
169        debug!("solving equation {} for unknowns {:?}", self.str(), symbols);
170        debug!("expanded: {}", eq.str());
171
172        if symbols.is_empty() {
173            Err(SolvingError {
174                unknowns: "".to_string(),
175                equation: eq.str(),
176                reason: "no unknowns given".into(),
177            })?
178        }
179
180        let move_right: Vec<_> = eq
181            .lhs
182            .terms()
183            .filter(|e| symbols.iter().all(|s| !e.has(s.get_ref())))
184            .collect();
185        let move_left: Vec<_> = eq
186            .rhs
187            .terms()
188            .filter(|e| symbols.iter().any(|s| e.has(s.get_ref())))
189            .collect();
190
191        // x + y = 2x + 2y -> -x = y
192        // x + y = 2x + 2y -> -x = y
193
194        let mut res = eq.clone();
195        debug!("Equation: {}", res.str());
196        // dbg!(&res);
197        // dbg!(&move_right);
198        // dbg!(&move_left);
199
200        for t in move_right {
201            debug!("Moving {t} to the right");
202            res -= t;
203            debug!("Equation: {}", res.str());
204        }
205
206        for t in move_left {
207            debug!("Moving {t} to the left");
208            res -= t;
209        }
210
211        let (coeff, _) = (&res.lhs).get_coeff();
212
213        if coeff.is_zero() {
214            Err(SolvingError {
215                unknowns: symbols.iter().map(|e| e.str()).join(", "),
216                equation: res.str(),
217                reason: "failed to get unknown coefficient".into(),
218            })?
219        }
220
221        if !coeff.is_one() {
222            res /= coeff.get_ref();
223        }
224
225        let mut symbols_coeff = Vec::new();
226        if let KnownExpr::Mul(Mul { operands }) = KnownExpr::from_expr_box(&res.lhs) {
227            for op in operands {
228                if !symbols.iter().any(|s| op.has(s.get_ref())) {
229                    symbols_coeff.push(op.clone_box());
230                }
231            }
232        }
233
234        let symbols_coeff = Mul {
235            operands: symbols_coeff,
236        };
237        res /= symbols_coeff.get_ref();
238
239        debug!("solved equation: {}", res.str());
240
241        Ok(res)
242    }
243}
244
245impl JsonSchema for Equation {
246    fn schema_name() -> std::borrow::Cow<'static, str> {
247        "Equation".into()
248    }
249    fn schema_id() -> std::borrow::Cow<'static, str> {
250        concat!(module_path!(), "::Equation").into()
251    }
252
253    fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
254        json_schema!({
255            "type": "string",
256            "pattern": "^[^=]+=[^=]+$"
257        })
258    }
259    fn inline_schema() -> bool {
260        true
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Integer, Symbol};

    /// Solves `equation` for the single unknown `unknown`, panicking on failure.
    fn solve_for(equation: Equation, unknown: &dyn Expr) -> Equation {
        equation.solve([unknown]).expect("solved equation")
    }

    #[test]
    fn test_solve_solved() {
        let x = &Symbol::new("x");

        let solved = solve_for(Equation::new(x, &Integer::zero()), x.get_ref());

        assert_eq!(solved.srepr(), "Eq(Symbol(x), Integer(0))")
    }

    #[test]
    fn test_solve_basic() {
        let x = &Symbol::new("x");

        let solved = solve_for(Equation::new(&Integer::zero(), x), x.get_ref());

        assert_eq!(solved.srepr(), "Eq(Symbol(x), Integer(0))")
    }

    #[test]
    fn test_solve_normal() {
        let x = &Symbol::new("x");
        let y = &Symbol::new("y");
        let two = &Integer::new(2);

        // y = 2x  =>  x = y / 2
        let equation = Equation::new(y, &*(two * x));
        let expected = Equation::new(x, &*(y / two));

        assert_eq!(solve_for(equation, x.get_ref()), expected)
    }

    #[test]
    fn test_solve_non_number_coeff() {
        let x = &Symbol::new("x");
        let y = &Symbol::new("y");
        let z = &Symbol::new("z");

        // y = z * x  =>  x = y / z
        let equation = Equation::new(y, &*(z * x));
        let expected = Equation::new(x, &*(y / z));

        assert_eq!(solve_for(equation, x.get_ref()), expected)
    }
}