1use std::{io::Error as IoError, sync::Arc};
2
3use bytes::Bytes;
4use futures_util::TryStreamExt;
5use http_body_util::BodyExt;
6use hyper_util::{client::legacy::Client, rt::TokioExecutor};
7use poem::{
8 Endpoint, EndpointExt, IntoEndpoint, Middleware, Request as HttpRequest,
9 Response as HttpResponse,
10 endpoint::{DynEndpoint, ToDynEndpoint},
11 http::{
12 Extensions, HeaderValue, Method, StatusCode, Uri, Version,
13 header::{self, InvalidHeaderValue},
14 uri::InvalidUri,
15 },
16};
17use rustls::ClientConfig as TlsClientConfig;
18
19use crate::{
20 Code, CompressionEncoding, Metadata, Request, Response, Status, Streaming,
21 codec::Codec,
22 compression::get_incoming_encodings,
23 connector::HttpsConnector,
24 encoding::{create_decode_response_body, create_encode_request_body},
25};
26
/// Boxed, type-erased HTTP body yielding `Bytes` chunks with `std::io::Error` errors.
///
/// `create_client_endpoint` converts outgoing requests into `hyper::Request<BoxBody>`
/// before handing them to the hyper client.
pub(crate) type BoxBody = http_body_util::combinators::BoxBody<Bytes, IoError>;
28
29pub struct ClientConfig {
31 uris: Vec<Uri>,
32 origin: Option<Uri>,
33 user_agent: Option<HeaderValue>,
34 tls_config: Option<TlsClientConfig>,
35 max_header_list_size: u32,
36}
37
38impl ClientConfig {
39 pub fn builder() -> ClientConfigBuilder {
41 ClientConfigBuilder {
42 config: Ok(ClientConfig {
43 uris: vec![],
44 origin: None,
45 user_agent: None,
46 tls_config: None,
47 max_header_list_size: 16384,
48 }),
49 }
50 }
51}
52
53#[derive(Debug, thiserror::Error)]
54pub enum ClientBuilderError {
55 #[error("invalid uri: {0}")]
57 InvalidUri(InvalidUri),
58
59 #[error("invalid origin: {0}")]
61 InvalidOrigin(InvalidUri),
62
63 #[error("invalid user-agent: {0}")]
65 InvalidUserAgent(InvalidHeaderValue),
66}
67
68pub struct ClientConfigBuilder {
70 config: Result<ClientConfig, ClientBuilderError>,
71}
72
73impl ClientConfigBuilder {
74 pub fn uri(mut self, uri: impl TryInto<Uri, Error = InvalidUri>) -> Self {
87 self.config = self.config.and_then(|mut config| {
88 config
89 .uris
90 .push(uri.try_into().map_err(ClientBuilderError::InvalidUri)?);
91 Ok(config)
92 });
93 self
94 }
95
96 pub fn uris<I, T>(self, uris: I) -> Self
111 where
112 I: IntoIterator<Item = T>,
113 T: TryInto<Uri, Error = InvalidUri>,
114 {
115 uris.into_iter().fold(self, |acc, uri| acc.uri(uri))
116 }
117
118 pub fn origin(mut self, origin: impl TryInto<Uri, Error = InvalidUri>) -> Self {
120 self.config = self.config.and_then(|mut config| {
121 config.origin = Some(
122 origin
123 .try_into()
124 .map_err(ClientBuilderError::InvalidOrigin)?,
125 );
126 Ok(config)
127 });
128 self
129 }
130
131 pub fn user_agent(
133 mut self,
134 user_agent: impl TryInto<HeaderValue, Error = InvalidHeaderValue>,
135 ) -> Self {
136 self.config = self.config.and_then(|mut config| {
137 config.user_agent = Some(
138 user_agent
139 .try_into()
140 .map_err(ClientBuilderError::InvalidUserAgent)?,
141 );
142 Ok(config)
143 });
144 self
145 }
146
147 pub fn tls_config(mut self, tls_config: TlsClientConfig) -> Self {
149 if let Ok(config) = &mut self.config {
150 config.tls_config = Some(tls_config);
151 }
152 self
153 }
154
155 pub fn http2_max_header_list_size(mut self, max: u32) -> Self {
159 if let Ok(config) = &mut self.config {
160 config.max_header_list_size = max;
161 }
162 self
163 }
164
165 pub fn build(self) -> Result<ClientConfig, ClientBuilderError> {
167 self.config
168 }
169}
170
171#[doc(hidden)]
172#[derive(Clone)]
173pub struct GrpcClient {
174 ep: Arc<dyn DynEndpoint<Output = HttpResponse> + 'static>,
175 send_compressed: Option<CompressionEncoding>,
176 accept_compressed: Arc<[CompressionEncoding]>,
177}
178
179impl GrpcClient {
180 #[inline]
181 pub fn new(config: ClientConfig) -> Self {
182 Self {
183 ep: create_client_endpoint(config),
184 send_compressed: None,
185 accept_compressed: Arc::new([]),
186 }
187 }
188
189 pub fn from_endpoint<T>(ep: T) -> Self
190 where
191 T: IntoEndpoint,
192 T::Endpoint: 'static,
193 <T::Endpoint as Endpoint>::Output: 'static,
194 {
195 Self {
196 ep: Arc::new(ToDynEndpoint(ep.map_to_response())),
197 send_compressed: None,
198 accept_compressed: Arc::new([]),
199 }
200 }
201
202 pub fn set_send_compressed(&mut self, encoding: CompressionEncoding) {
203 self.send_compressed = Some(encoding);
204 }
205
206 pub fn set_accept_compressed(&mut self, encodings: impl Into<Arc<[CompressionEncoding]>>) {
207 self.accept_compressed = encodings.into();
208 }
209
210 pub fn with<M>(mut self, middleware: M) -> Self
211 where
212 M: Middleware<Arc<dyn DynEndpoint<Output = HttpResponse> + 'static>>,
213 M::Output: 'static,
214 {
215 self.ep = Arc::new(ToDynEndpoint(
216 middleware.transform(self.ep).map_to_response(),
217 ));
218 self
219 }
220
221 pub async fn unary<T: Codec>(
222 &self,
223 path: &str,
224 mut codec: T,
225 request: Request<T::Encode>,
226 ) -> Result<Response<T::Decode>, Status> {
227 let Request {
228 metadata,
229 message,
230 extensions,
231 } = request;
232 let mut http_request =
233 create_http_request::<T>(path, metadata, extensions, self.send_compressed);
234 http_request.set_body(create_encode_request_body(
235 codec.encoder(),
236 Streaming::new(futures_util::stream::once(async move { Ok(message) })),
237 self.send_compressed,
238 ));
239
240 let mut resp = self
241 .ep
242 .call(http_request)
243 .await
244 .map_err(|err| Status::new(Code::Internal).with_message(err))?;
245
246 if resp.status() != StatusCode::OK {
247 return Err(Status::new(Code::Internal).with_message(format!(
248 "invalid http status code: {}",
249 resp.status().as_u16()
250 )));
251 }
252
253 let body = resp.take_body();
254 let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?;
255 let mut stream =
256 create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?;
257
258 let message = stream
259 .try_next()
260 .await?
261 .ok_or_else(|| Status::new(Code::Internal).with_message("missing response message"))?;
262 Ok(Response {
263 metadata: Metadata {
264 headers: std::mem::take(resp.headers_mut()),
265 },
266 message,
267 })
268 }
269
270 pub async fn client_streaming<T: Codec>(
271 &self,
272 path: &str,
273 mut codec: T,
274 request: Request<Streaming<T::Encode>>,
275 ) -> Result<Response<T::Decode>, Status> {
276 let Request {
277 metadata,
278 message,
279 extensions,
280 } = request;
281 let mut http_request =
282 create_http_request::<T>(path, metadata, extensions, self.send_compressed);
283 http_request.set_body(create_encode_request_body(
284 codec.encoder(),
285 message,
286 self.send_compressed,
287 ));
288
289 let mut resp = self
290 .ep
291 .call(http_request)
292 .await
293 .map_err(|err| Status::new(Code::Internal).with_message(err))?;
294
295 if resp.status() != StatusCode::OK {
296 return Err(Status::new(Code::Internal).with_message(format!(
297 "invalid http status code: {}",
298 resp.status().as_u16()
299 )));
300 }
301
302 let body = resp.take_body();
303 let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?;
304 let mut stream =
305 create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?;
306
307 let message = stream
308 .try_next()
309 .await?
310 .ok_or_else(|| Status::new(Code::Internal).with_message("missing response message"))?;
311 Ok(Response {
312 metadata: Metadata {
313 headers: std::mem::take(resp.headers_mut()),
314 },
315 message,
316 })
317 }
318
319 pub async fn server_streaming<T: Codec>(
320 &self,
321 path: &str,
322 mut codec: T,
323 request: Request<T::Encode>,
324 ) -> Result<Response<Streaming<T::Decode>>, Status> {
325 let Request {
326 metadata,
327 message,
328 extensions,
329 } = request;
330 let mut http_request =
331 create_http_request::<T>(path, metadata, extensions, self.send_compressed);
332 http_request.set_body(create_encode_request_body(
333 codec.encoder(),
334 Streaming::new(futures_util::stream::once(async move { Ok(message) })),
335 self.send_compressed,
336 ));
337
338 let mut resp = self
339 .ep
340 .call(http_request)
341 .await
342 .map_err(|err| Status::new(Code::Internal).with_message(err))?;
343
344 if resp.status() != StatusCode::OK {
345 return Err(Status::new(Code::Internal).with_message(format!(
346 "invalid http status code: {}",
347 resp.status().as_u16()
348 )));
349 }
350
351 let body = resp.take_body();
352 let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?;
353 let stream =
354 create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?;
355
356 Ok(Response {
357 metadata: Metadata {
358 headers: std::mem::take(resp.headers_mut()),
359 },
360 message: stream,
361 })
362 }
363
364 pub async fn bidirectional_streaming<T: Codec>(
365 &self,
366 path: &str,
367 mut codec: T,
368 request: Request<Streaming<T::Encode>>,
369 ) -> Result<Response<Streaming<T::Decode>>, Status> {
370 let Request {
371 metadata,
372 message,
373 extensions,
374 } = request;
375 let mut http_request =
376 create_http_request::<T>(path, metadata, extensions, self.send_compressed);
377 http_request.set_body(create_encode_request_body(
378 codec.encoder(),
379 message,
380 self.send_compressed,
381 ));
382
383 let mut resp = self
384 .ep
385 .call(http_request)
386 .await
387 .map_err(|err| Status::new(Code::Internal).with_message(err))?;
388
389 if resp.status() != StatusCode::OK {
390 return Err(Status::new(Code::Internal).with_message(format!(
391 "invalid http status code: {}",
392 resp.status().as_u16()
393 )));
394 }
395
396 let body = resp.take_body();
397 let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?;
398 let stream =
399 create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?;
400
401 Ok(Response {
402 metadata: Metadata {
403 headers: std::mem::take(resp.headers_mut()),
404 },
405 message: stream,
406 })
407 }
408}
409
410fn create_http_request<T: Codec>(
411 path: &str,
412 metadata: Metadata,
413 extensions: Extensions,
414 send_compressed: Option<CompressionEncoding>,
415) -> HttpRequest {
416 let mut http_request = HttpRequest::builder()
417 .uri_str(path)
418 .method(Method::POST)
419 .version(Version::HTTP_2)
420 .finish();
421 *http_request.headers_mut() = metadata.headers;
422 *http_request.extensions_mut() = extensions;
423 http_request
424 .headers_mut()
425 .insert("content-type", T::CONTENT_TYPES[0].parse().unwrap());
426 http_request
427 .headers_mut()
428 .insert(header::TE, "trailers".parse().unwrap());
429 if let Some(send_compressed) = send_compressed {
430 http_request.headers_mut().insert(
431 "grpc-encoding",
432 HeaderValue::from_str(send_compressed.as_str()).expect("BUG: invalid encoding"),
433 );
434 }
435 http_request
436}
437
/// Erases the concrete type of `err`, producing a thread-safe boxed error
/// (used to turn hyper client errors into something `?` can convert).
#[inline]
fn to_boxed_error(
    err: impl std::error::Error + Send + Sync + 'static,
) -> Box<dyn std::error::Error + Send + Sync> {
    err.into()
}
444
445fn make_uri(base_uri: &Uri, path: &Uri) -> Uri {
446 let path = path.path_and_query().unwrap().path();
447 let mut parts = base_uri.clone().into_parts();
448 match parts.path_and_query {
449 Some(path_and_query) => {
450 let mut new_path = format!("{}{}", path_and_query.path().trim_end_matches('/'), path);
451 if let Some(query) = path_and_query.query() {
452 new_path.push('?');
453 new_path.push_str(query);
454 }
455 parts.path_and_query = Some(new_path.parse().unwrap());
456 }
457 None => {
458 parts.path_and_query = Some(path.parse().unwrap());
459 }
460 }
461 Uri::from_parts(parts).unwrap()
462}
463
464fn create_client_endpoint(
465 config: ClientConfig,
466) -> Arc<dyn DynEndpoint<Output = HttpResponse> + 'static> {
467 let mut config = config;
468 let cli = Client::builder(TokioExecutor::new())
469 .http2_only(true)
470 .http2_max_header_list_size(config.max_header_list_size)
471 .build(HttpsConnector::new(config.tls_config.take()));
472
473 let config = Arc::new(config);
474
475 Arc::new(ToDynEndpoint(poem::endpoint::make(move |request| {
476 let config = config.clone();
477 let cli = cli.clone();
478 async move {
479 let mut request: hyper::Request<BoxBody> = request.into();
480
481 if config.uris.is_empty() {
482 return Err(poem::Error::from_string(
483 "uris is empty",
484 StatusCode::INTERNAL_SERVER_ERROR,
485 ));
486 }
487
488 let base_uri = if config.uris.len() == 1 {
489 &config.uris[0]
490 } else {
491 &config.uris[fastrand::usize(0..config.uris.len())]
492 };
493 *request.uri_mut() = make_uri(base_uri, request.uri());
494
495 if let Some(origin) = &config.origin {
496 if let Ok(value) = HeaderValue::from_maybe_shared(origin.to_string()) {
497 request.headers_mut().insert(header::ORIGIN, value);
498 }
499 }
500
501 if let Some(user_agent) = &config.user_agent {
502 request
503 .headers_mut()
504 .insert(header::ORIGIN, user_agent.clone());
505 }
506
507 let resp = cli.request(request).await.map_err(to_boxed_error)?;
508 let (parts, body) = resp.into_parts();
509
510 Ok::<_, poem::Error>(HttpResponse::from(hyper::Response::from_parts(
511 parts,
512 body.map_err(IoError::other),
513 )))
514 }
515 })))
516}