use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
use base64::Engine;
use ed25519_dalek::VerifyingKey;
use crate::error::{HyperError, HyperResult};
/// A manufacturer verifying key paired with the instant it was fetched from AIS.
struct CacheEntry {
    /// When `key` was fetched; compared against the cache TTL on every lookup.
    fetched_at: Instant,
    /// The manufacturer's Ed25519 verifying key.
    key: VerifyingKey,
}
/// Cache key: `(manufacturer, key_id)`.
///
/// Kept structured rather than joined into one string so that distinct
/// lookups can never collide: a `"{manufacturer}:{key_id}"` string would map
/// `("acme", Some("k1"))` and `("acme:k1", None)` to the same entry and hand
/// one manufacturer's key out for another.
type CacheKey = (String, Option<String>);

/// TTL cache of manufacturer (MFR) verifying keys fetched from AIS.
///
/// Keys are cached per `(manufacturer, key_id)` pair. Expired entries are not
/// evicted eagerly; they are ignored on lookup and overwritten by the next
/// successful fetch.
pub struct MfrCertCache {
    /// Base URL of the AIS service; request paths are appended to it.
    ais_endpoint: String,
    /// HTTP client used for AIS requests.
    http: reqwest::Client,
    /// How long a fetched key is served from cache before it is re-fetched.
    ttl: Duration,
    /// Fetched keys together with their fetch time.
    cache: RwLock<HashMap<CacheKey, CacheEntry>>,
}

// Hand-written so the HTTP client and cached keys stay out of debug output.
impl std::fmt::Debug for MfrCertCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("MfrCertCache")
            .field("ais_endpoint", &self.ais_endpoint)
            .field("ttl", &self.ttl)
            .finish_non_exhaustive()
    }
}

impl MfrCertCache {
    /// Cache TTL used by [`MfrCertCache::new`].
    const DEFAULT_TTL: Duration = Duration::from_secs(3600);

    /// Total time budget for one AIS request. Without it, a stalled AIS
    /// would block `get_or_fetch` indefinitely.
    const HTTP_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates a cache backed by the AIS service at `ais_endpoint`, with a
    /// one-hour TTL.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built (the same conditions under
    /// which `reqwest::Client::new` panics).
    pub fn new(ais_endpoint: impl Into<String>) -> Arc<Self> {
        Self::with_ttl(ais_endpoint, Self::DEFAULT_TTL)
    }

    /// Like [`MfrCertCache::new`], but with a caller-chosen cache TTL.
    ///
    /// # Panics
    ///
    /// Panics if the HTTP client cannot be built.
    pub fn with_ttl(ais_endpoint: impl Into<String>, ttl: Duration) -> Arc<Self> {
        let http = reqwest::Client::builder()
            .timeout(Self::HTTP_TIMEOUT)
            .build()
            .expect("failed to build AIS HTTP client");
        Arc::new(Self {
            ais_endpoint: ais_endpoint.into(),
            http,
            ttl,
            cache: RwLock::new(HashMap::new()),
        })
    }

    /// Returns the cached key for `(manufacturer, key_id)` if present and
    /// younger than the TTL. Never performs network I/O.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    pub fn get_from_cache(&self, manufacturer: &str, key_id: Option<&str>) -> Option<VerifyingKey> {
        // Build the (allocating) key before taking the lock.
        let cache_key = Self::cache_key(manufacturer, key_id);
        let cache = self.cache.read().expect("cert_cache read lock poisoned");
        cache
            .get(&cache_key)
            .filter(|entry| entry.fetched_at.elapsed() < self.ttl)
            .map(|entry| entry.key)
    }

    /// Returns the key for `(manufacturer, key_id)`, fetching it from AIS and
    /// caching it on a miss or after expiry.
    ///
    /// # Errors
    ///
    /// Returns `HyperError::UntrustedManufacturer` if the manufacturer id or
    /// AIS endpoint is invalid, the request fails or times out, AIS answers
    /// with a non-2xx status, or the returned key is malformed. Failures are
    /// not cached.
    ///
    /// # Panics
    ///
    /// Panics if the cache lock is poisoned.
    pub async fn get_or_fetch(
        &self,
        manufacturer: &str,
        key_id: Option<&str>,
    ) -> HyperResult<VerifyingKey> {
        if let Some(key) = self.get_from_cache(manufacturer, key_id) {
            tracing::debug!(manufacturer, ?key_id, "MFR pubkey cache hit");
            return Ok(key);
        }
        tracing::debug!(
            manufacturer,
            ?key_id,
            "MFR pubkey cache miss, fetching from AIS"
        );

        // No lock is held across this await, so concurrent misses for the
        // same key may each fetch; the last insert wins.
        let key = self.fetch_from_ais(manufacturer, key_id).await?;

        // The write guard is a temporary dropped at the end of this statement.
        self.cache
            .write()
            .expect("cert_cache write lock poisoned")
            .insert(
                Self::cache_key(manufacturer, key_id),
                CacheEntry {
                    key,
                    fetched_at: Instant::now(),
                },
            );

        tracing::info!(
            manufacturer,
            ?key_id,
            "MFR pubkey fetched from AIS and cached"
        );
        Ok(key)
    }

    /// Builds the structured map key for a lookup.
    fn cache_key(manufacturer: &str, key_id: Option<&str>) -> CacheKey {
        (manufacturer.to_owned(), key_id.map(str::to_owned))
    }

    /// Builds `{ais_endpoint}/mfr/{manufacturer}/verifying_key[?key_id={key_id}]`.
    ///
    /// `manufacturer` is percent-encoded as a single path segment (including
    /// `/`, `?`, `#` and `%`), and `key_id` is form-encoded as a query value.
    /// Caller-supplied text therefore cannot alter the request path or add
    /// query parameters.
    fn verifying_key_url(
        &self,
        manufacturer: &str,
        key_id: Option<&str>,
    ) -> HyperResult<reqwest::Url> {
        // Empty or dot segments would address a different resource.
        if matches!(manufacturer, "" | "." | "..") {
            return Err(HyperError::UntrustedManufacturer(format!(
                "invalid manufacturer identifier {manufacturer:?}"
            )));
        }
        let invalid_endpoint = |reason: String| {
            HyperError::UntrustedManufacturer(format!(
                "invalid AIS endpoint {:?} ({manufacturer}): {reason}",
                self.ais_endpoint
            ))
        };
        let mut url = reqwest::Url::parse(&self.ais_endpoint)
            .map_err(|e| invalid_endpoint(e.to_string()))?;
        url.path_segments_mut()
            .map_err(|()| invalid_endpoint("cannot be used as a base URL".to_owned()))?
            // Drop a trailing empty segment so "http://host/" does not yield "//mfr".
            .pop_if_empty()
            .extend(["mfr", manufacturer, "verifying_key"]);
        if let Some(id) = key_id {
            url.query_pairs_mut().append_pair("key_id", id);
        }
        Ok(url)
    }

    /// Fetches and decodes the verifying key from AIS without touching the cache.
    ///
    /// Expects a JSON body `{"public_key": "<standard base64 of 32 bytes>"}`.
    async fn fetch_from_ais(
        &self,
        manufacturer: &str,
        key_id: Option<&str>,
    ) -> HyperResult<VerifyingKey> {
        let url = self.verifying_key_url(manufacturer, key_id)?;
        tracing::debug!(%url, "fetching MFR pubkey from AIS");
        let resp = self.http.get(url).send().await.map_err(|e| {
            HyperError::UntrustedManufacturer(format!(
                "failed to fetch MFR pubkey ({manufacturer}): {e}"
            ))
        })?;
        if !resp.status().is_success() {
            let status = resp.status();
            // Best effort: the body is only used for logging.
            let body = resp.text().await.unwrap_or_default();
            tracing::warn!(
                manufacturer,
                status = status.as_u16(),
                body,
                "AIS returned non-2xx, MFR pubkey fetch failed"
            );
            return Err(HyperError::UntrustedManufacturer(format!(
                "AIS refused to provide MFR pubkey ({manufacturer}), status={status}"
            )));
        }

        #[derive(serde::Deserialize)]
        struct VerifyingKeyResp {
            public_key: String,
        }

        let body: VerifyingKeyResp = resp.json().await.map_err(|e| {
            HyperError::UntrustedManufacturer(format!(
                "failed to parse MFR pubkey response ({manufacturer}): {e}"
            ))
        })?;
        let key_bytes = base64::engine::general_purpose::STANDARD
            .decode(&body.public_key)
            .map_err(|e| {
                HyperError::UntrustedManufacturer(format!(
                    "MFR pubkey base64 decode failed ({manufacturer}): {e}"
                ))
            })?;
        let key_arr: [u8; 32] = key_bytes.try_into().map_err(|v: Vec<u8>| {
            HyperError::UntrustedManufacturer(format!(
                "MFR pubkey length incorrect ({manufacturer}), expected 32 bytes, got {} bytes",
                v.len()
            ))
        })?;
        VerifyingKey::from_bytes(&key_arr).map_err(|e| {
            HyperError::UntrustedManufacturer(format!(
                "MFR pubkey format invalid ({manufacturer}): {e}"
            ))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::SigningKey;
    use rand::rngs::OsRng;

    /// Generates a fresh key pair and returns the verifying key together with
    /// the JSON body AIS would serve for it.
    fn random_key_response() -> (VerifyingKey, String) {
        let verifying_key = SigningKey::generate(&mut OsRng).verifying_key();
        let key_b64 = base64::engine::general_purpose::STANDARD.encode(verifying_key.to_bytes());
        (verifying_key, format!(r#"{{"public_key":"{key_b64}"}}"#))
    }

    #[tokio::test]
    async fn cache_returns_cached_key_without_http() {
        let (verifying_key, body) = random_key_response();
        let mut server = mockito::Server::new_async().await;
        // `expect(1)`: the second lookup must be served from the cache.
        let mock = server
            .mock("GET", "/mfr/test-mfr/verifying_key")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body)
            .expect(1)
            .create_async()
            .await;

        let cache = MfrCertCache::new(server.url());
        let k1 = cache.get_or_fetch("test-mfr", None).await.unwrap();
        let k2 = cache.get_or_fetch("test-mfr", None).await.unwrap();

        mock.assert_async().await;
        assert_eq!(k1.to_bytes(), k2.to_bytes());
        assert_eq!(k1.to_bytes(), verifying_key.to_bytes());
    }

    #[tokio::test]
    async fn key_id_is_sent_as_query_parameter_and_cached_separately() {
        let (verifying_key, body) = random_key_response();
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("GET", "/mfr/test-mfr/verifying_key")
            .match_query(mockito::Matcher::UrlEncoded("key_id".into(), "k1".into()))
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(body)
            .expect(1)
            .create_async()
            .await;

        let cache = MfrCertCache::new(server.url());
        let first = cache.get_or_fetch("test-mfr", Some("k1")).await.unwrap();
        let second = cache.get_or_fetch("test-mfr", Some("k1")).await.unwrap();

        mock.assert_async().await;
        assert_eq!(first.to_bytes(), verifying_key.to_bytes());
        assert_eq!(second.to_bytes(), verifying_key.to_bytes());
        // A key fetched for a specific key_id must not satisfy the default lookup.
        assert!(cache.get_from_cache("test-mfr", None).is_none());
    }

    #[tokio::test]
    async fn fetch_fails_on_404() {
        let mut server = mockito::Server::new_async().await;
        // Assert the request actually reached this mock: mockito answers
        // unmatched requests with 501, which would also surface as
        // UntrustedManufacturer and mask a wrongly built URL.
        let mock = server
            .mock("GET", "/mfr/unknown-mfr/verifying_key")
            .with_status(404)
            .expect(1)
            .create_async()
            .await;

        let cache = MfrCertCache::new(server.url());
        let result = cache.get_or_fetch("unknown-mfr", None).await;

        mock.assert_async().await;
        assert!(
            matches!(result, Err(HyperError::UntrustedManufacturer(_))),
            "404 should return UntrustedManufacturer, actual: {result:?}"
        );
        // Failed fetches must not populate the cache.
        assert!(cache.get_from_cache("unknown-mfr", None).is_none());
    }
}