Skip to main content

reinhardt_test/
views.rs

1//! View test utilities for Reinhardt framework
2//!
3//! Provides test models, request builders, and test views for view testing.
4
5use bytes::Bytes;
6use hyper::{HeaderMap, Method, Uri, Version};
7use reinhardt_http::{Error, Request, Response, Result};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11// ============================================================================
12// Test Models
13// ============================================================================
14
15/// Test model for view tests
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17pub struct TestModel {
18	pub id: Option<i64>,
19	pub name: String,
20	pub slug: String,
21	pub created_at: String,
22}
23
24crate::impl_test_model!(TestModel, i64, "test_models");
25
26/// Test model for API view tests
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
28pub struct ApiTestModel {
29	pub id: Option<i64>,
30	pub title: String,
31	pub content: String,
32}
33
34crate::impl_test_model!(ApiTestModel, i64, "api_test_models");
35
36// ============================================================================
37// Request Creation Functions
38// ============================================================================
39
40/// Create a test request with the given parameters
41pub fn create_request(
42	method: Method,
43	path: &str,
44	query_params: Option<HashMap<String, String>>,
45	headers: Option<HeaderMap>,
46	body: Option<Bytes>,
47) -> Request {
48	// Fixes #880: URL-encode query parameter keys and values to prevent injection
49	let uri_str = if let Some(ref params) = query_params {
50		let query = params
51			.iter()
52			.map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
53			.collect::<Vec<_>>()
54			.join("&");
55		format!("{}?{}", path, query)
56	} else {
57		path.to_string()
58	};
59
60	let uri = uri_str.parse::<Uri>().unwrap();
61	Request::builder()
62		.method(method)
63		.uri(uri)
64		.version(Version::HTTP_11)
65		.headers(headers.unwrap_or_default())
66		.body(body.unwrap_or_default())
67		.build()
68		.expect("Failed to build request")
69}
70
71/// Create a test request with path parameters
72pub fn create_request_with_path_params(
73	method: Method,
74	path: &str,
75	path_params: HashMap<String, String>,
76	query_params: Option<HashMap<String, String>>,
77	headers: Option<HeaderMap>,
78	body: Option<Bytes>,
79) -> Request {
80	let mut request = create_request(method, path, query_params, headers, body);
81	request.path_params = path_params;
82	request
83}
84
85/// Create a test request with headers
86pub fn create_request_with_headers(
87	method: Method,
88	path: &str,
89	headers: HashMap<String, String>,
90	body: Option<Bytes>,
91) -> Request {
92	let mut header_map = HeaderMap::new();
93	for (key, value) in headers {
94		if let (Ok(header_name), Ok(header_value)) = (
95			hyper::header::HeaderName::from_bytes(key.as_bytes()),
96			hyper::header::HeaderValue::from_str(&value),
97		) {
98			header_map.insert(header_name, header_value);
99		}
100	}
101
102	create_request(method, path, None, Some(header_map), body)
103}
104
105/// Create a test request with JSON body
106pub fn create_json_request(method: Method, path: &str, json_data: &serde_json::Value) -> Request {
107	let body = Bytes::from(serde_json::to_vec(json_data).unwrap());
108	let mut headers = HeaderMap::new();
109	headers.insert(
110		hyper::header::CONTENT_TYPE,
111		hyper::header::HeaderValue::from_static("application/json"),
112	);
113
114	create_request(method, path, None, Some(headers), Some(body))
115}
116
117// ============================================================================
118// Test Data Generation
119// ============================================================================
120
121/// Create test objects for list views
122pub fn create_test_objects() -> Vec<TestModel> {
123	vec![
124		TestModel {
125			id: Some(1),
126			name: "First Object".to_string(),
127			slug: "first-object".to_string(),
128			created_at: "2023-01-01T00:00:00Z".to_string(),
129		},
130		TestModel {
131			id: Some(2),
132			name: "Second Object".to_string(),
133			slug: "second-object".to_string(),
134			created_at: "2023-01-02T00:00:00Z".to_string(),
135		},
136		TestModel {
137			id: Some(3),
138			name: "Third Object".to_string(),
139			slug: "third-object".to_string(),
140			created_at: "2023-01-03T00:00:00Z".to_string(),
141		},
142	]
143}
144
145/// Create test objects for API views
146pub fn create_api_test_objects() -> Vec<ApiTestModel> {
147	vec![
148		ApiTestModel {
149			id: Some(1),
150			title: "First Post".to_string(),
151			content: "This is the first post content".to_string(),
152		},
153		ApiTestModel {
154			id: Some(2),
155			title: "Second Post".to_string(),
156			content: "This is the second post content".to_string(),
157		},
158		ApiTestModel {
159			id: Some(3),
160			title: "Third Post".to_string(),
161			content: "This is the third post content".to_string(),
162		},
163	]
164}
165
166/// Create a large set of test objects for pagination testing
167pub fn create_large_test_objects(count: usize) -> Vec<TestModel> {
168	(0..count)
169		.map(|i| TestModel {
170			id: Some(i as i64),
171			name: format!("Object {}", i),
172			slug: format!("object-{}", i),
173			created_at: format!("2023-01-{:02}T00:00:00Z", (i % 30) + 1),
174		})
175		.collect()
176}
177
178// ============================================================================
179// Test Views
180// ============================================================================
181
182/// Create a simple view for testing basic functionality
183pub struct SimpleTestView {
184	pub content: String,
185	pub allowed_methods: Vec<Method>,
186}
187
188impl SimpleTestView {
189	pub fn new(content: &str) -> Self {
190		Self {
191			content: content.to_string(),
192			allowed_methods: vec![Method::GET],
193		}
194	}
195
196	pub fn with_methods(mut self, methods: Vec<Method>) -> Self {
197		self.allowed_methods = methods;
198		self
199	}
200}
201
202#[async_trait::async_trait]
203impl reinhardt_views::View for SimpleTestView {
204	async fn dispatch(&self, request: Request) -> Result<Response> {
205		if !self.allowed_methods.contains(&request.method) {
206			return Err(Error::Validation(format!(
207				"Method {} not allowed",
208				request.method
209			)));
210		}
211
212		Ok(Response::ok().with_body(self.content.clone().into_bytes()))
213	}
214}
215
/// A view that always returns an error, for testing error handling.
#[derive(Debug, Clone)]
pub struct ErrorTestView {
	/// Message carried by the error this view produces.
	pub error_message: String,
	/// Which kind of error this view produces.
	pub error_kind: ErrorKind,
}

/// The kinds of error an [`ErrorTestView`] can be configured to return.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
	/// Produces a not-found error.
	NotFound,
	/// Produces a validation error.
	Validation,
	/// Produces an internal error.
	Internal,
	/// Produces an authentication error.
	Authentication,
	/// Produces an authorization error.
	Authorization,
}
229
230impl ErrorTestView {
231	pub fn new(error_message: String, error_kind: ErrorKind) -> Self {
232		Self {
233			error_message,
234			error_kind,
235		}
236	}
237
238	pub fn not_found(message: impl Into<String>) -> Self {
239		Self::new(message.into(), ErrorKind::NotFound)
240	}
241
242	pub fn validation(message: impl Into<String>) -> Self {
243		Self::new(message.into(), ErrorKind::Validation)
244	}
245}
246
247#[async_trait::async_trait]
248impl reinhardt_views::View for ErrorTestView {
249	async fn dispatch(&self, _request: Request) -> Result<Response> {
250		match self.error_kind {
251			ErrorKind::NotFound => Err(Error::NotFound(self.error_message.clone())),
252			ErrorKind::Validation => Err(Error::Validation(self.error_message.clone())),
253			ErrorKind::Internal => Err(Error::Internal(self.error_message.clone())),
254			ErrorKind::Authentication => Err(Error::Authentication(self.error_message.clone())),
255			ErrorKind::Authorization => Err(Error::Authorization(self.error_message.clone())),
256		}
257	}
258}