1#![cfg_attr(docsrs, feature(doc_auto_cfg))]
2#![doc = include_str!("../README.md")]
3#![deny(missing_docs)]
4#![cfg_attr(not(feature = "std"), no_std)]
5
6use core::fmt::Debug;
7use std_shims::{
8 vec,
9 vec::Vec,
10 io::{self, Read, Write},
11};
12
13use curve25519_dalek::{
14 scalar::Scalar,
15 edwards::{EdwardsPoint, CompressedEdwardsY},
16};
17
/// The high bit of a VarInt byte. When set, another byte of the same VarInt follows; the low 7
/// bits carry value bits.
const VARINT_CONTINUATION_MASK: u8 = 0x80;
19
mod sealed {
  /// An unsigned integer type which can be encoded as a VarInt.
  ///
  /// This lives in a private module, so crates outside this one cannot name it and therefore
  /// cannot implement it for additional types.
  pub trait VarInt: TryInto<u64> + TryFrom<u64> + Copy {
    /// The width of the type, in bits.
    const BITS: usize;
  }

  // Implements `VarInt` for each listed type with the given bit width.
  macro_rules! impl_varint {
    ($($ty:ty => $bits:expr),* $(,)?) => {
      $(
        impl VarInt for $ty {
          const BITS: usize = $bits;
        }
      )*
    };
  }

  impl_varint! {
    u8 => 8,
    u32 => 32,
    u64 => 64,
    usize => core::mem::size_of::<usize>() * 8,
  }
}
41
42pub fn varint_len<V: sealed::VarInt>(varint: V) -> usize {
46 let varint_u64: u64 = varint.try_into().map_err(|_| "varint exceeded u64").unwrap();
47 ((usize::try_from(u64::BITS - varint_u64.leading_zeros()).unwrap().saturating_sub(1)) / 7) + 1
48}
49
/// Write a single byte.
///
/// # Errors
///
/// Propagates any error from the writer's `write_all`.
pub fn write_byte<W: Write>(byte: &u8, w: &mut W) -> io::Result<()> {
  let buf: [u8; 1] = [*byte];
  w.write_all(&buf)
}
56
57pub fn write_varint<W: Write, U: sealed::VarInt>(varint: &U, w: &mut W) -> io::Result<()> {
61 let mut varint: u64 = (*varint).try_into().map_err(|_| "varint exceeded u64").unwrap();
62 while {
63 let mut b = u8::try_from(varint & u64::from(!VARINT_CONTINUATION_MASK)).unwrap();
64 varint >>= 7;
65 if varint != 0 {
66 b |= VARINT_CONTINUATION_MASK;
67 }
68 write_byte(&b, w)?;
69 varint != 0
70 } {}
71 Ok(())
72}
73
74pub fn write_scalar<W: Write>(scalar: &Scalar, w: &mut W) -> io::Result<()> {
76 w.write_all(&scalar.to_bytes())
77}
78
79pub fn write_point<W: Write>(point: &EdwardsPoint, w: &mut W) -> io::Result<()> {
81 w.write_all(&point.compress().to_bytes())
82}
83
/// Write every value in `values`, in order, using `f`. No length prefix is written.
///
/// # Errors
///
/// Returns the first error produced by `f`; later values are not written.
pub fn write_raw_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
  f: F,
  values: &[T],
  w: &mut W,
) -> io::Result<()> {
  values.iter().try_for_each(|value| f(value, w))
}
95
96pub fn write_vec<T, W: Write, F: Fn(&T, &mut W) -> io::Result<()>>(
98 f: F,
99 values: &[T],
100 w: &mut W,
101) -> io::Result<()> {
102 write_varint(&values.len(), w)?;
103 write_raw_vec(f, values, w)
104}
105
/// Read exactly `N` bytes.
///
/// # Errors
///
/// Propagates the error from `read_exact`, including when fewer than `N` bytes are available.
pub fn read_bytes<R: Read, const N: usize>(r: &mut R) -> io::Result<[u8; N]> {
  let mut buf = [0u8; N];
  r.read_exact(&mut buf).map(|()| buf)
}
112
113pub fn read_byte<R: Read>(r: &mut R) -> io::Result<u8> {
115 Ok(read_bytes::<_, 1>(r)?[0])
116}
117
118pub fn read_u16<R: Read>(r: &mut R) -> io::Result<u16> {
120 read_bytes(r).map(u16::from_le_bytes)
121}
122
123pub fn read_u32<R: Read>(r: &mut R) -> io::Result<u32> {
125 read_bytes(r).map(u32::from_le_bytes)
126}
127
128pub fn read_u64<R: Read>(r: &mut R) -> io::Result<u64> {
130 read_bytes(r).map(u64::from_le_bytes)
131}
132
133pub fn read_varint<R: Read, U: sealed::VarInt>(r: &mut R) -> io::Result<U> {
135 let mut bits = 0;
136 let mut res = 0;
137 while {
138 let b = read_byte(r)?;
139 if (bits != 0) && (b == 0) {
140 Err(io::Error::other("non-canonical varint"))?;
141 }
142 if ((bits + 7) >= U::BITS) && (b >= (1 << (U::BITS - bits))) {
143 Err(io::Error::other("varint overflow"))?;
144 }
145
146 res += u64::from(b & (!VARINT_CONTINUATION_MASK)) << bits;
147 bits += 7;
148 b & VARINT_CONTINUATION_MASK == VARINT_CONTINUATION_MASK
149 } {}
150 res.try_into().map_err(|_| io::Error::other("VarInt does not fit into integer type"))
151}
152
153pub fn read_scalar<R: Read>(r: &mut R) -> io::Result<Scalar> {
158 Option::from(Scalar::from_canonical_bytes(read_bytes(r)?))
159 .ok_or_else(|| io::Error::other("unreduced scalar"))
160}
161
162pub fn decompress_point(bytes: [u8; 32]) -> Option<EdwardsPoint> {
172 CompressedEdwardsY(bytes)
173 .decompress()
174 .filter(|point| point.compress().to_bytes() == bytes)
176}
177
178pub fn read_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
183 let bytes = read_bytes(r)?;
184 decompress_point(bytes).ok_or_else(|| io::Error::other("invalid point"))
185}
186
187pub fn read_torsion_free_point<R: Read>(r: &mut R) -> io::Result<EdwardsPoint> {
189 read_point(r)
190 .ok()
191 .filter(EdwardsPoint::is_torsion_free)
192 .ok_or_else(|| io::Error::other("invalid point"))
193}
194
/// Read `len` consecutive values, each decoded by `f`. No length prefix is read.
///
/// # Errors
///
/// Returns the first error produced by `f`; no further values are read after it.
pub fn read_raw_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(
  f: F,
  len: usize,
  r: &mut R,
) -> io::Result<Vec<T>> {
  // The vector grows as values are decoded instead of reserving `len` up front. A large `len`
  // (e.g. one taken from the stream by `read_vec`) therefore only costs memory for values that
  // actually decode.
  let mut items = vec![];
  for _ in 0 .. len {
    let item = f(r)?;
    items.push(item);
  }
  Ok(items)
}
207
208pub fn read_array<R: Read, T: Debug, F: Fn(&mut R) -> io::Result<T>, const N: usize>(
210 f: F,
211 r: &mut R,
212) -> io::Result<[T; N]> {
213 read_raw_vec(f, N, r).map(|vec| vec.try_into().unwrap())
214}
215
216pub fn read_vec<R: Read, T, F: Fn(&mut R) -> io::Result<T>>(f: F, r: &mut R) -> io::Result<Vec<T>> {
218 read_raw_vec(f, read_varint(r)?, r)
219}