1use crate::conn::{
2 self, build_request, get_response_string, stream_json_response, stream_response, Compat,
3 Headers, Payload, Transport,
4};
5use futures_util::{
6 io::{AsyncRead, AsyncWrite},
7 stream::Stream,
8 TryFutureExt, TryStreamExt,
9};
10use hyper::{body::Bytes, header, Body, Method, Request, Response, StatusCode};
11use log::trace;
12use serde::de::DeserializeOwned;
13use std::future::Future;
14use std::pin::Pin;
15
16#[derive(Debug, Clone)]
17pub struct RequestClient<E> {
18 transport: Transport,
19 validate_fn: Box<ValidateResponseFn<E>>,
20 _error_type: std::marker::PhantomData<E>,
21}
22
23pub type ValidateResponseFn<E> =
24 fn(Response<Body>) -> Pin<Box<dyn Future<Output = Result<Response<Body>, E>> + Send + Sync>>;
25
26impl<E: From<conn::Error> + From<serde_json::Error>> RequestClient<E> {
27 pub fn new(transport: Transport, validate_fn: Box<ValidateResponseFn<E>>) -> Self {
30 Self {
31 transport,
32 validate_fn,
33 _error_type: std::marker::PhantomData,
34 }
35 }
36
37 fn make_request<B>(
38 &self,
39 method: http::Method,
40 endpoint: &str,
41 body: Payload<B>,
42 headers: Option<Headers>,
43 ) -> conn::Result<Request<Body>>
44 where
45 B: Into<Body>,
46 {
47 let uri = self.transport.make_uri(endpoint)?;
48 build_request(method, uri, body, headers)
49 }
50
51 async fn send_request(&self, request: Request<Body>) -> Result<Response<Body>, E> {
52 let response = self.transport.request(request).await.map_err(E::from)?;
53 (self.validate_fn)(response).await
54 }
55
56 pub async fn get(&self, endpoint: impl AsRef<str>) -> Result<Response<Body>, E> {
62 let req = self.make_request(
63 Method::GET,
64 endpoint.as_ref(),
65 Payload::empty(),
66 Headers::none(),
67 );
68 self.send_request(req?).await
69 }
70
71 pub async fn get_string(&self, endpoint: impl AsRef<str>) -> Result<String, E> {
73 let response = self.get(endpoint).await?;
74 get_response_string(response).await.map_err(E::from)
75 }
76
77 pub async fn get_json<T: DeserializeOwned>(&self, endpoint: impl AsRef<str>) -> Result<T, E> {
79 let raw_string = self.get_string(endpoint).await?;
80 trace!("{raw_string}");
81 serde_json::from_str::<T>(&raw_string).map_err(E::from)
82 }
83
84 async fn get_stream_impl(
85 &self,
86 endpoint: impl AsRef<str>,
87 ) -> Result<impl Stream<Item = Result<Bytes, E>> + '_, E> {
88 let response = self.get(endpoint).await?;
89 Ok(stream_response(response).map_err(E::from))
90 }
91
92 pub fn get_stream<'client>(
94 &'client self,
95 endpoint: impl AsRef<str> + 'client,
96 ) -> impl Stream<Item = Result<Bytes, E>> + 'client {
97 self.get_stream_impl(endpoint).try_flatten_stream()
98 }
99
100 pub fn get_json_stream<'client, T>(
102 &'client self,
103 endpoint: impl AsRef<str> + 'client,
104 ) -> impl Stream<Item = Result<T, E>> + 'client
105 where
106 T: DeserializeOwned,
107 {
108 self.get_stream(endpoint)
109 .and_then(|chunk| async move {
110 let stream = futures_util::stream::iter(
111 serde_json::Deserializer::from_slice(&chunk)
112 .into_iter()
113 .collect::<Vec<_>>(),
114 )
115 .map_err(E::from);
116
117 Ok(stream)
118 })
119 .try_flatten()
120 }
121
122 pub async fn post<B>(
128 &self,
129 endpoint: impl AsRef<str>,
130 body: Payload<B>,
131 headers: Option<Headers>,
132 ) -> Result<Response<Body>, E>
133 where
134 B: Into<Body>,
135 {
136 let req = self.make_request(Method::POST, endpoint.as_ref(), body, headers);
137 self.send_request(req?).await
138 }
139
140 pub async fn post_string<B>(
142 &self,
143 endpoint: impl AsRef<str>,
144 body: Payload<B>,
145 headers: Option<Headers>,
146 ) -> Result<String, E>
147 where
148 B: Into<Body>,
149 {
150 let response = self.post(endpoint, body, headers).await?;
151 get_response_string(response).await.map_err(E::from)
152 }
153
154 pub async fn post_json<B, T>(
157 &self,
158 endpoint: impl AsRef<str>,
159 body: Payload<B>,
160 headers: Option<Headers>,
161 ) -> Result<T, E>
162 where
163 T: DeserializeOwned,
164 B: Into<Body>,
165 {
166 let raw_string = self.post_string(endpoint, body, headers).await?;
167 trace!("{raw_string}");
168 serde_json::from_str::<T>(&raw_string).map_err(E::from)
169 }
170
171 async fn post_stream_impl<B>(
172 &self,
173 endpoint: impl AsRef<str>,
174 body: Payload<B>,
175 headers: Option<Headers>,
176 ) -> Result<impl Stream<Item = Result<Bytes, E>> + '_, E>
177 where
178 B: Into<Body>,
179 {
180 let response = self.post(endpoint, body, headers).await?;
181 Ok(stream_response(response).map_err(E::from))
182 }
183
184 pub fn post_stream<'client, B>(
190 &'client self,
191 endpoint: impl AsRef<str> + 'client,
192 body: Payload<B>,
193 headers: Option<Headers>,
194 ) -> impl Stream<Item = Result<Bytes, E>> + 'client
195 where
196 B: Into<Body> + 'client,
197 {
198 self.post_stream_impl(endpoint, body, headers)
199 .try_flatten_stream()
200 }
201
202 async fn post_json_stream_impl<B>(
203 &self,
204 endpoint: impl AsRef<str>,
205 body: Payload<B>,
206 headers: Option<Headers>,
207 ) -> Result<impl Stream<Item = Result<Bytes, E>> + '_, E>
208 where
209 B: Into<Body>,
210 {
211 let response = self.post(endpoint, body, headers).await?;
212 Ok(stream_json_response(response).map_err(E::from))
213 }
214
215 fn post_json_stream<'client, B>(
217 &'client self,
218 endpoint: impl AsRef<str> + 'client,
219 body: Payload<B>,
220 headers: Option<Headers>,
221 ) -> impl Stream<Item = Result<Bytes, E>> + 'client
222 where
223 B: Into<Body> + 'client,
224 {
225 self.post_json_stream_impl(endpoint, body, headers)
226 .try_flatten_stream()
227 }
228
229 pub fn post_into_stream<'client, B, T>(
232 &'client self,
233 endpoint: impl AsRef<str> + 'client,
234 body: Payload<B>,
235 headers: Option<Headers>,
236 ) -> impl Stream<Item = Result<T, E>> + 'client
237 where
238 B: Into<Body> + 'client,
239 T: DeserializeOwned,
240 {
241 self.post_json_stream(endpoint, body, headers)
242 .and_then(|chunk| async move {
243 trace!("got chunk {:?}", chunk);
244 let stream = futures_util::stream::iter(
245 serde_json::Deserializer::from_slice(&chunk)
246 .into_iter()
247 .collect::<Vec<_>>(),
248 )
249 .map_err(E::from);
250
251 Ok(stream)
252 })
253 .try_flatten()
254 }
255
256 pub async fn post_upgrade_stream<B>(
257 self,
258 endpoint: impl AsRef<str>,
259 body: Payload<B>,
260 ) -> Result<impl AsyncRead + AsyncWrite, E>
261 where
262 B: Into<Body>,
263 {
264 self.stream_upgrade(Method::POST, endpoint, body)
265 .await
266 .map_err(E::from)
267 }
268
269 pub async fn put<B>(
275 &self,
276 endpoint: impl AsRef<str>,
277 body: Payload<B>,
278 ) -> Result<Response<Body>, E>
279 where
280 B: Into<Body>,
281 {
282 let req = self.make_request(Method::PUT, endpoint.as_ref(), body, Headers::none());
283 self.send_request(req?).await
284 }
285
286 pub async fn put_string<B>(
288 &self,
289 endpoint: impl AsRef<str>,
290 body: Payload<B>,
291 ) -> Result<String, E>
292 where
293 B: Into<Body>,
294 {
295 let response = self.put(endpoint, body).await?;
296 get_response_string(response).await.map_err(E::from)
297 }
298
299 pub async fn delete(&self, endpoint: impl AsRef<str>) -> Result<Response<Body>, E> {
305 let req = self.make_request(
306 Method::DELETE,
307 endpoint.as_ref(),
308 Payload::empty(),
309 Headers::none(),
310 );
311 self.send_request(req?).await
312 }
313
314 pub async fn delete_string(&self, endpoint: impl AsRef<str>) -> Result<String, E> {
316 let response = self.delete(endpoint).await?;
317 get_response_string(response).await.map_err(E::from)
318 }
319
320 pub async fn delete_json<T: DeserializeOwned>(
323 &self,
324 endpoint: impl AsRef<str>,
325 ) -> Result<T, E> {
326 let raw_string = self.delete_string(endpoint).await?;
327 trace!("{raw_string}");
328 serde_json::from_str::<T>(&raw_string).map_err(E::from)
329 }
330
331 pub async fn head(&self, endpoint: impl AsRef<str>) -> Result<Response<Body>, E> {
337 let req = self.make_request(
338 Method::HEAD,
339 endpoint.as_ref(),
340 Payload::empty(),
341 Headers::none(),
342 );
343 self.send_request(req?).await
344 }
345
346 async fn stream_upgrade<B>(
351 &self,
352 method: Method,
353 endpoint: impl AsRef<str>,
354 body: Payload<B>,
355 ) -> Result<impl AsyncRead + AsyncWrite, E>
356 where
357 B: Into<Body>,
358 {
359 self.stream_upgrade_tokio(method, endpoint.as_ref(), body)
360 .await
361 .map(Compat::new)
362 }
363
364 async fn stream_upgrade_tokio<B>(
367 &self,
368 method: Method,
369 endpoint: &str,
370 body: Payload<B>,
371 ) -> Result<hyper::upgrade::Upgraded, E>
372 where
373 B: Into<Body>,
374 {
375 let mut headers = Headers::default();
376 headers.add(header::CONNECTION.as_str(), "Upgrade");
377 headers.add(header::UPGRADE.as_str(), "tcp");
378
379 let req = self.make_request(method, endpoint, body, Some(headers));
380
381 let response = self.send_request(req?).await?;
382 match response.status() {
383 StatusCode::SWITCHING_PROTOCOLS => Ok(hyper::upgrade::on(response)
384 .await
385 .map_err(conn::Error::from)?),
386 _ => Err(E::from(conn::Error::ConnectionNotUpgraded)),
387 }
388 }
389}