use crate::error::{PqcError, Result};
use hmac::{Hmac, Mac};
use sha2::Sha256;
/// HMAC instantiated with SHA-256; used by [`integrity_check`] to tag the code region.
type HmacSha256 = Hmac<Sha256>;
/// Verifies that the `code_len` bytes starting at `code_start` carry the expected HMAC.
///
/// Computes HMAC-SHA256 over the memory region using a fixed, built-in key and compares
/// the result with `expected_hmac` via `Mac::verify_slice`. A tag whose length differs
/// from the SHA-256 output length is rejected as well.
///
/// The key is a compile-time constant embedded in the binary, so it provides no secrecy.
/// This detects modification or corruption of the code region. It does not detect an
/// attacker who can also recompute the tag.
///
/// # Errors
///
/// Returns [`PqcError::IntegrityCheckFailure`] in these cases:
/// - `code_start` is null.
/// - `code_len` exceeds `isize::MAX`.
/// - The computed tag does not match `expected_hmac`.
///
/// # Safety
///
/// The caller must guarantee that `code_start` points to at least `code_len` bytes of
/// readable memory. That memory must remain valid and unmodified for the duration of
/// the call.
// `unsafe_code` is allowed here because the function builds a slice from a raw pointer.
#[allow(unsafe_code)]
pub unsafe fn integrity_check(
    code_start: *const u8,
    code_len: usize,
    expected_hmac: &[u8],
) -> Result<()> {
    // The expected tag was computed with exactly this key; changing it invalidates every tag.
    const INTEGRITY_KEY: &[u8] = b"FIPS_140_3_INTEGRITY_KEY";

    // `slice::from_raw_parts` requires a non-null pointer (even for a zero length) and a
    // total size of at most `isize::MAX`. Reject violations instead of invoking UB.
    if code_start.is_null() || code_len > isize::MAX as usize {
        return Err(PqcError::IntegrityCheckFailure);
    }

    // SAFETY: `code_start` is non-null and `code_len <= isize::MAX` (checked above). `u8`
    // has alignment 1. Readability and immutability of the region are guaranteed by the
    // caller per the `# Safety` contract.
    let code_slice = unsafe { core::slice::from_raw_parts(code_start, code_len) };

    // HMAC accepts keys of any length (RFC 2104), so this should not fail. The error is
    // mapped rather than unwrapped to keep the function free of panic paths.
    let mut mac =
        HmacSha256::new_from_slice(INTEGRITY_KEY).map_err(|_| PqcError::IntegrityCheckFailure)?;
    mac.update(code_slice);

    mac.verify_slice(expected_hmac)
        .map_err(|_| PqcError::IntegrityCheckFailure)
}
/// Locates the executable code region of the current module.
///
/// Returns `(start, len)`: the start address and byte length of the region that is meant
/// to be hashed by [`integrity_check`].
///
/// # Errors
///
/// - On Windows and Linux, propagates any error from the platform-specific lookup.
/// - On all other targets, returns [`PqcError::PlatformError`].
// `unsafe_code` is allowed because the platform-specific lookups are `unsafe fn`s.
#[allow(unsafe_code)]
pub fn get_code_segment() -> Result<(*const u8, usize)> {
    #[cfg(not(any(target_os = "windows", target_os = "linux")))]
    {
        Err(PqcError::PlatformError)
    }
    #[cfg(target_os = "linux")]
    {
        // SAFETY: the Linux lookup only takes the addresses of linker-defined symbols.
        unsafe { get_code_segment_linux() }
    }
    #[cfg(target_os = "windows")]
    {
        // SAFETY: the Windows lookup only reads this module's own PE headers, starting
        // at the linker-provided `__ImageBase` symbol.
        unsafe { get_code_segment_windows() }
    }
}
/// Returns the code range `[BaseOfCode, BaseOfCode + SizeOfCode)` of this module's
/// in-memory PE image.
///
/// # Errors
///
/// Returns [`PqcError::IntegrityCheckFailure`] in these cases:
/// - The DOS or NT signature is wrong.
/// - `e_lfanew` is negative.
/// - The optional-header magic is neither PE32 nor PE32+.
/// - The code range does not fit inside `SizeOfImage`.
///
/// # Safety
///
/// `__ImageBase` must resolve to the start of a loaded PE image whose headers are readable.
/// This holds when the linker provides the symbol for the current module.
#[cfg(target_os = "windows")]
#[allow(unsafe_code)]
unsafe fn get_code_segment_windows() -> Result<(*const u8, usize)> {
    extern "C" {
        // Linker-provided symbol placed at the first byte of this module's PE image.
        static __ImageBase: u8;
    }

    // PE/COFF constants and field offsets (Microsoft PE Format specification).
    const IMAGE_DOS_SIGNATURE: u16 = 0x5A4D; // "MZ"
    const IMAGE_NT_SIGNATURE: u32 = 0x0000_4550; // "PE\0\0"
    const PE32_MAGIC: u16 = 0x10B;
    const PE32_PLUS_MAGIC: u16 = 0x20B;
    const E_LFANEW_OFFSET: usize = 0x3C;
    // PE signature (4 bytes) + IMAGE_FILE_HEADER (20 bytes).
    const OPTIONAL_HEADER_OFFSET: usize = 24;
    // Offsets inside the optional header. They are identical for PE32 and PE32+:
    // - PE32's extra `BaseOfData` field comes after `BaseOfCode`.
    // - PE32+'s 8-byte `ImageBase` takes the place of `BaseOfData` + a 4-byte `ImageBase`.
    const SIZE_OF_CODE_OFFSET: usize = 4;
    const BASE_OF_CODE_OFFSET: usize = 20;
    const SIZE_OF_IMAGE_OFFSET: usize = 56;

    // Reads a `T` at `base + offset` without assuming alignment.
    // Safety: `offset .. offset + size_of::<T>()` must be readable memory belonging to the
    // same allocation as `base`.
    unsafe fn read_at<T>(base: *const u8, offset: usize) -> T {
        core::ptr::read_unaligned(base.add(offset).cast::<T>())
    }

    // `addr_of!` yields the raw address without materialising a reference to the symbol.
    let base = core::ptr::addr_of!(__ImageBase);

    if read_at::<u16>(base, 0) != IMAGE_DOS_SIGNATURE {
        return Err(PqcError::IntegrityCheckFailure);
    }

    // `e_lfanew` is a signed 32-bit offset. A negative value would move the pointer before
    // the image base.
    let e_lfanew = read_at::<i32>(base, E_LFANEW_OFFSET);
    if e_lfanew < 0 {
        return Err(PqcError::IntegrityCheckFailure);
    }
    let nt_headers = e_lfanew as usize;
    if read_at::<u32>(base, nt_headers) != IMAGE_NT_SIGNATURE {
        return Err(PqcError::IntegrityCheckFailure);
    }

    let optional_header = nt_headers + OPTIONAL_HEADER_OFFSET;
    let magic = read_at::<u16>(base, optional_header);
    if !matches!(magic, PE32_MAGIC | PE32_PLUS_MAGIC) {
        return Err(PqcError::IntegrityCheckFailure);
    }

    // `u32 -> usize` is lossless on every Windows target (usize is at least 32 bits).
    let size_of_code = read_at::<u32>(base, optional_header + SIZE_OF_CODE_OFFSET) as usize;
    let base_of_code = read_at::<u32>(base, optional_header + BASE_OF_CODE_OFFSET) as usize;
    let size_of_image = read_at::<u32>(base, optional_header + SIZE_OF_IMAGE_OFFSET) as usize;

    // Never hand out a range that extends past the mapped image. Callers turn this range
    // into a slice.
    let code_end = base_of_code
        .checked_add(size_of_code)
        .ok_or(PqcError::IntegrityCheckFailure)?;
    if code_end > size_of_image {
        return Err(PqcError::IntegrityCheckFailure);
    }

    // NOTE(review): `SizeOfCode` is the total size of all code sections, while
    // `BaseOfCode` is the RVA of the first one. This range therefore assumes the code
    // sections are contiguous.
    // NOTE(review): when the image is rebased (ASLR), base relocations may patch bytes
    // inside the code. The hashed bytes could then depend on the load address. Confirm
    // the expected HMAC accounts for this.
    Ok((base.add(base_of_code), size_of_code))
}
#[cfg(target_os = "linux")]
#[allow(unsafe_code)]
unsafe fn get_code_segment_linux() -> Result<(*const u8, usize)> {
extern "C" {
static __executable_start: u8;
static _etext: u8;
}
let start = &__executable_start as *const u8;
let end = &_etext as *const u8;
let len = end as usize - start as usize;
Ok((start, len))
}