kyber_rs/group/edwards25519/
point.rs

1use core::fmt::{Debug, Display, Formatter, LowerHex, UpperHex};
2
3use serde::{Deserialize, Serialize};
4
5use crate::{
6    cipher::Stream,
7    encoding::{BinaryMarshaler, BinaryUnmarshaler, Marshaling, MarshallingError},
8    group::{self, internal::marshalling, PointCanCheckCanonicalAndSmallOrder, PointError},
9};
10
11use super::{
12    constants::{BASEEXT, COFACTOR_SCALAR, NULL_POINT, PRIME_ORDER_SCALAR, WEAK_KEYS},
13    ge::{
14        ge_scalar_mult, ge_scalar_mult_base, CachedGroupElement, CompletedGroupElement,
15        ExtendedGroupElement,
16    },
17    ge_mult_vartime::ge_scalar_mult_vartime,
18    Scalar,
19};
20
/// Eight-byte type tag identifying marshalled Ed25519 points (`"ed.point"`).
const MARSHAL_POINT_ID: [u8; 8] = *b"ed.point";
22
23#[derive(Copy, Clone, Eq, Ord, PartialOrd, Debug, Default, Serialize, Deserialize)]
24pub struct Point {
25    ge: ExtendedGroupElement,
26    var_time: bool,
27}
28
29impl Point {
30    pub fn new() -> Self {
31        Self::default()
32    }
33}
34
35impl BinaryMarshaler for Point {
36    fn marshal_binary(&self) -> Result<Vec<u8>, MarshallingError> {
37        let mut b = [0_u8; 32];
38        self.ge.write_bytes(&mut b);
39        Ok(b.to_vec())
40    }
41}
42impl BinaryUnmarshaler for Point {
43    fn unmarshal_binary(&mut self, data: &[u8]) -> Result<(), MarshallingError> {
44        if !self.ge.set_bytes(data) {
45            return Err(MarshallingError::InvalidInput(
46                "invalid Ed25519 curve point".to_owned(),
47            ));
48        }
49        Ok(())
50    }
51}
52
53impl Marshaling for Point {
54    fn marshal_to(&self, w: &mut impl std::io::Write) -> Result<(), MarshallingError> {
55        marshalling::point_marshal_to(self, w)
56    }
57
58    fn marshal_size(&self) -> usize {
59        32
60    }
61
62    fn unmarshal_from(&mut self, r: &mut impl std::io::Read) -> Result<(), MarshallingError> {
63        marshalling::point_unmarshal_from(self, r)
64    }
65
66    fn unmarshal_from_random(&mut self, r: &mut (impl std::io::Read + Stream)) {
67        marshalling::point_unmarshal_from_random(self, r);
68    }
69
70    fn marshal_id(&self) -> [u8; 8] {
71        MARSHAL_POINT_ID
72    }
73}
74
75impl group::Point for Point {
76    type SCALAR = Scalar;
77
78    /// [`null()`] sets [`self`] to the neutral element, which is `(0,1)` for twisted Edwards curves.
79    fn null(mut self) -> Self {
80        self.ge.zero();
81        self
82    }
83
84    /// [`base()`] sets [`self`] to the standard base point for this curve
85    fn base(mut self) -> Self {
86        self.ge = BASEEXT;
87        self
88    }
89
90    fn pick<S: crate::cipher::Stream>(self, rand: &mut S) -> Self {
91        self.embed(None, rand)
92    }
93
94    fn set(&mut self, p: &Self) -> Self {
95        self.ge = p.ge;
96        *self
97    }
98
99    fn embed_len(&self) -> usize {
100        // Reserve the most-significant 8 bits for pseudo-randomness.
101        // Reserve the least-significant 8 bits for embedded data length.
102        // (Hopefully it's unlikely we'll need >=2048-bit curves soon.)
103        (255 - 8 - 8) / 8
104    }
105
106    fn embed<S: Stream>(mut self, data: Option<&[u8]>, rand: &mut S) -> Self {
107        // How many bytes to embed?
108        let mut dl = self.embed_len();
109        let data_len = match data {
110            Some(d) => d.len(),
111            None => 0,
112        };
113        if dl > data_len {
114            dl = data_len;
115        }
116
117        loop {
118            // Pick a random point, with optional embedded data
119            let mut b = [0_u8; 32];
120            rand.xor_key_stream(&mut b, &[0_u8; 32]).unwrap();
121            if let Some(d) = data {
122                // Encode length in low 8 bits
123                b[0] = dl as u8;
124                // Copy in data to embed
125                b[1..1 + dl].copy_from_slice(&d[0..dl]);
126            }
127            // Try to decode
128            if !self.ge.set_bytes(&b) {
129                // invalid point, retry
130                continue;
131            }
132
133            // TODO: manage this case
134            // If we're using the full group,
135            // we just need any point on the curve, so we're done.
136            //		if c.full {
137            //			return P,data[dl:]
138            //		}
139
140            // We're using the prime-order subgroup,
141            // so we need to make sure the point is in that subencoding.
142            // If we're not trying to embed data,
143            // we can convert our point into one in the subgroup
144            // simply by multiplying it by the cofactor.
145            if data.is_none() {
146                // multiply by cofactor
147                let old_self = &self.clone();
148                self = self.mul(&COFACTOR_SCALAR, Some(old_self));
149                if self.eq(&NULL_POINT) {
150                    // unlucky; try again
151                    continue;
152                }
153                // success
154                return self;
155            }
156
157            // Since we need the point's y-coordinate to hold our data,
158            // we must simply check if the point is in the subgroup
159            // and retry point generation until it is.
160            let mut q = Point::default();
161            q = q.mul(&PRIME_ORDER_SCALAR, Some(&self));
162            if q.eq(&NULL_POINT) {
163                return self; // success
164            }
165            // Keep trying...
166        }
167    }
168
169    fn data(&self) -> Result<Vec<u8>, PointError> {
170        let mut b = [0u8; 32];
171        self.ge.write_bytes(&mut b);
172        let dl = b[0] as usize; // extract length byte
173        if dl > self.embed_len() {
174            return Err(PointError::EmbedDataLength);
175        }
176        Ok(b[1..1 + dl].to_vec())
177    }
178
179    fn add(mut self, p1: &Self, p2: &Self) -> Self {
180        let mut t2 = CachedGroupElement::default();
181        let mut r = CompletedGroupElement::default();
182
183        p2.ge.write_cached(&mut t2);
184        r.add(&p1.ge, &t2);
185        r.to_extended(&mut self.ge);
186
187        self
188    }
189
190    fn sub(mut self, p1: &Self, p2: &Self) -> Self {
191        let mut t2 = CachedGroupElement::default();
192        let mut r = CompletedGroupElement::default();
193
194        p2.ge.write_cached(&mut t2);
195        r.sub(&p1.ge, &t2);
196        r.to_extended(&mut self.ge);
197
198        self
199    }
200
201    fn neg(&mut self, a: &Self) -> Self {
202        self.ge.neg(&a.ge);
203        *self
204    }
205
206    /// [`mul()`] multiplies [`Point`] `p` by [`Scalar`] `s` using the repeated doubling method.
207    fn mul(mut self, s: &Scalar, p: Option<&Self>) -> Self {
208        let mut a = s.v;
209
210        match p {
211            None => {
212                ge_scalar_mult_base(&mut self.ge, &mut a);
213            }
214            Some(a_p) => {
215                if self.var_time {
216                    ge_scalar_mult_vartime(&mut self.ge, &mut a, &mut a_p.ge.clone());
217                } else {
218                    ge_scalar_mult(&mut self.ge, &mut a, &mut a_p.ge.clone());
219                }
220            }
221        }
222
223        self
224    }
225}
226
227impl PartialEq for Point {
228    /// [`eq()`] is an equality test for two [`points`](Point) on the same curve
229    fn eq(&self, p2: &Self) -> bool {
230        let mut b1 = [0_u8; 32];
231        let mut b2 = [0_u8; 32];
232        self.ge.write_bytes(&mut b1);
233        p2.ge.write_bytes(&mut b2);
234        for i in 0..b1.len() {
235            if b1[i] != b2[i] {
236                return false;
237            }
238        }
239        true
240    }
241}
242
243impl core::hash::Hash for Point {
244    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
245        let mut b = [0_u8; 32];
246        self.ge.write_bytes(&mut b);
247        b.hash(state);
248    }
249}
250
251impl Display for Point {
252    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
253        write!(f, "Ed25519Point({self:#x})")
254    }
255}
256
257impl LowerHex for Point {
258    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
259        let prefix = if f.alternate() { "0x" } else { "" };
260        let mut b = [0u8; 32];
261        self.ge.write_bytes(&mut b);
262        let encoded = hex::encode(b);
263        write!(f, "{prefix}{encoded}")
264    }
265}
266
267impl UpperHex for Point {
268    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
269        let prefix = if f.alternate() { "0X" } else { "" };
270        let mut b = [0u8; 32];
271        self.ge.write_bytes(&mut b);
272        let encoded = hex::encode_upper(b);
273        write!(f, "{prefix}{encoded}")
274    }
275}
276
277impl PointCanCheckCanonicalAndSmallOrder for Point {
278    /// [`has_small_order()`] determines whether the group element has small order
279    ///
280    /// Provides resilience against malicious key substitution attacks (M-S-UEO)
281    /// and message bound security (MSB) even for malicious keys
282    /// See paper <https://eprint.iacr.org/2020/823.pdf> for definitions and theorems
283    ///
284    /// This is the same code as in
285    /// <https://github.com/jedisct1/libsodium/blob/4744636721d2e420f8bbe2d563f31b1f5e682229/src/libsodium/crypto_core/ed25519/ref10/ed25519_ref10.c#L1170>
286    fn has_small_order(&self) -> bool {
287        let s = match self.marshal_binary() {
288            Ok(v) => v,
289            Err(_) => return false,
290        };
291
292        let mut c = [0u8; 5];
293
294        (0..31).for_each(|j| {
295            for i in 0..5 {
296                c[i] |= s[j] ^ WEAK_KEYS[i][j];
297            }
298        });
299        for i in 0..5 {
300            c[i] |= (s[31] & 0x7f) ^ WEAK_KEYS[i][31];
301        }
302
303        // Constant time verification if one or more of the c's are zero
304        let mut k = 0;
305        (0..5).for_each(|i| {
306            k |= (c[i] as u16) - 1;
307        });
308
309        (k >> 8) & 1 > 0
310    }
311
312    /// [`is_canonical()`] determines whether the group element is canonical
313    ///
314    /// Checks whether group element s is less than p, according to RFC8032ยง5.1.3.1
315    /// <https://tools.ietf.org/html/rfc8032#section-5.1.3>
316    ///
317    /// Taken from
318    /// <https://github.com/jedisct1/libsodium/blob/4744636721d2e420f8bbe2d563f31b1f5e682229/src/libsodium/crypto_core/ed25519/ref10/ed25519_ref10.c#L1113>
319    ///
320    /// The method accepts a buffer instead of calling `marshal_binary()` on the receiver
321    /// because that always returns a value modulo `prime`.
322    fn is_canonical(&self, b: &[u8]) -> bool {
323        if b.len() != 32 {
324            return false;
325        }
326
327        let mut c = (b[31] & 0x7f) ^ 0x7f;
328        for i in (1..=30).into_iter().rev() {
329            c |= b[i] ^ 0xff;
330        }
331
332        // subtraction might underflow
333        c = (((c as u16) - 1) >> 8) as u8;
334        let d = ((0xEDu16.wrapping_sub(1u16.wrapping_sub(b[0] as u16))) >> 8) as u8;
335
336        1 - (c & d & 1) == 1
337    }
338}