use std::ops::{Add, Sub, Mul, Shl, Shr, BitOr, BitAnd, Neg};
use std::cmp::{Ord, PartialOrd, Ordering};
use crate::fixed_point::i256::I256;
/// A 512-bit signed integer in two's-complement form.
///
/// The value is stored as eight 64-bit limbs in little-endian order: `words[0]` holds the
/// least-significant bits and bit 63 of `words[7]` is the sign bit.
#[derive(Debug, Clone, Copy)]
pub struct I512 {
    /// Little-endian limbs; `words[0]` is least significant, `words[7]` most significant.
    pub words: [u64; 8],
}
impl I512 {
    /// Returns the value `0`.
    #[inline(always)]
    pub const fn zero() -> Self {
        I512 { words: [0; 8] }
    }

    /// Returns the value `1`.
    #[inline(always)]
    pub const fn one() -> Self {
        I512 { words: [1, 0, 0, 0, 0, 0, 0, 0] }
    }

    /// Returns the largest representable value, `2^511 - 1`.
    #[inline(always)]
    pub const fn max_value() -> Self {
        I512 {
            words: [
                0xFFFF_FFFF_FFFF_FFFF,
                0xFFFF_FFFF_FFFF_FFFF,
                0xFFFF_FFFF_FFFF_FFFF,
                0xFFFF_FFFF_FFFF_FFFF,
                0xFFFF_FFFF_FFFF_FFFF,
                0xFFFF_FFFF_FFFF_FFFF,
                0xFFFF_FFFF_FFFF_FFFF,
                0x7FFF_FFFF_FFFF_FFFF,
            ],
        }
    }

    /// Returns the smallest representable value, `-2^511`.
    #[inline(always)]
    pub const fn min_value() -> Self {
        I512 {
            words: [0, 0, 0, 0, 0, 0, 0, 0x8000_0000_0000_0000],
        }
    }

    /// Builds a value from raw little-endian limbs (`words[0]` is least significant).
    #[inline(always)]
    pub const fn from_words(words: [u64; 8]) -> Self {
        I512 { words }
    }

    /// Sign-extends an `I256` to 512 bits.
    #[inline(always)]
    pub const fn from_i256(value: I256) -> Self {
        // The sign of the I256 is the top bit of its most-significant limb.
        let is_negative = (value.words[3] as i64) < 0;
        let sign_extend = if is_negative { u64::MAX } else { 0 };
        I512 {
            words: [
                value.words[0],
                value.words[1],
                value.words[2],
                value.words[3],
                sign_extend,
                sign_extend,
                sign_extend,
                sign_extend,
            ],
        }
    }

    /// Returns `true` if every limb is zero.
    #[inline(always)]
    pub fn is_zero(&self) -> bool {
        self.words.iter().all(|&word| word == 0)
    }

    /// Returns `true` if the sign bit (bit 63 of `words[7]`) is set.
    #[inline(always)]
    pub fn is_negative(self) -> bool {
        (self.words[7] & 0x8000_0000_0000_0000) != 0
    }

    /// Sign-extends an `i128` to 512 bits.
    #[inline(always)]
    pub const fn from_i128(value: i128) -> Self {
        let sign_extend = if value < 0 { u64::MAX } else { 0 };
        I512 {
            words: [
                value as u64,
                (value >> 64) as u64,
                sign_extend,
                sign_extend,
                sign_extend,
                sign_extend,
                sign_extend,
                sign_extend,
            ],
        }
    }

    /// Zero-extends a `u128` to 512 bits.
    #[inline(always)]
    pub const fn from_u128(value: u128) -> Self {
        I512 {
            words: [value as u64, (value >> 64) as u64, 0, 0, 0, 0, 0, 0],
        }
    }

    /// Truncates to the low 256 bits without a range check.
    ///
    /// Use [`I512::fits_in_i256`] or [`I512::as_i256_saturating`] when the value may be
    /// out of range.
    #[inline(always)]
    pub fn as_i256(self) -> I256 {
        I256::from_words([self.words[0], self.words[1], self.words[2], self.words[3]])
    }

    /// Converts to `I256`, clamping out-of-range values to `I256::min_value()` /
    /// `I256::max_value()` according to the sign.
    #[inline(always)]
    pub fn as_i256_saturating(self) -> I256 {
        if self.fits_in_i256() {
            self.as_i256()
        } else if (self.words[7] as i64) < 0 {
            I256::min_value()
        } else {
            I256::max_value()
        }
    }

    /// Truncates to the low 128 bits without a range check (see [`I512::fits_in_i128`]).
    #[inline(always)]
    pub fn as_i128(self) -> i128 {
        ((self.words[1] as i128) << 64) | (self.words[0] as i128)
    }

    /// Returns `true` if the value is representable as an `I256`, i.e. limbs 4..8 are the
    /// sign extension of bit 255.
    #[inline(always)]
    pub fn fits_in_i256(self) -> bool {
        let sign_bit_i256 = (self.words[3] as i64) < 0;
        let expected_high = if sign_bit_i256 { u64::MAX } else { 0 };
        self.words[4..].iter().all(|&word| word == expected_high)
    }

    /// Returns `true` if the value is representable as an `i128`, i.e. limbs 2..8 are the
    /// sign extension of bit 127.
    #[inline(always)]
    pub fn fits_in_i128(self) -> bool {
        let is_negative = (self.words[1] as i64) < 0;
        let expected_high = if is_negative { u64::MAX } else { 0 };
        self.words[2..].iter().all(|&word| word == expected_high)
    }

    /// Serializes to 64 little-endian bytes (limb 0 first, each limb little-endian).
    #[inline(always)]
    pub fn to_bytes_le(self) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(64);
        for word in self.words.iter() {
            bytes.extend_from_slice(&word.to_le_bytes());
        }
        bytes
    }

    /// Deserializes from the 64-byte little-endian layout produced by
    /// [`I512::to_bytes_le`].
    ///
    /// # Panics
    ///
    /// Panics if `bytes.len() != 64`.
    #[inline(always)]
    pub fn from_bytes_le(bytes: &[u8]) -> Self {
        assert_eq!(bytes.len(), 64, "I512 requires exactly 64 bytes");
        let mut words = [0u64; 8];
        for (word, chunk) in words.iter_mut().zip(bytes.chunks_exact(8)) {
            let mut buf = [0u8; 8];
            buf.copy_from_slice(chunk);
            *word = u64::from_le_bytes(buf);
        }
        I512 { words }
    }

    /// Addition that returns `None` on signed overflow.
    #[inline(always)]
    pub fn checked_add(self, rhs: I512) -> Option<I512> {
        let result = self + rhs;
        let self_negative = (self.words[7] as i64) < 0;
        let rhs_negative = (rhs.words[7] as i64) < 0;
        let result_negative = (result.words[7] as i64) < 0;
        // Overflow is only possible when both operands share a sign and the wrapped
        // result has the opposite sign.
        if (self_negative == rhs_negative) && (self_negative != result_negative) {
            None
        } else {
            Some(result)
        }
    }

    /// Subtraction that returns `None` on signed overflow.
    #[inline(always)]
    pub fn checked_sub(self, rhs: I512) -> Option<I512> {
        let result = self - rhs;
        let self_negative = (self.words[7] as i64) < 0;
        let rhs_negative = (rhs.words[7] as i64) < 0;
        let result_negative = (result.words[7] as i64) < 0;
        // Overflow is only possible when the operands differ in sign and the wrapped
        // result's sign differs from the minuend's.
        if (self_negative != rhs_negative) && (result_negative != self_negative) {
            None
        } else {
            Some(result)
        }
    }

    /// Negation that returns `None` for `min_value()`, whose negation is not representable.
    #[inline(always)]
    pub fn checked_neg(self) -> Option<I512> {
        if self == I512::min_value() {
            None
        } else {
            Some(-self)
        }
    }

    /// Identity conversion; returns `self` unchanged.
    #[inline(always)]
    pub fn as_i512(self) -> I512 {
        self
    }

    /// Multiplication that returns `None` when the exact product does not fit in `I512`.
    ///
    /// The full product is computed at 1024 bits. It fits iff the high half (limbs 8..16)
    /// is exactly the sign extension of the low half, i.e. every high limb equals `0` when
    /// bit 511 is clear and `u64::MAX` when it is set. The expected fill must come from
    /// limb 7 (the would-be result's sign), not limb 15. Otherwise products such as
    /// `max_value() * 2`, whose high half is all zeros but whose bit 511 is set, are
    /// wrongly accepted as a negative result.
    #[inline(always)]
    pub fn checked_mul(self, rhs: I512) -> Option<I512> {
        let wide = self.multiply_i512(&rhs);
        let low_is_negative = (wide.words[7] as i64) < 0;
        let expected_high = if low_is_negative { u64::MAX } else { 0 };
        let fits = (8..16).all(|i| wide.words[i] == expected_high);
        if fits {
            Some(I512::from_words([
                wide.words[0],
                wide.words[1],
                wide.words[2],
                wide.words[3],
                wide.words[4],
                wide.words[5],
                wide.words[6],
                wide.words[7],
            ]))
        } else {
            None
        }
    }
}
impl PartialEq for I512 {
    /// Two values are equal exactly when every limb matches.
    #[inline(always)]
    fn eq(&self, other: &Self) -> bool {
        self.words
            .iter()
            .zip(other.words.iter())
            .all(|(lhs, rhs)| lhs == rhs)
    }
}
/// Limb-wise equality is reflexive, symmetric and transitive, so `Eq` holds.
impl Eq for I512 {}
impl PartialOrd for I512 {
    /// Delegates to the total order defined by [`Ord`].
    #[inline(always)]
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for I512 {
    /// Signed two's-complement comparison.
    ///
    /// Flipping the sign bit of the most-significant limb maps signed order onto unsigned
    /// order. Ties on that limb are broken by comparing the remaining limbs as unsigned
    /// values from most to least significant.
    #[inline(always)]
    fn cmp(&self, other: &Self) -> Ordering {
        const SIGN_BIT: u64 = 0x8000_0000_0000_0000;
        let self_top = self.words[7] ^ SIGN_BIT;
        let other_top = other.words[7] ^ SIGN_BIT;
        self_top
            .cmp(&other_top)
            .then_with(|| self.words[..7].iter().rev().cmp(other.words[..7].iter().rev()))
    }
}
impl Add for I512 {
    type Output = Self;
    /// Wrapping (mod 2^512) addition with the carry rippling from limb 0 upward.
    #[inline(always)]
    fn add(self, rhs: Self) -> Self {
        let mut sum = [0u64; 8];
        let mut carried = false;
        for (out, (&x, &y)) in sum
            .iter_mut()
            .zip(self.words.iter().zip(rhs.words.iter()))
        {
            let (partial, overflow_a) = x.overflowing_add(y);
            let (total, overflow_b) = partial.overflowing_add(carried as u64);
            *out = total;
            carried = overflow_a || overflow_b;
        }
        I512 { words: sum }
    }
}
impl Sub for I512 {
    type Output = Self;
    /// Wrapping (mod 2^512) subtraction with the borrow rippling from limb 0 upward.
    #[inline(always)]
    fn sub(self, rhs: Self) -> Self {
        let mut diff = [0u64; 8];
        let mut borrowed = false;
        for i in 0..8 {
            let (step, under_a) = self.words[i].overflowing_sub(rhs.words[i]);
            let (value, under_b) = step.overflowing_sub(borrowed as u64);
            diff[i] = value;
            borrowed = under_a || under_b;
        }
        I512 { words: diff }
    }
}
impl std::fmt::Display for I512 {
    /// Formats the value in decimal.
    ///
    /// * Values that fit in `i128` are printed as plain integers.
    /// * Values that fit in `I256` but not `i128` are printed as `"<value>(I256)"`.
    /// * Wider values are printed as their exact decimal expansion. Previously they were
    ///   clamped to the `I256` bounds, which displayed a different number than the one
    ///   stored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.fits_in_i128() {
            return write!(f, "{}", self.as_i128());
        }
        if self.fits_in_i256() {
            return write!(f, "{}(I256)", self.as_i256());
        }

        // 10^19 is the largest power of ten below 2^64, so each pass of the limb-wise
        // long division below peels off 19 decimal digits.
        const CHUNK: u128 = 10_000_000_000_000_000_000;
        // |value| <= 2^511 < 10^(19 * 9), so nine chunks always suffice.
        const MAX_CHUNKS: usize = 9;

        let negative = (self.words[7] as i64) < 0;
        // Unsigned magnitude. Two's-complement negation of the minimum value yields the
        // bit pattern of 2^511, which is the correct unsigned magnitude.
        let mut magnitude = self.words;
        if negative {
            let mut carry = true;
            for limb in magnitude.iter_mut() {
                let (value, overflow) = (!*limb).overflowing_add(carry as u64);
                *limb = value;
                carry = overflow;
            }
        }

        // Repeatedly divide the magnitude by 10^19, collecting remainders (least
        // significant chunk first).
        let mut chunks = [0u64; MAX_CHUNKS];
        let mut count = 0;
        while magnitude.iter().any(|&limb| limb != 0) {
            let mut remainder: u128 = 0;
            for limb in magnitude.iter_mut().rev() {
                // remainder < CHUNK, so the quotient of `current / CHUNK` fits in a u64.
                let current = (remainder << 64) | (*limb as u128);
                *limb = (current / CHUNK) as u64;
                remainder = current % CHUNK;
            }
            chunks[count] = remainder as u64;
            count += 1;
        }

        if negative {
            f.write_str("-")?;
        }
        // This branch is only reached for values outside i128, so `count >= 1`. The most
        // significant chunk is unpadded; every lower chunk is zero-padded to 19 digits.
        write!(f, "{}", chunks[count - 1])?;
        for chunk in chunks[..count - 1].iter().rev() {
            write!(f, "{:019}", chunk)?;
        }
        Ok(())
    }
}
impl Neg for I512 {
    type Output = Self;
    /// Wrapping two's-complement negation: `-x == !x + 1` (mod 2^512), so the minimum
    /// value negates to itself.
    #[inline(always)]
    fn neg(self) -> Self {
        let mut words = [0u64; 8];
        for (dst, &src) in words.iter_mut().zip(self.words.iter()) {
            *dst = !src;
        }
        // Add one; the increment stops rippling at the first limb that does not wrap.
        for limb in words.iter_mut() {
            let (bumped, wrapped) = limb.overflowing_add(1);
            *limb = bumped;
            if !wrapped {
                break;
            }
        }
        I512 { words }
    }
}
impl Shl<usize> for I512 {
    type Output = Self;
    /// Logical left shift; shifts of 512 or more produce zero.
    #[inline(always)]
    fn shl(self, shift: usize) -> Self {
        if shift == 0 {
            return self;
        }
        if shift >= 512 {
            return I512::zero();
        }
        let (limbs, bits) = (shift / 64, shift % 64);
        let mut out = [0u64; 8];
        for (dst, slot) in out.iter_mut().enumerate().skip(limbs) {
            let src = dst - limbs;
            // Bits spilling over from the next-lower source limb (none when the shift is
            // limb-aligned, which also avoids an invalid `>> 64`).
            let carried_in = if bits != 0 && src > 0 {
                self.words[src - 1] >> (64 - bits)
            } else {
                0
            };
            *slot = (self.words[src] << bits) | carried_in;
        }
        I512 { words: out }
    }
}
impl Shr<usize> for I512 {
    type Output = Self;
    /// Arithmetic right shift: vacated high bits are filled with copies of the sign bit,
    /// and shifts of 512 or more yield `0` or `-1` depending on the sign.
    #[inline(always)]
    fn shr(self, shift: usize) -> Self {
        if shift == 0 {
            return self;
        }
        let fill = if (self.words[7] as i64) < 0 { u64::MAX } else { 0 };
        if shift >= 512 {
            return I512 { words: [fill; 8] };
        }
        let limbs = shift / 64;
        let bits = shift % 64;
        // Source limb `k`; positions past the top read as the sign fill.
        let source = |k: usize| -> u64 {
            if k < 8 {
                self.words[k]
            } else {
                fill
            }
        };
        let mut out = [fill; 8];
        for dst in 0..(8 - limbs) {
            let low = source(dst + limbs);
            out[dst] = if bits == 0 {
                low
            } else {
                (low >> bits) | (source(dst + limbs + 1) << (64 - bits))
            };
        }
        I512 { words: out }
    }
}
impl BitOr for I512 {
    type Output = Self;
    /// Limb-wise bitwise OR.
    #[inline(always)]
    fn bitor(self, rhs: Self) -> Self {
        let mut words = self.words;
        for (dst, &src) in words.iter_mut().zip(rhs.words.iter()) {
            *dst |= src;
        }
        I512 { words }
    }
}
impl BitAnd for I512 {
    type Output = Self;
    /// Limb-wise bitwise AND.
    #[inline(always)]
    fn bitand(self, rhs: Self) -> Self {
        let mut words = self.words;
        for (dst, &src) in words.iter_mut().zip(rhs.words.iter()) {
            *dst &= src;
        }
        I512 { words }
    }
}
impl Mul for I512 {
    type Output = Self;
    /// Multiplies two values.
    ///
    /// When both operands fit in `i128`, the product is formed by
    /// `mul_i128_to_i256` and sign-extended back to 512 bits. Otherwise it falls back to
    /// the limb-wise `mul_simple` routine, which keeps the low 512 bits (wrapping).
    #[inline(always)]
    fn mul(self, rhs: Self) -> Self {
        let small_operands = self.fits_in_i128() && rhs.fits_in_i128();
        if !small_operands {
            return self.mul_simple(rhs);
        }
        let (lhs_native, rhs_native) = (self.as_i128(), rhs.as_i128());
        I512::from_i256(crate::fixed_point::i256::mul_i128_to_i256(lhs_native, rhs_native))
    }
}
impl std::ops::Div for I512 {
    type Output = Self;
    /// Signed division; returns the quotient half of `divmod_i512_by_i512`, which
    /// saturates instead of panicking when `rhs` is zero.
    #[inline(always)]
    fn div(self, rhs: Self) -> Self {
        let (quotient, _remainder) = divmod_i512_by_i512(self, rhs);
        quotient
    }
}
impl std::ops::Rem for I512 {
    type Output = Self;
    /// Signed remainder; returns the remainder half of `divmod_i512_by_i512`
    /// (zero when `rhs` is zero).
    #[inline(always)]
    fn rem(self, rhs: Self) -> Self {
        let (_quotient, remainder) = divmod_i512_by_i512(self, rhs);
        remainder
    }
}
impl I512 {
#[inline(always)]
fn mul_simple(self, rhs: Self) -> Self {
let mut result = I512::zero();
let self_neg = (self.words[7] as i64) < 0;
let rhs_neg = (rhs.words[7] as i64) < 0;
let result_neg = self_neg != rhs_neg;
let self_abs = if self_neg { -self } else { self };
let rhs_abs = if rhs_neg { -rhs } else { rhs };
for i in 0..8 {
if rhs_abs.words[i] == 0 { continue; }
let mut carry = 0u64;
for j in 0..8 {
if i + j >= 8 { break; }
let prod = (self_abs.words[j] as u128) * (rhs_abs.words[i] as u128) + (carry as u128) + (result.words[i + j] as u128);
result.words[i + j] = prod as u64;
carry = (prod >> 64) as u64;
}
}
if result_neg {
-result
} else {
result
}
}
}
impl From<I256> for I512 {
#[inline(always)]
fn from(value: I256) -> Self {
I512::from_i256(value)
}
}
impl From<i128> for I512 {
#[inline(always)]
fn from(value: i128) -> Self {
I512::from_i128(value)
}
}
impl From<i64> for I512 {
#[inline(always)]
fn from(value: i64) -> Self {
I512::from_i128(value as i128)
}
}
impl I512 {
    /// Returns `1.0` in Q256.256 fixed point, i.e. `2^256` (bit 0 of `words[4]`).
    #[inline(always)]
    pub const fn one_q256_256() -> Self {
        I512 { words: [0, 0, 0, 0, 1, 0, 0, 0] }
    }

    /// Multiplies two Q256.256 fixed-point values.
    ///
    /// The 1024-bit product is shifted right by the 256 fractional bits. If bit 255 of the
    /// full product (the most significant discarded bit) is set, one unit in the last
    /// place is added. The result is narrowed back to 512 bits without an overflow check.
    /// NOTE(review): relies on `I1024::as_i512` truncating and on `I1024`'s `>>` being
    /// arithmetic; confirm.
    #[inline(always)]
    pub fn mul_q256_256(a: I512, b: I512) -> I512 {
        let a_wide = crate::fixed_point::I1024::from_i512(a);
        let b_wide = crate::fixed_point::I1024::from_i512(b);
        let full_product = a_wide * b_wide;
        let shifted = full_product >> 256;
        let round_bit = (full_product >> 255).words[0] & 1;
        let result = shifted.as_i512();
        if round_bit != 0 {
            result + I512::from_i128(1)
        } else {
            result
        }
    }

    /// Converts a Q64.64 value to Q256.256 by shifting left by 192 (= 256 - 64) bits.
    #[inline(always)]
    pub fn from_q64_64(value: i128) -> Self {
        I512::from_i128(value) << 192
    }

    /// Converts a Q256.256 value to Q64.64.
    ///
    /// The arithmetic right shift by 192 bits rounds toward negative infinity. The result
    /// is then truncated to the low 128 bits; out-of-range values are not detected.
    #[inline(always)]
    pub fn to_q64_64(self) -> i128 {
        let shifted = self >> 192;
        shifted.as_i128()
    }

    /// Full-width product: both operands are widened to `I1024` via `I1024::from_i512`
    /// and multiplied there.
    #[inline(always)]
    pub fn multiply_i512(&self, other: &I512) -> crate::fixed_point::I1024 {
        let a_i1024 = crate::fixed_point::I1024::from_i512(*self);
        let b_i1024 = crate::fixed_point::I1024::from_i512(*other);
        a_i1024 * b_i1024
    }

    /// Exact signed 512x512 -> 1024-bit product computed limb by limb.
    ///
    /// The operands are first multiplied as unsigned integers `A` and `B`. For two's
    /// complement, `a = A - 2^512*[a<0]` and `b = B - 2^512*[b<0]`, so
    /// `a*b = A*B - 2^512*([a<0]*B + [b<0]*A)` (mod 2^1024). The correction is applied
    /// to the high half below. Without it, any negative operand produced a wrong signed
    /// result (e.g. `-1 * 1` came out as `+2^512 - 1`).
    #[inline(always)]
    pub fn mul_to_i1024(&self, other: I512) -> crate::fixed_point::I1024 {
        use crate::fixed_point::I1024;
        let mut result = I1024::zero();

        // Unsigned schoolbook product. Row i writes limbs i..i+8 and then stores its final
        // carry in limb i + 8, which no earlier row has touched.
        for i in 0..8 {
            let mut carry = 0u64;
            for j in 0..8 {
                let product = (self.words[i] as u128) * (other.words[j] as u128)
                    + (result.words[i + j] as u128)
                    + (carry as u128);
                result.words[i + j] = product as u64;
                carry = (product >> 64) as u64;
            }
            result.words[i + 8] = carry;
        }

        // Signed correction: subtract the other operand from the high half for each
        // negative operand. A borrow out of limb 15 is discarded (mod 2^1024).
        let corrections = [(self.is_negative(), other.words), (other.is_negative(), self.words)];
        for &(negative, subtrahend) in corrections.iter() {
            if !negative {
                continue;
            }
            let mut borrow = false;
            for k in 0..8 {
                let (step, under_a) = result.words[k + 8].overflowing_sub(subtrahend[k]);
                let (value, under_b) = step.overflowing_sub(borrow as u64);
                result.words[k + 8] = value;
                borrow = under_a || under_b;
            }
        }
        result
    }
}
/// Signed division of two `I512` values, returning `(quotient, remainder)`.
///
/// The quotient is truncated toward zero and a non-zero remainder takes the sign of the
/// dividend (the same convention as Rust's primitive `/` and `%`).
///
/// Special cases:
/// * A zero divisor does not panic. The quotient saturates to `I512::min_value()` for a
///   negative dividend and to `I512::max_value()` otherwise, and the remainder is zero.
/// * A divisor of `-1` returns `-dividend` with remainder zero. For `I512::min_value()`
///   the quotient wraps back to `min_value()`.
pub fn divmod_i512_by_i512(dividend: I512, divisor: I512) -> (I512, I512) {
    if divisor.is_zero() {
        let saturated_quotient = if dividend.is_negative() {
            I512::min_value()
        } else {
            I512::max_value()
        };
        return (saturated_quotient, I512::zero());
    }

    // x / -1 is exactly -x with remainder 0. Handling it up front keeps `i128::MIN / -1`
    // out of the native fast path below, where `/` and `%` would panic on overflow even
    // though the quotient 2^127 fits in I512. It also keeps `I256::MIN / -1`, whose
    // quotient 2^255 does not fit in I256, out of the I256 path.
    if divisor == I512::from_words([u64::MAX; 8]) {
        return (negate_i512(dividend), I512::zero());
    }

    // Fast path: both operands fit in a native i128.
    if dividend.fits_in_i128() && divisor.fits_in_i128() {
        let dividend_i128 = dividend.as_i128();
        let divisor_i128 = divisor.as_i128();
        let quotient = dividend_i128 / divisor_i128;
        let remainder = dividend_i128 % divisor_i128;
        return (I512::from_i128(quotient), I512::from_i128(remainder));
    }

    // Medium path: delegate to the I256 routine.
    if dividend.fits_in_i256() && divisor.fits_in_i256() {
        let dividend_i256 = dividend.as_i256();
        let divisor_i256 = divisor.as_i256();
        let (quot, rem) =
            crate::fixed_point::i256::divmod_i256_by_i256(dividend_i256, divisor_i256);
        return (I512::from_i256(quot), I512::from_i256(rem));
    }

    // General path: bit-by-bit restoring long division on the unsigned magnitudes.
    // Negating I512::MIN yields the bit pattern of 2^511, which is the correct
    // unsigned magnitude for these helpers.
    let dividend_negative = dividend.is_negative();
    let divisor_negative = divisor.is_negative();
    let quotient_negative = dividend_negative != divisor_negative;
    let abs_dividend = if dividend_negative {
        negate_i512(dividend)
    } else {
        dividend
    };
    let abs_divisor = if divisor_negative {
        negate_i512(divisor)
    } else {
        divisor
    };

    let mut quotient_words = [0u64; 8];
    let mut remainder = I512::zero();
    for word_idx in (0..8).rev() {
        for bit_idx in (0..64).rev() {
            // remainder < abs_divisor <= 2^511, so shifting left by one cannot overflow.
            remainder = shift_left_i512_by_1(remainder);
            let dividend_bit = (abs_dividend.words[word_idx] >> bit_idx) & 1;
            remainder.words[0] |= dividend_bit;
            if compare_i512_unsigned(remainder, abs_divisor) >= 0 {
                remainder = subtract_i512_unsigned(remainder, abs_divisor);
                quotient_words[word_idx] |= 1u64 << bit_idx;
            }
        }
    }

    let mut quotient = I512 { words: quotient_words };
    if quotient_negative && !quotient.is_zero() {
        quotient = negate_i512(quotient);
    }
    // Truncating division: the remainder follows the dividend's sign.
    if dividend_negative && !remainder.is_zero() {
        remainder = negate_i512(remainder);
    }
    (quotient, remainder)
}
/// Wrapping two's-complement negation (`-MIN == MIN`).
///
/// Limbs below the lowest non-zero limb stay zero. That limb is negated on its own
/// (`!w + 1` cannot carry out of it because `w != 0`), and every limb above it is
/// bitwise-inverted. This is equivalent to computing `!x + 1` across all limbs.
#[inline(always)]
fn negate_i512(value: I512) -> I512 {
    let mut words = value.words;
    if let Some(lowest) = words.iter().position(|&w| w != 0) {
        words[lowest] = words[lowest].wrapping_neg();
        for w in words[lowest + 1..].iter_mut() {
            *w = !*w;
        }
    }
    // An all-zero input is its own negation and falls through unchanged.
    I512 { words }
}
/// Shifts all 512 bits left by one position; the top bit is discarded.
#[inline(always)]
fn shift_left_i512_by_1(value: I512) -> I512 {
    let src = value.words;
    let mut words = [0u64; 8];
    words[0] = src[0] << 1;
    for i in 1..8 {
        // Pull in the bit that falls off the top of the next-lower limb.
        words[i] = (src[i] << 1) | (src[i - 1] >> 63);
    }
    I512 { words }
}
#[inline(always)]
fn compare_i512_unsigned(a: I512, b: I512) -> i8 {
for i in (0..8).rev() {
if a.words[i] > b.words[i] {
return 1;
} else if a.words[i] < b.words[i] {
return -1;
}
}
0 }
/// Wrapping unsigned 512-bit subtraction `a - b` with limb-wise borrow propagation.
#[inline(always)]
fn subtract_i512_unsigned(a: I512, b: I512) -> I512 {
    let mut out = [0u64; 8];
    let mut borrow = 0u64;
    for ((dst, &lhs), &rhs) in out.iter_mut().zip(a.words.iter()).zip(b.words.iter()) {
        let (step, under_a) = lhs.overflowing_sub(rhs);
        let (value, under_b) = step.overflowing_sub(borrow);
        *dst = value;
        borrow = (under_a || under_b) as u64;
    }
    I512 { words: out }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_i512_basic_operations() {
        let a = I512::from_i128(42);
        let b = I512::from_i128(17);
        assert_eq!((a + b).as_i128(), 59);
        assert_eq!((a - b).as_i128(), 25);
        assert_eq!((a * b).as_i128(), 714);
        // Mixed signs: `/` truncates toward zero and `%` takes the dividend's sign.
        let neg7 = I512::from_i128(-7);
        let two = I512::from_i128(2);
        assert_eq!((neg7 * two).as_i128(), -14);
        assert_eq!((neg7 / two).as_i128(), -3);
        assert_eq!((neg7 % two).as_i128(), -1);
        assert_eq!((-neg7).as_i128(), 7);
    }

    #[test]
    fn test_i512_from_i256() {
        let i256_val = I256::from_i128(0x123456789ABCDEF0_i128);
        let i512_val = I512::from_i256(i256_val);
        assert_eq!(i512_val.as_i256(), i256_val);
        assert_eq!(i512_val.as_i128(), 0x123456789ABCDEF0_i128);
    }

    #[test]
    fn test_i512_shift_operations() {
        let value = I512::from_i128(0xFFFFFFFF00000000_i128);
        let shifted = value >> 32;
        assert_eq!(shifted.as_i128(), 0xFFFFFFFF_i128);
        // Pin the exact result: a bare `!= 1` check would also accept a broken shift
        // (for example one that returned zero).
        let left_shifted = I512::from_i128(1) << 100;
        assert_eq!(left_shifted.as_i128(), 1_i128 << 100);
        // `>>` is an arithmetic shift, so negative values stay negative.
        assert_eq!((I512::from_i128(-8) >> 1).as_i128(), -4);
    }

    #[test]
    fn test_i512_ordering_and_checked_ops() {
        assert!(I512::from_i128(-1) < I512::from_i128(1));
        assert!(I512::min_value() < I512::max_value());
        assert_eq!(I512::max_value().checked_add(I512::one()), None);
        assert_eq!(I512::min_value().checked_sub(I512::one()), None);
        assert_eq!(I512::min_value().checked_neg(), None);
        assert_eq!(
            I512::from_i128(5).checked_add(I512::from_i128(-7)),
            Some(I512::from_i128(-2))
        );
    }

    #[test]
    fn test_display_small_values() {
        assert_eq!(I512::from_i128(-42).to_string(), "-42");
        assert_eq!(I512::zero().to_string(), "0");
    }

    #[test]
    fn test_q256_256_operations() {
        let one = I512::one_q256_256();
        let two = one + one;
        let result = I512::mul_q256_256(one, two);
        assert_eq!(result, two);
    }

    #[test]
    fn test_q64_64_conversion() {
        let q64_val = 42i128 << 64;
        let q256_val = I512::from_q64_64(q64_val);
        let back_to_q64 = q256_val.to_q64_64();
        assert_eq!(back_to_q64, q64_val);
    }

    #[test]
    fn test_bytes_serialization() {
        let value = I512::from_i128(0x123456789ABCDEF0_i128);
        let bytes = value.to_bytes_le();
        assert_eq!(bytes.len(), 64);
        let restored = I512::from_bytes_le(&bytes);
        assert_eq!(value, restored);
    }
}