1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
// Copyright (c) 2015, Mikhail Vorotilov
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
//
// * Redistributions of source code must retain the above copyright notice, this
//   list of conditions and the following disclaimer.
//
// * Redistributions in binary form must reproduce the above copyright notice,
//   this list of conditions and the following disclaimer in the documentation
//   and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
// FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
// DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
// SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
// OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

use super::super::FloatType;
use super::super::Roots;

/// Solves a quartic equation a4*x^4 + a4*x^3 + a2*x^2 + a1*x + a0 = 0.
/// pp, rr, and dd are already computed while searching for multiple roots
fn find_roots_via_depressed_quartic<F: FloatType>(a4: F, a3: F, a2: F, a1: F, a0: F, pp: F, rr: F, dd: F) -> Roots<F> {
    // Depressed quartic
    // https://en.wikipedia.org/wiki/Quartic_function#Converting_to_a_depressed_quartic

    let _2 = F::from(2i16);
    let _3 = F::from(3i16);
    let _4 = F::from(4i16);
    let _6 = F::from(6i16);
    let _8 = F::from(8i16);
    let _12 = F::from(12i16);
    let _16 = F::from(16i16);
    let _256 = F::from(256i16);

    // a4*x^4 + a3*x^3 + a2*x^2 + a1*x + a0 = 0 => y^4 + p*y^2 + q*y + r.
    let a4_pow_2 = a4 * a4;
    let a4_pow_3 = a4_pow_2 * a4;
    let a4_pow_4 = a4_pow_2 * a4_pow_2;
    // Re-use pre-calculated values
    let p = pp / (_8 * a4_pow_2);
    let q = rr / (_8 * a4_pow_3);
    let r = (dd + _16 * a4_pow_2 * (_12 * a0 * a4 - _3 * a1 * a3 + a2 * a2)) / (_256 * a4_pow_4);

    let mut roots = Roots::No([]);
    for y in super::quartic_depressed::find_roots_quartic_depressed(p, q, r)
        .as_ref()
        .iter()
    {
        roots = roots.add_new_root(*y - a3 / (_4 * a4));
    }
    roots
}

/// Solves a quartic equation a4*x^4 + a3*x^3 + a2*x^2 + a1*x + a0 = 0.
///
/// Returned roots are ordered.
/// Precision is about 5e-15 for f64, 5e-7 for f32.
/// WARNING: f32 is often not enough to find multiple roots.
///
/// # Examples
///
/// ```
/// use roots::find_roots_quartic;
///
/// let one_root = find_roots_quartic(1f64, 0f64, 0f64, 0f64, 0f64);
/// // Returns Roots::One([0f64]) as 'x^4 = 0' has one root 0
///
/// let two_roots = find_roots_quartic(1f32, 0f32, 0f32, 0f32, -1f32);
/// // Returns Roots::Two([-1f32, 1f32]) as 'x^4 - 1 = 0' has roots -1 and 1
///
/// let multiple_roots = find_roots_quartic(-14.0625f64, -3.75f64, 29.75f64, 4.0f64, -16.0f64);
/// // Returns Roots::Two([-1.1016116464173349f64, 0.9682783130840016f64])
///
/// let multiple_roots_not_found = find_roots_quartic(-14.0625f32, -3.75f32, 29.75f32, 4.0f32, -16.0f32);
/// // Returns Roots::No([]) because of f32 rounding errors when trying to calculate the discriminant
/// ```
pub fn find_roots_quartic<F: FloatType>(a4: F, a3: F, a2: F, a1: F, a0: F) -> Roots<F> {
    // Handle non-standard cases
    if a4 == F::zero() {
        // a4 = 0; a3*x^3 + a2*x^2 + a1*x + a0 = 0; solve cubic equation
        super::cubic::find_roots_cubic(a3, a2, a1, a0)
    } else if a0 == F::zero() {
        // a0 = 0; x^4 + a2*x^2 + a1*x = 0; reduce to cubic and arrange results
        super::cubic::find_roots_cubic(a4, a3, a2, a1).add_new_root(F::zero())
    } else if a1 == F::zero() && a3 == F::zero() {
        // a1 = 0, a3 =0; a4*x^4 + a2*x^2 + a0 = 0; solve bi-quadratic equation
        super::biquadratic::find_roots_biquadratic(a4, a2, a0)
    } else {
        let _3 = F::from(3i16);
        let _4 = F::from(4i16);
        let _6 = F::from(6i16);
        let _8 = F::from(8i16);
        let _9 = F::from(9i16);
        let _10 = F::from(10i16);
        let _12 = F::from(12i16);
        let _16 = F::from(16i16);
        let _18 = F::from(18i16);
        let _27 = F::from(27i16);
        let _64 = F::from(64i16);
        let _72 = F::from(72i16);
        let _80 = F::from(80i16);
        let _128 = F::from(128i16);
        let _144 = F::from(144i16);
        let _192 = F::from(192i16);
        let _256 = F::from(256i16);
        // Discriminant
        // https://en.wikipedia.org/wiki/Quartic_function#Nature_of_the_roots
        // Partially simplifed to keep intermediate values smaller (to minimize rounding errors).
        let discriminant = a4 * a0 * a4 * (_256 * a4 * a0 * a0 + a1 * (_144 * a2 * a1 - _192 * a3 * a0))
            + a4 * a0 * a2 * a2 * (_16 * a2 * a2 - _80 * a3 * a1 - _128 * a4 * a0)
            + (a3
                * a3
                * (a4 * a0 * (_144 * a2 * a0 - _6 * a1 * a1)
                    + (a0 * (_18 * a3 * a2 * a1 - _27 * a3 * a3 * a0 - _4 * a2 * a2 * a2) + a1 * a1 * (a2 * a2 - _4 * a3 * a1))))
            + a4 * a1 * a1 * (_18 * a3 * a2 * a1 - _27 * a4 * a1 * a1 - _4 * a2 * a2 * a2);
        let pp = _8 * a4 * a2 - _3 * a3 * a3;
        let rr = a3 * a3 * a3 + _8 * a4 * a4 * a1 - _4 * a4 * a3 * a2;
        let delta0 = a2 * a2 - _3 * a3 * a1 + _12 * a4 * a0;
        let dd = _64 * a4 * a4 * a4 * a0 - _16 * a4 * a4 * a2 * a2 + _16 * a4 * a3 * a3 * a2
            - _16 * a4 * a4 * a3 * a1
            - _3 * a3 * a3 * a3 * a3;

        // Handle special cases
        let double_root = discriminant == F::zero();
        if double_root {
            let triple_root = double_root && delta0 == F::zero();
            let quadruple_root = triple_root && dd == F::zero();
            let no_roots = dd == F::zero() && pp > F::zero() && rr == F::zero();
            if quadruple_root {
                // Wiki: all four roots are equal
                Roots::One([-a3 / (_4 * a4)])
            } else if triple_root {
                // Wiki: At least three roots are equal to each other
                // x0 is the unique root of the remainder of the Euclidean division of the quartic by its second derivative
                //
                // Solved by SymPy (ra is the desired reminder)
                // a, b, c, d, e = symbols('a,b,c,d,e')
                // f=a*x**4+b*x**3+c*x**2+d*x+e     // Quartic polynom
                // g=6*a*x**2+3*b*x+c               // Second derivative
                // q, r = div(f, g)                 // SymPy only finds the highest power
                // simplify(f-(q*g+r)) == 0         // Verify the first division
                // qa, ra = div(r/a,g/a)            // Workaround to get the second division
                // simplify(f-((q+qa)*g+ra*a)) == 0 // Verify the second division
                // solve(ra,x)
                // ----- yields
                // (−72*a^2*e+10*a*c^2−3*b^2*c)/(9*(8*a^2*d−4*a*b*c+b^3))
                let x0 = (-_72 * a4 * a4 * a0 + _10 * a4 * a2 * a2 - _3 * a3 * a3 * a2)
                    / (_9 * (_8 * a4 * a4 * a1 - _4 * a4 * a3 * a2 + a3 * a3 * a3));
                let roots = Roots::One([x0]);
                roots.add_new_root(-(a3 / a4 + _3 * x0))
            } else if no_roots {
                // Wiki: two complex conjugate double roots
                Roots::No([])
            } else {
                find_roots_via_depressed_quartic(a4, a3, a2, a1, a0, pp, rr, dd)
            }
        } else {
            let no_roots = discriminant > F::zero() && (pp > F::zero() || dd > F::zero());
            if no_roots {
                // Wiki: two pairs of non-real complex conjugate roots
                Roots::No([])
            } else {
                find_roots_via_depressed_quartic(a4, a3, a2, a1, a0, pp, rr, dd)
            }
        }
    }
}

#[cfg(test)]
mod test {
    use super::super::super::*;

    #[test]
    fn test_find_roots_quartic() {
        assert_eq!(find_roots_quartic(1f32, 0f32, 0f32, 0f32, 0f32), Roots::One([0f32]));
        assert_eq!(find_roots_quartic(1f64, 0f64, 0f64, 0f64, -1f64), Roots::Two([-1f64, 1f64]));
        assert_eq!(
            find_roots_quartic(1f64, -10f64, 35f64, -50f64, 24f64),
            Roots::Four([1f64, 2f64, 3f64, 4f64])
        );

        match find_roots_quartic(1.1248467624839498f64, -4.8721513473605924f64,
                                 7.9323705711747614f64, -5.7774307699949397f64,
                                 1.5971379368787519f64) {
            Roots::Two(x) => {
                assert_float_array_eq!(2e-15f64, x, [1.225913506454221f64, 1.257275575390252f64]);
            }
            _ => {
                assert!(false);
            }
        }

        match find_roots_quartic(3f64, 5f64, -5f64, -5f64, 2f64) {
            Roots::Four(x) => {
                assert_float_array_eq!(2e-15f64, x, [-2f64, -1f64, 0.33333333333333333f64, 1f64]);
            }
            _ => {
                assert!(false);
            }
        }

        match find_roots_quartic(3f32, 5f32, -5f32, -5f32, 2f32) {
            Roots::Four(x) => {
                assert_float_array_eq!(5e-7, x, [-2f32, -1f32, 0.33333333333333333f32, 1f32]);
            }
            _ => {
                assert!(false);
            }
        }
    }

    #[test]
    fn test_find_roots_quartic_tim_luecke() {
        // Reported in December 2019
        assert_eq!(
            find_roots_quartic(-14.0625f64, -3.75f64, 29.75f64, 4.0f64, -16.0f64),
            Roots::Two([-1.1016116464173349f64, 0.9682783130840016f64])
        );
        // 32-bit floating point is not accurate enough to solve this case ...
        assert_eq!(
            find_roots_quartic(-14.0625f32, -3.75f32, 29.75f32, 4.0f32, -16.0f32),
            Roots::No([])
        );
        // ... even after normalizing
        assert_eq!(
            find_roots_quartic(
                1f32,
                -3.75f32 / -14.0625f32,
                29.75f32 / -14.0625f32,
                4.0f32 / -14.0625f32,
                -16.0f32 / -14.0625f32
            ),
            Roots::No([])
        );
        // assert_eq!(
        //     find_roots_quartic(-14.0625f32, -3.75f32, 29.75f32, 4.0f32, -16.0f32),
        //     Roots::Two([-1.1016117f32, 0.96827835f32])
        // );
    }

    #[test]
    fn test_find_roots_quartic_triple_root() {
        // (x+3)(3x-1)^3 == 27 x^4 + 54 x^3 - 72 x^2 + 26 x - 3
        assert_eq!(
            find_roots_quartic(27f64, 54f64, -72f64, 26f64, -3f64),
            Roots::Two([-3.0f64, 0.3333333333333333f64])
        );
        assert_eq!(
            find_roots_quartic(27f32, 54f32, -72f32, 26f32, -3f32),
            Roots::Two([-3.0f32, 0.33333333f32])
        );
    }

    #[test]
    fn test_find_roots_quartic_quadruple_root() {
        // (7x+2)^4 == 2401 x^4 + 2744 x^3 + 1176 x^2 + 224 x + 16
        assert_eq!(
            find_roots_quartic(2401f64, 2744f64, 1176f64, 224f64, 16f64),
            Roots::One([-0.2857142857142857f64])
        );
        // 32-bit floating point is less accurate
        assert_eq!(
            find_roots_quartic(2401f32, 2744f32, 1176f32, 224f32, 16f32),
            Roots::One([-0.2857143f32])
        );
    }
}