chie_crypto/
oprf.rs

1//! Oblivious Pseudorandom Function (OPRF) implementation.
2//!
3//! OPRF allows a client to compute PRF(key, input) where:
4//! - Server holds the secret key
5//! - Client provides the input
6//! - Server learns nothing about the input
7//! - Client learns only PRF(key, input), not the key
8//!
9//! Perfect for:
10//! - Private rate limiting (check if user exceeded quota without revealing identity)
11//! - Password-authenticated key exchange
12//! - Anonymous credentials
13//! - Privacy-preserving set membership tests
14//!
15//! # Example
16//! ```
17//! use chie_crypto::oprf::{OprfServer, OprfClient};
18//!
19//! // Server setup
20//! let server = OprfServer::new();
21//!
22//! // Client blind request
23//! let input = b"user@example.com";
24//! let (client, blinded_input) = OprfClient::blind(input);
25//!
26//! // Server evaluates on blinded input
27//! let blinded_output = server.evaluate(&blinded_input);
28//!
29//! // Client unblinds to get PRF output
30//! let prf_output = client.unblind(&blinded_output);
31//!
32//! // Can verify this matches direct evaluation (for testing)
33//! assert_eq!(prf_output, server.evaluate_direct(input));
34//! ```
35
36use curve25519_dalek::{
37    constants::RISTRETTO_BASEPOINT_TABLE,
38    ristretto::{CompressedRistretto, RistrettoPoint},
39    scalar::Scalar,
40};
41use rand::Rng;
42use serde::{Deserialize, Serialize};
43use sha2::Sha512;
44
/// Failures that can be reported by the OPRF API.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OprfError {
    /// The blinded input (client -> server message) is not acceptable.
    InvalidBlindedInput,
    /// The blinded output (server -> client message) is not acceptable.
    InvalidBlindedOutput,
    /// A value could not be serialized or deserialized.
    SerializationError,
}

impl std::fmt::Display for OprfError {
    /// Writes the fixed, human-readable message associated with the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::InvalidBlindedInput => "Invalid blinded input",
            Self::InvalidBlindedOutput => "Invalid blinded output",
            Self::SerializationError => "Serialization error",
        };
        f.write_str(message)
    }
}

impl std::error::Error for OprfError {}

/// Shorthand for a `Result` whose error type is [`OprfError`].
pub type OprfResult<T> = Result<T, OprfError>;
69
70/// OPRF server holding the secret key.
71#[derive(Clone)]
72pub struct OprfServer {
73    /// Secret key for the PRF
74    secret_key: Scalar,
75}
76
77/// Blinded input sent from client to server.
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct BlindedInput {
80    point: CompressedRistretto,
81}
82
83/// Blinded output sent from server to client.
84#[derive(Debug, Clone, Serialize, Deserialize)]
85pub struct BlindedOutput {
86    point: CompressedRistretto,
87}
88
89/// PRF output after unblinding.
90#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
91pub struct OprfOutput {
92    value: [u8; 32],
93}
94
95/// OPRF client state during the protocol.
96pub struct OprfClient {
97    /// Blinding factor (kept secret)
98    blind: Scalar,
99    /// Original input (for verification)
100    input: Vec<u8>,
101}
102
103impl OprfServer {
104    /// Create a new OPRF server with random secret key.
105    pub fn new() -> Self {
106        let mut rng = rand::thread_rng();
107        let mut bytes = [0u8; 32];
108        rng.fill(&mut bytes);
109        let secret_key = Scalar::from_bytes_mod_order(bytes);
110        Self { secret_key }
111    }
112
113    /// Create OPRF server from existing secret key.
114    pub fn from_key(secret_key: Scalar) -> Self {
115        Self { secret_key }
116    }
117
118    /// Evaluate the OPRF on a blinded input.
119    ///
120    /// Returns blinded output that client can unblind.
121    pub fn evaluate(&self, blinded_input: &BlindedInput) -> BlindedOutput {
122        let point = blinded_input.point.decompress().unwrap_or_default();
123        let blinded_output_point = point * self.secret_key;
124        BlindedOutput {
125            point: blinded_output_point.compress(),
126        }
127    }
128
129    /// Evaluate OPRF directly on input (for testing/verification).
130    ///
131    /// In real protocol, server never sees the actual input.
132    pub fn evaluate_direct(&self, input: &[u8]) -> OprfOutput {
133        // Hash input to point
134        let point = hash_to_point(input);
135        // Apply secret key
136        let output_point = point * self.secret_key;
137        // Hash to final output
138        OprfOutput {
139            value: blake3::hash(output_point.compress().as_bytes()).into(),
140        }
141    }
142
143    /// Batch evaluate multiple blinded inputs.
144    pub fn batch_evaluate(&self, inputs: &[BlindedInput]) -> Vec<BlindedOutput> {
145        inputs.iter().map(|input| self.evaluate(input)).collect()
146    }
147
148    /// Get the server's public key (for verification protocols).
149    pub fn public_key(&self) -> CompressedRistretto {
150        (&self.secret_key * RISTRETTO_BASEPOINT_TABLE).compress()
151    }
152
153    /// Serialize server secret key.
154    pub fn to_bytes(&self) -> [u8; 32] {
155        self.secret_key.to_bytes()
156    }
157
158    /// Deserialize server secret key.
159    pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
160        let scalar = Scalar::from_canonical_bytes(*bytes)
161            .into_option()
162            .ok_or(OprfError::SerializationError)?;
163        Ok(Self::from_key(scalar))
164    }
165}
166
167impl Default for OprfServer {
168    fn default() -> Self {
169        Self::new()
170    }
171}
172
173impl OprfClient {
174    /// Blind an input to send to the server.
175    ///
176    /// Returns (client state, blinded input to send to server).
177    pub fn blind(input: &[u8]) -> (Self, BlindedInput) {
178        let mut rng = rand::thread_rng();
179        let mut bytes = [0u8; 32];
180        rng.fill(&mut bytes);
181        let blind = Scalar::from_bytes_mod_order(bytes);
182
183        // Hash input to point
184        let point = hash_to_point(input);
185
186        // Blind the point
187        let blinded_point = point * blind;
188
189        let client = Self {
190            blind,
191            input: input.to_vec(),
192        };
193
194        let blinded_input = BlindedInput {
195            point: blinded_point.compress(),
196        };
197
198        (client, blinded_input)
199    }
200
201    /// Unblind the server's response to get the final PRF output.
202    pub fn unblind(&self, blinded_output: &BlindedOutput) -> OprfOutput {
203        let point = blinded_output.point.decompress().unwrap_or_default();
204
205        // Unblind by multiplying by blind^(-1)
206        let blind_inv = self.blind.invert();
207        let output_point = point * blind_inv;
208
209        // Hash to final output
210        OprfOutput {
211            value: blake3::hash(output_point.compress().as_bytes()).into(),
212        }
213    }
214
215    /// Get the original input (for debugging).
216    pub fn input(&self) -> &[u8] {
217        &self.input
218    }
219}
220
221impl BlindedInput {
222    /// Serialize to bytes.
223    pub fn to_bytes(&self) -> [u8; 32] {
224        self.point.to_bytes()
225    }
226
227    /// Deserialize from bytes.
228    pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
229        Ok(Self {
230            point: CompressedRistretto(*bytes),
231        })
232    }
233}
234
235impl BlindedOutput {
236    /// Serialize to bytes.
237    pub fn to_bytes(&self) -> [u8; 32] {
238        self.point.to_bytes()
239    }
240
241    /// Deserialize from bytes.
242    pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
243        Ok(Self {
244            point: CompressedRistretto(*bytes),
245        })
246    }
247}
248
249impl OprfOutput {
250    /// Get output as bytes.
251    pub fn as_bytes(&self) -> &[u8; 32] {
252        &self.value
253    }
254
255    /// Create from bytes.
256    pub fn from_bytes(bytes: [u8; 32]) -> Self {
257        Self { value: bytes }
258    }
259}
260
261/// Hash arbitrary input to a Ristretto point.
262fn hash_to_point(input: &[u8]) -> RistrettoPoint {
263    // Hash input using SHA-512 and convert to scalar
264    let scalar = Scalar::hash_from_bytes::<Sha512>(input);
265    // Multiply base point to get deterministic point
266    &scalar * RISTRETTO_BASEPOINT_TABLE
267}
268
269/// Batch OPRF client for multiple inputs.
270pub struct BatchOprfClient {
271    clients: Vec<OprfClient>,
272}
273
274impl BatchOprfClient {
275    /// Blind multiple inputs at once.
276    pub fn blind_batch(inputs: &[&[u8]]) -> (Self, Vec<BlindedInput>) {
277        let mut clients = Vec::with_capacity(inputs.len());
278        let mut blinded_inputs = Vec::with_capacity(inputs.len());
279
280        for input in inputs {
281            let (client, blinded_input) = OprfClient::blind(input);
282            clients.push(client);
283            blinded_inputs.push(blinded_input);
284        }
285
286        (Self { clients }, blinded_inputs)
287    }
288
289    /// Unblind multiple outputs.
290    pub fn unblind_batch(&self, blinded_outputs: &[BlindedOutput]) -> Vec<OprfOutput> {
291        self.clients
292            .iter()
293            .zip(blinded_outputs.iter())
294            .map(|(client, output)| client.unblind(output))
295            .collect()
296    }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs one full blind -> evaluate -> unblind exchange against `server`.
    fn run_protocol(server: &OprfServer, input: &[u8]) -> OprfOutput {
        let (client, blinded_input) = OprfClient::blind(input);
        let blinded_output = server.evaluate(&blinded_input);
        client.unblind(&blinded_output)
    }

    #[test]
    fn test_oprf_basic() {
        let server = OprfServer::new();
        let input: &[u8] = b"test-input";

        // The interactive protocol must agree with direct evaluation.
        assert_eq!(run_protocol(&server, input), server.evaluate_direct(input));
    }

    #[test]
    fn test_oprf_deterministic() {
        let server = OprfServer::new();
        let input: &[u8] = b"deterministic-test";

        // Fresh blinds on each run must not affect the final output.
        let first = run_protocol(&server, input);
        let second = run_protocol(&server, input);

        assert_eq!(first, second);
    }

    #[test]
    fn test_oprf_different_inputs() {
        let server = OprfServer::new();

        let first = run_protocol(&server, b"input1");
        let second = run_protocol(&server, b"input2");

        assert_ne!(first, second);
    }

    #[test]
    fn test_oprf_different_servers() {
        let server_a = OprfServer::new();
        let server_b = OprfServer::new();
        let input: &[u8] = b"test";

        // Independent keys must lead to different PRF values.
        assert_ne!(
            run_protocol(&server_a, input),
            run_protocol(&server_b, input)
        );
    }

    #[test]
    fn test_oprf_serialization() {
        let original = OprfServer::new();
        let restored = OprfServer::from_bytes(&original.to_bytes()).unwrap();

        let input: &[u8] = b"serialize-test";

        assert_eq!(
            original.evaluate_direct(input),
            restored.evaluate_direct(input)
        );
    }

    #[test]
    fn test_blinded_input_serialization() {
        let (_client, blinded) = OprfClient::blind(b"test");
        let round_tripped = BlindedInput::from_bytes(&blinded.to_bytes()).unwrap();

        assert_eq!(blinded.point, round_tripped.point);
    }

    #[test]
    fn test_blinded_output_serialization() {
        let server = OprfServer::new();
        let (_client, blinded_input) = OprfClient::blind(b"test");
        let blinded_output = server.evaluate(&blinded_input);

        let round_tripped = BlindedOutput::from_bytes(&blinded_output.to_bytes()).unwrap();

        assert_eq!(blinded_output.point, round_tripped.point);
    }

    #[test]
    fn test_batch_oprf() {
        let server = OprfServer::new();
        let inputs: [&[u8]; 3] = [&b"input1"[..], &b"input2"[..], &b"input3"[..]];

        let (batch_client, blinded_inputs) = BatchOprfClient::blind_batch(&inputs);
        let blinded_outputs = server.batch_evaluate(&blinded_inputs);
        let outputs = batch_client.unblind_batch(&blinded_outputs);

        // Every batch result must match direct evaluation of the same input.
        for (input, output) in inputs.iter().zip(outputs.iter()) {
            assert_eq!(*output, server.evaluate_direct(input));
        }
    }

    #[test]
    fn test_batch_oprf_different_outputs() {
        let server = OprfServer::new();
        let inputs: [&[u8]; 3] = [&b"a"[..], &b"b"[..], &b"c"[..]];

        let (batch_client, blinded_inputs) = BatchOprfClient::blind_batch(&inputs);
        let blinded_outputs = server.batch_evaluate(&blinded_inputs);
        let outputs = batch_client.unblind_batch(&blinded_outputs);

        // Pairwise distinct outputs.
        assert_ne!(outputs[0], outputs[1]);
        assert_ne!(outputs[1], outputs[2]);
        assert_ne!(outputs[0], outputs[2]);
    }

    #[test]
    fn test_oprf_public_key() {
        let public_key = OprfServer::new().public_key();

        // The public key must be a decodable compressed point.
        assert!(public_key.decompress().is_some());
    }

    #[test]
    fn test_oprf_empty_input() {
        let server = OprfServer::new();
        let input: &[u8] = b"";

        assert_eq!(run_protocol(&server, input), server.evaluate_direct(input));
    }

    #[test]
    fn test_oprf_large_input() {
        let server = OprfServer::new();
        let input = vec![0xAB_u8; 10_000]; // 10KB input

        assert_eq!(
            run_protocol(&server, &input),
            server.evaluate_direct(&input)
        );
    }

    #[test]
    fn test_oprf_output_uniqueness() {
        let server = OprfServer::new();

        // Collect the outputs of 100 distinct inputs into a set.
        let distinct: std::collections::HashSet<[u8; 32]> = (0..100)
            .map(|i| *run_protocol(&server, format!("input-{}", i).as_bytes()).as_bytes())
            .collect();

        // No two inputs may collide.
        assert_eq!(distinct.len(), 100);
    }
}
477}