extern crate alloc;
use alloc::{vec, vec::Vec};
use aes::{
Aes128,
cipher::{BlockDecrypt, KeyInit, generic_array::GenericArray},
};
use hmac::{Hmac, Mac};
use sha1::Sha1;
// ---- EAPOL framing ---------------------------------------------------------

/// EAPOL protocol version written into outgoing frames.
const EAPOL_VERSION: u8 = 0x01;
/// EAPOL packet type identifying an EAPOL-Key frame.
const EAPOL_TYPE_KEY: u8 = 0x03;
/// Key descriptor type for RSN EAPOL-Key frames.
const KEY_DESC_TYPE_RSN: u8 = 0x02;

// ---- Key Information field bits ----------------------------------------------

/// Key descriptor version 2 (HMAC-SHA1 MIC, AES key wrap of key data).
const KEY_INFO_TYPE_HMAC_SHA1_AES: u16 = 0x0002;
const KEY_INFO_PAIRWISE: u16 = 1 << 3;
const KEY_INFO_INSTALL: u16 = 1 << 6;
const KEY_INFO_ACK: u16 = 1 << 7;
const KEY_INFO_MIC: u16 = 1 << 8;
const KEY_INFO_SECURE: u16 = 1 << 9;
const KEY_INFO_ENC_KEY_DATA: u16 = 1 << 12;

// ---- Field sizes ---------------------------------------------------------------

const REPLAY_COUNTER_LEN: usize = 8;
const NONCE_LEN: usize = 32;
const MIC_LEN: usize = 16;
const SHA1_DIGEST_SIZE: usize = 20;

// ---- Frame layout ------------------------------------------------------------

/// EAPOL header: version (1) + packet type (1) + body length (2).
const EAPOL_HDR_LEN: usize = 4;
/// Absolute offset of the Key MIC: EAPOL header, then descriptor type (1),
/// key info (2), key length (2), replay counter, nonce, IV (16), RSC (8),
/// reserved (8).
const MIC_OFFSET: usize =
    EAPOL_HDR_LEN + 1 + 2 + 2 + REPLAY_COUNTER_LEN + NONCE_LEN + 16 + 8 + 8;
/// Fixed EAPOL-Key descriptor length: everything before the MIC, the MIC,
/// and the 2-byte key data length.
const EAPOL_KEY_HDR_LEN: usize = (MIC_OFFSET - EAPOL_HDR_LEN) + MIC_LEN + 2;

// ---- Key hierarchy sizes -----------------------------------------------------

const PMK_LEN: usize = 32;
const KCK_LEN: usize = 16;
const KEK_LEN: usize = 16;
const TK_LEN: usize = 16;
/// PTK is the concatenation KCK || KEK || TK.
const PTK_LEN: usize = KCK_LEN + KEK_LEN + TK_LEN;
/// HMAC keyed with SHA-1; the primitive behind `hmac_sha1`, and through it the
/// PBKDF2 PMK derivation, the PTK PRF, and the EAPOL-Key MIC in this file.
type HmacSha1 = Hmac<Sha1>;
/// Progress of the supplicant through the 4-way handshake.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HandshakeState {
    /// Freshly constructed; no message 1 has been answered yet.
    Idle,
    /// Message 2 was produced in response to message 1; waiting for message 3.
    M2Sent,
    /// Message 3 was verified and message 4 produced.
    Completed,
}
/// Outcome of feeding one EAPOL frame to `Wpa2Handshake::process_eapol`.
///
/// Derives `Debug` like the other public types of this module; the payload
/// types already implement it.
#[derive(Debug)]
pub enum HandshakeAction {
    /// Transmit this fully built, MIC-protected EAPOL-Key message 2.
    SendM2(Vec<u8>),
    /// The handshake finished; carries message 4 and the negotiated keys.
    Completed(HandshakeResult),
}
/// Everything produced by a successful 4-way handshake.
///
/// `Debug` is implemented by hand so that `{:?}` logging never prints the
/// temporal or group key bytes.
#[derive(Clone)]
pub struct HandshakeResult {
    /// Fully built, MIC-protected EAPOL-Key message 4 to transmit.
    pub m4_frame: Vec<u8>,
    /// Pairwise temporal key (the final `TK_LEN` bytes of the PTK).
    pub tk: [u8; TK_LEN],
    /// Group temporal key unwrapped from the GTK KDE of message 3.
    pub gtk: Vec<u8>,
    /// Key index of `gtk` (the low two bits of the KDE's KeyID octet).
    pub gtk_key_idx: u8,
}

impl core::fmt::Debug for HandshakeResult {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Key material is redacted; only non-secret metadata is shown.
        f.debug_struct("HandshakeResult")
            .field("m4_frame", &self.m4_frame)
            .field("tk", &"<redacted>")
            .field("gtk", &"<redacted>")
            .field("gtk_len", &self.gtk.len())
            .field("gtk_key_idx", &self.gtk_key_idx)
            .finish()
    }
}
/// Errors produced while parsing or processing 4-way handshake frames.
#[derive(Debug)]
pub enum WpaError {
    /// The buffer cannot hold the EAPOL-Key header or the advertised key data.
    FrameTooShort,
    /// The EAPOL packet type is not EAPOL-Key.
    InvalidEapolType,
    /// The key descriptor type is not RSN.
    InvalidDescriptorType,
    /// The frame is not a message this handshake accepts (e.g. unrecognised
    /// Key Information flags).
    UnexpectedMessage,
    /// The message arrived in a handshake state where it cannot be processed.
    InvalidState,
    /// The replay counter is older than the one already recorded.
    ReplayCounterMismatch,
    /// The EAPOL-Key MIC did not verify.
    MicMismatch,
    /// The key data is missing or malformed.
    InvalidKeyData,
    /// AES key unwrap of the key data failed.
    AesUnwrapFailed,
    /// The decrypted key data contains no GTK KDE.
    GtkNotFound,
}

impl core::fmt::Display for WpaError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            WpaError::FrameTooShort => write!(f, "frame too short"),
            WpaError::InvalidEapolType => write!(f, "not an EAPOL-Key frame"),
            WpaError::InvalidDescriptorType => write!(f, "invalid key descriptor type"),
            WpaError::UnexpectedMessage => write!(f, "unexpected message"),
            WpaError::InvalidState => write!(f, "invalid handshake state"),
            WpaError::ReplayCounterMismatch => write!(f, "replay counter mismatch"),
            WpaError::MicMismatch => write!(f, "MIC verification failed"),
            WpaError::InvalidKeyData => write!(f, "invalid key data"),
            WpaError::AesUnwrapFailed => write!(f, "AES key unwrap failed"),
            WpaError::GtkNotFound => write!(f, "GTK not found in key data"),
        }
    }
}

/// Lets `WpaError` travel through `?` into `Box<dyn Error>` / error-reporting
/// code. No variant wraps an underlying cause, so `source()` stays `None`.
impl core::error::Error for WpaError {}
/// Decoded fields of an RSN EAPOL-Key descriptor plus its key data, as produced
/// by `parse_eapol_key_header`. Offsets below are relative to the first byte
/// after the 4-byte EAPOL header.
#[derive(Debug)]
struct EapolKeyHeader {
    /// Key Information bit field (offset 1, big-endian).
    key_info: u16,
    /// Key Length (offset 3, big-endian).
    key_length: u16,
    /// Key Replay Counter (offset 5), kept as raw big-endian bytes.
    replay_counter: [u8; REPLAY_COUNTER_LEN],
    /// Key Nonce (offset 13); the handshake reads the ANonce from here in
    /// messages 1 and 3.
    key_nonce: [u8; NONCE_LEN],
    /// EAPOL-Key IV (offset 45).
    key_iv: [u8; 16],
    /// Key RSC (offset 61).
    key_rsc: [u8; 8],
    /// Key MIC (offset 77).
    key_mic: [u8; MIC_LEN],
    /// Key Data Length (offset 93, big-endian).
    key_data_len: u16,
    /// The `key_data_len` bytes following the fixed descriptor; AES-wrapped in
    /// message 3.
    key_data: Vec<u8>,
}
/// Decodes an EAPOL-Key frame (EAPOL header included) into an [`EapolKeyHeader`].
///
/// Only the EAPOL packet type and the RSN descriptor type are validated; the
/// EAPOL version byte and body-length field are not inspected. The key data
/// is the `key_data_len` bytes right after the fixed descriptor; anything
/// beyond that is ignored.
///
/// Errors: `FrameTooShort` when the fixed descriptor or the advertised key
/// data does not fit, `InvalidEapolType` for a non-Key EAPOL packet, and
/// `InvalidDescriptorType` for a non-RSN descriptor.
fn parse_eapol_key_header(eapol: &[u8]) -> Result<EapolKeyHeader, WpaError> {
    // Copies the `N` bytes starting at `at` into a fixed-size array.
    fn field<const N: usize>(bytes: &[u8], at: usize) -> [u8; N] {
        let mut out = [0u8; N];
        out.copy_from_slice(&bytes[at..at + N]);
        out
    }

    let fixed_end = EAPOL_HDR_LEN + EAPOL_KEY_HDR_LEN;
    if eapol.len() < fixed_end {
        return Err(WpaError::FrameTooShort);
    }
    if eapol[1] != EAPOL_TYPE_KEY {
        return Err(WpaError::InvalidEapolType);
    }

    // Every offset below is relative to the start of the key descriptor.
    let desc = &eapol[EAPOL_HDR_LEN..fixed_end];
    if desc[0] != KEY_DESC_TYPE_RSN {
        return Err(WpaError::InvalidDescriptorType);
    }

    let key_data_len = u16::from_be_bytes(field(desc, 93));
    let key_data = match eapol.get(fixed_end..fixed_end + usize::from(key_data_len)) {
        Some(bytes) => bytes.to_vec(),
        None => return Err(WpaError::FrameTooShort),
    };

    Ok(EapolKeyHeader {
        key_info: u16::from_be_bytes(field(desc, 1)),
        key_length: u16::from_be_bytes(field(desc, 3)),
        replay_counter: field(desc, 5),
        key_nonce: field(desc, 13),
        key_iv: field(desc, 45),
        key_rsc: field(desc, 61),
        key_mic: field(desc, 77),
        key_data_len,
        key_data,
    })
}
/// Pairwise transient key, split into its three consecutive sub-keys.
#[derive(Clone)]
struct Ptk {
    /// Key confirmation key: keys the EAPOL-Key MIC of messages 2, 3 and 4.
    kck: [u8; KCK_LEN],
    /// Key encryption key: unwraps the key data of message 3.
    kek: [u8; KEK_LEN],
    /// Temporal key handed back to the caller on completion.
    tk: [u8; TK_LEN],
}

impl Ptk {
    /// Splits raw PRF output laid out as KCK || KEK || TK.
    fn from_bytes(ptk_bytes: &[u8; PTK_LEN]) -> Self {
        let (kck_part, rest) = ptk_bytes.split_at(KCK_LEN);
        let (kek_part, tk_part) = rest.split_at(KEK_LEN);
        let mut ptk = Self {
            kck: [0; KCK_LEN],
            kek: [0; KEK_LEN],
            tk: [0; TK_LEN],
        };
        ptk.kck.copy_from_slice(kck_part);
        ptk.kek.copy_from_slice(kek_part);
        ptk.tk.copy_from_slice(tk_part);
        ptk
    }
}
/// One-shot HMAC-SHA1 of `data` under `key`, returned as a plain 20-byte array.
fn hmac_sha1(key: &[u8], data: &[u8]) -> [u8; SHA1_DIGEST_SIZE] {
    let mut digest = [0u8; SHA1_DIGEST_SIZE];
    // Fully qualified: `KeyInit` is also in scope and provides `new_from_slice`.
    let mut mac = <HmacSha1 as Mac>::new_from_slice(key).expect("HMAC key length");
    mac.update(data);
    digest.copy_from_slice(mac.finalize().into_bytes().as_slice());
    digest
}
fn pbkdf2_sha1(passphrase: &[u8], salt: &[u8], iterations: u32, dk_len: usize) -> Vec<u8> {
let mut result = Vec::with_capacity(dk_len);
let blocks_needed = dk_len.div_ceil(SHA1_DIGEST_SIZE);
for block_index in 1..=blocks_needed {
let mut salt_block = Vec::with_capacity(salt.len() + 4);
salt_block.extend_from_slice(salt);
salt_block.extend_from_slice(&(block_index as u32).to_be_bytes());
let mut u = hmac_sha1(passphrase, &salt_block);
let mut t = u;
for _ in 1..iterations {
u = hmac_sha1(passphrase, &u);
for i in 0..SHA1_DIGEST_SIZE {
t[i] ^= u[i];
}
}
result.extend_from_slice(&t);
}
result.truncate(dk_len);
result
}
/// IEEE 802.11 PRF built on HMAC-SHA1.
///
/// Concatenates `HMAC-SHA1(key, label || 0x00 || data || counter)` for
/// counter = 0, 1, ... and truncates the stream to `output_len` bytes.
fn prf_sha1(key: &[u8], label: &[u8], data: &[u8], output_len: usize) -> Vec<u8> {
    let rounds = output_len.div_ceil(SHA1_DIGEST_SIZE);

    // label || 0x00 || data || counter; only the trailing counter byte changes per round.
    let mut message = Vec::with_capacity(label.len() + 1 + data.len() + 1);
    message.extend_from_slice(label);
    message.push(0x00);
    message.extend_from_slice(data);
    message.push(0x00);
    let counter_at = message.len() - 1;

    let mut stream = Vec::with_capacity(rounds * SHA1_DIGEST_SIZE);
    for round in 0..rounds {
        message[counter_at] = round as u8;
        stream.extend_from_slice(&hmac_sha1(key, &message));
    }
    stream.truncate(output_len);
    stream
}
/// Derives the PTK from the PMK via the "Pairwise key expansion" PRF.
///
/// The PRF input is Min(AA, SPA) || Max(AA, SPA) || Min(ANonce, SNonce) ||
/// Max(ANonce, SNonce), with byte-wise lexicographic ordering, so both peers
/// feed identical bytes regardless of role.
fn derive_ptk(
    pmk: &[u8; PMK_LEN],
    aa: &[u8; 6],
    spa: &[u8; 6],
    anonce: &[u8; NONCE_LEN],
    snonce: &[u8; NONCE_LEN],
) -> Ptk {
    let (lo_addr, hi_addr) = if spa < aa { (spa, aa) } else { (aa, spa) };
    let (lo_nonce, hi_nonce) = if snonce < anonce {
        (snonce, anonce)
    } else {
        (anonce, snonce)
    };

    let seed: Vec<u8> = [
        lo_addr.as_slice(),
        hi_addr.as_slice(),
        lo_nonce.as_slice(),
        hi_nonce.as_slice(),
    ]
    .concat();

    let expanded = prf_sha1(pmk, b"Pairwise key expansion", &seed, PTK_LEN);
    let mut raw = [0u8; PTK_LEN];
    raw.copy_from_slice(&expanded);
    Ptk::from_bytes(&raw)
}
/// EAPOL-Key MIC: HMAC-SHA1 over the whole EAPOL frame, truncated to `MIC_LEN`.
///
/// Callers in this file pass frames whose MIC field is all zeros.
fn compute_mic(kck: &[u8], eapol_frame: &[u8]) -> [u8; MIC_LEN] {
    let digest = hmac_sha1(kck, eapol_frame);
    let (truncated, _) = digest.split_at(MIC_LEN);
    let mut mic = [0u8; MIC_LEN];
    mic.copy_from_slice(truncated);
    mic
}
fn aes_key_unwrap(kek: &[u8], wrapped: &[u8]) -> Result<Vec<u8>, WpaError> {
if wrapped.len() < 16 || !wrapped.len().is_multiple_of(8) {
return Err(WpaError::AesUnwrapFailed);
}
let n = (wrapped.len() / 8) - 1; let cipher = Aes128::new(GenericArray::from_slice(kek));
let mut a = [0u8; 8];
a.copy_from_slice(&wrapped[0..8]);
let mut r = Vec::with_capacity(n * 8);
for i in 0..n {
r.extend_from_slice(&wrapped[(i + 1) * 8..(i + 2) * 8]);
}
for j in (0..6u64).rev() {
for i in (0..n).rev() {
let t = (n as u64) * j + (i as u64) + 1;
let t_bytes = t.to_be_bytes();
for k in 0..8 {
a[k] ^= t_bytes[k];
}
let mut block = [0u8; 16];
block[0..8].copy_from_slice(&a);
block[8..16].copy_from_slice(&r[i * 8..(i + 1) * 8]);
let ga = GenericArray::from_mut_slice(&mut block);
cipher.decrypt_block(ga);
a.copy_from_slice(&block[0..8]);
r[i * 8..(i + 1) * 8].copy_from_slice(&block[8..16]);
}
}
const DEFAULT_IV: [u8; 8] = [0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6, 0xA6];
if a != DEFAULT_IV {
log::error!(
"[wpa2] AES Key Unwrap IV check failed: {:02x?} != {:02x?}",
a,
DEFAULT_IV
);
return Err(WpaError::AesUnwrapFailed);
}
Ok(r)
}
fn generate_snonce() -> [u8; NONCE_LEN] {
let mut snonce = [0u8; NONCE_LEN];
let ts = crate::runtime::runtime().now_nanos();
let ts_bytes = ts.to_le_bytes();
let hash1 = hmac_sha1(&ts_bytes, b"snonce-gen-1");
let hash2 = hmac_sha1(&ts_bytes, b"snonce-gen-2");
snonce[..20].copy_from_slice(&hash1);
snonce[20..32].copy_from_slice(&hash2[..12]);
snonce
}
fn build_eapol_key_frame(
key_info: u16,
key_length: u16,
replay_counter: &[u8; REPLAY_COUNTER_LEN],
key_nonce: &[u8; NONCE_LEN],
key_data: &[u8],
) -> Vec<u8> {
let key_data_len = key_data.len() as u16;
let body_len = (EAPOL_KEY_HDR_LEN + key_data.len()) as u16;
let total_len = EAPOL_HDR_LEN + EAPOL_KEY_HDR_LEN + key_data.len();
let mut frame = vec![0u8; total_len];
frame[0] = EAPOL_VERSION;
frame[1] = EAPOL_TYPE_KEY;
frame[2..4].copy_from_slice(&body_len.to_be_bytes());
let off = EAPOL_HDR_LEN;
frame[off] = KEY_DESC_TYPE_RSN;
frame[off + 1..off + 3].copy_from_slice(&key_info.to_be_bytes());
frame[off + 3..off + 5].copy_from_slice(&key_length.to_be_bytes());
frame[off + 5..off + 13].copy_from_slice(replay_counter);
frame[off + 13..off + 45].copy_from_slice(key_nonce);
frame[off + 93..off + 95].copy_from_slice(&key_data_len.to_be_bytes());
if !key_data.is_empty() {
frame[EAPOL_HDR_LEN + EAPOL_KEY_HDR_LEN..].copy_from_slice(key_data);
}
frame
}
fn parse_gtk_kde(data: &[u8]) -> Result<(Vec<u8>, u8), WpaError> {
let mut offset = 0;
while offset + 2 <= data.len() {
let element_type = data[offset];
let element_len = data[offset + 1] as usize;
if offset + 2 + element_len > data.len() {
return Err(WpaError::InvalidKeyData);
}
if element_type == 0xDD
&& element_len >= 6
&& data[offset + 2..offset + 6] == [0x00, 0x0F, 0xAC, 0x01]
{
let key_id = data[offset + 6] & 0x03; let gtk = data[offset + 8..offset + 2 + element_len].to_vec();
log::debug!(
"[wpa2] Found GTK KDE: key_id={}, gtk_len={}",
key_id,
gtk.len(),
);
return Ok((gtk, key_id));
}
if element_type == 0x00 {
offset += 1;
continue;
}
offset += 2 + element_len;
}
log::error!(
"[wpa2] GTK KDE not found in key data ({} bytes)",
data.len()
);
Err(WpaError::GtkNotFound)
}
/// Supplicant-side state for one WPA2-PSK 4-way handshake.
pub struct Wpa2Handshake {
    /// Current progress through the handshake.
    pub state: HandshakeState,
    /// Pairwise master key, derived in `new` via PBKDF2-HMAC-SHA1
    /// (4096 iterations, SSID as salt).
    pmk: [u8; PMK_LEN],
    /// PTK derived while processing message 1; `None` before that.
    ptk: Option<Ptk>,
    /// ANonce copied from message 1.
    anonce: [u8; NONCE_LEN],
    /// SNonce generated once in `new` and sent in message 2.
    snonce: [u8; NONCE_LEN],
    /// Authenticator address (AA) used in the PTK derivation.
    aa: [u8; 6],
    /// Supplicant address (SPA) used in the PTK derivation.
    spa: [u8; 6],
    /// RSN IE sent as the key data of message 2.
    rsn_ie: Vec<u8>,
    /// Replay counter recorded from the most recently processed message;
    /// echoed in messages 2 and 4.
    replay_counter: [u8; REPLAY_COUNTER_LEN],
    /// GTK extracted from message 3.
    gtk: Vec<u8>,
    /// Key index of `gtk`.
    gtk_key_idx: u8,
}
impl Wpa2Handshake {
pub fn new(passphrase: &[u8], ssid: &[u8], aa: &[u8; 6], spa: &[u8; 6], rsn_ie: &[u8]) -> Self {
let pmk_vec = pbkdf2_sha1(passphrase, ssid, 4096, PMK_LEN);
let mut pmk = [0u8; PMK_LEN];
pmk.copy_from_slice(&pmk_vec);
let snonce = generate_snonce();
Self {
state: HandshakeState::Idle,
pmk,
ptk: None,
snonce,
anonce: [0u8; NONCE_LEN],
aa: *aa,
spa: *spa,
rsn_ie: rsn_ie.to_vec(),
replay_counter: [0u8; REPLAY_COUNTER_LEN],
gtk: Vec::new(),
gtk_key_idx: 0,
}
}
pub fn update_rsn_ie(&mut self, new_rsn_ie: &[u8]) {
log::debug!(
"[wpa2] Updating RSN IE: old={:02x?}, new={:02x?}",
self.rsn_ie,
new_rsn_ie
);
self.rsn_ie = new_rsn_ie.to_vec();
}
pub fn process_eapol(&mut self, eapol: &[u8]) -> Result<HandshakeAction, WpaError> {
let hdr = parse_eapol_key_header(eapol)?;
let has_ack = (hdr.key_info & KEY_INFO_ACK) != 0;
let has_mic = (hdr.key_info & KEY_INFO_MIC) != 0;
let has_install = (hdr.key_info & KEY_INFO_INSTALL) != 0;
let has_enc = (hdr.key_info & KEY_INFO_ENC_KEY_DATA) != 0;
if has_ack && !has_mic {
log::debug!(
"[wpa2] === M1 === key_info=0x{:04x} replay={:02x?}",
hdr.key_info,
hdr.replay_counter
);
self.process_m1(&hdr, eapol)
} else if has_ack && has_mic && has_install && has_enc {
log::debug!("[wpa2] === M3 === key_info=0x{:04x}", hdr.key_info);
self.process_m3(&hdr, eapol)
} else {
log::warn!(
"[wpa2] Unexpected EAPOL key_info=0x{:04x}, ignoring",
hdr.key_info
);
Err(WpaError::UnexpectedMessage)
}
}
fn process_m1(
&mut self,
hdr: &EapolKeyHeader,
_eapol: &[u8],
) -> Result<HandshakeAction, WpaError> {
if self.state != HandshakeState::Idle && self.state != HandshakeState::M2Sent {
log::warn!("[wpa2] M1 received in unexpected state: {:?}", self.state);
}
self.anonce.copy_from_slice(&hdr.key_nonce);
self.replay_counter.copy_from_slice(&hdr.replay_counter);
let ptk = derive_ptk(&self.pmk, &self.aa, &self.spa, &self.anonce, &self.snonce);
self.ptk = Some(ptk);
let key_info: u16 = KEY_INFO_TYPE_HMAC_SHA1_AES | KEY_INFO_PAIRWISE | KEY_INFO_MIC;
let mut m2 = build_eapol_key_frame(
key_info,
0, &self.replay_counter,
&self.snonce, &self.rsn_ie, );
let mic = compute_mic(&self.ptk.as_ref().unwrap().kck, &m2);
m2[MIC_OFFSET..MIC_OFFSET + MIC_LEN].copy_from_slice(&mic);
self.state = HandshakeState::M2Sent;
log::debug!(
"[wpa2] M2 built ({} bytes), snonce={:02x?}.. anonce={:02x?}.. MIC={:02x?}",
m2.len(),
&self.snonce[..4],
&self.anonce[..4],
&mic[..4]
);
Ok(HandshakeAction::SendM2(m2))
}
fn process_m3(
&mut self,
hdr: &EapolKeyHeader,
eapol: &[u8],
) -> Result<HandshakeAction, WpaError> {
if self.state != HandshakeState::M2Sent {
log::warn!("[wpa2] M3 received in unexpected state: {:?}", self.state);
return Err(WpaError::InvalidState);
}
let ptk = self.ptk.as_ref().ok_or(WpaError::InvalidState)?;
if hdr.replay_counter[..] < self.replay_counter[..] {
log::error!("[wpa2] M3 replay counter too old");
return Err(WpaError::ReplayCounterMismatch);
}
self.replay_counter.copy_from_slice(&hdr.replay_counter);
let mut eapol_copy = eapol.to_vec();
for i in 0..MIC_LEN {
eapol_copy[MIC_OFFSET + i] = 0;
}
let computed_mic = compute_mic(&ptk.kck, &eapol_copy);
if computed_mic != hdr.key_mic {
log::error!(
"[wpa2] M3 MIC mismatch! expected={:02x?}, got={:02x?}",
&computed_mic[..4],
&hdr.key_mic[..4],
);
return Err(WpaError::MicMismatch);
}
log::debug!("[wpa2] M3 MIC verified OK");
if hdr.key_nonce != self.anonce {
log::warn!("[wpa2] M3 ANonce differs from M1, updating");
self.anonce.copy_from_slice(&hdr.key_nonce);
}
let key_data = &hdr.key_data;
if key_data.is_empty() {
log::error!("[wpa2] M3 has no key data");
return Err(WpaError::InvalidKeyData);
}
let decrypted = aes_key_unwrap(&ptk.kek, key_data)?;
log::debug!("[wpa2] M3 key data decrypted: {} bytes", decrypted.len());
let (gtk, gtk_key_idx) = parse_gtk_kde(&decrypted)?;
self.gtk = gtk;
self.gtk_key_idx = gtk_key_idx;
log::debug!(
"[wpa2] GTK extracted: key_idx={}, len={}",
self.gtk_key_idx,
self.gtk.len(),
);
let key_info: u16 =
KEY_INFO_TYPE_HMAC_SHA1_AES | KEY_INFO_PAIRWISE | KEY_INFO_MIC | KEY_INFO_SECURE;
let mut m4 = build_eapol_key_frame(
key_info,
0, &self.replay_counter,
&[0u8; NONCE_LEN], &[], );
let mic = compute_mic(&ptk.kck, &m4);
m4[MIC_OFFSET..MIC_OFFSET + MIC_LEN].copy_from_slice(&mic);
self.state = HandshakeState::Completed;
log::debug!("[wpa2] M4 built ({} bytes), handshake complete!", m4.len());
let mut tk = [0u8; TK_LEN];
tk.copy_from_slice(&ptk.tk);
Ok(HandshakeAction::Completed(HandshakeResult {
m4_frame: m4,
tk,
gtk: self.gtk.clone(),
gtk_key_idx: self.gtk_key_idx,
}))
}
}