Skip to main content

primitives/types/identifiers/session_id.rs

1use aes::cipher::generic_array::GenericArray;
2use derive_more::derive::{AsMut, AsRef, IntoIterator};
3use hybrid_array::Array;
4#[cfg(any(test, feature = "dev"))]
5use rand::{distributions::Standard, prelude::Distribution, Rng};
6use serde::{Deserialize, Serialize};
7#[cfg(any(test, feature = "dev"))]
8use typenum::Unsigned;
9use typenum::U16;
10
11use crate::{
12    constants::CollisionResistanceBytes,
13    hashing::{self, Digest},
14    random::Seed,
15    transcripts::Transcript,
16};
17
18/// The type of a session identifier, commonly used by protocols to achieve UC
19/// security in the CRS model. It should be unique for each protocol execution.
20/// We make it be random by:
21/// - Sampling an original Session ID via a distributed protocol (or via local sampling in tests).
22/// - Refreshing it upon each protocol execution by mixig it with the protocol transcript.
23#[derive(Default, Copy, AsRef, AsMut, Clone, Serialize, Deserialize, PartialEq, IntoIterator)]
24#[into_iterator(owned, ref, ref_mut)]
25pub struct SessionId(Array<u8, CollisionResistanceBytes>);
26
27impl SessionId {
28    /// Refreshes the session ID.
29    pub fn refresh_with<T: AsRef<[u8]> + ?Sized>(&mut self, tag: &T) {
30        self.0 = hashing::hash(&[self.as_ref(), tag.as_ref()]);
31    }
32
33    /// Refreshes the session ID by extracting randomness from a given transcript.
34    pub fn refresh_from<T: Transcript>(transcript: &mut T) -> SessionId {
35        SessionId(transcript.extract(b"new_session_id").into())
36    }
37}
38
39// ---------- Conversions ----------- //
40
41impl AsRef<[u8]> for SessionId {
42    fn as_ref(&self) -> &[u8] {
43        &self.0
44    }
45}
46
47impl std::ops::Deref for SessionId {
48    type Target = [u8; 32];
49
50    fn deref(&self) -> &[u8; 32] {
51        self.0.as_ref()
52    }
53}
54
55impl From<Digest> for SessionId {
56    fn from(value: Digest) -> Self {
57        SessionId(value)
58    }
59}
60
61impl From<SessionId> for [u8; 32] {
62    fn from(session_id: SessionId) -> [u8; 32] {
63        session_id.0.into()
64    }
65}
66
67impl<'sid> From<&'sid SessionId> for &'sid [u8; 32] {
68    fn from(session_id: &'sid SessionId) -> &'sid [u8; 32] {
69        (&session_id.0).into()
70    }
71}
72
73impl From<&SessionId> for u32 {
74    fn from(session_id: &SessionId) -> u32 {
75        u32::from_le_bytes(session_id.0[0..4].try_into().unwrap())
76    }
77}
78
79impl From<&SessionId> for [u8; 16] {
80    fn from(session_id: &SessionId) -> [u8; 16] {
81        let mut hash = [0; 16];
82        hashing::hash_into([session_id], &mut hash);
83        hash
84    }
85}
86
87impl From<&SessionId> for GenericArray<u8, U16> {
88    fn from(session_id: &SessionId) -> GenericArray<u8, U16> {
89        let mut hash = GenericArray::<u8, U16>::default();
90        hashing::hash_into([session_id], &mut hash);
91        hash
92    }
93}
94
95// ------ Generation ------ //
96
97/// Trait to gate the SessionId generation to:
98/// - Random sampling or hashing in dev/tests.
99/// - Running `drand` in production.
100pub trait SessionIdGenerator {
101    /// Generates a new session ID from a given seed.
102    fn generate_session_id_from(&self, seed: Seed) -> SessionId {
103        SessionId(seed.into())
104    }
105}
106
#[cfg(any(test, feature = "dev"))]
impl Distribution<SessionId> for Standard {
    /// Samples a session ID whose every byte is drawn from `rng`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SessionId {
        // Start from the all-zero ID, then overwrite the whole buffer with randomness.
        let mut session_id = SessionId::default();
        rng.fill_bytes(&mut session_id.0);
        session_id
    }
}
115
#[cfg(any(test, feature = "dev"))]
impl SessionId {
    /// Deterministically derives a session ID by hashing `seed` into the ID's byte buffer.
    ///
    /// Only available in tests or with the `dev` feature.
    pub fn from_hashed_seed(seed: &[u8]) -> SessionId {
        let mut session_id = SessionId::default();
        hashing::hash_into([seed], &mut session_id.0);
        session_id
    }
}
125
126// --------- Display -------- //
127
128#[cfg(not(any(test, feature = "dev")))]
129impl std::fmt::Display for SessionId {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        write!(f, "SessionId({})", hex::encode(self.0))
132    }
133}
134
135#[cfg(not(any(test, feature = "dev")))]
136impl std::fmt::Debug for SessionId {
137    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138        write!(f, "SessionId({})", hex::encode(self.0))
139    }
140}
141
#[cfg(any(test, feature = "dev"))]
impl std::fmt::Display for SessionId {
    /// Writes a short form for dev/test logs: the first 3 bytes (6 hex digits)
    /// followed by `...`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Hex uses two digits per byte, so encoding just the first 3 bytes yields the
        // same 6 characters as truncating the full encoding.
        let prefix = hex::encode(&self.0[..3]);
        write!(f, "SessionId({}...)", prefix)
    }
}
148
#[cfg(any(test, feature = "dev"))]
impl std::fmt::Debug for SessionId {
    /// Debug output is identical to the dev `Display`: a truncated hex prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}