use super::decimal_compute::{
ComputeStorage, DECIMAL_COMPUTE_DP,
decimal_compute_zero, decimal_compute_one,
decimal_compute_add, decimal_compute_div, decimal_compute_halve,
decimal_compute_is_zero, decimal_compute_is_negative,
decimal_compute_cmp,
};
use crate::fixed_point::domains::symbolic::rational::rational_number::OverflowDetected;
/// Square root of a decimal fixed-point value, computed with Newton's method.
///
/// The iteration is `y <- (y + x / y) / 2`, built from `decimal_compute_div`,
/// `decimal_compute_add` and `decimal_compute_halve`. It stops when:
/// * `y` stops changing;
/// * `y` enters a two-value cycle; or
/// * `MAX_ITERATIONS` steps have run.
///
/// In the cycle case the smaller of the two alternating values is returned,
/// so the result does not depend on which side the iteration approached from.
///
/// Zero maps to zero without iterating.
///
/// # Errors
///
/// * `OverflowDetected::DomainError` if `x` is negative.
/// * Any error propagated from `decimal_compute_div`.
pub fn decimal_sqrt(x: ComputeStorage) -> Result<ComputeStorage, OverflowDetected> {
    if decimal_compute_is_negative(&x) {
        return Err(OverflowDetected::DomainError);
    }
    if decimal_compute_is_zero(&x) {
        return Ok(decimal_compute_zero());
    }

    // Initial guess: 2^((bits(x) + dp_bits) / 2). 33219 / 10000 approximates
    // log2(10), so `dp_bits` estimates how many bits DECIMAL_COMPUTE_DP decimal
    // digits occupy. Presumably x is scaled by 10^DP (confirm against
    // decimal_compute), which would make this a power-of-two estimate of the
    // scaled root. x > 0 here, so bit_length(&x) >= 1.
    let dp_bits = (DECIMAL_COMPUTE_DP as u32) * 33219 / 10000;
    let target_bits = (bit_length(&x) + dp_bits) / 2;
    let guess = shift_one_left(target_bits);

    // Defensive: never start Newton from a non-positive value (x / y would
    // divide by zero or flip sign). Fall back to one half.
    let mut y = if is_compute_zero_or_neg(&guess) {
        decimal_compute_halve(decimal_compute_one())
    } else {
        guess
    };

    const MAX_ITERATIONS: u32 = 200;
    let mut prev = decimal_compute_zero();
    for _ in 0..MAX_ITERATIONS {
        let x_div_y = decimal_compute_div(x, y)?;
        let new_y = decimal_compute_halve(decimal_compute_add(y, x_div_y));

        // Fixed point reached.
        if decimal_compute_cmp(&new_y, &y) == std::cmp::Ordering::Equal {
            return Ok(new_y);
        }

        // Two-cycle: new_y repeats the value before y. With truncating division
        // and halving (presumably what decimal_compute_* do — confirm), integer
        // Newton can alternate between two adjacent values. Which one is seen
        // first depends on the approach path. Return the smaller so the result
        // is rounded in a consistent direction.
        if decimal_compute_cmp(&new_y, &prev) == std::cmp::Ordering::Equal {
            let smaller = if decimal_compute_cmp(&new_y, &y) == std::cmp::Ordering::Less {
                new_y
            } else {
                y
            };
            return Ok(smaller);
        }

        prev = y;
        y = new_y;
    }

    // Iteration budget exhausted: return the latest estimate.
    Ok(y)
}
/// Number of significant bits in `v`: the index of the highest set bit plus one.
///
/// Returns 0 for zero.
///
/// * Native-integer formats (`q16_16`, `q32_32`): non-positive values also
///   yield 0.
/// * Word-array formats: `words[0]` is the least-significant 64-bit word (word
///   `idx` contributes `idx * 64` bits). The scan runs from the top word
///   downwards.
fn bit_length(v: &ComputeStorage) -> u32 {
    #[cfg(table_format = "q16_16")]
    {
        match *v {
            n if n > 0 => 64 - n.leading_zeros(),
            _ => 0,
        }
    }
    #[cfg(table_format = "q32_32")]
    {
        match *v {
            n if n > 0 => 128 - n.leading_zeros(),
            _ => 0,
        }
    }
    #[cfg(table_format = "q64_64")]
    {
        (0..4usize)
            .rev()
            .find(|&idx| v.words[idx] != 0)
            .map_or(0, |idx| (idx as u32) * 64 + (64 - v.words[idx].leading_zeros()))
    }
    #[cfg(table_format = "q128_128")]
    {
        (0..8usize)
            .rev()
            .find(|&idx| v.words[idx] != 0)
            .map_or(0, |idx| (idx as u32) * 64 + (64 - v.words[idx].leading_zeros()))
    }
    #[cfg(table_format = "q256_256")]
    {
        (0..16usize)
            .rev()
            .find(|&idx| v.words[idx] != 0)
            .map_or(0, |idx| (idx as u32) * 64 + (64 - v.words[idx].leading_zeros()))
    }
}
/// Builds `2^n` in `ComputeStorage`, clamped so the result stays positive.
///
/// * `q16_16` / `q32_32` (native `i64` / `i128`): for `n` at or past the sign
///   bit (63 / 127) the result saturates to the type's `MAX`.
/// * Wide formats: the shift amount is capped at 254 / 510 / 1020 bits, so
///   larger `n` yields that capped power of two.
fn shift_one_left(n: u32) -> ComputeStorage {
    #[cfg(table_format = "q16_16")]
    {
        if n < 63 { 1i64 << n } else { i64::MAX }
    }
    #[cfg(table_format = "q32_32")]
    {
        if n < 127 { 1i128 << n } else { i128::MAX }
    }
    #[cfg(table_format = "q64_64")]
    {
        use crate::fixed_point::i256::I256;
        let shift = n.min(254) as usize;
        I256::from_i128(1) << shift
    }
    #[cfg(table_format = "q128_128")]
    {
        use crate::fixed_point::i512::I512;
        let shift = n.min(510) as usize;
        I512::from_i128(1) << shift
    }
    #[cfg(table_format = "q256_256")]
    {
        use crate::fixed_point::I1024;
        let shift = n.min(1020) as usize;
        I1024::from_i128(1) << shift
    }
}
fn is_compute_zero_or_neg(v: &ComputeStorage) -> bool {
use super::decimal_compute::{decimal_compute_is_zero, decimal_compute_is_negative};
decimal_compute_is_zero(v) || decimal_compute_is_negative(v)
}
#[cfg(all(test, table_format = "q64_64"))]
mod tests {
    use super::*;
    use super::super::decimal_compute::decimal_compute_from_int;
    use crate::fixed_point::i256::I256;

    /// Parses a string of ASCII decimal digits (no sign, no separators) into an `I256`.
    /// Panics with "non-digit" on any other character.
    fn parse_decimal_str(s: &str) -> I256 {
        let ten = I256::from_i128(10);
        s.chars().fold(I256::from_i128(0), |acc, ch| {
            let digit = ch.to_digit(10).expect("non-digit");
            acc * ten + I256::from_i128(digit as i128)
        })
    }

    /// Absolute difference `|a - b|`.
    fn abs_diff(a: I256, b: I256) -> I256 {
        if a > b { a - b } else { b - a }
    }

    #[test]
    fn sqrt_zero() {
        assert_eq!(decimal_sqrt(decimal_compute_zero()).unwrap(), decimal_compute_zero());
    }

    #[test]
    fn sqrt_one() {
        assert_eq!(decimal_sqrt(decimal_compute_one()).unwrap(), decimal_compute_one());
    }

    #[test]
    fn sqrt_four_is_two() {
        let root = decimal_sqrt(decimal_compute_from_int(4)).unwrap();
        assert_eq!(root, decimal_compute_from_int(2));
    }

    #[test]
    fn sqrt_nine_is_three() {
        let root = decimal_sqrt(decimal_compute_from_int(9)).unwrap();
        assert_eq!(root, decimal_compute_from_int(3));
    }

    // Reference digits for sqrt(2) in raw storage units. The test name suggests
    // they came from mpmath; the scale is assumed to match DECIMAL_COMPUTE_DP.
    #[test]
    fn sqrt_two_mpmath() {
        let expected = parse_decimal_str("141421356237309504880168872420969807857");
        let result = decimal_sqrt(decimal_compute_from_int(2)).unwrap();
        let diff = abs_diff(result, expected);
        assert!(
            diff < I256::from_i128(1000),
            "sqrt(2) precision: got={:?} expected={:?} diff={:?}",
            result, expected, diff
        );
    }

    // Reference digits for sqrt(3) in raw storage units (same assumptions as above).
    #[test]
    fn sqrt_three_mpmath() {
        let expected = parse_decimal_str("173205080756887729352744634150587236694");
        let result = decimal_sqrt(decimal_compute_from_int(3)).unwrap();
        let diff = abs_diff(result, expected);
        assert!(
            diff < I256::from_i128(1000),
            "sqrt(3) precision: got={:?} expected={:?} diff={:?}",
            result, expected, diff
        );
    }
}