use byteorder::{ByteOrder, LittleEndian};
use std::collections::HashMap;
use std::iter::once;
use std::mem::MaybeUninit;
use std::str;
use windows_sys::Win32::Foundation::{
GetLastError, ERROR_BAD_USERNAME, ERROR_INVALID_FLAGS, ERROR_INVALID_PARAMETER,
ERROR_NOT_FOUND, ERROR_NO_SUCH_LOGON_SESSION, FILETIME,
};
use windows_sys::Win32::Security::Credentials::{
CredDeleteW, CredFree, CredReadW, CredWriteW, CREDENTIALW, CREDENTIAL_ATTRIBUTEW, CRED_FLAGS,
CRED_MAX_CREDENTIAL_BLOB_SIZE, CRED_MAX_GENERIC_TARGET_NAME_LENGTH, CRED_MAX_STRING_LENGTH,
CRED_MAX_USERNAME_LENGTH, CRED_PERSIST_ENTERPRISE, CRED_TYPE_GENERIC,
};
use super::credential::{Credential, CredentialApi, CredentialBuilder, CredentialBuilderApi};
use super::error::{Error as ErrorCode, Result};
/// A generic credential stored in the Windows Credential Manager.
///
/// `target_name` is the key used to read, write, and delete the entry
/// (`CredReadW`, `CredWriteW`, `CredDeleteW`); the other fields are metadata
/// written alongside the secret.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WinCredential {
/// Value written to the credential's `UserName` field.
pub username: String,
/// Name the credential is stored under (its `TargetName`); must be non-empty.
pub target_name: String,
/// Value written to the credential's `TargetAlias` field.
pub target_alias: String,
/// Value written to the credential's `Comment` field.
pub comment: String,
}
impl CredentialApi for WinCredential {
fn set_password(&self, password: &str) -> Result<()> {
self.validate_attributes(None, Some(password))?;
let blob_u16 = to_wstr_no_null(password);
let mut blob = vec![0; blob_u16.len() * 2];
LittleEndian::write_u16_into(&blob_u16, &mut blob);
self.set_secret(&blob)
}
fn set_secret(&self, secret: &[u8]) -> Result<()> {
self.validate_attributes(Some(secret), None)?;
self.save_credential(secret)
}
fn get_password(&self) -> Result<String> {
self.extract_from_platform(extract_password)
}
fn get_secret(&self) -> Result<Vec<u8>> {
self.extract_from_platform(extract_secret)
}
fn get_attributes(&self) -> Result<HashMap<String, String>> {
let cred = self.extract_from_platform(Self::extract_credential)?;
let mut attributes: HashMap<String, String> = HashMap::new();
attributes.insert("comment".to_string(), cred.comment.clone());
attributes.insert("target_alias".to_string(), cred.target_alias.clone());
attributes.insert("username".to_string(), cred.username.clone());
Ok(attributes)
}
fn update_attributes(&self, attributes: &HashMap<&str, &str>) -> Result<()> {
let secret = self.extract_from_platform(extract_secret)?;
let mut cred = self.extract_from_platform(Self::extract_credential)?;
if let Some(comment) = attributes.get(&"comment") {
cred.comment = comment.to_string();
}
if let Some(target_alias) = attributes.get(&"target_alias") {
cred.target_alias = target_alias.to_string();
}
if let Some(username) = attributes.get(&"username") {
cred.username = username.to_string();
}
cred.validate_attributes(Some(&secret), None)?;
cred.save_credential(&secret)
}
fn delete_credential(&self) -> Result<()> {
self.validate_attributes(None, None)?;
let target_name = to_wstr(&self.target_name);
let cred_type = CRED_TYPE_GENERIC;
match unsafe { CredDeleteW(target_name.as_ptr(), cred_type, 0) } {
0 => Err(decode_error()),
_ => Ok(()),
}
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn debug_fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Debug::fmt(self, f)
}
}
impl WinCredential {
    /// Check this credential's fields against the Windows Credential Manager
    /// size limits. An optional `secret` or `password` is checked too.
    ///
    /// String limits are counted in UTF-16 code units. The strings are handed
    /// to the wide-character API (`to_wstr` then `CredWriteW`), and Windows
    /// states these maxima in characters of that encoding. Counting UTF-8
    /// bytes, as before, wrongly rejected non-ASCII values that fit. For ASCII
    /// input the two counts agree, so those boundaries are unchanged.
    ///
    /// # Errors
    /// - [`ErrorCode::TooLong`] naming the offending field and its limit.
    /// - [`ErrorCode::Invalid`] if `target_name` is empty.
    fn validate_attributes(&self, secret: Option<&[u8]>, password: Option<&str>) -> Result<()> {
        /// Length of `s` in UTF-16 code units (the wide characters Windows counts).
        fn wide_len(s: &str) -> usize {
            s.encode_utf16().count()
        }
        if wide_len(&self.username) > CRED_MAX_USERNAME_LENGTH as usize {
            return Err(ErrorCode::TooLong(
                String::from("user"),
                CRED_MAX_USERNAME_LENGTH,
            ));
        }
        if self.target_name.is_empty() {
            return Err(ErrorCode::Invalid(
                "target".to_string(),
                "cannot be empty".to_string(),
            ));
        }
        if wide_len(&self.target_name) > CRED_MAX_GENERIC_TARGET_NAME_LENGTH as usize {
            return Err(ErrorCode::TooLong(
                String::from("target"),
                CRED_MAX_GENERIC_TARGET_NAME_LENGTH,
            ));
        }
        if wide_len(&self.target_alias) > CRED_MAX_STRING_LENGTH as usize {
            return Err(ErrorCode::TooLong(
                String::from("target alias"),
                CRED_MAX_STRING_LENGTH,
            ));
        }
        if wide_len(&self.comment) > CRED_MAX_STRING_LENGTH as usize {
            return Err(ErrorCode::TooLong(
                String::from("comment"),
                CRED_MAX_STRING_LENGTH,
            ));
        }
        if let Some(secret) = secret {
            if secret.len() > CRED_MAX_CREDENTIAL_BLOB_SIZE as usize {
                return Err(ErrorCode::TooLong(
                    String::from("secret"),
                    CRED_MAX_CREDENTIAL_BLOB_SIZE,
                ));
            }
        }
        if let Some(password) = password {
            // Passwords are stored as 2 bytes per UTF-16 code unit.
            if wide_len(password) * 2 > CRED_MAX_CREDENTIAL_BLOB_SIZE as usize {
                return Err(ErrorCode::TooLong(
                    String::from("password encoded as UTF-16"),
                    CRED_MAX_CREDENTIAL_BLOB_SIZE,
                ));
            }
        }
        Ok(())
    }

    /// Write this credential, with `secret` as its blob, via `CredWriteW`.
    ///
    /// This does not validate anything. Callers run `validate_attributes`
    /// first, which is what guarantees `secret.len()` fits the `u32`
    /// blob-size field.
    fn save_credential(&self, secret: &[u8]) -> Result<()> {
        // CREDENTIALW takes `*mut` string and blob pointers. Keep owned,
        // mutable buffers alive for the duration of the CredWriteW call.
        let mut username = to_wstr(&self.username);
        let mut target_name = to_wstr(&self.target_name);
        let mut target_alias = to_wstr(&self.target_alias);
        let mut comment = to_wstr(&self.comment);
        let mut blob = secret.to_vec();
        // Cannot truncate: callers have validated the length against
        // CRED_MAX_CREDENTIAL_BLOB_SIZE.
        let blob_len = blob.len() as u32;
        let attributes: *mut CREDENTIAL_ATTRIBUTEW = std::ptr::null_mut();
        let credential = CREDENTIALW {
            Flags: CRED_FLAGS::default(),
            Type: CRED_TYPE_GENERIC,
            TargetName: target_name.as_mut_ptr(),
            Comment: comment.as_mut_ptr(),
            // Zeroed; presumably ignored by CredWriteW.
            // NOTE(review): confirm against the Win32 docs.
            LastWritten: FILETIME {
                dwLowDateTime: 0,
                dwHighDateTime: 0,
            },
            CredentialBlobSize: blob_len,
            CredentialBlob: blob.as_mut_ptr(),
            Persist: CRED_PERSIST_ENTERPRISE,
            AttributeCount: 0,
            Attributes: attributes,
            TargetAlias: target_alias.as_mut_ptr(),
            UserName: username.as_mut_ptr(),
        };
        // SAFETY: every pointer in `credential` refers to a buffer owned by
        // this function that outlives the call. The strings are
        // null-terminated by `to_wstr`, and `CredentialBlobSize` equals
        // `blob.len()`.
        match unsafe { CredWriteW(&credential, 0) } {
            0 => Err(decode_error()),
            _ => Ok(()),
        }
    }

    /// Read the stored credential's metadata (not its secret) back from the
    /// platform.
    pub fn get_credential(&self) -> Result<Self> {
        self.extract_from_platform(Self::extract_credential)
    }

    /// Read the generic credential named by `target_name` and pass it to `f`.
    ///
    /// The platform buffer returned by `CredReadW` is released with
    /// `CredFree` after `f` returns, including when `f` panics.
    fn extract_from_platform<F, T>(&self, f: F) -> Result<T>
    where
        F: FnOnce(&CREDENTIALW) -> Result<T>,
    {
        /// Owns a `CredReadW` buffer and frees it exactly once when dropped.
        struct CredGuard(*mut CREDENTIALW);
        impl Drop for CredGuard {
            fn drop(&mut self) {
                // SAFETY: the pointer came from a successful CredReadW call
                // and is freed only here.
                unsafe { CredFree(self.0 as *mut _) };
            }
        }

        self.validate_attributes(None, None)?;
        let target_name = to_wstr(&self.target_name);
        let mut p_credential = MaybeUninit::uninit();
        // SAFETY: `target_name` is a null-terminated UTF-16 buffer that
        // outlives the call, and `p_credential` is a valid place for CredReadW
        // to store the pointer it allocates.
        let result = unsafe {
            CredReadW(
                target_name.as_ptr(),
                CRED_TYPE_GENERIC,
                0,
                p_credential.as_mut_ptr(),
            )
        };
        if result == 0 {
            // Decode right away so nothing can overwrite the last-error value.
            return Err(decode_error());
        }
        // SAFETY: CredReadW reported success, so it wrote the output pointer.
        let guard = CredGuard(unsafe { p_credential.assume_init() });
        // SAFETY: the buffer stays valid until `guard` is dropped, which is
        // after `f` has returned.
        let w_credential: &CREDENTIALW = unsafe { &*guard.0 };
        f(w_credential)
    }

    /// Build a `WinCredential` from the string fields of a platform
    /// credential. Null string pointers become empty strings.
    fn extract_credential(w_credential: &CREDENTIALW) -> Result<Self> {
        // SAFETY: each field is expected to be null or a null-terminated
        // UTF-16 string. That holds for buffers returned by CredReadW.
        Ok(Self {
            username: unsafe { from_wstr(w_credential.UserName) },
            target_name: unsafe { from_wstr(w_credential.TargetName) },
            target_alias: unsafe { from_wstr(w_credential.TargetAlias) },
            comment: unsafe { from_wstr(w_credential.Comment) },
        })
    }

    /// Create a credential description for `service`/`user`. Nothing is
    /// written to the platform.
    ///
    /// If `target` is given, it becomes the target name. Otherwise the target
    /// name is `"{user}.{service}"`. The comment records user, service, target
    /// and this crate's version.
    ///
    /// # Errors
    /// Any error from `validate_attributes`, e.g. an empty `target`.
    pub fn new_with_target(
        target: Option<&str>,
        service: &str,
        user: &str,
    ) -> Result<WinCredential> {
        const VERSION: &str = env!("CARGO_PKG_VERSION");
        let credential = if let Some(target) = target {
            Self {
                username: user.to_string(),
                target_name: target.to_string(),
                target_alias: String::new(),
                comment: format!("{user}@{service}:{target} (keyring v{VERSION})"),
            }
        } else {
            Self {
                username: user.to_string(),
                target_name: format!("{user}.{service}"),
                target_alias: String::new(),
                comment: format!("{user}@{service}:{user}.{service} (keyring v{VERSION})"),
            }
        };
        credential.validate_attributes(None, None)?;
        Ok(credential)
    }
}
/// Builder that produces [`WinCredential`]s for the generic entry API.
pub struct WinCredentialBuilder {}

/// Return the credential builder used by default on this platform.
pub fn default_credential_builder() -> Box<CredentialBuilder> {
    let builder = WinCredentialBuilder {};
    Box::new(builder)
}

impl CredentialBuilderApi for WinCredentialBuilder {
    /// Build a [`WinCredential`] for the given target, service and user,
    /// boxed as a generic credential.
    ///
    /// Errors from [`WinCredential::new_with_target`] are propagated
    /// unchanged.
    fn build(&self, target: Option<&str>, service: &str, user: &str) -> Result<Box<Credential>> {
        let credential = WinCredential::new_with_target(target, service, user)?;
        Ok(Box::new(credential))
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Decode a credential's secret as a UTF-16LE password.
///
/// # Errors
/// Returns `ErrorCode::BadEncoding` carrying the raw bytes when the blob has
/// an odd length or is not valid UTF-16.
fn extract_password(credential: &CREDENTIALW) -> Result<String> {
    let bytes = extract_secret(credential)?;
    // UTF-16 is made of whole 2-byte code units.
    if bytes.len() % 2 == 1 {
        return Err(ErrorCode::BadEncoding(bytes));
    }
    let units: Vec<u16> = bytes
        .chunks_exact(2)
        .map(|pair| u16::from_le_bytes([pair[0], pair[1]]))
        .collect();
    match String::from_utf16(&units) {
        Ok(password) => Ok(password),
        Err(_) => Err(ErrorCode::BadEncoding(bytes)),
    }
}
/// Copy a credential's blob into an owned byte vector.
///
/// An empty blob yields an empty vector without dereferencing the blob
/// pointer, which may be null in that case.
fn extract_secret(credential: &CREDENTIALW) -> Result<Vec<u8>> {
    let size = credential.CredentialBlobSize as usize;
    let secret = if size > 0 {
        // SAFETY: relies on `CredentialBlob` pointing to at least `size`
        // readable bytes for as long as `credential` is borrowed. This holds
        // for CredReadW buffers and for the test fixtures built from a live Vec.
        let bytes =
            unsafe { std::slice::from_raw_parts(credential.CredentialBlob as *const u8, size) };
        bytes.to_vec()
    } else {
        Vec::new()
    };
    Ok(secret)
}
/// Encode `s` as UTF-16 followed by a terminating `0` unit, for the
/// wide-string (`...W`) Windows APIs.
fn to_wstr(s: &str) -> Vec<u16> {
    let mut wide: Vec<u16> = s.encode_utf16().collect();
    wide.push(0);
    wide
}

/// Encode `s` as UTF-16 with no terminator. This is used for password blobs,
/// whose length is stored separately.
fn to_wstr_no_null(s: &str) -> Vec<u16> {
    s.encode_utf16().collect::<Vec<u16>>()
}
/// Convert a null-terminated UTF-16 string into a Rust `String`.
///
/// A null pointer or an empty string yields `""`. Invalid UTF-16, such as a
/// lone surrogate, is replaced with U+FFFD rather than reported as an error.
///
/// # Safety
/// `ws` must be null, or point to a sequence of `u16` values terminated by a
/// `0` unit. Every unit up to and including the terminator must be readable.
unsafe fn from_wstr(ws: *const u16) -> String {
    if ws.is_null() {
        return String::new();
    }
    let mut len = 0usize;
    // SAFETY: the caller guarantees every unit up to the terminator is readable.
    while unsafe { *ws.add(len) } != 0 {
        len += 1;
    }
    match len {
        0 => String::new(),
        _ => {
            // SAFETY: the `len` units before the terminator were just read above.
            let units = unsafe { std::slice::from_raw_parts(ws, len) };
            String::from_utf16_lossy(units)
        }
    }
}
/// A raw Windows error code, wrapped so it can be carried as the underlying
/// cause of a keyring `ErrorCode`.
#[derive(Debug)]
pub struct Error(pub u32);

impl std::fmt::Display for Error {
    /// Render well-known credential-manager codes by name, and any other
    /// code as its number.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let code = self.0;
        let text = match code {
            ERROR_NO_SUCH_LOGON_SESSION => "Windows ERROR_NO_SUCH_LOGON_SESSION",
            ERROR_NOT_FOUND => "Windows ERROR_NOT_FOUND",
            ERROR_BAD_USERNAME => "Windows ERROR_BAD_USERNAME",
            ERROR_INVALID_FLAGS => "Windows ERROR_INVALID_FLAGS",
            ERROR_INVALID_PARAMETER => "Windows ERROR_INVALID_PARAMETER",
            other => return write!(f, "Windows error code {other}"),
        };
        f.write_str(text)
    }
}

/// There is no further underlying cause, so the default `source` (which
/// returns `None`) applies.
impl std::error::Error for Error {}
pub fn decode_error() -> ErrorCode {
match unsafe { GetLastError() } {
ERROR_NOT_FOUND => ErrorCode::NoEntry,
ERROR_NO_SUCH_LOGON_SESSION => {
ErrorCode::NoStorageAccess(wrap(ERROR_NO_SUCH_LOGON_SESSION))
}
err => ErrorCode::PlatformFailure(wrap(err)),
}
}
fn wrap(code: u32) -> Box<dyn std::error::Error + Send + Sync> {
Box::new(Error(code))
}
// Tests for the Windows credential store. Several of them (for example
// test_get_update_attributes and test_get_credential) create, read, and delete
// real entries through the Credential Manager calls used above.
#[cfg(test)]
mod tests {
use super::*;
use crate::credential::CredentialPersistence;
use crate::tests::{generate_random_string, generate_random_string_of_len};
use crate::Entry;
// The default builder should report that credentials persist until deleted.
#[test]
fn test_persistence() {
assert!(matches!(
default_credential_builder().persistence(),
CredentialPersistence::UntilDelete
))
}
// Adapter that lets the shared crate::tests helpers build Windows-backed entries.
fn entry_new(service: &str, user: &str) -> Entry {
crate::tests::entry_from_constructor(WinCredential::new_with_target, service, user)
}
// extract_password must reject an odd-length blob and malformed UTF-16 (the
// unpaired high surrogate 0xD800) with BadEncoding carrying the raw bytes.
#[test]
fn test_bad_password() {
// Wraps `password` in a CREDENTIALW whose blob points into the Vec, so the
// Vec must outlive the returned struct. All string fields are null.
fn make_platform_credential(password: &mut Vec<u8>) -> CREDENTIALW {
let last_written = FILETIME {
dwLowDateTime: 0,
dwHighDateTime: 0,
};
let attribute_count = 0;
let attributes: *mut CREDENTIAL_ATTRIBUTEW = std::ptr::null_mut();
CREDENTIALW {
Flags: 0,
Type: CRED_TYPE_GENERIC,
TargetName: std::ptr::null_mut(),
Comment: std::ptr::null_mut(),
LastWritten: last_written,
CredentialBlobSize: password.len() as u32,
CredentialBlob: password.as_mut_ptr(),
Persist: CRED_PERSIST_ENTERPRISE,
AttributeCount: attribute_count,
Attributes: attributes,
TargetAlias: std::ptr::null_mut(),
UserName: std::ptr::null_mut(),
}
}
let mut odd_bytes = b"1".to_vec();
// A valid surrogate pair followed by text containing a lone high surrogate.
let malformed_utf16 = [0xD834, 0xDD1E, 0x006d, 0x0075, 0xD800, 0x0069, 0x0063];
let mut malformed_bytes: Vec<u8> = vec![0; malformed_utf16.len() * 2];
LittleEndian::write_u16_into(&malformed_utf16, &mut malformed_bytes);
for bytes in [&mut odd_bytes, &mut malformed_bytes] {
let credential = make_platform_credential(bytes);
match extract_password(&credential) {
Err(ErrorCode::BadEncoding(str)) => assert_eq!(&str, bytes),
Err(other) => panic!("Bad password ({bytes:?}) decode gave wrong error: {other}"),
Ok(s) => panic!("Bad password ({bytes:?}) decode gave results: {s:?}"),
}
}
}
// Each attribute set one unit over its limit must yield TooLong naming that
// attribute and that limit.
#[test]
fn test_validate_attributes() {
fn validate_attribute_too_long(result: Result<()>, attr: &str, len: u32) {
match result {
Err(ErrorCode::TooLong(arg, val)) => {
if attr == "password" {
assert_eq!(
&arg, "password encoded as UTF-16",
"Error names wrong attribute"
);
} else {
assert_eq!(&arg, attr, "Error names wrong attribute");
}
assert_eq!(val, len, "Error names wrong limit");
}
Err(other) => panic!("Error is not '{attr} too long': {other}"),
Ok(_) => panic!("No error when {attr} too long"),
}
}
let cred = WinCredential {
username: "username".to_string(),
target_name: "target_name".to_string(),
target_alias: "target_alias".to_string(),
comment: "comment".to_string(),
};
for (attr, len) in [
("user", CRED_MAX_USERNAME_LENGTH),
("target", CRED_MAX_GENERIC_TARGET_NAME_LENGTH),
("target alias", CRED_MAX_STRING_LENGTH),
("comment", CRED_MAX_STRING_LENGTH),
("password", CRED_MAX_CREDENTIAL_BLOB_SIZE),
("secret", CRED_MAX_CREDENTIAL_BLOB_SIZE),
] {
let long_string = generate_random_string_of_len(1 + len as usize);
let mut bad_cred = cred.clone();
match attr {
"user" => bad_cred.username = long_string.clone(),
"target" => bad_cred.target_name = long_string.clone(),
"target alias" => bad_cred.target_alias = long_string.clone(),
"comment" => bad_cred.comment = long_string.clone(),
_ => (),
}
let validate = |r| validate_attribute_too_long(r, attr, len);
match attr {
"password" => {
// Assuming the generated string is one UTF-16 unit per char,
// len/2 + 1 chars encode to just over `len` bytes.
let password = generate_random_string_of_len((len / 2) as usize + 1);
validate(bad_cred.validate_attributes(None, Some(&password)))
}
"secret" => {
let secret: Vec<u8> = vec![255u8; len as usize + 1];
validate(bad_cred.validate_attributes(Some(&secret), None))
}
_ => validate(bad_cred.validate_attributes(None, None)),
}
}
}
// A password of limit/2 copies of a character that is 3 bytes in UTF-8 but
// one UTF-16 unit: too long measured in UTF-8 bytes, yet exactly at the blob
// limit once encoded as UTF-16, so it must be accepted.
#[test]
fn test_password_valid_only_after_conversion_to_utf16() {
let cred = WinCredential {
username: "username".to_string(),
target_name: "target_name".to_string(),
target_alias: "target_alias".to_string(),
comment: "comment".to_string(),
};
let len = CRED_MAX_CREDENTIAL_BLOB_SIZE / 2;
let password: String = (0..len).map(|_| "笑").collect();
assert!(password.len() > CRED_MAX_CREDENTIAL_BLOB_SIZE as usize);
cred.validate_attributes(None, Some(&password))
.expect("Password of appropriate length in UTF16 was invalid");
}
// An explicitly empty target must be rejected when the credential is built.
#[test]
fn test_invalid_parameter() {
let credential = WinCredential::new_with_target(Some(""), "service", "user");
assert!(
matches!(credential, Err(ErrorCode::Invalid(_, _))),
"Created entry with empty target"
);
}
#[test]
fn test_empty_service_and_user() {
crate::tests::test_empty_service_and_user(entry_new);
}
#[test]
fn test_missing_entry() {
crate::tests::test_missing_entry(entry_new);
}
#[test]
fn test_empty_password() {
crate::tests::test_empty_password(entry_new);
}
#[test]
fn test_round_trip_ascii_password() {
crate::tests::test_round_trip_ascii_password(entry_new);
}
#[test]
fn test_round_trip_non_ascii_password() {
crate::tests::test_round_trip_non_ascii_password(entry_new);
}
#[test]
fn test_round_trip_random_secret() {
crate::tests::test_round_trip_random_secret(entry_new);
}
#[test]
fn test_update() {
crate::tests::test_update(entry_new);
}
// Lifecycle of the attribute API against a real entry:
// 1. Reading or updating a missing entry gives NoEntry.
// 2. After creation, the attributes match the constructed credential.
// 3. Updates apply only the known keys; unknown keys are ignored.
// 4. After deletion, the entry reads as missing again.
#[test]
fn test_get_update_attributes() {
let name = generate_random_string();
let cred = WinCredential::new_with_target(None, &name, &name)
.expect("Can't create credential for attribute test");
let entry = Entry::new_with_credential(Box::new(cred.clone()));
assert!(
matches!(entry.get_attributes(), Err(ErrorCode::NoEntry)),
"Read missing credential in attribute test",
);
let mut in_map: HashMap<&str, &str> = HashMap::new();
in_map.insert("label", "ignored label value");
in_map.insert("attribute name", "ignored attribute value");
in_map.insert("target_alias", "target alias value");
in_map.insert("comment", "comment value");
in_map.insert("username", "username value");
assert!(
matches!(entry.update_attributes(&in_map), Err(ErrorCode::NoEntry)),
"Updated missing credential in attribute test",
);
entry
.set_password("test password for attributes")
.unwrap_or_else(|err| panic!("Can't set password for attribute test: {err:?}"));
let out_map = entry
.get_attributes()
.expect("Can't get attributes after create");
assert_eq!(out_map["target_alias"], cred.target_alias);
assert_eq!(out_map["comment"], cred.comment);
assert_eq!(out_map["username"], cred.username);
assert!(
matches!(entry.update_attributes(&in_map), Ok(())),
"Couldn't update attributes in attribute test",
);
let after_map = entry
.get_attributes()
.expect("Can't get attributes after update");
assert_eq!(after_map["target_alias"], in_map["target_alias"]);
assert_eq!(after_map["comment"], in_map["comment"]);
assert_eq!(after_map["username"], in_map["username"]);
assert!(!after_map.contains_key("label"));
assert!(!after_map.contains_key("attribute name"));
entry
.delete_credential()
.unwrap_or_else(|err| panic!("Can't delete credential for attribute test: {err:?}"));
assert!(
matches!(entry.get_attributes(), Err(ErrorCode::NoEntry)),
"Read deleted credential in attribute test",
);
}
// Reading back a stored credential's metadata must reproduce the fields of
// the credential that wrote it.
#[test]
fn test_get_credential() {
let name = generate_random_string();
let entry = entry_new(&name, &name);
let password = "test get password";
entry
.set_password(password)
.expect("Can't set test get password");
let credential: &WinCredential = entry
.get_credential()
.downcast_ref()
.expect("Not a windows credential");
let actual = credential.get_credential().expect("Can't read credential");
assert_eq!(
actual.username, credential.username,
"Usernames don't match"
);
assert_eq!(
actual.target_name, credential.target_name,
"Target names don't match"
);
assert_eq!(
actual.target_alias, credential.target_alias,
"Target aliases don't match"
);
assert_eq!(actual.comment, credential.comment, "Comments don't match");
entry
.delete_credential()
.expect("Couldn't delete get-credential");
assert!(matches!(entry.get_password(), Err(ErrorCode::NoEntry)));
}
}