use std::io::BufRead;
use std::path::{Path, PathBuf};
/// File name of the PEM-encoded wallet inside a wallet directory (joined by `pem_wallet_path`).
pub const PEM_WALLET_FILE_NAME: &str = "ewallet.pem";
/// File name of the `cwallet.sso` wallet inside a wallet directory (joined by `sso_wallet_path`).
pub const SSO_WALLET_FILE_NAME: &str = "cwallet.sso";
/// Errors produced while locating, reading, or parsing a wallet.
///
/// The `Display` messages of `FileMissing` and `Io` deliberately omit the file path,
/// and `Debug` is implemented by hand so the stored path is redacted there as well.
#[derive(thiserror::Error)]
#[non_exhaustive]
pub enum WalletError {
/// The wallet file does not exist. Holds the path that was looked up; the path is
/// not included in the error message.
#[error("wallet file is missing")]
FileMissing(String),
/// The wallet file exists but could not be read.
#[error("failed to read wallet file: {source}")]
Io {
/// Path of the file that failed to read (not included in the error message).
path: String,
/// Underlying I/O failure, exposed through `Error::source`.
#[source]
source: std::io::Error,
},
/// The PEM content could not be parsed, or its only private key is encrypted.
/// Holds a human-readable description.
#[error("failed to parse wallet PEM: {0}")]
Pem(String),
/// The PEM wallet parsed cleanly but contained no X.509 certificates.
#[error("wallet contained no certificates")]
NoCertificates,
/// A `cwallet.sso` wallet could not be parsed. Holds a description.
#[error("cwallet.sso parse error: {0}")]
Sso(String),
/// A `cwallet.sso` wallet was encountered but experimental SSO support is not
/// compiled in.
#[error(
"cwallet.sso support is experimental and not enabled; rebuild with \
--features experimental, or convert the wallet to ewallet.pem"
)]
SsoNotEnabled,
/// The wallet format named by `format` is not supported by this build.
#[error("wallet format {format} is not supported by this thin build")]
UnsupportedFormat { format: &'static str },
}
impl std::fmt::Debug for WalletError {
    /// Writes a diagnostic representation in which any stored file path is masked.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Wallet paths are treated as sensitive, so they never reach `Debug` output.
        const MASK: &str = "***redacted***";
        match self {
            Self::FileMissing(_) => f.debug_tuple("FileMissing").field(&MASK).finish(),
            Self::Io { source, .. } => {
                let mut out = f.debug_struct("Io");
                out.field("path", &MASK);
                out.field("source", source);
                out.finish()
            }
            Self::Pem(msg) => f.debug_tuple("Pem").field(msg).finish(),
            Self::Sso(msg) => f.debug_tuple("Sso").field(msg).finish(),
            Self::UnsupportedFormat { format } => {
                f.debug_struct("UnsupportedFormat").field("format", format).finish()
            }
            // Field-less variants print as their bare name.
            Self::NoCertificates => f.write_str("NoCertificates"),
            Self::SsoNotEnabled => f.write_str("SsoNotEnabled"),
        }
    }
}
/// Certificate and key material extracted from a wallet.
///
/// Every entry holds DER bytes as produced by the PEM parser. `Debug` is implemented
/// by hand so the private key is never written to logs or panic messages; only its
/// presence is shown.
#[derive(Clone, Default)]
pub struct WalletContents {
    /// DER-encoded certificates found in the wallet.
    pub ca_certificates: Vec<Vec<u8>>,
    /// DER-encoded certificates to present alongside `client_private_key`.
    pub client_cert_chain: Vec<Vec<u8>>,
    /// DER-encoded client private key (PKCS#8, PKCS#1 or SEC1), if one was present.
    pub client_private_key: Option<Vec<u8>>,
}

impl std::fmt::Debug for WalletContents {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Certificates are public material and are printed as-is; the private key is
        // secret, so it is replaced by a marker when present.
        const REDACTED_KEY: &str = "***redacted***";
        f.debug_struct("WalletContents")
            .field("ca_certificates", &self.ca_certificates)
            .field("client_cert_chain", &self.client_cert_chain)
            .field(
                "client_private_key",
                &self.client_private_key.as_ref().map(|_| REDACTED_KEY),
            )
            .finish()
    }
}

impl WalletContents {
    /// Returns `true` when the wallet has both a non-empty client certificate chain
    /// and a private key.
    #[must_use]
    pub fn has_client_identity(&self) -> bool {
        !self.client_cert_chain.is_empty() && self.client_private_key.is_some()
    }
}
/// Chooses the directory in which to look for a wallet.
///
/// A non-empty `wallet_location` wins. The special value `SYSTEM` (compared
/// ASCII case-insensitively) yields `None` without consulting `tns_admin`.
/// When the location is unset or empty, a non-empty `tns_admin` is used instead.
#[must_use]
pub fn resolve_wallet_dir(
    wallet_location: Option<&str>,
    tns_admin: Option<&str>,
) -> Option<PathBuf> {
    match wallet_location {
        Some(location) if location.eq_ignore_ascii_case("SYSTEM") => None,
        Some(location) if !location.is_empty() => Some(PathBuf::from(location)),
        // Unset or empty location: defer to the TNS admin directory, ignoring "".
        _ => match tns_admin {
            Some(dir) if !dir.is_empty() => Some(PathBuf::from(dir)),
            _ => None,
        },
    }
}
/// Returns the path of the PEM wallet file (`ewallet.pem`) inside `dir`.
#[must_use]
pub fn pem_wallet_path(dir: &Path) -> PathBuf {
    let mut wallet = dir.to_path_buf();
    wallet.push(PEM_WALLET_FILE_NAME);
    wallet
}
/// Returns the path of the SSO wallet file (`cwallet.sso`) inside `dir`.
#[must_use]
pub fn sso_wallet_path(dir: &Path) -> PathBuf {
    let mut wallet = dir.to_path_buf();
    wallet.push(SSO_WALLET_FILE_NAME);
    wallet
}
pub fn parse_ewallet_pem(
pem: &[u8],
_wallet_password: Option<&str>,
) -> Result<WalletContents, WalletError> {
let mut reader = std::io::BufReader::new(pem);
let mut contents = WalletContents::default();
let mut all_certs: Vec<Vec<u8>> = Vec::new();
let mut keys: Vec<Vec<u8>> = Vec::new();
let mut saw_encrypted_key = false;
loop {
match rustls_pemfile::read_one(&mut reader) {
Ok(Some(item)) => match item {
rustls_pemfile::Item::X509Certificate(der) => {
all_certs.push(der.as_ref().to_vec());
}
rustls_pemfile::Item::Pkcs8Key(der) => {
keys.push(der.secret_pkcs8_der().to_vec());
}
rustls_pemfile::Item::Pkcs1Key(der) => {
keys.push(der.secret_pkcs1_der().to_vec());
}
rustls_pemfile::Item::Sec1Key(der) => {
keys.push(der.secret_sec1_der().to_vec());
}
_ => {}
},
Ok(None) => break,
Err(e) => return Err(WalletError::Pem(e.to_string())),
}
}
if keys.is_empty() && pem_contains_encrypted_key(pem) {
saw_encrypted_key = true;
}
if all_certs.is_empty() {
return Err(WalletError::NoCertificates);
}
contents.ca_certificates = all_certs.clone();
if let Some(key) = keys.into_iter().next() {
contents.client_cert_chain = all_certs;
contents.client_private_key = Some(key);
} else if saw_encrypted_key {
return Err(WalletError::Pem(
"wallet private key is encrypted; supply a wallet with an \
unencrypted ewallet.pem (orapki ... -auto_login) or use cwallet.sso"
.to_string(),
));
}
Ok(contents)
}
/// Extracts every X.509 certificate from a PEM stream, returned as DER bytes.
///
/// This is best-effort. A malformed PEM section is skipped and parsing continues
/// with the next one. Parsing stops at the first I/O error from `reader`; the
/// previous `filter_map(Result::ok)` would spin forever on a reader that keeps
/// failing.
pub fn parse_pem_certificates(reader: &mut dyn BufRead) -> Vec<Vec<u8>> {
    rustls_pemfile::certs(reader)
        // Malformed PEM surfaces as `ErrorKind::InvalidData`, and the parser has already
        // consumed the offending section, so it is safe to continue past it. Any other
        // kind comes from the underlying reader and may repeat indefinitely.
        // NOTE(review): relies on rustls-pemfile mapping PEM format errors to InvalidData.
        .take_while(|item| {
            !matches!(item, Err(err) if err.kind() != std::io::ErrorKind::InvalidData)
        })
        .filter_map(Result::ok)
        .map(|der| der.as_ref().to_vec())
        .collect()
}
/// Reports whether `pem` appears to contain a password-protected private key.
///
/// It looks for two markers: the `ENCRYPTED PRIVATE KEY` PEM label and the
/// `Proc-Type: 4,ENCRYPTED` header used by legacy encrypted PEM keys.
///
/// The search runs over raw bytes, so non-UTF-8 content anywhere in the input
/// cannot end the scan early. A line-by-line `read_line` loop stops at the first
/// invalid line. Neither marker contains a newline, so a whole-buffer search
/// matches exactly what a per-line search would.
fn pem_contains_encrypted_key(pem: &[u8]) -> bool {
    const MARKERS: [&[u8]; 2] = [b"ENCRYPTED PRIVATE KEY", b"Proc-Type: 4,ENCRYPTED"];
    MARKERS
        .iter()
        .any(|marker| pem.windows(marker.len()).any(|window| window == *marker))
}
/// Reads and parses `ewallet.pem` from the wallet directory `dir`.
///
/// # Errors
///
/// * [`WalletError::FileMissing`] if the file does not exist.
/// * [`WalletError::Io`] if it exists but cannot be read (permissions, I/O failure, ...).
/// * Any error returned by [`parse_ewallet_pem`].
pub fn read_ewallet_pem(
    dir: &Path,
    wallet_password: Option<&str>,
) -> Result<WalletContents, WalletError> {
    let path = pem_wallet_path(dir);
    // Read first and classify the failure afterwards. A separate `exists()` pre-check
    // races with the read, and it returns `false` on any metadata error, which
    // mislabels e.g. a permission-denied directory as a missing file.
    let bytes = std::fs::read(&path).map_err(|source| {
        let shown = path.display().to_string();
        if source.kind() == std::io::ErrorKind::NotFound {
            WalletError::FileMissing(shown)
        } else {
            WalletError::Io {
                path: shown,
                source,
            }
        }
    })?;
    parse_ewallet_pem(&bytes, wallet_password)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_prefers_explicit_location() {
        let resolved = resolve_wallet_dir(Some("/wallets/db1"), Some("/etc/tns"));
        assert_eq!(resolved.as_deref(), Some(Path::new("/wallets/db1")));
    }

    #[test]
    fn resolve_system_means_no_wallet() {
        // `SYSTEM` is matched case-insensitively and overrides any TNS admin directory.
        for (location, tns_admin) in [("SYSTEM", Some("/etc/tns")), ("system", None)] {
            assert_eq!(resolve_wallet_dir(Some(location), tns_admin), None);
        }
    }

    #[test]
    fn resolve_falls_back_to_tns_admin() {
        let resolved = resolve_wallet_dir(None, Some("/etc/tns"));
        assert_eq!(resolved, Some(PathBuf::from("/etc/tns")));
    }

    #[test]
    fn resolve_none_when_nothing_set() {
        for location in [None, Some("")] {
            assert_eq!(resolve_wallet_dir(location, None), None);
        }
    }

    #[test]
    fn parse_rejects_empty_pem() {
        let result = parse_ewallet_pem(b"", None);
        assert!(matches!(result, Err(WalletError::NoCertificates)));
    }

    #[test]
    fn wallet_errors_redact_paths_in_display_and_debug() {
        let sensitive_path = "/private/wallet/ewallet.pem";
        let errors = [
            WalletError::FileMissing(sensitive_path.to_string()),
            WalletError::Io {
                path: sensitive_path.to_string(),
                source: std::io::Error::new(std::io::ErrorKind::NotFound, "missing"),
            },
        ];
        for err in &errors {
            for rendered in [format!("{err}"), format!("{err:?}")] {
                assert!(!rendered.contains(sensitive_path), "path leaked: {rendered}");
            }
        }
    }
}