1use std::{
25 collections::{HashMap, HashSet},
26 num::NonZeroU32,
27 pin::Pin,
28 sync::{Arc, Mutex},
29 time::{Duration, SystemTime, UNIX_EPOCH},
30};
31
32use arc_swap::ArcSwap;
33use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
34use rustls::{
35 DigitallySignedStruct, DistinguishedName, Error as TlsError, RootCertStore, SignatureScheme,
36 client::danger::HandshakeSignatureValid,
37 pki_types::{CertificateDer, CertificateRevocationListDer, UnixTime},
38 server::{
39 WebPkiClientVerifier,
40 danger::{ClientCertVerified, ClientCertVerifier},
41 },
42};
43use tokio::{
44 net::lookup_host,
45 sync::{RwLock, Semaphore, mpsc},
46 task::JoinSet,
47 time::{Instant, Sleep},
48};
49use tokio_util::sync::CancellationToken;
50use url::Url;
51use x509_parser::{
52 extensions::{DistributionPointName, GeneralName, ParsedExtension},
53 prelude::{FromDer, X509Certificate},
54 revocation_list::CertificateRevocationList,
55};
56
57use crate::{
58 auth::MtlsConfig,
59 error::McpxError,
60 ssrf::{check_scheme, ip_block_reason},
61};
62
/// Overall deadline `bootstrap_fetch` allows for the initial CRL downloads.
const BOOTSTRAP_TIMEOUT: Duration = Duration::from_secs(10);
/// Lower bound applied by `clamp_refresh` to every refresh delay/interval.
const MIN_AUTO_REFRESH: Duration = Duration::from_mins(10);
/// Upper bound applied by `clamp_refresh`; also the delay `next_refresh_delay`
/// starts from when no cached CRL carries a `next_update`.
const MAX_AUTO_REFRESH: Duration = Duration::from_hours(24);
/// Connect timeout configured on the CRL HTTP clients built in this file.
const CRL_CONNECT_TIMEOUT: Duration = Duration::from_secs(3);
69
/// One downloaded (or pre-seeded) CRL together with the metadata used to
/// decide when it must be refreshed.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CachedCrl {
    /// The CRL bytes in DER form, as handed to the rustls verifier builder.
    pub der: CertificateRevocationListDer<'static>,
    /// The CRL's `last_update` time as returned by `parse_crl_metadata`.
    pub this_update: SystemTime,
    /// The CRL's `next_update` time, when the CRL carries one.
    pub next_update: Option<SystemTime>,
    /// Wall-clock time at which this entry was created.
    pub fetched_at: SystemTime,
    /// URL (or synthetic identifier) this CRL was obtained from.
    pub source_url: String,
}
85
/// Newtype around the currently active client-certificate verifier so it can
/// be stored in an `ArcSwap` and given a manual `Debug` impl.
pub(crate) struct VerifierHandle(pub Arc<dyn ClientCertVerifier>);

impl std::fmt::Debug for VerifierHandle {
    /// Prints only the type name; the wrapped verifier is not formatted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VerifierHandle").finish_non_exhaustive()
    }
}
93
/// Shared state for CRL handling: the CRL cache, the verifier rebuilt from it,
/// and the bookkeeping that bounds discovery and fetching.
#[allow(
    missing_debug_implementations,
    reason = "contains ArcSwap and dyn verifier internals"
)]
#[non_exhaustive]
pub struct CrlSet {
    /// Verifier currently used for handshakes; replaced by
    /// `swap_verifier_from_cache` whenever the cache changes.
    inner_verifier: ArcSwap<VerifierHandle>,
    /// Cached CRLs keyed by source URL.
    pub cache: RwLock<HashMap<String, CachedCrl>>,
    /// Trust anchors passed to the verifier builder.
    pub roots: Arc<RootCertStore>,
    /// mTLS/CRL configuration this set was built with.
    pub config: MtlsConfig,
    /// Sender side of the discovery channel; `note_discovered_urls` pushes
    /// newly observed CDP URLs here.
    pub discover_tx: mpsc::UnboundedSender<String>,
    /// HTTP client used for CRL downloads.
    client: reqwest::Client,
    /// URLs that have already been handed to the discovery channel.
    seen_urls: Mutex<HashSet<String>>,
    /// URLs currently present in `cache`, mirrored behind a synchronous mutex
    /// so the (non-async) handshake path can consult it.
    cached_urls: Mutex<HashSet<String>>,
    /// Bounds the number of concurrent fetches across all hosts.
    global_fetch_sem: Arc<Semaphore>,
    /// One single-permit semaphore per host key, created on demand by
    /// `gated_fetch`.
    host_semaphores: Arc<tokio::sync::Mutex<HashMap<String, Arc<Semaphore>>>>,
    /// Per-minute quota applied to newly discovered CDP URLs.
    discovery_limiter: Arc<DefaultDirectRateLimiter>,
    /// Maximum number of body bytes accepted for a single CRL download.
    max_response_bytes: u64,
    /// Time of the last "cap exceeded" warning per cap name, used to throttle
    /// that warning.
    last_cap_warn: Mutex<HashMap<&'static str, Instant>>,
}
133
134impl CrlSet {
135 fn new(
136 roots: Arc<RootCertStore>,
137 config: MtlsConfig,
138 discover_tx: mpsc::UnboundedSender<String>,
139 initial_cache: HashMap<String, CachedCrl>,
140 ) -> Result<Arc<Self>, McpxError> {
141 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
149 let resolver: Arc<dyn reqwest::dns::Resolve> =
150 Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
151 Arc::clone(&allowlist),
152 #[cfg(any(test, feature = "test-helpers"))]
153 Arc::new(std::sync::atomic::AtomicBool::new(false)),
154 #[cfg(not(any(test, feature = "test-helpers")))]
155 (),
156 ));
157
158 let client = reqwest::Client::builder()
159 .no_proxy()
161 .dns_resolver(Arc::clone(&resolver))
162 .timeout(config.crl_fetch_timeout)
163 .connect_timeout(CRL_CONNECT_TIMEOUT)
164 .tcp_keepalive(None)
165 .redirect(reqwest::redirect::Policy::none())
166 .user_agent(format!("rmcp-server-kit/{}", env!("CARGO_PKG_VERSION")))
167 .build()
168 .map_err(|error| McpxError::Startup(format!("CRL HTTP client init: {error}")))?;
169
170 let initial_verifier = rebuild_verifier(&roots, &config, &initial_cache)?;
171 let seen_urls = initial_cache.keys().cloned().collect::<HashSet<_>>();
172 let cached_urls = seen_urls.clone();
173
174 let concurrency = config.crl_max_concurrent_fetches.max(1);
175 let global_fetch_sem = Arc::new(Semaphore::new(concurrency));
176 let host_semaphores = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
177
178 let rate =
179 NonZeroU32::new(config.crl_discovery_rate_per_min.max(1)).unwrap_or(NonZeroU32::MIN);
180 let discovery_limiter = Arc::new(RateLimiter::direct(Quota::per_minute(rate)));
181
182 let max_response_bytes = config.crl_max_response_bytes;
183
184 Ok(Arc::new(Self {
185 inner_verifier: ArcSwap::from_pointee(VerifierHandle(initial_verifier)),
186 cache: RwLock::new(initial_cache),
187 roots,
188 config,
189 discover_tx,
190 client,
191 seen_urls: Mutex::new(seen_urls),
192 cached_urls: Mutex::new(cached_urls),
193 global_fetch_sem,
194 host_semaphores,
195 discovery_limiter,
196 max_response_bytes,
197 last_cap_warn: Mutex::new(HashMap::new()),
198 }))
199 }
200
201 fn warn_cap_exceeded_throttled(&self, which: &'static str) {
202 let now = Instant::now();
203 let cooldown = Duration::from_mins(1);
204 let should_warn = match self.last_cap_warn.lock() {
205 Ok(mut guard) => {
206 let should_emit = guard
207 .get(which)
208 .is_none_or(|last| now.saturating_duration_since(*last) >= cooldown);
209 if should_emit {
210 guard.insert(which, now);
211 }
212 should_emit
213 }
214 Err(poisoned) => {
215 let mut guard = poisoned.into_inner();
216 let should_emit = guard
217 .get(which)
218 .is_none_or(|last| now.saturating_duration_since(*last) >= cooldown);
219 if should_emit {
220 guard.insert(which, now);
221 }
222 should_emit
223 }
224 };
225
226 if should_warn {
227 tracing::warn!(which = which, "CRL map cap exceeded; dropping newest entry");
228 }
229 }
230
231 async fn insert_cache_entry(&self, url: String, cached: CachedCrl) -> bool {
232 let inserted = {
233 let mut guard = self.cache.write().await;
234 if guard.len() >= self.config.crl_max_cache_entries && !guard.contains_key(&url) {
235 false
236 } else {
237 guard.insert(url.clone(), cached);
238 true
239 }
240 };
241
242 if inserted {
243 match self.cached_urls.lock() {
244 Ok(mut cached_urls) => {
245 cached_urls.insert(url);
246 }
247 Err(poisoned) => {
248 poisoned.into_inner().insert(url);
249 }
250 }
251 } else {
252 self.warn_cap_exceeded_throttled("cache");
253 }
254
255 inserted
256 }
257
258 pub async fn force_refresh(&self) -> Result<(), McpxError> {
264 let urls = {
265 let cache = self.cache.read().await;
266 cache.keys().cloned().collect::<Vec<_>>()
267 };
268 self.refresh_urls(urls).await
269 }
270
271 async fn refresh_due_urls(&self) -> Result<(), McpxError> {
272 let now = SystemTime::now();
273 let urls = {
274 let cache = self.cache.read().await;
275 cache
276 .iter()
277 .filter(|(_, cached)| {
278 should_refresh_cached(cached, now, self.config.crl_refresh_interval)
279 })
280 .map(|(url, _)| url.clone())
281 .collect::<Vec<_>>()
282 };
283
284 if urls.is_empty() {
285 return Ok(());
286 }
287
288 self.refresh_urls(urls).await
289 }
290
291 async fn refresh_urls(&self, urls: Vec<String>) -> Result<(), McpxError> {
292 let results = self.fetch_url_results(urls).await;
293 let now = SystemTime::now();
294 let mut cache = self.cache.write().await;
295 let mut changed = false;
296
297 for (url, result) in results {
298 match result {
299 Ok(cached) => {
300 if cache.len() >= self.config.crl_max_cache_entries && !cache.contains_key(&url)
301 {
302 drop(cache);
303 self.warn_cap_exceeded_throttled("cache");
304 cache = self.cache.write().await;
305 continue;
306 }
307 cache.insert(url.clone(), cached);
308 changed = true;
309 match self.cached_urls.lock() {
310 Ok(mut cached_urls) => {
311 cached_urls.insert(url);
312 }
313 Err(poisoned) => {
314 poisoned.into_inner().insert(url);
315 }
316 }
317 }
318 Err(error) => {
319 let remove_entry = cache.get(&url).is_some_and(|existing| {
320 existing
321 .next_update
322 .and_then(|next| next.checked_add(self.config.crl_stale_grace))
323 .is_some_and(|deadline| now > deadline)
324 });
325 tracing::warn!(url = %url, error = %error, "CRL refresh failed");
326 if remove_entry {
327 cache.remove(&url);
328 changed = true;
329 match self.cached_urls.lock() {
330 Ok(mut cached_urls) => {
331 cached_urls.remove(&url);
332 }
333 Err(poisoned) => {
334 poisoned.into_inner().remove(&url);
335 }
336 }
337 match self.seen_urls.lock() {
338 Ok(mut seen_urls) => {
339 seen_urls.remove(&url);
340 }
341 Err(poisoned) => {
342 poisoned.into_inner().remove(&url);
343 }
344 }
345 }
346 }
347 }
348 }
349
350 if changed {
351 self.swap_verifier_from_cache(&cache)?;
352 }
353
354 Ok(())
355 }
356
357 async fn fetch_and_store_url(&self, url: String) -> Result<(), McpxError> {
358 let cached = gated_fetch(
359 &self.client,
360 &self.global_fetch_sem,
361 &self.host_semaphores,
362 &url,
363 self.config.crl_allow_http,
364 self.max_response_bytes,
365 self.config.crl_max_host_semaphores,
366 )
367 .await?;
368 if !self.insert_cache_entry(url, cached).await {
369 return Ok(());
370 }
371 let cache = self.cache.read().await;
372 self.swap_verifier_from_cache(&cache)?;
373 Ok(())
374 }
375
376 fn note_discovered_urls(&self, urls: &[String]) -> bool {
377 let mut missing_cached = false;
386
387 let candidates: Vec<String> = match self.seen_urls.lock() {
398 Ok(seen) => urls
399 .iter()
400 .filter(|url| !seen.contains(*url))
401 .cloned()
402 .collect(),
403 Err(_) => Vec::new(),
404 };
405
406 for url in candidates {
413 if self.discovery_limiter.check().is_err() {
414 tracing::warn!(
415 url = %url,
416 "discovery_rate_limited: dropped CDP URL beyond per-minute cap (will be retried on next handshake observing this URL)"
417 );
418 continue;
419 }
420 if self.discover_tx.send(url.clone()).is_err() {
421 tracing::debug!(
424 url = %url,
425 "discover channel closed; dropping CDP URL without marking seen"
426 );
427 continue;
428 }
429 let mut guard = self
431 .seen_urls
432 .lock()
433 .unwrap_or_else(std::sync::PoisonError::into_inner);
434 if guard.len() >= self.config.crl_max_seen_urls {
435 self.warn_cap_exceeded_throttled("seen_urls");
436 break;
437 }
438 guard.insert(url);
439 }
440
441 if self.config.crl_deny_on_unavailable {
442 let cached = self
443 .cached_urls
444 .lock()
445 .ok()
446 .map(|guard| guard.clone())
447 .unwrap_or_default();
448 missing_cached = urls.iter().any(|url| !cached.contains(url));
449 }
450
451 missing_cached
452 }
453
454 #[doc(hidden)]
460 pub fn __test_with_prepopulated_crls(
461 roots: Arc<RootCertStore>,
462 config: MtlsConfig,
463 prefilled_crls: Vec<CertificateRevocationListDer<'static>>,
464 ) -> Result<Arc<Self>, McpxError> {
465 let (discover_tx, discover_rx) = mpsc::unbounded_channel();
466 drop(discover_rx);
467
468 let mut initial_cache = HashMap::new();
469 for (index, der) in prefilled_crls.into_iter().enumerate() {
470 let source_url = format!("memory://crl/{index}");
471 let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;
472 initial_cache.insert(
473 source_url.clone(),
474 CachedCrl {
475 der,
476 this_update,
477 next_update,
478 fetched_at: SystemTime::now(),
479 source_url,
480 },
481 );
482 }
483
484 Self::new(roots, config, discover_tx, initial_cache)
485 }
486
487 #[doc(hidden)]
499 pub fn __test_with_kept_receiver(
500 roots: Arc<RootCertStore>,
501 config: MtlsConfig,
502 prefilled_crls: Vec<CertificateRevocationListDer<'static>>,
503 ) -> Result<(Arc<Self>, mpsc::UnboundedReceiver<String>), McpxError> {
504 let (discover_tx, discover_rx) = mpsc::unbounded_channel();
505
506 let mut initial_cache = HashMap::new();
507 for (index, der) in prefilled_crls.into_iter().enumerate() {
508 let source_url = format!("memory://crl/{index}");
509 let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;
510 initial_cache.insert(
511 source_url.clone(),
512 CachedCrl {
513 der,
514 this_update,
515 next_update,
516 fetched_at: SystemTime::now(),
517 source_url,
518 },
519 );
520 }
521
522 let crl_set = Self::new(roots, config, discover_tx, initial_cache)?;
523 Ok((crl_set, discover_rx))
524 }
525
526 #[doc(hidden)]
531 pub fn __test_check_discovery_rate(&self, urls: &[String]) -> (usize, usize) {
532 let mut accepted = 0usize;
533 let mut dropped = 0usize;
534 for url in urls {
535 if self.discovery_limiter.check().is_ok() {
536 let _ = self.discover_tx.send(url.clone());
537 accepted += 1;
538 } else {
539 dropped += 1;
540 }
541 }
542 (accepted, dropped)
543 }
544
545 #[doc(hidden)]
549 pub fn __test_note_discovered_urls(&self, urls: &[String]) -> bool {
550 let missing_cached = self.note_discovered_urls(urls);
551 if self.discover_tx.is_closed() {
552 match self.seen_urls.lock() {
553 Ok(mut guard) => {
554 for url in urls {
555 if guard.contains(url) {
556 continue;
557 }
558 if guard.len() >= self.config.crl_max_seen_urls {
559 self.warn_cap_exceeded_throttled("seen_urls");
560 break;
561 }
562 guard.insert(url.clone());
563 }
564 }
565 Err(poisoned) => {
566 let mut guard = poisoned.into_inner();
567 for url in urls {
568 if guard.contains(url) {
569 continue;
570 }
571 if guard.len() >= self.config.crl_max_seen_urls {
572 self.warn_cap_exceeded_throttled("seen_urls");
573 break;
574 }
575 guard.insert(url.clone());
576 }
577 }
578 }
579 }
580 missing_cached
581 }
582
583 #[doc(hidden)]
588 pub fn __test_is_seen(&self, url: &str) -> bool {
589 match self.seen_urls.lock() {
590 Ok(seen) => seen.contains(url),
591 Err(_) => false,
592 }
593 }
594
595 #[cfg(any(test, feature = "test-helpers"))]
598 #[doc(hidden)]
599 pub fn __test_host_semaphore_count(&self) -> usize {
600 self.host_semaphores
601 .try_lock()
602 .map_or(0, |guard| guard.len())
603 }
604
605 #[cfg(any(test, feature = "test-helpers"))]
607 #[doc(hidden)]
608 pub fn __test_cache_len(&self) -> usize {
609 self.cache.try_read().map_or(0, |guard| guard.len())
610 }
611
612 #[cfg(any(test, feature = "test-helpers"))]
614 #[doc(hidden)]
615 pub fn __test_cache_contains(&self, url: &str) -> bool {
616 self.cache
617 .try_read()
618 .is_ok_and(|guard| guard.contains_key(url))
619 }
620
621 #[cfg(any(test, feature = "test-helpers"))]
628 #[doc(hidden)]
629 pub async fn __test_trigger_fetch(&self, url: &str) -> Result<(), McpxError> {
630 if let Err(error) = gated_fetch(
631 &self.client,
632 &self.global_fetch_sem,
633 &self.host_semaphores,
634 url,
635 self.config.crl_allow_http,
636 self.max_response_bytes,
637 self.config.crl_max_host_semaphores,
638 )
639 .await
640 {
641 if error
642 .to_string()
643 .contains("crl_host_semaphore_cap_exceeded")
644 {
645 Err(error)
646 } else {
647 Ok(())
648 }
649 } else {
650 Ok(())
651 }
652 }
653
654 #[cfg(any(test, feature = "test-helpers"))]
666 #[doc(hidden)]
667 pub async fn __test_insert_cache(&self, url: &str, cached: CachedCrl) {
668 let _ = self.insert_cache_entry(url.to_owned(), cached).await;
669 }
670
671 #[cfg(any(test, feature = "test-helpers"))]
676 #[doc(hidden)]
677 pub async fn __test_trigger_refresh_url(&self, url: &str) -> Result<(), McpxError> {
678 self.refresh_urls(vec![url.to_owned()]).await
679 }
680
681 async fn fetch_url_results(
682 &self,
683 urls: Vec<String>,
684 ) -> Vec<(String, Result<CachedCrl, McpxError>)> {
685 let mut tasks = JoinSet::new();
686 for url in urls {
687 let client = self.client.clone();
688 let global_sem = Arc::clone(&self.global_fetch_sem);
689 let host_map = Arc::clone(&self.host_semaphores);
690 let allow_http = self.config.crl_allow_http;
691 let max_bytes = self.max_response_bytes;
692 let max_host_semaphores = self.config.crl_max_host_semaphores;
693 tasks.spawn(async move {
694 let result = gated_fetch(
695 &client,
696 &global_sem,
697 &host_map,
698 &url,
699 allow_http,
700 max_bytes,
701 max_host_semaphores,
702 )
703 .await;
704 (url, result)
705 });
706 }
707
708 let mut results = Vec::new();
709 while let Some(joined) = tasks.join_next().await {
710 match joined {
711 Ok(result) => results.push(result),
712 Err(error) => {
713 tracing::warn!(error = %error, "CRL refresh task join failed");
714 }
715 }
716 }
717
718 results
719 }
720
721 fn swap_verifier_from_cache(
722 &self,
723 cache: &impl std::ops::Deref<Target = HashMap<String, CachedCrl>>,
724 ) -> Result<(), McpxError> {
725 let verifier = rebuild_verifier(&self.roots, &self.config, cache)?;
726 self.inner_verifier
727 .store(Arc::new(VerifierHandle(verifier)));
728 Ok(())
729 }
730}
731
732impl CachedCrl {
733 #[cfg(any(test, feature = "test-helpers"))]
737 #[doc(hidden)]
738 #[must_use]
739 pub fn __test_synthetic(now: SystemTime) -> Self {
740 Self {
741 der: CertificateRevocationListDer::from(vec![0x30, 0x00]),
742 this_update: now,
743 next_update: now.checked_add(Duration::from_hours(24)),
744 fetched_at: now,
745 source_url: "test://synthetic".to_owned(),
746 }
747 }
748
749 #[cfg(any(test, feature = "test-helpers"))]
753 #[doc(hidden)]
754 #[must_use]
755 pub fn __test_stale(reference_past: SystemTime) -> Self {
756 Self {
757 der: CertificateRevocationListDer::from(vec![0x30, 0x00]),
758 this_update: reference_past,
759 next_update: Some(reference_past),
760 fetched_at: reference_past,
761 source_url: "test://stale".to_owned(),
762 }
763 }
764}
765
/// `ClientCertVerifier` that delegates to whatever verifier is currently
/// installed in the shared `CrlSet`, and feeds observed CDP URLs back to it.
pub struct DynamicClientCertVerifier {
    /// Shared CRL state holding the swappable inner verifier.
    inner: Arc<CrlSet>,
    /// Root subject names captured once at construction and returned from
    /// `root_hint_subjects`.
    dn_subjects: Vec<DistinguishedName>,
}
772
773impl DynamicClientCertVerifier {
774 #[must_use]
776 pub fn new(inner: Arc<CrlSet>) -> Self {
777 Self {
778 dn_subjects: inner.roots.subjects(),
779 inner,
780 }
781 }
782}
783
784impl std::fmt::Debug for DynamicClientCertVerifier {
785 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
786 f.debug_struct("DynamicClientCertVerifier")
787 .field("dn_subjects_len", &self.dn_subjects.len())
788 .finish_non_exhaustive()
789 }
790}
791
// Every method loads the currently installed verifier from the `ArcSwap` and
// forwards to it, so a verifier swap takes effect on the next call.
impl ClientCertVerifier for DynamicClientCertVerifier {
    fn offer_client_auth(&self) -> bool {
        let verifier = self.inner.inner_verifier.load();
        verifier.0.offer_client_auth()
    }

    fn client_auth_mandatory(&self) -> bool {
        let verifier = self.inner.inner_verifier.load();
        verifier.0.client_auth_mandatory()
    }

    /// Returns the subject names captured at construction time (not reloaded
    /// from the inner verifier).
    fn root_hint_subjects(&self) -> &[DistinguishedName] {
        &self.dn_subjects
    }

    /// Collects CDP URLs from the presented chain, reports them to the
    /// `CrlSet`, and then delegates verification to the current verifier.
    ///
    /// Fails early with a general TLS error when `note_discovered_urls`
    /// reports that revocation data is missing.
    fn verify_client_cert(
        &self,
        end_entity: &CertificateDer<'_>,
        intermediates: &[CertificateDer<'_>],
        now: UnixTime,
    ) -> Result<ClientCertVerified, TlsError> {
        // URL extraction happens before the chain is verified.
        let mut discovered =
            extract_cdp_urls(end_entity.as_ref(), self.inner.config.crl_allow_http);
        for intermediate in intermediates {
            discovered.extend(extract_cdp_urls(
                intermediate.as_ref(),
                self.inner.config.crl_allow_http,
            ));
        }
        // Deduplicate so each URL is considered once per handshake.
        discovered.sort();
        discovered.dedup();

        if self.inner.note_discovered_urls(&discovered) {
            return Err(TlsError::General(
                "client certificate revocation status unavailable".to_owned(),
            ));
        }

        let verifier = self.inner.inner_verifier.load();
        verifier
            .0
            .verify_client_cert(end_entity, intermediates, now)
    }

    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TlsError> {
        let verifier = self.inner.inner_verifier.load();
        verifier.0.verify_tls12_signature(message, cert, dss)
    }

    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &CertificateDer<'_>,
        dss: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, TlsError> {
        let verifier = self.inner.inner_verifier.load();
        verifier.0.verify_tls13_signature(message, cert, dss)
    }

    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        let verifier = self.inner.inner_verifier.load();
        verifier.0.supported_verify_schemes()
    }

    fn requires_raw_public_keys(&self) -> bool {
        let verifier = self.inner.inner_verifier.load();
        verifier.0.requires_raw_public_keys()
    }
}
877
878#[must_use]
886pub fn extract_cdp_urls(cert_der: &[u8], allow_http: bool) -> Vec<String> {
887 let Ok((_, cert)) = X509Certificate::from_der(cert_der) else {
888 return Vec::new();
889 };
890
891 let mut urls = Vec::new();
892 for ext in cert.extensions() {
893 if let ParsedExtension::CRLDistributionPoints(cdps) = ext.parsed_extension() {
894 for point in cdps.iter() {
895 if let Some(DistributionPointName::FullName(names)) = &point.distribution_point {
896 for name in names {
897 if let GeneralName::URI(uri) = name {
898 let raw = *uri;
899 let Ok(parsed) = Url::parse(raw) else {
900 tracing::debug!(url = %raw, "CDP URL parse failed; dropped");
901 continue;
902 };
903 if let Err(reason) = check_scheme(&parsed, allow_http) {
904 tracing::debug!(
905 url = %raw,
906 reason,
907 "CDP URL rejected by scheme guard; dropped"
908 );
909 continue;
910 }
911 urls.push(parsed.into());
912 }
913 }
914 }
915 }
916 }
917 }
918
919 urls
920}
921
/// Fetches the CRLs referenced by the configured CA certificates and builds
/// the initial `CrlSet`.
///
/// CDP URLs are extracted from `ca_certs`, de-duplicated, and fetched in
/// parallel through `gated_fetch`. Fetching stops when all tasks finish or
/// `BOOTSTRAP_TIMEOUT` elapses, whichever comes first; failed or unfinished
/// fetches are logged/skipped and the set is built from whatever succeeded.
///
/// Returns the set together with the receiving end of its discovery channel.
///
/// # Errors
///
/// Returns `McpxError::Startup` when the HTTP client cannot be built, and
/// propagates any error from `CrlSet::new`.
#[allow(
    clippy::cognitive_complexity,
    reason = "bootstrap coordinates timeout, parallel fetches, and partial-cache recovery"
)]
pub async fn bootstrap_fetch(
    roots: Arc<RootCertStore>,
    ca_certs: &[CertificateDer<'static>],
    config: MtlsConfig,
) -> Result<(Arc<CrlSet>, mpsc::UnboundedReceiver<String>), McpxError> {
    let (discover_tx, discover_rx) = mpsc::unbounded_channel();

    let mut urls = ca_certs
        .iter()
        .flat_map(|cert| extract_cdp_urls(cert.as_ref(), config.crl_allow_http))
        .collect::<Vec<_>>();
    urls.sort();
    urls.dedup();

    // Bootstrap uses its own client, resolver and semaphores; `CrlSet::new`
    // builds separate ones for the long-lived set.
    let bootstrap_allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
    let bootstrap_resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(&bootstrap_allowlist),
            #[cfg(any(test, feature = "test-helpers"))]
            Arc::new(std::sync::atomic::AtomicBool::new(false)),
            #[cfg(not(any(test, feature = "test-helpers")))]
            (),
        ));

    let client = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&bootstrap_resolver))
        .timeout(config.crl_fetch_timeout)
        .connect_timeout(CRL_CONNECT_TIMEOUT)
        .tcp_keepalive(None)
        .redirect(reqwest::redirect::Policy::none())
        .user_agent(format!("rmcp-server-kit/{}", env!("CARGO_PKG_VERSION")))
        .build()
        .map_err(|error| McpxError::Startup(format!("CRL HTTP client init: {error}")))?;

    let bootstrap_concurrency = config.crl_max_concurrent_fetches.max(1);
    let global_sem = Arc::new(Semaphore::new(bootstrap_concurrency));
    let host_semaphores = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
    let allow_http = config.crl_allow_http;
    let max_bytes = config.crl_max_response_bytes;
    let max_host_semaphores = config.crl_max_host_semaphores;

    let mut initial_cache = HashMap::new();
    let mut tasks = JoinSet::new();
    for url in &urls {
        let client = client.clone();
        let url = url.clone();
        let global_sem = Arc::clone(&global_sem);
        let host_semaphores = Arc::clone(&host_semaphores);
        tasks.spawn(async move {
            let result = gated_fetch(
                &client,
                &global_sem,
                &host_semaphores,
                &url,
                allow_http,
                max_bytes,
                max_host_semaphores,
            )
            .await;
            (url, result)
        });
    }

    // A single timer bounds the whole bootstrap, not each individual fetch.
    let timeout: Sleep = tokio::time::sleep(BOOTSTRAP_TIMEOUT);
    tokio::pin!(timeout);

    while !tasks.is_empty() {
        tokio::select! {
            () = &mut timeout => {
                tracing::warn!("CRL bootstrap timed out after {:?}", BOOTSTRAP_TIMEOUT);
                break;
            }
            maybe_joined = tasks.join_next() => {
                let Some(joined) = maybe_joined else {
                    break;
                };
                match joined {
                    Ok((url, Ok(cached))) => {
                        initial_cache.insert(url, cached);
                    }
                    Ok((url, Err(error))) => {
                        tracing::warn!(url = %url, error = %error, "CRL bootstrap fetch failed");
                    }
                    Err(error) => {
                        tracing::warn!(error = %error, "CRL bootstrap task join failed");
                    }
                }
            }
        }
    }

    let set = CrlSet::new(roots, config, discover_tx, initial_cache)?;
    Ok((set, discover_rx))
}
1034
1035#[allow(
1037 clippy::cognitive_complexity,
1038 reason = "refresher loop intentionally handles shutdown, timer, and discovery in one select"
1039)]
1040pub async fn run_crl_refresher(
1041 set: Arc<CrlSet>,
1042 mut discover_rx: mpsc::UnboundedReceiver<String>,
1043 shutdown: CancellationToken,
1044) {
1045 let mut refresh_sleep = schedule_next_refresh(&set).await;
1046
1047 loop {
1048 tokio::select! {
1049 () = shutdown.cancelled() => {
1050 break;
1051 }
1052 () = &mut refresh_sleep => {
1053 if let Err(error) = set.refresh_due_urls().await {
1054 tracing::warn!(error = %error, "CRL periodic refresh failed");
1055 }
1056 refresh_sleep = schedule_next_refresh(&set).await;
1057 }
1058 maybe_url = discover_rx.recv() => {
1059 let Some(url) = maybe_url else {
1060 break;
1061 };
1062 if let Err(error) = set.fetch_and_store_url(url.clone()).await {
1063 tracing::warn!(url = %url, error = %error, "CRL discovery fetch failed");
1064 }
1065 refresh_sleep = schedule_next_refresh(&set).await;
1066 }
1067 }
1068 }
1069}
1070
1071pub fn rebuild_verifier<S: std::hash::BuildHasher>(
1077 roots: &Arc<RootCertStore>,
1078 config: &MtlsConfig,
1079 cache: &HashMap<String, CachedCrl, S>,
1080) -> Result<Arc<dyn ClientCertVerifier>, McpxError> {
1081 let mut builder = WebPkiClientVerifier::builder(Arc::clone(roots));
1082
1083 if !cache.is_empty() {
1084 let crls = cache
1085 .values()
1086 .map(|cached| cached.der.clone())
1087 .collect::<Vec<_>>();
1088 builder = builder.with_crls(crls);
1089 }
1090 if config.crl_end_entity_only {
1091 builder = builder.only_check_end_entity_revocation();
1092 }
1093 if !config.crl_deny_on_unavailable {
1094 builder = builder.allow_unknown_revocation_status();
1095 }
1096 if config.crl_enforce_expiration {
1097 builder = builder.enforce_revocation_expiration();
1098 }
1099 if !config.required {
1100 builder = builder.allow_unauthenticated();
1101 }
1102
1103 builder
1104 .build()
1105 .map_err(|error| McpxError::Tls(format!("mTLS verifier error: {error}")))
1106}
1107
1108pub fn parse_crl_metadata(der: &[u8]) -> Result<(SystemTime, Option<SystemTime>), McpxError> {
1114 let (_, crl) = CertificateRevocationList::from_der(der)
1115 .map_err(|error| McpxError::Tls(format!("invalid CRL DER: {error:?}")))?;
1116
1117 Ok((
1118 asn1_time_to_system_time(crl.last_update()),
1119 crl.next_update().map(asn1_time_to_system_time),
1120 ))
1121}
1122
1123async fn schedule_next_refresh(set: &CrlSet) -> Pin<Box<Sleep>> {
1124 let duration = next_refresh_delay(set).await;
1125 boxed_sleep(duration)
1126}
1127
1128fn boxed_sleep(duration: Duration) -> Pin<Box<Sleep>> {
1129 Box::pin(tokio::time::sleep_until(Instant::now() + duration))
1130}
1131
1132async fn next_refresh_delay(set: &CrlSet) -> Duration {
1133 if let Some(interval) = set.config.crl_refresh_interval {
1134 return clamp_refresh(interval);
1135 }
1136
1137 let now = SystemTime::now();
1138 let cache = set.cache.read().await;
1139 let mut next = MAX_AUTO_REFRESH;
1140
1141 for cached in cache.values() {
1142 if let Some(next_update) = cached.next_update {
1143 let duration = next_update.duration_since(now).unwrap_or(Duration::ZERO);
1144 next = next.min(clamp_refresh(duration));
1145 }
1146 }
1147
1148 next
1149}
1150
1151async fn gated_fetch(
1158 client: &reqwest::Client,
1159 global_sem: &Arc<Semaphore>,
1160 host_semaphores: &Arc<tokio::sync::Mutex<HashMap<String, Arc<Semaphore>>>>,
1161 url: &str,
1162 allow_http: bool,
1163 max_bytes: u64,
1164 max_host_semaphores: usize,
1165) -> Result<CachedCrl, McpxError> {
1166 let host_key = Url::parse(url)
1167 .ok()
1168 .and_then(|u| u.host_str().map(str::to_owned))
1169 .unwrap_or_else(|| url.to_owned());
1170
1171 let host_sem = {
1172 let mut map = host_semaphores.lock().await;
1173 if !map.contains_key(&host_key) {
1174 if map.len() >= max_host_semaphores {
1175 return Err(McpxError::Config(
1176 "crl_host_semaphore_cap_exceeded: too many distinct CRL hosts in flight"
1177 .to_owned(),
1178 ));
1179 }
1180 map.insert(host_key.clone(), Arc::new(Semaphore::new(1)));
1181 }
1182 match map.get(&host_key) {
1183 Some(semaphore) => Arc::clone(semaphore),
1184 None => {
1185 return Err(McpxError::Tls(
1186 "CRL host semaphore missing after insertion".to_owned(),
1187 ));
1188 }
1189 }
1190 };
1191
1192 let _global_permit = Arc::clone(global_sem)
1193 .acquire_owned()
1194 .await
1195 .map_err(|error| McpxError::Tls(format!("CRL global semaphore closed: {error}")))?;
1196 let _host_permit = host_sem
1197 .acquire_owned()
1198 .await
1199 .map_err(|error| McpxError::Tls(format!("CRL host semaphore closed: {error}")))?;
1200
1201 fetch_crl(client, url, allow_http, max_bytes).await
1202}
1203
/// Downloads and parses a single CRL.
///
/// Steps, in order: parse the URL and check its scheme; resolve the host and
/// reject the fetch if *any* resolved address is blocked by `ip_block_reason`
/// (or if nothing resolves); issue the GET and require a success status;
/// stream the body while enforcing `max_bytes`; parse the CRL metadata.
///
/// # Errors
///
/// Every failure is reported as `McpxError::Tls` with a message naming the
/// failing step and the URL (metadata parse errors come from
/// `parse_crl_metadata`).
async fn fetch_crl(
    client: &reqwest::Client,
    url: &str,
    allow_http: bool,
    max_bytes: u64,
) -> Result<CachedCrl, McpxError> {
    let parsed =
        Url::parse(url).map_err(|error| McpxError::Tls(format!("CRL URL parse {url}: {error}")))?;

    if let Err(reason) = check_scheme(&parsed, allow_http) {
        tracing::warn!(url = %url, reason, "CRL fetch denied: scheme");
        return Err(McpxError::Tls(format!(
            "CRL scheme rejected ({reason}): {url}"
        )));
    }

    let host = parsed
        .host_str()
        .ok_or_else(|| McpxError::Tls(format!("CRL URL has no host: {url}")))?;
    let port = parsed
        .port_or_known_default()
        .ok_or_else(|| McpxError::Tls(format!("CRL URL has no known port: {url}")))?;

    // Pre-flight resolution so every candidate address can be screened before
    // any request is sent.
    // NOTE(review): `host_str()` presumably returns IPv6 literals in bracketed
    // form; confirm `lookup_host` accepts that form — TODO verify.
    let addrs = lookup_host((host, port))
        .await
        .map_err(|error| McpxError::Tls(format!("CRL DNS resolution {url}: {error}")))?;

    let mut any_addr = false;
    for addr in addrs {
        any_addr = true;
        if let Some(reason) = ip_block_reason(addr.ip()) {
            tracing::warn!(
                url = %url,
                resolved_ip = %addr.ip(),
                reason,
                "CRL fetch denied: blocked IP"
            );
            return Err(McpxError::Tls(format!(
                "CRL host resolved to blocked IP ({reason}): {url}"
            )));
        }
    }
    if !any_addr {
        return Err(McpxError::Tls(format!(
            "CRL DNS resolution returned no addresses: {url}"
        )));
    }

    let mut response = client
        .get(url)
        .send()
        .await
        .map_err(|error| McpxError::Tls(format!("CRL fetch {url}: {error}")))?
        .error_for_status()
        .map_err(|error| McpxError::Tls(format!("CRL fetch {url}: {error}")))?;

    // Pre-allocate at most 64 KiB regardless of how large the cap is.
    let initial_capacity = usize::try_from(max_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
    let mut body: Vec<u8> = Vec::with_capacity(initial_capacity);
    while let Some(chunk) = response
        .chunk()
        .await
        .map_err(|error| McpxError::Tls(format!("CRL read {url}: {error}")))?
    {
        // The cap is checked before appending, so `body` never exceeds `max_bytes`.
        let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
        let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
        if body_len.saturating_add(chunk_len) > max_bytes {
            return Err(McpxError::Tls(format!(
                "CRL body exceeded cap of {max_bytes} bytes: {url}"
            )));
        }
        body.extend_from_slice(&chunk);
    }

    let der = CertificateRevocationListDer::from(body);
    let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;

    Ok(CachedCrl {
        der,
        this_update,
        next_update,
        fetched_at: SystemTime::now(),
        source_url: url.to_owned(),
    })
}
1290
1291fn should_refresh_cached(
1292 cached: &CachedCrl,
1293 now: SystemTime,
1294 fixed_interval: Option<Duration>,
1295) -> bool {
1296 if let Some(interval) = fixed_interval {
1297 return cached
1298 .fetched_at
1299 .checked_add(clamp_refresh(interval))
1300 .is_none_or(|deadline| now >= deadline);
1301 }
1302
1303 cached
1304 .next_update
1305 .is_none_or(|next_update| now >= next_update)
1306}
1307
1308fn clamp_refresh(duration: Duration) -> Duration {
1309 duration.clamp(MIN_AUTO_REFRESH, MAX_AUTO_REFRESH)
1310}
1311
1312fn asn1_time_to_system_time(time: x509_parser::time::ASN1Time) -> SystemTime {
1313 let timestamp = time.timestamp();
1314 if timestamp >= 0 {
1315 let seconds = u64::try_from(timestamp).unwrap_or(0);
1316 UNIX_EPOCH + Duration::from_secs(seconds)
1317 } else {
1318 UNIX_EPOCH - Duration::from_secs(timestamp.unsigned_abs())
1319 }
1320}