use alloy_primitives::U256;
use crate::AmmMathError;
/// Encodes the ratio `amount1 / amount0` as a square-root price scaled by 2^96.
///
/// The ratio is first scaled by 2^192 (integer division, truncating), and the
/// integer square root of that quotient is returned. Equal amounts therefore
/// yield exactly `1 << 96`.
///
/// # Errors
///
/// * [`AmmMathError::ZeroAmounts`] when `amount0` is zero.
/// * Any error returned by `crate::full_math::mul_div` on the wide path
///   (taken when `amount1` does not fit in 64 bits).
pub fn encode_sqrt_ratio_x96(amount1: U256, amount0: U256) -> crate::Result<U256> {
    // The denominator must be non-zero for the ratio to exist.
    if amount0.is_zero() {
        return Err(AmmMathError::ZeroAmounts);
    }

    let needs_wide_math = amount1 > U256::from(u64::MAX);
    let scaled_ratio = if needs_wide_math {
        // `amount1 << 192` would drop high bits here, so delegate to `mul_div`.
        // NOTE(review): presumably `mul_div` uses a wider intermediate product; confirm.
        let two_pow_192 = U256::from(1u64) << 192;
        crate::full_math::mul_div(amount1, two_pow_192, amount0)?
    } else {
        // `amount1` is below 2^64, so shifting left by 192 stays within 256 bits.
        (amount1 << 192) / amount0
    };

    Ok(sqrt_u256(scaled_ratio))
}
/// Returns the integer square root of `x`, rounded down.
///
/// Inputs `0` and `1..=3` are answered directly; everything else is refined
/// with Newton's iteration starting from a power of two that is at least as
/// large as the true root, so the estimate only ever decreases.
fn sqrt_u256(x: U256) -> U256 {
    if x.is_zero() {
        U256::ZERO
    } else if x <= U256::from(3u64) {
        U256::from(1u64)
    } else {
        // Number of significant bits in `x`; `x < 2^significant_bits`.
        let significant_bits = 256 - x.leading_zeros() as u32;
        // 2^ceil(bits / 2) is an upper bound on the square root.
        let mut estimate = U256::from(1u64) << significant_bits.div_ceil(2);
        let mut candidate = (estimate + x / estimate) >> 1;
        // Stop as soon as a step no longer lowers the estimate.
        while candidate < estimate {
            estimate = candidate;
            candidate = (estimate + x / estimate) >> 1;
        }
        estimate
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `U256` from a `u64`.
    fn uint(value: u64) -> U256 {
        U256::from(value)
    }

    /// Equal amounts encode to exactly 2^96.
    #[test]
    fn test_encode_1_to_1() {
        let encoded = encode_sqrt_ratio_x96(uint(1), uint(1)).unwrap();
        assert_eq!(encoded, uint(1) << 96);
    }

    /// A 4:1 ratio has square root 2, i.e. 2^97 after scaling.
    #[test]
    fn test_encode_4_to_1() {
        let encoded = encode_sqrt_ratio_x96(uint(4), uint(1)).unwrap();
        assert_eq!(encoded, uint(1) << 97);
    }

    /// A 1:4 ratio has square root 1/2, i.e. 2^95 after scaling.
    #[test]
    fn test_encode_1_to_4() {
        let encoded = encode_sqrt_ratio_x96(uint(1), uint(4)).unwrap();
        assert_eq!(encoded, uint(1) << 95);
    }

    /// A zero `amount0` must be rejected rather than dividing by zero.
    #[test]
    fn test_encode_zero_denominator() {
        let outcome = encode_sqrt_ratio_x96(uint(1), U256::ZERO);
        assert!(outcome.is_err());
    }

    /// Large but equal amounts still encode to exactly 2^96.
    #[test]
    fn test_encode_large_ratio() {
        let amount = uint(1_000_000_000_000_000_000);
        let encoded = encode_sqrt_ratio_x96(amount, amount).unwrap();
        assert_eq!(encoded, uint(1) << 96);
    }

    /// Checks perfect squares plus one non-square input that rounds down.
    #[test]
    fn test_sqrt_u256() {
        // (input, expected integer square root)
        let cases: [(u64, u64); 6] = [(0, 0), (1, 1), (4, 2), (9, 3), (100, 10), (8, 2)];
        for (input, expected) in cases {
            assert_eq!(sqrt_u256(uint(input)), uint(expected));
        }
    }
}