1use std::ops;
2
3use crate::elliptic::curves::traits::*;
4
5use super::*;
6
/// Generates `impl ops::$trait<Rhs> for Lhs` for every operand pair listed in `pairs`.
///
/// Parameters:
/// * `trait` / `trait_fn` — the `std::ops` trait to implement and its method name
///   (e.g. `Add` / `add`).
/// * `output` — the `Output` associated type of every generated impl.
/// * `output_new` — a callable applied to the raw result to build `output`
///   (e.g. `Scalar::from_raw`); it is invoked as `$output_new(raw)`.
/// * `point_fn` — by-reference raw operation, called as `a.$point_fn(b)` on `as_raw()` views.
///   Despite the name, it is also used for scalar operations.
/// * `point_assign_fn` — in-place raw operation, called as `a.$point_assign_fn(b)` on a value
///   obtained with `into_raw()`.
/// * `pairs` — a brace-delimited list of `(TAG<lifetimes> Lhs, Rhs),` entries. Every entry,
///   including the last one, must be followed by a comma. `TAG` selects the strategy:
///   * `r_` — borrow both: `self.as_raw().$point_fn(rhs.as_raw())`.
///   * `_r` — borrow both, swapped receiver: `rhs.as_raw().$point_fn(self.as_raw())`.
///   * `o_` — consume the LHS: `self.into_raw()`, then `.$point_assign_fn(rhs.as_raw())`.
///   * `_o` — consume the RHS: `rhs.into_raw()`, then `.$point_assign_fn(self.as_raw())`.
///
///   The swapped forms (`_r`, `_o`) evaluate `rhs OP lhs`. They are only correct when the
///   operation commutes, or when the raw method lives on the RHS type (e.g. `Scalar * Point`
///   dispatched to the point's `scalar_mul`).
///
/// `<lifetimes>` are spliced into the impl's generic parameters ahead of `E: Curve`. Pass `<>`
/// when none are needed.
///
/// Each arm handles the first pair and recurses on the rest. The last arm ends the recursion on
/// an empty list.
macro_rules! matrix {
    // `r_`: both operands borrowed; `self` is the receiver of `$point_fn`.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {(r_<$($l:lifetime),*> $lhs_ref:ty, $rhs:ty), $($rest:tt)*}
    ) => {
        impl<$($l,)* E: Curve> ops::$trait<$rhs> for $lhs_ref {
            type Output = $output;
            fn $trait_fn(self, rhs: $rhs) -> Self::Output {
                let p = self.as_raw().$point_fn(rhs.as_raw());
                $output_new(p)
            }
        }
        matrix!{
            trait = $trait,
            trait_fn = $trait_fn,
            output = $output,
            output_new = $output_new,
            point_fn = $point_fn,
            point_assign_fn = $point_assign_fn,
            pairs = {$($rest)*}
        }
    };

    // `_r`: both operands borrowed, but `rhs` is the receiver, so this computes `rhs OP self`.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {(_r<$($l:lifetime),*> $lhs:ty, $rhs_ref:ty), $($rest:tt)*}
    ) => {
        impl<$($l,)* E: Curve> ops::$trait<$rhs_ref> for $lhs {
            type Output = $output;
            fn $trait_fn(self, rhs: $rhs_ref) -> Self::Output {
                let p = rhs.as_raw().$point_fn(self.as_raw());
                $output_new(p)
            }
        }
        matrix!{
            trait = $trait,
            trait_fn = $trait_fn,
            output = $output,
            output_new = $output_new,
            point_fn = $point_fn,
            point_assign_fn = $point_assign_fn,
            pairs = {$($rest)*}
        }
    };

    // `o_`: LHS is owned; its raw value is taken with `into_raw()` and updated in place.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {(o_<$($l:lifetime),*> $lhs_owned:ty, $rhs:ty), $($rest:tt)*}
    ) => {
        impl<$($l,)* E: Curve> ops::$trait<$rhs> for $lhs_owned {
            type Output = $output;
            fn $trait_fn(self, rhs: $rhs) -> Self::Output {
                let mut raw = self.into_raw();
                raw.$point_assign_fn(rhs.as_raw());
                $output_new(raw)
            }
        }
        matrix!{
            trait = $trait,
            trait_fn = $trait_fn,
            output = $output,
            output_new = $output_new,
            point_fn = $point_fn,
            point_assign_fn = $point_assign_fn,
            pairs = {$($rest)*}
        }
    };

    // `_o`: RHS is owned and updated in place with the LHS as argument, i.e. `rhs OP= self`.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {(_o<$($l:lifetime),*> $lhs:ty, $rhs_owned:ty), $($rest:tt)*}
    ) => {
        impl<$($l,)* E: Curve> ops::$trait<$rhs_owned> for $lhs {
            type Output = $output;
            fn $trait_fn(self, rhs: $rhs_owned) -> Self::Output {
                let mut raw = rhs.into_raw();
                raw.$point_assign_fn(self.as_raw());
                $output_new(raw)
            }
        }
        matrix!{
            trait = $trait,
            trait_fn = $trait_fn,
            output = $output,
            output_new = $output_new,
            point_fn = $point_fn,
            point_assign_fn = $point_assign_fn,
            pairs = {$($rest)*}
        }
    };

    // Recursion terminator: no pairs left, nothing to emit.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {}
    ) => {
    };
}
130
131fn addition_of_two_points<E: Curve>(result: E::Point) -> Point<E> {
132 unsafe { Point::from_raw_unchecked(result) }
135}
136
137matrix! {
138 trait = Add,
139 trait_fn = add,
140 output = Point<E>,
141 output_new = addition_of_two_points,
142 point_fn = add_point,
143 point_assign_fn = add_point_assign,
144 pairs = {
145 (o_<> Point<E>, Point<E>), (o_<> Point<E>, &Point<E>),
146 (o_<> Point<E>, Generator<E>),
147
148 (_o<> &Point<E>, Point<E>), (r_<> &Point<E>, &Point<E>),
149 (r_<> &Point<E>, Generator<E>),
150
151 (_o<> Generator<E>, Point<E>), (r_<> Generator<E>, &Point<E>),
152 (r_<> Generator<E>, Generator<E>),
153 }
154}
155
156fn subtraction_of_two_point<E: Curve>(result: E::Point) -> Point<E> {
157 unsafe { Point::from_raw_unchecked(result) }
160}
161
162matrix! {
163 trait = Sub,
164 trait_fn = sub,
165 output = Point<E>,
166 output_new = subtraction_of_two_point,
167 point_fn = sub_point,
168 point_assign_fn = sub_point_assign,
169 pairs = {
170 (o_<> Point<E>, Point<E>), (o_<> Point<E>, &Point<E>),
171 (o_<> Point<E>, Generator<E>),
172
173 (r_<> &Point<E>, Point<E>), (r_<> &Point<E>, &Point<E>),
174 (r_<> &Point<E>, Generator<E>),
175
176 (r_<> Generator<E>, Point<E>), (r_<> Generator<E>, &Point<E>),
177 (r_<> Generator<E>, Generator<E>),
178 }
179}
180
181fn multiplication_of_point_at_scalar<E: Curve>(result: E::Point) -> Point<E> {
182 unsafe { Point::from_raw_unchecked(result) }
185}
186
187matrix! {
188 trait = Mul,
189 trait_fn = mul,
190 output = Point<E>,
191 output_new = multiplication_of_point_at_scalar,
192 point_fn = scalar_mul,
193 point_assign_fn = scalar_mul_assign,
194 pairs = {
195 (o_<> Point<E>, Scalar<E>), (o_<> Point<E>, &Scalar<E>),
196 (r_<> &Point<E>, Scalar<E>), (r_<> &Point<E>, &Scalar<E>),
197
198 (_o<> Scalar<E>, Point<E>), (_o<> &Scalar<E>, Point<E>),
199 (_r<> Scalar<E>, &Point<E>), (_r<> &Scalar<E>, &Point<E>),
200 }
201}
202
203matrix! {
204 trait = Add,
205 trait_fn = add,
206 output = Scalar<E>,
207 output_new = Scalar::from_raw,
208 point_fn = add,
209 point_assign_fn = add_assign,
210 pairs = {
211 (o_<> Scalar<E>, Scalar<E>), (o_<> Scalar<E>, &Scalar<E>),
212 (_o<> &Scalar<E>, Scalar<E>), (r_<> &Scalar<E>, &Scalar<E>),
213 }
214}
215
216matrix! {
217 trait = Sub,
218 trait_fn = sub,
219 output = Scalar<E>,
220 output_new = Scalar::from_raw,
221 point_fn = sub,
222 point_assign_fn = sub_assign,
223 pairs = {
224 (o_<> Scalar<E>, Scalar<E>), (o_<> Scalar<E>, &Scalar<E>),
225 (r_<> &Scalar<E>, Scalar<E>), (r_<> &Scalar<E>, &Scalar<E>),
226 }
227}
228
229matrix! {
230 trait = Mul,
231 trait_fn = mul,
232 output = Scalar<E>,
233 output_new = Scalar::from_raw,
234 point_fn = mul,
235 point_assign_fn = mul_assign,
236 pairs = {
237 (o_<> Scalar<E>, Scalar<E>), (o_<> Scalar<E>, &Scalar<E>),
238 (_o<> &Scalar<E>, Scalar<E>), (r_<> &Scalar<E>, &Scalar<E>),
239 }
240}
241
242impl<E: Curve> ops::Mul<&Scalar<E>> for Generator<E> {
243 type Output = Point<E>;
244 fn mul(self, rhs: &Scalar<E>) -> Self::Output {
245 Point::from_raw(E::Point::generator_mul(rhs.as_raw())).expect(
246 "generator multiplied by scalar is always a point of group order or a zero point",
247 )
248 }
249}
250
251impl<E: Curve> ops::Mul<Scalar<E>> for Generator<E> {
252 type Output = Point<E>;
253 fn mul(self, rhs: Scalar<E>) -> Self::Output {
254 self.mul(&rhs)
255 }
256}
257
258impl<E: Curve> ops::Mul<Generator<E>> for &Scalar<E> {
259 type Output = Point<E>;
260 fn mul(self, rhs: Generator<E>) -> Self::Output {
261 rhs.mul(self)
262 }
263}
264
265impl<E: Curve> ops::Mul<Generator<E>> for Scalar<E> {
266 type Output = Point<E>;
267 fn mul(self, rhs: Generator<E>) -> Self::Output {
268 rhs.mul(self)
269 }
270}
271
272impl<E: Curve> ops::Neg for Scalar<E> {
273 type Output = Scalar<E>;
274
275 fn neg(self) -> Self::Output {
276 Scalar::from_raw(self.as_raw().neg())
277 }
278}
279
280impl<E: Curve> ops::Neg for &Scalar<E> {
281 type Output = Scalar<E>;
282
283 fn neg(self) -> Self::Output {
284 Scalar::from_raw(self.as_raw().neg())
285 }
286}
287
288impl<E: Curve> ops::Neg for Point<E> {
289 type Output = Point<E>;
290
291 fn neg(self) -> Self::Output {
292 Point::from_raw(self.as_raw().neg_point())
293 .expect("neg must not produce point of different order")
294 }
295}
296
297impl<E: Curve> ops::Neg for &Point<E> {
298 type Output = Point<E>;
299
300 fn neg(self) -> Self::Output {
301 Point::from_raw(self.as_raw().neg_point())
302 .expect("neg must not produce point of different order")
303 }
304}
305
306impl<E: Curve> ops::Neg for Generator<E> {
307 type Output = Point<E>;
308
309 fn neg(self) -> Self::Output {
310 Point::from_raw(self.as_raw().neg_point())
311 .expect("neg must not produce point of different order")
312 }
313}
314
#[cfg(test)]
mod test {
    use super::*;

    /// Emits `$assert_fn::<E, Lhs, Rhs>()` for every `Lhs` in `lhs` and every `Rhs` in `rhs`.
    ///
    /// `E` is not a macro parameter. It is resolved at the invocation site, so the macro must be
    /// used inside an item that has a generic parameter named `E`. Each emitted call only
    /// type-checks when the operator bound on `$assert_fn` holds, so a missing impl is a compile
    /// error rather than a runtime failure.
    macro_rules! assert_operator_defined_for {
        // Base case: no LHS types left.
        (
            assert_fn = $assert_fn:ident,
            lhs = {},
            rhs = {$($rhs:ty),*},
        ) => {
        };
        // Peel the first LHS type off the list and recurse on the tail.
        (
            assert_fn = $assert_fn:ident,
            lhs = {$lhs:ty $(, $lhs_tail:ty)*},
            rhs = {$($rhs:ty),*},
        ) => {
            assert_operator_defined_for! {
                assert_fn = $assert_fn,
                lhs = $lhs,
                rhs = {$($rhs),*},
            }
            assert_operator_defined_for! {
                assert_fn = $assert_fn,
                lhs = {$($lhs_tail),*},
                rhs = {$($rhs),*},
            }
        };
        // Single LHS type: pair it with every RHS type.
        (
            assert_fn = $assert_fn:ident,
            lhs = $lhs:ty,
            rhs = {$($rhs:ty),*},
        ) => {
            $($assert_fn::<E, $lhs, $rhs>());*
        };
    }

    // The `assert_*_defined` helpers below are only referenced from the `_curve` functions,
    // which are never called. Their bounds are checked at compile time, and `dead_code` is
    // allowed because nothing at runtime reaches them.

    /// Compiles only if `P1 + P2` is defined and yields `Point<E>`.
    #[allow(dead_code)]
    fn assert_point_addition_defined<E, P1, P2>()
    where
        P1: ops::Add<P2, Output = Point<E>>,
        E: Curve,
    {
    }

    #[test]
    fn test_point_addition_defined() {
        fn _curve<E: Curve>() {
            assert_operator_defined_for! {
                assert_fn = assert_point_addition_defined,
                lhs = {Point<E>, &Point<E>, Generator<E>},
                rhs = {Point<E>, &Point<E>, Generator<E>},
            }
        }
    }

    /// Compiles only if `P1 - P2` is defined and yields `Point<E>`.
    #[allow(dead_code)]
    fn assert_point_subtraction_defined<E, P1, P2>()
    where
        P1: ops::Sub<P2, Output = Point<E>>,
        E: Curve,
    {
    }

    #[test]
    fn test_point_subtraction_defined() {
        fn _curve<E: Curve>() {
            assert_operator_defined_for! {
                assert_fn = assert_point_subtraction_defined,
                lhs = {Point<E>, &Point<E>, Generator<E>},
                rhs = {Point<E>, &Point<E>, Generator<E>},
            }
        }
    }

    /// Compiles only if `M * N` is defined and yields `Point<E>`.
    #[allow(dead_code)]
    fn assert_point_multiplication_defined<E, M, N>()
    where
        M: ops::Mul<N, Output = Point<E>>,
        E: Curve,
    {
    }

    #[test]
    fn test_point_multiplication_defined() {
        fn _curve<E: Curve>() {
            assert_operator_defined_for! {
                assert_fn = assert_point_multiplication_defined,
                lhs = {Point<E>, &Point<E>, Generator<E>},
                rhs = {Scalar<E>, &Scalar<E>},
            }

            assert_operator_defined_for! {
                assert_fn = assert_point_multiplication_defined,
                lhs = {Scalar<E>, &Scalar<E>},
                rhs = {Point<E>, &Point<E>, Generator<E>},
            }
        }
    }

    /// Compiles only if `S1 + S2` is defined and yields `Scalar<E>`.
    #[allow(dead_code)]
    fn assert_scalars_addition_defined<E, S1, S2>()
    where
        S1: ops::Add<S2, Output = Scalar<E>>,
        E: Curve,
    {
    }

    #[test]
    fn test_scalars_addition_defined() {
        fn _curve<E: Curve>() {
            // Owned and borrowed operands on both sides (the list previously repeated
            // `Scalar<E>`, leaving the `&Scalar<E>` impls unchecked).
            assert_operator_defined_for! {
                assert_fn = assert_scalars_addition_defined,
                lhs = {Scalar<E>, &Scalar<E>},
                rhs = {Scalar<E>, &Scalar<E>},
            }
        }
    }

    /// Compiles only if `S1 - S2` is defined and yields `Scalar<E>`.
    #[allow(dead_code)]
    fn assert_scalars_subtraction_defined<E, S1, S2>()
    where
        S1: ops::Sub<S2, Output = Scalar<E>>,
        E: Curve,
    {
    }

    #[test]
    fn test_scalars_subtraction_defined() {
        fn _curve<E: Curve>() {
            assert_operator_defined_for! {
                assert_fn = assert_scalars_subtraction_defined,
                lhs = {Scalar<E>, &Scalar<E>},
                rhs = {Scalar<E>, &Scalar<E>},
            }
        }
    }

    /// Compiles only if `S1 * S2` is defined and yields `Scalar<E>`.
    #[allow(dead_code)]
    fn assert_scalars_multiplication_defined<E, S1, S2>()
    where
        S1: ops::Mul<S2, Output = Scalar<E>>,
        E: Curve,
    {
    }

    #[test]
    fn test_scalars_multiplication_defined() {
        fn _curve<E: Curve>() {
            assert_operator_defined_for! {
                assert_fn = assert_scalars_multiplication_defined,
                lhs = {Scalar<E>, &Scalar<E>},
                rhs = {Scalar<E>, &Scalar<E>},
            }
        }
    }

    /// Compiles only if `-P` is defined and yields `Point<E>`.
    #[allow(dead_code)]
    fn assert_point_negation_defined<E, P>()
    where
        P: ops::Neg<Output = Point<E>>,
        E: Curve,
    {
    }

    /// Compiles only if `-S` is defined and yields `Scalar<E>`.
    #[allow(dead_code)]
    fn assert_scalar_negation_defined<E, S>()
    where
        S: ops::Neg<Output = Scalar<E>>,
        E: Curve,
    {
    }

    #[test]
    fn test_negation_defined() {
        fn _curve<E: Curve>() {
            assert_point_negation_defined::<E, Point<E>>();
            assert_point_negation_defined::<E, &Point<E>>();
            assert_point_negation_defined::<E, Generator<E>>();
            assert_scalar_negation_defined::<E, Scalar<E>>();
            assert_scalar_negation_defined::<E, &Scalar<E>>();
        }
    }
}
491}