1use parking_lot::RwLock;
7use std::sync::Arc;
8use wae_types::{WaeError, WaeErrorKind, WaeResult as TestingResult};
9
10#[cfg(feature = "mockall")]
11pub use mockall::{self, __mock_MockObject, __mock_MockStatic, automock, mock, predicate, sequence};
12
/// One recorded invocation of a mock.
#[derive(Clone, Debug)]
pub struct MockCall {
    /// Arguments supplied to the call, in the string form the caller passed them.
    pub args: Vec<String>,
    /// Monotonic clock reading associated with the call.
    pub timestamp: std::time::Instant,
}
21
/// Canned outcome that a mock is configured to produce.
#[derive(Debug)]
pub enum MockResult<T> {
    /// A single value to hand back.
    Return(T),
    /// A failure, carrying the message to report.
    Error(String),
    /// An ordered list of values to hand back.
    Sequence(Vec<T>),
}
32
33impl<T: Clone> MockResult<T> {
34 pub fn return_value(value: T) -> Self {
36 MockResult::Return(value)
37 }
38
39 pub fn error(msg: impl Into<String>) -> Self {
41 MockResult::Error(msg.into())
42 }
43
44 pub fn sequence(values: Vec<T>) -> Self {
46 MockResult::Sequence(values)
47 }
48}
49
50impl<T: Clone> Clone for MockResult<T> {
51 fn clone(&self) -> Self {
52 match self {
53 MockResult::Return(v) => MockResult::Return(v.clone()),
54 MockResult::Error(e) => MockResult::Error(e.clone()),
55 MockResult::Sequence(v) => MockResult::Sequence(v.clone()),
56 }
57 }
58}
59
/// Expectations attached to a mock; every field defaults to `None`.
#[derive(Default, Debug)]
pub struct MockExpectation {
    /// Number of calls the mock is expected to receive, if one was specified.
    pub expected_calls: Option<usize>,
    /// Free-form label for this expectation, if one was specified.
    pub description: Option<String>,
}
68
69impl MockExpectation {
70 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub fn times(mut self, count: usize) -> Self {
77 self.expected_calls = Some(count);
78 self
79 }
80
81 pub fn description(mut self, desc: impl Into<String>) -> Self {
83 self.description = Some(desc.into());
84 self
85 }
86}
87
88pub trait Mock: Send + Sync {
90 fn calls(&self) -> Vec<MockCall>;
92
93 fn call_count(&self) -> usize;
95
96 fn verify(&self) -> TestingResult<()>;
98
99 fn reset(&self);
101}
102
103#[allow(async_fn_in_trait)]
105pub trait AsyncMock: Mock {
106 async fn verify_async(&self) -> TestingResult<()>;
108}
109
110pub struct MockBuilder<T> {
112 result: Option<MockResult<T>>,
113 expectation: MockExpectation,
114 calls: Arc<RwLock<Vec<MockCall>>>,
115}
116
117impl<T: Clone + Send + Sync + 'static> MockBuilder<T> {
118 pub fn new() -> Self {
120 Self { result: None, expectation: MockExpectation::default(), calls: Arc::new(RwLock::new(Vec::new())) }
121 }
122
123 pub fn return_value(mut self, value: T) -> Self {
125 self.result = Some(MockResult::return_value(value));
126 self
127 }
128
129 pub fn error(mut self, msg: impl Into<String>) -> Self {
131 self.result = Some(MockResult::error(msg));
132 self
133 }
134
135 pub fn sequence(mut self, values: Vec<T>) -> Self {
137 self.result = Some(MockResult::sequence(values));
138 self
139 }
140
141 pub fn expect(mut self, expectation: MockExpectation) -> Self {
143 self.expectation = expectation;
144 self
145 }
146
147 pub fn build(self) -> MockFn<T> {
149 MockFn {
150 result: self.result,
151 expectation: self.expectation,
152 calls: self.calls,
153 sequence_index: Arc::new(RwLock::new(0)),
154 }
155 }
156}
157
158impl<T: Clone + Send + Sync + 'static> Default for MockBuilder<T> {
159 fn default() -> Self {
160 Self::new()
161 }
162}
163
164pub struct MockFn<T> {
166 result: Option<MockResult<T>>,
167 expectation: MockExpectation,
168 calls: Arc<RwLock<Vec<MockCall>>>,
169 sequence_index: Arc<RwLock<usize>>,
170}
171
172impl<T: Clone + Send + Sync + 'static> MockFn<T> {
173 pub fn call(&self, args: Vec<String>) -> TestingResult<T> {
175 {
176 let mut calls = self.calls.write();
177 calls.push(MockCall { args, timestamp: std::time::Instant::now() });
178 }
179
180 match &self.result {
181 Some(MockResult::Return(v)) => Ok(v.clone()),
182 Some(MockResult::Error(e)) => Err(WaeError::new(WaeErrorKind::MockError { reason: e.clone() })),
183 Some(MockResult::Sequence(values)) => {
184 let mut idx = self.sequence_index.write();
185 if *idx < values.len() {
186 let value = values[*idx].clone();
187 *idx += 1;
188 Ok(value)
189 }
190 else {
191 Err(WaeError::new(WaeErrorKind::MockError { reason: "Mock sequence exhausted".to_string() }))
192 }
193 }
194 None => Err(WaeError::new(WaeErrorKind::MockError { reason: "No mock result configured".to_string() })),
195 }
196 }
197
198 pub async fn call_async(&self, args: Vec<String>) -> TestingResult<T> {
200 self.call(args)
201 }
202}
203
204impl<T: Clone + Send + Sync + 'static> Mock for MockFn<T> {
205 fn calls(&self) -> Vec<MockCall> {
206 self.calls.read().clone()
207 }
208
209 fn call_count(&self) -> usize {
210 self.calls.read().len()
211 }
212
213 fn verify(&self) -> TestingResult<()> {
214 let actual_calls = self.call_count();
215
216 if let Some(expected) = self.expectation.expected_calls
217 && actual_calls != expected
218 {
219 return Err(WaeError::new(WaeErrorKind::AssertionFailed {
220 message: format!("Expected {} calls, but got {}", expected, actual_calls),
221 }));
222 }
223
224 Ok(())
225 }
226
227 fn reset(&self) {
228 let mut calls = self.calls.write();
229 calls.clear();
230 let mut idx = self.sequence_index.write();
231 *idx = 0;
232 }
233}
234
235impl<T: Clone + Send + Sync + 'static> AsyncMock for MockFn<T> {
236 async fn verify_async(&self) -> TestingResult<()> {
237 self.verify()
238 }
239}
240
241pub fn verify<M: Mock>(mock: &M) -> TestingResult<()> {
243 mock.verify()
244}
245
246pub async fn verify_async<M: AsyncMock>(mock: &M) -> TestingResult<()> {
248 mock.verify_async().await
249}