// curv/elliptic/curves/wrappers/arithmetic.rs

1use std::ops;
2
3use crate::elliptic::curves::traits::*;
4
5use super::*;
6
// Generates `std::ops` binary-operator impls for every `(lhs, rhs)` pair listed in `pairs`.
//
// Parameters:
// - `trait` / `trait_fn`: the operator trait in `std::ops` and its method (e.g. `Add` / `add`).
// - `output`: the `Output` associated type of every generated impl.
// - `output_new`: a function/path applied to the raw result to build a value of `output`.
// - `point_fn`: raw method used when neither operand is consumed, called as `a.point_fn(b)`.
// - `point_assign_fn`: raw in-place method used when one operand is consumed, called as
//   `a.point_assign_fn(b)`.
// - `pairs`: a list of `(tag<lifetimes> Lhs, Rhs)` entries, each followed by a comma.
//
// The tag of each pair selects how the result is computed:
// - `r_`: `self.as_raw().point_fn(rhs.as_raw())`, so neither operand is consumed.
// - `_r`: `rhs.as_raw().point_fn(self.as_raw())`, so the operands are SWAPPED.
// - `o_`: `self.into_raw()` is updated in place with `point_assign_fn(rhs.as_raw())`.
// - `_o`: `rhs.into_raw()` is updated in place with `point_assign_fn(self.as_raw())`, so the
//   operands are SWAPPED.
//
// Because `_r` and `_o` evaluate `rhs ∘ self`, use them only when that equals `self ∘ rhs`, or
// when the raw method is defined on the right-hand operand's type (as for `Scalar * Point`).
// Non-commutative operators such as subtraction must use `r_` / `o_` exclusively.
//
// Each arm emits one impl for the first pair and recurses on the rest. The `pairs = {}` arm ends
// the recursion. The `<$l, ...>` list adds lifetime parameters to the generated impl; the
// invocations visible in this file all pass `<>`.
macro_rules! matrix {
    // `r_`: lhs is only borrowed; compute `self.point_fn(rhs)` into a fresh raw value.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {(r_<$($l:lifetime),*> $lhs_ref:ty, $rhs:ty), $($rest:tt)*}
    ) => {
        impl<$($l,)* E: Curve> ops::$trait<$rhs> for $lhs_ref {
            type Output = $output;
            fn $trait_fn(self, rhs: $rhs) -> Self::Output {
                let p = self.as_raw().$point_fn(rhs.as_raw());
                $output_new(p)
            }
        }
        matrix!{
            trait = $trait,
            trait_fn = $trait_fn,
            output = $output,
            output_new = $output_new,
            point_fn = $point_fn,
            point_assign_fn = $point_assign_fn,
            pairs = {$($rest)*}
        }
    };

    // `_r`: rhs is only borrowed; NOTE the operands are swapped: `rhs.point_fn(self)`.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {(_r<$($l:lifetime),*> $lhs:ty, $rhs_ref:ty), $($rest:tt)*}
    ) => {
        impl<$($l,)* E: Curve> ops::$trait<$rhs_ref> for $lhs {
            type Output = $output;
            fn $trait_fn(self, rhs: $rhs_ref) -> Self::Output {
                let p = rhs.as_raw().$point_fn(self.as_raw());
                $output_new(p)
            }
        }
        matrix!{
            trait = $trait,
            trait_fn = $trait_fn,
            output = $output,
            output_new = $output_new,
            point_fn = $point_fn,
            point_assign_fn = $point_assign_fn,
            pairs = {$($rest)*}
        }
    };

    // `o_`: lhs is owned; reuse its raw value and update it in place with rhs.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {(o_<$($l:lifetime),*> $lhs_owned:ty, $rhs:ty), $($rest:tt)*}
    ) => {
        impl<$($l,)* E: Curve> ops::$trait<$rhs> for $lhs_owned {
            type Output = $output;
            fn $trait_fn(self, rhs: $rhs) -> Self::Output {
                let mut raw = self.into_raw();
                raw.$point_assign_fn(rhs.as_raw());
                $output_new(raw)
            }
        }
        matrix!{
            trait = $trait,
            trait_fn = $trait_fn,
            output = $output,
            output_new = $output_new,
            point_fn = $point_fn,
            point_assign_fn = $point_assign_fn,
            pairs = {$($rest)*}
        }
    };

    // `_o`: rhs is owned; reuse its raw value and update it in place with lhs.
    // NOTE the operands are swapped: the result is `rhs ∘ self`.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {(_o<$($l:lifetime),*> $lhs:ty, $rhs_owned:ty), $($rest:tt)*}
    ) => {
        impl<$($l,)* E: Curve> ops::$trait<$rhs_owned> for $lhs {
            type Output = $output;
            fn $trait_fn(self, rhs: $rhs_owned) -> Self::Output {
                let mut raw = rhs.into_raw();
                raw.$point_assign_fn(self.as_raw());
                $output_new(raw)
            }
        }
        matrix!{
            trait = $trait,
            trait_fn = $trait_fn,
            output = $output,
            output_new = $output_new,
            point_fn = $point_fn,
            point_assign_fn = $point_assign_fn,
            pairs = {$($rest)*}
        }
    };

    // No pairs left: stop recursing.
    (
        trait = $trait:ident,
        trait_fn = $trait_fn:ident,
        output = $output:ty,
        output_new = $output_new:expr,
        point_fn = $point_fn:ident,
        point_assign_fn = $point_assign_fn:ident,
        pairs = {}
    ) => {
        // happy termination
    };
}
130
131fn addition_of_two_points<E: Curve>(result: E::Point) -> Point<E> {
132    // Safety: addition of two points of group order is always either a zero point or point of group
133    // order: `A + B = aG + bG = (a + b)G`
134    unsafe { Point::from_raw_unchecked(result) }
135}
136
137matrix! {
138    trait = Add,
139    trait_fn = add,
140    output = Point<E>,
141    output_new = addition_of_two_points,
142    point_fn = add_point,
143    point_assign_fn = add_point_assign,
144    pairs = {
145        (o_<> Point<E>, Point<E>), (o_<> Point<E>, &Point<E>),
146        (o_<> Point<E>, Generator<E>),
147
148        (_o<> &Point<E>, Point<E>), (r_<> &Point<E>, &Point<E>),
149        (r_<> &Point<E>, Generator<E>),
150
151        (_o<> Generator<E>, Point<E>), (r_<> Generator<E>, &Point<E>),
152        (r_<> Generator<E>, Generator<E>),
153    }
154}
155
156fn subtraction_of_two_point<E: Curve>(result: E::Point) -> Point<E> {
157    // Safety: subtraction of two points of group order is always either a zero point or point of group
158    // order: `A - B = aG - bG = (a - b)G`
159    unsafe { Point::from_raw_unchecked(result) }
160}
161
162matrix! {
163    trait = Sub,
164    trait_fn = sub,
165    output = Point<E>,
166    output_new = subtraction_of_two_point,
167    point_fn = sub_point,
168    point_assign_fn = sub_point_assign,
169    pairs = {
170        (o_<> Point<E>, Point<E>), (o_<> Point<E>, &Point<E>),
171        (o_<> Point<E>, Generator<E>),
172
173        (r_<> &Point<E>, Point<E>), (r_<> &Point<E>, &Point<E>),
174        (r_<> &Point<E>, Generator<E>),
175
176        (r_<> Generator<E>, Point<E>), (r_<> Generator<E>, &Point<E>),
177        (r_<> Generator<E>, Generator<E>),
178    }
179}
180
181fn multiplication_of_point_at_scalar<E: Curve>(result: E::Point) -> Point<E> {
182    // Safety: multiplication of point of group order at a scalar is always either a zero point or
183    // point of group order: `kA = kaG`
184    unsafe { Point::from_raw_unchecked(result) }
185}
186
187matrix! {
188    trait = Mul,
189    trait_fn = mul,
190    output = Point<E>,
191    output_new = multiplication_of_point_at_scalar,
192    point_fn = scalar_mul,
193    point_assign_fn = scalar_mul_assign,
194    pairs = {
195        (o_<> Point<E>, Scalar<E>), (o_<> Point<E>, &Scalar<E>),
196        (r_<> &Point<E>, Scalar<E>), (r_<> &Point<E>, &Scalar<E>),
197
198        (_o<> Scalar<E>, Point<E>), (_o<> &Scalar<E>, Point<E>),
199        (_r<> Scalar<E>, &Point<E>), (_r<> &Scalar<E>, &Point<E>),
200    }
201}
202
203matrix! {
204    trait = Add,
205    trait_fn = add,
206    output = Scalar<E>,
207    output_new = Scalar::from_raw,
208    point_fn = add,
209    point_assign_fn = add_assign,
210    pairs = {
211        (o_<> Scalar<E>, Scalar<E>), (o_<> Scalar<E>, &Scalar<E>),
212        (_o<> &Scalar<E>, Scalar<E>), (r_<> &Scalar<E>, &Scalar<E>),
213    }
214}
215
216matrix! {
217    trait = Sub,
218    trait_fn = sub,
219    output = Scalar<E>,
220    output_new = Scalar::from_raw,
221    point_fn = sub,
222    point_assign_fn = sub_assign,
223    pairs = {
224        (o_<> Scalar<E>, Scalar<E>), (o_<> Scalar<E>, &Scalar<E>),
225        (r_<> &Scalar<E>, Scalar<E>), (r_<> &Scalar<E>, &Scalar<E>),
226    }
227}
228
229matrix! {
230    trait = Mul,
231    trait_fn = mul,
232    output = Scalar<E>,
233    output_new = Scalar::from_raw,
234    point_fn = mul,
235    point_assign_fn = mul_assign,
236    pairs = {
237        (o_<> Scalar<E>, Scalar<E>), (o_<> Scalar<E>, &Scalar<E>),
238        (_o<> &Scalar<E>, Scalar<E>), (r_<> &Scalar<E>, &Scalar<E>),
239    }
240}
241
242impl<E: Curve> ops::Mul<&Scalar<E>> for Generator<E> {
243    type Output = Point<E>;
244    fn mul(self, rhs: &Scalar<E>) -> Self::Output {
245        Point::from_raw(E::Point::generator_mul(rhs.as_raw())).expect(
246            "generator multiplied by scalar is always a point of group order or a zero point",
247        )
248    }
249}
250
251impl<E: Curve> ops::Mul<Scalar<E>> for Generator<E> {
252    type Output = Point<E>;
253    fn mul(self, rhs: Scalar<E>) -> Self::Output {
254        self.mul(&rhs)
255    }
256}
257
258impl<E: Curve> ops::Mul<Generator<E>> for &Scalar<E> {
259    type Output = Point<E>;
260    fn mul(self, rhs: Generator<E>) -> Self::Output {
261        rhs.mul(self)
262    }
263}
264
265impl<E: Curve> ops::Mul<Generator<E>> for Scalar<E> {
266    type Output = Point<E>;
267    fn mul(self, rhs: Generator<E>) -> Self::Output {
268        rhs.mul(self)
269    }
270}
271
272impl<E: Curve> ops::Neg for Scalar<E> {
273    type Output = Scalar<E>;
274
275    fn neg(self) -> Self::Output {
276        Scalar::from_raw(self.as_raw().neg())
277    }
278}
279
280impl<E: Curve> ops::Neg for &Scalar<E> {
281    type Output = Scalar<E>;
282
283    fn neg(self) -> Self::Output {
284        Scalar::from_raw(self.as_raw().neg())
285    }
286}
287
288impl<E: Curve> ops::Neg for Point<E> {
289    type Output = Point<E>;
290
291    fn neg(self) -> Self::Output {
292        Point::from_raw(self.as_raw().neg_point())
293            .expect("neg must not produce point of different order")
294    }
295}
296
297impl<E: Curve> ops::Neg for &Point<E> {
298    type Output = Point<E>;
299
300    fn neg(self) -> Self::Output {
301        Point::from_raw(self.as_raw().neg_point())
302            .expect("neg must not produce point of different order")
303    }
304}
305
306impl<E: Curve> ops::Neg for Generator<E> {
307    type Output = Point<E>;
308
309    fn neg(self) -> Self::Output {
310        Point::from_raw(self.as_raw().neg_point())
311            .expect("neg must not produce point of different order")
312    }
313}
314
#[cfg(test)]
mod test {
    use super::*;

    /// Expands to one `$assert_fn::<E, Lhs, Rhs>()` call for every `Lhs` in `lhs` and every `Rhs`
    /// in `rhs`. The full cartesian product of operand types is thus checked at compile time.
    macro_rules! assert_operator_defined_for {
        (
            assert_fn = $assert_fn:ident,
            lhs = {},
            rhs = {$($rhs:ty),*},
        ) => {
            // Corner case
        };
        (
            assert_fn = $assert_fn:ident,
            lhs = {$lhs:ty $(, $lhs_tail:ty)*},
            rhs = {$($rhs:ty),*},
        ) => {
            assert_operator_defined_for! {
                assert_fn = $assert_fn,
                lhs = $lhs,
                rhs = {$($rhs),*},
            }
            assert_operator_defined_for! {
                assert_fn = $assert_fn,
                lhs = {$($lhs_tail),*},
                rhs = {$($rhs),*},
            }
        };
        (
            assert_fn = $assert_fn:ident,
            lhs = $lhs:ty,
            rhs = {$($rhs:ty),*},
        ) => {
            $($assert_fn::<E, $lhs, $rhs>());*
        };
    }

    /// Function asserts that P2 can be added to P1 (ie. P1 + P2) and result is Point.
    /// If any condition doesn't meet, function won't compile.
    // Only referenced from the never-called generic `_curve` helpers below.
    #[allow(dead_code)]
    fn assert_point_addition_defined<E, P1, P2>()
    where
        P1: ops::Add<P2, Output = Point<E>>,
        E: Curve,
    {
        // no-op
    }

    #[test]
    fn test_point_addition_defined() {
        fn _curve<E: Curve>() {
            assert_operator_defined_for! {
                assert_fn = assert_point_addition_defined,
                lhs = {Point<E>, &Point<E>, Generator<E>},
                rhs = {Point<E>, &Point<E>, Generator<E>},
            }
        }
    }

    /// Function asserts that P2 can be subtracted from P1 (ie. P1 - P2) and result is Point.
    /// If any condition doesn't meet, function won't compile.
    // Only referenced from the never-called generic `_curve` helpers below.
    #[allow(dead_code)]
    fn assert_point_subtraction_defined<E, P1, P2>()
    where
        P1: ops::Sub<P2, Output = Point<E>>,
        E: Curve,
    {
        // no-op
    }

    #[test]
    fn test_point_subtraction_defined() {
        fn _curve<E: Curve>() {
            assert_operator_defined_for! {
                assert_fn = assert_point_subtraction_defined,
                lhs = {Point<E>, &Point<E>, Generator<E>},
                rhs = {Point<E>, &Point<E>, Generator<E>},
            }
        }
    }

    /// Function asserts that M can be multiplied by N (ie. M * N) and result is Point.
    /// If any condition doesn't meet, function won't compile.
    // Only referenced from the never-called generic `_curve` helpers below.
    #[allow(dead_code)]
    fn assert_point_multiplication_defined<E, M, N>()
    where
        M: ops::Mul<N, Output = Point<E>>,
        E: Curve,
    {
        // no-op
    }

    #[test]
    fn test_point_multiplication_defined() {
        fn _curve<E: Curve>() {
            assert_operator_defined_for! {
                assert_fn = assert_point_multiplication_defined,
                lhs = {Point<E>, &Point<E>, Generator<E>},
                rhs = {Scalar<E>, &Scalar<E>},
            }

            // and vice-versa

            assert_operator_defined_for! {
                assert_fn = assert_point_multiplication_defined,
                lhs = {Scalar<E>, &Scalar<E>},
                rhs = {Point<E>, &Point<E>, Generator<E>},
            }
        }
    }

    /// Function asserts that S2 can be added to S1 (ie. S1 + S2) and result is Scalar.
    /// If any condition doesn't meet, function won't compile.
    // Only referenced from the never-called generic `_curve` helpers below.
    #[allow(dead_code)]
    fn assert_scalars_addition_defined<E, S1, S2>()
    where
        S1: ops::Add<S2, Output = Scalar<E>>,
        E: Curve,
    {
        // no-op
    }

    #[test]
    fn test_scalars_addition_defined() {
        fn _curve<E: Curve>() {
            // Cover owned and borrowed operands on both sides.
            assert_operator_defined_for! {
                assert_fn = assert_scalars_addition_defined,
                lhs = {Scalar<E>, &Scalar<E>},
                rhs = {Scalar<E>, &Scalar<E>},
            }
        }
    }

    /// Function asserts that S2 can be subtracted from S1 (ie. S1 - S2) and result is Scalar.
    /// If any condition doesn't meet, function won't compile.
    // Only referenced from the never-called generic `_curve` helpers below.
    #[allow(dead_code)]
    fn assert_scalars_subtraction_defined<E, S1, S2>()
    where
        S1: ops::Sub<S2, Output = Scalar<E>>,
        E: Curve,
    {
        // no-op
    }

    #[test]
    fn test_scalars_subtraction_defined() {
        fn _curve<E: Curve>() {
            // Cover owned and borrowed operands on both sides.
            assert_operator_defined_for! {
                assert_fn = assert_scalars_subtraction_defined,
                lhs = {Scalar<E>, &Scalar<E>},
                rhs = {Scalar<E>, &Scalar<E>},
            }
        }
    }

    /// Function asserts that S1 can be multiplied by S2 (ie. S1 * S2) and result is Scalar.
    /// If any condition doesn't meet, function won't compile.
    // Only referenced from the never-called generic `_curve` helpers below.
    #[allow(dead_code)]
    fn assert_scalars_multiplication_defined<E, S1, S2>()
    where
        S1: ops::Mul<S2, Output = Scalar<E>>,
        E: Curve,
    {
        // no-op
    }

    #[test]
    fn test_scalars_multiplication_defined() {
        fn _curve<E: Curve>() {
            // Cover owned and borrowed operands on both sides.
            assert_operator_defined_for! {
                assert_fn = assert_scalars_multiplication_defined,
                lhs = {Scalar<E>, &Scalar<E>},
                rhs = {Scalar<E>, &Scalar<E>},
            }
        }
    }

    /// Function asserts that T can be negated (ie. -T) and result is O.
    /// If any condition doesn't meet, function won't compile.
    // Only referenced from the never-called generic `_curve` helper below.
    #[allow(dead_code)]
    fn assert_negation_defined<T, O>()
    where
        T: ops::Neg<Output = O>,
    {
        // no-op
    }

    #[test]
    fn test_negation_defined() {
        fn _curve<E: Curve>() {
            assert_negation_defined::<Scalar<E>, Scalar<E>>();
            assert_negation_defined::<&Scalar<E>, Scalar<E>>();
            assert_negation_defined::<Point<E>, Point<E>>();
            assert_negation_defined::<&Point<E>, Point<E>>();
            assert_negation_defined::<Generator<E>, Point<E>>();
        }
    }
}