routee_compass/plugin/input/default/grid_search/plugin.rs

1use crate::app::search::SearchApp;
2use crate::plugin::input::input_field::InputField;
3use crate::plugin::input::input_plugin::InputPlugin;
4use crate::plugin::input::InputJsonExtensions;
5use crate::plugin::input::InputPluginError;
6use routee_compass_core::util::multiset::MultiSet;
7use std::sync::Arc;
8
/// Input plugin that expands a query containing a grid search section into
/// multiple queries, one for each combination of the values found in that
/// section's array-valued fields.
pub struct GridSearchPlugin {}
12
13impl InputPlugin for GridSearchPlugin {
14    fn process(
15        &self,
16        input: &mut serde_json::Value,
17        _search_app: Arc<SearchApp>,
18    ) -> Result<(), InputPluginError> {
19        match process_grid_search(input)? {
20            None => Ok(()),
21            Some(grid_expansion) => {
22                let mut replacement = serde_json::json![grid_expansion];
23                std::mem::swap(&mut replacement, input);
24                Ok(())
25            }
26        }
27    }
28}
29
30fn process_grid_search(
31    input: &serde_json::Value,
32) -> Result<Option<Vec<serde_json::Value>>, InputPluginError> {
33    let grid_search_input = match input.get_grid_search() {
34        Some(gsi) => gsi,
35        None => return Ok(None),
36    };
37
38    // prevent recursion due to nested grid search keys
39    let recurses = serde_json::to_string(grid_search_input)
40        .map_err(|e| InputPluginError::JsonError { source: e })?
41        .contains("grid_search");
42    if recurses {
43        return Err(InputPluginError::InputPluginFailed(String::from(
44            "grid search section cannot contain the string 'grid_search'",
45        )));
46    }
47
48    let map = grid_search_input
49        .as_object()
50        .ok_or_else(|| InputPluginError::UnexpectedQueryStructure(format!("{input:?}")))?;
51    let mut keys: Vec<String> = vec![];
52    let mut multiset_input: Vec<Vec<serde_json::Value>> = vec![];
53    let mut multiset_indices: Vec<Vec<usize>> = vec![];
54    for (k, v) in map {
55        if let Some(v) = v.as_array() {
56            keys.push(k.to_string());
57            multiset_input.push(v.to_vec());
58            let indices = (0..v.len()).collect();
59            multiset_indices.push(indices);
60        }
61    }
62    // for each combination, copy the grid search values into a fresh
63    // copy of the source (minus the "grid_search" key)
64    // let remove_key = InputField::GridSearch.to_str();
65    let mut initial_map = input
66        .as_object()
67        .ok_or_else(|| InputPluginError::UnexpectedQueryStructure(format!("{input:?}")))?
68        .clone();
69    initial_map.remove(InputField::GridSearch.to_str());
70    let initial = serde_json::json!(initial_map);
71    let multiset = MultiSet::from(&multiset_indices);
72    let result: Vec<serde_json::Value> = multiset
73        .into_iter()
74        .map(|combination| {
75            let mut instance = initial.clone();
76            let it = keys.iter().zip(combination.iter()).enumerate();
77            for (set_idx, (key, val_idx)) in it {
78                let value = multiset_input[set_idx][*val_idx].clone();
79                match value {
80                    serde_json::Value::Object(o) => {
81                        for (k, v) in o.into_iter() {
82                            instance[k] = v.clone();
83                        }
84                    }
85                    _ => {
86                        instance[key] = multiset_input[set_idx][*val_idx].clone();
87                    }
88                }
89            }
90            instance
91        })
92        .collect();
93
94    Ok(Some(result))
95}
96
#[cfg(test)]
mod test {
    use super::*;

    use serde_json::json;

    /// Runs `process_grid_search` on `input` and returns the expanded rows.
    /// Fails the test if there was nothing to expand or an error was returned.
    fn expand(input: &serde_json::Value) -> Vec<serde_json::Value> {
        match process_grid_search(input) {
            Ok(Some(rows)) => rows,
            Ok(None) => panic!("process_grid_search returned no expansions"),
            Err(e) => panic!("{}", e),
        }
    }

    #[test]
    fn test_grid_search_empty_parent_object() {
        let input = json!({
            "grid_search": {
                "bar": ["a", "b", "c"],
                "foo": [1.2, 3.4]
            }
        });

        let expected = vec![
            json![{"bar":"a","foo":1.2}],
            json![{"bar":"b","foo":1.2}],
            json![{"bar":"c","foo":1.2}],
            json![{"bar":"a","foo":3.4}],
            json![{"bar":"b","foo":3.4}],
            json![{"bar":"c","foo":3.4}],
        ];
        assert_eq!(expand(&input), expected)
    }

    #[test]
    fn test_grid_search_persisted_parent_keys() {
        let input = json!({
            "ignored_key": "ignored_value",
            "grid_search": {
                "bar": ["a", "b", "c"],
                "foo": [1.2, 3.4]
            }
        });

        let expected = vec![
            json![{"bar":"a","foo":1.2,"ignored_key": "ignored_value"}],
            json![{"bar":"b","foo":1.2,"ignored_key": "ignored_value"}],
            json![{"bar":"c","foo":1.2,"ignored_key": "ignored_value"}],
            json![{"bar":"a","foo":3.4,"ignored_key": "ignored_value"}],
            json![{"bar":"b","foo":3.4,"ignored_key": "ignored_value"}],
            json![{"bar":"c","foo":3.4,"ignored_key": "ignored_value"}],
        ];
        assert_eq!(expand(&input), expected)
    }

    #[test]
    fn test_grid_search_using_objects() {
        let input = json!({
            "ignored_key": "ignored_value",
            "grid_search": {
                "a": [1, 2],
                "ignored_inner_key": [
                    { "x": 0, "y": 0 },
                    { "x": 1, "y": 1 }
                ]
            }
        });

        let expected = vec![
            json![{"a":1,"ignored_key":"ignored_value","x":0,"y":0}],
            json![{"a":2,"ignored_key":"ignored_value","x":0,"y":0}],
            json![{"a":1,"ignored_key":"ignored_value","x":1,"y":1}],
            json![{"a":2,"ignored_key":"ignored_value","x":1,"y":1}],
        ];
        assert_eq!(expand(&input), expected)
    }

    #[test]
    fn test_nested() {
        let input = json!({
            "abc": 123,
            "grid_search": {
                "model_name": ["2016_TOYOTA_Camry_4cyl_2WD", "2017_CHEVROLET_Bolt"],
                "_ignore": [
                    { "name": "d1", "weights": { "distance": 1, "time": 0, "energy_electric": 0 } },
                    { "name": "t1", "weights": { "distance": 0, "time": 1, "energy_electric": 0 } },
                    { "name": "e1", "weights": { "distance": 0, "time": 0, "energy_electric": 1 } }
                ]
            }
        });

        let expected = vec![
            json![{"abc":123,"model_name":"2016_TOYOTA_Camry_4cyl_2WD","name":"d1","weights":{"distance":1,"time":0,"energy_electric":0}}],
            json![{"abc":123,"model_name":"2017_CHEVROLET_Bolt","name":"d1","weights":{"distance":1,"time":0,"energy_electric":0}}],
            json![{"abc":123,"model_name":"2016_TOYOTA_Camry_4cyl_2WD","name":"t1","weights":{"distance":0,"time":1,"energy_electric":0}}],
            json![{"abc":123,"model_name":"2017_CHEVROLET_Bolt","name":"t1","weights":{"distance":0,"time":1,"energy_electric":0}}],
            json![{"abc":123,"model_name":"2016_TOYOTA_Camry_4cyl_2WD","name":"e1","weights":{"distance":0,"time":0,"energy_electric":1}}],
            json![{"abc":123,"model_name":"2017_CHEVROLET_Bolt","name":"e1","weights":{"distance":0,"time":0,"energy_electric":1}}],
        ];
        assert_eq!(expand(&input), expected)
    }

    #[test]
    fn test_handle_recursion() {
        let input = json!({
            "abc": 123,
            "grid_search": {
                "grid_search": {
                    "foo": ["a", "b"]
                }
            }
        });

        // The nested key must be rejected by the recursion guard specifically,
        // not by some unrelated failure further along.
        match process_grid_search(&input) {
            Err(InputPluginError::InputPluginFailed(msg)) => {
                assert!(msg.contains("grid_search"), "unexpected message: {msg}")
            }
            Err(other) => panic!("expected InputPluginFailed, got: {}", other),
            Ok(_) => panic!("process_grid_search should return an error"),
        }
    }
}