// arcium-primitives 0.4.2
//
// Arcium primitives
// Documentation
use std::num::Wrapping;

use crate::{
    random::Seed,
    transcripts::{AutoTranscript, FolkloreTranscript, Transcript},
    types::{identifiers::ProtocolInfo, SessionId},
};

/// A transcript wrapper that tags messages with a running position counter.
///
/// Every message (or batch of messages) appended through
/// [`Transcript::append_with`] or [`Transcript::append_many_with`] is forwarded
/// to the inner transcript `T` with the current counter value placed in front
/// of it, so that each append can be identified by its position in the
/// transcript.
///
/// The counter is initialised to 0 and is incremented *before* it is read, so
/// the first append carries the value 1. It is encoded as 8 little-endian
/// bytes.
#[derive(Debug, Clone)]
pub struct ProtocolTranscript<T = FolkloreTranscript> {
    /// Inner transcript that all appends and extractions are forwarded to.
    transcript: T,
    /// Number of labels handed out so far; wraps around on `u64` overflow.
    counter: Wrapping<u64>,
    /// Protocol description passed to `new`; its `tag()` is the label used by
    /// the [`AutoTranscript`] append methods.
    protocol_info: &'static ProtocolInfo,
}

impl<T: Transcript> Transcript for ProtocolTranscript<T> {
    type Rng = T::Rng;

    /// Builds the inner transcript from `protocol_info` and `session_id` and
    /// starts the position counter at zero.
    fn new(protocol_info: &'static ProtocolInfo, session_id: &SessionId) -> Self {
        let transcript = T::new(protocol_info, session_id);
        Self {
            transcript,
            counter: Wrapping(0),
            protocol_info,
        }
    }

    /// Appends `msg` under `label`, preceded by the next counter value.
    ///
    /// The counter bytes and the message are handed to the inner transcript as
    /// a two-element batch.
    fn append_with<Msg: AsRef<[u8]>>(&mut self, label: &'static [u8], msg: &Msg) {
        let position = self.next_label();
        let parts: [&[u8]; 2] = [&position[..], msg.as_ref()];
        self.transcript.append_many_with(label, &parts);
    }

    /// Appends all of `messages` under `label` as one batch, preceded by a
    /// single counter value (the counter advances once per call, not once per
    /// message).
    fn append_many_with<Msg: AsRef<[u8]>>(&mut self, label: &'static [u8], messages: &[Msg]) {
        let position = self.next_label();

        // Counter first, then the caller's messages in their original order.
        let mut parts: Vec<&[u8]> = Vec::with_capacity(messages.len() + 1);
        parts.push(&position[..]);
        for message in messages {
            parts.push(message.as_ref());
        }

        self.transcript.append_many_with(label, &parts);
    }

    /// Forwards to the inner transcript's `extract`; the counter is untouched.
    fn extract(&mut self, label: &'static [u8]) -> Seed {
        let inner = &mut self.transcript;
        inner.extract(label)
    }

    /// Forwards to the inner transcript's `extract_rng`; the counter is
    /// untouched.
    fn extract_rng(&mut self, label: &'static [u8]) -> Self::Rng {
        let inner = &mut self.transcript;
        inner.extract_rng(label)
    }
}

impl<T: Transcript> AutoTranscript for ProtocolTranscript<T> {
    /// The automatic label is the position counter encoded as 8 bytes.
    type Label = [u8; 8];

    /// Returns the current counter value as little-endian bytes without
    /// advancing it.
    #[inline]
    fn get_current_label(&self) -> Self::Label {
        self.counter.0.to_le_bytes()
    }

    /// Advances the counter by one (wrapping on `u64` overflow) and returns
    /// the new value as little-endian bytes.
    #[inline]
    fn next_label(&mut self) -> Self::Label {
        self.counter += Wrapping(1);
        self.get_current_label()
    }

    /// Appends `msg` using the protocol tag as the label.
    ///
    /// The message is routed through this wrapper's own
    /// [`Transcript::append_with`], so it is prefixed with the next counter
    /// value exactly like an explicitly labelled append.
    #[inline]
    fn append<Msg: AsRef<[u8]>>(&mut self, msg: &Msg) {
        let tag = self.protocol_info.tag();
        // Deliberately NOT `self.transcript.append_with(..)`: calling the inner
        // transcript directly would skip the counter and break the
        // "every append is position-tagged" invariant of this type.
        <Self as Transcript>::append_with(self, tag, msg);
    }

    /// Appends all of `msg` as one batch using the protocol tag as the label.
    ///
    /// The batch is routed through this wrapper's own
    /// [`Transcript::append_many_with`], so it is prefixed with a single
    /// counter value (the counter advances once per call).
    #[inline]
    fn append_many<Msg: AsRef<[u8]>>(&mut self, msg: &[Msg]) {
        let tag = self.protocol_info.tag();
        // See `append`: go through the counting wrapper, not the inner
        // transcript.
        <Self as Transcript>::append_many_with(self, tag, msg);
    }
}