// lilium_transcript/transcript.rs

1use crate::{messages::PointRound, Error, Message};
2use ark_ff::Field;
3use sponge::sponge::Duplex;
4use std::{
5    any::{Any, TypeId},
6    marker::PhantomData,
7    vec::IntoIter,
8};
9
10pub struct Transcript<F: Field, S: Duplex<F>> {
11    sponge: S,
12    rounds: IntoIter<(TypeId, usize)>,
13    vars: usize,
14    _f: PhantomData<F>,
15}
16
17impl<F: Field, S: Duplex<F>> Transcript<F, S> {
18    /// Prints the current state of the sponge, for debugging.
19    pub fn print_state(&self) {
20        self.sponge.print();
21    }
22    pub fn guard<P>(&mut self, proof: P) -> TranscriptGuard<'_, F, S, P> {
23        TranscriptGuard::new(self, proof)
24    }
25
26    pub(crate) fn new(sponge: S, rounds: IntoIter<(TypeId, usize)>, vars: usize) -> Self {
27        Self {
28            sponge,
29            rounds,
30            vars,
31            _f: PhantomData,
32        }
33    }
34
35    pub fn send_message<T, const N: usize>(&mut self, message: &T) -> Result<[F; N], Error>
36    where
37        T: Any + Message<F>,
38    {
39        let id = message.type_id();
40        let elems = message.to_field_elements();
41        for elem in elems.into_iter() {
42            self.sponge.absorb(elem).map_err(Error::SpongeError)?;
43        }
44        let round = self.rounds.next().ok_or(Error::TranscriptFinished)?;
45        if round != (id, N) {
46            return Err(Error::UnexpectedMessage);
47        }
48        let challenges = [(); N].map(|_| self.sponge.squeeze().map_err(Error::SpongeError));
49        let challenges: Result<Vec<F>, Error> = challenges.into_iter().collect();
50        let challenges: [F; N] = challenges?.try_into().unwrap();
51        Ok(challenges)
52    }
53    /// generates a multivariate point
54    pub fn point(&mut self) -> Result<Vec<F>, Error> {
55        let round = self.rounds.next().ok_or(Error::TranscriptFinished)?;
56        let id = TypeId::of::<PointRound>();
57        if round != (id, self.vars) {
58            return Err(Error::UnexpectedMessage);
59        }
60        let challenges = (0..self.vars).map(|_| self.sponge.squeeze().map_err(Error::SpongeError));
61        challenges.into_iter().collect()
62    }
63    pub fn finish(self) -> Result<(), Error> {
64        self.sponge.finish().map_err(Error::SpongeError)
65    }
66    pub fn finish_unchecked(self) {
67        if let Err(err) = self.finish() {
68            println!("{:#?}", err);
69            panic!();
70        }
71    }
72}
73
74/// Wraps transcript and proof, ensuring no message circumvents
75/// the transcript.
76pub struct TranscriptGuard<'a, F: Field, S: Duplex<F>, P> {
77    transcript: &'a mut Transcript<F, S>,
78    proof: P,
79}
80
/// Wrapper around a value that has not been passed through the transcript yet.
/// Because the inner field is private, outside code has to go through an explicit
/// unwrapping step instead of using the value by accident.
pub struct MessageGuard<I>(I);

impl<I> MessageGuard<I> {
    /// Wraps `inner` without interacting with any transcript.
    pub fn new(inner: I) -> Self {
        Self(inner)
    }
}

impl<I> From<I> for MessageGuard<I> {
    fn from(value: I) -> Self {
        Self::new(value)
    }
}

impl<I> MessageGuard<Vec<I>> {
    /// Splits a guarded vector into a vector of individually guarded elements,
    /// keeping their order.
    pub fn transpose(self) -> Vec<MessageGuard<I>> {
        let MessageGuard(items) = self;
        items.into_iter().map(MessageGuard::new).collect()
    }
}

impl<I, const N: usize> MessageGuard<[I; N]> {
    /// Splits a guarded array into an array of individually guarded elements,
    /// keeping their order.
    pub fn transpose(self) -> [MessageGuard<I>; N] {
        let MessageGuard(items) = self;
        items.map(MessageGuard::new)
    }
}
107
108impl<'a, F: Field, S: Duplex<F>, P> TranscriptGuard<'a, F, S, P> {
109    /// Prints the current state of the sponge, for debugging.
110    pub fn print_state(&self) {
111        self.transcript.print_state();
112    }
113
114    pub fn new(transcript: &'a mut Transcript<F, S>, proof: P) -> Self {
115        Self { transcript, proof }
116    }
117
118    pub fn new_guard<P2>(
119        &mut self,
120        proof: impl Into<MessageGuard<P2>>,
121    ) -> TranscriptGuard<'_, F, S, P2> {
122        let proof: MessageGuard<P2> = proof.into();
123        let proof = proof.0;
124        TranscriptGuard {
125            transcript: self.transcript,
126            proof,
127        }
128    }
129
130    /// Allows to extract messages from the proof, absorbing them in the
131    /// transcript automatically, also returning the corresponding challenges.
132    pub fn receive_message<M, Q, const N: usize>(&mut self, query: Q) -> Result<(M, [F; N]), Error>
133    where
134        M: Message<F> + 'static,
135        Q: Fn(&P) -> M,
136    {
137        let message = query(&self.proof);
138        let challenges: [F; N] = self.transcript.send_message(&message)?;
139        Ok((message, challenges))
140    }
141    /// similar to receive_message, doesn't interact with the sponge in any way and returns
142    /// a guarded value to be unwrapped later.
143    pub fn receive_message_delayed<M, Q>(&mut self, query: Q) -> MessageGuard<M>
144    where
145        M: 'static,
146        Q: FnOnce(&P) -> M,
147    {
148        let message = query(&self.proof);
149        MessageGuard(message)
150    }
151    /// unwraps the instance while absorbing it and also returning
152    /// challenges.
153    pub fn unwrap_guard<I: Message<F> + 'static, const N: usize>(
154        &mut self,
155        instance: MessageGuard<I>,
156    ) -> Result<(I, [F; N]), Error> {
157        let MessageGuard(instance) = instance;
158        let challenges = self.transcript.send_message(&instance)?;
159        Ok((instance, challenges))
160    }
161    /// Unwraps the instance while ignoring the transcript, caller must ensure
162    /// that not including the instance is acceptable.
163    /// Will still ultimately fail if the instance was expected in the pattern.
164    pub fn unwrap_instance_unsafe<I>(&mut self, instance: MessageGuard<I>) -> I {
165        instance.0
166    }
167    /// generates a multivariate point
168    pub fn point(&mut self) -> Result<Vec<F>, Error> {
169        self.transcript.point()
170    }
171}