use std::fmt;
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::sync::atomic::{AtomicUsize, Ordering};
use zeroize::{Zeroize, ZeroizeOnDrop};
/// Common interface for values that carry confidential content.
///
/// Implementors expose a human-facing label and report whether they need the
/// strictest handling.
pub trait SensitiveData {
    /// Label suitable for showing to humans (implementations should avoid
    /// putting the confidential content itself here).
    fn display_name(&self) -> &str;

    /// Returns `true` when the value must be treated as highly sensitive.
    ///
    /// Defaults to `false`; types that track a sensitivity level override it.
    fn is_highly_sensitive(&self) -> bool {
        false
    }
}
/// How much care a secret value requires.
///
/// Variants are declared from least to most sensitive, so the derived
/// ordering reflects severity: `Low < Medium < High < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub enum SensitivityLevel {
    /// Least sensitive; this is the `Default` level.
    #[default]
    Low,
    Medium,
    High,
    /// Most sensitive.
    Critical,
}

impl SensitivityLevel {
    /// Returns `true` for [`SensitivityLevel::High`] and [`SensitivityLevel::Critical`].
    pub fn is_critical_or_high(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here
        // instead of silently falling into a catch-all.
        match self {
            SensitivityLevel::High | SensitivityLevel::Critical => true,
            SensitivityLevel::Low | SensitivityLevel::Medium => false,
        }
    }
}
/// Running total of `SecureString` values recorded as created.
static ALLOCATED_SECURE_STRINGS: AtomicUsize = AtomicUsize::new(0);
/// Running total of `SecureString` values recorded as dropped.
static DEALLOCATED_SECURE_STRINGS: AtomicUsize = AtomicUsize::new(0);

/// Current value of the global "created" counter.
pub fn allocated_secure_strings() -> usize {
    ALLOCATED_SECURE_STRINGS.load(Ordering::SeqCst)
}

/// Current value of the global "dropped" counter.
pub fn deallocated_secure_strings() -> usize {
    DEALLOCATED_SECURE_STRINGS.load(Ordering::SeqCst)
}

/// Test-only: sets both global counters back to zero (allocated first).
#[cfg(test)]
pub fn reset_secure_string_counters() {
    for counter in &[&ALLOCATED_SECURE_STRINGS, &DEALLOCATED_SECURE_STRINGS] {
        counter.store(0, Ordering::SeqCst);
    }
}
/// An owned secret whose bytes are zeroized when it is dropped.
///
/// `Debug` and `Display` render only a masked form of the value.
#[derive(Eq)]
pub struct SecureString {
    /// Raw secret bytes; not necessarily UTF-8 when built from bytes.
    data: Vec<u8>,
    /// Label returned through `SensitiveData::display_name`.
    display_name: String,
    /// How carefully this secret must be handled.
    sensitivity: SensitivityLevel,
}
impl SecureString {
pub fn new(s: impl Into<String>, sensitivity: SensitivityLevel) -> Self {
let string = s.into();
ALLOCATED_SECURE_STRINGS.fetch_add(1, Ordering::SeqCst);
let display_name = string.clone();
let data = string.into_bytes();
Self {
data,
sensitivity,
display_name,
}
}
pub fn from(s: impl Into<String>) -> Self {
Self::new(s, SensitivityLevel::Critical)
}
pub fn from_bytes(data: Vec<u8>, sensitivity: SensitivityLevel) -> Self {
ALLOCATED_SECURE_STRINGS.fetch_add(1, Ordering::SeqCst);
Self {
data,
sensitivity,
display_name: "[binary data]".to_string(),
}
}
pub fn as_str(&self) -> &str {
std::str::from_utf8(&self.data).unwrap_or("[invalid utf-8]")
}
pub fn as_bytes(&self) -> &[u8] {
&self.data
}
pub fn to_plain_string(self) -> String {
String::from_utf8(self.data.clone()).unwrap_or_default()
}
#[allow(clippy::result_unit_err)]
pub fn compare(&self, other: &str) -> Result<(), ()> {
let mut result: u8 = 0;
for (a, b) in self.data.iter().zip(other.bytes()) {
result |= a ^ b;
}
if self.data.len() != other.len() {
result |= 1;
}
if result == 0 {
Ok(())
} else {
Err(())
}
}
pub fn sensitivity(&self) -> SensitivityLevel {
self.sensitivity.clone()
}
pub fn is_highly_sensitive(&self) -> bool {
self.sensitivity.is_critical_or_high()
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn zeroize(&mut self) {
self.data.zeroize();
}
pub fn fingerprint(&self, max_len: usize) -> String {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(&self.data);
let result = hasher.finalize();
let hex = hex::encode(result);
if hex.len() > max_len {
hex[..max_len].to_string()
} else {
hex
}
}
pub fn masked(&self) -> String {
if self.data.is_empty() {
return "[empty]".to_string();
}
let s = self.as_str();
let len = s.len();
match len {
0 => "[empty]".to_string(),
1..=2 => "*".repeat(len),
3..=4 => {
let visible = if len == 3 { 1 } else { 2 };
format!("{}{}", &s[..visible], "*".repeat(len - visible))
}
_ => {
let visible = std::cmp::min(2, len / 4);
let masked_chars = std::cmp::min(6, len - visible);
format!("{}{}", &s[..visible], "*".repeat(masked_chars))
}
}
}
}
impl SensitiveData for SecureString {
    /// Returns the stored display label.
    fn display_name(&self) -> &str {
        self.display_name.as_str()
    }

    /// `true` when the level is `High` or `Critical`.
    fn is_highly_sensitive(&self) -> bool {
        // Ask the level directly rather than calling the same-named inherent
        // method, so this can never resolve back to itself.
        self.sensitivity.is_critical_or_high()
    }
}
impl Drop for SecureString {
    /// Wipes the secret bytes and the display label, then records the drop.
    fn drop(&mut self) {
        self.data.zeroize();
        // `display_name` can hold caller-supplied text or, depending on the
        // constructor, the secret itself, so it is wiped as well.
        self.display_name.zeroize();
        DEALLOCATED_SECURE_STRINGS.fetch_add(1, Ordering::SeqCst);
    }
}

/// Marker: this type's `Drop` impl zeroizes its contents.
impl ZeroizeOnDrop for SecureString {}
impl Clone for SecureString {
fn clone(&self) -> Self {
#[cfg(feature = "tracing")]
tracing::warn!("Cloning SecureString - this may leak sensitive data");
Self {
data: self.data.clone(),
sensitivity: self.sensitivity.clone(),
display_name: self.display_name.clone(),
}
}
}
impl fmt::Debug for SecureString {
    /// Writes `SecureString(<masked>)`; the raw value is never printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shown = self.masked();
        f.write_str("SecureString(")?;
        f.write_str(&shown)?;
        f.write_str(")")
    }
}
impl fmt::Display for SecureString {
    /// Writes only the masked form of the secret.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shown = self.masked();
        f.write_str(&shown)
    }
}
impl PartialEq for SecureString {
    /// Compares the raw bytes of both secrets without exiting at the first mismatch.
    ///
    /// Comparing bytes (not `as_str()`) keeps `Eq` reflexive for non-UTF-8
    /// secrets. It also keeps equality consistent with the byte-based `Hash`:
    /// the text `"[invalid utf-8]"` no longer equals arbitrary binary secrets.
    fn eq(&self, other: &Self) -> bool {
        let mut diff = u8::from(self.data.len() != other.data.len());
        for (a, b) in self.data.iter().zip(other.data.iter()) {
            diff |= a ^ b;
        }
        diff == 0
    }
}
impl Hash for SecureString {
    /// Feeds the raw secret bytes (length-prefixed, as for any slice) to `state`.
    ///
    /// `PartialEq` must compare exactly these bytes for the `Hash`/`Eq`
    /// contract to hold.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(self.data.as_slice(), state);
    }
}
impl Deref for SecureString {
type Target = str;
fn deref(&self) -> &Self::Target {
self.as_str()
}
}
#[derive(Default)]
pub struct SecureStringBuilder {
data: Vec<u8>,
sensitivity: SensitivityLevel,
display_name: Option<String>,
}
impl SecureStringBuilder {
pub fn new() -> Self {
Self::default()
}
pub fn with_sensitivity(mut self, sensitivity: SensitivityLevel) -> Self {
self.sensitivity = sensitivity;
self
}
pub fn with_display_name(mut self, name: impl Into<String>) -> Self {
self.display_name = Some(name.into());
self
}
pub fn push(mut self, c: char) -> Self {
let mut buf = [0u8; 4];
let encoded = c.encode_utf8(&mut buf);
self.data.extend_from_slice(encoded.as_bytes());
self
}
pub fn push_str(mut self, s: &str) -> Self {
self.data.extend_from_slice(s.as_bytes());
self
}
pub fn push_u8(mut self, b: u8) -> Self {
self.data.push(b);
self
}
pub fn build(self) -> SecureString {
let display_name = self
.display_name
.unwrap_or_else(|| String::from_utf8_lossy(&self.data).into_owned());
ALLOCATED_SECURE_STRINGS.fetch_add(1, Ordering::SeqCst);
SecureString {
data: self.data,
sensitivity: self.sensitivity,
display_name,
}
}
}
impl From<&str> for SecureString {
fn from(s: &str) -> Self {
Self::from(s.to_string())
}
}
impl From<String> for SecureString {
fn from(s: String) -> Self {
Self::from(s)
}
}
impl From<&String> for SecureString {
fn from(s: &String) -> Self {
Self::from(s.clone())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction: `len` reports 11 for an 11-byte ASCII secret, and it is non-empty.
    #[test]
    fn test_secure_string_creation() {
        let secret = SecureString::from("password123");
        assert_eq!(secret.len(), 11);
        assert!(!secret.is_empty());
    }

    // `compare` accepts the exact value and rejects a different one.
    #[test]
    fn test_secure_string_compare() {
        let secret = SecureString::from("password123");
        assert!(secret.compare("password123").is_ok());
        assert!(secret.compare("wrongpassword").is_err());
    }

    // The masked form contains asterisks and is under 12 bytes long.
    #[test]
    fn test_secure_string_masked() {
        let secret = SecureString::from("password123");
        let masked = secret.masked();
        assert!(masked.contains('*'));
        assert!(masked.len() < 12);
    }

    // `Display` output is masked.
    #[test]
    fn test_secure_string_display() {
        let secret = SecureString::from("password123");
        let display = format!("{}", secret);
        assert!(display.contains('*'));
    }

    // `Debug` output is wrapped in `SecureString(...)`.
    #[test]
    fn test_secure_string_debug() {
        let secret = SecureString::from("password123");
        let debug = format!("{:?}", secret);
        assert!(debug.contains("SecureString"));
    }

    // Only `Critical` and `High` count as highly sensitive.
    #[test]
    fn test_sensitivity_levels() {
        let critical = SecureString::new("secret", SensitivityLevel::Critical);
        let high = SecureString::new("token", SensitivityLevel::High);
        let medium = SecureString::new("user", SensitivityLevel::Medium);
        let low = SecureString::new("config", SensitivityLevel::Low);
        assert!(critical.is_highly_sensitive());
        assert!(high.is_highly_sensitive());
        assert!(!medium.is_highly_sensitive());
        assert!(!low.is_highly_sensitive());
    }

    // A fingerprint truncated to 16 chars is exactly 16 hex digits.
    #[test]
    fn test_fingerprint() {
        let secret = SecureString::from("password123");
        let fp = secret.fingerprint(16);
        assert_eq!(fp.len(), 16);
        assert!(fp.chars().all(|c| c.is_ascii_hexdigit()));
    }

    // The builder concatenates pushed strings and chars in order.
    #[test]
    fn test_secure_string_builder() {
        let secret = SecureStringBuilder::new()
            .push_str("pass")
            .push('w')
            .push_str("ord")
            .build();
        assert_eq!(secret.as_str(), "password");
    }

    // Binary secrets keep their bytes verbatim and use the "[binary data]" label.
    #[test]
    fn test_from_bytes() {
        let data = vec![0x01, 0x02, 0x03, 0x04];
        let secret = SecureString::from_bytes(data.clone(), SensitivityLevel::High);
        assert_eq!(secret.as_bytes(), data.as_slice());
        assert_eq!(secret.display_name(), "[binary data]");
    }

    // Ignored by default. The ignore reason reads (in English): "The counters
    // are global and are affected by other tests; for manual verification only."
    #[test]
    #[ignore = "计数器是全局的,会受其他测试影响,仅用于手动验证"]
    fn test_allocation_counters() {
        let initial_allocated = allocated_secure_strings();
        let initial_deallocated = deallocated_secure_strings();
        let _secret1 = SecureString::from("test1");
        let _secret2 = SecureString::from("test2");
        assert_eq!(
            allocated_secure_strings(),
            initial_allocated + 2,
            "Should allocate 2 new SecureStrings"
        );
        assert_eq!(
            deallocated_secure_strings(),
            initial_deallocated,
            "Should not deallocate any SecureStrings"
        );
    }

    // Equal content compares equal; different content compares unequal.
    #[test]
    fn test_partial_eq() {
        let secret1 = SecureString::from("password");
        let secret2 = SecureString::from("password");
        let secret3 = SecureString::from("different");
        assert_eq!(secret1, secret2);
        assert_ne!(secret1, secret3);
    }

    // Equal secrets collapse to one `HashSet` entry, so three inserts yield two entries.
    #[test]
    fn test_hash() {
        use std::collections::HashSet;
        let secret1 = SecureString::from("password");
        let secret2 = SecureString::from("password");
        let secret3 = SecureString::from("different");
        let mut set = HashSet::new();
        set.insert(secret1.clone());
        set.insert(secret2.clone());
        set.insert(secret3.clone());
        assert_eq!(set.len(), 2);
        assert!(set.contains(&secret1));
        assert!(set.contains(&secret2));
        assert!(set.contains(&secret3));
    }
}