use std::fmt;
use std::ops::Not;
use crate::utf16string::ZeroizedUtf16String;
use crate::{Error, Secret, Utf16String, Utf16StringExt};
/// Error produced when a [`Username`] cannot be built or parsed.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum UsernameError {
    /// A name component contains a `\` or `@` separator that is not allowed in that
    /// position, e.g. `DOMAIN\user@suffix`.
    MixedFormat,
}

impl std::error::Error for UsernameError {}

impl fmt::Display for UsernameError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::MixedFormat => "mixed username format",
        };
        f.write_str(message)
    }
}
/// Textual layout of a [`Username`].
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum UserNameFormat {
    /// User principal name: `account@upn_suffix`.
    UserPrincipalName,
    /// Down-level logon name: `NETBIOS_DOMAIN\account`. Also used for a bare account
    /// name that carries no domain at all.
    DownLevelLogonName,
}
/// A user name in one of the layouts described by [`UserNameFormat`].
///
/// The full text is stored once. `sep_idx` records where the `@` or `\` separator sits,
/// so [`Username::account_name`] and [`Username::domain_name`] can hand out borrowed
/// slices without re-parsing.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct Username {
    /// Complete user name text, e.g. `user@example.com`, `DOMAIN\user` or `user`.
    value: String,
    /// Layout of `value`.
    format: UserNameFormat,
    /// Byte offset of the separator inside `value`; `None` for a bare account name.
    sep_idx: Option<usize>,
}

impl Username {
    /// Characters that delimit the account part from the domain part.
    const SEPARATORS: [char; 2] = ['\\', '@'];

    /// Builds a user principal name of the form `account_name@upn_suffix`.
    ///
    /// `account_name` may itself contain `@`, because parsing splits on the *last* `@`.
    /// It must not contain `\`. `upn_suffix` must contain neither `\` nor `@`.
    ///
    /// # Errors
    ///
    /// [`UsernameError::MixedFormat`] if either part contains a forbidden separator.
    pub fn new_upn(account_name: &str, upn_suffix: &str) -> Result<Self, UsernameError> {
        let invalid = account_name.contains('\\') || upn_suffix.contains(Self::SEPARATORS);
        if invalid {
            return Err(UsernameError::MixedFormat);
        }

        Ok(Self {
            value: [account_name, upn_suffix].join("@"),
            format: UserNameFormat::UserPrincipalName,
            // The `@` immediately follows the account name.
            sep_idx: Some(account_name.len()),
        })
    }

    /// Builds a down-level logon name of the form `netbios_domain_name\account_name`.
    ///
    /// # Errors
    ///
    /// [`UsernameError::MixedFormat`] if either part contains `\` or `@`.
    pub fn new_down_level_logon_name(account_name: &str, netbios_domain_name: &str) -> Result<Self, UsernameError> {
        if account_name.contains(Self::SEPARATORS) || netbios_domain_name.contains(Self::SEPARATORS) {
            return Err(UsernameError::MixedFormat);
        }

        Ok(Self {
            value: [netbios_domain_name, account_name].join("\\"),
            format: UserNameFormat::DownLevelLogonName,
            // The `\` immediately follows the domain name.
            sep_idx: Some(netbios_domain_name.len()),
        })
    }

    /// Builds a user name from an account name and an optional NetBIOS domain.
    ///
    /// A non-empty domain produces a down-level logon name. Otherwise `account_name`
    /// goes through [`Username::parse`], so it may already carry its own domain.
    ///
    /// # Errors
    ///
    /// Propagates [`UsernameError::MixedFormat`] from the constructor that is used.
    pub fn new(account_name: &str, netbios_domain_name: Option<&str>) -> Result<Self, UsernameError> {
        let domain = netbios_domain_name.filter(|name| !name.is_empty());
        match domain {
            Some(domain) => Self::new_down_level_logon_name(account_name, domain),
            None => Self::parse(account_name),
        }
    }

    /// Parses `DOMAIN\user`, `user@suffix`, or a bare `user`.
    ///
    /// A `\` takes precedence: the text is split at the first `\` and treated as a
    /// down-level logon name. Otherwise it is split at the last `@` and treated as a UPN.
    /// Text with neither separator becomes a [`UserNameFormat::DownLevelLogonName`]
    /// without a domain.
    ///
    /// # Errors
    ///
    /// [`UsernameError::MixedFormat`] when a part produced by the split contains a
    /// separator the corresponding constructor rejects (e.g. `DOMAIN\user@suffix`).
    pub fn parse(value: &str) -> Result<Self, UsernameError> {
        if let Some((netbios_domain_name, account_name)) = value.split_once('\\') {
            return Self::new_down_level_logon_name(account_name, netbios_domain_name);
        }
        if let Some((account_name, upn_suffix)) = value.rsplit_once('@') {
            return Self::new_upn(account_name, upn_suffix);
        }

        Ok(Self {
            value: String::from(value),
            format: UserNameFormat::DownLevelLogonName,
            sep_idx: None,
        })
    }

    /// Returns the complete user name text.
    pub fn inner(&self) -> &str {
        self.value.as_str()
    }

    /// Returns the layout of this user name.
    pub fn format(&self) -> UserNameFormat {
        self.format
    }

    /// Returns the UPN suffix or the NetBIOS domain, or `None` for a bare account name.
    pub fn domain_name(&self) -> Option<&str> {
        let idx = self.sep_idx?;
        // Both separators are single-byte ASCII, so `idx + 1` is always a char boundary.
        let domain = match self.format {
            UserNameFormat::UserPrincipalName => &self.value[idx + 1..],
            UserNameFormat::DownLevelLogonName => &self.value[..idx],
        };
        Some(domain)
    }

    /// Returns the account part, or the whole text when there is no separator.
    pub fn account_name(&self) -> &str {
        match (self.sep_idx, self.format) {
            (None, _) => self.value.as_str(),
            (Some(idx), UserNameFormat::UserPrincipalName) => &self.value[..idx],
            (Some(idx), UserNameFormat::DownLevelLogonName) => &self.value[idx + 1..],
        }
    }
}
/// Password-based credentials: a parsed [`Username`] plus its password.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct AuthIdentity {
    /// Account, and optionally domain, that the password belongs to.
    pub username: Username,
    /// Plain-text password, kept inside a [`Secret`] wrapper.
    pub password: Secret<String>,
}
/// Password credentials stored as UTF-16 buffers for user, domain and password.
#[derive(Clone, Default, Eq, PartialEq)]
pub struct AuthIdentityBuffers {
    /// Account name.
    pub user: Utf16String,
    /// Domain name; left empty when there is no domain.
    pub domain: Utf16String,
    /// Password, wrapped in `ZeroizedUtf16String` inside a `Secret`
    /// (presumably wiped when dropped; confirm in `ZeroizedUtf16String`).
    pub password: Secret<ZeroizedUtf16String>,
}

impl AuthIdentityBuffers {
    /// Creates buffers from already-encoded UTF-16 strings.
    pub fn new(user: Utf16String, domain: Utf16String, password: Utf16String) -> Self {
        let password: Secret<ZeroizedUtf16String> = ZeroizedUtf16String(password).into();
        Self { user, domain, password }
    }

    /// Returns `true` when the user buffer is empty; domain and password are not inspected.
    pub fn is_empty(&self) -> bool {
        self.user.is_empty()
    }

    /// Encodes UTF-8 `user`, `domain` and `password` into UTF-16 buffers.
    pub fn from_utf8(user: &str, domain: &str, password: &str) -> Self {
        Self::new(
            Utf16String::from(user),
            Utf16String::from(domain),
            Utf16String::from(password),
        )
    }

    /// Like [`AuthIdentityBuffers::from_utf8`], but the password field is filled with
    /// `nt_hash.to_sspi_password()` instead of a clear-text password.
    pub fn from_utf8_with_hash(user: &str, domain: &str, nt_hash: &crate::NtlmHash) -> Self {
        Self::new(
            Utf16String::from(user),
            Utf16String::from(domain),
            Utf16String::from(nt_hash.to_sspi_password()),
        )
    }
}
impl fmt::Debug for AuthIdentityBuffers {
    /// Renders `user` and `domain` as `0x`-prefixed uppercase hex of the bytes returned by
    /// `as_bytes_le()`. The password is delegated to the `Debug` impl of its wrapper.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("AuthIdentityBuffers { user: 0x")?;
        for byte in self.user.as_bytes_le().iter() {
            write!(f, "{byte:02X}")?;
        }

        f.write_str(", domain: 0x")?;
        for byte in self.domain.as_bytes_le().iter() {
            write!(f, "{byte:02X}")?;
        }

        write!(f, ", password: {:?} }}", self.password)
    }
}
impl From<AuthIdentity> for AuthIdentityBuffers {
fn from(credentials: AuthIdentity) -> Self {
let password: &str = credentials.password.as_ref().as_ref();
Self {
user: credentials.username.account_name().into(),
domain: credentials.username.domain_name().unwrap_or_default().into(),
password: ZeroizedUtf16String(password.into()).into(),
}
}
}
impl TryFrom<&AuthIdentityBuffers> for AuthIdentity {
    type Error = UsernameError;

    /// Decodes the UTF-16 buffers back into a [`Username`] and password.
    ///
    /// An empty `domain` buffer means "no domain". In that case `user` is parsed on its
    /// own and may itself be `DOMAIN\user` or `user@suffix`.
    ///
    /// # Errors
    ///
    /// Propagates [`UsernameError`] from [`Username::new`].
    fn try_from(credentials_buffers: &AuthIdentityBuffers) -> Result<Self, Self::Error> {
        let has_domain = credentials_buffers.domain.is_empty().not();
        let domain_name = has_domain.then(|| credentials_buffers.domain.to_string());
        let account_name = credentials_buffers.user.to_string();

        let username = Username::new(&account_name, domain_name.as_deref())?;

        Ok(Self {
            username,
            password: credentials_buffers.password.as_ref().as_ref().to_string().into(),
        })
    }
}
impl TryFrom<AuthIdentityBuffers> for AuthIdentity {
type Error = UsernameError;
fn try_from(credentials_buffers: AuthIdentityBuffers) -> Result<Self, Self::Error> {
AuthIdentity::try_from(&credentials_buffers)
}
}
/// Smart-card credential types and their conversions; compiled only with the `scard` feature.
#[cfg(feature = "scard")]
mod scard_credentials {
    // `PathBuf` is only needed by `SmartCardType::SystemProvided`, which does not exist on wasm32.
    #[cfg(not(target_arch = "wasm32"))]
    use std::path::PathBuf;
    use picky::key::PrivateKey;
    use picky_asn1_der::Asn1DerError;
    use picky_asn1_x509::Certificate;
    use crate::secret::SecretPrivateKey;
    use crate::utf16string::ZeroizedUtf16String;
    use crate::{Error, ErrorKind, NonEmpty, Secret, Utf16String};

    /// DER-encoded X.509 certificate bytes.
    ///
    /// The inner bytes are private to this module. Every constructor either decodes them
    /// as a [`Certificate`] (`TryFrom<Vec<u8>>`) or produces them by DER-serializing one
    /// (`TryFrom<&Certificate>`). The infallible `From<&CertificateRaw> for Certificate`
    /// conversion below relies on this.
    #[derive(Clone, Eq, PartialEq, Debug)]
    pub struct CertificateRaw(Vec<u8>);

    impl AsRef<[u8]> for CertificateRaw {
        /// Borrows the raw DER bytes.
        fn as_ref(&self) -> &[u8] {
            self.0.as_ref()
        }
    }

    impl TryFrom<Vec<u8>> for CertificateRaw {
        type Error = Asn1DerError;
        /// Wraps `value` after checking that it decodes as a DER [`Certificate`].
        ///
        /// # Errors
        ///
        /// Returns the decoder's [`Asn1DerError`] if `value` is not a valid certificate.
        fn try_from(value: Vec<u8>) -> Result<Self, Self::Error> {
            // Decode only to validate; the parsed value is discarded and the original bytes are kept.
            let _: Certificate = picky_asn1_der::from_bytes(value.as_ref())?;
            Ok(Self(value))
        }
    }

    impl From<CertificateRaw> for Vec<u8> {
        /// Unwraps the raw DER bytes.
        fn from(value: CertificateRaw) -> Self {
            value.0
        }
    }

    impl TryFrom<&Certificate> for CertificateRaw {
        type Error = Asn1DerError;
        /// DER-serializes `value`.
        ///
        /// # Errors
        ///
        /// Returns the encoder's [`Asn1DerError`] if serialization fails.
        fn try_from(value: &Certificate) -> Result<Self, Self::Error> {
            picky_asn1_der::to_vec(value).map(Self)
        }
    }

    impl TryFrom<Certificate> for CertificateRaw {
        type Error = Asn1DerError;
        /// Owned counterpart of `TryFrom<&Certificate>`.
        fn try_from(value: Certificate) -> Result<Self, Self::Error> {
            Self::try_from(&value)
        }
    }

    impl From<&CertificateRaw> for Certificate {
        /// Decodes the stored DER bytes.
        ///
        /// # Panics
        ///
        /// Panics if the bytes fail to decode, which would mean the construction invariant
        /// of [`CertificateRaw`] was broken.
        /// NOTE(review): for values built via `TryFrom<&Certificate>`, this assumes that
        /// DER serialize -> deserialize always round-trips. Confirm.
        fn from(value: &CertificateRaw) -> Self {
            picky_asn1_der::from_bytes(&value.0).expect("value.0 is convertible to Certificate (checked on creation)")
        }
    }

    impl From<CertificateRaw> for Certificate {
        /// Owned counterpart of `From<&CertificateRaw>`, with the same panic condition.
        fn from(value: CertificateRaw) -> Self {
            Self::from(&value)
        }
    }

    /// Backend used to access the smart card.
    #[derive(Clone, Eq, PartialEq, Debug)]
    pub enum SmartCardType {
        /// Emulated smart card; carries the card PIN as raw bytes.
        Emulated {
            scard_pin: Secret<Vec<u8>>,
        },
        /// System-provided card. Presumably reached through the PKCS#11 module at
        /// `pkcs11_module_path` (confirm). This variant does not exist on wasm32.
        #[cfg(not(target_arch = "wasm32"))]
        SystemProvided {
            pkcs11_module_path: PathBuf,
        },
        /// Windows-native smart-card access. This variant exists only on Windows.
        #[cfg(target_os = "windows")]
        WindowsNative,
    }

    /// Smart-card credentials in UTF-16 buffer form; the counterpart of [`SmartCardIdentity`].
    #[derive(Clone, Eq, PartialEq, Debug)]
    pub struct SmartCardIdentityBuffers {
        /// User name.
        pub username: Utf16String,
        /// DER-encoded certificate.
        pub certificate: CertificateRaw,
        /// Optional card name.
        pub card_name: Option<NonEmpty<Utf16String>>,
        /// Smart-card reader name.
        pub reader_name: Utf16String,
        /// Optional key container name.
        pub container_name: Option<NonEmpty<Utf16String>>,
        /// CSP name (presumably the cryptographic service provider; confirm).
        pub csp_name: Utf16String,
        /// Card PIN as UTF-16 text.
        pub pin: Secret<ZeroizedUtf16String>,
        /// Private key as PEM text, if any.
        ///
        /// NOTE(review): unlike `pin`, this is a plain `Utf16String` and not a
        /// `Secret`/zeroized type, even though it holds key material.
        pub private_key_pem: Option<NonEmpty<Utf16String>>,
        /// Backend used to access the card.
        pub scard_type: SmartCardType,
    }

    /// Smart-card credentials in owned, decoded form.
    #[derive(Debug, Clone, PartialEq)]
    pub struct SmartCardIdentity {
        /// User name.
        pub username: String,
        /// Parsed X.509 certificate.
        pub certificate: Certificate,
        /// Smart-card reader name.
        pub reader_name: String,
        /// Optional card name.
        pub card_name: Option<String>,
        /// Optional key container name.
        pub container_name: Option<String>,
        /// CSP name (presumably the cryptographic service provider; confirm).
        pub csp_name: String,
        /// Card PIN as raw bytes; the conversions below treat them as UTF-8 text.
        pub pin: Secret<Vec<u8>>,
        /// Optional private key.
        pub private_key: Option<SecretPrivateKey>,
        /// Backend used to access the card.
        pub scard_type: SmartCardType,
    }

    impl TryFrom<SmartCardIdentity> for SmartCardIdentityBuffers {
        type Error = Error;
        /// Re-encodes an owned smart-card identity into UTF-16 buffers.
        ///
        /// - The private key, if present, is serialized to PEM text.
        /// - The certificate is DER-serialized into a [`CertificateRaw`].
        /// - The PIN bytes are decoded with `String::from_utf8_lossy`. Invalid UTF-8
        ///   sequences are replaced with U+FFFD instead of producing an error.
        /// - Optional strings are wrapped with `NonEmpty::new`, which returns an `Option`
        ///   (presumably `None` for empty input; confirm).
        ///
        /// # Errors
        ///
        /// Returns an [`ErrorKind::InternalError`] error if the private key cannot be
        /// serialized to PEM. Also propagates, via `?`, a certificate serialization failure
        /// (presumably converted from `Asn1DerError` by a `From` impl on `Error`).
        fn try_from(value: SmartCardIdentity) -> Result<Self, Self::Error> {
            // NOTE(review): the PEM `String` returned by `to_pem_str()` is an ordinary
            // heap string and is not explicitly wiped after conversion.
            let private_key = if let Some(key) = value.private_key {
                NonEmpty::new(Utf16String::from(key.as_ref().to_pem_str().map_err(|e| {
                    Error::new(
                        ErrorKind::InternalError,
                        format!("Unable to serialize a smart card private key: {e}"),
                    )
                })?))
            } else {
                None
            };
            Ok(Self {
                certificate: value.certificate.try_into()?,
                reader_name: value.reader_name.into(),
                // Lossy: non-UTF-8 PIN bytes become U+FFFD.
                pin: ZeroizedUtf16String(String::from_utf8_lossy(value.pin.as_ref()).as_ref().into()).into(),
                username: value.username.into(),
                // Inside these closures `value` shadows the outer `value` and is the inner `String`.
                card_name: value.card_name.and_then(|value| NonEmpty::new(value.into())),
                container_name: value.container_name.and_then(|value| NonEmpty::new(value.into())),
                csp_name: value.csp_name.into(),
                private_key_pem: private_key,
                scard_type: value.scard_type,
            })
        }
    }

    impl TryFrom<&SmartCardIdentityBuffers> for SmartCardIdentity {
        type Error = Error;
        /// Decodes UTF-16 buffers back into an owned smart-card identity.
        ///
        /// The PIN text is stored as its UTF-8 bytes. The certificate is decoded from DER
        /// through the infallible `From<&CertificateRaw>` conversion.
        ///
        /// # Errors
        ///
        /// Returns an [`ErrorKind::InternalError`] error if `private_key_pem` cannot be
        /// parsed as a PEM private key.
        fn try_from(value: &SmartCardIdentityBuffers) -> Result<Self, Self::Error> {
            let private_key = if let Some(key) = &value.private_key_pem {
                // NOTE(review): `pem_string` is a plain `String` copy of the key material
                // and is not explicitly wiped when dropped.
                let pem_string = key.as_ref().to_string();
                Some(SecretPrivateKey::new(PrivateKey::from_pem_str(&pem_string).map_err(
                    |e| {
                        Error::new(
                            ErrorKind::InternalError,
                            format!("Unable to create a PrivateKey from a PEM string: {e}"),
                        )
                    },
                )?))
            } else {
                None
            };
            Ok(Self {
                certificate: Certificate::from(&value.certificate),
                reader_name: value.reader_name.to_string(),
                pin: value.pin.as_ref().0.to_string().into_bytes().into(),
                username: value.username.to_string(),
                card_name: value.card_name.as_ref().map(NonEmpty::as_ref).map(ToString::to_string),
                container_name: value
                    .container_name
                    .as_ref()
                    .map(NonEmpty::as_ref)
                    .map(ToString::to_string),
                csp_name: value.csp_name.to_string(),
                private_key,
                scard_type: value.scard_type.clone(),
            })
        }
    }
}
#[cfg(feature = "scard")]
pub use self::scard_credentials::{CertificateRaw, SmartCardIdentity, SmartCardIdentityBuffers, SmartCardType};
/// Credentials in their buffer (UTF-16) representation.
#[derive(Clone, Eq, PartialEq, Debug)]
pub enum CredentialsBuffers {
    /// Username/domain/password buffers.
    AuthIdentity(AuthIdentityBuffers),
    /// Smart-card credential buffers (requires the `scard` feature).
    #[cfg(feature = "scard")]
    SmartCard(SmartCardIdentityBuffers),
}

impl CredentialsBuffers {
    /// Consumes `self` and returns the password buffers if this is
    /// [`CredentialsBuffers::AuthIdentity`], otherwise `None`.
    pub fn into_auth_identity(self) -> Option<AuthIdentityBuffers> {
        match self {
            CredentialsBuffers::AuthIdentity(identity) => Some(identity),
            // Every variant is listed instead of using `_`, so adding a new kind of
            // credentials forces this accessor to be revisited.
            #[cfg(feature = "scard")]
            CredentialsBuffers::SmartCard(_) => None,
        }
    }

    /// Returns a clone of the password buffers, if any.
    pub fn to_auth_identity(&self) -> Option<AuthIdentityBuffers> {
        self.as_auth_identity().cloned()
    }

    /// Borrows the password buffers, if any.
    pub fn as_auth_identity(&self) -> Option<&AuthIdentityBuffers> {
        match self {
            CredentialsBuffers::AuthIdentity(identity) => Some(identity),
            #[cfg(feature = "scard")]
            CredentialsBuffers::SmartCard(_) => None,
        }
    }

    /// Mutably borrows the password buffers, if any.
    pub fn as_mut_auth_identity(&mut self) -> Option<&mut AuthIdentityBuffers> {
        match self {
            CredentialsBuffers::AuthIdentity(identity) => Some(identity),
            #[cfg(feature = "scard")]
            CredentialsBuffers::SmartCard(_) => None,
        }
    }
}
/// Credentials in their owned, decoded representation.
#[derive(Clone, PartialEq, Debug)]
pub enum Credentials {
    /// Username and password.
    AuthIdentity(AuthIdentity),
    /// Smart-card credentials (requires the `scard` feature).
    #[cfg(feature = "scard")]
    SmartCard(Box<SmartCardIdentity>),
}

impl Credentials {
    /// Returns a clone of the password credentials, if this is [`Credentials::AuthIdentity`].
    pub fn to_auth_identity(&self) -> Option<AuthIdentity> {
        match self {
            Credentials::AuthIdentity(identity) => Some(identity.clone()),
            // Every variant is listed instead of using `_`, so adding a new kind of
            // credentials forces this accessor to be revisited.
            #[cfg(feature = "scard")]
            Credentials::SmartCard(_) => None,
        }
    }

    /// Consumes `self` and returns the password credentials, if any.
    pub fn auth_identity(self) -> Option<AuthIdentity> {
        match self {
            Credentials::AuthIdentity(identity) => Some(identity),
            #[cfg(feature = "scard")]
            Credentials::SmartCard(_) => None,
        }
    }
}
#[cfg(feature = "scard")]
impl From<SmartCardIdentity> for Credentials {
    /// Wraps smart-card credentials. The identity is stored boxed, as the
    /// `Credentials::SmartCard` variant requires.
    fn from(value: SmartCardIdentity) -> Self {
        Credentials::SmartCard(Box::from(value))
    }
}
impl From<AuthIdentity> for Credentials {
fn from(value: AuthIdentity) -> Self {
Self::AuthIdentity(value)
}
}
impl TryFrom<Credentials> for CredentialsBuffers {
type Error = Error;
fn try_from(value: Credentials) -> Result<Self, Self::Error> {
Ok(match value {
Credentials::AuthIdentity(identity) => Self::AuthIdentity(identity.into()),
#[cfg(feature = "scard")]
Credentials::SmartCard(identity) => Self::SmartCard((*identity).try_into()?),
})
}
}
#[cfg(test)]
mod tests {
    use proptest::prelude::*;

    use super::*;

    /// Any parsable name must keep its text. Its parts must also be re-buildable in the
    /// other layouts with the same account/domain split.
    #[test]
    fn username_format_conversion() {
        proptest!(|(value in "[a-zA-Z0-9.]{1,3}@?\\\\?[a-zA-Z0-9.]{1,3}@?\\\\?[a-zA-Z0-9.]{1,3}")| {
            let parsed = Username::parse(&value);
            prop_assume!(parsed.is_ok());
            let original = parsed.unwrap();
            prop_assert_eq!(original.inner(), value.as_str());

            let account_name = original.account_name();
            let domain_name = original.domain_name();

            if let Some(domain) = domain_name {
                let upn = Username::new_upn(account_name, domain).expect("UPN");
                prop_assert_eq!(upn.account_name(), account_name);
                prop_assert_eq!(upn.domain_name(), domain_name);
            }

            // An account name containing `@` would be rejected by `Username::new` when a
            // domain is given, or re-parsed as a UPN when none is, so skip it here.
            if !account_name.contains('@') {
                let netbios_name = Username::new(account_name, domain_name).expect("NetBIOS");
                prop_assert_eq!(netbios_name.format(), UserNameFormat::DownLevelLogonName);
                prop_assert_eq!(netbios_name.account_name(), account_name);
                prop_assert_eq!(netbios_name.domain_name(), domain_name);
            }
        })
    }

    /// Re-parses the textual form of `username` and checks the result is identical.
    fn check_round_trip_property(username: &Username) {
        let reparsed = Username::parse(username.inner()).expect("round-trip parse");
        assert_eq!(*username, reparsed);
    }

    #[test]
    fn upn_round_trip() {
        proptest!(|(account_name in "[a-zA-Z0-9@.]{1,3}", domain_name in "[a-z0-9.]{1,3}")| {
            let username = Username::new_upn(&account_name, &domain_name).expect("UPN");
            prop_assert_eq!(username.account_name(), account_name.as_str());
            prop_assert_eq!(username.domain_name(), Some(domain_name.as_str()));
            prop_assert_eq!(username.format(), UserNameFormat::UserPrincipalName);
            check_round_trip_property(&username);
        })
    }

    #[test]
    fn down_level_logon_name_round_trip() {
        proptest!(|(account_name in "[a-zA-Z0-9.]{1,3}", domain_name in "[A-Z0-9.]{1,3}")| {
            let username = Username::new_down_level_logon_name(&account_name, &domain_name).expect("down-level logon name");
            prop_assert_eq!(username.account_name(), account_name.as_str());
            prop_assert_eq!(username.domain_name(), Some(domain_name.as_str()));
            prop_assert_eq!(username.format(), UserNameFormat::DownLevelLogonName);
            check_round_trip_property(&username);
        })
    }
}