1use std::convert::TryInto;
2
3use http::header::{HeaderName, HeaderValue};
4
5use super::*;
6
7fn host_from_url(url: &url::Url) -> Option<String> {
9 url.host_str().map(|host| match url.port() {
10 Some(port) => format!("{}:{}", host, port),
11 None => host.into(),
12 })
13}
14
15impl RequestLike for reqwest::Request {
16 fn header(&self, header: &Header) -> Option<HeaderValue> {
17 match header {
18 Header::Normal(header_name) => self.headers().get(header_name).cloned(),
19 Header::Pseudo(PseudoHeader::RequestTarget) => {
20 let method = self.method().as_str().to_ascii_lowercase();
21 let path = self.url().path();
22 format!("{} {}", method, path).try_into().ok()
23 }
24 _ => None,
25 }
26 }
27}
28
29impl ClientRequestLike for reqwest::Request {
30 fn host(&self) -> Option<String> {
31 host_from_url(self.url())
32 }
33 fn compute_digest(&mut self, digest: &dyn HttpDigest) -> Option<String> {
34 self.body()?.as_bytes().map(|b| digest.http_digest(b))
35 }
36 fn set_header(&mut self, header: HeaderName, value: HeaderValue) {
37 self.headers_mut().insert(header, value);
38 }
39}
40
41impl RequestLike for reqwest::blocking::Request {
42 fn header(&self, header: &Header) -> Option<HeaderValue> {
43 match header {
44 Header::Normal(header_name) => self.headers().get(header_name).cloned(),
45 Header::Pseudo(PseudoHeader::RequestTarget) => {
46 let method = self.method().as_str().to_ascii_lowercase();
47 let path = self.url().path();
48 if let Some(query) = self.url().query() {
49 format!("{} {}?{}", method, path, query)
50 } else {
51 format!("{} {}", method, path)
52 }
53 .try_into()
54 .ok()
55 }
56 _ => None,
57 }
58 }
59}
60
61impl ClientRequestLike for reqwest::blocking::Request {
62 fn host(&self) -> Option<String> {
63 host_from_url(self.url())
64 }
65 fn compute_digest(&mut self, digest: &dyn HttpDigest) -> Option<String> {
66 let bytes_to_digest = self.body_mut().as_mut()?.buffer().ok()?;
67 Some(digest.http_digest(bytes_to_digest))
68 }
69 fn set_header(&mut self, header: HeaderName, value: HeaderValue) {
70 self.headers_mut().insert(header, value);
71 }
72}
73
#[cfg(test)]
mod tests {
    use chrono::{offset::TimeZone, Utc};
    use http::header::{AUTHORIZATION, CONTENT_TYPE, DATE, HOST};

    use super::*;

    /// Key id and shared secret every signing test uses.
    const KEY_ID: &str = "test_key";
    const SECRET: &[u8] = b"abcdefgh";

    /// JSON body POSTed by every signing test.
    const JSON_BODY: &[u8] = br#"{ "x": 1, "y": 2}"#;

    /// `Authorization` value expected for `JSON_BODY` POSTed to
    /// `http://test.com/foo/bar` with the `fixed_date()` `Date` header.
    const EXPECTED_AUTHORIZATION: &str = "Signature keyId=\"test_key\",algorithm=\"hs2019\",signature=\"F8gZiriO7dtKFiP5eSZ+Oh1h61JIrAR6D5Mdh98DjqA=\",headers=\"(request-target) host date digest\"";

    /// `Digest` value expected for `JSON_BODY`.
    const EXPECTED_DIGEST: &str = "SHA-256=2vgEVkfe4d6VW+tSWAziO7BUx7uT/rA9hn1EoxUJi2o=";

    /// `Date` header for 2014-07-08 09:10:11 UTC in HTTP-date layout, so
    /// signatures are reproducible.
    fn fixed_date() -> String {
        Utc.with_ymd_and_hms(2014, 7, 8, 9, 10, 11)
            .single()
            .expect("valid date")
            .format("%a, %d %b %Y %T GMT")
            .to_string()
    }

    #[test]
    fn it_works() {
        let unsigned = reqwest::Client::new()
            .post("http://test.com/foo/bar")
            .header(CONTENT_TYPE, "application/json")
            .header(DATE, fixed_date())
            .body(JSON_BODY)
            .build()
            .unwrap();

        let config = SigningConfig::new_default(KEY_ID, SECRET);
        let signed = unsigned.signed(&config).unwrap();
        let headers = signed.headers();

        assert_eq!(headers.get(AUTHORIZATION).unwrap(), EXPECTED_AUTHORIZATION);
        assert_eq!(
            headers.get(HeaderName::from_static("digest")).unwrap(),
            EXPECTED_DIGEST
        );
        assert_eq!(headers.get(HOST).unwrap(), "test.com");
    }

    #[test]
    fn it_works_blocking() {
        let unsigned = reqwest::blocking::Client::new()
            .post("http://test.com/foo/bar")
            .header(CONTENT_TYPE, "application/json")
            .header(DATE, fixed_date())
            .body(JSON_BODY)
            .build()
            .unwrap();

        let config = SigningConfig::new_default(KEY_ID, SECRET);
        let signed = unsigned.signed(&config).unwrap();
        let headers = signed.headers();

        assert_eq!(headers.get(AUTHORIZATION).unwrap(), EXPECTED_AUTHORIZATION);
        assert_eq!(
            headers.get(HeaderName::from_static("digest")).unwrap(),
            EXPECTED_DIGEST
        );
        assert_eq!(headers.get(HOST).unwrap(), "test.com");
    }

    #[test]
    fn sets_host_header_with_port_correctly() {
        let unsigned = reqwest::Client::new()
            .post("http://localhost:8080/foo/bar")
            .header(CONTENT_TYPE, "application/json")
            .header(DATE, fixed_date())
            .body(JSON_BODY)
            .build()
            .unwrap();

        let config = SigningConfig::new_default(KEY_ID, SECRET);
        let signed = unsigned.signed(&config).unwrap();

        assert_eq!(signed.headers().get(HOST).unwrap(), "localhost:8080");
    }

    #[test]
    fn sets_host_header_with_port_correctly_blocking() {
        let unsigned = reqwest::blocking::Client::new()
            .post("http://localhost:8080/foo/bar")
            .header(CONTENT_TYPE, "application/json")
            .header(DATE, fixed_date())
            .body(JSON_BODY)
            .build()
            .unwrap();

        let config = SigningConfig::new_default(KEY_ID, SECRET);
        let signed = unsigned.signed(&config).unwrap();

        assert_eq!(signed.headers().get(HOST).unwrap(), "localhost:8080");
    }

    /// Manual check against a locally running reference server; ignored by default.
    #[test]
    #[ignore]
    fn it_can_talk_to_reference_integration() {
        let config = SigningConfig::new_default("dummykey", &base64::decode("dummykey").unwrap());

        let client = reqwest::blocking::Client::new();
        let request = client
            .get("http://localhost:8080/config")
            .build()
            .unwrap()
            .signed(&config)
            .unwrap();

        let response = client.execute(request).unwrap();
        println!("{:?}", response.text().unwrap());
    }
}