routee_compass/plugin/input/default/grid_search/
plugin.rs1use crate::app::search::SearchApp;
2use crate::plugin::input::input_field::InputField;
3use crate::plugin::input::input_plugin::InputPlugin;
4use crate::plugin::input::InputJsonExtensions;
5use crate::plugin::input::InputPluginError;
6use routee_compass_core::util::multiset::MultiSet;
7use std::sync::Arc;
8
/// Input plugin that expands a query's grid search section into one query per
/// combination of the listed values (see the `InputPlugin` impl below).
#[derive(Debug, Clone, Default)]
pub struct GridSearchPlugin {}
12
13impl InputPlugin for GridSearchPlugin {
14 fn process(
15 &self,
16 input: &mut serde_json::Value,
17 _search_app: Arc<SearchApp>,
18 ) -> Result<(), InputPluginError> {
19 match process_grid_search(input)? {
20 None => Ok(()),
21 Some(grid_expansion) => {
22 let mut replacement = serde_json::json![grid_expansion];
23 std::mem::swap(&mut replacement, input);
24 Ok(())
25 }
26 }
27 }
28}
29
30fn process_grid_search(
31 input: &serde_json::Value,
32) -> Result<Option<Vec<serde_json::Value>>, InputPluginError> {
33 let grid_search_input = match input.get_grid_search() {
34 Some(gsi) => gsi,
35 None => return Ok(None),
36 };
37
38 let recurses = serde_json::to_string(grid_search_input)
40 .map_err(|e| InputPluginError::JsonError { source: e })?
41 .contains("grid_search");
42 if recurses {
43 return Err(InputPluginError::InputPluginFailed(String::from(
44 "grid search section cannot contain the string 'grid_search'",
45 )));
46 }
47
48 let map = grid_search_input
49 .as_object()
50 .ok_or_else(|| InputPluginError::UnexpectedQueryStructure(format!("{input:?}")))?;
51 let mut keys: Vec<String> = vec![];
52 let mut multiset_input: Vec<Vec<serde_json::Value>> = vec![];
53 let mut multiset_indices: Vec<Vec<usize>> = vec![];
54 for (k, v) in map {
55 if let Some(v) = v.as_array() {
56 keys.push(k.to_string());
57 multiset_input.push(v.to_vec());
58 let indices = (0..v.len()).collect();
59 multiset_indices.push(indices);
60 }
61 }
62 let mut initial_map = input
66 .as_object()
67 .ok_or_else(|| InputPluginError::UnexpectedQueryStructure(format!("{input:?}")))?
68 .clone();
69 initial_map.remove(InputField::GridSearch.to_str());
70 let initial = serde_json::json!(initial_map);
71 let multiset = MultiSet::from(&multiset_indices);
72 let result: Vec<serde_json::Value> = multiset
73 .into_iter()
74 .map(|combination| {
75 let mut instance = initial.clone();
76 let it = keys.iter().zip(combination.iter()).enumerate();
77 for (set_idx, (key, val_idx)) in it {
78 let value = multiset_input[set_idx][*val_idx].clone();
79 match value {
80 serde_json::Value::Object(o) => {
81 for (k, v) in o.into_iter() {
82 instance[k] = v.clone();
83 }
84 }
85 _ => {
86 instance[key] = multiset_input[set_idx][*val_idx].clone();
87 }
88 }
89 }
90 instance
91 })
92 .collect();
93
94 Ok(Some(result))
95}
96
#[cfg(test)]
mod test {
    use super::*;

    use serde_json::json;

    /// Runs the grid search expansion on `input`, failing the calling test unless it
    /// yields a list of expanded queries.
    fn expand(input: &serde_json::Value) -> Vec<serde_json::Value> {
        match process_grid_search(input) {
            Ok(Some(rows)) => rows,
            Ok(None) => panic!("process_grid_search returned no expansions"),
            Err(e) => panic!("{}", e),
        }
    }

    /// With only a grid search section present, each expansion holds just the grid keys.
    #[test]
    fn test_grid_search_empty_parent_object() {
        let input = json!({
            "grid_search": {
                "bar": ["a", "b", "c"],
                "foo": [1.2, 3.4]
            }
        });

        let expected = vec![
            json!({"bar": "a", "foo": 1.2}),
            json!({"bar": "b", "foo": 1.2}),
            json!({"bar": "c", "foo": 1.2}),
            json!({"bar": "a", "foo": 3.4}),
            json!({"bar": "b", "foo": 3.4}),
            json!({"bar": "c", "foo": 3.4}),
        ];

        assert_eq!(expand(&input), expected);
    }

    /// Keys outside the grid search section are copied into every expansion.
    #[test]
    fn test_grid_search_persisted_parent_keys() {
        let input = json!({
            "ignored_key": "ignored_value",
            "grid_search": {
                "bar": ["a", "b", "c"],
                "foo": [1.2, 3.4]
            }
        });

        let expected = vec![
            json!({"bar": "a", "foo": 1.2, "ignored_key": "ignored_value"}),
            json!({"bar": "b", "foo": 1.2, "ignored_key": "ignored_value"}),
            json!({"bar": "c", "foo": 1.2, "ignored_key": "ignored_value"}),
            json!({"bar": "a", "foo": 3.4, "ignored_key": "ignored_value"}),
            json!({"bar": "b", "foo": 3.4, "ignored_key": "ignored_value"}),
            json!({"bar": "c", "foo": 3.4, "ignored_key": "ignored_value"}),
        ];

        assert_eq!(expand(&input), expected);
    }

    /// Object-valued grid entries are flattened into the top level of each expansion.
    #[test]
    fn test_grid_search_using_objects() {
        let input = json!({
            "ignored_key": "ignored_value",
            "grid_search": {
                "a": [1, 2],
                "ignored_inner_key": [
                    { "x": 0, "y": 0 },
                    { "x": 1, "y": 1 }
                ]
            }
        });

        let expected = vec![
            json!({"a": 1, "ignored_key": "ignored_value", "x": 0, "y": 0}),
            json!({"a": 2, "ignored_key": "ignored_value", "x": 0, "y": 0}),
            json!({"a": 1, "ignored_key": "ignored_value", "x": 1, "y": 1}),
            json!({"a": 2, "ignored_key": "ignored_value", "x": 1, "y": 1}),
        ];

        assert_eq!(expand(&input), expected);
    }

    /// Nested objects inside an object-valued grid entry are copied over intact.
    #[test]
    fn test_nested() {
        let input = json!({
            "abc": 123,
            "grid_search": {
                "model_name": ["2016_TOYOTA_Camry_4cyl_2WD", "2017_CHEVROLET_Bolt"],
                "_ignore": [
                    { "name": "d1", "weights": { "distance": 1, "time": 0, "energy_electric": 0 } },
                    { "name": "t1", "weights": { "distance": 0, "time": 1, "energy_electric": 0 } },
                    { "name": "e1", "weights": { "distance": 0, "time": 0, "energy_electric": 1 } }
                ]
            }
        });

        let row = |model_name: &str, name: &str, (distance, time, energy): (i32, i32, i32)| {
            json!({
                "abc": 123,
                "model_name": model_name,
                "name": name,
                "weights": { "distance": distance, "time": time, "energy_electric": energy }
            })
        };
        let camry = "2016_TOYOTA_Camry_4cyl_2WD";
        let bolt = "2017_CHEVROLET_Bolt";
        let expected = vec![
            row(camry, "d1", (1, 0, 0)),
            row(bolt, "d1", (1, 0, 0)),
            row(camry, "t1", (0, 1, 0)),
            row(bolt, "t1", (0, 1, 0)),
            row(camry, "e1", (0, 0, 1)),
            row(bolt, "e1", (0, 0, 1)),
        ];

        assert_eq!(expand(&input), expected);
    }

    /// A grid search nested inside another grid search must be rejected with an error.
    #[test]
    pub fn test_handle_recursion() {
        let input = json!({
            "abc": 123,
            "grid_search": {
                "grid_search": {
                    "foo": ["a", "b"]
                }
            }
        });

        if let Ok(outcome) = process_grid_search(&input) {
            match outcome {
                Some(_) => panic!("process_grid_search should return an error"),
                None => panic!("process_grid_search returned no error"),
            }
        }
    }
}