1use crate::digit_decomposition::group_by_ignoring_digits;
7use crate::multi_trie::{MultiTrie, MultiTrieDump, MultiTrieIterator};
8use crate::utils::get_value_callback;
9
10use crate::{DlcTrie, OracleNumericInfo, RangeInfo, TrieIterInfo};
11use ddk_dlc::{Error, RangePayout};
12
13#[derive(Clone)]
17pub struct MultiOracleTrieWithDiff {
18 pub multi_trie: MultiTrie<RangeInfo>,
20 pub oracle_numeric_infos: OracleNumericInfo,
22}
23
24impl MultiOracleTrieWithDiff {
25 pub fn new(
27 oracle_numeric_infos: &OracleNumericInfo,
28 threshold: usize,
29 min_support_exp: usize,
30 max_error_exp: usize,
31 ) -> Result<Self, Error> {
32 let nb_oracles = oracle_numeric_infos.nb_digits.len();
33 let is_valid =
34 nb_oracles >= 1 && threshold <= nb_oracles && min_support_exp < max_error_exp;
35 if !is_valid {
36 return Err(Error::InvalidArgument);
37 }
38 let multi_trie = MultiTrie::new(
39 oracle_numeric_infos,
40 threshold,
41 min_support_exp,
42 max_error_exp,
43 true,
44 );
45 Ok(MultiOracleTrieWithDiff {
46 multi_trie,
47 oracle_numeric_infos: oracle_numeric_infos.clone(),
48 })
49 }
50}
51
52impl<'a> DlcTrie<'a, MultiOracleTrieWithDiffIter<'a>> for MultiOracleTrieWithDiff {
53 fn generate(
54 &mut self,
55 adaptor_index_start: usize,
56 outcomes: &[RangePayout],
57 ) -> Result<Vec<TrieIterInfo>, Error> {
58 let mut adaptor_index = adaptor_index_start;
59 let mut trie_infos = Vec::new();
60
61 for (cet_index, outcome) in outcomes.iter().enumerate() {
62 if outcome.count == 0 {
63 return Err(Error::InvalidArgument);
64 }
65 let groups = group_by_ignoring_digits(
66 outcome.start,
67 outcome.start + outcome.count - 1,
68 self.oracle_numeric_infos.base,
69 self.oracle_numeric_infos.get_min_nb_digits(),
70 );
71 for group in groups {
72 let mut get_value =
73 |paths: &[Vec<usize>], oracle_indexes: &[usize]| -> Result<RangeInfo, Error> {
74 get_value_callback(
75 paths,
76 oracle_indexes,
77 cet_index,
78 &mut adaptor_index,
79 &mut trie_infos,
80 )
81 };
82 self.multi_trie.insert(&group, &mut get_value)?;
83 }
84 }
85
86 if self.oracle_numeric_infos.has_diff_nb_digits() {
87 let mut get_value =
88 |paths: &[Vec<usize>], oracle_indexes: &[usize]| -> Result<RangeInfo, Error> {
89 get_value_callback(
90 paths,
91 oracle_indexes,
92 outcomes.len() - 1,
93 &mut adaptor_index,
94 &mut trie_infos,
95 )
96 };
97 self.multi_trie.insert_max_paths(&mut get_value)?;
98 }
99
100 Ok(trie_infos)
101 }
102
103 fn iter(&'a self) -> MultiOracleTrieWithDiffIter<'a> {
104 let multi_trie_iterator = MultiTrieIterator::new(&self.multi_trie);
105 MultiOracleTrieWithDiffIter {
106 multi_trie_iterator,
107 }
108 }
109}
110
111pub struct MultiOracleTrieWithDiffDump {
113 pub multi_trie_dump: MultiTrieDump<RangeInfo>,
115 pub oracle_numeric_infos: OracleNumericInfo,
117}
118
119impl MultiOracleTrieWithDiff {
120 pub fn dump(&self) -> MultiOracleTrieWithDiffDump {
122 let multi_trie_dump = self.multi_trie.dump();
123 MultiOracleTrieWithDiffDump {
124 multi_trie_dump,
125 oracle_numeric_infos: self.oracle_numeric_infos.clone(),
126 }
127 }
128
129 pub fn from_dump(dump: MultiOracleTrieWithDiffDump) -> MultiOracleTrieWithDiff {
131 let MultiOracleTrieWithDiffDump {
132 multi_trie_dump,
133 oracle_numeric_infos,
134 } = dump;
135 MultiOracleTrieWithDiff {
136 multi_trie: MultiTrie::from_dump(multi_trie_dump),
137 oracle_numeric_infos,
138 }
139 }
140}
141
142pub struct MultiOracleTrieWithDiffIter<'a> {
144 multi_trie_iterator: MultiTrieIterator<'a, RangeInfo>,
145}
146
147impl Iterator for MultiOracleTrieWithDiffIter<'_> {
148 type Item = TrieIterInfo;
149
150 fn next(&mut self) -> Option<Self::Item> {
151 let res = self.multi_trie_iterator.next()?;
152 let (indexes, paths) =
153 res.path
154 .iter()
155 .fold((Vec::new(), Vec::new()), |(mut indexes, mut paths), x| {
156 indexes.push(x.0);
157 paths.push(x.1.clone());
158 (indexes, paths)
159 });
160 Some(TrieIterInfo {
161 indexes,
162 paths,
163 value: res.value.clone(),
164 })
165 }
166}
167
#[cfg(test)]
mod tests {
    use bitcoin::Amount;
    use ddk_dlc::{Payout, RangePayout};

    use crate::{test_utils::get_variable_oracle_numeric_infos, DlcTrie};

    use super::MultiOracleTrieWithDiff;

    /// Builds a `RangePayout` covering `count` outcomes from `start`, paying the given
    /// offer/accept amounts (in satoshis).
    fn range_payout(start: usize, count: usize, offer_sat: u64, accept_sat: u64) -> RangePayout {
        RangePayout {
            start,
            count,
            payout: Payout {
                offer: Amount::from_sat(offer_sat),
                accept: Amount::from_sat(accept_sat),
            },
        }
    }

    /// Sorts `indexes` and asserts that every element after the first equals its position.
    fn assert_contiguous(mut indexes: Vec<usize>) {
        indexes.sort();
        for (position, index) in indexes.iter().enumerate().skip(1) {
            assert_eq!(*index, position);
        }
    }

    #[test]
    fn test_is_ordered() {
        // (start, count, offer sats, accept sats) for each outcome range.
        let table: [(usize, usize, u64, u64); 6] = [
            (0, 1, 0, 200000000),
            (1, 1, 40000000, 160000000),
            (2, 1, 80000000, 120000000),
            (3, 1, 120000000, 80000000),
            (4, 1, 160000000, 40000000),
            (5, 1019, 200000000, 0),
        ];
        let range_payouts: Vec<RangePayout> = table
            .iter()
            .map(|&(start, count, offer, accept)| range_payout(start, count, offer, accept))
            .collect();

        let oracle_numeric_infos = get_variable_oracle_numeric_infos(&[13, 12], 2);
        let mut trie = MultiOracleTrieWithDiff::new(&oracle_numeric_infos, 2, 1, 2).unwrap();
        let generated_indexes: Vec<usize> = trie
            .generate(0, &range_payouts)
            .unwrap()
            .into_iter()
            .map(|info| info.value.adaptor_index)
            .collect();

        let lookup_res = trie
            .multi_trie
            .look_up(&[
                (0, vec![0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0]),
                (1, vec![0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0]),
            ])
            .expect("Could not find");

        // Adaptor indexes must be contiguous both as returned by `generate` and as
        // observed through the trie iterator.
        assert_contiguous(generated_indexes);
        assert_contiguous(trie.iter().map(|info| info.value.adaptor_index).collect());

        // The entry found by direct lookup must match the one reached through iteration.
        let iter_res = trie
            .iter()
            .find(|info| info.value.adaptor_index == 22)
            .unwrap();
        let looked_up_paths = lookup_res
            .1
            .iter()
            .map(|(_, path)| path.clone())
            .collect::<Vec<_>>();
        assert_eq!(&looked_up_paths, &iter_res.paths);
    }

    #[test]
    fn test_invalid_range_payout() {
        let range_payouts = vec![range_payout(0, 0, 0, 200000000)];

        let oracle_numeric_infos = get_variable_oracle_numeric_infos(&[13, 12], 2);
        let mut trie = MultiOracleTrieWithDiff::new(&oracle_numeric_infos, 2, 1, 2).unwrap();
        trie.generate(0, &range_payouts)
            .expect_err("Should fail when given a range payout with a count of 0");
    }
}
295}