ddk_trie/
multi_oracle_trie.rs

1//! # MultiOracleTrie
2//! Data structure and functions used to store adaptor signature information
3//! for numerical outcome DLC with t of n oracles where at least t oracles
4//! need to sign the same outcome for the contract to be able to close.
5
6use crate::combination_iterator::CombinationIterator;
7use crate::digit_decomposition::group_by_ignoring_digits;
8use crate::digit_trie::{DigitTrie, DigitTrieDump, DigitTrieIter};
9use crate::multi_trie::{MultiTrie, MultiTrieDump, MultiTrieIterator};
10use crate::utils::{get_value_callback, pre_pad_vec};
11use crate::{DlcTrie, IndexedPath, LookupResult, OracleNumericInfo, RangeInfo, TrieIterInfo};
12use ddk_dlc::{Error, RangePayout};
13
14/// Data structure used to store adaptor signature information for numerical
15/// outcome DLC with t of n oracles where at least t oracles need to sign the
16/// same outcome for the contract to be able to close.
17#[derive(Clone)]
18pub struct MultiOracleTrie {
19    /// The underlying trie data structure.
20    digit_trie: DigitTrie<Vec<RangeInfo>>,
21    threshold: usize,
22    oracle_numeric_infos: OracleNumericInfo,
23    extra_cover_trie: Option<MultiTrie<RangeInfo>>,
24}
25
26/// Container for a dump of a MultiOracleTrie used for serialization purpose.
27pub struct MultiOracleTrieDump {
28    /// A dump of the underlying digit trie.
29    pub digit_trie_dump: DigitTrieDump<Vec<RangeInfo>>,
30    /// The required number of oracles for this trie.
31    pub threshold: usize,
32    /// Information about each oracle numerical representation.
33    pub oracle_numeric_infos: OracleNumericInfo,
34    /// A dump of the trie for extra coverage.
35    pub extra_cover_trie_dump: Option<MultiTrieDump<RangeInfo>>,
36}
37
38impl MultiOracleTrie {
39    /// Dump the trie information.
40    pub fn dump(&self) -> MultiOracleTrieDump {
41        MultiOracleTrieDump {
42            digit_trie_dump: self.digit_trie.dump(),
43            threshold: self.threshold,
44            oracle_numeric_infos: self.oracle_numeric_infos.clone(),
45            extra_cover_trie_dump: self.extra_cover_trie.as_ref().map(|trie| trie.dump()),
46        }
47    }
48
49    /// Recover a MultiOracleTrie from a dump.
50    pub fn from_dump(dump: MultiOracleTrieDump) -> MultiOracleTrie {
51        let MultiOracleTrieDump {
52            digit_trie_dump,
53            threshold,
54            oracle_numeric_infos,
55            extra_cover_trie_dump,
56        } = dump;
57        MultiOracleTrie {
58            digit_trie: DigitTrie::from_dump(digit_trie_dump),
59            threshold,
60            oracle_numeric_infos,
61            extra_cover_trie: extra_cover_trie_dump.map(MultiTrie::from_dump),
62        }
63    }
64
65    fn get_agreeing_oracles(
66        &self,
67        paths: &[(usize, Vec<usize>)],
68    ) -> Option<Vec<(Vec<usize>, Vec<usize>)>> {
69        let mut hash_set: std::collections::HashMap<Vec<usize>, Vec<usize>> =
70            std::collections::HashMap::new();
71
72        for path in paths {
73            let index = path.0;
74            let outcome_path = &path.1;
75
76            if let Some(index_set) = hash_set.get_mut(outcome_path) {
77                index_set.push(index);
78            } else {
79                let index_set = vec![index];
80                hash_set.insert(outcome_path.to_vec(), index_set);
81            }
82        }
83
84        if hash_set.is_empty() {
85            return None;
86        }
87
88        let mut values: Vec<_> = hash_set.into_iter().collect();
89        values.sort_by(|x, y| x.1.len().partial_cmp(&y.1.len()).unwrap());
90        let res = values
91            .into_iter()
92            .filter(|x| x.1.len() >= self.threshold)
93            .collect::<Vec<_>>();
94        if !res.is_empty() {
95            Some(res)
96        } else {
97            None
98        }
99    }
100
101    /// Lookup for nodes whose path is either equal or a prefix of `path`.
102    pub fn look_up(&self, paths: &[(usize, Vec<usize>)]) -> Option<(RangeInfo, Vec<IndexedPath>)> {
103        let min_nb_digits = self.oracle_numeric_infos.get_min_nb_digits();
104        // Take all the paths that have a max value of base^min_nb_digits - 1, and
105        // shorten them to min_nb_digits.
106        let stripped_paths = paths
107            .iter()
108            .filter_map(|x| {
109                let extra_len = x.1.len().checked_sub(min_nb_digits)?;
110                if extra_len == 0 {
111                    Some((x.0, x.1.clone()))
112                } else if x.1.iter().take(extra_len).all(|x| *x == 0) {
113                    let mut cloned = x.1.clone();
114                    cloned.drain(..extra_len);
115                    Some((x.0, cloned))
116                } else {
117                    None
118                }
119            })
120            .collect::<Vec<_>>();
121        // Try to get the combinations of at least threshold oracles that agree on the outcome.
122        let agreeing_combinations = self.get_agreeing_oracles(&stripped_paths);
123        if let Some(sufficient_combinations) = agreeing_combinations {
124            for (path, combination) in sufficient_combinations {
125                debug_assert_eq!(
126                    path.len(),
127                    min_nb_digits,
128                    "Expected length {} got length {}",
129                    min_nb_digits,
130                    path.len()
131                );
132
133                if let Some(res) = self.digit_trie.look_up(&path) {
134                    let sufficient_combination: Vec<_> =
135                        combination.into_iter().take(self.threshold).collect();
136                    if let Some(position) = CombinationIterator::new(
137                        self.oracle_numeric_infos.nb_digits.len(),
138                        self.threshold,
139                    )
140                    .get_index_for_combination(&sufficient_combination)
141                    {
142                        return Some((
143                            res[0].value[position].clone(),
144                            sufficient_combination
145                                .iter()
146                                .map(|x| {
147                                    let actual_len = res[0].path.len()
148                                        + self.oracle_numeric_infos.nb_digits[*x]
149                                        - min_nb_digits;
150                                    (*x, pre_pad_vec(res[0].path.clone(), actual_len))
151                                })
152                                .collect::<Vec<_>>(),
153                        ));
154                    }
155                }
156            }
157        }
158
159        if let Some(extra_cover_trie) = &self.extra_cover_trie {
160            if let Some(res) = extra_cover_trie.look_up(paths) {
161                return Some((res.0.clone(), res.1));
162            }
163        }
164
165        None
166    }
167
168    /// Creates a new MultiOracleTrie
169    pub fn new(oracle_numeric_infos: &OracleNumericInfo, threshold: usize) -> Result<Self, Error> {
170        if oracle_numeric_infos.nb_digits.is_empty() {
171            return Err(Error::InvalidArgument);
172        }
173        let digit_trie = DigitTrie::new(oracle_numeric_infos.base);
174        let extra_cover_trie = if oracle_numeric_infos.has_diff_nb_digits() {
175            // The support and coverage parameters don't matter as we only use this trie for coverage of
176            // the "out of bounds" outcomes.
177            Some(MultiTrie::new(oracle_numeric_infos, threshold, 1, 2, true))
178        } else {
179            None
180        };
181        Ok(MultiOracleTrie {
182            digit_trie,
183            threshold,
184            oracle_numeric_infos: oracle_numeric_infos.clone(),
185            extra_cover_trie,
186        })
187    }
188}
189
190impl<'a> DlcTrie<'a, MultiOracleTrieIter<'a>> for MultiOracleTrie {
191    fn generate(
192        &mut self,
193        adaptor_index_start: usize,
194        outcomes: &[RangePayout],
195    ) -> Result<Vec<TrieIterInfo>, Error> {
196        let threshold = self.threshold;
197        let nb_oracles = self.oracle_numeric_infos.nb_digits.len();
198        let min_nb_digits = self.oracle_numeric_infos.get_min_nb_digits();
199        let mut adaptor_index = adaptor_index_start;
200        let mut trie_infos = Vec::new();
201        let oracle_numeric_infos = &self.oracle_numeric_infos;
202        for (cet_index, outcome) in outcomes.iter().enumerate() {
203            if outcome.count == 0 {
204                return Err(Error::InvalidArgument);
205            }
206            let groups = group_by_ignoring_digits(
207                outcome.start,
208                outcome.start + outcome.count - 1,
209                self.digit_trie.base,
210                min_nb_digits,
211            );
212            for group in groups {
213                let mut get_value = |_: Option<Vec<RangeInfo>>| -> Result<Vec<RangeInfo>, Error> {
214                    let combination_iterator = CombinationIterator::new(nb_oracles, threshold);
215                    let mut range_infos: Vec<RangeInfo> = Vec::new();
216                    for selector in combination_iterator {
217                        let range_info = RangeInfo {
218                            cet_index,
219                            adaptor_index,
220                        };
221                        adaptor_index += 1;
222                        let paths = oracle_numeric_infos
223                            .nb_digits
224                            .iter()
225                            .enumerate()
226                            .filter_map(|(i, nb_digits)| {
227                                if !selector.contains(&i) {
228                                    return None;
229                                }
230                                let expected_len = group.len() + nb_digits - min_nb_digits;
231                                Some(pre_pad_vec(group.clone(), expected_len))
232                            })
233                            .collect();
234                        let trie_info = TrieIterInfo {
235                            paths,
236                            indexes: selector,
237                            value: range_info.clone(),
238                        };
239                        trie_infos.push(trie_info);
240                        range_infos.push(range_info);
241                    }
242                    Ok(range_infos)
243                };
244                self.digit_trie.insert(&group, &mut get_value)?;
245            }
246        }
247
248        if let Some(extra_cover_trie) = &mut self.extra_cover_trie {
249            let mut get_value =
250                |paths: &[Vec<usize>], oracle_indexes: &[usize]| -> Result<RangeInfo, Error> {
251                    get_value_callback(
252                        paths,
253                        oracle_indexes,
254                        outcomes.len() - 1,
255                        &mut adaptor_index,
256                        &mut trie_infos,
257                    )
258                };
259            extra_cover_trie.insert_max_paths(&mut get_value)?;
260        }
261
262        Ok(trie_infos)
263    }
264
265    fn iter(&'a self) -> MultiOracleTrieIter<'a> {
266        let digit_trie_iterator = DigitTrieIter::new(&self.digit_trie);
267        let extra_cover_trie_iterator = self.extra_cover_trie.as_ref().map(MultiTrieIterator::new);
268        MultiOracleTrieIter {
269            digit_trie_iterator,
270            extra_cover_trie_iterator,
271            cur_res: None,
272            cur_index: 0,
273            combination_iter: CombinationIterator::new(
274                self.oracle_numeric_infos.nb_digits.len(),
275                self.threshold,
276            ),
277            oracle_numeric_infos: self.oracle_numeric_infos.clone(),
278        }
279    }
280}
281
282/// Iterator for a MultiOracleTrie.
283pub struct MultiOracleTrieIter<'a> {
284    digit_trie_iterator: DigitTrieIter<'a, Vec<RangeInfo>>,
285    extra_cover_trie_iterator: Option<MultiTrieIterator<'a, RangeInfo>>,
286    cur_res: Option<LookupResult<'a, Vec<RangeInfo>, usize>>,
287    cur_index: usize,
288    combination_iter: CombinationIterator,
289    oracle_numeric_infos: OracleNumericInfo,
290}
291
292impl Iterator for MultiOracleTrieIter<'_> {
293    type Item = TrieIterInfo;
294
295    fn next(&mut self) -> Option<Self::Item> {
296        if self.cur_res.is_none() {
297            self.cur_res = self.digit_trie_iterator.next();
298        }
299        let res = match &self.cur_res {
300            None => {
301                if let Some(extra_cover_trie_iterator) = &mut self.extra_cover_trie_iterator {
302                    let res = extra_cover_trie_iterator.next()?;
303                    let (indexes, paths) = res.path.iter().fold(
304                        (Vec::new(), Vec::new()),
305                        |(mut indexes, mut paths), x| {
306                            indexes.push(x.0);
307                            paths.push(x.1.clone());
308                            (indexes, paths)
309                        },
310                    );
311                    return Some(TrieIterInfo {
312                        indexes,
313                        paths,
314                        value: res.value.clone(),
315                    });
316                } else {
317                    return None;
318                }
319            }
320            Some(res) => res,
321        };
322
323        let indexes = match self.combination_iter.next() {
324            Some(selector) => selector,
325            None => {
326                self.cur_res = None;
327                self.cur_index = 0;
328                self.combination_iter = CombinationIterator::new(
329                    self.combination_iter.nb_elements,
330                    self.combination_iter.nb_selected,
331                );
332                return self.next();
333            }
334        };
335        let min_nb_digits = self.oracle_numeric_infos.get_min_nb_digits();
336        let paths = &std::iter::repeat_n(res.path.clone(), indexes.len())
337            .take(indexes.len())
338            .zip(indexes.iter())
339            .map(|(x, i)| {
340                let extra_len = self.oracle_numeric_infos.nb_digits[*i] - min_nb_digits;
341                if extra_len == 0 {
342                    x
343                } else {
344                    let expected_size = extra_len + x.len();
345                    pre_pad_vec(x, expected_size)
346                }
347            })
348            .collect::<Vec<Vec<_>>>();
349        let value = res.value[self.cur_index].clone();
350        self.cur_index += 1;
351        Some(TrieIterInfo {
352            indexes,
353            paths: paths.clone(),
354            value,
355        })
356    }
357}
358
#[cfg(test)]
mod tests {
    use bitcoin::Amount;
    use ddk_dlc::{Payout, RangePayout};

    use crate::{test_utils::get_variable_oracle_numeric_infos, DlcTrie};

    use super::MultiOracleTrie;

    /// A single range payout over outcomes `0..1023` paying everything to the offer party.
    fn offer_takes_all() -> Vec<RangePayout> {
        vec![RangePayout {
            start: 0,
            count: 1023,
            payout: Payout {
                offer: Amount::from_sat(200000000),
                accept: Amount::ZERO,
            },
        }]
    }

    /// A 2-of-5 trie over oracles with 10, 15, 15, 15 and 12 digits, generated
    /// from [`offer_takes_all`].
    fn two_of_five_trie() -> MultiOracleTrie {
        let infos = get_variable_oracle_numeric_infos(&[10, 15, 15, 15, 12], 2);
        let mut trie = MultiOracleTrie::new(&infos, 2).unwrap();
        trie.generate(0, &offer_takes_all()).unwrap();
        trie
    }

    #[test]
    fn test_longer_outcome_len() {
        // Both paths only have zeros beyond the minimum number of digits, so
        // they are stripped and resolved through the digit trie.
        let trie = two_of_five_trie();
        trie.look_up(&[
            (1, vec![0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
            (4, vec![0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
        ])
        .expect("Could not retrieve path with extra len.");
    }

    #[test]
    fn test_over_bound_outcome() {
        // Both paths have a non-zero digit beyond the minimum number of digits,
        // so they can only be resolved through the extra cover trie.
        let trie = two_of_five_trie();
        trie.look_up(&[
            (1, vec![1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
            (4, vec![0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
        ])
        .expect("Could not retrieve path with extra len.");
    }

    #[test]
    fn test_invalid_range_payout() {
        let empty_range = RangePayout {
            start: 0,
            count: 0,
            payout: Payout {
                offer: Amount::ZERO,
                accept: Amount::from_sat(200000000),
            },
        };
        let infos = get_variable_oracle_numeric_infos(&[13, 12], 2);
        let mut trie = MultiOracleTrie::new(&infos, 2).unwrap();
        trie.generate(0, &vec![empty_range])
            .expect_err("Should fail when given a range payout with a count of 0");
    }
}