1use std::collections::BTreeMap;
2
3use crate::crypto::hasher_cert_ext;
4use crate::error::A1Error;
5
/// Upper bound on the bytes that `CertExtensions::set_checked` accepts across all entries
/// (16 KiB).
pub const MAX_EXTENSION_BYTES: usize = 16 * 1024;
8
/// A single certificate-extension value.
///
/// With the `serde` feature enabled the enum is `untagged`: it (de)serializes as a bare
/// string, an unsigned integer, or a sequence of strings, with no variant name.
#[derive(Debug, Clone, Eq, PartialEq)]
#[cfg_attr(
    feature = "serde",
    derive(serde::Serialize, serde::Deserialize),
    serde(untagged)
)]
pub enum ExtValue {
    /// A string value.
    Str(String),
    /// An unsigned 64-bit integer value.
    U64(u64),
    /// A list of string values.
    Strings(Vec<String>),
}
17
18#[derive(Clone, Debug, Default)]
42#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
43pub struct CertExtensions {
44 fields: BTreeMap<String, ExtValue>,
45 #[cfg_attr(feature = "serde", serde(skip))]
46 byte_count: usize,
47}
48
49impl CertExtensions {
50 pub fn new() -> Self {
52 Self::default()
53 }
54
55 pub fn set(self, key: impl Into<String>, value: impl Into<ExtValue>) -> Self {
58 self.set_checked(key, value)
59 .expect("extension limit exceeded")
60 }
61
62 pub fn set_checked(
69 mut self,
70 key: impl Into<String>,
71 value: impl Into<ExtValue>,
72 ) -> Result<Self, A1Error> {
73 let key_str = key.into();
74
75 if key_str.is_empty() {
76 return Err(A1Error::WireFormatError(
77 "extension key must not be empty".into(),
78 ));
79 }
80
81 let val = value.into();
82
83 let mut add_bytes = key_str.len();
84 match &val {
85 ExtValue::Str(s) => add_bytes += s.len(),
86 ExtValue::U64(_) => add_bytes += 8,
87 ExtValue::Strings(v) => {
88 for s in v {
89 add_bytes += s.len();
90 }
91 }
92 }
93
94 if self.byte_count + add_bytes > MAX_EXTENSION_BYTES {
95 return Err(A1Error::WireFormatError(
96 "maximum extension byte limit exceeded".into(),
97 ));
98 }
99
100 self.byte_count += add_bytes;
101 self.fields.insert(key_str, val);
102 Ok(self)
103 }
104
105 pub fn get(&self, key: &str) -> Option<&ExtValue> {
107 self.fields.get(key)
108 }
109
110 pub fn is_empty(&self) -> bool {
112 self.fields.is_empty()
113 }
114
115 pub fn iter(&self) -> impl Iterator<Item = (&String, &ExtValue)> {
116 self.fields.iter()
117 }
118
119 pub fn commitment(&self) -> [u8; 32] {
121 let mut h = hasher_cert_ext(crate::cert::CERT_VERSION);
122 h.update(b"a1::dyolo::cert::ext::v2.8.0");
123 if self.fields.is_empty() {
124 h.update(&0u64.to_le_bytes());
126 return h.finalize().into();
127 }
128
129 h.update(&(self.fields.len() as u64).to_le_bytes());
130 for (k, v) in &self.fields {
131 let k_bytes = k.as_bytes();
132 h.update(&(k_bytes.len() as u64).to_le_bytes());
133 h.update(k_bytes);
134
135 match v {
137 ExtValue::Str(s) => {
138 h.update(&[0u8]); h.update(&(s.len() as u64).to_le_bytes());
140 h.update(s.as_bytes());
141 }
142 ExtValue::U64(n) => {
143 h.update(&[1u8]); h.update(&n.to_le_bytes());
145 }
146 ExtValue::Strings(vec) => {
147 h.update(&[2u8]); h.update(&(vec.len() as u64).to_le_bytes());
149 for s in vec {
150 h.update(&(s.len() as u64).to_le_bytes());
151 h.update(s.as_bytes());
152 }
153 }
154 }
155 }
156 h.finalize().into()
157 }
158}
159
160impl std::fmt::Display for ExtValue {
161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162 match self {
163 ExtValue::Str(s) => write!(f, "{s}"),
164 ExtValue::U64(n) => write!(f, "{n}"),
165 ExtValue::Strings(v) => write!(f, "[{}]", v.join(", ")),
166 }
167 }
168}
169
170impl From<serde_json::Value> for ExtValue {
171 fn from(v: serde_json::Value) -> Self {
172 match v {
173 serde_json::Value::Number(n) if n.is_u64() => ExtValue::U64(n.as_u64().unwrap()),
174 serde_json::Value::String(s) => ExtValue::Str(s),
175 serde_json::Value::Array(arr) => ExtValue::Strings(
176 arr.into_iter()
177 .map(|x| match x {
178 serde_json::Value::String(s) => s,
179 other => other.to_string(),
180 })
181 .collect(),
182 ),
183 other => ExtValue::Str(other.to_string()),
184 }
185 }
186}