Skip to main content

openapi_contract/
request.rs

1use std::marker::PhantomData;
2
3use futures_core::Stream;
4use serde::de::DeserializeOwned;
5
6use crate::client::{ApiClient, Method};
7use crate::error::{ApiError, DefinedErrorBody};
8use crate::sse::SseEvent;
9
10fn parse_error_response(status: u16, body: String) -> ApiError {
11    if let Ok(parsed) = serde_json::from_str::<DefinedErrorBody>(&body) {
12        if parsed.defined && !parsed.code.is_empty() {
13            return ApiError::Defined {
14                status,
15                code: parsed.code,
16                message: parsed.message,
17            };
18        }
19    }
20    ApiError::Api {
21        status,
22        message: body,
23    }
24}
25
/// A type-safe API request built by the `api!` macro.
///
/// `T` is the expected response type, inferred from the OpenAPI spec at compile time.
/// It is a type-level marker only: no value of `T` is stored. It selects which
/// execution methods are available ([`ApiRequest::fetch`] needs a deserializable
/// `T`, [`ApiRequest::fetch_empty`] exists only for `T = ()`, and
/// [`ApiRequest::fetch_stream`] works for any `T`).
pub struct ApiRequest<T> {
    /// HTTP method to send.
    pub method: Method,
    /// Request path, handed to the client unchanged.
    pub path: String,
    /// Raw query string, handed to the client unchanged (no encoding is applied here).
    pub query: Option<String>,
    /// JSON request body as a string, if any.
    pub body: Option<String>,
    /// Ties the response type `T` to the request without storing a `T`.
    pub _marker: PhantomData<T>,
}
36
37impl<T> ApiRequest<T> {
38    pub fn new(method: Method, path: String) -> Self {
39        Self {
40            method,
41            path,
42            query: None,
43            body: None,
44            _marker: PhantomData,
45        }
46    }
47
48    pub fn query_raw(mut self, qs: impl Into<String>) -> Self {
49        self.query = Some(qs.into());
50        self
51    }
52
53    pub fn body_json(mut self, body: &impl serde::Serialize) -> Self {
54        self.body = Some(serde_json::to_string(body).expect("request body must be serializable"));
55        self
56    }
57
58    /// Set a pre-serialized JSON body. Returns an error if serialization fails.
59    pub fn try_body_json(mut self, body: &impl serde::Serialize) -> Result<Self, ApiError> {
60        self.body = Some(serde_json::to_string(body)?);
61        Ok(self)
62    }
63}
64
65impl<T: DeserializeOwned> ApiRequest<T> {
66    /// Execute the request and deserialize the response.
67    pub async fn fetch(self, client: &(impl ApiClient + ?Sized)) -> Result<T, ApiError> {
68        let resp = client
69            .request(self.method, &self.path, self.query.as_deref(), self.body)
70            .await?;
71
72        let status = resp.status();
73        if !status.is_success() {
74            let body = resp.text().await.unwrap_or_default();
75            return Err(parse_error_response(status.as_u16(), body));
76        }
77
78        let text = resp.text().await?;
79        if text.is_empty() {
80            return serde_json::from_str("null").map_err(ApiError::from);
81        }
82        serde_json::from_str(&text).map_err(ApiError::from)
83    }
84}
85
86impl ApiRequest<()> {
87    /// Execute a request that returns no body (e.g. DELETE, PUT with 204).
88    pub async fn fetch_empty(self, client: &(impl ApiClient + ?Sized)) -> Result<(), ApiError> {
89        let resp = client
90            .request(self.method, &self.path, self.query.as_deref(), self.body)
91            .await?;
92
93        let status = resp.status();
94        if !status.is_success() {
95            let body = resp.text().await.unwrap_or_default();
96            return Err(parse_error_response(status.as_u16(), body));
97        }
98        Ok(())
99    }
100}
101
102impl<T> ApiRequest<T> {
103    /// Execute the request and return an SSE event stream.
104    pub async fn fetch_stream(
105        self,
106        client: &(impl ApiClient + ?Sized),
107    ) -> Result<impl Stream<Item = Result<SseEvent, ApiError>>, ApiError> {
108        let stream = client
109            .request_stream(self.method, &self.path, self.query.as_deref())
110            .await?;
111        Ok(stream)
112    }
113}
114
/// Percent-encode `input` for use in a URL query component.
///
/// Bytes in the RFC 3986 unreserved set (`A-Z`, `a-z`, `0-9`, `-`, `_`, `.`, `~`)
/// are kept; every other byte becomes `%XX` with uppercase hex digits. Non-ASCII
/// characters are therefore encoded byte by byte from their UTF-8 form.
fn percent_encode(input: &str) -> String {
    const HEX: &[u8; 16] = b"0123456789ABCDEF";

    // Sized for the common mostly-unreserved case; grows if many bytes need escaping.
    let mut out = String::with_capacity(input.len());
    for byte in input.bytes() {
        if byte.is_ascii_alphanumeric() || matches!(byte, b'-' | b'_' | b'.' | b'~') {
            out.push(char::from(byte));
        } else {
            // Write the two hex digits directly rather than allocating a `format!` string per byte.
            out.push('%');
            out.push(char::from(HEX[usize::from(byte >> 4)]));
            out.push(char::from(HEX[usize::from(byte & 0x0F)]));
        }
    }
    out
}
130
131/// Build a URL-encoded query string from key-value pairs.
132pub fn build_query_string(pairs: &[(&str, &dyn ToString)]) -> String {
133    pairs
134        .iter()
135        .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(&v.to_string())))
136        .collect::<Vec<_>>()
137        .join("&")
138}
139
#[cfg(test)]
// The mock `ApiClient` impls below write `request`/`request_stream` as plain `fn`s
// returning `impl Future<...> + Send` built from `async` blocks; clippy's
// `manual_async_fn` lint would otherwise suggest rewriting them as `async fn`.
#[allow(clippy::manual_async_fn)]
mod tests {
    use super::*;
    use crate::sse::SseStream;
    use tokio::io::AsyncWriteExt;
    use tokio::time::{Duration, sleep};

    // ── Helpers ─────────────────────────────────────────────────

    /// Start a throwaway mockito server that answers `GET /mock` with `status` and
    /// `body` (served as `application/json`) and return the real `reqwest::Response`.
    ///
    /// NOTE(review): `server` and `_mock` are dropped when this returns, before the
    /// caller reads the body; this relies on the small body already having been
    /// received — confirm it cannot flake.
    async fn mock_response(status: u16, body: &str) -> reqwest::Response {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/mock")
            .with_status(status as usize)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await;
        reqwest::get(&format!("{}/mock", server.url()))
            .await
            .unwrap()
    }

    /// Produce a `reqwest::Error` without any network I/O: a header name containing
    /// a NUL byte makes `RequestBuilder::build` fail.
    fn make_reqwest_error() -> reqwest::Error {
        reqwest::Client::new()
            .get("http://localhost:1/x")
            .header("bad\0header", "v")
            .build()
            .unwrap_err()
    }

    /// Serve one response from a local TCP listener whose headers promise
    /// `Content-Length: 40` but which sends only the 2-byte body `{}` and then shuts
    /// the socket down, so reading the body afterwards fails.
    ///
    /// NOTE(review): despite the name, no chunked transfer encoding is involved —
    /// this is a truncated fixed-length body.
    async fn malformed_chunked_response() -> reqwest::Response {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket
                .write_all(
                    b"HTTP/1.1 200 OK\r\n\
Content-Type: application/json\r\n\
Content-Length: 40\r\n\
Connection: close\r\n\
\r\n\
{}",
                )
                .await
                .unwrap();
            // Presumably gives the client time to receive the headers before the shutdown.
            sleep(Duration::from_millis(250)).await;
            socket.shutdown().await.unwrap();
        });
        let resp = reqwest::get(format!("http://{addr}")).await.unwrap();
        // Wait for the server task so the connection is already closed when the
        // caller tries to read the (truncated) body.
        let _ = server.await;
        resp
    }

    /// Client whose `request` returns a real response with the configured status and
    /// body, and whose `request_stream` yields a single `data: hi` event.
    struct MockClient {
        status: u16,
        body: String,
    }
    /// Client whose `request` and `request_stream` both fail with `ApiError::Http`.
    struct FailingClient;
    /// Client whose `request` returns the truncated-body response from
    /// `malformed_chunked_response`; `request_stream` fails with `ApiError::Http`.
    struct MalformedBodyClient;

    impl ApiClient for MockClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            // Copy the fields out so the returned future does not borrow `self`.
            let status = self.status;
            let body = self.body.clone();
            async move { Ok(mock_response(status, &body).await) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async move {
                let chunks: Vec<Result<bytes::Bytes, reqwest::Error>> =
                    vec![Ok(bytes::Bytes::from(&b"data: hi\n\n"[..]))];
                Ok(SseStream::new(Box::pin(futures_util::stream::iter(chunks))))
            }
        }
    }

    impl ApiClient for FailingClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
    }

    impl ApiClient for MalformedBodyClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            async { Ok(malformed_chunked_response().await) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
    }

    // ── Builder tests ───────────────────────────────────────────

    /// `new` starts with no query/body; `query_raw` and `body_json` chain.
    #[test]
    fn api_request_builder() {
        // new
        let req = ApiRequest::<String>::new(Method::GET, "/test".into());
        assert_eq!(req.method, Method::GET);
        assert_eq!(req.path, "/test");
        assert!(req.query.is_none());
        assert!(req.body.is_none());

        // chaining query + body
        let body = serde_json::json!({"x": 1});
        let req = ApiRequest::<String>::new(Method::POST, "/x".into())
            .query_raw("q=1")
            .body_json(&body);
        assert_eq!(req.query.as_deref(), Some("q=1"));
        assert_eq!(req.body.as_deref(), Some(r#"{"x":1}"#));
    }

    /// `try_body_json` succeeds on valid input and returns `Err` when `Serialize` fails.
    #[test]
    fn body_serialization() {
        // try_body_json success
        let req = ApiRequest::<String>::new(Method::POST, "/t".into())
            .try_body_json(&serde_json::json!({"x": 1}))
            .unwrap();
        assert!(req.body.is_some());

        // try_body_json failure
        #[derive(Debug)]
        struct Bad;
        impl serde::Serialize for Bad {
            fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
                Err(serde::ser::Error::custom("fail"))
            }
        }
        assert!(
            ApiRequest::<String>::new(Method::POST, "/t".into())
                .try_body_json(&Bad)
                .is_err()
        );
    }

    /// `body_json` panics with its documented message when `Serialize` fails.
    #[test]
    #[should_panic(expected = "request body must be serializable")]
    fn body_json_panics_on_bad_input() {
        #[derive(Debug)]
        struct Bad;
        impl serde::Serialize for Bad {
            fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
                Err(serde::ser::Error::custom("fail"))
            }
        }
        let _ = ApiRequest::<String>::new(Method::POST, "/t".into()).body_json(&Bad);
    }

    /// Query-string joining and percent-encoding of reserved characters.
    #[test]
    fn query_string_and_percent_encode() {
        assert_eq!(build_query_string(&[]), "");
        assert_eq!(build_query_string(&[("limit", &10)]), "limit=10");
        assert_eq!(
            build_query_string(&[("a", &"hello"), ("b", &42)]),
            "a=hello&b=42"
        );
        assert_eq!(
            build_query_string(&[("q", &"hello world"), ("x", &"a&b=c")]),
            "q=hello%20world&x=a%26b%3Dc"
        );
        assert_eq!(percent_encode("abc-_.~123"), "abc-_.~123");
        assert_eq!(percent_encode("&="), "%26%3D");
    }

    // ── fetch tests ─────────────────────────────────────────────

    /// `fetch` decodes JSON, maps an empty body to `null`, and reports decode errors.
    #[tokio::test]
    async fn fetch_success_and_edge_cases() {
        // Normal success
        let client = MockClient {
            status: 200,
            body: r#""hello""#.into(),
        };
        assert_eq!(
            ApiRequest::<String>::new(Method::GET, "/t".into())
                .fetch(&client)
                .await
                .unwrap(),
            "hello"
        );

        // Empty body → null
        let client = MockClient {
            status: 200,
            body: String::new(),
        };
        let result: Option<String> = ApiRequest::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap();
        assert_eq!(result, None);

        // Deserialization error
        let client = MockClient {
            status: 200,
            body: "not-json".into(),
        };
        assert!(
            ApiRequest::<i32>::new(Method::GET, "/t".into())
                .fetch(&client)
                .await
                .unwrap_err()
                .to_string()
                .starts_with("serialization error:")
        );
    }

    /// Non-2xx responses map to `ApiError::Api` or `ApiError::Defined` as appropriate.
    #[tokio::test]
    async fn fetch_error_responses() {
        // Plain API error
        let client = MockClient {
            status: 403,
            body: "forbidden".into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(matches!(err, ApiError::Api { status: 403, .. }));

        // Defined error
        let client = MockClient {
            status: 404,
            body: r#"{"defined":true,"code":"TEAM_NOT_FOUND","message":"Team not found"}"#.into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(err.is_code("TEAM_NOT_FOUND"));
        assert_eq!(err.status(), Some(404));

        // Non-defined JSON fallback
        let client = MockClient {
            status: 400,
            body: r#"{"defined":false,"code":"NOPE","message":"nope"}"#.into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(matches!(err, ApiError::Api { status: 400, .. }));
        assert_eq!(err.code(), None);
    }

    /// `fetch_empty` succeeds on 2xx, maps error statuses, and propagates client errors.
    #[tokio::test]
    async fn fetch_empty_success_and_errors() {
        // Success
        let client = MockClient {
            status: 204,
            body: String::new(),
        };
        assert!(
            ApiRequest::<()>::new(Method::DELETE, "/t".into())
                .fetch_empty(&client)
                .await
                .is_ok()
        );

        // API error
        let client = MockClient {
            status: 500,
            body: "oops".into(),
        };
        assert!(matches!(
            ApiRequest::<()>::new(Method::DELETE, "/t".into())
                .fetch_empty(&client)
                .await
                .unwrap_err(),
            ApiError::Api { status: 500, .. }
        ));

        // Defined error
        let client = MockClient {
            status: 403,
            body: r#"{"defined":true,"code":"FORBIDDEN","message":"no access"}"#.into(),
        };
        let err = ApiRequest::<()>::new(Method::DELETE, "/t".into())
            .fetch_empty(&client)
            .await
            .unwrap_err();
        assert!(err.is_code("FORBIDDEN"));

        // Request error propagation
        assert!(
            ApiRequest::<()>::new(Method::DELETE, "/t".into())
                .fetch_empty(&FailingClient)
                .await
                .unwrap_err()
                .to_string()
                .starts_with("HTTP error:")
        );
    }

    /// `fetch_stream` yields the mock event and propagates client errors.
    #[tokio::test]
    async fn fetch_stream_success_and_errors() {
        use futures_util::StreamExt;

        // Success
        let client = MockClient {
            status: 200,
            body: String::new(),
        };
        let mut stream = ApiRequest::<()>::new(Method::GET, "/sse".into())
            .fetch_stream(&client)
            .await
            .unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap().data, "hi");

        // Request error propagation
        assert!(
            ApiRequest::<()>::new(Method::GET, "/sse".into())
                .fetch_stream(&FailingClient)
                .await
                .is_err()
        );
    }

    /// A failure while reading a 2xx body surfaces as an HTTP error, not a decode error.
    #[tokio::test]
    async fn fetch_propagates_body_read_error() {
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&MalformedBodyClient)
            .await
            .unwrap_err();
        assert!(err.to_string().starts_with("HTTP error:"));
    }
}