Skip to main content

openapi_contract/
request.rs

1use std::marker::PhantomData;
2
3use futures_core::Stream;
4use serde::de::DeserializeOwned;
5
6use crate::client::{ApiClient, Method};
7use crate::error::{ApiError, DefinedErrorBody};
8use crate::sse::SseEvent;
9
10fn parse_error_response(status: u16, body: String) -> ApiError {
11    if let Ok(parsed) = serde_json::from_str::<DefinedErrorBody>(&body) {
12        if parsed.defined && !parsed.code.is_empty() {
13            return ApiError::Defined {
14                status,
15                code: parsed.code,
16                message: parsed.message,
17            };
18        }
19    }
20    ApiError::Api {
21        status,
22        message: body,
23    }
24}
25
/// A type-safe API request built by the `api!` macro.
///
/// `T` is the expected response type, inferred from the OpenAPI spec at compile time.
/// It is carried only through [`PhantomData`]; no value of `T` is stored in the request.
pub struct ApiRequest<T> {
    /// HTTP method handed to the client when the request is executed.
    pub method: Method,
    /// Request path, handed to the client as-is.
    pub path: String,
    /// Raw query string, if any; stored exactly as given to `query_raw`.
    pub query: Option<String>,
    /// Request body as a JSON string, if any.
    pub body: Option<String>,
    /// Zero-sized marker tying the request to its response type `T`.
    pub _marker: PhantomData<T>,
}
36
37impl<T> ApiRequest<T> {
38    pub fn new(method: Method, path: String) -> Self {
39        Self {
40            method,
41            path,
42            query: None,
43            body: None,
44            _marker: PhantomData,
45        }
46    }
47
48    pub fn query_raw(mut self, qs: impl Into<String>) -> Self {
49        self.query = Some(qs.into());
50        self
51    }
52
53    pub fn body_json(mut self, body: &impl serde::Serialize) -> Self {
54        self.body = Some(serde_json::to_string(body).expect("request body must be serializable"));
55        self
56    }
57
58    /// Set a pre-serialized JSON body. Returns an error if serialization fails.
59    pub fn try_body_json(mut self, body: &impl serde::Serialize) -> Result<Self, ApiError> {
60        self.body = Some(serde_json::to_string(body)?);
61        Ok(self)
62    }
63}
64
65impl<T: DeserializeOwned> ApiRequest<T> {
66    /// Execute the request and deserialize the response.
67    pub async fn fetch(self, client: &(impl ApiClient + ?Sized)) -> Result<T, ApiError> {
68        let resp = client
69            .request(self.method, &self.path, self.query.as_deref(), self.body)
70            .await?;
71
72        let status = resp.status();
73        if !status.is_success() {
74            let body = resp.text().await.unwrap_or_default();
75            return Err(parse_error_response(status.as_u16(), body));
76        }
77
78        let text = resp.text().await?;
79        if text.is_empty() {
80            return serde_json::from_str("null").map_err(ApiError::from);
81        }
82        serde_json::from_str(&text).map_err(ApiError::from)
83    }
84}
85
86impl ApiRequest<()> {
87    /// Execute a request that returns no body (e.g. DELETE, PUT with 204).
88    pub async fn fetch_empty(self, client: &(impl ApiClient + ?Sized)) -> Result<(), ApiError> {
89        let resp = client
90            .request(self.method, &self.path, self.query.as_deref(), self.body)
91            .await?;
92
93        let status = resp.status();
94        if !status.is_success() {
95            let body = resp.text().await.unwrap_or_default();
96            return Err(parse_error_response(status.as_u16(), body));
97        }
98        Ok(())
99    }
100}
101
102impl<T> ApiRequest<T> {
103    /// Execute the request and return an SSE event stream.
104    pub async fn fetch_stream(
105        self,
106        client: &(impl ApiClient + ?Sized),
107    ) -> Result<impl Stream<Item = Result<SseEvent, ApiError>>, ApiError> {
108        let stream = client
109            .request_stream(self.method, &self.path, self.query.as_deref())
110            .await?;
111        Ok(stream)
112    }
113}
114
/// Percent-encode `input` for use as a query-string key or value.
///
/// ASCII letters, ASCII digits, and `-`, `_`, `.`, `~` are copied through unchanged.
/// Every other byte of the UTF-8 encoding is written as `%` followed by two uppercase
/// hex digits, so a space becomes `%20` (not `+`) and non-ASCII characters expand to
/// one escape per byte.
fn percent_encode(input: &str) -> String {
    /// Uppercase hex digits indexed by nibble value.
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    // Exact when nothing needs escaping; grows on demand otherwise.
    let mut out = String::with_capacity(input.len());
    for byte in input.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                out.push(char::from(byte));
            }
            _ => {
                // Table lookup instead of `format!` avoids a temporary `String`
                // for every escaped byte. Both nibbles are < 16, so indexing is in bounds.
                out.push('%');
                out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
        }
    }
    out
}
130
131/// Build a URL-encoded query string from key-value pairs.
132pub fn build_query_string(pairs: &[(&str, &dyn ToString)]) -> String {
133    pairs
134        .iter()
135        .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(&v.to_string())))
136        .collect::<Vec<_>>()
137        .join("&")
138}
139
#[cfg(test)]
// The `ApiClient` test doubles below are plain `fn`s returning `impl Future<...> + Send`
// around a single `async` block, which is exactly the shape `manual_async_fn` flags.
#[allow(clippy::manual_async_fn)]
mod tests {
    use super::*;
    use crate::sse::SseStream;
    use tokio::io::AsyncWriteExt;
    use tokio::time::{sleep, Duration};

    // ── Helpers ─────────────────────────────────────────────────

    /// Serve `body` with `status` from a throwaway mockito server and return the
    /// `reqwest::Response` obtained by GETting it.
    async fn mock_response(status: u16, body: &str) -> reqwest::Response {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/mock")
            .with_status(usize::from(status))
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await;
        reqwest::get(format!("{}/mock", server.url()))
            .await
            .unwrap()
    }

    /// Produce a `reqwest::Error` without touching the network: a header name
    /// containing a NUL byte makes the request builder fail.
    fn make_reqwest_error() -> reqwest::Error {
        reqwest::Client::new()
            .get("http://localhost:1/x")
            .header("bad\0header", "v")
            .build()
            .unwrap_err()
    }

    /// Return a 200 response whose body is shorter than its declared `Content-Length`:
    /// 40 bytes are announced, 2 are sent, and then the connection is shut down.
    /// Reading the body of the returned response is therefore expected to fail.
    async fn truncated_body_response() -> reqwest::Response {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket
                .write_all(
                    b"HTTP/1.1 200 OK\r\n\
                      Content-Type: application/json\r\n\
                      Content-Length: 40\r\n\
                      Connection: close\r\n\
                      \r\n\
                      {}",
                )
                .await
                .unwrap();
            sleep(Duration::from_millis(250)).await;
            socket.shutdown().await.unwrap();
        });
        let resp = reqwest::get(format!("http://{addr}")).await.unwrap();
        // Wait for the server task so the connection is closed before the body is read.
        let _ = server.await;
        resp
    }

    /// A value whose `Serialize` impl always fails, for exercising serialization errors.
    #[derive(Debug)]
    struct Unserializable;

    impl serde::Serialize for Unserializable {
        fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
            Err(serde::ser::Error::custom("fail"))
        }
    }

    /// Client that answers every `request` with a canned status and body, and every
    /// `request_stream` with a single `data: hi` event.
    struct MockClient {
        status: u16,
        body: String,
    }
    /// Client whose calls always fail with an HTTP-level error.
    struct FailingClient;
    /// Client whose `request` succeeds but returns a response with a truncated body.
    struct MalformedBodyClient;

    impl ApiClient for MockClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            let status = self.status;
            let body = self.body.clone();
            async move { Ok(mock_response(status, &body).await) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async move {
                let chunks: Vec<Result<bytes::Bytes, reqwest::Error>> =
                    vec![Ok(bytes::Bytes::from(&b"data: hi\n\n"[..]))];
                Ok(SseStream::new(Box::pin(futures_util::stream::iter(chunks))))
            }
        }
    }

    impl ApiClient for FailingClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
    }

    impl ApiClient for MalformedBodyClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            async { Ok(truncated_body_response().await) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
    }

    // ── Builder tests ───────────────────────────────────────────

    #[test]
    fn api_request_builder() {
        // new
        let req = ApiRequest::<String>::new(Method::GET, "/test".into());
        assert_eq!(req.method, Method::GET);
        assert_eq!(req.path, "/test");
        assert!(req.query.is_none());
        assert!(req.body.is_none());

        // chaining query + body
        let body = serde_json::json!({"x": 1});
        let req = ApiRequest::<String>::new(Method::POST, "/x".into())
            .query_raw("q=1")
            .body_json(&body);
        assert_eq!(req.query.as_deref(), Some("q=1"));
        assert_eq!(req.body.as_deref(), Some(r#"{"x":1}"#));
    }

    #[test]
    fn body_serialization() {
        // try_body_json success
        let req = ApiRequest::<String>::new(Method::POST, "/t".into())
            .try_body_json(&serde_json::json!({"x": 1}))
            .unwrap();
        assert!(req.body.is_some());

        // try_body_json failure
        assert!(ApiRequest::<String>::new(Method::POST, "/t".into())
            .try_body_json(&Unserializable)
            .is_err());
    }

    #[test]
    #[should_panic(expected = "request body must be serializable")]
    fn body_json_panics_on_bad_input() {
        let _ = ApiRequest::<String>::new(Method::POST, "/t".into()).body_json(&Unserializable);
    }

    #[test]
    fn query_string_and_percent_encode() {
        assert_eq!(build_query_string(&[]), "");
        assert_eq!(build_query_string(&[("limit", &10)]), "limit=10");
        assert_eq!(
            build_query_string(&[("a", &"hello"), ("b", &42)]),
            "a=hello&b=42"
        );
        assert_eq!(
            build_query_string(&[("q", &"hello world"), ("x", &"a&b=c")]),
            "q=hello%20world&x=a%26b%3Dc"
        );
        assert_eq!(percent_encode(""), "");
        assert_eq!(percent_encode("abc-_.~123"), "abc-_.~123");
        assert_eq!(percent_encode("&="), "%26%3D");
        // Non-ASCII input is escaped one UTF-8 byte at a time.
        assert_eq!(percent_encode("é"), "%C3%A9");
    }

    // ── fetch tests ─────────────────────────────────────────────

    #[tokio::test]
    async fn fetch_success_and_edge_cases() {
        // Normal success
        let client = MockClient {
            status: 200,
            body: r#""hello""#.into(),
        };
        assert_eq!(
            ApiRequest::<String>::new(Method::GET, "/t".into())
                .fetch(&client)
                .await
                .unwrap(),
            "hello"
        );

        // Empty body → null
        let client = MockClient {
            status: 200,
            body: String::new(),
        };
        let result: Option<String> = ApiRequest::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap();
        assert_eq!(result, None);

        // Deserialization error
        let client = MockClient {
            status: 200,
            body: "not-json".into(),
        };
        assert!(ApiRequest::<i32>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err()
            .to_string()
            .starts_with("serialization error:"));
    }

    #[tokio::test]
    async fn fetch_error_responses() {
        // Plain API error
        let client = MockClient {
            status: 403,
            body: "forbidden".into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(matches!(err, ApiError::Api { status: 403, .. }));

        // Defined error
        let client = MockClient {
            status: 404,
            body: r#"{"defined":true,"code":"TEAM_NOT_FOUND","message":"Team not found"}"#.into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(err.is_code("TEAM_NOT_FOUND"));
        assert_eq!(err.status(), Some(404));

        // Non-defined JSON fallback
        let client = MockClient {
            status: 400,
            body: r#"{"defined":false,"code":"NOPE","message":"nope"}"#.into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(matches!(err, ApiError::Api { status: 400, .. }));
        assert_eq!(err.code(), None);
    }

    #[tokio::test]
    async fn fetch_empty_success_and_errors() {
        // Success
        let client = MockClient {
            status: 204,
            body: String::new(),
        };
        assert!(ApiRequest::<()>::new(Method::DELETE, "/t".into())
            .fetch_empty(&client)
            .await
            .is_ok());

        // API error
        let client = MockClient {
            status: 500,
            body: "oops".into(),
        };
        assert!(matches!(
            ApiRequest::<()>::new(Method::DELETE, "/t".into())
                .fetch_empty(&client)
                .await
                .unwrap_err(),
            ApiError::Api { status: 500, .. }
        ));

        // Defined error
        let client = MockClient {
            status: 403,
            body: r#"{"defined":true,"code":"FORBIDDEN","message":"no access"}"#.into(),
        };
        let err = ApiRequest::<()>::new(Method::DELETE, "/t".into())
            .fetch_empty(&client)
            .await
            .unwrap_err();
        assert!(err.is_code("FORBIDDEN"));

        // Request error propagation
        assert!(ApiRequest::<()>::new(Method::DELETE, "/t".into())
            .fetch_empty(&FailingClient)
            .await
            .unwrap_err()
            .to_string()
            .starts_with("HTTP error:"));
    }

    #[tokio::test]
    async fn fetch_stream_success_and_errors() {
        use futures_util::StreamExt;

        // Success
        let client = MockClient {
            status: 200,
            body: String::new(),
        };
        let mut stream = ApiRequest::<()>::new(Method::GET, "/sse".into())
            .fetch_stream(&client)
            .await
            .unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap().data, "hi");

        // Request error propagation
        assert!(ApiRequest::<()>::new(Method::GET, "/sse".into())
            .fetch_stream(&FailingClient)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn fetch_propagates_body_read_error() {
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&MalformedBodyClient)
            .await
            .unwrap_err();
        assert!(err.to_string().starts_with("HTTP error:"));
    }
}