1use std::{
25 collections::{HashMap, HashSet},
26 num::NonZeroU32,
27 pin::Pin,
28 sync::{Arc, Mutex},
29 time::{Duration, SystemTime, UNIX_EPOCH},
30};
31
32use arc_swap::ArcSwap;
33use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
34use rustls::{
35 DigitallySignedStruct, DistinguishedName, Error as TlsError, RootCertStore, SignatureScheme,
36 client::danger::HandshakeSignatureValid,
37 pki_types::{CertificateDer, CertificateRevocationListDer, UnixTime},
38 server::{
39 WebPkiClientVerifier,
40 danger::{ClientCertVerified, ClientCertVerifier},
41 },
42};
43use tokio::{
44 net::lookup_host,
45 sync::{RwLock, Semaphore, mpsc},
46 task::JoinSet,
47 time::{Instant, Sleep},
48};
49use tokio_util::sync::CancellationToken;
50use url::Url;
51use x509_parser::{
52 extensions::{DistributionPointName, GeneralName, ParsedExtension},
53 prelude::{FromDer, X509Certificate},
54 revocation_list::CertificateRevocationList,
55};
56
57use crate::{
58 auth::MtlsConfig,
59 error::McpxError,
60 ssrf::{check_scheme, ip_block_reason},
61};
62
/// Maximum time `bootstrap_fetch` waits for the initial CRL downloads before
/// continuing with whatever has been fetched so far.
const BOOTSTRAP_TIMEOUT: Duration = Duration::from_secs(10);
/// Lower bound applied (via `clamp_refresh`) to every computed refresh delay.
const MIN_AUTO_REFRESH: Duration = Duration::from_mins(10);
/// Upper bound applied (via `clamp_refresh`) to every computed refresh delay; also
/// the delay used when no cached CRL advertises a `nextUpdate`.
const MAX_AUTO_REFRESH: Duration = Duration::from_hours(24);
/// Connect timeout for CRL HTTP requests (the whole-request timeout comes from
/// `MtlsConfig::crl_fetch_timeout`).
const CRL_CONNECT_TIMEOUT: Duration = Duration::from_secs(3);
69
/// A CRL held in the [`CrlSet`] cache, plus the metadata used to decide when it
/// must be refreshed.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct CachedCrl {
    /// The CRL exactly as received (DER bytes).
    pub der: CertificateRevocationListDer<'static>,
    /// The CRL's `thisUpdate` time (x509-parser's `last_update`).
    pub this_update: SystemTime,
    /// The CRL's `nextUpdate` time, if the CRL carries one.
    pub next_update: Option<SystemTime>,
    /// Local wall-clock time at which this entry was fetched or loaded.
    pub fetched_at: SystemTime,
    /// URL the CRL came from (`memory://crl/{index}` for pre-populated test entries).
    pub source_url: String,
}
85
/// The currently active rustls client-certificate verifier.
///
/// Stored in an `ArcSwap` inside [`CrlSet`] so it can be replaced atomically each
/// time the CRL cache changes.
pub(crate) struct VerifierHandle(pub Arc<dyn ClientCertVerifier>);
87
88impl std::fmt::Debug for VerifierHandle {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.debug_struct("VerifierHandle").finish_non_exhaustive()
91 }
92}
93
/// Shared CRL state: the cache of fetched CRLs, the rustls verifier built from it,
/// and the limits that bound CDP discovery and CRL fetching.
#[allow(
    missing_debug_implementations,
    reason = "contains ArcSwap and dyn verifier internals"
)]
#[non_exhaustive]
pub struct CrlSet {
    /// Verifier rebuilt from `cache` and swapped in whenever the cache changes.
    inner_verifier: ArcSwap<VerifierHandle>,
    /// Cached CRLs keyed by distribution-point URL.
    pub cache: RwLock<HashMap<String, CachedCrl>>,
    /// Trust anchors used for every rebuilt verifier.
    pub roots: Arc<RootCertStore>,
    /// mTLS / CRL configuration.
    pub config: MtlsConfig,
    /// Queue of newly discovered CDP URLs; the matching receiver is drained by
    /// `run_crl_refresher`.
    pub discover_tx: mpsc::UnboundedSender<String>,
    /// HTTP client for CRL downloads (no proxy, no redirects, screening DNS resolver).
    client: reqwest::Client,
    /// URLs already queued for discovery (capped at `crl_max_seen_urls`).
    seen_urls: Mutex<HashSet<String>>,
    /// Synchronous mirror of the keys of `cache`, readable from the handshake path.
    cached_urls: Mutex<HashSet<String>>,
    /// Limits concurrent CRL fetches (`crl_max_concurrent_fetches`, at least 1).
    global_fetch_sem: Arc<Semaphore>,
    /// Single-permit semaphore per CRL host (at most `crl_max_host_semaphores` hosts).
    host_semaphores: Arc<tokio::sync::Mutex<HashMap<String, Arc<Semaphore>>>>,
    /// Rate limiter for forwarding discovered URLs (`crl_discovery_rate_per_min`).
    discovery_limiter: Arc<DefaultDirectRateLimiter>,
    /// Maximum accepted CRL response body size, in bytes.
    max_response_bytes: u64,
    /// Last emission time of each cap-exceeded warning, used to throttle the log.
    last_cap_warn: Mutex<HashMap<&'static str, Instant>>,
}
130
131impl CrlSet {
132 fn new(
133 roots: Arc<RootCertStore>,
134 config: MtlsConfig,
135 discover_tx: mpsc::UnboundedSender<String>,
136 initial_cache: HashMap<String, CachedCrl>,
137 ) -> Result<Arc<Self>, McpxError> {
138 let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
146 let resolver: Arc<dyn reqwest::dns::Resolve> =
147 Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
148 Arc::clone(&allowlist),
149 #[cfg(any(test, feature = "test-helpers"))]
150 Arc::new(std::sync::atomic::AtomicBool::new(false)),
151 #[cfg(not(any(test, feature = "test-helpers")))]
152 (),
153 ));
154
155 let client = reqwest::Client::builder()
156 .no_proxy()
158 .dns_resolver(Arc::clone(&resolver))
159 .timeout(config.crl_fetch_timeout)
160 .connect_timeout(CRL_CONNECT_TIMEOUT)
161 .tcp_keepalive(None)
162 .redirect(reqwest::redirect::Policy::none())
163 .user_agent(format!("rmcp-server-kit/{}", env!("CARGO_PKG_VERSION")))
164 .build()
165 .map_err(|error| McpxError::Startup(format!("CRL HTTP client init: {error}")))?;
166
167 let initial_verifier = rebuild_verifier(&roots, &config, &initial_cache)?;
168 let seen_urls = initial_cache.keys().cloned().collect::<HashSet<_>>();
169 let cached_urls = seen_urls.clone();
170
171 let concurrency = config.crl_max_concurrent_fetches.max(1);
172 let global_fetch_sem = Arc::new(Semaphore::new(concurrency));
173 let host_semaphores = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
174
175 let rate =
176 NonZeroU32::new(config.crl_discovery_rate_per_min.max(1)).unwrap_or(NonZeroU32::MIN);
177 let discovery_limiter = Arc::new(RateLimiter::direct(Quota::per_minute(rate)));
178
179 let max_response_bytes = config.crl_max_response_bytes;
180
181 Ok(Arc::new(Self {
182 inner_verifier: ArcSwap::from_pointee(VerifierHandle(initial_verifier)),
183 cache: RwLock::new(initial_cache),
184 roots,
185 config,
186 discover_tx,
187 client,
188 seen_urls: Mutex::new(seen_urls),
189 cached_urls: Mutex::new(cached_urls),
190 global_fetch_sem,
191 host_semaphores,
192 discovery_limiter,
193 max_response_bytes,
194 last_cap_warn: Mutex::new(HashMap::new()),
195 }))
196 }
197
198 fn warn_cap_exceeded_throttled(&self, which: &'static str) {
199 let now = Instant::now();
200 let cooldown = Duration::from_mins(1);
201 let should_warn = match self.last_cap_warn.lock() {
202 Ok(mut guard) => {
203 let should_emit = guard
204 .get(which)
205 .is_none_or(|last| now.saturating_duration_since(*last) >= cooldown);
206 if should_emit {
207 guard.insert(which, now);
208 }
209 should_emit
210 }
211 Err(poisoned) => {
212 let mut guard = poisoned.into_inner();
213 let should_emit = guard
214 .get(which)
215 .is_none_or(|last| now.saturating_duration_since(*last) >= cooldown);
216 if should_emit {
217 guard.insert(which, now);
218 }
219 should_emit
220 }
221 };
222
223 if should_warn {
224 tracing::warn!(which = which, "CRL map cap exceeded; dropping newest entry");
225 }
226 }
227
228 async fn insert_cache_entry(&self, url: String, cached: CachedCrl) -> bool {
229 let inserted = {
230 let mut guard = self.cache.write().await;
231 if guard.len() >= self.config.crl_max_cache_entries && !guard.contains_key(&url) {
232 false
233 } else {
234 guard.insert(url.clone(), cached);
235 true
236 }
237 };
238
239 if inserted {
240 match self.cached_urls.lock() {
241 Ok(mut cached_urls) => {
242 cached_urls.insert(url);
243 }
244 Err(poisoned) => {
245 poisoned.into_inner().insert(url);
246 }
247 }
248 } else {
249 self.warn_cap_exceeded_throttled("cache");
250 }
251
252 inserted
253 }
254
255 pub async fn force_refresh(&self) -> Result<(), McpxError> {
261 let urls = {
262 let cache = self.cache.read().await;
263 cache.keys().cloned().collect::<Vec<_>>()
264 };
265 self.refresh_urls(urls).await
266 }
267
268 async fn refresh_due_urls(&self) -> Result<(), McpxError> {
269 let now = SystemTime::now();
270 let urls = {
271 let cache = self.cache.read().await;
272 cache
273 .iter()
274 .filter(|(_, cached)| {
275 should_refresh_cached(cached, now, self.config.crl_refresh_interval)
276 })
277 .map(|(url, _)| url.clone())
278 .collect::<Vec<_>>()
279 };
280
281 if urls.is_empty() {
282 return Ok(());
283 }
284
285 self.refresh_urls(urls).await
286 }
287
288 async fn refresh_urls(&self, urls: Vec<String>) -> Result<(), McpxError> {
289 let results = self.fetch_url_results(urls).await;
290 let now = SystemTime::now();
291 let mut cache = self.cache.write().await;
292 let mut changed = false;
293
294 for (url, result) in results {
295 match result {
296 Ok(cached) => {
297 if cache.len() >= self.config.crl_max_cache_entries && !cache.contains_key(&url)
298 {
299 drop(cache);
300 self.warn_cap_exceeded_throttled("cache");
301 cache = self.cache.write().await;
302 continue;
303 }
304 cache.insert(url.clone(), cached);
305 changed = true;
306 match self.cached_urls.lock() {
307 Ok(mut cached_urls) => {
308 cached_urls.insert(url);
309 }
310 Err(poisoned) => {
311 poisoned.into_inner().insert(url);
312 }
313 }
314 }
315 Err(error) => {
316 let remove_entry = cache.get(&url).is_some_and(|existing| {
317 existing
318 .next_update
319 .and_then(|next| next.checked_add(self.config.crl_stale_grace))
320 .is_some_and(|deadline| now > deadline)
321 });
322 tracing::warn!(url = %url, error = %error, "CRL refresh failed");
323 if remove_entry {
324 cache.remove(&url);
325 changed = true;
326 match self.cached_urls.lock() {
327 Ok(mut cached_urls) => {
328 cached_urls.remove(&url);
329 }
330 Err(poisoned) => {
331 poisoned.into_inner().remove(&url);
332 }
333 }
334 match self.seen_urls.lock() {
335 Ok(mut seen_urls) => {
336 seen_urls.remove(&url);
337 }
338 Err(poisoned) => {
339 poisoned.into_inner().remove(&url);
340 }
341 }
342 }
343 }
344 }
345 }
346
347 if changed {
348 self.swap_verifier_from_cache(&cache)?;
349 }
350
351 Ok(())
352 }
353
354 async fn fetch_and_store_url(&self, url: String) -> Result<(), McpxError> {
355 let cached = gated_fetch(
356 &self.client,
357 &self.global_fetch_sem,
358 &self.host_semaphores,
359 &url,
360 self.config.crl_allow_http,
361 self.max_response_bytes,
362 self.config.crl_max_host_semaphores,
363 )
364 .await?;
365 if !self.insert_cache_entry(url, cached).await {
366 return Ok(());
367 }
368 let cache = self.cache.read().await;
369 self.swap_verifier_from_cache(&cache)?;
370 Ok(())
371 }
372
373 fn note_discovered_urls(&self, urls: &[String]) -> bool {
374 let mut missing_cached = false;
383
384 let candidates: Vec<String> = match self.seen_urls.lock() {
395 Ok(seen) => urls
396 .iter()
397 .filter(|url| !seen.contains(*url))
398 .cloned()
399 .collect(),
400 Err(_) => Vec::new(),
401 };
402
403 for url in candidates {
410 if self.discovery_limiter.check().is_err() {
411 tracing::warn!(
412 url = %url,
413 "discovery_rate_limited: dropped CDP URL beyond per-minute cap (will be retried on next handshake observing this URL)"
414 );
415 continue;
416 }
417 if self.discover_tx.send(url.clone()).is_err() {
418 tracing::debug!(
421 url = %url,
422 "discover channel closed; dropping CDP URL without marking seen"
423 );
424 continue;
425 }
426 let mut guard = self
428 .seen_urls
429 .lock()
430 .unwrap_or_else(std::sync::PoisonError::into_inner);
431 if guard.len() >= self.config.crl_max_seen_urls {
432 self.warn_cap_exceeded_throttled("seen_urls");
433 break;
434 }
435 guard.insert(url);
436 }
437
438 if self.config.crl_deny_on_unavailable {
439 let cached = self
440 .cached_urls
441 .lock()
442 .ok()
443 .map(|guard| guard.clone())
444 .unwrap_or_default();
445 missing_cached = urls.iter().any(|url| !cached.contains(url));
446 }
447
448 missing_cached
449 }
450
451 #[doc(hidden)]
457 pub fn __test_with_prepopulated_crls(
458 roots: Arc<RootCertStore>,
459 config: MtlsConfig,
460 prefilled_crls: Vec<CertificateRevocationListDer<'static>>,
461 ) -> Result<Arc<Self>, McpxError> {
462 let (discover_tx, discover_rx) = mpsc::unbounded_channel();
463 drop(discover_rx);
464
465 let mut initial_cache = HashMap::new();
466 for (index, der) in prefilled_crls.into_iter().enumerate() {
467 let source_url = format!("memory://crl/{index}");
468 let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;
469 initial_cache.insert(
470 source_url.clone(),
471 CachedCrl {
472 der,
473 this_update,
474 next_update,
475 fetched_at: SystemTime::now(),
476 source_url,
477 },
478 );
479 }
480
481 Self::new(roots, config, discover_tx, initial_cache)
482 }
483
484 #[doc(hidden)]
496 pub fn __test_with_kept_receiver(
497 roots: Arc<RootCertStore>,
498 config: MtlsConfig,
499 prefilled_crls: Vec<CertificateRevocationListDer<'static>>,
500 ) -> Result<(Arc<Self>, mpsc::UnboundedReceiver<String>), McpxError> {
501 let (discover_tx, discover_rx) = mpsc::unbounded_channel();
502
503 let mut initial_cache = HashMap::new();
504 for (index, der) in prefilled_crls.into_iter().enumerate() {
505 let source_url = format!("memory://crl/{index}");
506 let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;
507 initial_cache.insert(
508 source_url.clone(),
509 CachedCrl {
510 der,
511 this_update,
512 next_update,
513 fetched_at: SystemTime::now(),
514 source_url,
515 },
516 );
517 }
518
519 let crl_set = Self::new(roots, config, discover_tx, initial_cache)?;
520 Ok((crl_set, discover_rx))
521 }
522
523 #[doc(hidden)]
528 pub fn __test_check_discovery_rate(&self, urls: &[String]) -> (usize, usize) {
529 let mut accepted = 0usize;
530 let mut dropped = 0usize;
531 for url in urls {
532 if self.discovery_limiter.check().is_ok() {
533 let _ = self.discover_tx.send(url.clone());
534 accepted += 1;
535 } else {
536 dropped += 1;
537 }
538 }
539 (accepted, dropped)
540 }
541
542 #[doc(hidden)]
546 pub fn __test_note_discovered_urls(&self, urls: &[String]) -> bool {
547 let missing_cached = self.note_discovered_urls(urls);
548 if self.discover_tx.is_closed() {
549 match self.seen_urls.lock() {
550 Ok(mut guard) => {
551 for url in urls {
552 if guard.contains(url) {
553 continue;
554 }
555 if guard.len() >= self.config.crl_max_seen_urls {
556 self.warn_cap_exceeded_throttled("seen_urls");
557 break;
558 }
559 guard.insert(url.clone());
560 }
561 }
562 Err(poisoned) => {
563 let mut guard = poisoned.into_inner();
564 for url in urls {
565 if guard.contains(url) {
566 continue;
567 }
568 if guard.len() >= self.config.crl_max_seen_urls {
569 self.warn_cap_exceeded_throttled("seen_urls");
570 break;
571 }
572 guard.insert(url.clone());
573 }
574 }
575 }
576 }
577 missing_cached
578 }
579
580 #[doc(hidden)]
585 pub fn __test_is_seen(&self, url: &str) -> bool {
586 match self.seen_urls.lock() {
587 Ok(seen) => seen.contains(url),
588 Err(_) => false,
589 }
590 }
591
592 #[cfg(any(test, feature = "test-helpers"))]
595 #[doc(hidden)]
596 pub fn __test_host_semaphore_count(&self) -> usize {
597 self.host_semaphores
598 .try_lock()
599 .map_or(0, |guard| guard.len())
600 }
601
602 #[cfg(any(test, feature = "test-helpers"))]
604 #[doc(hidden)]
605 pub fn __test_cache_len(&self) -> usize {
606 self.cache.try_read().map_or(0, |guard| guard.len())
607 }
608
609 #[cfg(any(test, feature = "test-helpers"))]
611 #[doc(hidden)]
612 pub fn __test_cache_contains(&self, url: &str) -> bool {
613 self.cache
614 .try_read()
615 .is_ok_and(|guard| guard.contains_key(url))
616 }
617
618 #[cfg(any(test, feature = "test-helpers"))]
625 #[doc(hidden)]
626 pub async fn __test_trigger_fetch(&self, url: &str) -> Result<(), McpxError> {
627 if let Err(error) = gated_fetch(
628 &self.client,
629 &self.global_fetch_sem,
630 &self.host_semaphores,
631 url,
632 self.config.crl_allow_http,
633 self.max_response_bytes,
634 self.config.crl_max_host_semaphores,
635 )
636 .await
637 {
638 if error
639 .to_string()
640 .contains("crl_host_semaphore_cap_exceeded")
641 {
642 Err(error)
643 } else {
644 Ok(())
645 }
646 } else {
647 Ok(())
648 }
649 }
650
651 #[cfg(any(test, feature = "test-helpers"))]
663 #[doc(hidden)]
664 pub async fn __test_insert_cache(&self, url: &str, cached: CachedCrl) {
665 let _ = self.insert_cache_entry(url.to_owned(), cached).await;
666 }
667
668 #[cfg(any(test, feature = "test-helpers"))]
673 #[doc(hidden)]
674 pub async fn __test_trigger_refresh_url(&self, url: &str) -> Result<(), McpxError> {
675 self.refresh_urls(vec![url.to_owned()]).await
676 }
677
678 async fn fetch_url_results(
679 &self,
680 urls: Vec<String>,
681 ) -> Vec<(String, Result<CachedCrl, McpxError>)> {
682 let mut tasks = JoinSet::new();
683 for url in urls {
684 let client = self.client.clone();
685 let global_sem = Arc::clone(&self.global_fetch_sem);
686 let host_map = Arc::clone(&self.host_semaphores);
687 let allow_http = self.config.crl_allow_http;
688 let max_bytes = self.max_response_bytes;
689 let max_host_semaphores = self.config.crl_max_host_semaphores;
690 tasks.spawn(async move {
691 let result = gated_fetch(
692 &client,
693 &global_sem,
694 &host_map,
695 &url,
696 allow_http,
697 max_bytes,
698 max_host_semaphores,
699 )
700 .await;
701 (url, result)
702 });
703 }
704
705 let mut results = Vec::new();
706 while let Some(joined) = tasks.join_next().await {
707 match joined {
708 Ok(result) => results.push(result),
709 Err(error) => {
710 tracing::warn!(error = %error, "CRL refresh task join failed");
711 }
712 }
713 }
714
715 results
716 }
717
718 fn swap_verifier_from_cache(
719 &self,
720 cache: &impl std::ops::Deref<Target = HashMap<String, CachedCrl>>,
721 ) -> Result<(), McpxError> {
722 let verifier = rebuild_verifier(&self.roots, &self.config, cache)?;
723 self.inner_verifier
724 .store(Arc::new(VerifierHandle(verifier)));
725 Ok(())
726 }
727}
728
729impl CachedCrl {
730 #[cfg(any(test, feature = "test-helpers"))]
734 #[doc(hidden)]
735 #[must_use]
736 pub fn __test_synthetic(now: SystemTime) -> Self {
737 Self {
738 der: CertificateRevocationListDer::from(vec![0x30, 0x00]),
739 this_update: now,
740 next_update: now.checked_add(Duration::from_hours(24)),
741 fetched_at: now,
742 source_url: "test://synthetic".to_owned(),
743 }
744 }
745
746 #[cfg(any(test, feature = "test-helpers"))]
750 #[doc(hidden)]
751 #[must_use]
752 pub fn __test_stale(reference_past: SystemTime) -> Self {
753 Self {
754 der: CertificateRevocationListDer::from(vec![0x30, 0x00]),
755 this_update: reference_past,
756 next_update: Some(reference_past),
757 fetched_at: reference_past,
758 source_url: "test://stale".to_owned(),
759 }
760 }
761}
762
/// rustls client-certificate verifier that delegates to whatever verifier a
/// [`CrlSet`] currently publishes, and reports CDP URLs seen in presented chains
/// so they can be fetched.
pub struct DynamicClientCertVerifier {
    // Shared CRL state; its `inner_verifier` is loaded on every verifier call.
    inner: Arc<CrlSet>,
    // Root-store subjects, computed once at construction for `root_hint_subjects`.
    dn_subjects: Vec<DistinguishedName>,
}
769
770impl DynamicClientCertVerifier {
771 #[must_use]
773 pub fn new(inner: Arc<CrlSet>) -> Self {
774 Self {
775 dn_subjects: inner.roots.subjects(),
776 inner,
777 }
778 }
779}
780
781impl std::fmt::Debug for DynamicClientCertVerifier {
782 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
783 f.debug_struct("DynamicClientCertVerifier")
784 .field("dn_subjects_len", &self.dn_subjects.len())
785 .finish_non_exhaustive()
786 }
787}
788
789impl ClientCertVerifier for DynamicClientCertVerifier {
790 fn offer_client_auth(&self) -> bool {
791 let verifier = self.inner.inner_verifier.load();
792 verifier.0.offer_client_auth()
793 }
794
795 fn client_auth_mandatory(&self) -> bool {
796 let verifier = self.inner.inner_verifier.load();
797 verifier.0.client_auth_mandatory()
798 }
799
800 fn root_hint_subjects(&self) -> &[DistinguishedName] {
801 &self.dn_subjects
802 }
803
804 fn verify_client_cert(
805 &self,
806 end_entity: &CertificateDer<'_>,
807 intermediates: &[CertificateDer<'_>],
808 now: UnixTime,
809 ) -> Result<ClientCertVerified, TlsError> {
810 let mut discovered =
822 extract_cdp_urls(end_entity.as_ref(), self.inner.config.crl_allow_http);
823 for intermediate in intermediates {
824 discovered.extend(extract_cdp_urls(
825 intermediate.as_ref(),
826 self.inner.config.crl_allow_http,
827 ));
828 }
829 discovered.sort();
830 discovered.dedup();
831
832 if self.inner.note_discovered_urls(&discovered) {
833 return Err(TlsError::General(
834 "client certificate revocation status unavailable".to_owned(),
835 ));
836 }
837
838 let verifier = self.inner.inner_verifier.load();
839 verifier
840 .0
841 .verify_client_cert(end_entity, intermediates, now)
842 }
843
844 fn verify_tls12_signature(
845 &self,
846 message: &[u8],
847 cert: &CertificateDer<'_>,
848 dss: &DigitallySignedStruct,
849 ) -> Result<HandshakeSignatureValid, TlsError> {
850 let verifier = self.inner.inner_verifier.load();
851 verifier.0.verify_tls12_signature(message, cert, dss)
852 }
853
854 fn verify_tls13_signature(
855 &self,
856 message: &[u8],
857 cert: &CertificateDer<'_>,
858 dss: &DigitallySignedStruct,
859 ) -> Result<HandshakeSignatureValid, TlsError> {
860 let verifier = self.inner.inner_verifier.load();
861 verifier.0.verify_tls13_signature(message, cert, dss)
862 }
863
864 fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
865 let verifier = self.inner.inner_verifier.load();
866 verifier.0.supported_verify_schemes()
867 }
868
869 fn requires_raw_public_keys(&self) -> bool {
870 let verifier = self.inner.inner_verifier.load();
871 verifier.0.requires_raw_public_keys()
872 }
873}
874
875#[must_use]
883pub fn extract_cdp_urls(cert_der: &[u8], allow_http: bool) -> Vec<String> {
884 let Ok((_, cert)) = X509Certificate::from_der(cert_der) else {
885 return Vec::new();
886 };
887
888 let mut urls = Vec::new();
889 for ext in cert.extensions() {
890 if let ParsedExtension::CRLDistributionPoints(cdps) = ext.parsed_extension() {
891 for point in cdps.iter() {
892 if let Some(DistributionPointName::FullName(names)) = &point.distribution_point {
893 for name in names {
894 if let GeneralName::URI(uri) = name {
895 let raw = *uri;
896 let Ok(parsed) = Url::parse(raw) else {
897 tracing::debug!(url = %raw, "CDP URL parse failed; dropped");
898 continue;
899 };
900 if let Err(reason) = check_scheme(&parsed, allow_http) {
901 tracing::debug!(
902 url = %raw,
903 reason,
904 "CDP URL rejected by scheme guard; dropped"
905 );
906 continue;
907 }
908 urls.push(parsed.into());
909 }
910 }
911 }
912 }
913 }
914 }
915
916 urls
917}
918
919#[allow(
926 clippy::cognitive_complexity,
927 reason = "bootstrap coordinates timeout, parallel fetches, and partial-cache recovery"
928)]
929pub async fn bootstrap_fetch(
930 roots: Arc<RootCertStore>,
931 ca_certs: &[CertificateDer<'static>],
932 config: MtlsConfig,
933) -> Result<(Arc<CrlSet>, mpsc::UnboundedReceiver<String>), McpxError> {
934 let (discover_tx, discover_rx) = mpsc::unbounded_channel();
935
936 let mut urls = ca_certs
937 .iter()
938 .flat_map(|cert| extract_cdp_urls(cert.as_ref(), config.crl_allow_http))
939 .collect::<Vec<_>>();
940 urls.sort();
941 urls.dedup();
942
943 let bootstrap_allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
947 let bootstrap_resolver: Arc<dyn reqwest::dns::Resolve> =
948 Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
949 Arc::clone(&bootstrap_allowlist),
950 #[cfg(any(test, feature = "test-helpers"))]
951 Arc::new(std::sync::atomic::AtomicBool::new(false)),
952 #[cfg(not(any(test, feature = "test-helpers")))]
953 (),
954 ));
955
956 let client = reqwest::Client::builder()
957 .no_proxy()
959 .dns_resolver(Arc::clone(&bootstrap_resolver))
960 .timeout(config.crl_fetch_timeout)
961 .connect_timeout(CRL_CONNECT_TIMEOUT)
962 .tcp_keepalive(None)
963 .redirect(reqwest::redirect::Policy::none())
964 .user_agent(format!("rmcp-server-kit/{}", env!("CARGO_PKG_VERSION")))
965 .build()
966 .map_err(|error| McpxError::Startup(format!("CRL HTTP client init: {error}")))?;
967
968 let bootstrap_concurrency = config.crl_max_concurrent_fetches.max(1);
972 let global_sem = Arc::new(Semaphore::new(bootstrap_concurrency));
973 let host_semaphores = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
974 let allow_http = config.crl_allow_http;
975 let max_bytes = config.crl_max_response_bytes;
976 let max_host_semaphores = config.crl_max_host_semaphores;
977
978 let mut initial_cache = HashMap::new();
979 let mut tasks = JoinSet::new();
980 for url in &urls {
981 let client = client.clone();
982 let url = url.clone();
983 let global_sem = Arc::clone(&global_sem);
984 let host_semaphores = Arc::clone(&host_semaphores);
985 tasks.spawn(async move {
986 let result = gated_fetch(
987 &client,
988 &global_sem,
989 &host_semaphores,
990 &url,
991 allow_http,
992 max_bytes,
993 max_host_semaphores,
994 )
995 .await;
996 (url, result)
997 });
998 }
999
1000 let timeout: Sleep = tokio::time::sleep(BOOTSTRAP_TIMEOUT);
1001 tokio::pin!(timeout);
1002
1003 while !tasks.is_empty() {
1004 tokio::select! {
1005 () = &mut timeout => {
1006 tracing::warn!("CRL bootstrap timed out after {:?}", BOOTSTRAP_TIMEOUT);
1007 break;
1008 }
1009 maybe_joined = tasks.join_next() => {
1010 let Some(joined) = maybe_joined else {
1011 break;
1012 };
1013 match joined {
1014 Ok((url, Ok(cached))) => {
1015 initial_cache.insert(url, cached);
1016 }
1017 Ok((url, Err(error))) => {
1018 tracing::warn!(url = %url, error = %error, "CRL bootstrap fetch failed");
1019 }
1020 Err(error) => {
1021 tracing::warn!(error = %error, "CRL bootstrap task join failed");
1022 }
1023 }
1024 }
1025 }
1026 }
1027
1028 let set = CrlSet::new(roots, config, discover_tx, initial_cache)?;
1029 Ok((set, discover_rx))
1030}
1031
1032#[allow(
1034 clippy::cognitive_complexity,
1035 reason = "refresher loop intentionally handles shutdown, timer, and discovery in one select"
1036)]
1037pub async fn run_crl_refresher(
1038 set: Arc<CrlSet>,
1039 mut discover_rx: mpsc::UnboundedReceiver<String>,
1040 shutdown: CancellationToken,
1041) {
1042 let mut refresh_sleep = schedule_next_refresh(&set).await;
1043
1044 loop {
1045 tokio::select! {
1046 () = shutdown.cancelled() => {
1047 break;
1048 }
1049 () = &mut refresh_sleep => {
1050 if let Err(error) = set.refresh_due_urls().await {
1051 tracing::warn!(error = %error, "CRL periodic refresh failed");
1052 }
1053 refresh_sleep = schedule_next_refresh(&set).await;
1054 }
1055 maybe_url = discover_rx.recv() => {
1056 let Some(url) = maybe_url else {
1057 break;
1058 };
1059 if let Err(error) = set.fetch_and_store_url(url.clone()).await {
1060 tracing::warn!(url = %url, error = %error, "CRL discovery fetch failed");
1061 }
1062 refresh_sleep = schedule_next_refresh(&set).await;
1063 }
1064 }
1065 }
1066}
1067
1068pub fn rebuild_verifier<S: std::hash::BuildHasher>(
1074 roots: &Arc<RootCertStore>,
1075 config: &MtlsConfig,
1076 cache: &HashMap<String, CachedCrl, S>,
1077) -> Result<Arc<dyn ClientCertVerifier>, McpxError> {
1078 let mut builder = WebPkiClientVerifier::builder(Arc::clone(roots));
1079
1080 if !cache.is_empty() {
1081 let crls = cache
1082 .values()
1083 .map(|cached| cached.der.clone())
1084 .collect::<Vec<_>>();
1085 builder = builder.with_crls(crls);
1086 }
1087 if config.crl_end_entity_only {
1088 builder = builder.only_check_end_entity_revocation();
1089 }
1090 if !config.crl_deny_on_unavailable {
1091 builder = builder.allow_unknown_revocation_status();
1092 }
1093 if config.crl_enforce_expiration {
1094 builder = builder.enforce_revocation_expiration();
1095 }
1096 if !config.required {
1097 builder = builder.allow_unauthenticated();
1098 }
1099
1100 builder
1101 .build()
1102 .map_err(|error| McpxError::Tls(format!("mTLS verifier error: {error}")))
1103}
1104
1105pub fn parse_crl_metadata(der: &[u8]) -> Result<(SystemTime, Option<SystemTime>), McpxError> {
1111 let (_, crl) = CertificateRevocationList::from_der(der)
1112 .map_err(|error| McpxError::Tls(format!("invalid CRL DER: {error:?}")))?;
1113
1114 Ok((
1115 asn1_time_to_system_time(crl.last_update()),
1116 crl.next_update().map(asn1_time_to_system_time),
1117 ))
1118}
1119
1120async fn schedule_next_refresh(set: &CrlSet) -> Pin<Box<Sleep>> {
1121 let duration = next_refresh_delay(set).await;
1122 boxed_sleep(duration)
1123}
1124
1125fn boxed_sleep(duration: Duration) -> Pin<Box<Sleep>> {
1126 Box::pin(tokio::time::sleep_until(Instant::now() + duration))
1127}
1128
1129async fn next_refresh_delay(set: &CrlSet) -> Duration {
1130 if let Some(interval) = set.config.crl_refresh_interval {
1131 return clamp_refresh(interval);
1132 }
1133
1134 let now = SystemTime::now();
1135 let cache = set.cache.read().await;
1136 let mut next = MAX_AUTO_REFRESH;
1137
1138 for cached in cache.values() {
1139 if let Some(next_update) = cached.next_update {
1140 let duration = next_update.duration_since(now).unwrap_or(Duration::ZERO);
1141 next = next.min(clamp_refresh(duration));
1142 }
1143 }
1144
1145 next
1146}
1147
1148async fn gated_fetch(
1155 client: &reqwest::Client,
1156 global_sem: &Arc<Semaphore>,
1157 host_semaphores: &Arc<tokio::sync::Mutex<HashMap<String, Arc<Semaphore>>>>,
1158 url: &str,
1159 allow_http: bool,
1160 max_bytes: u64,
1161 max_host_semaphores: usize,
1162) -> Result<CachedCrl, McpxError> {
1163 let host_key = Url::parse(url)
1164 .ok()
1165 .and_then(|u| u.host_str().map(str::to_owned))
1166 .unwrap_or_else(|| url.to_owned());
1167
1168 let host_sem = {
1169 let mut map = host_semaphores.lock().await;
1170 if !map.contains_key(&host_key) {
1171 if map.len() >= max_host_semaphores {
1172 return Err(McpxError::Config(
1173 "crl_host_semaphore_cap_exceeded: too many distinct CRL hosts in flight"
1174 .to_owned(),
1175 ));
1176 }
1177 map.insert(host_key.clone(), Arc::new(Semaphore::new(1)));
1178 }
1179 match map.get(&host_key) {
1180 Some(semaphore) => Arc::clone(semaphore),
1181 None => {
1182 return Err(McpxError::Tls(
1183 "CRL host semaphore missing after insertion".to_owned(),
1184 ));
1185 }
1186 }
1187 };
1188
1189 let _global_permit = Arc::clone(global_sem)
1190 .acquire_owned()
1191 .await
1192 .map_err(|error| McpxError::Tls(format!("CRL global semaphore closed: {error}")))?;
1193 let _host_permit = host_sem
1194 .acquire_owned()
1195 .await
1196 .map_err(|error| McpxError::Tls(format!("CRL host semaphore closed: {error}")))?;
1197
1198 fetch_crl(client, url, allow_http, max_bytes).await
1199}
1200
1201async fn fetch_crl(
1202 client: &reqwest::Client,
1203 url: &str,
1204 allow_http: bool,
1205 max_bytes: u64,
1206) -> Result<CachedCrl, McpxError> {
1207 let parsed =
1208 Url::parse(url).map_err(|error| McpxError::Tls(format!("CRL URL parse {url}: {error}")))?;
1209
1210 if let Err(reason) = check_scheme(&parsed, allow_http) {
1211 tracing::warn!(url = %url, reason, "CRL fetch denied: scheme");
1212 return Err(McpxError::Tls(format!(
1213 "CRL scheme rejected ({reason}): {url}"
1214 )));
1215 }
1216
1217 let host = parsed
1218 .host_str()
1219 .ok_or_else(|| McpxError::Tls(format!("CRL URL has no host: {url}")))?;
1220 let port = parsed
1221 .port_or_known_default()
1222 .ok_or_else(|| McpxError::Tls(format!("CRL URL has no known port: {url}")))?;
1223
1224 let addrs = lookup_host((host, port))
1225 .await
1226 .map_err(|error| McpxError::Tls(format!("CRL DNS resolution {url}: {error}")))?;
1227
1228 let mut any_addr = false;
1229 for addr in addrs {
1230 any_addr = true;
1231 if let Some(reason) = ip_block_reason(addr.ip()) {
1232 tracing::warn!(
1233 url = %url,
1234 resolved_ip = %addr.ip(),
1235 reason,
1236 "CRL fetch denied: blocked IP"
1237 );
1238 return Err(McpxError::Tls(format!(
1239 "CRL host resolved to blocked IP ({reason}): {url}"
1240 )));
1241 }
1242 }
1243 if !any_addr {
1244 return Err(McpxError::Tls(format!(
1245 "CRL DNS resolution returned no addresses: {url}"
1246 )));
1247 }
1248
1249 let mut response = client
1250 .get(url)
1251 .send()
1252 .await
1253 .map_err(|error| McpxError::Tls(format!("CRL fetch {url}: {error}")))?
1254 .error_for_status()
1255 .map_err(|error| McpxError::Tls(format!("CRL fetch {url}: {error}")))?;
1256
1257 let initial_capacity = usize::try_from(max_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
1260 let mut body: Vec<u8> = Vec::with_capacity(initial_capacity);
1261 while let Some(chunk) = response
1262 .chunk()
1263 .await
1264 .map_err(|error| McpxError::Tls(format!("CRL read {url}: {error}")))?
1265 {
1266 let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
1267 let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
1268 if body_len.saturating_add(chunk_len) > max_bytes {
1269 return Err(McpxError::Tls(format!(
1270 "CRL body exceeded cap of {max_bytes} bytes: {url}"
1271 )));
1272 }
1273 body.extend_from_slice(&chunk);
1274 }
1275
1276 let der = CertificateRevocationListDer::from(body);
1277 let (this_update, next_update) = parse_crl_metadata(der.as_ref())?;
1278
1279 Ok(CachedCrl {
1280 der,
1281 this_update,
1282 next_update,
1283 fetched_at: SystemTime::now(),
1284 source_url: url.to_owned(),
1285 })
1286}
1287
1288fn should_refresh_cached(
1289 cached: &CachedCrl,
1290 now: SystemTime,
1291 fixed_interval: Option<Duration>,
1292) -> bool {
1293 if let Some(interval) = fixed_interval {
1294 return cached
1295 .fetched_at
1296 .checked_add(clamp_refresh(interval))
1297 .is_none_or(|deadline| now >= deadline);
1298 }
1299
1300 cached
1301 .next_update
1302 .is_none_or(|next_update| now >= next_update)
1303}
1304
1305fn clamp_refresh(duration: Duration) -> Duration {
1306 duration.clamp(MIN_AUTO_REFRESH, MAX_AUTO_REFRESH)
1307}
1308
1309fn asn1_time_to_system_time(time: x509_parser::time::ASN1Time) -> SystemTime {
1310 let timestamp = time.timestamp();
1311 if timestamp >= 0 {
1312 let seconds = u64::try_from(timestamp).unwrap_or(0);
1313 UNIX_EPOCH + Duration::from_secs(seconds)
1314 } else {
1315 UNIX_EPOCH - Duration::from_secs(timestamp.unsigned_abs())
1316 }
1317}