Skip to main content

nils_test_support/
http.rs

1use std::collections::HashMap;
2use std::io::{self, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream};
4use std::sync::atomic::{AtomicBool, Ordering};
5use std::sync::{Arc, Mutex};
6use std::thread::{self, JoinHandle};
7use std::time::Duration;
8
/// One HTTP request as captured by the test servers in this module.
///
/// When populated by this module's parser, `headers` holds names lower-cased and
/// values trimmed, in the order received, and `path` has any query string removed.
/// The fields are public, so hand-built values may not follow those conventions;
/// the accessors below therefore do not rely on them.
#[derive(Clone, Debug)]
pub struct RecordedRequest {
    /// Request method as it appeared on the request line.
    pub method: String,
    /// Request target with the query string (`?…`) stripped.
    pub path: String,
    /// Header `(name, value)` pairs in arrival order.
    pub headers: Vec<(String, String)>,
    /// Decoded body bytes (chunked framing already removed).
    pub body: Vec<u8>,
}

impl RecordedRequest {
    /// Returns the body decoded as UTF-8, replacing invalid sequences with U+FFFD.
    pub fn body_text(&self) -> String {
        String::from_utf8_lossy(&self.body).into_owned()
    }

    /// Returns the value of the first header whose name matches `key`,
    /// ignoring ASCII case on both sides, or `None` if there is no such header.
    pub fn header_value(&self, key: &str) -> Option<String> {
        self.headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(key))
            .map(|(_, value)| value.clone())
    }
}
30
/// A canned HTTP response that a test server writes back to the client.
#[derive(Clone, Debug)]
pub struct HttpResponse {
    /// Numeric status code placed on the status line.
    pub status: u16,
    /// Response body, sent verbatim.
    pub body: String,
    /// Extra `(name, value)` header pairs, written in insertion order.
    pub headers: Vec<(String, String)>,
}

impl HttpResponse {
    /// Builds a response with the given status and body and no extra headers.
    pub fn new(status: u16, body: impl Into<String>) -> Self {
        let body: String = body.into();
        Self {
            status,
            body,
            headers: Vec::new(),
        }
    }

    /// Builder-style helper: appends one header and returns the updated response.
    pub fn with_header(self, key: &str, value: &str) -> Self {
        let mut response = self;
        let entry = (key.to_owned(), value.to_owned());
        response.headers.push(entry);
        response
    }
}
52
/// Route-table key used by `LoopbackServer`: an HTTP method plus a request path.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct RouteKey {
    /// HTTP method; the code in this module upper-cases it before building a key.
    method: String,
    /// Request path; incoming request paths have their query string stripped.
    path: String,
}
58
59pub struct LoopbackServer {
60    addr: SocketAddr,
61    routes: Arc<Mutex<HashMap<RouteKey, HttpResponse>>>,
62    requests: Arc<Mutex<Vec<RecordedRequest>>>,
63    stop: Arc<AtomicBool>,
64    handle: Option<JoinHandle<()>>,
65}
66
67impl LoopbackServer {
68    pub fn new() -> io::Result<Self> {
69        let listener = TcpListener::bind("127.0.0.1:0")?;
70        listener.set_nonblocking(true)?;
71        let addr = listener.local_addr()?;
72
73        let routes = Arc::new(Mutex::new(HashMap::new()));
74        let requests = Arc::new(Mutex::new(Vec::new()));
75        let stop = Arc::new(AtomicBool::new(false));
76
77        let routes_t = Arc::clone(&routes);
78        let requests_t = Arc::clone(&requests);
79        let stop_t = Arc::clone(&stop);
80
81        let handle = thread::spawn(move || {
82            while !stop_t.load(Ordering::SeqCst) {
83                match listener.accept() {
84                    Ok((mut stream, _)) => {
85                        let _ = handle_connection(&mut stream, &routes_t, &requests_t);
86                    }
87                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
88                        thread::yield_now();
89                    }
90                    Err(_) => break,
91                }
92            }
93        });
94
95        Ok(Self {
96            addr,
97            routes,
98            requests,
99            stop,
100            handle: Some(handle),
101        })
102    }
103
104    pub fn url(&self) -> String {
105        format!("http://{}", self.addr)
106    }
107
108    pub fn add_route(&self, method: &str, path: &str, response: HttpResponse) {
109        let key = RouteKey {
110            method: method.trim().to_ascii_uppercase(),
111            path: path.trim().to_string(),
112        };
113        let mut routes = self.routes.lock().expect("routes lock");
114        routes.insert(key, response);
115    }
116
117    pub fn take_requests(&self) -> Vec<RecordedRequest> {
118        let mut guard = self.requests.lock().expect("requests lock");
119        std::mem::take(&mut *guard)
120    }
121}
122
123impl Drop for LoopbackServer {
124    fn drop(&mut self) {
125        self.stop.store(true, Ordering::SeqCst);
126        if let Some(handle) = self.handle.take() {
127            let _ = handle.join();
128        }
129    }
130}
131
132fn handle_connection(
133    stream: &mut TcpStream,
134    routes: &Arc<Mutex<HashMap<RouteKey, HttpResponse>>>,
135    requests: &Arc<Mutex<Vec<RecordedRequest>>>,
136) -> io::Result<()> {
137    let request = read_request(stream)?;
138    requests
139        .lock()
140        .expect("requests lock")
141        .push(request.clone());
142
143    let route_key = RouteKey {
144        method: request.method.to_ascii_uppercase(),
145        path: request.path.clone(),
146    };
147
148    let response = routes
149        .lock()
150        .expect("routes lock")
151        .get(&route_key)
152        .cloned()
153        .unwrap_or_else(|| HttpResponse::new(404, "not found"));
154
155    write_response(stream, response)?;
156    Ok(())
157}
158
159fn read_request(stream: &mut TcpStream) -> io::Result<RecordedRequest> {
160    stream.set_nonblocking(false)?;
161    stream.set_read_timeout(Some(Duration::from_secs(2)))?;
162    let mut buffer = Vec::new();
163    let mut temp = [0u8; 8192];
164
165    loop {
166        let n = match stream.read(&mut temp) {
167            Ok(n) => n,
168            Err(err) if err.kind() == io::ErrorKind::WouldBlock => 0,
169            Err(err) if err.kind() == io::ErrorKind::TimedOut => 0,
170            Err(err) => return Err(err),
171        };
172        if n == 0 {
173            break;
174        }
175        buffer.extend_from_slice(&temp[..n]);
176        if buffer.windows(4).any(|w| w == b"\r\n\r\n") {
177            break;
178        }
179        if buffer.len() > 1024 * 1024 {
180            break;
181        }
182    }
183
184    let (method, path, headers, rest) = parse_headers_and_rest(&buffer);
185
186    if header_value(&headers, "expect")
187        .is_some_and(|v| v.to_ascii_lowercase().contains("100-continue"))
188    {
189        let _ = stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n");
190        let _ = stream.flush();
191    }
192
193    let body = if let Some(te) =
194        header_value(&headers, "transfer-encoding").map(|v| v.to_ascii_lowercase())
195    {
196        if te.contains("chunked") {
197            read_chunked_body(stream, rest)
198        } else {
199            Vec::new()
200        }
201    } else if let Some(cl) = header_value(&headers, "content-length") {
202        let len = cl.parse::<usize>().unwrap_or(0);
203        read_exact_bytes(stream, rest, len)
204    } else {
205        rest
206    };
207
208    Ok(RecordedRequest {
209        method,
210        path,
211        headers,
212        body,
213    })
214}
215
/// Splits a raw request buffer into `(method, path, headers, rest)`.
///
/// The head is everything before the first `\r\n\r\n` (or the whole buffer if
/// there is none). The request line supplies the method (default `"GET"`) and
/// target (default `"/"`); the path is the target up to the first `?`. Header
/// lines without a `:` are skipped; names are trimmed and lower-cased and values
/// trimmed. `rest` holds the bytes after the terminator, or is empty when no
/// terminator was found.
fn parse_headers_and_rest(buffer: &[u8]) -> (String, String, Vec<(String, String)>, Vec<u8>) {
    let head_len = buffer
        .windows(4)
        .position(|window| window == b"\r\n\r\n")
        .unwrap_or(buffer.len());

    let head = String::from_utf8_lossy(&buffer[..head_len]);
    let mut head_lines = head.split("\r\n");

    let request_line = head_lines.next().unwrap_or("");
    let mut tokens = request_line.split_whitespace();
    let method = tokens.next().unwrap_or("GET").to_string();
    let target = tokens.next().unwrap_or("/");
    let path = match target.split_once('?') {
        Some((before_query, _)) => before_query.to_string(),
        None => target.to_string(),
    };

    let headers: Vec<(String, String)> = head_lines
        .filter(|line| !line.is_empty())
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_ascii_lowercase(), value.trim().to_string()))
        .collect();

    // Skip the 4-byte terminator; past the end of the buffer this yields nothing.
    let rest = buffer
        .get(head_len.saturating_add(4)..)
        .map(<[u8]>::to_vec)
        .unwrap_or_default();

    (method, path, headers, rest)
}
255
/// Returns the value of the first header in `headers` whose name matches `key`,
/// ignoring ASCII case, or `None` if there is no such header.
///
/// Compares case-insensitively rather than lower-casing `key`, which avoids an
/// allocation per lookup and also matches stored names that were not lower-cased.
fn header_value(headers: &[(String, String)], key: &str) -> Option<String> {
    headers
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case(key))
        .map(|(_, value)| value.clone())
}
263
/// Tops up `already` from `stream` until it holds `want` bytes, then truncates it
/// to exactly `want` bytes.
///
/// Best-effort: stops early, returning what it has, on EOF or on any read error
/// other than `Interrupted` (e.g. a read timeout). Interrupted reads are retried
/// instead of silently truncating the body. Accepts any reader, so existing
/// `&mut TcpStream` callers keep working.
fn read_exact_bytes<R: Read + ?Sized>(stream: &mut R, mut already: Vec<u8>, want: usize) -> Vec<u8> {
    // One reusable scratch buffer instead of a fresh allocation per read.
    let mut chunk = [0u8; 8192];
    while already.len() < want {
        // Never read past `want`, so bytes belonging to whatever follows stay in
        // the stream.
        let room = (want - already.len()).min(chunk.len());
        match stream.read(&mut chunk[..room]) {
            Ok(0) => break,
            Ok(n) => already.extend_from_slice(&chunk[..n]),
            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(_) => break,
        }
    }
    already.truncate(want);
    already
}
276
/// Removes and returns the next CRLF-terminated line from `buf`, reading more
/// data from `stream` as needed.
///
/// The returned line excludes the `\r\n` terminator, and one additional trailing
/// `\r` is stripped if present. Bytes after the terminator stay in `buf`. Returns
/// `None` on EOF or on a read error other than `Interrupted` (which is retried);
/// in that case `buf` keeps the bytes read so far. Accepts any reader, so
/// existing `&mut TcpStream` callers keep working.
fn take_line<R: Read + ?Sized>(buf: &mut Vec<u8>, stream: &mut R) -> Option<Vec<u8>> {
    let mut chunk = [0u8; 4096];
    // Prefix of `buf` already known to hold no complete "\r\n"; a terminator can
    // only start at the last old byte or later.
    let mut scanned = 0;
    loop {
        if let Some(offset) = buf[scanned..].windows(2).position(|w| w == b"\r\n") {
            let pos = scanned + offset;
            let mut line: Vec<u8> = buf.drain(..pos + 2).collect();
            line.truncate(pos);
            if line.last() == Some(&b'\r') {
                line.pop();
            }
            return Some(line);
        }
        scanned = buf.len().saturating_sub(1);

        match stream.read(&mut chunk) {
            Ok(0) => return None,
            Ok(n) => buf.extend_from_slice(&chunk[..n]),
            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(_) => return None,
        }
    }
}
296
297fn read_chunked_body(stream: &mut TcpStream, mut buf: Vec<u8>) -> Vec<u8> {
298    let mut body = Vec::new();
299    while let Some(line) = take_line(&mut buf, stream) {
300        let line_str = String::from_utf8_lossy(&line);
301        let size_str = line_str.split(';').next().unwrap_or("").trim();
302        let Ok(size) = usize::from_str_radix(size_str, 16) else {
303            break;
304        };
305        if size == 0 {
306            while let Some(l) = take_line(&mut buf, stream) {
307                if l.is_empty() {
308                    break;
309                }
310            }
311            break;
312        }
313
314        while buf.len() < size + 2 {
315            let mut tmp = [0u8; 8192];
316            let n = stream.read(&mut tmp).unwrap_or_default();
317            if n == 0 {
318                break;
319            }
320            buf.extend_from_slice(&tmp[..n]);
321        }
322
323        if buf.len() < size + 2 {
324            break;
325        }
326        body.extend_from_slice(&buf[..size]);
327        buf.drain(..size + 2);
328    }
329    body
330}
331
332pub struct TestServer {
333    addr: SocketAddr,
334    requests: Arc<Mutex<Vec<RecordedRequest>>>,
335    stop: Arc<AtomicBool>,
336    handle: Option<JoinHandle<()>>,
337}
338
339impl TestServer {
340    pub fn new<F>(handler: F) -> io::Result<Self>
341    where
342        F: Fn(&RecordedRequest) -> HttpResponse + Send + Sync + 'static,
343    {
344        let listener = TcpListener::bind("127.0.0.1:0")?;
345        listener.set_nonblocking(true)?;
346        let addr = listener.local_addr()?;
347
348        let requests = Arc::new(Mutex::new(Vec::new()));
349        let stop = Arc::new(AtomicBool::new(false));
350        let handler = Arc::new(handler);
351
352        let requests_t = Arc::clone(&requests);
353        let stop_t = Arc::clone(&stop);
354        let handler_t = Arc::clone(&handler);
355
356        let handle = thread::spawn(move || {
357            while !stop_t.load(Ordering::SeqCst) {
358                match listener.accept() {
359                    Ok((mut stream, _)) => {
360                        if let Ok(request) = read_request(&mut stream) {
361                            let response = handler_t(&request);
362                            requests_t.lock().expect("requests lock").push(request);
363                            let _ = write_response(&mut stream, response);
364                        }
365                    }
366                    Err(err) if err.kind() == io::ErrorKind::WouldBlock => {
367                        thread::yield_now();
368                    }
369                    Err(_) => break,
370                }
371            }
372        });
373
374        Ok(Self {
375            addr,
376            requests,
377            stop,
378            handle: Some(handle),
379        })
380    }
381
382    pub fn url(&self) -> String {
383        format!("http://{}", self.addr)
384    }
385
386    pub fn take_requests(&self) -> Vec<RecordedRequest> {
387        let mut guard = self.requests.lock().expect("requests lock");
388        std::mem::take(&mut *guard)
389    }
390}
391
392impl Drop for TestServer {
393    fn drop(&mut self) {
394        self.stop.store(true, Ordering::SeqCst);
395        if let Some(handle) = self.handle.take() {
396            let _ = handle.join();
397        }
398    }
399}
400
401fn write_response(stream: &mut TcpStream, response: HttpResponse) -> io::Result<()> {
402    let status_text = match response.status {
403        200 => "OK",
404        201 => "Created",
405        202 => "Accepted",
406        204 => "No Content",
407        400 => "Bad Request",
408        401 => "Unauthorized",
409        403 => "Forbidden",
410        404 => "Not Found",
411        500 => "Internal Server Error",
412        _ => "OK",
413    };
414
415    let body = response.body;
416    let mut headers = response.headers;
417    headers.push(("Content-Length".to_string(), body.len().to_string()));
418
419    let mut out = String::new();
420    out.push_str(&format!("HTTP/1.1 {} {}\r\n", response.status, status_text));
421    for (k, v) in headers {
422        out.push_str(&format!("{k}: {v}\r\n"));
423    }
424    out.push_str("\r\n");
425    out.push_str(&body);
426
427    stream.write_all(out.as_bytes())?;
428    stream.flush()?;
429    Ok(())
430}