use crate::{
Error, Result,
mode::Mode,
pwxform::{PwxformCtx, RMIN},
};
use core::{
fmt::{self, Display},
str::{self, FromStr},
};
/// Parameter set that can be serialized to a compact string (crate-private `encode`, also used
/// by the `Display` implementation) and parsed back through the `FromStr` implementation.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Params {
/// Mode ("flavor"); converted to a `u32` and written as the first value of the encoded form.
pub(crate) mode: Mode,
/// `N` parameter. Only its base-2 logarithm is serialized, and parsing reconstructs it as
/// `1 << log2`.
pub(crate) n: u64,
/// `r` parameter; always present in the encoded form.
pub(crate) r: u32,
/// `p` parameter; omitted from the encoded form when it equals `1`.
pub(crate) p: u32,
/// `t` parameter; omitted from the encoded form when it equals `0`.
pub(crate) t: u32,
/// `g` parameter; the constructors in this file reject any non-zero value.
pub(crate) g: u32,
/// ROM size parameter. The constructors in this file always set it to `0`, and parsing rejects
/// strings that carry it; when non-zero it is serialized as its base-2 logarithm.
pub(crate) nrom: u64,
}
impl Params {
    /// Upper bound on the length of the encoded form: `encode` writes at most eight values
    /// (flavor, log2 of `n`, `r`, the presence mask, `p`, `t`, `g`, log2 of `nrom`) and each
    /// value occupies at most six characters.
    pub(crate) const MAX_ENCODED_LEN: usize = 8 * 6;

    /// Creates a parameter set from `mode`, `n`, `r` and `p`, with `t` and `g` set to zero.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Params`] under the same conditions as
    /// [`Params::new_with_all_params`].
    pub fn new(mode: Mode, n: u64, r: u32, p: u32) -> Result<Params> {
        Self::new_with_all_params(mode, n, r, p, 0, 0)
    }

    /// Creates a parameter set with every tunable given explicitly.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Params`] when:
    ///
    /// - `g` is non-zero, or
    /// - `mode.is_rw()` is true and any of the following holds: `p` is zero, `n / p <= 1`,
    ///   `r < RMIN`, or `p` exceeds one of the two size limits checked below.
    pub fn new_with_all_params(
        mode: Mode,
        n: u64,
        r: u32,
        p: u32,
        t: u32,
        g: u32,
    ) -> Result<Params> {
        if g != 0 {
            return Err(Error::Params);
        }

        if mode.is_rw() {
            let p64 = u64::from(p);
            // `p == 0` must be tested before the division below: `n / 0` would panic, whereas
            // an invalid parameter should be reported as an error.
            let invalid = p == 0
                || n / p64 <= 1
                || r < RMIN
                || p64 > u64::MAX / (3 * (1 << 8) * 2 * 8)
                || p64 > u64::MAX / (size_of::<PwxformCtx<'_>>() as u64);
            if invalid {
                return Err(Error::Params);
            }
        }

        Ok(Params {
            mode,
            n,
            r,
            p,
            t,
            g,
            nrom: 0,
        })
    }

    /// Returns the `n` parameter.
    #[must_use]
    pub const fn n(&self) -> u64 {
        self.n
    }

    /// Returns the `r` parameter.
    #[must_use]
    pub const fn r(&self) -> u32 {
        self.r
    }

    /// Returns the `p` parameter.
    #[must_use]
    pub const fn p(&self) -> u32 {
        self.p
    }

    /// Serializes the parameters into `out` and returns the written prefix as a string slice.
    ///
    /// Layout: flavor, log2 of `n`, `r`, then — only if at least one optional value is present —
    /// a presence mask followed by the optional values in the order `p`, `t`, `g`, log2 of
    /// `nrom`. Mask bits: `1` = `p != 1`, `2` = `t != 0`, `4` = `g != 0`, `8` = ROM present.
    ///
    /// # Errors
    ///
    /// - [`Error::Params`] if `N2log2` yields `0` for `n` (or for a non-zero `nrom`), if
    ///   `r * p >= 2^30`, or if a value is below the minimum of its slot (for example `r == 0`).
    /// - [`Error::Encoding`] if `out` is too short or a value is too large for the
    ///   variable-length encoding.
    pub(crate) fn encode<'o>(&self, out: &'o mut [u8]) -> Result<&'o str> {
        let flavor = u32::from(self.mode);

        let n_log2 = N2log2(self.n);
        if n_log2 == 0 {
            return Err(Error::Params);
        }

        let nrom_log2 = N2log2(self.nrom);
        if self.nrom != 0 && nrom_log2 == 0 {
            return Err(Error::Params);
        }

        // Widened to u64 so the product of two u32 values cannot overflow.
        if u64::from(self.r) * u64::from(self.p) >= (1 << 30) {
            return Err(Error::Params);
        }

        // `pos` never exceeds `out.len()`: `encode64_uint32` checks the remaining space before
        // writing, so the `out[pos..]` slices below cannot panic.
        let mut pos = 0;
        pos += encode64_uint32(&mut out[pos..], flavor, 0)?;
        pos += encode64_uint32(&mut out[pos..], n_log2, 1)?;
        pos += encode64_uint32(&mut out[pos..], self.r, 1)?;

        let has_p = self.p != 1;
        let has_t = self.t != 0;
        let has_g = self.g != 0;
        let has_rom = nrom_log2 != 0;

        let mut have = 0u32;
        if has_p {
            have |= 1;
        }
        if has_t {
            have |= 2;
        }
        if has_g {
            have |= 4;
        }
        if has_rom {
            have |= 8;
        }

        // The mask itself is omitted when every optional value has its default.
        if have != 0 {
            pos += encode64_uint32(&mut out[pos..], have, 1)?;
        }
        if has_p {
            pos += encode64_uint32(&mut out[pos..], self.p, 2)?;
        }
        if has_t {
            pos += encode64_uint32(&mut out[pos..], self.t, 1)?;
        }
        if has_g {
            pos += encode64_uint32(&mut out[pos..], self.g, 1)?;
        }
        if has_rom {
            pos += encode64_uint32(&mut out[pos..], nrom_log2, 1)?;
        }

        str::from_utf8(&out[..pos]).map_err(|_| Error::Encoding)
    }
}
impl Default for Params {
    /// Default parameter set: default mode, `n = 4096`, `r = 32`, `p = 1`, with `t`, `g` and
    /// `nrom` all zero.
    fn default() -> Self {
        let mode = Mode::default();
        Self {
            mode,
            n: 4096,
            r: 32,
            p: 1,
            t: 0,
            g: 0,
            nrom: 0,
        }
    }
}
impl FromStr for Params {
    type Err = Error;

    /// Parses the compact encoded form: flavor, log2 of `n`, `r`, then optionally a presence
    /// mask followed by whichever of `p`, `t`, `g` and the ROM size the mask announces.
    ///
    /// Bytes remaining after the last decoded value are not inspected.
    ///
    /// # Errors
    ///
    /// - [`Error::Encoding`] if a value is truncated or uses a byte outside the alphabet, or if
    ///   the decoded log2 of `n` exceeds 63.
    /// - Whatever `Mode::try_from` reports for an unrecognised flavor.
    /// - [`Error::Params`] if a ROM size is present, or if the decoded values are rejected by
    ///   [`Params::new_with_all_params`].
    fn from_str(s: &str) -> Result<Params> {
        let input = s.as_bytes();

        let (flavor, cursor) = decode64_uint32(input, 0, 0)?;
        let mode = Mode::try_from(flavor)?;

        let (n_log2, cursor) = decode64_uint32(input, cursor, 1)?;
        if n_log2 > 63 {
            // `1 << n_log2` would not fit in a u64.
            return Err(Error::Encoding);
        }
        let n = 1u64 << n_log2;

        let (r, cursor) = decode64_uint32(input, cursor, 1)?;

        // Values used when the optional section (or an individual field) is absent.
        let (mut p, mut t, mut g) = (1u32, 0u32, 0u32);

        if cursor < input.len() {
            let (have, mut cursor) = decode64_uint32(input, cursor, 1)?;

            if (have & 0x01) != 0 {
                (p, cursor) = decode64_uint32(input, cursor, 2)?;
            }
            if (have & 0x02) != 0 {
                (t, cursor) = decode64_uint32(input, cursor, 1)?;
            }
            if (have & 0x04) != 0 {
                (g, cursor) = decode64_uint32(input, cursor, 1)?;
            }
            if (have & 0x08) != 0 {
                let (nrom_log2, _) = decode64_uint32(input, cursor, 1)?;
                if nrom_log2 != 0 {
                    return Err(Error::Params);
                }
            }
        }

        Self::new_with_all_params(mode, n, r, p, t, g)
    }
}
impl Display for Params {
    /// Writes the encoded parameter string to the formatter.
    ///
    /// # Panics
    ///
    /// Panics if `encode` fails for the current field values.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stack scratch space sized for the longest possible encoding.
        let mut scratch = [0u8; Self::MAX_ENCODED_LEN];
        let encoded = self.encode(&mut scratch).expect("params encode failed");
        f.write_str(encoded)
    }
}
/// Returns the position of the highest set bit of `N` (its base-2 logarithm rounded down), or
/// `0` when `N < 2`.
///
/// Callers in this file treat a result of `0` as "cannot be encoded".
///
/// The result is computed from `leading_zeros` rather than by repeatedly shifting `N`: a
/// shift-counting loop has to evaluate `N >> 64` for any `N >= 2^63`, which overflows the shift
/// amount (a panic in debug builds, a non-terminating loop in release builds).
///
/// NOTE(review): an input that is not a power of two is rounded down rather than rejected
/// (e.g. `3` yields `1`). Confirm whether callers validate `N` elsewhere or whether such
/// inputs should return `0` here.
#[allow(non_snake_case)]
fn N2log2(N: u64) -> u32 {
    if N < 2 {
        return 0;
    }
    // `N >= 2` guarantees `leading_zeros() <= 62`, so the subtraction cannot underflow.
    u64::BITS - 1 - N.leading_zeros()
}
/// Alphabet of the 64-character encoding: index `i` holds the ASCII byte for 6-bit value `i`.
static ITOA64: &[u8] = b"./0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

/// Inverse of [`ITOA64`], indexed by ASCII byte: holds the 6-bit value of each alphabet
/// character and `0xFF` for every byte that is not part of the alphabet.
static ATOI64: [u8; 128] = {
    let mut table = [0xFFu8; 128];
    let mut value = 0usize;
    while value < 64 {
        let symbol = ITOA64[value];
        // `value < 64`, so narrowing to u8 is lossless.
        table[symbol as usize] = value as u8;
        value += 1;
    }
    table
};
fn decode64_uint32(src: &[u8], mut pos: usize, min: u32) -> Result<(u32, usize)> {
let mut start = 0u32;
let mut end = 47u32;
let mut chars = 1u32;
let mut bits = 0u32;
if pos >= src.len() {
return Err(Error::Encoding);
}
let n = *ATOI64
.get(usize::from(src[pos]))
.filter(|&&n| n <= 63)
.ok_or(Error::Encoding)?;
pos += 1;
let mut dst = min;
while u32::from(n) > end {
dst += (end + 1 - start) << bits;
start = end + 1;
end = start + (62 - end) / 2;
chars += 1;
bits += 6;
}
dst += (u32::from(n) - start) << bits;
while chars > 1 {
chars -= 1;
if bits < 6 || pos >= src.len() {
return Err(Error::Encoding);
}
let c = match ATOI64.get(src[pos] as usize) {
Some(&c) if c <= 63 => c,
_ => return Err(Error::Encoding),
};
pos += 1;
bits -= 6;
dst += u32::from(c) << bits;
}
Ok((dst, pos))
}
fn encode64_uint32(dst: &mut [u8], mut src: u32, min: u32) -> Result<usize> {
let mut start = 0u32;
let mut end = 47u32;
let mut chars = 1u32;
let mut bits = 0u32;
if src < min {
return Err(Error::Params);
}
src -= min;
loop {
let count = (end + 1 - start) << bits;
if src < count {
break;
}
if start >= 63 {
return Err(Error::Encoding);
}
start = end + 1;
end = start + (62 - end) / 2;
src -= count;
chars += 1;
bits += 6;
}
if dst.len() < (chars as usize) {
return Err(Error::Encoding);
}
let mut pos: usize = 0;
dst[pos] = ITOA64[(start + (src >> bits)) as usize];
pos += 1;
while chars > 1 {
chars -= 1;
bits = bits.wrapping_sub(6);
dst[pos] = ITOA64[((src >> bits) & 0x3f) as usize];
pos += 1;
}
Ok(pos)
}
#[cfg(test)]
mod tests {
    use crate::{Mode, Params};
    use alloc::string::ToString;

    /// Builds a `Params` in the default mode directly from its fields, bypassing the
    /// validating constructors.
    fn raw(n: u64, r: u32, p: u32, t: u32, g: u32, nrom: u64) -> Params {
        Params {
            mode: Mode::default(),
            n,
            r,
            p,
            t,
            g,
            nrom,
        }
    }

    #[test]
    fn encoder() {
        let cases = [
            // Only the mandatory values.
            (raw(4096, 32, 1, 0, 0, 0), "j9T"),
            // Non-default `p`.
            (raw(4096, 8, 4, 0, 0, 0), "j95.0"),
            // Non-default `t` and `g`.
            (raw(4096, 8, 1, 2, 5, 0), "j953/2"),
            // ROM size present.
            (raw(32768, 8, 1, 0, 0, 4096), "jC559"),
        ];
        for (input, expected) in &cases {
            assert_eq!(input.to_string(), *expected);
        }
    }

    #[test]
    #[allow(clippy::unwrap_used)]
    fn decoder() {
        let mandatory_only: Params = "j9T".parse().unwrap();
        assert_eq!(mandatory_only, raw(4096, 32, 1, 0, 0, 0));

        let with_p: Params = "j95.0".parse().unwrap();
        assert_eq!(with_p, raw(4096, 8, 4, 0, 0, 0));

        // A non-zero `g` is rejected by the constructor.
        assert!("j953/2".parse::<Params>().is_err());
        // A ROM size in the encoded string is rejected by the parser.
        assert!("jC559".parse::<Params>().is_err());
    }
}