use data_size::DataSize;
use serde::{Deserialize, Deserializer};
use std::fmt;
use crate::{types::TypeInner, CandidType};
/// Sentinel value meaning "no limit" for any of the three `BoundedVec` bounds.
pub const UNBOUNDED: usize = usize::MAX;

/// A `Vec<T>` wrapper whose limits are carried as const generics:
///
/// * `MAX_ALLOWED_LEN` — maximum number of elements,
/// * `MAX_ALLOWED_TOTAL_DATA_SIZE` — maximum sum of the elements' data sizes,
/// * `MAX_ALLOWED_ELEMENT_DATA_SIZE` — maximum data size of any single element.
///
/// Any limit may be [`UNBOUNDED`]. The limits are checked when the value is
/// deserialized. Constructing a value directly does not check `data` against them.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct BoundedVec<
    const MAX_ALLOWED_LEN: usize,
    const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
    const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
    T,
>(Vec<T>);
/// On the wire a `BoundedVec<.., T>` is an ordinary candid `vec T`.
///
/// Both the type description and the serialization mirror `Vec<T>`.
impl<
        const MAX_ALLOWED_LEN: usize,
        const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
        const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
        T: CandidType,
    > CandidType
    for BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
{
    fn _ty() -> super::Type {
        // Describe the element through `T::ty()`, not the raw `T::_ty()` hook.
        // This is how candid describes `Vec<T>` elements: element types pass through
        // the trait's registration/recursion handling, and the declared type stays
        // identical to that of the `Vec<T>` used by `idl_serialize` below.
        TypeInner::Vec(T::ty()).into()
    }

    fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
    where
        S: super::Serializer,
    {
        // Serialize exactly as the wrapped vector would.
        self.0.idl_serialize(serializer)
    }
}
impl<
const MAX_ALLOWED_LEN: usize,
const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
T,
> BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
{
pub fn new(data: Vec<T>) -> Self {
assert!(
MAX_ALLOWED_LEN != UNBOUNDED
|| MAX_ALLOWED_TOTAL_DATA_SIZE != UNBOUNDED
|| MAX_ALLOWED_ELEMENT_DATA_SIZE != UNBOUNDED,
"BoundedVec must be bounded by at least one parameter."
);
Self(data)
}
pub fn get(&self) -> &Vec<T> {
&self.0
}
}
/// Deserializes a sequence while enforcing the three const-generic limits.
///
/// Decoding fails with a custom error as soon as:
/// * more than `MAX_ALLOWED_LEN` elements are present,
/// * one element's `data_size()` exceeds `MAX_ALLOWED_ELEMENT_DATA_SIZE`, or
/// * the running sum of element `data_size()`s exceeds `MAX_ALLOWED_TOTAL_DATA_SIZE`.
///
/// The running total counts only the elements, not the outer vector itself.
impl<
        'de,
        const MAX_ALLOWED_LEN: usize,
        const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
        const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
        T: Deserialize<'de> + DataSize,
    > Deserialize<'de>
    for BoundedVec<MAX_ALLOWED_LEN, MAX_ALLOWED_TOTAL_DATA_SIZE, MAX_ALLOWED_ELEMENT_DATA_SIZE, T>
{
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        // Stateless visitor; the limits travel as const generics.
        struct SeqVisitor<
            const MAX_ALLOWED_LEN: usize,
            const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
            const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
            T,
        > {
            _marker: std::marker::PhantomData<T>,
        }

        use serde::de::{SeqAccess, Visitor};

        impl<
                'de,
                const MAX_ALLOWED_LEN: usize,
                const MAX_ALLOWED_TOTAL_DATA_SIZE: usize,
                const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize,
                T: Deserialize<'de> + DataSize,
            > Visitor<'de>
            for SeqVisitor<
                MAX_ALLOWED_LEN,
                MAX_ALLOWED_TOTAL_DATA_SIZE,
                MAX_ALLOWED_ELEMENT_DATA_SIZE,
                T,
            >
        {
            type Value = BoundedVec<
                MAX_ALLOWED_LEN,
                MAX_ALLOWED_TOTAL_DATA_SIZE,
                MAX_ALLOWED_ELEMENT_DATA_SIZE,
                T,
            >;

            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                write!(
                    formatter,
                    "{}",
                    describe_sequence(
                        MAX_ALLOWED_LEN,
                        MAX_ALLOWED_TOTAL_DATA_SIZE,
                        MAX_ALLOWED_ELEMENT_DATA_SIZE,
                    )
                )
            }

            fn visit_seq<S>(self, mut seq: S) -> Result<Self::Value, S::Error>
            where
                S: SeqAccess<'de>,
            {
                // Reserve from the sequence's own size hint, clamped by the length
                // bound and by a fixed byte budget. Reserving `MAX_ALLOWED_LEN`
                // up front would allocate the worst case on every decode, even for
                // empty input, and the hint itself may come from untrusted bytes.
                const MAX_PREALLOC_BYTES: usize = 1024 * 1024;
                let max_prealloc = MAX_PREALLOC_BYTES / std::mem::size_of::<T>().max(1);
                let capacity = seq
                    .size_hint()
                    .unwrap_or(0)
                    .min(MAX_ALLOWED_LEN)
                    .min(max_prealloc);
                let mut elements = Vec::with_capacity(capacity);
                let mut total_data_size: usize = 0;

                while let Some(element) = seq.next_element::<T>()? {
                    // `elements` already holds MAX_ALLOWED_LEN items, so this one is
                    // one too many.
                    if elements.len() >= MAX_ALLOWED_LEN {
                        return Err(serde::de::Error::custom(format!(
                            "The number of elements exceeds maximum allowed {MAX_ALLOWED_LEN}"
                        )));
                    }
                    let new_element_data_size = element.data_size();
                    if new_element_data_size > MAX_ALLOWED_ELEMENT_DATA_SIZE {
                        return Err(serde::de::Error::custom(format!(
                            "The single element data size exceeds maximum allowed {MAX_ALLOWED_ELEMENT_DATA_SIZE}"
                        )));
                    }
                    // Saturate instead of overflowing. With UNBOUNDED == usize::MAX a
                    // saturated total still never exceeds the limit; with a real limit
                    // it correctly trips the check below.
                    let new_total_data_size = total_data_size.saturating_add(new_element_data_size);
                    if new_total_data_size > MAX_ALLOWED_TOTAL_DATA_SIZE {
                        return Err(serde::de::Error::custom(format!(
                            "The total data size exceeds maximum allowed {MAX_ALLOWED_TOTAL_DATA_SIZE}"
                        )));
                    }
                    total_data_size = new_total_data_size;
                    elements.push(element);
                }
                Ok(BoundedVec::new(elements))
            }
        }

        deserializer.deserialize_seq(SeqVisitor::<
            MAX_ALLOWED_LEN,
            MAX_ALLOWED_TOTAL_DATA_SIZE,
            MAX_ALLOWED_ELEMENT_DATA_SIZE,
            T,
        > {
            _marker: std::marker::PhantomData,
        })
    }
}
fn describe_sequence(
max_allowed_len: usize,
max_allowed_total_data_size: usize,
max_allowed_element_data_size: usize,
) -> String {
let mut msg = String::new();
if max_allowed_len != UNBOUNDED {
msg.push_str(&format!("max {max_allowed_len} elements"));
};
if max_allowed_total_data_size != UNBOUNDED {
if !msg.is_empty() {
msg.push_str(", ");
}
msg.push_str(&format!("max {max_allowed_total_data_size} bytes total"));
};
if max_allowed_element_data_size != UNBOUNDED {
if !msg.is_empty() {
msg.push_str(", ");
}
msg.push_str(&format!(
"max {max_allowed_element_data_size} bytes per element"
));
};
format!("a sequence with {msg}")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Decode, Encode};

    /// Fails the test unless the `Debug` rendering of `error` contains `expected`.
    fn assert_error_mentions<E>(error: E, expected: String)
    where
        E: std::fmt::Debug + std::fmt::Display,
    {
        let rendered = format!("{error:?}");
        assert!(rendered.contains(&expected), "Actual: {}", error);
    }

    #[test]
    fn test_describe_sequence() {
        let cases = [
            ((42, UNBOUNDED, UNBOUNDED), "a sequence with max 42 elements"),
            ((UNBOUNDED, 256, UNBOUNDED), "a sequence with max 256 bytes total"),
            ((UNBOUNDED, UNBOUNDED, 64), "a sequence with max 64 bytes per element"),
            (
                (42, 256, UNBOUNDED),
                "a sequence with max 42 elements, max 256 bytes total",
            ),
            (
                (42, UNBOUNDED, 64),
                "a sequence with max 42 elements, max 64 bytes per element",
            ),
            (
                (UNBOUNDED, 256, 64),
                "a sequence with max 256 bytes total, max 64 bytes per element",
            ),
            (
                (42, 256, 64),
                "a sequence with max 42 elements, max 256 bytes total, max 64 bytes per element",
            ),
        ];
        for ((len, total, per_element), expected) in cases {
            assert_eq!(
                describe_sequence(len, total, per_element),
                expected.to_string()
            );
        }
    }

    #[test]
    #[should_panic]
    fn test_not_bounded_vector_fails() {
        type Unbounded = BoundedVec<UNBOUNDED, UNBOUNDED, UNBOUNDED, u8>;
        let _ = Unbounded::new(vec![1, 2, 3]);
    }

    #[test]
    fn test_bounded_vector_lengths() {
        const MAX_ALLOWED_LEN: usize = 30;
        type BoundedLen = BoundedVec<MAX_ALLOWED_LEN, UNBOUNDED, UNBOUNDED, u8>;

        // Sweep lengths on both sides of the limit.
        for len in 20..=40 {
            let data = BoundedLen::new(vec![42; len]);
            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedLen);
            if len > MAX_ALLOWED_LEN {
                assert!(result.is_err());
                assert_error_mentions(
                    result.unwrap_err(),
                    format!(
                        "Deserialize error: The number of elements exceeds maximum allowed {MAX_ALLOWED_LEN}"
                    ),
                );
            } else {
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), data);
            }
        }
    }

    #[test]
    fn test_bounded_vector_total_data_sizes() {
        const MAX_ALLOWED_TOTAL_DATA_SIZE: usize = 100;
        const ELEMENT_SIZE: usize = 37;
        type BoundedSize = BoundedVec<UNBOUNDED, MAX_ALLOWED_TOTAL_DATA_SIZE, UNBOUNDED, Vec<u8>>;

        // The limit is deliberately not a whole multiple of the element size.
        assert_ne!(MAX_ALLOWED_TOTAL_DATA_SIZE % ELEMENT_SIZE, 0);
        // Payload length chosen so that each element's data size is ELEMENT_SIZE.
        let payload_len = ELEMENT_SIZE - std::mem::size_of::<Vec<u8>>();

        for aimed_total_size in 64..=256 {
            let element = vec![b'a'; payload_len];
            let count = aimed_total_size / element.data_size();
            let data = BoundedSize::new(vec![element; count]);
            let actual_total_size = data.get().data_size();
            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedSize);
            if actual_total_size > MAX_ALLOWED_TOTAL_DATA_SIZE {
                assert!(result.is_err());
                assert_error_mentions(
                    result.unwrap_err(),
                    format!(
                        "Deserialize error: The total data size exceeds maximum allowed {MAX_ALLOWED_TOTAL_DATA_SIZE}"
                    ),
                );
            } else {
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), data);
            }
        }
    }

    #[test]
    fn test_bounded_vector_element_data_sizes() {
        const MAX_ALLOWED_ELEMENT_DATA_SIZE: usize = 100;
        type BoundedSize = BoundedVec<UNBOUNDED, UNBOUNDED, MAX_ALLOWED_ELEMENT_DATA_SIZE, Vec<u8>>;

        let header = std::mem::size_of::<Vec<u8>>();
        for element_size in 64..=256 {
            let element = vec![b'a'; element_size - header];
            let data = BoundedSize::new(vec![element; 42]);
            let bytes = Encode!(&data).unwrap();
            let result = Decode!(&bytes, BoundedSize);
            if element_size > MAX_ALLOWED_ELEMENT_DATA_SIZE {
                assert!(result.is_err());
                assert_error_mentions(
                    result.unwrap_err(),
                    format!(
                        "Deserialize error: The single element data size exceeds maximum allowed {MAX_ALLOWED_ELEMENT_DATA_SIZE}"
                    ),
                );
            } else {
                assert!(result.is_ok());
                assert_eq!(result.unwrap(), data);
            }
        }
    }
}
/// Byte-size accounting used as the `DataSize` bound of `BoundedVec`'s deserializer.
mod data_size {
    /// Reports how many bytes a value counts for against a data-size limit.
    ///
    /// The provided method returns `0`. A type that implements this trait
    /// without overriding `data_size` therefore contributes nothing to any limit.
    pub trait DataSize {
        /// Number of bytes attributed to `self`.
        fn data_size(&self) -> usize {
            0
        }
    }

    impl DataSize for u8 {
        fn data_size(&self) -> usize {
            std::mem::size_of::<u8>()
        }
    }

    /// A byte slice counts one byte per element.
    impl DataSize for [u8] {
        fn data_size(&self) -> usize {
            std::mem::size_of_val(self)
        }
    }

    impl DataSize for u64 {
        fn data_size(&self) -> usize {
            std::mem::size_of::<u64>()
        }
    }

    /// Counts the UTF-8 byte length only (no header).
    impl DataSize for &str {
        fn data_size(&self) -> usize {
            self.as_bytes().data_size()
        }
    }

    /// Counts the UTF-8 byte length only. Unlike `Vec<T>`, the `String` header is
    /// not included.
    impl DataSize for String {
        fn data_size(&self) -> usize {
            self.as_bytes().data_size()
        }
    }

    /// Counts the `Vec` header (`size_of::<Vec<T>>()`) plus each element's data size.
    impl<T: DataSize> DataSize for Vec<T> {
        fn data_size(&self) -> usize {
            std::mem::size_of::<Self>() + self.iter().map(|x| x.data_size()).sum::<usize>()
        }
    }

    /// Counts the length of the principal's byte representation.
    impl DataSize for ic_principal::Principal {
        fn data_size(&self) -> usize {
            self.as_slice().len()
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_data_size_u8() {
            assert_eq!(0_u8.data_size(), 1);
            assert_eq!(42_u8.data_size(), 1);
        }

        #[test]
        fn test_data_size_u8_slice() {
            let a: [u8; 0] = [];
            assert_eq!(a.data_size(), 0);
            assert_eq!([1_u8].data_size(), 1);
            assert_eq!([1_u8, 2_u8].data_size(), 2);
        }

        #[test]
        fn test_data_size_u64() {
            assert_eq!(0_u64.data_size(), 8);
            assert_eq!(42_u64.data_size(), 8);
        }

        #[test]
        fn test_data_size_u8_vec() {
            // The header size is platform-dependent (24 bytes on 64-bit, 12 on
            // 32-bit such as wasm32), so derive it instead of hard-coding 24.
            let base = std::mem::size_of::<Vec<u8>>();
            assert_eq!(Vec::<u8>::from([]).data_size(), base);
            assert_eq!(Vec::<u8>::from([1]).data_size(), base + 1);
            assert_eq!(Vec::<u8>::from([1, 2]).data_size(), base + 2);
        }

        #[test]
        fn test_data_size_str() {
            assert_eq!("a".data_size(), 1);
            assert_eq!("ab".data_size(), 2);
        }

        #[test]
        fn test_data_size_string() {
            assert_eq!(String::from("a").data_size(), 1);
            assert_eq!(String::from("ab").data_size(), 2);
            for size_bytes in 0..1_024 {
                assert_eq!(
                    String::from_utf8(vec![b'x'; size_bytes])
                        .unwrap()
                        .data_size(),
                    size_bytes
                );
            }
        }
    }
}