#![allow(
dead_code,
unused_imports,
unused_qualifications,
unreachable_patterns,
unsafe_code
)]
use std::iter::once;
use std::mem::size_of;
use std::ptr::{null, null_mut};
use windows::core::{PCWSTR, PWSTR};
use windows::Win32::Foundation::{CloseHandle, HANDLE, HWND};
use windows::Win32::Graphics::Gdi::HBITMAP;
use windows::Win32::Security::Credentials::{
CredUIPromptForWindowsCredentialsW, CredUnPackAuthenticationBufferW,
CREDUIWIN_ENUMERATE_CURRENT_USER, CREDUI_INFOW, CRED_PACK_FLAGS,
};
use windows::Win32::Security::{LogonUserW, LOGON32_LOGON_NETWORK, LOGON32_PROVIDER_DEFAULT};
use windows::Win32::System::Com::CoTaskMemFree;
use zeroize::Zeroize;
/// Result of asking the interactive Windows user to re-enter their password.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PresenceOutcome {
    /// The user entered credentials that Windows accepted.
    Verified,
    /// The user cancelled the prompt or used up every attempt; the payload is a
    /// human-readable reason.
    Denied(String),
    /// The check could not be carried out (the prompt or the logon API failed);
    /// the payload describes the failure.
    Unavailable(String),
}
/// Win32 `ERROR_SUCCESS`: the prompt completed and returned a credential buffer.
const ERROR_SUCCESS_CODE: u32 = 0;
/// Win32 `ERROR_CANCELLED`: the user dismissed the credential prompt.
const ERROR_CANCELLED_CODE: u32 = 1223;
/// Win32 `ERROR_LOGON_FAILURE`: unknown user name or bad password.
const ERROR_LOGON_FAILURE_CODE: u32 = 1326;
/// Number of password prompts shown before the check is reported as denied.
const MAX_ATTEMPTS: u32 = 3;
/// Asks the interactive Windows user to confirm their presence by re-entering
/// their password in the standard credential dialog.
///
/// `reason` is displayed as the dialog's message text. The entered password is
/// validated with `LogonUserW`, and the user gets up to `MAX_ATTEMPTS` tries.
///
/// Returns [`PresenceOutcome::Verified`] on success, [`PresenceOutcome::Denied`]
/// when the user cancels or runs out of attempts, and
/// [`PresenceOutcome::Unavailable`] when the dialog or the logon API fails.
pub fn verify_current_user(reason: &str) -> PresenceOutcome {
    // SAFETY: the inner routine only passes Win32 pointers to buffers it owns
    // for the duration of each call, and it releases the CredUI output buffer
    // itself, so no raw-pointer obligations leak to this caller.
    unsafe { verify_current_user_inner(reason) }
}
/// Shows the CredUI password dialog and validates whatever the user enters.
/// After a wrong password it re-prompts, passing `ERROR_LOGON_FAILURE` to the
/// dialog, up to `MAX_ATTEMPTS` prompts in total.
///
/// NOTE(review): the authenticated account is never compared with the account
/// running this process; limiting the dialog to the current user relies
/// entirely on `CREDUIWIN_ENUMERATE_CURRENT_USER` — confirm that is sufficient.
///
/// # Safety
///
/// Calls raw Win32 APIs. Every pointer handed to them refers to a buffer owned
/// by this function that outlives the call, and each CredUI output buffer is
/// zeroized and released with `CoTaskMemFree` before the next iteration.
unsafe fn verify_current_user_inner(reason: &str) -> PresenceOutcome {
    // CredUI wants NUL-terminated UTF-16. These vectors back the raw pointers
    // stored in `prompt`, so they must stay alive across every prompt below.
    let to_wide = |text: &str| -> Vec<u16> { text.encode_utf16().chain(once(0)).collect() };
    let message_w = to_wide(reason);
    let caption_w = to_wide("gocode-dev");
    let prompt = CREDUI_INFOW {
        cbSize: size_of::<CREDUI_INFOW>() as u32,
        hwndParent: HWND::default(),
        pszMessageText: PCWSTR(message_w.as_ptr()),
        pszCaptionText: PCWSTR(caption_w.as_ptr()),
        hbmBanner: HBITMAP::default(),
    };

    // Error code handed to the dialog on the next prompt (0 on the first one).
    let mut hint_error = 0_u32;
    let mut attempt = 0_u32;
    loop {
        attempt += 1;

        let mut package = 0_u32;
        let mut packed: *mut core::ffi::c_void = null_mut();
        let mut packed_len = 0_u32;
        let status = CredUIPromptForWindowsCredentialsW(
            Some(&prompt),
            hint_error,
            &mut package,
            None,
            0,
            &mut packed,
            &mut packed_len,
            None,
            CREDUIWIN_ENUMERATE_CURRENT_USER,
        );
        match status {
            ERROR_SUCCESS_CODE => {}
            ERROR_CANCELLED_CODE => {
                return PresenceOutcome::Denied("user cancelled the password prompt".into());
            }
            other => {
                return PresenceOutcome::Unavailable(format!(
                    "CredUIPromptForWindowsCredentialsW failed (0x{other:08X})"
                ));
            }
        }

        let verdict = verify_auth_buffer(packed, packed_len);

        // The packed buffer carries the password: wipe it before returning it
        // to the COM allocator.
        if !packed.is_null() {
            let bytes = std::slice::from_raw_parts_mut(packed.cast::<u8>(), packed_len as usize);
            bytes.zeroize();
            CoTaskMemFree(Some(packed.cast_const()));
        }

        match verdict {
            AuthCheck::Verified => return PresenceOutcome::Verified,
            AuthCheck::Unavailable(detail) => return PresenceOutcome::Unavailable(detail),
            AuthCheck::WrongPassword if attempt >= MAX_ATTEMPTS => {
                return PresenceOutcome::Denied("Windows password could not be verified".into());
            }
            // Try again, letting the dialog tell the user the password was wrong.
            AuthCheck::WrongPassword => hint_error = ERROR_LOGON_FAILURE_CODE,
        }
    }
}
/// Result of validating one credential buffer returned by the prompt.
enum AuthCheck {
    /// `LogonUserW` accepted the credentials.
    Verified,
    /// Windows reported `ERROR_LOGON_FAILURE`; the caller may prompt again.
    WrongPassword,
    /// The credentials could not be checked at all; the payload says why.
    Unavailable(String),
}
/// Length of the NUL-terminated UTF-16 string held in `buf`: the number of
/// code units before the first `0`, or `buf.len()` when there is no terminator.
fn wlen(buf: &[u16]) -> usize {
    buf.iter().take_while(|&&unit| unit != 0).count()
}
/// Unpacks a CredUI authentication buffer and checks the contained
/// credentials with a `LOGON32_LOGON_NETWORK` call to `LogonUserW`.
///
/// Returns [`AuthCheck::Verified`] when the logon succeeds,
/// [`AuthCheck::WrongPassword`] when Windows reports `ERROR_LOGON_FAILURE`,
/// and [`AuthCheck::Unavailable`] for every other failure.
///
/// The UTF-16 buffers filled by the unpack call (user, domain, password) are
/// zeroized before returning. The caller still owns `buf` and remains
/// responsible for wiping and freeing it.
///
/// NOTE(review): the buffer is unpacked with `CRED_PACK_FLAGS(0)`; if CredUI
/// ever returns protected credentials, `CRED_PACK_PROTECTED_CREDENTIALS` may
/// be needed to obtain a usable password — confirm.
///
/// # Safety
///
/// `buf` must be null or point to at least `size` readable bytes of an
/// authentication buffer produced by `CredUIPromptForWindowsCredentialsW`.
unsafe fn verify_auth_buffer(buf: *mut core::ffi::c_void, size: u32) -> AuthCheck {
    if buf.is_null() || size == 0 {
        return AuthCheck::Unavailable("empty credential buffer".into());
    }

    // Sizing pass: with null output buffers the call only reports the required
    // lengths (in UTF-16 units). Its result is ignored on purpose, because it
    // is expected to fail for lack of buffer space; only the lengths matter.
    let mut user_len: u32 = 0;
    let mut domain_len: u32 = 0;
    let mut pass_len: u32 = 0;
    drop(CredUnPackAuthenticationBufferW(
        CRED_PACK_FLAGS(0),
        buf,
        size,
        PWSTR(null_mut()),
        &mut user_len,
        PWSTR(null_mut()),
        Some(&mut domain_len),
        PWSTR(null_mut()),
        &mut pass_len,
    ));
    if user_len == 0 || pass_len == 0 {
        return AuthCheck::Unavailable("could not size unpacked credentials".into());
    }
    // The domain may be absent (reported length 0). Reserve at least one unit
    // for its NUL terminator and advertise exactly the capacity allocated, so
    // the API is never told the buffer is smaller than it really is.
    domain_len = domain_len.max(1);

    let mut user = vec![0_u16; user_len as usize];
    let mut domain = vec![0_u16; domain_len as usize];
    let mut password = vec![0_u16; pass_len as usize];
    let unpacked = CredUnPackAuthenticationBufferW(
        CRED_PACK_FLAGS(0),
        buf,
        size,
        PWSTR(user.as_mut_ptr()),
        &mut user_len,
        PWSTR(domain.as_mut_ptr()),
        Some(&mut domain_len),
        PWSTR(password.as_mut_ptr()),
        &mut pass_len,
    );
    if unpacked.is_err() {
        user.zeroize();
        domain.zeroize();
        password.zeroize();
        return AuthCheck::Unavailable("could not unpack credentials".into());
    }

    // Resolve the account directly on the UTF-16 data (no lossy String round
    // trip that could alter the name), splitting `DOMAIN\user` when no
    // separate domain was returned.
    let (account, realm) = split_account(&user[..wlen(&user)], &domain[..wlen(&domain)]);
    let user_w: Vec<u16> = account.iter().copied().chain(once(0)).collect();
    let domain_w: Vec<u16> = realm.iter().copied().chain(once(0)).collect();
    let domain_ptr = if realm.is_empty() {
        PCWSTR(null())
    } else {
        PCWSTR(domain_w.as_ptr())
    };

    let mut token = HANDLE::default();
    let logon = LogonUserW(
        PCWSTR(user_w.as_ptr()),
        domain_ptr,
        // `password` was zero-initialised with room for the terminator, so it
        // is NUL-terminated after a successful unpack.
        PCWSTR(password.as_ptr()),
        LOGON32_LOGON_NETWORK,
        LOGON32_PROVIDER_DEFAULT,
        &mut token,
    );
    password.zeroize();
    user.zeroize();
    domain.zeroize();

    match logon {
        Ok(()) => {
            // Only the success of the logon matters; release the token at once.
            if !token.is_invalid() {
                drop(CloseHandle(token));
            }
            AuthCheck::Verified
        }
        Err(err) if is_win32_error(err.code().0, ERROR_LOGON_FAILURE_CODE) => {
            AuthCheck::WrongPassword
        }
        Err(err) => {
            AuthCheck::Unavailable(format!("LogonUserW could not validate the account: {err}"))
        }
    }
}

/// Splits a down-level `DOMAIN\user` name into `(user, domain)` when `domain`
/// is empty; otherwise returns `(user, domain)` unchanged.
fn split_account<'a>(user: &'a [u16], domain: &'a [u16]) -> (&'a [u16], &'a [u16]) {
    if !domain.is_empty() {
        return (user, domain);
    }
    match user.iter().position(|&unit| unit == u16::from(b'\\')) {
        Some(sep) => (&user[sep + 1..], &user[..sep]),
        // `domain` is empty here, so it doubles as the "no domain" slice.
        None => (user, domain),
    }
}

/// Returns `true` when `hresult` is exactly `HRESULT_FROM_WIN32(win32)` for a
/// non-zero Win32 code: failure bit set, `FACILITY_WIN32`, and that code.
fn is_win32_error(hresult: i32, win32: u32) -> bool {
    const WIN32_FAILURE_PREFIX: u32 = 0x8007_0000;
    // Reinterpret the HRESULT's bits; no numeric conversion is intended.
    hresult as u32 == WIN32_FAILURE_PREFIX | (win32 & 0xFFFF)
}
#[cfg(test)]
mod tests {
    use super::*;
    use windows::Win32::Foundation::{ERROR_CANCELLED, ERROR_LOGON_FAILURE, ERROR_SUCCESS};

    /// The hand-written codes must agree with the `windows` crate's own
    /// definitions rather than with literals copied from the constants.
    #[test]
    fn win32_codes_match_platform_definitions() {
        assert_eq!(ERROR_SUCCESS_CODE, ERROR_SUCCESS.0);
        assert_eq!(ERROR_CANCELLED_CODE, ERROR_CANCELLED.0);
        assert_eq!(ERROR_LOGON_FAILURE_CODE, ERROR_LOGON_FAILURE.0);
    }

    /// Each outcome keeps its payload and renders it in its Debug output.
    #[test]
    fn outcomes_carry_the_expected_shape() {
        let denied = PresenceOutcome::Denied("cancelled".into());
        let unavailable = PresenceOutcome::Unavailable("headless".into());
        assert!(matches!(denied, PresenceOutcome::Denied(ref m) if m == "cancelled"));
        assert!(matches!(unavailable, PresenceOutcome::Unavailable(ref m) if m == "headless"));
        assert_eq!(format!("{denied:?}"), "Denied(\"cancelled\")");
        assert_eq!(format!("{unavailable:?}"), "Unavailable(\"headless\")");
        assert_eq!(format!("{:?}", PresenceOutcome::Verified), "Verified");
    }

    #[test]
    fn wlen_stops_at_nul() {
        assert_eq!(wlen(&[b'a' as u16, b'b' as u16, 0, b'c' as u16]), 2);
        assert_eq!(wlen(&[b'x' as u16, b'y' as u16]), 2);
        assert_eq!(wlen(&[0]), 0);
        assert_eq!(wlen(&[]), 0);
    }
}