1#![cfg_attr(docsrs, doc(cfg(feature = "borsh")))]
2
3#[cfg(feature = "borsh-schema")]
4use ::{
5 alloc::collections::btree_map::BTreeMap,
6 alloc::format,
7 borsh::schema::{add_definition, Declaration, Definition},
8 borsh::BorshSchema,
9};
10
11use alloc::vec::Vec;
12use core::hash::BuildHasher;
13use core::hash::Hash;
14use core::mem::size_of;
15
16use borsh::error::ERROR_ZST_FORBIDDEN;
17use borsh::io::{Error, ErrorKind, Read, Result, Write};
18use borsh::{BorshDeserialize, BorshSerialize};
19
20use crate::map::IndexMap;
21use crate::set::IndexSet;
22
23impl<K, V, S> BorshSerialize for IndexMap<K, V, S>
24where
25 K: BorshSerialize,
26 V: BorshSerialize,
27{
28 #[inline]
29 fn serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
30 check_zst::<K>()?;
31
32 let iterator = self.iter();
33
34 u32::try_from(iterator.len())
35 .map_err(|_| ErrorKind::InvalidData)?
36 .serialize(writer)?;
37
38 for (key, value) in iterator {
39 key.serialize(writer)?;
40 value.serialize(writer)?;
41 }
42
43 Ok(())
44 }
45}
46
47impl<K, V, S> BorshDeserialize for IndexMap<K, V, S>
48where
49 K: BorshDeserialize + Eq + Hash,
50 V: BorshDeserialize,
51 S: BuildHasher + Default,
52{
53 #[inline]
54 fn deserialize_reader<R: Read>(reader: &mut R) -> Result<Self> {
55 check_zst::<K>()?;
56 let vec = <Vec<(K, V)>>::deserialize_reader(reader)?;
57 Ok(vec.into_iter().collect::<IndexMap<K, V, S>>())
58 }
59}
60
#[cfg(feature = "borsh-schema")]
/// Describes an [`IndexMap`] as a length-prefixed sequence of `(K, V)` tuples.
///
/// The declaration names only `K` and `V`; the hasher `S` does not appear.
impl<K, V, S> BorshSchema for IndexMap<K, V, S>
where
    K: BorshSchema,
    V: BorshSchema,
{
    fn add_definitions_recursively(definitions: &mut BTreeMap<Declaration, Definition>) {
        let entry_declaration = <(K, V)>::declaration();
        add_definition(
            Self::declaration(),
            Definition::Sequence {
                length_width: Definition::DEFAULT_LENGTH_WIDTH,
                length_range: Definition::DEFAULT_LENGTH_RANGE,
                elements: entry_declaration,
            },
            definitions,
        );
        // Also register the tuple (and, transitively, K and V) so the schema is complete.
        <(K, V)>::add_definitions_recursively(definitions);
    }

    fn declaration() -> Declaration {
        let key_declaration = K::declaration();
        let value_declaration = V::declaration();
        format!("IndexMap<{}, {}>", key_declaration, value_declaration)
    }
}
81
82impl<T, S> BorshSerialize for IndexSet<T, S>
83where
84 T: BorshSerialize,
85{
86 #[inline]
87 fn serialize<W: Write>(&self, writer: &mut W) -> Result<()> {
88 check_zst::<T>()?;
89
90 let iterator = self.iter();
91
92 u32::try_from(iterator.len())
93 .map_err(|_| ErrorKind::InvalidData)?
94 .serialize(writer)?;
95
96 for item in iterator {
97 item.serialize(writer)?;
98 }
99
100 Ok(())
101 }
102}
103
104impl<T, S> BorshDeserialize for IndexSet<T, S>
105where
106 T: BorshDeserialize + Eq + Hash,
107 S: BuildHasher + Default,
108{
109 #[inline]
110 fn deserialize_reader<R: Read>(reader: &mut R) -> Result<Self> {
111 check_zst::<T>()?;
112 let vec = <Vec<T>>::deserialize_reader(reader)?;
113 Ok(vec.into_iter().collect::<IndexSet<T, S>>())
114 }
115}
116
#[cfg(feature = "borsh-schema")]
/// Describes an [`IndexSet`] as a length-prefixed sequence of `T`.
///
/// The declaration names only `T`; the hasher `S` does not appear.
impl<T, S> BorshSchema for IndexSet<T, S>
where
    T: BorshSchema,
{
    fn add_definitions_recursively(definitions: &mut BTreeMap<Declaration, Definition>) {
        let element_declaration = T::declaration();
        add_definition(
            Self::declaration(),
            Definition::Sequence {
                length_width: Definition::DEFAULT_LENGTH_WIDTH,
                length_range: Definition::DEFAULT_LENGTH_RANGE,
                elements: element_declaration,
            },
            definitions,
        );
        // Also register the element type so the schema is complete.
        T::add_definitions_recursively(definitions);
    }

    fn declaration() -> Declaration {
        let element_declaration = T::declaration();
        format!("IndexSet<{}>", element_declaration)
    }
}
136
137fn check_zst<T>() -> Result<()> {
138 if size_of::<T>() == 0 {
139 return Err(Error::new(ErrorKind::InvalidData, ERROR_ZST_FORBIDDEN));
140 }
141 Ok(())
142}
143
#[cfg(test)]
mod borsh_tests {
    use super::*;

    /// Serializes `value` with borsh and decodes the bytes back as the same type.
    fn roundtrip<T: BorshSerialize + BorshDeserialize>(value: &T) -> T {
        let bytes = borsh::to_vec(value).unwrap();
        T::try_from_slice(&bytes).unwrap()
    }

    #[test]
    fn map_borsh_roundtrip() {
        let original_map: IndexMap<i32, i32> = {
            let mut map = IndexMap::new();
            map.insert(1, 2);
            map.insert(3, 4);
            map.insert(5, 6);
            map
        };
        let serialized_map = borsh::to_vec(&original_map).unwrap();
        let deserialized_map: IndexMap<i32, i32> =
            BorshDeserialize::try_from_slice(&serialized_map).unwrap();
        assert_eq!(original_map, deserialized_map);
        // `==` on `IndexMap` is a map comparison that may ignore order, so also
        // compare the entry sequences explicitly.
        assert!(original_map.iter().eq(deserialized_map.iter()));
    }

    #[test]
    fn map_borsh_preserves_insertion_order() {
        // Keys are deliberately inserted out of sorted order.
        let original: IndexMap<u8, u32> = [(30, 3), (10, 1), (20, 2)].into_iter().collect();
        let decoded = roundtrip(&original);
        assert_eq!(decoded.keys().copied().collect::<Vec<_>>(), [30, 10, 20]);
        assert!(original.iter().eq(decoded.iter()));
    }

    #[test]
    fn map_borsh_wire_format_matches_ordered_pairs() {
        // The encoding must match a borsh `Vec<(K, V)>` of the entries in map order.
        let original: IndexMap<u8, u32> = [(30, 3), (10, 1), (20, 2)].into_iter().collect();
        let pairs: Vec<(u8, u32)> = original.iter().map(|(&k, &v)| (k, v)).collect();
        assert_eq!(
            borsh::to_vec(&original).unwrap(),
            borsh::to_vec(&pairs).unwrap()
        );
    }

    #[test]
    fn set_borsh_roundtrip() {
        let original_set: IndexSet<i32> = [1, 2, 3, 4, 5, 6].into_iter().collect();
        let serialized_set = borsh::to_vec(&original_set).unwrap();
        let deserialized_set: IndexSet<i32> =
            BorshDeserialize::try_from_slice(&serialized_set).unwrap();
        assert_eq!(original_set, deserialized_set);
        // `==` on `IndexSet` is a set comparison that may ignore order.
        assert!(original_set.iter().eq(deserialized_set.iter()));
    }

    #[test]
    fn set_borsh_preserves_insertion_order() {
        let original: IndexSet<u16> = [300, 100, 200].into_iter().collect();
        let decoded = roundtrip(&original);
        assert_eq!(decoded.iter().copied().collect::<Vec<_>>(), [300, 100, 200]);
    }

    #[test]
    fn empty_collections_roundtrip() {
        let map: IndexMap<u8, u8> = IndexMap::new();
        assert!(roundtrip(&map).is_empty());

        let set: IndexSet<u8> = IndexSet::new();
        assert!(roundtrip(&set).is_empty());
    }

    #[test]
    fn zero_sized_elements_are_rejected() {
        let set: IndexSet<()> = [()].into_iter().collect();
        assert!(borsh::to_vec(&set).is_err());
        // Decoding is refused even for an empty payload: the check runs before any read.
        assert!(IndexSet::<()>::try_from_slice(&[0, 0, 0, 0]).is_err());

        let map: IndexMap<(), u8> = [((), 1)].into_iter().collect();
        assert!(borsh::to_vec(&map).is_err());
        assert!(IndexMap::<(), u8>::try_from_slice(&[0, 0, 0, 0]).is_err());
    }
}