1use crate::ast::*;
2use crate::calc::{calc_function_call, CalcError, Env};
3
4use thiserror::Error;
5
6#[derive(Debug, PartialEq)]
9struct NormForm {
10 a1: Number,
11 a0: Number,
12}
13
14#[derive(Debug, PartialEq, Eq, Error)]
15pub enum SolverError {
16 #[error("Unknown variable `{0}` in `solve ... for ...`")]
17 UnknownVariable(String),
18 #[error("Unsupported `^2` of variable to solve for in `solve ... for ...`")]
19 UnsupportedXSquare,
20 #[error("Unsupported variable in denominator in `solve ... for ...`")]
21 UnsupportedXDenominator,
22 #[error("Unsupported % with solve for variable in `solve ... for ...`")]
23 UnsupportedRemainder,
24 #[error("Unsupported power in `solve ... for ...`")]
25 UnsupportedPower,
26 #[error("`solve ... for ...` contains no variable (after simplification)")]
27 NoVariable,
28 #[error(transparent)]
29 FunctionCallError(#[from] CalcError),
30}
31
32fn normalize_term(term: &Term, sym: &str, env: &dyn Env) -> Result<NormForm, SolverError> {
33 let lhs = normalize(&term.lhs, sym, env)?;
34 let rhs = normalize(&term.rhs, sym, env)?;
35 match term.op {
36 Operation::Add => Ok({
37 let factor = lhs.a1 + rhs.a1;
38 let summand = lhs.a0 + rhs.a0;
39 NormForm {
40 a1: factor,
41 a0: summand,
42 }
43 }),
44 Operation::Sub => Ok({
45 let factor = lhs.a1 - rhs.a1;
46 let summand = lhs.a0 - rhs.a0;
47 NormForm {
48 a1: factor,
49 a0: summand,
50 }
51 }),
52 Operation::Mul => {
53 let a2 = lhs.a1 * rhs.a1;
54 let a1 = lhs.a1 * rhs.a0 + rhs.a1 * lhs.a0;
55 let a0 = lhs.a0 * rhs.a0;
56 if a2 != 0.0 {
57 Err(SolverError::UnsupportedXSquare)
58 } else {
59 Ok(NormForm { a1, a0 })
60 }
61 }
62 Operation::Div => {
63 if rhs.a1 != 0.0 {
64 Err(SolverError::UnsupportedXDenominator)
65 } else {
66 let a1 = lhs.a1 / rhs.a0;
67 let a0 = lhs.a0 / rhs.a0;
68 Ok(NormForm { a1, a0 })
69 }
70 }
71 Operation::Rem => {
72 if (lhs.a1 != 0.0) || (rhs.a1 != 0.0) {
73 Err(SolverError::UnsupportedRemainder)
74 } else {
75 Ok(NormForm {
76 a1: 0.0,
77 a0: (lhs.a0 % rhs.a0),
78 })
79 }
80 }
81 Operation::Pow => {
82 if (lhs.a1 != 0.0) || (rhs.a1 != 0.0) {
83 Err(SolverError::UnsupportedPower)
84 } else {
85 Ok(NormForm {
86 a1: 0.0,
87 a0: (lhs.a0.powf(rhs.a0)),
88 })
89 }
90 }
91 }
92}
93
94fn normalize(op: &Operand, sym: &str, env: &dyn Env) -> Result<NormForm, SolverError> {
95 match op {
96 Operand::Number(num) => Ok(NormForm { a1: 0.0, a0: *num }),
97 Operand::Symbol(s) => {
98 if op.is_symbol(sym) {
99 Ok(NormForm { a1: 1.0, a0: 0.0 })
100 } else {
101 let num = env
102 .get(s)
103 .ok_or_else(|| SolverError::UnknownVariable(s.clone()))?;
104 Ok(NormForm { a1: 0.0, a0: *num })
105 }
106 }
107 Operand::Term(term) => normalize_term(&*term, sym, env),
108 Operand::FunCall(fun_call) => {
109 let num = calc_function_call(fun_call, env)?;
110 Ok(NormForm { a1: 0.0, a0: num })
111 }
112 }
113}
114
115pub fn solve_for(
116 lhs: &Operand,
117 rhs: &Operand,
118 sym: &str,
119 env: &dyn Env,
120) -> Result<Number, SolverError> {
121 let norm_form_lhs = normalize(lhs, sym, env)?;
122 let norm_form_rhs = normalize(rhs, sym, env)?;
123 let denominator = norm_form_lhs.a1 - norm_form_rhs.a1;
124 if 0.0 == denominator {
125 Err(SolverError::NoVariable)
126 } else {
127 let nominator = norm_form_rhs.a0 - norm_form_lhs.a0;
128 Ok(nominator / denominator)
129 }
130}
131
#[cfg(test)]
mod tests {
    mod helpers {
        use crate::ast::{Operand, Statement};
        use crate::parser::parse;

        /// Parses `s` and returns the operand of the resulting expression statement.
        ///
        /// # Panics
        /// Panics if `s` fails to parse or is not a plain expression.
        pub fn parse_expression(s: &str) -> Operand {
            let statement = parse(s).unwrap();
            if let Statement::Expression { op } = statement {
                op
            } else {
                panic!("string is not a valid expression")
            }
        }

        #[test]
        fn parse_expression_success() {
            assert_eq!(Operand::Number(1.0), parse_expression("1"));
        }

        #[test]
        #[should_panic(expected = "string is not a valid expression")]
        fn parse_expression_failed_assignment() {
            parse_expression("x:=1");
        }

        #[test]
        #[should_panic(expected = "InvalidExpression")]
        fn parse_expression_failed_equation() {
            parse_expression("1 @");
        }
    }
    use self::helpers::parse_expression;
    use super::*;
    use crate::ast::CustomFunction;
    use crate::calc::TopLevelEnv;
    use crate::parse;

    /// Parses a `solve ... for ...` statement and solves it in `env`.
    ///
    /// Panics with a descriptive message if `s` is not such a statement.
    fn solve_statement(s: &str, env: &dyn Env) -> Result<Number, SolverError> {
        match parse(s).unwrap() {
            Statement::SolveFor { lhs, rhs, sym } => solve_for(&lhs, &rhs, &sym, env),
            _ => panic!("`{}` is not a `solve ... for ...` statement", s),
        }
    }

    #[test]
    fn normalize_operand_number() {
        let exp = NormForm { a1: 0f64, a0: 1.2 };
        assert_eq!(
            exp,
            normalize(&parse_expression("1.2"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_symbol_x() {
        let exp = NormForm { a1: 1f64, a0: 0f64 };
        assert_eq!(
            exp,
            normalize(&parse_expression("x"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_symbol_y_unknown() {
        let act = normalize(&parse_expression("y"), "x", &TopLevelEnv::default());
        assert!(matches!(act, Err(SolverError::UnknownVariable(s)) if s == "y"));
    }

    #[test]
    fn normalize_operand_symbol_y() {
        let mut env = TopLevelEnv::default();
        env.put("y".to_string(), 12.0).unwrap();
        let act = normalize(&parse_expression("y"), "x", &env);
        assert_eq!(Ok(NormForm { a1: 0.0, a0: 12.0 }), act);
    }

    #[test]
    fn normalize_operand_simple_add() {
        let exp = NormForm { a1: 1f64, a0: 1f64 };
        assert_eq!(
            exp,
            normalize(&parse_expression("x + 1"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_simple_sub() {
        let exp = NormForm {
            a1: 1f64,
            a0: -12f64,
        };
        assert_eq!(
            exp,
            normalize(&parse_expression("x - 12"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_simple_mul() {
        let exp = NormForm { a1: 2f64, a0: 0f64 };
        assert_eq!(
            exp,
            normalize(&parse_expression("x * 2"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_simple_rem() {
        let exp = NormForm { a1: 0f64, a0: 1f64 };
        assert_eq!(
            exp,
            normalize(&parse_expression("7 % 3"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_simple_pow() {
        let exp = NormForm {
            a1: 0f64,
            a0: 27f64,
        };
        assert_eq!(
            exp,
            normalize(&parse_expression("3 ^ 3"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_simple_norm_form() {
        let exp = NormForm { a1: 3f64, a0: 2f64 };
        assert_eq!(
            exp,
            normalize(&parse_expression("3 * x + 2"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_simple_norm_sub() {
        let exp = NormForm {
            a1: 3f64,
            a0: -2f64,
        };
        assert_eq!(
            exp,
            normalize(&parse_expression("3 * x - 2"), "x", &TopLevelEnv::default()).unwrap()
        );
    }

    #[test]
    fn normalize_operand_div() {
        let exp = NormForm {
            a1: 4f64,
            a0: -5f64,
        };
        assert_eq!(
            exp,
            normalize(
                &parse_expression("(12 * x - 15) / 3"),
                "x",
                &TopLevelEnv::default()
            )
            .unwrap()
        );
    }

    #[test]
    fn normalize_rejects_x_square() {
        assert_eq!(
            Err(SolverError::UnsupportedXSquare),
            normalize(&parse_expression("x * x"), "x", &TopLevelEnv::default())
        );
    }

    #[test]
    fn normalize_rejects_x_in_denominator() {
        assert_eq!(
            Err(SolverError::UnsupportedXDenominator),
            normalize(&parse_expression("1 / x"), "x", &TopLevelEnv::default())
        );
    }

    #[test]
    fn normalize_rejects_remainder_with_x() {
        assert_eq!(
            Err(SolverError::UnsupportedRemainder),
            normalize(&parse_expression("x % 2"), "x", &TopLevelEnv::default())
        );
    }

    #[test]
    fn normalize_rejects_x_in_exponent() {
        assert_eq!(
            Err(SolverError::UnsupportedPower),
            normalize(&parse_expression("2 ^ x"), "x", &TopLevelEnv::default())
        );
    }

    #[test]
    fn solve_for_without_variable() {
        assert_eq!(
            Err(SolverError::NoVariable),
            solve_for(
                &parse_expression("x + 1"),
                &parse_expression("x + 2"),
                "x",
                &TopLevelEnv::default()
            )
        );
    }

    #[test]
    fn solve_for_uses_env_for_other_symbols() {
        let mut env = TopLevelEnv::default();
        env.put("y".to_string(), 4.0).unwrap();
        assert_eq!(
            Ok(6.0),
            solve_for(
                &parse_expression("x + y"),
                &parse_expression("10"),
                "x",
                &env
            )
        );
    }

    #[test]
    fn solve_for_simple() {
        assert_eq!(
            Ok(10.0),
            solve_statement("solve x = 10 for x", &TopLevelEnv::default())
        );
    }

    #[test]
    fn solve_for_complex() {
        assert_eq!(
            Ok(1.5),
            solve_statement(
                "solve 5 + 2 * x + 12 = 22 - 6 * x + 7 for x",
                &TopLevelEnv::default()
            )
        );
    }

    #[test]
    fn solve_for_with_function_call() {
        let mut env = TopLevelEnv::default();
        env.put_fun(
            "add".to_string(),
            Function::Custom(CustomFunction {
                args: vec!["x".to_string(), "y".to_string()],
                body: Operand::Term(Box::new(Term {
                    lhs: Operand::Symbol("x".to_string()),
                    rhs: Operand::Symbol("y".to_string()),
                    op: Operation::Add,
                })),
            }),
        );
        assert_eq!(
            Ok(1.5),
            solve_statement("solve 2 * x + add(5, 12) = 22 - 6 * x + 7 for x", &env)
        );
    }
}