1use crate::App;
7use crate::app::{BuiltApp, Policy};
8use crate::clock::Clock;
9use crate::error::Error;
10use crate::response::{IntoResponse, Response};
11use bytes::Bytes;
12use http::{Method, StatusCode, header};
13use http_body_util::BodyExt;
14use serde::Serialize;
15use serde::de::DeserializeOwned;
16use std::any::TypeId;
17use std::sync::Arc;
18
19pub struct TestPart {
21 name: String,
22 filename: Option<String>,
23 content_type: Option<String>,
24 data: Bytes,
25}
26
27impl TestPart {
28 pub fn text(name: &str, value: &str) -> Self {
30 Self {
31 name: name.to_string(),
32 filename: None,
33 content_type: None,
34 data: Bytes::copy_from_slice(value.as_bytes()),
35 }
36 }
37 pub fn file(name: &str, filename: &str, content_type: &str, data: &[u8]) -> Self {
39 Self {
40 name: name.to_string(),
41 filename: Some(filename.to_string()),
42 content_type: Some(content_type.to_string()),
43 data: Bytes::copy_from_slice(data),
44 }
45 }
46}
47
/// Multipart boundary used when `TestApp::post_multipart_with` assembles a
/// `multipart/form-data` body. Part data must not contain `CRLF--` followed by
/// this string, otherwise the assembled request would be split in the wrong
/// place (checked with a `debug_assert!` at assembly time).
const TEST_BOUNDARY: &str = "jerrycan-test-boundary-7f3a";
53
54impl App {
55 pub fn into_test(self) -> TestApp {
62 let mut built = self.build().expect("app failed to build");
63 let clock = Clock::test();
64 let mut overrides = (*built.overrides).clone();
65 overrides.insert(
66 TypeId::of::<Clock>(),
67 Arc::new(clock.clone()) as crate::dep::AnyArc,
68 );
69 built.overrides = Arc::new(overrides);
70 TestApp { built, clock }
71 }
72}
73
/// An app prepared for in-process tests; create one with [`App::into_test`].
///
/// Requests made through its helpers are handed straight to the built app's
/// dispatcher (`BuiltApp::dispatch`) instead of travelling over a socket.
pub struct TestApp {
    // The built app, including its dependency override map.
    built: BuiltApp,
    // The test clock registered as the `Clock` override by `into_test`; kept
    // so tests can read or advance it via `TestApp::clock()`.
    clock: Clock,
}
81
82impl TestApp {
83 pub fn override_dep<T: Send + Sync + 'static>(mut self, value: T) -> Self {
86 let mut map = (*self.built.overrides).clone();
87 map.insert(TypeId::of::<T>(), Arc::new(value) as crate::dep::AnyArc);
88 self.built.overrides = Arc::new(map);
89 self
90 }
91
92 pub fn clock(&self) -> Clock {
96 self.clock.clone()
97 }
98
99 pub fn task_context(&self) -> crate::TaskContext {
106 self.built.task_context()
107 }
108
109 pub async fn get(&self, path: &str) -> TestResponse {
110 self.request_json(Method::GET, path, None).await
111 }
112 pub async fn delete(&self, path: &str) -> TestResponse {
113 self.request_json(Method::DELETE, path, None).await
114 }
115 pub async fn post_json<B: Serialize>(&self, path: &str, body: &B) -> TestResponse {
116 self.request_json(
117 Method::POST,
118 path,
119 Some(serde_json::to_vec(body).expect("serialize")),
120 )
121 .await
122 }
123 pub async fn put_json<B: Serialize>(&self, path: &str, body: &B) -> TestResponse {
124 self.request_json(
125 Method::PUT,
126 path,
127 Some(serde_json::to_vec(body).expect("serialize")),
128 )
129 .await
130 }
131 pub async fn patch_json<B: Serialize>(&self, path: &str, body: &B) -> TestResponse {
132 self.request_json(
133 Method::PATCH,
134 path,
135 Some(serde_json::to_vec(body).expect("serialize")),
136 )
137 .await
138 }
139
140 pub async fn options_with(&self, path: &str, headers: &[(&str, &str)]) -> TestResponse {
142 self.request(http::Method::OPTIONS, path, headers, None)
143 .await
144 }
145
146 pub async fn request(
150 &self,
151 method: http::Method,
152 path: &str,
153 headers: &[(&str, &str)],
154 body: Option<&[u8]>,
155 ) -> TestResponse {
156 self.send(
157 method,
158 path,
159 body.map(Bytes::copy_from_slice),
160 None,
161 headers,
162 None,
163 )
164 .await
165 }
166
167 pub async fn post_bytes(&self, path: &str, bytes: &[u8]) -> TestResponse {
170 self.post_bytes_with(path, bytes, &[]).await
171 }
172
173 pub async fn post_bytes_with(
175 &self,
176 path: &str,
177 bytes: &[u8],
178 headers: &[(&str, &str)],
179 ) -> TestResponse {
180 self.send(
181 Method::POST,
182 path,
183 Some(Bytes::copy_from_slice(bytes)),
184 Some("application/octet-stream"),
185 headers,
186 None,
187 )
188 .await
189 }
190
191 pub async fn post_multipart(&self, path: &str, parts: &[TestPart]) -> TestResponse {
193 self.post_multipart_with(path, parts, &[]).await
194 }
195
196 pub async fn post_multipart_with(
198 &self,
199 path: &str,
200 parts: &[TestPart],
201 headers: &[(&str, &str)],
202 ) -> TestResponse {
203 let mut body = Vec::new();
204 for part in parts {
205 debug_assert!(
206 !part
207 .data
208 .windows(TEST_BOUNDARY.len() + 4)
209 .any(|w| w[..4] == *b"\r\n--" && &w[4..] == TEST_BOUNDARY.as_bytes()),
210 "TestPart data contains the multipart delimiter — the assembled request would corrupt"
211 );
212 body.extend_from_slice(format!("--{TEST_BOUNDARY}\r\n").as_bytes());
213 match &part.filename {
214 Some(filename) => body.extend_from_slice(
215 format!(
216 "content-disposition: form-data; name=\"{}\"; filename=\"{}\"\r\n",
217 part.name, filename
218 )
219 .as_bytes(),
220 ),
221 None => body.extend_from_slice(
222 format!("content-disposition: form-data; name=\"{}\"\r\n", part.name)
223 .as_bytes(),
224 ),
225 }
226 if let Some(content_type) = &part.content_type {
227 body.extend_from_slice(format!("content-type: {content_type}\r\n").as_bytes());
228 }
229 body.extend_from_slice(b"\r\n");
230 body.extend_from_slice(&part.data);
231 body.extend_from_slice(b"\r\n");
232 }
233 body.extend_from_slice(format!("--{TEST_BOUNDARY}--\r\n").as_bytes());
234 let content_type = format!("multipart/form-data; boundary={TEST_BOUNDARY}");
235 self.send(
236 Method::POST,
237 path,
238 Some(Bytes::from(body)),
239 Some(&content_type),
240 headers,
241 None,
242 )
243 .await
244 }
245
246 pub async fn get_with(&self, path: &str, headers: &[(&str, &str)]) -> TestResponse {
248 self.request_with(Method::GET, path, None, headers).await
249 }
250
251 pub async fn get_from(&self, path: &str, peer: std::net::SocketAddr) -> TestResponse {
254 self.send(Method::GET, path, None, None, &[], Some(peer))
255 .await
256 }
257
258 pub async fn request_from(
260 &self,
261 method: http::Method,
262 path: &str,
263 headers: &[(&str, &str)],
264 peer: std::net::SocketAddr,
265 ) -> TestResponse {
266 self.send(method, path, None, None, headers, Some(peer))
267 .await
268 }
269
270 pub async fn post_json_with<B: Serialize>(
272 &self,
273 path: &str,
274 body: &B,
275 headers: &[(&str, &str)],
276 ) -> TestResponse {
277 self.request_with(
278 Method::POST,
279 path,
280 Some(serde_json::to_vec(body).expect("serialize")),
281 headers,
282 )
283 .await
284 }
285
286 pub async fn delete_with(&self, path: &str, headers: &[(&str, &str)]) -> TestResponse {
288 self.request_with(Method::DELETE, path, None, headers).await
289 }
290
291 pub async fn put_json_with<B: Serialize>(
293 &self,
294 path: &str,
295 body: &B,
296 headers: &[(&str, &str)],
297 ) -> TestResponse {
298 self.request_with(
299 Method::PUT,
300 path,
301 Some(serde_json::to_vec(body).expect("serialize")),
302 headers,
303 )
304 .await
305 }
306
307 pub async fn patch_json_with<B: Serialize>(
309 &self,
310 path: &str,
311 body: &B,
312 headers: &[(&str, &str)],
313 ) -> TestResponse {
314 self.request_with(
315 Method::PATCH,
316 path,
317 Some(serde_json::to_vec(body).expect("serialize")),
318 headers,
319 )
320 .await
321 }
322
323 async fn request_json(
324 &self,
325 method: Method,
326 path: &str,
327 json: Option<Vec<u8>>,
328 ) -> TestResponse {
329 self.request_with(method, path, json, &[]).await
330 }
331
332 async fn request_with(
333 &self,
334 method: Method,
335 path: &str,
336 json: Option<Vec<u8>>,
337 headers: &[(&str, &str)],
338 ) -> TestResponse {
339 let content_type = json.as_ref().map(|_| "application/json");
340 self.send(
341 method,
342 path,
343 json.map(Bytes::from),
344 content_type,
345 headers,
346 None,
347 )
348 .await
349 }
350
351 async fn send(
358 &self,
359 method: Method,
360 path: &str,
361 body: Option<Bytes>,
362 content_type: Option<&str>,
363 headers: &[(&str, &str)],
364 peer: Option<std::net::SocketAddr>,
365 ) -> TestResponse {
366 let mut builder = http::Request::builder().method(method).uri(path);
367 let explicit_ct = headers
372 .iter()
373 .any(|(name, _)| name.eq_ignore_ascii_case("content-type"));
374 if let Some(ct) = content_type
375 && !explicit_ct
376 {
377 builder = builder.header(header::CONTENT_TYPE, ct);
378 }
379 for (name, value) in headers {
380 builder = builder.header(*name, *value);
381 }
382 let req = builder.body(()).expect("test request build");
383 let (mut parts, ()) = req.into_parts();
384 if let Some(peer) = peer {
387 parts.extensions.insert(crate::extract::ClientAddr(peer));
388 }
389 let body = body.unwrap_or_default();
390 let cors_origin = parts.headers.get(http::header::ORIGIN).cloned();
394
395 let (limit, stream) = match self.built.route_policy(&parts) {
397 Policy::Reject(response) => return TestResponse::collect(response).await,
398 Policy::Route { limit, stream } => (limit, stream),
399 };
400 let lane = if stream {
405 crate::extract::BodyLane::Stream(Some(test_stream_lane(body, limit)))
406 } else {
407 if body.len() > limit {
408 let mut response = Error::payload_too_large().into_response();
409 if self.built.security_headers {
410 crate::app::apply_security_headers(&mut response);
411 }
412 if let Some(config) = &self.built.cors {
413 crate::cors::apply_cors(&mut response, cors_origin.as_ref(), config);
414 }
415 return TestResponse::collect(response).await;
416 }
417 crate::extract::BodyLane::Buffered(body)
418 };
419 TestResponse::collect(self.built.dispatch(parts, lane).await).await
420 }
421}
422
423fn test_stream_lane(body: Bytes, limit: usize) -> crate::extract::StreamLane {
428 struct Frames(std::collections::VecDeque<Bytes>);
429 impl http_body::Body for Frames {
430 type Data = Bytes;
431 type Error = Box<dyn std::error::Error + Send + Sync>;
432 fn poll_frame(
433 mut self: std::pin::Pin<&mut Self>,
434 _cx: &mut std::task::Context<'_>,
435 ) -> std::task::Poll<Option<Result<http_body::Frame<Bytes>, Self::Error>>> {
436 std::task::Poll::Ready(self.0.pop_front().map(|b| Ok(http_body::Frame::data(b))))
437 }
438 }
439 let frames = body.chunks(13).map(Bytes::copy_from_slice).collect();
440 let limited = http_body_util::Limited::new(Frames(frames), limit);
443 http_body_util::combinators::UnsyncBoxBody::new(limited)
444}
445
446pub struct TestResponse {
447 status: StatusCode,
448 headers: http::HeaderMap,
449 body: Bytes,
450}
451
452impl TestResponse {
453 async fn collect(res: Response) -> Self {
454 let (parts, body) = res.into_parts();
455 let body = body
456 .collect()
457 .await
459 .unwrap_or_else(|e| panic!("response body failed mid-stream: {e}"))
460 .to_bytes();
461 Self {
462 status: parts.status,
463 headers: parts.headers,
464 body,
465 }
466 }
467
468 pub fn status(&self) -> StatusCode {
469 self.status
470 }
471 pub fn headers(&self) -> &http::HeaderMap {
472 &self.headers
473 }
474 pub fn text(&self) -> String {
475 String::from_utf8_lossy(&self.body).into_owned()
476 }
477
478 pub fn bytes(&self) -> &[u8] {
480 &self.body
481 }
482
483 pub fn json<T: DeserializeOwned>(&self) -> T {
485 serde_json::from_slice(&self.body).unwrap_or_else(|e| {
486 panic!(
487 "response body is not the expected JSON shape: {e}\nbody: {}",
488 self.text()
489 )
490 })
491 }
492}
493
#[cfg(test)]
mod tests {
    use crate::prelude::*;

    #[tokio::test]
    async fn post_multipart_builds_a_parseable_request() {
        use crate::multipart::Multipart;

        // Replies with one `name[:filename](byte_len)` label per part, in order.
        async fn upload(mut form: Multipart) -> Result<Json<Vec<String>>> {
            let mut labels = Vec::new();
            while let Some(part) = form.next_part().await? {
                let label = if let Some(file) = part.filename() {
                    format!("{}:{}", part.name(), file)
                } else {
                    part.name().to_string()
                };
                let byte_len = part.bytes().await?.len();
                labels.push(format!("{label}({byte_len})"));
            }
            Ok(Json(labels))
        }

        let app = App::new()
            .route("/upload", post(upload).stream_body())
            .into_test();
        let form_parts = [
            TestPart::text("title", "Q3 leads"),
            TestPart::file("csv", "leads.csv", "text/csv", b"a,b\n1,2\n"),
        ];
        let response = app.post_multipart("/upload", &form_parts).await;

        assert_eq!(response.status().as_u16(), 200, "body: {}", response.text());
        let expected: Vec<String> = ["title(8)", "csv:leads.csv(8)"].map(String::from).into();
        assert_eq!(response.json::<Vec<String>>(), expected);
    }

    #[tokio::test]
    async fn test_response_exposes_raw_bytes() {
        // Streams a short payload that is deliberately not valid UTF-8.
        async fn download() -> StreamBody {
            let (stream, sender) = StreamBody::channel();
            tokio::spawn(async move {
                let _ = sender.send(&b"\x00\x01binary"[..]).await;
            });
            stream.content_type("application/octet-stream")
        }

        let app = App::new().route("/dl", get(download)).into_test();
        let response = app.get("/dl").await;
        assert_eq!(response.bytes(), b"\x00\x01binary");
    }

    #[tokio::test]
    async fn post_multipart_with_carries_extra_headers() {
        use crate::multipart::Multipart;

        // Replies with whether `x-auth` arrived and how many parts were read.
        async fn upload(headers: Headers, mut form: Multipart) -> Result<Json<(bool, usize)>> {
            let has_auth = headers.get("x-auth").is_some();
            let mut seen = 0;
            while let Some(part) = form.next_part().await? {
                let _ = part.bytes().await?;
                seen += 1;
            }
            Ok(Json((has_auth, seen)))
        }

        let app = App::new()
            .route("/upload", post(upload).stream_body())
            .into_test();
        let response = app
            .post_multipart_with(
                "/upload",
                &[TestPart::text("field", "value")],
                &[("x-auth", "token")],
            )
            .await;

        assert_eq!(response.status().as_u16(), 200, "body: {}", response.text());
        assert_eq!(response.json::<(bool, usize)>(), (true, 1));
    }

    #[tokio::test]
    async fn test_clock_handle_drives_resolved_clock_in_task_context() {
        let app = App::new().into_test();
        let mut ctx = app.task_context();
        let resolved_clock = ctx.resolve::<Clock>().await.unwrap();
        let start = resolved_clock.now();
        let one_minute = std::time::Duration::from_secs(60);

        app.clock().advance(one_minute);

        assert_eq!(
            resolved_clock.now().duration_since(start).unwrap(),
            one_minute,
            "TestApp::clock() and the resolved Clock share one offset",
        );
    }
}