cesride/core/counter/
mod.rs

1pub mod tables;
2
3use crate::core::util;
4use crate::error::{err, Error, Result};
5
/// A count code: a table-defined `code` string paired with a numeric `count`.
///
/// `Eq` is derived next to `PartialEq` because both fields (`String` and `u32`) have total
/// equality, which lets the type be used wherever an `Eq` bound is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Counter {
    /// Code identifying this counter; the methods of this type look it up via `tables::sizage`.
    pub(crate) code: String,
    /// Count carried alongside the code.
    pub(crate) count: u32,
}
11
12impl Counter {
13    pub fn new(
14        count: Option<u32>,
15        count_b64: Option<&str>,
16        code: Option<&str>,
17        qb64b: Option<&[u8]>,
18        qb64: Option<&str>,
19        qb2: Option<&[u8]>,
20    ) -> Result<Self> {
21        if let Some(code) = code {
22            let count = if let Some(count) = count {
23                count
24            } else if let Some(count_b64) = count_b64 {
25                util::b64_to_u32(count_b64)?
26            } else {
27                1
28            };
29
30            Self::new_with_code_and_count(code, count)
31        } else if let Some(qb64b) = qb64b {
32            Self::new_with_qb64b(qb64b)
33        } else if let Some(qb64) = qb64 {
34            Self::new_with_qb64(qb64)
35        } else if let Some(qb2) = qb2 {
36            Self::new_with_qb2(qb2)
37        } else {
38            err!(Error::Validation("need either code and count, qb64b, qb64 or qb2".to_string()))
39        }
40    }
41
    /// Returns an owned copy of this counter's code.
    pub fn code(&self) -> String {
        self.code.clone()
    }
45
    /// Returns this counter's count.
    pub fn count(&self) -> u32 {
        self.count
    }
49
50    pub fn count_as_b64(&self, length: usize) -> Result<String> {
51        let length = if length == 0 { tables::sizage(&self.code())?.ss as usize } else { length };
52        util::u32_to_b64(self.count(), length)
53    }
54
    /// Returns the qb64 string for this counter, as produced by `infil`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `infil`.
    pub fn qb64(&self) -> Result<String> {
        self.infil()
    }
58
59    pub fn qb64b(&self) -> Result<Vec<u8>> {
60        Ok(self.qb64()?.as_bytes().to_vec())
61    }
62
    /// Returns the qualified base2 (binary) form of this counter, as produced by `binfil`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `binfil`.
    pub fn qb2(&self) -> Result<Vec<u8>> {
        self.binfil()
    }
66
67    pub fn sem_ver_str_to_b64(version: &str) -> Result<String> {
68        let strings = version.split('.').collect::<Vec<_>>();
69        let mut parts = vec![0u8; 3];
70
71        if strings.len() > 3 {
72            return err!(Error::Conversion(format!(
73                "invalid semantic version: version = '{version}'"
74            )));
75        }
76
77        for i in 0..strings.len() {
78            let n = match strings[i].parse::<i8>() {
79                Ok(n) => {
80                    if n < 0 {
81                        return err!(Error::Conversion(format!(
82                            "invalid semantic version: version = '{version}'"
83                        )));
84                    } else {
85                        n as u8
86                    }
87                }
88                Err(_) => {
89                    if strings[i].is_empty() {
90                        0
91                    } else {
92                        return err!(Error::Conversion(format!(
93                            "invalid semantic version: version = '{version}'"
94                        )));
95                    }
96                }
97            };
98            parts[i] = n;
99        }
100
101        Counter::sem_ver_parts_to_b64(&parts)
102    }
103
104    pub fn sem_ver_to_b64(major: u8, minor: u8, patch: u8) -> Result<String> {
105        let parts = &[major, minor, patch];
106        Counter::sem_ver_parts_to_b64(parts)
107    }
108
109    pub fn new_with_code_and_count(code: &str, count: u32) -> Result<Self> {
110        if code.is_empty() {
111            return err!(Error::EmptyMaterial("empty code".to_string()));
112        }
113
114        let szg = tables::sizage(code)?;
115        let cs = szg.hs + szg.ss;
116        if szg.fs != cs || cs % 4 != 0 {
117            // unreachable
118            // code validated and unless sizages are broken this cannot be reached
119            return err!(Error::InvalidCodeSize(format!(
120                "whole code size not a multiple of 4: cs = {cs}, fs = {}",
121                szg.fs
122            )));
123        }
124
125        if count > 64_u32.pow(szg.ss) - 1 {
126            return err!(Error::InvalidVarIndex(format!(
127                "invalid count for code: count = {count}, code = '{code}'"
128            )));
129        }
130
131        Ok(Counter { code: code.to_string(), count })
132    }
133
134    pub fn new_with_qb64(qb64: &str) -> Result<Self> {
135        let mut counter: Counter = Default::default();
136        counter.exfil(qb64)?;
137        Ok(counter)
138    }
139
140    pub fn new_with_qb64b(qb64b: &[u8]) -> Result<Self> {
141        let qb64 = String::from_utf8(qb64b.to_vec())?;
142
143        let mut counter: Counter = Default::default();
144        counter.exfil(&qb64)?;
145        Ok(counter)
146    }
147
148    pub fn new_with_qb2(qb2: &[u8]) -> Result<Self> {
149        let mut counter: Counter = Default::default();
150        counter.bexfil(qb2)?;
151        Ok(counter)
152    }
153
154    fn sem_ver_parts_to_b64(parts: &[u8]) -> Result<String> {
155        for p in parts.iter().copied() {
156            if p > 63 {
157                return err!(Error::Parsing(format!(
158                    "semantic version out of bounds: parts = {parts:?}"
159                )));
160            }
161        }
162
163        Ok(parts
164            .iter()
165            .map(|p| {
166                match util::u32_to_b64(*p as u32, 1) {
167                    Ok(s) => s,
168                    Err(_) => unreachable!(), // this is programmer error, since *p < 64
169                }
170            })
171            .collect::<Vec<String>>()
172            .join(""))
173    }
174
175    fn infil(&self) -> Result<String> {
176        let code = &self.code();
177        let count = self.count();
178
179        let szg = tables::sizage(code)?;
180        let cs = szg.hs + szg.ss;
181
182        if szg.fs != cs || cs % 4 != 0 {
183            // unreachable
184            // unless sizages are broken this cannot happen
185            return err!(Error::InvalidCodeSize(format!(
186                "whole code size not complete or not a multiple of 4: cs = {cs}, fs = {}",
187                szg.fs
188            )));
189        }
190
191        if count > 64_u32.pow(szg.ss) - 1 {
192            return err!(Error::InvalidVarIndex(format!(
193                "invalid count for code: count = {count}, code = '{code}'"
194            )));
195        }
196
197        let both = format!("{code}{}", util::u32_to_b64(count, szg.ss as usize)?);
198        if both.len() != cs as usize {
199            // unreachable
200            // unless sizages are broken, we constructed both to be of length cs
201            return err!(Error::InvalidCodeSize(format!(
202                "mismatched code size: size = {}, code = '{both}'",
203                both.len()
204            )));
205        }
206
207        Ok(both)
208    }
209
210    fn binfil(&self) -> Result<Vec<u8>> {
211        let both = self.infil()?;
212        util::code_b64_to_b2(&both)
213    }
214
215    fn exfil(&mut self, qb64: &str) -> Result<()> {
216        if qb64.is_empty() {
217            return err!(Error::EmptyMaterial("empty qb64".to_string()));
218        }
219
220        // we validated there will be a char here, above.
221        let first = &qb64[..2];
222
223        let hs = tables::hardage(first)? as usize;
224        if qb64.len() < hs {
225            return err!(Error::Shortage(format!(
226                "insufficient material for hard part of code: qb64 size = {}, hs = {hs}",
227                qb64.len()
228            )));
229        }
230
231        // bounds already checked
232        let hard = &qb64[..hs];
233        let szg = tables::sizage(hard)?;
234        let cs = szg.hs + szg.ss;
235
236        if qb64.len() < cs as usize {
237            return err!(Error::Shortage(format!(
238                "insufficient material for code: qb64 size = {}, cs = {cs}",
239                qb64.len()
240            )));
241        }
242
243        let count_b64 = &qb64[szg.hs as usize..cs as usize];
244        let count = util::b64_to_u64(count_b64)? as u32;
245
246        self.code = hard.to_string();
247        self.count = count;
248
249        Ok(())
250    }
251
252    fn bexfil(&mut self, qb2: &[u8]) -> Result<()> {
253        if qb2.is_empty() {
254            return err!(Error::EmptyMaterial("empty qualified base2".to_string()));
255        }
256
257        let first = util::nab_sextets(qb2, 2)?;
258        if first[0] > 0x3e {
259            if first[0] == 0x3f {
260                return err!(Error::UnexpectedOpCode(
261                    "unexpected start during extraction".to_string(),
262                ));
263            } else {
264                // unreachable
265                // programmer error - nab_sextets ensures values fall below 0x40. the only possible
266                // value is 0x3f, and we handle it
267                return err!(Error::UnexpectedCode(format!(
268                    "unexpected code start: sextets = {first:?}"
269                )));
270            }
271        }
272
273        let hs = tables::bardage(&first)?;
274        let bhs = ((hs + 1) * 3) / 4;
275        if qb2.len() < bhs as usize {
276            return err!(Error::Shortage(format!(
277                "need more bytes: qb2 size = {}, bhs = {bhs}",
278                qb2.len()
279            )));
280        }
281
282        let hard = util::code_b2_to_b64(qb2, hs as usize)?;
283        let szg = tables::sizage(&hard)?;
284        let cs = szg.hs + szg.ss;
285        let bcs = ((cs + 1) * 3) / 4;
286        if qb2.len() < bcs as usize {
287            return err!(Error::Shortage(format!(
288                "need more bytes: qb2 size = {}, bcs = {bcs}",
289                qb2.len()
290            )));
291        }
292
293        let both = util::code_b2_to_b64(qb2, cs as usize)?;
294        let mut count = 0;
295        for c in both[hs as usize..cs as usize].chars() {
296            count <<= 6;
297            count += util::b64_char_to_index(c)? as u32;
298        }
299
300        self.code = hard;
301        self.count = count;
302
303        Ok(())
304    }
305
306    pub fn full_size(&self) -> Result<usize> {
307        Ok(tables::sizage(&self.code())?.fs as usize)
308    }
309}
310
311impl Default for Counter {
312    fn default() -> Self {
313        Counter { code: "".to_string(), count: 0 }
314    }
315}
316
// Unit tests for `Counter`: construction from each supported input, round trips between the
// text and binary forms, semantic version helpers, and error paths.
#[cfg(test)]
mod test {
    use crate::core::counter::{tables as counter, Counter};
    use base64::{engine::general_purpose as b64_engine, Engine};
    use rstest::rstest;

    // `Counter::new` rejects all-`None` input, defaults the count to 1 when only a code is
    // given, and yields the same code/count from numeric count, base64 count and qb64 inputs.
    #[rstest]
    #[case("-AAB", 1, "B", counter::Codex::ControllerIdxSigs)]
    #[case("-AAF", 5, "F", counter::Codex::ControllerIdxSigs)]
    #[case("-0VAAAQA", 1024, "QA", counter::Codex::BigAttachedMaterialQuadlets)]
    fn new(#[case] qsc: &str, #[case] count: u32, #[case] count_b64: &str, #[case] code: &str) {
        assert!(Counter::new(None, None, None, None, None, None).is_err());
        let counter = Counter::new(None, None, Some(code), None, None, None).unwrap();
        assert_eq!(counter.count(), 1);

        let counter1 = Counter::new(Some(count), None, Some(code), None, None, None).unwrap();
        let counter2 = Counter::new(None, Some(count_b64), Some(code), None, None, None).unwrap();
        let counter3 = Counter::new(None, None, None, None, Some(qsc), None).unwrap();

        assert_eq!(counter1.code(), code);
        assert_eq!(counter2.code(), code);
        assert_eq!(counter3.code(), code);
        assert_eq!(counter1.count(), count);
        assert_eq!(counter2.count(), count);
        assert_eq!(counter3.count(), count);

        let qb64b = counter1.qb64b().unwrap();
        let qb2 = counter1.qb2().unwrap();

        assert!(Counter::new(None, None, None, Some(&qb64b), None, None).is_ok());
        assert!(Counter::new(None, None, None, None, None, Some(&qb2)).is_ok());
    }

    // All five construction routes (count, base64 count, qb64, qb64b, qb2) agree on code and
    // count.
    #[rstest]
    #[case("-AAB", 1, "B", counter::Codex::ControllerIdxSigs)]
    #[case("-AAF", 5, "F", counter::Codex::ControllerIdxSigs)]
    #[case("-0VAAAQA", 1024, "QA", counter::Codex::BigAttachedMaterialQuadlets)]
    fn creation(
        #[case] qsc: &str,
        #[case] count: u32,
        #[case] count_b64: &str,
        #[case] code: &str,
    ) {
        let qscb = qsc.as_bytes();
        let qscb2 = b64_engine::URL_SAFE.decode(qsc).unwrap();

        let counter1 = Counter::new(Some(count), None, Some(code), None, None, None).unwrap();
        let counter2 = Counter::new(None, Some(count_b64), Some(code), None, None, None).unwrap();
        let counter3 = Counter::new(None, None, None, None, Some(qsc), None).unwrap();
        let counter4 = Counter::new(None, None, None, Some(qscb), None, None).unwrap();
        let counter5 = Counter::new(None, None, None, None, None, Some(&qscb2)).unwrap();

        assert_eq!(counter1.code(), counter2.code());
        assert_eq!(counter1.count(), counter2.count());
        assert_eq!(counter1.code(), counter3.code());
        assert_eq!(counter1.count(), counter3.count());
        assert_eq!(counter1.code(), counter4.code());
        assert_eq!(counter1.count(), counter4.count());
        assert_eq!(counter1.code(), counter5.code());
        assert_eq!(counter1.count(), counter5.count());
    }

    // Same agreement check for a code whose count field carries a version, plus
    // `count_as_b64` with an explicit length and with the 0 (default) length.
    #[rstest]
    #[case(0, "AAA", 0, "AAA", counter::Codex::KERIProtocolStack)]
    fn versioned_creation(
        #[case] verint: u32,
        #[case] version: &str,
        #[case] count: u32,
        #[case] count_b64: &str,
        #[case] code: &str,
    ) {
        let qsc = &format!("{code}{version}");
        let qscb = qsc.as_bytes();
        let qscb2 = b64_engine::URL_SAFE.decode(qsc).unwrap();

        let counter1 = Counter::new(Some(count), None, Some(code), None, None, None).unwrap();
        let counter2 = Counter::new(None, Some(count_b64), Some(code), None, None, None).unwrap();
        let counter3 = Counter::new(None, None, None, None, Some(qsc), None).unwrap();
        let counter4 = Counter::new(None, None, None, Some(qscb), None, None).unwrap();
        let counter5 = Counter::new(None, None, None, None, None, Some(&qscb2)).unwrap();

        assert_eq!(counter1.code(), code);
        assert_eq!(counter1.count(), verint);
        assert_eq!(counter1.code(), counter2.code());
        assert_eq!(counter1.count(), counter2.count());
        assert_eq!(counter1.code(), counter3.code());
        assert_eq!(counter1.count(), counter3.count());
        assert_eq!(counter1.code(), counter4.code());
        assert_eq!(counter1.count(), counter4.count());
        assert_eq!(counter1.code(), counter5.code());
        assert_eq!(counter1.count(), counter5.count());

        assert_eq!(counter1.count_as_b64(3).unwrap(), version);

        // when 0 is an argument, we use a default
        assert_eq!(counter1.count_as_b64(0).unwrap(), version);
    }

    // Extra trailing characters are ignored when parsing qb64; a truncated code is rejected.
    #[rstest]
    fn b64_overflow_and_underflow(#[values("-AAB")] qsc: &str) {
        // add some chars
        let longqsc64 = &format!("{qsc}ABCD");
        let counter = Counter::new(None, None, None, None, Some(longqsc64), None).unwrap();
        assert_eq!(
            counter.qb64().unwrap().len() as u32,
            counter::sizage(&counter.code()).unwrap().fs
        );

        // remove a char
        let shortqsc64 = &qsc[..qsc.len() - 1];
        assert!(Counter::new_with_qb64(shortqsc64).is_err());
    }

    // Extra trailing bytes are ignored when parsing qb2; a truncated code is rejected.
    #[rstest]
    fn binary_overflow_and_underflow(#[values(vec![248, 0, 1])] qscb2: Vec<u8>) {
        // add some bytes
        let mut longqscb2 = qscb2.clone();
        longqscb2.resize(longqscb2.len() + 5, 1);
        let counter = Counter::new(None, None, None, None, None, Some(&longqscb2)).unwrap();
        assert_eq!(counter.qb2().unwrap(), *qscb2);
        assert_eq!(
            counter.qb64().unwrap().len() as u32,
            counter::sizage(&counter.code()).unwrap().fs
        );

        // remove a byte
        let shortqscb2 = &qscb2[..qscb2.len() - 1];
        assert!(Counter::new(None, None, None, None, None, Some(shortqscb2)).is_err());
    }

    // Round trip: qb64 -> Counter -> qb2 -> Counter reproduces the code, count and both
    // encodings.
    #[rstest]
    fn exfil_infil_bexfil_binfil(#[values("-0VAAAQA")] qsc: &str) {
        let counter1 = Counter::new(None, None, None, None, Some(qsc), None).unwrap();
        let qb2 = counter1.qb2().unwrap();
        let counter2 = Counter::new(None, None, None, None, None, Some(&qb2)).unwrap();
        assert_eq!(counter1.code(), counter2.code());
        assert_eq!(counter1.count(), counter2.count());
        assert_eq!(counter1.qb2().unwrap(), counter2.qb2().unwrap());
        assert_eq!(qsc, counter2.qb64().unwrap());
    }

    // Version strings, including ones with missing or empty fields, map to the expected
    // three-character base64 text.
    #[rstest]
    #[case("1.2.3", "BCD")]
    #[case("1.1", "BBA")]
    #[case("1.", "BAA")]
    #[case("1", "BAA")]
    #[case("1.2.", "BCA")]
    #[case("..", "AAA")]
    #[case("1..3", "BAD")]
    fn semantic_versioning_strings(#[case] version: &str, #[case] b64: &str) {
        assert_eq!(Counter::sem_ver_str_to_b64(version).unwrap(), b64);
    }

    // Numeric version components map to the expected three-character base64 text.
    #[rstest]
    #[case(1, 0, 0, "BAA")]
    #[case(0, 1, 0, "ABA")]
    #[case(0, 0, 1, "AAB")]
    #[case(3, 4, 5, "DEF")]
    fn semantic_versioning_u8s(
        #[case] major: u8,
        #[case] minor: u8,
        #[case] patch: u8,
        #[case] b64: &str,
    ) {
        assert_eq!(Counter::sem_ver_to_b64(major, minor, patch).unwrap(), b64);
    }

    // Out-of-range and negative components in a version string are rejected.
    #[rstest]
    fn semantic_versioning_unhappy_strings(#[values("64.0.1", "-1.0.1", "0.0.64")] version: &str) {
        assert!(Counter::sem_ver_str_to_b64(version).is_err());
    }

    // A numeric component above 63 is rejected (the parameters are `u8` despite the name).
    #[rstest]
    #[case(64, 0, 0)]
    fn semantic_versioning_unhappy_u32s(#[case] major: u8, #[case] minor: u8, #[case] patch: u8) {
        assert!(Counter::sem_ver_to_b64(major, minor, patch).is_err());
    }

    // Assorted error paths: empty or oversized inputs, malformed version strings, unknown or
    // truncated qb64 text, and empty or invalid qb2 bytes.
    #[test]
    fn unhappy_paths() {
        assert!(Counter::new_with_code_and_count("", 1).is_err());
        assert!(
            Counter::new_with_code_and_count(counter::Codex::ControllerIdxSigs, 64 * 64).is_err()
        );
        assert!(Counter::sem_ver_str_to_b64("1.2.3.4").is_err());
        assert!(Counter::sem_ver_str_to_b64("bad.semantic.version").is_err());
        assert!((Counter { code: counter::Codex::ControllerIdxSigs.to_string(), count: 64 * 64 })
            .qb64()
            .is_err());

        assert!(Counter::new(None, None, None, None, Some(""), None).is_err());
        assert!(Counter::new(None, None, None, None, Some("--"), None).is_err());
        assert!(Counter::new(None, None, None, None, Some("__"), None).is_err());
        assert!(Counter::new(
            None,
            None,
            None,
            None,
            Some(counter::Codex::ControllerIdxSigs),
            None
        )
        .is_err());

        assert!(Counter::new(None, None, None, Some(&[]), None, None).is_err());

        assert!(Counter::new(None, None, None, None, None, Some(&[])).is_err());
        assert!(Counter::new(None, None, None, None, None, Some(&[0xf8, 0])).is_err());
        assert!(Counter::new(None, None, None, None, None, Some(&[0xfc, 0])).is_err());
        assert!(Counter::new(None, None, None, None, None, Some(&[0xfb, 0xe0])).is_err());
    }

    // The bytes produced by `qb64b` can be parsed back through the `qb64b` input of `new`.
    #[rstest]
    #[case(counter::Codex::ControllerIdxSigs, 1)]
    fn qb64b(#[case] code: &str, #[case] count: u32) {
        let c = Counter { code: code.to_string(), count };
        let qb64b = c.qb64b().unwrap();
        assert!(Counter::new(None, None, None, Some(&qb64b), None, None).is_ok());
    }

    // `full_size` reports the expected size for each sample code.
    #[rstest]
    #[case("-AAB", counter::Codex::ControllerIdxSigs, 4)]
    #[case("-0VAAAQA", counter::Codex::BigAttachedMaterialQuadlets, 8)]
    #[case("--AAAAAA", counter::Codex::KERIProtocolStack, 8)]
    fn qb_size(#[case] qsc: &str, #[case] code: &str, #[case] full_size: usize) {
        let counter = Counter::new(None, None, None, None, Some(qsc), None).unwrap();
        assert_eq!(counter.code(), code); // Just a self-check of the input data
        assert_eq!(counter.full_size().unwrap(), full_size);
    }
}