1use std::collections::HashMap;
19use std::fmt::Write;
20use std::sync::Arc;
21use std::time::{Duration, SystemTime};
22
23use parking_lot::RwLock;
24use tracing::warn;
25use x509_parser::prelude::*;
26
27use crate::error::{NetError, NetResult};
28use crate::mtls::RevocationStatus;
29
/// DER universal tag 16, constructed: SEQUENCE / SEQUENCE OF.
const TAG_SEQUENCE: u8 = 0x30;
/// DER universal tag 4: OCTET STRING.
const TAG_OCTET_STRING: u8 = 0x04;
/// DER universal tag 2: INTEGER.
const TAG_INTEGER: u8 = 0x02;
/// DER universal tag 6: OBJECT IDENTIFIER.
const TAG_OID: u8 = 0x06;
/// DER universal tag 10: ENUMERATED (used for the OCSP `responseStatus`).
const TAG_ENUMERATED: u8 = 0x0A;
/// Context-specific, constructed tag `[0]`.
const TAG_CONTEXT_0: u8 = 0xA0;
/// Context-specific, constructed tag `[1]`.
const TAG_CONTEXT_1: u8 = 0xA1;
/// Context-specific, primitive tag `[1]`; accepted as the `revoked` certStatus alongside the
/// constructed form.
const TAG_CONTEXT_PRIM_1: u8 = 0x81;

/// Content octets of OID 2.16.840.1.101.3.4.2.1 (id-sha256), the CertID hash algorithm.
const SHA256_OID_BYTES: &[u8] = &[0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01];

/// Content octets of OID 1.3.6.1.5.5.7.48.1.1 (id-pkix-ocsp-basic), the only `responseType`
/// this module accepts.
const OCSP_BASIC_OID_BYTES: &[u8] = &[0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01, 0x01];

/// Content octets of OID 1.3.6.1.5.5.7.48.1 (id-ad-ocsp), the AIA access method for OCSP.
const AIA_OCSP_OID_BYTES: &[u8] = &[0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x30, 0x01];

/// Content octets of OID 1.3.6.1.5.5.7.1.1 (id-pe-authorityInfoAccess), the AIA extension.
const AIA_EXT_OID_BYTES: &[u8] = &[0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x01, 0x01];

/// Default network timeout used by `OcspRevocationChecker::new`.
const DEFAULT_OCSP_TIMEOUT: Duration = Duration::from_secs(5);

/// Default lifetime of a cached revocation status used by `OcspRevocationChecker::new` (1 hour).
const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(3600);
66
/// Encodes `len` as a DER length field (X.690 §8.1.3).
///
/// Lengths below 0x80 use the one-octet short form. Larger lengths use the long form: an
/// initial octet `0x80 | n` followed by the `n` big-endian octets of `len`, with no leading
/// zero octets. Every `usize` value is representable (no silent truncation above 2^32).
fn der_encode_length(len: usize) -> Vec<u8> {
    if len < 0x80 {
        return vec![len as u8];
    }
    let be = len.to_be_bytes();
    let leading_zeros = be.iter().take_while(|&&b| b == 0).count();
    let significant = &be[leading_zeros..];

    let mut out = Vec::with_capacity(1 + significant.len());
    // `significant.len()` is at most size_of::<usize>(), so the cast cannot truncate.
    out.push(0x80 | significant.len() as u8);
    out.extend_from_slice(significant);
    out
}
89
90fn der_tlv(tag: u8, content: &[u8]) -> Vec<u8> {
92 let mut out = vec![tag];
93 out.extend(der_encode_length(content.len()));
94 out.extend(content);
95 out
96}
97
98fn der_oid(oid_bytes: &[u8]) -> Vec<u8> {
100 der_tlv(TAG_OID, oid_bytes)
101}
102
103fn der_algorithm_identifier(oid_bytes: &[u8]) -> Vec<u8> {
105 let mut content = der_oid(oid_bytes);
106 content.extend(&[0x05, 0x00]);
108 der_tlv(TAG_SEQUENCE, &content)
109}
110
111fn der_octet_string(data: &[u8]) -> Vec<u8> {
113 der_tlv(TAG_OCTET_STRING, data)
114}
115
116fn der_integer_from_bytes(data: &[u8]) -> Vec<u8> {
118 let stripped = data
120 .iter()
121 .position(|&b| b != 0)
122 .map_or(&data[data.len().saturating_sub(1)..], |pos| &data[pos..]);
123
124 if stripped.first().is_some_and(|&b| b & 0x80 != 0) {
126 let mut content = vec![0x00];
127 content.extend(stripped);
128 der_tlv(TAG_INTEGER, &content)
129 } else {
130 der_tlv(TAG_INTEGER, stripped)
131 }
132}
133
134fn der_read_length(data: &[u8]) -> NetResult<(usize, usize)> {
138 if data.is_empty() {
139 return Err(NetError::InvalidCertificate(
140 "OCSP: unexpected end of DER data reading length".to_string(),
141 ));
142 }
143 let first = data[0];
144 if first < 0x80 {
145 Ok((first as usize, 1))
146 } else {
147 let num_bytes = (first & 0x7F) as usize;
148 if num_bytes == 0 || num_bytes > 4 {
149 return Err(NetError::InvalidCertificate(format!(
150 "OCSP: unsupported DER length encoding ({num_bytes} bytes)"
151 )));
152 }
153 if data.len() < 1 + num_bytes {
154 return Err(NetError::InvalidCertificate(
155 "OCSP: truncated DER length".to_string(),
156 ));
157 }
158 let mut val: usize = 0;
159 for i in 0..num_bytes {
160 val = (val << 8) | (data[1 + i] as usize);
161 }
162 Ok((val, 1 + num_bytes))
163 }
164}
165
166fn der_read_tlv(data: &[u8]) -> NetResult<(u8, &[u8], usize)> {
168 if data.is_empty() {
169 return Err(NetError::InvalidCertificate(
170 "OCSP: unexpected end of DER data reading TLV".to_string(),
171 ));
172 }
173 let tag = data[0];
174 let (len, len_bytes) = der_read_length(&data[1..])?;
175 let header_len = 1 + len_bytes;
176 let total = header_len + len;
177 if data.len() < total {
178 return Err(NetError::InvalidCertificate(format!(
179 "OCSP: DER content truncated (need {total}, have {})",
180 data.len()
181 )));
182 }
183 Ok((tag, &data[header_len..total], total))
184}
185
186fn der_children(data: &[u8]) -> NetResult<Vec<(u8, Vec<u8>)>> {
188 let mut children = Vec::new();
189 let mut pos = 0;
190 while pos < data.len() {
191 let (tag, content, consumed) = der_read_tlv(&data[pos..])?;
192 children.push((tag, content.to_vec()));
193 pos += consumed;
194 }
195 Ok(children)
196}
197
/// Returns the cache key for a certificate: the lowercase hex encoding of its *entire* DER.
///
/// The whole encoding is used rather than a prefix. The leading bytes of a certificate (outer
/// SEQUENCE and TBS headers, version, start of the serial number) are shared by many
/// certificates, so a truncated key would let one certificate's cached status answer for
/// another.
fn cert_fingerprint(cert_der: &[u8]) -> String {
    let mut hex = String::with_capacity(cert_der.len() * 2);
    for byte in cert_der {
        // Formatting into a `String` never fails.
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
207
208pub fn extract_ocsp_url(cert_der: &[u8]) -> NetResult<Option<String>> {
212 let (_, parsed) = X509Certificate::from_der(cert_der).map_err(|e| {
213 NetError::InvalidCertificate(format!("OCSP: failed to parse certificate: {e}"))
214 })?;
215
216 let aia_oid = asn1_rs::Oid::new(std::borrow::Cow::Borrowed(AIA_EXT_OID_BYTES));
218 let aia = parsed.extensions().iter().find(|ext| ext.oid == aia_oid);
219
220 let aia_ext = match aia {
221 Some(ext) => ext,
222 None => return Ok(None),
223 };
224
225 let children = der_children(aia_ext.value)?;
228
229 for (tag, child_data) in &children {
230 if *tag != TAG_SEQUENCE {
232 continue;
233 }
234 let inner = der_children(child_data)?;
235 if inner.len() < 2 {
236 continue;
237 }
238
239 let (oid_tag, oid_data) = &inner[0];
241 if *oid_tag != TAG_OID {
242 continue;
243 }
244
245 if oid_data.as_slice() != AIA_OCSP_OID_BYTES {
247 continue;
248 }
249
250 let (name_tag, name_data) = &inner[1];
252 if *name_tag == 0x86 {
254 let url = String::from_utf8(name_data.clone()).map_err(|e| {
255 NetError::InvalidCertificate(format!("OCSP: invalid URL encoding: {e}"))
256 })?;
257 return Ok(Some(url));
258 }
259 }
260
261 Ok(None)
262}
263
264pub fn build_ocsp_request(cert_der: &[u8]) -> NetResult<Vec<u8>> {
288 let (_, parsed) = X509Certificate::from_der(cert_der).map_err(|e| {
289 NetError::InvalidCertificate(format!("OCSP: failed to parse certificate: {e}"))
290 })?;
291
292 let issuer_name_der = parsed.issuer().as_raw();
294 let issuer_name_hash = blake3::hash(issuer_name_der);
295
296 let spki_der = parsed.public_key().raw;
301 let issuer_key_hash = blake3::hash(spki_der);
302
303 let serial_bytes = parsed.serial.to_bytes_be();
305
306 let algo_id = der_algorithm_identifier(SHA256_OID_BYTES);
308 let name_hash = der_octet_string(issuer_name_hash.as_bytes());
309 let key_hash = der_octet_string(issuer_key_hash.as_bytes());
310 let serial_int = der_integer_from_bytes(&serial_bytes);
311
312 let mut cert_id_content = Vec::new();
313 cert_id_content.extend(&algo_id);
314 cert_id_content.extend(&name_hash);
315 cert_id_content.extend(&key_hash);
316 cert_id_content.extend(&serial_int);
317 let cert_id = der_tlv(TAG_SEQUENCE, &cert_id_content);
318
319 let request = der_tlv(TAG_SEQUENCE, &cert_id);
321
322 let request_list = der_tlv(TAG_SEQUENCE, &request);
324
325 let tbs_request = der_tlv(TAG_SEQUENCE, &request_list);
327
328 let ocsp_request = der_tlv(TAG_SEQUENCE, &tbs_request);
330
331 Ok(ocsp_request)
332}
333
334fn parse_url(url: &str) -> NetResult<(String, u16, String)> {
338 let without_scheme = if let Some(rest) = url.strip_prefix("http://") {
340 rest
341 } else if let Some(rest) = url.strip_prefix("https://") {
342 rest
344 } else {
345 url
346 };
347
348 let (host_port, path) = match without_scheme.find('/') {
350 Some(idx) => (&without_scheme[..idx], &without_scheme[idx..]),
351 None => (without_scheme, "/"),
352 };
353
354 let (host, port) = match host_port.rfind(':') {
356 Some(idx) => {
357 let port_str = &host_port[idx + 1..];
358 let port: u16 = port_str.parse().map_err(|e| {
359 NetError::InvalidCertificate(format!("OCSP: invalid port in URL: {e}"))
360 })?;
361 (host_port[..idx].to_string(), port)
362 }
363 None => (host_port.to_string(), 80),
364 };
365
366 Ok((host, port, path.to_string()))
367}
368
369pub async fn send_ocsp_request(
373 url: &str,
374 request_der: &[u8],
375 timeout: Duration,
376) -> NetResult<Vec<u8>> {
377 use tokio::io::{AsyncReadExt, AsyncWriteExt};
378 use tokio::net::TcpStream;
379
380 let (host, port, path) = parse_url(url)?;
381
382 let http_request = format!(
384 "POST {path} HTTP/1.1\r\n\
385 Host: {host}\r\n\
386 Content-Type: application/ocsp-request\r\n\
387 Content-Length: {}\r\n\
388 Connection: close\r\n\
389 \r\n",
390 request_der.len()
391 );
392
393 let addr = format!("{host}:{port}");
395 let stream = tokio::time::timeout(timeout, TcpStream::connect(&addr))
396 .await
397 .map_err(|_| NetError::Timeout(format!("OCSP: connection to {addr} timed out")))?
398 .map_err(|e| {
399 NetError::ConnectionRefused(format!("OCSP: failed to connect to {addr}: {e}"))
400 })?;
401
402 let mut stream = stream;
403
404 tokio::time::timeout(timeout, async {
406 stream
407 .write_all(http_request.as_bytes())
408 .await
409 .map_err(|e| NetError::ConnectionReset(format!("OCSP: failed to send request: {e}")))?;
410 stream.write_all(request_der).await.map_err(|e| {
411 NetError::ConnectionReset(format!("OCSP: failed to send request body: {e}"))
412 })?;
413 stream
414 .flush()
415 .await
416 .map_err(|e| NetError::ConnectionReset(format!("OCSP: failed to flush: {e}")))?;
417 Ok::<(), NetError>(())
418 })
419 .await
420 .map_err(|_| NetError::Timeout("OCSP: send timed out".to_string()))??;
421
422 let response_bytes = tokio::time::timeout(timeout, async {
424 let mut buf = Vec::with_capacity(8192);
425 stream.read_to_end(&mut buf).await.map_err(|e| {
426 NetError::ConnectionReset(format!("OCSP: failed to read response: {e}"))
427 })?;
428 Ok::<Vec<u8>, NetError>(buf)
429 })
430 .await
431 .map_err(|_| NetError::Timeout("OCSP: read timed out".to_string()))??;
432
433 let header_end = response_bytes
435 .windows(4)
436 .position(|w| w == b"\r\n\r\n")
437 .ok_or_else(|| {
438 NetError::InvalidCertificate(
439 "OCSP: malformed HTTP response (no header end)".to_string(),
440 )
441 })?;
442
443 let header_str = String::from_utf8_lossy(&response_bytes[..header_end]);
444
445 let status_line = header_str
447 .lines()
448 .next()
449 .ok_or_else(|| NetError::InvalidCertificate("OCSP: empty HTTP response".to_string()))?;
450
451 let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
453 if parts.len() < 2 {
454 return Err(NetError::InvalidCertificate(format!(
455 "OCSP: malformed HTTP status line: {status_line}"
456 )));
457 }
458 let status_code: u16 = parts[1].parse().map_err(|e| {
459 NetError::InvalidCertificate(format!("OCSP: invalid HTTP status code: {e}"))
460 })?;
461
462 if status_code != 200 {
463 return Err(NetError::InvalidCertificate(format!(
464 "OCSP: HTTP error {status_code}"
465 )));
466 }
467
468 let body_start = header_end + 4;
469 if body_start >= response_bytes.len() {
470 return Err(NetError::InvalidCertificate(
471 "OCSP: empty HTTP response body".to_string(),
472 ));
473 }
474
475 Ok(response_bytes[body_start..].to_vec())
476}
477
478pub fn parse_ocsp_response(response_der: &[u8]) -> NetResult<RevocationStatus> {
520 let (tag, ocsp_resp_content, _) = der_read_tlv(response_der)?;
522 if tag != TAG_SEQUENCE {
523 return Err(NetError::InvalidCertificate(format!(
524 "OCSP: expected SEQUENCE, got 0x{tag:02x}"
525 )));
526 }
527
528 let children = der_children(ocsp_resp_content)?;
529 if children.is_empty() {
530 return Err(NetError::InvalidCertificate(
531 "OCSP: empty OCSPResponse".to_string(),
532 ));
533 }
534
535 let (status_tag, status_data) = &children[0];
537 if *status_tag != TAG_ENUMERATED {
538 return Err(NetError::InvalidCertificate(format!(
539 "OCSP: expected ENUMERATED for responseStatus, got 0x{status_tag:02x}"
540 )));
541 }
542 let response_status = status_data
543 .first()
544 .copied()
545 .ok_or_else(|| NetError::InvalidCertificate("OCSP: empty responseStatus".to_string()))?;
546
547 if response_status != 0 {
550 return Err(NetError::InvalidCertificate(format!(
551 "OCSP: non-successful responseStatus: {response_status}"
552 )));
553 }
554
555 if children.len() < 2 {
557 return Err(NetError::InvalidCertificate(
558 "OCSP: missing responseBytes".to_string(),
559 ));
560 }
561
562 let (rb_tag, rb_data) = &children[1];
563 if *rb_tag != TAG_CONTEXT_0 {
564 return Err(NetError::InvalidCertificate(format!(
565 "OCSP: expected [0] for responseBytes, got 0x{rb_tag:02x}"
566 )));
567 }
568
569 let (inner_tag, inner_content, _) = der_read_tlv(rb_data)?;
571 if inner_tag != TAG_SEQUENCE {
572 return Err(NetError::InvalidCertificate(
573 "OCSP: responseBytes inner not SEQUENCE".to_string(),
574 ));
575 }
576
577 let rb_children = der_children(inner_content)?;
578 if rb_children.len() < 2 {
579 return Err(NetError::InvalidCertificate(
580 "OCSP: responseBytes SEQUENCE too short".to_string(),
581 ));
582 }
583
584 let (oid_tag, oid_data) = &rb_children[0];
586 if *oid_tag != TAG_OID {
587 return Err(NetError::InvalidCertificate(
588 "OCSP: responseType not OID".to_string(),
589 ));
590 }
591 if oid_data.as_slice() != OCSP_BASIC_OID_BYTES {
592 return Err(NetError::InvalidCertificate(
593 "OCSP: responseType is not id-pkix-ocsp-basic".to_string(),
594 ));
595 }
596
597 let (oct_tag, oct_data) = &rb_children[1];
599 if *oct_tag != TAG_OCTET_STRING {
600 return Err(NetError::InvalidCertificate(
601 "OCSP: response not OCTET STRING".to_string(),
602 ));
603 }
604
605 parse_basic_ocsp_response(oct_data)
607}
608
609fn parse_basic_ocsp_response(data: &[u8]) -> NetResult<RevocationStatus> {
611 let (tag, content, _) = der_read_tlv(data)?;
613 if tag != TAG_SEQUENCE {
614 return Err(NetError::InvalidCertificate(
615 "OCSP: BasicOCSPResponse not SEQUENCE".to_string(),
616 ));
617 }
618
619 let children = der_children(content)?;
620 if children.is_empty() {
621 return Err(NetError::InvalidCertificate(
622 "OCSP: empty BasicOCSPResponse".to_string(),
623 ));
624 }
625
626 let (tbs_tag, tbs_data) = &children[0];
628 if *tbs_tag != TAG_SEQUENCE {
629 return Err(NetError::InvalidCertificate(
630 "OCSP: tbsResponseData not SEQUENCE".to_string(),
631 ));
632 }
633
634 parse_tbs_response_data(tbs_data)
635}
636
637fn parse_tbs_response_data(data: &[u8]) -> NetResult<RevocationStatus> {
639 let children = der_children(data)?;
640
641 let mut response_seq: Option<&Vec<u8>> = None;
650 let mut found_time = false;
651
652 for (tag, child_data) in &children {
653 if *tag == TAG_CONTEXT_0 {
655 continue;
656 }
657 if *tag == TAG_CONTEXT_1 || *tag == 0xA2 {
659 continue;
660 }
661 if *tag == 0x18 {
663 found_time = true;
664 continue;
665 }
666 if *tag == TAG_SEQUENCE && found_time {
668 response_seq = Some(child_data);
669 break;
670 }
671 if *tag == TAG_SEQUENCE && !found_time {
674 if let Ok(inner) = der_children(child_data) {
676 if !inner.is_empty() && inner[0].0 == TAG_SEQUENCE {
677 response_seq = Some(child_data);
678 break;
679 }
680 }
681 }
682 }
683
684 let responses_data = response_seq.ok_or_else(|| {
685 NetError::InvalidCertificate(
686 "OCSP: could not find responses SEQUENCE in ResponseData".to_string(),
687 )
688 })?;
689
690 let single_responses = der_children(responses_data)?;
692 if single_responses.is_empty() {
693 return Err(NetError::InvalidCertificate(
694 "OCSP: no SingleResponse found".to_string(),
695 ));
696 }
697
698 let (sr_tag, sr_data) = &single_responses[0];
699 if *sr_tag != TAG_SEQUENCE {
700 return Err(NetError::InvalidCertificate(
701 "OCSP: SingleResponse not SEQUENCE".to_string(),
702 ));
703 }
704
705 parse_single_response(sr_data)
706}
707
708fn parse_single_response(data: &[u8]) -> NetResult<RevocationStatus> {
710 let children = der_children(data)?;
711
712 if children.len() < 2 {
719 return Err(NetError::InvalidCertificate(
720 "OCSP: SingleResponse too short".to_string(),
721 ));
722 }
723
724 let (status_tag, _status_data) = &children[1];
726
727 match *status_tag {
733 0x80 => Ok(RevocationStatus::Good), 0xA1 | TAG_CONTEXT_PRIM_1 => Ok(RevocationStatus::Revoked), 0x82 => Ok(RevocationStatus::Unknown), other => {
737 warn!("OCSP: unexpected certStatus tag 0x{other:02x}");
738 Ok(RevocationStatus::Unknown)
739 }
740 }
741}
742
743#[derive(Debug)]
754pub struct OcspRevocationChecker {
755 responder_url: Option<String>,
757 response_cache: Arc<RwLock<HashMap<String, (RevocationStatus, SystemTime)>>>,
759 cache_ttl: Duration,
761 timeout: Duration,
763}
764
765impl Default for OcspRevocationChecker {
766 fn default() -> Self {
767 Self::new()
768 }
769}
770
771impl OcspRevocationChecker {
772 pub fn new() -> Self {
774 Self {
775 responder_url: None,
776 response_cache: Arc::new(RwLock::new(HashMap::new())),
777 cache_ttl: DEFAULT_CACHE_TTL,
778 timeout: DEFAULT_OCSP_TIMEOUT,
779 }
780 }
781
782 pub fn with_responder_url(mut self, url: impl Into<String>) -> Self {
784 self.responder_url = Some(url.into());
785 self
786 }
787
788 pub fn with_cache_ttl(mut self, ttl: Duration) -> Self {
790 self.cache_ttl = ttl;
791 self
792 }
793
794 pub fn with_timeout(mut self, timeout: Duration) -> Self {
796 self.timeout = timeout;
797 self
798 }
799
800 pub fn get_cached(&self, fingerprint: &str) -> Option<RevocationStatus> {
802 let cache = self.response_cache.read();
803 if let Some((status, timestamp)) = cache.get(fingerprint) {
804 if timestamp.elapsed().unwrap_or(Duration::MAX) < self.cache_ttl {
805 return Some(*status);
806 }
807 }
808 None
809 }
810
811 pub fn cache_status(&self, fingerprint: String, status: RevocationStatus) {
813 let mut cache = self.response_cache.write();
814 cache.insert(fingerprint, (status, SystemTime::now()));
815 }
816
817 pub fn clear_cache(&self) {
819 self.response_cache.write().clear();
820 }
821
822 pub fn cache_size(&self) -> usize {
824 self.response_cache.read().len()
825 }
826
827 fn resolve_responder_url(&self, cert_der: &[u8]) -> NetResult<Option<String>> {
829 if let Some(ref url) = self.responder_url {
831 return Ok(Some(url.clone()));
832 }
833 extract_ocsp_url(cert_der)
835 }
836
837 pub fn check_revocation(
839 &self,
840 cert: &rustls::pki_types::CertificateDer<'_>,
841 ) -> NetResult<RevocationStatus> {
842 let fingerprint = cert_fingerprint(cert.as_ref());
843
844 if let Some(status) = self.get_cached(&fingerprint) {
846 return Ok(status);
847 }
848
849 Ok(RevocationStatus::Unknown)
851 }
852
853 pub fn check_revocation_async<'a>(
855 &'a self,
856 cert: &'a rustls::pki_types::CertificateDer<'_>,
857 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + 'a>>
858 {
859 let cert_bytes = cert.as_ref().to_vec();
860 let fingerprint = cert_fingerprint(&cert_bytes);
861
862 if let Some(status) = self.get_cached(&fingerprint) {
864 return Box::pin(async move { Ok(status) });
865 }
866
867 let timeout = self.timeout;
868
869 Box::pin(async move {
870 let url = match self.resolve_responder_url(&cert_bytes) {
872 Ok(Some(url)) => url,
873 Ok(None) => {
874 warn!("OCSP: no responder URL available for certificate");
875 return Ok(RevocationStatus::Unknown);
876 }
877 Err(e) => {
878 warn!("OCSP: failed to resolve responder URL: {e}");
879 return Ok(RevocationStatus::Unknown);
880 }
881 };
882
883 let request_der = match build_ocsp_request(&cert_bytes) {
885 Ok(req) => req,
886 Err(e) => {
887 warn!("OCSP: failed to build request: {e}");
888 return Ok(RevocationStatus::Unknown);
889 }
890 };
891
892 let response_der = match send_ocsp_request(&url, &request_der, timeout).await {
894 Ok(resp) => resp,
895 Err(e) => {
896 warn!("OCSP: network error: {e}");
897 return Ok(RevocationStatus::Unknown);
898 }
899 };
900
901 let status = match parse_ocsp_response(&response_der) {
903 Ok(s) => s,
904 Err(e) => {
905 warn!("OCSP: failed to parse response: {e}");
906 RevocationStatus::Unknown
907 }
908 };
909
910 self.cache_status(fingerprint, status);
912
913 Ok(status)
914 })
915 }
916}
917
918impl crate::mtls::RevocationChecker for OcspRevocationChecker {
921 fn check_revocation(
922 &self,
923 cert: &rustls::pki_types::CertificateDer<'_>,
924 ) -> NetResult<RevocationStatus> {
925 OcspRevocationChecker::check_revocation(self, cert)
926 }
927
928 fn check_revocation_async(
929 &self,
930 cert: &rustls::pki_types::CertificateDer<'_>,
931 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = NetResult<RevocationStatus>> + Send + '_>>
932 {
933 let cert_bytes = cert.as_ref().to_vec();
934 let fingerprint = cert_fingerprint(&cert_bytes);
935
936 if let Some(status) = self.get_cached(&fingerprint) {
938 return Box::pin(async move { Ok(status) });
939 }
940
941 let timeout = self.timeout;
942
943 Box::pin(async move {
944 let url = match self.resolve_responder_url(&cert_bytes) {
946 Ok(Some(url)) => url,
947 Ok(None) => {
948 warn!("OCSP: no responder URL available for certificate");
949 return Ok(RevocationStatus::Unknown);
950 }
951 Err(e) => {
952 warn!("OCSP: failed to resolve responder URL: {e}");
953 return Ok(RevocationStatus::Unknown);
954 }
955 };
956
957 let request_der = match build_ocsp_request(&cert_bytes) {
959 Ok(req) => req,
960 Err(e) => {
961 warn!("OCSP: failed to build request: {e}");
962 return Ok(RevocationStatus::Unknown);
963 }
964 };
965
966 let response_der = match send_ocsp_request(&url, &request_der, timeout).await {
968 Ok(resp) => resp,
969 Err(e) => {
970 warn!("OCSP: network error: {e}");
971 return Ok(RevocationStatus::Unknown);
972 }
973 };
974
975 let status = match parse_ocsp_response(&response_der) {
977 Ok(s) => s,
978 Err(e) => {
979 warn!("OCSP: failed to parse response: {e}");
980 RevocationStatus::Unknown
981 }
982 };
983
984 self.cache_status(fingerprint, status);
986
987 Ok(status)
988 })
989 }
990}
991
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tls::SelfSignedGenerator;

    /// Generates a fresh self-signed certificate, returned as a rustls type and as raw DER.
    fn gen_test_cert() -> (rustls::pki_types::CertificateDer<'static>, Vec<u8>) {
        let generator = SelfSignedGenerator::new("test-ocsp");
        let (cert, _key) = generator.generate().expect("should generate cert");
        let der = cert.as_ref().to_vec();
        (cert, der)
    }

    #[test]
    fn test_build_ocsp_request_structure() {
        let (_cert, der) = gen_test_cert();
        let req = build_ocsp_request(&der).expect("should build OCSP request");

        assert_eq!(
            req[0], TAG_SEQUENCE,
            "OCSP request must start with SEQUENCE tag"
        );

        // OCSPRequest
        let (tag, content, total) = der_read_tlv(&req).expect("should parse outer SEQUENCE");
        assert_eq!(tag, TAG_SEQUENCE);
        assert_eq!(total, req.len(), "entire request should be consumed");

        // TBSRequest
        let (tbs_tag, tbs_content, _) = der_read_tlv(content).expect("should parse TBSRequest");
        assert_eq!(tbs_tag, TAG_SEQUENCE);

        // requestList
        let (rl_tag, rl_content, _) = der_read_tlv(tbs_content).expect("should parse requestList");
        assert_eq!(rl_tag, TAG_SEQUENCE);

        // Request
        let (r_tag, r_content, _) = der_read_tlv(rl_content).expect("should parse Request");
        assert_eq!(r_tag, TAG_SEQUENCE);

        // CertID
        let (cid_tag, cid_content, _) = der_read_tlv(r_content).expect("should parse CertID");
        assert_eq!(cid_tag, TAG_SEQUENCE);

        let cert_id_children = der_children(cid_content).expect("should parse CertID fields");
        assert_eq!(
            cert_id_children.len(),
            4,
            "CertID must have 4 fields (algo, nameHash, keyHash, serial)"
        );

        assert_eq!(cert_id_children[0].0, TAG_SEQUENCE, "algo must be SEQUENCE");
        assert_eq!(cert_id_children[1].0, TAG_OCTET_STRING);
        assert_eq!(cert_id_children[2].0, TAG_OCTET_STRING);
        assert_eq!(cert_id_children[3].0, TAG_INTEGER);
    }

    #[test]
    fn test_extract_ocsp_url_without_aia() {
        // Presumably the generated self-signed certificate has no AIA extension, so no
        // responder URL is expected.
        let (_cert, der) = gen_test_cert();
        let url = extract_ocsp_url(&der).expect("should not error");
        assert_eq!(url, None);
    }

    #[test]
    fn test_parse_ocsp_response_good() {
        // good [0] IMPLICIT NULL (content is ignored by the parser).
        let response = build_test_ocsp_response(0x80, &[0x00]);
        let status = parse_ocsp_response(&response).expect("should parse good response");
        assert_eq!(status, RevocationStatus::Good);
    }

    #[test]
    fn test_parse_ocsp_response_revoked() {
        // revoked [1] IMPLICIT RevokedInfo carrying a revocationTime.
        let revoked_info = der_tlv(0x18, b"20250101000000Z");
        let response = build_test_ocsp_response(0xA1, &revoked_info);
        let status = parse_ocsp_response(&response).expect("should parse revoked response");
        assert_eq!(status, RevocationStatus::Revoked);
    }

    #[test]
    fn test_parse_ocsp_response_unknown() {
        // unknown [2] IMPLICIT UnknownInfo.
        let response = build_test_ocsp_response(0x82, &[0x00]);
        let status = parse_ocsp_response(&response).expect("should parse unknown response");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_parse_ocsp_response_malformed() {
        let garbage = vec![0xFF, 0x01, 0x02, 0x03, 0xDE, 0xAD, 0xBE, 0xEF];
        let result = parse_ocsp_response(&garbage);
        assert!(result.is_err(), "garbage bytes should return error");
    }

    #[test]
    fn test_parse_ocsp_response_unsuccessful_status() {
        // OCSPResponse { responseStatus tryLater(3) } without responseBytes.
        let response = der_tlv(TAG_SEQUENCE, &der_tlv(TAG_ENUMERATED, &[0x03]));
        assert!(
            parse_ocsp_response(&response).is_err(),
            "non-successful responseStatus must be rejected"
        );
    }

    #[test]
    fn test_der_read_tlv_rejects_truncated_input() {
        assert!(der_read_tlv(&[]).is_err());
        assert!(der_read_tlv(&[TAG_OCTET_STRING, 0x05, 0x01]).is_err());
    }

    #[test]
    fn test_der_read_length_forms() {
        assert!(der_read_length(&[0x80]).is_err(), "indefinite form is not DER");
        assert_eq!(
            der_read_length(&[0x82, 0x01, 0x00]).expect("long form should parse"),
            (256, 3)
        );
    }

    #[test]
    fn test_der_children_splits_siblings() {
        let children = der_children(&[TAG_INTEGER, 0x01, 0x05, TAG_OCTET_STRING, 0x00])
            .expect("valid DER should split");
        assert_eq!(
            children,
            vec![(TAG_INTEGER, vec![0x05]), (TAG_OCTET_STRING, vec![])]
        );
    }

    #[test]
    fn test_ocsp_cache_hit() {
        let checker = OcspRevocationChecker::new().with_cache_ttl(Duration::from_secs(3600));
        let (cert, _der) = gen_test_cert();

        let fingerprint = cert_fingerprint(cert.as_ref());
        checker.cache_status(fingerprint, RevocationStatus::Good);

        let status = checker
            .check_revocation(&cert)
            .expect("should check revocation");
        assert_eq!(status, RevocationStatus::Good);
    }

    #[test]
    fn test_ocsp_cache_miss_and_populate() {
        let checker = OcspRevocationChecker::new().with_cache_ttl(Duration::from_secs(3600));
        let (cert, _der) = gen_test_cert();

        let fingerprint = cert_fingerprint(cert.as_ref());

        assert!(checker.get_cached(&fingerprint).is_none());

        // The sync path is cache-only, so a miss reports Unknown.
        let status = checker
            .check_revocation(&cert)
            .expect("should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);

        checker.cache_status(fingerprint.clone(), RevocationStatus::Good);
        assert_eq!(
            checker.get_cached(&fingerprint),
            Some(RevocationStatus::Good)
        );
        assert_eq!(checker.cache_size(), 1);
    }

    #[test]
    fn test_ocsp_cache_expiry() {
        let checker = OcspRevocationChecker::new().with_cache_ttl(Duration::from_millis(1));
        let (cert, _der) = gen_test_cert();

        let fingerprint = cert_fingerprint(cert.as_ref());
        checker.cache_status(fingerprint.clone(), RevocationStatus::Good);

        std::thread::sleep(Duration::from_millis(10));

        assert!(
            checker.get_cached(&fingerprint).is_none(),
            "expired cache entry should not be returned"
        );

        let status = checker
            .check_revocation(&cert)
            .expect("should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_ocsp_sync_cache_only() {
        // Even with a responder configured, the sync path never goes to the network.
        let checker = OcspRevocationChecker::new()
            .with_responder_url("http://ocsp.example.com")
            .with_cache_ttl(Duration::from_secs(3600));

        let (cert, _der) = gen_test_cert();

        let status = checker
            .check_revocation(&cert)
            .expect("should check revocation");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[tokio::test]
    async fn test_ocsp_fallback_on_error() {
        // Port 1 on loopback is expected to refuse the connection.
        let checker = OcspRevocationChecker::new()
            .with_responder_url("http://127.0.0.1:1")
            .with_timeout(Duration::from_millis(100));

        let (cert, _der) = gen_test_cert();

        let status = checker
            .check_revocation_async(&cert)
            .await
            .expect("should not error even on network failure");
        assert_eq!(status, RevocationStatus::Unknown);
    }

    #[test]
    fn test_ocsp_with_custom_responder() {
        let checker =
            OcspRevocationChecker::new().with_responder_url("http://custom-ocsp.example.com/ocsp");

        let (_cert, der) = gen_test_cert();

        let url = checker
            .resolve_responder_url(&der)
            .expect("should resolve URL");
        assert_eq!(url, Some("http://custom-ocsp.example.com/ocsp".to_string()));
    }

    #[test]
    fn test_parse_url_variants() {
        let (host, port, path) =
            parse_url("http://ocsp.example.com:8080/ocsp").expect("should parse");
        assert_eq!(host, "ocsp.example.com");
        assert_eq!(port, 8080);
        assert_eq!(path, "/ocsp");

        let (host, port, path) = parse_url("http://ocsp.example.com/check").expect("should parse");
        assert_eq!(host, "ocsp.example.com");
        assert_eq!(port, 80);
        assert_eq!(path, "/check");

        let (host, port, path) = parse_url("http://ocsp.example.com").expect("should parse");
        assert_eq!(host, "ocsp.example.com");
        assert_eq!(port, 80);
        assert_eq!(path, "/");
    }

    #[test]
    fn test_der_integer_from_bytes() {
        // Small positive value.
        let encoded = der_integer_from_bytes(&[0x05]);
        assert_eq!(encoded, vec![TAG_INTEGER, 0x01, 0x05]);

        // High bit set: needs a 0x00 pad to stay positive.
        let encoded = der_integer_from_bytes(&[0x80]);
        assert_eq!(encoded, vec![TAG_INTEGER, 0x02, 0x00, 0x80]);

        // Multi-octet value kept as-is.
        let encoded = der_integer_from_bytes(&[0x01, 0x00]);
        assert_eq!(encoded, vec![TAG_INTEGER, 0x02, 0x01, 0x00]);

        // Redundant leading zeros stripped.
        let encoded = der_integer_from_bytes(&[0x00, 0x00, 0x42]);
        assert_eq!(encoded, vec![TAG_INTEGER, 0x01, 0x42]);
    }

    #[test]
    fn test_der_encode_length() {
        assert_eq!(der_encode_length(0), vec![0x00]);
        assert_eq!(der_encode_length(127), vec![0x7F]);
        assert_eq!(der_encode_length(128), vec![0x81, 0x80]);
        assert_eq!(der_encode_length(256), vec![0x82, 0x01, 0x00]);
        assert_eq!(
            der_encode_length(0x100_0000),
            vec![0x84, 0x01, 0x00, 0x00, 0x00]
        );
    }

    /// Builds a minimal, unsigned OCSPResponse whose single SingleResponse carries the given
    /// certStatus tag and content. Hashes, serial, responder ID and signature are dummies.
    fn build_test_ocsp_response(cert_status_tag: u8, cert_status_data: &[u8]) -> Vec<u8> {
        // CertID
        let algo = der_algorithm_identifier(SHA256_OID_BYTES);
        let name_hash = der_octet_string(&[0u8; 32]);
        let key_hash = der_octet_string(&[0u8; 32]);
        let serial = der_integer_from_bytes(&[0x01]);
        let mut cert_id_content = Vec::new();
        cert_id_content.extend(&algo);
        cert_id_content.extend(&name_hash);
        cert_id_content.extend(&key_hash);
        cert_id_content.extend(&serial);
        let cert_id = der_tlv(TAG_SEQUENCE, &cert_id_content);

        let cert_status = der_tlv(cert_status_tag, cert_status_data);
        let this_update = der_tlv(0x18, b"20250101000000Z");

        // SingleResponse { certID, certStatus, thisUpdate }
        let mut sr_content = Vec::new();
        sr_content.extend(&cert_id);
        sr_content.extend(&cert_status);
        sr_content.extend(&this_update);
        let single_response = der_tlv(TAG_SEQUENCE, &sr_content);

        // responses ::= SEQUENCE OF SingleResponse
        let responses = der_tlv(TAG_SEQUENCE, &single_response);

        // responderID as a [1]-tagged element with dummy content.
        let responder_id = der_tlv(0xA1, &der_octet_string(&[0u8; 20]));
        let produced_at = der_tlv(0x18, b"20250101000000Z");

        // ResponseData { responderID, producedAt, responses }
        let mut tbs_content = Vec::new();
        tbs_content.extend(&responder_id);
        tbs_content.extend(&produced_at);
        tbs_content.extend(&responses);
        let tbs_response_data = der_tlv(TAG_SEQUENCE, &tbs_content);

        // BasicOCSPResponse { tbsResponseData, signatureAlgorithm, signature }
        let sig_algo = der_algorithm_identifier(SHA256_OID_BYTES);
        let signature = der_tlv(0x03, &[0x00, 0x00]);
        let mut basic_content = Vec::new();
        basic_content.extend(&tbs_response_data);
        basic_content.extend(&sig_algo);
        basic_content.extend(&signature);
        let basic_ocsp_response = der_tlv(TAG_SEQUENCE, &basic_content);

        // responseBytes [0] EXPLICIT SEQUENCE { responseType, response }
        let response_type = der_oid(OCSP_BASIC_OID_BYTES);
        let response_octet = der_octet_string(&basic_ocsp_response);
        let mut rb_content = Vec::new();
        rb_content.extend(&response_type);
        rb_content.extend(&response_octet);
        let response_bytes_seq = der_tlv(TAG_SEQUENCE, &rb_content);
        let response_bytes = der_tlv(TAG_CONTEXT_0, &response_bytes_seq);

        // responseStatus successful(0)
        let response_status = der_tlv(TAG_ENUMERATED, &[0x00]);

        let mut ocsp_resp_content = Vec::new();
        ocsp_resp_content.extend(&response_status);
        ocsp_resp_content.extend(&response_bytes);
        der_tlv(TAG_SEQUENCE, &ocsp_resp_content)
    }
}