1use std::marker::PhantomData;
2
3use futures_core::Stream;
4use serde::de::DeserializeOwned;
5
6use crate::client::{ApiClient, Method};
7use crate::error::{ApiError, DefinedErrorBody};
8use crate::sse::SseEvent;
9
10fn parse_error_response(status: u16, body: String) -> ApiError {
11 if let Ok(parsed) = serde_json::from_str::<DefinedErrorBody>(&body) {
12 if parsed.defined && !parsed.code.is_empty() {
13 return ApiError::Defined {
14 status,
15 code: parsed.code,
16 message: parsed.message,
17 };
18 }
19 }
20 ApiError::Api {
21 status,
22 message: body,
23 }
24}
25
26pub struct ApiRequest<T> {
30 pub method: Method,
31 pub path: String,
32 pub query: Option<String>,
33 pub body: Option<String>,
34 pub _marker: PhantomData<T>,
35}
36
37impl<T> ApiRequest<T> {
38 pub fn new(method: Method, path: String) -> Self {
39 Self {
40 method,
41 path,
42 query: None,
43 body: None,
44 _marker: PhantomData,
45 }
46 }
47
48 pub fn query_raw(mut self, qs: impl Into<String>) -> Self {
49 self.query = Some(qs.into());
50 self
51 }
52
53 pub fn body_json(mut self, body: &impl serde::Serialize) -> Self {
54 self.body = Some(serde_json::to_string(body).expect("request body must be serializable"));
55 self
56 }
57
58 pub fn try_body_json(mut self, body: &impl serde::Serialize) -> Result<Self, ApiError> {
60 self.body = Some(serde_json::to_string(body)?);
61 Ok(self)
62 }
63}
64
65impl<T: DeserializeOwned> ApiRequest<T> {
66 pub async fn fetch(self, client: &(impl ApiClient + ?Sized)) -> Result<T, ApiError> {
68 let resp = client
69 .request(self.method, &self.path, self.query.as_deref(), self.body)
70 .await?;
71
72 let status = resp.status();
73 if !status.is_success() {
74 let body = resp.text().await.unwrap_or_default();
75 return Err(parse_error_response(status.as_u16(), body));
76 }
77
78 let text = resp.text().await?;
79 if text.is_empty() {
80 return serde_json::from_str("null").map_err(ApiError::from);
81 }
82 serde_json::from_str(&text).map_err(ApiError::from)
83 }
84}
85
86impl ApiRequest<()> {
87 pub async fn fetch_empty(self, client: &(impl ApiClient + ?Sized)) -> Result<(), ApiError> {
89 let resp = client
90 .request(self.method, &self.path, self.query.as_deref(), self.body)
91 .await?;
92
93 let status = resp.status();
94 if !status.is_success() {
95 let body = resp.text().await.unwrap_or_default();
96 return Err(parse_error_response(status.as_u16(), body));
97 }
98 Ok(())
99 }
100}
101
102impl<T> ApiRequest<T> {
103 pub async fn fetch_stream(
105 self,
106 client: &(impl ApiClient + ?Sized),
107 ) -> Result<impl Stream<Item = Result<SseEvent, ApiError>>, ApiError> {
108 let stream = client
109 .request_stream(self.method, &self.path, self.query.as_deref())
110 .await?;
111 Ok(stream)
112 }
113}
114
/// Percent-encodes `input` for use as a URL query key or value.
///
/// Bytes in the RFC 3986 "unreserved" set (`A-Z`, `a-z`, `0-9`, `-`, `_`, `.`, `~`) are copied
/// through. Every other byte of the UTF-8 encoding becomes `%XX` with uppercase hex digits, so
/// non-ASCII characters become one escape per byte.
fn percent_encode(input: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut out = String::with_capacity(input.len());
    for byte in input.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                out.push(char::from(byte));
            }
            _ => {
                // Emit the two hex digits directly instead of allocating a `format!` string per
                // escaped byte. Both nibbles are <= 15, so the table index is always in bounds.
                out.push('%');
                out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
        }
    }
    out
}
130
131pub fn build_query_string(pairs: &[(&str, &dyn ToString)]) -> String {
133 pairs
134 .iter()
135 .map(|(k, v)| format!("{}={}", percent_encode(k), percent_encode(&v.to_string())))
136 .collect::<Vec<_>>()
137 .join("&")
138}
139
#[cfg(test)]
// NOTE(review): the mock `ApiClient` impls below are written as `fn ... -> impl Future + Send`
// returning `async` blocks rather than `async fn`, which clippy's `manual_async_fn` lint flags.
// Presumably this mirrors the trait's declared signatures — confirm before removing the allow.
#[allow(clippy::manual_async_fn)]
mod tests {
    use super::*;
    use crate::sse::SseStream;
    use tokio::io::AsyncWriteExt;
    use tokio::time::{Duration, sleep};

    /// Starts a mockito server that answers `GET /mock` with `status`, a JSON content-type
    /// header and `body`, then performs that request and returns the response.
    ///
    /// NOTE(review): `server` and `_mock` are dropped when this function returns, before callers
    /// read the body with `resp.text()` — confirm mockito keeps serving the connection after drop.
    async fn mock_response(status: u16, body: &str) -> reqwest::Response {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("GET", "/mock")
            .with_status(status as usize)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await;
        reqwest::get(&format!("{}/mock", server.url()))
            .await
            .unwrap()
    }

    /// Produces a `reqwest::Error` without touching the network: the NUL byte makes the header
    /// name invalid, so `build()` fails.
    fn make_reqwest_error() -> reqwest::Error {
        reqwest::Client::new()
            .get("http://localhost:1/x")
            .header("bad\0header", "v")
            .build()
            .unwrap_err()
    }

    /// Returns a `200 OK` response whose body is truncated. The headers advertise
    /// `Content-Length: 40`, but only the two bytes `{}` are written before the socket is shut
    /// down, so reading the body fails. Despite the name, no chunked transfer encoding is used.
    async fn malformed_chunked_response() -> reqwest::Response {
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            socket
                .write_all(
                    b"HTTP/1.1 200 OK\r\n\
Content-Type: application/json\r\n\
Content-Length: 40\r\n\
Connection: close\r\n\
\r\n\
{}",
                )
                .await
                .unwrap();
            // Presumably gives the client time to parse the headers before the connection is
            // closed — TODO confirm the delay is still needed.
            sleep(Duration::from_millis(250)).await;
            socket.shutdown().await.unwrap();
        });
        let resp = reqwest::get(format!("http://{addr}")).await.unwrap();
        // Wait for the server task so the socket has been shut down before the body is read.
        let _ = server.await;
        resp
    }

    /// Client whose `request` returns a canned response with the given status and body, and
    /// whose `request_stream` always yields a single `data: hi` SSE event.
    struct MockClient {
        status: u16,
        body: String,
    }
    /// Client whose `request` and `request_stream` always fail with `ApiError::Http`.
    struct FailingClient;
    /// Client whose `request` returns a response with a truncated body; `request_stream` fails.
    struct MalformedBodyClient;

    impl ApiClient for MockClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            // Copy the fields out so the returned future does not borrow `self`.
            let status = self.status;
            let body = self.body.clone();
            async move { Ok(mock_response(status, &body).await) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async move {
                let chunks: Vec<Result<bytes::Bytes, reqwest::Error>> =
                    vec![Ok(bytes::Bytes::from(&b"data: hi\n\n"[..]))];
                Ok(SseStream::new(Box::pin(futures_util::stream::iter(chunks))))
            }
        }
    }

    impl ApiClient for FailingClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
    }

    impl ApiClient for MalformedBodyClient {
        fn request(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
            _: Option<String>,
        ) -> impl std::future::Future<Output = Result<reqwest::Response, ApiError>> + Send {
            async { Ok(malformed_chunked_response().await) }
        }
        fn request_stream(
            &self,
            _: Method,
            _: &str,
            _: Option<&str>,
        ) -> impl std::future::Future<Output = Result<SseStream, ApiError>> + Send {
            async { Err(ApiError::Http(make_reqwest_error())) }
        }
    }

    /// `new` starts with no query/body; the builder methods fill them in.
    #[test]
    fn api_request_builder() {
        let req = ApiRequest::<String>::new(Method::GET, "/test".into());
        assert_eq!(req.method, Method::GET);
        assert_eq!(req.path, "/test");
        assert!(req.query.is_none());
        assert!(req.body.is_none());

        let body = serde_json::json!({"x": 1});
        let req = ApiRequest::<String>::new(Method::POST, "/x".into())
            .query_raw("q=1")
            .body_json(&body);
        assert_eq!(req.query.as_deref(), Some("q=1"));
        assert_eq!(req.body.as_deref(), Some(r#"{"x":1}"#));
    }

    /// `try_body_json` sets the body on success and returns `Err` when serialization fails.
    #[test]
    fn body_serialization() {
        let req = ApiRequest::<String>::new(Method::POST, "/t".into())
            .try_body_json(&serde_json::json!({"x": 1}))
            .unwrap();
        assert!(req.body.is_some());

        // A type whose `Serialize` impl always fails.
        #[derive(Debug)]
        struct Bad;
        impl serde::Serialize for Bad {
            fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
                Err(serde::ser::Error::custom("fail"))
            }
        }
        assert!(
            ApiRequest::<String>::new(Method::POST, "/t".into())
                .try_body_json(&Bad)
                .is_err()
        );
    }

    /// `body_json` panics with its `expect` message when serialization fails.
    #[test]
    #[should_panic(expected = "request body must be serializable")]
    fn body_json_panics_on_bad_input() {
        #[derive(Debug)]
        struct Bad;
        impl serde::Serialize for Bad {
            fn serialize<S: serde::Serializer>(&self, _: S) -> Result<S::Ok, S::Error> {
                Err(serde::ser::Error::custom("fail"))
            }
        }
        let _ = ApiRequest::<String>::new(Method::POST, "/t".into()).body_json(&Bad);
    }

    /// Query strings join pairs with `&` and percent-encode reserved characters.
    #[test]
    fn query_string_and_percent_encode() {
        assert_eq!(build_query_string(&[]), "");
        assert_eq!(build_query_string(&[("limit", &10)]), "limit=10");
        assert_eq!(
            build_query_string(&[("a", &"hello"), ("b", &42)]),
            "a=hello&b=42"
        );
        assert_eq!(
            build_query_string(&[("q", &"hello world"), ("x", &"a&b=c")]),
            "q=hello%20world&x=a%26b%3Dc"
        );
        assert_eq!(percent_encode("abc-_.~123"), "abc-_.~123");
        assert_eq!(percent_encode("&="), "%26%3D");
    }

    /// Successful `fetch`: a JSON value, an empty body, and a non-JSON body.
    #[tokio::test]
    async fn fetch_success_and_edge_cases() {
        let client = MockClient {
            status: 200,
            body: r#""hello""#.into(),
        };
        assert_eq!(
            ApiRequest::<String>::new(Method::GET, "/t".into())
                .fetch(&client)
                .await
                .unwrap(),
            "hello"
        );

        // An empty body is parsed as JSON `null`, which deserializes to `None`.
        let client = MockClient {
            status: 200,
            body: String::new(),
        };
        let result: Option<String> = ApiRequest::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap();
        assert_eq!(result, None);

        // Invalid JSON surfaces as a serialization error.
        let client = MockClient {
            status: 200,
            body: "not-json".into(),
        };
        assert!(
            ApiRequest::<i32>::new(Method::GET, "/t".into())
                .fetch(&client)
                .await
                .unwrap_err()
                .to_string()
                .starts_with("serialization error:")
        );
    }

    /// Non-success responses: plain text, a "defined" error body, and an undefined one.
    #[tokio::test]
    async fn fetch_error_responses() {
        let client = MockClient {
            status: 403,
            body: "forbidden".into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(matches!(err, ApiError::Api { status: 403, .. }));

        let client = MockClient {
            status: 404,
            body: r#"{"defined":true,"code":"TEAM_NOT_FOUND","message":"Team not found"}"#.into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(err.is_code("TEAM_NOT_FOUND"));
        assert_eq!(err.status(), Some(404));

        // `defined: false` falls back to the generic `Api` variant with no code.
        let client = MockClient {
            status: 400,
            body: r#"{"defined":false,"code":"NOPE","message":"nope"}"#.into(),
        };
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&client)
            .await
            .unwrap_err();
        assert!(matches!(err, ApiError::Api { status: 400, .. }));
        assert_eq!(err.code(), None);
    }

    /// `fetch_empty`: 204 success, plain and defined error bodies, and a transport failure.
    #[tokio::test]
    async fn fetch_empty_success_and_errors() {
        let client = MockClient {
            status: 204,
            body: String::new(),
        };
        assert!(
            ApiRequest::<()>::new(Method::DELETE, "/t".into())
                .fetch_empty(&client)
                .await
                .is_ok()
        );

        let client = MockClient {
            status: 500,
            body: "oops".into(),
        };
        assert!(matches!(
            ApiRequest::<()>::new(Method::DELETE, "/t".into())
                .fetch_empty(&client)
                .await
                .unwrap_err(),
            ApiError::Api { status: 500, .. }
        ));

        let client = MockClient {
            status: 403,
            body: r#"{"defined":true,"code":"FORBIDDEN","message":"no access"}"#.into(),
        };
        let err = ApiRequest::<()>::new(Method::DELETE, "/t".into())
            .fetch_empty(&client)
            .await
            .unwrap_err();
        assert!(err.is_code("FORBIDDEN"));

        assert!(
            ApiRequest::<()>::new(Method::DELETE, "/t".into())
                .fetch_empty(&FailingClient)
                .await
                .unwrap_err()
                .to_string()
                .starts_with("HTTP error:")
        );
    }

    /// `fetch_stream` yields the mock's single event, and propagates client errors.
    #[tokio::test]
    async fn fetch_stream_success_and_errors() {
        use futures_util::StreamExt;

        // `MockClient::request_stream` ignores `status`/`body` and always emits `data: hi`.
        let client = MockClient {
            status: 200,
            body: String::new(),
        };
        let mut stream = ApiRequest::<()>::new(Method::GET, "/sse".into())
            .fetch_stream(&client)
            .await
            .unwrap();
        assert_eq!(stream.next().await.unwrap().unwrap().data, "hi");

        assert!(
            ApiRequest::<()>::new(Method::GET, "/sse".into())
                .fetch_stream(&FailingClient)
                .await
                .is_err()
        );
    }

    /// A failure while reading the success body (truncated response) is returned, not swallowed.
    #[tokio::test]
    async fn fetch_propagates_body_read_error() {
        let err = ApiRequest::<String>::new(Method::GET, "/t".into())
            .fetch(&MalformedBodyClient)
            .await
            .unwrap_err();
        assert!(err.to_string().starts_with("HTTP error:"));
    }
}
503}