// rustapi_testing/server.rs

use std::net::SocketAddr;
use std::sync::{Arc, Mutex, PoisonError};

use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::service::service_fn;
use hyper::{Request, Response, StatusCode};
use hyper_util::rt::TokioIo;
use tokio::net::TcpListener;
use tokio::sync::oneshot;

use super::expectation::{Expectation, MockResponse, Times};
use super::matcher::RequestMatcher;

/// Boxed, thread-safe error type used by the request handler.
type GenericError = Box<dyn std::error::Error + Send + Sync + 'static>;
/// Module-local `Result` whose error type is always [`GenericError`].
type Result<T> = core::result::Result<T, GenericError>;

16/// A mock HTTP server
17pub struct MockServer {
18    addr: SocketAddr,
19    state: Arc<Mutex<ServerState>>,
20    shutdown_tx: Option<oneshot::Sender<()>>,
21}
22
23struct ServerState {
24    expectations: Vec<Expectation>,
25    unmatched_requests: Vec<RecordedRequest>,
26}
27
28#[derive(Debug, Clone)]
29pub struct RecordedRequest {
30    pub method: http::Method,
31    pub path: String,
32    pub headers: http::HeaderMap,
33    pub body: Bytes,
34}
35
36impl MockServer {
37    /// Start a new mock server on a random port
38    pub async fn start() -> Self {
39        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
40        let addr = listener.local_addr().unwrap();
41
42        let state = Arc::new(Mutex::new(ServerState {
43            expectations: Vec::new(),
44            unmatched_requests: Vec::new(),
45        }));
46
47        let state_clone = state.clone();
48        let (shutdown_tx, shutdown_rx) = oneshot::channel();
49
50        tokio::spawn(async move {
51            let mut stop_future = shutdown_rx;
52
53            loop {
54                tokio::select! {
55                    res = listener.accept() => {
56                        match res {
57                            Ok((stream, _)) => {
58                                let io = TokioIo::new(stream);
59                                let state = state_clone.clone();
60
61                                tokio::spawn(async move {
62                                    if let Err(err) = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
63                                        .serve_connection(io, service_fn(move |req| handle_request(req, state.clone())))
64                                        .await
65                                    {
66                                        eprintln!("Error serving connection: {:?}", err);
67                                    }
68                                });
69                            }
70                            Err(e) => eprintln!("Accept error: {}", e),
71                        }
72                    }
73                    _ = &mut stop_future => {
74                        break;
75                    }
76                }
77            }
78        });
79
80        Self {
81            addr,
82            state,
83            shutdown_tx: Some(shutdown_tx),
84        }
85    }
86
87    /// Get the base URL of the server
88    pub fn kind_url(&self) -> String {
89        format!("http://{}", self.addr)
90    }
91
92    /// Alias for kind_url but more standard name
93    pub fn base_url(&self) -> String {
94        self.kind_url()
95    }
96
97    /// Get requests that didn't match any expectation
98    pub fn unmatched_requests(&self) -> Vec<RecordedRequest> {
99        let state = self.state.lock().unwrap();
100        state.unmatched_requests.clone()
101    }
102
103    /// Add an expectation
104    pub fn expect(&self, matcher: RequestMatcher) -> ExpectationBuilder {
105        ExpectationBuilder {
106            server: self.state.clone(),
107            expectation: Some(Expectation::new(matcher)),
108        }
109    }
110
111    /// Verify that all expectations were met
112    pub fn verify(&self) {
113        let state = self.state.lock().unwrap();
114        for exp in &state.expectations {
115            match exp.times {
116                Times::Once => assert_eq!(
117                    exp.call_count, 1,
118                    "Expectation {:?} expected 1 call, got {}",
119                    exp.matcher, exp.call_count
120                ),
121                Times::Exactly(n) => assert_eq!(
122                    exp.call_count, n,
123                    "Expectation {:?} expected {} calls, got {}",
124                    exp.matcher, n, exp.call_count
125                ),
126                Times::AtLeast(n) => assert!(
127                    exp.call_count >= n,
128                    "Expectation {:?} expected at least {} calls, got {}",
129                    exp.matcher,
130                    n,
131                    exp.call_count
132                ),
133                Times::AtMost(n) => assert!(
134                    exp.call_count <= n,
135                    "Expectation {:?} expected at most {} calls, got {}",
136                    exp.matcher,
137                    n,
138                    exp.call_count
139                ),
140                Times::Any => {}
141            }
142        }
143    }
144}
145
146impl Drop for MockServer {
147    fn drop(&mut self) {
148        if let Some(tx) = self.shutdown_tx.take() {
149            let _ = tx.send(());
150        }
151    }
152}
153
154pub struct ExpectationBuilder {
155    server: Arc<Mutex<ServerState>>,
156    expectation: Option<Expectation>,
157}
158
159impl ExpectationBuilder {
160    pub fn respond_with(mut self, response: MockResponse) -> Self {
161        if let Some(exp) = self.expectation.as_mut() {
162            exp.response = response;
163        }
164        self
165    }
166
167    pub fn times(mut self, n: usize) -> Self {
168        if let Some(exp) = self.expectation.as_mut() {
169            exp.times = Times::Exactly(n);
170        }
171        self
172    }
173
174    pub fn once(mut self) -> Self {
175        if let Some(exp) = self.expectation.as_mut() {
176            exp.times = Times::Once;
177        }
178        self
179    }
180
181    pub fn at_least_once(mut self) -> Self {
182        if let Some(exp) = self.expectation.as_mut() {
183            exp.times = Times::AtLeast(1);
184        }
185        self
186    }
187
188    pub fn never(mut self) -> Self {
189        if let Some(exp) = self.expectation.as_mut() {
190            exp.times = Times::Exactly(0);
191        }
192        self
193    }
194}
195
196impl Drop for ExpectationBuilder {
197    fn drop(&mut self) {
198        if let Some(exp) = self.expectation.take() {
199            let mut state = self.server.lock().unwrap();
200            state.expectations.push(exp);
201        }
202    }
203}
204
205async fn handle_request(
206    req: Request<hyper::body::Incoming>,
207    state: Arc<Mutex<ServerState>>,
208) -> Result<Response<Full<Bytes>>> {
209    // Read the full body
210    let (parts, body) = req.into_parts();
211    let body_bytes = body.collect().await?.to_bytes();
212
213    let mut state_guard = state.lock().unwrap();
214
215    // Find matching expectation
216    // We iterate in reverse to prioritize later expectations (override)
217    let matching_idx = state_guard
218        .expectations
219        .iter()
220        .enumerate()
221        .rev()
222        .find(|(_, exp)| {
223            exp.matcher
224                .matches(&parts.method, parts.uri.path(), &parts.headers, &body_bytes)
225        })
226        .map(|(i, _)| i);
227
228    if let Some(idx) = matching_idx {
229        let exp = &mut state_guard.expectations[idx];
230        exp.call_count += 1;
231
232        let resp_def = &exp.response;
233        let mut response = Response::builder().status(resp_def.status);
234
235        for (k, v) in &resp_def.headers {
236            response = response.header(k, v);
237        }
238
239        Ok(response.body(Full::new(resp_def.body.clone()))?)
240    } else {
241        // Record unmatched
242        state_guard.unmatched_requests.push(RecordedRequest {
243            method: parts.method,
244            path: parts.uri.path().to_string(),
245            headers: parts.headers,
246            body: body_bytes,
247        });
248
249        Ok(Response::builder()
250            .status(StatusCode::NOT_FOUND)
251            .body(Full::new(Bytes::from("No expectation matched")))?)
252    }
253}