use chrono::{DateTime, Utc};
use num_bigint::BigUint;
use num_traits::One;
use sha2::{Digest, Sha256};
use crate::domain::crypto::{decrypt_aes_gcm, encrypt_aes_gcm};
use crate::domain::errors::TimeLockError;
use crate::domain::types::{Payload, TimeLockPuzzle};
/// Assumed sequential modular squarings per second, used to convert a wall-clock
/// delay into a squaring count (exposed via `default_squarings_per_sec`).
const DEFAULT_SQUARINGS_PER_SEC: u64 = 10 * 1_000_000;
/// Bit length of the puzzle modulus `n = p * q`; each prime factor gets half of it.
///
/// NOTE(review): a 256-bit modulus is small enough to be factored with public tools;
/// anyone who factors `n` can compute φ(n) and skip the squarings. Confirm this
/// size is an intentional trade-off.
const MODULUS_BITS: u64 = 256;
/// Length in bytes of the nonce passed to `encrypt_aes_gcm` / `decrypt_aes_gcm`.
const NONCE_LEN: usize = 12;
fn derive_key_and_nonce(solution: &BigUint) -> ([u8; 32], [u8; NONCE_LEN]) {
let solution_bytes = solution.to_bytes_be();
let key: [u8; 32] = Sha256::digest(&solution_bytes).into();
let mut nonce_input = b"nonce".to_vec();
nonce_input.extend_from_slice(&solution_bytes);
let nonce_hash = Sha256::digest(&nonce_input);
let mut nonce = [0u8; NONCE_LEN];
nonce.copy_from_slice(nonce_hash.get(..NONCE_LEN).unwrap_or(&[0u8; NONCE_LEN]));
(key, nonce)
}
/// Computes `g^(2^t) mod n` by performing `t` modular squarings one after another.
///
/// This is the deliberately slow path a solver must take without knowing φ(n).
/// When `t == 0` the input `g` is returned as-is (not reduced modulo `n`).
/// Panics on division by zero if `n` is zero and `t > 0`.
fn sequential_square(g: &BigUint, n: &BigUint, t: u64) -> BigUint {
    (0..t).fold(g.clone(), |acc, _| (&acc * &acc) % n)
}
/// Locks `payload` in a time-lock puzzle that should take until `unlock_at` to open.
///
/// The number of sequential squarings is `(unlock_at - now)` in whole seconds
/// (clamped at zero) times `squarings_per_sec`, saturating on overflow. The creator
/// knows the factors `p`, `q` of the modulus and uses φ(n) = (p-1)(q-1) to compute
/// the answer `g^(2^t) mod n` quickly. That shortcut is only correct if `p` and `q`
/// are actually prime. The answer then keys the encryption of the payload.
///
/// # Errors
/// Returns [`TimeLockError::ComputationFailed`] if encryption fails.
pub fn create_puzzle(
    payload: &Payload,
    unlock_at: DateTime<Utc>,
    squarings_per_sec: u64,
) -> Result<TimeLockPuzzle, TimeLockError> {
    let created_at = Utc::now();
    // An unlock time already in the past yields a zero-work puzzle.
    let wait_secs = (unlock_at - created_at)
        .num_seconds()
        .max(0)
        .cast_unsigned();
    let squarings_required = wait_secs.saturating_mul(squarings_per_sec);

    // Secret factorisation; only the product is published.
    let mut rng = rand::rng();
    let factor_bits = MODULUS_BITS / 2;
    let p = generate_random_prime(&mut rng, factor_bits);
    let q = generate_random_prime(&mut rng, factor_bits);
    let modulus = &p * &q;

    // Public starting value, drawn from [2, n - 1].
    let two = BigUint::from(2u32);
    let base = random_biguint(&mut rng, MODULUS_BITS) % (&modulus - &two) + &two;

    // Trapdoor: reduce the exponent 2^t modulo φ(n) instead of squaring t times.
    let one = BigUint::one();
    let totient = (&p - &one) * (&q - &one);
    let reduced_exponent = two.modpow(&BigUint::from(squarings_required), &totient);
    let solution = base.modpow(&reduced_exponent, &modulus);

    let (key, nonce) = derive_key_and_nonce(&solution);
    let ciphertext = encrypt_aes_gcm(&key, &nonce, payload.as_bytes()).map_err(|err| {
        TimeLockError::ComputationFailed {
            reason: format!("encryption failed: {err}"),
        }
    })?;

    Ok(TimeLockPuzzle {
        modulus: modulus.to_bytes_be(),
        start_value: base.to_bytes_be(),
        ciphertext,
        squarings_required,
        created_at,
        unlock_at,
    })
}
/// Opens a puzzle by performing all `squarings_required` sequential squarings, then
/// decrypting the ciphertext with the key derived from the result.
///
/// Runtime grows linearly with `puzzle.squarings_required`; this is the intended cost.
///
/// # Errors
/// - [`TimeLockError::ComputationFailed`] if the puzzle's modulus decodes to zero.
///   A zero modulus can only come from a malformed or tampered puzzle, and squaring
///   modulo zero would otherwise panic.
/// - [`TimeLockError::DecryptFailed`] if the derived key does not decrypt the ciphertext.
pub fn solve_puzzle(puzzle: &TimeLockPuzzle) -> Result<Payload, TimeLockError> {
    let n = BigUint::from_bytes_be(&puzzle.modulus);
    // Puzzles may be deserialised from untrusted data; reject a zero modulus up front.
    if n == BigUint::ZERO {
        return Err(TimeLockError::ComputationFailed {
            reason: "invalid puzzle: modulus is zero".to_string(),
        });
    }
    let g = BigUint::from_bytes_be(&puzzle.start_value);
    let solution = sequential_square(&g, &n, puzzle.squarings_required);
    let (key, nonce) = derive_key_and_nonce(&solution);
    let plaintext = decrypt_aes_gcm(&key, &nonce, &puzzle.ciphertext)
        .map_err(|source| TimeLockError::DecryptFailed { source })?;
    Ok(Payload::from_bytes(plaintext.to_vec()))
}
/// Solves the puzzle only if enough time has passed since it was created.
///
/// The estimated work budget is the whole seconds elapsed since `puzzle.created_at`
/// (clamped at zero) times `squarings_per_sec`, saturating on overflow. If that budget
/// covers `puzzle.squarings_required`, the puzzle is solved with [`solve_puzzle`];
/// otherwise `Ok(None)` is returned without doing any squarings.
/// NOTE(review): presumably callers pass the same rate used at creation — confirm.
///
/// # Errors
/// Propagates any error from [`solve_puzzle`].
pub fn try_solve_puzzle(
    puzzle: &TimeLockPuzzle,
    squarings_per_sec: u64,
) -> Result<Option<Payload>, TimeLockError> {
    let age_secs = (Utc::now() - puzzle.created_at)
        .num_seconds()
        .max(0)
        .cast_unsigned();
    let affordable_squarings = age_secs.saturating_mul(squarings_per_sec);

    if affordable_squarings >= puzzle.squarings_required {
        solve_puzzle(puzzle).map(Some)
    } else {
        Ok(None)
    }
}
/// Draws random `bits`-bit candidates until one passes [`is_probably_prime`].
///
/// Every candidate has its top bit set, so the result has exactly `bits` bits, and its
/// lowest bit set, so it is odd. `bits` must be at least 1 (`bits - 1` is computed).
fn generate_random_prime(rng: &mut impl rand::Rng, bits: u64) -> BigUint {
    let forced_bits = (BigUint::one() << (bits - 1)) | BigUint::one();
    loop {
        let candidate = random_biguint(rng, bits) | &forced_bits;
        if is_probably_prime(&candidate) {
            break candidate;
        }
    }
}
/// Returns `true` if `n` is prime with overwhelming probability, and `false` if it is
/// certainly composite (or less than 2).
///
/// Trial division alone is not enough here. `generate_random_prime` feeds this function
/// 128-bit candidates, and most of those with no small factor are still composite.
/// `create_puzzle` computes φ(n) as `(p - 1) * (q - 1)`, which is wrong for a composite
/// factor. The puzzle would then encrypt under a value that sequential squaring never
/// reproduces.
///
/// Method:
/// 1. Trial division by the primes below 100 cheaply rejects most composites.
/// 2. Survivors go through Miller–Rabin rounds using those same primes as witnesses.
///    For random candidates, a composite passing all 25 rounds is negligibly unlikely.
fn is_probably_prime(n: &BigUint) -> bool {
    const SMALL_PRIMES: [u32; 25] = [
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, 79, 83,
        89, 97,
    ];

    if *n < BigUint::from(2u32) {
        return false;
    }
    for &small in &SMALL_PRIMES {
        let small = BigUint::from(small);
        if *n == small {
            return true;
        }
        if n % &small == BigUint::ZERO {
            return false;
        }
    }

    // Here n is odd, greater than 97, and thus larger than every witness.
    // Write n - 1 = d * 2^s with d odd.
    let one = BigUint::one();
    let n_minus_one = n - &one;
    let s = n_minus_one.trailing_zeros().unwrap_or(0);
    let d = &n_minus_one >> s;

    'witness: for &a in &SMALL_PRIMES {
        let mut x = BigUint::from(a).modpow(&d, n);
        if x == one || x == n_minus_one {
            continue;
        }
        for _ in 1..s {
            x = (&x * &x) % n;
            if x == n_minus_one {
                continue 'witness;
            }
        }
        // `a` is a witness to compositeness.
        return false;
    }
    true
}
/// Returns a uniformly random integer in `[0, 2^bits)`.
///
/// Fills `ceil(bits / 8)` bytes from `rng` and clears the unused high bits of the
/// leading (most significant) byte before decoding the bytes as big-endian.
fn random_biguint(rng: &mut impl rand::Rng, bits: u64) -> BigUint {
    // How many high bits of the first byte lie beyond `bits` (0..=7).
    let surplus = (8 - bits % 8) % 8;
    #[expect(
        clippy::cast_possible_truncation,
        reason = "bits <= 256, always fits usize"
    )]
    let len = bits.div_ceil(8) as usize;

    let mut bytes = vec![0u8; len];
    rng.fill_bytes(&mut bytes);
    // With surplus == 0 the mask is 0xFF and leaves the byte unchanged.
    if let Some(msb) = bytes.first_mut() {
        *msb &= u8::MAX >> surplus;
    }
    BigUint::from_bytes_be(&bytes)
}
/// The built-in estimate of sequential squarings per second.
///
/// Pass this to [`create_puzzle`] and [`try_solve_puzzle`] when no calibrated rate is
/// available.
#[must_use]
pub const fn default_squarings_per_sec() -> u64 {
    DEFAULT_SQUARINGS_PER_SEC
}
#[cfg(test)]
mod tests {
    use super::*;

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Wraps `bytes` in a [`Payload`] and locks it with an unlock time of "now".
    ///
    /// `create_puzzle` reads the clock after `unlock_at`, so the puzzle needs zero squarings.
    fn lock_now(bytes: &[u8]) -> Result<(Payload, TimeLockPuzzle), TimeLockError> {
        let payload = Payload::from_bytes(bytes.to_vec());
        let puzzle = create_puzzle(&payload, Utc::now(), DEFAULT_SQUARINGS_PER_SEC)?;
        Ok((payload, puzzle))
    }

    #[test]
    fn roundtrip_lock_then_unlock() -> TestResult {
        let (original, puzzle) = lock_now(b"secret message")?;
        let unlocked = solve_puzzle(&puzzle)?;
        assert_eq!(unlocked.as_bytes(), original.as_bytes());
        Ok(())
    }

    #[test]
    fn roundtrip_with_small_delay() -> TestResult {
        // NOTE(review): despite the name, this unlocks "now" (zero squarings), so the
        // φ(n) trapdoor path with t > 0 is not exercised here.
        let (original, puzzle) = lock_now(b"timed secret")?;
        let unlocked = solve_puzzle(&puzzle)?;
        assert_eq!(unlocked.as_bytes(), original.as_bytes());
        Ok(())
    }

    #[test]
    fn try_solve_returns_none_for_future_puzzle() -> TestResult {
        let payload = Payload::from_bytes(b"future secret".to_vec());
        let an_hour_from_now = Utc::now() + chrono::Duration::hours(1);
        let puzzle = create_puzzle(&payload, an_hour_from_now, DEFAULT_SQUARINGS_PER_SEC)?;
        assert!(try_solve_puzzle(&puzzle, DEFAULT_SQUARINGS_PER_SEC)?.is_none());
        Ok(())
    }

    #[test]
    fn puzzle_serialises_to_json() -> TestResult {
        let (_, puzzle) = lock_now(b"json test")?;
        let json = serde_json::to_string_pretty(&puzzle)?;
        for field in ["squarings_required", "modulus"] {
            assert!(json.contains(field), "missing field {field}");
        }
        let decoded: TimeLockPuzzle = serde_json::from_str(&json)?;
        assert_eq!(decoded.squarings_required, puzzle.squarings_required);
        Ok(())
    }

    #[test]
    fn different_payloads_produce_different_puzzles() -> TestResult {
        let unlock_at = Utc::now();
        let first = create_puzzle(
            &Payload::from_bytes(b"payload one".to_vec()),
            unlock_at,
            DEFAULT_SQUARINGS_PER_SEC,
        )?;
        let second = create_puzzle(
            &Payload::from_bytes(b"payload two".to_vec()),
            unlock_at,
            DEFAULT_SQUARINGS_PER_SEC,
        )?;
        assert_ne!(first.modulus, second.modulus);
        Ok(())
    }

    #[test]
    fn wrong_solution_fails_decrypt() -> TestResult {
        let (_, mut puzzle) = lock_now(b"fail test")?;
        // Flip every bit of the leading byte of g so the derived key no longer matches.
        if let Some(lead) = puzzle.start_value.first_mut() {
            *lead = !*lead;
        }
        assert!(solve_puzzle(&puzzle).is_err());
        Ok(())
    }

    #[test]
    fn derive_key_and_nonce_is_deterministic() {
        let solution = BigUint::from(42u32);
        assert_eq!(
            derive_key_and_nonce(&solution),
            derive_key_and_nonce(&solution)
        );
    }

    #[test]
    fn sequential_square_identity() {
        let modulus = BigUint::from(143u32);
        let base = BigUint::from(2u32);
        for (steps, expected) in [(0u64, 2u32), (1, 4), (2, 16)] {
            assert_eq!(
                sequential_square(&base, &modulus, steps),
                BigUint::from(expected)
            );
        }
    }
}