use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Stable identifier of this primitive; used as the `generator` of the region
/// emitted by [`bigint_add_carry`].
pub const OP_ID: &'static str = "vyre-primitives::math::bigint_add_carry";
/// Binding slot of the read-only `a` operand limb buffer.
pub const BINDING_A_IN: u32 = 0;
/// Binding slot of the read-only `b` operand limb buffer.
pub const BINDING_B_IN: u32 = BINDING_A_IN + 1;
/// Binding slot of the read-write per-limb partial sum buffer.
pub const BINDING_SUM_PARTIAL_OUT: u32 = BINDING_B_IN + 1;
/// Binding slot of the read-write per-limb carry buffer (one `u32` per limb).
pub const BINDING_CARRY_PARTIAL_OUT: u32 = BINDING_SUM_PARTIAL_OUT + 1;
/// Errors reported by the CPU reference helpers of the bigint add-carry primitive.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum BigIntAddCarryError {
    /// The two operands passed to the add step have different limb counts.
    LimbCountMismatch {
        /// Number of limbs in operand `a`.
        a_len: usize,
        /// Number of limbs in operand `b`.
        b_len: usize,
    },
    /// The partial-sum and partial-carry slices given to the carry resolver
    /// have different lengths.
    SplitCarryLengthMismatch {
        /// Length of the partial-sum slice.
        sum_len: usize,
        /// Length of the partial-carry slice.
        carry_len: usize,
    },
}

impl std::fmt::Display for BigIntAddCarryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LimbCountMismatch { a_len, b_len } => write!(
                f,
                "limb count mismatch: a has {} limbs, b has {} limbs",
                a_len, b_len
            ),
            Self::SplitCarryLengthMismatch { sum_len, carry_len } => write!(
                f,
                "split carry length mismatch: sum_partial has {} limbs, carry_partial has {} limbs",
                sum_len, carry_len
            ),
        }
    }
}

// Lets callers box this error as `dyn Error` or convert it with `?` into
// broader error types; there is no underlying cause, so `source` stays `None`.
impl std::error::Error for BigIntAddCarryError {}
#[must_use]
pub fn bigint_add_carry(limb_count: u32) -> Program {
let body = vec![
Node::let_bind("limb_idx", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(Expr::var("limb_idx"), Expr::u32(limb_count)),
vec![
Node::let_bind("a_limb", Expr::load("a", Expr::var("limb_idx"))),
Node::let_bind("b_limb", Expr::load("b", Expr::var("limb_idx"))),
Node::let_bind("sum", Expr::add(Expr::var("a_limb"), Expr::var("b_limb"))),
Node::let_bind(
"carry_bool",
Expr::lt(Expr::var("sum"), Expr::var("a_limb")),
),
Node::let_bind(
"carry",
Expr::select(Expr::var("carry_bool"), Expr::u32(1), Expr::u32(0)),
),
Node::store("sum_partial", Expr::var("limb_idx"), Expr::var("sum")),
Node::store("carry_partial", Expr::var("limb_idx"), Expr::var("carry")),
],
),
];
let buffers = vec![
BufferDecl::storage("a", BINDING_A_IN, BufferAccess::ReadOnly, DataType::U32)
.with_count(limb_count),
BufferDecl::storage("b", BINDING_B_IN, BufferAccess::ReadOnly, DataType::U32)
.with_count(limb_count),
BufferDecl::storage(
"sum_partial",
BINDING_SUM_PARTIAL_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(limb_count),
BufferDecl::storage(
"carry_partial",
BINDING_CARRY_PARTIAL_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(limb_count),
];
let entry = vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}];
Program::wrapped(buffers, [256, 1, 1], entry)
}
/// CPU reference for the per-limb add step: returns `(sum_partial, carry_partial)`.
///
/// Each `sum_partial[i]` is `a[i].wrapping_add(b[i])`. Each `carry_partial[i]`
/// is `1` if that addition overflowed, else `0`. No carry propagation between
/// limbs is performed.
///
/// # Errors
///
/// Returns [`BigIntAddCarryError::LimbCountMismatch`] when `a` and `b` differ in length.
pub fn bigint_add_carry_cpu(
    a: &[u32],
    b: &[u32],
) -> Result<(Vec<u32>, Vec<u32>), BigIntAddCarryError> {
    let limbs = a.len();
    let (mut sums, mut carries) = (Vec::with_capacity(limbs), Vec::with_capacity(limbs));
    bigint_add_carry_cpu_into(a, b, &mut sums, &mut carries).map(|()| (sums, carries))
}
/// Allocation-reusing form of [`bigint_add_carry_cpu`].
///
/// On success, both output vectors are cleared and refilled with one entry per
/// limb. Each `sum_partial[i]` is the wrapping sum of `a[i]` and `b[i]`. Each
/// `carry_partial[i]` is `1` when that sum overflowed, else `0`. Existing
/// capacity is reused, and capacity grows only if it is smaller than the limb
/// count.
///
/// # Errors
///
/// Returns [`BigIntAddCarryError::LimbCountMismatch`] when `a` and `b` differ in
/// length. In that case the length check runs before any mutation, so both
/// output vectors are left untouched.
pub fn bigint_add_carry_cpu_into(
    a: &[u32],
    b: &[u32],
    sum_partial: &mut Vec<u32>,
    carry_partial: &mut Vec<u32>,
) -> Result<(), BigIntAddCarryError> {
    let limbs = a.len();
    if limbs != b.len() {
        return Err(BigIntAddCarryError::LimbCountMismatch {
            a_len: limbs,
            b_len: b.len(),
        });
    }

    sum_partial.clear();
    carry_partial.clear();
    sum_partial.reserve(limbs);
    carry_partial.reserve(limbs);

    a.iter().zip(b).for_each(|(&lhs, &rhs)| {
        let (wrapped, carried) = lhs.overflowing_add(rhs);
        sum_partial.push(wrapped);
        carry_partial.push(u32::from(carried));
    });
    Ok(())
}
/// Sequentially folds per-limb carries into the partial sums.
///
/// Returns `(final_sum, carry_out)`. Here `carry_out` is the carry leaving the
/// highest-index limb; carries move from index `i` into index `i + 1`, so
/// index 0 is the least significant limb.
///
/// # Errors
///
/// Returns [`BigIntAddCarryError::SplitCarryLengthMismatch`] when the two
/// slices differ in length.
pub fn resolve_carry_chain_cpu(
    sum_partial: &[u32],
    carry_partial: &[u32],
) -> Result<(Vec<u32>, u32), BigIntAddCarryError> {
    let mut resolved = Vec::with_capacity(sum_partial.len());
    resolve_carry_chain_cpu_into(sum_partial, carry_partial, &mut resolved)
        .map(|carry_out| (resolved, carry_out))
}
/// Allocation-reusing form of [`resolve_carry_chain_cpu`].
///
/// The function walks the limbs from index 0 upward. For each limb it adds the
/// incoming carry to `sum_partial[i]`, pushes the result into `final_sum`, and
/// computes the carry into limb `i + 1` as "`carry_partial[i]` is set OR adding
/// the incoming carry overflowed". Returns the carry out of the last limb
/// (`0` or `1`; `0` for empty input).
///
/// Any nonzero `carry_partial` entry is treated as a single set carry bit. This
/// means 0/1 flags and all-ones bool masks (`0xFFFF_FFFF`) both work. Raw
/// multi-bit values are never added into the next limb.
///
/// For partials produced by [`bigint_add_carry_cpu`], a set partial carry
/// implies `sum_partial[i] <= u32::MAX - 1`. So the two carry sources are never
/// both set for the same limb, and the OR is exact.
///
/// # Errors
///
/// Returns [`BigIntAddCarryError::SplitCarryLengthMismatch`] when the slices
/// differ in length. `final_sum` is left untouched in that case.
pub fn resolve_carry_chain_cpu_into(
    sum_partial: &[u32],
    carry_partial: &[u32],
    final_sum: &mut Vec<u32>,
) -> Result<u32, BigIntAddCarryError> {
    if sum_partial.len() != carry_partial.len() {
        return Err(BigIntAddCarryError::SplitCarryLengthMismatch {
            sum_len: sum_partial.len(),
            carry_len: carry_partial.len(),
        });
    }
    final_sum.clear();
    final_sum.reserve(sum_partial.len());
    let mut carry_in = false;
    for (&sum, &carry) in sum_partial.iter().zip(carry_partial) {
        let (with_in, overflow_from_in) = sum.overflowing_add(u32::from(carry_in));
        final_sum.push(with_in);
        // Normalise to one bit: a raw value like 0xFFFF_FFFF must not be added
        // into the next limb as a multi-bit quantity.
        carry_in = carry != 0 || overflow_from_in;
    }
    Ok(u32::from(carry_in))
}
#[cfg(test)]
mod tests {
    // Unit tests for the per-limb add step, the carry resolver, and the
    // program builder's shape and stable identifiers.
    use super::*;

    #[test]
    fn cpu_zero_plus_zero_returns_zero_with_no_carries() {
        let (sum, carry) =
            bigint_add_carry_cpu(&[0, 0, 0, 0], &[0, 0, 0, 0]).expect("matching limbs");
        assert_eq!(sum, vec![0, 0, 0, 0]);
        assert_eq!(carry, vec![0, 0, 0, 0]);
    }

    #[test]
    fn cpu_zero_limb_operands_produce_empty_outputs() {
        let (sum, carry) = bigint_add_carry_cpu(&[], &[]).expect("matching (empty) limbs");
        assert!(sum.is_empty());
        assert!(carry.is_empty());
        let (final_sum, final_carry) =
            resolve_carry_chain_cpu(&sum, &carry).expect("matching (empty) split limbs");
        assert!(final_sum.is_empty());
        assert_eq!(final_carry, 0);
    }

    #[test]
    fn cpu_no_overflow_per_limb_keeps_carries_zero() {
        let a = [1u32, 2, 3, 4];
        let b = [10u32, 20, 30, 40];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("matching limbs");
        assert_eq!(sum, vec![11, 22, 33, 44]);
        assert_eq!(carry, vec![0, 0, 0, 0]);
    }

    #[test]
    fn cpu_per_limb_overflow_emits_carry_bit() {
        let a = [0xFFFF_FFFFu32, 0xFFFF_FFFFu32];
        let b = [1u32, 0u32];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("matching limbs");
        assert_eq!(sum, vec![0, 0xFFFF_FFFF]);
        assert_eq!(carry, vec![1, 0]);
    }

    #[test]
    fn cpu_max_plus_max_emits_per_limb_carry_and_truncated_sum() {
        let a = [0xFFFF_FFFFu32; 4];
        let b = [0xFFFF_FFFFu32; 4];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("matching limbs");
        assert_eq!(sum, vec![0xFFFF_FFFEu32; 4]);
        assert_eq!(carry, vec![1u32; 4]);
    }

    #[test]
    fn resolve_carry_chain_propagates_single_carry_through_zeros() {
        let sum_partial = vec![0xFFFF_FFFFu32, 0, 0, 0];
        let carry_partial = vec![1u32, 0, 0, 0];
        let (final_sum, final_carry) =
            resolve_carry_chain_cpu(&sum_partial, &carry_partial).expect("matching split limbs");
        assert_eq!(final_sum, vec![0xFFFF_FFFF, 1, 0, 0]);
        assert_eq!(
            final_carry, 0,
            "the carry from limb 0 propagates into limb 1, then dies"
        );
    }

    #[test]
    fn resolve_carry_chain_handles_chained_overflow() {
        let a = [0xFFFF_FFFFu32, 0xFFFF_FFFFu32, 0xFFFF_FFFFu32, 0];
        let b = [1u32, 0, 0, 0];
        let (sum_partial, carry_partial) = bigint_add_carry_cpu(&a, &b).expect("matching limbs");
        let (final_sum, final_carry) =
            resolve_carry_chain_cpu(&sum_partial, &carry_partial).expect("matching split limbs");
        assert_eq!(final_sum, vec![0, 0, 0, 1]);
        assert_eq!(final_carry, 0);
    }

    #[test]
    fn resolve_carry_chain_emits_final_carry_out_at_top() {
        let a = [0xFFFF_FFFFu32, 0xFFFF_FFFFu32];
        let b = [0xFFFF_FFFFu32, 0xFFFF_FFFFu32];
        let (sum_partial, carry_partial) = bigint_add_carry_cpu(&a, &b).expect("matching limbs");
        let (_final_sum, final_carry) =
            resolve_carry_chain_cpu(&sum_partial, &carry_partial).expect("matching split limbs");
        assert_eq!(
            final_carry, 1,
            "max + max in 64 bits overflows into the 65th bit"
        );
    }

    #[test]
    fn resolve_carry_chain_handles_corner_carry_in_only() {
        let sum_partial = vec![0xFFFF_FFFFu32, 0xFFFF_FFFFu32];
        let carry_partial = vec![1u32, 0];
        let (final_sum, final_carry) =
            resolve_carry_chain_cpu(&sum_partial, &carry_partial).expect("matching split limbs");
        assert_eq!(final_sum, vec![0xFFFF_FFFF, 0]);
        assert_eq!(
            final_carry, 1,
            "carry propagated into limb 1 made it overflow"
        );
    }

    #[test]
    fn split_then_resolve_matches_native_u64_addition() {
        // Index 0 is the least significant limb; `as u32` deliberately keeps the low 32 bits.
        let limbs = |v: u64| [v as u32, (v >> 32) as u32];
        let samples = [
            (0u64, 0u64),
            (1, u64::MAX),
            (u64::MAX, u64::MAX),
            (0x0000_0001_FFFF_FFFF, 1),
            (0xFFFF_FFFF, 0xFFFF_FFFF),
            (0x1234_5678_9ABC_DEF0, 0xFEDC_BA98_7654_3210),
        ];
        for &(x, y) in &samples {
            let (sum_partial, carry_partial) =
                bigint_add_carry_cpu(&limbs(x), &limbs(y)).expect("matching limbs");
            let (final_sum, final_carry) = resolve_carry_chain_cpu(&sum_partial, &carry_partial)
                .expect("matching split limbs");
            let got = u128::from(final_sum[0])
                | (u128::from(final_sum[1]) << 32)
                | (u128::from(final_carry) << 64);
            assert_eq!(got, u128::from(x) + u128::from(y), "x={:#x} y={:#x}", x, y);
        }
    }

    #[test]
    fn cpu_handles_8_limb_256_bit_operands() {
        let a = [0x1234_5678u32; 8];
        let b = [0x8765_4321u32; 8];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("matching limbs");
        assert_eq!(sum, vec![0x9999_9999u32; 8]);
        assert_eq!(carry, vec![0u32; 8]);
    }

    #[test]
    fn cpu_handles_128_limb_4096_bit_operands() {
        let a = vec![0x5555_5555u32; 128];
        let b = vec![0xAAAA_AAAAu32; 128];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("matching limbs");
        assert_eq!(sum, vec![0xFFFF_FFFFu32; 128]);
        assert_eq!(carry, vec![0u32; 128]);
    }

    #[test]
    fn cpu_mismatched_limb_count_returns_error() {
        let a = vec![0u32; 4];
        let b = vec![0u32; 5];
        assert_eq!(
            bigint_add_carry_cpu(&a, &b),
            Err(BigIntAddCarryError::LimbCountMismatch { a_len: 4, b_len: 5 })
        );
    }

    #[test]
    fn cpu_into_mismatch_leaves_outputs_untouched() {
        let mut sum = vec![7u32, 8];
        let mut carry = vec![1u32];
        assert_eq!(
            bigint_add_carry_cpu_into(&[0, 0], &[0], &mut sum, &mut carry),
            Err(BigIntAddCarryError::LimbCountMismatch { a_len: 2, b_len: 1 })
        );
        assert_eq!(sum, vec![7, 8]);
        assert_eq!(carry, vec![1]);
    }

    #[test]
    fn cpu_into_reuses_output_capacity() {
        let a = [1u32, u32::MAX];
        let b = [2u32, 1];
        let mut sum = Vec::with_capacity(32);
        let mut carry = Vec::with_capacity(32);
        let sum_cap = sum.capacity();
        let carry_cap = carry.capacity();
        bigint_add_carry_cpu_into(&a, &b, &mut sum, &mut carry).expect("matching limbs");
        assert_eq!(sum, vec![3, 0]);
        assert_eq!(carry, vec![0, 1]);
        assert_eq!(sum.capacity(), sum_cap);
        assert_eq!(carry.capacity(), carry_cap);
    }

    #[test]
    fn resolve_carry_chain_rejects_length_mismatch() {
        let mut out = Vec::new();
        assert_eq!(
            resolve_carry_chain_cpu_into(&[0, 1], &[0], &mut out),
            Err(BigIntAddCarryError::SplitCarryLengthMismatch {
                sum_len: 2,
                carry_len: 1,
            })
        );
    }

    #[test]
    fn resolve_carry_chain_into_mismatch_leaves_output_untouched() {
        let mut out = vec![9u32];
        assert!(resolve_carry_chain_cpu_into(&[0], &[0, 0], &mut out).is_err());
        assert_eq!(out, vec![9]);
    }

    #[test]
    fn build_program_returns_well_formed_program() {
        let program = bigint_add_carry(8);
        assert_eq!(
            program.buffers().len(),
            4,
            "a, b, sum_partial, carry_partial"
        );
        assert_eq!(program.workgroup_size(), [256, 1, 1]);
    }

    #[test]
    fn build_program_is_deterministic_across_calls() {
        let p1 = bigint_add_carry(16);
        let p2 = bigint_add_carry(16);
        assert_eq!(
            p1.buffers().len(),
            p2.buffers().len(),
            "two builds with identical inputs must produce identical buffer lists"
        );
        assert_eq!(p1.workgroup_size(), p2.workgroup_size());
    }

    #[test]
    fn op_id_is_canonical_and_stable() {
        assert_eq!(OP_ID, "vyre-primitives::math::bigint_add_carry");
    }

    #[test]
    fn binding_indices_are_canonical_and_stable() {
        assert_eq!(BINDING_A_IN, 0);
        assert_eq!(BINDING_B_IN, 1);
        assert_eq!(BINDING_SUM_PARTIAL_OUT, 2);
        assert_eq!(BINDING_CARRY_PARTIAL_OUT, 3);
    }
}