use crate::shake::{i_shake256_extract, InnerShake256Context};
/// Fills `x[0..n]` (with `n = 2^logn`) with values in `[0, 12289)` drawn from `sc`.
///
/// Each candidate is a 16-bit big-endian value read from `sc`. Candidates
/// `>= 61445` (= 5 * 12289) are rejected and another pair of bytes is read, so
/// every residue modulo 12289 has exactly five preimages among the accepted
/// values. This is variable time: how many bytes are consumed from `sc` depends
/// on the sampled values.
///
/// # Panics
///
/// Panics if `x` has fewer than `n` elements.
pub fn hash_to_point_vartime(sc: &mut InnerShake256Context, x: &mut [u16], logn: u32) {
    let count: usize = 1 << logn;
    let mut filled = 0usize;
    while filled < count {
        let mut bytes = [0u8; 2];
        i_shake256_extract(sc, &mut bytes);
        let sample = u32::from(u16::from_be_bytes(bytes));
        if sample >= 61445 {
            // Rejected: draw again for the same output slot.
            continue;
        }
        x[filled] = (sample % 12289) as u16;
        filled += 1;
    }
}
/// Hashes the output of `sc` into `n = 2^logn` values in `[0, 12289)`, written to
/// `x[0..n]`. Which memory locations are read or written does not depend on where
/// rejected samples occur.
///
/// `m = n + OVERTAB[logn]` 16-bit big-endian samples are drawn up front.
/// - A sample `w < 61445` is reduced modulo 12289 with branch-free conditional
///   subtractions.
/// - A sample `w >= 61445` is replaced by the marker `0xFFFF`.
///
/// The `m` slots are then compacted in place by passes with distances
/// `p = 1, 2, 4, ... <= OVERTAB[logn]`. In each pass, the value at slot `u` is
/// swapped with slot `u - p` when it is valid and `j & p != 0`, where `j` is the
/// number of invalid values preceding slot `u`. Every pass visits the same slots
/// in the same order whatever the sample values are; only the swap masks change.
///
/// If fewer than `n` of the `m` samples are valid, at least one `0xFFFF` marker
/// remains in `x`. NOTE(review): `OVERTAB` is presumably sized so that this is
/// negligible — confirm against the specification.
///
/// Slot storage:
/// - slots `0..n` live in `x`;
/// - slots `n..2n` live in `tmp`, as native-endian `u16` byte pairs;
/// - any remaining slots live in a 63-entry local array.
///
/// `tmp` is accessed bytewise, so it needs no particular alignment. Reinterpreting
/// it as `&mut [u16]` would be undefined behavior for unaligned buffers.
///
/// # Panics
///
/// Panics if `logn > 10`, if `x.len() < n`, or if `tmp` is shorter than
/// `2 * min(OVERTAB[logn], n)` bytes.
pub fn hash_to_point_ct(sc: &mut InnerShake256Context, x: &mut [u16], logn: u32, tmp: &mut [u8]) {
    /// Number of extra samples drawn beyond `n`, indexed by `logn`.
    static OVERTAB: [u16; 11] = [
        0, 65, 67, 71, 77, 86, 100, 122, 154, 205, 287,
    ];

    /// Uniform, bounds-checked view of the `m` sample slots spread over `x`,
    /// the caller's scratch bytes, and a small local array.
    struct Slots<'a> {
        x: &'a mut [u16],
        tmp: &'a mut [u8],
        tt2: [u16; 63],
        n: usize,
        n2: usize,
    }

    impl Slots<'_> {
        /// Reads slot `u`.
        fn get(&self, u: usize) -> u16 {
            if u < self.n {
                self.x[u]
            } else if u < self.n2 {
                let k = (u - self.n) << 1;
                u16::from_ne_bytes([self.tmp[k], self.tmp[k + 1]])
            } else {
                self.tt2[u - self.n2]
            }
        }

        /// Writes `val` to slot `u`.
        fn set(&mut self, u: usize, val: u16) {
            if u < self.n {
                self.x[u] = val;
            } else if u < self.n2 {
                let k = (u - self.n) << 1;
                self.tmp[k..k + 2].copy_from_slice(&val.to_ne_bytes());
            } else {
                self.tt2[u - self.n2] = val;
            }
        }
    }

    let n = 1usize << logn;
    let over = OVERTAB[logn as usize] as usize;
    let m = n + over;
    let mut slots = Slots {
        x,
        tmp,
        tt2: [0u16; 63],
        n,
        n2: n << 1,
    };

    for u in 0..m {
        let mut buf = [0u8; 2];
        i_shake256_extract(sc, &mut buf);
        let w: u32 = ((buf[0] as u32) << 8) | (buf[1] as u32);
        // Branch-free reduction: conditionally subtract 2q, 2q, then q (q = 12289).
        // Each mask is all-ones exactly when the subtraction does not underflow.
        let mut wr = w;
        wr = wr.wrapping_sub(24578 & ((wr.wrapping_sub(24578) >> 31).wrapping_sub(1)));
        wr = wr.wrapping_sub(24578 & ((wr.wrapping_sub(24578) >> 31).wrapping_sub(1)));
        wr = wr.wrapping_sub(12289 & ((wr.wrapping_sub(12289) >> 31).wrapping_sub(1)));
        // All-ones when w >= 61445, which turns the slot into the 0xFFFF marker.
        wr |= (w.wrapping_sub(61445) >> 31).wrapping_sub(1);
        slots.set(u, wr as u16);
    }

    let mut p: usize = 1;
    while p <= over {
        // Number of valid (non-0xFFFF) values seen so far in this pass.
        let mut v: usize = 0;
        for u in 0..m {
            let sv = slots.get(u);
            // Invalid values preceding slot u; v <= u, so this cannot underflow.
            let j = u - v;
            // 0xFFFF if sv is valid (< 0x8000), 0 if sv is the 0xFFFF marker.
            let mk = (sv >> 15).wrapping_sub(1u16);
            v += (mk & 1) as usize;
            if u < p {
                continue;
            }
            let dv = slots.get(u - p);
            // (j & p) is 0 or p with 1 <= p <= 256, so ((j & p) + 0x1FF) >> 9 is
            // 0 or 1; negating gives a 0 / 0xFFFF mask without a branch.
            let swap = mk & ((((j & p) as u32 + 0x1FF) >> 9) as u16).wrapping_neg();
            let diff = swap & (sv ^ dv);
            slots.set(u, sv ^ diff);
            slots.set(u - p, dv ^ diff);
        }
        p <<= 1;
    }
}
/// Upper bounds on the squared l2 norm, indexed by `logn` (0..=10).
///
/// `is_short` and `is_short_half` accept a vector when its squared norm is
/// `<=` the entry for their `logn`.
static L2BOUND: [u32; 11] = [
    0, // logn = 0 (NOTE(review): presumably unused — confirm)
    101498,   // logn = 1
    208714,   // logn = 2
    428865,   // logn = 3
    892039,   // logn = 4
    1852696,  // logn = 5
    3842630,  // logn = 6
    7959734,  // logn = 7
    16468416, // logn = 8
    34034726, // logn = 9
    70265242, // logn = 10
];
/// Returns `true` when the squared l2 norm of `(s1, s2)` is at most `L2BOUND[logn]`.
///
/// The norm is taken over the first `n = 2^logn` coefficients of each slice,
/// i.e. the sum of `s1[u]^2 + s2[u]^2`.
///
/// The sum is accumulated in a `u32`. Each square is at most `2^30` (inputs are
/// `i16`), so the sum cannot wrap past `2^32` without some partial sum first
/// having its top bit set. Every partial sum is OR-ed into `ng`; if that top bit
/// was ever set, the result is forced to `u32::MAX` and rejected rather than
/// wrapping to a small value. No branch depends on the coefficient values.
///
/// # Panics
///
/// Panics if `s1` or `s2` has fewer than `n` elements, or if `logn > 10`.
pub fn is_short(s1: &[i16], s2: &[i16], logn: u32) -> bool {
    let n: usize = 1 << logn;
    // Slicing once turns a too-short input into a panic. The previous
    // `get_unchecked` + `debug_assert!` made it UB in release builds. It also
    // lets the loop below run without per-element bounds checks.
    let (s1, s2) = (&s1[..n], &s2[..n]);
    let mut s: u32 = 0;
    let mut ng: u32 = 0;
    for (&a, &b) in s1.iter().zip(s2) {
        let z = i32::from(a);
        s = s.wrapping_add((z * z) as u32);
        ng |= s;
        let z = i32::from(b);
        s = s.wrapping_add((z * z) as u32);
        ng |= s;
    }
    s |= (ng >> 31).wrapping_neg();
    s <= L2BOUND[logn as usize]
}
/// Returns `true` when `sqn + sum(s2[u]^2)` over the first `n = 2^logn`
/// coefficients of `s2` is at most `L2BOUND[logn]`.
///
/// NOTE(review): `sqn` is presumably the squared norm of the other half of the
/// vector, computed by the caller — confirm.
///
/// Overflow handling matches `is_short`:
/// - an incoming `sqn` with its top bit set is rejected;
/// - otherwise every partial sum is OR-ed into `ng`, and if any of them had its
///   top bit set, the total is forced to `u32::MAX`.
///
/// No branch depends on the coefficient values.
///
/// # Panics
///
/// Panics if `s2` has fewer than `n` elements, or if `logn > 10`.
pub fn is_short_half(mut sqn: u32, s2: &[i16], logn: u32) -> bool {
    let n: usize = 1 << logn;
    // Slicing once turns a too-short input into a panic. The previous
    // `get_unchecked` + `debug_assert!` made it UB in release builds.
    let s2 = &s2[..n];
    // Seed the overflow flag from the incoming partial norm.
    let mut ng: u32 = (sqn >> 31).wrapping_neg();
    for &c in s2 {
        let z = i32::from(c);
        sqn = sqn.wrapping_add((z * z) as u32);
        ng |= sqn;
    }
    sqn |= (ng >> 31).wrapping_neg();
    sqn <= L2BOUND[logn as usize]
}