chem_eq/
balance.rs

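//! Balance a chemical equation.
//!
//! [`EquationBalancer`] turns an [`Equation`] into an element-count matrix and
//! solves for the coefficient of every compound using exact rational arithmetic.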
4
5use std::collections::HashMap;
6
7use crate::{error::BalanceError, Equation};
8use ndarray::prelude::*;
9use num::{Integer, Rational64, Signed, Zero};
10
11/// Takes an equation and balances it.
12///
13/// # Examples
14///
15/// ```rust
16/// use chem_eq::{Equation, balance::EquationBalancer};
17///
18/// let eq = Equation::new("H2 + O2 -> H2O").unwrap();
19/// let balancer = EquationBalancer::new(&eq);
20/// let balanced_eq = balancer.balance().unwrap();
21///
22/// assert_eq!(balanced_eq.equation(), "2H2 + O2 -> 2H2O");
23/// ```
24#[derive(Debug, Clone)]
25pub struct EquationBalancer<'a> {
26    eq: &'a Equation,
27    matrix: Array2<Rational64>,
28}
29
30impl<'a> EquationBalancer<'a> {
31    /// Create an equation balancer of a given [`Equation`]
32    pub fn new(eq: &'a Equation) -> Self {
33        // map each unique element to a column in the matrix
34        let uniq_elements: HashMap<&str, usize> = eq
35            .uniq_elements()
36            .into_iter()
37            .enumerate()
38            .map(|(i, e)| (e, i))
39            .collect();
40
41        // construct vector with correct sizing
42        let row = eq.num_compounds();
43        let col = uniq_elements.len();
44        let mut arr = Array2::<Rational64>::zeros((row, col));
45
46        let mut left_or_right: Rational64 = 1.into();
47        // fill in vector with counts of elements
48        for (cmp, i) in eq.iter_compounds().zip(0..row) {
49            for el in &cmp.elements {
50                let index = *uniq_elements.get(el.symbol()).unwrap();
51                arr[[i, index]] = <i64 as Into<Rational64>>::into(el.count as i64) * left_or_right;
52            }
53            // invert compounds on the right because they are products.
54            // when they're brought to the other side of the equation, (because they start off
55            // on the opposite side) the counts will be inverted (as math works).
56            if i + 1 >= eq.left.len() {
57                left_or_right = Rational64::from_integer(-1);
58            }
59        }
60
61        Self {
62            eq,
63            matrix: arr.reversed_axes(),
64        }
65    }
66
67    /// Balance the internal equation consuming self and returning the balanced form.
68    ///
69    /// # Examples
70    ///
71    /// ```rust
72    /// use chem_eq::{Equation, balance::EquationBalancer};
73    ///
74    /// let eq = Equation::new("Fe + O2 -> Fe2O3").unwrap();
75    /// let solver = EquationBalancer::new(&eq);
76    /// let solved = solver.balance().unwrap();
77    ///
78    /// assert_eq!(solved.equation(), "4Fe + 3O2 -> 2Fe2O3");
79    /// ```
80    pub fn balance(self) -> Result<Equation, BalanceError> {
81        if !self.eq.is_valid() {
82            return Err(BalanceError::InvalidEquation);
83        }
84        if self.eq.is_balanced() {
85            return Ok(self.eq.clone());
86        }
87        let mut eq = self.eq.clone();
88
89        let matrix = self.matrix;
90        // reduced row echelon form, or kernel, or null space
91        let null_space = rref(augment(rref(matrix.view()).t()).view());
92
93        // last column is the coefficients (as fractions)
94        let vec = null_space
95            .row(null_space.dim().0 - 1)
96            .to_owned()
97            .iter()
98            .skip_while(|n| *n.numer() == 0)
99            .map(Rational64::abs)
100            .collect::<Vec<Rational64>>();
101        let coef_col = Array1::from_vec(vec);
102
103        // get lcm of the denominators of the coefficients to scale them up
104        let lcm = coef_col
105            .iter()
106            .map(Rational64::denom)
107            .fold(1, |acc: i64, f| acc.lcm(f));
108
109        // scale up the solutions
110        let coef_col = coef_col * lcm;
111        if coef_col.to_vec().contains(&Rational64::from_integer(0)) {
112            return Err(BalanceError::Infeasable);
113        }
114
115        // replace the coefficients
116        for (compound, coef) in eq
117            .iter_compounds_mut()
118            .zip(coef_col.iter().map(Rational64::numer))
119        {
120            compound.coefficient = *coef as _;
121        }
122
123        // replace equation field with correct coefficients
124        let mut comp_str: Vec<String> = self
125            .eq
126            .equation
127            .split(' ')
128            .filter(|c| !matches!(*c, "+" | "<-" | "<->" | "->"))
129            .map(Into::into)
130            .collect();
131        for (cmp, str) in eq.iter_compounds().zip(comp_str.iter_mut()) {
132            let mut to_remove = 0;
133            for c in str.chars() {
134                if c.is_numeric() {
135                    to_remove += 1;
136                } else {
137                    break;
138                }
139            }
140            for _ in 0..to_remove {
141                str.remove(0);
142            }
143            if cmp.coefficient != 1 {
144                str.insert_str(0, cmp.coefficient.to_string().as_str());
145            }
146        }
147        // concatenate compounds with "+" signs
148        let reactants = comp_str[..eq.left.len()].join(" + ");
149        let products = comp_str[eq.left.len()..].join(" + ");
150
151        // combine products and reactants with sign in the middle
152        eq.equation = format!("{} {} {}", reactants, eq.direction, products);
153
154        Ok(eq)
155    }
156}
157
158impl<'a> From<&'a Equation> for EquationBalancer<'a> {
159    /// Create matrix for solving out of equation
160    fn from(eq: &'a Equation) -> Self {
161        Self::new(eq)
162    }
163}
164
165// Thanks to u/mindv0rtex on reddit, @mindv0rtex on github
166// reduced row echelon form
167fn rref(a: ArrayView2<Rational64>) -> Array2<Rational64> {
168    let mut out = ArrayBase::zeros(a.raw_dim());
169    out.zip_mut_with(&a, |x, y| *x = *y);
170
171    let mut pivot = 0;
172    let (rows, cols) = out.raw_dim().into_pattern();
173
174    'outer: for r in 0..rows {
175        if cols <= pivot {
176            break;
177        }
178        let mut i = r;
179        while (out[[i, pivot]] as Rational64).numer().is_zero() {
180            i += 1;
181            if i == rows {
182                i = r;
183                pivot += 1;
184                if cols == pivot {
185                    break 'outer;
186                }
187            }
188        }
189        for j in 0..cols {
190            out.swap([r, j], [i, j]);
191        }
192        let divisor: Rational64 = out[[r, pivot]];
193        if !divisor.numer().is_zero() {
194            out.row_mut(r).iter_mut().for_each(|e| *e /= divisor);
195        }
196        for j in 0..rows {
197            if j != r {
198                let hold = out[[j, pivot]];
199                for k in 0..cols {
200                    let t = out[[r, k]];
201                    out[[j, k]] -= hold * t;
202                }
203            }
204        }
205        pivot += 1;
206    }
207
208    out
209}
210
211// ...
212fn augment(a: ArrayView2<Rational64>) -> Array2<Rational64> {
213    ndarray::concatenate(Axis(1), &[a.view(), Array2::eye(a.shape()[0]).view()]).unwrap()
214}
215
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `input` (panicking if it does not parse) and run it through the balancer.
    fn try_balance(input: &str) -> Result<Equation, BalanceError> {
        Equation::new(input).unwrap().to_balancer().balance()
    }

    #[test]
    fn balance_simple() {
        let eq = try_balance("H2 + O2 -> H2O").unwrap();
        assert_eq!(eq.equation, "2H2 + O2 -> 2H2O");
    }

    #[test]
    fn balance_simple_backwards() {
        let eq = try_balance("O2 + H2 -> H2O").unwrap();
        assert_eq!(eq.equation, "O2 + 2H2 -> 2H2O");
    }

    #[test]
    fn balance_other_simple() {
        let eq = try_balance("Al + O2 -> Al2O3").unwrap();
        assert_eq!(eq.equation, "4Al + 3O2 -> 2Al2O3");
    }

    #[test]
    fn balance_already_done() {
        let eq = try_balance("C2H4 + 3O2 -> 2CO2 + 2H2O").unwrap();
        assert_eq!(eq.equation, "C2H4 + 3O2 -> 2CO2 + 2H2O");
    }

    #[test]
    fn balance_harder() {
        let eq = try_balance("C2H6 + O2 -> CO2 + H2O").unwrap();
        assert_eq!(eq.equation, "2C2H6 + 7O2 -> 4CO2 + 6H2O");
    }

    #[test]
    fn try_balance_infeasible() {
        let res = try_balance("K4Fe(CN)6 + K2S2O3 -> CO2 + K2SO4 + NO2 + FeS");
        assert_eq!(res, Err(BalanceError::Infeasable));
    }

    #[test]
    fn try_balance_coefs_already_exist() {
        let res = try_balance("H2 + I -> 2HI").unwrap();
        assert_eq!(res.equation(), "H2 + 2I -> 2HI");
    }

    #[test]
    fn try_balance_coefs_already_exist_two() {
        let res = try_balance("N2 + H <-> 2NH3").unwrap();
        assert_eq!(res.equation(), "N2 + 6H <-> 2NH3");
    }

    #[test]
    fn balance_coefs_exist_but_should_be_one() {
        let res = try_balance("2H2 + I2 -> 2HI").unwrap();
        assert_eq!(res.equation(), "H2 + I2 -> 2HI");
    }
}