1use super::super::successor_generator::SuccessorGenerator;
2use super::TransitionWithId;
3use core::ops::Deref;
4use dypdl::variable_type::Numeric;
5use dypdl::{CostExpression, StateFunctionCache, StateFunctions, Transition, TransitionInterface};
6use std::fmt::Debug;
7
8#[derive(Debug, PartialEq, Clone, Default)]
10pub struct TransitionWithCustomCost {
11 pub transition: Transition,
13 pub custom_cost: CostExpression,
15}
16
17impl TransitionInterface for TransitionWithCustomCost {
18 #[inline]
19 fn is_applicable<S: dypdl::StateInterface>(
20 &self,
21 state: &S,
22 function_cache: &mut StateFunctionCache,
23 state_functions: &StateFunctions,
24 registry: &dypdl::TableRegistry,
25 ) -> bool {
26 self.transition
27 .is_applicable(state, function_cache, state_functions, registry)
28 }
29
30 #[inline]
31 fn apply<S: dypdl::StateInterface, T: From<dypdl::State>>(
32 &self,
33 state: &S,
34 function_cache: &mut StateFunctionCache,
35 state_functions: &StateFunctions,
36 registry: &dypdl::TableRegistry,
37 ) -> T {
38 self.transition
39 .apply(state, function_cache, state_functions, registry)
40 }
41
42 #[inline]
43 fn eval_cost<U: Numeric, T: dypdl::StateInterface>(
44 &self,
45 cost: U,
46 state: &T,
47 function_cache: &mut StateFunctionCache,
48 state_functions: &StateFunctions,
49 registry: &dypdl::TableRegistry,
50 ) -> U {
51 self.transition
52 .eval_cost(cost, state, function_cache, state_functions, registry)
53 }
54}
55
56impl From<TransitionWithCustomCost> for Transition {
57 fn from(transition: TransitionWithCustomCost) -> Self {
58 transition.transition
59 }
60}
61
62impl<U, R> SuccessorGenerator<TransitionWithCustomCost, U, R>
63where
64 U: Deref<Target = TransitionWithId<TransitionWithCustomCost>>
65 + Clone
66 + From<TransitionWithId<TransitionWithCustomCost>>,
67 R: Deref<Target = dypdl::Model>,
68{
69 pub fn from_model_with_custom_costs(
71 model: R,
72 custom_costs: &[CostExpression],
73 forced_custom_costs: &[CostExpression],
74 backward: bool,
75 ) -> Self {
76 let forced_transitions = if backward {
77 &model.backward_forced_transitions
78 } else {
79 &model.forward_forced_transitions
80 };
81 let forced_transitions = forced_transitions
82 .iter()
83 .zip(forced_custom_costs)
84 .enumerate()
85 .map(|(id, (t, c))| {
86 U::from(TransitionWithId {
87 transition: TransitionWithCustomCost {
88 transition: t.clone(),
89 custom_cost: c.simplify(&model.table_registry),
90 },
91 forced: true,
92 id,
93 })
94 })
95 .collect();
96
97 let transitions = if backward {
98 &model.backward_transitions
99 } else {
100 &model.forward_transitions
101 };
102 let transitions = transitions
103 .iter()
104 .zip(custom_costs)
105 .enumerate()
106 .map(|(id, (t, c))| {
107 U::from(TransitionWithId {
108 transition: TransitionWithCustomCost {
109 transition: t.clone(),
110 custom_cost: c.simplify(&model.table_registry),
111 },
112 forced: false,
113 id,
114 })
115 })
116 .collect();
117
118 SuccessorGenerator::new(forced_transitions, transitions, backward, model)
119 }
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use dypdl::expression::*;
    use dypdl::prelude::*;
    use std::rc::Rc;

    #[test]
    fn transition_with_custom_cost_to_transition() {
        let mut transition = Transition::new("transition");
        transition.set_cost(IntegerExpression::Cost + 1);
        let transition_with_custom_cost = TransitionWithCustomCost {
            transition: transition.clone(),
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };
        assert_eq!(Transition::from(transition_with_custom_cost), transition);
    }

    #[test]
    fn is_applicable() {
        let mut model = Model::default();
        let var = model.add_integer_variable("v", 0);
        assert!(var.is_ok());
        let var = var.unwrap();

        let mut transition = Transition::new("transition");
        transition.add_precondition(Condition::comparison_i(ComparisonOperator::Le, var, 1));
        let transition = TransitionWithCustomCost {
            transition,
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };
        let state = model.target;
        let mut function_cache = StateFunctionCache::new(&model.state_functions);
        assert!(transition.is_applicable(
            &state,
            &mut function_cache,
            &model.state_functions,
            &model.table_registry
        ));
    }

    #[test]
    fn is_not_applicable() {
        let mut model = Model::default();
        let var = model.add_integer_variable("v", 0);
        assert!(var.is_ok());
        let var = var.unwrap();

        // `v` starts at 0, so the precondition `v >= 1` does not hold in the target state.
        let mut transition = Transition::new("transition");
        transition.add_precondition(Condition::comparison_i(ComparisonOperator::Ge, var, 1));
        let transition = TransitionWithCustomCost {
            transition,
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };
        let mut function_cache = StateFunctionCache::new(&model.state_functions);
        assert!(!transition.is_applicable(
            &model.target,
            &mut function_cache,
            &model.state_functions,
            &model.table_registry
        ));
    }

    #[test]
    fn apply() {
        let mut model = Model::default();
        let var1 = model.add_integer_variable("var1", 0);
        assert!(var1.is_ok());
        let var1 = var1.unwrap();
        let var2 = model.add_integer_variable("var2", 0);
        assert!(var2.is_ok());

        let mut transition = Transition::new("transition");
        let result = transition.add_effect(var1, var1 + 1);
        assert!(result.is_ok());
        let transition = TransitionWithCustomCost {
            transition,
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };

        let mut function_cache = StateFunctionCache::new(&model.state_functions);
        let state: State = transition.apply(
            &model.target,
            &mut function_cache,
            &model.state_functions,
            &model.table_registry,
        );
        assert_eq!(state.get_integer_variable(0), 1);
        assert_eq!(state.get_integer_variable(1), 0);
    }

    #[test]
    fn eval_cost() {
        let model = Model::default();

        let mut transition = Transition::new("transition");
        transition.set_cost(IntegerExpression::Cost + 1);
        let transition = TransitionWithCustomCost {
            transition,
            custom_cost: CostExpression::Integer(IntegerExpression::Cost + 2),
        };
        let mut function_cache = StateFunctionCache::new(&model.state_functions);
        let cost = transition.eval_cost(
            0,
            &model.target,
            &mut function_cache,
            &model.state_functions,
            &model.table_registry,
        );
        // The wrapped transition's cost (`cost + 1`) is used, not the custom cost (`cost + 2`).
        assert_eq!(cost, 1);
    }

    #[test]
    fn from_model_with_custom_costs_forward() {
        let mut model = Model::default();
        let mut transition1 = Transition::new("transition1");
        transition1.set_cost(IntegerExpression::Cost + 1);
        let result = model.add_forward_transition(transition1.clone());
        assert!(result.is_ok());
        let mut transition2 = Transition::new("transition2");
        transition2.set_cost(IntegerExpression::Cost + 2);
        let result = model.add_forward_transition(transition2.clone());
        assert!(result.is_ok());
        let mut transition3 = Transition::new("transition3");
        transition3.set_cost(IntegerExpression::Cost + 3);
        let result = model.add_forward_forced_transition(transition3.clone());
        assert!(result.is_ok());
        let mut transition4 = Transition::new("transition4");
        transition4.set_cost(IntegerExpression::Cost + 4);
        let result = model.add_forward_forced_transition(transition4.clone());
        assert!(result.is_ok());
        let mut transition5 = Transition::new("transition5");
        transition5.set_cost(IntegerExpression::Cost + 5);
        let result = model.add_backward_transition(transition5.clone());
        assert!(result.is_ok());
        let mut transition6 = Transition::new("transition6");
        transition6.set_cost(IntegerExpression::Cost + 6);
        let result = model.add_backward_forced_transition(transition6.clone());
        assert!(result.is_ok());
        let model = Rc::new(model);

        let custom_costs = [
            CostExpression::Integer(IntegerExpression::Cost + 7),
            CostExpression::Integer(IntegerExpression::Cost + 8),
        ];
        let forced_custom_costs = [
            CostExpression::Integer(IntegerExpression::Cost + 9),
            CostExpression::Integer(IntegerExpression::Cost + 10),
        ];
        let generator = SuccessorGenerator::<_>::from_model_with_custom_costs(
            model.clone(),
            &custom_costs,
            &forced_custom_costs,
            false,
        );

        assert_eq!(generator.model, model);
        assert_eq!(
            generator.transitions,
            vec![
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition1,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 7),
                    },
                    forced: false,
                    id: 0
                }),
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition2,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 8),
                    },
                    forced: false,
                    id: 1
                }),
            ]
        );
        assert_eq!(
            generator.forced_transitions,
            vec![
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition3,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 9),
                    },
                    forced: true,
                    id: 0,
                }),
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition4,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 10),
                    },
                    forced: true,
                    id: 1,
                }),
            ]
        );
    }

    #[test]
    fn from_model_with_custom_costs_backward() {
        let mut model = Model::default();
        let mut transition1 = Transition::new("transition1");
        transition1.set_cost(IntegerExpression::Cost + 1);
        let result = model.add_backward_transition(transition1.clone());
        assert!(result.is_ok());
        let mut transition2 = Transition::new("transition2");
        transition2.set_cost(IntegerExpression::Cost + 2);
        let result = model.add_backward_transition(transition2.clone());
        assert!(result.is_ok());
        let mut transition3 = Transition::new("transition3");
        transition3.set_cost(IntegerExpression::Cost + 3);
        let result = model.add_backward_forced_transition(transition3.clone());
        assert!(result.is_ok());
        let mut transition4 = Transition::new("transition4");
        transition4.set_cost(IntegerExpression::Cost + 4);
        let result = model.add_backward_forced_transition(transition4.clone());
        assert!(result.is_ok());
        let mut transition5 = Transition::new("transition5");
        transition5.set_cost(IntegerExpression::Cost + 5);
        let result = model.add_forward_transition(transition5.clone());
        assert!(result.is_ok());
        let mut transition6 = Transition::new("transition6");
        transition6.set_cost(IntegerExpression::Cost + 6);
        let result = model.add_forward_forced_transition(transition6.clone());
        assert!(result.is_ok());
        let model = Rc::new(model);

        let custom_costs = [
            CostExpression::Integer(IntegerExpression::Cost + 7),
            CostExpression::Integer(IntegerExpression::Cost + 8),
        ];
        let forced_custom_costs = [
            CostExpression::Integer(IntegerExpression::Cost + 9),
            CostExpression::Integer(IntegerExpression::Cost + 10),
        ];
        let generator = SuccessorGenerator::<_>::from_model_with_custom_costs(
            model.clone(),
            &custom_costs,
            &forced_custom_costs,
            true,
        );

        assert_eq!(generator.model, model);
        assert_eq!(
            generator.transitions,
            vec![
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition1,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 7),
                    },
                    forced: false,
                    id: 0,
                }),
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition2,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 8),
                    },
                    forced: false,
                    id: 1,
                }),
            ]
        );
        assert_eq!(
            generator.forced_transitions,
            vec![
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition3,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 9),
                    },
                    forced: true,
                    id: 0,
                }),
                Rc::new(TransitionWithId {
                    transition: TransitionWithCustomCost {
                        transition: transition4,
                        custom_cost: CostExpression::Integer(IntegerExpression::Cost + 10),
                    },
                    forced: true,
                    id: 1,
                }),
            ]
        );
    }
}