use curve25519_dalek::{
constants::RISTRETTO_BASEPOINT_TABLE,
ristretto::{CompressedRistretto, RistrettoPoint},
scalar::Scalar,
};
use rand::RngExt;
use serde::{Deserialize, Serialize};
use sha2::Sha512;
/// Errors reported by the OPRF types in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OprfError {
    /// A blinded input was rejected.
    InvalidBlindedInput,
    /// A blinded output was rejected.
    InvalidBlindedOutput,
    /// Bytes could not be decoded back into a value
    /// (e.g. a non-canonical secret-key encoding in `OprfServer::from_bytes`).
    SerializationError,
}

impl std::fmt::Display for OprfError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // One fixed, human-readable sentence per variant.
        let message = match *self {
            OprfError::InvalidBlindedInput => "Invalid blinded input",
            OprfError::InvalidBlindedOutput => "Invalid blinded output",
            OprfError::SerializationError => "Serialization error",
        };
        f.write_str(message)
    }
}

impl std::error::Error for OprfError {}

/// Shorthand for results whose error type is [`OprfError`].
pub type OprfResult<T> = Result<T, OprfError>;
/// Server side of the OPRF: holds the secret key and evaluates blinded inputs.
///
/// `Debug` is not derived, so the secret key cannot be printed via `{:?}`.
#[derive(Clone)]
pub struct OprfServer {
/// Secret OPRF key; evaluation multiplies group elements by this scalar.
secret_key: Scalar,
}
/// Client-to-server message: the hashed input multiplied by the client's blinding scalar.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlindedInput {
/// Compressed encoding of the blinded group element.
/// NOTE(review): the derived `Deserialize` presumably accepts any 32 bytes without
/// decompressing them — confirm against curve25519-dalek's serde impl; validate before use.
point: CompressedRistretto,
}
/// Server-to-client message: a blinded input multiplied by the server's secret key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlindedOutput {
/// Compressed encoding of the evaluated group element.
/// NOTE(review): as with `BlindedInput`, serde deserialization likely performs no
/// point validation — confirm.
point: CompressedRistretto,
}
/// Final OPRF output: the BLAKE3 hash of the compressed, unblinded group element.
/// NOTE(review): the derived `PartialEq` is not guaranteed to be constant-time; if outputs
/// are compared as secrets, consider a constant-time comparison.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OprfOutput {
/// Raw 32 output bytes.
value: [u8; 32],
}
/// Client-side state for a single OPRF evaluation, created by `OprfClient::blind`.
///
/// `Debug` is not derived, so the blinding scalar cannot be printed via `{:?}`.
pub struct OprfClient {
/// Random blinding scalar; unblinding multiplies by its inverse.
blind: Scalar,
/// Copy of the original input bytes, returned by `input()`.
input: Vec<u8>,
}
impl OprfServer {
pub fn new() -> Self {
let mut rng = rand::rng();
let mut bytes = [0u8; 32];
rng.fill(&mut bytes);
let secret_key = Scalar::from_bytes_mod_order(bytes);
Self { secret_key }
}
pub fn from_key(secret_key: Scalar) -> Self {
Self { secret_key }
}
pub fn evaluate(&self, blinded_input: &BlindedInput) -> BlindedOutput {
let point = blinded_input.point.decompress().unwrap_or_default();
let blinded_output_point = point * self.secret_key;
BlindedOutput {
point: blinded_output_point.compress(),
}
}
pub fn evaluate_direct(&self, input: &[u8]) -> OprfOutput {
let point = hash_to_point(input);
let output_point = point * self.secret_key;
OprfOutput {
value: blake3::hash(output_point.compress().as_bytes()).into(),
}
}
pub fn batch_evaluate(&self, inputs: &[BlindedInput]) -> Vec<BlindedOutput> {
inputs.iter().map(|input| self.evaluate(input)).collect()
}
pub fn public_key(&self) -> CompressedRistretto {
(&self.secret_key * RISTRETTO_BASEPOINT_TABLE).compress()
}
pub fn to_bytes(&self) -> [u8; 32] {
self.secret_key.to_bytes()
}
pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
let scalar = Scalar::from_canonical_bytes(*bytes)
.into_option()
.ok_or(OprfError::SerializationError)?;
Ok(Self::from_key(scalar))
}
}
impl Default for OprfServer {
fn default() -> Self {
Self::new()
}
}
impl OprfClient {
pub fn blind(input: &[u8]) -> (Self, BlindedInput) {
let mut rng = rand::rng();
let mut bytes = [0u8; 32];
rng.fill(&mut bytes);
let blind = Scalar::from_bytes_mod_order(bytes);
let point = hash_to_point(input);
let blinded_point = point * blind;
let client = Self {
blind,
input: input.to_vec(),
};
let blinded_input = BlindedInput {
point: blinded_point.compress(),
};
(client, blinded_input)
}
pub fn unblind(&self, blinded_output: &BlindedOutput) -> OprfOutput {
let point = blinded_output.point.decompress().unwrap_or_default();
let blind_inv = self.blind.invert();
let output_point = point * blind_inv;
OprfOutput {
value: blake3::hash(output_point.compress().as_bytes()).into(),
}
}
pub fn input(&self) -> &[u8] {
&self.input
}
}
impl BlindedInput {
pub fn to_bytes(&self) -> [u8; 32] {
self.point.to_bytes()
}
pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
Ok(Self {
point: CompressedRistretto(*bytes),
})
}
}
impl BlindedOutput {
pub fn to_bytes(&self) -> [u8; 32] {
self.point.to_bytes()
}
pub fn from_bytes(bytes: &[u8; 32]) -> OprfResult<Self> {
Ok(Self {
point: CompressedRistretto(*bytes),
})
}
}
impl OprfOutput {
pub fn as_bytes(&self) -> &[u8; 32] {
&self.value
}
pub fn from_bytes(bytes: [u8; 32]) -> Self {
Self { value: bytes }
}
}
fn hash_to_point(input: &[u8]) -> RistrettoPoint {
let scalar = Scalar::hash_from_bytes::<Sha512>(input);
&scalar * RISTRETTO_BASEPOINT_TABLE
}
pub struct BatchOprfClient {
clients: Vec<OprfClient>,
}
impl BatchOprfClient {
pub fn blind_batch(inputs: &[&[u8]]) -> (Self, Vec<BlindedInput>) {
let mut clients = Vec::with_capacity(inputs.len());
let mut blinded_inputs = Vec::with_capacity(inputs.len());
for input in inputs {
let (client, blinded_input) = OprfClient::blind(input);
clients.push(client);
blinded_inputs.push(blinded_input);
}
(Self { clients }, blinded_inputs)
}
pub fn unblind_batch(&self, blinded_outputs: &[BlindedOutput]) -> Vec<OprfOutput> {
self.clients
.iter()
.zip(blinded_outputs.iter())
.map(|(client, output)| client.unblind(output))
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the full blind → evaluate → unblind exchange for a single input.
    fn run_protocol(server: &OprfServer, input: &[u8]) -> OprfOutput {
        let (client, blinded) = OprfClient::blind(input);
        client.unblind(&server.evaluate(&blinded))
    }

    /// Asserts that, for a fresh server, the oblivious exchange matches direct evaluation.
    fn assert_matches_direct(input: &[u8]) {
        let server = OprfServer::new();
        let oblivious = run_protocol(&server, input);
        assert_eq!(oblivious, server.evaluate_direct(input));
    }

    #[test]
    fn test_oprf_basic() {
        assert_matches_direct(b"test-input");
    }

    #[test]
    fn test_oprf_deterministic() {
        let server = OprfServer::new();
        let input = b"deterministic-test";
        let first = run_protocol(&server, input);
        let second = run_protocol(&server, input);
        assert_eq!(first, second);
    }

    #[test]
    fn test_oprf_different_inputs() {
        let server = OprfServer::new();
        let first = run_protocol(&server, b"input1");
        let second = run_protocol(&server, b"input2");
        assert_ne!(first, second);
    }

    #[test]
    fn test_oprf_different_servers() {
        let (server_a, server_b) = (OprfServer::new(), OprfServer::new());
        let input = b"test";
        assert_ne!(run_protocol(&server_a, input), run_protocol(&server_b, input));
    }

    #[test]
    fn test_oprf_serialization() {
        let server = OprfServer::new();
        let restored = OprfServer::from_bytes(&server.to_bytes()).unwrap();
        let input = b"serialize-test";
        assert_eq!(server.evaluate_direct(input), restored.evaluate_direct(input));
    }

    #[test]
    fn test_blinded_input_serialization() {
        let (_client, blinded) = OprfClient::blind(b"test");
        let decoded = BlindedInput::from_bytes(&blinded.to_bytes()).unwrap();
        assert_eq!(blinded.point, decoded.point);
    }

    #[test]
    fn test_blinded_output_serialization() {
        let server = OprfServer::new();
        let (_client, blinded_input) = OprfClient::blind(b"test");
        let evaluated = server.evaluate(&blinded_input);
        let decoded = BlindedOutput::from_bytes(&evaluated.to_bytes()).unwrap();
        assert_eq!(evaluated.point, decoded.point);
    }

    #[test]
    fn test_batch_oprf() {
        let server = OprfServer::new();
        let inputs: [&[u8]; 3] = [b"input1", b"input2", b"input3"];
        let (batch_client, blinded_inputs) = BatchOprfClient::blind_batch(&inputs);
        let outputs = batch_client.unblind_batch(&server.batch_evaluate(&blinded_inputs));
        for (input, output) in inputs.iter().zip(&outputs) {
            assert_eq!(*output, server.evaluate_direct(input));
        }
    }

    #[test]
    fn test_batch_oprf_different_outputs() {
        let server = OprfServer::new();
        let inputs: [&[u8]; 3] = [b"a", b"b", b"c"];
        let (batch_client, blinded_inputs) = BatchOprfClient::blind_batch(&inputs);
        let outputs = batch_client.unblind_batch(&server.batch_evaluate(&blinded_inputs));
        for (i, j) in [(0, 1), (1, 2), (0, 2)] {
            assert_ne!(outputs[i], outputs[j]);
        }
    }

    #[test]
    fn test_oprf_public_key() {
        let public_key = OprfServer::new().public_key();
        assert!(public_key.decompress().is_some());
    }

    #[test]
    fn test_oprf_empty_input() {
        assert_matches_direct(b"");
    }

    #[test]
    fn test_oprf_large_input() {
        assert_matches_direct(&[0xABu8; 10000]);
    }

    #[test]
    fn test_oprf_output_uniqueness() {
        let server = OprfServer::new();
        let outputs: std::collections::HashSet<[u8; 32]> = (0..100)
            .map(|i| run_protocol(&server, format!("input-{}", i).as_bytes()).value)
            .collect();
        assert_eq!(outputs.len(), 100);
    }
}