use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use crate::crypto::VerificationKey;
use crate::error::Error;
use crate::traits::PairingEngine;
/// Opaque handle identifying a verification key stored in a [`VkRegistry`].
///
/// Ids come from a per-registry counter that starts at zero, so an id is only
/// meaningful for the registry that produced it.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct VkId(u64);

impl VkId {
    /// Wraps a raw numeric id.
    ///
    /// Crate-internal (and the field is private) so code outside the crate cannot
    /// fabricate ids.
    pub(crate) fn new(id: u64) -> Self {
        VkId(id)
    }

    /// Returns the raw numeric value of this id.
    pub fn as_u64(&self) -> u64 {
        let VkId(raw) = *self;
        raw
    }
}
/// A verification key paired with the name it was registered under.
///
/// The registry keeps these behind an `Arc`, so lookups hand out cheap shared
/// clones instead of copying the key.
#[derive(Clone, Debug)]
pub struct RegisteredVk<E: PairingEngine> {
    /// Name supplied to [`VkRegistry::register`].
    pub name: String,
    /// The stored verification key.
    pub vk: VerificationKey<E>,
}
/// Store of verification keys, addressable by [`VkId`] or by name.
///
/// All state sits behind a single `RwLock`: lookups take the shared lock and may
/// run concurrently, while registrations take the exclusive lock.
pub struct VkRegistry<E: PairingEngine> {
    inner: RwLock<VkRegistryInner<E>>,
}

/// Lock-protected state of a [`VkRegistry`].
struct VkRegistryInner<E: PairingEngine> {
    /// Registered keys, indexed by the id handed out at registration.
    vks: HashMap<VkId, Arc<RegisteredVk<E>>>,
    /// Name -> id index. Registering an existing name overwrites its entry, so a
    /// name always resolves to the most recent id registered under it.
    name_to_id: HashMap<String, VkId>,
    /// The id the next registration will receive; only ever incremented.
    next_id: u64,
}
impl<E: PairingEngine> VkRegistry<E> {
pub fn new() -> Self {
Self {
inner: RwLock::new(VkRegistryInner {
vks: HashMap::new(),
name_to_id: HashMap::new(),
next_id: 0,
}),
}
}
pub fn register(&self, name: impl Into<String>, vk: VerificationKey<E>) -> VkId {
let name = name.into();
let mut inner = self.inner.write().unwrap();
let id = VkId(inner.next_id);
inner.next_id += 1;
let registered = Arc::new(RegisteredVk {
name: name.clone(),
vk,
});
inner.vks.insert(id, registered);
inner.name_to_id.insert(name, id);
id
}
pub fn get(&self, id: VkId) -> Option<Arc<RegisteredVk<E>>> {
self.inner.read().ok()?.vks.get(&id).cloned()
}
pub fn get_by_name(&self, name: &str) -> Option<VkId> {
self.inner.read().ok()?.name_to_id.get(name).copied()
}
pub fn contains(&self, id: VkId) -> bool {
self.inner.read().map(|guard| guard.vks.contains_key(&id)).unwrap_or(false)
}
pub fn len(&self) -> usize {
self.inner.read().map(|guard| guard.vks.len()).unwrap_or(0)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn require(&self, id: VkId) -> Result<Arc<RegisteredVk<E>>, Error> {
self.get(id).ok_or(Error::UnknownVk(id))
}
}
impl<E: PairingEngine> Default for VkRegistry<E> {
fn default() -> Self {
Self::new()
}
}
impl<E: PairingEngine> std::fmt::Debug for VkRegistry<E> {
    /// Prints only the number of stored keys, not the keys themselves.
    ///
    /// A poisoned lock is recovered rather than unwrapped. `Debug` output is often
    /// produced while already handling a failure, and panicking here would mask it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner = self
            .inner
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        f.debug_struct("VkRegistry")
            .field("num_vks", &inner.vks.len())
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use group::{Curve, Group};
    use halo2curves::bn256::{G1, G2};
    use rand::rngs::OsRng;

    /// Engine over BN256 types whose pairing functions are stubs; the registry
    /// tests never call them.
    #[derive(Clone, Debug)]
    struct MockEngine;

    impl PairingEngine for MockEngine {
        type Fr = halo2curves::bn256::Fr;
        type G1Affine = halo2curves::bn256::G1Affine;
        type G1 = halo2curves::bn256::G1;
        type G2Affine = halo2curves::bn256::G2Affine;
        type Gt = halo2curves::bn256::Gt;

        fn pairing(_p: &Self::G1Affine, _q: &Self::G2Affine) -> Self::Gt {
            unimplemented!("mock")
        }

        fn multi_pairing<'a>(
            _pairs: impl IntoIterator<Item = (&'a Self::G1Affine, &'a Self::G2Affine)>,
        ) -> Self::Gt {
            unimplemented!("mock")
        }
    }

    /// Builds a verification key with random commitments and the given
    /// public-input count.
    fn mock_vk(num_public_inputs: usize) -> VerificationKey<MockEngine> {
        VerificationKey {
            num_public_inputs,
            domain_size: 1024,
            selector_commitments: vec![
                G1::random(OsRng).to_affine(),
                G1::random(OsRng).to_affine(),
            ],
            permutation_commitments: vec![G1::random(OsRng).to_affine()],
            x_g2: G2::random(OsRng).to_affine(),
            g2_generator: G2::generator().to_affine(),
        }
    }

    #[test]
    fn registry_assigns_unique_ids() {
        let registry = VkRegistry::<MockEngine>::new();
        let id1 = registry.register("circuit_a", mock_vk(5));
        let id2 = registry.register("circuit_b", mock_vk(3));
        assert_ne!(id1, id2);
    }

    #[test]
    fn registry_retrieves_by_id() {
        let registry = VkRegistry::<MockEngine>::new();
        let id = registry.register("test", mock_vk(5));
        let retrieved = registry.get(id).unwrap();
        assert_eq!(retrieved.name, "test");
        assert_eq!(retrieved.vk.num_public_inputs, 5);
    }

    #[test]
    fn registry_retrieves_by_name() {
        let registry = VkRegistry::<MockEngine>::new();
        let id = registry.register("transfer", mock_vk(2));
        let found_id = registry.get_by_name("transfer").unwrap();
        assert_eq!(id, found_id);
    }

    #[test]
    fn registry_returns_none_for_unknown() {
        let registry = VkRegistry::<MockEngine>::new();
        assert!(registry.get(VkId(999)).is_none());
        assert!(registry.get_by_name("unknown").is_none());
    }

    #[test]
    fn registry_require_returns_error_for_unknown() {
        let registry = VkRegistry::<MockEngine>::new();
        let result = registry.require(VkId(999));
        assert!(matches!(result, Err(Error::UnknownVk(_))));
    }

    #[test]
    fn registry_len_tracks_registrations() {
        let registry = VkRegistry::<MockEngine>::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        registry.register("a", mock_vk(1));
        assert!(!registry.is_empty());
        assert_eq!(registry.len(), 1);
        registry.register("b", mock_vk(1));
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn registry_reregistering_name_points_to_latest_id() {
        let registry = VkRegistry::<MockEngine>::new();
        let first = registry.register("dup", mock_vk(1));
        let second = registry.register("dup", mock_vk(2));
        assert_ne!(first, second);
        assert_eq!(registry.get_by_name("dup"), Some(second));
        // The earlier entry is no longer reachable by name but is still stored under its id.
        assert_eq!(registry.get(first).unwrap().vk.num_public_inputs, 1);
        assert_eq!(registry.len(), 2);
    }

    #[test]
    fn registry_is_thread_safe() {
        use std::thread;
        let registry = Arc::new(VkRegistry::<MockEngine>::new());
        let registry2 = Arc::clone(&registry);
        let handle = thread::spawn(move || registry2.register("from_thread", mock_vk(1)));
        let id1 = registry.register("from_main", mock_vk(1));
        let id2 = handle.join().unwrap();
        assert_ne!(id1, id2);
        assert_eq!(registry.len(), 2);
    }
}