Skip to main content

ntex_grpc/client/transport.rs

1use std::{convert::TryFrom, str::FromStr};
2
3use ntex_bytes::{Buf, BufMut, BytesMut};
4use ntex_error::Error;
5use ntex_h2::{self as h2};
6use ntex_http::{HeaderMap, Method, header};
7
8use super::{Client, ClientError, Transport, request::RequestContext, request::Response};
9use crate::{DecodeError, GrpcStatus, Message, consts, service::MethodDef, utils::Data};
10
11impl<T: MethodDef> Transport<T> for Client {
12    type Error = Error<ClientError>;
13
14    #[inline]
15    async fn request(
16        &self,
17        val: &T::Input,
18        ctx: RequestContext,
19    ) -> Result<Response<T>, Self::Error> {
20        Transport::request(&self.0, val, ctx).await
21    }
22}
23
24impl<T: MethodDef> Transport<T> for h2::client::Client {
25    type Error = Error<ClientError>;
26
27    #[inline]
28    async fn request(
29        &self,
30        val: &T::Input,
31        ctx: RequestContext,
32    ) -> Result<Response<T>, Self::Error> {
33        Transport::request(
34            &self.client().await.map_err(|e| e.map(ClientError::from))?,
35            val,
36            ctx,
37        )
38        .await
39    }
40}
41
42impl<T: MethodDef> Transport<T> for h2::client::SimpleClient {
43    type Error = Error<ClientError>;
44
45    #[allow(clippy::too_many_lines)]
46    async fn request(
47        &self,
48        val: &T::Input,
49        ctx: RequestContext,
50    ) -> Result<Response<T>, Self::Error> {
51        let len = val.encoded_len();
52        let mut buf = BytesMut::with_capacity(len + 5);
53        buf.put_u8(0); // compression
54        buf.put_u32(len as u32); // length
55        val.write(&mut buf);
56        let req_size = buf.len();
57
58        let mut hdrs = HeaderMap::new();
59        hdrs.append(header::CONTENT_TYPE, consts::HDRV_CT_GRPC);
60        hdrs.append(header::USER_AGENT, consts::HDRV_USER_AGENT);
61        hdrs.insert(header::TE, consts::HDRV_TRAILERS);
62        hdrs.insert(consts::GRPC_ENCODING, consts::IDENTITY);
63        hdrs.insert(consts::GRPC_ACCEPT_ENCODING, consts::IDENTITY);
64        for (key, val) in ctx.headers() {
65            hdrs.insert(key.clone(), val.clone());
66        }
67
68        // send request
69        let (snd_stream, rcv_stream) = self
70            .send(Method::POST, T::PATH, hdrs, false)
71            .await
72            .map_err(|e| e.map(ClientError::from))?;
73        if ctx.get_disconnect_on_drop() {
74            snd_stream.disconnect_on_drop();
75        }
76        snd_stream
77            .send_payload(buf.freeze(), true)
78            .await
79            .map_err(|e| e.map(ClientError::from))?;
80
81        // read response
82        let mut status = None;
83        let mut hdrs = HeaderMap::default();
84        let mut trailers = HeaderMap::default();
85        let mut payload = Data::Empty;
86
87        async {
88            loop {
89                let Some(msg) = rcv_stream.recv().await else {
90                    return Err(Error::from(ClientError::UnexpectedEof(status, hdrs)));
91                };
92
93                match msg.kind {
94                    h2::MessageKind::Headers {
95                        headers,
96                        pseudo,
97                        eof,
98                    } => {
99                        if eof {
100                            // check grpc status
101                            match check_grpc_status(&headers) {
102                                Some(Ok(GrpcStatus::DeadlineExceeded)) => {
103                                    return Err(Error::from(ClientError::DeadlineExceeded(hdrs)));
104                                }
105                                Some(Ok(status)) => {
106                                    if status != GrpcStatus::Ok {
107                                        return Err(Error::from(ClientError::GrpcStatus(
108                                            status, headers,
109                                        )));
110                                    }
111                                }
112                                Some(Err(())) => {
113                                    return Err(Error::from(ClientError::Decode(
114                                        DecodeError::new("Cannot parse grpc status"),
115                                    )));
116                                }
117                                None => {}
118                            }
119
120                            return Err(Error::from(ClientError::UnexpectedEof(
121                                pseudo.status,
122                                headers,
123                            )));
124                        }
125                        hdrs = headers;
126                        status = pseudo.status;
127                        continue;
128                    }
129                    h2::MessageKind::Data(data, _cap) => {
130                        payload.push(data);
131                        continue;
132                    }
133                    h2::MessageKind::Eof(data) => {
134                        match data {
135                            h2::StreamEof::Data(data) => {
136                                payload.push(data);
137                            }
138                            h2::StreamEof::Trailers(hdrs) => {
139                                // check grpc status
140                                match check_grpc_status(&hdrs) {
141                                    Some(Ok(GrpcStatus::Ok)) | None => Ok(()),
142                                    Some(Ok(GrpcStatus::DeadlineExceeded)) => {
143                                        return Err(Error::from(ClientError::DeadlineExceeded(
144                                            hdrs,
145                                        )));
146                                    }
147                                    Some(Ok(st)) => {
148                                        return Err(Error::from(ClientError::GrpcStatus(
149                                            st, hdrs,
150                                        )));
151                                    }
152                                    Some(Err(())) => Err(Error::from(ClientError::Decode(
153                                        DecodeError::new("Cannot parse grpc status"),
154                                    ))),
155                                }?;
156                                trailers = hdrs;
157                            }
158                            h2::StreamEof::Error(err) => {
159                                return Err(err.map(ClientError::Stream));
160                            }
161                        }
162                    }
163                    h2::MessageKind::Disconnect(err) => {
164                        return Err(err.map(ClientError::Operation));
165                    }
166                }
167
168                let mut data = payload.get();
169                match status {
170                    Some(st) => {
171                        if !st.is_success() {
172                            return Err(Error::from(ClientError::Response(Some(st), hdrs, data)));
173                        }
174                    }
175                    None => return Err(Error::from(ClientError::Response(None, hdrs, data))),
176                }
177                let _compressed = data.get_u8();
178                let len = data.get_u32();
179                let Some(mut block) = data.split_to_checked(len as usize) else {
180                    return Err(Error::from(ClientError::UnexpectedEof(None, hdrs)));
181                };
182
183                return match <T::Output as Message>::read(&mut block) {
184                    Ok(output) => Ok(Response {
185                        output,
186                        trailers,
187                        req_size,
188                        headers: hdrs,
189                        res_size: data.len(),
190                    }),
191                    Err(e) => Err(Error::from(ClientError::Decode(e))),
192                };
193            }
194        }
195        .await
196        .map_err(|e| e.set_service(self.service()))
197    }
198}
199
200fn check_grpc_status(hdrs: &HeaderMap) -> Option<Result<GrpcStatus, ()>> {
201    // check grpc status
202    if let Some(val) = hdrs.get(consts::GRPC_STATUS) {
203        if let Ok(status) = val
204            .to_str()
205            .map_err(|_| ())
206            .and_then(|v| u8::from_str(v).map_err(|_| ()))
207            .and_then(GrpcStatus::try_from)
208        {
209            Some(Ok(status))
210        } else {
211            Some(Err(()))
212        }
213    } else {
214        None
215    }
216}