use crate::{wrappers, LocalBox, Sid};
use std::ffi::{OsStr, OsString};
use std::io;
use std::os::windows::ffi::{OsStrExt, OsStringExt};
use std::ptr::{null, null_mut};
use winapi::um::winnt::{HANDLE, TOKEN_USER};
/// Converts a NUL-terminated UTF-16 buffer into an [`OsString`].
///
/// Only the code units before the first NUL (`0`) are converted; anything
/// after it (e.g. trailing padding) is ignored. If the buffer contains no NUL,
/// the whole slice is converted.
pub fn os_from_buf(buf: &[u16]) -> OsString {
    // Borrow the prefix directly instead of copying it into a temporary Vec.
    let end = buf.iter().position(|&unit| unit == 0).unwrap_or(buf.len());
    OsString::from_wide(&buf[..end])
}
/// Encodes `os` as UTF-16 code units followed by a single terminating NUL,
/// the layout expected by wide-character Win32 APIs.
pub fn buf_from_os<S: AsRef<OsStr> + ?Sized>(os: &S) -> Vec<u16> {
    os.as_ref()
        .encode_wide()
        .chain(std::iter::once(0))
        .collect()
}
/// Returns the index of the first element equal to `needle` in the buffer
/// starting at `haystack`.
///
/// The scan has no upper bound: it keeps advancing until it finds a match.
///
/// # Safety
///
/// `haystack` must be non-null, properly aligned, and point to a contiguous
/// run of initialized, readable `T` values that contains at least one element
/// equal to `needle`. If no such element exists within the allocation, the
/// scan reads out of bounds, which is undefined behavior.
pub unsafe fn search_buffer<T: PartialEq>(needle: &T, haystack: *const T) -> usize {
    let mut position = 0usize;
    // SAFETY: the caller guarantees a matching element exists, so every offset
    // visited is at or before that element and therefore within the buffer.
    while *haystack.add(position) != *needle {
        position += 1;
    }
    position
}
/// Reports whether `field` shares at least one set bit with the mask `bit`.
///
/// With a multi-bit mask this is true when *any* of those bits is set, not all.
pub fn has_bit(field: u32, bit: u32) -> bool {
    let overlap = field & bit;
    overlap != 0
}
/// Turns an optional reference into a raw pointer, using null for `None`.
///
/// Useful for passing optional in-parameters to C APIs that accept null.
pub fn ptr_from_opt<T>(opt: Option<&T>) -> *const T {
    opt.map_or(null(), |inner| inner as *const T)
}
/// Owns a Win32 handle and closes it when dropped, so every exit path of
/// `current_process_sid` (including unwinding) releases the process token.
struct TokenHandleGuard(HANDLE);

impl Drop for TokenHandleGuard {
    fn drop(&mut self) {
        // SAFETY: the wrapped handle came from a successful `OpenProcessToken`
        // call and is owned solely by this guard, so it is closed exactly once.
        unsafe {
            winapi::um::handleapi::CloseHandle(self.0);
        }
    }
}

/// Returns a copy of the user SID stored in the current process's token.
///
/// # Errors
///
/// Returns an error in three cases:
/// - the process token cannot be opened;
/// - `GetTokenInformation` fails for any reason other than an undersized
///   buffer (or reports an undersized buffer without asking for more space);
/// - `wrappers::CopySid` fails.
pub fn current_process_sid() -> io::Result<LocalBox<Sid>> {
    let mut raw_token: HANDLE = null_mut();
    // SAFETY: `GetCurrentProcess` returns a pseudo-handle that is always valid
    // for the calling process, and `raw_token` is a valid out-pointer.
    let result = unsafe {
        winapi::um::processthreadsapi::OpenProcessToken(
            winapi::um::processthreadsapi::GetCurrentProcess(),
            winapi::um::winnt::TOKEN_QUERY,
            &mut raw_token,
        )
    };
    if result == 0 {
        return Err(io::Error::last_os_error());
    }
    let token_handle = TokenHandleGuard(raw_token);

    // Initial size guess; on ERROR_INSUFFICIENT_BUFFER the API writes the
    // required size back into `len` and we retry with a bigger buffer.
    let mut len = 44u32;
    // Backed by `u64`s rather than `u8`s so the buffer start is aligned well
    // enough to be read as a `TOKEN_USER`, which contains a pointer field.
    let mut token_info: Vec<u64>;
    loop {
        let word = std::mem::size_of::<u64>();
        token_info = vec![0u64; (len as usize + word - 1) / word];
        let requested = len;
        // SAFETY: `token_info` provides at least `requested` writable bytes,
        // and `len` is a valid out-pointer for the required length.
        let result = unsafe {
            winapi::um::securitybaseapi::GetTokenInformation(
                token_handle.0,
                winapi::um::winnt::TokenUser,
                token_info.as_mut_ptr() as *mut _,
                requested,
                &mut len,
            )
        };
        if result != 0 {
            break;
        }
        let error = io::Error::last_os_error();
        let too_small = error.raw_os_error()
            == Some(winapi::shared::winerror::ERROR_INSUFFICIENT_BUFFER as i32);
        // Retry only when the API asked for more space than we just offered;
        // otherwise a misbehaving call could make this loop spin forever.
        if !(too_small && len > requested) {
            return Err(error);
        }
    }

    // SAFETY: `GetTokenInformation(TokenUser)` succeeded, so `token_info`
    // begins with a `TOKEN_USER`. Per the Win32 docs its `User.Sid` points at
    // SID data inside the same buffer; the SID is copied below while
    // `token_info` is still alive.
    let sid_ref =
        unsafe { &*((*(token_info.as_ptr() as *const TOKEN_USER)).User.Sid as *mut Sid) };
    wrappers::CopySid(sid_ref)
}
#[cfg(test)]
mod test {
    use super::*;

    /// UTF-16 buffers and `OsString`s convert to each other losslessly, and
    /// trailing NULs after the terminator are ignored.
    #[test]
    fn round_trip() {
        let basic_os = OsString::from("TeSt");
        let basic_buf = vec![0x54, 0x65, 0x53, 0x74, 0x00];
        let basic_buf_nuls = vec![0x54, 0x65, 0x53, 0x74, 0x00, 0x00, 0x00, 0x00];
        assert_eq!(os_from_buf(&basic_buf), basic_os);
        assert_eq!(buf_from_os(&basic_os), basic_buf);
        assert_eq!(os_from_buf(&basic_buf_nuls), basic_os);
        let unicode_os = OsString::from("💩");
        let unicode_buf = vec![0xd83d, 0xdca9, 0x0];
        let unicode_buf_nuls = vec![0xd83d, 0xdca9, 0x0, 0x0, 0x0, 0x0, 0x0];
        assert_eq!(os_from_buf(&unicode_buf), unicode_os);
        assert_eq!(buf_from_os(&unicode_os), unicode_buf);
        assert_eq!(os_from_buf(&unicode_buf_nuls), unicode_os);
    }

    /// A buffer with no terminator is converted in full; an empty one is empty.
    #[test]
    fn os_from_buf_without_terminator() {
        assert_eq!(os_from_buf(&[0x54, 0x65]), OsString::from("Te"));
        assert_eq!(os_from_buf(&[]), OsString::new());
    }

    #[test]
    fn has_bit_checks_mask_overlap() {
        assert!(has_bit(0b1010, 0b0010));
        assert!(!has_bit(0b1010, 0b0101));
        assert!(!has_bit(0, 0));
    }

    #[test]
    fn ptr_from_opt_maps_none_to_null() {
        let value = 42u32;
        assert_eq!(ptr_from_opt(Some(&value)), &value as *const u32);
        assert!(ptr_from_opt::<u32>(None).is_null());
    }

    #[test]
    fn search_buffer_finds_first_match() {
        let data = [7u16, 9, 0, 9, 0];
        // SAFETY: `data` contains every needle searched for, so the scan
        // never leaves the array.
        assert_eq!(unsafe { search_buffer(&0, data.as_ptr()) }, 2);
        assert_eq!(unsafe { search_buffer(&9, data.as_ptr()) }, 1);
        assert_eq!(unsafe { search_buffer(&7, data.as_ptr()) }, 0);
    }

    #[test]
    fn got_a_sid_for_the_current_process() {
        assert!(current_process_sid().is_ok());
    }
}