Skip to main content

reinhardt_testkit/
mock.rs

1use std::collections::VecDeque;
2use std::sync::Arc;
3use tokio::sync::Mutex;
4
5/// Call record for tracking function calls
6#[derive(Debug, Clone)]
7pub struct CallRecord {
8	/// Arguments passed to the function call (serialized as JSON values).
9	pub args: Vec<serde_json::Value>,
10	/// Timestamp when the call was recorded.
11	pub timestamp: std::time::Instant,
12}
13
14/// Mock function tracker
15pub struct MockFunction<T> {
16	calls: Arc<Mutex<Vec<CallRecord>>>,
17	return_values: Arc<Mutex<VecDeque<T>>>,
18	default_return: Option<T>,
19}
20
21impl<T: Clone> MockFunction<T> {
22	/// Create a new mock function
23	///
24	/// # Examples
25	///
26	/// ```
27	/// use reinhardt_testkit::mock::MockFunction;
28	///
29	/// # tokio_test::block_on(async {
30	/// let mock = MockFunction::<i32>::new();
31	/// assert_eq!(mock.call_count().await, 0);
32	/// # });
33	/// ```
34	pub fn new() -> Self {
35		Self {
36			calls: Arc::new(Mutex::new(Vec::new())),
37			return_values: Arc::new(Mutex::new(VecDeque::new())),
38			default_return: None,
39		}
40	}
41	/// Create a mock function with a default return value
42	///
43	/// # Examples
44	///
45	/// ```
46	/// use reinhardt_testkit::mock::MockFunction;
47	///
48	/// # tokio_test::block_on(async {
49	/// let mock = MockFunction::with_default(42);
50	/// let result = mock.call(vec![]).await;
51	/// assert_eq!(result, Some(42));
52	/// # });
53	/// ```
54	pub fn with_default(default_return: T) -> Self {
55		Self {
56			calls: Arc::new(Mutex::new(Vec::new())),
57			return_values: Arc::new(Mutex::new(VecDeque::new())),
58			default_return: Some(default_return),
59		}
60	}
61	/// Queue a return value for the next call
62	///
63	/// # Examples
64	///
65	/// ```
66	/// use reinhardt_testkit::mock::MockFunction;
67	///
68	/// # tokio_test::block_on(async {
69	/// let mock = MockFunction::<i32>::new();
70	/// mock.returns(42).await;
71	///
72	/// let result = mock.call(vec![]).await;
73	/// assert_eq!(result, Some(42));
74	/// # });
75	/// ```
76	pub async fn returns(&self, value: T) {
77		self.return_values.lock().await.push_back(value);
78	}
79	/// Queue multiple return values for sequential calls
80	///
81	/// # Examples
82	///
83	/// ```
84	/// use reinhardt_testkit::mock::MockFunction;
85	///
86	/// # tokio_test::block_on(async {
87	/// let mock = MockFunction::<i32>::new();
88	/// mock.returns_many(vec![1, 2, 3]).await;
89	///
90	/// assert_eq!(mock.call(vec![]).await, Some(1));
91	/// assert_eq!(mock.call(vec![]).await, Some(2));
92	/// assert_eq!(mock.call(vec![]).await, Some(3));
93	/// # });
94	/// ```
95	pub async fn returns_many(&self, values: Vec<T>) {
96		let mut queue = self.return_values.lock().await;
97		for value in values {
98			queue.push_back(value);
99		}
100	}
101	/// Record a call and return the next queued value
102	///
103	/// # Examples
104	///
105	/// ```
106	/// use reinhardt_testkit::mock::MockFunction;
107	/// use serde_json::json;
108	///
109	/// # tokio_test::block_on(async {
110	/// let mock = MockFunction::<i32>::new();
111	/// mock.returns(42).await;
112	///
113	/// let result = mock.call(vec![json!("arg1"), json!(123)]).await;
114	/// assert_eq!(result, Some(42));
115	/// assert_eq!(mock.call_count().await, 1);
116	/// # });
117	/// ```
118	pub async fn call(&self, args: Vec<serde_json::Value>) -> Option<T> {
119		let record = CallRecord {
120			args,
121			timestamp: std::time::Instant::now(),
122		};
123		self.calls.lock().await.push(record);
124
125		let mut queue = self.return_values.lock().await;
126		queue.pop_front().or_else(|| self.default_return.clone())
127	}
128	/// Get the number of times the function has been called
129	///
130	/// # Examples
131	///
132	/// ```
133	/// use reinhardt_testkit::mock::MockFunction;
134	///
135	/// # tokio_test::block_on(async {
136	/// let mock = MockFunction::<i32>::new();
137	/// assert_eq!(mock.call_count().await, 0);
138	///
139	/// mock.call(vec![]).await;
140	/// assert_eq!(mock.call_count().await, 1);
141	/// # });
142	/// ```
143	pub async fn call_count(&self) -> usize {
144		self.calls.lock().await.len()
145	}
146	/// Check if the function has been called at least once
147	///
148	/// # Examples
149	///
150	/// ```
151	/// use reinhardt_testkit::mock::MockFunction;
152	///
153	/// # tokio_test::block_on(async {
154	/// let mock = MockFunction::<i32>::new();
155	/// assert!(!mock.was_called().await);
156	///
157	/// mock.call(vec![]).await;
158	/// assert!(mock.was_called().await);
159	/// # });
160	/// ```
161	pub async fn was_called(&self) -> bool {
162		self.call_count().await > 0
163	}
164	/// Check if the function was called with specific arguments
165	///
166	/// # Examples
167	///
168	/// ```
169	/// use reinhardt_testkit::mock::MockFunction;
170	/// use serde_json::json;
171	///
172	/// # tokio_test::block_on(async {
173	/// let mock = MockFunction::<i32>::new();
174	/// mock.call(vec![json!("test"), json!(42)]).await;
175	///
176	/// assert!(mock.was_called_with(vec![json!("test"), json!(42)]).await);
177	/// assert!(!mock.was_called_with(vec![json!("other")]).await);
178	/// # });
179	/// ```
180	pub async fn was_called_with(&self, args: Vec<serde_json::Value>) -> bool {
181		let calls = self.calls.lock().await;
182		calls.iter().any(|record| record.args == args)
183	}
184	/// Get all call records for inspection
185	///
186	/// # Examples
187	///
188	/// ```
189	/// use reinhardt_testkit::mock::MockFunction;
190	/// use serde_json::json;
191	///
192	/// # tokio_test::block_on(async {
193	/// let mock = MockFunction::<i32>::new();
194	/// mock.call(vec![json!("arg1")]).await;
195	/// mock.call(vec![json!("arg2")]).await;
196	///
197	/// let calls = mock.get_calls().await;
198	/// assert_eq!(calls.len(), 2);
199	/// assert_eq!(calls[0].args, vec![json!("arg1")]);
200	/// # });
201	/// ```
202	pub async fn get_calls(&self) -> Vec<CallRecord> {
203		self.calls.lock().await.clone()
204	}
205	/// Reset the mock to its initial state
206	///
207	/// # Examples
208	///
209	/// ```
210	/// use reinhardt_testkit::mock::MockFunction;
211	///
212	/// # tokio_test::block_on(async {
213	/// let mock = MockFunction::<i32>::new();
214	/// mock.call(vec![]).await;
215	/// assert_eq!(mock.call_count().await, 1);
216	///
217	/// mock.reset().await;
218	/// assert_eq!(mock.call_count().await, 0);
219	/// # });
220	/// ```
221	pub async fn reset(&self) {
222		self.calls.lock().await.clear();
223		self.return_values.lock().await.clear();
224	}
225	/// Get the arguments from the last function call
226	///
227	/// # Examples
228	///
229	/// ```
230	/// use reinhardt_testkit::mock::MockFunction;
231	/// use serde_json::json;
232	///
233	/// # tokio_test::block_on(async {
234	/// let mock = MockFunction::<i32>::new();
235	/// mock.call(vec![json!("first")]).await;
236	/// mock.call(vec![json!("last")]).await;
237	///
238	/// let last_args = mock.last_call_args().await;
239	/// assert_eq!(last_args, Some(vec![json!("last")]));
240	/// # });
241	/// ```
242	pub async fn last_call_args(&self) -> Option<Vec<serde_json::Value>> {
243		self.calls.lock().await.last().map(|r| r.args.clone())
244	}
245}
246
247impl<T: Clone> Default for MockFunction<T> {
248	fn default() -> Self {
249		Self::new()
250	}
251}
252
253/// Spy for tracking method calls with arguments
254pub struct Spy<T> {
255	inner: Option<T>,
256	calls: Arc<Mutex<Vec<CallRecord>>>,
257}
258
259impl<T> Spy<T> {
260	/// Create a new spy without wrapping any object
261	///
262	/// # Examples
263	///
264	/// ```
265	/// use reinhardt_testkit::mock::Spy;
266	///
267	/// let spy = Spy::<String>::new();
268	/// assert!(spy.inner().is_none());
269	/// ```
270	pub fn new() -> Self {
271		Self {
272			inner: None,
273			calls: Arc::new(Mutex::new(Vec::new())),
274		}
275	}
276	/// Create a spy that wraps an existing object
277	///
278	/// # Examples
279	///
280	/// ```
281	/// use reinhardt_testkit::mock::Spy;
282	///
283	/// let value = "test".to_string();
284	/// let spy = Spy::wrap(value);
285	/// assert!(spy.inner().is_some());
286	/// ```
287	pub fn wrap(inner: T) -> Self {
288		Self {
289			inner: Some(inner),
290			calls: Arc::new(Mutex::new(Vec::new())),
291		}
292	}
293	/// Record a method call with arguments
294	///
295	/// # Examples
296	///
297	/// ```
298	/// use reinhardt_testkit::mock::Spy;
299	/// use serde_json::json;
300	///
301	/// # tokio_test::block_on(async {
302	/// let spy = Spy::<String>::new();
303	/// spy.record_call(vec![json!("arg1"), json!(42)]).await;
304	/// assert_eq!(spy.call_count().await, 1);
305	/// # });
306	/// ```
307	pub async fn record_call(&self, args: Vec<serde_json::Value>) {
308		let record = CallRecord {
309			args,
310			timestamp: std::time::Instant::now(),
311		};
312		self.calls.lock().await.push(record);
313	}
314	/// Get the total number of recorded calls.
315	pub async fn call_count(&self) -> usize {
316		self.calls.lock().await.len()
317	}
318	/// Check if the spy was called at least once.
319	pub async fn was_called(&self) -> bool {
320		self.call_count().await > 0
321	}
322	/// Check if the spy was called with specific arguments
323	///
324	/// # Examples
325	///
326	/// ```
327	/// use reinhardt_testkit::mock::Spy;
328	/// use serde_json::json;
329	///
330	/// # tokio_test::block_on(async {
331	/// let spy = Spy::<String>::new();
332	/// spy.record_call(vec![json!("test")]).await;
333	///
334	/// assert!(spy.was_called_with(vec![json!("test")]).await);
335	/// assert!(!spy.was_called_with(vec![json!("other")]).await);
336	/// # });
337	/// ```
338	pub async fn was_called_with(&self, args: Vec<serde_json::Value>) -> bool {
339		let calls = self.calls.lock().await;
340		calls.iter().any(|record| record.args == args)
341	}
342	/// Get all recorded call records.
343	pub async fn get_calls(&self) -> Vec<CallRecord> {
344		self.calls.lock().await.clone()
345	}
346	/// Get the arguments from the last recorded call, if any.
347	pub async fn last_call_args(&self) -> Option<Vec<serde_json::Value>> {
348		self.calls.lock().await.last().map(|r| r.args.clone())
349	}
350	/// Reset the spy by clearing all call records
351	///
352	/// # Examples
353	///
354	/// ```
355	/// use reinhardt_testkit::mock::Spy;
356	/// use serde_json::json;
357	///
358	/// # tokio_test::block_on(async {
359	/// let spy = Spy::<String>::new();
360	/// spy.record_call(vec![json!("test")]).await;
361	/// assert_eq!(spy.call_count().await, 1);
362	///
363	/// spy.reset().await;
364	/// assert_eq!(spy.call_count().await, 0);
365	/// # });
366	/// ```
367	pub async fn reset(&self) {
368		self.calls.lock().await.clear();
369	}
370	/// Get a reference to the wrapped inner value, if any.
371	pub fn inner(&self) -> Option<&T> {
372		self.inner.as_ref()
373	}
374	/// Consume the spy and return the wrapped inner value, if any.
375	pub fn into_inner(self) -> Option<T> {
376		self.inner
377	}
378}
379
380impl<T> Default for Spy<T> {
381	fn default() -> Self {
382		Self::new()
383	}
384}
385
386// ============================================================================
387// Handler Mocks
388// ============================================================================
389
390/// Simple handler wrapper for testing
391///
392/// Provides a convenient way to create handlers from closures for testing purposes.
393/// The handler function can be any closure that takes a `Request` and returns a
394/// [`Result<Response>`].
395///
396/// # Examples
397///
398/// ## Basic usage
399///
400/// ```no_run
401/// use reinhardt_testkit::mock::SimpleHandler;
402/// use reinhardt_http::{Request, Response};
403/// use reinhardt_http::Handler;
404///
405/// let handler = SimpleHandler::new(|req: Request| {
406///     Ok(Response::ok().with_body("Hello, World!"))
407/// });
408///
409/// // Use handler in tests
410/// ```
411///
412/// ## With path-based routing
413///
414/// ```no_run
415/// use reinhardt_testkit::mock::SimpleHandler;
416/// use reinhardt_http::{Request, Response};
417///
418/// let handler = SimpleHandler::new(|req: Request| {
419///     match req.path() {
420///         "/" => Ok(Response::ok().with_body("Home")),
421///         "/api" => Ok(Response::ok().with_body(r#"{"status": "ok"}"#)),
422///         _ => Ok(Response::not_found().with_body("Not Found")),
423///     }
424/// });
425/// ```
426///
427/// ## With custom logic
428///
429/// ```no_run
430/// use reinhardt_testkit::mock::SimpleHandler;
431/// use reinhardt_http::{Request, Response};
432/// use std::sync::{Arc, Mutex};
433///
434/// let call_count = Arc::new(Mutex::new(0));
435/// let call_count_clone = call_count.clone();
436///
437/// let handler = SimpleHandler::new(move |req: Request| {
438///     let mut count = call_count_clone.lock().unwrap();
439///     *count += 1;
440///     Ok(Response::ok().with_body(format!("Call count: {}", *count)))
441/// });
442/// ```
443pub struct SimpleHandler<F>
444where
445	F: Fn(reinhardt_http::Request) -> reinhardt_http::Result<reinhardt_http::Response>
446		+ Send
447		+ Sync
448		+ 'static,
449{
450	handler_fn: F,
451}
452
453impl<F> SimpleHandler<F>
454where
455	F: Fn(reinhardt_http::Request) -> reinhardt_http::Result<reinhardt_http::Response>
456		+ Send
457		+ Sync
458		+ 'static,
459{
460	/// Create a new SimpleHandler with the given handler function
461	///
462	/// # Arguments
463	///
464	/// * `handler_fn` - A closure that processes requests and returns responses
465	///
466	/// # Examples
467	///
468	/// ```no_run
469	/// use reinhardt_testkit::mock::SimpleHandler;
470	/// use reinhardt_http::{Request, Response};
471	///
472	/// let handler = SimpleHandler::new(|req| {
473	///     Ok(Response::ok().with_body("Success"))
474	/// });
475	/// ```
476	pub fn new(handler_fn: F) -> Self {
477		Self { handler_fn }
478	}
479}
480
481#[async_trait::async_trait]
482impl<F> reinhardt_http::Handler for SimpleHandler<F>
483where
484	F: Fn(reinhardt_http::Request) -> reinhardt_http::Result<reinhardt_http::Response>
485		+ Send
486		+ Sync
487		+ 'static,
488{
489	async fn handle(
490		&self,
491		request: reinhardt_http::Request,
492	) -> reinhardt_http::Result<reinhardt_http::Response> {
493		(self.handler_fn)(request)
494	}
495}
496
#[cfg(test)]
mod tests {
	use super::*;
	use serde_json::json;

	/// Queued values come back in FIFO order and every call is counted.
	#[tokio::test]
	async fn test_mock_function() {
		let mock = MockFunction::<i32>::new();
		for queued in [42, 100] {
			mock.returns(queued).await;
		}

		let first = mock.call(vec![json!(1)]).await;
		let second = mock.call(vec![json!(2)]).await;
		assert_eq!((first, second), (Some(42), Some(100)));

		assert_eq!(mock.call_count().await, 2);
		assert!(mock.was_called().await);
	}

	/// With an empty queue the configured default is returned.
	#[tokio::test]
	async fn test_mock_default() {
		let mock = MockFunction::with_default(99);
		assert_eq!(mock.call(vec![]).await, Some(99));
	}

	/// A spy records each call and can match on exact arguments.
	#[tokio::test]
	async fn test_spy() {
		let spy = Spy::<String>::new();
		for arg in ["arg1", "arg2"] {
			spy.record_call(vec![json!(arg)]).await;
		}

		assert_eq!(spy.call_count().await, 2);
		assert!(spy.was_called().await);
		assert!(spy.was_called_with(vec![json!("arg1")]).await);
	}

	/// Resetting a mock drops its recorded calls.
	#[tokio::test]
	async fn test_mock_reset() {
		let mock: MockFunction<i32> = MockFunction::new();
		mock.call(Vec::new()).await;
		assert_eq!(mock.call_count().await, 1);

		mock.reset().await;
		assert_eq!(mock.call_count().await, 0);
	}
}