use crate::math::NttContext;
use serde::{Deserialize, Serialize};
/// Default pair of CRT moduli (`268_369_921` and `249_561_089`).
///
/// Both are written as `k·2^e + 1` to make the NTT structure visible. `4095·2^16 + 1` and
/// `119·2^21 + 1` are both `≡ 1 (mod 2^16)`, so the `m ≡ 1 (mod 2·ring_dim)` NTT condition
/// holds for every power-of-two ring dimension up to `2^15`.
pub const DEFAULT_CRT_MODULI: [u64; 2] = [4095 * (1 << 16) + 1, 119 * (1 << 21) + 1];
/// Security level a parameter set is labelled with.
///
/// This is a tag only: nothing in this module estimates or enforces the bit security.
/// Variant order is part of the serialized form for index-based serde formats — do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SecurityLevel {
/// Nominal 128-bit security (the level used by the `secure_128_*` presets).
Bits128,
/// Nominal 256-bit security.
Bits256,
}
/// Protocol variant selector; `NoPacking` is the default.
///
/// NOTE(review): variant names suggest how many packing steps the protocol uses — TODO confirm
/// against the code that matches on this enum.
/// Variant order is part of the serialized form for index-based serde formats — do not reorder.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
pub enum InspireVariant {
/// Default variant.
#[default]
NoPacking,
// NOTE(review): `allow(dead_code)` presumably silences warnings because this variant is not yet
// constructed inside the crate — confirm, and remove the attribute once it is used.
#[allow(dead_code)]
OnePacking,
// NOTE(review): same `allow(dead_code)` rationale as `OnePacking` — confirm.
#[allow(dead_code)]
TwoPacking,
}
/// Lattice parameter set; call `InspireParams::validate` to check its internal consistency.
///
/// Field order is part of the serialized and `Debug` form — do not reorder.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct InspireParams {
/// Ring dimension `d`; `validate` requires a power of two.
pub ring_dim: usize,
/// Ciphertext modulus; `validate` requires it to equal the product of `crt_moduli`.
pub q: u64,
/// One or two pairwise-coprime moduli, each `≡ 1 (mod 2·ring_dim)` (checked by `validate`).
pub crt_moduli: Vec<u64>,
/// Plaintext modulus; the scaling factor is `delta() = q / p`.
pub p: u64,
/// Presets use 6.4. NOTE(review): presumably the std-dev of the error distribution — TODO confirm.
pub sigma: f64,
/// Base of the gadget decomposition (presets use `2^20`).
pub gadget_base: u64,
/// Number of gadget digits; presets use `ceil(log_{gadget_base}(q))`.
pub gadget_len: usize,
/// Security label attached to this parameter set.
pub security_level: SecurityLevel,
}
impl InspireParams {
    /// Preset tagged [`SecurityLevel::Bits128`] with ring dimension 2048.
    ///
    /// Uses [`DEFAULT_CRT_MODULI`] (`q` is their product), `p = 65537`, `sigma = 6.4`, and a
    /// base-`2^20` gadget whose length is the number of base-`2^20` digits needed to cover `q`.
    pub fn secure_128_d2048() -> Self {
        Self::secure_128_with_ring_dim(2048)
    }

    /// Same as [`secure_128_d2048`](Self::secure_128_d2048) but with ring dimension 4096.
    pub fn secure_128_d4096() -> Self {
        Self::secure_128_with_ring_dim(4096)
    }

    /// Shared body of the `secure_128_*` presets, which differ only in `ring_dim`.
    fn secure_128_with_ring_dim(ring_dim: usize) -> Self {
        let crt_moduli = DEFAULT_CRT_MODULI.to_vec();
        // Constant moduli whose product is known to fit in u64 (~2^55.9).
        let q: u64 = crt_moduli.iter().product();
        let gadget_base: u64 = 1 << 20;
        // Derived from the base rather than repeating its exponent, so the two cannot drift apart.
        let gadget_len = Self::gadget_len_for(q, gadget_base);
        Self {
            ring_dim,
            q,
            crt_moduli,
            p: 65537,
            sigma: 6.4,
            gadget_base,
            gadget_len,
            security_level: SecurityLevel::Bits128,
        }
    }

    /// Smallest `len` with `base^len >= q`, i.e. `ceil(log_base(q))`: the number of
    /// base-`base` digits needed to represent any value in `[0, q)`.
    ///
    /// Computed with integers. The `f64` form `ceil(log2(q) / log2(base))` undercounts when
    /// `q` is just above a power of `base` but rounds down to it as an `f64`.
    /// `base` must be at least 2.
    fn gadget_len_for(q: u64, base: u64) -> usize {
        debug_assert!(base >= 2, "gadget base must be >= 2");
        let mut len = 0;
        let mut covered: u128 = 1;
        // Loop invariant `covered < q <= u64::MAX` and `base <= u64::MAX` keep the product in u128.
        while covered < u128::from(q) {
            covered *= u128::from(base);
            len += 1;
        }
        len
    }

    /// Scaling factor `Δ = ⌊q / p⌋` (integer division).
    ///
    /// # Panics
    /// Panics if `p == 0`, a parameter set that [`validate`](Self::validate) rejects.
    pub fn delta(&self) -> u64 {
        self.q / self.p
    }

    /// The CRT moduli as a slice.
    pub fn moduli(&self) -> &[u64] {
        &self.crt_moduli
    }

    /// Number of CRT moduli.
    pub fn crt_count(&self) -> usize {
        self.crt_moduli.len()
    }

    /// Builds an `NttContext` for `ring_dim` over the CRT moduli.
    pub fn ntt_context(&self) -> NttContext {
        NttContext::with_moduli(self.ring_dim, self.moduli())
    }

    /// Checks that the parameter set is internally consistent.
    ///
    /// # Errors
    /// Returns a static description of the first violated condition:
    /// - `ring_dim` is not a power of two, or `2 * ring_dim` overflows `u64`;
    /// - `crt_moduli` is empty or has more than two entries;
    /// - a modulus is not `≡ 1 (mod 2·ring_dim)` (required for the NTT), or is exactly 1;
    /// - the product of the moduli overflows `u64`, or differs from `q`;
    /// - the two moduli are not coprime;
    /// - `p` is zero, or `q < p`;
    /// - `gadget_base < 2`.
    pub fn validate(&self) -> Result<(), &'static str> {
        if !self.ring_dim.is_power_of_two() {
            return Err("ring_dim must be a power of two");
        }
        if self.crt_moduli.is_empty() {
            return Err("crt_moduli must be non-empty");
        }
        if self.crt_moduli.len() > 2 {
            return Err("crt_moduli length > 2 is not supported");
        }
        let two_n = (self.ring_dim as u64)
            .checked_mul(2)
            .ok_or("ring_dim is too large")?;
        for &m in &self.crt_moduli {
            if m % two_n != 1 {
                return Err("CRT moduli must be ≡ 1 (mod 2d) for NTT");
            }
            // 1 satisfies the congruence above but is not a usable modulus.
            if m == 1 {
                return Err("CRT moduli must be greater than 1");
            }
        }
        // The moduli are user-supplied: a plain `product()` would panic (debug) or wrap (release).
        let crt_product = self
            .crt_moduli
            .iter()
            .try_fold(1u64, |acc, &m| acc.checked_mul(m))
            .ok_or("product of CRT moduli overflows u64")?;
        if self.q != crt_product {
            return Err("q must equal product of CRT moduli");
        }
        if let &[a, b] = self.crt_moduli.as_slice() {
            if gcd_u64(a, b) != 1 {
                return Err("CRT moduli must be coprime");
            }
        }
        // `delta()` divides by `p`.
        if self.p == 0 {
            return Err("p must be non-zero");
        }
        if self.q < self.p {
            return Err("q must be >= p");
        }
        // A base of 0 or 1 cannot represent anything as digits.
        if self.gadget_base < 2 {
            return Err("gadget_base must be >= 2");
        }
        Ok(())
    }
}
/// Greatest common divisor via Euclid's algorithm, written recursively.
///
/// `gcd_u64(x, 0) == x`, so `gcd_u64(0, 0) == 0`. The recursion depth is bounded by the
/// Euclid step count for `u64`, which is under 100.
fn gcd_u64(a: u64, b: u64) -> u64 {
    match b {
        0 => a,
        _ => gcd_u64(b, a % b),
    }
}
impl Default for InspireParams {
fn default() -> Self {
Self::secure_128_d2048()
}
}
/// Layout of a database split into fixed-size shards of fixed-size entries.
///
/// Field order is part of the serialized and `Debug` form — do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ShardConfig {
/// Capacity of one shard in bytes.
pub shard_size_bytes: u64,
/// Size of one entry in bytes; `entries_per_shard()` divides by it.
pub entry_size_bytes: usize,
/// Total number of entries across all shards.
pub total_entries: u64,
}
impl ShardConfig {
    /// Layout used for Ethereum state: 1 GiB (`1 << 30` byte) shards of 32-byte entries.
    pub fn ethereum_state(total_entries: u64) -> Self {
        Self {
            shard_size_bytes: 1 << 30,
            entry_size_bytes: 32,
            total_entries,
        }
    }

    /// Number of whole entries that fit in one shard (`shard_size_bytes / entry_size_bytes`,
    /// rounded down). This is `0` if an entry is larger than a shard.
    ///
    /// # Panics
    /// Panics if `entry_size_bytes` is zero.
    pub fn entries_per_shard(&self) -> u64 {
        self.shard_size_bytes / self.entry_size_bytes as u64
    }

    /// [`entries_per_shard`](Self::entries_per_shard), checked to be non-zero before it is used as a divisor.
    fn nonzero_entries_per_shard(&self) -> u64 {
        let per_shard = self.entries_per_shard();
        assert!(
            per_shard != 0,
            "entry_size_bytes must not exceed shard_size_bytes"
        );
        per_shard
    }

    /// Number of shards needed to hold `total_entries`; the last shard may be partially filled.
    ///
    /// # Panics
    /// Panics if `entry_size_bytes` is zero or larger than `shard_size_bytes`.
    pub fn num_shards(&self) -> u64 {
        self.total_entries.div_ceil(self.nonzero_entries_per_shard())
    }

    /// Splits a global entry index into `(shard_id, local_idx)`.
    ///
    /// # Panics
    /// Panics in either of these cases:
    /// - `entry_size_bytes` is zero or larger than `shard_size_bytes`;
    /// - the shard id does not fit in a `u32`. The id is never truncated, because a truncated
    ///   id would name the wrong shard.
    pub fn index_to_shard(&self, global_idx: u64) -> (u32, u64) {
        let per_shard = self.nonzero_entries_per_shard();
        let shard = global_idx / per_shard;
        // An unchecked `as u32` would silently wrap and route the index to the wrong shard.
        assert!(
            shard <= u64::from(u32::MAX),
            "shard id for global index does not fit in u32"
        );
        (shard as u32, global_idx % per_shard)
    }

    /// Inverse of [`index_to_shard`](Self::index_to_shard): `shard_id * entries_per_shard() + local_idx`.
    ///
    /// # Panics
    /// Panics if `entry_size_bytes` is zero, or if the result overflows `u64`. Release builds
    /// would otherwise wrap silently.
    pub fn shard_to_index(&self, shard_id: u32, local_idx: u64) -> u64 {
        u64::from(shard_id)
            .checked_mul(self.entries_per_shard())
            .and_then(|shard_start| shard_start.checked_add(local_idx))
            .expect("global index overflows u64")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_params_valid() {
        let params = InspireParams::default();
        assert!(params.validate().is_ok());
        assert_eq!(params, InspireParams::secure_128_d2048());
    }

    #[test]
    fn test_d4096_params_valid() {
        let params = InspireParams::secure_128_d4096();
        assert_eq!(params.ring_dim, 4096);
        assert!(params.validate().is_ok());
    }

    #[test]
    fn test_preset_derived_values() {
        let params = InspireParams::secure_128_d2048();
        assert_eq!(params.q, DEFAULT_CRT_MODULI[0] * DEFAULT_CRT_MODULI[1]);
        assert_eq!(params.moduli(), &DEFAULT_CRT_MODULI[..]);
        assert_eq!(params.crt_count(), 2);
        // 2^40 < q < 2^60, so three base-2^20 digits are needed.
        assert_eq!(params.gadget_len, 3);
    }

    #[test]
    fn test_delta_calculation() {
        let params = InspireParams::secure_128_d2048();
        let delta = params.delta();
        assert_eq!(delta, params.q / params.p);
        assert!(delta > (1 << 39));
    }

    #[test]
    fn test_validate_rejects_bad_params() {
        let base = InspireParams::secure_128_d2048();

        let mut params = base.clone();
        params.ring_dim = 3000;
        assert_eq!(params.validate(), Err("ring_dim must be a power of two"));

        // Both default moduli are only ≡ 1 (mod 2^16), so 2d = 2^17 fails the NTT condition.
        let mut params = base.clone();
        params.ring_dim = 1 << 16;
        assert_eq!(
            params.validate(),
            Err("CRT moduli must be ≡ 1 (mod 2d) for NTT")
        );

        let mut params = base.clone();
        params.q += 1;
        assert_eq!(params.validate(), Err("q must equal product of CRT moduli"));

        let mut params = base;
        params.crt_moduli = vec![DEFAULT_CRT_MODULI[0]; 2];
        params.q = DEFAULT_CRT_MODULI[0] * DEFAULT_CRT_MODULI[0];
        assert_eq!(params.validate(), Err("CRT moduli must be coprime"));
    }

    #[test]
    fn test_shard_config() {
        let config = ShardConfig::ethereum_state(2_417_514_276);
        assert_eq!(config.entries_per_shard(), 1 << 25);
        // 72 full shards hold 2_415_919_104 entries; the remainder needs a 73rd.
        assert_eq!(config.num_shards(), 73);
    }

    #[test]
    fn test_index_conversion() {
        let config = ShardConfig::ethereum_state(2_417_514_276);
        assert_eq!(config.index_to_shard(100_000_000), (2, 32_891_136));
        for global_idx in [0, (1 << 25) - 1, 1 << 25, 100_000_000, 2_417_514_275] {
            let (shard_id, local_idx) = config.index_to_shard(global_idx);
            assert_eq!(config.shard_to_index(shard_id, local_idx), global_idx);
        }
    }
}