// alexa_verifier/async.rs

use crate::{
    constants::*,
    error::{log_error, VerificationError},
    normalize,
};
use failure::{bail, Error, ResultExt};
use futures_util::lock::Mutex;
use std::{collections::HashMap, convert::TryFrom, path::Path};
use time::Duration;
use url::{Host, Url};
use x509_parser::objects::Nid;

13/// Exposes verify method and caches new certificates asynchronously on the first request
14pub struct RequestVerifierAsync {
15    cert_cache: Mutex<HashMap<String, Vec<u8>>>,
16}
17
18/// ```rust
19/// impl Default for RequestVerifierAsync {
20///     fn default() -> Self {
21///         RequestVerifierAsync {
22///             cert_cache: Mutex::new(HashMap::new()),
23///         }
24///     }
25/// }
26/// ```
27impl Default for RequestVerifierAsync {
28    fn default() -> Self {
29        RequestVerifierAsync {
30            cert_cache: Mutex::new(HashMap::new()),
31        }
32    }
33}
34
35impl RequestVerifierAsync {
36    /// Create default instance with an empty cache
37    pub fn new() -> Self {
38        RequestVerifierAsync::default()
39    }
40
41    /// Asynchronously verify that the request came from Alexa. Returns a `std::future::Future`
42    /// that can `.await`'d.
43    ///
44    /// - `SignatureCertChainUrl` and `Signature` are headers of the request
45    ///
46    /// - Pass the entire body of the request for signature verification
47    ///
48    /// - Timestamp comes from the body, `{ "request" : { "timestamp": "" } }`. If deserialized using [alexa_sdk](https://github.com/tarkah/alexa_rust) then timestamp can be taken from `alexa_sdk::Request.body.timestamp`
49    ///
50    /// - A tolerance value in milliseconds can be passed to verify the request was received within that tolerance (default is `150_000`)
51    pub async fn verify(
52        &self,
53        signature_cert_chain_url: &str,
54        signature: &str,
55        body: &[u8],
56        timestamp: &str,
57        timestamp_tolerance_millis: Option<u64>,
58    ) -> Result<(), Error> {
59        if let Err(e) = self
60            .retrieve_and_validate_cert(signature_cert_chain_url, signature, body)
61            .await
62        {
63            log_error(e)?;
64        };
65
66        if let Err(e) = self.validate_timestamp(timestamp, timestamp_tolerance_millis) {
67            log_error(e)?;
68        };
69
70        Ok(())
71    }
72
73    async fn retrieve_and_validate_cert(
74        &self,
75        signature_cert_chain_url: &str,
76        signature: &str,
77        body: &[u8],
78    ) -> Result<(), Error> {
79        // First, validate cert url
80        self.validate_cert_url(&signature_cert_chain_url)?;
81
82        // Look for certificate in cache, if not, download using validated url
83        let mut not_exists = false;
84        if !self
85            .cert_cache
86            .lock()
87            .await
88            .contains_key(&signature_cert_chain_url.to_string())
89        {
90            not_exists = true;
91            self.retrieve_cert(&signature_cert_chain_url)
92                .await
93                .context(VerificationError::RetrieveCert)?;
94        }
95
96        // Get certificate from cache (shouldn't fail), convert from pem to der,
97        // then parse as x509
98        let cert_cache = self.cert_cache.lock().await;
99        let pem_bytes = cert_cache
100            .get(&signature_cert_chain_url.to_string())
101            .ok_or(VerificationError::MissingCertCache)?;
102        let (_, pem) =
103            x509_parser::pem::pem_to_der(pem_bytes).map_err(|_| VerificationError::PemParse)?;
104        drop(cert_cache);
105        let certificate = pem.parse_x509().map_err(|_| VerificationError::CertParse)?;
106
107        // Make sure cert is not expired
108        let not_before = certificate.tbs_certificate.validity.not_before;
109        let not_after = certificate.tbs_certificate.validity.not_after;
110        let now_utc = time::now_utc();
111        if now_utc < not_before || now_utc > not_after {
112            bail!(VerificationError::ExpiredCert)
113        }
114
115        // Make sure domain is in SAN extension
116        // Only need to validate first time cert is downloaded
117        if not_exists {
118            let mut sans: Vec<&str> = Vec::new();
119            for ext in &certificate.tbs_certificate.extensions {
120                if ext.oid == x509_parser::objects::nid2obj(&Nid::SubjectAltName).unwrap() {
121                    let (_, ber) = der_parser::parse_der(&ext.value)
122                        .map_err(|_| VerificationError::CertExtParse)?;
123                    for b in ber.into_iter() {
124                        if let der_parser::ber::BerObjectContent::Unknown(_, i) = b.content {
125                            sans.push(
126                                std::str::from_utf8(i).context(VerificationError::SanExtension)?,
127                            )
128                        } else {
129                            bail!(VerificationError::SanExtension)
130                        }
131                    }
132                }
133            }
134            if !sans.contains(&CERT_CHAIN_DOMAIN) {
135                bail!(VerificationError::DomainNotInSan)
136            }
137        }
138
139        // Get primary key for signature verification
140        let pkey = certificate
141            .tbs_certificate
142            .subject_pki
143            .subject_public_key
144            .data;
145
146        // Parses the public key and verifies signature is a valid signature of message using it.
147        self.validate_request_body(signature, body, pkey)?;
148
149        Ok(())
150    }
151
152    async fn retrieve_cert(&self, signature_cert_chain_url: &str) -> Result<(), Error> {
153        // Get cert using validated SignatureCertChainUrl
154        let resp = reqwest::get(signature_cert_chain_url).await?;
155        let bytes = resp.bytes().await?;
156
157        // Add to cert cache
158        let _ = self
159            .cert_cache
160            .lock()
161            .await
162            .insert(signature_cert_chain_url.to_string(), bytes.to_vec());
163
164        Ok(())
165    }
166
167    fn validate_cert_url(&self, signature_cert_chain_url: &str) -> Result<(), Error> {
168        let parsed_url = Url::parse(signature_cert_chain_url)?;
169
170        let scheme = parsed_url.scheme();
171        if scheme != CERT_CHAIN_URL_SCHEME {
172            bail!(VerificationError::UrlScheme {
173                scheme: scheme.to_string()
174            })
175        }
176
177        if let Some(hostname) = parsed_url.host() {
178            match hostname {
179                Host::Domain(hostname) => {
180                    if hostname.to_lowercase() != CERT_CHAIN_URL_HOSTNAME {
181                        bail!(VerificationError::UrlHostname {
182                            hostname: hostname.to_string()
183                        });
184                    }
185                }
186                Host::Ipv4(ip) => bail!(VerificationError::UrlHostname {
187                    hostname: format!("{}", ip)
188                }),
189                Host::Ipv6(ip) => bail!(VerificationError::UrlHostname {
190                    hostname: format!("{}", ip)
191                }),
192            }
193        } else {
194            bail!(VerificationError::UrlHostname {
195                hostname: "".to_string()
196            })
197        }
198
199        let path = Path::new(parsed_url.path());
200        let normalized_path = normalize::normalize_path(&path);
201        if !normalized_path.starts_with(CERT_CHAIN_URL_STARTPATH) {
202            bail!(VerificationError::UrlPath {
203                path: format!("{}", normalized_path.display())
204            })
205        }
206
207        if let Some(port) = parsed_url.port() {
208            if port != CERT_CHAIN_URL_PORT {
209                bail!(VerificationError::UrlPort { port })
210            }
211        }
212
213        Ok(())
214    }
215
216    fn validate_request_body(
217        &self,
218        signature: &str,
219        body: &[u8],
220        pkey_bytes: &[u8],
221    ) -> Result<(), Error> {
222        let decoded_signature = base64::decode(&signature)?;
223
224        let pkey = ring::signature::UnparsedPublicKey::new(
225            &ring::signature::RSA_PKCS1_2048_8192_SHA1_FOR_LEGACY_USE_ONLY,
226            pkey_bytes,
227        );
228
229        pkey.verify(body, &decoded_signature)?;
230
231        Ok(())
232    }
233
234    fn validate_timestamp(
235        &self,
236        timestamp: &str,
237        timestamp_tolerance_millis: Option<u64>,
238    ) -> Result<(), Error> {
239        // If no tolerance is provided, use DEFAULT
240        let tolerance_millis = {
241            if let Some(t) = timestamp_tolerance_millis {
242                Duration::milliseconds(t as i64)
243            } else {
244                Duration::milliseconds(DEFAULT_TIMESTAMP_TOLERANCE_IN_MILLIS)
245            }
246        };
247
248        // Make sure tolerance is not higher than max allowed by Alexa
249        if tolerance_millis > Duration::milliseconds(MAX_TIMESTAMP_TOLERANCE_IN_MILLIS) {
250            bail!(VerificationError::TimestampMax {
251                millis: tolerance_millis.num_milliseconds()
252            });
253        }
254
255        // Timestamp is in ISO 8601 format
256        let timestamp =
257            time::strptime(timestamp, "%FT%TZ").context(VerificationError::TimestampParse {
258                timestamp: timestamp.to_owned(),
259            })?;
260        let utc_now = time::now_utc();
261
262        // Ensure request received within tolerance milliseconds
263        let duration_between = utc_now - timestamp;
264        if duration_between > tolerance_millis {
265            bail!(VerificationError::Timestamp);
266        };
267
268        Ok(())
269    }
270}