1use std::collections::HashMap;
6
7use crate::{error::BalanceError, Equation};
8use ndarray::prelude::*;
9use num::{Integer, Rational64, Signed, Zero};
10
11#[derive(Debug, Clone)]
25pub struct EquationBalancer<'a> {
26 eq: &'a Equation,
27 matrix: Array2<Rational64>,
28}
29
30impl<'a> EquationBalancer<'a> {
31 pub fn new(eq: &'a Equation) -> Self {
33 let uniq_elements: HashMap<&str, usize> = eq
35 .uniq_elements()
36 .into_iter()
37 .enumerate()
38 .map(|(i, e)| (e, i))
39 .collect();
40
41 let row = eq.num_compounds();
43 let col = uniq_elements.len();
44 let mut arr = Array2::<Rational64>::zeros((row, col));
45
46 let mut left_or_right: Rational64 = 1.into();
47 for (cmp, i) in eq.iter_compounds().zip(0..row) {
49 for el in &cmp.elements {
50 let index = *uniq_elements.get(el.symbol()).unwrap();
51 arr[[i, index]] = <i64 as Into<Rational64>>::into(el.count as i64) * left_or_right;
52 }
53 if i + 1 >= eq.left.len() {
57 left_or_right = Rational64::from_integer(-1);
58 }
59 }
60
61 Self {
62 eq,
63 matrix: arr.reversed_axes(),
64 }
65 }
66
67 pub fn balance(self) -> Result<Equation, BalanceError> {
81 if !self.eq.is_valid() {
82 return Err(BalanceError::InvalidEquation);
83 }
84 if self.eq.is_balanced() {
85 return Ok(self.eq.clone());
86 }
87 let mut eq = self.eq.clone();
88
89 let matrix = self.matrix;
90 let null_space = rref(augment(rref(matrix.view()).t()).view());
92
93 let vec = null_space
95 .row(null_space.dim().0 - 1)
96 .to_owned()
97 .iter()
98 .skip_while(|n| *n.numer() == 0)
99 .map(Rational64::abs)
100 .collect::<Vec<Rational64>>();
101 let coef_col = Array1::from_vec(vec);
102
103 let lcm = coef_col
105 .iter()
106 .map(Rational64::denom)
107 .fold(1, |acc: i64, f| acc.lcm(f));
108
109 let coef_col = coef_col * lcm;
111 if coef_col.to_vec().contains(&Rational64::from_integer(0)) {
112 return Err(BalanceError::Infeasable);
113 }
114
115 for (compound, coef) in eq
117 .iter_compounds_mut()
118 .zip(coef_col.iter().map(Rational64::numer))
119 {
120 compound.coefficient = *coef as _;
121 }
122
123 let mut comp_str: Vec<String> = self
125 .eq
126 .equation
127 .split(' ')
128 .filter(|c| !matches!(*c, "+" | "<-" | "<->" | "->"))
129 .map(Into::into)
130 .collect();
131 for (cmp, str) in eq.iter_compounds().zip(comp_str.iter_mut()) {
132 let mut to_remove = 0;
133 for c in str.chars() {
134 if c.is_numeric() {
135 to_remove += 1;
136 } else {
137 break;
138 }
139 }
140 for _ in 0..to_remove {
141 str.remove(0);
142 }
143 if cmp.coefficient != 1 {
144 str.insert_str(0, cmp.coefficient.to_string().as_str());
145 }
146 }
147 let reactants = comp_str[..eq.left.len()].join(" + ");
149 let products = comp_str[eq.left.len()..].join(" + ");
150
151 eq.equation = format!("{} {} {}", reactants, eq.direction, products);
153
154 Ok(eq)
155 }
156}
157
158impl<'a> From<&'a Equation> for EquationBalancer<'a> {
159 fn from(eq: &'a Equation) -> Self {
161 Self::new(eq)
162 }
163}
164
165fn rref(a: ArrayView2<Rational64>) -> Array2<Rational64> {
168 let mut out = ArrayBase::zeros(a.raw_dim());
169 out.zip_mut_with(&a, |x, y| *x = *y);
170
171 let mut pivot = 0;
172 let (rows, cols) = out.raw_dim().into_pattern();
173
174 'outer: for r in 0..rows {
175 if cols <= pivot {
176 break;
177 }
178 let mut i = r;
179 while (out[[i, pivot]] as Rational64).numer().is_zero() {
180 i += 1;
181 if i == rows {
182 i = r;
183 pivot += 1;
184 if cols == pivot {
185 break 'outer;
186 }
187 }
188 }
189 for j in 0..cols {
190 out.swap([r, j], [i, j]);
191 }
192 let divisor: Rational64 = out[[r, pivot]];
193 if !divisor.numer().is_zero() {
194 out.row_mut(r).iter_mut().for_each(|e| *e /= divisor);
195 }
196 for j in 0..rows {
197 if j != r {
198 let hold = out[[j, pivot]];
199 for k in 0..cols {
200 let t = out[[r, k]];
201 out[[j, k]] -= hold * t;
202 }
203 }
204 }
205 pivot += 1;
206 }
207
208 out
209}
210
211fn augment(a: ArrayView2<Rational64>) -> Array2<Rational64> {
213 ndarray::concatenate(Axis(1), &[a.view(), Array2::eye(a.shape()[0]).view()]).unwrap()
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` and runs the balancer on it.
    fn try_balance(input: &str) -> Result<Equation, BalanceError> {
        Equation::new(input).unwrap().to_balancer().balance()
    }

    /// Parses and balances `input`, panicking if either step fails.
    fn balanced(input: &str) -> Equation {
        try_balance(input).unwrap()
    }

    #[test]
    fn balance_simple() {
        assert_eq!(balanced("H2 + O2 -> H2O").equation, "2H2 + O2 -> 2H2O");
    }

    #[test]
    fn balance_simple_backwards() {
        assert_eq!(balanced("O2 + H2 -> H2O").equation, "O2 + 2H2 -> 2H2O");
    }

    #[test]
    fn balance_other_simple() {
        assert_eq!(balanced("Al + O2 -> Al2O3").equation, "4Al + 3O2 -> 2Al2O3");
    }

    #[test]
    fn balance_already_done() {
        assert_eq!(
            balanced("C2H4 + 3O2 -> 2CO2 + 2H2O").equation,
            "C2H4 + 3O2 -> 2CO2 + 2H2O"
        );
    }

    #[test]
    fn balance_harder() {
        assert_eq!(
            balanced("C2H6 + O2 -> CO2 + H2O").equation,
            "2C2H6 + 7O2 -> 4CO2 + 6H2O"
        );
    }

    #[test]
    fn try_balance_infeasible() {
        assert_eq!(
            try_balance("K4Fe(CN)6 + K2S2O3 -> CO2 + K2SO4 + NO2 + FeS"),
            Err(BalanceError::Infeasable)
        );
    }

    #[test]
    fn try_balance_coefs_already_exist() {
        assert_eq!(balanced("H2 + I -> 2HI").equation(), "H2 + 2I -> 2HI");
    }

    #[test]
    fn try_balance_coefs_already_exist_two() {
        assert_eq!(balanced("N2 + H <-> 2NH3").equation(), "N2 + 6H <-> 2NH3");
    }

    #[test]
    fn balance_coefs_exist_but_should_be_one() {
        assert_eq!(balanced("2H2 + I2 -> 2HI").equation(), "H2 + I2 -> 2HI");
    }
}