use std::{ffi, mem::MaybeUninit, path::Path, ptr};
use widestring::{U16CString, U16Str, U16String};
use windows::{
core::{self, HSTRING, PWSTR},
Storage::Provider::StorageProviderSyncRootManager,
Win32::{
Foundation::{self, GetLastError, LocalFree, HANDLE, HLOCAL, SUCCESS},
Security::{self, Authorization::ConvertSidToStringSidW, GetTokenInformation, TOKEN_USER},
Storage::CloudFilters,
},
};
use crate::ext::PathExt;
/// Placeholder that is not implemented yet.
///
/// NOTE(review): the name suggests this should list the active sync roots; its
/// intended signature/return type is still undecided — TODO confirm.
///
/// # Panics
///
/// Always panics (via `todo!()`).
pub fn active_roots() {
    todo!()
}
/// Reports whether sync roots are supported on this system, as answered by
/// `StorageProviderSyncRootManager::IsSupported`.
///
/// # Errors
///
/// Propagates any error returned by `IsSupported`.
pub fn is_supported() -> core::Result<bool> {
    let supported = StorageProviderSyncRootManager::IsSupported()?;
    Ok(supported)
}
#[derive(Debug, Clone)]
pub struct SyncRootIdBuilder {
provider_name: U16String,
user_security_id: SecurityId,
account_name: U16String,
}
impl SyncRootIdBuilder {
pub fn new(provider_name: U16String) -> Self {
assert!(
provider_name.len() <= CloudFilters::CF_MAX_PROVIDER_NAME_LENGTH as usize,
"provider name must not exceed {} characters, got {} characters",
CloudFilters::CF_MAX_PROVIDER_NAME_LENGTH,
provider_name.len()
);
Self {
provider_name,
user_security_id: SecurityId(U16String::new()),
account_name: U16String::new(),
}
}
pub fn user_security_id(mut self, security_id: SecurityId) -> Self {
self.user_security_id = security_id;
self
}
pub fn account_name(mut self, account_name: U16String) -> Self {
self.account_name = account_name;
self
}
pub fn build(self) -> core::Result<SyncRootId> {
Ok(SyncRootId(HSTRING::from_wide(
&[
self.provider_name.as_slice(),
self.user_security_id.0.as_slice(),
self.account_name.as_slice(),
]
.join(&SyncRootId::SEPARATOR),
)))
}
}
/// Identifier of a sync root, made of three `!`-separated components:
/// provider name, user security ID, and account name.
#[derive(Debug, Clone)]
pub struct SyncRootId(HSTRING);

impl SyncRootId {
    /// Separator between the components of the ID (`'!'`, 0x21).
    const SEPARATOR: u16 = b'!' as u16;

    /// Returns the `Id` of the sync root information obtained for `path` through
    /// `PathExt::sync_root_info`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `sync_root_info` and from reading its `Id`.
    pub fn from_path<P: AsRef<Path>>(path: P) -> core::Result<Self> {
        Ok(Self(path.as_ref().sync_root_info()?.Id()?))
    }

    /// Reports whether a sync root with this ID is registered.
    ///
    /// # Errors
    ///
    /// A lookup failing with `ERROR_NOT_FOUND` yields `Ok(false)`. Any other lookup failure is
    /// returned as an error rather than being treated as "registered".
    pub fn is_registered(&self) -> core::Result<bool> {
        match StorageProviderSyncRootManager::GetSyncRootInformationForId(&self.0) {
            Ok(_) => Ok(true),
            Err(err) if err.code() == Foundation::ERROR_NOT_FOUND.to_hresult() => Ok(false),
            Err(err) => Err(err),
        }
    }

    /// Unregisters the sync root with this ID via `StorageProviderSyncRootManager::Unregister`.
    pub fn unregister(&self) -> core::Result<()> {
        StorageProviderSyncRootManager::Unregister(&self.0)
    }

    /// Borrows the ID as a UTF-16 string slice.
    pub fn as_u16str(&self) -> &U16Str {
        U16Str::from_slice(&self.0)
    }

    /// Borrows the underlying `HSTRING`.
    pub fn as_hstring(&self) -> &HSTRING {
        &self.0
    }

    /// Splits the ID into `(provider name, security ID, account name)`.
    ///
    /// Only the first two `!` separators split the ID. Any further `!` stays part of the
    /// account name.
    ///
    /// # Errors
    ///
    /// Returns `ERROR_INVALID_DATA` if the ID has fewer than three components.
    pub fn to_components(&self) -> core::Result<(&U16Str, &U16Str, &U16Str)> {
        let mut parts = self.0.splitn(3, |&unit| unit == Self::SEPARATOR);
        let (Some(provider), Some(security_id), Some(account)) =
            (parts.next(), parts.next(), parts.next())
        else {
            return Err(Foundation::ERROR_INVALID_DATA.into());
        };
        Ok((
            U16Str::from_slice(provider),
            U16Str::from_slice(security_id),
            U16Str::from_slice(account),
        ))
    }
}
/// String form of a Windows security identifier (SID), such as the text produced by
/// `ConvertSidToStringSidW`.
#[derive(Debug, Clone)]
pub struct SecurityId(U16String);

impl SecurityId {
    /// Pseudo-handle for the calling thread's effective token. The Windows SDK defines
    /// `GetCurrentThreadEffectiveToken()` as the handle value `-6`.
    const CURRENT_THREAD_EFFECTIVE_TOKEN: HANDLE = HANDLE(-6i32 as *mut ffi::c_void);

    /// Wraps `id` without checking that it is a well-formed SID string.
    pub fn new_unchecked(id: U16String) -> Self {
        Self(id)
    }

    /// Returns the SID, in string form, of the user of the calling thread's effective token.
    ///
    /// # Errors
    ///
    /// Returns an error if any of the following fails:
    /// - querying the token's `TokenUser` information;
    /// - converting the SID to a string;
    /// - freeing the string returned by `ConvertSidToStringSidW`.
    pub fn current_user() -> core::Result<Self> {
        // `TokenUser` yields a TOKEN_USER header followed by the variable-length SID it points
        // to, so the required size must be queried first.
        let mut token_size = 0u32;
        // SAFETY: no buffer with length 0 is the documented way to ask for the required size;
        // `token_size` is a valid out-pointer.
        let probe = unsafe {
            GetTokenInformation(
                Self::CURRENT_THREAD_EFFECTIVE_TOKEN,
                Security::TokenUser,
                None,
                0,
                &mut token_size,
            )
        };
        if let Err(err) = probe {
            if err.code() != Foundation::ERROR_INSUFFICIENT_BUFFER.to_hresult() {
                return Err(err);
            }
        }

        // Heap buffer of at least `token_size` bytes (and never smaller than the header),
        // made of `u64` words so it is suitably aligned for TOKEN_USER.
        let byte_len = (token_size as usize).max(std::mem::size_of::<TOKEN_USER>());
        let mut buffer =
            vec![MaybeUninit::<u64>::uninit(); byte_len.div_ceil(std::mem::size_of::<u64>())];

        // SAFETY: `buffer` is writable for at least `token_size` bytes.
        unsafe {
            GetTokenInformation(
                Self::CURRENT_THREAD_EFFECTIVE_TOKEN,
                Security::TokenUser,
                Some(buffer.as_mut_ptr().cast()),
                token_size,
                &mut token_size,
            )?;
        }

        // SAFETY: the successful call above initialized a TOKEN_USER at the start of the
        // aligned `buffer`. The SID it points to lives inside `buffer`, which stays alive
        // until the end of this function.
        let token: TOKEN_USER = unsafe { ptr::read(buffer.as_ptr().cast::<TOKEN_USER>()) };

        let mut sid = PWSTR(ptr::null_mut());
        // SAFETY: `token.User.Sid` points to a valid SID inside `buffer`; `sid` is a valid
        // out-pointer.
        unsafe { ConvertSidToStringSidW(token.User.Sid, &mut sid as *mut _)? };

        // SAFETY: on success `sid` is a NUL-terminated wide string that must be released
        // with `LocalFree`; it is copied before being freed.
        let string_sid = unsafe { U16CString::from_ptr_str(sid.0) }.into_ustring();
        // SAFETY: `sid` was allocated by `ConvertSidToStringSidW` and is freed exactly once.
        let freed = unsafe { LocalFree(Some(HLOCAL(sid.0 as *mut _))) };
        // `LocalFree` returns a non-null handle on failure; the last error holds the reason.
        if !freed.0.is_null() {
            let last_error = unsafe { GetLastError() };
            if last_error.0 != SUCCESS {
                return Err(last_error.into());
            }
        }

        Ok(SecurityId::new_unchecked(string_sid))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a literal in a `SyncRootId` directly, bypassing the builder.
    fn raw_id(text: &str) -> SyncRootId {
        SyncRootId(HSTRING::from(text))
    }

    #[test]
    fn test_syncroot_id_parse() {
        let id = raw_id("provider-id!security-id!account-name");
        let parsed = id.to_components();
        assert!(parsed.is_ok());

        let (provider, security, account) = parsed.unwrap();
        assert_eq!(provider, U16String::from("provider-id"));
        assert_eq!(security, U16String::from("security-id"));
        assert_eq!(account, U16String::from("account-name"));
    }

    #[test]
    fn test_invalid_syncroot_id_parse() {
        // Only one separator: the account component is missing.
        assert!(raw_id("provider-id!security-id0000").to_components().is_err());
    }

    #[test]
    fn test_empty_syncroot_id_parse() {
        assert!(raw_id("").to_components().is_err());
    }
}