1use std::fmt;
2
3use itertools::Itertools;
4use log::debug;
5use schemars::{JsonSchema, json_schema};
6use serde::{Deserialize, Serialize};
7use thiserror::Error;
8
9use super::{ops::ParseExprError, *};
10
/// An equation `lhs = rhs` between two boxed, type-erased expressions.
#[derive(Clone)]
pub struct Equation {
    /// Left-hand side of the equation.
    pub lhs: Box<dyn Expr>,
    /// Right-hand side of the equation.
    pub rhs: Box<dyn Expr>,
}
16
17impl fmt::Display for Equation {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 write!(f, "{} = {}", self.lhs.str(), self.rhs.str())
20 }
21}
22
23impl Serialize for Equation {
24 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
25 where
26 S: serde::Serializer,
27 {
28 serializer.serialize_str(&self.str())
29 }
30}
31
/// Serde visitor that builds an [`Equation`] from a string by parsing it with
/// `Equation::from_str`.
struct ExprDeserializeVisitor;
33
34#[derive(Error, Debug)]
35pub enum Error {
36 #[error("{0}")]
37 Message(String),
38 #[error("parsed expression is not an equation")]
39 NotAnEquation,
40 #[error("{0}")]
41 FailedParsing(#[from] ParseExprError),
42}
43
44impl<'de> serde::de::Visitor<'de> for ExprDeserializeVisitor {
45 type Value = Equation;
46
47 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
48 formatter.write_str("a properly written equation")
49 }
50
51 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
52 where
53 E: serde::de::Error,
54 {
55 Equation::from_str(v).map_err(|e| E::custom(e.to_string()))
56 }
57}
58
59impl<'de> Deserialize<'de> for Equation {
60 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
61 where
62 D: serde::Deserializer<'de>,
63 {
64 deserializer.deserialize_str(ExprDeserializeVisitor)
65 }
66}
67
68impl fmt::Debug for Equation {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 write!(f, "{}\n{}", self.str(), self.srepr())
71 }
73}
74
75impl Equation {
76 pub fn into_new(lhs: &Box<dyn Expr>, rhs: &Box<dyn Expr>) -> Equation {
77 Equation {
78 lhs: lhs.clone(),
79 rhs: rhs.clone(),
80 }
81 }
82
83 pub fn new(lhs: &dyn Expr, rhs: &dyn Expr) -> Equation {
84 Equation {
85 lhs: lhs.clone_box(),
86 rhs: rhs.clone_box(),
87 }
88 }
89
90 pub fn new_box(lhs: Box<dyn Expr>, rhs: Box<dyn Expr>) -> Box<dyn Expr> {
91 Box::new(Equation { lhs, rhs })
92 }
93
94 pub fn from_str(s: &str) -> Result<Equation, Error> {
95 ops::parse_expr(s)?.as_eq().ok_or(Error::NotAnEquation)
96 }
97}
98
99impl Expr for Equation {
100 fn name(&self) -> String {
101 "Eq".to_string()
102 }
103
104 fn get_ref<'a>(&'a self) -> &'a dyn Expr {
105 self as &dyn Expr
106 }
107 fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
108 f(&*self.lhs);
109 f(&*self.rhs);
110 }
111
112 fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
113 Box::new(Equation {
114 lhs: args[0].clone().into(),
115 rhs: args[1].clone().into(),
116 })
117 }
118
119 fn clone_box(&self) -> Box<dyn Expr> {
120 Box::new(self.clone())
121 }
122
123 fn str(&self) -> String {
124 format!("{} = {}", self.lhs.str(), self.rhs.str())
125 }
126}
127
128impl std::ops::SubAssign<&dyn Expr> for Equation {
129 fn sub_assign(&mut self, rhs: &dyn Expr) {
130 self.lhs -= rhs;
131 self.rhs -= rhs;
132 }
133}
134
135impl std::ops::DivAssign<&dyn Expr> for Equation {
136 fn div_assign(&mut self, rhs: &dyn Expr) {
137 self.lhs /= rhs;
138 self.rhs /= rhs;
139 }
140}
141
142impl std::cmp::PartialEq for Equation {
143 fn eq(&self, other: &Self) -> bool {
144 &self.lhs == &other.lhs && &self.rhs == &other.rhs
145 }
146}
147
/// Error returned by [`Equation::solve`] when the equation cannot be rearranged
/// for the requested unknowns.
#[derive(Error, Debug)]
#[error("failed to solve equation {equation} for unknowns {unknowns} : {reason}")]
pub struct SolvingError {
    /// The requested unknowns, rendered with `str()` and joined by `", "`
    /// (empty when no unknowns were given).
    unknowns: String,
    /// The equation, in its `str()` form, at the point where solving failed.
    equation: String,
    /// Human-readable explanation of the failure.
    reason: String,
}
155
156impl Equation {
157 pub fn solve<'a, S: IntoIterator<Item = &'a dyn Expr>>(
163 &self,
164 exprs: S,
165 ) -> Result<Equation, SolvingError> {
166 let eq = (self.expand()).as_eq().expect("Should remain an eqation");
167
168 let symbols: Vec<_> = exprs.into_iter().collect();
169 debug!("solving equation {} for unknowns {:?}", self.str(), symbols);
170 debug!("expanded: {}", eq.str());
171
172 if symbols.is_empty() {
173 Err(SolvingError {
174 unknowns: "".to_string(),
175 equation: eq.str(),
176 reason: "no unknowns given".into(),
177 })?
178 }
179
180 let move_right: Vec<_> = eq
181 .lhs
182 .terms()
183 .filter(|e| symbols.iter().all(|s| !e.has(s.get_ref())))
184 .collect();
185 let move_left: Vec<_> = eq
186 .rhs
187 .terms()
188 .filter(|e| symbols.iter().any(|s| e.has(s.get_ref())))
189 .collect();
190
191 let mut res = eq.clone();
195 debug!("Equation: {}", res.str());
196 for t in move_right {
201 debug!("Moving {t} to the right");
202 res -= t;
203 debug!("Equation: {}", res.str());
204 }
205
206 for t in move_left {
207 debug!("Moving {t} to the left");
208 res -= t;
209 }
210
211 let (coeff, _) = (&res.lhs).get_coeff();
212
213 if coeff.is_zero() {
214 Err(SolvingError {
215 unknowns: symbols.iter().map(|e| e.str()).join(", "),
216 equation: res.str(),
217 reason: "failed to get unknown coefficient".into(),
218 })?
219 }
220
221 if !coeff.is_one() {
222 res /= coeff.get_ref();
223 }
224
225 let mut symbols_coeff = Vec::new();
226 if let KnownExpr::Mul(Mul { operands }) = KnownExpr::from_expr_box(&res.lhs) {
227 for op in operands {
228 if !symbols.iter().any(|s| op.has(s.get_ref())) {
229 symbols_coeff.push(op.clone_box());
230 }
231 }
232 }
233
234 let symbols_coeff = Mul {
235 operands: symbols_coeff,
236 };
237 res /= symbols_coeff.get_ref();
238
239 debug!("solved equation: {}", res.str());
240
241 Ok(res)
242 }
243}
244
245impl JsonSchema for Equation {
246 fn schema_name() -> std::borrow::Cow<'static, str> {
247 "Equation".into()
248 }
249 fn schema_id() -> std::borrow::Cow<'static, str> {
250 concat!(module_path!(), "::Equation").into()
251 }
252
253 fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
254 json_schema!({
255 "type": "string",
256 "pattern": "^[^=]+=[^=]+$"
257 })
258 }
259 fn inline_schema() -> bool {
260 true
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Integer, Symbol};

    #[test]
    fn test_solve_solved() {
        let x = &Symbol::new("x");

        let expr = Equation::new(x, &Integer::zero())
            .solve([x.get_ref()])
            .expect("solved equation");
        let expected = "Eq(Symbol(x), Integer(0))";

        assert_eq!(expr.srepr(), expected)
    }

    #[test]
    fn test_solve_basic() {
        let x = &Symbol::new("x");

        let expr = Equation::new(&Integer::zero(), x)
            .solve([x.get_ref()])
            .expect("solved equation");
        let expected = "Eq(Symbol(x), Integer(0))";

        assert_eq!(expr.srepr(), expected)
    }

    #[test]
    fn test_solve_normal() {
        let x = &Symbol::new("x");
        let y = &Symbol::new("y");
        let two = &Integer::new(2);

        let expr = Equation::new(y, &*(two * x));
        let expected = Equation::new(x, &*(y / two));

        assert_eq!(
            expr.solve([x.get_ref()]).expect("solved equation"),
            expected
        )
    }

    #[test]
    fn test_solve_non_number_coeff() {
        let x = &Symbol::new("x");
        let y = &Symbol::new("y");
        let z = &Symbol::new("z");

        let expr = Equation::new(y, &*(z * x));
        let expected = Equation::new(x, &*(y / z));

        assert_eq!(
            expr.solve([x.get_ref()]).expect("solved equation"),
            expected
        )
    }

    /// Solving with an empty set of unknowns must fail with a descriptive error.
    #[test]
    fn test_solve_no_unknowns() {
        let x = &Symbol::new("x");
        let no_unknowns: [&dyn Expr; 0] = [];

        let err = Equation::new(x, &Integer::zero())
            .solve(no_unknowns)
            .expect_err("solving without unknowns must fail");

        assert!(err.to_string().contains("no unknowns given"));
    }

    /// `Display` and `Expr::str` must render an equation identically.
    #[test]
    fn test_display_matches_str() {
        let x = &Symbol::new("x");
        let eq = Equation::new(x, &Integer::new(2));

        assert_eq!(eq.to_string(), eq.str());
    }
}