affinitree/pwl/
impl_composition.rs

1//   Copyright 2025 affinitree developers
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15use std::time::{Duration, Instant};
16
17use console::style;
18use indicatif::{ProgressBar, ProgressStyle};
19use itertools::Itertools;
20use log::debug;
21
22use super::afftree::*;
23use crate::linalg::affine::*;
24use crate::pwl::node::AffContent;
25use crate::tree::graph::*;
26
27/// A vistor for the composition of two [``AffTree``]s.
28///
29/// Use cases include logging or to display progress.
30pub trait CompositionVisitor {
31    fn start_composition(&mut self, expected_iterations: usize);
32    fn start_subtree(&mut self, node: TreeIndex);
33    fn finish_subtree(&mut self, n_nodes: usize);
34    fn finish_composition(&mut self);
35}
36
37/// A set of rules that define an algebraic operation over two [``AffTree``] instances.
38pub trait CompositionSchema {
39    /// Creates a new decision node based on the two operands ``original`` and ``context``.
40    /// Here, ``original`` is the the value of the lhs tree and ``context`` of the rhs tree.
41    fn update_decision(original: &AffFunc, context: &AffFunc) -> AffFunc;
42
43    /// Creates a new terminal node based on the two operands ``original`` and ``context``.
44    /// Here, ``original`` is the the value of the lhs tree and ``context`` of the rhs tree.
45    fn update_terminal(original: &AffFunc, context: &AffFunc) -> AffFunc;
46
47    /// A filter to stop descending an edge in lhs.
48    fn explore<const K: usize>(context: &AffTree<K>, parent: TreeIndex, child: TreeIndex) -> bool;
49}
50
51/// A [``CompositionVisitor``] with no operations.
52#[derive(Clone, Debug)]
53pub struct NoOpVis {}
54
55impl CompositionVisitor for NoOpVis {
56    fn start_composition(&mut self, _: usize) {}
57
58    fn start_subtree(&mut self, _: TreeIndex) {}
59
60    fn finish_subtree(&mut self, _: usize) {}
61
62    fn finish_composition(&mut self) {}
63}
64
/// A [``CompositionVisitor``] which displays a progress bar of the current state
/// of the composition at the console.
///
/// The bar advances by one step for every finished subtree and is cleared when the
/// composition finishes.
#[derive(Clone, Debug)]
pub struct CompositionConsole {
    /// Progress bar; a hidden placeholder until ``start_composition`` replaces it.
    pb: ProgressBar,
    /// Instant at which the most recent composition started.
    /// NOTE(review): written but never read in this file — confirm whether it is still needed.
    timer: Instant,
    /// Expected number of iterations announced by the last ``start_composition`` call.
    /// NOTE(review): written but never read in this file.
    len: usize,
}
73
74impl Default for CompositionConsole {
75    fn default() -> Self {
76        Self::new()
77    }
78}
79
80impl CompositionConsole {
81    pub fn new() -> CompositionConsole {
82        CompositionConsole {
83            pb: ProgressBar::hidden(),
84            timer: Instant::now(),
85            len: 0,
86        }
87    }
88}
89
90impl CompositionVisitor for CompositionConsole {
91    fn start_composition(&mut self, expected_iterations: usize) {
92        self.pb = ProgressBar::new(expected_iterations as u64);
93        let sty = ProgressStyle::default_bar()
94            .template(&format!(
95                "{: >12} {}",
96                style("Building").cyan().bold(),
97                "[{bar:25}] {pos:>2}/{len:2} ({elapsed})"
98            ))
99            .unwrap()
100            .progress_chars("=> ");
101        self.pb.set_style(sty.clone());
102        self.pb.enable_steady_tick(Duration::from_secs(5));
103
104        self.timer = Instant::now();
105        self.len = expected_iterations;
106    }
107
108    fn start_subtree(&mut self, _node: TreeIndex) {}
109
110    fn finish_subtree(&mut self, _n_nodes: usize) {
111        self.pb.inc(1);
112    }
113
114    fn finish_composition(&mut self) {
115        self.pb.finish_and_clear();
116    }
117}
118
119/// A [``CompositionSchema``] for mathematical function composition.
120#[derive(Clone, Debug)]
121pub struct FunctionComposition {}
122
123impl CompositionSchema for FunctionComposition {
124    fn update_decision(original: &AffFunc, context: &AffFunc) -> AffFunc {
125        AffFunc::from_mats(
126            original.mat.dot(&context.mat),
127            -original.mat.dot(&context.bias) + &original.bias,
128        )
129    }
130
131    fn update_terminal(original: &AffFunc, context: &AffFunc) -> AffFunc {
132        original.compose(context)
133    }
134
135    fn explore<const K: usize>(
136        _context: &AffTree<K>,
137        _parent: TreeIndex,
138        _child: TreeIndex,
139    ) -> bool {
140        true
141    }
142}
143
144/// A [``CompositionSchema``] for mathematical function composition which performs on-the-fly infeasible elimination.
145#[derive(Clone, Debug)]
146pub struct FunctionCompositionInfeasible {}
147
148impl CompositionSchema for FunctionCompositionInfeasible {
149    fn update_decision(original: &AffFunc, context: &AffFunc) -> AffFunc {
150        AffFunc::from_mats(
151            original.mat.dot(&context.mat),
152            -original.mat.dot(&context.bias) + &original.bias,
153        )
154    }
155
156    fn update_terminal(original: &AffFunc, context: &AffFunc) -> AffFunc {
157        original.compose(context)
158    }
159
160    fn explore<const K: usize>(context: &AffTree<K>, parent: TreeIndex, child: TreeIndex) -> bool {
161        context.is_edge_feasible(parent, child)
162    }
163}
164
165/// # Composition
166impl<const K: usize> AffTree<K> {
167    /// Performs mathematical function composition of ``self`` and ``other``.
168    ///
169    /// # Example
170    ///
171    /// ```rust
172    /// use affinitree::{aff, pwl::afftree::AffTree};
173    /// use ndarray::arr1;
174    ///
175    /// let mut tree0 = AffTree::<2>::from_aff(aff!([[1., 0.]] + [2.]));
176    /// tree0.add_child_node(0, 0, aff!([[2, 0], [0, 2]] + [1, 0]));
177    /// tree0.add_child_node(0, 1, aff!([[2, 0], [0, 2]] + [0, 1]));
178    ///
179    /// let mut tree1 = AffTree::<2>::from_aff(aff!([[-0.5, 0.]] + [-1.]));
180    /// tree1.add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0]));
181    /// tree1.add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5]));
182    ///
183    /// let mut comp = tree0.clone();
184    /// comp.compose::<false, false>(&tree1);
185    ///
186    /// assert_eq!(
187    ///     tree1.evaluate(&tree0.evaluate(&arr1(&[2., -7.])).unwrap()).unwrap(),
188    ///     comp.evaluate(&arr1(&[2., -7.])).unwrap()
189    /// );
190    /// ```
191    #[inline]
192    pub fn compose<const PRUNE: bool, const VERBOSE: bool>(&mut self, other: &AffTree<K>) {
193        if PRUNE && VERBOSE {
194            AffTree::<K>::generic_composition_inplace(
195                other,
196                self,
197                self.tree.terminal_indices().collect_vec(),
198                FunctionCompositionInfeasible {},
199                CompositionConsole::new(),
200            );
201        } else if PRUNE && !VERBOSE {
202            AffTree::<K>::generic_composition_inplace(
203                other,
204                self,
205                self.tree.terminal_indices().collect_vec(),
206                FunctionCompositionInfeasible {},
207                NoOpVis {},
208            );
209        } else if !PRUNE && VERBOSE {
210            AffTree::<K>::generic_composition_inplace(
211                other,
212                self,
213                self.tree.terminal_indices().collect_vec(),
214                FunctionComposition {},
215                CompositionConsole::new(),
216            );
217        } else {
218            AffTree::<K>::generic_composition_inplace(
219                other,
220                self,
221                self.tree.terminal_indices().collect_vec(),
222                FunctionComposition {},
223                NoOpVis {},
224            );
225        }
226    }
227
228    /// Applies the algebraic operation defined by ``schema`` to ``rhs`` and ``lhs``.
229    /// This operation is applied inplace modifying ``rhs``.
230    /// It is implemented as one tree traversal over ``lhs`` for each terminal in ``rhs``.
231    pub fn generic_composition_inplace<I, C, V>(
232        lhs: &AffTree<K>,
233        rhs: &mut AffTree<K>,
234        terminals: I,
235        _schema: C,
236        mut visitor: V,
237    ) where
238        I: IntoIterator<Item = TreeIndex>,
239        C: CompositionSchema,
240        V: CompositionVisitor,
241    {
242        let iter = terminals.into_iter();
243
244        visitor.start_composition(iter.size_hint().0);
245
246        for terminal_idx in iter {
247            debug!("Processing terminal (rhs): id={:?}", terminal_idx);
248            let terminal = rhs
249                .tree
250                .tree_node(terminal_idx)
251                .expect("All nodes of the iterator should be terminals in the rhs tree");
252            assert!(
253                terminal.isleaf,
254                "Terminal node of given iterator should be a leaf"
255            );
256
257            let terminal_aff: AffFuncBase<FunctionT, ndarray::OwnedRepr<f64>> =
258                terminal.value.aff.clone();
259            let new_root_aff = match lhs.tree.get_root().isleaf {
260                true => C::update_terminal(&lhs.tree.get_root().value.aff, &terminal_aff),
261                false => C::update_decision(&lhs.tree.get_root().value.aff, &terminal_aff),
262            };
263            debug!("New terminal value: {}", new_root_aff);
264
265            visitor.start_subtree(terminal_idx);
266            let mut n_nodes = 0;
267
268            // Update the stored function to the new predicate while keeping the cache intact
269            rhs.update_node(terminal_idx, new_root_aff).unwrap();
270
271            let mut stack: Vec<(TreeIndex, TreeIndex)> =
272                Vec::with_capacity(lhs.tree.num_terminals());
273            stack.push((lhs.tree.get_root_idx(), terminal_idx));
274
275            while let Some((parent0_idx, parent1_idx)) = stack.pop() {
276                let mut created_children = 0;
277                let mut skipped_children = 0;
278                let mut label_created = None;
279
280                for edg in lhs.tree.children(parent0_idx) {
281                    let child0_idx = edg.target_idx;
282                    let child0 = edg.target_value;
283                    let label = edg.label;
284
285                    let is_leaf = lhs.tree.is_leaf(child0_idx).unwrap();
286                    debug!(
287                        "Processing node from left tree with id={:?} ({})",
288                        child0_idx,
289                        if is_leaf { "T" } else { "D" }
290                    );
291                    let child1_aff = match is_leaf {
292                        true => C::update_terminal(&child0.aff, &terminal_aff),
293                        false => C::update_decision(&child0.aff, &terminal_aff),
294                    };
295                    debug!("New node value: {}", child1_aff);
296
297                    let child1_idx = rhs
298                        .tree
299                        .add_child_node(parent1_idx, label, AffContent::new(child1_aff))
300                        .unwrap();
301
302                    // Test feasibility of newly created edge, remove if infeasible
303                    if C::explore(rhs, parent1_idx, child1_idx) {
304                        stack.push((child0_idx, child1_idx));
305                        created_children += 1;
306                        n_nodes += 1;
307                        label_created = Some(label);
308                    } else {
309                        skipped_children += 1;
310                        rhs.tree.remove_child(parent1_idx, label);
311                    }
312                }
313
314                // In the case of no children remove_child already cleans up the tree
315                if created_children == 1 && created_children + skipped_children == K {
316                    debug!("Forwarding node");
317                    // Move affine function to parent node and clean up tree
318                    rhs.tree
319                        .merge_child_with_parent(parent1_idx, label_created.unwrap())
320                        .unwrap();
321                }
322            }
323
324            visitor.finish_subtree(n_nodes);
325        }
326        visitor.finish_composition();
327    }
328}
329
#[cfg(test)]
mod tests {
    use ndarray::arr1;

    use super::*;
    use crate::{aff, path};

    // Installs a debug-level logger writing to stdout; messages from `minilp` are limited
    // to errors. The result of `try_init` is discarded, so repeated calls are harmless.
    fn init_logger() {
        use env_logger::Target;
        use log::LevelFilter;

        let _ = env_logger::builder()
            .is_test(true)
            .filter_module("minilp", LevelFilter::Error)
            .target(Target::Stdout)
            .filter_level(LevelFilter::Debug)
            .try_init();
    }

    // Composes two depth-1 trees without pruning through the generic entry point and
    // checks selected node values plus pointwise agreement with sequential evaluation.
    #[test]
    fn test_compose() {
        init_logger();

        let mut tree0 = AffTree::<2>::from_aff(aff!([[1., 0.]] + [2.]));
        tree0
            .add_child_node(0, 0, aff!([[2, 0], [0, 2]] + [1, 0]))
            .unwrap();
        tree0
            .add_child_node(0, 1, aff!([[2, 0], [0, 2]] + [0, 1]))
            .unwrap();

        let mut tree1 = AffTree::<2>::from_aff(aff!([[-0.5, 0.]] + [-1.]));
        tree1
            .add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0]))
            .unwrap();
        tree1
            .add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5]))
            .unwrap();

        let mut comp = tree0.clone();

        let terminals = comp.tree.terminal_indices().collect_vec();
        AffTree::generic_composition_inplace(
            &tree1,
            &mut comp,
            terminals,
            FunctionComposition {},
            CompositionConsole::new(),
        );

        assert_eq!(
            comp.tree.node_value(path!(comp.tree, 0)).unwrap().aff,
            aff!([-1, 0] + -0.5)
        );

        assert_eq!(
            comp.tree.node_value(path!(comp.tree, 1, 1)).unwrap().aff,
            aff!([[-6, 0], [0, -6]] + [0, 2])
        );

        // The composed tree must agree with evaluating tree0 and then tree1.
        assert_eq!(
            tree1
                .evaluate(&tree0.evaluate(&arr1(&[2., -7.])).unwrap())
                .unwrap(),
            comp.evaluate(&arr1(&[2., -7.])).unwrap()
        );

        assert_eq!(
            tree1
                .evaluate(&tree0.evaluate(&arr1(&[-1., -0.3])).unwrap())
                .unwrap(),
            comp.evaluate(&arr1(&[-1., -0.3])).unwrap()
        );

        assert_eq!(
            tree1
                .evaluate(&tree0.evaluate(&arr1(&[12., 3.])).unwrap())
                .unwrap(),
            comp.evaluate(&arr1(&[12., 3.])).unwrap()
        );

        // full binary tree 1 + 2 + 4
        assert_eq!(comp.tree.len(), 7);
    }

    // Same trees as `test_compose`, but composed with pruning enabled.
    #[test]
    fn test_compose_infeasible() {
        init_logger();

        let mut dd0 = AffTree::<2>::from_aff(aff!([[1., 0.]] + [2.]));
        dd0.add_child_node(0, 0, aff!([[2, 0], [0, 2]] + [1, 0]))
            .unwrap();
        dd0.add_child_node(0, 1, aff!([[2, 0], [0, 2]] + [0, 1]))
            .unwrap();

        let mut dd1 = AffTree::<2>::from_aff(aff!([[-0.5, 0.]] + [-1.]));
        dd1.add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0]))
            .unwrap();
        dd1.add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5]))
            .unwrap();

        dd0.compose::<true, false>(&dd1);

        // one node is infeasible and should be eliminated
        // and one forwarded
        assert_eq!(dd0.tree.len(), 5);
        // dd1 should remain unchanged
        assert_eq!(dd1.tree.len(), 3);
    }

    // Shorthand for the affine function stored at the node that `path!` reaches
    // from the root of `$tree.tree` via the given labels.
    macro_rules! value_at {
        ($tree:expr , $( $label:literal ),* ) => {
            $tree.tree.node_value(path!($tree.tree, $( $label ),* )).unwrap().aff
        }
    }

    // Composition without pruning: checks the exact value of every node of the result.
    #[test]
    #[rustfmt::skip]
    fn test_compose_exact() {
        init_logger();

        let mut tree0 = AffTree::<2>::from_aff(aff!([1, 0, 0] + 2));
        tree0.add_child_node(0, 0, aff!([[2, 0, 0], [0, 2, 0]] + [-4, 0])).unwrap();
        tree0.add_child_node(0, 1, aff!([[2, 0, 0], [0, 0, 2]] + [-8, 1])).unwrap();

        let mut tree1 = AffTree::<2>::from_aff(aff!([[0.5, 0.]] + [-1.]));
        tree1.add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0])).unwrap();
        tree1.add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5])).unwrap();

        tree0.compose::<false, false>(&tree1);

        eprintln!("{}", &tree0);

        assert_eq!(
            tree0.tree.node_value(0).unwrap().aff,
            aff!([1, 0, 0] + 2)
        );

        assert_eq!(
            value_at!(tree0, 0),
            aff!([1, 0, 0] + 1)
        );

        assert_eq!(
            value_at!(tree0, 1),
            aff!([1, 0, 0] + 3)
        );

        assert_eq!(
            value_at!(tree0, 0, 0),
            aff!([[6, 0, 0], [0, 6, 0]] + [-12 + 5, 0])
        );

        assert_eq!(
            value_at!(tree0, 0, 1),
            aff!([[-6, 0, 0], [0, -6, 0]] + [12, 5])
        );

        assert_eq!(
            value_at!(tree0, 1, 0),
            aff!([[6, 0, 0], [0, 0, 6]] + [-24 + 5, 3])
        );

        assert_eq!(
            value_at!(tree0, 1, 1),
            aff!([[-6, 0, 0], [0, 0, -6]] + [24, -3 + 5])
        );
    }

    // Same input as `test_compose_exact` with pruning: each inserted decision keeps a
    // single feasible branch, so the surviving terminal ends up directly below the root.
    #[test]
    #[rustfmt::skip]
    fn test_compose_exact_infeasible() {
        init_logger();

        let mut tree0 = AffTree::<2>::from_aff(aff!([1, 0, 0] + 2));
        tree0.add_child_node(0, 0, aff!([[2, 0, 0], [0, 2, 0]] + [-4, 0])).unwrap();
        tree0.add_child_node(0, 1, aff!([[2, 0, 0], [0, 0, 2]] + [-8, 1])).unwrap();

        let mut tree1 = AffTree::<2>::from_aff(aff!([[0.5, 0.]] + [-1.]));
        tree1.add_child_node(0, 0, aff!([[3, 0], [0, 3]] + [5, 0])).unwrap();
        tree1.add_child_node(0, 1, aff!([[-3, 0], [0, -3]] + [0, 5])).unwrap();

        tree0.compose::<true, false>(&tree1);

        eprintln!("{}", &tree0);

        assert_eq!(
            tree0.tree.node_value(0).unwrap().aff,
            aff!([1, 0, 0] + 2)
        );

        assert_eq!(
            value_at!(tree0, 0),
            aff!([[6, 0, 0], [0, 6, 0]] + [-12 + 5, 0])
        );

        assert_eq!(
            value_at!(tree0, 1),
            aff!([[-6, 0, 0], [0, 0, -6]] + [24, -3 + 5])
        );
    }

    // Composes a tree built by `from_slice` with a tree built from a hypercube polytope.
    // NOTE(review): the expected size of 5 depends on the structure `from_poly` creates.
    #[test]
    fn test_terminal_tree() {
        init_logger();

        let mut dd0 = AffTree::<2>::from_slice(&arr1(&[f64::NAN, 0.5]));

        let dd1 = AffTree::<2>::from_poly(Polytope::hypercube(2, 1.0), AffFunc::identity(2), None)
            .unwrap();

        dd0.compose::<false, false>(&dd1);

        assert_eq!(dd0.len(), 5);
    }
}