1use conduit::{header, Body, HeaderMap, Method, RequestExt, Response, StatusCode};
2use conduit_middleware::{AfterResult, Middleware};
3use std::borrow::Cow;
4use time::{OffsetDateTime, ParseError, PrimitiveDateTime};
5
/// Middleware that turns a fresh `200 OK` answer to a conditional `GET` or
/// `HEAD` request into an empty `304 Not Modified`.
///
/// After the inner handler runs, the request's `If-Modified-Since` /
/// `If-None-Match` headers are compared with the response's `Last-Modified` /
/// `ETag` headers.
#[derive(Clone, Copy, Debug, Default)]
pub struct ConditionalGet;
8
9impl Middleware for ConditionalGet {
10 fn after(&self, req: &mut dyn RequestExt, res: AfterResult) -> AfterResult {
11 let res = res?;
12
13 match *req.method() {
14 Method::GET | Method::HEAD => {
15 if is_ok(&res) && is_fresh(req, &res) {
16 let (mut parts, _) = res.into_parts();
17 parts.status = StatusCode::NOT_MODIFIED;
18 parts.headers.remove(header::CONTENT_TYPE);
19 parts.headers.remove(header::CONTENT_LENGTH);
20 return Ok(Response::from_parts(parts, Body::empty()));
21 }
22 }
23 _ => (),
24 }
25
26 Ok(res)
27 }
28}
29
30fn is_ok(response: &Response<Body>) -> bool {
31 response.status() == 200
32}
33
34fn is_fresh(req: &dyn RequestExt, res: &Response<Body>) -> bool {
35 let modified_since = get_and_concat_header(req.headers(), header::IF_MODIFIED_SINCE);
36 let none_match = get_and_concat_header(req.headers(), header::IF_NONE_MATCH);
37
38 if modified_since.is_empty() && none_match.is_empty() {
39 return false;
40 }
41
42 let is_modified_since = match std::str::from_utf8(&modified_since) {
43 Err(_) => true,
44 Ok(string) if string.is_empty() => true,
45 Ok(modified_since) => {
46 let modified_since = parse_http_date(modified_since);
47 match modified_since {
48 Err(_) => return false, Ok(parsed) => is_modified_since(parsed, res),
50 }
51 }
52 };
53
54 is_modified_since && etag_matches(&none_match, res)
55}
56
57fn etag_matches(none_match: &[u8], res: &Response<Body>) -> bool {
58 let value = get_and_concat_header(res.headers(), header::ETAG);
59 value == none_match
60}
61
62fn is_modified_since(modified_since: OffsetDateTime, res: &Response<Body>) -> bool {
63 let last_modified = get_and_concat_header(res.headers(), header::LAST_MODIFIED);
64
65 match std::str::from_utf8(&last_modified) {
66 Err(_) => false,
67 Ok(last_modified) => match parse_http_date(last_modified) {
68 Err(_) => false,
69 Ok(last_modified) => modified_since.unix_timestamp() >= last_modified.unix_timestamp(),
70 },
71 }
72}
73
74fn get_and_concat_header(headers: &HeaderMap, name: header::HeaderName) -> Cow<'_, [u8]> {
75 let mut values = headers.get_all(name).iter();
76 if values.size_hint() == (1, Some(1)) {
77 Cow::Borrowed(values.next().unwrap().as_bytes())
80 } else {
81 let values: Vec<_> = values.map(|val| val.as_bytes()).collect();
82 Cow::Owned(values.concat())
83 }
84}
85
86fn parse_http_date(string: &str) -> Result<OffsetDateTime, ()> {
87 parse_rfc1123(string)
88 .or_else(|_| parse_rfc850(string))
89 .or_else(|_| parse_asctime(string))
90 .map_err(|_| ())
91}
92
93fn parse_rfc1123(string: &str) -> Result<OffsetDateTime, ParseError> {
94 Ok(PrimitiveDateTime::parse(string, "%a, %d %b %Y %T GMT")?.assume_utc())
95}
96
97fn parse_rfc850(string: &str) -> Result<OffsetDateTime, ParseError> {
98 Ok(PrimitiveDateTime::parse(string, "%a, %d-%m-%y %T GMT")?.assume_utc())
99}
100
101fn parse_asctime(string: &str) -> Result<OffsetDateTime, ParseError> {
102 Ok(PrimitiveDateTime::parse(string, "%a %m\t%d %T %Y")?.assume_utc())
104}
105
#[cfg(test)]
mod tests {
    use conduit::{
        box_error, header, Body, Handler, HandlerResult, HeaderMap, Method, RequestExt, Response,
        StatusCode,
    };
    use conduit_middleware::MiddlewareBuilder;
    use conduit_test::{MockRequest, ResponseExt};
    use time::{Duration, OffsetDateTime};

    use super::ConditionalGet;

    // Builds a middleware stack: a `SimpleHandler` that answers with the given
    // status (defaulting to `200 OK`), the given response headers and the body
    // "hello", wrapped in `ConditionalGet`.
    macro_rules! returning {
        ($status:expr, $($header:expr => $value:expr),+) => ({
            use std::convert::TryInto;
            let mut headers = HeaderMap::new();
            $(headers.append($header, $value.try_into().unwrap());)+
            let handler = SimpleHandler::new(headers, $status, "hello");
            let mut stack = MiddlewareBuilder::new(handler);
            stack.add(ConditionalGet);
            stack
        });
        ($($header:expr => $value:expr),+) => ({
            returning!(StatusCode::OK, $($header => $value),+)
        })
    }

    // Builds a `GET /` mock request carrying the given request headers.
    macro_rules! request {
        ($($header:expr => $value:expr),+) => ({
            let mut req = MockRequest::new(Method::GET, "/");
            $(req.header($header, &$value.to_string());)+
            req
        })
    }

    // Client date equal to Last-Modified counts as unchanged.
    #[test]
    fn test_sends_304() {
        let handler = returning!(header::LAST_MODIFIED => httpdate(OffsetDateTime::now_utc()));
        expect_304(handler.call(&mut request!(
            header::IF_MODIFIED_SINCE => httpdate(OffsetDateTime::now_utc())
        )));
    }

    // Client date after Last-Modified counts as unchanged.
    #[test]
    fn test_sends_304_if_older_than_now() {
        let handler = returning!(header::LAST_MODIFIED => before_now());
        expect_304(handler.call(&mut request!(
            header::IF_MODIFIED_SINCE => httpdate(OffsetDateTime::now_utc())
        )));
    }

    // Matching If-None-Match alone is enough for a 304.
    #[test]
    fn test_sends_304_with_etag() {
        let handler = returning!(header::ETAG => "1234");
        expect_304(handler.call(&mut request!(
            header::IF_NONE_MATCH => "1234"
        )));
    }

    // Both validators are sent; the date agrees but the ETag does not, so 200.
    #[test]
    fn test_sends_200_with_fresh_time_but_not_etag() {
        let handler = returning!(header::LAST_MODIFIED => before_now(), header::ETAG => "1234");
        expect_200(handler.call(&mut request!(
            header::IF_MODIFIED_SINCE => now(),
            header::IF_NONE_MATCH => "4321"
        )));
    }

    // Both validators are sent; the ETag agrees but the client date is a year
    // before Last-Modified, so 200.
    #[test]
    fn test_sends_200_with_fresh_etag_but_not_time() {
        let handler = returning!(header::LAST_MODIFIED => now(), header::ETAG => "1234");
        expect_200(handler.call(&mut request!(
            header::IF_MODIFIED_SINCE => before_now(),
            header::IF_NONE_MATCH => "1234"
        )));
    }

    // NOTE(review): despite "fresh" in the name, the request's If-None-Match
    // ("4321") does NOT match the response ETag ("1234"); this checks the
    // mismatch case.
    #[test]
    fn test_sends_200_with_fresh_etag() {
        let handler = returning!(header::ETAG => "1234");
        expect_200(handler.call(&mut request!(
            header::IF_NONE_MATCH => "4321"
        )));
    }

    // NOTE(review): despite "fresh" in the name, the client's date is a year
    // before Last-Modified, i.e. the cached copy is stale.
    #[test]
    fn test_sends_200_with_fresh_time() {
        let handler = returning!(header::LAST_MODIFIED => now());
        expect_200(handler.call(&mut request!(
            header::IF_MODIFIED_SINCE => before_now()
        )));
    }

    // Both validators are sent and both agree, so 304.
    #[test]
    fn test_sends_304_with_fresh_time_and_etag() {
        let handler = returning!(header::LAST_MODIFIED => before_now(), header::ETAG => "1234");
        expect_304(handler.call(&mut request!(
            header::IF_MODIFIED_SINCE => now(),
            header::IF_NONE_MATCH => "1234"
        )));
    }

    // Only 200 responses are eligible; a 302 passes through with its body.
    #[test]
    fn test_does_not_affect_non_200() {
        let handler = returning!(StatusCode::FOUND, header::LAST_MODIFIED => before_now(), header::ETAG => "1234");
        expect(
            StatusCode::FOUND,
            handler.call(&mut request!(
                header::IF_MODIFIED_SINCE => now(),
                header::IF_NONE_MATCH => "1234"
            )),
        );
    }

    // An If-Modified-Since value in an unsupported layout leaves the 200 as-is.
    #[test]
    fn test_does_not_affect_malformed_timestamp() {
        let bad_stamp = OffsetDateTime::now_utc().format("%Y-%m-%d %H:%M:%S %z");
        let handler = returning!(header::LAST_MODIFIED => before_now());
        expect_200(handler.call(&mut request!(
            header::IF_MODIFIED_SINCE => bad_stamp
        )));
    }

    // Asserts a 304 status with an empty body.
    fn expect_304(response: HandlerResult) {
        let response = response.expect("No response");
        assert_eq!(response.status(), StatusCode::NOT_MODIFIED);
        assert_eq!(*response.into_cow(), b""[..]);
    }

    // Asserts an untouched 200 with the handler's "hello" body.
    fn expect_200(response: HandlerResult) {
        expect(StatusCode::OK, response);
    }

    // Asserts the given status and the handler's original "hello" body.
    fn expect(status: StatusCode, response: HandlerResult) {
        let response = response.expect("No response");
        assert_eq!(response.status(), status);
        assert_eq!(*response.into_cow(), b"hello"[..]);
    }

    // Handler that always replies with fixed headers, status and body.
    struct SimpleHandler {
        headers: HeaderMap,
        status: StatusCode,
        body: &'static str,
    }

    impl SimpleHandler {
        fn new(headers: HeaderMap, status: StatusCode, body: &'static str) -> SimpleHandler {
            SimpleHandler {
                headers,
                status,
                body,
            }
        }
    }

    impl Handler for SimpleHandler {
        // Ignores the request; clones the stored headers into each response.
        fn call(&self, _: &mut dyn RequestExt) -> HandlerResult {
            let mut builder = Response::builder().status(self.status);
            builder.headers_mut().unwrap().extend(self.headers.clone());
            builder
                .body(Body::from_static(self.body.as_bytes()))
                .map_err(box_error)
        }
    }

    // Current UTC time minus 52 weeks, formatted with `httpdate`.
    fn before_now() -> String {
        let now = OffsetDateTime::now_utc();
        httpdate(now - Duration::weeks(52))
    }

    // Current UTC time, formatted with `httpdate`.
    fn now() -> String {
        httpdate(OffsetDateTime::now_utc())
    }

    // Formats with the same numeric-month pattern `parse_rfc850` parses, so
    // these dates round-trip through the middleware's date parser.
    fn httpdate(time: OffsetDateTime) -> String {
        time.format("%a, %d-%m-%y %T GMT")
    }
}