use std::collections::HashMap;
use std::ops::Deref;
use crate::auth::AuthError;
use crate::primitives::public_key::PublicKey;
use crate::primitives::symmetric_key::SymmetricKey;
use crate::wallet::interfaces::{
Certificate, CreateSignatureArgs, DecryptArgs, EncryptArgs, GetPublicKeyArgs,
VerifySignatureArgs, WalletInterface,
};
use crate::wallet::types::{Counterparty, CounterpartyType, Protocol};
/// Protocol name placed in the `Protocol` handed to the wallet when a
/// certificate is signed (`AuthCertificate::sign`) or checked
/// (`AuthCertificate::verify`).
pub const CERTIFICATE_SIGNATURE_PROTOCOL: &str = "certificate signature";
/// Protocol name returned by
/// `AuthCertificate::get_certificate_field_encryption_details` and used when
/// the per-field symmetric keys are encrypted or decrypted through the wallet.
pub const CERTIFICATE_FIELD_ENCRYPTION_PROTOCOL: &str = "certificate field encryption";
/// Security level put into the `Protocol` values built for both the
/// certificate-signature and the field-encryption protocol.
pub const SECURITY_LEVEL: u8 = 2;
/// Encodes `data` as standard-alphabet base64 (`A-Z a-z 0-9 + /`) with `=`
/// padding, so the output length is always a multiple of four.
///
/// An empty slice encodes to an empty string.
pub(crate) fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Picks the 6-bit symbol that starts `shift` bits from the bottom of a
    // 24-bit group.
    let symbol = |packed: u32, shift: u32| ALPHABET[((packed >> shift) & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity(data.len() / 3 * 4 + 4);
    for group in data.chunks(3) {
        // Pack the 1..=3 input bytes big-endian, then left-align a short
        // final group inside the 24-bit window (missing bytes read as zero).
        let packed = group
            .iter()
            .fold(0u32, |acc, &byte| (acc << 8) | u32::from(byte))
            << (8 * (3 - group.len()));

        encoded.push(symbol(packed, 18));
        encoded.push(symbol(packed, 12));
        // The third and fourth symbols exist only when the group carried a
        // second / third byte; otherwise padding takes their place.
        encoded.push(if group.len() > 1 { symbol(packed, 6) } else { '=' });
        encoded.push(if group.len() > 2 { symbol(packed, 0) } else { '=' });
    }
    encoded
}
/// Decodes standard-alphabet base64 (`A-Z a-z 0-9 + /`).
///
/// Trailing `=` padding is optional: both `"TQ=="` and `"TQ"` decode to
/// `b"M"`. An empty string (or a string made only of `=`) decodes to an empty
/// vector.
///
/// # Errors
///
/// Returns [`AuthError::SerializationError`] when
/// * a character outside the base64 alphabet is found — this includes a `=`
///   that is followed by further data, so nothing can hide behind padding;
/// * the number of symbols (ignoring trailing padding) is `4k + 1`, because a
///   lone 6-bit symbol cannot encode a whole byte.
///
/// Earlier revisions silently produced bytes for both situations.
pub(crate) fn base64_decode(s: &str) -> Result<Vec<u8>, AuthError> {
    /// Maps one alphabet character to its 6-bit value.
    fn char_to_val(c: u8) -> Result<u8, AuthError> {
        match c {
            b'A'..=b'Z' => Ok(c - b'A'),
            b'a'..=b'z' => Ok(c - b'a' + 26),
            b'0'..=b'9' => Ok(c - b'0' + 52),
            b'+' => Ok(62),
            b'/' => Ok(63),
            _ => Err(AuthError::SerializationError(format!(
                "invalid base64 character: {}",
                c as char
            ))),
        }
    }

    // Only *trailing* padding is stripped; any '=' left in `symbols` is
    // rejected by `char_to_val` like every other non-alphabet byte.
    let symbols = s.trim_end_matches('=').as_bytes();
    if symbols.len() % 4 == 1 {
        return Err(AuthError::SerializationError(
            "invalid base64 length: a single trailing symbol cannot encode a byte".to_string(),
        ));
    }

    let mut result = Vec::with_capacity(symbols.len() / 4 * 3 + 2);
    for group in symbols.chunks(4) {
        // Thanks to the length check above every group holds 2, 3 or 4 symbols.
        let mut packed: u32 = 0;
        for &symbol in group {
            packed = (packed << 6) | u32::from(char_to_val(symbol)?);
        }
        // Left-align a short final group inside the 24-bit window.
        packed <<= 6 * (4 - group.len());

        result.push(((packed >> 16) & 0xFF) as u8);
        if group.len() > 2 {
            result.push(((packed >> 8) & 0xFF) as u8);
        }
        if group.len() > 3 {
            result.push((packed & 0xFF) as u8);
        }
    }
    Ok(result)
}
/// Thin wrapper around a wallet [`Certificate`] on which the certificate
/// signing, verification and field encryption helpers of this module are
/// implemented.
///
/// With the `network` feature enabled the type derives serde
/// `Serialize`/`Deserialize`; because `inner` is flattened, the wrapper
/// serializes as the fields of the inner certificate rather than as a nested
/// `inner` object.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "network", derive(serde::Serialize, serde::Deserialize))]
pub struct AuthCertificate {
/// The wrapped certificate; also reachable through `Deref`.
#[cfg_attr(feature = "network", serde(flatten))]
pub inner: Certificate,
}
/// Lets an `AuthCertificate` be used wherever a `&Certificate` is expected,
/// exposing the wrapped certificate's fields and methods directly.
impl Deref for AuthCertificate {
type Target = Certificate;
/// Returns a shared reference to the wrapped certificate.
fn deref(&self) -> &Self::Target {
&self.inner
}
}
impl AuthCertificate {
    /// Wraps an already-built [`Certificate`].
    pub fn new(inner: Certificate) -> Self {
        Self { inner }
    }

    /// Derives the wallet key ID shared by [`Self::sign`] and
    /// [`Self::verify`]: the base64 certificate type and the base64 serial
    /// number joined by a single space.
    fn signature_key_id(cert: &Certificate) -> String {
        let type_b64 = base64_encode(&cert.cert_type.0);
        let serial_b64 = base64_encode(&cert.serial_number.0);
        format!("{} {}", type_b64, serial_b64)
    }

    /// Serializes the parts of `cert` that are covered by its signature.
    ///
    /// Layout, in order:
    /// 1. raw certificate type bytes, raw serial number bytes;
    /// 2. `subject.to_der()` followed by `certifier.to_der()`;
    /// 3. when the revocation outpoint is present and contains a `.`: the
    ///    hex-decoded part before the first `.` followed by the part after it
    ///    as a varint. Nothing is written when the outpoint is absent or has
    ///    no `.`;
    /// 4. the field count as a varint, then every field sorted by name as
    ///    `varint(len) name varint(len) value`. A missing field map is written
    ///    as a count of zero.
    ///
    /// Parsing of the outpoint is lenient: non-hex characters decode as zero
    /// nibbles and an index that does not parse as `u64` is written as `0`.
    fn to_binary_for_signing(cert: &Certificate) -> Vec<u8> {
        let mut preimage = Vec::new();
        preimage.extend_from_slice(&cert.cert_type.0);
        preimage.extend_from_slice(&cert.serial_number.0);
        preimage.extend_from_slice(&cert.subject.to_der());
        preimage.extend_from_slice(&cert.certifier.to_der());

        if let Some(ref outpoint) = cert.revocation_outpoint {
            // Split at the first '.', i.e. "<txid hex>.<output index>".
            if let Some((txid_hex, index_str)) = outpoint.split_once('.') {
                preimage.extend_from_slice(&hex_decode(txid_hex));
                write_varint(&mut preimage, index_str.parse::<u64>().unwrap_or(0));
            }
        }

        match cert.fields {
            Some(ref fields) => {
                // Map iteration order is unspecified; sorting by name makes
                // the preimage deterministic.
                let mut entries: Vec<(&String, &String)> = fields.iter().collect();
                entries.sort_by(|left, right| left.0.cmp(right.0));

                write_varint(&mut preimage, entries.len() as u64);
                for (name, value) in entries {
                    write_varint(&mut preimage, name.len() as u64);
                    preimage.extend_from_slice(name.as_bytes());
                    write_varint(&mut preimage, value.len() as u64);
                    preimage.extend_from_slice(value.as_bytes());
                }
            }
            None => write_varint(&mut preimage, 0),
        }
        preimage
    }

    /// Signs `cert` with `wallet` acting as certifier.
    ///
    /// The wallet's identity key is fetched and stored in `cert.certifier`
    /// *before* the preimage is built, so the signature covers it. On success
    /// the signature returned by the wallet is stored in `cert.signature`.
    ///
    /// # Errors
    ///
    /// * [`AuthError::CertificateValidation`] if `cert` already carries a
    ///   signature (nothing is modified in that case).
    /// * Any error produced by the wallet calls. If signature creation fails,
    ///   `cert.certifier` has already been overwritten.
    pub async fn sign<W: WalletInterface + ?Sized>(
        cert: &mut Certificate,
        wallet: &W,
    ) -> Result<(), AuthError> {
        if cert.signature.is_some() {
            return Err(AuthError::CertificateValidation(
                "certificate has already been signed".to_string(),
            ));
        }

        let identity_request = GetPublicKeyArgs {
            identity_key: true,
            protocol_id: None,
            key_id: None,
            counterparty: None,
            privileged: false,
            privileged_reason: None,
            for_self: None,
            seek_permission: None,
        };
        let identity = wallet.get_public_key(identity_request, None).await?;
        cert.certifier = identity.public_key;

        let signing_request = CreateSignatureArgs {
            data: Some(Self::to_binary_for_signing(cert)),
            hash_to_directly_sign: None,
            protocol_id: Protocol {
                security_level: SECURITY_LEVEL,
                protocol: CERTIFICATE_SIGNATURE_PROTOCOL.to_string(),
            },
            key_id: Self::signature_key_id(cert),
            // NOTE(review): the counterparty is left uninitialized; how the
            // wallet resolves that default is not visible from this file.
            counterparty: Counterparty {
                counterparty_type: CounterpartyType::Uninitialized,
                public_key: None,
            },
            privileged: false,
            privileged_reason: None,
            seek_permission: None,
        };
        let signed = wallet.create_signature(signing_request, None).await?;
        cert.signature = Some(signed.signature);
        Ok(())
    }

    /// Asks `wallet` to verify the signature stored in `cert`, using
    /// `cert.certifier` as the counterparty.
    ///
    /// A certificate without a signature is checked against the default
    /// (empty) signature value rather than rejected up front.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the wallet's `verify_signature`.
    pub async fn verify<W: WalletInterface + ?Sized>(
        cert: &Certificate,
        wallet: &W,
    ) -> Result<bool, AuthError> {
        let verification_request = VerifySignatureArgs {
            data: Some(Self::to_binary_for_signing(cert)),
            hash_to_directly_verify: None,
            signature: cert.signature.clone().unwrap_or_default(),
            protocol_id: Protocol {
                security_level: SECURITY_LEVEL,
                protocol: CERTIFICATE_SIGNATURE_PROTOCOL.to_string(),
            },
            key_id: Self::signature_key_id(cert),
            counterparty: Counterparty {
                counterparty_type: CounterpartyType::Other,
                public_key: Some(cert.certifier.clone()),
            },
            for_self: None,
            privileged: false,
            privileged_reason: None,
            seek_permission: None,
        };
        let outcome = wallet.verify_signature(verification_request, None).await?;
        Ok(outcome.valid)
    }

    /// Returns the protocol and key ID under which the symmetric key of one
    /// certificate field is encrypted.
    ///
    /// The key ID is `"<serial_number> <field_name>"` when a serial number is
    /// given and just `field_name` otherwise.
    pub fn get_certificate_field_encryption_details(
        field_name: &str,
        serial_number: Option<&str>,
    ) -> (Protocol, String) {
        let protocol = Protocol {
            security_level: SECURITY_LEVEL,
            protocol: CERTIFICATE_FIELD_ENCRYPTION_PROTOCOL.to_string(),
        };
        let key_id = serial_number.map_or_else(
            || field_name.to_string(),
            |serial| format!("{} {}", serial, field_name),
        );
        (protocol, key_id)
    }

    /// Encrypts every entry of `fields` for `counterparty`.
    ///
    /// Each value is encrypted with its own freshly generated
    /// [`SymmetricKey`]; that key is in turn encrypted through `wallet` for
    /// `counterparty`, using the details from
    /// [`Self::get_certificate_field_encryption_details`].
    ///
    /// Returns `(encrypted_fields, keyring)`, both keyed by field name and
    /// both holding base64 text: the encrypted values and the encrypted
    /// per-field keys respectively.
    ///
    /// # Errors
    ///
    /// Propagates the first failure of the symmetric encryption or of the
    /// wallet's `encrypt` call.
    pub async fn encrypt_fields<W: WalletInterface + ?Sized>(
        fields: &HashMap<String, String>,
        serial_number: Option<&str>,
        counterparty: &PublicKey,
        wallet: &W,
    ) -> Result<(HashMap<String, String>, HashMap<String, String>), AuthError> {
        let mut ciphertext_by_field = HashMap::new();
        let mut key_by_field = HashMap::new();

        for (name, plaintext) in fields {
            let field_key = SymmetricKey::from_random();
            let sealed_value = field_key.encrypt(plaintext.as_bytes())?;
            ciphertext_by_field.insert(name.clone(), base64_encode(&sealed_value));

            let (protocol_id, key_id) =
                Self::get_certificate_field_encryption_details(name, serial_number);
            let wrap_request = EncryptArgs {
                plaintext: field_key.to_bytes(),
                protocol_id,
                key_id,
                counterparty: Counterparty {
                    counterparty_type: CounterpartyType::Other,
                    public_key: Some(counterparty.clone()),
                },
                privileged: false,
                privileged_reason: None,
                seek_permission: None,
            };
            let wrapped_key = wallet.encrypt(wrap_request, None).await?;
            key_by_field.insert(name.clone(), base64_encode(&wrapped_key.ciphertext));
        }

        Ok((ciphertext_by_field, key_by_field))
    }

    /// Decrypts the fields named in `keyring`.
    ///
    /// For every keyring entry the base64 key is decrypted through `wallet`
    /// (with `counterparty` as the other party), turned into a
    /// [`SymmetricKey`], and used to decrypt the base64 value stored under the
    /// same name in `encrypted_fields`. Entries of `encrypted_fields` that
    /// have no keyring entry are not part of the result.
    ///
    /// # Errors
    ///
    /// * [`AuthError::CertificateValidation`] if `keyring` is empty, if a
    ///   keyring entry has no matching encrypted field, or if a decrypted
    ///   value is not valid UTF-8.
    /// * Any error from base64 decoding, the wallet's `decrypt` call, or the
    ///   symmetric key handling.
    pub async fn decrypt_fields<W: WalletInterface + ?Sized>(
        encrypted_fields: &HashMap<String, String>,
        keyring: &HashMap<String, String>,
        serial_number: &str,
        counterparty: &PublicKey,
        wallet: &W,
    ) -> Result<HashMap<String, String>, AuthError> {
        if keyring.is_empty() {
            return Err(AuthError::CertificateValidation(
                "a keyring is required to decrypt certificate fields".to_string(),
            ));
        }

        let mut plaintext_by_field = HashMap::new();
        for (name, wrapped_key_b64) in keyring {
            let wrapped_key = base64_decode(wrapped_key_b64)?;
            let (protocol_id, key_id) =
                Self::get_certificate_field_encryption_details(name, Some(serial_number));
            let unwrap_request = DecryptArgs {
                ciphertext: wrapped_key,
                protocol_id,
                key_id,
                counterparty: Counterparty {
                    counterparty_type: CounterpartyType::Other,
                    public_key: Some(counterparty.clone()),
                },
                privileged: false,
                privileged_reason: None,
                seek_permission: None,
            };
            let unwrapped = wallet.decrypt(unwrap_request, None).await?;
            let field_key = SymmetricKey::from_bytes(&unwrapped.plaintext)?;

            // The field lookup deliberately comes after the key has been
            // recovered, so key-related errors take precedence.
            let sealed_value_b64 = encrypted_fields.get(name).ok_or_else(|| {
                AuthError::CertificateValidation(format!(
                    "field '{}' not found in encrypted fields",
                    name
                ))
            })?;
            let sealed_value = base64_decode(sealed_value_b64)?;

            let opened = field_key.decrypt(&sealed_value)?;
            let text = match String::from_utf8(opened) {
                Ok(text) => text,
                Err(e) => {
                    return Err(AuthError::CertificateValidation(format!(
                        "decrypted field '{}' is not valid UTF-8: {}",
                        name, e
                    )));
                }
            };
            plaintext_by_field.insert(name.clone(), text);
        }
        Ok(plaintext_by_field)
    }
}
/// Decodes a hex string two characters at a time, most significant nibble
/// first.
///
/// Decoding is lenient and never fails: a character that is not a hex digit
/// contributes a zero nibble (see `hex_nibble`), and an unpaired trailing
/// character is ignored.
fn hex_decode(hex: &str) -> Vec<u8> {
    hex.as_bytes()
        .chunks_exact(2)
        .map(|pair| (hex_nibble(pair[0]) << 4) | hex_nibble(pair[1]))
        .collect()
}
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value `0..=15`.
///
/// Any other byte yields `0` instead of an error.
fn hex_nibble(c: u8) -> u8 {
    // For a `u8` widened to `char`, `to_digit(16)` recognises exactly the
    // ASCII hex digits; the resulting value is at most 15.
    match char::from(c).to_digit(16) {
        Some(value) => value as u8,
        None => 0,
    }
}
/// Appends `val` to `buf` as a variable-length integer.
///
/// * values below `0xfd` take a single byte;
/// * values up to `0xffff` take the marker `0xfd` plus 2 little-endian bytes;
/// * values up to `0xffff_ffff` take the marker `0xfe` plus 4 little-endian
///   bytes;
/// * everything larger takes the marker `0xff` plus 8 little-endian bytes.
fn write_varint(buf: &mut Vec<u8>, val: u64) {
    match val {
        // Each narrowing cast is lossless inside its range.
        0..=0xfc => buf.push(val as u8),
        0xfd..=0xffff => {
            buf.push(0xfd);
            buf.extend_from_slice(&(val as u16).to_le_bytes());
        }
        0x1_0000..=0xffff_ffff => {
            buf.push(0xfe);
            buf.extend_from_slice(&(val as u32).to_le_bytes());
        }
        _ => {
            buf.push(0xff);
            buf.extend_from_slice(&val.to_le_bytes());
        }
    }
}