Skip to main content

a1/
cert_extensions.rs

1use std::collections::BTreeMap;
2
3use crate::crypto::hasher_cert_ext;
4use crate::error::A1Error;
5
/// Upper bound (16 KiB) on the combined size, in bytes, of all extension keys and
/// values that [`CertExtensions::set_checked`] will accept into one extension map.
pub const MAX_EXTENSION_BYTES: usize = 16 * 1024;
8
/// A single typed value stored under a key in [`CertExtensions`].
///
/// With the `serde` feature enabled the enum is serialized *untagged*: each variant
/// appears as its bare form (e.g. a JSON string, an unsigned number, or an array of
/// strings) with no variant name wrapped around it.
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(untagged))]
pub enum ExtValue {
    /// A free-form UTF-8 string.
    Str(String),
    /// An unsigned 64-bit integer.
    U64(u64),
    /// An ordered list of strings.
    Strings(Vec<String>),
}
17
18/// Typed extension fields committed into a [`DelegationCert`] signature.
19///
20/// Extensions augment the minimal cert payload with business-level metadata
21/// without breaking the cryptographic invariants of the core fields.
22/// The extension map is canonically serialized and hashed before inclusion in
23/// `signable_bytes`, so a tampered extension causes signature verification to fail.
24///
25/// # Reserved namespaces
26///
27/// `dyolo.*` is reserved for protocol extensions. Applications should use
28/// their own reverse-DNS prefix, e.g. `acme.cost_center`.
29///
30/// # Well-known keys
31///
32/// | Key                     | Type       | Meaning                                      |
33/// |-------------------------|------------|----------------------------------------------|
34/// | `dyolo.rate_limit_rpm`  | `u64`      | Max requests per minute for the delegate     |
35/// | `dyolo.quota_tokens`    | `u64`      | Max LLM tokens the delegate may consume      |
36/// | `dyolo.geo_allow`       | `[String]` | ISO-3166-1 alpha-2 country codes allowed     |
37/// | `dyolo.cost_center`     | `String`   | Billing cost-center tag for audit            |
38/// | `dyolo.ttl_warn_sec`    | `u64`      | Alert threshold before expiry                |
39///
40/// [`DelegationCert`]: crate::DelegationCert
41#[derive(Clone, Debug, Default)]
42#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
43pub struct CertExtensions {
44    fields: BTreeMap<String, ExtValue>,
45    #[cfg_attr(feature = "serde", serde(skip))]
46    byte_count: usize,
47}
48
49impl CertExtensions {
50    /// Create an empty extension map.
51    pub fn new() -> Self {
52        Self::default()
53    }
54
55    /// Set an extension field. Returns `self` for chaining.
56    /// Panics if the size limit is exceeded. Prefer `set_checked` for safety.
57    pub fn set(self, key: impl Into<String>, value: impl Into<ExtValue>) -> Self {
58        self.set_checked(key, value)
59            .expect("extension limit exceeded")
60    }
61
62    /// Set an extension field with size validation.
63    ///
64    /// Keys must be non-empty. Use a reverse-DNS prefix for application-specific
65    /// keys (e.g. `acme.cost_center`). The `dyolo.*` namespace is reserved for
66    /// well-known protocol keys (`dyolo.rate_limit_rpm`, `dyolo.quota_tokens`,
67    /// `dyolo.geo_allow`, `dyolo.cost_center`, `dyolo.ttl_warn_sec`).
68    pub fn set_checked(
69        mut self,
70        key: impl Into<String>,
71        value: impl Into<ExtValue>,
72    ) -> Result<Self, A1Error> {
73        let key_str = key.into();
74
75        if key_str.is_empty() {
76            return Err(A1Error::WireFormatError(
77                "extension key must not be empty".into(),
78            ));
79        }
80
81        let val = value.into();
82
83        let mut add_bytes = key_str.len();
84        match &val {
85            ExtValue::Str(s) => add_bytes += s.len(),
86            ExtValue::U64(_) => add_bytes += 8,
87            ExtValue::Strings(v) => {
88                for s in v {
89                    add_bytes += s.len();
90                }
91            }
92        }
93
94        if self.byte_count + add_bytes > MAX_EXTENSION_BYTES {
95            return Err(A1Error::WireFormatError(
96                "maximum extension byte limit exceeded".into(),
97            ));
98        }
99
100        self.byte_count += add_bytes;
101        self.fields.insert(key_str, val);
102        Ok(self)
103    }
104
105    /// Get an extension field by key.
106    pub fn get(&self, key: &str) -> Option<&ExtValue> {
107        self.fields.get(key)
108    }
109
110    /// Returns `true` if no extension fields are set.
111    pub fn is_empty(&self) -> bool {
112        self.fields.is_empty()
113    }
114
115    pub fn iter(&self) -> impl Iterator<Item = (&String, &ExtValue)> {
116        self.fields.iter()
117    }
118
119    /// A deterministic 32-byte commitment over all extension fields.
120    pub fn commitment(&self) -> [u8; 32] {
121        let mut h = hasher_cert_ext(crate::cert::CERT_VERSION);
122        h.update(b"a1::dyolo::cert::ext::v2.8.0");
123        if self.fields.is_empty() {
124            // Explicit empty sentinel (hashes the domain + version + length 0)
125            h.update(&0u64.to_le_bytes());
126            return h.finalize().into();
127        }
128
129        h.update(&(self.fields.len() as u64).to_le_bytes());
130        for (k, v) in &self.fields {
131            let k_bytes = k.as_bytes();
132            h.update(&(k_bytes.len() as u64).to_le_bytes());
133            h.update(k_bytes);
134
135            // Deterministic typed encoding
136            match v {
137                ExtValue::Str(s) => {
138                    h.update(&[0u8]); // type tag
139                    h.update(&(s.len() as u64).to_le_bytes());
140                    h.update(s.as_bytes());
141                }
142                ExtValue::U64(n) => {
143                    h.update(&[1u8]); // type tag
144                    h.update(&n.to_le_bytes());
145                }
146                ExtValue::Strings(vec) => {
147                    h.update(&[2u8]); // type tag
148                    h.update(&(vec.len() as u64).to_le_bytes());
149                    for s in vec {
150                        h.update(&(s.len() as u64).to_le_bytes());
151                        h.update(s.as_bytes());
152                    }
153                }
154            }
155        }
156        h.finalize().into()
157    }
158}
159
160impl std::fmt::Display for ExtValue {
161    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
162        match self {
163            ExtValue::Str(s) => write!(f, "{s}"),
164            ExtValue::U64(n) => write!(f, "{n}"),
165            ExtValue::Strings(v) => write!(f, "[{}]", v.join(", ")),
166        }
167    }
168}
169
170impl From<serde_json::Value> for ExtValue {
171    fn from(v: serde_json::Value) -> Self {
172        match v {
173            serde_json::Value::Number(n) if n.is_u64() => ExtValue::U64(n.as_u64().unwrap()),
174            serde_json::Value::String(s) => ExtValue::Str(s),
175            serde_json::Value::Array(arr) => ExtValue::Strings(
176                arr.into_iter()
177                    .map(|x| match x {
178                        serde_json::Value::String(s) => s,
179                        other => other.to_string(),
180                    })
181                    .collect(),
182            ),
183            other => ExtValue::Str(other.to_string()),
184        }
185    }
186}