Skip to main content

ranvier_http/
test_harness.rs

1use bytes::Bytes;
2use http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode};
3use hyper::server::conn::http1;
4use hyper_util::rt::TokioIo;
5use hyper_util::service::TowerToHyperService;
6use ranvier_core::transition::ResourceRequirement;
7use serde::Serialize;
8use serde::de::DeserializeOwned;
9use tokio::io::{AsyncReadExt, AsyncWriteExt};
10
11use crate::ingress::{HttpIngress, RawIngressService};
12
13#[derive(Debug, thiserror::Error)]
14pub enum TestHarnessError {
15    #[error("io error: {0}")]
16    Io(#[from] std::io::Error),
17    #[error("hyper error: {0}")]
18    Hyper(#[from] hyper::Error),
19    #[error("task join error: {0}")]
20    Join(#[from] tokio::task::JoinError),
21    #[error("invalid response: {0}")]
22    InvalidResponse(&'static str),
23    #[error("invalid utf-8 in response headers: {0}")]
24    Utf8(#[from] std::str::Utf8Error),
25    #[error("invalid status code text: {0}")]
26    InvalidStatus(#[from] std::num::ParseIntError),
27    #[error("invalid status code value: {0}")]
28    InvalidStatusCode(#[from] http::status::InvalidStatusCode),
29    #[error("invalid header name: {0}")]
30    InvalidHeaderName(#[from] http::header::InvalidHeaderName),
31    #[error("invalid header value: {0}")]
32    InvalidHeaderValue(#[from] http::header::InvalidHeaderValue),
33    #[error("json serialization error: {0}")]
34    Json(#[from] serde_json::Error),
35}
36
37/// In-process HTTP test harness for `HttpIngress`.
38///
39/// Uses an in-memory duplex stream and Hyper HTTP/1.1 server connection,
40/// so no TCP socket/network bind is required.
41#[derive(Clone)]
42pub struct TestApp<R> {
43    service: RawIngressService<R>,
44    host: String,
45}
46
47impl<R> TestApp<R>
48where
49    R: ResourceRequirement + Clone + Send + Sync + 'static,
50{
51    pub fn new(ingress: HttpIngress<R>, resources: R) -> Self {
52        Self {
53            service: ingress.into_raw_service(resources),
54            host: "test.local".to_string(),
55        }
56    }
57
58    pub fn with_host(mut self, host: impl Into<String>) -> Self {
59        self.host = host.into();
60        self
61    }
62
63    pub async fn send(&self, request: TestRequest) -> Result<TestResponse, TestHarnessError> {
64        let mut request_bytes = request.to_http1_bytes(&self.host);
65        let capacity = request_bytes.len().saturating_mul(2).max(16 * 1024);
66        let (mut client_io, server_io) = tokio::io::duplex(capacity);
67
68        let service = self.service.clone();
69        let server_task = tokio::spawn(async move {
70            let hyper_service = TowerToHyperService::new(service);
71            http1::Builder::new()
72                .keep_alive(false)
73                .serve_connection(TokioIo::new(server_io), hyper_service)
74                .await
75        });
76
77        client_io.write_all(&request_bytes).await?;
78
79        let mut raw_response = Vec::new();
80        client_io.read_to_end(&mut raw_response).await?;
81
82        let response = TestResponse::from_http1_bytes(&raw_response)?;
83
84        // Connection close races in in-memory duplex mode can surface as
85        // IncompleteMessage after a valid response is already produced.
86        // Treat that specific case as non-fatal for test harness usage.
87        let server_result = server_task.await?;
88        if let Err(error) = server_result {
89            if !error.is_incomplete_message() {
90                return Err(TestHarnessError::Hyper(error));
91            }
92        }
93
94        // Avoid keeping oversized request buffer alive longer than needed.
95        request_bytes.clear();
96
97        Ok(response)
98    }
99}
100
101#[derive(Clone, Debug)]
102pub struct TestRequest {
103    method: Method,
104    path: String,
105    headers: Vec<(String, String)>,
106    body: Bytes,
107}
108
109impl TestRequest {
110    pub fn new(method: Method, path: impl Into<String>) -> Self {
111        Self {
112            method,
113            path: path.into(),
114            headers: Vec::new(),
115            body: Bytes::new(),
116        }
117    }
118
119    pub fn get(path: impl Into<String>) -> Self {
120        Self::new(Method::GET, path)
121    }
122
123    pub fn post(path: impl Into<String>) -> Self {
124        Self::new(Method::POST, path)
125    }
126
127    pub fn put(path: impl Into<String>) -> Self {
128        Self::new(Method::PUT, path)
129    }
130
131    pub fn delete(path: impl Into<String>) -> Self {
132        Self::new(Method::DELETE, path)
133    }
134
135    pub fn patch(path: impl Into<String>) -> Self {
136        Self::new(Method::PATCH, path)
137    }
138
139    pub fn header(mut self, name: impl Into<String>, value: impl Into<String>) -> Self {
140        self.headers.push((name.into(), value.into()));
141        self
142    }
143
144    pub fn body(mut self, body: impl Into<Bytes>) -> Self {
145        self.body = body.into();
146        self
147    }
148
149    pub fn text(mut self, body: impl Into<String>) -> Self {
150        self.body = Bytes::from(body.into());
151        self
152    }
153
154    pub fn json<T: Serialize>(mut self, payload: &T) -> Result<Self, TestHarnessError> {
155        self.body = Bytes::from(serde_json::to_vec(payload)?);
156        self.headers
157            .push(("content-type".to_string(), "application/json".to_string()));
158        Ok(self)
159    }
160
161    fn to_http1_bytes(&self, host: &str) -> Vec<u8> {
162        let path = if self.path.is_empty() {
163            "/"
164        } else {
165            &self.path
166        };
167
168        let mut has_host = false;
169        let mut has_connection = false;
170        let mut has_content_length = false;
171
172        for (name, _) in &self.headers {
173            let lower = name.to_ascii_lowercase();
174            if lower == "host" {
175                has_host = true;
176            } else if lower == "connection" {
177                has_connection = true;
178            } else if lower == "content-length" {
179                has_content_length = true;
180            }
181        }
182
183        let mut output = format!("{} {} HTTP/1.1\r\n", self.method, path);
184
185        if !has_host {
186            output.push_str(&format!("Host: {host}\r\n"));
187        }
188        if !has_connection {
189            output.push_str("Connection: close\r\n");
190        }
191        if !has_content_length {
192            output.push_str(&format!("Content-Length: {}\r\n", self.body.len()));
193        }
194
195        for (name, value) in &self.headers {
196            output.push_str(name);
197            output.push_str(": ");
198            output.push_str(value);
199            output.push_str("\r\n");
200        }
201
202        output.push_str("\r\n");
203
204        let mut bytes = output.into_bytes();
205        bytes.extend_from_slice(&self.body);
206        bytes
207    }
208}
209
210#[derive(Clone, Debug)]
211pub struct TestResponse {
212    status: StatusCode,
213    headers: HeaderMap,
214    body: Bytes,
215}
216
217impl TestResponse {
218    fn from_http1_bytes(raw: &[u8]) -> Result<Self, TestHarnessError> {
219        let delimiter = b"\r\n\r\n";
220        let header_end = raw
221            .windows(delimiter.len())
222            .position(|window| window == delimiter)
223            .ok_or(TestHarnessError::InvalidResponse(
224                "missing HTTP header delimiter",
225            ))?;
226
227        let header_text = std::str::from_utf8(&raw[..header_end])?;
228        let mut lines = header_text.split("\r\n");
229
230        let status_line = lines
231            .next()
232            .ok_or(TestHarnessError::InvalidResponse("missing status line"))?;
233        let mut status_parts = status_line.split_whitespace();
234        let _http_version = status_parts
235            .next()
236            .ok_or(TestHarnessError::InvalidResponse("missing HTTP version"))?;
237        let status_code = status_parts
238            .next()
239            .ok_or(TestHarnessError::InvalidResponse("missing status code"))?
240            .parse::<u16>()?;
241        let status = StatusCode::from_u16(status_code)?;
242
243        let mut headers = HeaderMap::new();
244        for line in lines {
245            if line.is_empty() {
246                continue;
247            }
248            let (name, value) = line
249                .split_once(':')
250                .ok_or(TestHarnessError::InvalidResponse("malformed header line"))?;
251            let name = HeaderName::from_bytes(name.trim().as_bytes())?;
252            let value = HeaderValue::from_str(value.trim())?;
253            headers.append(name, value);
254        }
255
256        let body = Bytes::copy_from_slice(&raw[(header_end + delimiter.len())..]);
257
258        Ok(Self {
259            status,
260            headers,
261            body,
262        })
263    }
264
265    pub fn status(&self) -> StatusCode {
266        self.status
267    }
268
269    pub fn headers(&self) -> &HeaderMap {
270        &self.headers
271    }
272
273    pub fn header(&self, name: &str) -> Option<&HeaderValue> {
274        self.headers.get(name)
275    }
276
277    pub fn body(&self) -> &[u8] {
278        &self.body
279    }
280
281    pub fn text(&self) -> Result<&str, std::str::Utf8Error> {
282        std::str::from_utf8(&self.body)
283    }
284
285    pub fn json<T: DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
286        serde_json::from_slice(&self.body)
287    }
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use ranvier_core::{Outcome, Transition};
    use ranvier_runtime::Axon;

    #[derive(Clone)]
    struct Ping;

    #[async_trait::async_trait]
    impl Transition<(), &'static str> for Ping {
        type Error = std::convert::Infallible;
        type Resources = ();

        async fn run(
            &self,
            _state: (),
            _resources: &Self::Resources,
            _bus: &mut ranvier_core::Bus,
        ) -> Outcome<&'static str, Self::Error> {
            Outcome::next("pong")
        }
    }

    /// End-to-end: a routed request is served over the in-memory pipe.
    #[tokio::test]
    async fn test_app_executes_route_without_network_socket() {
        let ingress = crate::Ranvier::http::<()>().get(
            "/ping",
            Axon::<(), (), std::convert::Infallible, ()>::new("Ping").then(Ping),
        );
        let app = TestApp::new(ingress, ());

        let response = app
            .send(TestRequest::get("/ping"))
            .await
            .expect("test request should succeed");

        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.text().expect("utf8 body"), "pong");
    }

    /// Default Host / Connection / Content-Length headers precede the body.
    #[test]
    fn request_serialization_adds_default_framing_headers() {
        let bytes = TestRequest::post("/items")
            .text("hi")
            .to_http1_bytes("example.test");

        assert_eq!(
            String::from_utf8(bytes).expect("ascii request"),
            "POST /items HTTP/1.1\r\nHost: example.test\r\nConnection: close\r\nContent-Length: 2\r\n\r\nhi"
        );
    }

    /// Caller-supplied framing headers suppress the defaults; empty path is `/`.
    #[test]
    fn request_serialization_respects_caller_supplied_headers() {
        let bytes = TestRequest::get("")
            .header("Host", "custom.example")
            .header("Connection", "keep-alive")
            .header("Content-Length", "0")
            .to_http1_bytes("ignored.test");

        assert_eq!(
            String::from_utf8(bytes).expect("ascii request"),
            "GET / HTTP/1.1\r\nHost: custom.example\r\nConnection: keep-alive\r\nContent-Length: 0\r\n\r\n"
        );
    }

    /// Status, repeated headers and body are all recovered from raw bytes.
    #[test]
    fn response_parser_extracts_status_headers_and_body() {
        let raw = b"HTTP/1.1 201 Created\r\ncontent-type: text/plain\r\nx-tag: a\r\nx-tag: b\r\n\r\nhello";

        let response = TestResponse::from_http1_bytes(raw).expect("well-formed response");

        assert_eq!(response.status(), StatusCode::CREATED);
        assert_eq!(
            response
                .header("content-type")
                .and_then(|value| value.to_str().ok()),
            Some("text/plain")
        );
        assert_eq!(response.headers().get_all("x-tag").iter().count(), 2);
        assert_eq!(response.body(), b"hello");
    }

    /// A response without the blank line ending the header block is rejected.
    #[test]
    fn response_parser_rejects_missing_header_terminator() {
        let result = TestResponse::from_http1_bytes(b"HTTP/1.1 200 OK\r\n");

        assert!(matches!(result, Err(TestHarnessError::InvalidResponse(_))));
    }
}