// reinhardt_testkit/factory.rs

//! Request factory for creating test requests
//!
//! Similar to Django REST Framework's `APIRequestFactory`.

5use bytes::Bytes;
6use http::{HeaderMap, HeaderValue, Method, Request};
7use http_body_util::Full;
8use serde::Serialize;
9use serde_json::Value;
10use std::collections::HashMap;
11
12use crate::client::ClientError;
13
14/// Factory for creating test requests
15pub struct APIRequestFactory {
16	default_format: String,
17	default_headers: HeaderMap,
18}
19
20impl APIRequestFactory {
21	/// Create a new request factory
22	///
23	/// # Examples
24	///
25	/// ```
26	/// use reinhardt_testkit::factory::APIRequestFactory;
27	///
28	/// let factory = APIRequestFactory::new();
29	/// let request = factory.get("/api/users/").build();
30	/// ```
31	pub fn new() -> Self {
32		Self {
33			default_format: "json".to_string(),
34			default_headers: HeaderMap::new(),
35		}
36	}
37	/// Set the default content format (e.g., `"json"`, `"xml"`).
38	pub fn with_format(mut self, format: impl Into<String>) -> Self {
39		self.default_format = format.into();
40		self
41	}
42	/// Add a default header to all requests created by this factory.
43	pub fn with_header(
44		mut self,
45		name: impl AsRef<str>,
46		value: impl AsRef<str>,
47	) -> Result<Self, ClientError> {
48		let header_name: http::header::HeaderName = name.as_ref().parse().map_err(|_| {
49			ClientError::RequestFailed(format!("Invalid header name: {}", name.as_ref()))
50		})?;
51		self.default_headers
52			.insert(header_name, HeaderValue::from_str(value.as_ref())?);
53		Ok(self)
54	}
55	/// Create a GET request
56	///
57	/// # Examples
58	///
59	/// ```
60	/// use reinhardt_testkit::factory::APIRequestFactory;
61	///
62	/// let factory = APIRequestFactory::new();
63	/// let request = factory.get("/api/users/").build().unwrap();
64	/// assert_eq!(request.method(), "GET");
65	/// ```
66	pub fn get(&self, path: &str) -> RequestBuilder {
67		RequestBuilder::new(Method::GET, path, &self.default_headers)
68	}
69	/// Create a POST request
70	///
71	/// # Examples
72	///
73	/// ```
74	/// use reinhardt_testkit::factory::APIRequestFactory;
75	/// use serde_json::json;
76	///
77	/// let factory = APIRequestFactory::new();
78	/// let data = json!({"name": "test"});
79	/// let request = factory.post("/api/users/").json(&data).unwrap().build().unwrap();
80	/// assert_eq!(request.method(), "POST");
81	/// ```
82	pub fn post(&self, path: &str) -> RequestBuilder {
83		RequestBuilder::new(Method::POST, path, &self.default_headers)
84			.with_format(&self.default_format)
85	}
86	/// Create a PUT request
87	///
88	/// # Examples
89	///
90	/// ```
91	/// use reinhardt_testkit::factory::APIRequestFactory;
92	/// use serde_json::json;
93	///
94	/// let factory = APIRequestFactory::new();
95	/// let data = json!({"name": "updated"});
96	/// let request = factory.put("/api/users/1/").json(&data).unwrap().build().unwrap();
97	/// assert_eq!(request.method(), "PUT");
98	/// ```
99	pub fn put(&self, path: &str) -> RequestBuilder {
100		RequestBuilder::new(Method::PUT, path, &self.default_headers)
101			.with_format(&self.default_format)
102	}
103	/// Create a PATCH request
104	///
105	/// # Examples
106	///
107	/// ```
108	/// use reinhardt_testkit::factory::APIRequestFactory;
109	/// use serde_json::json;
110	///
111	/// let factory = APIRequestFactory::new();
112	/// let data = json!({"name": "partial_update"});
113	/// let request = factory.patch("/api/users/1/").json(&data).unwrap().build().unwrap();
114	/// assert_eq!(request.method(), "PATCH");
115	/// ```
116	pub fn patch(&self, path: &str) -> RequestBuilder {
117		RequestBuilder::new(Method::PATCH, path, &self.default_headers)
118			.with_format(&self.default_format)
119	}
120	/// Create a DELETE request
121	///
122	/// # Examples
123	///
124	/// ```
125	/// use reinhardt_testkit::factory::APIRequestFactory;
126	///
127	/// let factory = APIRequestFactory::new();
128	/// let request = factory.delete("/api/users/1/").build().unwrap();
129	/// assert_eq!(request.method(), "DELETE");
130	/// ```
131	pub fn delete(&self, path: &str) -> RequestBuilder {
132		RequestBuilder::new(Method::DELETE, path, &self.default_headers)
133	}
134	/// Create a HEAD request
135	///
136	/// # Examples
137	///
138	/// ```
139	/// use reinhardt_testkit::factory::APIRequestFactory;
140	///
141	/// let factory = APIRequestFactory::new();
142	/// let request = factory.head("/api/users/").build().unwrap();
143	/// assert_eq!(request.method(), "HEAD");
144	/// ```
145	pub fn head(&self, path: &str) -> RequestBuilder {
146		RequestBuilder::new(Method::HEAD, path, &self.default_headers)
147	}
148	/// Create an OPTIONS request
149	///
150	/// # Examples
151	///
152	/// ```
153	/// use reinhardt_testkit::factory::APIRequestFactory;
154	///
155	/// let factory = APIRequestFactory::new();
156	/// let request = factory.options("/api/users/").build().unwrap();
157	/// assert_eq!(request.method(), "OPTIONS");
158	/// ```
159	pub fn options(&self, path: &str) -> RequestBuilder {
160		RequestBuilder::new(Method::OPTIONS, path, &self.default_headers)
161	}
162	/// Create a generic request with custom method
163	///
164	/// # Examples
165	///
166	/// ```
167	/// use reinhardt_testkit::factory::APIRequestFactory;
168	/// use http::Method;
169	///
170	/// let factory = APIRequestFactory::new();
171	/// let request = factory.request(Method::TRACE, "/api/trace/").build().unwrap();
172	/// assert_eq!(request.method(), "TRACE");
173	/// ```
174	pub fn request(&self, method: Method, path: &str) -> RequestBuilder {
175		RequestBuilder::new(method, path, &self.default_headers)
176	}
177}
178
179impl Default for APIRequestFactory {
180	fn default() -> Self {
181		Self::new()
182	}
183}
184
185/// Builder for constructing test requests
186pub struct RequestBuilder {
187	method: Method,
188	path: String,
189	headers: HeaderMap,
190	query_params: HashMap<String, String>,
191	body: Option<Bytes>,
192	format: String,
193	user: Option<Value>,
194}
195
196impl RequestBuilder {
197	/// Create a new request builder with the given HTTP method, path, and default headers.
198	pub fn new(method: Method, path: &str, default_headers: &HeaderMap) -> Self {
199		Self {
200			method,
201			path: path.to_string(),
202			headers: default_headers.clone(),
203			query_params: HashMap::new(),
204			body: None,
205			format: "json".to_string(),
206			user: None,
207		}
208	}
209	/// Get the HTTP method of this request.
210	pub fn method(&self) -> Method {
211		self.method.clone()
212	}
213	/// Get the path of this request.
214	pub fn path(&self) -> &str {
215		&self.path
216	}
217	/// Set the content format for this request.
218	pub fn with_format(mut self, format: &str) -> Self {
219		self.format = format.to_string();
220		self
221	}
222	// Fixes #865
223	/// Add a custom HTTP header to this request builder.
224	pub fn header(mut self, name: &str, value: &str) -> Result<Self, ClientError> {
225		let header_name: http::header::HeaderName = name
226			.parse()
227			.map_err(|_| ClientError::RequestFailed(format!("Invalid header name: {}", name)))?;
228		self.headers
229			.insert(header_name, HeaderValue::from_str(value)?);
230		Ok(self)
231	}
232	/// Add a query parameter to this request.
233	pub fn query(mut self, key: &str, value: &str) -> Self {
234		self.query_params.insert(key.to_string(), value.to_string());
235		self
236	}
237	/// Add a query parameter (alias for `query`).
238	pub fn query_param(self, key: &str, value: &str) -> Self {
239		self.query(key, value)
240	}
241	/// Set request body as JSON
242	///
243	/// # Examples
244	///
245	/// ```
246	/// use reinhardt_testkit::factory::APIRequestFactory;
247	/// use serde_json::json;
248	///
249	/// let factory = APIRequestFactory::new();
250	/// let data = json!({"name": "test"});
251	/// let request = factory.post("/api/users/").json(&data).unwrap().build();
252	/// ```
253	pub fn json<T: Serialize>(mut self, data: &T) -> Result<Self, ClientError> {
254		let json = serde_json::to_vec(data)?;
255		self.body = Some(Bytes::from(json));
256		self.format = "json".to_string();
257		Ok(self)
258	}
259	/// Set request body as form data
260	///
261	/// # Examples
262	///
263	/// ```
264	/// use reinhardt_testkit::factory::APIRequestFactory;
265	/// use serde_json::json;
266	///
267	/// let factory = APIRequestFactory::new();
268	/// let data = json!({"name": "test", "age": 30});
269	/// let request = factory.post("/api/users/").form(&data).unwrap().build();
270	/// ```
271	pub fn form<T: Serialize>(mut self, data: &T) -> Result<Self, ClientError> {
272		let json_value = serde_json::to_value(data)?;
273		if let Value::Object(map) = json_value {
274			let form_data = map
275				.iter()
276				.map(|(k, v)| {
277					let value_str = match v {
278						Value::String(s) => s.clone(),
279						_ => v.to_string(),
280					};
281					format!(
282						"{}={}",
283						url::form_urlencoded::byte_serialize(k.as_bytes()).collect::<String>(),
284						url::form_urlencoded::byte_serialize(value_str.as_bytes())
285							.collect::<String>()
286					)
287				})
288				.collect::<Vec<_>>()
289				.join("&");
290			self.body = Some(Bytes::from(form_data));
291			self.format = "form".to_string();
292			Ok(self)
293		} else {
294			Err(ClientError::RequestFailed(
295				"Expected object for form data".to_string(),
296			))
297		}
298	}
299	/// Set raw body
300	///
301	/// # Examples
302	///
303	/// ```
304	/// use reinhardt_testkit::factory::APIRequestFactory;
305	///
306	/// let factory = APIRequestFactory::new();
307	/// let request = factory.post("/api/upload/").body("raw data").build().unwrap();
308	/// ```
309	pub fn body(mut self, body: impl Into<Bytes>) -> Self {
310		self.body = Some(body.into());
311		self
312	}
313	/// Force authenticate as user (for testing)
314	///
315	/// # Examples
316	///
317	/// ```
318	/// use reinhardt_testkit::factory::APIRequestFactory;
319	/// use serde_json::json;
320	///
321	/// let factory = APIRequestFactory::new();
322	/// let user = json!({"id": 1, "username": "testuser"});
323	/// let request = factory.get("/api/profile/").force_authenticate(user).build().unwrap();
324	/// ```
325	pub fn force_authenticate(mut self, user: Value) -> Self {
326		self.user = Some(user);
327		self
328	}
329	/// Build the request
330	///
331	/// # Examples
332	///
333	/// ```
334	/// use reinhardt_testkit::factory::APIRequestFactory;
335	///
336	/// let factory = APIRequestFactory::new();
337	/// let request = factory.get("/api/users/").build().unwrap();
338	/// assert_eq!(request.method(), "GET");
339	/// ```
340	pub fn build(self) -> Result<Request<Full<Bytes>>, ClientError> {
341		let mut url = self.path.clone();
342
343		// Add query parameters
344		if !self.query_params.is_empty() {
345			let query_string = self
346				.query_params
347				.iter()
348				.map(|(k, v)| {
349					format!(
350						"{}={}",
351						url::form_urlencoded::byte_serialize(k.as_bytes()).collect::<String>(),
352						url::form_urlencoded::byte_serialize(v.as_bytes()).collect::<String>()
353					)
354				})
355				.collect::<Vec<_>>()
356				.join("&");
357			url = format!("{}?{}", url, query_string);
358		}
359
360		let mut request = Request::builder().method(self.method).uri(url);
361
362		// Add headers
363		for (name, value) in self.headers.iter() {
364			request = request.header(name, value);
365		}
366
367		// Add content type based on format
368		if self.body.is_some() {
369			let content_type = match self.format.as_str() {
370				"json" => "application/json",
371				"form" => "application/x-www-form-urlencoded",
372				_ => "application/octet-stream",
373			};
374			request = request.header("Content-Type", content_type);
375		}
376
377		// Add authentication marker if user is set
378		if self.user.is_some() {
379			request = request.header("X-Test-User", "authenticated");
380		}
381
382		// Build request with body
383		let body = self.body.unwrap_or_default();
384		let req = request.body(Full::new(body))?;
385
386		Ok(req)
387	}
388}
389
#[cfg(test)]
mod tests {
	// Unit tests for `APIRequestFactory` and `RequestBuilder`: factory
	// behaviour, builder behaviour, error cases and edge cases.
	use super::*;
	use rstest::rstest;
	use serde_json::json;

	/// Content-Type header of a built request; panics when it is missing.
	fn content_type(request: &Request<Full<Bytes>>) -> &HeaderValue {
		request.headers().get("Content-Type").unwrap()
	}

	// --- APIRequestFactory --------------------------------------------------

	#[rstest]
	fn test_factory_new() {
		let request = APIRequestFactory::new().get("/api/users/").build().unwrap();
		assert_eq!(request.method(), Method::GET);
	}

	#[rstest]
	fn test_factory_default() {
		let build_test = |factory: APIRequestFactory| factory.get("/test").build().unwrap();
		let via_new = build_test(APIRequestFactory::new());
		let via_default = build_test(APIRequestFactory::default());
		assert_eq!(via_new.method(), via_default.method());
		assert_eq!(via_new.uri(), via_default.uri());
	}

	#[rstest]
	fn test_factory_with_format() {
		let request = APIRequestFactory::new()
			.with_format("xml")
			.post("/api/data/")
			.body("payload")
			.build()
			.unwrap();
		assert_eq!(content_type(&request), "application/octet-stream");
	}

	#[rstest]
	fn test_factory_with_header() {
		let factory = APIRequestFactory::new()
			.with_header("X-Custom", "value123")
			.unwrap();
		let request = factory.get("/api/items/").build().unwrap();
		assert_eq!(request.headers().get("x-custom").unwrap(), "value123");
	}

	#[rstest]
	#[case(APIRequestFactory::get, "/api/users/", Method::GET)]
	#[case(APIRequestFactory::post, "/api/users/", Method::POST)]
	#[case(APIRequestFactory::put, "/api/users/1/", Method::PUT)]
	#[case(APIRequestFactory::patch, "/api/users/1/", Method::PATCH)]
	#[case(APIRequestFactory::delete, "/api/users/1/", Method::DELETE)]
	#[case(APIRequestFactory::head, "/api/users/", Method::HEAD)]
	#[case(APIRequestFactory::options, "/api/users/", Method::OPTIONS)]
	fn test_factory_method_shortcuts(
		#[case] shortcut: fn(&APIRequestFactory, &str) -> RequestBuilder,
		#[case] path: &str,
		#[case] expected: Method,
	) {
		let factory = APIRequestFactory::new();
		let request = shortcut(&factory, path).build().unwrap();
		assert_eq!(request.method(), expected);
	}

	#[rstest]
	fn test_factory_request_custom() {
		let request = APIRequestFactory::new()
			.request(Method::TRACE, "/api/trace/")
			.build()
			.unwrap();
		assert_eq!(request.method(), Method::TRACE);
	}

	// --- RequestBuilder -----------------------------------------------------

	#[rstest]
	fn test_builder_json() {
		let payload = json!({"name": "test"});
		let request = APIRequestFactory::new()
			.post("/api/users/")
			.json(&payload)
			.unwrap()
			.build()
			.unwrap();
		assert_eq!(content_type(&request), "application/json");
		assert_eq!(request.method(), Method::POST);
	}

	#[rstest]
	fn test_builder_form() {
		let payload = json!({"name": "test", "age": 30});
		let request = APIRequestFactory::new()
			.post("/api/users/")
			.form(&payload)
			.unwrap()
			.build()
			.unwrap();
		assert_eq!(content_type(&request), "application/x-www-form-urlencoded");
	}

	#[rstest]
	fn test_builder_raw_body() {
		let request = APIRequestFactory::new()
			.post("/api/upload/")
			.body("raw data")
			.build()
			.unwrap();
		assert_eq!(request.method(), Method::POST);
		assert_eq!(content_type(&request), "application/json");
	}

	#[rstest]
	fn test_builder_query_single() {
		let request = APIRequestFactory::new()
			.get("/api/users/")
			.query("page", "1")
			.build()
			.unwrap();
		assert_eq!(request.uri().to_string(), "/api/users/?page=1");
	}

	#[rstest]
	fn test_builder_query_multiple() {
		let request = APIRequestFactory::new()
			.get("/api/users/")
			.query("page", "1")
			.query_param("limit", "10")
			.build()
			.unwrap();
		let uri = request.uri().to_string();
		for fragment in ["page=1", "limit=10"] {
			assert!(uri.contains(fragment));
		}
		assert!(uri.contains('&'));
	}

	#[rstest]
	fn test_builder_force_authenticate() {
		let user = json!({"id": 1, "username": "testuser"});
		let request = APIRequestFactory::new()
			.get("/api/profile/")
			.force_authenticate(user)
			.build()
			.unwrap();
		assert_eq!(
			request.headers().get("X-Test-User").unwrap(),
			"authenticated"
		);
	}

	#[rstest]
	fn test_builder_method_getter() {
		let builder = APIRequestFactory::new().get("/test");
		assert_eq!(builder.method(), Method::GET);
	}

	#[rstest]
	fn test_builder_path_getter() {
		let builder = APIRequestFactory::new().get("/api/items/");
		assert_eq!(builder.path(), "/api/items/");
	}

	#[rstest]
	fn test_builder_with_format() {
		let request = APIRequestFactory::new()
			.post("/api/data/")
			.with_format("form")
			.body("key=val")
			.build()
			.unwrap();
		assert_eq!(content_type(&request), "application/x-www-form-urlencoded");
	}

	// --- Error cases --------------------------------------------------------

	#[rstest]
	fn test_factory_with_header_invalid_name() {
		let outcome = APIRequestFactory::new().with_header("invalid header!", "value");
		assert!(outcome.is_err());
	}

	#[rstest]
	fn test_builder_form_non_object() {
		let not_an_object = json!([1, 2, 3]);
		let outcome = APIRequestFactory::new()
			.post("/api/users/")
			.form(&not_an_object);
		assert!(outcome.is_err());
	}

	#[rstest]
	fn test_builder_header_invalid_name() {
		let outcome = APIRequestFactory::new()
			.get("/test")
			.header("bad header!", "value");
		assert!(outcome.is_err());
	}

	// --- Edge cases ---------------------------------------------------------

	#[rstest]
	fn test_builder_no_body_no_content_type() {
		let request = APIRequestFactory::new().get("/api/users/").build().unwrap();
		assert!(request.headers().get("Content-Type").is_none());
	}

	#[rstest]
	fn test_builder_json_empty_object() {
		let empty = json!({});
		let request = APIRequestFactory::new()
			.post("/api/data/")
			.json(&empty)
			.unwrap()
			.build()
			.unwrap();
		assert_eq!(content_type(&request), "application/json");
	}

	#[rstest]
	fn test_builder_query_special_chars() {
		let request = APIRequestFactory::new()
			.get("/api/search/")
			.query("q", "hello world&foo=bar")
			.build()
			.unwrap();
		let uri = request.uri().to_string();
		assert!(uri.contains("hello+world"));
		assert!(!uri.contains("hello world&foo=bar"));
	}

	#[rstest]
	fn test_builder_unknown_format() {
		let request = APIRequestFactory::new()
			.with_format("xml")
			.post("/api/data/")
			.body("<xml/>")
			.build()
			.unwrap();
		assert_eq!(content_type(&request), "application/octet-stream");
	}
}