keri_core/event/sections/
threshold.rs1use std::{fmt, str::FromStr};
2
3use fraction::{Fraction, One, Zero};
4use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
5use serde_hex::{Compact, SerHex};
6
7use super::key_config::SignatureError;
8
/// Errors produced while parsing a [`ThresholdFraction`] from its string form.
#[derive(Debug, thiserror::Error, Serialize, Deserialize)]
pub enum ThresholdError {
    /// The numerator or denominator is not a valid unsigned integer.
    /// Produced from any `ParseIntError` via the `From` impl below.
    #[error("Error parsing numerical value")]
    ParseIntError,
    /// The value is not a well-formed fraction (for example, it contains
    /// more than one `/`).
    #[error("Wrong threshold value. Should be fraction")]
    FractionExpected,
}
16
17impl From<core::num::ParseIntError> for ThresholdError {
18 fn from(_: core::num::ParseIntError) -> Self {
19 ThresholdError::ParseIntError
20 }
21}
22
/// A single signing weight, stored as an exact rational number.
///
/// Serde represents it as a string such as `"1/2"` or `"1"` (see the
/// hand-written `Serialize` / `Deserialize` impls for this type).
#[derive(Debug, Clone, PartialEq, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
#[rkyv(derive(Debug))]
pub struct ThresholdFraction {
    // `Fraction` comes from the external `fraction` crate, so rkyv archives it
    // through the remote mirror type defined in `rkyv_serialization`.
    #[rkyv(with = rkyv_serialization::FractionDef)]
    fraction: Fraction,
}
29
30impl ThresholdFraction {
31 pub fn new(n: u64, d: u64) -> Self {
32 Self {
33 fraction: Fraction::new(n, d),
34 }
35 }
36}
37
38impl fmt::Display for ThresholdFraction {
39 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
40 write!(f, "{}", self.fraction)
41 }
42}
43
44impl FromStr for ThresholdFraction {
45 type Err = ThresholdError;
46
47 fn from_str(s: &str) -> Result<Self, Self::Err> {
48 let f: Vec<_> = s.split('/').collect();
49 if f.len() > 2 {
50 Err(ThresholdError::FractionExpected)
51 } else if f.len() == 1 {
52 let a = f[0].parse::<u64>()?;
53 Ok(ThresholdFraction {
54 fraction: Fraction::new(a, 1u64),
55 })
56 } else {
57 let a = f[0].parse::<u64>()?;
58 let b = f[1].parse::<u64>()?;
59 Ok(ThresholdFraction {
60 fraction: Fraction::new(a, b),
61 })
62 }
63 }
64}
65impl<'de> Deserialize<'de> for ThresholdFraction {
66 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
67 where
68 D: Deserializer<'de>,
69 {
70 let s = String::deserialize(deserializer)?;
71 FromStr::from_str(&s).map_err(de::Error::custom)
72 }
73}
74
75impl Serialize for ThresholdFraction {
76 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77 where
78 S: Serializer,
79 {
80 serializer.serialize_str(&self.to_string())
81 }
82}
83
/// The rule deciding whether a set of signatures is sufficient.
///
/// It is either a plain count of signatures or a weighted scheme.
///
/// Serde uses `untagged`, so deserialization tries the variants in
/// declaration order:
/// 1. `Simple`, encoded as a hex string via `SerHex::<Compact>`;
/// 2. `Weighted`.
///
/// Do not reorder the variants.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
#[rkyv(derive(Debug))]
pub enum SignatureThreshold {
    /// At least this many signatures are required.
    #[serde(with = "SerHex::<Compact>")]
    Simple(u64),
    /// The weights of the signing keys must reach the threshold.
    Weighted(WeightedThreshold),
}
93
/// A weighted threshold: one clause of weights, or several clauses that must
/// all be satisfied.
///
/// Serde uses `untagged`:
/// - a flat array of weight strings (`["1/2","1/2"]`) deserializes as `Single`;
/// - a nested array (`[["1"],["1/2","1/2"]]`) fails as `Single` and falls
///   through to `Multi`.
///
/// Variant order therefore matters.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
#[serde(untagged)]
#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
#[rkyv(derive(Debug))]
pub enum WeightedThreshold {
    /// One clause covering all keys, starting at index 0.
    Single(ThresholdClause),
    /// Several clauses over consecutive key-index ranges.
    Multi(MultiClauses),
}
102
103impl WeightedThreshold {
104 pub fn enough_signatures(&self, sigs_indexes: &[usize]) -> Result<(), SignatureError> {
105 match self {
106 WeightedThreshold::Single(clause) => clause.enough_signatures(0, sigs_indexes),
107 WeightedThreshold::Multi(clauses) => clauses.enough_signatures(sigs_indexes),
108 }
109 }
110}
111
112impl SignatureThreshold {
113 pub fn simple(t: u64) -> Self {
114 Self::Simple(t)
115 }
116
117 pub fn single_weighted(fracs: Vec<(u64, u64)>) -> Self {
118 Self::Weighted(WeightedThreshold::Single(ThresholdClause::new_from_tuples(
119 fracs,
120 )))
121 }
122
123 pub fn multi_weighted(fracs: Vec<Vec<(u64, u64)>>) -> Self {
124 Self::Weighted(WeightedThreshold::Multi(MultiClauses::new_from_tuples(
125 fracs,
126 )))
127 }
128
129 pub fn enough_signatures(&self, sigs_indexes: &[usize]) -> Result<(), SignatureError> {
130 match self {
131 SignatureThreshold::Simple(ref t) => {
132 if (sigs_indexes.len() as u64) >= *t {
133 Ok(())
134 } else {
135 Err(SignatureError::NotEnoughSigsError)
136 }
137 }
138 SignatureThreshold::Weighted(ref thresh) => thresh.enough_signatures(sigs_indexes),
139 }
140 }
141}
142
143impl Default for SignatureThreshold {
144 fn default() -> Self {
145 Self::Simple(1)
146 }
147}
/// An ordered list of weights, one per key position.
///
/// The clause is satisfied when the weights of the signing positions sum to
/// at least 1.
#[derive(
    Serialize,
    Deserialize,
    Debug,
    Clone,
    PartialEq,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
)]
#[rkyv(derive(Debug))]
pub struct ThresholdClause(Vec<ThresholdFraction>);
160
161impl ThresholdClause {
162 pub fn new(fracs: &[ThresholdFraction]) -> Self {
163 Self(fracs.to_owned())
164 }
165
166 pub fn new_from_tuples(tuples: Vec<(u64, u64)>) -> Self {
167 let clause = tuples
168 .into_iter()
169 .map(|(n, d)| ThresholdFraction::new(n, d))
170 .collect();
171 Self(clause)
172 }
173
174 pub fn length(&self) -> usize {
175 self.0.len()
176 }
177
178 pub fn enough_signatures(
179 &self,
180 start_index: usize,
181 sigs_indexes: &[usize],
182 ) -> Result<(), SignatureError> {
183 (sigs_indexes
184 .iter()
185 .fold(Some(Zero::zero()), |acc: Option<Fraction>, sig_index| {
186 if let (Some(element), Some(sum)) = (self.0.get(sig_index - start_index), acc) {
187 Some(sum + element.fraction)
188 } else {
189 None
190 }
191 })
192 .ok_or_else(|| SignatureError::MissingIndex)?
193 >= One::one())
194 .then(|| ())
195 .ok_or(SignatureError::NotEnoughSigsError)
196 }
197}
198
/// A sequence of clauses covering consecutive ranges of key indexes.
///
/// The first clause covers indexes `0..len0`, the next `len0..len0+len1`,
/// and so on. Every clause must be satisfied independently.
#[derive(
    Deserialize,
    Serialize,
    Debug,
    Clone,
    PartialEq,
    rkyv::Archive,
    rkyv::Serialize,
    rkyv::Deserialize,
)]
#[rkyv(derive(Debug))]

pub struct MultiClauses(Vec<ThresholdClause>);
212
213impl MultiClauses {
214 pub fn new(fracs: Vec<Vec<ThresholdFraction>>) -> Self {
215 let clauses = fracs
216 .iter()
217 .map(|clause| ThresholdClause::new(clause))
218 .collect();
219
220 Self(clauses)
221 }
222
223 pub fn new_from_tuples(fracs: Vec<Vec<(u64, u64)>>) -> Self {
224 let wt = fracs
225 .into_iter()
226 .map(ThresholdClause::new_from_tuples)
227 .collect();
228 MultiClauses(wt)
229 }
230
231 pub fn length(&self) -> usize {
232 self.0.iter().map(|l| l.length()).sum()
233 }
234
235 pub fn enough_signatures(&self, sigs_indexes: &[usize]) -> Result<(), SignatureError> {
236 self.0
237 .iter()
238 .fold(Ok((0, true)), |acc, clause| -> Result<_, SignatureError> {
239 let (start, enough) = acc?;
240 let sigs: Vec<usize> = sigs_indexes
241 .iter()
242 .cloned()
243 .filter(|sig_index| {
244 sig_index >= &start && sig_index < &(start + clause.0.len())
245 })
246 .collect();
247 Ok((
248 start + clause.0.len(),
249 enough && clause.enough_signatures(start, &sigs).is_ok(),
250 ))
251 })?
252 .1
253 .then(|| ())
254 .ok_or(SignatureError::NotEnoughSigsError)
255 }
256}
257
// rkyv "remote" mirror types that let `Fraction` (a foreign type from the
// `fraction` crate) be archived with rkyv.
// - Each `*Def` type mirrors the shape of the corresponding foreign type.
// - Each `From` impl converts a deserialized mirror back into the real type.
mod rkyv_serialization {
    use fraction::Fraction;

    // Mirror of `fraction::Sign`.
    #[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
    #[rkyv(remote = fraction::Sign)]
    #[rkyv(derive(Debug))]
    pub enum SignDef {
        Plus,
        Minus,
    }

    impl From<SignDef> for fraction::Sign {
        fn from(value: SignDef) -> Self {
            match value {
                SignDef::Plus => Self::Plus,
                SignDef::Minus => Self::Minus,
            }
        }
    }

    // Mirror of `Fraction`. The `with` attributes route the nested foreign
    // types through the other mirrors in this module.
    #[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
    #[rkyv(remote = Fraction)]
    #[rkyv(derive(Debug))]
    pub enum FractionDef {
        Rational(
            #[rkyv(with = SignDef)] fraction::Sign,
            #[rkyv(with = RatioDef)] fraction::Ratio<u64>,
        ),
        Infinity(#[rkyv(with = SignDef)] fraction::Sign),
        NaN,
    }

    impl From<FractionDef> for Fraction {
        fn from(value: FractionDef) -> Self {
            match value {
                FractionDef::Rational(sign, ratio) => Fraction::Rational(sign, ratio),
                FractionDef::Infinity(sign) => Fraction::Infinity(sign),
                FractionDef::NaN => Fraction::NaN,
            }
        }
    }

    // Mirror of `fraction::Ratio<u64>`. Archiving reads the numerator and
    // denominator through the accessor functions named in `getter`, rather
    // than through the remote type's fields.
    #[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
    #[rkyv(remote = fraction::Ratio<u64>)]
    #[rkyv(derive(Debug))]
    pub struct RatioDef {
        #[rkyv(getter = fraction::Ratio::numer)]
        numer: u64,
        #[rkyv(getter = fraction::Ratio::denom)]
        denom: u64,
    }

    impl From<RatioDef> for fraction::Ratio<u64> {
        fn from(value: RatioDef) -> Self {
            // NOTE(review): `Ratio::new` reduces to lowest terms and panics on
            // a zero denominator. A corrupted archive with `denom == 0` would
            // therefore panic here. Confirm that archives are validated before
            // deserialization.
            Self::new(value.numer, value.denom)
        }
    }
}
318
#[test]
fn test_enough_sigs() -> Result<(), SignatureError> {
    // Clause 0 = [1] covers index 0.
    // Clause 1 = [1/2, 1/2, 1/2] covers indexes 1..=3.
    let wt = MultiClauses::new_from_tuples(vec![vec![(1, 1)], vec![(1, 2), (1, 2), (1, 2)]]);
    let sigs_indexes = [0usize, 1, 2, 3];

    // Every key signed: both clauses are satisfied.
    assert!(wt.enough_signatures(&sigs_indexes).is_ok());

    // Index 0 satisfies clause 0; indexes 1 and 3 give clause 1 a weight of 1.
    let enough = [sigs_indexes[0], sigs_indexes[1], sigs_indexes[3]];
    assert!(wt.enough_signatures(&enough).is_ok());

    // Only clause 0 is satisfied; clause 1 has no weight at all.
    let not_enough = [sigs_indexes[0]];
    assert!(matches!(
        wt.enough_signatures(&not_enough),
        Err(SignatureError::NotEnoughSigsError)
    ));

    Ok(())
}
341
#[test]
pub fn test_weighted_treshold_serialization() -> Result<(), SignatureError> {
    // Parse JSON into a WeightedThreshold and serialize it straight back.
    let round_trip = |json: &str| -> (WeightedThreshold, String) {
        let parsed: WeightedThreshold = serde_json::from_str(json).unwrap();
        let reserialized = serde_json::to_string(&parsed).unwrap();
        (parsed, reserialized)
    };

    // A nested array cannot be a single clause, so it must land in `Multi`.
    let multi_json = r#"[["1"],["1/2","1/2","1/2"]]"#;
    let (multi, multi_out) = round_trip(multi_json);
    assert!(matches!(multi, WeightedThreshold::Multi(_)));
    assert_eq!(multi_out, multi_json);

    // A flat array of weights is a single clause.
    let single_json = r#"["1/2","1/2","1/2"]"#;
    let (single, single_out) = round_trip(single_json);
    assert!(matches!(single, WeightedThreshold::Single(_)));
    assert_eq!(single_out, single_json);

    Ok(())
}