use crate::component::{ComponentData, ComponentId};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use tls_codec::{TlsDeserialize, TlsDeserializeBytes, TlsSerialize, TlsSize, VLBytes};
/// Reasons a sequence of [`ComponentData`] entries cannot be turned into an
/// [`AppDataDictionary`].
///
/// Returned by `AppDataDictionary::try_from_data`. The TLS deserializer embeds
/// the `Display` text of the variant in a `tls_codec::Error::DecodingError`.
#[derive(thiserror::Error, Debug)]
enum BuildAppDataDictionaryError {
/// An entry's component id is smaller than the largest id seen before it.
#[error("entries not in order")]
EntriesNotInOrder,
/// An entry's component id already appeared earlier in the sequence.
#[error("duplicate entries")]
DuplicateEntries,
}
/// Dictionary holding at most one [`ComponentData`] entry per [`ComponentId`].
#[derive(PartialEq, Eq, Clone, Debug, Default, Serialize, Deserialize)]
pub struct AppDataDictionary {
/// Entries keyed by their component id. A `BTreeMap` iterates in ascending
/// key order, so the entries are always visited sorted by component id.
component_data: BTreeMap<ComponentId, ComponentData>,
}
impl AppDataDictionary {
    /// Creates an empty dictionary.
    pub fn new() -> Self {
        Self {
            component_data: BTreeMap::new(),
        }
    }

    /// Iterates over the stored entries in ascending component id order.
    pub fn entries(&self) -> impl Iterator<Item = &ComponentData> {
        self.component_data.values()
    }

    /// Consumes the dictionary and returns its entries in ascending component
    /// id order.
    pub fn to_entries(self) -> Vec<ComponentData> {
        // `self` is owned here, so the values can be moved out of the map
        // instead of being cloned one by one.
        self.component_data.into_values().collect()
    }

    /// Returns the number of entries in the dictionary.
    pub fn len(&self) -> usize {
        self.component_data.len()
    }

    /// Returns `true` if the dictionary holds no entries.
    pub fn is_empty(&self) -> bool {
        self.component_data.is_empty()
    }

    /// Returns the data stored for `component_id`, or `None` if there is no
    /// entry for that id.
    pub fn get(&self, component_id: &ComponentId) -> Option<&[u8]> {
        self.component_data
            .get(component_id)
            .map(|component_data| component_data.data())
    }

    /// Stores `data` under `component_id`.
    ///
    /// Returns the data that was previously stored under the same id, or
    /// `None` if the id was not present.
    pub fn insert(&mut self, component_id: ComponentId, data: Vec<u8>) -> Option<VLBytes> {
        self.component_data
            .insert(
                component_id,
                ComponentData::from_parts(component_id, data.into()),
            )
            .map(|component_data| component_data.into_data())
    }

    /// Returns `true` if an entry exists for `component_id`.
    pub fn contains(&self, component_id: &ComponentId) -> bool {
        self.component_data.contains_key(component_id)
    }

    /// Removes the entry for `component_id`, returning its data if it existed.
    pub fn remove(&mut self, component_id: &ComponentId) -> Option<VLBytes> {
        self.component_data
            .remove(component_id)
            .map(|component_data| component_data.into_data())
    }

    /// Builds a dictionary from entries that must arrive in strictly
    /// ascending component id order.
    ///
    /// # Errors
    ///
    /// * [`BuildAppDataDictionaryError::DuplicateEntries`] if an id appears
    ///   more than once.
    /// * [`BuildAppDataDictionaryError::EntriesNotInOrder`] if a new id is
    ///   smaller than the largest id seen so far.
    fn try_from_data(
        data: impl IntoIterator<Item = ComponentData>,
    ) -> Result<Self, BuildAppDataDictionaryError> {
        let mut map = BTreeMap::<ComponentId, ComponentData>::new();
        for component_data in data {
            let (component_id, data) = component_data.into_parts();

            // The duplicate check runs first, so a repeated id is reported as
            // a duplicate even when it is also below the current maximum.
            if map.contains_key(&component_id) {
                return Err(BuildAppDataDictionaryError::DuplicateEntries);
            }

            // The map's last key is the largest id accepted so far; anything
            // smaller than it arrived out of order.
            if let Some((max, _)) = map.last_key_value() {
                if *max > component_id {
                    return Err(BuildAppDataDictionaryError::EntriesNotInOrder);
                }
            }

            let _ = map.insert(component_id, ComponentData::from_parts(component_id, data));
        }
        Ok(Self {
            component_data: map,
        })
    }
}
impl tls_codec::Size for AppDataDictionary {
    /// Returns the encoded length of the dictionary, computed as the encoded
    /// length of a vector holding references to its entries.
    fn tls_serialized_len(&self) -> usize {
        let entries = self.entries().collect::<Vec<&ComponentData>>();
        tls_codec::Size::tls_serialized_len(&entries)
    }
}
impl tls_codec::Serialize for AppDataDictionary {
    /// Writes the dictionary to `writer` as a vector of its entries, in
    /// ascending component id order, and returns the number of bytes written.
    fn tls_serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<usize, tls_codec::Error> {
        let entries = self.entries().collect::<Vec<&ComponentData>>();
        tls_codec::Serialize::tls_serialize(&entries, writer)
    }
}
impl tls_codec::Deserialize for AppDataDictionary {
fn tls_deserialize<R: std::io::Read>(bytes: &mut R) -> Result<Self, tls_codec::Error> {
let data = Vec::<ComponentData>::tls_deserialize(bytes)?;
AppDataDictionary::try_from_data(data)
.map_err(|e| tls_codec::Error::DecodingError(e.to_string()))
}
}
impl tls_codec::DeserializeBytes for AppDataDictionary {
    /// Decodes a dictionary from the front of `bytes` and returns it together
    /// with the bytes that were not consumed.
    ///
    /// # Errors
    ///
    /// Returns whatever error the reader-based `tls_deserialize`
    /// implementation for this type produces.
    fn tls_deserialize_bytes(bytes: &[u8]) -> Result<(Self, &[u8]), tls_codec::Error> {
        // Reading from a `&[u8]` advances the slice past the consumed bytes,
        // so after decoding `remainder` is exactly the unread tail. This
        // avoids re-measuring the decoded value and indexing into `bytes`,
        // which would panic if the two lengths ever disagreed.
        let mut remainder = bytes;
        let dictionary = <Self as tls_codec::Deserialize>::tls_deserialize(&mut remainder)?;
        Ok((dictionary, remainder))
    }
}
/// Extension that carries an [`AppDataDictionary`].
///
/// The TLS codec traits are derived from the struct's single `dictionary`
/// field.
#[derive(
PartialEq,
Eq,
Clone,
Debug,
Default,
Serialize,
Deserialize,
TlsSerialize,
TlsDeserialize,
TlsDeserializeBytes,
TlsSize,
)]
pub struct AppDataDictionaryExtension {
/// The dictionary wrapped by this extension.
dictionary: AppDataDictionary,
}
impl AppDataDictionaryExtension {
pub fn dictionary(&self) -> &AppDataDictionary {
&self.dictionary
}
pub fn new(dictionary: AppDataDictionary) -> Self {
Self { dictionary }
}
}
// Unit tests for the TLS encoding of `AppDataDictionary` and
// `AppDataDictionaryExtension`.
#[cfg(test)]
mod test {
use super::*;
use tls_codec::{Deserialize, Serialize};
// Inserting twice under the same id keeps a single entry, and an extension
// survives a serialize/deserialize round trip.
#[openmls_test::openmls_test]
fn test_serialize_deserialize() {
let mut dictionary = AppDataDictionary::new();
let _ = dictionary.insert(0, vec![]);
// Same id as above: the dictionary must still hold exactly one entry.
let _ = dictionary.insert(0, vec![1, 2, 3]);
assert_eq!(dictionary.len(), 1);
let mut dictionary_orig = AppDataDictionary::new();
// Ids are inserted in descending order on purpose (5, then 0).
let _ = dictionary_orig.insert(5, vec![]);
let _ = dictionary_orig.insert(0, vec![1, 2, 3]);
assert_eq!(dictionary_orig.len(), 2);
let extension_orig = AppDataDictionaryExtension::new(dictionary_orig.clone());
let bytes = extension_orig.tls_serialize_detached().unwrap();
let extension_deserialized =
AppDataDictionaryExtension::tls_deserialize(&mut bytes.as_slice()).unwrap();
assert_eq!(extension_orig, extension_deserialized);
}
// An extension wrapping an empty dictionary survives a round trip.
#[openmls_test::openmls_test]
fn test_serialization_empty() {
let dictionary_orig = AppDataDictionary::new();
assert_eq!(dictionary_orig.len(), 0);
let extension_orig = AppDataDictionaryExtension::new(dictionary_orig.clone());
let bytes = extension_orig.tls_serialize_detached().unwrap();
let extension_deserialized =
AppDataDictionaryExtension::tls_deserialize(&mut bytes.as_slice()).unwrap();
assert_eq!(extension_orig, extension_deserialized);
}
// Encoded entry lists with a repeated id or with ids out of order are
// rejected with the matching decoding error.
#[openmls_test::openmls_test]
fn test_serialization_invalid() {
// Id 5 appears twice: expect the duplicate-entries error.
let component_data = vec![
ComponentData::from_parts(5, vec![].into()),
ComponentData::from_parts(5, vec![1, 2, 3].into()),
ComponentData::from_parts(9, vec![].into()),
];
let serialized = component_data.tls_serialize_detached().unwrap();
let err = AppDataDictionary::tls_deserialize_exact(serialized).unwrap_err();
assert_eq!(
err,
tls_codec::Error::DecodingError(
BuildAppDataDictionaryError::DuplicateEntries.to_string()
)
);
// Id 4 follows id 9: expect the not-in-order error.
let component_data = vec![
ComponentData::from_parts(5, vec![].into()),
ComponentData::from_parts(9, vec![].into()),
ComponentData::from_parts(4, vec![1, 2, 3].into()),
];
let serialized = component_data.tls_serialize_detached().unwrap();
let err = AppDataDictionary::tls_deserialize_exact(serialized).unwrap_err();
assert_eq!(
err,
tls_codec::Error::DecodingError(
BuildAppDataDictionaryError::EntriesNotInOrder.to_string()
)
);
}
}