use crate::wide::{to_mont_128,mont_prod_128,mont_sqr_128, nqr_128, mont_sub_128};
/// Returns `2·x (mod n)`.
///
/// The result is fully reduced when `x < n` (then `2·x < 2·n`, so one
/// subtraction of `n` suffices). The doubling may wrap past 2^128; that
/// overflow shows up as the wrapped value being smaller than `x`, and the
/// wrapping subtraction of `n` then recovers the true residue.
const fn double(x: u128, n: u128) -> u128 {
    let doubled = x.wrapping_add(x);
    let wrapped = doubled < x;
    if wrapped || doubled >= n {
        doubled.wrapping_sub(n)
    } else {
        doubled
    }
}
/// Returns `x + y (mod n)`.
///
/// The result is fully reduced when both inputs are `< n`. A carry out of the
/// 128-bit addition means the true sum exceeds `n`; the wrapping subtraction
/// then yields the correct residue.
const fn mont_add(x: u128, y: u128, n: u128) -> u128 {
    let (total, carried) = x.overflowing_add(y);
    if carried || total >= n {
        total.wrapping_sub(n)
    } else {
        total
    }
}
/// Coefficient of the √c term of `(x + y·√c)^2`, i.e. `2·x·y (mod n)`.
///
/// Independent of `c`, so it is shared by every extension below.
const fn sqr_coef(x: u128, y: u128, inv: u128, n: u128) -> u128 {
    let xy = mont_prod_128(x, y, inv, n);
    double(xy, n)
}

/// Coefficient of the √c term of `(x + y·√c)·(a + b·√c)`, i.e. `x·b + y·a (mod n)`.
///
/// Independent of `c`, so it is shared by every extension below.
const fn prod_coef(x: u128, y: u128, a: u128, b: u128, inv: u128, n: u128) -> u128 {
    let xb = mont_prod_128(x, b, inv, n);
    let ya = mont_prod_128(y, a, inv, n);
    mont_add(xb, ya, n)
}
/// Real part of `(x + y·i)^2` in `Z_n[i]` (i.e. c = −1): `x² − y² (mod n)`.
const fn gaussian_sqr_real(x: u128, y: u128, inv: u128, n: u128) -> u128 {
    let x_sq = mont_sqr_128(x, inv, n);
    let y_sq = mont_sqr_128(y, inv, n);
    mont_sub_128(x_sq, y_sq, n)
}

/// Real part of `(x + y·i)·(a + b·i)` in `Z_n[i]`: `x·a − y·b (mod n)`.
const fn gaussian_prod_real(x: u128, y: u128, a: u128, b: u128, inv: u128, n: u128) -> u128 {
    let xa = mont_prod_128(x, a, inv, n);
    let yb = mont_prod_128(y, b, inv, n);
    mont_sub_128(xa, yb, n)
}

/// Squares the Gaussian element `x = (real, imag)` modulo `n`.
const fn gaussian_sqr(x: (u128, u128), inv: u128, n: u128) -> (u128, u128) {
    let (re, im) = x;
    let new_re = gaussian_sqr_real(re, im, inv, n);
    let new_im = sqr_coef(re, im, inv, n);
    (new_re, new_im)
}

/// Multiplies the Gaussian elements `x` and `a` (each `(real, imag)`) modulo `n`.
const fn gaussian_prod(x: (u128, u128), a: (u128, u128), inv: u128, n: u128) -> (u128, u128) {
    let (x_re, x_im) = x;
    let (a_re, a_im) = a;
    let new_re = gaussian_prod_real(x_re, x_im, a_re, a_im, inv, n);
    let new_im = prod_coef(x_re, x_im, a_re, a_im, inv, n);
    (new_re, new_im)
}
/// Returns `one · base^pow` in `Z_n[i]` (c = −1), using right-to-left
/// square-and-multiply.
///
/// `one` is the starting accumulator; `qft` passes the multiplicative
/// identity, so the result is then `base^pow`. For `pow == 0` the accumulator
/// is returned unchanged (base⁰ = 1).
const fn gaussian_pow(mut base: (u128, u128), mut one: (u128, u128), mut pow: u128, inv: u128, n: u128) -> (u128, u128) {
    if pow == 0 {
        return one;
    }
    // Process all bits except the top one; the final multiply below folds in
    // the most significant bit without an extra (unused) squaring.
    while pow > 1 {
        if pow & 1 == 1 {
            one = gaussian_prod(one, base, inv, n);
        }
        base = gaussian_sqr(base, inv, n);
        pow >>= 1;
    }
    gaussian_prod(one, base, inv, n)
}
/// Real part of `(x + y·√2)^2` in `Z_n[√2]`: `x² + 2·y² (mod n)`.
const fn two_sqr_real(x: u128, y: u128, inv: u128, n: u128) -> u128 {
    let x_sq = mont_sqr_128(x, inv, n);
    let two_y_sq = double(mont_sqr_128(y, inv, n), n);
    mont_add(x_sq, two_y_sq, n)
}

/// Real part of `(x + y·√2)·(a + b·√2)` in `Z_n[√2]`: `x·a + 2·y·b (mod n)`.
const fn two_prod_real(x: u128, y: u128, a: u128, b: u128, inv: u128, n: u128) -> u128 {
    let xa = mont_prod_128(x, a, inv, n);
    let two_yb = double(mont_prod_128(y, b, inv, n), n);
    mont_add(xa, two_yb, n)
}

/// Squares `x = (real, √2-coefficient)` in `Z_n[√2]`.
const fn two_sqr(x: (u128, u128), inv: u128, n: u128) -> (u128, u128) {
    let (re, im) = x;
    let new_re = two_sqr_real(re, im, inv, n);
    let new_im = sqr_coef(re, im, inv, n);
    (new_re, new_im)
}

/// Multiplies `x` and `a` (each `(real, √2-coefficient)`) in `Z_n[√2]`.
const fn two_prod(x: (u128, u128), a: (u128, u128), inv: u128, n: u128) -> (u128, u128) {
    let (x_re, x_im) = x;
    let (a_re, a_im) = a;
    let new_re = two_prod_real(x_re, x_im, a_re, a_im, inv, n);
    let new_im = prod_coef(x_re, x_im, a_re, a_im, inv, n);
    (new_re, new_im)
}
/// Returns `one · base^pow` in `Z_n[√2]`, using right-to-left
/// square-and-multiply.
///
/// `one` is the starting accumulator; `qft` passes the multiplicative
/// identity, so the result is then `base^pow`. For `pow == 0` the accumulator
/// is returned unchanged (base⁰ = 1).
const fn two_pow(mut base: (u128, u128), mut one: (u128, u128), mut pow: u128, inv: u128, n: u128) -> (u128, u128) {
    if pow == 0 {
        return one;
    }
    // Process all bits except the top one; the final multiply below folds in
    // the most significant bit without an extra (unused) squaring.
    while pow > 1 {
        if pow & 1 == 1 {
            one = two_prod(one, base, inv, n);
        }
        base = two_sqr(base, inv, n);
        pow >>= 1;
    }
    two_prod(one, base, inv, n)
}
/// Real part of `(x + y·√c)^2` in `Z_n[√c]`: `x² + c·y² (mod n)`.
///
/// `c` is multiplied with `mont_prod_128`, so it is expected in the same
/// representation as the other operands (`qft` passes `to_mont_128(idx, n)`).
const fn general_sqr_real(x: u128, y: u128, c: u128, inv: u128, n: u128) -> u128 {
    let x_sq = mont_sqr_128(x, inv, n);
    let y_sq = mont_sqr_128(y, inv, n);
    let c_y_sq = mont_prod_128(y_sq, c, inv, n);
    mont_add(x_sq, c_y_sq, n)
}

/// Real part of `(x + y·√c)·(a + b·√c)` in `Z_n[√c]`: `x·a + c·y·b (mod n)`.
const fn general_prod_real(x: u128, y: u128, a: u128, b: u128, c: u128, inv: u128, n: u128) -> u128 {
    let xa = mont_prod_128(x, a, inv, n);
    let yb = mont_prod_128(y, b, inv, n);
    let c_yb = mont_prod_128(yb, c, inv, n);
    mont_add(xa, c_yb, n)
}

/// Squares `x = (real, √c-coefficient)` in `Z_n[√c]`.
const fn general_sqr(x: (u128, u128), c: u128, inv: u128, n: u128) -> (u128, u128) {
    let (re, im) = x;
    let new_re = general_sqr_real(re, im, c, inv, n);
    let new_im = sqr_coef(re, im, inv, n);
    (new_re, new_im)
}

/// Multiplies `x` and `a` (each `(real, √c-coefficient)`) in `Z_n[√c]`.
const fn general_prod(x: (u128, u128), a: (u128, u128), c: u128, inv: u128, n: u128) -> (u128, u128) {
    let (x_re, x_im) = x;
    let (a_re, a_im) = a;
    let new_re = general_prod_real(x_re, x_im, a_re, a_im, c, inv, n);
    let new_im = prod_coef(x_re, x_im, a_re, a_im, inv, n);
    (new_re, new_im)
}
/// Returns `one · base^pow` in `Z_n[√c]`, using right-to-left
/// square-and-multiply.
///
/// `one` is the starting accumulator; `qft` passes the multiplicative
/// identity, so the result is then `base^pow`. For `pow == 0` the accumulator
/// is returned unchanged (base⁰ = 1).
const fn general_pow(mut base: (u128, u128), mut one: (u128, u128), c: u128, mut pow: u128, inv: u128, n: u128) -> (u128, u128) {
    if pow == 0 {
        return one;
    }
    // Process all bits except the top one; the final multiply below folds in
    // the most significant bit without an extra (unused) squaring.
    while pow > 1 {
        if pow & 1 == 1 {
            one = general_prod(one, base, c, inv, n);
        }
        base = general_sqr(base, c, inv, n);
        pow >>= 1;
    }
    general_prod(one, base, c, inv, n)
}
/// Chooses the extension parameter `c` used by `qft` to form `Z_n[√c]`.
///
/// Returns:
/// * `-1` when `n ≡ 3 (mod 4)`;
/// * `2` when `n ≡ 5 (mod 8)`;
/// * `3` when `n % 12` is 5 or 7;
/// * `5` when `n % 5` is 2 or 3;
/// * otherwise the first candidate in `7, 9, 11, …` accepted by `nqr_128`.
///
/// NOTE(review): `nqr_128(c, n)` presumably reports whether `c` is a quadratic
/// non-residue (Jacobi symbol −1) modulo `n`. If it never returns true for
/// this `n` (e.g. `n` a perfect square), the search loop does not terminate.
/// Confirm that callers exclude such `n`.
const fn frobenius_idx(n: u128) -> i32 {
    if n % 4 == 3 {
        return -1;
    }
    if n % 8 == 5 {
        return 2;
    }
    // The `== 7` arm cannot fire: 7 ≡ 3 (mod 4), already handled above.
    let rem12 = n % 12;
    if rem12 == 5 || rem12 == 7 {
        return 3;
    }
    let rem5 = n % 5;
    if rem5 == 2 || rem5 == 3 {
        return 5;
    }
    let mut candidate = 7;
    loop {
        if nqr_128(candidate, n) {
            break candidate as i32;
        }
        candidate += 2;
    }
}
/// Quadratic Frobenius probable-prime test on `n`.
///
/// Picks `c` with `frobenius_idx(n)` and checks the Frobenius congruence
/// `(a + b·√c)^n ≡ a − b·√c (mod n)` in `Z_n[√c]`:
/// * `c = −1`: base `2 + i`, expected `2 − i`;
/// * `c = 2`: base `2 + √2`, expected `2 − √2`;
/// * otherwise: base `1 + √c`, expected `1 − √c`.
///
/// NOTE(review): parameter meanings are inferred from how they are used and
/// should be confirmed against `crate::wide`:
/// * `one` and `two` are presumably 1 and 2 in the Montgomery domain of `n`;
/// * `oneinv` is presumably the representation of −1 (the expected conjugate
///   coefficient);
/// * `inv` is the reduction constant forwarded to the `mont_*_128` helpers.
///
/// Returns `true` when the congruence holds.
pub const fn qft(n: u128, one: u128, two: u128, oneinv: u128, inv: u128) -> bool {
    let idx = frobenius_idx(n);
    let identity = (one, 0);
    match idx {
        -1 => {
            let residue = gaussian_pow((two, one), identity, n, inv, n);
            residue.0 == two && residue.1 == oneinv
        }
        2 => {
            let residue = two_pow((two, one), identity, n, inv, n);
            residue.0 == two && residue.1 == oneinv
        }
        _ => {
            // Here idx >= 3, so the cast to u128 is lossless.
            let c = to_mont_128(idx as u128, n);
            let residue = general_pow((one, one), identity, c, n, inv, n);
            residue.0 == one && residue.1 == oneinv
        }
    }
}