1use std::collections::HashMap;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::{Arc, Mutex};
6use std::thread::{self, JoinHandle};
7use std::time::Duration;
8
/// A single HTTP request captured by one of the loopback servers.
#[derive(Clone, Debug)]
pub struct RecordedRequest {
    /// Request method exactly as it appeared on the request line (e.g. `POST`).
    pub method: String,
    /// Request target with any query string removed.
    pub path: String,
    /// Header `(name, value)` pairs in arrival order.
    pub headers: Vec<(String, String)>,
    /// Request body bytes (already de-chunked when chunked transfer coding was used).
    pub body: Vec<u8>,
}

impl RecordedRequest {
    /// Returns the body decoded as UTF-8, replacing invalid sequences with U+FFFD.
    pub fn body_text(&self) -> String {
        String::from_utf8_lossy(&self.body).into_owned()
    }

    /// Returns the value of the first header whose name equals `key`, ignoring ASCII case.
    ///
    /// `eq_ignore_ascii_case` avoids allocating a lowercased copy of `key` and also matches
    /// headers whose stored name is not lowercase (e.g. a hand-built `RecordedRequest`).
    pub fn header_value(&self, key: &str) -> Option<String> {
        self.headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(key))
            .map(|(_, value)| value.clone())
    }
}
30
/// A canned HTTP response served by the loopback servers.
#[derive(Debug, Clone)]
pub struct HttpResponse {
    /// Numeric status code placed on the status line.
    pub status: u16,
    /// Response body, written verbatim after the headers.
    pub body: String,
    /// Extra `(name, value)` headers, written in insertion order.
    pub headers: Vec<(String, String)>,
}

impl HttpResponse {
    /// Creates a response with the given status and body and no extra headers.
    pub fn new(status: u16, body: impl Into<String>) -> Self {
        let body: String = body.into();
        HttpResponse {
            status,
            body,
            headers: vec![],
        }
    }

    /// Appends one header and hands the response back for builder-style chaining.
    pub fn with_header(self, key: &str, value: &str) -> Self {
        let mut headers = self.headers;
        headers.push((String::from(key), String::from(value)));
        HttpResponse { headers, ..self }
    }
}
52
/// Lookup key for a canned route: the HTTP method (uppercased by the code that builds keys)
/// together with the request path.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct RouteKey {
    method: String,
    path: String,
}
58
59pub struct LoopbackServer {
60 addr: SocketAddr,
61 routes: Arc<Mutex<HashMap<RouteKey, HttpResponse>>>,
62 requests: Arc<Mutex<Vec<RecordedRequest>>>,
63 stop: Arc<AtomicBool>,
64 handle: Option<JoinHandle<()>>,
65}
66
67impl LoopbackServer {
68 pub fn new() -> io::Result<Self> {
69 let listener = TcpListener::bind("127.0.0.1:0")?;
70 listener.set_nonblocking(true)?;
71 let addr = listener.local_addr()?;
72
73 let routes = Arc::new(Mutex::new(HashMap::new()));
74 let requests = Arc::new(Mutex::new(Vec::new()));
75 let stop = Arc::new(AtomicBool::new(false));
76
77 let routes_t = Arc::clone(&routes);
78 let requests_t = Arc::clone(&requests);
79 let stop_t = Arc::clone(&stop);
80
81 let handle = thread::spawn(move || {
82 while !stop_t.load(Ordering::SeqCst) {
83 match listener.accept() {
84 Ok((mut stream, _)) => {
85 let _ = handle_connection(&mut stream, &routes_t, &requests_t);
86 }
87 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
88 thread::yield_now();
89 }
90 Err(_) => break,
91 }
92 }
93 });
94
95 Ok(Self {
96 addr,
97 routes,
98 requests,
99 stop,
100 handle: Some(handle),
101 })
102 }
103
104 pub fn url(&self) -> String {
105 format!("http://{}", self.addr)
106 }
107
108 pub fn add_route(&self, method: &str, path: &str, response: HttpResponse) {
109 let key = RouteKey {
110 method: method.trim().to_ascii_uppercase(),
111 path: path.trim().to_string(),
112 };
113 let mut routes = self.routes.lock().expect("routes lock");
114 routes.insert(key, response);
115 }
116
117 pub fn take_requests(&self) -> Vec<RecordedRequest> {
118 let mut guard = self.requests.lock().expect("requests lock");
119 std::mem::take(&mut *guard)
120 }
121}
122
123impl Drop for LoopbackServer {
124 fn drop(&mut self) {
125 self.stop.store(true, Ordering::SeqCst);
126 if let Some(handle) = self.handle.take() {
127 let _ = handle.join();
128 }
129 }
130}
131
132fn handle_connection(
133 stream: &mut TcpStream,
134 routes: &Arc<Mutex<HashMap<RouteKey, HttpResponse>>>,
135 requests: &Arc<Mutex<Vec<RecordedRequest>>>,
136) -> io::Result<()> {
137 let request = read_request(stream)?;
138 requests
139 .lock()
140 .expect("requests lock")
141 .push(request.clone());
142
143 let route_key = RouteKey {
144 method: request.method.to_ascii_uppercase(),
145 path: request.path.clone(),
146 };
147
148 let response = routes
149 .lock()
150 .expect("routes lock")
151 .get(&route_key)
152 .cloned()
153 .unwrap_or_else(|| HttpResponse::new(404, "not found"));
154
155 write_response(stream, response)?;
156 Ok(())
157}
158
159fn read_request(stream: &mut TcpStream) -> io::Result<RecordedRequest> {
160 stream.set_nonblocking(false)?;
161 stream.set_read_timeout(Some(Duration::from_secs(2)))?;
162 let mut buffer = Vec::new();
163 let mut temp = [0u8; 8192];
164
165 loop {
166 let n = match stream.read(&mut temp) {
167 Ok(n) => n,
168 Err(err) if err.kind() == io::ErrorKind::WouldBlock => 0,
169 Err(err) if err.kind() == io::ErrorKind::TimedOut => 0,
170 Err(err) => return Err(err),
171 };
172 if n == 0 {
173 break;
174 }
175 buffer.extend_from_slice(&temp[..n]);
176 if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
177 break;
178 }
179 if buffer.len() > 1024 * 1024 {
180 break;
181 }
182 }
183
184 let (method, path, headers, rest) = parse_headers_and_rest(&buffer);
185
186 if header_value(&headers, "expect")
187 .is_some_and(|v| v.to_ascii_lowercase().contains("100-continue"))
188 {
189 let _ = stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n");
190 let _ = stream.flush();
191 }
192
193 let body = if let Some(te) =
194 header_value(&headers, "transfer-encoding").map(|v| v.to_ascii_lowercase())
195 {
196 if te.contains("chunked") {
197 read_chunked_body(stream, rest)
198 } else {
199 Vec::new()
200 }
201 } else if let Some(cl) = header_value(&headers, "content-length") {
202 let len = cl.parse::<usize>().unwrap_or(0);
203 read_exact_bytes(stream, rest, len)
204 } else {
205 rest
206 };
207
208 Ok(RecordedRequest {
209 method,
210 path,
211 headers,
212 body,
213 })
214}
215
/// Splits a raw request head into `(method, path, headers, rest)`.
///
/// The head ends at the first `\r\n\r\n`. If there is none, the whole buffer is treated as
/// the head and `rest` is empty. Method and target come from the first line and default to
/// `GET` and `/` when missing; any `?query` is cut from the target. Header names are trimmed
/// and lowercased and values trimmed; lines without a `:` are skipped. `rest` holds the bytes
/// after the terminator (the start of the body, if any).
fn parse_headers_and_rest(buffer: &[u8]) -> (String, String, Vec<(String, String)>, Vec<u8>) {
    const TERMINATOR: &[u8] = b"\r\n\r\n";

    let terminator_at = buffer
        .windows(TERMINATOR.len())
        .position(|window| window == TERMINATOR);
    let head_len = terminator_at.unwrap_or(buffer.len());
    let head = String::from_utf8_lossy(&buffer[..head_len]);

    let mut lines = head.split("\r\n");
    let request_line = lines.next().unwrap_or("");
    let mut tokens = request_line.split_whitespace();
    let method = String::from(tokens.next().unwrap_or("GET"));
    let target = tokens.next().unwrap_or("/");
    let path = String::from(target.split('?').next().unwrap_or("/"));

    // Empty lines contain no ':' and are therefore dropped by `split_once` as well.
    let headers: Vec<(String, String)> = lines
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_ascii_lowercase(), value.trim().to_owned()))
        .collect();

    let rest = match terminator_at {
        Some(at) => buffer[at + TERMINATOR.len()..].to_vec(),
        None => Vec::new(),
    };

    (method, path, headers, rest)
}
255
/// Returns the value of the first header whose name equals `key`, ignoring ASCII case.
///
/// `eq_ignore_ascii_case` avoids allocating a lowercased copy of `key`. It also matches
/// stored names that are not lowercase, instead of relying on every producer lowercasing them.
fn header_value(headers: &[(String, String)], key: &str) -> Option<String> {
    headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case(key))
        .map(|(_, value)| value.clone())
}
263
/// Collects up to `want` body bytes: first those in `already`, then the remainder from
/// `stream`.
///
/// Surplus bytes in `already` are truncated away. Reading stops early on end-of-stream or on
/// any read error other than `Interrupted` (which includes an expired read timeout). In that
/// case the shorter body is returned.
fn read_exact_bytes(stream: &mut TcpStream, mut already: Vec<u8>, want: usize) -> Vec<u8> {
    // One scratch buffer for the whole read instead of a fresh allocation per iteration.
    let mut chunk = [0u8; 8192];
    while already.len() < want {
        // Never read past the body: the stream may carry nothing else we should consume.
        let limit = (want - already.len()).min(chunk.len());
        match stream.read(&mut chunk[..limit]) {
            Ok(0) => break,
            Ok(n) => already.extend_from_slice(&chunk[..n]),
            // A signal interrupted the call before any data was read; retry rather than
            // silently truncating the body.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
            Err(_) => break,
        }
    }
    already.truncate(want);
    already
}
276
/// Removes the next CRLF-terminated line from the front of `buf`, reading more from `stream`
/// as needed.
///
/// The returned line excludes the `\r\n`. One additional trailing `\r` (as in `\r\r\n`) is
/// also stripped. Returns `None` on end-of-stream, or on a read error other than
/// `Interrupted` (including an expired read timeout), before a full line is available. Any
/// partial line is left in `buf`.
fn take_line(buf: &mut Vec<u8>, stream: &mut TcpStream) -> Option<Vec<u8>> {
    let mut chunk = [0u8; 4096];
    // Offset from which `buf` still needs searching; earlier bytes are known to hold no CRLF.
    let mut scanned = 0;
    loop {
        if let Some(offset) = buf[scanned..].windows(2).position(|w| w == b"\r\n") {
            let end = scanned + offset;
            let mut line: Vec<u8> = buf.drain(..end).collect();
            buf.drain(..2);
            if line.ends_with(b"\r") {
                line.pop();
            }
            return Some(line);
        }

        // Keep the final byte in the next search: it may be the `\r` of a CRLF split across
        // two reads.
        scanned = buf.len().saturating_sub(1);
        match stream.read(&mut chunk) {
            Ok(0) => return None,
            Ok(n) => buf.extend_from_slice(&chunk[..n]),
            // Interrupted before any data arrived: retry instead of reporting end-of-input.
            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(_) => return None,
        }
    }
}
296
297fn read_chunked_body(stream: &mut TcpStream, mut buf: Vec<u8>) -> Vec<u8> {
298 let mut body = Vec::new();
299 while let Some(line) = take_line(&mut buf, stream) {
300 let line_str = String::from_utf8_lossy(&line);
301 let size_str = line_str.split(';').next().unwrap_or("").trim();
302 let Ok(size) = usize::from_str_radix(size_str, 16) else {
303 break;
304 };
305 if size == 0 {
306 while let Some(l) = take_line(&mut buf, stream) {
307 if l.is_empty() {
308 break;
309 }
310 }
311 break;
312 }
313
314 while buf.len() < size + 2 {
315 let mut tmp = [0u8; 8192];
316 let n = stream.read(&mut tmp).unwrap_or_default();
317 if n == 0 {
318 break;
319 }
320 buf.extend_from_slice(&tmp[..n]);
321 }
322
323 if buf.len() < size + 2 {
324 break;
325 }
326 body.extend_from_slice(&buf[..size]);
327 buf.drain(..size + 2);
328 }
329 body
330}
331
332pub struct TestServer {
333 addr: SocketAddr,
334 requests: Arc<Mutex<Vec<RecordedRequest>>>,
335 stop: Arc<AtomicBool>,
336 handle: Option<JoinHandle<()>>,
337}
338
339impl TestServer {
340 pub fn new<F>(handler: F) -> io::Result<Self>
341 where
342 F: Fn(&RecordedRequest) -> HttpResponse + Send + Sync + 'static,
343 {
344 let listener = TcpListener::bind("127.0.0.1:0")?;
345 listener.set_nonblocking(true)?;
346 let addr = listener.local_addr()?;
347
348 let requests = Arc::new(Mutex::new(Vec::new()));
349 let stop = Arc::new(AtomicBool::new(false));
350 let handler = Arc::new(handler);
351
352 let requests_t = Arc::clone(&requests);
353 let stop_t = Arc::clone(&stop);
354 let handler_t = Arc::clone(&handler);
355
356 let handle = thread::spawn(move || {
357 while !stop_t.load(Ordering::SeqCst) {
358 match listener.accept() {
359 Ok((mut stream, _)) => {
360 if let Ok(request) = read_request(&mut stream) {
361 let response = handler_t(&request);
362 requests_t.lock().expect("requests lock").push(request);
363 let _ = write_response(&mut stream, response);
364 }
365 }
366 Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
367 thread::yield_now();
368 }
369 Err(_) => break,
370 }
371 }
372 });
373
374 Ok(Self {
375 addr,
376 requests,
377 stop,
378 handle: Some(handle),
379 })
380 }
381
382 pub fn url(&self) -> String {
383 format!("http://{}", self.addr)
384 }
385
386 pub fn take_requests(&self) -> Vec<RecordedRequest> {
387 let mut guard = self.requests.lock().expect("requests lock");
388 std::mem::take(&mut *guard)
389 }
390}
391
392impl Drop for TestServer {
393 fn drop(&mut self) {
394 self.stop.store(true, Ordering::SeqCst);
395 if let Some(handle) = self.handle.take() {
396 let _ = handle.join();
397 }
398 }
399}
400
401fn write_response(stream: &mut TcpStream, response: HttpResponse) -> io::Result<()> {
402 let status_text = match response.status {
403 200 => "OK",
404 201 => "Created",
405 202 => "Accepted",
406 204 => "No Content",
407 400 => "Bad Request",
408 401 => "Unauthorized",
409 403 => "Forbidden",
410 404 => "Not Found",
411 500 => "Internal Server Error",
412 _ => "OK",
413 };
414
415 let body = response.body;
416 let mut headers = response.headers;
417 headers.push(("Content-Length".to_string(), body.len().to_string()));
418
419 let mut out = String::new();
420 out.push_str(&format!("HTTP/1.1 {} {}\r\n", response.status, status_text));
421 for (k, v) in headers {
422 out.push_str(&format!("{k}: {v}\r\n"));
423 }
424 out.push_str("\r\n");
425 out.push_str(&body);
426
427 stream.write_all(out.as_bytes())?;
428 stream.flush()?;
429 Ok(())
430}