1use curve25519_dalek::{
37 constants::RISTRETTO_BASEPOINT_TABLE,
38 ristretto::{CompressedRistretto, RistrettoPoint},
39 scalar::Scalar,
40};
41use rand::Rng;
42use serde::{Deserialize, Serialize};
43use sha2::Sha512;
44
/// Errors reported by the OPRF types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OprfError {
    /// The blinded input was invalid.
    InvalidBlindedInput,
    /// The blinded output was invalid.
    InvalidBlindedOutput,
    /// A value could not be decoded from its serialized form.
    SerializationError,
}

impl std::fmt::Display for OprfError {
    /// Writes the fixed human-readable message associated with the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the message first, then emit it with a single write.
        let message = match self {
            OprfError::InvalidBlindedInput => "Invalid blinded input",
            OprfError::InvalidBlindedOutput => "Invalid blinded output",
            OprfError::SerializationError => "Serialization error",
        };
        f.write_str(message)
    }
}

impl std::error::Error for OprfError {}
67
/// Shorthand for a `Result` whose error type is [`OprfError`].
pub type OprfResult<T> = Result<T, OprfError>;
69
/// Server side of the OPRF: owns the secret scalar that inputs are multiplied by.
#[derive(Clone)]
pub struct OprfServer {
    /// Secret scalar applied to (blinded or hashed) input points.
    secret_key: Scalar,
}
76
/// Blinded client input, carried as a compressed Ristretto point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlindedInput {
    /// Compressed encoding of the blinded point.
    point: CompressedRistretto,
}
82
/// Server response to a blinded input, carried as a compressed Ristretto point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlindedOutput {
    /// Compressed encoding of the evaluated (still blinded) point.
    point: CompressedRistretto,
}
88
/// Final OPRF output: a 32-byte value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OprfOutput {
    /// The 32 output bytes.
    value: [u8; 32],
}
94
/// Client-side state kept between blinding an input and unblinding the response.
pub struct OprfClient {
    /// Blinding scalar; its inverse is applied when unblinding.
    blind: Scalar,
    /// Copy of the input bytes that were blinded.
    input: Vec<u8>,
}
102
103impl OprfServer {
104 pub fn new() -> Self {
106 let mut rng = rand::thread_rng();
107 let mut bytes = [0u8; 32];
108 rng.fill(&mut bytes);
109 let secret_key = Scalar::from_bytes_mod_order(bytes);
110 Self { secret_key }
111 }
112
113 pub fn from_key(secret_key: Scalar) -> Self {
115 Self { secret_key }
116 }
117
118 pub fn evaluate(&self, blinded_input: &BlindedInput) -> BlindedOutput {
122 let point = blinded_input.point.decompress().unwrap_or_default();
123 let blinded_output_point = point * self.secret_key;
124 BlindedOutput {
125 point: blinded_output_point.compress(),
126 }
127 }
128
129 pub fn evaluate_direct(&self, input: &[u8]) -> OprfOutput {
133 let point = hash_to_point(input);
135 let output_point = point * self.secret_key;
137 OprfOutput {
139 value: blake3::hash(output_point.compress().as_bytes()).into(),
140 }
141 }
142
143 pub fn batch_evaluate(&self, inputs: &[BlindedInput]) -> Vec<BlindedOutput> {
145 inputs.iter().map(|input| self.evaluate(input)).collect()
146 }
147
148 pub fn public_key(&self) -> CompressedRistretto {
150 (&self.secret_key * RISTRETTO_BASEPOINT_TABLE).compress()
151 }
152
153 pub fn to_bytes(&self) -> [u8; 32] {
155 self.secret_key.to_bytes()
156 }
157
158 pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
160 let scalar = Scalar::from_canonical_bytes(*bytes)
161 .into_option()
162 .ok_or(OprfError::SerializationError)?;
163 Ok(Self::from_key(scalar))
164 }
165}
166
167impl Default for OprfServer {
168 fn default() -> Self {
169 Self::new()
170 }
171}
172
173impl OprfClient {
174 pub fn blind(input: &[u8]) -> (Self, BlindedInput) {
178 let mut rng = rand::thread_rng();
179 let mut bytes = [0u8; 32];
180 rng.fill(&mut bytes);
181 let blind = Scalar::from_bytes_mod_order(bytes);
182
183 let point = hash_to_point(input);
185
186 let blinded_point = point * blind;
188
189 let client = Self {
190 blind,
191 input: input.to_vec(),
192 };
193
194 let blinded_input = BlindedInput {
195 point: blinded_point.compress(),
196 };
197
198 (client, blinded_input)
199 }
200
201 pub fn unblind(&self, blinded_output: &BlindedOutput) -> OprfOutput {
203 let point = blinded_output.point.decompress().unwrap_or_default();
204
205 let blind_inv = self.blind.invert();
207 let output_point = point * blind_inv;
208
209 OprfOutput {
211 value: blake3::hash(output_point.compress().as_bytes()).into(),
212 }
213 }
214
215 pub fn input(&self) -> &[u8] {
217 &self.input
218 }
219}
220
221impl BlindedInput {
222 pub fn to_bytes(&self) -> [u8; 32] {
224 self.point.to_bytes()
225 }
226
227 pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
229 Ok(Self {
230 point: CompressedRistretto(*bytes),
231 })
232 }
233}
234
235impl BlindedOutput {
236 pub fn to_bytes(&self) -> [u8; 32] {
238 self.point.to_bytes()
239 }
240
241 pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
243 Ok(Self {
244 point: CompressedRistretto(*bytes),
245 })
246 }
247}
248
249impl OprfOutput {
250 pub fn as_bytes(&self) -> &[u8; 32] {
252 &self.value
253 }
254
255 pub fn from_bytes(bytes: [u8; 32]) -> Self {
257 Self { value: bytes }
258 }
259}
260
261fn hash_to_point(input: &[u8]) -> RistrettoPoint {
263 let scalar = Scalar::hash_from_bytes::<Sha512>(input);
265 &scalar * RISTRETTO_BASEPOINT_TABLE
267}
268
/// Client state for a batch of blinded inputs: one [`OprfClient`] per input, in input order.
pub struct BatchOprfClient {
    /// Per-input client states, in the order the inputs were blinded.
    clients: Vec<OprfClient>,
}
273
274impl BatchOprfClient {
275 pub fn blind_batch(inputs: &[&[u8]]) -> (Self, Vec<BlindedInput>) {
277 let mut clients = Vec::with_capacity(inputs.len());
278 let mut blinded_inputs = Vec::with_capacity(inputs.len());
279
280 for input in inputs {
281 let (client, blinded_input) = OprfClient::blind(input);
282 clients.push(client);
283 blinded_inputs.push(blinded_input);
284 }
285
286 (Self { clients }, blinded_inputs)
287 }
288
289 pub fn unblind_batch(&self, blinded_outputs: &[BlindedOutput]) -> Vec<OprfOutput> {
291 self.clients
292 .iter()
293 .zip(blinded_outputs.iter())
294 .map(|(client, output)| client.unblind(output))
295 .collect()
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Runs the blinded protocol end to end: blind, evaluate, unblind.
    fn run_protocol(server: &OprfServer, input: &[u8]) -> OprfOutput {
        let (client, blinded_input) = OprfClient::blind(input);
        let blinded_output = server.evaluate(&blinded_input);
        client.unblind(&blinded_output)
    }

    /// The blinded protocol and direct evaluation agree.
    #[test]
    fn test_oprf_basic() {
        let server = OprfServer::new();
        let input = b"test-input";

        assert_eq!(run_protocol(&server, input), server.evaluate_direct(input));
    }

    /// Independent blinds of the same input give the same output.
    #[test]
    fn test_oprf_deterministic() {
        let server = OprfServer::new();
        let input = b"deterministic-test";

        let first = run_protocol(&server, input);
        let second = run_protocol(&server, input);

        assert_eq!(first, second);
    }

    /// Distinct inputs give distinct outputs under one key.
    #[test]
    fn test_oprf_different_inputs() {
        let server = OprfServer::new();

        let first = run_protocol(&server, b"input1");
        let second = run_protocol(&server, b"input2");

        assert_ne!(first, second);
    }

    /// The same input gives distinct outputs under different keys.
    #[test]
    fn test_oprf_different_servers() {
        let server1 = OprfServer::new();
        let server2 = OprfServer::new();
        let input = b"test";

        let first = run_protocol(&server1, input);
        let second = run_protocol(&server2, input);

        assert_ne!(first, second);
    }

    /// A server restored from its serialized key evaluates identically.
    #[test]
    fn test_oprf_serialization() {
        let original = OprfServer::new();
        let restored = OprfServer::from_bytes(&original.to_bytes()).unwrap();

        let input = b"serialize-test";

        assert_eq!(
            original.evaluate_direct(input),
            restored.evaluate_direct(input)
        );
    }

    /// A blinded input survives a byte round trip.
    #[test]
    fn test_blinded_input_serialization() {
        let (_client, blinded) = OprfClient::blind(b"test");
        let decoded = BlindedInput::from_bytes(&blinded.to_bytes()).unwrap();

        assert_eq!(blinded.point, decoded.point);
    }

    /// A blinded output survives a byte round trip.
    #[test]
    fn test_blinded_output_serialization() {
        let server = OprfServer::new();
        let (_client, blinded_input) = OprfClient::blind(b"test");
        let blinded_output = server.evaluate(&blinded_input);

        let decoded = BlindedOutput::from_bytes(&blinded_output.to_bytes()).unwrap();

        assert_eq!(blinded_output.point, decoded.point);
    }

    /// Batch processing matches direct evaluation element by element.
    #[test]
    fn test_batch_oprf() {
        let server = OprfServer::new();
        let inputs: [&[u8]; 3] = [b"input1", b"input2", b"input3"];

        let (batch_client, blinded_inputs) = BatchOprfClient::blind_batch(&inputs);
        let blinded_outputs = server.batch_evaluate(&blinded_inputs);
        let outputs = batch_client.unblind_batch(&blinded_outputs);

        for (input, output) in inputs.iter().zip(outputs.iter()) {
            assert_eq!(*output, server.evaluate_direct(input));
        }
    }

    /// Batch outputs for distinct inputs are pairwise distinct.
    #[test]
    fn test_batch_oprf_different_outputs() {
        let server = OprfServer::new();
        let inputs: [&[u8]; 3] = [b"a", b"b", b"c"];

        let (batch_client, blinded_inputs) = BatchOprfClient::blind_batch(&inputs);
        let blinded_outputs = server.batch_evaluate(&blinded_inputs);
        let outputs = batch_client.unblind_batch(&blinded_outputs);

        assert_ne!(outputs[0], outputs[1]);
        assert_ne!(outputs[1], outputs[2]);
        assert_ne!(outputs[0], outputs[2]);
    }

    /// The public key is a decodable point.
    #[test]
    fn test_oprf_public_key() {
        let public_key = OprfServer::new().public_key();

        assert!(public_key.decompress().is_some());
    }

    /// The empty input is handled like any other.
    #[test]
    fn test_oprf_empty_input() {
        let server = OprfServer::new();
        let input = b"";

        assert_eq!(run_protocol(&server, input), server.evaluate_direct(input));
    }

    /// A 10 000-byte input is handled like any other.
    #[test]
    fn test_oprf_large_input() {
        let server = OprfServer::new();
        let input = vec![0xAB_u8; 10000];

        assert_eq!(
            run_protocol(&server, &input),
            server.evaluate_direct(&input)
        );
    }

    /// One hundred distinct inputs produce one hundred distinct outputs.
    #[test]
    fn test_oprf_output_uniqueness() {
        let server = OprfServer::new();

        let outputs: HashSet<[u8; 32]> = (0..100)
            .map(|i| {
                let input = format!("input-{}", i);
                run_protocol(&server, input.as_bytes()).value
            })
            .collect();

        assert_eq!(outputs.len(), 100);
    }
}