1use crate::config::Config;
10use crate::handler::Handler;
11use hyper::{Response, Method};
12use hyper::body::Bytes;
13use http_body_util::Full;
14use hyper::header::{HeaderMap, CONTENT_LENGTH, CONTENT_ENCODING, CONTENT_TYPE};
15use std::sync::Arc;
16use std::io::Write;
17
/// Handler for HTTP/2 push-style requests.
///
/// The type is currently a stateless unit struct; construct it with
/// [`Http2Server::new`].
#[derive(Debug, Clone, Copy)]
pub struct Http2Server;
20
21impl Http2Server {
22 pub fn new(_config: Config, _handler: Arc<Handler>) -> Self {
24 Self
25 }
26
27 pub async fn handle_push(
29 &self,
30 body: Bytes,
31 headers: &HeaderMap,
32 _stream: &mut (dyn tokio::io::AsyncWrite + std::marker::Unpin),
33 ) -> Result<Response<Full<Bytes>>, crate::error::Error> {
34 use hyper::header::CONTENT_LENGTH;
35
36 let _content_length = headers.get(CONTENT_LENGTH)
39 .and_then(|len| {
40 let len = len.to_str().unwrap_or("0");
41 len.parse::<usize>().ok()
42 })
43 .unwrap_or(0);
44
45 if body.is_empty() {
46 let response = Response::builder()
48 .status(200)
49 .header(CONTENT_LENGTH, "0")
50 .header(CONTENT_TYPE, "application/http2")
51 .body(Full::new(Bytes::new()))
52 .unwrap();
53 return Ok(response);
54 }
55
56 if body.len() > 65536 {
58 return Err(crate::error::Error::Http(
59 "Body too large for HTTP/2 push".to_string(),
60 ));
61 }
62
63 let response_body = body;
65 let response_length = response_body.len();
66
67 let response = Response::builder()
68 .status(200)
69 .header(CONTENT_LENGTH, response_length.to_string())
70 .header(CONTENT_ENCODING, "gzip") .header(CONTENT_TYPE, "application/http2")
72 .body(Full::new(response_body))
73 .unwrap();
74
75 Ok(response)
76 }
77
78 pub async fn handle_client_push(
80 &self,
81 body: Bytes,
82 headers: &HeaderMap,
83 _stream: &mut (dyn tokio::io::AsyncWrite + std::marker::Unpin),
84 ) -> Result<Response<Full<Bytes>>, crate::error::Error> {
85 use hyper::header::{CONTENT_LENGTH, CONTENT_ENCODING, ACCEPT_ENCODING};
86
87 let accept_encoding = headers.get(ACCEPT_ENCODING)
89 .and_then(|enc| enc.to_str().ok())
90 .unwrap_or("");
91
92 if !accept_encoding.contains("http/2") {
93 return Err(crate::error::Error::Http(
94 "Client-initiated push requires Accept-Encoding: http/2".to_string(),
95 ));
96 }
97
98 if body.is_empty() {
100 let response = Response::builder()
101 .status(200)
102 .header(CONTENT_LENGTH, "0")
103 .header(CONTENT_TYPE, "application/http2")
104 .body(Full::new(Bytes::new()))
105 .unwrap();
106 return Ok(response);
107 }
108
109 let is_compressed = headers.get(CONTENT_ENCODING)
111 .and_then(|enc| enc.to_str().ok())
112 .unwrap_or("");
113
114 let final_body = if is_compressed.contains("gzip") {
115 use flate2::write::GzEncoder;
116 use flate2::Compression;
117
118 let mut encoder = GzEncoder::new(Vec::new(), Compression::default());
119 encoder.write_all(&body).map_err(|e| {
120 crate::error::Error::Internal(format!("Gzip compression failed: {}", e))
121 })?;
122 encoder.finish().map_err(|e| {
123 crate::error::Error::Internal(format!("Gzip finalization failed: {}", e))
124 })?
125 } else {
126 body.to_vec()
127 };
128
129 let response_length = final_body.len();
130
131 let response = Response::builder()
132 .status(200)
133 .header(CONTENT_LENGTH, response_length.to_string())
134 .header(CONTENT_ENCODING, "gzip")
135 .header(CONTENT_TYPE, "application/http2")
136 .body(Full::new(Bytes::from(final_body)))
137 .unwrap();
138
139 Ok(response)
140 }
141
142 pub fn is_http2_push(method: &Method, headers: &HeaderMap) -> bool {
144 method == &Method::POST
145 && headers.get("content-type")
146 .and_then(|ct| ct.to_str().ok())
147 .unwrap_or("")
148 .starts_with("application/http2+push")
149 }
150
151 pub fn create_http2_response(
153 status: hyper::StatusCode,
154 body: Bytes,
155 headers: &HeaderMap,
156 content_encoding: Option<&str>,
157 ) -> Response<Full<Bytes>> {
158 let mut builder = Response::builder().status(status);
159
160 if let Some(encoding) = content_encoding {
162 builder = builder.header(CONTENT_ENCODING, encoding);
163 }
164 if let Some(content_type) = headers.get(CONTENT_TYPE) {
165 builder = builder.header(CONTENT_TYPE, content_type);
166 }
167
168 let content_length = body.len();
169 builder = builder.header(CONTENT_LENGTH, content_length.to_string());
170
171 builder.body(Full::new(body)).unwrap()
172 }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Content type sent with every push request in these tests.
    const PUSH_CONTENT_TYPE: &str = "application/http2+push";

    /// Builds a server from a default `Config` and a `Handler` wrapping it.
    fn test_server() -> Http2Server {
        let config = Arc::new(Config::default());
        let handler = Arc::new(Handler::new(config.clone()));
        Http2Server::new((*config).clone(), handler)
    }

    /// Builds a `HeaderMap` from `(name, value)` pairs; panics on an invalid value.
    fn header_map(pairs: &[(&'static str, &str)]) -> HeaderMap {
        let mut map = HeaderMap::new();
        for &(name, value) in pairs {
            map.insert(name, value.parse().unwrap());
        }
        map
    }

    /// Returns a response header as `&str`; panics if it is missing or unreadable.
    fn header_str<'a>(response: &'a Response<Full<Bytes>>, name: &'static str) -> &'a str {
        response.headers().get(name).unwrap().to_str().unwrap()
    }

    #[tokio::test]
    async fn test_handle_empty_push() {
        let server = test_server();
        let headers = header_map(&[("content-type", PUSH_CONTENT_TYPE), ("content-length", "0")]);

        let response = server
            .handle_push(Bytes::new(), &headers, &mut Vec::new())
            .await
            .expect("empty push should succeed");

        assert_eq!(response.status(), 200);
        assert_eq!(header_str(&response, "content-length"), "0");
        assert_eq!(header_str(&response, "content-type"), "application/http2");
    }

    #[tokio::test]
    async fn test_push_with_body() {
        let server = test_server();
        let payload: &[u8] = b"Hello from HTTP/2 server!";
        let expected_len = payload.len().to_string();
        let headers = header_map(&[
            ("content-type", PUSH_CONTENT_TYPE),
            ("content-length", expected_len.as_str()),
        ]);

        let response = server
            .handle_push(Bytes::copy_from_slice(payload), &headers, &mut Vec::new())
            .await
            .expect("push with body should succeed");

        assert_eq!(response.status(), 200);
        assert_eq!(header_str(&response, "content-length"), expected_len);
    }

    #[tokio::test]
    async fn test_push_too_large() {
        let server = test_server();
        let headers = header_map(&[("content-type", PUSH_CONTENT_TYPE)]);

        let err = server
            .handle_push(Bytes::from(vec![b'X'; 100000]), &headers, &mut Vec::new())
            .await
            .expect_err("oversized push should be rejected");

        assert!(err.to_string().contains("too large"));
    }

    #[tokio::test]
    async fn test_client_initiated_push() {
        let server = test_server();
        let headers = header_map(&[
            ("content-type", PUSH_CONTENT_TYPE),
            ("accept-encoding", "http/2"),
        ]);

        let outcome = server
            .handle_client_push(Bytes::copy_from_slice(b"Client push"), &headers, &mut Vec::new())
            .await;

        assert!(outcome.is_ok());
    }

    #[test]
    fn test_http2_detection() {
        let headers = header_map(&[("content-type", PUSH_CONTENT_TYPE)]);
        assert!(Http2Server::is_http2_push(&Method::POST, &headers));
        assert!(!Http2Server::is_http2_push(&Method::GET, &headers));

        let headers_post = header_map(&[("content-type", PUSH_CONTENT_TYPE)]);
        assert!(Http2Server::is_http2_push(&Method::POST, &headers_post));
    }

    #[tokio::test]
    async fn test_content_encoding_gzip() {
        let server = test_server();
        let headers = header_map(&[
            ("content-encoding", "gzip"),
            ("content-type", PUSH_CONTENT_TYPE),
        ]);

        let response = server
            .handle_push(Bytes::copy_from_slice(b"Compressed"), &headers, &mut Vec::new())
            .await
            .expect("gzip-labelled push should succeed");

        assert_eq!(response.status(), 200);
        assert_eq!(header_str(&response, "content-encoding"), "gzip");
    }

    #[tokio::test]
    async fn test_priority_ordering() {
        let server = test_server();
        let mut push_order = Vec::new();
        let payloads: [&[u8]; 2] = [b"Push 1", b"Push 2 with body"];

        for payload in payloads.iter() {
            let response = server
                .handle_push(Bytes::copy_from_slice(payload), &HeaderMap::new(), &mut push_order)
                .await
                .expect("each push should succeed");
            assert_eq!(response.status(), 200);
        }
    }

    #[tokio::test]
    async fn test_client_push_missing_accept_encoding() {
        let server = test_server();
        let headers = header_map(&[("content-type", PUSH_CONTENT_TYPE)]);

        let err = server
            .handle_client_push(Bytes::copy_from_slice(b"Client push"), &headers, &mut Vec::new())
            .await
            .expect_err("client push without Accept-Encoding should fail");

        assert!(err.to_string().contains("Accept-Encoding"));
    }

    #[tokio::test]
    async fn test_client_push_empty_body() {
        let server = test_server();
        let headers = header_map(&[
            ("content-type", PUSH_CONTENT_TYPE),
            ("accept-encoding", "http/2"),
        ]);

        let response = server
            .handle_client_push(Bytes::new(), &headers, &mut Vec::new())
            .await
            .expect("empty client push should succeed");

        assert_eq!(response.status(), 200);
        assert_eq!(header_str(&response, "content-length"), "0");
    }

    #[tokio::test]
    async fn test_client_push_with_gzip_compression() {
        let server = test_server();
        let headers = header_map(&[
            ("content-type", PUSH_CONTENT_TYPE),
            ("accept-encoding", "http/2"),
            ("content-encoding", "gzip"),
        ]);
        let payload: &[u8] = b"Test body for compression";

        let response = server
            .handle_client_push(Bytes::copy_from_slice(payload), &headers, &mut Vec::new())
            .await
            .expect("gzip client push should succeed");

        assert_eq!(response.status(), 200);
        assert_eq!(header_str(&response, "content-encoding"), "gzip");
    }

    #[tokio::test]
    async fn test_client_push_without_compression() {
        let server = test_server();
        let headers = header_map(&[
            ("content-type", PUSH_CONTENT_TYPE),
            ("accept-encoding", "http/2"),
        ]);
        let payload: &[u8] = b"Test body without compression";

        let response = server
            .handle_client_push(Bytes::copy_from_slice(payload), &headers, &mut Vec::new())
            .await
            .expect("plain client push should succeed");

        assert_eq!(response.status(), 200);
    }

    #[test]
    fn test_create_http2_response() {
        let body = Bytes::from_static(b"Test response body");
        let headers = header_map(&[("content-type", "text/plain")]);

        let response = Http2Server::create_http2_response(
            hyper::StatusCode::OK,
            body.clone(),
            &headers,
            Some("gzip"),
        );

        assert_eq!(response.status(), 200);
        assert_eq!(header_str(&response, "content-encoding"), "gzip");
        assert_eq!(header_str(&response, "content-type"), "text/plain");
        assert_eq!(header_str(&response, "content-length"), body.len().to_string());
    }

    #[test]
    fn test_create_http2_response_without_encoding() {
        let body = Bytes::from_static(b"Test response body");
        let headers = header_map(&[("content-type", "application/json")]);

        let response = Http2Server::create_http2_response(
            hyper::StatusCode::NOT_FOUND,
            body.clone(),
            &headers,
            None,
        );

        assert_eq!(response.status(), 404);
        assert!(response.headers().get("content-encoding").is_none());
        assert_eq!(header_str(&response, "content-type"), "application/json");
    }

    #[test]
    fn test_is_http2_push_wrong_content_type() {
        let headers = header_map(&[("content-type", "application/json")]);
        assert!(!Http2Server::is_http2_push(&Method::POST, &headers));
    }

    #[test]
    fn test_is_http2_push_missing_content_type() {
        assert!(!Http2Server::is_http2_push(&Method::POST, &HeaderMap::new()));
    }

    #[test]
    fn test_is_http2_push_put_method() {
        let headers = header_map(&[("content-type", PUSH_CONTENT_TYPE)]);
        assert!(!Http2Server::is_http2_push(&Method::PUT, &headers));
    }

    #[tokio::test]
    async fn test_push_with_invalid_content_length() {
        let server = test_server();
        let headers = header_map(&[
            ("content-type", PUSH_CONTENT_TYPE),
            ("content-length", "invalid"),
        ]);

        let outcome = server
            .handle_push(Bytes::from_static(b"test"), &headers, &mut Vec::new())
            .await;

        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_push_at_boundary_size() {
        let server = test_server();
        let headers = header_map(&[("content-type", PUSH_CONTENT_TYPE)]);

        // 64 KiB exactly is accepted; one byte more is rejected.
        for &(size, accepted) in &[(65536usize, true), (65537usize, false)] {
            let outcome = server
                .handle_push(Bytes::from(vec![b'X'; size]), &headers, &mut Vec::new())
                .await;
            assert_eq!(outcome.is_ok(), accepted, "body of {} bytes", size);
        }
    }
}