use crate::{FieldElement, PrimitiveError};
use embed_doc_image::embed_doc_image;
use ruint::aliases::U256;
use serde::{Deserialize, Deserializer, Serialize, Serializer, de::Error as _};
#[expect(unused_imports, reason = "used in doc comments")]
use crate::circuit_inputs::QueryProofCircuitInput;
/// Type tag written into the most significant (first big-endian) byte of a
/// session-scoped field element, so values used in different roles can be told apart.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum SessionFeType {
    /// Tag carried by the OPRF seed of a [`SessionId`].
    OprfSeed = 0x01,
    /// Tag carried by the action of a [`SessionNullifier`].
    Action = 0x02,
}
/// Creation and validation of field elements whose most significant byte
/// carries a [`SessionFeType`] tag.
pub trait SessionFieldElement {
    /// Produces a random field element tagged with `element_type`.
    fn random_for_session<R>(rng: &mut R, element_type: SessionFeType) -> FieldElement
    where
        R: rand::CryptoRng + rand::RngCore;

    /// Reports whether `self` is tagged with `element_type`.
    fn is_valid_for_session(&self, element_type: SessionFeType) -> bool;
}
impl SessionFieldElement for FieldElement {
    /// Fills 32 bytes from `rng`, overwrites the first (most significant) byte with the
    /// `element_type` tag, and reads the result as a big-endian integer.
    ///
    /// The remaining 31 bytes (248 bits) stay random.
    ///
    /// # Panics
    /// Only if the tagged value exceeded the field modulus, which the small tag byte
    /// (0x01 or 0x02) rules out for the BN254 scalar field.
    fn random_for_session<R: rand::CryptoRng + rand::RngCore>(
        rng: &mut R,
        element_type: SessionFeType,
    ) -> FieldElement {
        let mut bytes = [0u8; 32];
        rng.fill_bytes(&mut bytes);
        bytes[0] = element_type as u8;
        let seed = U256::from_be_bytes(bytes);
        Self::try_from(seed).expect(
            "should always fit in the field because the leading session tag byte (0x01 or 0x02) keeps the value below the field modulus",
        )
    }

    /// Returns `true` iff the first big-endian byte of `self` equals the `element_type` tag.
    fn is_valid_for_session(&self, element_type: SessionFeType) -> bool {
        self.to_be_bytes()[0] == element_type as u8
    }
}
/// Session identifier made of a commitment and an OPRF seed.
///
/// Its compressed form is 64 bytes: the commitment followed by the OPRF seed, each
/// big-endian. Human-readable serde formats use a `session_`-prefixed hex string.
///
/// ![Session proofs][session-proofs.png]
#[embed_doc_image("session-proofs.png", "assets/session-proofs.png")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId {
    /// Session commitment; [`SessionId::from_r_seed`] derives it with a Poseidon2 permutation.
    pub commitment: FieldElement,
    /// OPRF seed. Valid seeds carry the [`SessionFeType::OprfSeed`] tag in their most
    /// significant byte. This field is public, so constructing the struct directly
    /// bypasses the check done by [`SessionId::new`].
    pub oprf_seed: FieldElement,
}
impl SessionId {
    /// Prefix of the human-readable encoding; the hex of the compressed bytes follows it.
    const JSON_PREFIX: &str = "session_";
    /// Domain-separation tag hashed into the commitment by [`SessionId::from_r_seed`].
    const DS_C: &[u8] = b"H(id, r)";

    /// Creates a session id from an existing commitment and OPRF seed.
    ///
    /// # Errors
    /// Returns [`PrimitiveError::InvalidInput`] if `oprf_seed` does not carry the
    /// [`SessionFeType::OprfSeed`] tag.
    pub fn new(commitment: FieldElement, oprf_seed: FieldElement) -> Result<Self, PrimitiveError> {
        Self::check_oprf_seed(&oprf_seed)?;
        Ok(Self {
            commitment,
            oprf_seed,
        })
    }

    /// Derives the session commitment from `leaf_index` and `session_id_r_seed`, and pairs
    /// it with `oprf_seed`.
    ///
    /// The commitment is lane 1 of a Poseidon2 (BN254, t = 3) permutation applied to
    /// `[DS_C (reduced mod the field order), leaf_index, session_id_r_seed]`.
    ///
    /// # Errors
    /// Returns [`PrimitiveError::InvalidInput`] if `oprf_seed` does not carry the
    /// [`SessionFeType::OprfSeed`] tag.
    pub fn from_r_seed(
        leaf_index: u64,
        session_id_r_seed: FieldElement,
        oprf_seed: FieldElement,
    ) -> Result<Self, PrimitiveError> {
        // Validate first so no hashing work is done for an unusable seed.
        Self::check_oprf_seed(&oprf_seed)?;
        let sub_ds = FieldElement::from_be_bytes_mod_order(Self::DS_C);
        let mut input = [*sub_ds, leaf_index.into(), *session_id_r_seed];
        poseidon2::bn254::t3::permutation_in_place(&mut input);
        let commitment = input[1].into();
        Ok(Self {
            commitment,
            oprf_seed,
        })
    }

    /// Generates a random OPRF seed tagged with [`SessionFeType::OprfSeed`].
    #[must_use]
    pub fn generate_oprf_seed<R: rand::CryptoRng + rand::RngCore>(rng: &mut R) -> FieldElement {
        FieldElement::random_for_session(rng, SessionFeType::OprfSeed)
    }

    /// Encodes the id as 64 bytes: big-endian commitment followed by big-endian OPRF seed.
    #[must_use]
    pub fn to_compressed_bytes(&self) -> [u8; 64] {
        let mut bytes = [0u8; 64];
        bytes[..32].copy_from_slice(&self.commitment.to_be_bytes());
        bytes[32..].copy_from_slice(&self.oprf_seed.to_be_bytes());
        bytes
    }

    /// Decodes the 64-byte layout produced by [`SessionId::to_compressed_bytes`].
    ///
    /// # Errors
    /// Returns a message if `bytes` is not exactly 64 bytes long, if either half is not a
    /// valid field element, or if the OPRF seed lacks the [`SessionFeType::OprfSeed`] tag.
    pub fn from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
        if bytes.len() != 64 {
            return Err(format!(
                "Invalid length: expected 64 bytes, got {}",
                bytes.len()
            ));
        }
        let (commitment_bytes, oprf_seed_bytes) = bytes.split_at(32);
        let commitment = FieldElement::from_be_bytes(
            commitment_bytes
                .try_into()
                .expect("length checked above, so the first half is 32 bytes"),
        )
        .map_err(|e| format!("invalid commitment: {e}"))?;
        let oprf_seed = FieldElement::from_be_bytes(
            oprf_seed_bytes
                .try_into()
                .expect("length checked above, so the second half is 32 bytes"),
        )
        .map_err(|e| format!("invalid oprf_seed: {e}"))?;
        if !oprf_seed.is_valid_for_session(SessionFeType::OprfSeed) {
            return Err("invalid prefix for oprf_seed".to_string());
        }
        Ok(Self {
            commitment,
            oprf_seed,
        })
    }

    /// Rejects OPRF seeds that lack the [`SessionFeType::OprfSeed`] tag.
    fn check_oprf_seed(oprf_seed: &FieldElement) -> Result<(), PrimitiveError> {
        if oprf_seed.is_valid_for_session(SessionFeType::OprfSeed) {
            Ok(())
        } else {
            Err(PrimitiveError::InvalidInput {
                attribute: "session_id".to_string(),
                reason: "inner oprf_seed is not valid".to_string(),
            })
        }
    }
}
impl Default for SessionId {
    /// Returns a session id with a zero commitment and an OPRF seed whose only
    /// non-zero byte is the leading [`SessionFeType::OprfSeed`] tag.
    fn default() -> Self {
        let tagged = U256::from(SessionFeType::OprfSeed as u8) << 248_usize;
        Self {
            commitment: FieldElement::ZERO,
            oprf_seed: tagged.try_into().expect("always fits in the field"),
        }
    }
}
impl Serialize for SessionId {
    /// Writes the compressed 64-byte form. Human-readable formats get the bytes as a
    /// `session_`-prefixed hex string; other formats get them as raw bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let compressed = self.to_compressed_bytes();
        if !serializer.is_human_readable() {
            return serializer.serialize_bytes(&compressed);
        }
        let encoded = format!("{}{}", Self::JSON_PREFIX, hex::encode(compressed));
        serializer.serialize_str(&encoded)
    }
}
impl<'de> Deserialize<'de> for SessionId {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let bytes = if deserializer.is_human_readable() {
let value = String::deserialize(deserializer)?;
let hex_str = value.strip_prefix(Self::JSON_PREFIX).ok_or_else(|| {
D::Error::custom(format!(
"session id must start with '{}'",
Self::JSON_PREFIX
))
})?;
hex::decode(hex_str).map_err(D::Error::custom)?
} else {
Vec::deserialize(deserializer)?
};
Self::from_compressed_bytes(&bytes).map_err(D::Error::custom)
}
}
/// A nullifier paired with the action it was produced for.
///
/// The action must carry the [`SessionFeType::Action`] tag in its most significant byte.
/// Both fields are private, so values can only be built through the validating
/// constructors (or `Default`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionNullifier {
    /// The nullifier value; not validated by this type.
    nullifier: FieldElement,
    /// The action, tagged with [`SessionFeType::Action`].
    action: FieldElement,
}
impl SessionNullifier {
    /// Prefix of the human-readable encoding; the hex of the compressed bytes follows it.
    const JSON_PREFIX: &str = "snil_";

    /// Creates a session nullifier from a nullifier and an action.
    ///
    /// # Errors
    /// Returns [`PrimitiveError::InvalidInput`] if `action` does not carry the
    /// [`SessionFeType::Action`] tag.
    pub fn new(nullifier: FieldElement, action: FieldElement) -> Result<Self, PrimitiveError> {
        if !action.is_valid_for_session(SessionFeType::Action) {
            return Err(PrimitiveError::InvalidInput {
                attribute: "session_nullifier".to_string(),
                reason: "inner action is not valid".to_string(),
            });
        }
        Ok(Self { nullifier, action })
    }

    /// The nullifier value.
    #[must_use]
    pub const fn nullifier(&self) -> FieldElement {
        self.nullifier
    }

    /// The tagged action.
    #[must_use]
    pub const fn action(&self) -> FieldElement {
        self.action
    }

    /// Returns `[nullifier, action]` as `U256` words.
    #[must_use]
    pub fn as_ethereum_representation(&self) -> [U256; 2] {
        [self.nullifier.into(), self.action.into()]
    }

    /// Parses `[nullifier, action]` from `U256` words, the inverse of
    /// [`SessionNullifier::as_ethereum_representation`].
    ///
    /// # Errors
    /// Returns a message if either word is not a valid field element or if the action
    /// lacks the [`SessionFeType::Action`] tag.
    pub fn from_ethereum_representation(value: [U256; 2]) -> Result<Self, String> {
        let [nullifier_word, action_word] = value;
        let nullifier =
            FieldElement::try_from(nullifier_word).map_err(|e| format!("invalid nullifier: {e}"))?;
        let action =
            FieldElement::try_from(action_word).map_err(|e| format!("invalid action: {e}"))?;
        if !action.is_valid_for_session(SessionFeType::Action) {
            return Err("inner action is not valid".to_string());
        }
        Ok(Self { nullifier, action })
    }

    /// Encodes as 64 bytes: big-endian nullifier followed by big-endian action.
    #[must_use]
    pub fn to_compressed_bytes(&self) -> [u8; 64] {
        let mut bytes = [0u8; 64];
        bytes[..32].copy_from_slice(&self.nullifier.to_be_bytes());
        bytes[32..].copy_from_slice(&self.action.to_be_bytes());
        bytes
    }

    /// Decodes the 64-byte layout produced by [`SessionNullifier::to_compressed_bytes`].
    ///
    /// # Errors
    /// Returns a message if `bytes` is not exactly 64 bytes long, if either half is not a
    /// valid field element, or if the action lacks the [`SessionFeType::Action`] tag.
    pub fn from_compressed_bytes(bytes: &[u8]) -> Result<Self, String> {
        if bytes.len() != 64 {
            return Err(format!(
                "Invalid length: expected 64 bytes, got {}",
                bytes.len()
            ));
        }
        let (nullifier_bytes, action_bytes) = bytes.split_at(32);
        let nullifier = FieldElement::from_be_bytes(
            nullifier_bytes
                .try_into()
                .expect("length checked above, so the first half is 32 bytes"),
        )
        .map_err(|e| format!("invalid nullifier: {e}"))?;
        let action = FieldElement::from_be_bytes(
            action_bytes
                .try_into()
                .expect("length checked above, so the second half is 32 bytes"),
        )
        .map_err(|e| format!("invalid action: {e}"))?;
        if !action.is_valid_for_session(SessionFeType::Action) {
            return Err("invalid action. missing expected prefix.".to_string());
        }
        Ok(Self { nullifier, action })
    }
}
impl Default for SessionNullifier {
    /// Returns a zero nullifier paired with an action whose only non-zero byte is
    /// the leading [`SessionFeType::Action`] tag.
    fn default() -> Self {
        let tagged = U256::from(SessionFeType::Action as u8) << 248_usize;
        Self {
            nullifier: FieldElement::ZERO,
            action: tagged.try_into().expect("always fits in the field"),
        }
    }
}
impl Serialize for SessionNullifier {
    /// Writes the compressed 64-byte form. Human-readable formats get the bytes as a
    /// `snil_`-prefixed hex string; other formats get them as raw bytes.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let compressed = self.to_compressed_bytes();
        if !serializer.is_human_readable() {
            return serializer.serialize_bytes(&compressed);
        }
        let encoded = format!("{}{}", Self::JSON_PREFIX, hex::encode(compressed));
        serializer.serialize_str(&encoded)
    }
}
impl<'de> Deserialize<'de> for SessionNullifier {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let bytes = if deserializer.is_human_readable() {
let value = String::deserialize(deserializer)?;
let hex_str = value.strip_prefix(Self::JSON_PREFIX).ok_or_else(|| {
D::Error::custom(format!(
"session nullifier must start with '{}'",
Self::JSON_PREFIX
))
})?;
hex::decode(hex_str).map_err(D::Error::custom)?
} else {
Vec::deserialize(deserializer)?
};
Self::from_compressed_bytes(&bytes).map_err(D::Error::custom)
}
}
impl From<SessionNullifier> for [U256; 2] {
fn from(value: SessionNullifier) -> Self {
value.as_ethereum_representation()
}
}
#[cfg(test)]
mod session_id_tests {
    use super::*;
    use ruint::uint;

    /// Plain field element from a small integer (no session tag).
    fn test_field_element(value: u64) -> FieldElement {
        FieldElement::from(value)
    }

    /// Field element carrying the `SessionFeType::OprfSeed` (0x01) tag byte.
    fn test_oprf_seed(value: u64) -> FieldElement {
        let n = U256::from(value)
            | uint!(0x0100000000000000000000000000000000000000000000000000000000000000_U256);
        FieldElement::try_from(n).expect("test value fits in field")
    }

    #[test]
    fn test_new_and_accessors() {
        let commitment = test_field_element(1001);
        let seed = test_oprf_seed(42);
        let id = SessionId::new(commitment, seed).unwrap();
        assert_eq!(id.commitment, commitment);
        assert_eq!(id.oprf_seed, seed);
    }

    #[test]
    fn test_new_rejects_untagged_oprf_seed() {
        let result = SessionId::new(test_field_element(1), test_field_element(42));
        assert!(
            matches!(result, Err(PrimitiveError::InvalidInput { .. })),
            "expected InvalidInput, got {result:?}"
        );
    }

    #[test]
    fn test_new_rejects_action_tagged_seed() {
        let action_tagged = U256::from(42u64)
            | uint!(0x0200000000000000000000000000000000000000000000000000000000000000_U256);
        let seed = FieldElement::try_from(action_tagged).expect("test value fits in field");
        assert!(SessionId::new(test_field_element(1), seed).is_err());
    }

    #[test]
    fn test_default() {
        let id = SessionId::default();
        assert_eq!(id.commitment, FieldElement::ZERO);
        assert_eq!(
            id.oprf_seed,
            uint!(0x0100000000000000000000000000000000000000000000000000000000000000_U256)
                .try_into()
                .unwrap()
        );
    }

    #[test]
    fn test_bytes_roundtrip() {
        let id = SessionId::new(test_field_element(1001), test_oprf_seed(42)).unwrap();
        let bytes = id.to_compressed_bytes();
        assert_eq!(bytes.len(), 64);
        let decoded = SessionId::from_compressed_bytes(&bytes).unwrap();
        assert_eq!(id, decoded);
    }

    #[test]
    fn test_bytes_use_field_element_encoding() {
        let id = SessionId::new(test_field_element(1001), test_oprf_seed(42)).unwrap();
        let bytes = id.to_compressed_bytes();
        let mut expected = [0u8; 64];
        expected[..32].copy_from_slice(&id.commitment.to_be_bytes());
        expected[32..].copy_from_slice(&id.oprf_seed.to_be_bytes());
        assert_eq!(bytes, expected);
    }

    #[test]
    fn test_invalid_bytes_length() {
        let too_short = vec![0u8; 63];
        let result = SessionId::from_compressed_bytes(&too_short);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));
        let too_long = vec![0u8; 65];
        let result = SessionId::from_compressed_bytes(&too_long);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));
    }

    #[test]
    fn test_from_compressed_bytes_rejects_wrong_oprf_seed_prefix() {
        let mut bytes = [0u8; 64];
        bytes[32] = 0x00;
        let result = SessionId::from_compressed_bytes(&bytes);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().contains("invalid prefix"),
            "should reject oprf_seed without 0x01 prefix"
        );
    }

    #[test]
    fn test_json_roundtrip() {
        let id = SessionId::new(test_field_element(1001), test_oprf_seed(42)).unwrap();
        let json = serde_json::to_string(&id).unwrap();
        assert!(json.starts_with("\"session_"));
        assert!(json.ends_with('"'));
        let decoded: SessionId = serde_json::from_str(&json).unwrap();
        assert_eq!(id, decoded);
    }

    #[test]
    fn test_json_format() {
        let id = SessionId::new(test_field_element(1), test_oprf_seed(2)).unwrap();
        let json = serde_json::to_string(&id).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(parsed.is_string());
        let value = parsed.as_str().unwrap();
        assert!(value.starts_with("session_"));
    }

    #[test]
    fn test_json_wrong_prefix_rejected() {
        let result = serde_json::from_str::<SessionId>("\"snil_00\"");
        assert!(result.is_err());
    }

    #[test]
    fn test_json_invalid_hex_rejected() {
        let result = serde_json::from_str::<SessionId>("\"session_zz\"");
        assert!(result.is_err());
    }

    #[test]
    fn test_generates_random_oprf_seed() {
        let mut rng = rand::rngs::OsRng;
        let seed_1 = SessionId::generate_oprf_seed(&mut rng);
        let seed_2 = SessionId::generate_oprf_seed(&mut rng);
        assert_ne!(seed_1, seed_2);
    }

    #[test]
    fn test_generated_oprf_seed_has_session_prefix() {
        let mut rng = rand::rngs::OsRng;
        for _ in 0..50 {
            let seed = SessionId::generate_oprf_seed(&mut rng);
            assert_eq!(seed.to_u256() >> 248, U256::from(1));
            assert!(seed.is_valid_for_session(SessionFeType::OprfSeed));
        }
    }

    #[test]
    fn test_from_r_seed_rejects_untagged_oprf_seed() {
        let result = SessionId::from_r_seed(42, test_field_element(123), test_field_element(456));
        assert!(
            matches!(result, Err(PrimitiveError::InvalidInput { .. })),
            "expected InvalidInput, got {result:?}"
        );
    }

    #[test]
    fn test_from_r_seed_commitment_snapshot() {
        let leaf_index = 42u64;
        let r_seed = test_field_element(123);
        let oprf_seed = test_oprf_seed(456);
        let session_id = SessionId::from_r_seed(leaf_index, r_seed, oprf_seed).unwrap();
        let expected = "0x1e7853ebd4fc9d9f0232fdcfae116023610bdf66a22e2700445d7a2e0e7e6152"
            .parse::<U256>()
            .unwrap();
        assert_eq!(
            session_id.commitment.to_u256(),
            expected,
            "commitment snapshot for session commitment changed"
        );
        assert_eq!(session_id.oprf_seed, oprf_seed);
    }
}
#[cfg(test)]
mod session_nullifier_tests {
    use super::*;
    use ruint::uint;

    /// Plain field element from a small integer (no session tag).
    fn test_field_element(value: u64) -> FieldElement {
        FieldElement::from(value)
    }

    /// Field element carrying the `SessionFeType::Action` (0x02) tag byte.
    fn test_action(value: u64) -> FieldElement {
        let n = U256::from(value)
            | uint!(0x0200000000000000000000000000000000000000000000000000000000000000_U256);
        FieldElement::try_from(n).expect("test value fits in field")
    }

    #[test]
    fn test_new_and_accessors() {
        let nullifier = test_field_element(1001);
        let action = test_action(42);
        let session = SessionNullifier::new(nullifier, action).unwrap();
        assert_eq!(session.nullifier(), nullifier);
        assert_eq!(session.action(), action);
    }

    #[test]
    fn test_as_ethereum_representation() {
        let nullifier = test_field_element(100);
        let action = test_action(200);
        let session = SessionNullifier::new(nullifier, action).unwrap();
        let repr = session.as_ethereum_representation();
        assert_eq!(repr[0], U256::from(100));
        assert_eq!(repr[1], action.to_u256());
    }

    #[test]
    fn test_from_ethereum_representation() {
        let action = test_action(200);
        let repr = [U256::from(100), action.to_u256()];
        let session = SessionNullifier::from_ethereum_representation(repr).unwrap();
        assert_eq!(session.nullifier(), test_field_element(100));
        assert_eq!(session.action(), action);
    }

    #[test]
    fn test_ethereum_representation_roundtrip() {
        let session = SessionNullifier::new(test_field_element(7), test_action(9)).unwrap();
        let decoded =
            SessionNullifier::from_ethereum_representation(session.as_ethereum_representation())
                .unwrap();
        assert_eq!(session, decoded);
    }

    #[test]
    fn test_json_roundtrip() {
        let session = SessionNullifier::new(test_field_element(1001), test_action(42)).unwrap();
        let json = serde_json::to_string(&session).unwrap();
        assert!(json.starts_with("\"snil_"));
        assert!(json.ends_with('"'));
        let decoded: SessionNullifier = serde_json::from_str(&json).unwrap();
        assert_eq!(session, decoded);
    }

    #[test]
    fn test_json_format() {
        let session = SessionNullifier::new(test_field_element(1), test_action(2)).unwrap();
        let json = serde_json::to_string(&session).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert!(parsed.is_string());
        let value = parsed.as_str().unwrap();
        assert!(value.starts_with("snil_"));
    }

    #[test]
    fn test_bytes_roundtrip() {
        let session = SessionNullifier::new(test_field_element(1001), test_action(42)).unwrap();
        let bytes = session.to_compressed_bytes();
        assert_eq!(bytes.len(), 64);
        let decoded = SessionNullifier::from_compressed_bytes(&bytes).unwrap();
        assert_eq!(session, decoded);
    }

    #[test]
    fn test_bytes_use_field_element_encoding() {
        let session = SessionNullifier::new(test_field_element(1001), test_action(42)).unwrap();
        let bytes = session.to_compressed_bytes();
        let mut expected = [0u8; 64];
        expected[..32].copy_from_slice(&session.nullifier().to_be_bytes());
        expected[32..].copy_from_slice(&session.action().to_be_bytes());
        assert_eq!(bytes, expected);
    }

    #[test]
    fn test_invalid_bytes_length() {
        let too_short = vec![0u8; 63];
        let result = SessionNullifier::from_compressed_bytes(&too_short);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));
        let too_long = vec![0u8; 65];
        let result = SessionNullifier::from_compressed_bytes(&too_long);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("Invalid length"));
    }

    #[test]
    fn test_default() {
        let session = SessionNullifier::default();
        assert_eq!(session.nullifier(), FieldElement::ZERO);
        let expected_action: FieldElement =
            uint!(0x0200000000000000000000000000000000000000000000000000000000000000_U256)
                .try_into()
                .unwrap();
        assert_eq!(session.action(), expected_action);
    }

    #[test]
    fn test_into_u256_array() {
        let action = test_action(200);
        let session = SessionNullifier::new(test_field_element(100), action).unwrap();
        let arr: [U256; 2] = session.into();
        assert_eq!(arr[0], U256::from(100));
        assert_eq!(arr[1], action.to_u256());
    }

    #[test]
    fn test_new_rejects_invalid_action_prefix() {
        let nullifier = test_field_element(1);
        // No 0x02 tag byte.
        let bad_action = test_field_element(42);
        let result = SessionNullifier::new(nullifier, bad_action);
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, PrimitiveError::InvalidInput { .. }),
            "expected InvalidInput, got {err:?}"
        );
    }

    #[test]
    fn test_new_rejects_oprf_seed_prefix_as_action() {
        let nullifier = test_field_element(1);
        let oprf_prefixed = U256::from(42u64)
            | uint!(0x0100000000000000000000000000000000000000000000000000000000000000_U256);
        let bad_action = FieldElement::try_from(oprf_prefixed).unwrap();
        assert!(SessionNullifier::new(nullifier, bad_action).is_err());
    }

    #[test]
    fn test_random_action_is_valid_for_session() {
        let mut rng = rand::rngs::OsRng;
        for _ in 0..50 {
            let action = FieldElement::random_for_session(&mut rng, SessionFeType::Action);
            assert!(action.is_valid_for_session(SessionFeType::Action));
            assert!(!action.is_valid_for_session(SessionFeType::OprfSeed));
            assert!(SessionNullifier::new(test_field_element(1), action).is_ok());
        }
    }

    #[test]
    fn test_from_ethereum_representation_rejects_invalid_action() {
        // The action word has no 0x02 tag byte.
        let repr = [U256::from(100), U256::from(200)];
        let result = SessionNullifier::from_ethereum_representation(repr);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().contains("action"),
            "error should mention the action"
        );
    }

    #[test]
    fn test_from_compressed_bytes_rejects_invalid_action_prefix() {
        let mut bytes = [0u8; 64];
        bytes[32] = 0x00;
        let result = SessionNullifier::from_compressed_bytes(&bytes);
        assert!(result.is_err());
        assert!(
            result.unwrap_err().contains("action"),
            "error should mention the action"
        );
    }

    #[test]
    fn test_json_rejects_invalid_action_prefix() {
        let nullifier = test_field_element(1);
        // No 0x02 tag byte.
        let bad_action = test_field_element(2);
        let mut bytes = [0u8; 64];
        bytes[..32].copy_from_slice(&nullifier.to_be_bytes());
        bytes[32..].copy_from_slice(&bad_action.to_be_bytes());
        let json = format!("\"snil_{}\"", hex::encode(bytes));
        let result = serde_json::from_str::<SessionNullifier>(&json);
        assert!(result.is_err());
    }
}