Skip to main content

ad_ess/
trellis.rs

1use rug::Integer;
2use std::collections::HashSet;
3
4/// [Trellis] is a data structure to hold a bounded trellis
5///
6/// Trellis nodes hold a [rug::Integer] and are indexed by `stage` (0..n_max)
7/// and `weight_level` (one of the accepted weight levels).
8///
9/// `weight_levels` for each `stage` are returned by [Trellis::get_weight_levels()].
10/// Node values can be read and set by using the [Trellis::get()] and [Trellis::set()]
11/// methods.
12#[derive(Debug)]
13pub struct Trellis {
14    pub threshold: usize,
15    pub n_max: usize,
16    weights: Vec<usize>,
17    weight_levels: Vec<usize>,
18    weight_level_lookup: Vec<i64>,
19    sorted_weights: Vec<(usize, usize)>,
20    data: Vec<Vec<Integer>>,
21}
22
23impl Trellis {
24    /// Create a new [Trellis] instance
25    ///
26    /// The smallest weight must be 0
27    pub fn new(threshold: usize, n_max: usize, weights: &[usize]) -> Trellis {
28        assert_eq!(*weights.iter().min().unwrap(), 0);
29
30        let mut sorted_weights: Vec<(usize, usize)> = weights.iter().copied().enumerate().collect();
31        sorted_weights.sort_by_key(|&w_tuple| w_tuple.1);
32
33        let weight_levels = Trellis::calc_weight_levels(threshold, weights);
34        let weight_level_lookup = Trellis::make_weight_level_lookup(&weight_levels);
35
36        let data = vec![vec![Integer::from(0); weight_levels.len()]; 1 + n_max];
37
38        Trellis {
39            threshold,
40            n_max,
41            weights: weights.to_vec(),
42            weight_levels,
43            weight_level_lookup,
44            sorted_weights,
45            data,
46        }
47    }
48
49    pub fn new_like(trellis: &Trellis) -> Trellis {
50        Trellis::new(trellis.threshold, trellis.n_max, &trellis.get_weights())
51    }
52
53    pub fn new_expandable(n_max: usize, weights: &[usize]) -> Trellis {
54        assert_eq!(*weights.iter().min().unwrap(), 0);
55
56        let mut sorted_weights: Vec<(usize, usize)> = weights.iter().copied().enumerate().collect();
57        sorted_weights.sort_by_key(|&w_tuple| w_tuple.1);
58
59        let max_weight = weights
60            .iter()
61            .max()
62            .expect("Already checked if empty in assert above");
63        let max_threshold = n_max * max_weight;
64        let all_wls = Trellis::calc_weight_levels(max_threshold, weights);
65        let wl_lookup = Trellis::make_weight_level_lookup(&all_wls);
66
67        let data = vec![Vec::<Integer>::new(); 1 + n_max];
68
69        let threshold = all_wls[0];
70
71        Trellis {
72            threshold,
73            n_max,
74            weights: weights.to_vec(),
75            weight_levels: all_wls,
76            weight_level_lookup: wl_lookup,
77            sorted_weights,
78            data,
79        }
80    }
81
82    fn calc_weight_levels(threshold: usize, weights: &[usize]) -> Vec<usize> {
83        let mut weight_levels = HashSet::new();
84        weight_levels.insert(0);
85
86        let mut new_exist = true;
87        //print!("Calculating weight levels ");
88        while new_exist {
89            new_exist = false;
90            let mut new_entries = vec![];
91            for wl in weight_levels.iter() {
92                for w in weights.iter() {
93                    let new_wl = wl + w;
94                    if new_wl <= threshold {
95                        if !weight_levels.contains(&new_wl) {
96                            new_exist = true;
97                        }
98                        new_entries.push(new_wl);
99                    }
100                }
101            }
102            for new_wl in new_entries.into_iter() {
103                weight_levels.insert(new_wl);
104            }
105            //print!(".");
106        }
107
108        // convert to sorted vec
109        let mut weight_levels: Vec<usize> = weight_levels.into_iter().collect();
110        weight_levels.sort();
111
112        //println!(" done");
113        weight_levels
114    }
115
116    fn make_weight_level_lookup(weight_levels: &[usize]) -> Vec<i64> {
117        let max_wl = weight_levels
118            .iter()
119            .max()
120            .expect("weight_levels must be non empty");
121        let mut wl_lookup = vec![-1; *max_wl + 1];
122        for (wl_idx, &wl) in weight_levels.iter().enumerate() {
123            wl_lookup[wl] = wl_idx as i64;
124        }
125        wl_lookup
126    }
127    fn wl_idx_valid(weight_level_index: i64) -> bool {
128        // use not negative as 0 is a valid index
129        !weight_level_index.is_negative()
130    }
131}
132
133impl Trellis {
134    fn wl_valid(&self, weight_level: usize) -> bool {
135        let weight_level_index = self.weight_level_lookup[weight_level];
136        Trellis::wl_idx_valid(weight_level_index)
137    }
138    /// Get function for trellis values
139    pub fn get(&self, stage: usize, weight_level: usize) -> Integer {
140        let weight_level_index = self.weight_level_lookup[weight_level];
141        assert!(Trellis::wl_idx_valid(weight_level_index));
142        self.data[stage][weight_level_index as usize].clone()
143    }
144    /// Get function for trellis values, returns 0 if `weight_level` is invalid
145    pub fn get_or_0(&self, stage: usize, weight_level: usize) -> Integer {
146        if weight_level >= self.weight_level_lookup.len() {
147            return Integer::from(0);
148        }
149        let weight_level_index = self.weight_level_lookup[weight_level];
150
151        if Trellis::wl_idx_valid(weight_level_index) {
152            self.data[stage][weight_level_index as usize].clone()
153        } else {
154            Integer::from(0)
155        }
156    }
157    pub fn get_stage(&self, stage: usize) -> Vec<Integer> {
158        self.data[stage].clone()
159    }
160    /// Set function for trellis values
161    pub fn set(&mut self, stage: usize, weight_level: usize, value: Integer) {
162        let weight_level_index = self.weight_level_lookup[weight_level];
163        assert!(Trellis::wl_idx_valid(weight_level_index));
164        self.data[stage][weight_level_index as usize] = value;
165    }
166    /// Function to add a value to an existing trellis value
167    pub fn add(&mut self, stage: usize, weight_level: usize, value: Integer) {
168        let weight_level_index = self.weight_level_lookup[weight_level];
169        assert!(Trellis::wl_idx_valid(weight_level_index));
170        self.data[stage][weight_level_index as usize] += value;
171    }
172    /// Returns the weight for the given weight index
173    pub fn get_weight(&self, weight_index: usize) -> usize {
174        self.weights[weight_index]
175    }
176    /// Returns the weights of this trellis
177    pub fn get_weights(&self) -> Vec<usize> {
178        self.weights.clone()
179    }
180    /// Returns the weight levels of this trellis
181    pub fn get_weight_levels(&self) -> Vec<usize> {
182        self.weight_levels.clone()
183    }
184    /// Returns the number of weight levels used by the stored data
185    pub fn get_num_weight_levels(&self) -> usize {
186        self.data[0].len()
187    }
188    /// Returns the index of the given weight level
189    pub fn get_weight_level_index(&self, weight_level: usize) -> usize {
190        let weight_level_index = self.weight_level_lookup[weight_level];
191        assert!(Trellis::wl_idx_valid(weight_level_index));
192        weight_level_index as usize
193    }
194    pub fn get_storage_dimensions(&self) -> (usize, usize) {
195        (self.data.len(), self.get_num_weight_levels())
196    }
197    /// Increase the trellis size by one weight level mooving in the provided trellis values
198    ///
199    /// Note: the values are removed from `new_values`
200    pub fn expand_with(&mut self, new_values: &mut Vec<Integer>) -> Result<(), &'static str> {
201        assert_eq!(new_values.len(), self.data.len());
202
203        let current_num_wls = self.get_num_weight_levels();
204        let new_num_wls = current_num_wls + 1;
205
206        let max_num_wls = self.weight_levels.len();
207        if new_num_wls <= max_num_wls {
208            for stage in self.data.iter_mut().rev() {
209                stage.push(new_values.pop().expect("checked lenghts in assert"))
210            }
211            self.threshold = self.weight_levels[new_num_wls - 1];
212            Ok(())
213        } else {
214            Err("Impossible to add another weight level, trellis to small")
215        }
216    }
217    /// Returns a [Vec] of (weight_index, weight_level) for each weight level reachable
218    /// from `weight_level` with a single step
219    ///
220    /// The weight levels are sorted in ascending order.
221    /// Multiple entries with the same weight level are sorted by weight index in ascending order.
222    pub fn get_successors(&self, weight_level: usize) -> Vec<(usize, usize)> {
223        let mut successors = Vec::with_capacity(self.weights.len());
224        for (weight_index, w) in self.sorted_weights.iter() {
225            let possible_successor = weight_level + w;
226            if possible_successor <= self.threshold {
227                successors.push((*weight_index, possible_successor));
228            }
229        }
230        successors
231    }
232    /// Returns a [Vec] of (weight_index, weight_level) for each weight level which
233    /// can reach `weight_level` with a single step
234    ///
235    /// The tuples are sorted in ascending order wrt. the weight_level values.
236    /// Multiple tuples with the same `weight_level` are sorted in descending order wrt. the
237    /// `weight_index`.
238    pub fn get_predecessors(&self, weight_level: usize) -> Vec<(usize, usize)> {
239        let mut predecessors = Vec::with_capacity(self.weights.len());
240        for (weight_index, w) in self.sorted_weights.iter().rev() {
241            if weight_level >= *w {
242                let possible_predecessor = weight_level - w;
243                if self.wl_valid(possible_predecessor) {
244                    predecessors.push((*weight_index, possible_predecessor));
245                }
246            }
247        }
248        predecessors
249    }
250}
251
252impl PartialEq for Trellis {
253    fn eq(&self, other: &Self) -> bool {
254        if self.get_storage_dimensions() != other.get_storage_dimensions()
255            || self.get_weights() != other.get_weights()
256        {
257            return false;
258        }
259        let (i_max, j_max) = self.get_storage_dimensions();
260        for i in 0..i_max {
261            for j in 0..j_max {
262                if self.data[i][j] != other.data[i][j] {
263                    return false;
264                }
265            }
266        }
267        true
268    }
269}
270
271impl Eq for Trellis {}