use std::sync::Arc;
use vyre_foundation::ir::model::expr::Ident;
use vyre_foundation::ir::{BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Stable identifier of this operation; also used as the generator name of the
/// region emitted by `bigint_add_carry` and as the inventory registration key.
pub const OP_ID: &str = "vyre-primitives::math::bigint_add_carry";
/// Binding slot of the read-only left-hand operand limbs (buffer `a`).
pub const BINDING_A_IN: u32 = 0;
/// Binding slot of the read-only right-hand operand limbs (buffer `b`).
pub const BINDING_B_IN: u32 = 1;
/// Binding slot of the per-limb sums written by the kernel (buffer `sum_partial`).
pub const BINDING_SUM_PARTIAL_OUT: u32 = 2;
/// Binding slot of the per-limb carry flags written by the kernel (buffer `carry_partial`).
pub const BINDING_CARRY_PARTIAL_OUT: u32 = 3;
/// Errors reported by the bigint add-carry CPU reference helpers.
#[derive(Debug, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum BigIntAddCarryError {
    /// The two operands passed to the split-add stage have different limb counts.
    LimbCountMismatch {
        /// Number of limbs in operand `a`.
        a_len: usize,
        /// Number of limbs in operand `b`.
        b_len: usize,
    },
    /// The `sum_partial` and `carry_partial` slices handed to carry resolution
    /// have different lengths.
    SplitCarryLengthMismatch {
        /// Length of the partial-sum slice.
        sum_len: usize,
        /// Length of the partial-carry slice.
        carry_len: usize,
    },
    /// Reserving capacity for an output vector failed.
    AllocationFailed {
        /// Name of the output buffer being reserved (e.g. `"sum_partial"`).
        operation: &'static str,
        /// Diagnostic reported by the allocator helper.
        message: String,
    },
}

// A public library error type should be printable and usable as
// `dyn std::error::Error` (e.g. with `?` into `Box<dyn Error>`). Messages follow
// the crate's "Fix: ..." diagnostic convention.
impl std::fmt::Display for BigIntAddCarryError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::LimbCountMismatch { a_len, b_len } => write!(
                f,
                "Fix: bigint add-carry operands must have the same limb count, got a_len={a_len} and b_len={b_len}."
            ),
            Self::SplitCarryLengthMismatch { sum_len, carry_len } => write!(
                f,
                "Fix: sum_partial and carry_partial must have the same length, got sum_len={sum_len} and carry_len={carry_len}."
            ),
            Self::AllocationFailed { operation, message } => write!(
                f,
                "Fix: bigint add-carry could not allocate `{operation}`: {message}"
            ),
        }
    }
}

impl std::error::Error for BigIntAddCarryError {}
#[must_use]
pub fn bigint_add_carry(limb_count: u32) -> Program {
if limb_count == 0 {
return crate::invalid_output_program(
OP_ID,
"sum_partial",
DataType::U32,
"Fix: bigint_add_carry requires limb_count > 0, got 0.".to_string(),
);
}
let body = vec![
Node::let_bind("limb_idx", Expr::InvocationId { axis: 0 }),
Node::if_then(
Expr::lt(Expr::var("limb_idx"), Expr::u32(limb_count)),
vec![
Node::let_bind("a_limb", Expr::load("a", Expr::var("limb_idx"))),
Node::let_bind("b_limb", Expr::load("b", Expr::var("limb_idx"))),
Node::let_bind("sum", Expr::add(Expr::var("a_limb"), Expr::var("b_limb"))),
Node::let_bind(
"carry_bool",
Expr::lt(Expr::var("sum"), Expr::var("a_limb")),
),
Node::let_bind(
"carry",
Expr::select(Expr::var("carry_bool"), Expr::u32(1), Expr::u32(0)),
),
Node::store("sum_partial", Expr::var("limb_idx"), Expr::var("sum")),
Node::store("carry_partial", Expr::var("limb_idx"), Expr::var("carry")),
],
),
];
let buffers = vec![
BufferDecl::storage("a", BINDING_A_IN, BufferAccess::ReadOnly, DataType::U32)
.with_count(limb_count),
BufferDecl::storage("b", BINDING_B_IN, BufferAccess::ReadOnly, DataType::U32)
.with_count(limb_count),
BufferDecl::storage(
"sum_partial",
BINDING_SUM_PARTIAL_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(limb_count),
BufferDecl::storage(
"carry_partial",
BINDING_CARRY_PARTIAL_OUT,
BufferAccess::ReadWrite,
DataType::U32,
)
.with_count(limb_count),
];
let entry = vec![Node::Region {
generator: Ident::from(OP_ID),
source_region: None,
body: Arc::new(body),
}];
Program::wrapped(buffers, [256, 1, 1], entry)
}
/// CPU reference for the split add stage, returning freshly allocated outputs.
///
/// Returns `(sum_partial, carry_partial)`, where each sum is the wrapped `a[i] + b[i]`
/// and each carry is `0` or `1`. Both vectors start with capacity `a.len()`.
///
/// # Errors
/// Propagates the errors of `bigint_add_carry_cpu_into`: `LimbCountMismatch` when
/// `a` and `b` differ in length, `AllocationFailed` when reserving output fails.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn bigint_add_carry_cpu(
    a: &[u32],
    b: &[u32],
) -> Result<(Vec<u32>, Vec<u32>), BigIntAddCarryError> {
    let limbs = a.len();
    let (mut sums, mut carries) = (Vec::with_capacity(limbs), Vec::with_capacity(limbs));
    bigint_add_carry_cpu_into(a, b, &mut sums, &mut carries).map(|()| (sums, carries))
}
/// CPU reference for the split add stage, writing into caller-owned vectors.
///
/// On success `sum_partial[i]` holds the wrapped `a[i] + b[i]` and
/// `carry_partial[i]` holds `1` if that limb overflowed, else `0`. Any previous
/// contents of the outputs are discarded. Their allocations are reused when they
/// already have enough capacity.
///
/// # Errors
/// `LimbCountMismatch` if `a.len() != b.len()`; the outputs are left untouched.
/// `AllocationFailed` if reserving either output fails.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn bigint_add_carry_cpu_into(
    a: &[u32],
    b: &[u32],
    sum_partial: &mut Vec<u32>,
    carry_partial: &mut Vec<u32>,
) -> Result<(), BigIntAddCarryError> {
    let limbs = a.len();
    if b.len() != limbs {
        return Err(BigIntAddCarryError::LimbCountMismatch {
            a_len: limbs,
            b_len: b.len(),
        });
    }

    // Reserve both outputs before clearing either one.
    reserve_bigint_output(sum_partial, limbs, "sum_partial")?;
    reserve_bigint_output(carry_partial, limbs, "carry_partial")?;
    sum_partial.clear();
    carry_partial.clear();

    sum_partial.extend(a.iter().zip(b).map(|(&lhs, &rhs)| lhs.wrapping_add(rhs)));
    carry_partial.extend(
        a.iter()
            .zip(b)
            .map(|(&lhs, &rhs)| u32::from(lhs.overflowing_add(rhs).1)),
    );
    Ok(())
}
/// CPU reference for carry resolution, returning a freshly allocated result.
///
/// Returns `(final_sum, final_carry)`, where `final_carry` is the carry out of
/// the most significant limb. `final_sum` starts with capacity `sum_partial.len()`.
///
/// # Errors
/// Propagates the errors of `resolve_carry_chain_cpu_into`:
/// `SplitCarryLengthMismatch` on differing input lengths, and `AllocationFailed`.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn resolve_carry_chain_cpu(
    sum_partial: &[u32],
    carry_partial: &[u32],
) -> Result<(Vec<u32>, u32), BigIntAddCarryError> {
    let mut resolved = Vec::with_capacity(sum_partial.len());
    resolve_carry_chain_cpu_into(sum_partial, carry_partial, &mut resolved)
        .map(|carry_out| (resolved, carry_out))
}
/// CPU reference for carry resolution: ripples the split-stage carries from the
/// least significant limb (index 0) upward, writing into `final_sum`.
///
/// For each limb, the incoming carry is added to `sum_partial[i]`. The outgoing
/// carry is `carry_partial[i] | overflow`, and the function returns the carry out
/// of the last limb. `final_sum`'s previous contents are discarded and its
/// allocation is reused when it is large enough.
///
/// NOTE(review): the `|` combination is only meaningful when `carry_partial`
/// holds 0/1 flags as produced by `bigint_add_carry_cpu_into`. Other values are
/// not rejected.
///
/// # Errors
/// `SplitCarryLengthMismatch` if the two input slices differ in length (output
/// untouched). `AllocationFailed` if reserving `final_sum` fails.
#[cfg(any(test, feature = "cpu-parity"))]
pub fn resolve_carry_chain_cpu_into(
    sum_partial: &[u32],
    carry_partial: &[u32],
    final_sum: &mut Vec<u32>,
) -> Result<u32, BigIntAddCarryError> {
    let limbs = sum_partial.len();
    if carry_partial.len() != limbs {
        return Err(BigIntAddCarryError::SplitCarryLengthMismatch {
            sum_len: limbs,
            carry_len: carry_partial.len(),
        });
    }
    reserve_bigint_output(final_sum, limbs, "final_sum")?;
    final_sum.clear();

    let mut ripple = 0u32;
    final_sum.extend(
        sum_partial
            .iter()
            .zip(carry_partial)
            .map(|(&partial, &split_carry)| {
                let (limb, wrapped) = partial.overflowing_add(ripple);
                ripple = split_carry | u32::from(wrapped);
                limb
            }),
    );
    Ok(ripple)
}
/// Ensures `out` can hold at least `len` elements without reallocating.
///
/// Nothing happens when the existing capacity already suffices. Otherwise
/// `len - out.len()` additional slots are requested through
/// `crate::graph::scratch::reserve_graph_items`, and a failure is mapped to
/// `BigIntAddCarryError::AllocationFailed` tagged with `operation`.
#[cfg(any(test, feature = "cpu-parity"))]
fn reserve_bigint_output(
    out: &mut Vec<u32>,
    len: usize,
    operation: &'static str,
) -> Result<(), BigIntAddCarryError> {
    if len <= out.capacity() {
        return Ok(());
    }
    // Cannot underflow: out.len() <= out.capacity() < len.
    let additional = len - out.len();
    crate::graph::scratch::reserve_graph_items(
        out,
        additional,
        "bigint add-carry CPU oracle",
        operation,
    )
    .map_err(|message| BigIntAddCarryError::AllocationFailed { operation, message })?;
    Ok(())
}
// Registers this op with the harness inventory when the `inventory-registry`
// feature is enabled, using a 4-limb program built by `bigint_add_carry(4)`.
// The first closure appears to supply input buffers in binding order (a, b, then
// zero-filled sum_partial and carry_partial). The second appears to supply the
// expected output buffers (sum_partial, carry_partial). Confirm this against
// `crate::harness::OpEntry::new`.
// Per limb: 1+2=3 (no carry), MAX+1=0 (carry), 5+MAX=4 (carry),
// MAX+MAX=MAX-1 (carry).
#[cfg(feature = "inventory-registry")]
inventory::submit! {
    crate::harness::OpEntry::new(
        OP_ID,
        || bigint_add_carry(4),
        Some(|| {
            vec![vec![
                crate::wire::pack_u32_slice(&[1, u32::MAX, 5, u32::MAX]),
                crate::wire::pack_u32_slice(&[2, 1, u32::MAX, u32::MAX]),
                crate::wire::pack_u32_slice(&[0; 4]),
                crate::wire::pack_u32_slice(&[0; 4]),
            ]]
        }),
        Some(|| {
            vec![vec![
                crate::wire::pack_u32_slice(&[3, 0, 4, u32::MAX - 1]),
                crate::wire::pack_u32_slice(&[0, 1, 1, 1]),
            ]]
        }),
    )
}
// Unit tests covering the CPU reference oracle (split stage and carry
// resolution), buffer-reuse guarantees of the `_into` variants, and basic shape
// checks on the generated program.
#[cfg(test)]
mod tests {
    use super::*;

    // --- Split stage: per-limb sums and carry flags, no propagation. ---

    #[test]
    fn cpu_zero_plus_zero_returns_zero_with_no_carries() {
        let (sum, carry) =
            bigint_add_carry_cpu(&[0, 0, 0, 0], &[0, 0, 0, 0]).expect("Fix: matching limbs");
        assert_eq!(sum, vec![0, 0, 0, 0]);
        assert_eq!(carry, vec![0, 0, 0, 0]);
    }

    #[test]
    fn cpu_no_overflow_per_limb_keeps_carries_zero() {
        let a = [1u32, 2, 3, 4];
        let b = [10u32, 20, 30, 40];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("Fix: matching limbs");
        assert_eq!(sum, vec![11, 22, 33, 44]);
        assert_eq!(carry, vec![0, 0, 0, 0]);
    }

    // The split stage flags limb 0's overflow but does NOT add it into limb 1.
    #[test]
    fn cpu_per_limb_overflow_emits_carry_bit() {
        let a = [0xFFFF_FFFFu32, 0xFFFF_FFFFu32];
        let b = [1u32, 0u32];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("Fix: matching limbs");
        assert_eq!(sum, vec![0, 0xFFFF_FFFF]);
        assert_eq!(carry, vec![1, 0]);
    }

    // MAX + MAX = 2^33 - 2 → low word 0xFFFF_FFFE with a carry in every limb.
    #[test]
    fn cpu_max_plus_max_emits_per_limb_carry_and_truncated_sum() {
        let a = [0xFFFF_FFFFu32; 4];
        let b = [0xFFFF_FFFFu32; 4];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("Fix: matching limbs");
        assert_eq!(sum, vec![0xFFFF_FFFEu32; 4]);
        assert_eq!(carry, vec![1u32; 4]);
    }

    // --- Carry resolution: ripple from limb 0 upward. ---

    #[test]
    fn resolve_carry_chain_propagates_single_carry_through_zeros() {
        let sum_partial = vec![0xFFFF_FFFFu32, 0, 0, 0];
        let carry_partial = vec![1u32, 0, 0, 0];
        let (final_sum, final_carry) = resolve_carry_chain_cpu(&sum_partial, &carry_partial)
            .expect("Fix: matching split limbs");
        assert_eq!(final_sum, vec![0xFFFF_FFFF, 1, 0, 0]);
        assert_eq!(
            final_carry, 0,
            "the carry from limb 0 propagates into limb 1, then dies"
        );
    }

    // Split + resolve end-to-end: (2^96 - 1) + 1 = 2^96.
    #[test]
    fn resolve_carry_chain_handles_chained_overflow() {
        let a = [0xFFFF_FFFFu32, 0xFFFF_FFFFu32, 0xFFFF_FFFFu32, 0];
        let b = [1u32, 0, 0, 0];
        let (sum_partial, carry_partial) =
            bigint_add_carry_cpu(&a, &b).expect("Fix: matching limbs");
        let (final_sum, final_carry) = resolve_carry_chain_cpu(&sum_partial, &carry_partial)
            .expect("Fix: matching split limbs");
        assert_eq!(final_sum, vec![0, 0, 0, 1]);
        assert_eq!(final_carry, 0);
    }

    #[test]
    fn resolve_carry_chain_emits_final_carry_out_at_top() {
        let a = [0xFFFF_FFFFu32, 0xFFFF_FFFFu32];
        let b = [0xFFFF_FFFFu32, 0xFFFF_FFFFu32];
        let (sum_partial, carry_partial) =
            bigint_add_carry_cpu(&a, &b).expect("Fix: matching limbs");
        let (_final_sum, final_carry) = resolve_carry_chain_cpu(&sum_partial, &carry_partial)
            .expect("Fix: matching split limbs");
        assert_eq!(
            final_carry, 1,
            "max + max in 64 bits overflows into the 65th bit"
        );
    }

    // Limb 1 has no split carry of its own but overflows when limb 0's carry arrives.
    #[test]
    fn resolve_carry_chain_handles_corner_carry_in_only() {
        let sum_partial = vec![0xFFFF_FFFFu32, 0xFFFF_FFFFu32];
        let carry_partial = vec![1u32, 0];
        let (final_sum, final_carry) = resolve_carry_chain_cpu(&sum_partial, &carry_partial)
            .expect("Fix: matching split limbs");
        assert_eq!(final_sum, vec![0xFFFF_FFFF, 0]);
        assert_eq!(
            final_carry, 1,
            "carry propagated into limb 1 made it overflow"
        );
    }

    // --- Larger operand widths. ---

    #[test]
    fn cpu_handles_8_limb_256_bit_operands() {
        let a = [0x1234_5678u32; 8];
        let b = [0x8765_4321u32; 8];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("Fix: matching limbs");
        assert_eq!(sum, vec![0x9999_9999u32; 8]);
        assert_eq!(carry, vec![0u32; 8]);
    }

    #[test]
    fn cpu_handles_128_limb_4096_bit_operands() {
        let a = vec![0x5555_5555u32; 128];
        let b = vec![0xAAAA_AAAAu32; 128];
        let (sum, carry) = bigint_add_carry_cpu(&a, &b).expect("Fix: matching limbs");
        assert_eq!(sum, vec![0xFFFF_FFFFu32; 128]);
        assert_eq!(carry, vec![0u32; 128]);
    }

    // --- Error paths and buffer reuse. ---

    #[test]
    fn cpu_mismatched_limb_count_returns_error() {
        let a = vec![0u32; 4];
        let b = vec![0u32; 5];
        assert_eq!(
            bigint_add_carry_cpu(&a, &b),
            Err(BigIntAddCarryError::LimbCountMismatch { a_len: 4, b_len: 5 })
        );
    }

    // Pre-sized outputs must not be grown or shrunk by the `_into` variant.
    #[test]
    fn cpu_into_reuses_output_capacity() {
        let a = [1u32, u32::MAX];
        let b = [2u32, 1];
        let mut sum = Vec::with_capacity(32);
        let mut carry = Vec::with_capacity(32);
        let sum_cap = sum.capacity();
        let carry_cap = carry.capacity();
        bigint_add_carry_cpu_into(&a, &b, &mut sum, &mut carry).expect("Fix: matching limbs");
        assert_eq!(sum, vec![3, 0]);
        assert_eq!(carry, vec![0, 1]);
        assert_eq!(sum.capacity(), sum_cap);
        assert_eq!(carry.capacity(), carry_cap);
    }

    // Stale contents longer than the result are dropped; the pointer check
    // proves no reallocation happened.
    #[test]
    fn cpu_into_truncates_stale_tail_without_reallocating() {
        let a = [1u32, u32::MAX];
        let b = [2u32, 1];
        let mut sum = Vec::with_capacity(8);
        let mut carry = Vec::with_capacity(8);
        sum.extend([99u32; 8]);
        carry.extend([99u32; 8]);
        let sum_ptr = sum.as_ptr();
        let carry_ptr = carry.as_ptr();
        bigint_add_carry_cpu_into(&a, &b, &mut sum, &mut carry).unwrap();
        assert_eq!(sum, vec![3, 0]);
        assert_eq!(carry, vec![0, 1]);
        assert_eq!(sum.as_ptr(), sum_ptr);
        assert_eq!(carry.as_ptr(), carry_ptr);
    }

    #[test]
    fn resolve_into_truncates_stale_tail_without_reallocating() {
        let mut out = Vec::with_capacity(8);
        out.extend([99u32; 8]);
        let ptr = out.as_ptr();
        let carry = resolve_carry_chain_cpu_into(&[u32::MAX, u32::MAX], &[1, 0], &mut out).unwrap();
        assert_eq!(out, vec![u32::MAX, 0]);
        assert_eq!(carry, 1);
        assert_eq!(out.as_ptr(), ptr);
    }

    // Differential check: split + resolve must match a straightforward u64
    // ripple-carry reference for pseudo-random operands of lengths 1..=24.
    #[test]
    fn generated_split_and_resolve_matches_ripple_reference() {
        for len in 1usize..=24 {
            let a: Vec<u32> = (0..len)
                .map(|idx| {
                    (idx as u32)
                        .wrapping_mul(0x9E37_79B9)
                        .wrapping_add(len as u32)
                })
                .collect();
            let b: Vec<u32> = (0..len)
                .map(|idx| (idx as u32).wrapping_mul(0x85EB_CA6B).wrapping_add(7))
                .collect();
            let (sum_partial, carry_partial) = bigint_add_carry_cpu(&a, &b).unwrap();
            let (final_sum, final_carry) =
                resolve_carry_chain_cpu(&sum_partial, &carry_partial).unwrap();
            let mut expected = Vec::with_capacity(len);
            let mut carry = 0u64;
            for i in 0..len {
                let total = a[i] as u64 + b[i] as u64 + carry;
                expected.push(total as u32);
                carry = total >> 32;
            }
            assert_eq!(final_sum, expected, "generated bigint case len={len}");
            assert_eq!(
                final_carry, carry as u32,
                "generated bigint carry len={len}"
            );
        }
    }

    #[test]
    fn resolve_carry_chain_rejects_length_mismatch() {
        let mut out = Vec::new();
        assert_eq!(
            resolve_carry_chain_cpu_into(&[0, 1], &[0], &mut out),
            Err(BigIntAddCarryError::SplitCarryLengthMismatch {
                sum_len: 2,
                carry_len: 1,
            })
        );
    }

    // --- Program builder shape and stable identifiers. ---

    #[test]
    fn build_program_returns_well_formed_program() {
        let program = bigint_add_carry(8);
        assert_eq!(
            program.buffers().len(),
            4,
            "a, b, sum_partial, carry_partial"
        );
        assert_eq!(program.workgroup_size(), [256, 1, 1]);
    }

    #[test]
    fn zero_limb_count_traps() {
        let program = bigint_add_carry(0);
        assert!(program.stats().trap());
    }

    // NOTE(review): this compares buffer counts and workgroup size only, not
    // full program equality.
    #[test]
    fn build_program_is_deterministic_across_calls() {
        let p1 = bigint_add_carry(16);
        let p2 = bigint_add_carry(16);
        assert_eq!(
            p1.buffers().len(),
            p2.buffers().len(),
            "two builds with identical inputs must produce identical buffer lists"
        );
        assert_eq!(p1.workgroup_size(), p2.workgroup_size());
    }

    #[test]
    fn op_id_is_canonical_and_stable() {
        assert_eq!(OP_ID, "vyre-primitives::math::bigint_add_carry");
    }

    #[test]
    fn binding_indices_are_canonical_and_stable() {
        assert_eq!(BINDING_A_IN, 0);
        assert_eq!(BINDING_B_IN, 1);
        assert_eq!(BINDING_SUM_PARTIAL_OUT, 2);
        assert_eq!(BINDING_CARRY_PARTIAL_OUT, 3);
    }
}