clock_bigint/
timeboxed.rs

1//! Time-sliced (pausable) operations for long-running computations.
2//!
3//! Enables asynchronous VM execution by allowing operations to be paused
4//! and resumed, preventing blocking during long computations.
5
6use crate::ct;
7use crate::error::{BigIntError, Result};
8use crate::limbs::{Limb, limb_mac, limb_mul_wide};
9use crate::montgomery::{MontgomeryContext, from_mont, mont_mul, to_mont};
10use crate::types::BigIntCore;
11
12#[cfg(feature = "alloc")]
13use crate::types::BigInt;
14#[cfg(feature = "alloc")]
15use alloc::{vec, vec::Vec};
16
/// Outcome of one slice of a time-sliced operation.
///
/// Either the operation finished within the slice's cycle budget and its value
/// is available, or it paused and must be continued by passing the returned
/// [`ContinuationState`] to `resume`.
#[derive(Clone, Debug)]
pub enum PartialResult<T> {
    /// Operation completed; carries the final value.
    Complete(T),
    /// Operation paused; hand this state to `resume` to run the next slice.
    Partial(ContinuationState),
}
28
/// Snapshot of a paused time-sliced operation, consumed by `resume`.
#[derive(Clone, Debug)]
pub struct ContinuationState {
    /// Which operation paused. Must agree with the variant of `state`;
    /// `resume` returns `BigIntError::InvalidState` on a mismatch.
    pub op_type: OpType,
    /// Intermediate values needed to continue the operation.
    pub state: OperationState,
    /// Estimated cycle count recorded by the operation when it paused.
    /// These are this module's heuristic cost units, not hardware cycles.
    pub cycles_consumed: u64,
    /// Cycle budget the operation was run with; `resume` derives the budget
    /// of the next slice from it.
    pub max_cycles: u64,
}
41
/// Kind of time-sliced operation a [`ContinuationState`] belongs to.
///
/// Selects which [`OperationState`] variant `resume` expects.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum OpType {
    /// Multiplication (`mul_timeboxed`); pairs with `OperationState::MulState`.
    Mul,
    /// Modular exponentiation (`modexp_timeboxed`); pairs with
    /// `OperationState::ModExpState`.
    ModExp,
}
50
/// Operation-specific data saved when a time-sliced operation pauses.
#[derive(Clone, Debug)]
pub enum OperationState {
    /// State of a paused schoolbook multiplication.
    MulState {
        /// First operand, built from the limb slice the kernel was using
        /// (zero-padded to the common operand width at pause time).
        a: BigInt,
        /// Second operand, built the same way as `a`.
        b: BigInt,
        /// Partial-product accumulator, least-significant limb first.
        result: Vec<Limb>,
        /// Row (index into `a`'s limbs) of the next limb step to execute.
        i: usize,
        /// Column (index into `b`'s limbs) within row `i` of the next step.
        j: usize,
    },
    /// State of a paused Montgomery-ladder modular exponentiation.
    ModExpState {
        /// Modulus; `resume` rebuilds the Montgomery context from it.
        modulus: BigInt,
        /// Exponent being processed.
        exponent: BigInt,
        /// Ladder accumulator R0 (Montgomery form).
        r0: BigInt,
        /// Ladder accumulator R1 (Montgomery form).
        r1: BigInt,
        /// Absolute index (0 = least significant) of the exponent bit that
        /// was about to be processed when the operation paused.
        bit_pos: usize,
        /// Index of the exponent's most significant set bit; the ladder walks
        /// bits `total_bits` down to 0.
        total_bits: usize,
    },
}
83
/// Starts a time-sliced multiplication of `a` and `b`.
///
/// The product is accumulated by a limb-by-limb schoolbook kernel that charges
/// an estimated cost per step. If the budget `max_cycles` runs out first, a
/// [`PartialResult::Partial`] is returned that can be handed to `resume`.
///
/// # Errors
///
/// Returns [`BigIntError::Overflow`] when `2 * max(limb_count)` exceeds the
/// larger of the operands' `max_limbs`; errors from the `BigInt` helpers are
/// propagated.
///
/// NOTE(review): only limb magnitudes are multiplied and operand signs are
/// never consulted, so the product presumably always comes back non-negative —
/// confirm against the intended `BigInt` sign semantics.
#[cfg(feature = "alloc")]
pub fn mul_timeboxed(a: &BigInt, b: &BigInt, max_cycles: u64) -> Result<PartialResult<BigInt>> {
    let cap = a.max_limbs().max(b.max_limbs());

    // Anything times zero is zero; no slicing needed.
    if a.is_zero() || b.is_zero() {
        return Ok(PartialResult::Complete(BigInt::new(cap)));
    }

    let width = a.limb_count().max(b.limb_count());
    let product_limbs = 2 * width;
    if product_limbs > cap {
        return Err(BigIntError::Overflow);
    }

    let mut accumulator = BigInt::new(cap.max(product_limbs));
    accumulator.ensure_capacity(product_limbs)?;

    // The kernel requires both operands padded to the same width.
    let pad = |src: &[Limb]| {
        let mut padded = vec![0u64; width];
        padded[..src.len()].copy_from_slice(src);
        padded
    };
    let lhs = pad(a.limbs());
    let rhs = pad(b.limbs());

    let slots = accumulator.limbs_mut();
    slots[..product_limbs].fill(0);
    let mut scratch = slots.to_vec();

    mul_timeboxed_comba(&lhs, &rhs, &mut scratch, 0, 0, 0, max_cycles)
}
136
137fn mul_timeboxed_comba(
138    a: &[Limb],
139    b: &[Limb],
140    result: &mut Vec<Limb>,
141    start_i: usize,
142    start_j: usize,
143    cycles_consumed: u64,
144    max_cycles: u64,
145) -> Result<PartialResult<BigInt>> {
146    let n = a.len();
147    debug_assert_eq!(b.len(), n);
148    debug_assert!(result.len() >= 2 * n);
149
150    let mut cycles = cycles_consumed;
151    let mut i = start_i;
152    let mut j = start_j;
153
154    // Comba multiplication algorithm with cycle tracking
155    while i < n {
156        let mut carry = 0u64;
157        let _j_start = if i == start_i { start_j } else { 0 };
158
159        while j < n {
160            // Each limb multiplication costs some cycles
161            cycles += 10; // Estimate cycles for limb operations
162            if cycles >= max_cycles {
163                // Create continuation state
164                let max_limbs = result.len();
165                let mut result_bigint = BigInt::new(max_limbs);
166                result_bigint.limbs_mut().copy_from_slice(result);
167
168                let state = OperationState::MulState {
169                    a: BigInt::from_limbs(a, max_limbs)?,
170                    b: BigInt::from_limbs(b, max_limbs)?,
171                    result: result.to_vec(),
172                    i,
173                    j,
174                };
175
176                return Ok(PartialResult::Partial(ContinuationState {
177                    op_type: OpType::Mul,
178                    state,
179                    cycles_consumed: cycles,
180                    max_cycles,
181                }));
182            }
183
184            let wide = limb_mul_wide(a[i], b[j]);
185            let (lo, hi) = (wide.lo, wide.hi);
186            let (sum_lo, carry1) = limb_mac(result[i + j], lo, carry);
187            result[i + j] = sum_lo;
188            carry = carry1 + hi;
189            j += 1;
190        }
191
192        // Add final carry
193        cycles += 2; // Estimate cycles for carry addition
194        if i + n < result.len() {
195            let (sum, _) = limb_mac(result[i + n], 0, carry);
196            result[i + n] = sum;
197        }
198
199        i += 1;
200        j = 0;
201    }
202
203    // Operation complete - create final BigInt
204    let max_limbs = result.len();
205    let mut result_bigint = BigInt::new(max_limbs);
206    result_bigint.ensure_capacity(result.len())?;
207    result_bigint.limbs_mut().copy_from_slice(result);
208    result_bigint.canonicalize()?;
209
210    Ok(PartialResult::Complete(result_bigint))
211}
212
/// Starts a time-sliced modular exponentiation `base^exponent mod m`, where
/// `m` is the modulus of `ctx`.
///
/// Runs a Montgomery ladder over the exponent bits, from the most significant
/// set bit down to bit 0. If the estimated cycle budget `max_cycles` is used
/// up first, a [`PartialResult::Partial`] is returned for use with `resume`.
///
/// Special cases are answered directly without entering the ladder: a zero
/// exponent yields 1 and a zero base yields 0. Both are sized to the
/// modulus's `max_limbs`.
///
/// # Errors
///
/// Propagates failures from the Montgomery conversions and multiplications.
#[cfg(feature = "alloc")]
pub fn modexp_timeboxed(
    ctx: &MontgomeryContext,
    base: &BigInt,
    exponent: &BigInt,
    max_cycles: u64,
) -> Result<PartialResult<BigInt>> {
    let width = ctx.modulus().max_limbs();

    if exponent.is_zero() {
        // base^0 = 1 mod m
        return Ok(PartialResult::Complete(BigInt::from_u64(1, width)));
    }
    if base.is_zero() {
        // 0^e = 0 mod m for e > 0
        return Ok(PartialResult::Complete(BigInt::from_u64(0, width)));
    }

    // The ladder starts from (R0, R1) = (1, base), both in Montgomery form.
    let ladder_hi = to_mont(ctx, base)?;
    let ladder_lo = to_mont(ctx, &BigInt::from_u64(1, width))?;

    // Locate the exponent's most significant set bit.
    let exp_limbs = exponent.limbs();
    let top_bit = match exp_limbs.iter().rposition(|&limb| limb != 0) {
        Some(idx) => idx * 64 + (63 - exp_limbs[idx].leading_zeros() as usize),
        // Defensive: presumably unreachable once `is_zero()` returned false.
        None => return Ok(PartialResult::Complete(BigInt::from_u64(1, width))),
    };

    modexp_timeboxed_ladder(ctx, exponent, ladder_lo, ladder_hi, top_bit, 0, max_cycles)
}
272
273fn modexp_timeboxed_ladder(
274    ctx: &MontgomeryContext,
275    exponent: &BigInt,
276    mut r0: BigInt,
277    mut r1: BigInt,
278    total_bits: usize,
279    start_bit: usize,
280    max_cycles: u64,
281) -> Result<PartialResult<BigInt>> {
282    let mut cycles_consumed = 0u64;
283    let exp_limbs = exponent.limbs();
284
285    // Process bits from MSB to LSB
286    for bit_pos in (0..=total_bits).rev().skip(start_bit) {
287        // Check cycle limit before each step
288        cycles_consumed += 5; // Estimate cycles for bit processing setup
289        if cycles_consumed >= max_cycles {
290            let state = OperationState::ModExpState {
291                modulus: ctx.modulus().clone(),
292                exponent: exponent.clone(),
293                r0: r0.clone(),
294                r1: r1.clone(),
295                bit_pos,
296                total_bits,
297            };
298
299            return Ok(PartialResult::Partial(ContinuationState {
300                op_type: OpType::ModExp,
301                state,
302                cycles_consumed,
303                max_cycles,
304            }));
305        }
306
307        let limb_idx = bit_pos / 64;
308        let bit_idx = bit_pos % 64;
309        let bit = if limb_idx < exp_limbs.len() {
310            (exp_limbs[limb_idx] >> bit_idx) & 1
311        } else {
312            0
313        };
314
315        // Constant-time conditional swap
316        cycles_consumed += 20; // Estimate cycles for swap operations
317        ct::ct_swap(bit, &mut r0.limbs_mut()[0], &mut r1.limbs_mut()[0]);
318        swap_bigint(bit, &mut r0, &mut r1);
319
320        // Montgomery ladder step
321        cycles_consumed += 50; // Estimate cycles for Montgomery operations
322        r1 = mont_mul(ctx, &r0, &r1)?;
323        r0 = mont_mul(ctx, &r0, &r0)?;
324
325        // Swap again
326        cycles_consumed += 20; // Estimate cycles for swap operations
327        swap_bigint(bit, &mut r0, &mut r1);
328    }
329
330    // Convert result from Montgomery form
331    let result = from_mont(ctx, &r0)?;
332
333    Ok(PartialResult::Complete(result))
334}
335
336/// Constant-time swap of two BigInt values.
337fn swap_bigint(choice: u64, a: &mut BigInt, b: &mut BigInt) {
338    let a_limbs = a.limbs_mut();
339    let b_limbs = b.limbs_mut();
340    let len = a_limbs.len().min(b_limbs.len());
341
342    for i in 0..len {
343        ct::ct_swap(choice, &mut a_limbs[i], &mut b_limbs[i]);
344    }
345
346    // Also swap signs
347    let a_sign = a.sign();
348    let b_sign = b.sign();
349    let new_a_sign = ct::ct_select(choice, b_sign as u64, a_sign as u64) != 0;
350    let new_b_sign = ct::ct_select(choice, a_sign as u64, b_sign as u64) != 0;
351    a.set_sign(new_a_sign);
352    b.set_sign(new_b_sign);
353}
354
/// Resumes a paused time-sliced operation from its continuation state.
///
/// Runs the operation for another slice with the same per-slice budget
/// (`state.max_cycles`). Returns either the final value or a fresh
/// continuation.
///
/// # Errors
///
/// Returns [`BigIntError::InvalidState`] when `op_type` does not match the
/// variant of `state.state`, or when the saved positions are inconsistent with
/// the saved buffers. Errors from rebuilding the Montgomery context or from
/// the arithmetic are propagated.
#[cfg(feature = "alloc")]
pub fn resume(state: ContinuationState) -> Result<PartialResult<BigInt>> {
    let ContinuationState {
        op_type,
        state,
        cycles_consumed,
        max_cycles,
    } = state;

    match (op_type, state) {
        (
            OpType::Mul,
            OperationState::MulState {
                a,
                b,
                mut result,
                i,
                j,
            },
        ) => {
            // The saved operands may come back with a different limb count
            // than the padded slices the kernel paused on. Re-pad both to
            // their common significant width. Only zero limbs are dropped, so
            // the saved (i, j) position remains valid.
            let width = significant_limbs(a.limbs()).max(significant_limbs(b.limbs()));
            if 2 * width > result.len() {
                return Err(BigIntError::InvalidState);
            }
            let a_limbs = pad_limbs(a.limbs(), width);
            let b_limbs = pad_limbs(b.limbs(), width);

            mul_timeboxed_comba(
                &a_limbs,
                &b_limbs,
                &mut result,
                i,
                j,
                cycles_consumed,
                max_cycles,
            )
        }
        (
            OpType::ModExp,
            OperationState::ModExpState {
                modulus,
                exponent,
                r0,
                r1,
                bit_pos,
                total_bits,
            },
        ) => {
            // The ladder takes the number of bits already processed, while
            // the state records the absolute index of the next bit.
            let start_bit = total_bits
                .checked_sub(bit_pos)
                .ok_or(BigIntError::InvalidState)?;

            // Reconstruct the Montgomery context.
            let ctx = MontgomeryContext::new(modulus)?;

            // A paused slice has already spent at least `max_cycles`, so
            // `max_cycles - cycles_consumed` would be zero or would underflow.
            // Each slice gets the full per-slice budget instead.
            modexp_timeboxed_ladder(&ctx, &exponent, r0, r1, total_bits, start_bit, max_cycles)
        }
        _ => Err(BigIntError::InvalidState),
    }
}

/// Number of limbs up to and including the most significant non-zero limb.
#[cfg(feature = "alloc")]
fn significant_limbs(limbs: &[Limb]) -> usize {
    limbs.iter().rposition(|&limb| limb != 0).map_or(0, |top| top + 1)
}

/// Copies the low `width` limbs of `limbs` into a zero-padded vector of exactly
/// `width` limbs.
#[cfg(feature = "alloc")]
fn pad_limbs(limbs: &[Limb], width: usize) -> Vec<Limb> {
    let mut padded = vec![0; width];
    let take = limbs.len().min(width);
    padded[..take].copy_from_slice(&limbs[..take]);
    padded
}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a completed result, failing the test with `msg` otherwise.
    fn expect_complete(result: PartialResult<BigInt>, msg: &str) -> BigInt {
        match result {
            PartialResult::Complete(value) => value,
            PartialResult::Partial(_) => panic!("{}", msg),
        }
    }

    #[test]
    fn test_mul_timeboxed() {
        let lhs = BigInt::from_u64(10, 10);
        let rhs = BigInt::from_u64(20, 10);

        let outcome = mul_timeboxed(&lhs, &rhs, 1000).unwrap();
        let product = expect_complete(outcome, "Expected complete result");
        assert_eq!(product.limbs()[0], 200);
    }

    #[test]
    fn test_mul_timeboxed_with_high_cycles() {
        let lhs = BigInt::from_u64(100, 10);
        let rhs = BigInt::from_u64(200, 10);

        // A generous cycle budget must finish in a single slice.
        let outcome = mul_timeboxed(&lhs, &rhs, 10000).unwrap();
        let product = expect_complete(outcome, "Expected complete result with high cycle limit");
        assert_eq!(product.limbs()[0], 20000);
    }
}