Skip to main content

ad_ess/
rts.rs

1use rug::Complete;
2use rug::Integer;
3use rug::Rational;
4
5use crate::trellis::Trellis;
6use crate::trellis_utils;
7use crate::utils;
8
9pub struct RTS {
10    pub trellis: Trellis,
11}
12
13impl RTS {
14    /// Returns an [RTS] instance which encodes at least `num_bits` bits
15    ///
16    /// The smallest possible trellis that encodes `num_bits` bits is used, in
17    /// some cases this trellis is capable of encoding more than `num_bits` bits.
18    pub fn new(num_bits: usize, n_max: usize, weights: &[usize]) -> RTS {
19        let trellis = trellis_utils::reverse_trellis_upto_num_sequences(
20            Integer::u_pow_u(2, num_bits as u32).complete(),
21            n_max,
22            weights,
23        )
24        .unwrap();
25        RTS { trellis }
26    }
27
28    /// Returns the amplitude value for a given weight index
29    fn weight_idx_to_amplitude(weight_index: usize) -> usize {
30        weight_index * 2 + 1
31    }
32    /// Returns the weight index for a given amplitude value
33    fn amplitude_to_weight_idx(amplitude: usize) -> usize {
34        (amplitude - 1) / 2
35    }
36    /// Replaces the amplitudes in a sequence by their weight indexes
37    fn amplitude_seq_to_weight_idx_seq(amplitude_sequence: &[usize]) -> Vec<usize> {
38        amplitude_sequence.iter().map(|a| (a - 1) / 2).collect()
39    }
40}
41
42impl RTS {
43    /// Returns the number of sequences that can be encoded / decoded
44    pub fn num_sequences(&self) -> Integer {
45        let n_max = self.trellis.n_max;
46        self.trellis.get_stage(n_max).iter().sum()
47    }
48    /// Returns the number of bits that can be encoded / decoded
49    pub fn num_bits(&self) -> u32 {
50        self.num_sequences().significant_bits() - 1
51    }
52    /// Returns the weights used by the internal trellis
53    pub fn get_weights(&self) -> Vec<usize> {
54        self.trellis.get_weights()
55    }
56    /// Returns the distribution [RTS] is optimizing for
57    pub fn get_distribution(&self, res_factor: f32) -> Vec<f32> {
58        utils::distribution_from_weights(&self.get_weights(), res_factor)
59    }
60    /// Returns the amplitude sequence for a given index
61    pub fn sequence_for_index(&self, index: &Integer) -> Vec<usize> {
62        assert!(index < &self.num_sequences(), "Index out of range!");
63
64        let n_max = self.trellis.n_max;
65        let mut wl_path = vec![0usize; n_max + 1];
66
67        let mut lower_nodes_sum = Integer::from(0);
68        for (wl_idx, node_value) in self.trellis.get_stage(n_max).iter().enumerate() {
69            lower_nodes_sum += node_value;
70            if &lower_nodes_sum > index {
71                wl_path[n_max] = self.trellis.get_weight_levels()[wl_idx];
72                lower_nodes_sum -= node_value;
73                break;
74            }
75        }
76        let mut local_index = index - lower_nodes_sum;
77        let mut weight_idx_seq = vec![0usize; n_max];
78        for stage in (0..n_max).rev() {
79            lower_nodes_sum = Integer::from(0);
80            // caching predecessors may improve speed
81            for (w_idx, pred_wl) in self.trellis.get_predecessors(wl_path[stage + 1]) {
82                let node_value = self.trellis.get(stage, pred_wl);
83
84                lower_nodes_sum += &node_value;
85                if lower_nodes_sum > local_index {
86                    wl_path[stage] = pred_wl;
87                    weight_idx_seq[stage] = w_idx;
88                    lower_nodes_sum -= &node_value;
89                    break;
90                }
91            }
92            local_index -= lower_nodes_sum;
93        }
94        weight_idx_seq
95            .iter()
96            .map(|&weight_idx| RTS::weight_idx_to_amplitude(weight_idx))
97            .collect()
98    }
99    /// Returns the index for a given amplitude sequence
100    pub fn index_for_sequence(&self, amplitude_sequence: &[usize]) -> Integer {
101        let n_max = self.trellis.n_max;
102
103        let weight_idx_seq = RTS::amplitude_seq_to_weight_idx_seq(amplitude_sequence);
104        let weights = self.trellis.get_weights();
105        let weight_seq: Vec<usize> = weight_idx_seq.iter().map(|&w_idx| weights[w_idx]).collect();
106        let wl_path = utils::cumsum(&weight_seq);
107
108        let num_lower_end_nodes = self.trellis.get_weight_level_index(wl_path[n_max]);
109
110        let mut index: Integer = self
111            .trellis
112            .get_stage(self.trellis.n_max)
113            .iter()
114            .take(num_lower_end_nodes)
115            .sum();
116
117        for (idx, (&weight_idx, wl_transition)) in weight_idx_seq
118            .iter()
119            .zip(wl_path.windows(2))
120            .enumerate()
121            .rev()
122        {
123            let stage = idx + 1;
124            if let &[predecessor_wl, wl] = wl_transition {
125                self.trellis
126                    .get_predecessors(wl)
127                    .iter()
128                    .take_while(|(possible_weight_idx, possible_predecessor_wl)| {
129                        *possible_predecessor_wl <= predecessor_wl
130                            && *possible_weight_idx != weight_idx
131                    })
132                    .for_each(|(_, possible_predecessor_wl)| {
133                        index += self.trellis.get(stage - 1, *possible_predecessor_wl);
134                    });
135            } else {
136                panic!("`window(2)` produced a window of length != 2");
137            }
138        }
139
140        index
141    }
142    fn count_amplitude_in_stage(
143        &self,
144        amplitude: usize,
145        stage: usize,
146        first_abandoned_seq: &[usize],
147    ) -> Integer {
148        let the_w_idx = RTS::amplitude_to_weight_idx(amplitude);
149        let the_weight = self.trellis.get_weight(the_w_idx);
150        let the_stage = stage;
151
152        let n_max = self.trellis.n_max;
153
154        let fas_w_idxs = RTS::amplitude_seq_to_weight_idx_seq(first_abandoned_seq);
155        let fas_weights: Vec<usize> = fas_w_idxs
156            .iter()
157            .map(|&w_idx| self.trellis.get_weight(w_idx))
158            .collect();
159        let fas_wls = utils::cumsum(&fas_weights);
160
161        // calculation is split according to the stage in which the considered sequences join the
162        // first abandoned sequence (FAS)
163        let mut amplitude_count = Integer::from(0);
164
165        // sequences that never join the FAS
166        amplitude_count += self
167            .trellis
168            .get_weight_levels()
169            .iter()
170            .take_while(|wl| *wl < fas_wls.last().unwrap())
171            .skip_while(|wl| **wl < the_weight) // ensure `wl - the_weight` is positive
172            .map(|wl| self.trellis.get_or_0(n_max - 1, *wl - the_weight))
173            .sum::<Integer>();
174
175        // sequences that join the FAS between stages `the_stage` + 2 and `n_max`
176        amplitude_count += (the_stage + 2..n_max + 1)
177            .flat_map(|stage| {
178                let ref_fas_w_idxs = &fas_w_idxs;
179                self.trellis
180                    .get_predecessors(fas_wls[stage])
181                    .into_iter()
182                    .take_while(move |(w_idx, _)| *w_idx != ref_fas_w_idxs[stage - 1])
183                    // ensure `predecessor_wl - the_weight >= 0`
184                    .skip_while(|(_, predecessor_wl)| *predecessor_wl < the_weight)
185                    .map(move |(_, predecessor_wl)| {
186                        self.trellis
187                            .get_or_0(stage - 2, predecessor_wl - the_weight)
188                    })
189            })
190            .sum::<Integer>();
191
192        // sequences that join the FAS in stage `the_stage` + 1
193        if (the_weight > fas_weights[the_stage]
194            // this condition depends on the internal ordering used by [Trellis::get_predecessor]
195            // if code is changed there this code might break
196            || (the_weight == fas_weights[the_stage] && the_w_idx > fas_w_idxs[the_stage]))
197        // ensure `fas_wls[the_stage + 1] - the_weight >= 0`
198        && fas_wls[the_stage + 1] >= the_weight
199        {
200            amplitude_count += self
201                .trellis
202                .get_or_0(the_stage, fas_wls[the_stage + 1] - the_weight)
203        }
204
205        // sequences that join the FAS in stage `the_stage` or before
206        if the_w_idx == fas_w_idxs[the_stage] {
207            amplitude_count += (1..the_stage + 1)
208                .flat_map(|stage| {
209                    let ref_fas_w_idxs = &fas_w_idxs;
210                    self.trellis
211                        .get_predecessors(fas_wls[stage])
212                        .into_iter()
213                        .take_while(move |(w_idx, _)| *w_idx != ref_fas_w_idxs[stage - 1])
214                        .map(move |(_, predecessor_wl)| {
215                            self.trellis.get_or_0(stage - 1, predecessor_wl)
216                        })
217                })
218                .sum::<Integer>();
219        }
220
221        amplitude_count
222    }
223    /// Returns the amplitude distribution as a [Vec]
224    ///
225    /// The amplitude distribution is valid if only sequences with indexes
226    /// representable with [self.num_bits] bits are used.
227    pub fn amplitude_distribution(&self) -> Vec<f32> {
228        let num_sequences_used = Integer::u_pow_u(2, self.num_bits()).complete();
229        if num_sequences_used == self.num_sequences() {
230            return self.amplitude_distribution_full_utilization();
231        }
232
233        let first_abandoned_seq = self.sequence_for_index(&num_sequences_used);
234        let n_max = self.trellis.n_max;
235
236        let num_weights = self.trellis.get_weights().len();
237        let amplitudes = (0..num_weights).map(RTS::weight_idx_to_amplitude);
238
239        let amplitude_counts = amplitudes.map(|amplitude| {
240            (0..n_max)
241                .map(|stage| self.count_amplitude_in_stage(amplitude, stage, &first_abandoned_seq))
242                .sum::<Integer>()
243        });
244
245        amplitude_counts
246            .map(|amplitude_count| {
247                Rational::from((&amplitude_count, &num_sequences_used * n_max)).to_f32()
248            })
249            .collect()
250    }
251    /// Returns the amplitude distribution as a [Vec]
252    ///
253    /// The amplitude distribution is valid if all sequences in the trellis
254    /// are used equiprobably.
255    pub fn amplitude_distribution_full_utilization(&self) -> Vec<f32> {
256        let n_max = self.trellis.n_max;
257        let weight_levels = self.trellis.get_weight_levels();
258        let threshold = self.trellis.threshold;
259        let num_sequences = self.num_sequences();
260
261        self.trellis
262            .get_weights()
263            .iter()
264            .map(|weight| {
265                let num_weight_occurences: Integer = weight_levels
266                    .iter()
267                    .take_while(|wl| *wl + *weight <= threshold)
268                    .map(|wl| self.trellis.get(n_max - 1, *wl))
269                    .sum();
270
271                Rational::from((&num_weight_occurences, &num_sequences)).to_f32()
272            })
273            .collect()
274    }
275    // /// Returns the average energy
276    // ///
277    // /// Assumes only indexes representable with [self.num_bits] bits are used.
278    // pub fn average_energy(&self) -> f32 {
279    // let amplitude_distribution = self.amplitude_distribution();
280    // amplitude_distribution
281    // .iter()
282    // .enumerate()
283    // .map(|(w_idx, p)| (RTS::weight_idx_to_amplitude(w_idx) as f32, p))
284    // .map(|(a, p)| a * a * p) // expected value of energy == squared amplitude * probability
285    // .sum::<f32>()
286    // }
287}