Skip to main content

bsv/primitives/
point.rs

1//! Affine point representation on the secp256k1 curve.
2//!
3//! Point provides the public-facing API for elliptic curve point operations
4//! including addition, scalar multiplication, compression/decompression, and
5//! DER encoding. Internally delegates heavy arithmetic to JacobianPoint for
6//! efficiency.
7
8use crate::primitives::big_number::{BigNumber, Endian};
9use crate::primitives::curve::Curve;
10use crate::primitives::error::PrimitivesError;
11use crate::primitives::jacobian_point::JacobianPoint;
12
13/// A point on the secp256k1 curve in affine coordinates (x, y).
14///
15/// The point at infinity is represented by `inf == true` (x and y are zero).
16#[derive(Clone, Debug)]
17pub struct Point {
18    /// The x-coordinate.
19    pub x: BigNumber,
20    /// The y-coordinate.
21    pub y: BigNumber,
22    /// Whether this is the point at infinity.
23    pub inf: bool,
24}
25
26impl Point {
27    /// Create a new point from x, y coordinates.
28    pub fn new(x: BigNumber, y: BigNumber) -> Self {
29        Point { x, y, inf: false }
30    }
31
32    /// Create the point at infinity (identity element).
33    pub fn infinity() -> Self {
34        Point {
35            x: BigNumber::zero(),
36            y: BigNumber::zero(),
37            inf: true,
38        }
39    }
40
41    /// Check if this is the point at infinity.
42    pub fn is_infinity(&self) -> bool {
43        self.inf
44    }
45
46    /// Validate that this point lies on the secp256k1 curve.
47    /// Returns true if y^2 = x^3 + 7 (mod p).
48    pub fn validate(&self) -> bool {
49        if self.inf {
50            return false;
51        }
52
53        let curve = Curve::secp256k1();
54        let red = &curve.red;
55
56        let x_red = self.x.to_red(red.clone());
57        let y_red = self.y.to_red(red.clone());
58
59        // lhs = y^2 mod p
60        let y2 = red.sqr(&y_red);
61
62        // rhs = x^3 + 7 mod p
63        let x2 = red.sqr(&x_red);
64        let x3 = red.mul(&x_red, &x2);
65        let seven = BigNumber::from_number(7).to_red(red.clone());
66        let rhs = red.add(&x3, &seven);
67
68        y2.from_red().cmp(&rhs.from_red()) == 0
69    }
70
71    /// Recover a point from its x coordinate and y-parity.
72    /// `odd` = true means y should be odd.
73    pub fn from_x(x: &BigNumber, odd: bool) -> Result<Self, PrimitivesError> {
74        let curve = Curve::secp256k1();
75        let red = &curve.red;
76
77        let x_red = x.to_red(red.clone());
78
79        // y^2 = x^3 + 7 mod p
80        let x2 = red.sqr(&x_red);
81        let x3 = red.mul(&x_red, &x2);
82        let seven = BigNumber::from_number(7).to_red(red.clone());
83        let y2 = red.add(&x3, &seven);
84
85        // sqrt(y^2) mod p
86        // For secp256k1, p % 4 == 3, so sqrt(a) = a^((p+1)/4)
87        let y_red = red.sqrt(&y2);
88
89        // Verify the square root is valid
90        let y_check = red.sqr(&y_red);
91        if y_check.from_red().cmp(&y2.from_red()) != 0 {
92            return Err(PrimitivesError::PointNotOnCurve);
93        }
94
95        let mut y_val = y_red.from_red();
96
97        // Adjust parity
98        if y_val.is_odd() != odd {
99            y_val = curve.p.sub(&y_val);
100        }
101
102        let point = Point::new(x.clone(), y_val);
103        if !point.validate() {
104            return Err(PrimitivesError::PointNotOnCurve);
105        }
106        Ok(point)
107    }
108
109    /// Parse a point from DER-encoded bytes (compressed or uncompressed).
110    ///
111    /// Compressed format: 0x02/0x03 || x (33 bytes total)
112    /// Uncompressed format: 0x04 || x || y (65 bytes total)
113    pub fn from_der(bytes: &[u8]) -> Result<Self, PrimitivesError> {
114        if bytes.is_empty() {
115            return Err(PrimitivesError::InvalidDer("empty input".to_string()));
116        }
117
118        let prefix = bytes[0];
119
120        match prefix {
121            0x04 | 0x06 | 0x07 => {
122                // Uncompressed or hybrid format
123                if bytes.len() != 65 {
124                    return Err(PrimitivesError::InvalidDer(format!(
125                        "uncompressed point must be 65 bytes, got {}",
126                        bytes.len()
127                    )));
128                }
129
130                // Validate hybrid format parity
131                if prefix == 0x06 {
132                    if bytes[64] & 1 != 0 {
133                        return Err(PrimitivesError::InvalidDer(
134                            "hybrid point parity mismatch (expected even y)".to_string(),
135                        ));
136                    }
137                } else if prefix == 0x07 && bytes[64] & 1 == 0 {
138                    return Err(PrimitivesError::InvalidDer(
139                        "hybrid point parity mismatch (expected odd y)".to_string(),
140                    ));
141                }
142
143                let x = BigNumber::from_bytes(&bytes[1..33], Endian::Big);
144                let y = BigNumber::from_bytes(&bytes[33..65], Endian::Big);
145
146                let point = Point::new(x, y);
147                if !point.validate() {
148                    return Err(PrimitivesError::PointNotOnCurve);
149                }
150                Ok(point)
151            }
152            0x02 | 0x03 => {
153                // Compressed format
154                if bytes.len() != 33 {
155                    return Err(PrimitivesError::InvalidDer(format!(
156                        "compressed point must be 33 bytes, got {}",
157                        bytes.len()
158                    )));
159                }
160
161                let x = BigNumber::from_bytes(&bytes[1..33], Endian::Big);
162                let odd = prefix == 0x03;
163                Point::from_x(&x, odd)
164            }
165            _ => Err(PrimitivesError::InvalidDer(format!(
166                "unknown point format prefix: 0x{:02x}",
167                prefix
168            ))),
169        }
170    }
171
172    /// Parse a point from a hex string (DER encoded).
173    pub fn from_string(hex: &str) -> Result<Self, PrimitivesError> {
174        let bytes = hex_to_bytes(hex)?;
175        Self::from_der(&bytes)
176    }
177
178    /// Encode this point to DER format.
179    ///
180    /// Compressed (33 bytes): 0x02/0x03 || x
181    /// Uncompressed (65 bytes): 0x04 || x || y
182    pub fn to_der(&self, compressed: bool) -> Vec<u8> {
183        if self.inf {
184            return vec![0x00];
185        }
186
187        let x_bytes = self.x.to_array(Endian::Big, Some(32));
188
189        if compressed {
190            let prefix = if self.y.is_even() { 0x02 } else { 0x03 };
191            let mut result = Vec::with_capacity(33);
192            result.push(prefix);
193            result.extend_from_slice(&x_bytes);
194            result
195        } else {
196            let y_bytes = self.y.to_array(Endian::Big, Some(32));
197            let mut result = Vec::with_capacity(65);
198            result.push(0x04);
199            result.extend_from_slice(&x_bytes);
200            result.extend_from_slice(&y_bytes);
201            result
202        }
203    }
204
205    /// Encode to hex string (compressed DER).
206    pub fn to_hex(&self) -> String {
207        bytes_to_hex(&self.to_der(true))
208    }
209
210    /// Add two points.
211    pub fn add(&self, other: &Point) -> Point {
212        if self.inf {
213            return other.clone();
214        }
215        if other.inf {
216            return self.clone();
217        }
218
219        // Use Jacobian arithmetic for efficiency
220        let jp1 = JacobianPoint::from_affine(&self.x, &self.y);
221        let jp2 = JacobianPoint::from_affine(&other.x, &other.y);
222        let result = jp1.add(&jp2);
223
224        if result.is_infinity() {
225            return Point::infinity();
226        }
227
228        let (x, y) = result.to_affine();
229        Point::new(x, y)
230    }
231
232    /// Scalar multiplication: self * k.
233    pub fn mul(&self, k: &BigNumber) -> Point {
234        if k.is_zero() || self.inf {
235            return Point::infinity();
236        }
237
238        let is_neg = k.is_neg();
239        let k_abs = if is_neg { k.neg() } else { k.clone() };
240
241        // Reduce k mod n
242        let curve = Curve::secp256k1();
243        let k_mod = k_abs.umod(&curve.n).unwrap_or(k_abs);
244
245        if k_mod.is_zero() {
246            return Point::infinity();
247        }
248
249        let jp = JacobianPoint::from_affine(&self.x, &self.y);
250        let result = jp.mul_wnaf(&k_mod);
251
252        if result.is_infinity() {
253            return Point::infinity();
254        }
255
256        let (x, y) = result.to_affine();
257        let point = Point::new(x, y);
258
259        if is_neg {
260            point.negate()
261        } else {
262            point
263        }
264    }
265
266    /// Negate a point (same x, y = p - y).
267    pub fn negate(&self) -> Point {
268        if self.inf {
269            return self.clone();
270        }
271        let curve = Curve::secp256k1();
272        let neg_y = curve.p.sub(&self.y);
273        Point::new(self.x.clone(), neg_y)
274    }
275
276    /// Check equality of two points.
277    #[allow(clippy::should_implement_trait)]
278    pub fn eq(&self, other: &Point) -> bool {
279        if self.inf && other.inf {
280            return true;
281        }
282        if self.inf != other.inf {
283            return false;
284        }
285        self.x.cmp(&other.x) == 0 && self.y.cmp(&other.y) == 0
286    }
287
288    /// Double this point (P + P = 2P).
289    pub fn dbl(&self) -> Point {
290        if self.inf {
291            return self.clone();
292        }
293        let jp = JacobianPoint::from_affine(&self.x, &self.y);
294        let result = jp.dbl();
295        if result.is_infinity() {
296            return Point::infinity();
297        }
298        let (x, y) = result.to_affine();
299        Point::new(x, y)
300    }
301
302    /// Get x coordinate (clone).
303    pub fn get_x(&self) -> BigNumber {
304        self.x.clone()
305    }
306
307    /// Get y coordinate (clone).
308    pub fn get_y(&self) -> BigNumber {
309        self.y.clone()
310    }
311}
312
313// ---------------------------------------------------------------------------
314// Hex helpers
315// ---------------------------------------------------------------------------
316
317fn hex_to_bytes(hex: &str) -> Result<Vec<u8>, PrimitivesError> {
318    if hex.len() & 1 != 0 {
319        return Err(PrimitivesError::InvalidHex(
320            "odd-length hex string".to_string(),
321        ));
322    }
323    let mut bytes = Vec::with_capacity(hex.len() / 2);
324    for i in (0..hex.len()).step_by(2) {
325        let byte = u8::from_str_radix(&hex[i..i + 2], 16)
326            .map_err(|e| PrimitivesError::InvalidHex(e.to_string()))?;
327        bytes.push(byte);
328    }
329    Ok(bytes)
330}
331
/// Encode `bytes` as a lower-case hex string (two digits per byte, no prefix).
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // One allocation up front, instead of a temporary `String` per byte
    // from `format!`.
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}
335
336// ---------------------------------------------------------------------------
337// Tests
338// ---------------------------------------------------------------------------
339
#[cfg(test)]
mod tests {
    use super::*;

    /// The secp256k1 generator point G.
    fn g() -> Point {
        let curve = Curve::secp256k1();
        curve.generator()
    }

    #[test]
    fn test_point_infinity() {
        let inf = Point::infinity();
        assert!(inf.is_infinity());
    }

    #[test]
    fn test_point_g_on_curve() {
        let g = g();
        assert!(g.validate());
    }

    #[test]
    fn test_point_infinity_not_on_curve() {
        let inf = Point::infinity();
        assert!(!inf.validate());
    }

    #[test]
    fn test_point_add_g_plus_g() {
        let g = g();
        let two_g = g.add(&g);
        assert_eq!(
            two_g.x.to_hex(),
            "c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5"
        );
        assert_eq!(
            two_g.y.to_hex(),
            "1ae168fea63dc339a3c58419466ceaeef7f632653266d0e1236431a950cfe52a"
        );
    }

    #[test]
    fn test_point_add_identity() {
        let g = g();
        let inf = Point::infinity();

        let r1 = g.add(&inf);
        assert!(r1.eq(&g));

        let r2 = inf.add(&g);
        assert!(r2.eq(&g));
    }

    #[test]
    fn test_point_mul_1() {
        let g = g();
        let k = BigNumber::one();
        let result = g.mul(&k);
        assert!(result.eq(&g));
    }

    #[test]
    fn test_point_mul_zero_is_infinity() {
        let g = g();
        assert!(g.mul(&BigNumber::zero()).is_infinity());
    }

    #[test]
    fn test_point_mul_negative_scalar() {
        // (-1) * G must equal -G.
        let g = g();
        let minus_one = BigNumber::one().neg();
        assert!(g.mul(&minus_one).eq(&g.negate()));
    }

    #[test]
    fn test_point_mul_2_equals_add() {
        let g = g();
        let k = BigNumber::from_number(2);
        let mul_result = g.mul(&k);
        let add_result = g.add(&g);
        assert!(mul_result.eq(&add_result));
    }

    #[test]
    fn test_point_mul_n_is_infinity() {
        let g = g();
        let curve = Curve::secp256k1();
        let result = g.mul(&curve.n);
        assert!(result.is_infinity());
    }

    #[test]
    fn test_point_mul_n_minus_1() {
        let g = g();
        let curve = Curve::secp256k1();
        let n_minus_1 = curve.n.subn(1);
        let result = g.mul(&n_minus_1);
        // (n-1)*G should have same x as G but negated y (= p - G.y)
        assert_eq!(result.x.cmp(&g.x), 0);
        let neg_y = curve.p.sub(&g.y);
        assert_eq!(result.y.cmp(&neg_y), 0);
    }

    #[test]
    fn test_point_negate() {
        let g = g();
        let neg_g = g.negate();
        assert_eq!(neg_g.x.cmp(&g.x), 0);
        let curve = Curve::secp256k1();
        let expected_y = curve.p.sub(&g.y);
        assert_eq!(neg_g.y.cmp(&expected_y), 0);
    }

    #[test]
    fn test_point_negate_add_is_infinity() {
        let g = g();
        let neg_g = g.negate();
        let result = g.add(&neg_g);
        assert!(result.is_infinity());
    }

    #[test]
    fn test_point_compressed_even_y() {
        let g = g();
        let der = g.to_der(true);
        assert_eq!(der.len(), 33);
        // G has even y, so prefix should be 0x02
        assert_eq!(der[0], 0x02);
    }

    #[test]
    fn test_point_uncompressed() {
        let g = g();
        let der = g.to_der(false);
        assert_eq!(der.len(), 65);
        assert_eq!(der[0], 0x04);
    }

    #[test]
    fn test_point_infinity_to_der() {
        // Infinity encodes as a single zero byte in both forms.
        assert_eq!(Point::infinity().to_der(true), vec![0x00]);
        assert_eq!(Point::infinity().to_der(false), vec![0x00]);
    }

    #[test]
    fn test_point_from_der_compressed() {
        let g = g();
        let der = g.to_der(true);
        let recovered = Point::from_der(&der).unwrap();
        assert!(recovered.eq(&g));
    }

    #[test]
    fn test_point_from_der_uncompressed() {
        let g = g();
        let der = g.to_der(false);
        let recovered = Point::from_der(&der).unwrap();
        assert!(recovered.eq(&g));
    }

    #[test]
    fn test_point_from_der_hybrid() {
        let g = g();
        let mut der = g.to_der(false);
        // G.y is even, so the hybrid prefix must be 0x06; 0x07 is a mismatch.
        der[0] = 0x06;
        assert!(Point::from_der(&der).unwrap().eq(&g));
        der[0] = 0x07;
        assert!(Point::from_der(&der).is_err());
    }

    #[test]
    fn test_point_from_der_rejects_malformed() {
        assert!(Point::from_der(&[]).is_err());

        // Unknown prefix byte.
        let mut bad_prefix = g().to_der(true);
        bad_prefix[0] = 0x05;
        assert!(Point::from_der(&bad_prefix).is_err());

        // Truncated encodings.
        let compressed = g().to_der(true);
        assert!(Point::from_der(&compressed[..32]).is_err());
        let uncompressed = g().to_der(false);
        assert!(Point::from_der(&uncompressed[..64]).is_err());
    }

    #[test]
    fn test_point_from_der_round_trip_compressed() {
        let g = g();
        for k in 1..=10 {
            let p = g.mul(&BigNumber::from_number(k));
            if p.is_infinity() {
                continue;
            }
            let der = p.to_der(true);
            let recovered = Point::from_der(&der).unwrap();
            assert!(recovered.eq(&p), "round-trip failed for k={}", k);
        }
    }

    #[test]
    fn test_point_from_der_round_trip_uncompressed() {
        let g = g();
        for k in 1..=10 {
            let p = g.mul(&BigNumber::from_number(k));
            if p.is_infinity() {
                continue;
            }
            let der = p.to_der(false);
            let recovered = Point::from_der(&der).unwrap();
            assert!(recovered.eq(&p), "round-trip failed for k={}", k);
        }
    }

    #[test]
    fn test_point_invalid_not_on_curve() {
        // Fixed coordinates x = y = 0x0101…01 (every byte 0x01, not the
        // integer 1); this pair does not satisfy y^2 = x^3 + 7.
        let mut bytes = vec![0x04];
        bytes.extend_from_slice(&[0x01; 32]); // x = 0x0101…01
        bytes.extend_from_slice(&[0x01; 32]); // y = 0x0101…01
        let result = Point::from_der(&bytes);
        assert!(result.is_err());
    }

    #[test]
    fn test_point_from_string() {
        let g = g();
        let hex = g.to_hex();
        let recovered = Point::from_string(&hex).unwrap();
        assert!(recovered.eq(&g));
    }

    #[test]
    fn test_point_from_string_invalid_hex() {
        assert!(Point::from_string("0").is_err());
        assert!(Point::from_string("zz").is_err());
    }

    #[test]
    fn test_point_mul_known_multiples() {
        let g = g();
        let expected = vec![
            (
                2,
                "c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5",
                "1ae168fea63dc339a3c58419466ceaeef7f632653266d0e1236431a950cfe52a",
            ),
            (
                3,
                "f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9",
                "388f7b0f632de8140fe337e62a37f3566500a99934c2231b6cb9fd7584b8e672",
            ),
            (
                5,
                "2f8bde4d1a07209355b4a7250a5c5128e88b84bddc619ab7cba8d569b240efe4",
                "d8ac222636e5e3d6d4dba9dda6c9c426f788271bab0d6840dca87d3aa6ac62d6",
            ),
            (
                10,
                "a0434d9e47f3c86235477c7b1ae6ae5d3442d49b1943c2b752a68e2a47e247c7",
                "893aba425419bc27a3b6c7e693a24c696f794c2ed877a1593cbee53b037368d7",
            ),
        ];

        for (k, ex, ey) in expected {
            let result = g.mul(&BigNumber::from_number(k));
            assert_eq!(result.x.to_hex(), ex, "x mismatch for k={}", k);
            assert_eq!(result.y.to_hex(), ey, "y mismatch for k={}", k);
        }
    }

    #[test]
    fn test_point_dbl() {
        let g = g();
        let dbl = g.dbl();
        let add = g.add(&g);
        assert!(dbl.eq(&add));
    }

    #[test]
    fn test_point_from_x() {
        let curve = Curve::secp256k1();
        // Recover G from its x coordinate
        let p = Point::from_x(&curve.g_x, false).unwrap();
        assert_eq!(p.x.cmp(&curve.g_x), 0);
        assert_eq!(p.y.cmp(&curve.g_y), 0);
    }

    #[test]
    fn test_point_from_x_odd() {
        let curve = Curve::secp256k1();
        // G.y is even, so asking for odd should give p - G.y
        let p = Point::from_x(&curve.g_x, true).unwrap();
        let neg_y = curve.p.sub(&curve.g_y);
        assert_eq!(p.y.cmp(&neg_y), 0);
    }
}