1use std::{
25 collections::{HashMap, HashSet},
26 num::NonZeroU32,
27 pin::Pin,
28 sync::{Arc, Mutex},
29 time::{Duration, SystemTime, UNIX_EPOCH},
30};
31
32use arc_swap::ArcSwap;
33use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
34use rustls::{
35 DigitallySignedStruct, DistinguishedName, Error as TlsError, RootCertStore, SignatureScheme,
36 client::danger::HandshakeSignatureValid,
37 pki_types::{CertificateDer, CertificateRevocationListDer, UnixTime},
38 server::{
39 WebPkiClientVerifier,
40 danger::{ClientCertVerified, ClientCertVerifier},
41 },
42};
43use tokio::{
44 net::lookup_host,
45 sync::{RwLock, Semaphore, mpsc},
46 task::JoinSet,
47 time::{Instant, Sleep},
48};
49use tokio_util::sync::CancellationToken;
50use url::Url;
51use x509_parser::{
52 extensions::{DistributionPointName, GeneralName, ParsedExtension},
53 prelude::{FromDer, X509Certificate},
54 revocation_list::CertificateRevocationList,
55};
56
57use crate::{
58 auth::MtlsConfig,
59 error::McpxError,
60 ssrf::{check_scheme, ip_block_reason},
61};
62
/// Upper bound on how long `bootstrap_fetch` keeps collecting initial CRL downloads.
const BOOTSTRAP_TIMEOUT: Duration = Duration::from_secs(10);
/// Lower bound applied by `clamp_refresh` to every refresh delay.
const MIN_AUTO_REFRESH: Duration = Duration::from_mins(10);
/// Upper bound applied by `clamp_refresh`; also the starting value in `next_refresh_delay`.
const MAX_AUTO_REFRESH: Duration = Duration::from_hours(24);
/// Connect timeout configured on the CRL HTTP clients.
const CRL_CONNECT_TIMEOUT: Duration = Duration::from_secs(3);
69
/// One certificate revocation list held in the cache, plus the metadata used to
/// decide when it has to be refreshed.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CachedCrl {
    /// Raw DER bytes of the CRL; cloned into the verifier builder by `rebuild_verifier`.
    pub der: CertificateRevocationListDer<'static>,
    /// Last-update time read from the CRL by `parse_crl_metadata`.
    pub this_update: SystemTime,
    /// Next-update time read from the CRL, or `None` when the CRL does not carry one.
    pub next_update: Option<SystemTime>,
    /// Wall-clock time (`SystemTime::now()`) at which this entry was created.
    pub fetched_at: SystemTime,
    /// URL this CRL was fetched from (test constructors use synthetic values).
    pub source_url: String,
}
85
86pub(crate) struct VerifierHandle(pub Arc<dyn ClientCertVerifier>);
87
88impl std::fmt::Debug for VerifierHandle {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("VerifierHandle").finish_non_exhaustive()
91 }
92}
93
/// Shared CRL state: the cache of fetched CRLs, the verifier built from it, and
/// the machinery (HTTP client, semaphores, rate limiter) used to fetch more.
#[allow(
    missing_debug_implementations,
    reason = "contains ArcSwap and dyn verifier internals"
)]
#[non_exhaustive]
pub struct CrlSet {
    /// Verifier built from the current cache; replaced by `swap_verifier_from_cache`.
    inner_verifier: ArcSwap<VerifierHandle>,
    /// Cached CRLs keyed by the URL they were fetched from.
    pub cache: RwLock<HashMap<String, CachedCrl>>,
    /// Trust anchors handed to every verifier rebuild.
    pub roots: Arc<RootCertStore>,
    /// mTLS / CRL configuration this set was created with.
    pub config: MtlsConfig,
    /// Sender on which newly observed CDP URLs are queued for fetching.
    pub discover_tx: mpsc::UnboundedSender<String>,
    /// HTTP client used for CRL downloads.
    client: reqwest::Client,
    /// URLs that were in the initial cache or have already been queued on `discover_tx`.
    seen_urls: Mutex<HashSet<String>>,
    /// Mirror of the cache's key set behind a std mutex, so the synchronous
    /// `note_discovered_urls` can consult it without awaiting the async `cache` lock.
    cached_urls: Mutex<HashSet<String>>,
    /// Bounds the number of concurrently running fetches (`crl_max_concurrent_fetches`, at least 1).
    global_fetch_sem: Arc<Semaphore>,
    /// One single-permit semaphore per host key, created on demand by `acquire_host_semaphore`.
    host_semaphores: Arc<tokio::sync::Mutex<HashMap<String, Arc<Semaphore>>>>,
    /// Per-minute rate limiter consulted before a discovered URL is queued.
    discovery_limiter: Arc<DefaultDirectRateLimiter>,
    /// Maximum accepted CRL response body size in bytes.
    max_response_bytes: u64,
    /// Last time a "cap exceeded" warning was logged, keyed by the name of the capped map.
    last_cap_warn: Mutex<HashMap<&'static str, Instant>>,
}
136
137impl CrlSet {
138 fn new(
139 roots: Arc<RootCertStore>,
140 config: MtlsConfig,
141 discover_tx: mpsc::UnboundedSender<String>,
142 initial_cache: HashMap<String, CachedCrl>,
143 ) -> Result<Arc<Self>, McpxError> {
144 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
152 let resolver: Arc<dyn reqwest::dns::Resolve> =
153 Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
154 Arc::clone(&allowlist),
155 #[cfg(any(test, feature = "test-helpers"))]
156 Arc::new(std::sync::atomic::AtomicBool::new(false)),
157 #[cfg(not(any(test, feature = "test-helpers")))]
158 (),
159 ));
160
161 let client = reqwest::Client::builder()
162 .no_proxy()
164 .dns_resolver(Arc::clone(&resolver))
165 .timeout(config.crl_fetch_timeout)
166 .connect_timeout(CRL_CONNECT_TIMEOUT)
167 .tcp_keepalive(None)
168 .redirect(reqwest::redirect::Policy::none())
169 .user_agent(format!("rmcp-server-kit/{}", env!("CARGO_PKG_VERSION")))
170 .build()
171 .map_err(|error| McpxError::Startup(format!("CRL HTTP client init: {error}")))?;
172
173 let initial_verifier = rebuild_verifier(&roots, &config, &initial_cache)?;
174 let seen_urls = initial_cache.keys().cloned().collect::<HashSet<_>>();
175 let cached_urls = seen_urls.clone();
176
177 let concurrency = config.crl_max_concurrent_fetches.max(1);
178 let global_fetch_sem = Arc::new(Semaphore::new(concurrency));
179 let host_semaphores = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
180
181 let rate =
182 NonZeroU32::new(config.crl_discovery_rate_per_min.max(1)).unwrap_or(NonZeroU32::MIN);
183 let discovery_limiter = Arc::new(RateLimiter::direct(Quota::per_minute(rate)));
184
185 let max_response_bytes = config.crl_max_response_bytes;
186
187 Ok(Arc::new(Self {
188 inner_verifier: ArcSwap::from_pointee(VerifierHandle(initial_verifier)),
189 cache: RwLock::new(initial_cache),
190 roots,
191 config,
192 discover_tx,
193 client,
194 seen_urls: Mutex::new(seen_urls),
195 cached_urls: Mutex::new(cached_urls),
196 global_fetch_sem,
197 host_semaphores,
198 discovery_limiter,
199 max_response_bytes,
200 last_cap_warn: Mutex::new(HashMap::new()),
201 }))
202 }
203
204 fn warn_cap_exceeded_throttled(&self, which: &'static str) {
205 let now = Instant::now();
206 let cooldown = Duration::from_mins(1);
207 let should_warn = match self.last_cap_warn.lock() {
208 Ok(mut guard) => {
209 let should_emit = guard
210 .get(which)
211 .is_none_or(|last| now.saturating_duration_since(*last) >= cooldown);
212 if should_emit {
213 guard.insert(which, now);
214 }
215 should_emit
216 }
217 Err(poisoned) => {
218 let mut guard = poisoned.into_inner();
219 let should_emit = guard
220 .get(which)
221 .is_none_or(|last| now.saturating_duration_since(*last) >= cooldown);
222 if should_emit {
223 guard.insert(which, now);
224 }
225 should_emit
226 }
227 };
228
229 if should_warn {
230 tracing::warn!(which = which, "CRL map cap exceeded; dropping newest entry");
231 }
232 }
233
234 async fn insert_cache_entry(&self, url: String, cached: CachedCrl) -> bool {
235 let inserted = {
236 let mut guard = self.cache.write().await;
237 if guard.len() >= self.config.crl_max_cache_entries && !guard.contains_key(&url) {
238 false
239 } else {
240 guard.insert(url.clone(), cached);
241 true
242 }
243 };
244
245 if inserted {
246 match self.cached_urls.lock() {
247 Ok(mut cached_urls) => {
248 cached_urls.insert(url);
249 }
250 Err(poisoned) => {
251 poisoned.into_inner().insert(url);
252 }
253 }
254 } else {
255 self.warn_cap_exceeded_throttled("cache");
256 }
257
258 inserted
259 }
260
261 pub async fn force_refresh(&self) -> Result<(), McpxError> {
267 let urls = {
268 let cache = self.cache.read().await;
269 cache.keys().cloned().collect::<Vec<_>>()
270 };
271 self.refresh_urls(urls).await
272 }
273
274 async fn refresh_due_urls(&self) -> Result<(), McpxError> {
275 let now = SystemTime::now();
276 let urls = {
277 let cache = self.cache.read().await;
278 cache
279 .iter()
280 .filter(|(_, cached)| {
281 should_refresh_cached(cached, now, self.config.crl_refresh_interval)
282 })
283 .map(|(url, _)| url.clone())
284 .collect::<Vec<_>>()
285 };
286
287 if urls.is_empty() {
288 return Ok(());
289 }
290
291 self.refresh_urls(urls).await
292 }
293
294 async fn refresh_urls(&self, urls: Vec<String>) -> Result<(), McpxError> {
295 let results = self.fetch_url_results(urls).await;
296 let now = SystemTime::now();
297 let mut cache = self.cache.write().await;
298 let mut changed = false;
299
300 for (url, result) in results {
301 match result {
302 Ok(cached) => {
303 if cache.len() >= self.config.crl_max_cache_entries && !cache.contains_key(&url)
304 {
305 drop(cache);
306 self.warn_cap_exceeded_throttled("cache");
307 cache = self.cache.write().await;
308 continue;
309 }
310 cache.insert(url.clone(), cached);
311 changed = true;
312 match self.cached_urls.lock() {
313 Ok(mut cached_urls) => {
314 cached_urls.insert(url);
315 }
316 Err(poisoned) => {
317 poisoned.into_inner().insert(url);
318 }
319 }
320 }
321 Err(error) => {
322 let remove_entry = cache.get(&url).is_some_and(|existing| {
323 existing
324 .next_update
325 .and_then(|next| next.checked_add(self.config.crl_stale_grace))
326 .is_some_and(|deadline| now > deadline)
327 });
328 tracing::warn!(url = %url, error = %error, "CRL refresh failed");
329 if remove_entry {
330 cache.remove(&url);
331 changed = true;
332 match self.cached_urls.lock() {
333 Ok(mut cached_urls) => {
334 cached_urls.remove(&url);
335 }
336 Err(poisoned) => {
337 poisoned.into_inner().remove(&url);
338 }
339 }
340 match self.seen_urls.lock() {
341 Ok(mut seen_urls) => {
342 seen_urls.remove(&url);
343 }
344 Err(poisoned) => {
345 poisoned.into_inner().remove(&url);
346 }
347 }
348 }
349 }
350 }
351 }
352
353 if changed {
354 self.swap_verifier_from_cache(&cache)?;
355 }
356
357 Ok(())
358 }
359
360 async fn fetch_and_store_url(&self, url: String) -> Result<(), McpxError> {
361 let cached = gated_fetch(
362 &self.client,
363 &self.global_fetch_sem,
364 &self.host_semaphores,
365 &url,
366 self.config.crl_allow_http,
367 self.max_response_bytes,
368 self.config.crl_max_host_semaphores,
369 )
370 .await?;
371 if !self.insert_cache_entry(url, cached).await {
372 return Ok(());
373 }
374 let cache = self.cache.read().await;
375 self.swap_verifier_from_cache(&cache)?;
376 Ok(())
377 }
378
379 fn note_discovered_urls(&self, urls: &[String]) -> bool {
380 let mut missing_cached = false;
389
390 let candidates: Vec<String> = match self.seen_urls.lock() {
401 Ok(seen) => urls
402 .iter()
403 .filter(|url| !seen.contains(*url))
404 .cloned()
405 .collect(),
406 Err(_) => Vec::new(),
407 };
408
409 for url in candidates {
416 if self.discovery_limiter.check().is_err() {
417 tracing::warn!(
418 url = %url,
419 "discovery_rate_limited: dropped CDP URL beyond per-minute cap (will be retried on next handshake observing this URL)"
420 );
421 continue;
422 }
423 if self.discover_tx.send(url.clone()).is_err() {
424 tracing::debug!(
427 url = %url,
428 "discover channel closed; dropping CDP URL without marking seen"
429 );
430 continue;
431 }
432 let mut guard = self
434 .seen_urls
435 .lock()
436 .unwrap_or_else(std::sync::PoisonError::into_inner);
437 if guard.len() >= self.config.crl_max_seen_urls {
438 self.warn_cap_exceeded_throttled("seen_urls");
439 break;
440 }
441 guard.insert(url);
442 }
443
444 if self.config.crl_deny_on_unavailable {
445 let cached = self
446 .cached_urls
447 .lock()
448 .ok()
449 .map(|guard| guard.clone())
450 .unwrap_or_default();
451 missing_cached = urls.iter().any(|url| !cached.contains(url));
452 }
453
454 missing_cached
455 }
456
457 #[doc(hidden)]
463 pub fn __test_with_prepopulated_crls(
464 roots: Arc<RootCertStore>,
465 config: MtlsConfig,
466 prefilled_crls: Vec<CertificateRevocationListDer<'static>>,
467 ) -> Result<Arc<Self>, McpxError> {
468 let (discover_tx, discover_rx) = mpsc::unbounded_channel();
469 drop(discover_rx);
470
471 let mut initial_cache = HashMap::new();
472 for (index, der) in prefilled_crls.into_iter().enumerate() {
473 let source_url = format!("memory://crl/{index}");
474 let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;
475 initial_cache.insert(
476 source_url.clone(),
477 CachedCrl {
478 der,
479 this_update,
480 next_update,
481 fetched_at: SystemTime::now(),
482 source_url,
483 },
484 );
485 }
486
487 Self::new(roots, config, discover_tx, initial_cache)
488 }
489
490 #[doc(hidden)]
502 pub fn __test_with_kept_receiver(
503 roots: Arc<RootCertStore>,
504 config: MtlsConfig,
505 prefilled_crls: Vec<CertificateRevocationListDer<'static>>,
506 ) -> Result<(Arc<Self>, mpsc::UnboundedReceiver<String>), McpxError> {
507 let (discover_tx, discover_rx) = mpsc::unbounded_channel();
508
509 let mut initial_cache = HashMap::new();
510 for (index, der) in prefilled_crls.into_iter().enumerate() {
511 let source_url = format!("memory://crl/{index}");
512 let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;
513 initial_cache.insert(
514 source_url.clone(),
515 CachedCrl {
516 der,
517 this_update,
518 next_update,
519 fetched_at: SystemTime::now(),
520 source_url,
521 },
522 );
523 }
524
525 let crl_set = Self::new(roots, config, discover_tx, initial_cache)?;
526 Ok((crl_set, discover_rx))
527 }
528
529 #[doc(hidden)]
534 pub fn __test_check_discovery_rate(&self, urls: &[String]) -> (usize, usize) {
535 let mut accepted = 0usize;
536 let mut dropped = 0usize;
537 for url in urls {
538 if self.discovery_limiter.check().is_ok() {
539 let _ = self.discover_tx.send(url.clone());
540 accepted += 1;
541 } else {
542 dropped += 1;
543 }
544 }
545 (accepted, dropped)
546 }
547
548 #[doc(hidden)]
552 pub fn __test_note_discovered_urls(&self, urls: &[String]) -> bool {
553 let missing_cached = self.note_discovered_urls(urls);
554 if self.discover_tx.is_closed() {
555 match self.seen_urls.lock() {
556 Ok(mut guard) => {
557 for url in urls {
558 if guard.contains(url) {
559 continue;
560 }
561 if guard.len() >= self.config.crl_max_seen_urls {
562 self.warn_cap_exceeded_throttled("seen_urls");
563 break;
564 }
565 guard.insert(url.clone());
566 }
567 }
568 Err(poisoned) => {
569 let mut guard = poisoned.into_inner();
570 for url in urls {
571 if guard.contains(url) {
572 continue;
573 }
574 if guard.len() >= self.config.crl_max_seen_urls {
575 self.warn_cap_exceeded_throttled("seen_urls");
576 break;
577 }
578 guard.insert(url.clone());
579 }
580 }
581 }
582 }
583 missing_cached
584 }
585
586 #[doc(hidden)]
591 pub fn __test_is_seen(&self, url: &str) -> bool {
592 match self.seen_urls.lock() {
593 Ok(seen) => seen.contains(url),
594 Err(_) => false,
595 }
596 }
597
598 #[cfg(any(test, feature = "test-helpers"))]
601 #[doc(hidden)]
602 pub fn __test_host_semaphore_count(&self) -> usize {
603 self.host_semaphores
604 .try_lock()
605 .map_or(0, |guard| guard.len())
606 }
607
608 #[cfg(any(test, feature = "test-helpers"))]
610 #[doc(hidden)]
611 pub fn __test_cache_len(&self) -> usize {
612 self.cache.try_read().map_or(0, |guard| guard.len())
613 }
614
615 #[cfg(any(test, feature = "test-helpers"))]
617 #[doc(hidden)]
618 pub fn __test_cache_contains(&self, url: &str) -> bool {
619 self.cache
620 .try_read()
621 .is_ok_and(|guard| guard.contains_key(url))
622 }
623
624 #[cfg(any(test, feature = "test-helpers"))]
631 #[doc(hidden)]
632 pub async fn __test_trigger_fetch(&self, url: &str) -> Result<(), McpxError> {
633 if let Err(error) = gated_fetch(
634 &self.client,
635 &self.global_fetch_sem,
636 &self.host_semaphores,
637 url,
638 self.config.crl_allow_http,
639 self.max_response_bytes,
640 self.config.crl_max_host_semaphores,
641 )
642 .await
643 {
644 if error
645 .to_string()
646 .contains("crl_host_semaphore_cap_exceeded")
647 {
648 Err(error)
649 } else {
650 Ok(())
651 }
652 } else {
653 Ok(())
654 }
655 }
656
657 #[cfg(any(test, feature = "test-helpers"))]
669 #[doc(hidden)]
670 pub async fn __test_insert_cache(&self, url: &str, cached: CachedCrl) {
671 let _ = self.insert_cache_entry(url.to_owned(), cached).await;
672 }
673
674 #[cfg(any(test, feature = "test-helpers"))]
679 #[doc(hidden)]
680 pub async fn __test_trigger_refresh_url(&self, url: &str) -> Result<(), McpxError> {
681 self.refresh_urls(vec![url.to_owned()]).await
682 }
683
684 async fn fetch_url_results(
685 &self,
686 urls: Vec<String>,
687 ) -> Vec<(String, Result<CachedCrl, McpxError>)> {
688 let mut tasks = JoinSet::new();
689 for url in urls {
690 let client = self.client.clone();
691 let global_sem = Arc::clone(&self.global_fetch_sem);
692 let host_map = Arc::clone(&self.host_semaphores);
693 let allow_http = self.config.crl_allow_http;
694 let max_bytes = self.max_response_bytes;
695 let max_host_semaphores = self.config.crl_max_host_semaphores;
696 tasks.spawn(async move {
697 let result = gated_fetch(
698 &client,
699 &global_sem,
700 &host_map,
701 &url,
702 allow_http,
703 max_bytes,
704 max_host_semaphores,
705 )
706 .await;
707 (url, result)
708 });
709 }
710
711 let mut results = Vec::new();
712 while let Some(joined) = tasks.join_next().await {
713 match joined {
714 Ok(result) => results.push(result),
715 Err(error) => {
716 tracing::warn!(error = %error, "CRL refresh task join failed");
717 }
718 }
719 }
720
721 results
722 }
723
724 fn swap_verifier_from_cache(
725 &self,
726 cache: &impl std::ops::Deref<Target = HashMap<String, CachedCrl>>,
727 ) -> Result<(), McpxError> {
728 let verifier = rebuild_verifier(&self.roots, &self.config, cache)?;
729 self.inner_verifier
730 .store(Arc::new(VerifierHandle(verifier)));
731 Ok(())
732 }
733}
734
735impl CachedCrl {
736 #[cfg(any(test, feature = "test-helpers"))]
740 #[doc(hidden)]
741 #[must_use]
742 pub fn __test_synthetic(now: SystemTime) -> Self {
743 Self {
744 der: CertificateRevocationListDer::from(vec![0x30, 0x00]),
745 this_update: now,
746 next_update: now.checked_add(Duration::from_hours(24)),
747 fetched_at: now,
748 source_url: "test://synthetic".to_owned(),
749 }
750 }
751
752 #[cfg(any(test, feature = "test-helpers"))]
756 #[doc(hidden)]
757 #[must_use]
758 pub fn __test_stale(reference_past: SystemTime) -> Self {
759 Self {
760 der: CertificateRevocationListDer::from(vec![0x30, 0x00]),
761 this_update: reference_past,
762 next_update: Some(reference_past),
763 fetched_at: reference_past,
764 source_url: "test://stale".to_owned(),
765 }
766 }
767}
768
/// rustls client-certificate verifier that forwards every call to whichever
/// verifier the shared `CrlSet` currently holds.
pub struct DynamicClientCertVerifier {
    /// Shared CRL state; the active verifier is loaded from it on each call.
    inner: Arc<CrlSet>,
    /// Root subject names captured from `inner.roots` when this verifier was constructed.
    dn_subjects: Vec<DistinguishedName>,
}
775
776impl DynamicClientCertVerifier {
777 #[must_use]
779 pub fn new(inner: Arc<CrlSet>) -> Self {
780 Self {
781 dn_subjects: inner.roots.subjects(),
782 inner,
783 }
784 }
785}
786
787impl std::fmt::Debug for DynamicClientCertVerifier {
788 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
789 f.debug_struct("DynamicClientCertVerifier")
790 .field("dn_subjects_len", &self.dn_subjects.len())
791 .finish_non_exhaustive()
792 }
793}
794
795impl ClientCertVerifier for DynamicClientCertVerifier {
796 fn offer_client_auth(&self) -> bool {
797 let verifier = self.inner.inner_verifier.load();
798 verifier.0.offer_client_auth()
799 }
800
801 fn client_auth_mandatory(&self) -> bool {
802 let verifier = self.inner.inner_verifier.load();
803 verifier.0.client_auth_mandatory()
804 }
805
806 fn root_hint_subjects(&self) -> &[DistinguishedName] {
807 &self.dn_subjects
808 }
809
810 fn verify_client_cert(
811 &self,
812 end_entity: &CertificateDer<'_>,
813 intermediates: &[CertificateDer<'_>],
814 now: UnixTime,
815 ) -> Result<ClientCertVerified, TlsError> {
816 let mut discovered =
828 extract_cdp_urls(end_entity.as_ref(), self.inner.config.crl_allow_http);
829 for intermediate in intermediates {
830 discovered.extend(extract_cdp_urls(
831 intermediate.as_ref(),
832 self.inner.config.crl_allow_http,
833 ));
834 }
835 discovered.sort();
836 discovered.dedup();
837
838 if self.inner.note_discovered_urls(&discovered) {
839 return Err(TlsError::General(
840 "client certificate revocation status unavailable".to_owned(),
841 ));
842 }
843
844 let verifier = self.inner.inner_verifier.load();
845 verifier
846 .0
847 .verify_client_cert(end_entity, intermediates, now)
848 }
849
850 fn verify_tls12_signature(
851 &self,
852 message: &[u8],
853 cert: &CertificateDer<'_>,
854 dss: &DigitallySignedStruct,
855 ) -> Result<HandshakeSignatureValid, TlsError> {
856 let verifier = self.inner.inner_verifier.load();
857 verifier.0.verify_tls12_signature(message, cert, dss)
858 }
859
860 fn verify_tls13_signature(
861 &self,
862 message: &[u8],
863 cert: &CertificateDer<'_>,
864 dss: &DigitallySignedStruct,
865 ) -> Result<HandshakeSignatureValid, TlsError> {
866 let verifier = self.inner.inner_verifier.load();
867 verifier.0.verify_tls13_signature(message, cert, dss)
868 }
869
870 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
871 let verifier = self.inner.inner_verifier.load();
872 verifier.0.supported_verify_schemes()
873 }
874
875 fn requires_raw_public_keys(&self) -> bool {
876 let verifier = self.inner.inner_verifier.load();
877 verifier.0.requires_raw_public_keys()
878 }
879}
880
881#[must_use]
889pub fn extract_cdp_urls(cert_der: &[u8], allow_http: bool) -> Vec<String> {
890 let Ok((_, cert)) = X509Certificate::from_der(cert_der) else {
891 return Vec::new();
892 };
893
894 let mut urls = Vec::new();
895 for ext in cert.extensions() {
896 if let ParsedExtension::CRLDistributionPoints(cdps) = ext.parsed_extension() {
897 for point in cdps.iter() {
898 if let Some(DistributionPointName::FullName(names)) = &point.distribution_point {
899 for name in names {
900 if let GeneralName::URI(uri) = name {
901 let raw = *uri;
902 let Ok(parsed) = Url::parse(raw) else {
903 tracing::debug!(url = %raw, "CDP URL parse failed; dropped");
904 continue;
905 };
906 if let Err(reason) = check_scheme(&parsed, allow_http) {
907 tracing::debug!(
908 url = %raw,
909 reason,
910 "CDP URL rejected by scheme guard; dropped"
911 );
912 continue;
913 }
914 urls.push(parsed.into());
915 }
916 }
917 }
918 }
919 }
920 }
921
922 urls
923}
924
/// Performs the start-up CRL download and builds the shared `CrlSet`.
///
/// CDP URLs are extracted from `ca_certs`, sorted, de-duplicated and fetched
/// concurrently through `gated_fetch`. Collection stops when every task has
/// finished or `BOOTSTRAP_TIMEOUT` elapses, whichever comes first. Individual
/// fetch or join failures are logged and skipped, so the set may start with a
/// partial (or empty) cache.
///
/// Returns the set together with the receiving end of its discovery channel.
///
/// # Errors
/// Returns `McpxError::Startup` when the HTTP client cannot be built, or the
/// error from `CrlSet::new`.
#[allow(
    clippy::cognitive_complexity,
    reason = "bootstrap coordinates timeout, parallel fetches, and partial-cache recovery"
)]
pub async fn bootstrap_fetch(
    roots: Arc<RootCertStore>,
    ca_certs: &[CertificateDer<'static>],
    config: MtlsConfig,
) -> Result<(Arc<CrlSet>, mpsc::UnboundedReceiver<String>), McpxError> {
    let (discover_tx, discover_rx) = mpsc::unbounded_channel();

    // Gather every CDP URL advertised by the CA certificates, without duplicates.
    let mut urls = ca_certs
        .iter()
        .flat_map(|cert| extract_cdp_urls(cert.as_ref(), config.crl_allow_http))
        .collect::<Vec<_>>();
    urls.sort();
    urls.dedup();

    // Bootstrap uses its own resolver and client, configured the same way as in `CrlSet::new`.
    let bootstrap_allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
    let bootstrap_resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(&bootstrap_allowlist),
            #[cfg(any(test, feature = "test-helpers"))]
            Arc::new(std::sync::atomic::AtomicBool::new(false)),
            #[cfg(not(any(test, feature = "test-helpers")))]
            (),
        ));

    let client = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&bootstrap_resolver))
        .timeout(config.crl_fetch_timeout)
        .connect_timeout(CRL_CONNECT_TIMEOUT)
        .tcp_keepalive(None)
        .redirect(reqwest::redirect::Policy::none())
        .user_agent(format!("rmcp-server-kit/{}", env!("CARGO_PKG_VERSION")))
        .build()
        .map_err(|error| McpxError::Startup(format!("CRL HTTP client init: {error}")))?;

    // Concurrency limits local to bootstrap; the returned `CrlSet` creates its own.
    let bootstrap_concurrency = config.crl_max_concurrent_fetches.max(1);
    let global_sem = Arc::new(Semaphore::new(bootstrap_concurrency));
    let host_semaphores = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
    let allow_http = config.crl_allow_http;
    let max_bytes = config.crl_max_response_bytes;
    let max_host_semaphores = config.crl_max_host_semaphores;

    let mut initial_cache = HashMap::new();
    let mut tasks = JoinSet::new();
    for url in &urls {
        let client = client.clone();
        let url = url.clone();
        let global_sem = Arc::clone(&global_sem);
        let host_semaphores = Arc::clone(&host_semaphores);
        tasks.spawn(async move {
            let result = gated_fetch(
                &client,
                &global_sem,
                &host_semaphores,
                &url,
                allow_http,
                max_bytes,
                max_host_semaphores,
            )
            .await;
            (url, result)
        });
    }

    // One deadline for the whole bootstrap, not per fetch.
    let timeout: Sleep = tokio::time::sleep(BOOTSTRAP_TIMEOUT);
    tokio::pin!(timeout);

    while !tasks.is_empty() {
        tokio::select! {
            () = &mut timeout => {
                // NOTE(review): fetches still running here are presumably aborted when
                // `tasks` is dropped at the end of this function — confirm against the
                // tokio `JoinSet` documentation.
                tracing::warn!("CRL bootstrap timed out after {:?}", BOOTSTRAP_TIMEOUT);
                break;
            }
            maybe_joined = tasks.join_next() => {
                let Some(joined) = maybe_joined else {
                    break;
                };
                match joined {
                    Ok((url, Ok(cached))) => {
                        initial_cache.insert(url, cached);
                    }
                    Ok((url, Err(error))) => {
                        tracing::warn!(url = %url, error = %error, "CRL bootstrap fetch failed");
                    }
                    Err(error) => {
                        tracing::warn!(error = %error, "CRL bootstrap task join failed");
                    }
                }
            }
        }
    }

    // Whatever was fetched in time becomes the initial cache.
    let set = CrlSet::new(roots, config, discover_tx, initial_cache)?;
    Ok((set, discover_rx))
}
1037
1038#[allow(
1040 clippy::cognitive_complexity,
1041 reason = "refresher loop intentionally handles shutdown, timer, and discovery in one select"
1042)]
1043pub async fn run_crl_refresher(
1044 set: Arc<CrlSet>,
1045 mut discover_rx: mpsc::UnboundedReceiver<String>,
1046 shutdown: CancellationToken,
1047) {
1048 let mut refresh_sleep = schedule_next_refresh(&set).await;
1049
1050 loop {
1051 tokio::select! {
1052 () = shutdown.cancelled() => {
1053 break;
1054 }
1055 () = &mut refresh_sleep => {
1056 if let Err(error) = set.refresh_due_urls().await {
1057 tracing::warn!(error = %error, "CRL periodic refresh failed");
1058 }
1059 refresh_sleep = schedule_next_refresh(&set).await;
1060 }
1061 maybe_url = discover_rx.recv() => {
1062 let Some(url) = maybe_url else {
1063 break;
1064 };
1065 if let Err(error) = set.fetch_and_store_url(url.clone()).await {
1066 tracing::warn!(url = %url, error = %error, "CRL discovery fetch failed");
1067 }
1068 refresh_sleep = schedule_next_refresh(&set).await;
1069 }
1070 }
1071 }
1072}
1073
1074pub fn rebuild_verifier<S: std::hash::BuildHasher>(
1080 roots: &Arc<RootCertStore>,
1081 config: &MtlsConfig,
1082 cache: &HashMap<String, CachedCrl, S>,
1083) -> Result<Arc<dyn ClientCertVerifier>, McpxError> {
1084 let mut builder = WebPkiClientVerifier::builder(Arc::clone(roots));
1085
1086 if !cache.is_empty() {
1087 let crls = cache
1088 .values()
1089 .map(|cached| cached.der.clone())
1090 .collect::<Vec<_>>();
1091 builder = builder.with_crls(crls);
1092 }
1093 if config.crl_end_entity_only {
1094 builder = builder.only_check_end_entity_revocation();
1095 }
1096 if !config.crl_deny_on_unavailable {
1097 builder = builder.allow_unknown_revocation_status();
1098 }
1099 if config.crl_enforce_expiration {
1100 builder = builder.enforce_revocation_expiration();
1101 }
1102 if !config.required {
1103 builder = builder.allow_unauthenticated();
1104 }
1105
1106 builder
1107 .build()
1108 .map_err(|error| McpxError::Tls(format!("mTLS verifier error: {error}")))
1109}
1110
1111pub fn parse_crl_metadata(der: &[u8]) -> Result<(SystemTime, Option<SystemTime>), McpxError> {
1117 let (_, crl) = CertificateRevocationList::from_der(der)
1118 .map_err(|error| McpxError::Tls(format!("invalid CRL DER: {error:?}")))?;
1119
1120 Ok((
1121 asn1_time_to_system_time(crl.last_update()),
1122 crl.next_update().map(asn1_time_to_system_time),
1123 ))
1124}
1125
1126async fn schedule_next_refresh(set: &CrlSet) -> Pin<Box<Sleep>> {
1127 let duration = next_refresh_delay(set).await;
1128 boxed_sleep(duration)
1129}
1130
1131fn boxed_sleep(duration: Duration) -> Pin<Box<Sleep>> {
1132 Box::pin(tokio::time::sleep_until(Instant::now() + duration))
1133}
1134
1135async fn next_refresh_delay(set: &CrlSet) -> Duration {
1136 if let Some(interval) = set.config.crl_refresh_interval {
1137 return clamp_refresh(interval);
1138 }
1139
1140 let now = SystemTime::now();
1141 let cache = set.cache.read().await;
1142 let mut next = MAX_AUTO_REFRESH;
1143
1144 for cached in cache.values() {
1145 if let Some(next_update) = cached.next_update {
1146 let duration = next_update.duration_since(now).unwrap_or(Duration::ZERO);
1147 next = next.min(clamp_refresh(duration));
1148 }
1149 }
1150
1151 next
1152}
1153
1154fn acquire_host_semaphore(
1164 map: &mut HashMap<String, Arc<Semaphore>>,
1165 host_key: &str,
1166 max_host_semaphores: usize,
1167) -> Result<Arc<Semaphore>, McpxError> {
1168 if !map.contains_key(host_key) {
1169 if map.len() >= max_host_semaphores {
1170 map.retain(|_, semaphore| Arc::strong_count(semaphore) > 1);
1172 }
1173 if map.len() >= max_host_semaphores {
1174 return Err(McpxError::Config(
1175 "crl_host_semaphore_cap_exceeded: too many distinct CRL hosts in flight".to_owned(),
1176 ));
1177 }
1178 map.insert(host_key.to_owned(), Arc::new(Semaphore::new(1)));
1179 }
1180 match map.get(host_key) {
1181 Some(semaphore) => Ok(Arc::clone(semaphore)),
1182 None => Err(McpxError::Tls(
1183 "CRL host semaphore missing after insertion".to_owned(),
1184 )),
1185 }
1186}
1187
1188async fn gated_fetch(
1196 client: &reqwest::Client,
1197 global_sem: &Arc<Semaphore>,
1198 host_semaphores: &Arc<tokio::sync::Mutex<HashMap<String, Arc<Semaphore>>>>,
1199 url: &str,
1200 allow_http: bool,
1201 max_bytes: u64,
1202 max_host_semaphores: usize,
1203) -> Result<CachedCrl, McpxError> {
1204 let host_key = Url::parse(url)
1205 .ok()
1206 .and_then(|u| u.host_str().map(str::to_owned))
1207 .unwrap_or_else(|| url.to_owned());
1208
1209 let host_sem = {
1210 let mut map = host_semaphores.lock().await;
1211 acquire_host_semaphore(&mut map, &host_key, max_host_semaphores)?
1212 };
1213
1214 let _global_permit = Arc::clone(global_sem)
1215 .acquire_owned()
1216 .await
1217 .map_err(|error| McpxError::Tls(format!("CRL global semaphore closed: {error}")))?;
1218 let _host_permit = host_sem
1219 .acquire_owned()
1220 .await
1221 .map_err(|error| McpxError::Tls(format!("CRL host semaphore closed: {error}")))?;
1222
1223 fetch_crl(client, url, allow_http, max_bytes).await
1224}
1225
1226async fn fetch_crl(
1227 client: &reqwest::Client,
1228 url: &str,
1229 allow_http: bool,
1230 max_bytes: u64,
1231) -> Result<CachedCrl, McpxError> {
1232 let parsed =
1233 Url::parse(url).map_err(|error| McpxError::Tls(format!("CRL URL parse {url}: {error}")))?;
1234
1235 if let Err(reason) = check_scheme(&parsed, allow_http) {
1236 tracing::warn!(url = %url, reason, "CRL fetch denied: scheme");
1237 return Err(McpxError::Tls(format!(
1238 "CRL scheme rejected ({reason}): {url}"
1239 )));
1240 }
1241
1242 let host = parsed
1243 .host_str()
1244 .ok_or_else(|| McpxError::Tls(format!("CRL URL has no host: {url}")))?;
1245 let port = parsed
1246 .port_or_known_default()
1247 .ok_or_else(|| McpxError::Tls(format!("CRL URL has no known port: {url}")))?;
1248
1249 let addrs = lookup_host((host, port))
1250 .await
1251 .map_err(|error| McpxError::Tls(format!("CRL DNS resolution {url}: {error}")))?;
1252
1253 let mut any_addr = false;
1254 for addr in addrs {
1255 any_addr = true;
1256 if let Some(reason) = ip_block_reason(addr.ip()) {
1257 tracing::warn!(
1258 url = %url,
1259 resolved_ip = %addr.ip(),
1260 reason,
1261 "CRL fetch denied: blocked IP"
1262 );
1263 return Err(McpxError::Tls(format!(
1264 "CRL host resolved to blocked IP ({reason}): {url}"
1265 )));
1266 }
1267 }
1268 if !any_addr {
1269 return Err(McpxError::Tls(format!(
1270 "CRL DNS resolution returned no addresses: {url}"
1271 )));
1272 }
1273
1274 let mut response = client
1275 .get(url)
1276 .send()
1277 .await
1278 .map_err(|error| McpxError::Tls(format!("CRL fetch {url}: {error}")))?
1279 .error_for_status()
1280 .map_err(|error| McpxError::Tls(format!("CRL fetch {url}: {error}")))?;
1281
1282 let initial_capacity = usize::try_from(max_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
1285 let mut body: Vec<u8> = Vec::with_capacity(initial_capacity);
1286 while let Some(chunk) = response
1287 .chunk()
1288 .await
1289 .map_err(|error| McpxError::Tls(format!("CRL read {url}: {error}")))?
1290 {
1291 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
1292 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
1293 if body_len.saturating_add(chunk_len) > max_bytes {
1294 return Err(McpxError::Tls(format!(
1295 "CRL body exceeded cap of {max_bytes} bytes: {url}"
1296 )));
1297 }
1298 body.extend_from_slice(&chunk);
1299 }
1300
1301 let der = CertificateRevocationListDer::from(body);
1302 let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;
1303
1304 Ok(CachedCrl {
1305 der,
1306 this_update,
1307 next_update,
1308 fetched_at: SystemTime::now(),
1309 source_url: url.to_owned(),
1310 })
1311}
1312
1313fn should_refresh_cached(
1314 cached: &CachedCrl,
1315 now: SystemTime,
1316 fixed_interval: Option<Duration>,
1317) -> bool {
1318 if let Some(interval) = fixed_interval {
1319 return cached
1320 .fetched_at
1321 .checked_add(clamp_refresh(interval))
1322 .is_none_or(|deadline| now >= deadline);
1323 }
1324
1325 cached
1326 .next_update
1327 .is_none_or(|next_update| now >= next_update)
1328}
1329
1330fn clamp_refresh(duration: Duration) -> Duration {
1331 duration.clamp(MIN_AUTO_REFRESH, MAX_AUTO_REFRESH)
1332}
1333
1334const MAX_ASN1_TIMESTAMP_SECS: u64 = 253_402_300_799;
1338
1339fn asn1_time_to_system_time(time: x509_parser::time::ASN1Time) -> SystemTime {
1348 let timestamp = time.timestamp();
1349 if timestamp >= 0 {
1350 let seconds = u64::try_from(timestamp)
1351 .unwrap_or(0)
1352 .min(MAX_ASN1_TIMESTAMP_SECS);
1353 UNIX_EPOCH
1354 .checked_add(Duration::from_secs(seconds))
1355 .unwrap_or(UNIX_EPOCH)
1356 } else {
1357 UNIX_EPOCH
1358 .checked_sub(Duration::from_secs(timestamp.unsigned_abs()))
1359 .unwrap_or(UNIX_EPOCH)
1360 }
1361}
1362
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an `ASN1Time` from a Unix timestamp, panicking if it is out of range.
    fn asn1(timestamp: i64) -> x509_parser::time::ASN1Time {
        x509_parser::time::ASN1Time::from_timestamp(timestamp).expect("valid ASN.1 timestamp")
    }

    /// `asn1_time_to_system_time` must never panic and must stay on the correct
    /// side of the epoch for pre-epoch, ordinary and maximum timestamps.
    #[test]
    fn asn1_time_clamps_unrepresentable_timestamps() {
        // Pre-epoch timestamp labelled "year 1500" by the variable name.
        let year_1500 = asn1_time_to_system_time(asn1(-14_831_769_600));
        assert!(year_1500 <= UNIX_EPOCH);
        // On Windows this value is expected to fall back to the epoch.
        #[cfg(windows)]
        assert_eq!(year_1500, UNIX_EPOCH);

        // Pre-epoch timestamp labelled "year 1601" by the variable name.
        let year_1601 = asn1_time_to_system_time(asn1(-11_644_473_600));
        assert!(year_1601 <= UNIX_EPOCH);

        // Just before the epoch.
        assert!(asn1_time_to_system_time(asn1(-2)) <= UNIX_EPOCH);

        // An ordinary positive timestamp converts exactly.
        assert_eq!(
            asn1_time_to_system_time(asn1(1_700_000_000)),
            UNIX_EPOCH + Duration::from_secs(1_700_000_000)
        );

        // The maximum supported timestamp converts exactly.
        let max = i64::try_from(MAX_ASN1_TIMESTAMP_SECS).expect("fits in i64");
        assert_eq!(
            asn1_time_to_system_time(asn1(max)),
            UNIX_EPOCH + Duration::from_secs(MAX_ASN1_TIMESTAMP_SECS)
        );
    }

    /// At capacity, entries whose semaphore nobody else holds are evicted to
    /// make room for a new host.
    #[test]
    fn host_semaphore_evicts_idle_at_cap() {
        let mut map = HashMap::new();
        for i in 0..4 {
            // Dropping the returned Arc immediately leaves the map as sole owner (idle).
            drop(
                acquire_host_semaphore(&mut map, &format!("idle-{i}.example"), 4)
                    .expect("under cap"),
            );
        }
        assert_eq!(map.len(), 4);

        let sem = acquire_host_semaphore(&mut map, "new-host.example", 4)
            .expect("idle eviction frees space for a new host");
        assert!(map.contains_key("new-host.example"));
        drop(sem);
    }

    /// Eviction at capacity must keep entries whose semaphore is still held by a caller.
    #[test]
    fn host_semaphore_keeps_inflight_at_cap() {
        let mut map = HashMap::new();
        // Held for the whole test, so this entry counts as in flight.
        let inflight = acquire_host_semaphore(&mut map, "busy.example", 3).expect("under cap");
        for i in 0..2 {
            drop(
                acquire_host_semaphore(&mut map, &format!("idle-{i}.example"), 3)
                    .expect("under cap"),
            );
        }
        assert_eq!(map.len(), 3);

        drop(
            acquire_host_semaphore(&mut map, "new-host.example", 3)
                .expect("idle entries evicted while in-flight survives"),
        );
        assert!(
            map.contains_key("busy.example"),
            "in-flight host must survive eviction"
        );
        assert!(map.contains_key("new-host.example"));
        drop(inflight);
    }

    /// When every entry is in flight, nothing can be evicted and a new host is rejected.
    #[test]
    fn host_semaphore_cap_error_when_all_inflight() {
        let mut map = HashMap::new();
        // Keeping the Arcs alive in `held` makes both entries in flight.
        let held: Vec<_> = (0..2)
            .map(|i| {
                acquire_host_semaphore(&mut map, &format!("busy-{i}.example"), 2)
                    .expect("under cap")
            })
            .collect();

        let result = acquire_host_semaphore(&mut map, "new-host.example", 2);
        assert!(
            result.is_err(),
            "cap must still reject when every entry has an in-flight fetch"
        );
        drop(held);
    }
}