1use std::hash::Hash;
2
3macro_rules! hashmap_newtype {
4 ($map:ident, $name:expr) => {
5 #[derive(Debug, Clone, PartialEq)]
14 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(rename = $name))]
15 #[cfg_attr(
16 feature = "serde",
17 derive(serde::Serialize, serde::Deserialize),
18 serde(
19 from = "Collection::<K>",
20 into = "Collection::<K>",
21 bound(serialize = "K: serde::Serialize + Clone")
22 )
23 )]
24 pub struct $map<K: Eq + Hash>(indexmap::IndexMap<K, f64, rustc_hash::FxBuildHasher>);
25
26 impl<K: Eq + Hash> Default for $map<K> {
27 fn default() -> Self {
28 Self(indexmap::IndexMap::default())
29 }
30 }
31
32 impl<K: Eq + Hash> std::ops::Deref for $map<K> {
33 type Target = indexmap::IndexMap<K, f64, rustc_hash::FxBuildHasher>;
34
35 fn deref(&self) -> &Self::Target {
36 &self.0
37 }
38 }
39
40 impl<K: Eq + Hash> std::ops::DerefMut for $map<K> {
41 fn deref_mut(&mut self) -> &mut Self::Target {
42 &mut self.0
43 }
44 }
45
46 impl<K: Eq + Hash> IntoIterator for $map<K> {
47 type Item = (K, f64);
48 type IntoIter = indexmap::map::IntoIter<K, f64>;
49
50 fn into_iter(self) -> Self::IntoIter {
51 self.0.into_iter()
52 }
53 }
54
55 impl<K: Eq + Hash> FromIterator<(K, f64)> for $map<K> {
56 fn from_iter<I: IntoIterator<Item = (K, f64)>>(iter: I) -> Self {
57 Self(indexmap::IndexMap::from_iter(iter))
58 }
59 }
60
61 impl<K: Eq + Hash> From<Collection<K>> for $map<K> {
62 fn from(value: Collection<K>) -> Self {
63 match value {
64 Collection::OneOf(entry) => std::iter::once((entry, 1.0)).collect(),
65 Collection::SumOf(entries) => {
66 entries.into_iter().zip(std::iter::repeat(1.0)).collect()
67 }
68 Collection::MapOf(entries) => Self(entries),
69 }
70 }
71 }
72
73 impl<K: Eq + Hash + Clone> Into<Collection<K>> for $map<K> {
74 fn into(self) -> Collection<K> {
75 Collection::MapOf(self.0)
76 }
77 }
78 };
79}
80
81hashmap_newtype!(DemandGroup, "DemandGroup");
87hashmap_newtype!(ProductGroup, "ProductGroup");
88
89#[derive(Debug)]
92#[cfg_attr(
93 feature = "schemars",
94 derive(schemars::JsonSchema),
95 schemars(rename = "{T}Group", untagged)
96)]
97enum Collection<T: Eq + Hash> {
98 OneOf(T),
99 SumOf(Vec<T>),
100 MapOf(indexmap::IndexMap<T, f64, rustc_hash::FxBuildHasher>),
101}
102
/// Deserializes a [`Collection`] by dispatching on the shape of the input:
/// a string becomes `OneOf`, a sequence becomes `SumOf`, a map becomes `MapOf`.
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for Collection<T>
where
    T: Eq + Hash + serde::Deserialize<'de>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::value::StrDeserializer;

        let visitor = serde_untagged::UntaggedEnumVisitor::new()
            // A bare string is handed to `T`'s own deserializer so that `T`
            // decides how to interpret the text.
            .string(|text| T::deserialize(StrDeserializer::new(text)).map(Collection::OneOf))
            .seq(|items| items.deserialize().map(Collection::SumOf))
            .map(|pairs| pairs.deserialize().map(Collection::MapOf));

        visitor.deserialize(deserializer)
    }
}
118
/// Serializes a [`Collection`] transparently: each variant is written as its
/// payload, with no tag or wrapper around it.
#[cfg(feature = "serde")]
impl<T> serde::Serialize for Collection<T>
where
    T: Eq + Hash + serde::Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            Collection::OneOf(entry) => serde::Serialize::serialize(entry, serializer),
            Collection::SumOf(entries) => serde::Serialize::serialize(entries, serializer),
            Collection::MapOf(weights) => serde::Serialize::serialize(weights, serializer),
        }
    }
}
132
// These tests exercise the serde implementations of `Collection`, which only
// exist when the `serde` feature is enabled, so the module is gated on it too.
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;

    /// Parses `json` into a `Collection<String>`, failing the test on error.
    fn parse(json: &str) -> Collection<String> {
        serde_json::from_str(json).expect("could not parse collection")
    }

    /// Serializes `collection` to a JSON value, failing the test on error.
    fn to_json(collection: Collection<String>) -> serde_json::Value {
        serde_json::to_value(collection).expect("could not serialize collection")
    }

    #[test]
    fn test_from_scalar() {
        let parsed = parse(r#""apples""#);
        assert!(
            matches!(&parsed, Collection::OneOf(entry) if entry == "apples"),
            "scalar parsed incorrectly: {parsed:?}"
        );
    }

    #[test]
    fn test_from_vector() {
        let parsed = parse(r#"["apples", "bananas"]"#);
        assert!(
            matches!(&parsed, Collection::SumOf(entries) if entries.len() == 2),
            "vector parsed incorrectly: {parsed:?}"
        );
    }

    #[test]
    fn test_from_map() {
        let parsed = parse(r#"{"apples": 1, "bananas": 1}"#);
        assert!(
            matches!(&parsed, Collection::MapOf(entries) if entries.len() == 2),
            "map parsed incorrectly: {parsed:?}"
        );
    }

    #[test]
    fn test_to_scalar() {
        let json = to_json(Collection::OneOf("apples".to_owned()));
        assert_eq!(
            json,
            serde_json::json!("apples"),
            "scalar serialized incorrectly"
        );
    }

    #[test]
    fn test_to_vector() {
        let json = to_json(Collection::SumOf(vec![
            "apples".to_owned(),
            "bananas".to_owned(),
        ]));
        assert_eq!(
            json,
            serde_json::json!(["apples", "bananas"]),
            "vector serialized incorrectly"
        );
    }

    #[test]
    fn test_to_map() {
        let json = to_json(Collection::MapOf(
            ["apples", "bananas"]
                .into_iter()
                .map(|fruit| (fruit.to_owned(), 1.0))
                .collect(),
        ));
        assert_eq!(
            json,
            serde_json::json!({ "apples": 1.0, "bananas": 1.0 }),
            "map serialized incorrectly"
        );
    }
}