1use crate::ct;
7use crate::error::{BigIntError, Result};
8use crate::limbs::{Limb, limb_mac, limb_mul_wide};
9use crate::montgomery::{MontgomeryContext, from_mont, mont_mul, to_mont};
10use crate::types::BigIntCore;
11
12#[cfg(feature = "alloc")]
13use crate::types::BigInt;
14#[cfg(feature = "alloc")]
15use alloc::{vec, vec::Vec};
16
17#[derive(Clone, Debug)]
22pub enum PartialResult<T> {
23 Complete(T),
25 Partial(ContinuationState),
27}
28
29#[derive(Clone, Debug)]
31pub struct ContinuationState {
32 pub op_type: OpType,
34 pub state: OperationState,
36 pub cycles_consumed: u64,
38 pub max_cycles: u64,
40}
41
/// Kind of timeboxed operation captured in a [`ContinuationState`].
///
/// This is a fieldless enum, so it is cheap to copy and usable as a hash key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum OpType {
    /// Multi-limb multiplication (`mul_timeboxed`).
    Mul,
    /// Modular exponentiation (`modexp_timeboxed`).
    ModExp,
}
50
51#[derive(Clone, Debug)]
53pub enum OperationState {
54 MulState {
56 a: BigInt,
58 b: BigInt,
60 result: Vec<Limb>,
62 i: usize,
64 j: usize,
66 },
67 ModExpState {
69 modulus: BigInt,
71 exponent: BigInt,
73 r0: BigInt,
75 r1: BigInt,
77 bit_pos: usize,
79 total_bits: usize,
81 },
82}
83
/// Multiplies `a` by `b`, yielding a [`PartialResult::Partial`] if the estimated cycle budget
/// `max_cycles` is exhausted before the product is complete.
///
/// A zero operand short-circuits to a fresh `BigInt` sized to the larger operand capacity.
/// Both operands are widened to the same limb count `n`, and the product is accumulated in
/// `2 * n` limbs.
///
/// NOTE(review): only the operands' limbs are used here; no sign handling is visible, so the
/// sign of negative operands is presumably not propagated — confirm intended semantics.
///
/// # Errors
///
/// Returns [`BigIntError::Overflow`] when `2 * n` exceeds the larger operand capacity, and
/// propagates errors from `ensure_capacity` and the multiplication kernel.
#[cfg(feature = "alloc")]
pub fn mul_timeboxed(a: &BigInt, b: &BigInt, max_cycles: u64) -> Result<PartialResult<BigInt>> {
    let capacity = a.max_limbs().max(b.max_limbs());

    if a.is_zero() || b.is_zero() {
        return Ok(PartialResult::Complete(BigInt::new(capacity)));
    }

    let width = a.limb_count().max(b.limb_count());
    let product_width = 2 * width;
    if product_width > capacity {
        return Err(BigIntError::Overflow);
    }

    // `product_width <= capacity` here, so `capacity` already covers the product.
    let mut scratch = BigInt::new(capacity);
    scratch.ensure_capacity(product_width)?;

    // The kernel requires equal-length operands, so zero-pad both to `width` limbs.
    let mut lhs = vec![0u64; width];
    let mut rhs = vec![0u64; width];
    let (lhs_src, rhs_src) = (a.limbs(), b.limbs());
    lhs[..lhs_src.len()].copy_from_slice(lhs_src);
    rhs[..rhs_src.len()].copy_from_slice(rhs_src);

    let accumulator_view = scratch.limbs_mut();
    accumulator_view[..product_width].fill(0);
    let mut accumulator = accumulator_view.to_vec();

    mul_timeboxed_comba(&lhs, &rhs, &mut accumulator, 0, 0, 0, max_cycles)
}
136
137fn mul_timeboxed_comba(
138 a: &[Limb],
139 b: &[Limb],
140 result: &mut Vec<Limb>,
141 start_i: usize,
142 start_j: usize,
143 cycles_consumed: u64,
144 max_cycles: u64,
145) -> Result<PartialResult<BigInt>> {
146 let n = a.len();
147 debug_assert_eq!(b.len(), n);
148 debug_assert!(result.len() >= 2 * n);
149
150 let mut cycles = cycles_consumed;
151 let mut i = start_i;
152 let mut j = start_j;
153
154 while i < n {
156 let mut carry = 0u64;
157 let _j_start = if i == start_i { start_j } else { 0 };
158
159 while j < n {
160 cycles += 10; if cycles >= max_cycles {
163 let max_limbs = result.len();
165 let mut result_bigint = BigInt::new(max_limbs);
166 result_bigint.limbs_mut().copy_from_slice(result);
167
168 let state = OperationState::MulState {
169 a: BigInt::from_limbs(a, max_limbs)?,
170 b: BigInt::from_limbs(b, max_limbs)?,
171 result: result.to_vec(),
172 i,
173 j,
174 };
175
176 return Ok(PartialResult::Partial(ContinuationState {
177 op_type: OpType::Mul,
178 state,
179 cycles_consumed: cycles,
180 max_cycles,
181 }));
182 }
183
184 let wide = limb_mul_wide(a[i], b[j]);
185 let (lo, hi) = (wide.lo, wide.hi);
186 let (sum_lo, carry1) = limb_mac(result[i + j], lo, carry);
187 result[i + j] = sum_lo;
188 carry = carry1 + hi;
189 j += 1;
190 }
191
192 cycles += 2; if i + n < result.len() {
195 let (sum, _) = limb_mac(result[i + n], 0, carry);
196 result[i + n] = sum;
197 }
198
199 i += 1;
200 j = 0;
201 }
202
203 let max_limbs = result.len();
205 let mut result_bigint = BigInt::new(max_limbs);
206 result_bigint.ensure_capacity(result.len())?;
207 result_bigint.limbs_mut().copy_from_slice(result);
208 result_bigint.canonicalize()?;
209
210 Ok(PartialResult::Complete(result_bigint))
211}
212
/// Computes `base ^ exponent` modulo the context's modulus using a timeboxed Montgomery
/// ladder.
///
/// Returns 1 for a zero exponent and 0 for a zero base (with a non-zero exponent). Otherwise
/// both the ladder registers (`1` and `base`) are converted to Montgomery form, and the ladder
/// runs from the exponent's most significant set bit down to bit 0.
///
/// # Errors
///
/// Propagates errors from `to_mont` and the ladder.
#[cfg(feature = "alloc")]
pub fn modexp_timeboxed(
    ctx: &MontgomeryContext,
    base: &BigInt,
    exponent: &BigInt,
    max_cycles: u64,
) -> Result<PartialResult<BigInt>> {
    let width = ctx.modulus().max_limbs();

    if exponent.is_zero() {
        return Ok(PartialResult::Complete(BigInt::from_u64(1, width)));
    }
    if base.is_zero() {
        return Ok(PartialResult::Complete(BigInt::from_u64(0, width)));
    }

    let base_mont = to_mont(ctx, base)?;
    let one_mont = to_mont(ctx, &BigInt::from_u64(1, width))?;

    // Scan from the most significant limb for the highest set bit.
    let top_bit = exponent
        .limbs()
        .iter()
        .enumerate()
        .rev()
        .find(|(_, limb)| **limb != 0)
        .map(|(index, limb)| index * 64 + (63 - limb.leading_zeros() as usize));

    match top_bit {
        Some(top_bit) => {
            modexp_timeboxed_ladder(ctx, exponent, one_mont, base_mont, top_bit, 0, max_cycles)
        }
        None => Ok(PartialResult::Complete(BigInt::from_u64(1, width))),
    }
}
272
273fn modexp_timeboxed_ladder(
274 ctx: &MontgomeryContext,
275 exponent: &BigInt,
276 mut r0: BigInt,
277 mut r1: BigInt,
278 total_bits: usize,
279 start_bit: usize,
280 max_cycles: u64,
281) -> Result<PartialResult<BigInt>> {
282 let mut cycles_consumed = 0u64;
283 let exp_limbs = exponent.limbs();
284
285 for bit_pos in (0..=total_bits).rev().skip(start_bit) {
287 cycles_consumed += 5; if cycles_consumed >= max_cycles {
290 let state = OperationState::ModExpState {
291 modulus: ctx.modulus().clone(),
292 exponent: exponent.clone(),
293 r0: r0.clone(),
294 r1: r1.clone(),
295 bit_pos,
296 total_bits,
297 };
298
299 return Ok(PartialResult::Partial(ContinuationState {
300 op_type: OpType::ModExp,
301 state,
302 cycles_consumed,
303 max_cycles,
304 }));
305 }
306
307 let limb_idx = bit_pos / 64;
308 let bit_idx = bit_pos % 64;
309 let bit = if limb_idx < exp_limbs.len() {
310 (exp_limbs[limb_idx] >> bit_idx) & 1
311 } else {
312 0
313 };
314
315 cycles_consumed += 20; ct::ct_swap(bit, &mut r0.limbs_mut()[0], &mut r1.limbs_mut()[0]);
318 swap_bigint(bit, &mut r0, &mut r1);
319
320 cycles_consumed += 50; r1 = mont_mul(ctx, &r0, &r1)?;
323 r0 = mont_mul(ctx, &r0, &r0)?;
324
325 cycles_consumed += 20; swap_bigint(bit, &mut r0, &mut r1);
328 }
329
330 let result = from_mont(ctx, &r0)?;
332
333 Ok(PartialResult::Complete(result))
334}
335
336fn swap_bigint(choice: u64, a: &mut BigInt, b: &mut BigInt) {
338 let a_limbs = a.limbs_mut();
339 let b_limbs = b.limbs_mut();
340 let len = a_limbs.len().min(b_limbs.len());
341
342 for i in 0..len {
343 ct::ct_swap(choice, &mut a_limbs[i], &mut b_limbs[i]);
344 }
345
346 let a_sign = a.sign();
348 let b_sign = b.sign();
349 let new_a_sign = ct::ct_select(choice, b_sign as u64, a_sign as u64) != 0;
350 let new_b_sign = ct::ct_select(choice, a_sign as u64, b_sign as u64) != 0;
351 a.set_sign(new_a_sign);
352 b.set_sign(new_b_sign);
353}
354
/// Continues an operation previously suspended with [`PartialResult::Partial`].
///
/// Each call is granted a fresh budget of `state.max_cycles` estimated cycles. The saved
/// `cycles_consumed` is already at or above that budget when a routine yields. Reusing it
/// would therefore either yield again immediately without progress or underflow.
///
/// # Errors
///
/// Returns [`BigIntError::InvalidState`] in three cases:
/// - `op_type` does not match the variant of `state.state`;
/// - a saved multiplication accumulator is too short for its operands;
/// - a saved exponentiation position lies above `total_bits`.
///
/// Other errors propagate from `MontgomeryContext::new` and the underlying kernels.
#[cfg(feature = "alloc")]
pub fn resume(state: ContinuationState) -> Result<PartialResult<BigInt>> {
    let ContinuationState {
        op_type,
        state: saved,
        max_cycles,
        ..
    } = state;

    match (op_type, saved) {
        (
            OpType::Mul,
            OperationState::MulState {
                a,
                b,
                mut result,
                i,
                j,
            },
        ) => {
            // `a`/`b` were rebuilt via `BigInt::from_limbs`, so their limb slices need not
            // have the equal lengths the kernel requires. Re-widen them from their
            // significant limbs. Saved `i`/`j` at or beyond the new width only address zero
            // limbs, and the kernel simply skips them.
            let significant = |limbs: &[Limb]| {
                limbs
                    .iter()
                    .rposition(|&limb| limb != 0)
                    .map_or(0, |top| top + 1)
            };
            let (a_src, b_src) = (a.limbs(), b.limbs());
            let (a_len, b_len) = (significant(a_src), significant(b_src));
            let n = a_len.max(b_len);
            if result.len() < 2 * n {
                return Err(BigIntError::InvalidState);
            }

            let mut a_norm: Vec<Limb> = vec![0; n];
            let mut b_norm: Vec<Limb> = vec![0; n];
            a_norm[..a_len].copy_from_slice(&a_src[..a_len]);
            b_norm[..b_len].copy_from_slice(&b_src[..b_len]);

            // Start this slice's count at zero so the kernel gets the full budget.
            mul_timeboxed_comba(&a_norm, &b_norm, &mut result, i, j, 0, max_cycles)
        }
        (
            OpType::ModExp,
            OperationState::ModExpState {
                modulus,
                exponent,
                r0,
                r1,
                bit_pos,
                total_bits,
            },
        ) => {
            // The state records the next bit *position* to process. The ladder's
            // `start_bit` is instead the number of positions already consumed from
            // `total_bits` downward, so convert it here.
            let processed = total_bits
                .checked_sub(bit_pos)
                .ok_or(BigIntError::InvalidState)?;

            let ctx = MontgomeryContext::new(modulus)?;

            modexp_timeboxed_ladder(&ctx, &exponent, r0, r1, total_bits, processed, max_cycles)
        }
        _ => Err(BigIntError::InvalidState),
    }
}
416
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps a finished multiplication, failing the test with `msg` if it was suspended.
    fn expect_complete(outcome: PartialResult<BigInt>, msg: &str) -> BigInt {
        match outcome {
            PartialResult::Complete(product) => product,
            PartialResult::Partial(_) => panic!("{}", msg),
        }
    }

    #[test]
    fn test_mul_timeboxed() {
        let lhs = BigInt::from_u64(10, 10);
        let rhs = BigInt::from_u64(20, 10);
        let outcome = mul_timeboxed(&lhs, &rhs, 1000).unwrap();
        let product = expect_complete(outcome, "Expected complete result");
        assert_eq!(product.limbs()[0], 200);
    }

    #[test]
    fn test_mul_timeboxed_with_high_cycles() {
        let lhs = BigInt::from_u64(100, 10);
        let rhs = BigInt::from_u64(200, 10);
        let outcome = mul_timeboxed(&lhs, &rhs, 10000).unwrap();
        let product = expect_complete(outcome, "Expected complete result with high cycle limit");
        assert_eq!(product.limbs()[0], 20000);
    }
}