1use crate::primitives::big_number::{BigNumber, Endian};
9use crate::primitives::curve::Curve;
10use crate::primitives::error::PrimitivesError;
11use crate::primitives::jacobian_point::JacobianPoint;
12
/// A point in affine coordinates, or the point at infinity.
///
/// The methods on this type perform their curve checks and arithmetic with the
/// parameters returned by `Curve::secp256k1()`.
#[derive(Clone, Debug)]
pub struct Point {
    /// Affine x-coordinate. Set to zero by `Point::infinity`.
    pub x: BigNumber,
    /// Affine y-coordinate. Set to zero by `Point::infinity`.
    pub y: BigNumber,
    /// `true` when this value stands for the point at infinity, in which case
    /// `x` and `y` carry no meaning.
    pub inf: bool,
}
25
26impl Point {
27 pub fn new(x: BigNumber, y: BigNumber) -> Self {
29 Point { x, y, inf: false }
30 }
31
32 pub fn infinity() -> Self {
34 Point {
35 x: BigNumber::zero(),
36 y: BigNumber::zero(),
37 inf: true,
38 }
39 }
40
41 pub fn is_infinity(&self) -> bool {
43 self.inf
44 }
45
46 pub fn validate(&self) -> bool {
49 if self.inf {
50 return false;
51 }
52
53 let curve = Curve::secp256k1();
54 let red = &curve.red;
55
56 let x_red = self.x.to_red(red.clone());
57 let y_red = self.y.to_red(red.clone());
58
59 let y2 = red.sqr(&y_red);
61
62 let x2 = red.sqr(&x_red);
64 let x3 = red.mul(&x_red, &x2);
65 let seven = BigNumber::from_number(7).to_red(red.clone());
66 let rhs = red.add(&x3, &seven);
67
68 y2.from_red().cmp(&rhs.from_red()) == 0
69 }
70
71 pub fn from_x(x: &BigNumber, odd: bool) -> Result<Self, PrimitivesError> {
74 let curve = Curve::secp256k1();
75 let red = &curve.red;
76
77 let x_red = x.to_red(red.clone());
78
79 let x2 = red.sqr(&x_red);
81 let x3 = red.mul(&x_red, &x2);
82 let seven = BigNumber::from_number(7).to_red(red.clone());
83 let y2 = red.add(&x3, &seven);
84
85 let y_red = red.sqrt(&y2);
88
89 let y_check = red.sqr(&y_red);
91 if y_check.from_red().cmp(&y2.from_red()) != 0 {
92 return Err(PrimitivesError::PointNotOnCurve);
93 }
94
95 let mut y_val = y_red.from_red();
96
97 if y_val.is_odd() != odd {
99 y_val = curve.p.sub(&y_val);
100 }
101
102 let point = Point::new(x.clone(), y_val);
103 if !point.validate() {
104 return Err(PrimitivesError::PointNotOnCurve);
105 }
106 Ok(point)
107 }
108
109 pub fn from_der(bytes: &[u8]) -> Result<Self, PrimitivesError> {
114 if bytes.is_empty() {
115 return Err(PrimitivesError::InvalidDer("empty input".to_string()));
116 }
117
118 let prefix = bytes[0];
119
120 match prefix {
121 0x04 | 0x06 | 0x07 => {
122 if bytes.len() != 65 {
124 return Err(PrimitivesError::InvalidDer(format!(
125 "uncompressed point must be 65 bytes, got {}",
126 bytes.len()
127 )));
128 }
129
130 if prefix == 0x06 {
132 if bytes[64] & 1 != 0 {
133 return Err(PrimitivesError::InvalidDer(
134 "hybrid point parity mismatch (expected even y)".to_string(),
135 ));
136 }
137 } else if prefix == 0x07 && bytes[64] & 1 == 0 {
138 return Err(PrimitivesError::InvalidDer(
139 "hybrid point parity mismatch (expected odd y)".to_string(),
140 ));
141 }
142
143 let x = BigNumber::from_bytes(&bytes[1..33], Endian::Big);
144 let y = BigNumber::from_bytes(&bytes[33..65], Endian::Big);
145
146 let point = Point::new(x, y);
147 if !point.validate() {
148 return Err(PrimitivesError::PointNotOnCurve);
149 }
150 Ok(point)
151 }
152 0x02 | 0x03 => {
153 if bytes.len() != 33 {
155 return Err(PrimitivesError::InvalidDer(format!(
156 "compressed point must be 33 bytes, got {}",
157 bytes.len()
158 )));
159 }
160
161 let x = BigNumber::from_bytes(&bytes[1..33], Endian::Big);
162 let odd = prefix == 0x03;
163 Point::from_x(&x, odd)
164 }
165 _ => Err(PrimitivesError::InvalidDer(format!(
166 "unknown point format prefix: 0x{:02x}",
167 prefix
168 ))),
169 }
170 }
171
172 pub fn from_string(hex: &str) -> Result<Self, PrimitivesError> {
174 let bytes = hex_to_bytes(hex)?;
175 Self::from_der(&bytes)
176 }
177
178 pub fn to_der(&self, compressed: bool) -> Vec<u8> {
183 if self.inf {
184 return vec![0x00];
185 }
186
187 let x_bytes = self.x.to_array(Endian::Big, Some(32));
188
189 if compressed {
190 let prefix = if self.y.is_even() { 0x02 } else { 0x03 };
191 let mut result = Vec::with_capacity(33);
192 result.push(prefix);
193 result.extend_from_slice(&x_bytes);
194 result
195 } else {
196 let y_bytes = self.y.to_array(Endian::Big, Some(32));
197 let mut result = Vec::with_capacity(65);
198 result.push(0x04);
199 result.extend_from_slice(&x_bytes);
200 result.extend_from_slice(&y_bytes);
201 result
202 }
203 }
204
205 pub fn to_hex(&self) -> String {
207 bytes_to_hex(&self.to_der(true))
208 }
209
210 pub fn add(&self, other: &Point) -> Point {
212 if self.inf {
213 return other.clone();
214 }
215 if other.inf {
216 return self.clone();
217 }
218
219 let jp1 = JacobianPoint::from_affine(&self.x, &self.y);
221 let jp2 = JacobianPoint::from_affine(&other.x, &other.y);
222 let result = jp1.add(&jp2);
223
224 if result.is_infinity() {
225 return Point::infinity();
226 }
227
228 let (x, y) = result.to_affine();
229 Point::new(x, y)
230 }
231
232 pub fn mul(&self, k: &BigNumber) -> Point {
234 if k.is_zero() || self.inf {
235 return Point::infinity();
236 }
237
238 let is_neg = k.is_neg();
239 let k_abs = if is_neg { k.neg() } else { k.clone() };
240
241 let curve = Curve::secp256k1();
243 let k_mod = k_abs.umod(&curve.n).unwrap_or(k_abs);
244
245 if k_mod.is_zero() {
246 return Point::infinity();
247 }
248
249 let jp = JacobianPoint::from_affine(&self.x, &self.y);
250 let result = jp.mul_wnaf(&k_mod);
251
252 if result.is_infinity() {
253 return Point::infinity();
254 }
255
256 let (x, y) = result.to_affine();
257 let point = Point::new(x, y);
258
259 if is_neg {
260 point.negate()
261 } else {
262 point
263 }
264 }
265
266 pub fn negate(&self) -> Point {
268 if self.inf {
269 return self.clone();
270 }
271 let curve = Curve::secp256k1();
272 let neg_y = curve.p.sub(&self.y);
273 Point::new(self.x.clone(), neg_y)
274 }
275
276 #[allow(clippy::should_implement_trait)]
278 pub fn eq(&self, other: &Point) -> bool {
279 if self.inf && other.inf {
280 return true;
281 }
282 if self.inf != other.inf {
283 return false;
284 }
285 self.x.cmp(&other.x) == 0 && self.y.cmp(&other.y) == 0
286 }
287
288 pub fn dbl(&self) -> Point {
290 if self.inf {
291 return self.clone();
292 }
293 let jp = JacobianPoint::from_affine(&self.x, &self.y);
294 let result = jp.dbl();
295 if result.is_infinity() {
296 return Point::infinity();
297 }
298 let (x, y) = result.to_affine();
299 Point::new(x, y)
300 }
301
302 pub fn get_x(&self) -> BigNumber {
304 self.x.clone()
305 }
306
307 pub fn get_y(&self) -> BigNumber {
309 self.y.clone()
310 }
311}
312
313fn hex_to_bytes(hex: &str) -> Result<Vec<u8>, PrimitivesError> {
318 if hex.len() & 1 != 0 {
319 return Err(PrimitivesError::InvalidHex(
320 "odd-length hex string".to_string(),
321 ));
322 }
323 let mut bytes = Vec::with_capacity(hex.len() / 2);
324 for i in (0..hex.len()).step_by(2) {
325 let byte = u8::from_str_radix(&hex[i..i + 2], 16)
326 .map_err(|e| PrimitivesError::InvalidHex(e.to_string()))?;
327 bytes.push(byte);
328 }
329 Ok(bytes)
330}
331
/// Encodes bytes as a lower-case hexadecimal string, two digits per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // Append into one preallocated buffer rather than building a temporary
    // `String` for every byte.
    let mut out = String::with_capacity(bytes.len() * 2);
    for &b in bytes {
        out.push(char::from(DIGITS[usize::from(b >> 4)]));
        out.push(char::from(DIGITS[usize::from(b & 0x0f)]));
    }
    out
}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// The secp256k1 generator point used as the base for every test.
    fn generator() -> Point {
        Curve::secp256k1().generator()
    }

    /// Encodes `k * G` for k in 1..=10 and checks that decoding gives the
    /// same point back.
    fn assert_der_round_trip(compressed: bool) {
        let base = generator();
        for k in 1..=10 {
            let multiple = base.mul(&BigNumber::from_number(k));
            if multiple.is_infinity() {
                continue;
            }
            let encoded = multiple.to_der(compressed);
            let decoded = Point::from_der(&encoded).unwrap();
            assert!(decoded.eq(&multiple), "round-trip failed for k={}", k);
        }
    }

    #[test]
    fn test_point_infinity() {
        assert!(Point::infinity().is_infinity());
    }

    #[test]
    fn test_point_g_on_curve() {
        assert!(generator().validate());
    }

    #[test]
    fn test_point_infinity_not_on_curve() {
        assert!(!Point::infinity().validate());
    }

    #[test]
    fn test_point_add_g_plus_g() {
        let base = generator();
        let doubled = base.add(&base);
        assert_eq!(
            doubled.x.to_hex(),
            "c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5"
        );
        assert_eq!(
            doubled.y.to_hex(),
            "1ae168fea63dc339a3c58419466ceaeef7f632653266d0e1236431a950cfe52a"
        );
    }

    #[test]
    fn test_point_add_identity() {
        let base = generator();
        let identity = Point::infinity();

        // The identity must be neutral on either side.
        assert!(base.add(&identity).eq(&base));
        assert!(identity.add(&base).eq(&base));
    }

    #[test]
    fn test_point_mul_1() {
        let base = generator();
        let product = base.mul(&BigNumber::one());
        assert!(product.eq(&base));
    }

    #[test]
    fn test_point_mul_2_equals_add() {
        let base = generator();
        let via_mul = base.mul(&BigNumber::from_number(2));
        let via_add = base.add(&base);
        assert!(via_mul.eq(&via_add));
    }

    #[test]
    fn test_point_mul_n_is_infinity() {
        let curve = Curve::secp256k1();
        let product = generator().mul(&curve.n);
        assert!(product.is_infinity());
    }

    #[test]
    fn test_point_mul_n_minus_1() {
        let base = generator();
        let curve = Curve::secp256k1();
        let product = base.mul(&curve.n.subn(1));

        // (n - 1) * G is -G: same x, y mirrored to p - y.
        assert_eq!(product.x.cmp(&base.x), 0);
        let mirrored_y = curve.p.sub(&base.y);
        assert_eq!(product.y.cmp(&mirrored_y), 0);
    }

    #[test]
    fn test_point_negate() {
        let base = generator();
        let negated = base.negate();
        assert_eq!(negated.x.cmp(&base.x), 0);

        let curve = Curve::secp256k1();
        let mirrored_y = curve.p.sub(&base.y);
        assert_eq!(negated.y.cmp(&mirrored_y), 0);
    }

    #[test]
    fn test_point_negate_add_is_infinity() {
        let base = generator();
        let sum = base.add(&base.negate());
        assert!(sum.is_infinity());
    }

    #[test]
    fn test_point_compressed_even_y() {
        let encoded = generator().to_der(true);
        assert_eq!(encoded.len(), 33);
        assert_eq!(encoded[0], 0x02);
    }

    #[test]
    fn test_point_uncompressed() {
        let encoded = generator().to_der(false);
        assert_eq!(encoded.len(), 65);
        assert_eq!(encoded[0], 0x04);
    }

    #[test]
    fn test_point_from_der_compressed() {
        let base = generator();
        let decoded = Point::from_der(&base.to_der(true)).unwrap();
        assert!(decoded.eq(&base));
    }

    #[test]
    fn test_point_from_der_uncompressed() {
        let base = generator();
        let decoded = Point::from_der(&base.to_der(false)).unwrap();
        assert!(decoded.eq(&base));
    }

    #[test]
    fn test_point_from_der_round_trip_compressed() {
        assert_der_round_trip(true);
    }

    #[test]
    fn test_point_from_der_round_trip_uncompressed() {
        assert_der_round_trip(false);
    }

    #[test]
    fn test_point_invalid_not_on_curve() {
        // Uncompressed prefix followed by x and y made of 0x01 bytes.
        let mut encoded = vec![0x01u8; 65];
        encoded[0] = 0x04;

        let outcome = Point::from_der(&encoded);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_point_from_string() {
        let base = generator();
        let decoded = Point::from_string(&base.to_hex()).unwrap();
        assert!(decoded.eq(&base));
    }

    #[test]
    fn test_point_mul_known_multiples() {
        let base = generator();

        // (scalar, expected x, expected y) for small multiples of G.
        let cases = [
            (
                2,
                "c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5",
                "1ae168fea63dc339a3c58419466ceaeef7f632653266d0e1236431a950cfe52a",
            ),
            (
                3,
                "f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9",
                "388f7b0f632de8140fe337e62a37f3566500a99934c2231b6cb9fd7584b8e672",
            ),
            (
                5,
                "2f8bde4d1a07209355b4a7250a5c5128e88b84bddc619ab7cba8d569b240efe4",
                "d8ac222636e5e3d6d4dba9dda6c9c426f788271bab0d6840dca87d3aa6ac62d6",
            ),
            (
                10,
                "a0434d9e47f3c86235477c7b1ae6ae5d3442d49b1943c2b752a68e2a47e247c7",
                "893aba425419bc27a3b6c7e693a24c696f794c2ed877a1593cbee53b037368d7",
            ),
        ];

        for (k, want_x, want_y) in cases {
            let product = base.mul(&BigNumber::from_number(k));
            assert_eq!(product.x.to_hex(), want_x, "x mismatch for k={}", k);
            assert_eq!(product.y.to_hex(), want_y, "y mismatch for k={}", k);
        }
    }

    #[test]
    fn test_point_dbl() {
        let base = generator();
        let via_dbl = base.dbl();
        let via_add = base.add(&base);
        assert!(via_dbl.eq(&via_add));
    }

    #[test]
    fn test_point_from_x() {
        let curve = Curve::secp256k1();
        let recovered = Point::from_x(&curve.g_x, false).unwrap();
        assert_eq!(recovered.x.cmp(&curve.g_x), 0);
        assert_eq!(recovered.y.cmp(&curve.g_y), 0);
    }

    #[test]
    fn test_point_from_x_odd() {
        let curve = Curve::secp256k1();
        let recovered = Point::from_x(&curve.g_x, true).unwrap();
        let mirrored_y = curve.p.sub(&curve.g_y);
        assert_eq!(recovered.y.cmp(&mirrored_y), 0);
    }
}