keri_core/event/sections/threshold.rs

1use std::{fmt, str::FromStr};
2
3use fraction::{Fraction, One, Zero};
4use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
5use serde_hex::{Compact, SerHex};
6
7use super::key_config::SignatureError;
8
9#[derive(Debug, thiserror::Error, Serialize, Deserialize)]
10pub enum ThresholdError {
11    #[error("Error parsing numerical value")]
12    ParseIntError,
13    #[error("Wrong threshold value. Should be fraction")]
14    FractionExpected,
15}
16
17impl From<core::num::ParseIntError> for ThresholdError {
18    fn from(_: core::num::ParseIntError) -> Self {
19        ThresholdError::ParseIntError
20    }
21}
22
23#[derive(Debug, Clone, PartialEq, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
24#[rkyv(derive(Debug))]
25pub struct ThresholdFraction {
26    #[rkyv(with = rkyv_serialization::FractionDef)]
27    fraction: Fraction,
28}
29
30impl ThresholdFraction {
31    pub fn new(n: u64, d: u64) -> Self {
32        Self {
33            fraction: Fraction::new(n, d),
34        }
35    }
36}
37
38impl fmt::Display for ThresholdFraction {
39    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
40        write!(f, "{}", self.fraction)
41    }
42}
43
44impl FromStr for ThresholdFraction {
45    type Err = ThresholdError;
46
47    fn from_str(s: &str) -> Result<Self, Self::Err> {
48        let f: Vec<_> = s.split('/').collect();
49        if f.len() > 2 {
50            Err(ThresholdError::FractionExpected)
51        } else if f.len() == 1 {
52            let a = f[0].parse::<u64>()?;
53            Ok(ThresholdFraction {
54                fraction: Fraction::new(a, 1u64),
55            })
56        } else {
57            let a = f[0].parse::<u64>()?;
58            let b = f[1].parse::<u64>()?;
59            Ok(ThresholdFraction {
60                fraction: Fraction::new(a, b),
61            })
62        }
63    }
64}
65impl<'de> Deserialize<'de> for ThresholdFraction {
66    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
67    where
68        D: Deserializer<'de>,
69    {
70        let s = String::deserialize(deserializer)?;
71        FromStr::from_str(&s).map_err(de::Error::custom)
72    }
73}
74
75impl Serialize for ThresholdFraction {
76    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
77    where
78        S: Serializer,
79    {
80        serializer.serialize_str(&self.to_string())
81    }
82}
83
84#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
85#[serde(untagged)]
86#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
87#[rkyv(derive(Debug))]
88pub enum SignatureThreshold {
89    #[serde(with = "SerHex::<Compact>")]
90    Simple(u64),
91    Weighted(WeightedThreshold),
92}
93
94#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
95#[serde(untagged)]
96#[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
97#[rkyv(derive(Debug))]
98pub enum WeightedThreshold {
99    Single(ThresholdClause),
100    Multi(MultiClauses),
101}
102
103impl WeightedThreshold {
104    pub fn enough_signatures(&self, sigs_indexes: &[usize]) -> Result<(), SignatureError> {
105        match self {
106            WeightedThreshold::Single(clause) => clause.enough_signatures(0, sigs_indexes),
107            WeightedThreshold::Multi(clauses) => clauses.enough_signatures(sigs_indexes),
108        }
109    }
110}
111
112impl SignatureThreshold {
113    pub fn simple(t: u64) -> Self {
114        Self::Simple(t)
115    }
116
117    pub fn single_weighted(fracs: Vec<(u64, u64)>) -> Self {
118        Self::Weighted(WeightedThreshold::Single(ThresholdClause::new_from_tuples(
119            fracs,
120        )))
121    }
122
123    pub fn multi_weighted(fracs: Vec<Vec<(u64, u64)>>) -> Self {
124        Self::Weighted(WeightedThreshold::Multi(MultiClauses::new_from_tuples(
125            fracs,
126        )))
127    }
128
129    pub fn enough_signatures(&self, sigs_indexes: &[usize]) -> Result<(), SignatureError> {
130        match self {
131            SignatureThreshold::Simple(ref t) => {
132                if (sigs_indexes.len() as u64) >= *t {
133                    Ok(())
134                } else {
135                    Err(SignatureError::NotEnoughSigsError)
136                }
137            }
138            SignatureThreshold::Weighted(ref thresh) => thresh.enough_signatures(sigs_indexes),
139        }
140    }
141}
142
143impl Default for SignatureThreshold {
144    fn default() -> Self {
145        Self::Simple(1)
146    }
147}
148#[derive(
149    Serialize,
150    Deserialize,
151    Debug,
152    Clone,
153    PartialEq,
154    rkyv::Archive,
155    rkyv::Serialize,
156    rkyv::Deserialize,
157)]
158#[rkyv(derive(Debug))]
159pub struct ThresholdClause(Vec<ThresholdFraction>);
160
161impl ThresholdClause {
162    pub fn new(fracs: &[ThresholdFraction]) -> Self {
163        Self(fracs.to_owned())
164    }
165
166    pub fn new_from_tuples(tuples: Vec<(u64, u64)>) -> Self {
167        let clause = tuples
168            .into_iter()
169            .map(|(n, d)| ThresholdFraction::new(n, d))
170            .collect();
171        Self(clause)
172    }
173
174    pub fn length(&self) -> usize {
175        self.0.len()
176    }
177
178    pub fn enough_signatures(
179        &self,
180        start_index: usize,
181        sigs_indexes: &[usize],
182    ) -> Result<(), SignatureError> {
183        (sigs_indexes
184            .iter()
185            .fold(Some(Zero::zero()), |acc: Option<Fraction>, sig_index| {
186                if let (Some(element), Some(sum)) = (self.0.get(sig_index - start_index), acc) {
187                    Some(sum + element.fraction)
188                } else {
189                    None
190                }
191            })
192            .ok_or_else(|| SignatureError::MissingIndex)?
193            >= One::one())
194        .then(|| ())
195        .ok_or(SignatureError::NotEnoughSigsError)
196    }
197}
198
199#[derive(
200    Deserialize,
201    Serialize,
202    Debug,
203    Clone,
204    PartialEq,
205    rkyv::Archive,
206    rkyv::Serialize,
207    rkyv::Deserialize,
208)]
209#[rkyv(derive(Debug))]
210
211pub struct MultiClauses(Vec<ThresholdClause>);
212
213impl MultiClauses {
214    pub fn new(fracs: Vec<Vec<ThresholdFraction>>) -> Self {
215        let clauses = fracs
216            .iter()
217            .map(|clause| ThresholdClause::new(clause))
218            .collect();
219
220        Self(clauses)
221    }
222
223    pub fn new_from_tuples(fracs: Vec<Vec<(u64, u64)>>) -> Self {
224        let wt = fracs
225            .into_iter()
226            .map(ThresholdClause::new_from_tuples)
227            .collect();
228        MultiClauses(wt)
229    }
230
231    pub fn length(&self) -> usize {
232        self.0.iter().map(|l| l.length()).sum()
233    }
234
235    pub fn enough_signatures(&self, sigs_indexes: &[usize]) -> Result<(), SignatureError> {
236        self.0
237            .iter()
238            .fold(Ok((0, true)), |acc, clause| -> Result<_, SignatureError> {
239                let (start, enough) = acc?;
240                let sigs: Vec<usize> = sigs_indexes
241                    .iter()
242                    .cloned()
243                    .filter(|sig_index| {
244                        sig_index >= &start && sig_index < &(start + clause.0.len())
245                    })
246                    .collect();
247                Ok((
248                    start + clause.0.len(),
249                    enough && clause.enough_signatures(start, &sigs).is_ok(),
250                ))
251            })?
252            .1
253            .then(|| ())
254            .ok_or(SignatureError::NotEnoughSigsError)
255    }
256}
257
258mod rkyv_serialization {
259    use fraction::Fraction;
260
261    #[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
262    #[rkyv(remote = fraction::Sign)]
263    #[rkyv(derive(Debug))]
264    pub enum SignDef {
265        Plus,
266        Minus,
267    }
268
269    impl From<SignDef> for fraction::Sign {
270        fn from(value: SignDef) -> Self {
271            match value {
272                SignDef::Plus => Self::Plus,
273                SignDef::Minus => Self::Minus,
274            }
275        }
276    }
277
278    #[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
279    #[rkyv(remote = Fraction)]
280    #[rkyv(derive(Debug))]
281    pub enum FractionDef {
282        Rational(
283            #[rkyv(with = SignDef)] fraction::Sign,
284            #[rkyv(with = RatioDef)] fraction::Ratio<u64>,
285        ),
286        Infinity(#[rkyv(with = SignDef)] fraction::Sign),
287        NaN,
288    }
289
290    impl From<FractionDef> for Fraction {
291        fn from(value: FractionDef) -> Self {
292            match value {
293                FractionDef::Rational(sign, ratio) => Fraction::Rational(sign, ratio),
294                FractionDef::Infinity(sign) => Fraction::Infinity(sign),
295                FractionDef::NaN => Fraction::NaN,
296            }
297        }
298    }
299
300    #[derive(rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
301    #[rkyv(remote = fraction::Ratio<u64>)]
302    #[rkyv(derive(Debug))]
303    pub struct RatioDef {
304        /// Numerator.
305        #[rkyv(getter = fraction::Ratio::numer)]
306        numer: u64,
307        /// Denominator.
308        #[rkyv(getter = fraction::Ratio::denom)]
309        denom: u64,
310    }
311
312    impl From<RatioDef> for fraction::Ratio<u64> {
313        fn from(value: RatioDef) -> Self {
314            Self::new(value.numer, value.denom)
315        }
316    }
317}
318
#[test]
fn test_enough_sigs() -> Result<(), SignatureError> {
    // Threshold: [[1/1], [1/2, 1/2, 1/2]]. Both clauses must be satisfied; the first clause
    // covers index 0 and the second covers indexes 1..=3.
    let wt = MultiClauses::new_from_tuples(vec![vec![(1, 1)], vec![(1, 2), (1, 2), (1, 2)]]);
    let sigs_indexes: Vec<usize> = vec![0, 1, 2, 3];

    // All signatures.
    assert!(wt.enough_signatures(&sigs_indexes).is_ok());

    // Enough signatures: index 0 fills the first clause, and indexes 1 and 3 give
    // 1/2 + 1/2 in the second.
    let enough = [sigs_indexes[0], sigs_indexes[1], sigs_indexes[3]];
    assert!(wt.enough_signatures(&enough).is_ok());

    // Not enough: the second clause receives no signatures.
    let not_enough = [sigs_indexes[0]];
    assert!(wt.enough_signatures(&not_enough).is_err());

    Ok(())
}
341
#[test]
pub fn test_weighted_treshold_serialization() -> Result<(), SignatureError> {
    // Multi-clause threshold: a list of clauses, each a list of weights.
    let multi_threshold = r#"[["1"],["1/2","1/2","1/2"]]"#;
    let wt: WeightedThreshold = serde_json::from_str(multi_threshold).unwrap();
    assert!(matches!(wt, WeightedThreshold::Multi(_)));
    // Serializing must reproduce the input exactly.
    assert_eq!(serde_json::to_string(&wt).unwrap(), multi_threshold);

    // Single-clause threshold: a flat list of weights.
    let single_threshold = r#"["1/2","1/2","1/2"]"#;
    let wt: WeightedThreshold = serde_json::from_str(single_threshold).unwrap();
    assert!(matches!(wt, WeightedThreshold::Single(_)));
    assert_eq!(serde_json::to_string(&wt).unwrap(), single_threshold);
    Ok(())
}