dypdl_heuristic_search/search_algorithm/data_structure/transition/transition_with_custom_cost.rs

1use super::super::successor_generator::SuccessorGenerator;
2use super::TransitionWithId;
3use core::ops::Deref;
4use dypdl::variable_type::Numeric;
5use dypdl::{CostExpression, StateFunctionCache, StateFunctions, Transition, TransitionInterface};
6use std::fmt::Debug;
7
8/// Transition with a customized cost expression.
9#[derive(Debug, PartialEq, Clone, Default)]
10pub struct TransitionWithCustomCost {
11    /// Transition.
12    pub transition: Transition,
13    /// Customized cost expression.
14    pub custom_cost: CostExpression,
15}
16
17impl TransitionInterface for TransitionWithCustomCost {
18    #[inline]
19    fn is_applicable<S: dypdl::StateInterface>(
20        &self,
21        state: &S,
22        function_cache: &mut StateFunctionCache,
23        state_functions: &StateFunctions,
24        registry: &dypdl::TableRegistry,
25    ) -> bool {
26        self.transition
27            .is_applicable(state, function_cache, state_functions, registry)
28    }
29
30    #[inline]
31    fn apply<S: dypdl::StateInterface, T: From<dypdl::State>>(
32        &self,
33        state: &S,
34        function_cache: &mut StateFunctionCache,
35        state_functions: &StateFunctions,
36        registry: &dypdl::TableRegistry,
37    ) -> T {
38        self.transition
39            .apply(state, function_cache, state_functions, registry)
40    }
41
42    #[inline]
43    fn eval_cost<U: Numeric, T: dypdl::StateInterface>(
44        &self,
45        cost: U,
46        state: &T,
47        function_cache: &mut StateFunctionCache,
48        state_functions: &StateFunctions,
49        registry: &dypdl::TableRegistry,
50    ) -> U {
51        self.transition
52            .eval_cost(cost, state, function_cache, state_functions, registry)
53    }
54}
55
56impl From<TransitionWithCustomCost> for Transition {
57    fn from(transition: TransitionWithCustomCost) -> Self {
58        transition.transition
59    }
60}
61
62impl<U, R> SuccessorGenerator<TransitionWithCustomCost, U, R>
63where
64    U: Deref<Target = TransitionWithId<TransitionWithCustomCost>>
65        + Clone
66        + From<TransitionWithId<TransitionWithCustomCost>>,
67    R: Deref<Target = dypdl::Model>,
68{
69    /// Returns a successor generator returning applicable transitions with customized cost expressions.
70    pub fn from_model_with_custom_costs(
71        model: R,
72        custom_costs: &[CostExpression],
73        forced_custom_costs: &[CostExpression],
74        backward: bool,
75    ) -> Self {
76        let forced_transitions = if backward {
77            &model.backward_forced_transitions
78        } else {
79            &model.forward_forced_transitions
80        };
81        let forced_transitions = forced_transitions
82            .iter()
83            .zip(forced_custom_costs)
84            .enumerate()
85            .map(|(id, (t, c))| {
86                U::from(TransitionWithId {
87                    transition: TransitionWithCustomCost {
88                        transition: t.clone(),
89                        custom_cost: c.simplify(&model.table_registry),
90                    },
91                    forced: true,
92                    id,
93                })
94            })
95            .collect();
96
97        let transitions = if backward {
98            &model.backward_transitions
99        } else {
100            &model.forward_transitions
101        };
102        let transitions = transitions
103            .iter()
104            .zip(custom_costs)
105            .enumerate()
106            .map(|(id, (t, c))| {
107                U::from(TransitionWithId {
108                    transition: TransitionWithCustomCost {
109                        transition: t.clone(),
110                        custom_cost: c.simplify(&model.table_registry),
111                    },
112                    forced: false,
113                    id,
114                })
115            })
116            .collect();
117
118        SuccessorGenerator::new(forced_transitions, transitions, backward, model)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use dypdl::expression::*;
    use dypdl::prelude::*;
    use std::rc::Rc;

    #[test]
    fn transition_with_custom_cost_to_transition() {
        let mut transition = Transition::new("transition");
        transition.set_cost(IntegerExpression::Cost + 1);
        let transition_with_custom_cost = TransitionWithCustomCost {
            transition: transition.clone(),
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };
        assert_eq!(Transition::from(transition_with_custom_cost), transition);
    }

    #[test]
    fn is_applicable() {
        let mut model = Model::default();
        let var = model.add_integer_variable("v", 0);
        assert!(var.is_ok());
        let var = var.unwrap();

        let mut transition = Transition::new("transition");
        // `v` starts at 0, so `v <= 1` holds.
        transition.add_precondition(Condition::comparison_i(ComparisonOperator::Le, var, 1));
        let transition = TransitionWithCustomCost {
            transition,
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };
        let state = model.target;
        let mut function_cache = StateFunctionCache::new(&model.state_functions);
        assert!(transition.is_applicable(
            &state,
            &mut function_cache,
            &model.state_functions,
            &model.table_registry
        ));
    }

    #[test]
    fn is_not_applicable() {
        let mut model = Model::default();
        let var = model.add_integer_variable("v", 0);
        assert!(var.is_ok());
        let var = var.unwrap();

        let mut transition = Transition::new("transition");
        // `v` starts at 0, so `v >= 1` does not hold and the transition must be rejected.
        transition.add_precondition(Condition::comparison_i(ComparisonOperator::Ge, var, 1));
        let transition = TransitionWithCustomCost {
            transition,
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };
        let mut function_cache = StateFunctionCache::new(&model.state_functions);
        assert!(!transition.is_applicable(
            &model.target,
            &mut function_cache,
            &model.state_functions,
            &model.table_registry
        ));
    }

    #[test]
    fn apply() {
        let mut model = Model::default();
        let var1 = model.add_integer_variable("var1", 0);
        assert!(var1.is_ok());
        let var1 = var1.unwrap();
        let var2 = model.add_integer_variable("var2", 0);
        assert!(var2.is_ok());

        let mut transition = Transition::new("transition");
        let result = transition.add_effect(var1, var1 + 1);
        assert!(result.is_ok());
        let transition = TransitionWithCustomCost {
            transition,
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };

        let mut function_cache = StateFunctionCache::new(&model.state_functions);
        let state: State = transition.apply(
            &model.target,
            &mut function_cache,
            &model.state_functions,
            &model.table_registry,
        );
        // Only `var1` has an effect; `var2` must be left untouched.
        assert_eq!(state.get_integer_variable(0), 1);
        assert_eq!(state.get_integer_variable(1), 0);
    }

    #[test]
    fn eval_cost() {
        let model = Model::default();

        let mut transition = Transition::new("transition");
        transition.set_cost(IntegerExpression::Cost + 1);
        let transition = TransitionWithCustomCost {
            transition,
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };
        let mut function_cache = StateFunctionCache::new(&model.state_functions);
        let cost = transition.eval_cost(
            0,
            &model.target,
            &mut function_cache,
            &model.state_functions,
            &model.table_registry,
        );
        // The wrapped transition's cost (`cost + 1`) is used, not the custom cost (`cost + 2`).
        assert_eq!(cost, 1);
    }

    #[test]
    fn from_model_with_custom_costs_forward() {
        let mut model = Model::default();
        let mut transition1 = Transition::new("transition1");
        transition1.set_cost(IntegerExpression::Cost + 1);
        let result = model.add_forward_transition(transition1.clone());
        assert!(result.is_ok());
        let mut transition2 = Transition::new("transition2");
        transition2.set_cost(IntegerExpression::Cost + 2);
        let result = model.add_forward_transition(transition2.clone());
        assert!(result.is_ok());
        let mut transition3 = Transition::new("transition3");
        transition3.set_cost(IntegerExpression::Cost + 3);
        let result = model.add_forward_forced_transition(transition3.clone());
        assert!(result.is_ok());
        let mut transition4 = Transition::new("transition4");
        transition4.set_cost(IntegerExpression::Cost + 4);
        let result = model.add_forward_forced_transition(transition4.clone());
        assert!(result.is_ok());
        // Backward transitions must be ignored when building a forward generator.
        let mut transition5 = Transition::new("transition5");
        transition5.set_cost(IntegerExpression::Cost + 5);
        let result = model.add_backward_transition(transition5.clone());
        assert!(result.is_ok());
        let mut transition6 = Transition::new("transition6");
        transition6.set_cost(IntegerExpression::Cost + 6);
        let result = model.add_backward_forced_transition(transition6.clone());
        assert!(result.is_ok());
        let model = Rc::new(model);

        let custom_costs = [
            CostExpression::Integer(IntegerExpression::Cost + 7),
            CostExpression::Integer(IntegerExpression::Cost + 8),
        ];
        let forced_custom_costs = [
            CostExpression::Integer(IntegerExpression::Cost + 9),
            CostExpression::Integer(IntegerExpression::Cost + 10),
        ];
        let generator = SuccessorGenerator::<_>::from_model_with_custom_costs(
            model.clone(),
            &custom_costs,
            &forced_custom_costs,
            false,
        );

        assert_eq!(generator.model, model);
        assert_eq!(
            generator.transitions,
            vec![
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition1,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 7),
                    },
                    forced: false,
                    id: 0
                }),
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition2,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 8),
                    },
                    forced: false,
                    id: 1
                }),
            ]
        );
        assert_eq!(
            generator.forced_transitions,
            vec![
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition3,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 9),
                    },
                    forced: true,
                    id: 0,
                }),
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition4,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 10),
                    },
                    forced: true,
                    id: 1,
                }),
            ]
        );
    }

    #[test]
    fn from_model_with_custom_costs_backward() {
        let mut model = Model::default();
        let mut transition1 = Transition::new("transition1");
        transition1.set_cost(IntegerExpression::Cost + 1);
        let result = model.add_backward_transition(transition1.clone());
        assert!(result.is_ok());
        let mut transition2 = Transition::new("transition2");
        transition2.set_cost(IntegerExpression::Cost + 2);
        let result = model.add_backward_transition(transition2.clone());
        assert!(result.is_ok());
        let mut transition3 = Transition::new("transition3");
        transition3.set_cost(IntegerExpression::Cost + 3);
        let result = model.add_backward_forced_transition(transition3.clone());
        assert!(result.is_ok());
        let mut transition4 = Transition::new("transition4");
        transition4.set_cost(IntegerExpression::Cost + 4);
        let result = model.add_backward_forced_transition(transition4.clone());
        assert!(result.is_ok());
        // Forward transitions must be ignored when building a backward generator.
        let mut transition5 = Transition::new("transition5");
        transition5.set_cost(IntegerExpression::Cost + 5);
        let result = model.add_forward_transition(transition5.clone());
        assert!(result.is_ok());
        let mut transition6 = Transition::new("transition6");
        transition6.set_cost(IntegerExpression::Cost + 6);
        let result = model.add_forward_forced_transition(transition6.clone());
        assert!(result.is_ok());
        let model = Rc::new(model);

        let custom_costs = [
            CostExpression::Integer(IntegerExpression::Cost + 7),
            CostExpression::Integer(IntegerExpression::Cost + 8),
        ];
        let forced_custom_costs = [
            CostExpression::Integer(IntegerExpression::Cost + 9),
            CostExpression::Integer(IntegerExpression::Cost + 10),
        ];
        let generator = SuccessorGenerator::<_>::from_model_with_custom_costs(
            model.clone(),
            &custom_costs,
            &forced_custom_costs,
            true,
        );

        assert_eq!(generator.model, model);
        assert_eq!(
            generator.transitions,
            vec![
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition1,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 7),
                    },
                    forced: false,
                    id: 0,
                }),
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition2,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 8),
                    },
                    forced: false,
                    id: 1,
                }),
            ]
        );
        assert_eq!(
            generator.forced_transitions,
            vec![
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition3,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 9),
                    },
                    forced: true,
                    id: 0,
                }),
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition4,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 10),
                    },
                    forced: true,
                    id: 1,
                }),
            ]
        );
    }
}