use crate::{B64, Error, Result};
use base64ct::Encoding;
use core::{cmp::Ordering, fmt, str::FromStr};
use ctutils::{Choice, CtEq};
/// Fixed-capacity container for a byte string of at most [`Output::MAX_LENGTH`] bytes.
///
/// The bytes are stored inline in a fixed-size array and `length` records how many of them are
/// in use; `as_bytes` exposes only that prefix.
///
/// `PartialEq` is implemented by hand in terms of `CtEq` instead of being derived, which is why
/// only `Eq` appears in the derive list below.
// NOTE(review): presumably this holds the raw output of a password hashing function — confirm
// against the crate-level documentation.
#[derive(Copy, Clone, Eq)]
pub struct Output {
/// Inline backing storage; only the first `length` bytes are exposed by the accessors.
bytes: [u8; Self::MAX_LENGTH],
/// Number of bytes of `bytes` that are in use.
length: u8,
}
#[allow(clippy::len_without_is_empty)]
impl Output {
pub const MIN_LENGTH: usize = 10;
pub const MAX_LENGTH: usize = 64;
pub const B64_MAX_LENGTH: usize = (Self::MAX_LENGTH * 4).div_ceil(3);
pub fn new(input: &[u8]) -> Result<Self> {
Self::init_with(input.len(), |bytes| {
bytes.copy_from_slice(input);
Ok(())
})
}
pub fn init_with<F>(output_size: usize, f: F) -> Result<Self>
where
F: FnOnce(&mut [u8]) -> Result<()>,
{
if output_size < Self::MIN_LENGTH {
return Err(Error::OutputSize {
provided: Ordering::Less,
expected: Self::MIN_LENGTH,
});
}
if output_size > Self::MAX_LENGTH {
return Err(Error::OutputSize {
provided: Ordering::Greater,
expected: Self::MAX_LENGTH,
});
}
let mut bytes = [0u8; Self::MAX_LENGTH];
f(&mut bytes[..output_size])?;
Ok(Self {
bytes,
length: output_size as u8,
})
}
pub fn as_bytes(&self) -> &[u8] {
&self.bytes[..self.len()]
}
pub fn len(&self) -> usize {
usize::from(self.length)
}
pub fn decode(input: &str) -> Result<Self> {
let mut bytes = [0u8; Self::MAX_LENGTH];
B64::decode(input, &mut bytes)
.map_err(Into::into)
.and_then(Self::new)
}
pub fn encode<'a>(&self, out: &'a mut [u8]) -> Result<&'a str> {
Ok(B64::encode(self.as_ref(), out)?)
}
pub fn encoded_len(&self) -> usize {
B64::encoded_len(self.as_ref())
}
#[deprecated(since = "0.3.0", note = "Use `Output::decode` instead")]
pub fn b64_decode(input: &str) -> Result<Self> {
Self::decode(input)
}
#[deprecated(since = "0.3.0", note = "Use `Output::encode` instead")]
pub fn b64_encode<'a>(&self, out: &'a mut [u8]) -> Result<&'a str> {
self.encode(out)
}
#[deprecated(since = "0.3.0", note = "Use `Output::encoded_len` instead")]
pub fn b64_len(&self) -> usize {
self.encoded_len()
}
}
impl AsRef<[u8]> for Output {
fn as_ref(&self) -> &[u8] {
self.as_bytes()
}
}
/// Equality check expressed through `CtEq`, comparing only the stored bytes.
impl CtEq for Output {
    /// Delegates to the `CtEq` implementation for byte slices, applied to the two values'
    /// stored bytes.
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs: &[u8] = self.as_bytes();
        let rhs: &[u8] = other.as_bytes();
        lhs.ct_eq(rhs)
    }
}
impl FromStr for Output {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
Self::decode(s)
}
}
impl PartialEq for Output {
fn eq(&self, other: &Self) -> bool {
self.ct_eq(other).into()
}
}
impl TryFrom<&[u8]> for Output {
type Error = Error;
fn try_from(input: &[u8]) -> Result<Output> {
Self::new(input)
}
}
/// Formats the value as the `B64` encoding of its stored bytes.
impl fmt::Display for Output {
    /// Encodes into a stack buffer of [`Output::B64_MAX_LENGTH`] bytes and writes the result.
    ///
    /// An encoding failure is reported as [`fmt::Error`].
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut scratch = [0u8; Self::B64_MAX_LENGTH];

        match self.encode(&mut scratch) {
            Ok(encoded) => f.write_str(encoded),
            Err(_) => Err(fmt::Error),
        }
    }
}
/// Debug output wraps the `Display` rendering as `Output("...")`.
impl fmt::Debug for Output {
    /// Writes `Output("` + the `Display` form of `self` + `")`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Output(\"{self}\")"))
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::{Error, Ordering, Output};

    /// Builds a 32-byte `Output` with every byte set to `fill`.
    fn filled(fill: u8) -> Output {
        Output::new(&[fill; 32]).unwrap()
    }

    /// A 10-byte input (the minimum) is accepted and round-trips through `as_ref`.
    #[test]
    fn new_with_valid_min_length_input() {
        let input = [10u8; 10];
        let output = Output::new(&input).unwrap();
        let view: &[u8] = output.as_ref();
        assert_eq!(view, input.as_slice());
    }

    /// A 64-byte input (the maximum) is accepted and round-trips through `as_ref`.
    #[test]
    fn new_with_valid_max_length_input() {
        let input = [64u8; 64];
        let output = Output::new(&input).unwrap();
        let view: &[u8] = output.as_ref();
        assert_eq!(view, input.as_slice());
    }

    /// A 9-byte input is one byte short of the minimum and must be rejected.
    #[test]
    fn reject_new_too_short() {
        let expected = Error::OutputSize {
            provided: Ordering::Less,
            expected: Output::MIN_LENGTH,
        };
        assert_eq!(Output::new(&[9u8; 9]), Err(expected));
    }

    /// A 65-byte input is one byte over the maximum and must be rejected.
    #[test]
    fn reject_new_too_long() {
        let expected = Error::OutputSize {
            provided: Ordering::Greater,
            expected: Output::MAX_LENGTH,
        };
        assert_eq!(Output::new(&[65u8; 65]), Err(expected));
    }

    /// Two values built from identical bytes compare equal.
    #[test]
    fn partialeq_true() {
        let (a, b) = (filled(1), filled(1));
        assert_eq!(a, b);
    }

    /// Two values built from different bytes compare unequal.
    #[test]
    fn partialeq_false() {
        let (a, b) = (filled(1), filled(2));
        assert_ne!(a, b);
    }
}