1use std::time::{Duration, Instant};
16
17use console::style;
18use indicatif::{ProgressBar, ProgressStyle};
19use itertools::Itertools;
20use log::debug;
21
22use super::afftree::*;
23use crate::linalg::affine::*;
24use crate::pwl::node::AffContent;
25use crate::tree::graph::*;
26
27pub trait CompositionVisitor {
31 fn start_composition(&mut self, expected_iterations: usize);
32 fn start_subtree(&mut self, node: TreeIndex);
33 fn finish_subtree(&mut self, n_nodes: usize);
34 fn finish_composition(&mut self);
35}
36
37pub trait CompositionSchema {
39 fn update_decision(original: &AffFunc, context: &AffFunc) -> AffFunc;
42
43 fn update_terminal(original: &AffFunc, context: &AffFunc) -> AffFunc;
46
47 fn explore<const K: usize>(context: &AffTree<K>, parent: TreeIndex, child: TreeIndex) -> bool;
49}
50
51#[derive(Clone, Debug)]
53pub struct NoOpVis {}
54
55impl CompositionVisitor for NoOpVis {
56 fn start_composition(&mut self, _: usize) {}
57
58 fn start_subtree(&mut self, _: TreeIndex) {}
59
60 fn finish_subtree(&mut self, _: usize) {}
61
62 fn finish_composition(&mut self) {}
63}
64
65#[derive(Clone, Debug)]
68pub struct CompositionConsole {
69 pb: ProgressBar,
70 timer: Instant,
71 len: usize,
72}
73
74impl Default for CompositionConsole {
75 fn default() -> Self {
76 Self::new()
77 }
78}
79
80impl CompositionConsole {
81 pub fn new() -> CompositionConsole {
82 CompositionConsole {
83 pb: ProgressBar::hidden(),
84 timer: Instant::now(),
85 len: 0,
86 }
87 }
88}
89
90impl CompositionVisitor for CompositionConsole {
91 fn start_composition(&mut self, expected_iterations: usize) {
92 self.pb = ProgressBar::new(expected_iterations as u64);
93 let sty = ProgressStyle::default_bar()
94 .template(&format!(
95 "{: >12} {}",
96 style("Building").cyan().bold(),
97 "[{bar:25}] {pos:>2}/{len:2} ({elapsed})"
98 ))
99 .unwrap()
100 .progress_chars("=> ");
101 self.pb.set_style(sty.clone());
102 self.pb.enable_steady_tick(Duration::from_secs(5));
103
104 self.timer = Instant::now();
105 self.len = expected_iterations;
106 }
107
108 fn start_subtree(&mut self, _node: TreeIndex) {}
109
110 fn finish_subtree(&mut self, _n_nodes: usize) {
111 self.pb.inc(1);
112 }
113
114 fn finish_composition(&mut self) {
115 self.pb.finish_and_clear();
116 }
117}
118
119#[derive(Clone, Debug)]
121pub struct FunctionComposition {}
122
123impl CompositionSchema for FunctionComposition {
124 fn update_decision(original: &AffFunc, context: &AffFunc) -> AffFunc {
125 AffFunc::from_mats(
126 original.mat.dot(&context.mat),
127 -original.mat.dot(&context.bias) + &original.bias,
128 )
129 }
130
131 fn update_terminal(original: &AffFunc, context: &AffFunc) -> AffFunc {
132 original.compose(context)
133 }
134
135 fn explore<const K: usize>(
136 _context: &AffTree<K>,
137 _parent: TreeIndex,
138 _child: TreeIndex,
139 ) -> bool {
140 true
141 }
142}
143
144#[derive(Clone, Debug)]
146pub struct FunctionCompositionInfeasible {}
147
148impl CompositionSchema for FunctionCompositionInfeasible {
149 fn update_decision(original: &AffFunc, context: &AffFunc) -> AffFunc {
150 AffFunc::from_mats(
151 original.mat.dot(&context.mat),
152 -original.mat.dot(&context.bias) + &original.bias,
153 )
154 }
155
156 fn update_terminal(original: &AffFunc, context: &AffFunc) -> AffFunc {
157 original.compose(context)
158 }
159
160 fn explore<const K: usize>(context: &AffTree<K>, parent: TreeIndex, child: TreeIndex) -> bool {
161 context.is_edge_feasible(parent, child)
162 }
163}
164
165impl<const K: usize> AffTree<K> {
167 #[inline]
192 pub fn compose<const PRUNE: bool, const VERBOSE: bool>(&mut self, other: &AffTree<K>) {
193 if PRUNE && VERBOSE {
194 AffTree::<K>::generic_composition_inplace(
195 other,
196 self,
197 self.tree.terminal_indices().collect_vec(),
198 FunctionCompositionInfeasible {},
199 CompositionConsole::new(),
200 );
201 } else if PRUNE && !VERBOSE {
202 AffTree::<K>::generic_composition_inplace(
203 other,
204 self,
205 self.tree.terminal_indices().collect_vec(),
206 FunctionCompositionInfeasible {},
207 NoOpVis {},
208 );
209 } else if !PRUNE && VERBOSE {
210 AffTree::<K>::generic_composition_inplace(
211 other,
212 self,
213 self.tree.terminal_indices().collect_vec(),
214 FunctionComposition {},
215 CompositionConsole::new(),
216 );
217 } else {
218 AffTree::<K>::generic_composition_inplace(
219 other,
220 self,
221 self.tree.terminal_indices().collect_vec(),
222 FunctionComposition {},
223 NoOpVis {},
224 );
225 }
226 }
227
228 pub fn generic_composition_inplace<I, C, V>(
232 lhs: &AffTree<K>,
233 rhs: &mut AffTree<K>,
234 terminals: I,
235 _schema: C,
236 mut visitor: V,
237 ) where
238 I: IntoIterator<Item = TreeIndex>,
239 C: CompositionSchema,
240 V: CompositionVisitor,
241 {
242 let iter = terminals.into_iter();
243
244 visitor.start_composition(iter.size_hint().0);
245
246 for terminal_idx in iter {
247 debug!("Processing terminal (rhs): id={:?}", terminal_idx);
248 let terminal = rhs
249 .tree
250 .tree_node(terminal_idx)
251 .expect("All nodes of the iterator should be terminals in the rhs tree");
252 assert!(
253 terminal.isleaf,
254 "Terminal node of given iterator should be a leaf"
255 );
256
257 let terminal_aff: AffFuncBase<FunctionT, ndarray::OwnedRepr<f64>> =
258 terminal.value.aff.clone();
259 let new_root_aff = match lhs.tree.get_root().isleaf {
260 true => C::update_terminal(&lhs.tree.get_root().value.aff, &terminal_aff),
261 false => C::update_decision(&lhs.tree.get_root().value.aff, &terminal_aff),
262 };
263 debug!("New terminal value: {}", new_root_aff);
264
265 visitor.start_subtree(terminal_idx);
266 let mut n_nodes = 0;
267
268 rhs.update_node(terminal_idx, new_root_aff).unwrap();
270
271 let mut stack: Vec<(TreeIndex, TreeIndex)> =
272 Vec::with_capacity(lhs.tree.num_terminals());
273 stack.push((lhs.tree.get_root_idx(), terminal_idx));
274
275 while let Some((parent0_idx, parent1_idx)) = stack.pop() {
276 let mut created_children = 0;
277 let mut skipped_children = 0;
278 let mut label_created = None;
279
280 for edg in lhs.tree.children(parent0_idx) {
281 let child0_idx = edg.target_idx;
282 let child0 = edg.target_value;
283 let label = edg.label;
284
285 let is_leaf = lhs.tree.is_leaf(child0_idx).unwrap();
286 debug!(
287 "Processing node from left tree with id={:?} ({})",
288 child0_idx,
289 if is_leaf { "T" } else { "D" }
290 );
291 let child1_aff = match is_leaf {
292 true => C::update_terminal(&child0.aff, &terminal_aff),
293 false => C::update_decision(&child0.aff, &terminal_aff),
294 };
295 debug!("New node value: {}", child1_aff);
296
297 let child1_idx = rhs
298 .tree
299 .add_child_node(parent1_idx, label, AffContent::new(child1_aff))
300 .unwrap();
301
302 if C::explore(rhs, parent1_idx, child1_idx) {
304 stack.push((child0_idx, child1_idx));
305 created_children += 1;
306 n_nodes += 1;
307 label_created = Some(label);
308 } else {
309 skipped_children += 1;
310 rhs.tree.remove_child(parent1_idx, label);
311 }
312 }
313
314 if created_children == 1 && created_children + skipped_children == K {
316 debug!("Forwarding node");
317 rhs.tree
319 .merge_child_with_parent(parent1_idx, label_created.unwrap())
320 .unwrap();
321 }
322 }
323
324 visitor.finish_subtree(n_nodes);
325 }
326 visitor.finish_composition();
327 }
328}
329
#[cfg(test)]
mod tests {
    use ndarray::arr1;

    use super::*;
    use crate::{aff, path};

    // Routes debug-level log output to stdout for the test harness, while
    // limiting `minilp` to errors. The result of `try_init` is discarded so
    // that calling this from several tests does not fail once a logger is
    // already installed.
    fn init_logger() {
        use env_logger::Target;
        use log::LevelFilter;

        let _ = env_logger::builder()
            .is_test(true)
            .filter_module("minilp", LevelFilter::Error)
            .target(Target::Stdout)
            .filter_level(LevelFilter::Debug)
            .try_init();
    }

    #[test]
    fn test_compose() {
        init_logger();

        // tree0: one decision on x0 with two affine terminals.
        let mut tree0 = AffTree::<2>::from_aff(aff!([[1., 0.]] + [2.]));
        tree0
            .add_child_node(0, 0, aff!([[2, 0], [0, 2]] + [1, 0]))
            .unwrap();
        tree0
            .add_child_node(0, 1, aff!([[2, 0], [0, 2]] + [0, 1]))
            .unwrap();

        // tree1: the left-hand tree that is grafted below each terminal of tree0.
        let mut tree1 = AffTree::<2>::from_aff(aff!([[-0.5, 0.]] + [-1.]));
        tree1
            .add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0]))
            .unwrap();
        tree1
            .add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5]))
            .unwrap();

        let mut comp = tree0.clone();

        // Compose tree1 after tree0 without pruning, reporting to the console.
        let terminals = comp.tree.terminal_indices().collect_vec();
        AffTree::generic_composition_inplace(
            &tree1,
            &mut comp,
            terminals,
            FunctionComposition {},
            CompositionConsole::new(),
        );

        // tree1's root decision [-0.5, 0] + -1 over tree0's terminal 0 (2x + [1, 0]):
        // mat -0.5 * 2 = -1, bias -1 - (-0.5 * 1) = -0.5.
        assert_eq!(
            comp.tree.node_value(path!(comp.tree, 0)).unwrap().aff,
            aff!([-1, 0] + -0.5)
        );

        // tree1's terminal 1 (-3y + [0, 5]) after tree0's terminal 1 (2x + [0, 1]):
        // mat -6, bias -3 * [0, 1] + [0, 5] = [0, 2].
        assert_eq!(
            comp.tree.node_value(path!(comp.tree, 1, 1)).unwrap().aff,
            aff!([[-6, 0], [0, -6]] + [0, 2])
        );

        // The composed tree must agree with evaluating tree0 and then tree1.
        assert_eq!(
            tree1
                .evaluate(&tree0.evaluate(&arr1(&[2., -7.])).unwrap())
                .unwrap(),
            comp.evaluate(&arr1(&[2., -7.])).unwrap()
        );

        assert_eq!(
            tree1
                .evaluate(&tree0.evaluate(&arr1(&[-1., -0.3])).unwrap())
                .unwrap(),
            comp.evaluate(&arr1(&[-1., -0.3])).unwrap()
        );

        assert_eq!(
            tree1
                .evaluate(&tree0.evaluate(&arr1(&[12., 3.])).unwrap())
                .unwrap(),
            comp.evaluate(&arr1(&[12., 3.])).unwrap()
        );

        // 3 original nodes plus 2 new children below each of the 2 former terminals.
        assert_eq!(comp.tree.len(), 7);
    }

    // Same input trees as `test_compose`, but composed with pruning enabled.
    #[test]
    fn test_compose_infeasible() {
        init_logger();

        let mut dd0 = AffTree::<2>::from_aff(aff!([[1., 0.]] + [2.]));
        dd0.add_child_node(0, 0, aff!([[2, 0], [0, 2]] + [1, 0]))
            .unwrap();
        dd0.add_child_node(0, 1, aff!([[2, 0], [0, 2]] + [0, 1]))
            .unwrap();

        let mut dd1 = AffTree::<2>::from_aff(aff!([[-0.5, 0.]] + [-1.]));
        dd1.add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0]))
            .unwrap();
        dd1.add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5]))
            .unwrap();

        // PRUNE = true, VERBOSE = false.
        dd0.compose::<true, false>(&dd1);

        // The unpruned composition of these trees has 7 nodes (see `test_compose`);
        // pruning infeasible edges and forwarding single children leaves 5.
        assert_eq!(dd0.tree.len(), 5);
        // The left-hand tree is only read by the composition.
        assert_eq!(dd1.tree.len(), 3);
    }

    // Shorthand for the affine function stored at the node of `$tree` reached
    // by following the given edge labels from the root.
    macro_rules! value_at {
        ($tree:expr , $( $label:literal ),* ) => {
            $tree.tree.node_value(path!($tree.tree, $( $label ),* )).unwrap().aff
        }
    }

    #[test]
    #[rustfmt::skip]
    fn test_compose_exact() {
        init_logger();

        let mut tree0 = AffTree::<2>::from_aff(aff!([1, 0, 0] + 2));
        tree0.add_child_node(0, 0, aff!([[2, 0, 0], [0, 2, 0]] + [-4, 0])).unwrap();
        tree0.add_child_node(0, 1, aff!([[2, 0, 0], [0, 0, 2]] + [-8, 1])).unwrap();

        let mut tree1 = AffTree::<2>::from_aff(aff!([[0.5, 0.]] + [-1.]));
        tree1.add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0])).unwrap();
        tree1.add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5])).unwrap();

        // No pruning: every terminal of tree0 receives a full copy of tree1.
        tree0.compose::<false, false>(&tree1);

        // Printed to ease debugging when an assertion fails.
        eprintln!("{}", &tree0);

        // The root of the right-hand tree is left untouched.
        assert_eq!(
            tree0.tree.node_value(0).unwrap().aff,
            aff!([1, 0, 0] + 2)
        );

        // Decision [0.5, 0] + -1 over terminal 0: mat [1, 0, 0], bias -1 - 0.5 * -4 = 1.
        assert_eq!(
            value_at!(tree0, 0),
            aff!([1, 0, 0] + 1)
        );

        // Decision [0.5, 0] + -1 over terminal 1: mat [1, 0, 0], bias -1 - 0.5 * -8 = 3.
        assert_eq!(
            value_at!(tree0, 1),
            aff!([1, 0, 0] + 3)
        );

        // Terminals of tree1 are composed as ordinary functions: 3 * (...) + [5, 0] ...
        assert_eq!(
            value_at!(tree0, 0, 0),
            aff!([[6, 0, 0], [0, 6, 0]] + [-12 + 5, 0])
        );

        // ... and -3 * (...) + [0, 5].
        assert_eq!(
            value_at!(tree0, 0, 1),
            aff!([[-6, 0, 0], [0, -6, 0]] + [12, 5])
        );

        assert_eq!(
            value_at!(tree0, 1, 0),
            aff!([[6, 0, 0], [0, 0, 6]] + [-24 + 5, 3])
        );

        assert_eq!(
            value_at!(tree0, 1, 1),
            aff!([[-6, 0, 0], [0, 0, -6]] + [24, -3 + 5])
        );
    }

    #[test]
    #[rustfmt::skip]
    fn test_compose_exact_infeasible() {
        init_logger();

        let mut tree0 = AffTree::<2>::from_aff(aff!([1, 0, 0] + 2));
        tree0.add_child_node(0, 0, aff!([[2, 0, 0], [0, 2, 0]] + [-4, 0])).unwrap();
        tree0.add_child_node(0, 1, aff!([[2, 0, 0], [0, 0, 2]] + [-8, 1])).unwrap();

        let mut tree1 = AffTree::<2>::from_aff(aff!([[0.5, 0.]] + [-1.]));
        tree1.add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0])).unwrap();
        tree1.add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5])).unwrap();

        // Same trees as `test_compose_exact`, now with pruning.
        tree0.compose::<true, false>(&tree1);

        // Printed to ease debugging when an assertion fails.
        eprintln!("{}", &tree0);

        assert_eq!(
            tree0.tree.node_value(0).unwrap().aff,
            aff!([1, 0, 0] + 2)
        );

        // Only one child survived below each former terminal, so the intermediate
        // decision was forwarded: node 0 now holds the value of the unpruned (0, 0) ...
        assert_eq!(
            value_at!(tree0, 0),
            aff!([[6, 0, 0], [0, 6, 0]] + [-12 + 5, 0])
        );

        // ... and node 1 the value of the unpruned (1, 1).
        assert_eq!(
            value_at!(tree0, 1),
            aff!([[-6, 0, 0], [0, 0, -6]] + [24, -3 + 5])
        );
    }

    #[test]
    fn test_terminal_tree() {
        init_logger();

        // NOTE(review): presumably a tree consisting of a single terminal built from
        // these values (the test name suggests so); confirm `from_slice` semantics.
        let mut dd0 = AffTree::<2>::from_slice(&arr1(&[f64::NAN, 0.5]));

        let dd1 = AffTree::<2>::from_poly(Polytope::hypercube(2, 1.0), AffFunc::identity(2), None)
            .unwrap();

        dd0.compose::<false, false>(&dd1);

        assert_eq!(dd0.len(), 5);
    }
}