use crate::{Encoding, Error, Result};
use core::{cmp::PartialEq, fmt, str::FromStr};
use subtle::{Choice, ConstantTimeEq};
/// Fixed-capacity byte string paired with the [`Encoding`] used to render it
/// as text.
///
/// Only the first `length` bytes of `bytes` are meaningful. `PartialEq` is
/// not derived: it is implemented manually on top of `ConstantTimeEq` and
/// compares those bytes only (the `encoding` field does not take part).
#[derive(Copy, Clone, Eq)]
pub struct Output {
/// Inline storage of `MAX_LENGTH` bytes; positions at or beyond `length`
/// are left zeroed by the constructors.
bytes: [u8; Self::MAX_LENGTH],
/// Number of leading bytes of `bytes` that are in use.
length: u8,
/// Encoding used when formatting this value via `Display`.
encoding: Encoding,
}
#[allow(clippy::len_without_is_empty)]
impl Output {
pub const MIN_LENGTH: usize = 10;
pub const MAX_LENGTH: usize = 64;
pub const B64_MAX_LENGTH: usize = ((Self::MAX_LENGTH * 4) / 3) + 1;
pub fn new(input: &[u8]) -> Result<Self> {
Self::init_with(input.len(), |bytes| {
bytes.copy_from_slice(input);
Ok(())
})
}
pub fn new_with_encoding(input: &[u8], encoding: Encoding) -> Result<Self> {
let mut result = Self::new(input)?;
result.encoding = encoding;
Ok(result)
}
pub fn init_with<F>(output_size: usize, f: F) -> Result<Self>
where
F: FnOnce(&mut [u8]) -> Result<()>,
{
if output_size < Self::MIN_LENGTH {
return Err(Error::OutputTooShort);
}
if output_size > Self::MAX_LENGTH {
return Err(Error::OutputTooLong);
}
let mut bytes = [0u8; Self::MAX_LENGTH];
f(&mut bytes[..output_size])?;
Ok(Self {
bytes,
length: output_size as u8,
encoding: Encoding::default(),
})
}
pub fn as_bytes(&self) -> &[u8] {
&self.bytes[..self.len()]
}
pub fn encoding(&self) -> Encoding {
self.encoding
}
pub fn len(&self) -> usize {
usize::from(self.length)
}
pub fn b64_decode(input: &str) -> Result<Self> {
Self::decode(input, Encoding::B64)
}
pub fn b64_encode<'a>(&self, out: &'a mut [u8]) -> Result<&'a str> {
self.encode(out, Encoding::B64)
}
pub fn decode(input: &str, encoding: Encoding) -> Result<Self> {
let mut bytes = [0u8; Self::MAX_LENGTH];
encoding
.decode(input, &mut bytes)
.map_err(Into::into)
.and_then(|decoded| Self::new_with_encoding(decoded, encoding))
}
pub fn encode<'a>(&self, out: &'a mut [u8], encoding: Encoding) -> Result<&'a str> {
Ok(encoding.encode(self.as_ref(), out)?)
}
pub fn b64_len(&self) -> usize {
Encoding::B64.encoded_len(self.as_ref())
}
}
/// Exposes the bytes in use, exactly as `Output::as_bytes` does.
impl AsRef<[u8]> for Output {
    fn as_ref(&self) -> &[u8] {
        let used = usize::from(self.length);
        &self.bytes[..used]
    }
}
impl ConstantTimeEq for Output {
fn ct_eq(&self, other: &Self) -> Choice {
self.as_ref().ct_eq(other.as_ref())
}
}
impl FromStr for Output {
type Err = Error;
fn from_str(s: &str) -> Result<Self> {
Self::b64_decode(s)
}
}
/// Equality is defined by the `ConstantTimeEq` implementation, so only the
/// bytes in use are compared and the recorded encoding is ignored.
impl PartialEq for Output {
    fn eq(&self, other: &Self) -> bool {
        bool::from(self.ct_eq(other))
    }
}
impl TryFrom<&[u8]> for Output {
type Error = Error;
fn try_from(input: &[u8]) -> Result<Output> {
Self::new(input)
}
}
/// Writes the output as text, encoded with the encoding recorded on the value.
impl fmt::Display for Output {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Stack scratch space for the encoded text; no heap allocation needed.
        let mut scratch = [0u8; Self::B64_MAX_LENGTH];

        match self.encode(&mut scratch, self.encoding) {
            Ok(text) => f.write_str(text),
            // `fmt::Error` carries no payload, so the encoder's error detail
            // is dropped here.
            Err(_) => Err(fmt::Error),
        }
    }
}
/// Debug form wraps the `Display` rendering as `Output("...")`.
impl fmt::Debug for Output {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("Output(\"{}\")", self))
    }
}
#[cfg(test)]
mod tests {
    use super::{Error, Output};

    /// Builds a 32-byte output in which every byte equals `fill`.
    fn filled(fill: u8) -> Output {
        Output::new(&[fill; 32]).unwrap()
    }

    /// An input of exactly `MIN_LENGTH` (10) bytes is accepted and round-trips.
    #[test]
    fn new_with_valid_min_length_input() {
        let input = [10u8; 10];
        let output = Output::new(&input).unwrap();
        let stored: &[u8] = output.as_ref();
        assert_eq!(stored, &input[..]);
    }

    /// An input of exactly `MAX_LENGTH` (64) bytes is accepted and round-trips.
    #[test]
    fn new_with_valid_max_length_input() {
        let input = [64u8; 64];
        let output = Output::new(&input).unwrap();
        let stored: &[u8] = output.as_ref();
        assert_eq!(stored, &input[..]);
    }

    /// One byte below the minimum is rejected with `OutputTooShort`.
    #[test]
    fn reject_new_too_short() {
        let input = [9u8; 9];
        assert_eq!(Output::new(&input).err(), Some(Error::OutputTooShort));
    }

    /// One byte above the maximum is rejected with `OutputTooLong`.
    #[test]
    fn reject_new_too_long() {
        let input = [65u8; 65];
        assert_eq!(Output::new(&input).err(), Some(Error::OutputTooLong));
    }

    /// Outputs with identical bytes compare equal.
    #[test]
    fn partialeq_true() {
        let (a, b) = (filled(1), filled(1));
        assert_eq!(a, b);
    }

    /// Outputs of equal length but different bytes compare unequal.
    #[test]
    fn partialeq_false() {
        let (a, b) = (filled(1), filled(2));
        assert_ne!(a, b);
    }
}