solana_zk_token_sdk/encryption/
discrete_log.rs

1//! The discrete log implementation for the twisted ElGamal decryption.
2//!
3//! The implementation uses the baby-step giant-step method, which consists of a precomputation
4//! step and an online step. The precomputation step involves computing a hash table of a number
5//! of Ristretto points that is independent of a discrete log instance. The online phase computes
6//! the final discrete log solution using the discrete log instance and the pre-computed hash
7//! table. More details on the baby-step giant-step algorithm and the implementation can be found
8//! in the [spl documentation](https://spl.solana.com).
9//!
10//! The implementation is NOT intended to run in constant-time. There are some measures to prevent
11//! straightforward timing attacks. For instance, it does not short-circuit the search when a
//! solution is found. However, the use of hash tables, batching, and threads makes the
13//! implementation inherently not constant-time. This may theoretically allow an adversary to gain
14//! information on a discrete log solution depending on the execution time of the implementation.
15//!
16
17#![cfg(not(target_os = "solana"))]
18
19#[cfg(not(target_arch = "wasm32"))]
20use std::thread;
21use {
22    crate::RISTRETTO_POINT_LEN,
23    curve25519_dalek::{
24        constants::RISTRETTO_BASEPOINT_POINT as G,
25        ristretto::RistrettoPoint,
26        scalar::Scalar,
27        traits::{Identity, IsIdentity},
28    },
29    itertools::Itertools,
30    serde::{Deserialize, Serialize},
31    std::{collections::HashMap, num::NonZeroUsize},
32    thiserror::Error,
33};
34
/// 2^16: the size of each half of the 16/16 bit split of a 32-bit solution.
const TWO16: u64 = 1 << 16;
/// 2^17: the stride of the precomputed table, which stores `x_hi * 2^17 * generator` so that it
/// matches the doubled points produced by batch compression during the online search.
const TWO17: u64 = 1 << 17;

/// Maximum number of threads permitted for discrete log computation (2^16).
#[cfg(not(target_arch = "wasm32"))]
const MAX_THREAD: usize = 1 << 16;
41
42#[derive(Error, Clone, Debug, Eq, PartialEq)]
43pub enum DiscreteLogError {
44    #[error("discrete log number of threads not power-of-two")]
45    DiscreteLogThreads,
46    #[error("discrete log batch size too large")]
47    DiscreteLogBatchSize,
48}
49
50/// Type that captures a discrete log challenge.
51///
52/// The goal of discrete log is to find x such that x * generator = target.
53#[derive(Serialize, Deserialize, Copy, Clone, Debug, Eq, PartialEq)]
54pub struct DiscreteLog {
55    /// Generator point for discrete log
56    pub generator: RistrettoPoint,
57    /// Target point for discrete log
58    pub target: RistrettoPoint,
59    /// Number of threads used for discrete log computation
60    num_threads: Option<NonZeroUsize>,
61    /// Range bound for discrete log search derived from the max value to search for and
62    /// `num_threads`
63    range_bound: NonZeroUsize,
64    /// Ristretto point representing each step of the discrete log search
65    step_point: RistrettoPoint,
66    /// Ristretto point compression batch size
67    compression_batch_size: NonZeroUsize,
68}
69
70#[derive(Serialize, Deserialize, Default)]
71pub struct DecodePrecomputation(HashMap<[u8; RISTRETTO_POINT_LEN], u16>);
72
73/// Builds a HashMap of 2^16 elements
74#[allow(dead_code)]
75fn decode_u32_precomputation(generator: RistrettoPoint) -> DecodePrecomputation {
76    let mut hashmap = HashMap::new();
77
78    let two17_scalar = Scalar::from(TWO17);
79    let identity = RistrettoPoint::identity(); // 0 * G
80    let generator = two17_scalar * generator; // 2^17 * G
81
82    // iterator for 2^17*0G , 2^17*1G, 2^17*2G, ...
83    let ristretto_iter = RistrettoIterator::new((identity, 0), (generator, 1));
84    for (point, x_hi) in ristretto_iter.take(TWO16 as usize) {
85        let key = point.compress().to_bytes();
86        hashmap.insert(key, x_hi as u16);
87    }
88
89    DecodePrecomputation(hashmap)
90}
91
92/// Pre-computed HashMap needed for decryption. The HashMap is independent of (works for) any key.
93pub static DECODE_PRECOMPUTATION_FOR_G: std::sync::LazyLock<DecodePrecomputation> =
94    std::sync::LazyLock::new(|| {
95        static DECODE_PRECOMPUTATION_FOR_G_BINCODE: &[u8] =
96            include_bytes!("decode_u32_precomputation_for_G.bincode");
97        bincode::deserialize(DECODE_PRECOMPUTATION_FOR_G_BINCODE).unwrap_or_default()
98    });
99
100/// Solves the discrete log instance using a 16/16 bit offline/online split
101impl DiscreteLog {
102    /// Discrete log instance constructor.
103    ///
104    /// Default number of threads set to 1.
105    pub fn new(generator: RistrettoPoint, target: RistrettoPoint) -> Self {
106        Self {
107            generator,
108            target,
109            num_threads: None,
110            range_bound: (TWO16 as usize).try_into().unwrap(),
111            step_point: G,
112            compression_batch_size: 32.try_into().unwrap(),
113        }
114    }
115
116    /// Adjusts number of threads in a discrete log instance.
117    #[cfg(not(target_arch = "wasm32"))]
118    pub fn num_threads(&mut self, num_threads: NonZeroUsize) -> Result<(), DiscreteLogError> {
119        // number of threads must be a positive power-of-two integer
120        if !num_threads.is_power_of_two() || num_threads.get() > MAX_THREAD {
121            return Err(DiscreteLogError::DiscreteLogThreads);
122        }
123
124        self.num_threads = Some(num_threads);
125        self.range_bound = (TWO16 as usize)
126            .checked_div(num_threads.get())
127            .and_then(|range_bound| range_bound.try_into().ok())
128            .unwrap(); // `num_threads` cannot exceed `TWO16`, so `range_bound` always non-zero
129        self.step_point = Scalar::from(num_threads.get() as u64) * G;
130
131        Ok(())
132    }
133
134    /// Adjusts inversion batch size in a discrete log instance.
135    pub fn set_compression_batch_size(
136        &mut self,
137        compression_batch_size: NonZeroUsize,
138    ) -> Result<(), DiscreteLogError> {
139        if compression_batch_size.get() >= TWO16 as usize {
140            return Err(DiscreteLogError::DiscreteLogBatchSize);
141        }
142        self.compression_batch_size = compression_batch_size;
143
144        Ok(())
145    }
146
147    /// Solves the discrete log problem under the assumption that the solution
148    /// is a positive 32-bit number.
149    pub fn decode_u32(self) -> Option<u64> {
150        if let Some(num_threads) = self.num_threads {
151            #[cfg(not(target_arch = "wasm32"))]
152            {
153                let mut starting_point = self.target;
154                let handles = (0..num_threads.get())
155                    .map(|i| {
156                        let ristretto_iterator = RistrettoIterator::new(
157                            (starting_point, i as u64),
158                            (-(&self.step_point), num_threads.get() as u64),
159                        );
160
161                        let handle = thread::spawn(move || {
162                            Self::decode_range(
163                                ristretto_iterator,
164                                self.range_bound,
165                                self.compression_batch_size,
166                            )
167                        });
168
169                        starting_point -= G;
170                        handle
171                    })
172                    .collect::<Vec<_>>();
173
174                handles
175                    .into_iter()
176                    .map_while(|h| h.join().ok())
177                    .find(|x| x.is_some())
178                    .flatten()
179            }
180            #[cfg(target_arch = "wasm32")]
181            unreachable!() // `self.num_threads` always `None` on wasm target
182        } else {
183            let ristretto_iterator =
184                RistrettoIterator::new((self.target, 0_u64), (-(&self.step_point), 1u64));
185
186            Self::decode_range(
187                ristretto_iterator,
188                self.range_bound,
189                self.compression_batch_size,
190            )
191        }
192    }
193
194    fn decode_range(
195        ristretto_iterator: RistrettoIterator,
196        range_bound: NonZeroUsize,
197        compression_batch_size: NonZeroUsize,
198    ) -> Option<u64> {
199        let hashmap = &DECODE_PRECOMPUTATION_FOR_G;
200        let mut decoded = None;
201
202        for batch in &ristretto_iterator
203            .take(range_bound.get())
204            .chunks(compression_batch_size.get())
205        {
206            // batch compression currently errors if any point in the batch is the identity point
207            let (batch_points, batch_indices): (Vec<_>, Vec<_>) = batch
208                .filter(|(point, index)| {
209                    if point.is_identity() {
210                        decoded = Some(*index);
211                        return false;
212                    }
213                    true
214                })
215                .unzip();
216
217            let batch_compressed = RistrettoPoint::double_and_compress_batch(&batch_points);
218
219            for (point, x_lo) in batch_compressed.iter().zip(batch_indices.iter()) {
220                let key = point.to_bytes();
221                if hashmap.0.contains_key(&key) {
222                    let x_hi = hashmap.0[&key];
223                    decoded = Some(x_lo + TWO16 * x_hi as u64);
224                }
225            }
226        }
227
228        decoded
229    }
230}
231
232/// Hashable Ristretto iterator.
233///
234/// Given an initial point X and a stepping point P, the iterator iterates through
235/// X + 0*P, X + 1*P, X + 2*P, X + 3*P, ...
236struct RistrettoIterator {
237    pub current: (RistrettoPoint, u64),
238    pub step: (RistrettoPoint, u64),
239}
240
241impl RistrettoIterator {
242    fn new(current: (RistrettoPoint, u64), step: (RistrettoPoint, u64)) -> Self {
243        RistrettoIterator { current, step }
244    }
245}
246
247impl Iterator for RistrettoIterator {
248    type Item = (RistrettoPoint, u64);
249
250    fn next(&mut self) -> Option<Self::Item> {
251        let r = self.current;
252        self.current = (self.current.0 + self.step.0, self.current.1 + self.step.1);
253        Some(r)
254    }
255}
256
257#[cfg(test)]
258mod tests {
259    use {super::*, std::time::Instant};
260
261    #[test]
262    #[allow(non_snake_case)]
263    fn test_serialize_decode_u32_precomputation_for_G() {
264        let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
265        // let decode_u32_precomputation_for_G = decode_u32_precomputation(G);
266
267        if decode_u32_precomputation_for_G.0 != DECODE_PRECOMPUTATION_FOR_G.0 {
268            use std::{fs::File, io::Write, path::PathBuf};
269            let mut f = File::create(PathBuf::from(
270                "src/encryption/decode_u32_precomputation_for_G.bincode",
271            ))
272            .unwrap();
273            f.write_all(&bincode::serialize(&decode_u32_precomputation_for_G).unwrap())
274                .unwrap();
275            panic!("Rebuild and run this test again");
276        }
277    }
278
279    #[test]
280    fn test_decode_correctness() {
281        // general case
282        let amount: u64 = 4294967295;
283
284        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
285
286        // Very informal measurements for now
287        let start_computation = Instant::now();
288        let decoded = instance.decode_u32();
289        let computation_secs = start_computation.elapsed().as_secs_f64();
290
291        assert_eq!(amount, decoded.unwrap());
292
293        println!("single thread discrete log computation secs: {computation_secs:?} sec");
294    }
295
296    #[cfg(not(target_arch = "wasm32"))]
297    #[test]
298    fn test_decode_correctness_threaded() {
299        // general case
300        let amount: u64 = 55;
301
302        let mut instance = DiscreteLog::new(G, Scalar::from(amount) * G);
303        instance.num_threads(4.try_into().unwrap()).unwrap();
304
305        // Very informal measurements for now
306        let start_computation = Instant::now();
307        let decoded = instance.decode_u32();
308        let computation_secs = start_computation.elapsed().as_secs_f64();
309
310        assert_eq!(amount, decoded.unwrap());
311
312        println!("4 thread discrete log computation: {computation_secs:?} sec");
313
314        // amount 0
315        let amount: u64 = 0;
316
317        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
318
319        let decoded = instance.decode_u32();
320        assert_eq!(amount, decoded.unwrap());
321
322        // amount 1
323        let amount: u64 = 1;
324
325        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
326
327        let decoded = instance.decode_u32();
328        assert_eq!(amount, decoded.unwrap());
329
330        // amount 2
331        let amount: u64 = 2;
332
333        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
334
335        let decoded = instance.decode_u32();
336        assert_eq!(amount, decoded.unwrap());
337
338        // amount 3
339        let amount: u64 = 3;
340
341        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
342
343        let decoded = instance.decode_u32();
344        assert_eq!(amount, decoded.unwrap());
345
346        // max amount
347        let amount: u64 = (1_u64 << 32) - 1;
348
349        let instance = DiscreteLog::new(G, Scalar::from(amount) * G);
350
351        let decoded = instance.decode_u32();
352        assert_eq!(amount, decoded.unwrap());
353    }
354}