mls_rs_core/group/
context.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// Copyright by contributors to this project.
3// SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5use crate::{crypto::CipherSuite, extension::ExtensionList, protocol_version::ProtocolVersion};
6use alloc::{vec, vec::Vec};
7use core::{
8    fmt::{self, Debug},
9    ops::Deref,
10};
11use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
12
13#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
14#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
15#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
16pub struct ConfirmedTranscriptHash(
17    #[mls_codec(with = "mls_rs_codec::byte_vec")]
18    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
19    Vec<u8>,
20);
21
22impl Debug for ConfirmedTranscriptHash {
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        crate::debug::pretty_bytes(&self.0)
25            .named("ConfirmedTranscriptHash")
26            .fmt(f)
27    }
28}
29
30impl Deref for ConfirmedTranscriptHash {
31    type Target = Vec<u8>;
32
33    fn deref(&self) -> &Self::Target {
34        &self.0
35    }
36}
37
38impl From<Vec<u8>> for ConfirmedTranscriptHash {
39    fn from(value: Vec<u8>) -> Self {
40        Self(value)
41    }
42}
43
44#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
45#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
46#[cfg_attr(
47    all(feature = "ffi", not(test)),
48    safer_ffi_gen::ffi_type(clone, opaque)
49)]
50#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
51pub struct GroupContext {
52    pub protocol_version: ProtocolVersion,
53    pub cipher_suite: CipherSuite,
54    #[mls_codec(with = "mls_rs_codec::byte_vec")]
55    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
56    pub group_id: Vec<u8>,
57    pub epoch: u64,
58    #[cfg_attr(feature = "serde", serde(with = "crate::vec_serde"))]
59    #[mls_codec(with = "mls_rs_codec::byte_vec")]
60    pub tree_hash: Vec<u8>,
61    pub confirmed_transcript_hash: ConfirmedTranscriptHash,
62    pub extensions: ExtensionList,
63}
64
65impl Debug for GroupContext {
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        f.debug_struct("GroupContext")
68            .field("protocol_version", &self.protocol_version)
69            .field("cipher_suite", &self.cipher_suite)
70            .field("group_id", &crate::debug::pretty_group_id(&self.group_id))
71            .field("epoch", &self.epoch)
72            .field("tree_hash", &crate::debug::pretty_bytes(&self.tree_hash))
73            .field("confirmed_transcript_hash", &self.confirmed_transcript_hash)
74            .field("extensions", &self.extensions)
75            .finish()
76    }
77}
78
79#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
80impl GroupContext {
81    /// Create a group context for a new MLS group.
82    pub fn new(
83        protocol_version: ProtocolVersion,
84        cipher_suite: CipherSuite,
85        group_id: Vec<u8>,
86        tree_hash: Vec<u8>,
87        extensions: ExtensionList,
88    ) -> GroupContext {
89        GroupContext {
90            protocol_version,
91            cipher_suite,
92            group_id,
93            epoch: 0,
94            tree_hash,
95            confirmed_transcript_hash: vec![].into(),
96            extensions,
97        }
98    }
99
100    /// Get the current protocol version in use by the group.
101    pub fn version(&self) -> ProtocolVersion {
102        self.protocol_version
103    }
104
105    /// Get the current cipher suite in use by the group.
106    pub fn cipher_suite(&self) -> CipherSuite {
107        self.cipher_suite
108    }
109
110    /// Get the unique identifier of this group.
111    pub fn group_id(&self) -> &[u8] {
112        &self.group_id
113    }
114
115    /// Get the current epoch number of the group's state.
116    pub fn epoch(&self) -> u64 {
117        self.epoch
118    }
119
120    pub fn extensions(&self) -> &ExtensionList {
121        &self.extensions
122    }
123}