use crate::{
pos_finder::PosFinder, LocalizedKey, Md5, SecurityError, SecurityParams, SecurityResult, Sha1,
AUTH_PARAMS_LEN, AUTH_PARAMS_PLACEHOLDER,
};
use hmac::{Hmac, Mac, NewMac};
use md5::digest::{BlockInput, FixedOutput, Reset, Update};
use std::ops::Range;
/// Maximum tolerated difference between the engine time carried in an incoming
/// message and the local engine time; a larger gap makes the timeliness checks
/// report `SecurityError::NotInTimeWindow`.
/// NOTE(review): RFC 3414 specifies this window as 150 seconds, i.e. it assumes engine
/// time is counted in seconds — confirm against `SecurityParams::engine_time`.
const TIME_WINDOW: u32 = 150;
/// Bundle of `digest` traits used as the bound on the hash function `D` that
/// `AuthKey` passes to `Hmac<D>` when computing authentication parameters.
pub trait Digest: Update + BlockInput + FixedOutput + Reset + Default + Clone {}
// Hash functions this module supports for message authentication.
impl Digest for Md5 {}
impl Digest for Sha1 {}
/// Key used to compute (`auth_out_msg`) and verify (`auth_in_msg`) the HMAC-based
/// authentication parameters of a message, for the hash function `D`.
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
pub struct AuthKey<'a, D> {
    /// Localized key whose bytes are used as the HMAC key.
    localized_key: LocalizedKey<'a, D>,
}
impl<'a, D: 'a> AuthKey<'a, D> {
pub fn new(localized_key: LocalizedKey<'a, D>) -> Self {
Self { localized_key }
}
fn params_ranges(msg: &[u8]) -> SecurityResult<(Range<usize>, Range<usize>)> {
let mut pos_finder = PosFinder::new(msg);
pos_finder.step_into_seq()?; pos_finder.skip_int()?; pos_finder.skip_seq()?; let security_params_range = pos_finder.step_into_octet_str()?;
let auth_params_range = Self::find_auth_params_range(&mut pos_finder)
.map_err(|_| SecurityError::MalformedSecurityParams)?;
let auth_params_len = auth_params_range.end - auth_params_range.start;
if auth_params_len != AUTH_PARAMS_LEN {
return Err(SecurityError::WrongAuthParams);
}
Ok((security_params_range, auth_params_range))
}
fn find_auth_params_range(pos_finder: &mut PosFinder) -> SecurityResult<Range<usize>> {
pos_finder.step_into_seq()?; pos_finder.skip_octet_str()?; pos_finder.skip_int()?; pos_finder.skip_int()?; pos_finder.skip_octet_str()?;
pos_finder.step_into_octet_str() }
fn validate_timeliness(
security_params: &SecurityParams,
local_engine_id: &[u8],
local_engine_boots: u32,
local_engine_time: u32,
) -> SecurityResult<()> {
if local_engine_boots >= SecurityParams::ENGINE_BOOTS_MAX {
return Err(SecurityError::NotInTimeWindow);
}
let is_authoritative_engine = security_params.engine_id() == local_engine_id;
if is_authoritative_engine {
Self::validate_timeliness_for_authoritative(
security_params,
local_engine_boots,
local_engine_time,
)?;
} else {
Self::validate_timeliness_for_non_authoritative(
security_params,
local_engine_boots,
local_engine_time,
)?;
}
Ok(())
}
fn validate_timeliness_for_authoritative(
security_params: &SecurityParams,
local_engine_boots: u32,
local_engine_time: u32,
) -> SecurityResult<()> {
if security_params.engine_boots() != local_engine_boots {
return Err(SecurityError::NotInTimeWindow);
}
let time_diff = Self::diff(security_params.engine_time(), local_engine_time);
if time_diff > TIME_WINDOW {
return Err(SecurityError::NotInTimeWindow);
}
Ok(())
}
fn validate_timeliness_for_non_authoritative(
security_params: &SecurityParams,
local_engine_boots: u32,
local_engine_time: u32,
) -> SecurityResult<()> {
if security_params.engine_boots() < local_engine_boots {
return Err(SecurityError::NotInTimeWindow);
}
if security_params.engine_boots() == local_engine_boots
&& Self::less_than_over(
TIME_WINDOW,
security_params.engine_time(),
local_engine_time,
)
{
return Err(SecurityError::NotInTimeWindow);
}
Ok(())
}
fn diff(lhs: u32, rhs: u32) -> u32 {
if lhs > rhs {
lhs - rhs
} else {
rhs - lhs
}
}
fn less_than_over(amount: u32, lhs: u32, rhs: u32) -> bool {
if lhs > rhs {
return false;
}
rhs - lhs > amount
}
}
impl<'a, D: 'a> AuthKey<'a, D>
where
    D: Digest,
{
    /// Verifies the authentication parameters of an incoming message, then checks its
    /// timeliness against the local engine state.
    ///
    /// While the HMAC is computed, the authentication-parameters field inside `msg` is
    /// temporarily replaced with `AUTH_PARAMS_PLACEHOLDER`. It is always restored to
    /// the received bytes before this method returns, whether verification succeeds
    /// or fails.
    ///
    /// # Errors
    /// - Errors from locating the parameter ranges (malformed message / wrong
    ///   authentication-parameters length).
    /// - `SecurityError::WrongAuthParams` if the received MAC does not match.
    /// - Errors from `SecurityParams::decode`.
    /// - `SecurityError::NotInTimeWindow` if the message fails the timeliness checks.
    pub fn auth_in_msg(
        &self,
        msg: &mut [u8],
        local_engine_id: &[u8],
        local_engine_boots: u32,
        local_engine_time: u32,
    ) -> SecurityResult<()> {
        let (security_params_range, auth_params_range) = Self::params_ranges(msg)?;

        let mut received_auth_params: [u8; AUTH_PARAMS_LEN] = [0x0; AUTH_PARAMS_LEN];
        received_auth_params.copy_from_slice(&msg[auth_params_range.clone()]);

        msg[auth_params_range.clone()].copy_from_slice(&AUTH_PARAMS_PLACEHOLDER);
        let expected_auth_params = self.hmac(msg);
        // Put the received bytes back before any early return, so a failed
        // verification does not leave the caller's buffer modified.
        msg[auth_params_range].copy_from_slice(&received_auth_params);

        if !Self::constant_time_eq(&received_auth_params, &expected_auth_params) {
            return Err(SecurityError::WrongAuthParams);
        }

        let security_params = SecurityParams::decode(&msg[security_params_range])?;
        Self::validate_timeliness(
            &security_params,
            local_engine_id,
            local_engine_boots,
            local_engine_time,
        )
    }

    /// Computes the authentication parameters of an outgoing message and writes them
    /// into `msg` in place.
    ///
    /// The field is first set to `AUTH_PARAMS_PLACEHOLDER`, the same value
    /// `auth_in_msg` substitutes during verification. The MAC therefore never depends
    /// on whatever bytes the caller left in that field.
    ///
    /// # Errors
    /// Errors from locating the parameter ranges (malformed message / wrong
    /// authentication-parameters length).
    pub fn auth_out_msg(&self, msg: &mut [u8]) -> SecurityResult<()> {
        let (_, auth_params_range) = Self::params_ranges(msg)?;
        msg[auth_params_range.clone()].copy_from_slice(&AUTH_PARAMS_PLACEHOLDER);
        let auth_params = self.hmac(msg);
        msg[auth_params_range].copy_from_slice(&auth_params);
        Ok(())
    }

    /// Computes `HMAC<D>` of `msg`, keyed with the localized key, and returns the first
    /// `AUTH_PARAMS_LEN` bytes of the MAC.
    ///
    /// # Panics
    /// Panics if `D`'s output is shorter than `AUTH_PARAMS_LEN` bytes.
    fn hmac(&self, msg: &[u8]) -> Vec<u8> {
        let mut mac = Hmac::<D>::new_varkey(self.localized_key.bytes())
            .expect("HMAC accepts keys of any length");
        mac.update(msg);
        let bytes = mac.finalize().into_bytes();
        bytes[..AUTH_PARAMS_LEN].to_vec()
    }

    /// Compares two byte slices for equality without stopping at the first mismatch,
    /// so the comparison time does not reveal how many leading MAC bytes were correct.
    fn constant_time_eq(lhs: &[u8], rhs: &[u8]) -> bool {
        lhs.len() == rhs.len()
            && lhs
                .iter()
                .zip(rhs.iter())
                .fold(0u8, |acc, (a, b)| acc | (a ^ b))
                == 0
    }
}