use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::cmp::Ordering;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use zeroize::Zeroizing;
/// Secret byte string (API key, token, password, ...) kept in a shared buffer
/// wrapped in `Zeroizing`.
///
/// Cloning copies only the `Arc` handle, so all clones refer to the same
/// allocation. The `Debug` and `Display` impls in this file print a redacted
/// placeholder; the raw bytes are handed out by the `expose_*` accessors.
#[derive(Clone)]
pub struct Credential {
// Immutable secret bytes, shared between clones through the `Arc`.
inner: Arc<Zeroizing<Box<[u8]>>>,
}
impl Credential {
    /// Builds a credential by copying `bytes` into a freshly allocated,
    /// `Zeroizing`-wrapped buffer.
    ///
    /// The caller's slice is only read; it is not modified or wiped here.
    #[must_use]
    pub fn from_bytes(bytes: &[u8]) -> Self {
        let owned: Box<[u8]> = Box::from(bytes);
        Self {
            inner: Arc::new(Zeroizing::new(owned)),
        }
    }

    /// Builds a credential from the UTF-8 bytes of `s`.
    #[must_use]
    pub fn from_text(s: &str) -> Self {
        Self::from_bytes(s.as_bytes())
    }

    /// Number of secret bytes stored.
    #[must_use]
    pub fn len(&self) -> usize {
        self.expose_secret().len()
    }

    /// Returns `true` when the credential holds no bytes at all.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.expose_secret().is_empty()
    }

    /// Borrows the raw secret bytes.
    ///
    /// The method name is deliberately explicit so call sites that touch the
    /// secret are easy to find.
    #[must_use]
    pub fn expose_secret(&self) -> &[u8] {
        // Deref coercion walks Arc -> Zeroizing -> Box -> [u8].
        let bytes: &[u8] = &self.inner;
        bytes
    }

    /// Borrows the secret as text, or returns `None` when the bytes are not
    /// valid UTF-8.
    #[must_use]
    pub fn expose_str(&self) -> Option<&str> {
        match std::str::from_utf8(self.expose_secret()) {
            Ok(text) => Some(text),
            Err(_) => None,
        }
    }
}
impl From<&str> for Credential {
fn from(s: &str) -> Self {
Self::from_text(s)
}
}
impl From<String> for Credential {
fn from(s: String) -> Self {
Self::from_bytes(s.as_bytes())
}
}
impl From<&[u8]> for Credential {
fn from(b: &[u8]) -> Self {
Self::from_bytes(b)
}
}
impl From<Vec<u8>> for Credential {
fn from(v: Vec<u8>) -> Self {
Self::from_bytes(&v)
}
}
/// Byte-wise equality that inspects every byte pair instead of stopping at the
/// first mismatch.
///
/// Only a length mismatch returns early.
impl PartialEq for Credential {
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.expose_secret();
        let rhs = other.expose_secret();
        // OR together the XOR of every byte pair; the accumulator stays zero
        // only when all pairs are identical.
        lhs.len() == rhs.len()
            && lhs
                .iter()
                .zip(rhs)
                .fold(0u8, |acc, (x, y)| acc | (x ^ y))
                == 0
    }
}
// Marker impl: the byte-wise `PartialEq` for `Credential` is a total equivalence relation.
impl Eq for Credential {}
/// Partial ordering that always defers to the total order from `Ord`.
impl PartialOrd for Credential {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Total order: the lexicographic order of the secret bytes.
///
/// Unlike `eq`, this delegates to the standard slice comparison.
impl Ord for Credential {
    fn cmp(&self, other: &Self) -> Ordering {
        let lhs: &[u8] = self.expose_secret();
        let rhs: &[u8] = other.expose_secret();
        lhs.cmp(rhs)
    }
}
/// Hashes the secret as a byte slice, so credentials with equal bytes hash
/// identically.
impl Hash for Credential {
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(self.expose_secret(), state);
    }
}
/// Redacted diagnostic output: prints the byte count, never the bytes.
impl std::fmt::Debug for Credential {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("Credential(<redacted {} bytes>)", self.len()))
    }
}
/// Redacted user-facing output: prints the byte count, never the bytes.
impl std::fmt::Display for Credential {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!("<redacted {} bytes>", self.len()))
    }
}
/// Serializes the credential as a one-entry map.
///
/// Valid UTF-8 secrets are written under the key `"text"`; anything else is
/// written as standard base64 under the key `"b64"`. The secret itself is part
/// of the output.
impl Serialize for Credential {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeMap;

        let mut envelope = serializer.serialize_map(Some(1))?;
        if let Some(text) = self.expose_str() {
            envelope.serialize_entry("text", text)?;
        } else {
            let encoded = base64_encode(self.expose_secret());
            envelope.serialize_entry("b64", &encoded)?;
        }
        envelope.end()
    }
}
impl<'de> Deserialize<'de> for Credential {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
#[derive(Deserialize)]
#[serde(untagged)]
enum Wire {
Tagged {
#[serde(default)]
text: Option<String>,
#[serde(default)]
b64: Option<String>,
},
Legacy(String),
}
match Wire::deserialize(deserializer)? {
Wire::Tagged {
text: Some(t),
b64: None,
} => Ok(Credential::from_text(&t)),
Wire::Tagged {
text: None,
b64: Some(b),
} => {
let bytes = crate::encoding::decode_standard_base64(&b)
.map_err(serde::de::Error::custom)?;
Ok(Credential::from_bytes(&bytes))
}
Wire::Tagged { .. } => Err(serde::de::Error::custom(
"Credential must specify exactly one of `text` or `b64`",
)),
Wire::Legacy(s) => {
if let Some(rest) = s.strip_prefix("b64:") {
let bytes = crate::encoding::decode_standard_base64(rest)
.map_err(serde::de::Error::custom)?;
Ok(Credential::from_bytes(&bytes))
} else {
Ok(Credential::from_text(&s))
}
}
}
}
}
/// Encodes `input` with the standard base64 alphabet (`A-Z a-z 0-9 + /`),
/// padding the final group with `=` so the output length is a multiple of four.
fn base64_encode(input: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const PAD: char = '=';

    // Picks the `pos`-th 6-bit group (0 = most significant) of a 24-bit word
    // and maps it to its alphabet character.
    let symbol = |word: u32, pos: u32| -> char {
        let index = (word >> (18 - 6 * pos)) & 0x3F;
        char::from(ALPHABET[index as usize])
    };

    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // Pack up to three bytes into the top of a 24-bit word; missing
        // trailing bytes stay zero.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));

        encoded.push(symbol(word, 0));
        encoded.push(symbol(word, 1));
        encoded.push(if group.len() > 1 { symbol(word, 2) } else { PAD });
        encoded.push(if group.len() > 2 { symbol(word, 3) } else { PAD });
    }
    encoded
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Debug` output shows the redaction marker and none of the secret text.
    #[test]
    fn debug_redacts_bytes() {
        let c = Credential::from_text("AKIAIOSFODNN7EXAMPLE");
        let s = format!("{c:?}");
        let leaked = s.contains("AKIA");
        assert!(s.contains("redacted"));
        assert!(!leaked);
    }

    /// `Display` output shows the redaction marker and none of the secret text.
    #[test]
    fn display_redacts_bytes() {
        let c = Credential::from_text("ghp_abcdef1234567890");
        let s = format!("{c}");
        let leaked = s.contains("ghp_");
        assert!(s.contains("redacted"));
        assert!(!leaked);
    }

    /// The explicit accessors return the stored bytes and their UTF-8 view.
    #[test]
    fn expose_secret_returns_bytes() {
        let credential = Credential::from_text("hello");
        assert_eq!(credential.expose_secret(), b"hello");
        assert_eq!(credential.expose_str(), Some("hello"));
    }

    /// Equal bytes compare equal; a single differing byte compares unequal.
    #[test]
    fn equality_constant_time() {
        let original = Credential::from_text("aaa");
        let twin = Credential::from_text("aaa");
        let different = Credential::from_text("aab");
        assert_eq!(original, twin);
        assert_ne!(original, different);
    }

    /// UTF-8 secrets serialize under the `text` key.
    #[test]
    fn serialize_utf8_credential_as_tagged_text() {
        let credential = Credential::from_text("AKIA1234");
        let encoded = serde_json::to_string(&credential).unwrap();
        assert_eq!(encoded, "{\"text\":\"AKIA1234\"}");
    }

    /// Non-UTF-8 secrets serialize under the `b64` key.
    #[test]
    fn serialize_binary_credential_as_tagged_b64() {
        let credential = Credential::from_bytes(&[0xFF, 0xFE, 0x00, 0x42]);
        let json = serde_json::to_string(&credential).unwrap();
        assert!(
            json.starts_with("{\"b64\":\""),
            "expected tagged b64 envelope, got {json}"
        );
    }

    /// A bare string with the `b64:` prefix is still decoded as base64.
    #[test]
    fn legacy_b64_prefix_still_deserializes() {
        let bytes = [0xFF, 0xFE, 0x00, 0x42];
        let legacy = format!("\"b64:{}\"", super::base64_encode(&bytes));
        let decoded: Credential = serde_json::from_str(&legacy).unwrap();
        assert_eq!(decoded.expose_secret(), &bytes);
    }

    /// A bare string without a prefix is still accepted as plain text.
    #[test]
    fn legacy_plain_string_still_deserializes() {
        let decoded: Credential = serde_json::from_str("\"AKIA1234\"").unwrap();
        assert_eq!(decoded.expose_str(), Some("AKIA1234"));
    }

    /// A text credential survives a JSON round trip.
    #[test]
    fn round_trip_serde() {
        let before = Credential::from_text("xoxb-1234-5678-abc");
        let wire = serde_json::to_string(&before).unwrap();
        let after: Credential = serde_json::from_str(&wire).unwrap();
        assert_eq!(before, after);
    }

    /// A binary credential survives a JSON round trip.
    #[test]
    fn round_trip_binary_serde() {
        let before = Credential::from_bytes(&[0x00, 0x01, 0xFF, 0xFE]);
        let wire = serde_json::to_string(&before).unwrap();
        let after: Credential = serde_json::from_str(&wire).unwrap();
        assert_eq!(before, after);
    }

    /// Clones point at the same backing buffer as the original.
    #[test]
    fn cloning_does_not_duplicate_buffer() {
        let first = Credential::from_text("shared");
        let second = first.clone();
        let first_ptr = first.expose_secret().as_ptr();
        let second_ptr = second.expose_secret().as_ptr();
        assert!(std::ptr::eq(first_ptr, second_ptr));
    }
}
/// Cheaply clonable string whose backing buffer is wrapped in `Zeroizing`.
///
/// Clones share a single allocation through the `Arc`. Note that the
/// `Display`, `Deref<Target = str>`, `AsRef<str>` and serde impls in this file
/// hand out the contents verbatim.
#[derive(Clone, Default)]
pub struct SensitiveString {
// Shared backing storage for the string contents.
inner: Arc<Zeroizing<String>>,
}
impl SensitiveString {
    /// Wraps `s`, taking ownership of its buffer (no copy is made).
    pub fn new(s: String) -> Self {
        Self {
            inner: Arc::new(Zeroizing::new(s)),
        }
    }

    /// Concatenates `parts`, inserting `sep` between consecutive elements.
    ///
    /// An empty `parts` slice yields an empty string.
    ///
    /// The output buffer is sized once up front and lives inside `Zeroizing`
    /// from the start. Growing a `String` incrementally may move its contents
    /// and release the old block without wiping it, which would leave partial
    /// copies of the secret in freed memory; pre-sizing avoids those moves,
    /// and the wrapper wipes the buffer even if this function unwinds.
    pub fn join(parts: &[SensitiveString], sep: &str) -> Self {
        // Exact final length: all parts plus one separator between each pair.
        let separator_bytes = sep.len().saturating_mul(parts.len().saturating_sub(1));
        let capacity = parts
            .iter()
            .fold(separator_bytes, |total, part| total.saturating_add(part.len()));

        let mut joined = Zeroizing::new(String::with_capacity(capacity));
        for (index, part) in parts.iter().enumerate() {
            if index > 0 {
                joined.push_str(sep);
            }
            joined.push_str(part.as_str());
        }
        Self {
            inner: Arc::new(joined),
        }
    }

    /// Borrows the contents as a string slice.
    pub fn as_str(&self) -> &str {
        self.inner.as_str()
    }

    /// Borrows the contents as UTF-8 bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.inner.as_bytes()
    }

    /// Length of the contents in bytes.
    pub fn len(&self) -> usize {
        self.inner.len()
    }

    /// Returns `true` when the string is empty.
    pub fn is_empty(&self) -> bool {
        self.inner.is_empty()
    }
}
/// Lets a `SensitiveString` be used wherever a `&str` is expected.
impl std::ops::Deref for SensitiveString {
    type Target = str;

    fn deref(&self) -> &str {
        self.inner.as_str()
    }
}
/// Borrows the contents as `&str` for APIs generic over `AsRef<str>`.
impl AsRef<str> for SensitiveString {
    fn as_ref(&self) -> &str {
        self.inner.as_str()
    }
}
impl From<String> for SensitiveString {
fn from(s: String) -> Self {
Self::new(s)
}
}
/// Copies the slice into a newly allocated `SensitiveString`.
impl From<&str> for SensitiveString {
    fn from(s: &str) -> Self {
        let owned = String::from(s);
        Self::new(owned)
    }
}
/// Copies the borrowed `String` into a newly allocated `SensitiveString`.
impl From<&String> for SensitiveString {
    fn from(s: &String) -> Self {
        Self::from(s.as_str())
    }
}
/// Writes the contents verbatim.
///
/// NOTE(review): unlike `Credential`, this prints the raw value, so
/// `format!("{}", s)` and `to_string()` expose it. Confirm that every caller
/// relies on this before logging or displaying such values.
impl std::fmt::Display for SensitiveString {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}
/// Redacted diagnostic output: prints the byte length, never the contents.
///
/// `Debug` output is routinely pulled into logs and panic messages (for
/// example through `#[derive(Debug)]` on a containing struct), so the value is
/// replaced by a placeholder in the same format `Credential` uses.
impl std::fmt::Debug for SensitiveString {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "SensitiveString(<redacted {} bytes>)", self.len())
    }
}
/// Serializes the contents as a plain string; the value itself is written out.
impl Serialize for SensitiveString {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        serializer.serialize_str(self.as_str())
    }
}
/// Deserializes a plain string and takes ownership of it without copying.
impl<'de> Deserialize<'de> for SensitiveString {
    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let raw = String::deserialize(deserializer)?;
        Ok(Self::new(raw))
    }
}