poem_grpc/
client.rs

1use std::{io::Error as IoError, sync::Arc};
2
3use bytes::Bytes;
4use futures_util::TryStreamExt;
5use http_body_util::BodyExt;
6use hyper_util::{client::legacy::Client, rt::TokioExecutor};
7use poem::{
8    Endpoint, EndpointExt, IntoEndpoint, Middleware, Request as HttpRequest,
9    Response as HttpResponse,
10    endpoint::{DynEndpoint, ToDynEndpoint},
11    http::{
12        Extensions, HeaderValue, Method, StatusCode, Uri, Version,
13        header::{self, InvalidHeaderValue},
14        uri::InvalidUri,
15    },
16};
17use rustls::ClientConfig as TlsClientConfig;
18
19use crate::{
20    Code, CompressionEncoding, Metadata, Request, Response, Status, Streaming,
21    codec::Codec,
22    compression::get_incoming_encodings,
23    connector::HttpsConnector,
24    encoding::{create_decode_response_body, create_encode_request_body},
25};
26
/// Type-erased HTTP body that yields [`Bytes`] chunks and reports failures as
/// [`std::io::Error`].
///
/// Outgoing poem requests are converted into `hyper::Request<BoxBody>` before
/// they are handed to the hyper client.
pub(crate) type BoxBody = http_body_util::combinators::BoxBody<Bytes, IoError>;
28
/// A configuration for GRPC client
///
/// Instances are assembled with [`ClientConfig::builder`].
pub struct ClientConfig {
    /// Base uris of the GRPC endpoints. When more than one is configured, one
    /// of them is picked at random for every request.
    uris: Vec<Uri>,
    /// Value sent as the `Origin` header of every request, if set.
    origin: Option<Uri>,
    /// Header value supplied through `ClientConfigBuilder::user_agent`, if any.
    user_agent: Option<HeaderValue>,
    /// TLS configuration handed to the HTTPS connector when the client is
    /// created.
    tls_config: Option<TlsClientConfig>,
    /// Passed to the hyper client builder as the HTTP/2 max header list size.
    max_header_list_size: u32,
}
37
38impl ClientConfig {
39    /// Create a `ClientConfig` builder
40    pub fn builder() -> ClientConfigBuilder {
41        ClientConfigBuilder {
42            config: Ok(ClientConfig {
43                uris: vec![],
44                origin: None,
45                user_agent: None,
46                tls_config: None,
47                max_header_list_size: 16384,
48            }),
49        }
50    }
51}
52
/// Errors recorded by [`ClientConfigBuilder`] while converting its inputs and
/// returned from [`ClientConfigBuilder::build`].
#[derive(Debug, thiserror::Error)]
pub enum ClientBuilderError {
    /// Invalid uri
    ///
    /// Produced when a value passed to `uri`/`uris` cannot be converted to a
    /// `Uri`.
    #[error("invalid uri: {0}")]
    InvalidUri(InvalidUri),

    /// Invalid origin
    ///
    /// Produced when the value passed to `origin` cannot be converted to a
    /// `Uri`.
    #[error("invalid origin: {0}")]
    InvalidOrigin(InvalidUri),

    /// Invalid user-agent
    ///
    /// Produced when the value passed to `user_agent` cannot be converted to a
    /// `HeaderValue`.
    #[error("invalid user-agent: {0}")]
    InvalidUserAgent(InvalidHeaderValue),
}
67
/// A `ClientConfig` builder
///
/// Created by [`ClientConfig::builder`]; finish with
/// [`build`](ClientConfigBuilder::build).
pub struct ClientConfigBuilder {
    /// The configuration built so far, or the first conversion error that was
    /// encountered. Once this is `Err`, later setters leave it unchanged.
    config: Result<ClientConfig, ClientBuilderError>,
}
72
73impl ClientConfigBuilder {
74    /// Add a uri as GRPC endpoint
75    ///
76    /// # Examples
77    ///
78    /// ```rust
79    /// # use poem_grpc::ClientConfig;
80    /// let cfg = ClientConfig::builder()
81    ///     .uri("http://server1:3000")
82    ///     .uri("http://server2:3000")
83    ///     .uri("http://server3:3000")
84    ///     .build();
85    /// ```
86    pub fn uri(mut self, uri: impl TryInto<Uri, Error = InvalidUri>) -> Self {
87        self.config = self.config.and_then(|mut config| {
88            config
89                .uris
90                .push(uri.try_into().map_err(ClientBuilderError::InvalidUri)?);
91            Ok(config)
92        });
93        self
94    }
95
96    /// Add some uris
97    ///
98    /// # Examples
99    ///
100    /// ```rust
101    /// # use poem_grpc::ClientConfig;
102    /// let cfg = ClientConfig::builder()
103    ///     .uris([
104    ///         "http://server1:3000",
105    ///         "http://server2:3000",
106    ///         "http://server3:3000",
107    ///     ])
108    ///     .build();
109    /// ```
110    pub fn uris<I, T>(self, uris: I) -> Self
111    where
112        I: IntoIterator<Item = T>,
113        T: TryInto<Uri, Error = InvalidUri>,
114    {
115        uris.into_iter().fold(self, |acc, uri| acc.uri(uri))
116    }
117
118    /// Set `Origin` header for each requests.
119    pub fn origin(mut self, origin: impl TryInto<Uri, Error = InvalidUri>) -> Self {
120        self.config = self.config.and_then(|mut config| {
121            config.origin = Some(
122                origin
123                    .try_into()
124                    .map_err(ClientBuilderError::InvalidOrigin)?,
125            );
126            Ok(config)
127        });
128        self
129    }
130
131    /// Set `User-Agent` header for each requests.
132    pub fn user_agent(
133        mut self,
134        user_agent: impl TryInto<HeaderValue, Error = InvalidHeaderValue>,
135    ) -> Self {
136        self.config = self.config.and_then(|mut config| {
137            config.user_agent = Some(
138                user_agent
139                    .try_into()
140                    .map_err(ClientBuilderError::InvalidUserAgent)?,
141            );
142            Ok(config)
143        });
144        self
145    }
146
147    /// Set `TlsConfig` for `HTTPS` uri
148    pub fn tls_config(mut self, tls_config: TlsClientConfig) -> Self {
149        if let Ok(config) = &mut self.config {
150            config.tls_config = Some(tls_config);
151        }
152        self
153    }
154
155    /// Sets the max size of received header frames.
156    ///
157    /// Default is `16384` bytes.
158    pub fn http2_max_header_list_size(mut self, max: u32) -> Self {
159        if let Ok(config) = &mut self.config {
160            config.max_header_list_size = max;
161        }
162        self
163    }
164
165    /// Consumes this builder and returns the `ClientConfig`
166    pub fn build(self) -> Result<ClientConfig, ClientBuilderError> {
167        self.config
168    }
169}
170
/// Client handle that drives the four GRPC call shapes over a type-erased
/// endpoint. Cloning copies only reference-counted handles and the selected
/// encoding.
#[doc(hidden)]
#[derive(Clone)]
pub struct GrpcClient {
    /// Endpoint every request is dispatched to (network client or a
    /// user-supplied endpoint, optionally wrapped in middleware).
    ep: Arc<dyn DynEndpoint<Output = HttpResponse> + 'static>,
    /// Encoding applied to outgoing request bodies and announced in the
    /// `grpc-encoding` header; `None` means no encoding is requested.
    send_compressed: Option<CompressionEncoding>,
    /// Encodings passed to `get_incoming_encodings` when a response is decoded.
    accept_compressed: Arc<[CompressionEncoding]>,
}
178
179impl GrpcClient {
180    #[inline]
181    pub fn new(config: ClientConfig) -> Self {
182        Self {
183            ep: create_client_endpoint(config),
184            send_compressed: None,
185            accept_compressed: Arc::new([]),
186        }
187    }
188
189    pub fn from_endpoint<T>(ep: T) -> Self
190    where
191        T: IntoEndpoint,
192        T::Endpoint: 'static,
193        <T::Endpoint as Endpoint>::Output: 'static,
194    {
195        Self {
196            ep: Arc::new(ToDynEndpoint(ep.map_to_response())),
197            send_compressed: None,
198            accept_compressed: Arc::new([]),
199        }
200    }
201
202    pub fn set_send_compressed(&mut self, encoding: CompressionEncoding) {
203        self.send_compressed = Some(encoding);
204    }
205
206    pub fn set_accept_compressed(&mut self, encodings: impl Into<Arc<[CompressionEncoding]>>) {
207        self.accept_compressed = encodings.into();
208    }
209
210    pub fn with<M>(mut self, middleware: M) -> Self
211    where
212        M: Middleware<Arc<dyn DynEndpoint<Output = HttpResponse> + 'static>>,
213        M::Output: 'static,
214    {
215        self.ep = Arc::new(ToDynEndpoint(
216            middleware.transform(self.ep).map_to_response(),
217        ));
218        self
219    }
220
221    pub async fn unary<T: Codec>(
222        &self,
223        path: &str,
224        mut codec: T,
225        request: Request<T::Encode>,
226    ) -> Result<Response<T::Decode>, Status> {
227        let Request {
228            metadata,
229            message,
230            extensions,
231        } = request;
232        let mut http_request =
233            create_http_request::<T>(path, metadata, extensions, self.send_compressed);
234        http_request.set_body(create_encode_request_body(
235            codec.encoder(),
236            Streaming::new(futures_util::stream::once(async move { Ok(message) })),
237            self.send_compressed,
238        ));
239
240        let mut resp = self
241            .ep
242            .call(http_request)
243            .await
244            .map_err(|err| Status::new(Code::Internal).with_message(err))?;
245
246        if resp.status() != StatusCode::OK {
247            return Err(Status::new(Code::Internal).with_message(format!(
248                "invalid http status code: {}",
249                resp.status().as_u16()
250            )));
251        }
252
253        let body = resp.take_body();
254        let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?;
255        let mut stream =
256            create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?;
257
258        let message = stream
259            .try_next()
260            .await?
261            .ok_or_else(|| Status::new(Code::Internal).with_message("missing response message"))?;
262        Ok(Response {
263            metadata: Metadata {
264                headers: std::mem::take(resp.headers_mut()),
265            },
266            message,
267        })
268    }
269
270    pub async fn client_streaming<T: Codec>(
271        &self,
272        path: &str,
273        mut codec: T,
274        request: Request<Streaming<T::Encode>>,
275    ) -> Result<Response<T::Decode>, Status> {
276        let Request {
277            metadata,
278            message,
279            extensions,
280        } = request;
281        let mut http_request =
282            create_http_request::<T>(path, metadata, extensions, self.send_compressed);
283        http_request.set_body(create_encode_request_body(
284            codec.encoder(),
285            message,
286            self.send_compressed,
287        ));
288
289        let mut resp = self
290            .ep
291            .call(http_request)
292            .await
293            .map_err(|err| Status::new(Code::Internal).with_message(err))?;
294
295        if resp.status() != StatusCode::OK {
296            return Err(Status::new(Code::Internal).with_message(format!(
297                "invalid http status code: {}",
298                resp.status().as_u16()
299            )));
300        }
301
302        let body = resp.take_body();
303        let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?;
304        let mut stream =
305            create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?;
306
307        let message = stream
308            .try_next()
309            .await?
310            .ok_or_else(|| Status::new(Code::Internal).with_message("missing response message"))?;
311        Ok(Response {
312            metadata: Metadata {
313                headers: std::mem::take(resp.headers_mut()),
314            },
315            message,
316        })
317    }
318
319    pub async fn server_streaming<T: Codec>(
320        &self,
321        path: &str,
322        mut codec: T,
323        request: Request<T::Encode>,
324    ) -> Result<Response<Streaming<T::Decode>>, Status> {
325        let Request {
326            metadata,
327            message,
328            extensions,
329        } = request;
330        let mut http_request =
331            create_http_request::<T>(path, metadata, extensions, self.send_compressed);
332        http_request.set_body(create_encode_request_body(
333            codec.encoder(),
334            Streaming::new(futures_util::stream::once(async move { Ok(message) })),
335            self.send_compressed,
336        ));
337
338        let mut resp = self
339            .ep
340            .call(http_request)
341            .await
342            .map_err(|err| Status::new(Code::Internal).with_message(err))?;
343
344        if resp.status() != StatusCode::OK {
345            return Err(Status::new(Code::Internal).with_message(format!(
346                "invalid http status code: {}",
347                resp.status().as_u16()
348            )));
349        }
350
351        let body = resp.take_body();
352        let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?;
353        let stream =
354            create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?;
355
356        Ok(Response {
357            metadata: Metadata {
358                headers: std::mem::take(resp.headers_mut()),
359            },
360            message: stream,
361        })
362    }
363
364    pub async fn bidirectional_streaming<T: Codec>(
365        &self,
366        path: &str,
367        mut codec: T,
368        request: Request<Streaming<T::Encode>>,
369    ) -> Result<Response<Streaming<T::Decode>>, Status> {
370        let Request {
371            metadata,
372            message,
373            extensions,
374        } = request;
375        let mut http_request =
376            create_http_request::<T>(path, metadata, extensions, self.send_compressed);
377        http_request.set_body(create_encode_request_body(
378            codec.encoder(),
379            message,
380            self.send_compressed,
381        ));
382
383        let mut resp = self
384            .ep
385            .call(http_request)
386            .await
387            .map_err(|err| Status::new(Code::Internal).with_message(err))?;
388
389        if resp.status() != StatusCode::OK {
390            return Err(Status::new(Code::Internal).with_message(format!(
391                "invalid http status code: {}",
392                resp.status().as_u16()
393            )));
394        }
395
396        let body = resp.take_body();
397        let incoming_encoding = get_incoming_encodings(resp.headers(), &self.accept_compressed)?;
398        let stream =
399            create_decode_response_body(codec.decoder(), resp.headers(), body, incoming_encoding)?;
400
401        Ok(Response {
402            metadata: Metadata {
403                headers: std::mem::take(resp.headers_mut()),
404            },
405            message: stream,
406        })
407    }
408}
409
410fn create_http_request<T: Codec>(
411    path: &str,
412    metadata: Metadata,
413    extensions: Extensions,
414    send_compressed: Option<CompressionEncoding>,
415) -> HttpRequest {
416    let mut http_request = HttpRequest::builder()
417        .uri_str(path)
418        .method(Method::POST)
419        .version(Version::HTTP_2)
420        .finish();
421    *http_request.headers_mut() = metadata.headers;
422    *http_request.extensions_mut() = extensions;
423    http_request
424        .headers_mut()
425        .insert("content-type", T::CONTENT_TYPES[0].parse().unwrap());
426    http_request
427        .headers_mut()
428        .insert(header::TE, "trailers".parse().unwrap());
429    if let Some(send_compressed) = send_compressed {
430        http_request.headers_mut().insert(
431            "grpc-encoding",
432            HeaderValue::from_str(send_compressed.as_str()).expect("BUG: invalid encoding"),
433        );
434    }
435    http_request
436}
437
/// Moves `err` onto the heap as a thread-safe `dyn Error` trait object.
#[inline]
fn to_boxed_error(
    err: impl std::error::Error + Send + Sync + 'static,
) -> Box<dyn std::error::Error + Send + Sync> {
    err.into()
}
444
445fn make_uri(base_uri: &Uri, path: &Uri) -> Uri {
446    let path = path.path_and_query().unwrap().path();
447    let mut parts = base_uri.clone().into_parts();
448    match parts.path_and_query {
449        Some(path_and_query) => {
450            let mut new_path = format!("{}{}", path_and_query.path().trim_end_matches('/'), path);
451            if let Some(query) = path_and_query.query() {
452                new_path.push('?');
453                new_path.push_str(query);
454            }
455            parts.path_and_query = Some(new_path.parse().unwrap());
456        }
457        None => {
458            parts.path_and_query = Some(path.parse().unwrap());
459        }
460    }
461    Uri::from_parts(parts).unwrap()
462}
463
464fn create_client_endpoint(
465    config: ClientConfig,
466) -> Arc<dyn DynEndpoint<Output = HttpResponse> + 'static> {
467    let mut config = config;
468    let cli = Client::builder(TokioExecutor::new())
469        .http2_only(true)
470        .http2_max_header_list_size(config.max_header_list_size)
471        .build(HttpsConnector::new(config.tls_config.take()));
472
473    let config = Arc::new(config);
474
475    Arc::new(ToDynEndpoint(poem::endpoint::make(move |request| {
476        let config = config.clone();
477        let cli = cli.clone();
478        async move {
479            let mut request: hyper::Request<BoxBody> = request.into();
480
481            if config.uris.is_empty() {
482                return Err(poem::Error::from_string(
483                    "uris is empty",
484                    StatusCode::INTERNAL_SERVER_ERROR,
485                ));
486            }
487
488            let base_uri = if config.uris.len() == 1 {
489                &config.uris[0]
490            } else {
491                &config.uris[fastrand::usize(0..config.uris.len())]
492            };
493            *request.uri_mut() = make_uri(base_uri, request.uri());
494
495            if let Some(origin) = &config.origin {
496                if let Ok(value) = HeaderValue::from_maybe_shared(origin.to_string()) {
497                    request.headers_mut().insert(header::ORIGIN, value);
498                }
499            }
500
501            if let Some(user_agent) = &config.user_agent {
502                request
503                    .headers_mut()
504                    .insert(header::ORIGIN, user_agent.clone());
505            }
506
507            let resp = cli.request(request).await.map_err(to_boxed_error)?;
508            let (parts, body) = resp.into_parts();
509
510            Ok::<_, poem::Error>(HttpResponse::from(hyper::Response::from_parts(
511                parts,
512                body.map_err(IoError::other),
513            )))
514        }
515    })))
516}