1use crate::combination_iterator::CombinationIterator;
7use crate::digit_decomposition::group_by_ignoring_digits;
8use crate::digit_trie::{DigitTrie, DigitTrieDump, DigitTrieIter};
9use crate::multi_trie::{MultiTrie, MultiTrieDump, MultiTrieIterator};
10use crate::utils::{get_value_callback, pre_pad_vec};
11use crate::{DlcTrie, IndexedPath, LookupResult, OracleNumericInfo, RangeInfo, TrieIterInfo};
12use ddk_dlc::{Error, RangePayout};
13
14#[derive(Clone)]
18pub struct MultiOracleTrie {
19 digit_trie: DigitTrie<Vec<RangeInfo>>,
21 threshold: usize,
22 oracle_numeric_infos: OracleNumericInfo,
23 extra_cover_trie: Option<MultiTrie<RangeInfo>>,
24}
25
26pub struct MultiOracleTrieDump {
28 pub digit_trie_dump: DigitTrieDump<Vec<RangeInfo>>,
30 pub threshold: usize,
32 pub oracle_numeric_infos: OracleNumericInfo,
34 pub extra_cover_trie_dump: Option<MultiTrieDump<RangeInfo>>,
36}
37
38impl MultiOracleTrie {
39 pub fn dump(&self) -> MultiOracleTrieDump {
41 MultiOracleTrieDump {
42 digit_trie_dump: self.digit_trie.dump(),
43 threshold: self.threshold,
44 oracle_numeric_infos: self.oracle_numeric_infos.clone(),
45 extra_cover_trie_dump: self.extra_cover_trie.as_ref().map(|trie| trie.dump()),
46 }
47 }
48
49 pub fn from_dump(dump: MultiOracleTrieDump) -> MultiOracleTrie {
51 let MultiOracleTrieDump {
52 digit_trie_dump,
53 threshold,
54 oracle_numeric_infos,
55 extra_cover_trie_dump,
56 } = dump;
57 MultiOracleTrie {
58 digit_trie: DigitTrie::from_dump(digit_trie_dump),
59 threshold,
60 oracle_numeric_infos,
61 extra_cover_trie: extra_cover_trie_dump.map(MultiTrie::from_dump),
62 }
63 }
64
65 fn get_agreeing_oracles(
66 &self,
67 paths: &[(usize, Vec<usize>)],
68 ) -> Option<Vec<(Vec<usize>, Vec<usize>)>> {
69 let mut hash_set: std::collections::HashMap<Vec<usize>, Vec<usize>> =
70 std::collections::HashMap::new();
71
72 for path in paths {
73 let index = path.0;
74 let outcome_path = &path.1;
75
76 if let Some(index_set) = hash_set.get_mut(outcome_path) {
77 index_set.push(index);
78 } else {
79 let index_set = vec![index];
80 hash_set.insert(outcome_path.to_vec(), index_set);
81 }
82 }
83
84 if hash_set.is_empty() {
85 return None;
86 }
87
88 let mut values: Vec<_> = hash_set.into_iter().collect();
89 values.sort_by(|x, y| x.1.len().partial_cmp(&y.1.len()).unwrap());
90 let res = values
91 .into_iter()
92 .filter(|x| x.1.len() >= self.threshold)
93 .collect::<Vec<_>>();
94 if !res.is_empty() {
95 Some(res)
96 } else {
97 None
98 }
99 }
100
101 pub fn look_up(&self, paths: &[(usize, Vec<usize>)]) -> Option<(RangeInfo, Vec<IndexedPath>)> {
103 let min_nb_digits = self.oracle_numeric_infos.get_min_nb_digits();
104 let stripped_paths = paths
107 .iter()
108 .filter_map(|x| {
109 let extra_len = x.1.len().checked_sub(min_nb_digits)?;
110 if extra_len == 0 {
111 Some((x.0, x.1.clone()))
112 } else if x.1.iter().take(extra_len).all(|x| *x == 0) {
113 let mut cloned = x.1.clone();
114 cloned.drain(..extra_len);
115 Some((x.0, cloned))
116 } else {
117 None
118 }
119 })
120 .collect::<Vec<_>>();
121 let agreeing_combinations = self.get_agreeing_oracles(&stripped_paths);
123 if let Some(sufficient_combinations) = agreeing_combinations {
124 for (path, combination) in sufficient_combinations {
125 debug_assert_eq!(
126 path.len(),
127 min_nb_digits,
128 "Expected length {} got length {}",
129 min_nb_digits,
130 path.len()
131 );
132
133 if let Some(res) = self.digit_trie.look_up(&path) {
134 let sufficient_combination: Vec<_> =
135 combination.into_iter().take(self.threshold).collect();
136 if let Some(position) = CombinationIterator::new(
137 self.oracle_numeric_infos.nb_digits.len(),
138 self.threshold,
139 )
140 .get_index_for_combination(&sufficient_combination)
141 {
142 return Some((
143 res[0].value[position].clone(),
144 sufficient_combination
145 .iter()
146 .map(|x| {
147 let actual_len = res[0].path.len()
148 + self.oracle_numeric_infos.nb_digits[*x]
149 - min_nb_digits;
150 (*x, pre_pad_vec(res[0].path.clone(), actual_len))
151 })
152 .collect::<Vec<_>>(),
153 ));
154 }
155 }
156 }
157 }
158
159 if let Some(extra_cover_trie) = &self.extra_cover_trie {
160 if let Some(res) = extra_cover_trie.look_up(paths) {
161 return Some((res.0.clone(), res.1));
162 }
163 }
164
165 None
166 }
167
168 pub fn new(oracle_numeric_infos: &OracleNumericInfo, threshold: usize) -> Result<Self, Error> {
170 if oracle_numeric_infos.nb_digits.is_empty() {
171 return Err(Error::InvalidArgument);
172 }
173 let digit_trie = DigitTrie::new(oracle_numeric_infos.base);
174 let extra_cover_trie = if oracle_numeric_infos.has_diff_nb_digits() {
175 Some(MultiTrie::new(oracle_numeric_infos, threshold, 1, 2, true))
178 } else {
179 None
180 };
181 Ok(MultiOracleTrie {
182 digit_trie,
183 threshold,
184 oracle_numeric_infos: oracle_numeric_infos.clone(),
185 extra_cover_trie,
186 })
187 }
188}
189
190impl<'a> DlcTrie<'a, MultiOracleTrieIter<'a>> for MultiOracleTrie {
191 fn generate(
192 &mut self,
193 adaptor_index_start: usize,
194 outcomes: &[RangePayout],
195 ) -> Result<Vec<TrieIterInfo>, Error> {
196 let threshold = self.threshold;
197 let nb_oracles = self.oracle_numeric_infos.nb_digits.len();
198 let min_nb_digits = self.oracle_numeric_infos.get_min_nb_digits();
199 let mut adaptor_index = adaptor_index_start;
200 let mut trie_infos = Vec::new();
201 let oracle_numeric_infos = &self.oracle_numeric_infos;
202 for (cet_index, outcome) in outcomes.iter().enumerate() {
203 if outcome.count == 0 {
204 return Err(Error::InvalidArgument);
205 }
206 let groups = group_by_ignoring_digits(
207 outcome.start,
208 outcome.start + outcome.count - 1,
209 self.digit_trie.base,
210 min_nb_digits,
211 );
212 for group in groups {
213 let mut get_value = |_: Option<Vec<RangeInfo>>| -> Result<Vec<RangeInfo>, Error> {
214 let combination_iterator = CombinationIterator::new(nb_oracles, threshold);
215 let mut range_infos: Vec<RangeInfo> = Vec::new();
216 for selector in combination_iterator {
217 let range_info = RangeInfo {
218 cet_index,
219 adaptor_index,
220 };
221 adaptor_index += 1;
222 let paths = oracle_numeric_infos
223 .nb_digits
224 .iter()
225 .enumerate()
226 .filter_map(|(i, nb_digits)| {
227 if !selector.contains(&i) {
228 return None;
229 }
230 let expected_len = group.len() + nb_digits - min_nb_digits;
231 Some(pre_pad_vec(group.clone(), expected_len))
232 })
233 .collect();
234 let trie_info = TrieIterInfo {
235 paths,
236 indexes: selector,
237 value: range_info.clone(),
238 };
239 trie_infos.push(trie_info);
240 range_infos.push(range_info);
241 }
242 Ok(range_infos)
243 };
244 self.digit_trie.insert(&group, &mut get_value)?;
245 }
246 }
247
248 if let Some(extra_cover_trie) = &mut self.extra_cover_trie {
249 let mut get_value =
250 |paths: &[Vec<usize>], oracle_indexes: &[usize]| -> Result<RangeInfo, Error> {
251 get_value_callback(
252 paths,
253 oracle_indexes,
254 outcomes.len() - 1,
255 &mut adaptor_index,
256 &mut trie_infos,
257 )
258 };
259 extra_cover_trie.insert_max_paths(&mut get_value)?;
260 }
261
262 Ok(trie_infos)
263 }
264
265 fn iter(&'a self) -> MultiOracleTrieIter<'a> {
266 let digit_trie_iterator = DigitTrieIter::new(&self.digit_trie);
267 let extra_cover_trie_iterator = self.extra_cover_trie.as_ref().map(MultiTrieIterator::new);
268 MultiOracleTrieIter {
269 digit_trie_iterator,
270 extra_cover_trie_iterator,
271 cur_res: None,
272 cur_index: 0,
273 combination_iter: CombinationIterator::new(
274 self.oracle_numeric_infos.nb_digits.len(),
275 self.threshold,
276 ),
277 oracle_numeric_infos: self.oracle_numeric_infos.clone(),
278 }
279 }
280}
281
282pub struct MultiOracleTrieIter<'a> {
284 digit_trie_iterator: DigitTrieIter<'a, Vec<RangeInfo>>,
285 extra_cover_trie_iterator: Option<MultiTrieIterator<'a, RangeInfo>>,
286 cur_res: Option<LookupResult<'a, Vec<RangeInfo>, usize>>,
287 cur_index: usize,
288 combination_iter: CombinationIterator,
289 oracle_numeric_infos: OracleNumericInfo,
290}
291
292impl Iterator for MultiOracleTrieIter<'_> {
293 type Item = TrieIterInfo;
294
295 fn next(&mut self) -> Option<Self::Item> {
296 if self.cur_res.is_none() {
297 self.cur_res = self.digit_trie_iterator.next();
298 }
299 let res = match &self.cur_res {
300 None => {
301 if let Some(extra_cover_trie_iterator) = &mut self.extra_cover_trie_iterator {
302 let res = extra_cover_trie_iterator.next()?;
303 let (indexes, paths) = res.path.iter().fold(
304 (Vec::new(), Vec::new()),
305 |(mut indexes, mut paths), x| {
306 indexes.push(x.0);
307 paths.push(x.1.clone());
308 (indexes, paths)
309 },
310 );
311 return Some(TrieIterInfo {
312 indexes,
313 paths,
314 value: res.value.clone(),
315 });
316 } else {
317 return None;
318 }
319 }
320 Some(res) => res,
321 };
322
323 let indexes = match self.combination_iter.next() {
324 Some(selector) => selector,
325 None => {
326 self.cur_res = None;
327 self.cur_index = 0;
328 self.combination_iter = CombinationIterator::new(
329 self.combination_iter.nb_elements,
330 self.combination_iter.nb_selected,
331 );
332 return self.next();
333 }
334 };
335 let min_nb_digits = self.oracle_numeric_infos.get_min_nb_digits();
336 let paths = &std::iter::repeat_n(res.path.clone(), indexes.len())
337 .take(indexes.len())
338 .zip(indexes.iter())
339 .map(|(x, i)| {
340 let extra_len = self.oracle_numeric_infos.nb_digits[*i] - min_nb_digits;
341 if extra_len == 0 {
342 x
343 } else {
344 let expected_size = extra_len + x.len();
345 pre_pad_vec(x, expected_size)
346 }
347 })
348 .collect::<Vec<Vec<_>>>();
349 let value = res.value[self.cur_index].clone();
350 self.cur_index += 1;
351 Some(TrieIterInfo {
352 indexes,
353 paths: paths.clone(),
354 value,
355 })
356 }
357}
358
#[cfg(test)]
mod tests {
    use bitcoin::Amount;
    use ddk_dlc::{Payout, RangePayout};

    use crate::{test_utils::get_variable_oracle_numeric_infos, DlcTrie};

    use super::MultiOracleTrie;

    /// Builds a trie with a threshold of 2 over five oracles using 10, 15,
    /// 15, 15 and 12 digits, populated with a single range starting at 0
    /// with a count of 1023.
    fn generate_variable_digits_trie() -> MultiOracleTrie {
        let payout = Payout {
            offer: Amount::from_sat(200000000),
            accept: Amount::ZERO,
        };
        let range_payouts = vec![RangePayout {
            start: 0,
            count: 1023,
            payout,
        }];
        let oracle_numeric_infos = get_variable_oracle_numeric_infos(&[10, 15, 15, 15, 12], 2);
        let mut trie = MultiOracleTrie::new(&oracle_numeric_infos, 2).unwrap();
        trie.generate(0, &range_payouts).unwrap();
        trie
    }

    #[test]
    fn test_longer_outcome_len() {
        let multi_oracle_trie = generate_variable_digits_trie();
        let attested_paths = [
            (1, vec![0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
            (4, vec![0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
        ];
        multi_oracle_trie
            .look_up(&attested_paths)
            .expect("Could not retrieve path with extra len.");
    }

    #[test]
    fn test_over_bound_outcome() {
        let multi_oracle_trie = generate_variable_digits_trie();
        let attested_paths = [
            (1, vec![1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
            (4, vec![0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0]),
        ];
        multi_oracle_trie
            .look_up(&attested_paths)
            .expect("Could not retrieve path with extra len.");
    }

    #[test]
    fn test_invalid_range_payout() {
        let payout = Payout {
            offer: Amount::ZERO,
            accept: Amount::from_sat(200000000),
        };
        let range_payouts = vec![RangePayout {
            start: 0,
            count: 0,
            payout,
        }];

        let oracle_numeric_infos = get_variable_oracle_numeric_infos(&[13, 12], 2);
        let mut multi_oracle_trie = MultiOracleTrie::new(&oracle_numeric_infos, 2).unwrap();
        multi_oracle_trie
            .generate(0, &range_payouts)
            .expect_err("Should fail when given a range payout with a count of 0");
    }
}