use glass_pumpkin::safe_prime;
use num_bigint::{BigUint, RandBigInt};
use rand::thread_rng;
use sha2::{digest::Digest, Sha256};
use std::{cell::RefCell, fmt};
use arithmetic_eval::{
arith::{Arithmetic, ArithmeticExt, ModularArithmetic},
error::{ArithmeticError, AuxErrorInfo},
fns, Assertions, CallContext, ErrorKind, EvalResult, ExecutableModule, NativeFn, Number,
Prelude, SpannedValue, Value,
};
use arithmetic_parser::{
grammars::{Features, NumGrammar, NumLiteral, Parse, Untyped},
InputSpan, NomResult,
};
/// Primitive value manipulated by the group scripts.
///
/// Numeric literals in a script always become [`GroupLiteral::Scalar`]s; group elements
/// are only introduced through imports (e.g. the generator) and arithmetic on them.
#[derive(Clone, Debug)]
enum GroupLiteral {
    /// Integer scalar, e.g. an exponent applied to a group element.
    Scalar(BigUint),
    /// Group element, stored as its integer representative.
    GroupElement(BigUint),
}
impl fmt::Display for GroupLiteral {
fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Scalar(sc) => fmt::Display::fmt(sc, formatter),
Self::GroupElement(ge) => write!(formatter, "Ge({})", ge),
}
}
}
/// Scripts can only spell scalars: every numeric literal is parsed as a `BigUint`
/// and wrapped into [`GroupLiteral::Scalar`] (without any modular reduction).
impl NumLiteral for GroupLiteral {
    fn parse(input: InputSpan<'_>) -> NomResult<'_, Self> {
        let (remaining, value) = <BigUint as NumLiteral>::parse(input)?;
        Ok((remaining, Self::Scalar(value)))
    }
}

/// Marks `GroupLiteral` as an interpreter number type; no methods need to be provided.
impl Number for GroupLiteral {}
/// Arithmetic for scripts working in the prime-order subgroup of `Z_p^*`, where `p` is a
/// safe prime (`p = 2q + 1` with `q` prime).
///
/// Group elements are integers modulo `p`; scalars (exponents) are integers modulo `q`.
#[derive(Debug)]
struct CyclicGroupArithmetic {
    /// Arithmetic modulo the safe prime `p`, used for group elements.
    for_group: ModularArithmetic<BigUint>,
    /// Generator of the order-`q` subgroup (a quadratic residue modulo `p`, not equal to 1).
    generator: BigUint,
    /// Arithmetic modulo the subgroup order `q = (p - 1) / 2`, used for scalars.
    for_scalars: ModularArithmetic<BigUint>,
}

impl CyclicGroupArithmetic {
    /// Creates arithmetic over a freshly generated `bits`-bit safe prime together with a
    /// random generator of its prime-order subgroup.
    ///
    /// # Panics
    ///
    /// Panics if `glass_pumpkin` cannot generate a safe prime of the requested bit length.
    fn new(bits: usize) -> Self {
        let safe_prime = safe_prime::new(bits)
            .expect("cannot generate safe prime of the requested bit length");
        // `p` is odd, so `p >> 1 == (p - 1) / 2 == q`.
        let prime_subgroup_order = &safe_prime >> 1;
        let two = BigUint::from(2_u32);
        // Squaring maps `Z_p^*` onto the quadratic residues, which form the subgroup of prime
        // order `q`; any residue other than 1 therefore generates it. The base is sampled from
        // `[2, p - 1)` (the range is half-open), which excludes `±1` and so guarantees the
        // square is not 1. Sampling up to `p` would allow `p - 1`, whose square is 1.
        let base_upper_bound = &safe_prime - 1_u32;
        let generator = thread_rng()
            .gen_biguint_range(&two, &base_upper_bound)
            .modpow(&two, &safe_prime);

        Self {
            for_group: ModularArithmetic::new(safe_prime),
            generator,
            for_scalars: ModularArithmetic::new(prime_subgroup_order),
        }
    }

    /// Returns a closure producing random scalars drawn from `[2, q)`.
    fn rand_scalar(&self) -> impl Fn() -> GroupLiteral {
        // The closure is `Fn`, so the RNG needs interior mutability.
        let rng = RefCell::new(thread_rng());
        let two = BigUint::from(2_u32);
        let prime_subgroup_order = self.for_scalars.modulus().to_owned();
        move || {
            GroupLiteral::Scalar(
                rng.borrow_mut()
                    .gen_biguint_range(&two, &prime_subgroup_order),
            )
        }
    }

    /// Creates the `hash_to_scalar` native function.
    ///
    /// Hashed values are padded to the byte length of `p`, which is large enough for every
    /// group element. The resulting digest is reduced modulo `q`.
    fn hash_to_scalar(&self) -> HashToScalar {
        let max_bit_len = self.for_group.modulus().bits();
        // Round the bit length up to whole bytes.
        let max_byte_len = ((max_bit_len + 7) / 8) as usize;
        HashToScalar {
            modulus: self.for_scalars.modulus().to_owned(),
            max_byte_len,
        }
    }

    /// Returns a closure converting a value into a scalar.
    ///
    /// Scalars pass through unchanged; group elements are reduced modulo `q`.
    fn to_scalar(&self) -> impl Fn(GroupLiteral) -> GroupLiteral {
        let prime_subgroup_order = self.for_scalars.modulus().to_owned();
        move |value| match value {
            GroupLiteral::Scalar(sc) => GroupLiteral::Scalar(sc),
            GroupLiteral::GroupElement(ge) => GroupLiteral::Scalar(ge % &prime_subgroup_order),
        }
    }

    /// Binds the group-dependent imports of `module`:
    /// - `GEN`: the subgroup generator;
    /// - `ORDER`: the subgroup order `q`, as a scalar;
    /// - `rand_scalar` and `hash_to_scalar`.
    fn set_imports(&self, module: &mut ExecutableModule<'_, GroupLiteral>) {
        let generator = GroupLiteral::GroupElement(self.generator.clone());
        // `ORDER` must be the order `q` of the subgroup generated by `GEN` (the scalar
        // modulus), not the group modulus `p`: `GEN ^ q` is the identity, `GEN ^ p == GEN`.
        let prime_subgroup_order = GroupLiteral::Scalar(self.for_scalars.modulus().to_owned());
        module
            .set_import("GEN", Value::Prim(generator))
            .set_import("ORDER", Value::Prim(prime_subgroup_order))
            .set_import("rand_scalar", Value::wrapped_fn(self.rand_scalar()))
            .set_import("hash_to_scalar", Value::native_fn(self.hash_to_scalar()));
    }
}
/// Native function that hashes its arguments to a scalar.
///
/// Each argument is encoded as a one-byte type tag followed by its little-endian
/// representation, zero-padded to `max_byte_len` bytes. The concatenation of all encodings
/// is hashed with SHA-256, and the digest (read as a little-endian integer) is reduced
/// modulo `modulus`.
#[derive(Debug)]
struct HashToScalar {
    /// Modulus the digest is reduced by.
    modulus: BigUint,
    /// Fixed byte width of every encoded argument.
    max_byte_len: usize,
}

impl HashToScalar {
    /// Type tag prepended to encoded scalars.
    const SCALAR_TAG: u8 = 0;
    /// Type tag prepended to encoded group elements.
    const GROUP_ELEMENT_TAG: u8 = 1;

    /// Encodes `value` as exactly `max_byte_len` little-endian bytes.
    ///
    /// Returns `None` if `value` needs more than `max_byte_len` bytes.
    fn encode(&self, value: &BigUint) -> Option<Vec<u8>> {
        let mut bytes = value.to_bytes_le();
        if bytes.len() > self.max_byte_len {
            return None;
        }
        bytes.resize(self.max_byte_len, 0);
        Some(bytes)
    }
}

impl NativeFn<GroupLiteral> for HashToScalar {
    /// Hashes all `args` (scalars or group elements, in order) into a single scalar.
    ///
    /// # Errors
    ///
    /// Returns an invalid-argument error if an argument is not a primitive group value,
    /// or if it is too large to be encoded into `max_byte_len` bytes.
    fn evaluate<'a>(
        &self,
        args: Vec<SpannedValue<'a, GroupLiteral>>,
        context: &mut CallContext<'_, 'a, GroupLiteral>,
    ) -> EvalResult<'a, GroupLiteral> {
        let mut hasher = Sha256::default();
        for arg in &args {
            let (tag, value) = match &arg.extra {
                Value::Prim(GroupLiteral::Scalar(sc)) => (Self::SCALAR_TAG, sc),
                Value::Prim(GroupLiteral::GroupElement(ge)) => (Self::GROUP_ELEMENT_TAG, ge),
                _ => {
                    let err = ErrorKind::native("Cannot hash value");
                    return Err(context
                        .call_site_error(err)
                        .with_span(arg, AuxErrorInfo::InvalidArg));
                }
            };
            let bytes = match self.encode(value) {
                Some(bytes) => bytes,
                None => {
                    // Scalars parsed from literals are not reduced, so a script can pass an
                    // arbitrarily large number here. Report it instead of panicking.
                    let err = ErrorKind::native("Value is too large to hash");
                    return Err(context
                        .call_site_error(err)
                        .with_span(arg, AuxErrorInfo::InvalidArg));
                }
            };
            hasher.update(&[tag]);
            hasher.update(&bytes);
        }

        let digest = hasher.finalize();
        let hash_scalar = BigUint::from_bytes_le(digest.as_slice()) % &self.modulus;
        Ok(Value::Prim(GroupLiteral::Scalar(hash_scalar)))
    }
}
impl Arithmetic<GroupLiteral> for CyclicGroupArithmetic {
fn add(&self, x: GroupLiteral, y: GroupLiteral) -> Result<GroupLiteral, ArithmeticError> {
match (x, y) {
(GroupLiteral::Scalar(x), GroupLiteral::Scalar(y)) => {
self.for_scalars.add(x, y).map(GroupLiteral::Scalar)
}
_ => Err(ArithmeticError::invalid_op("only scalars may be added")),
}
}
fn sub(&self, x: GroupLiteral, y: GroupLiteral) -> Result<GroupLiteral, ArithmeticError> {
match (x, y) {
(GroupLiteral::Scalar(x), GroupLiteral::Scalar(y)) => {
self.for_scalars.sub(x, y).map(GroupLiteral::Scalar)
}
_ => Err(ArithmeticError::invalid_op(
"only scalars may be subtracted",
)),
}
}
fn mul(&self, x: GroupLiteral, y: GroupLiteral) -> Result<GroupLiteral, ArithmeticError> {
match (x, y) {
(GroupLiteral::Scalar(x), GroupLiteral::Scalar(y)) => {
self.for_scalars.mul(x, y).map(GroupLiteral::Scalar)
}
(GroupLiteral::GroupElement(x), GroupLiteral::GroupElement(y)) => {
self.for_group.mul(x, y).map(GroupLiteral::GroupElement)
}
_ => Err(ArithmeticError::invalid_op(
"multiplication operands must have same type",
)),
}
}
fn div(&self, x: GroupLiteral, y: GroupLiteral) -> Result<GroupLiteral, ArithmeticError> {
match (x, y) {
(GroupLiteral::Scalar(x), GroupLiteral::Scalar(y)) => {
self.for_scalars.div(x, y).map(GroupLiteral::Scalar)
}
(GroupLiteral::GroupElement(x), GroupLiteral::GroupElement(y)) => {
self.for_group.div(x, y).map(GroupLiteral::GroupElement)
}
_ => Err(ArithmeticError::invalid_op(
"division operands must have same type",
)),
}
}
fn pow(&self, x: GroupLiteral, y: GroupLiteral) -> Result<GroupLiteral, ArithmeticError> {
match (x, y) {
(GroupLiteral::Scalar(x), GroupLiteral::Scalar(y)) => {
self.for_scalars.pow(x, y).map(GroupLiteral::Scalar)
}
(GroupLiteral::GroupElement(x), GroupLiteral::Scalar(y)) => {
self.for_group.pow(x, y).map(GroupLiteral::GroupElement)
}
_ => Err(ArithmeticError::invalid_op("exponent must be a scalar")),
}
}
fn neg(&self, x: GroupLiteral) -> Result<GroupLiteral, ArithmeticError> {
if let GroupLiteral::Scalar(x) = x {
self.for_scalars.neg(x).map(GroupLiteral::Scalar)
} else {
Err(ArithmeticError::invalid_op("only scalars can be negated"))
}
}
fn eq(&self, x: &GroupLiteral, y: &GroupLiteral) -> bool {
match (x, y) {
(GroupLiteral::Scalar(x), GroupLiteral::Scalar(y)) => self.for_scalars.eq(x, y),
(GroupLiteral::GroupElement(x), GroupLiteral::GroupElement(y)) => {
self.for_group.eq(x, y)
}
_ => false,
}
}
}
// Script sources are embedded at compile time; `include_str!` resolves the paths relative
// to this source file.
/// Source of the script executed as the `schnorr` module in `main`.
const SCHNORR_SIGNATURES: &str = include_str!("schnorr.script");
/// Source of the script executed as the `dsa` module in `main`.
const DSA_SIGNATURES: &str = include_str!("dsa.script");
/// Grammar used to parse the group scripts.
///
/// Numeric literals are parsed as [`GroupLiteral`]s. Type annotations and order
/// comparisons are turned off, and every other parser feature stays enabled.
#[derive(Clone, Copy, Debug)]
struct GroupGrammar;

impl<'a> Parse<'a> for GroupGrammar {
    type Base = Untyped<NumGrammar<GroupLiteral>>;

    // Start from the full feature set and remove what the group scripts cannot use.
    const FEATURES: Features = Features::all()
        .without(Features::TYPE_ANNOTATIONS)
        .without(Features::ORDER_COMPARISONS);
}
fn main() -> anyhow::Result<()> {
const BIT_LENGTH: usize = 256;
let schnorr_signatures = GroupGrammar::parse_statements(SCHNORR_SIGNATURES)?;
let mut schnorr_signatures = ExecutableModule::builder("schnorr", &schnorr_signatures)?
.with_imports_from(&Prelude)
.with_imports_from(&Assertions)
.with_import("dbg", Value::native_fn(fns::Dbg))
.set_imports(|_| Value::void());
let dsa_signatures = GroupGrammar::parse_statements(DSA_SIGNATURES)?;
let mut dsa_signatures = ExecutableModule::builder("dsa", &dsa_signatures)?
.with_imports_from(&Prelude)
.with_imports_from(&Assertions)
.with_import("dbg", Value::native_fn(fns::Dbg))
.set_imports(|_| Value::void());
for i in 0..5 {
println!("\nRunning sample #{}", i);
let arithmetic = CyclicGroupArithmetic::new(BIT_LENGTH);
arithmetic.set_imports(&mut schnorr_signatures);
arithmetic.set_imports(&mut dsa_signatures);
dsa_signatures.set_import("to_scalar", Value::wrapped_fn(arithmetic.to_scalar()));
let arithmetic = arithmetic.without_comparisons();
schnorr_signatures.with_arithmetic(&arithmetic).run()?;
dsa_signatures.with_arithmetic(&arithmetic).run()?;
}
Ok(())
}