window-sand-box 0.1.1

Windows 沙盒终端执行工具 — 使用受限令牌、ACL 和私有桌面隔离进程权限,提供安全的命令执行环境
//! Windows 工具函数

use anyhow::{Result, anyhow};
use std::ffi::c_void;
use windows_sys::Win32::Foundation::{CloseHandle, GetLastError, HANDLE, HLOCAL, LocalFree};
use windows_sys::Win32::Security::{
    GetTokenInformation, SID_AND_ATTRIBUTES, TOKEN_QUERY,
};
use windows_sys::Win32::Security::Authorization::ConvertSidToStringSidW;
use windows_sys::Win32::System::Threading::{GetCurrentProcess, OpenProcessToken};

/// Encodes a string as UTF-16 code units followed by a single terminating
/// NUL (`0`), the shape expected by wide-character Windows APIs.
///
/// Accepts anything that can be viewed as an [`std::ffi::OsStr`]
/// (`&str`, `String`, `&Path`, ...).
pub fn to_wide(s: impl AsRef<std::ffi::OsStr>) -> Vec<u16> {
    use std::os::windows::ffi::OsStrExt;

    // Collect the code units first, then append the terminator explicitly.
    let mut units: Vec<u16> = s.as_ref().encode_wide().collect();
    units.push(0);
    units
}

/// Hand-written mirror of the Win32 `TOKEN_USER` structure.
///
/// Defined manually to avoid depending on a type that may be missing from
/// `windows-sys`. `#[repr(C)]` keeps the field layout C-compatible so a
/// buffer filled by `GetTokenInformation` can be interpreted as this type.
#[repr(C)]
struct TokenUserStruct {
    // SID pointer plus attribute flags describing the token's user.
    user: SID_AND_ATTRIBUTES,
}

/// `TOKEN_INFORMATION_CLASS` value for `TokenUser` (= 1), passed to
/// `GetTokenInformation` to request the token's user SID.
const TOKEN_USER_CLASS: i32 = 1;

/// Returns the SID of the current Windows user in string form.
///
/// Example result: `S-1-5-21-123456789-1234567890-123456789-1001`.
/// The same Windows user yields the same SID from every process.
///
/// # Errors
///
/// Returns an error carrying the `GetLastError` code when opening the
/// process token, querying the token information, or converting the SID to
/// a string fails.
pub fn get_current_user_sid() -> Result<String> {
    /// Owns the process token handle and closes it on drop, so every exit
    /// path (including early error returns) releases it exactly once.
    struct TokenHandle(HANDLE);

    impl Drop for TokenHandle {
        fn drop(&mut self) {
            // SAFETY: the wrapped handle was returned by a successful
            // `OpenProcessToken` call and is closed nowhere else.
            unsafe {
                CloseHandle(self.0);
            }
        }
    }

    // SAFETY: every pointer handed to the Win32 calls below refers to a live
    // local or to `buf`, which outlives all uses of the SID pointer that
    // `GetTokenInformation` stores inside it.
    unsafe {
        let mut raw_token: HANDLE = 0;
        if OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &mut raw_token) == 0 {
            return Err(anyhow!("OpenProcessToken 失败: {}", GetLastError()));
        }
        // From here on the guard closes the handle. Error values are built
        // (and `GetLastError` is read) before the guard is dropped, so the
        // cleanup call can no longer clobber the reported error code.
        let token = TokenHandle(raw_token);

        // First call with an empty buffer only asks for the required size.
        let mut buf_size: u32 = 0;
        GetTokenInformation(
            token.0,
            TOKEN_USER_CLASS,
            std::ptr::null_mut(),
            0,
            &mut buf_size,
        );
        if buf_size == 0 {
            return Err(anyhow!("GetTokenInformation 查询大小失败: {}", GetLastError()));
        }

        let mut buf = vec![0u8; buf_size as usize];
        if GetTokenInformation(
            token.0,
            TOKEN_USER_CLASS,
            buf.as_mut_ptr() as *mut c_void,
            buf_size,
            &mut buf_size,
        ) == 0
        {
            return Err(anyhow!("GetTokenInformation 失败: {}", GetLastError()));
        }

        // Make sure the header we are about to read lies inside the buffer.
        if buf.len() < std::mem::size_of::<TokenUserStruct>() {
            return Err(anyhow!(
                "GetTokenInformation 返回的缓冲区过小: {} 字节",
                buf.len()
            ));
        }

        // A `Vec<u8>` only guarantees 1-byte alignment, so copy the header
        // out with an unaligned read instead of forming a reference to it.
        let token_user = std::ptr::read_unaligned(buf.as_ptr() as *const TokenUserStruct);
        let user_sid = token_user.user.Sid;

        // Convert the binary SID into its textual form; on success the
        // system allocates the string and we must release it with LocalFree.
        let mut sid_str_ptr: *mut u16 = std::ptr::null_mut();
        if ConvertSidToStringSidW(user_sid, &mut sid_str_ptr) == 0 {
            return Err(anyhow!("ConvertSidToStringSidW 失败: {}", GetLastError()));
        }

        // The returned string is NUL-terminated: count units up to the NUL.
        let len = (0..).take_while(|&i| *sid_str_ptr.add(i) != 0).count();
        let sid = String::from_utf16_lossy(std::slice::from_raw_parts(sid_str_ptr, len));

        LocalFree(sid_str_ptr as HLOCAL);

        Ok(sid)
    }
}

/// Decodes bytes using the system OEM code page (the default output
/// encoding of `cmd.exe`).
///
/// `cmd.exe` emits text in the system OEM code page (for example GBK/936 on
/// Chinese Windows, CP437 on English Windows). This function decodes through
/// the Windows API `MultiByteToWideChar` with that code page, which is more
/// accurate than hard-coding GBK.
///
/// Returns `None` when the OEM code page is unavailable, when it is UTF-8
/// (65001; the caller is expected to have tried UTF-8 already), when the
/// input is too long to describe with an `i32` length, or when decoding
/// fails. Empty input decodes to an empty string.
#[cfg(windows)]
pub fn decode_oem_string(bytes: &[u8]) -> Option<String> {
    use windows_sys::Win32::Globalization::{GetOEMCP, MultiByteToWideChar};

    // SAFETY: `GetOEMCP` takes no arguments and has no preconditions.
    let codepage = unsafe { GetOEMCP() };
    if codepage == 0 {
        return None;
    }

    // If the OEM code page happens to be UTF-8 (65001), skip it: the caller
    // has already attempted UTF-8 decoding.
    if codepage == 65001 {
        return None;
    }

    // `MultiByteToWideChar` rejects a zero-length input, which would be
    // indistinguishable from a decoding failure; handle it up front.
    if bytes.is_empty() {
        return Some(String::new());
    }

    // The API takes the length as `i32`. A silent `as` cast could wrap to a
    // negative value, and -1 means "input is NUL-terminated", which would
    // read past the end of the slice. Refuse oversized input instead.
    let byte_len = i32::try_from(bytes.len()).ok()?;

    // First call: ask how many UTF-16 units the decoded text needs.
    // SAFETY: `bytes` is valid for `byte_len` bytes; a null output pointer
    // with size 0 is the documented way to query the required length.
    let needed = unsafe {
        MultiByteToWideChar(
            codepage,
            0,
            bytes.as_ptr(),
            byte_len,
            std::ptr::null_mut(),
            0,
        )
    };
    if needed <= 0 {
        return None;
    }

    let mut buf = vec![0u16; needed as usize];
    // SAFETY: `buf` holds exactly `needed` writable `u16` slots and `bytes`
    // is valid for `byte_len` bytes.
    let written = unsafe {
        MultiByteToWideChar(
            codepage,
            0,
            bytes.as_ptr(),
            byte_len,
            buf.as_mut_ptr(),
            needed,
        )
    };
    if written <= 0 {
        return None;
    }

    // Only the units actually written are meaningful.
    let decoded = buf.get(..written as usize)?;
    Some(String::from_utf16_lossy(decoded))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Terminator every `to_wide` result is expected to end with.
    const NUL: u16 = 0;

    #[test]
    fn test_to_wide_ascii() {
        // "hello" as UTF-16 code units, followed by the terminator.
        let expected: Vec<u16> = vec![104, 101, 108, 108, 111, NUL];
        assert_eq!(to_wide("hello"), expected);
    }

    #[test]
    fn test_to_wide_ends_with_null() {
        let units = to_wide("test");
        assert_eq!(units.last(), Some(&NUL));
        // 't', 'e', 's', 't' plus the terminator.
        assert_eq!(units.len(), 5);
    }

    #[test]
    fn test_to_wide_empty_string() {
        // An empty string still yields the lone terminator.
        assert_eq!(to_wide(""), vec![NUL]);
    }

    #[test]
    fn test_to_wide_unicode() {
        // U+4E2D and U+6587 each fit in a single UTF-16 unit.
        let expected: Vec<u16> = vec![0x4E2D, 0x6587, NUL];
        assert_eq!(to_wide("中文"), expected);
    }

    #[test]
    fn test_to_wide_with_emoji() {
        let units = to_wide("a🌍b");
        // U+1F30D is encoded as the surrogate pair 0xD83C 0xDF0D, sitting
        // between 'a' and 'b', with the terminator right after 'b'.
        assert_eq!(units[..5], [0x0061, 0xD83C, 0xDF0D, 0x0062, NUL]);
    }

    #[test]
    fn test_to_wide_path_with_spaces() {
        let units = to_wide("C:\\Program Files\\test");
        // The space character must survive the conversion.
        assert!(units.iter().any(|&unit| unit == 0x0020));
        assert_eq!(units.last(), Some(&NUL));
    }

    #[test]
    fn test_to_wide_preserves_case() {
        assert_ne!(to_wide("abc"), to_wide("ABC"), "大小写应被保留");
    }

    #[test]
    fn test_get_current_user_sid_works() -> Result<()> {
        let sid = get_current_user_sid()?;
        // Any "S-1-5-21-..." SID also starts with "S-1-5-", so checking the
        // shorter prefix covers both accepted forms.
        assert!(sid.starts_with("S-1-5-"), "用户 SID 格式无效: {sid}");
        assert!(!sid.is_empty(), "SID 不应为空");
        Ok(())
    }
}