actr_hyper/verify/
cert_cache.rs1use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use std::time::{Duration, Instant};
13
14use base64::Engine;
15use ed25519_dalek::VerifyingKey;
16
17use crate::error::{HyperError, HyperResult};
18
19struct CacheEntry {
21 key: VerifyingKey,
22 fetched_at: Instant,
23}
24
25pub struct MfrCertCache {
30 ais_endpoint: String,
31 http: reqwest::Client,
32 ttl: Duration,
33 cache: RwLock<HashMap<String, CacheEntry>>,
34}
35
36impl std::fmt::Debug for MfrCertCache {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("MfrCertCache")
39 .field("ais_endpoint", &self.ais_endpoint)
40 .field("ttl", &self.ttl)
41 .finish_non_exhaustive()
42 }
43}
44
45impl MfrCertCache {
46 pub fn new(ais_endpoint: impl Into<String>) -> Arc<Self> {
47 Arc::new(Self {
48 ais_endpoint: ais_endpoint.into(),
49 http: reqwest::Client::new(),
50 ttl: Duration::from_secs(3600),
51 cache: RwLock::new(HashMap::new()),
52 })
53 }
54
55 pub fn get_from_cache(&self, manufacturer: &str, key_id: Option<&str>) -> Option<VerifyingKey> {
58 let cache_key = match key_id {
59 Some(id) => format!("{}:{}", manufacturer, id),
60 None => manufacturer.to_string(),
61 };
62 let cache = self.cache.read().expect("cert_cache read lock poisoned");
63 cache.get(&cache_key).and_then(|entry| {
64 if entry.fetched_at.elapsed() < self.ttl {
65 Some(entry.key)
66 } else {
67 None
68 }
69 })
70 }
71
72 pub async fn get_or_fetch(
76 &self,
77 manufacturer: &str,
78 key_id: Option<&str>,
79 ) -> HyperResult<VerifyingKey> {
80 if let Some(key) = self.get_from_cache(manufacturer, key_id) {
82 tracing::debug!(manufacturer, ?key_id, "MFR pubkey cache hit");
83 return Ok(key);
84 }
85
86 tracing::debug!(
87 manufacturer,
88 ?key_id,
89 "MFR pubkey cache miss, fetching from AIS"
90 );
91
92 let key = self.fetch_from_ais(manufacturer, key_id).await?;
94
95 let cache_key = match key_id {
97 Some(id) => format!("{}:{}", manufacturer, id),
98 None => manufacturer.to_string(),
99 };
100 {
101 let mut cache = self.cache.write().expect("cert_cache write lock poisoned");
102 cache.insert(
103 cache_key,
104 CacheEntry {
105 key,
106 fetched_at: Instant::now(),
107 },
108 );
109 }
110
111 tracing::info!(
112 manufacturer,
113 ?key_id,
114 "MFR pubkey fetched from AIS and cached"
115 );
116 Ok(key)
117 }
118
119 async fn fetch_from_ais(
121 &self,
122 manufacturer: &str,
123 key_id: Option<&str>,
124 ) -> HyperResult<VerifyingKey> {
125 let url = if let Some(id) = key_id {
126 format!(
127 "{}/mfr/{}/verifying_key?key_id={}",
128 self.ais_endpoint, manufacturer, id
129 )
130 } else {
131 format!("{}/mfr/{}/verifying_key", self.ais_endpoint, manufacturer)
132 };
133 tracing::debug!(url, "fetching MFR pubkey from AIS");
134
135 let resp = self.http.get(&url).send().await.map_err(|e| {
136 HyperError::UntrustedManufacturer(format!(
137 "failed to fetch MFR pubkey ({manufacturer}): {e}"
138 ))
139 })?;
140
141 if !resp.status().is_success() {
142 let status = resp.status();
143 let body = resp.text().await.unwrap_or_default();
144 tracing::warn!(
145 manufacturer,
146 status = status.as_u16(),
147 body,
148 "AIS returned non-2xx, MFR pubkey fetch failed"
149 );
150 return Err(HyperError::UntrustedManufacturer(format!(
151 "AIS refused to provide MFR pubkey ({manufacturer}), status={status}"
152 )));
153 }
154
155 #[derive(serde::Deserialize)]
156 struct VerifyingKeyResp {
157 public_key: String,
159 }
160
161 let body: VerifyingKeyResp = resp.json().await.map_err(|e| {
162 HyperError::UntrustedManufacturer(format!(
163 "failed to parse MFR pubkey response ({manufacturer}): {e}"
164 ))
165 })?;
166
167 let key_bytes = base64::engine::general_purpose::STANDARD
168 .decode(&body.public_key)
169 .map_err(|e| {
170 HyperError::UntrustedManufacturer(format!(
171 "MFR pubkey base64 decode failed ({manufacturer}): {e}"
172 ))
173 })?;
174
175 let key_arr: [u8; 32] = key_bytes.try_into().map_err(|v: Vec<u8>| {
176 HyperError::UntrustedManufacturer(format!(
177 "MFR pubkey length incorrect ({manufacturer}), expected 32 bytes, got {} bytes",
178 v.len()
179 ))
180 })?;
181
182 VerifyingKey::from_bytes(&key_arr).map_err(|e| {
183 HyperError::UntrustedManufacturer(format!(
184 "MFR pubkey format invalid ({manufacturer}): {e}"
185 ))
186 })
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Two lookups for the same manufacturer must hit the mock AIS exactly once;
    /// the second answer comes from the cache and equals the first.
    #[tokio::test]
    async fn cache_returns_cached_key_without_http() {
        use ed25519_dalek::SigningKey;
        use rand::rngs::OsRng;

        let expected = SigningKey::generate(&mut OsRng).verifying_key();
        let encoded = base64::engine::general_purpose::STANDARD.encode(expected.to_bytes());
        let response_body = format!(r#"{{"public_key":"{encoded}"}}"#);

        let mut ais = mockito::Server::new_async().await;
        let endpoint = ais
            .mock("GET", "/mfr/test-mfr/verifying_key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_body)
            .expect(1)
            .create_async()
            .await;

        let cache = MfrCertCache::new(ais.url());
        let first = cache.get_or_fetch("test-mfr", None).await.unwrap();
        let second = cache.get_or_fetch("test-mfr", None).await.unwrap();

        endpoint.assert_async().await;
        assert_eq!(first.to_bytes(), second.to_bytes());
        assert_eq!(first.to_bytes(), expected.to_bytes());
    }

    /// A 404 from the AIS must surface as `HyperError::UntrustedManufacturer`.
    #[tokio::test]
    async fn fetch_fails_on_404() {
        let mut ais = mockito::Server::new_async().await;
        let _not_found = ais
            .mock("GET", "/mfr/unknown-mfr/verifying_key")
            .with_status(404)
            .create_async()
            .await;

        let cache = MfrCertCache::new(ais.url());
        let result = cache.get_or_fetch("unknown-mfr", None).await;

        assert!(
            matches!(result, Err(HyperError::UntrustedManufacturer(_))),
            "404 should return UntrustedManufacturer, actual: {result:?}"
        );
    }
}