use std::fmt::{self, Debug};
use crate::modhex::{self, ModHex};
use aes::{
Aes128,
cipher::{self, BlockDecrypt, KeyInit},
};
/// Largest usage-counter value (0x7fff, the maximum 15-bit value); a counter at this value is
/// reported as [`ValidationError::UsageCounterAtMaximum`].
const MAX_USAGE_COUNTER: u16 = (1 << 15) - 1;
/// Errors returned while parsing or decrypting an [`Otp`].
#[derive(Debug, PartialEq)]
pub enum Error {
    /// The OTP string is not in the expected format (for example, it is not 44 bytes long).
    InvalidOtpFormat,
    /// The OTP string could not be decoded as ModHex; wraps the decoder's error.
    Modhex(modhex::Error),
    /// The decrypted block did not pass its CRC check.
    Decryption,
    /// The key is not exactly 16 bytes long (the AES-128 key size).
    InvalidKey,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidOtpFormat => "Invalid OTP format",
            Self::Decryption => "Decryption failed",
            Self::InvalidKey => "Invalid decryption key",
            // The wrapped decoder error is rendered with `Debug`, so it needs its own
            // formatted write instead of the plain-string path below.
            Self::Modhex(inner) => return write!(f, "Modhex error: {inner:?}"),
        };
        f.write_str(message)
    }
}

/// Lets `Error` be used as a `dyn std::error::Error`; no `source()` is provided.
impl std::error::Error for Error {}
/// Unencrypted public identifier carried in the first 6 decoded bytes of an OTP.
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct PublicId {
    /// The 6 raw bytes of the identifier.
    pub raw_bytes: [u8; 6],
}
/// Private identifier stored in the first 6 bytes of the decrypted OTP block.
#[repr(transparent)]
#[derive(Copy, Clone, Eq, PartialEq, Hash)]
pub struct PrivateId {
    /// The 6 raw bytes of the identifier.
    pub raw_bytes: [u8; 6],
}
impl fmt::Display for PublicId {
    /// Renders the public ID in ModHex, the same encoding the OTP string is parsed from.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = ModHex::from(self.raw_bytes.as_slice());
        write!(f, "{encoded}")
    }
}
impl Debug for PublicId {
    /// Wraps the ModHex `Display` form: `PublicId (<modhex>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "PublicId ({})", self)
    }
}
impl fmt::Display for PrivateId {
    /// Renders the private ID as lowercase hexadecimal, two digits per byte.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let bytes: &[u8] = self.raw_bytes.as_slice();
        format_hex(f, bytes)
    }
}
impl Debug for PrivateId {
    /// Wraps the hexadecimal `Display` form: `PrivateId (<hex>)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "PrivateId ({})", self)
    }
}
/// An OTP as parsed from its ModHex string, before decryption.
#[derive(Clone, Eq, PartialEq)]
pub struct Otp {
    /// Public identifier taken from the first 6 decoded bytes; not encrypted.
    pub id: PublicId,
    /// The remaining 16 decoded bytes, still AES-128 encrypted until `decrypt` is called.
    private: [u8; 16],
}
/// An OTP whose encrypted block has been decrypted and parsed.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DecryptedOtp {
    /// Public identifier, copied unchanged from the source [`Otp`].
    pub id: PublicId,
    /// Fields parsed out of the decrypted 16-byte block.
    pub private: DecryptedPrivateData,
}
/// Fields parsed from a decrypted 16-byte OTP block.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct DecryptedPrivateData {
    /// Private identifier from bytes 0..6 of the decrypted block.
    pub id: PrivateId,
    /// Little-endian value from bytes 6..8 of the decrypted block.
    pub usage_counter: u16,
    /// Byte 11 of the decrypted block.
    pub session_counter: u8,
    /// 24-bit little-endian value from bytes 8..11; the top byte of the `u32` is always 0.
    pub timestamp: u32,
    /// Bytes 12..14 of the decrypted block.
    pub random: [u8; 2],
}
/// Reasons an OTP is rejected when checked against a previously accepted OTP.
#[derive(Debug, Eq, PartialEq)]
pub enum ValidationError {
    /// The two OTPs carry different public IDs.
    PublicIdMismatch,
    /// Decrypting either OTP failed (this includes an invalid key).
    Decryption,
    /// The decrypted private IDs differ.
    PrivateIdMismatch,
    /// The usage counter is lower than the previous one, or is zero.
    UsageCounterDecreased,
    /// The usage counter has reached `MAX_USAGE_COUNTER`.
    UsageCounterAtMaximum,
    /// The usage counter is unchanged but the session counter did not increase.
    InvalidSessionCounter,
}

impl fmt::Display for ValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::PublicIdMismatch => "Public ID does not match",
            Self::Decryption => "Decryption failed",
            Self::PrivateIdMismatch => "Private ID does not match previous OTP",
            Self::UsageCounterDecreased => "Usage counter decreased - possible replay attack",
            Self::UsageCounterAtMaximum => "Usage counter is at maximum value",
            Self::InvalidSessionCounter => "Session counter is not valid",
        };
        f.write_str(message)
    }
}
impl DecryptedPrivateData {
    /// Returns `usage_counter` with bit 15 cleared (`MAX_USAGE_COUNTER` doubles as the mask).
    ///
    /// Yubico's reference C library does the same masking (`yubikey_counter()`) and treats
    /// bit 15 as a caps-lock flag. Without the mask, toggling that flag between two OTPs
    /// would make the counter appear to jump or roll back, and would hide a counter
    /// that has reached `MAX_USAGE_COUNTER`.
    fn masked_usage_counter(&self) -> u16 {
        self.usage_counter & MAX_USAGE_COUNTER
    }

    /// Checks that `self` is an acceptable successor of the previously seen `previous`.
    ///
    /// # Errors
    /// - Anything reported by [`Self::validate_basic`] against `previous.id`.
    /// - [`ValidationError::UsageCounterDecreased`] if the usage counter is zero or lower
    ///   than the previous one.
    /// - [`ValidationError::InvalidSessionCounter`] if the usage counter is unchanged and the
    ///   session counter did not strictly increase.
    fn validate_with_previous(&self, previous: &Self) -> Result<(), ValidationError> {
        self.validate_basic(previous.id)?;

        let current = self.masked_usage_counter();
        let prior = previous.masked_usage_counter();

        // A zero counter is rejected outright, with the same error as a rollback.
        if current == 0 || current < prior {
            return Err(ValidationError::UsageCounterDecreased);
        }
        // Equal usage counters: the session counter must strictly increase.
        if current == prior && self.session_counter <= previous.session_counter {
            return Err(ValidationError::InvalidSessionCounter);
        }
        Ok(())
    }

    /// Checks the properties of `self` that do not depend on the previous counters.
    ///
    /// # Errors
    /// - [`ValidationError::PrivateIdMismatch`] if `self.id` differs from `private_id`.
    /// - [`ValidationError::UsageCounterAtMaximum`] if the masked usage counter equals
    ///   `MAX_USAGE_COUNTER`.
    fn validate_basic(&self, private_id: PrivateId) -> Result<(), ValidationError> {
        if self.id != private_id {
            return Err(ValidationError::PrivateIdMismatch);
        }
        if self.masked_usage_counter() == MAX_USAGE_COUNTER {
            return Err(ValidationError::UsageCounterAtMaximum);
        }
        Ok(())
    }
}
impl Otp {
pub fn from_modhex(otp: &str) -> Result<Self, Error> {
if otp.len() != 44 {
return Err(Error::InvalidOtpFormat);
}
let otp_raw = ModHex::try_from(otp).map_err(Error::Modhex)?;
unsafe {
Ok(Otp {
id: PublicId {
raw_bytes: otp_raw.raw_bytes()[..6].try_into().unwrap_unchecked(),
},
private: otp_raw.raw_bytes()[6..].try_into().unwrap_unchecked(),
})
}
}
pub fn decrypt(&self, key: &[u8]) -> Result<DecryptedOtp, Error> {
Ok(DecryptedOtp {
id: self.id,
private: DecryptedPrivateData::decrypt(
&self.private,
&key.try_into().map_err(|_| Error::InvalidKey)?,
)
.or(Err(Error::Decryption))?,
})
}
pub fn validate(&self, previous: &Self, key: &[u8]) -> Result<(), ValidationError> {
if self.id != previous.id {
return Err(ValidationError::PublicIdMismatch);
}
let decrypted_self = self.decrypt(key).or(Err(ValidationError::Decryption))?;
let decrypted_previous = previous.decrypt(key).or(Err(ValidationError::Decryption))?;
decrypted_self.validate(&decrypted_previous)
}
}
impl DecryptedOtp {
    /// Checks that `self` is an acceptable successor of `previous`.
    ///
    /// The public IDs must match. The decrypted private data is then compared: private ID,
    /// usage counter and session counter.
    ///
    /// # Errors
    /// [`ValidationError::PublicIdMismatch`] if the public IDs differ; otherwise whatever the
    /// private-data comparison reports.
    pub fn validate(&self, previous: &Self) -> Result<(), ValidationError> {
        if self.id == previous.id {
            self.private.validate_with_previous(&previous.private)
        } else {
            Err(ValidationError::PublicIdMismatch)
        }
    }
}
impl DecryptedPrivateData {
    /// Decrypts a 16-byte block with AES-128 under `secret_key` and parses its fields.
    ///
    /// Returns `Err(())` when the CRC-16 over the full decrypted block does not leave the
    /// expected residue `0xf0b8`, e.g. because the key is wrong or the data is corrupted.
    fn decrypt(private_data: &[u8; 16], secret_key: &[u8; 16]) -> Result<Self, ()> {
        // Decrypt into a fixed-size array; no heap allocation is needed for one block.
        let decrypted: [u8; 16] = {
            use cipher::generic_array::GenericArray;
            let cipher = Aes128::new(GenericArray::from_slice(secret_key));
            let mut block = GenericArray::clone_from_slice(private_data);
            cipher.decrypt_block(&mut block);
            let mut out = [0u8; 16];
            out.copy_from_slice(&block);
            out
        };
        if Self::calculate_crc16(&decrypted) != 0xf0b8 {
            return Err(());
        }
        // Layout: private id (6) | usage counter (2, LE) | timestamp (3, LE) |
        //         session counter (1) | random (2) | trailing 2 bytes (only covered by the
        //         CRC check above). Destructuring the array is infallible, so no `unsafe`
        //         or bounds checks are required.
        let [id0, id1, id2, id3, id4, id5, ctr_lo, ctr_hi, ts0, ts1, ts2, session_counter, rnd0, rnd1, _, _] =
            decrypted;
        Ok(DecryptedPrivateData {
            id: PrivateId {
                raw_bytes: [id0, id1, id2, id3, id4, id5],
            },
            usage_counter: u16::from_le_bytes([ctr_lo, ctr_hi]),
            timestamp: u32::from_le_bytes([ts0, ts1, ts2, 0]),
            session_counter,
            random: [rnd0, rnd1],
        })
    }

    /// CRC-16 over `data`: initial value `0xffff`, reflected polynomial `0x8408`, LSB-first,
    /// no final XOR.
    fn calculate_crc16(data: &[u8]) -> u16 {
        let mut crc: u16 = 0xffff;
        for &byte in data {
            crc ^= u16::from(byte);
            for _ in 0..8 {
                if crc & 1 != 0 {
                    crc = (crc >> 1) ^ 0x8408;
                } else {
                    crc >>= 1;
                }
            }
        }
        crc
    }
}
/// Writes `value` to `f` as lowercase hexadecimal, two digits per byte, with no separators.
///
/// The caller's formatter options (width, fill, alignment) are not applied to the output.
pub(crate) fn format_hex(f: &mut std::fmt::Formatter, value: &[u8]) -> std::fmt::Result {
    value
        .iter()
        .try_for_each(|byte| write!(f, "{byte:02x}"))
}