use std::collections::BTreeMap;
use crate::crypto::hasher_cert_ext;
use crate::error::KyaError;
/// Upper bound on the total bytes an extension set may be charged, as enforced by
/// `CertExtensions::set_checked` (each entry is charged its key length plus its
/// value payload).
pub const MAX_EXTENSION_BYTES: usize = 16384;
/// A single certificate-extension value.
///
/// With the `serde` feature enabled the enum is `untagged`: each variant is
/// (de)serialized as its bare inner value (a string, an unsigned integer, or a
/// sequence of strings) with no variant tag. Untagged deserialization tries the
/// variants in declaration order, so the order below is significant.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(untagged))]
pub enum ExtValue {
    /// A UTF-8 string value.
    Str(String),
    /// An unsigned 64-bit integer value.
    U64(u64),
    /// A list of string values.
    Strings(Vec<String>),
}
/// A set of named certificate extensions with a byte budget.
///
/// Entries live in a `BTreeMap`, so iteration visits keys in ascending
/// (byte-wise `String`) order.
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct CertExtensions {
    /// Extension entries keyed by name.
    fields: BTreeMap<String, ExtValue>,
    /// Cached total of bytes charged against `MAX_EXTENSION_BYTES`
    /// (key length plus value payload per entry).
    ///
    /// Not serialized: because of `serde(skip)`, a deserialized value gets
    /// `Default::default()` (0) here regardless of what `fields` contains.
    /// NOTE(review): code relying on this cache must recompute it from `fields`
    /// after deserialization.
    #[cfg_attr(feature = "serde", serde(skip))]
    byte_count: usize,
}
impl CertExtensions {
    /// Creates an empty extension set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts `key = value` and returns the updated set (builder style).
    ///
    /// # Panics
    ///
    /// Panics whenever [`CertExtensions::set_checked`] would return an error:
    /// an empty `key`, or an insertion that would exceed [`MAX_EXTENSION_BYTES`].
    /// Prefer `set_checked` for keys or values derived from untrusted input.
    pub fn set(self, key: impl Into<String>, value: impl Into<ExtValue>) -> Self {
        self.set_checked(key, value)
            .expect("extension limit exceeded")
    }

    /// Inserts `key = value`, enforcing the key and size limits.
    ///
    /// Each entry is charged `key.len()` plus its value payload (string bytes,
    /// or 8 bytes for a `u64`). Replacing an existing key first releases the
    /// bytes charged for the old value, so overwriting does not inflate the total.
    ///
    /// # Errors
    ///
    /// Returns [`KyaError::WireFormatError`] if `key` is empty, or if the new
    /// total would exceed [`MAX_EXTENSION_BYTES`]. The set is consumed on error.
    pub fn set_checked(
        mut self,
        key: impl Into<String>,
        value: impl Into<ExtValue>,
    ) -> Result<Self, KyaError> {
        let key_str = key.into();
        if key_str.is_empty() {
            return Err(KyaError::WireFormatError(
                "extension key must not be empty".into(),
            ));
        }
        let val = value.into();
        let add_bytes = Self::entry_bytes(&key_str, &val);
        // An overwrite replaces the old entry, so its bytes must not stay charged.
        let released = self
            .fields
            .get(&key_str)
            .map_or(0, |old| Self::entry_bytes(&key_str, old));
        let new_total = self
            .current_byte_count()
            .saturating_sub(released)
            .saturating_add(add_bytes);
        if new_total > MAX_EXTENSION_BYTES {
            return Err(KyaError::WireFormatError(
                "maximum extension byte limit exceeded".into(),
            ));
        }
        self.byte_count = new_total;
        self.fields.insert(key_str, val);
        Ok(self)
    }

    /// Returns the value stored under `key`, if any.
    pub fn get(&self, key: &str) -> Option<&ExtValue> {
        self.fields.get(key)
    }

    /// Returns `true` when no extensions are set.
    pub fn is_empty(&self) -> bool {
        self.fields.is_empty()
    }

    /// Iterates over all entries in ascending key order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &ExtValue)> {
        self.fields.iter()
    }

    /// Computes a 32-byte commitment over every entry.
    ///
    /// The hasher comes from `hasher_cert_ext(CERT_VERSION)`. Bytes fed, in order:
    /// the entry count as little-endian `u64`, then for each entry in ascending key
    /// order: key length (LE `u64`), key bytes, a one-byte type tag, and the value:
    /// * tag `0` (`Str`): length (LE `u64`), then UTF-8 bytes;
    /// * tag `1` (`U64`): the value as LE `u64`;
    /// * tag `2` (`Strings`): element count (LE `u64`), then for each element its
    ///   length (LE `u64`) and UTF-8 bytes.
    ///
    /// An empty set therefore hashes only the 8-byte count `0`.
    pub fn commitment(&self) -> [u8; 32] {
        let mut h = hasher_cert_ext(crate::cert::CERT_VERSION);
        h.update(&(self.fields.len() as u64).to_le_bytes());
        for (k, v) in &self.fields {
            let k_bytes = k.as_bytes();
            h.update(&(k_bytes.len() as u64).to_le_bytes());
            h.update(k_bytes);
            match v {
                ExtValue::Str(s) => {
                    h.update(&[0u8]);
                    h.update(&(s.len() as u64).to_le_bytes());
                    h.update(s.as_bytes());
                }
                ExtValue::U64(n) => {
                    h.update(&[1u8]);
                    h.update(&n.to_le_bytes());
                }
                ExtValue::Strings(vec) => {
                    h.update(&[2u8]);
                    h.update(&(vec.len() as u64).to_le_bytes());
                    for s in vec {
                        h.update(&(s.len() as u64).to_le_bytes());
                        h.update(s.as_bytes());
                    }
                }
            }
        }
        h.finalize().into()
    }

    /// Bytes charged against [`MAX_EXTENSION_BYTES`] for one entry: the key's
    /// UTF-8 length plus the value payload (string bytes, or 8 for a `u64`).
    fn entry_bytes(key: &str, value: &ExtValue) -> usize {
        let value_bytes = match value {
            ExtValue::Str(s) => s.len(),
            ExtValue::U64(_) => 8,
            // NOTE(review): list elements are charged only their string bytes, so a
            // long list of empty strings costs nothing against the limit (while
            // `commitment` still hashes 8 bytes per element). Consider charging a
            // per-element overhead.
            ExtValue::Strings(v) => v.iter().map(String::len).sum::<usize>(),
        };
        key.len().saturating_add(value_bytes)
    }

    /// Returns the bytes currently charged for all entries.
    ///
    /// `byte_count` is `serde(skip)`, so a deserialized set has `byte_count == 0`
    /// even when `fields` is populated. In that state the total is recomputed from
    /// the entries, so deserialization cannot be used to bypass the limit.
    fn current_byte_count(&self) -> usize {
        if self.byte_count == 0 && !self.fields.is_empty() {
            self.fields
                .iter()
                .fold(0usize, |acc, (k, v)| acc.saturating_add(Self::entry_bytes(k, v)))
        } else {
            self.byte_count
        }
    }
}
impl std::fmt::Display for ExtValue {
    /// Renders the value for humans: strings verbatim, integers in decimal, and
    /// lists as `[a, b, c]` (elements separated by `", "`, unquoted).
    ///
    /// Width/fill flags of the outer formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ExtValue::Str(text) => f.write_str(text),
            ExtValue::U64(num) => write!(f, "{num}"),
            ExtValue::Strings(items) => {
                f.write_str("[")?;
                for (idx, item) in items.iter().enumerate() {
                    if idx > 0 {
                        f.write_str(", ")?;
                    }
                    f.write_str(item)?;
                }
                f.write_str("]")
            }
        }
    }
}
impl From<serde_json::Value> for ExtValue {
fn from(v: serde_json::Value) -> Self {
match v {
serde_json::Value::Number(n) if n.is_u64() => ExtValue::U64(n.as_u64().unwrap()),
serde_json::Value::String(s) => ExtValue::Str(s),
serde_json::Value::Array(arr) => ExtValue::Strings(
arr.into_iter()
.map(|x| match x {
serde_json::Value::String(s) => s,
other => other.to_string(),
})
.collect(),
),
other => ExtValue::Str(other.to_string()),
}
}
}