use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::Mutex;
/// A single invocation captured by [`MockFunction`] or [`Spy`].
#[derive(Clone, Debug)]
pub struct CallRecord {
    /// JSON values that were passed as the call's arguments.
    pub args: Vec<serde_json::Value>,
    /// Instant captured when the record was created.
    pub timestamp: std::time::Instant,
}
/// Mock of a function returning `T`.
///
/// Every invocation made through `call` is appended to a call history, and
/// return values are served front-first from a queue filled via `returns` /
/// `returns_many`. When the queue is empty, a clone of the optional default
/// value is returned instead.
///
/// `Debug` is derived so the mock can be printed in test diagnostics; the
/// derived impl is only available when `T: Debug`.
#[derive(Debug)]
pub struct MockFunction<T> {
    /// History of all recorded invocations, oldest first.
    calls: Arc<Mutex<Vec<CallRecord>>>,
    /// Queued return values, consumed from the front by `call`.
    return_values: Arc<Mutex<VecDeque<T>>>,
    /// Fallback returned (cloned) by `call` once the queue is exhausted.
    default_return: Option<T>,
}
impl<T: Clone> MockFunction<T> {
    /// Creates a mock with no recorded calls, no queued return values and no
    /// default return value.
    pub fn new() -> Self {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let return_values = Arc::new(Mutex::new(VecDeque::new()));
        Self {
            calls,
            return_values,
            default_return: None,
        }
    }

    /// Creates a mock that yields a clone of `default_return` whenever the
    /// queue of return values is empty.
    pub fn with_default(default_return: T) -> Self {
        Self {
            default_return: Some(default_return),
            ..Self::new()
        }
    }

    /// Appends one value to the back of the return-value queue.
    pub async fn returns(&self, value: T) {
        let mut pending = self.return_values.lock().await;
        pending.push_back(value);
    }

    /// Appends all `values`, in order, to the back of the return-value queue.
    pub async fn returns_many(&self, values: Vec<T>) {
        self.return_values.lock().await.extend(values);
    }

    /// Records an invocation with `args` and produces the next return value.
    ///
    /// Returns the front of the queue if one is available, otherwise a clone
    /// of the default value, otherwise `None`.
    pub async fn call(&self, args: Vec<serde_json::Value>) -> Option<T> {
        // Capture the timestamp before waiting on the lock.
        let timestamp = std::time::Instant::now();
        let entry = CallRecord { args, timestamp };
        self.calls.lock().await.push(entry);

        let mut pending = self.return_values.lock().await;
        match pending.pop_front() {
            Some(queued) => Some(queued),
            None => self.default_return.clone(),
        }
    }

    /// Number of invocations recorded so far.
    pub async fn call_count(&self) -> usize {
        let history = self.calls.lock().await;
        history.len()
    }

    /// Whether at least one invocation has been recorded.
    pub async fn was_called(&self) -> bool {
        !self.calls.lock().await.is_empty()
    }

    /// Whether any recorded invocation received exactly `args`.
    pub async fn was_called_with(&self, args: Vec<serde_json::Value>) -> bool {
        let history = self.calls.lock().await;
        for entry in history.iter() {
            if entry.args == args {
                return true;
            }
        }
        false
    }

    /// Returns a snapshot (clone) of the full call history.
    pub async fn get_calls(&self) -> Vec<CallRecord> {
        let history = self.calls.lock().await;
        history.to_vec()
    }

    /// Clears the call history and any queued return values.
    ///
    /// The default return value is left untouched.
    pub async fn reset(&self) {
        {
            let mut history = self.calls.lock().await;
            history.clear();
        }
        let mut pending = self.return_values.lock().await;
        pending.clear();
    }

    /// Arguments of the most recent invocation, or `None` if nothing has been
    /// recorded.
    pub async fn last_call_args(&self) -> Option<Vec<serde_json::Value>> {
        let history = self.calls.lock().await;
        let latest = history.last()?;
        Some(latest.args.clone())
    }
}
impl<T: Clone> Default for MockFunction<T> {
fn default() -> Self {
Self::new()
}
}
/// Call recorder that can optionally hold a wrapped value of type `T`.
///
/// The spy itself never invokes the wrapped value; it only stores it and
/// keeps a history of calls reported through `record_call`.
///
/// `Debug` is derived so the spy can be printed in test diagnostics; the
/// derived impl is only available when `T: Debug`.
#[derive(Debug)]
pub struct Spy<T> {
    /// Wrapped value, present only when the spy was built with `wrap`.
    inner: Option<T>,
    /// History of all recorded invocations, oldest first.
    calls: Arc<Mutex<Vec<CallRecord>>>,
}
impl<T> Spy<T> {
    /// Creates a spy with an empty call history and no wrapped value.
    pub fn new() -> Self {
        let calls = Arc::new(Mutex::new(Vec::new()));
        Self { inner: None, calls }
    }

    /// Creates a spy with an empty call history that holds `inner`.
    pub fn wrap(inner: T) -> Self {
        Self {
            inner: Some(inner),
            ..Self::new()
        }
    }

    /// Appends an invocation with the given `args` to the call history.
    pub async fn record_call(&self, args: Vec<serde_json::Value>) {
        // Capture the timestamp before waiting on the lock.
        let timestamp = std::time::Instant::now();
        let entry = CallRecord { args, timestamp };
        let mut history = self.calls.lock().await;
        history.push(entry);
    }

    /// Number of invocations recorded so far.
    pub async fn call_count(&self) -> usize {
        let history = self.calls.lock().await;
        history.len()
    }

    /// Whether at least one invocation has been recorded.
    pub async fn was_called(&self) -> bool {
        !self.calls.lock().await.is_empty()
    }

    /// Whether any recorded invocation received exactly `args`.
    pub async fn was_called_with(&self, args: Vec<serde_json::Value>) -> bool {
        let history = self.calls.lock().await;
        for entry in history.iter() {
            if entry.args == args {
                return true;
            }
        }
        false
    }

    /// Returns a snapshot (clone) of the full call history.
    pub async fn get_calls(&self) -> Vec<CallRecord> {
        let history = self.calls.lock().await;
        history.to_vec()
    }

    /// Arguments of the most recent invocation, or `None` if nothing has been
    /// recorded.
    pub async fn last_call_args(&self) -> Option<Vec<serde_json::Value>> {
        let history = self.calls.lock().await;
        let latest = history.last()?;
        Some(latest.args.clone())
    }

    /// Clears the call history; the wrapped value is left untouched.
    pub async fn reset(&self) {
        let mut history = self.calls.lock().await;
        history.clear();
    }

    /// Borrows the wrapped value, if the spy holds one.
    pub fn inner(&self) -> Option<&T> {
        self.inner.as_ref()
    }

    /// Consumes the spy and returns the wrapped value, if any.
    pub fn into_inner(self) -> Option<T> {
        let Self { inner, .. } = self;
        inner
    }
}
impl<T> Default for Spy<T> {
fn default() -> Self {
Self::new()
}
}
/// Adapter that stores a plain function or closure so it can be used as a
/// `reinhardt_http::Handler`.
///
/// The `Fn(Request) -> Result<Response> + Send + Sync + 'static` requirement
/// is enforced by the `impl` blocks (the constructor and the `Handler`
/// implementation) rather than repeated on the type definition. The field is
/// private, so values can still only be built through the bounded
/// constructor.
pub struct SimpleHandler<F> {
    /// Callable invoked for every request passed to `handle`.
    handler_fn: F,
}
impl<F> SimpleHandler<F>
where
F: Fn(reinhardt_http::Request) -> reinhardt_http::Result<reinhardt_http::Response>
+ Send
+ Sync
+ 'static,
{
pub fn new(handler_fn: F) -> Self {
Self { handler_fn }
}
}
#[async_trait::async_trait]
impl<F> reinhardt_http::Handler for SimpleHandler<F>
where
    F: Fn(reinhardt_http::Request) -> reinhardt_http::Result<reinhardt_http::Response>
        + Send
        + Sync
        + 'static,
{
    /// Forwards `request` to the stored function and returns its result
    /// unchanged.
    async fn handle(
        &self,
        request: reinhardt_http::Request,
    ) -> reinhardt_http::Result<reinhardt_http::Response> {
        let respond = &self.handler_fn;
        respond(request)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Queued values are returned in insertion order and each call is counted.
    #[tokio::test]
    async fn test_mock_function() {
        let mock: MockFunction<i32> = MockFunction::new();
        mock.returns(42).await;
        mock.returns(100).await;

        let first = mock.call(vec![json!(1)]).await;
        assert_eq!(first, Some(42));

        let second = mock.call(vec![json!(2)]).await;
        assert_eq!(second, Some(100));

        assert_eq!(mock.call_count().await, 2);
        assert!(mock.was_called().await);
    }

    /// With nothing queued, the configured default value is returned.
    #[tokio::test]
    async fn test_mock_default() {
        let mock = MockFunction::with_default(99);
        let outcome = mock.call(Vec::new()).await;
        assert_eq!(outcome, Some(99));
    }

    /// A spy counts recorded calls and can be queried by argument list.
    #[tokio::test]
    async fn test_spy() {
        let spy = Spy::<String>::new();
        spy.record_call(vec![json!("arg1")]).await;
        spy.record_call(vec![json!("arg2")]).await;

        assert_eq!(spy.call_count().await, 2);
        assert!(spy.was_called().await);
        assert!(spy.was_called_with(vec![json!("arg1")]).await);
    }

    /// Resetting a mock empties its call history.
    #[tokio::test]
    async fn test_mock_reset() {
        let mock: MockFunction<i32> = MockFunction::new();
        let _ = mock.call(Vec::new()).await;
        assert_eq!(mock.call_count().await, 1);

        mock.reset().await;
        assert_eq!(mock.call_count().await, 0);
    }
}