Skip to main content

actr_hyper/verify/
cert_cache.rs

1//! Production mode MFR public key cache
2//!
3//! `MfrCertCache` fetches manufacturer Ed25519 public keys on demand from
4//! AIS `GET /mfr/{name}/verifying_key`, caching locally (TTL 1 hour).
5//!
6//! Uses `std::sync::RwLock` (not tokio) internally because:
7//! - Cache reads/writes are extremely short memory operations that won't block the tokio executor
8//! - Provides a synchronous read path for `RegistryTrust::verify_package` to call directly
9
10use std::collections::HashMap;
11use std::sync::{Arc, RwLock};
12use std::time::{Duration, Instant};
13
14use base64::Engine;
15use ed25519_dalek::VerifyingKey;
16
17use crate::error::{HyperError, HyperResult};
18
19/// MFR public key cache entry
20struct CacheEntry {
21    key: VerifyingKey,
22    fetched_at: Instant,
23}
24
25/// Production mode MFR Ed25519 public key cache
26///
27/// Fetches manufacturer public keys on demand from the AIS endpoint, cache TTL defaults to 1 hour.
28/// Shared across tasks via `Arc<MfrCertCache>`.
29pub struct MfrCertCache {
30    ais_endpoint: String,
31    http: reqwest::Client,
32    ttl: Duration,
33    cache: RwLock<HashMap<String, CacheEntry>>,
34}
35
36impl std::fmt::Debug for MfrCertCache {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        f.debug_struct("MfrCertCache")
39            .field("ais_endpoint", &self.ais_endpoint)
40            .field("ttl", &self.ttl)
41            .finish_non_exhaustive()
42    }
43}
44
45impl MfrCertCache {
46    pub fn new(ais_endpoint: impl Into<String>) -> Arc<Self> {
47        Arc::new(Self {
48            ais_endpoint: ais_endpoint.into(),
49            http: reqwest::Client::new(),
50            ttl: Duration::from_secs(3600),
51            cache: RwLock::new(HashMap::new()),
52        })
53    }
54
55    /// Used in `RegistryTrust::verify_package` synchronous path;
56    /// caller must ensure the cache has been warmed via `get_or_fetch` beforehand.
57    pub fn get_from_cache(&self, manufacturer: &str, key_id: Option<&str>) -> Option<VerifyingKey> {
58        let cache_key = match key_id {
59            Some(id) => format!("{}:{}", manufacturer, id),
60            None => manufacturer.to_string(),
61        };
62        let cache = self.cache.read().expect("cert_cache read lock poisoned");
63        cache.get(&cache_key).and_then(|entry| {
64            if entry.fetched_at.elapsed() < self.ttl {
65                Some(entry.key)
66            } else {
67                None
68            }
69        })
70    }
71
72    /// Get the Ed25519 verifying key for the specified manufacturer
73    ///
74    /// Reads from cache first (if not expired); on miss, fetches from AIS and updates cache.
75    pub async fn get_or_fetch(
76        &self,
77        manufacturer: &str,
78        key_id: Option<&str>,
79    ) -> HyperResult<VerifyingKey> {
80        // fast path: read cache
81        if let Some(key) = self.get_from_cache(manufacturer, key_id) {
82            tracing::debug!(manufacturer, ?key_id, "MFR pubkey cache hit");
83            return Ok(key);
84        }
85
86        tracing::debug!(
87            manufacturer,
88            ?key_id,
89            "MFR pubkey cache miss, fetching from AIS"
90        );
91
92        // slow path: HTTP fetch
93        let key = self.fetch_from_ais(manufacturer, key_id).await?;
94
95        // write to cache (brief blocking lock, just a HashMap insert)
96        let cache_key = match key_id {
97            Some(id) => format!("{}:{}", manufacturer, id),
98            None => manufacturer.to_string(),
99        };
100        {
101            let mut cache = self.cache.write().expect("cert_cache write lock poisoned");
102            cache.insert(
103                cache_key,
104                CacheEntry {
105                    key,
106                    fetched_at: Instant::now(),
107                },
108            );
109        }
110
111        tracing::info!(
112            manufacturer,
113            ?key_id,
114            "MFR pubkey fetched from AIS and cached"
115        );
116        Ok(key)
117    }
118
119    /// Fetch public key from AIS `GET /mfr/{manufacturer}/verifying_key`
120    async fn fetch_from_ais(
121        &self,
122        manufacturer: &str,
123        key_id: Option<&str>,
124    ) -> HyperResult<VerifyingKey> {
125        let url = if let Some(id) = key_id {
126            format!(
127                "{}/mfr/{}/verifying_key?key_id={}",
128                self.ais_endpoint, manufacturer, id
129            )
130        } else {
131            format!("{}/mfr/{}/verifying_key", self.ais_endpoint, manufacturer)
132        };
133        tracing::debug!(url, "fetching MFR pubkey from AIS");
134
135        let resp = self.http.get(&url).send().await.map_err(|e| {
136            HyperError::UntrustedManufacturer(format!(
137                "failed to fetch MFR pubkey ({manufacturer}): {e}"
138            ))
139        })?;
140
141        if !resp.status().is_success() {
142            let status = resp.status();
143            let body = resp.text().await.unwrap_or_default();
144            tracing::warn!(
145                manufacturer,
146                status = status.as_u16(),
147                body,
148                "AIS returned non-2xx, MFR pubkey fetch failed"
149            );
150            return Err(HyperError::UntrustedManufacturer(format!(
151                "AIS refused to provide MFR pubkey ({manufacturer}), status={status}"
152            )));
153        }
154
155        #[derive(serde::Deserialize)]
156        struct VerifyingKeyResp {
157            /// Base64-encoded Ed25519 verifying key (32 bytes)
158            public_key: String,
159        }
160
161        let body: VerifyingKeyResp = resp.json().await.map_err(|e| {
162            HyperError::UntrustedManufacturer(format!(
163                "failed to parse MFR pubkey response ({manufacturer}): {e}"
164            ))
165        })?;
166
167        let key_bytes = base64::engine::general_purpose::STANDARD
168            .decode(&body.public_key)
169            .map_err(|e| {
170                HyperError::UntrustedManufacturer(format!(
171                    "MFR pubkey base64 decode failed ({manufacturer}): {e}"
172                ))
173            })?;
174
175        let key_arr: [u8; 32] = key_bytes.try_into().map_err(|v: Vec<u8>| {
176            HyperError::UntrustedManufacturer(format!(
177                "MFR pubkey length incorrect ({manufacturer}), expected 32 bytes, got {} bytes",
178                v.len()
179            ))
180        })?;
181
182        VerifyingKey::from_bytes(&key_arr).map_err(|e| {
183            HyperError::UntrustedManufacturer(format!(
184                "MFR pubkey format invalid ({manufacturer}): {e}"
185            ))
186        })
187    }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::SigningKey;
    use rand::rngs::OsRng;

    /// Produce a fresh verifying key plus the base64 form AIS would serve for it.
    fn random_key() -> (VerifyingKey, String) {
        let verifying_key = SigningKey::generate(&mut OsRng).verifying_key();
        let encoded = base64::engine::general_purpose::STANDARD.encode(verifying_key.to_bytes());
        (verifying_key, encoded)
    }

    #[tokio::test]
    async fn cache_returns_cached_key_without_http() {
        let (expected, encoded) = random_key();

        let mut server = mockito::Server::new_async().await;
        let endpoint = server
            .mock("GET", "/mfr/test-mfr/verifying_key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(format!(r#"{{"public_key":"{encoded}"}}"#))
            // Exactly one request: the second lookup must be answered by the cache.
            .expect(1)
            .create_async()
            .await;

        let cache = MfrCertCache::new(server.url());
        let on_miss = cache.get_or_fetch("test-mfr", None).await.unwrap();
        let on_hit = cache.get_or_fetch("test-mfr", None).await.unwrap();

        endpoint.assert_async().await;
        assert_eq!(on_miss.to_bytes(), expected.to_bytes());
        assert_eq!(on_hit.to_bytes(), expected.to_bytes());
    }

    #[tokio::test]
    async fn fetch_fails_on_404() {
        let mut server = mockito::Server::new_async().await;
        // Kept alive for the duration of the test so the server keeps answering 404.
        let _not_found = server
            .mock("GET", "/mfr/unknown-mfr/verifying_key")
            .with_status(404)
            .create_async()
            .await;

        let cache = MfrCertCache::new(server.url());
        let result = cache.get_or_fetch("unknown-mfr", None).await;

        assert!(
            matches!(result, Err(HyperError::UntrustedManufacturer(_))),
            "404 should return UntrustedManufacturer, actual: {result:?}"
        );
    }
}