fts_core/models/
group.rs

1use std::hash::Hash;
2
macro_rules! hashmap_newtype {
    ($map:ident, $name:expr) => {
        /// A wrapper around an implementation of a HashMap, with values of f64.
        ///
        /// Predictable and consistent ordering is important to ensure identical solutions
        /// from repeated solves, so we replace the std::collections::HashMap with
        /// indexmap::IndexMap. However, this is an implementation detail, so we wrap it
        /// in a newtype, allowing us to replace it at a future date without breaking
        /// semver. This unfortunately leads to additional boilerplate, but at least it
        /// is not particularly complicated.
        ///
        /// With the `serde` feature enabled, a group may be written as a single key
        /// (weight 1), a list of keys (weight 1 per occurrence, so repeated keys
        /// accumulate), or an explicit map from keys to weights. It is always
        /// serialized in the explicit map form.
        #[derive(Debug, Clone, PartialEq)]
        #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema), schemars(rename = $name))]
        #[cfg_attr(
            feature = "serde",
            derive(serde::Serialize, serde::Deserialize),
            serde(
                from = "Collection::<K>",
                into = "Collection::<K>",
                bound(serialize = "K: serde::Serialize + Clone")
            )
        )]
        pub struct $map<K: Eq + Hash>(indexmap::IndexMap<K, f64, rustc_hash::FxBuildHasher>);

        impl<K: Eq + Hash> Default for $map<K> {
            fn default() -> Self {
                Self(indexmap::IndexMap::default())
            }
        }

        // Deref/DerefMut expose the underlying map's API (get, insert, iter, ...)
        // without re-declaring every method on the newtype.
        impl<K: Eq + Hash> std::ops::Deref for $map<K> {
            type Target = indexmap::IndexMap<K, f64, rustc_hash::FxBuildHasher>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl<K: Eq + Hash> std::ops::DerefMut for $map<K> {
            fn deref_mut(&mut self) -> &mut Self::Target {
                &mut self.0
            }
        }

        impl<K: Eq + Hash> IntoIterator for $map<K> {
            type Item = (K, f64);
            type IntoIter = indexmap::map::IntoIter<K, f64>;

            fn into_iter(self) -> Self::IntoIter {
                self.0.into_iter()
            }
        }

        /// Borrowing iteration, so `for (key, weight) in &group` works; `Deref`
        /// alone does not make `&$map<K>` usable in a `for` loop.
        impl<'a, K: Eq + Hash> IntoIterator for &'a $map<K> {
            type Item = (&'a K, &'a f64);
            type IntoIter = indexmap::map::Iter<'a, K, f64>;

            fn into_iter(self) -> Self::IntoIter {
                self.0.iter()
            }
        }

        impl<K: Eq + Hash> FromIterator<(K, f64)> for $map<K> {
            fn from_iter<I: IntoIterator<Item = (K, f64)>>(iter: I) -> Self {
                Self(indexmap::IndexMap::from_iter(iter))
            }
        }

        impl<K: Eq + Hash> From<Collection<K>> for $map<K> {
            /// Normalizes any accepted collection shape into a weighted map.
            ///
            /// - `OneOf(k)` becomes `{k: 1.0}`.
            /// - `SumOf(ks)` gives each occurrence weight 1.0. A key that appears several
            ///   times accumulates (e.g. `[a, a]` becomes `{a: 2.0}`) instead of silently
            ///   collapsing to 1.0. Keys keep the position of their first occurrence.
            /// - `MapOf(m)` is used as-is.
            fn from(value: Collection<K>) -> Self {
                match value {
                    Collection::OneOf(entry) => std::iter::once((entry, 1.0)).collect(),
                    Collection::SumOf(entries) => {
                        let mut weights = indexmap::IndexMap::with_capacity_and_hasher(
                            entries.len(),
                            rustc_hash::FxBuildHasher::default(),
                        );
                        for entry in entries {
                            *weights.entry(entry).or_insert(0.0) += 1.0;
                        }
                        Self(weights)
                    }
                    Collection::MapOf(entries) => Self(entries),
                }
            }
        }

        // Implementing `From` (rather than `Into`) gives the `Into` conversion that
        // serde's `into = "Collection::<K>"` needs via the std blanket impl, without
        // requiring `K: Clone` on the conversion itself.
        impl<K: Eq + Hash> From<$map<K>> for Collection<K> {
            fn from(value: $map<K>) -> Self {
                Collection::MapOf(value.0)
            }
        }
    };
}
80
81// For now, we implement demand- and product-groups the same way, though this
82// allows us to change the implementations separately. (For example, maybe we
83// switch the DemandGroup implementation to be optimal for assumed-small hash
84// tables.)
85
86hashmap_newtype!(DemandGroup, "DemandGroup");
87hashmap_newtype!(ProductGroup, "ProductGroup");
88
89// This type spells out the 3 ways to define a collection
90
91#[derive(Debug)]
92#[cfg_attr(
93    feature = "schemars",
94    derive(schemars::JsonSchema),
95    schemars(rename = "{T}Group", untagged)
96)]
97enum Collection<T: Eq + Hash> {
98    OneOf(T),
99    SumOf(Vec<T>),
100    MapOf(indexmap::IndexMap<T, f64, rustc_hash::FxBuildHasher>),
101}
102
#[cfg(feature = "serde")]
impl<'de, T> serde::Deserialize<'de> for Collection<T>
where
    T: Eq + Hash + serde::Deserialize<'de>,
{
    // Picks the variant from the shape of the input. A string is parsed as a single
    // `T` (`OneOf`), a sequence as `Vec<T>` (`SumOf`), and a map as a `T -> f64`
    // map (`MapOf`). Only these three shapes are registered with the visitor.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::value::StrDeserializer;

        let visitor = serde_untagged::UntaggedEnumVisitor::new()
            .string(|text| T::deserialize(StrDeserializer::new(text)).map(Collection::OneOf))
            .seq(|items| items.deserialize().map(Collection::SumOf))
            .map(|pairs| pairs.deserialize().map(Collection::MapOf));
        visitor.deserialize(deserializer)
    }
}
118
#[cfg(feature = "serde")]
impl<T> serde::Serialize for Collection<T>
where
    T: Eq + Hash + serde::Serialize,
{
    // Each variant is written in the same shape it is read from: a bare value,
    // a sequence, or a map. No tag is emitted.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        match *self {
            Collection::OneOf(ref single) => single.serialize(serializer),
            Collection::SumOf(ref items) => items.serialize(serializer),
            Collection::MapOf(ref weights) => weights.serialize(serializer),
        }
    }
}
132
// These tests exercise the serde implementations, which only exist when the
// `serde` feature is enabled. Gating on the feature too keeps `cargo test`
// compiling when the feature is off.
#[cfg(all(test, feature = "serde"))]
mod tests {
    use super::*;

    #[test]
    fn test_from_scalar() {
        let result: Result<Collection<String>, _> = serde_json::from_str(r#""apples""#);
        match result {
            Ok(Collection::OneOf(value)) => assert_eq!(value, "apples"),
            Ok(other) => panic!("scalar parsed incorrectly: {:?}", other),
            Err(err) => panic!("could not parse collection: {}", err),
        }
    }

    #[test]
    fn test_from_vector() {
        let result: Result<Collection<String>, _> =
            serde_json::from_str(r#"["apples", "bananas"]"#);
        match result {
            Ok(Collection::SumOf(values)) => assert_eq!(values, ["apples", "bananas"]),
            Ok(other) => panic!("vector parsed incorrectly: {:?}", other),
            Err(err) => panic!("could not parse collection: {}", err),
        }
    }

    #[test]
    fn test_from_map() {
        let result: Result<Collection<String>, _> =
            serde_json::from_str(r#"{"apples": 1, "bananas": 1}"#);
        match result {
            Ok(Collection::MapOf(weights)) => {
                assert_eq!(weights.len(), 2);
                assert_eq!(weights.get("apples"), Some(&1.0));
                assert_eq!(weights.get("bananas"), Some(&1.0));
            }
            Ok(other) => panic!("map parsed incorrectly: {:?}", other),
            Err(err) => panic!("could not parse collection: {}", err),
        }
    }

    #[test]
    fn test_to_scalar() {
        match serde_json::to_value(Collection::OneOf("apples".to_owned())) {
            Ok(serde_json::Value::String(value)) => assert_eq!(value, "apples"),
            Ok(other) => panic!("scalar serialized incorrectly: {:?}", other),
            Err(err) => panic!("could not serialize collection: {}", err),
        }
    }

    #[test]
    fn test_to_vector() {
        match serde_json::to_value(Collection::SumOf(vec![
            "apples".to_owned(),
            "bananas".to_owned(),
        ])) {
            Ok(serde_json::Value::Array(value)) => assert_eq!(value.len(), 2),
            Ok(other) => panic!("vector serialized incorrectly: {:?}", other),
            Err(err) => panic!("could not serialize collection: {}", err),
        }
    }

    #[test]
    fn test_to_map() {
        match serde_json::to_value(Collection::MapOf(
            ["apples", "bananas"]
                .iter()
                .map(|fruit| (fruit.to_string(), 1.0))
                .collect(),
        )) {
            Ok(serde_json::Value::Object(value)) => assert_eq!(value.len(), 2),
            Ok(other) => panic!("map serialized incorrectly: {:?}", other),
            Err(err) => panic!("could not serialize collection: {}", err),
        }
    }

    #[test]
    fn test_group_from_scalar() {
        let group: DemandGroup<String> =
            serde_json::from_str(r#""apples""#).expect("could not parse group");
        let expected: DemandGroup<String> =
            std::iter::once(("apples".to_owned(), 1.0)).collect();
        assert_eq!(group, expected);
    }

    #[test]
    fn test_group_from_vector_preserves_order() {
        let group: ProductGroup<String> =
            serde_json::from_str(r#"["bananas", "apples"]"#).expect("could not parse group");
        assert_eq!(group.keys().collect::<Vec<_>>(), ["bananas", "apples"]);
        assert!(group.values().all(|&weight| weight == 1.0));
    }

    #[test]
    fn test_group_serializes_as_map() {
        let group: DemandGroup<String> = std::iter::once(("apples".to_owned(), 2.0)).collect();
        let value = serde_json::to_value(&group).expect("could not serialize group");
        assert_eq!(value.as_object().map(|object| object.len()), Some(1));
        assert_eq!(value["apples"], 2.0);
    }
}