1use std::collections::HashMap;
2use std::convert::TryInto;
3use std::error::Error;
4use std::fmt::{self, Display};
5use std::io::{BufRead, Write};
6
7use anyhow::Context;
8use http::{header::HeaderName, HeaderValue, Method};
9use url::Url;
10
11use crate::{ClientRequestLike, Header, HttpDigest, PseudoHeader, RequestLike, ServerRequestLike};
12
/// Error returned by [`MockRequest::from_reader`] when the input is
/// structurally malformed: the request line lacks a method or path, or a
/// header line has no `:` separator.
#[derive(Debug)]
pub struct ParseError;

impl Display for ParseError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(formatter, "Malformed HTTP request")
    }
}

impl Error for ParseError {}
24
25#[derive(Debug, Clone, PartialEq)]
27pub struct MockRequest {
28 method: Method,
29 path: String,
30 headers: HashMap<HeaderName, HeaderValue>,
31 body: Option<Vec<u8>>,
32}
33
34impl MockRequest {
35 pub fn method(&self) -> Method {
37 self.method.clone()
38 }
39 pub fn path(&self) -> &str {
41 &self.path
42 }
43 pub fn headers(&self) -> impl IntoIterator<Item = (&HeaderName, &HeaderValue)> {
45 &self.headers
46 }
47 pub fn body(&self) -> Option<&[u8]> {
49 self.body.as_deref()
50 }
51
52 pub fn new(method: Method, url: &str) -> Self {
54 let url: Url = url.parse().unwrap();
55
56 let path = if let Some(query) = url.query() {
57 format!("{}?{}", url.path(), query)
58 } else {
59 url.path().into()
60 };
61 let mut res = Self {
62 method,
63 path,
64 headers: Default::default(),
65 body: None,
66 };
67 if let Some(host) = url.host_str().map(ToOwned::to_owned) {
68 res = res.with_header("Host", &host)
69 }
70 res
71 }
72 pub fn with_header(mut self, name: &str, value: &str) -> Self {
74 self.set_header(
75 HeaderName::from_bytes(name.as_bytes()).unwrap(),
76 HeaderValue::from_bytes(value.as_bytes()).unwrap(),
77 );
78 self
79 }
80 pub fn with_body(mut self, body: Vec<u8>) -> Self {
82 let l = body.len();
83 self.body = Some(body);
84 self.with_header("Content-Length", &l.to_string())
85 }
86
87 pub fn from_reader<R: BufRead>(reader: &mut R) -> Result<Self, Box<dyn Error>> {
89 let mut line = String::new();
90
91 reader.read_line(&mut line)?;
93 let mut parts = line.split_ascii_whitespace();
94
95 let method: Method = parts.next().ok_or(ParseError)?.parse()?;
97
98 let path: String = parts.next().ok_or(ParseError)?.parse()?;
100
101 #[allow(clippy::mutable_key_type)]
103 let mut headers = HashMap::new();
104 let has_body = loop {
105 line.truncate(0);
106 if reader.read_line(&mut line)? == 0 {
107 break false;
108 }
109 if line.trim().is_empty() {
110 break true;
111 }
112
113 let mut parts = line.splitn(2, ':');
114
115 let name_str = parts.next().ok_or(ParseError)?.trim();
116 let header_name: HeaderName = name_str
117 .parse()
118 .with_context(|| format!("{:?}", name_str))?;
119 let value_str = parts.next().ok_or(ParseError)?.trim();
120 let header_value: HeaderValue = value_str
121 .parse()
122 .with_context(|| format!("{:?}", value_str))?;
123 headers.insert(header_name, header_value);
124 };
125
126 let body = if has_body {
127 let mut body = Vec::new();
128 reader.read_to_end(&mut body)?;
129 Some(body)
130 } else {
131 None
132 };
133
134 Ok(Self {
135 method,
136 path,
137 headers,
138 body,
139 })
140 }
141
142 pub fn write<W: Write>(&self, writer: &mut W) -> Result<(), Box<dyn Error>> {
144 writeln!(writer, "{} {} HTTP/1.1", self.method.as_str(), self.path)?;
145 for (header_name, header_value) in &self.headers {
146 writeln!(
147 writer,
148 "{}: {}",
149 header_name.as_str(),
150 header_value.to_str()?
151 )?;
152 }
153
154 if let Some(body) = &self.body {
155 writeln!(writer)?;
156 writer.write_all(body)?;
157 }
158
159 Ok(())
160 }
161}
162
163impl RequestLike for MockRequest {
164 fn header(&self, header: &Header) -> Option<HeaderValue> {
165 match header {
166 Header::Normal(header_name) => self.headers.get(header_name).cloned(),
167 Header::Pseudo(PseudoHeader::RequestTarget) => {
168 let method = self.method.as_str().to_ascii_lowercase();
169 format!("{} {}", method, self.path).try_into().ok()
170 }
171 _ => None,
172 }
173 }
174}
175
176impl ClientRequestLike for MockRequest {
177 fn compute_digest(&mut self, digest: &dyn HttpDigest) -> Option<String> {
178 self.body.as_ref().map(|b| digest.http_digest(b))
179 }
180 fn set_header(&mut self, header: HeaderName, value: HeaderValue) {
181 self.headers.insert(header, value);
182 }
183}
184
185impl ServerRequestLike for &MockRequest {
186 type Remnant = ();
187
188 fn complete_with_digest(self, digest: &dyn HttpDigest) -> (Option<String>, Self::Remnant) {
189 if let Some(body) = self.body.as_ref() {
190 let computed_digest = digest.http_digest(body);
191 (Some(computed_digest), ())
192 } else {
193 (None, ())
194 }
195 }
196 fn complete(self) -> Self::Remnant {}
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use http::header::DATE;

    use crate::{
        HttpSignatureVerify, RsaSha256Verify, SimpleKeyProvider, VerifyingConfig, VerifyingExt,
    };

    /// Sample signed request shared by all tests. The values appear to be the
    /// test request from the draft-cavage HTTP Signatures spec, appendix C
    /// (NOTE(review): confirm against the spec revision targeted).
    fn test_request() -> MockRequest {
        MockRequest::new(Method::POST, "http://example.com/foo?param=value&pet=dog")
            .with_header("Date", "Sun, 05 Jan 2014 21:31:40 GMT")
            .with_header("Content-Type", "application/json")
            .with_header(
                "Digest",
                "SHA-256=X48E9qOokqqrvdts8nOJRJN3OWDUoyWxBf7kbu9DBPE=",
            )
            .with_body(r#"{"hello": "world"}"#.as_bytes().into())
    }

    /// Key provider exposing the RSA public key in `test_data/public.pem`
    /// under key id `Test`.
    fn test_key_provider() -> SimpleKeyProvider {
        SimpleKeyProvider::new(vec![(
            "Test",
            Arc::new(
                RsaSha256Verify::new_pem(
                    include_bytes!("../test_data/public.pem"),
                )
                .unwrap(),
            ) as Arc<dyn HttpSignatureVerify>,
        )])
    }

    /// Signature over `date` only: verifies, then fails once `Date` changes.
    #[test]
    fn default_test() {
        let mut req = test_request().with_header(
            "Authorization",
            "\
            Signature \
            keyId=\"Test\", \
            algorithm=\"rsa-sha256\", \
            headers=\"date\", \
            signature=\"SjWJWbWN7i0wzBvtPl8rbASWz5xQW6mcJmn+ibttBqtifLN7Sazz\
            6m79cNfwwb8DMJ5cou1s7uEGKKCs+FLEEaDV5lp7q25WqS+lavg7T8hc0GppauB\
            6hbgEKTwblDHYGEtbGmtdHgVCk9SuS13F0hZ8FD0k/5OxEPXe5WozsbM=\"\
            ",
        );

        let mut config = VerifyingConfig::new(test_key_provider());
        config.set_validate_date(false);
        config.set_require_digest(false);
        config.set_required_headers(&[Header::Normal(DATE)]);

        req.verify(&config)
            .expect("Signature to be verified correctly");

        // Tampering with a signed header must invalidate the signature.
        req = req.with_header("Date", "Sun, 05 Jan 2014 21:31:41 GMT");

        req.verify(&config)
            .expect_err("Signature verification to fail");
    }

    /// Signature over `(request-target) host date`.
    #[test]
    fn basic_test() {
        let req = test_request().with_header(
            "Authorization",
            "\
            Signature \
            keyId=\"Test\", \
            algorithm=\"rsa-sha256\", \
            headers=\"(request-target) host date\", \
            signature=\"qdx+H7PHHDZgy4y/Ahn9Tny9V3GP6YgBPyUXMmoxWtLbHpUnXS\
            2mg2+SbrQDMCJypxBLSPQR2aAjn7ndmw2iicw3HMbe8VfEdKFYRqzic+efkb3\
            nndiv/x1xSHDJWeSWkx3ButlYSuBskLu6kd9Fswtemr3lgdDEmn04swr2Os0=\"\
            ",
        );

        let mut config = VerifyingConfig::new(test_key_provider());
        config.set_validate_date(false);
        config.set_require_digest(false);

        req.verify(&config)
            .expect("Signature to be verified correctly");
    }

    /// Signature over every header of the request, with `created`/`expires`.
    #[test]
    fn all_headers_test() {
        let req = test_request().with_header(
            "Authorization",
            "\
            Signature \
            keyId=\"Test\", \
            algorithm=\"rsa-sha256\", \
            created=1402170695, \
            expires=1402170699, \
            headers=\"(request-target) host date content-type digest content-length\", \
            signature=\"vSdrb+dS3EceC9bcwHSo4MlyKS59iFIrhgYkz8+oVLEEzmYZZvRs\
            8rgOp+63LEM3v+MFHB32NfpB2bEKBIvB1q52LaEUHFv120V01IL+TAD48XaERZF\
            ukWgHoBTLMhYS2Gb51gWxpeIq8knRmPnYePbF5MOkR0Zkly4zKH7s1dE=\"\
            ",
        );

        let mut config = VerifyingConfig::new(test_key_provider());
        config.set_validate_date(false);
        config.set_require_digest(false);

        req.verify(&config)
            .expect("Signature to be verified correctly");
    }
}