use super::FixedPoint;
use crate::fixed_point::universal::fasc::stack_evaluator::BinaryStorage;
#[cfg(table_format = "q16_16")]
use crate::fixed_point::frac_config;
#[cfg(table_format = "q64_64")]
use crate::fixed_point::I256;
#[cfg(table_format = "q128_128")]
use crate::fixed_point::{I256, I512};
#[cfg(table_format = "q256_256")]
use crate::fixed_point::{I512, I1024};
pub(crate) use crate::fixed_point::universal::fasc::stack_evaluator::ComputeStorage;
use crate::fixed_point::universal::fasc::stack_evaluator::compute::downscale_to_storage;
pub(crate) use crate::fixed_point::universal::fasc::stack_evaluator::compute::sincos_at_compute_tier;
/// Narrows a compute-tier accumulator to the storage tier.
///
/// The conversion is delegated to `downscale_to_storage`. If that helper
/// reports an error, the value is instead shifted right by the format's
/// fractional bit count and narrowed with a plain cast/`as_*` conversion,
/// so this function never fails.
///
/// NOTE(review): the fallback path does no rounding and no range check of
/// its own — confirm that silently narrowing is the intended behaviour when
/// `downscale_to_storage` fails.
#[inline]
pub(crate) fn round_to_storage(acc: ComputeStorage) -> BinaryStorage {
    // Preferred path: let the shared helper do the conversion.
    if let Ok(narrowed) = downscale_to_storage(acc) {
        return narrowed;
    }
    // Fallback path: drop the extra fractional bits and narrow directly.
    #[cfg(table_format = "q64_64")]
    { (acc >> 64u32).as_i128() }
    #[cfg(table_format = "q32_32")]
    { (acc >> 32) as i64 }
    #[cfg(table_format = "q16_16")]
    { (acc >> frac_config::FRAC_BITS) as i32 }
    #[cfg(table_format = "q128_128")]
    { (acc >> 128usize).as_i256() }
    #[cfg(table_format = "q256_256")]
    { (acc >> 256usize).as_i512() }
}
/// Widens a storage-tier raw value into the compute tier.
///
/// The value is first converted to the wider compute type and then shifted
/// left by the format's fractional bit count, which is the same amount that
/// the fallback path of `round_to_storage` shifts right.
#[inline]
pub(crate) fn upscale_to_compute(val: BinaryStorage) -> ComputeStorage {
    #[cfg(table_format = "q64_64")]
    {
        let widened = I256::from_i128(val);
        widened << 64usize
    }
    #[cfg(table_format = "q32_32")]
    {
        let widened = val as i128;
        widened << 32
    }
    #[cfg(table_format = "q16_16")]
    {
        let widened = val as i64;
        widened << frac_config::FRAC_BITS
    }
    #[cfg(table_format = "q128_128")]
    {
        let widened = I512::from_i256(val);
        widened << 128usize
    }
    #[cfg(table_format = "q256_256")]
    {
        let widened = I1024::from_i512(val);
        widened << 256usize
    }
}
/// Dot product of two `FixedPoint` slices, accumulated in the wider compute type.
///
/// Every pairwise product of raw values is summed in the wide accumulator and
/// the total is narrowed once at the end by shifting right by the fractional
/// bit count. Unlike `compute_tier_dot_raw`, the final narrowing here is a
/// direct shift-and-convert and does not go through `round_to_storage`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub fn compute_tier_dot(a: &[FixedPoint], b: &[FixedPoint]) -> FixedPoint {
    assert_eq!(a.len(), b.len(), "compute_tier_dot: length mismatch");
    #[cfg(table_format = "q64_64")]
    {
        let mut sum = I256::zero();
        for (lhs, rhs) in a.iter().zip(b.iter()) {
            let term = I256::from_i128(lhs.raw()) * I256::from_i128(rhs.raw());
            sum = sum + term;
        }
        FixedPoint::from_raw((sum >> 64u32).as_i128())
    }
    #[cfg(table_format = "q32_32")]
    {
        let mut sum: i128 = 0;
        for (lhs, rhs) in a.iter().zip(b.iter()) {
            sum += (lhs.raw() as i128) * (rhs.raw() as i128);
        }
        FixedPoint::from_raw((sum >> 32) as i64)
    }
    #[cfg(table_format = "q16_16")]
    {
        let mut sum: i64 = 0;
        for (lhs, rhs) in a.iter().zip(b.iter()) {
            sum += (lhs.raw() as i64) * (rhs.raw() as i64);
        }
        FixedPoint::from_raw((sum >> frac_config::FRAC_BITS) as i32)
    }
    #[cfg(table_format = "q128_128")]
    {
        let mut sum = I512::zero();
        for (lhs, rhs) in a.iter().zip(b.iter()) {
            let x = lhs.raw();
            let y = rhs.raw();
            let x_neg = x.is_negative();
            let y_neg = y.is_negative();
            // Multiply magnitudes, then restore the sign of the product.
            let mag_x = if x_neg { -x } else { x };
            let mag_y = if y_neg { -y } else { y };
            let wide = mag_x.mul_to_i512(mag_y);
            let term = if x_neg != y_neg { -wide } else { wide };
            sum = sum + term;
        }
        FixedPoint::from_raw((sum >> 128usize).as_i256())
    }
    #[cfg(table_format = "q256_256")]
    {
        let mut sum = I1024::zero();
        for (lhs, rhs) in a.iter().zip(b.iter()) {
            let x = lhs.raw();
            let y = rhs.raw();
            let x_neg = x.is_negative();
            let y_neg = y.is_negative();
            // Multiply magnitudes, then restore the sign of the product.
            let mag_x = if x_neg { -x } else { x };
            let mag_y = if y_neg { -y } else { y };
            let wide = mag_x.mul_to_i1024(mag_y);
            let term = if x_neg != y_neg { -wide } else { wide };
            sum = sum + term;
        }
        FixedPoint::from_raw((sum >> 256usize).as_i512())
    }
}
/// Dot product of two raw storage-tier slices.
///
/// Pairwise products are summed in the wider compute type and the total is
/// converted back to the storage tier once, via `round_to_storage`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
#[inline]
pub(crate) fn compute_tier_dot_raw(a: &[BinaryStorage], b: &[BinaryStorage]) -> BinaryStorage {
    assert_eq!(a.len(), b.len(), "compute_tier_dot_raw: length mismatch");
    #[cfg(table_format = "q64_64")]
    {
        let mut sum = I256::zero();
        for (&x, &y) in a.iter().zip(b.iter()) {
            sum = sum + (I256::from_i128(x) * I256::from_i128(y));
        }
        round_to_storage(sum)
    }
    #[cfg(table_format = "q32_32")]
    {
        let mut sum: i128 = 0;
        for (&x, &y) in a.iter().zip(b.iter()) {
            sum += (x as i128) * (y as i128);
        }
        round_to_storage(sum)
    }
    #[cfg(table_format = "q16_16")]
    {
        let mut sum: i64 = 0;
        for (&x, &y) in a.iter().zip(b.iter()) {
            sum += (x as i64) * (y as i64);
        }
        round_to_storage(sum)
    }
    #[cfg(table_format = "q128_128")]
    {
        let mut sum = I512::zero();
        for (&x, &y) in a.iter().zip(b.iter()) {
            let x_neg = x.is_negative();
            let y_neg = y.is_negative();
            // Multiply magnitudes, then restore the sign of the product.
            let mag_x = if x_neg { -x } else { x };
            let mag_y = if y_neg { -y } else { y };
            let wide = mag_x.mul_to_i512(mag_y);
            sum = sum + if x_neg != y_neg { -wide } else { wide };
        }
        round_to_storage(sum)
    }
    #[cfg(table_format = "q256_256")]
    {
        let mut sum = I1024::zero();
        for (&x, &y) in a.iter().zip(b.iter()) {
            let x_neg = x.is_negative();
            let y_neg = y.is_negative();
            // Multiply magnitudes, then restore the sign of the product.
            let mag_x = if x_neg { -x } else { x };
            let mag_y = if y_neg { -y } else { y };
            let wide = mag_x.mul_to_i1024(mag_y);
            sum = sum + if x_neg != y_neg { -wide } else { wide };
        }
        round_to_storage(sum)
    }
}
/// Computes `init - Σ a[i]·b[i]` and returns it at the storage tier.
///
/// The subtraction chain is evaluated by `compute_tier_sub_dot_compute` in the
/// wide compute type; the result is converted once with `round_to_storage`.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub(crate) fn compute_tier_sub_dot_raw(
    init: BinaryStorage,
    a: &[BinaryStorage],
    b: &[BinaryStorage],
) -> BinaryStorage {
    assert_eq!(a.len(), b.len(), "compute_tier_sub_dot_raw: length mismatch");
    round_to_storage(compute_tier_sub_dot_compute(init, a, b))
}
/// Computes `init - Σ a[i]·b[i]` and returns it unrounded in the compute type.
///
/// `init` is widened and shifted left by the fractional bit count so that it
/// sits at the same scale as the raw products, which are then subtracted one
/// by one. No narrowing happens here; the caller decides when to convert back.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
pub(crate) fn compute_tier_sub_dot_compute(
    init: BinaryStorage,
    a: &[BinaryStorage],
    b: &[BinaryStorage],
) -> ComputeStorage {
    assert_eq!(a.len(), b.len(), "compute_tier_sub_dot_compute: length mismatch");
    #[cfg(table_format = "q64_64")]
    {
        let mut remainder = I256::from_i128(init) << 64usize;
        for (&x, &y) in a.iter().zip(b.iter()) {
            remainder = remainder - (I256::from_i128(x) * I256::from_i128(y));
        }
        remainder
    }
    #[cfg(table_format = "q32_32")]
    {
        let mut remainder: i128 = (init as i128) << 32;
        for (&x, &y) in a.iter().zip(b.iter()) {
            remainder -= (x as i128) * (y as i128);
        }
        remainder
    }
    #[cfg(table_format = "q16_16")]
    {
        let mut remainder: i64 = (init as i64) << frac_config::FRAC_BITS;
        for (&x, &y) in a.iter().zip(b.iter()) {
            remainder -= (x as i64) * (y as i64);
        }
        remainder
    }
    #[cfg(table_format = "q128_128")]
    {
        let mut remainder = I512::from_i256(init) << 128usize;
        for (&x, &y) in a.iter().zip(b.iter()) {
            let x_neg = x.is_negative();
            let y_neg = y.is_negative();
            // Multiply magnitudes, then restore the sign of the product.
            let mag_x = if x_neg { -x } else { x };
            let mag_y = if y_neg { -y } else { y };
            let wide = mag_x.mul_to_i512(mag_y);
            remainder = remainder - if x_neg != y_neg { -wide } else { wide };
        }
        remainder
    }
    #[cfg(table_format = "q256_256")]
    {
        let mut remainder = I1024::from_i512(init) << 256usize;
        for (&x, &y) in a.iter().zip(b.iter()) {
            let x_neg = x.is_negative();
            let y_neg = y.is_negative();
            // Multiply magnitudes, then restore the sign of the product.
            let mag_x = if x_neg { -x } else { x };
            let mag_y = if y_neg { -y } else { y };
            let wide = mag_x.mul_to_i1024(mag_y);
            remainder = remainder - if x_neg != y_neg { -wide } else { wide };
        }
        remainder
    }
}
/// Computes the `(cs, sn)` pair of a Givens rotation for the inputs `(a, b)`.
///
/// Special cases:
/// * `b == 0` returns `(1, 0)`.
/// * `a == 0` (with `b != 0`) returns `(0, ±1)`, taking the sign of `b`.
///
/// Otherwise the ratio `tau` is formed with the larger-magnitude input as the
/// divisor, so `|tau| <= 1`, and the coefficient belonging to that larger
/// input is `1 / sqrt(1 + tau²)`; the other coefficient is that value times
/// `tau`.
pub(crate) fn givens(a: FixedPoint, b: FixedPoint) -> (FixedPoint, FixedPoint) {
    let one = FixedPoint::one();
    if b.is_zero() {
        return (one, FixedPoint::ZERO);
    }
    if a.is_zero() {
        let signed_unit = if b.is_negative() { -one } else { one };
        return (FixedPoint::ZERO, signed_unit);
    }

    let b_dominates = b.abs() > a.abs();
    // Always divide by the larger magnitude to keep the ratio bounded.
    let tau = if b_dominates { a / b } else { b / a };
    let major = one / (one + tau * tau).sqrt();
    let minor = major * tau;

    if b_dominates {
        // `b` is the larger input, so `sn` is the major coefficient.
        (minor, major)
    } else {
        (major, minor)
    }
}
/// Applies the rotation `(cs, sn)` to the pair `(x, y)`.
///
/// Returns `(cs·x + sn·y, −sn·x + cs·y)`. Each component is evaluated as a
/// two-element dot product through `compute_tier_dot_raw`, so both products
/// are summed in the compute type before a single conversion to storage.
#[inline]
pub(crate) fn apply_givens_compute(
    cs: FixedPoint, sn: FixedPoint, x: FixedPoint, y: FixedPoint,
) -> (FixedPoint, FixedPoint) {
    let (c, s, minus_s) = (cs.raw(), sn.raw(), (-sn).raw());
    // Both output components are dotted against the same input vector.
    let point = [x.raw(), y.raw()];
    let rotated_x = compute_tier_dot_raw(&[c, s], &point);
    let rotated_y = compute_tier_dot_raw(&[minus_s, c], &point);
    (FixedPoint::from_raw(rotated_x), FixedPoint::from_raw(rotated_y))
}
/// Derives a convergence threshold from `magnitude`.
///
/// The raw value of `|magnitude|` is shifted right by `half_frac_bits()`.
/// If the shift leaves zero, the smallest positive raw value
/// (`quantum_raw()`) is returned instead, so the threshold is never zero.
pub(crate) fn convergence_threshold(magnitude: FixedPoint) -> FixedPoint {
    let scaled = FixedPoint::from_raw(magnitude.abs().raw() >> half_frac_bits());
    if scaled.is_zero() {
        FixedPoint::from_raw(quantum_raw())
    } else {
        scaled
    }
}
/// Derives a tighter convergence threshold from `magnitude`.
///
/// Same scheme as `convergence_threshold`, but the raw value of `|magnitude|`
/// is shifted right by the larger `two_thirds_frac_bits()`. If the shift
/// leaves zero, the smallest positive raw value (`quantum_raw()`) is returned.
pub(crate) fn convergence_threshold_tight(magnitude: FixedPoint) -> FixedPoint {
    let scaled = FixedPoint::from_raw(magnitude.abs().raw() >> two_thirds_frac_bits());
    if scaled.is_zero() {
        FixedPoint::from_raw(quantum_raw())
    } else {
        scaled
    }
}
/// Shift amount used by `convergence_threshold_tight`: two thirds of the
/// format's fractional bit count, rounded down (42 for `q64_64`).
#[cfg(table_format = "q64_64")]
fn two_thirds_frac_bits() -> u32 { 2 * 64 / 3 }
/// Shift amount used by `convergence_threshold_tight`: two thirds of the
/// format's fractional bit count, rounded down (21 for `q32_32`).
#[cfg(table_format = "q32_32")]
fn two_thirds_frac_bits() -> u32 { 2 * 32 / 3 }
/// Shift amount used by `convergence_threshold_tight`: two thirds of the
/// format's fractional bit count, rounded down (10 for `q16_16`).
#[cfg(table_format = "q16_16")]
fn two_thirds_frac_bits() -> u32 { 2 * 16 / 3 }
/// Shift amount used by `convergence_threshold_tight`: two thirds of the
/// format's fractional bit count, rounded down (85 for `q128_128`).
#[cfg(table_format = "q128_128")]
fn two_thirds_frac_bits() -> u32 { 2 * 128 / 3 }
/// Shift amount used by `convergence_threshold_tight`: two thirds of the
/// format's fractional bit count, rounded down (170 for `q256_256`).
/// This variant returns `usize` rather than `u32`.
#[cfg(table_format = "q256_256")]
fn two_thirds_frac_bits() -> usize { 2 * 256 / 3 }
/// Shift amount used by `convergence_threshold`: half of the format's
/// fractional bit count (32 for `q64_64`).
#[cfg(table_format = "q64_64")]
fn half_frac_bits() -> u32 { 64 / 2 }
/// Shift amount used by `convergence_threshold`: half of the format's
/// fractional bit count (16 for `q32_32`).
#[cfg(table_format = "q32_32")]
fn half_frac_bits() -> u32 { 32 / 2 }
/// Shift amount used by `convergence_threshold`: half of the format's
/// fractional bit count (8 for `q16_16`).
#[cfg(table_format = "q16_16")]
fn half_frac_bits() -> u32 { 16 / 2 }
/// Shift amount used by `convergence_threshold`: half of the format's
/// fractional bit count (64 for `q128_128`).
#[cfg(table_format = "q128_128")]
fn half_frac_bits() -> u32 { 128 / 2 }
/// Shift amount used by `convergence_threshold`: half of the format's
/// fractional bit count (128 for `q256_256`).
/// This variant returns `usize` rather than `u32`.
#[cfg(table_format = "q256_256")]
fn half_frac_bits() -> usize { 256 / 2 }
/// Raw storage value `1`: the smallest positive step of the `q64_64` format.
#[cfg(table_format = "q64_64")]
fn quantum_raw() -> BinaryStorage { 1 }
/// Raw storage value `1`: the smallest positive step of the `q32_32` format.
#[cfg(table_format = "q32_32")]
fn quantum_raw() -> BinaryStorage { 1 }
/// Raw storage value `1`: the smallest positive step of the `q16_16` format.
#[cfg(table_format = "q16_16")]
fn quantum_raw() -> BinaryStorage { 1 }
/// Raw storage value `1`: the smallest positive step of the `q128_128` format.
#[cfg(table_format = "q128_128")]
fn quantum_raw() -> BinaryStorage { I256::from_i128(1i128) }
/// Raw storage value `1`: the smallest positive step of the `q256_256` format.
#[cfg(table_format = "q256_256")]
fn quantum_raw() -> BinaryStorage { I512::from_i128(1i128) }
#[allow(unused_imports)]
use crate::fixed_point::domains::balanced_ternary::trit_packing::Trit;
/// Dot product of a packed ternary vector with `values`, multiplied by `scale`.
///
/// Each byte of `packed_trits` is decoded as five base-3 digits, most
/// significant first. Digit `2` adds the matching entry of `values` to the
/// accumulator, digit `0` subtracts it, and digit `1` leaves the accumulator
/// unchanged (presumably encoding trits +1 / -1 / 0 — confirm against the
/// packing code). Accumulation happens in the compute type; the sum is
/// converted to storage once and then multiplied by `scale` with
/// `compute_tier_mul_pair`.
///
/// Only the first `num_elements` trits are consumed. If `packed_trits` runs
/// out earlier, the remaining entries of `values` are silently ignored. Any
/// part of a byte beyond its five base-3 digits is discarded.
///
/// # Panics
///
/// Panics if `values` has fewer than `num_elements` entries.
#[allow(dead_code)]
pub fn compute_tier_trit_dot_raw(
    packed_trits: &[u8],
    num_elements: usize,
    values: &[BinaryStorage],
    scale: BinaryStorage,
) -> BinaryStorage {
    assert!(values.len() >= num_elements, "compute_tier_trit_dot_raw: values shorter than num_elements");
    let mut total = compute_zero();
    let mut next = 0;
    for &packed in packed_trits {
        if next >= num_elements {
            break;
        }
        // Peel off base-3 digits from the least significant end, filling the
        // array back to front so it reads most-significant-first.
        let mut digits = [1u8; 5];
        let mut rest = packed;
        for slot in digits.iter_mut().rev() {
            *slot = rest % 3;
            rest /= 3;
        }
        for &digit in digits.iter() {
            if next >= num_elements {
                break;
            }
            match digit {
                2 => total = compute_add(total, upscale_to_compute(values[next])),
                0 => total = compute_sub(total, upscale_to_compute(values[next])),
                _ => {}
            }
            next += 1;
        }
    }
    compute_tier_mul_pair(round_to_storage(total), scale)
}
/// Multiplies a packed ternary matrix by the vector `values`, scaling each row.
///
/// `packed_trits` holds `rows` consecutive rows, each occupying
/// `ceil(cols / 5)` bytes (five trits per byte). Row `r` of the result is
/// `compute_tier_trit_dot_raw(row_bytes, cols, values, scales[r])`.
///
/// Returns a vector with exactly `rows` entries.
///
/// # Panics
///
/// Panics if `values` has fewer than `cols` entries, if `scales` has fewer
/// than `rows` entries, or if `packed_trits` holds fewer than
/// `rows * ceil(cols / 5)` bytes.
#[allow(dead_code)]
pub fn compute_tier_trit_matvec_raw(
    packed_trits: &[u8],
    rows: usize,
    cols: usize,
    values: &[BinaryStorage],
    scales: &[BinaryStorage],
) -> Vec<BinaryStorage> {
    assert!(values.len() >= cols, "compute_tier_trit_matvec_raw: values shorter than cols");
    assert!(scales.len() >= rows, "compute_tier_trit_matvec_raw: scales shorter than rows");
    // Five trits per byte, so a row needs ceil(cols / 5) bytes.
    let bytes_per_row = (cols + 4) / 5;
    // Validate the packed buffer up front so a short input fails with a
    // descriptive message instead of an opaque slice-range panic mid-loop.
    // `checked_mul` keeps the size computation from wrapping.
    let required_bytes = rows.checked_mul(bytes_per_row);
    assert!(
        matches!(required_bytes, Some(needed) if packed_trits.len() >= needed),
        "compute_tier_trit_matvec_raw: packed_trits shorter than rows * bytes_per_row"
    );
    (0..rows)
        .map(|row| {
            let row_start = row * bytes_per_row;
            let row_trits = &packed_trits[row_start..row_start + bytes_per_row];
            compute_tier_trit_dot_raw(row_trits, cols, values, scales[row])
        })
        .collect()
}
/// Returns the zero value of the compute-tier accumulator type for the
/// active `table_format`.
#[allow(dead_code)]
#[inline]
fn compute_zero() -> ComputeStorage {
#[cfg(table_format = "q64_64")]
{ I256::zero() }
#[cfg(table_format = "q32_32")]
{ 0i128 }
#[cfg(table_format = "q16_16")]
{ 0i64 }
#[cfg(table_format = "q128_128")]
{ I512::zero() }
#[cfg(table_format = "q256_256")]
{ I1024::zero() }
}
/// Adds two compute-tier values with the type's `+` operator.
///
/// Gives the ternary dot product one spelling for addition regardless of
/// which concrete type `ComputeStorage` is for the active format.
#[allow(dead_code)]
#[inline]
fn compute_add(a: ComputeStorage, b: ComputeStorage) -> ComputeStorage {
a + b
}
/// Subtracts `b` from `a` at the compute tier with the type's `-` operator.
///
/// Counterpart of `compute_add`; used for trits that subtract their value.
#[allow(dead_code)]
#[inline]
fn compute_sub(a: ComputeStorage, b: ComputeStorage) -> ComputeStorage {
a - b
}
/// Multiplies two raw storage-tier values and returns the product at the
/// storage tier.
///
/// The full product is formed in the wider compute type and converted back
/// once with `round_to_storage`. For the two widest formats the magnitudes
/// are multiplied and the sign is restored afterwards.
#[allow(dead_code)]
#[inline]
fn compute_tier_mul_pair(a: BinaryStorage, b: BinaryStorage) -> BinaryStorage {
    #[cfg(table_format = "q64_64")]
    {
        round_to_storage(I256::from_i128(a) * I256::from_i128(b))
    }
    #[cfg(table_format = "q32_32")]
    {
        round_to_storage((a as i128) * (b as i128))
    }
    #[cfg(table_format = "q16_16")]
    {
        round_to_storage((a as i64) * (b as i64))
    }
    #[cfg(table_format = "q128_128")]
    {
        let a_neg = a.is_negative();
        let b_neg = b.is_negative();
        let mag_a = if a_neg { -a } else { a };
        let mag_b = if b_neg { -b } else { b };
        let wide = mag_a.mul_to_i512(mag_b);
        // Signs differ => negative product.
        let signed = if a_neg != b_neg { -wide } else { wide };
        round_to_storage(signed)
    }
    #[cfg(table_format = "q256_256")]
    {
        let a_neg = a.is_negative();
        let b_neg = b.is_negative();
        let mag_a = if a_neg { -a } else { a };
        let mag_b = if b_neg { -b } else { b };
        let wide = mag_a.mul_to_i1024(mag_b);
        // Signs differ => negative product.
        let signed = if a_neg != b_neg { -wide } else { wide };
        round_to_storage(signed)
    }
}