macro_galois_field/
lib.rs

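//! Derive macro for prime-field (Galois field GF(p)) arithmetic.
//!
//! `#[derive(Field)]` applies to a tuple struct wrapping a single `u64`. It
//! requires a `#[prime = <p>]` attribute giving the modulus `p`. It generates:
//!
//! * the `Add`, `Sub`, `Mul`, `Div` and `Neg` operators, plus the matching
//!   `*Assign` operators;
//! * `PartialEq`/`Eq`;
//! * an associated constant `prime`;
//! * a constructor `n(num)` that reduces `num` modulo `p`.
//!
//! Constructing the struct directly (e.g. `Fp2(3)`) does not reduce the value.
//! Equality compares values modulo `p`, so `Fp2(2) == Fp2(0)` when `p = 2`.
//! The modulus is not checked for primality. Division only succeeds when the
//! divisor has an inverse modulo `p`, which is guaranteed for non-zero values
//! only when `p` is prime.
//!
//! # Quick Start
//!
//! ```
//! use macro_galois_field::Field;
//!
//! #[derive(Field, Debug, Default, Copy, Clone)]
//! #[prime = 2]
//! struct Fp2(u64);
//!
//! let a = Fp2(3);
//! let b = Fp2(3);
//! assert_eq!(a + b, Fp2(0), "{} + {}", a.0, b.0);
//! assert_eq!(a - b, Fp2(0), "{} - {}", a.0, b.0);
//! // Equality is modulo 2, so Fp2(2) also equals the zero element.
//! assert_eq!(a - b, Fp2(2), "{} - {}", a.0, b.0);
//! assert_eq!(a * b, Fp2(1), "{} * {}", a.0, b.0);
//! // Likewise Fp2(3) equals Fp2(1).
//! assert_eq!(a * b, Fp2(3), "{} * {}", a.0, b.0);
//! assert_eq!(a / b, Fp2(1), "{} / {}", a.0, b.0);
//!
//! let a = Fp2(3);
//! let b = Fp2(100);
//! assert_eq!(a + b, Fp2(1), "{} + {}", a.0, b.0);
//! assert_eq!(a - b, Fp2(1), "{} - {}", a.0, b.0);
//! ```
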
27#![recursion_limit="1024"]
28use proc_macro::TokenStream;
29use quote::{quote, ToTokens};
30use syn::{
31    parse_macro_input, DeriveInput,
32};
33
34// TODO: not only prime field, non-prime field
35// TODO: not only struct(u64), but also struct(usize), struct {num: u32}
36// TODO: check if prime is really a prime number
37#[proc_macro_derive(Field, attributes(prime))]
38pub fn derive(input: TokenStream) -> TokenStream {
39    // Parse the input tokens into a syntax tree.
40    let ast = parse_macro_input!(input as DeriveInput);
41    let name = &ast.ident;
42    let prime_attr = ast.attrs.iter().find(|atter| {
43        atter.path.segments[0].ident.to_string() == "prime"
44    }).expect("prime attr should exist.");
45    let tt= prime_attr.tts.clone().into_iter().nth(0).expect(&format!("expect: #[prime = <num>].\nreal: {:?}", prime_attr.into_token_stream().to_string()));
46    assert_eq!(tt.to_string(), "=");
47    let prime_tt= prime_attr.tts.clone().into_iter().nth(1).expect(&format!("expect: #[prime = <num>].\nreal: {:?}", prime_attr.into_token_stream().to_string()));
48    let prime: u64 = match prime_tt {
49        proc_macro2::TokenTree::Literal(l) => {
50            l.to_string().parse().unwrap()
51        }
52        _ => {
53            panic!("{:?}", prime_tt)
54        }
55    };
56    (quote!{
57        impl #name {
58            const prime: u64 = #prime;
59
60            pub fn n(num: u64) -> Self {
61                #name(num % #prime)
62            }
63
64            fn modinv(&self) -> Self {
65                let mut x0: i64 = 1;
66                let mut y0: i64 = 0;
67                let mut x1: i64 = 0;
68                let mut y1:i64 = 1;
69                let mut a: i64 = self.0 as i64;
70                let mut b: i64 = #prime as i64;
71                while b != 0 {
72                    let q = a / b;
73                    let pre_b = b;
74                    let pre_a = a;
75                    a = pre_b;
76                    b = pre_a % pre_b;
77
78                    let pre_x0 = x0;
79                    let pre_x1 = x1;
80                    x0 = pre_x1;
81                    x1 = pre_x0 - q * pre_x1;
82
83                    let pre_y0 = y0;
84                    let pre_y1 = y1;
85                    y0 = pre_y1;
86                    y1 = pre_y0 - q * pre_y1;
87                }
88                if a != 1 {
89                    dbg!(a, b, x0, x1, y0, y1);
90                    panic!("modular inverse does not exist for num: {}, moduler: {}", self.0, #prime);
91                }
92                if x0 < 0 {
93                    let q = x0 / #prime as i64;
94                    x0 -= (q - 1) * #prime as i64;
95                }
96                x0 = x0 % #prime as i64;
97                #name(x0 as u64)
98            }
99        }
100
101        impl std::ops::Add for #name {
102            type Output = Self;
103            fn add(self, rhs: Self) -> Self::Output {
104                #name((self.0 + rhs.0) % #prime)
105            }
106        }
107
108        impl std::ops::AddAssign for #name {
109            fn add_assign(&mut self, rhs: Self) {
110                self.0 = (self.0 + rhs.0) % #prime;
111            }
112        }
113
114        impl std::ops::Sub for #name {
115            type Output = Self;
116            fn sub(self, rhs: Self) -> Self::Output {
117                let mut n = self.0;
118                while n < rhs.0 {
119                    n += #prime
120                }
121                #name((n - rhs.0) % #prime)
122            }
123        }
124
125        impl std::ops::SubAssign for #name {
126            fn sub_assign(&mut self, rhs: Self) {
127                while self.0 < rhs.0 {
128                    self.0 += #prime
129                }
130                self.0 = (self.0 - rhs.0) % #prime;
131            }
132        }
133
134        impl std::ops::Mul for #name {
135            type Output = Self;
136            fn mul(self, rhs: Self) -> Self::Output {
137                #name((self.0 * rhs.0) % #prime)
138            }
139        }
140
141        impl std::ops::MulAssign for #name {
142            fn mul_assign(&mut self, rhs: Self) {
143                self.0 = (self.0 * rhs.0) % #prime;
144            }
145        }
146
147        impl std::ops::Div for #name {
148            type Output = Self;
149            fn div(self, rhs: Self) -> Self::Output {
150                self * rhs.modinv()
151            }
152        }
153
154        impl std::ops::DivAssign for #name {
155            fn div_assign(&mut self, rhs: Self) {
156                let a = #name(self.0) * rhs.modinv();
157                self.0 = a.0;
158            }
159        }
160
161        impl std::ops::Neg for #name {
162            type Output = Self;
163            fn neg(self) -> Self::Output {
164                let mut num = -(self.0 as i64);
165                if num < 0 {
166                    let q = num / #prime as i64;
167                    num -= (q - 1) * #prime as i64;
168                }
169
170                num = num % #prime as i64;
171                #name(num as u64)
172            }
173        }
174
175        impl PartialEq for #name {
176            fn eq(&self, other: &Self) -> bool {
177                self.0 % #prime == other.0 % #prime
178            }
179        }
180        impl Eq for #name {}
181
182    }).into()
183}