use super::super::FloatType;
use super::super::Roots;
/// Solves the quadratic equation `a2*x^2 + a1*x + a0 = 0` for real roots.
///
/// * When `a2` equals zero the work is delegated to
///   `super::linear::find_roots_linear(a1, a0)`.
/// * A negative discriminant produces `Roots::No`.
/// * A zero discriminant produces `Roots::One` holding `-a1 / (2*a2)`.
/// * Otherwise `Roots::Two` is returned with the smaller root first.
pub fn find_roots_quadratic<F: FloatType>(a2: F, a1: F, a0: F) -> Roots<F> {
    let zero = F::zero();

    // Without a quadratic term this is a linear equation.
    if a2 == zero {
        return super::linear::find_roots_linear(a1, a0);
    }

    let two = F::from(2i16);
    let four = F::from(4i16);
    let discriminant = a1 * a1 - four * a2 * a0;
    if discriminant < zero {
        return Roots::No([]);
    }

    let double_a2 = two * a2;
    if discriminant == zero {
        return Roots::One([-a1 / double_a2]);
    }

    let disc_root = discriminant.sqrt();
    let sum = -a1 + disc_root;
    let difference = -a1 - disc_root;

    // `reinforced` adds two terms of equal sign (no cancellation);
    // `cancelling` is the other combination of `-a1` and the square root.
    let (reinforced, cancelling) = if a1 < zero {
        (sum, difference)
    } else {
        (difference, sum)
    };

    // Each root can be written either as `numerator / (2*a2)` or as
    // `(2*a0) / numerator`; the form is chosen by comparing the numerator's
    // magnitude with `|2*a2|`.
    // NOTE(review): presumably this keeps the quotients in range when `a2`
    // is tiny — confirm the intent with the author.
    let threshold = double_a2.abs();
    let double_a0 = two * a0;
    let reinforced_is_large = reinforced.abs() > threshold;

    let first = if reinforced_is_large {
        double_a0 / reinforced
    } else {
        cancelling / double_a2
    };
    let second = if reinforced_is_large && cancelling.abs() > threshold {
        double_a0 / cancelling
    } else {
        reinforced / double_a2
    };

    // Report the pair with the smaller root first.
    let ordered = if first < second {
        [first, second]
    } else {
        [second, first]
    };
    Roots::Two(ordered)
}
#[cfg(test)]
mod test {
    use super::super::super::*;

    /// Basic outcomes: all-zero coefficients, no real roots, two real roots.
    #[test]
    fn test_find_roots_quadratic() {
        let all_zero = find_roots_quadratic(0f32, 0f32, 0f32);
        assert_eq!(all_zero, Roots::One([0f32]));

        let no_real_roots = find_roots_quadratic(1f32, 0f32, 1f32);
        assert_eq!(no_real_roots, Roots::No([]));

        let symmetric = find_roots_quadratic(1f64, 0f64, -1f64);
        assert_eq!(symmetric, Roots::Two([-1f64, 1f64]));
    }

    /// Leading coefficient many orders of magnitude smaller than `a1`.
    #[test]
    fn test_find_roots_quadratic_small_a2() {
        // ((a2, a1, a0), expected roots in ascending order)
        let cases = [
            ((1e-20f32, -1f32, -1e-30f32), [-1e-30f32, 1e20f32]),
            ((-1e-20f32, 1f32, 1e-30f32), [-1e-30f32, 1e20f32]),
            ((1e-20f32, -1f32, 1f32), [1f32, 1e20f32]),
            ((-1e-20f32, 1f32, 1f32), [-1f32, 1e20f32]),
            ((-1e-20f32, 1f32, -1f32), [1f32, 1e20f32]),
        ];
        for &((a2, a1, a0), expected) in cases.iter() {
            assert_eq!(find_roots_quadratic(a2, a1, a0), Roots::Two(expected));
        }
    }

    /// Linear coefficient many orders of magnitude larger than `a2` and `a0`.
    #[test]
    fn test_find_roots_quadratic_big_a1() {
        // ((a2, a1, a0), expected roots in ascending order)
        let cases = [
            ((1f32, -1e15f32, -1f32), [-1e-15f32, 1e15f32]),
            ((-1f32, 1e15f32, 1f32), [-1e-15f32, 1e15f32]),
        ];
        for &((a2, a1, a0), expected) in cases.iter() {
            assert_eq!(find_roots_quadratic(a2, a1, a0), Roots::Two(expected));
        }
    }
}