1use crate::{
2 constants::*,
3 error::{log_error, VerificationError},
4 normalize,
5};
6use failure::{bail, Error, ResultExt};
7use futures_util::lock::Mutex;
8use std::{collections::HashMap, path::Path};
9use time::Duration;
10use url::{Host, Url};
11use x509_parser::objects::Nid;
12
13pub struct RequestVerifierAsync {
15 cert_cache: Mutex<HashMap<String, Vec<u8>>>,
16}
17
18impl Default for RequestVerifierAsync {
28 fn default() -> Self {
29 RequestVerifierAsync {
30 cert_cache: Mutex::new(HashMap::new()),
31 }
32 }
33}
34
35impl RequestVerifierAsync {
36 pub fn new() -> Self {
38 RequestVerifierAsync::default()
39 }
40
41 pub async fn verify(
52 &self,
53 signature_cert_chain_url: &str,
54 signature: &str,
55 body: &[u8],
56 timestamp: &str,
57 timestamp_tolerance_millis: Option<u64>,
58 ) -> Result<(), Error> {
59 if let Err(e) = self
60 .retrieve_and_validate_cert(signature_cert_chain_url, signature, body)
61 .await
62 {
63 log_error(e)?;
64 };
65
66 if let Err(e) = self.validate_timestamp(timestamp, timestamp_tolerance_millis) {
67 log_error(e)?;
68 };
69
70 Ok(())
71 }
72
73 async fn retrieve_and_validate_cert(
74 &self,
75 signature_cert_chain_url: &str,
76 signature: &str,
77 body: &[u8],
78 ) -> Result<(), Error> {
79 self.validate_cert_url(&signature_cert_chain_url)?;
81
82 let mut not_exists = false;
84 if !self
85 .cert_cache
86 .lock()
87 .await
88 .contains_key(&signature_cert_chain_url.to_string())
89 {
90 not_exists = true;
91 self.retrieve_cert(&signature_cert_chain_url)
92 .await
93 .context(VerificationError::RetrieveCert)?;
94 }
95
96 let cert_cache = self.cert_cache.lock().await;
99 let pem_bytes = cert_cache
100 .get(&signature_cert_chain_url.to_string())
101 .ok_or(VerificationError::MissingCertCache)?;
102 let (_, pem) =
103 x509_parser::pem::pem_to_der(pem_bytes).map_err(|_| VerificationError::PemParse)?;
104 drop(cert_cache);
105 let certificate = pem.parse_x509().map_err(|_| VerificationError::CertParse)?;
106
107 let not_before = certificate.tbs_certificate.validity.not_before;
109 let not_after = certificate.tbs_certificate.validity.not_after;
110 let now_utc = time::now_utc();
111 if now_utc < not_before || now_utc > not_after {
112 bail!(VerificationError::ExpiredCert)
113 }
114
115 if not_exists {
118 let mut sans: Vec<&str> = Vec::new();
119 for ext in &certificate.tbs_certificate.extensions {
120 if ext.oid == x509_parser::objects::nid2obj(&Nid::SubjectAltName).unwrap() {
121 let (_, ber) = der_parser::parse_der(&ext.value)
122 .map_err(|_| VerificationError::CertExtParse)?;
123 for b in ber.into_iter() {
124 if let der_parser::ber::BerObjectContent::Unknown(_, i) = b.content {
125 sans.push(
126 std::str::from_utf8(i).context(VerificationError::SanExtension)?,
127 )
128 } else {
129 bail!(VerificationError::SanExtension)
130 }
131 }
132 }
133 }
134 if !sans.contains(&CERT_CHAIN_DOMAIN) {
135 bail!(VerificationError::DomainNotInSan)
136 }
137 }
138
139 let pkey = certificate
141 .tbs_certificate
142 .subject_pki
143 .subject_public_key
144 .data;
145
146 self.validate_request_body(signature, body, pkey)?;
148
149 Ok(())
150 }
151
152 async fn retrieve_cert(&self, signature_cert_chain_url: &str) -> Result<(), Error> {
153 let resp = reqwest::get(signature_cert_chain_url).await?;
155 let bytes = resp.bytes().await?;
156
157 let _ = self
159 .cert_cache
160 .lock()
161 .await
162 .insert(signature_cert_chain_url.to_string(), bytes.to_vec());
163
164 Ok(())
165 }
166
167 fn validate_cert_url(&self, signature_cert_chain_url: &str) -> Result<(), Error> {
168 let parsed_url = Url::parse(signature_cert_chain_url)?;
169
170 let scheme = parsed_url.scheme();
171 if scheme != CERT_CHAIN_URL_SCHEME {
172 bail!(VerificationError::UrlScheme {
173 scheme: scheme.to_string()
174 })
175 }
176
177 if let Some(hostname) = parsed_url.host() {
178 match hostname {
179 Host::Domain(hostname) => {
180 if hostname.to_lowercase() != CERT_CHAIN_URL_HOSTNAME {
181 bail!(VerificationError::UrlHostname {
182 hostname: hostname.to_string()
183 });
184 }
185 }
186 Host::Ipv4(ip) => bail!(VerificationError::UrlHostname {
187 hostname: format!("{}", ip)
188 }),
189 Host::Ipv6(ip) => bail!(VerificationError::UrlHostname {
190 hostname: format!("{}", ip)
191 }),
192 }
193 } else {
194 bail!(VerificationError::UrlHostname {
195 hostname: "".to_string()
196 })
197 }
198
199 let path = Path::new(parsed_url.path());
200 let normalized_path = normalize::normalize_path(&path);
201 if !normalized_path.starts_with(CERT_CHAIN_URL_STARTPATH) {
202 bail!(VerificationError::UrlPath {
203 path: format!("{}", normalized_path.display())
204 })
205 }
206
207 if let Some(port) = parsed_url.port() {
208 if port != CERT_CHAIN_URL_PORT {
209 bail!(VerificationError::UrlPort { port })
210 }
211 }
212
213 Ok(())
214 }
215
216 fn validate_request_body(
217 &self,
218 signature: &str,
219 body: &[u8],
220 pkey_bytes: &[u8],
221 ) -> Result<(), Error> {
222 let decoded_signature = base64::decode(&signature)?;
223
224 let pkey = ring::signature::UnparsedPublicKey::new(
225 &ring::signature::RSA_PKCS1_2048_8192_SHA1_FOR_LEGACY_USE_ONLY,
226 pkey_bytes,
227 );
228
229 pkey.verify(body, &decoded_signature)?;
230
231 Ok(())
232 }
233
234 fn validate_timestamp(
235 &self,
236 timestamp: &str,
237 timestamp_tolerance_millis: Option<u64>,
238 ) -> Result<(), Error> {
239 let tolerance_millis = {
241 if let Some(t) = timestamp_tolerance_millis {
242 Duration::milliseconds(t as i64)
243 } else {
244 Duration::milliseconds(DEFAULT_TIMESTAMP_TOLERANCE_IN_MILLIS)
245 }
246 };
247
248 if tolerance_millis > Duration::milliseconds(MAX_TIMESTAMP_TOLERANCE_IN_MILLIS) {
250 bail!(VerificationError::TimestampMax {
251 millis: tolerance_millis.num_milliseconds()
252 });
253 }
254
255 let timestamp =
257 time::strptime(timestamp, "%FT%TZ").context(VerificationError::TimestampParse {
258 timestamp: timestamp.to_owned(),
259 })?;
260 let utc_now = time::now_utc();
261
262 let duration_between = utc_now - timestamp;
264 if duration_between > tolerance_millis {
265 bail!(VerificationError::Timestamp);
266 };
267
268 Ok(())
269 }
270}