// primitives/types/identifiers/session_id.rs

1use aes::cipher::generic_array::GenericArray;
2use derive_more::derive::{AsMut, AsRef, IntoIterator};
3use hybrid_array::Array;
4#[cfg(any(test, feature = "dev"))]
5use rand::{distributions::Standard, prelude::Distribution, Rng};
6use serde::{Deserialize, Serialize};
7#[cfg(any(test, feature = "dev"))]
8use typenum::Unsigned;
9use typenum::U16;
10
11use crate::{
12    constants::CollisionResistanceBytes,
13    hashing::{self, Digest},
14    random::{derive_rng::DeriveRng, Seed},
15    transcripts::Transcript,
16};
17
18/// The type of a session identifier, commonly used by protocols to achieve UC
19/// security in the CRS model. It should be unique for each protocol execution.
20/// We make it be random by:
21/// - Sampling an original Session ID via a distributed protocol (or via local sampling in tests).
22/// - Refreshing it upon each protocol execution by mixig it with the protocol transcript.
23#[derive(Default, Copy, AsRef, AsMut, Clone, Serialize, Deserialize, PartialEq, IntoIterator)]
24#[into_iterator(owned, ref, ref_mut)]
25pub struct SessionId(Array<u8, CollisionResistanceBytes>);
26
27impl SessionId {
28    /// Refreshes the session ID.
29    pub fn refresh_with<T: AsRef<[u8]> + ?Sized>(&mut self, tag: &T) {
30        self.0 = hashing::hash(&[self.as_ref(), tag.as_ref()]);
31    }
32
33    /// Refreshes the session ID by extracting randomness from a given transcript.
34    pub fn refresh_from<T: Transcript>(transcript: &mut T) -> SessionId {
35        SessionId(transcript.extract(b"new_session_id").into())
36    }
37}
38
39impl DeriveRng for SessionId {}
40
// ---------- Conversions ----------- //
42
43impl AsRef<[u8]> for SessionId {
44    fn as_ref(&self) -> &[u8] {
45        &self.0
46    }
47}
48
49impl From<Digest> for SessionId {
50    fn from(value: Digest) -> Self {
51        SessionId(value)
52    }
53}
54
55impl From<SessionId> for [u8; 32] {
56    fn from(session_id: SessionId) -> [u8; 32] {
57        session_id.0.into()
58    }
59}
60
61impl<'sid> From<&'sid SessionId> for &'sid [u8; 32] {
62    fn from(session_id: &'sid SessionId) -> &'sid [u8; 32] {
63        (&session_id.0).into()
64    }
65}
66
67impl From<&SessionId> for u32 {
68    fn from(session_id: &SessionId) -> u32 {
69        u32::from_le_bytes(session_id.0[0..4].try_into().unwrap())
70    }
71}
72
73impl From<&SessionId> for [u8; 16] {
74    fn from(session_id: &SessionId) -> [u8; 16] {
75        let mut hash = [0; 16];
76        hashing::hash_into([session_id], &mut hash);
77        hash
78    }
79}
80
81impl From<&SessionId> for GenericArray<u8, U16> {
82    fn from(session_id: &SessionId) -> GenericArray<u8, U16> {
83        let mut hash = GenericArray::<u8, U16>::default();
84        hashing::hash_into([session_id], &mut hash);
85        hash
86    }
87}
88
// ------ Generation ------ //
90
91/// Trait to gate the SessionId generation to:
92/// - Random sampling or hashing in dev/tests.
93/// - Running `drand` in production.
94pub trait SessionIdGenerator {
95    /// Generates a new session ID from a given seed.
96    fn generate_session_id_from(&self, seed: Seed) -> SessionId {
97        SessionId(seed.into())
98    }
99}
100
#[cfg(any(test, feature = "dev"))]
impl Distribution<SessionId> for Standard {
    /// Samples a session ID by filling every one of its bytes from `rng`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> SessionId {
        let mut buf = Array([0u8; CollisionResistanceBytes::USIZE]);
        rng.fill_bytes(&mut buf[..]);
        SessionId(buf)
    }
}
109
#[cfg(any(test, feature = "dev"))]
impl SessionId {
    /// Builds a session ID by hashing `seed` into a buffer of
    /// `CollisionResistanceBytes` bytes. Only available in tests / `dev` builds.
    pub fn from_hashed_seed(seed: &[u8]) -> SessionId {
        let mut digest = Array([0u8; CollisionResistanceBytes::USIZE]);
        hashing::hash_into([seed], &mut digest);
        Self(digest)
    }
}
119
// --------- Display -------- //
121
122#[cfg(not(any(test, feature = "dev")))]
123impl std::fmt::Display for SessionId {
124    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
125        write!(f, "SessionId({})", hex::encode(self.0))
126    }
127}
128
129#[cfg(not(any(test, feature = "dev")))]
130impl std::fmt::Debug for SessionId {
131    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132        write!(f, "SessionId({})", hex::encode(self.0))
133    }
134}
135
#[cfg(any(test, feature = "dev"))]
impl std::fmt::Display for SessionId {
    /// Writes an abbreviated form, `SessionId(<first 6 hex chars>...)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Six hex characters encode exactly the first three bytes, so only that
        // prefix is encoded instead of the whole ID.
        let prefix = hex::encode(&self.0[..3]);
        write!(f, "SessionId({}...)", prefix)
    }
}
142
#[cfg(any(test, feature = "dev"))]
impl std::fmt::Debug for SessionId {
    /// Renders exactly like `Display`: `SessionId(<first 6 hex chars>...)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(self, f)
    }
}