Skip to main content

ntex_grpc/server/
service.rs

1use std::{cell::RefCell, rc::Rc};
2
3use ntex_bytes::{Buf, BufMut, ByteString, BytesMut};
4use ntex_h2::{self as h2, StreamRef, frame::Reason, frame::StreamId};
5use ntex_http::{HeaderMap, HeaderValue, StatusCode, header::CONTENT_TYPE};
6use ntex_io::{Filter, Io, IoBoxed};
7use ntex_service::{Service, ServiceCtx, ServiceFactory, cfg::SharedCfg};
8use ntex_util::{HashMap, time::Millis, time::timeout_checked};
9
10use crate::{consts, status::GrpcStatus, utils::Data};
11
12use super::{ServerError, ServerRequest, ServerResponse};
13
14const ERR_DECODE: HeaderValue =
15    HeaderValue::from_static("Cannot decode request message: not enough data provided");
16const ERR_DATA_DECODE: HeaderValue =
17    HeaderValue::from_static("Cannot decode request message: not enough data provided");
18const ERR_DECODE_TIMEOUT: HeaderValue =
19    HeaderValue::from_static("Cannot decode grpc-timeout header");
20const ERR_DEADLINE: HeaderValue = HeaderValue::from_static("Deadline exceeded");
21const HDR_APP_GRPC: HeaderValue = HeaderValue::from_static("application/grpc");
22
23const MILLIS_IN_HOUR: u64 = 60 * 60 * 1000;
24const MILLIS_IN_MINUTE: u64 = 60 * 1000;
25
/// Grpc server
///
/// Holds the request-service factory `T` behind an `Rc` so that every
/// connection-level [`GrpcService`] built from this server shares it.
pub struct GrpcServer<T> {
    factory: Rc<T>,
}

impl<T> GrpcServer<T> {
    /// Create grpc server
    ///
    /// `factory` is moved into a reference-counted handle.
    pub fn new(factory: T) -> Self {
        let shared = Rc::new(factory);
        GrpcServer { factory: shared }
    }
}
39
40impl<T> GrpcServer<T>
41where
42    T: ServiceFactory<ServerRequest, SharedCfg, Response = ServerResponse, Error = ServerError>,
43    T::Service: Clone,
44{
45    /// Create default server
46    pub fn make_server(&self, cfg: SharedCfg) -> GrpcService<T> {
47        log::trace!("{}: Starting grpc service", cfg.tag());
48
49        GrpcService {
50            cfg,
51            factory: self.factory.clone(),
52        }
53    }
54}
55
56impl<F, T> ServiceFactory<Io<F>, SharedCfg> for GrpcServer<T>
57where
58    F: Filter,
59    T: ServiceFactory<ServerRequest, SharedCfg, Response = ServerResponse, Error = ServerError>
60        + 'static,
61    T::Service: Clone,
62{
63    type Response = ();
64    type Error = T::InitError;
65    type Service = GrpcService<T>;
66    type InitError = ();
67
68    async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
69        Ok(self.make_server(cfg))
70    }
71}
72
73pub struct GrpcService<T> {
74    cfg: SharedCfg,
75    factory: Rc<T>,
76}
77
78impl<T, F> Service<Io<F>> for GrpcService<T>
79where
80    F: Filter,
81    T: ServiceFactory<ServerRequest, SharedCfg, Response = ServerResponse, Error = ServerError>
82        + 'static,
83{
84    type Response = ();
85    type Error = T::InitError;
86
87    async fn call(&self, io: Io<F>, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
88        // init server
89        let service = self.factory.create(self.cfg.clone()).await?;
90
91        let _ = h2::server::handle_one(
92            io.into(),
93            PublishService::new(service, self.cfg.clone()),
94            ControlService,
95        )
96        .await;
97
98        Ok(())
99    }
100}
101
102impl<T> Service<IoBoxed> for GrpcService<T>
103where
104    T: ServiceFactory<ServerRequest, SharedCfg, Response = ServerResponse, Error = ServerError>
105        + 'static,
106{
107    type Response = ();
108    type Error = T::InitError;
109
110    async fn call(&self, io: IoBoxed, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
111        // init server
112        let service = self.factory.create(self.cfg.clone()).await?;
113
114        let _ = h2::server::handle_one(
115            io,
116            PublishService::new(service, self.cfg.clone()),
117            ControlService,
118        )
119        .await;
120
121        Ok(())
122    }
123}
124
125struct ControlService;
126
127impl Service<h2::Control<h2::StreamError>> for ControlService {
128    type Response = h2::ControlAck;
129    type Error = ();
130
131    async fn call(
132        &self,
133        msg: h2::Control<h2::StreamError>,
134        _: ServiceCtx<'_, Self>,
135    ) -> Result<Self::Response, Self::Error> {
136        log::trace!("Control message: {msg:?}");
137        Ok::<_, ()>(msg.ack())
138    }
139}
140
141struct PublishService<S: Service<ServerRequest>> {
142    cfg: SharedCfg,
143    service: S,
144    streams: RefCell<HashMap<StreamId, Inflight>>,
145}
146
147struct Inflight {
148    name: ByteString,
149    service: ByteString,
150    data: Data,
151    headers: HeaderMap,
152}
153
154impl<S> PublishService<S>
155where
156    S: Service<ServerRequest, Response = ServerResponse, Error = ServerError>,
157{
158    fn new(service: S, cfg: SharedCfg) -> Self {
159        Self {
160            cfg,
161            service,
162            streams: RefCell::new(HashMap::default()),
163        }
164    }
165}
166
167impl<S> Service<h2::Message> for PublishService<S>
168where
169    S: Service<ServerRequest, Response = ServerResponse, Error = ServerError> + 'static,
170{
171    type Response = ();
172    type Error = h2::StreamError;
173
174    #[allow(clippy::await_holding_refcell_ref, clippy::too_many_lines)]
175    async fn call(
176        &self,
177        msg: h2::Message,
178        ctx: ServiceCtx<'_, Self>,
179    ) -> Result<Self::Response, Self::Error> {
180        let id = msg.id();
181        let h2::Message { stream, kind } = msg;
182        let mut streams = self.streams.borrow_mut();
183
184        match kind {
185            h2::MessageKind::Headers {
186                headers,
187                pseudo,
188                eof,
189            } => {
190                let mut path = pseudo.path.unwrap().split_off(1);
191                let srvname = if let Some(n) = path.find('/') {
192                    path.split_to(n)
193                } else {
194                    // not found
195                    let _ = stream.send_response(StatusCode::NOT_FOUND, hdrs(), true);
196                    return Ok(());
197                };
198
199                // stream eof, cannot do anything
200                if eof {
201                    if stream.send_response(StatusCode::OK, hdrs(), false).is_ok() {
202                        send_error(&stream, GrpcStatus::InvalidArgument, ERR_DECODE);
203                    }
204                    return Ok(());
205                }
206
207                let mut path = path.split_off(1);
208                let methodname = if let Some(n) = path.find('/') {
209                    path.split_to(n)
210                } else {
211                    path
212                };
213
214                let _ = streams.insert(
215                    stream.id(),
216                    Inflight {
217                        headers,
218                        data: Data::Empty,
219                        name: methodname,
220                        service: srvname,
221                    },
222                );
223            }
224            h2::MessageKind::Data(data, _cap) => {
225                if let Some(inflight) = streams.get_mut(&stream.id()) {
226                    inflight.data.push(data);
227                }
228            }
229            h2::MessageKind::Eof(data) => {
230                if let Some(mut inflight) = streams.remove(&id) {
231                    match data {
232                        h2::StreamEof::Data(chunk) => inflight.data.push(chunk),
233                        h2::StreamEof::Trailers(hdrs) => {
234                            for (name, val) in &hdrs {
235                                inflight.headers.insert(name.clone(), val.clone());
236                            }
237                        }
238                        h2::StreamEof::Error(err) => return Err(err.into_error()),
239                    }
240
241                    let mut data = inflight.data.get();
242                    let _compressed = data.get_u8();
243                    let len = data.get_u32();
244                    if (len as usize) > data.len() {
245                        if stream.send_response(StatusCode::OK, hdrs(), false).is_ok() {
246                            send_error(&stream, GrpcStatus::InvalidArgument, ERR_DATA_DECODE);
247                        }
248                        return Ok(());
249                    }
250                    let data = data
251                        .split_to_checked(len as usize)
252                        .ok_or(h2::StreamError::Reset(Reason::PROTOCOL_ERROR))?;
253
254                    log::debug!(
255                        "{}: Call service {} method {}",
256                        self.cfg.tag(),
257                        inflight.service,
258                        inflight.name
259                    );
260                    let req = ServerRequest {
261                        payload: data,
262                        name: inflight.name,
263                        headers: inflight.headers,
264                    };
265                    if stream.send_response(StatusCode::OK, hdrs(), false).is_err() {
266                        return Ok(());
267                    }
268                    drop(streams);
269
270                    // GRPC Timeout
271                    let to = if let Some(to) = req.headers.get(consts::GRPC_TIMEOUT) {
272                        if let Ok(to) = try_parse_grpc_timeout(to) {
273                            to
274                        } else {
275                            send_error(&stream, GrpcStatus::InvalidArgument, ERR_DECODE_TIMEOUT);
276                            return Ok(());
277                        }
278                    } else {
279                        Millis::ZERO
280                    };
281
282                    match timeout_checked(to, ctx.call(&self.service, req)).await {
283                        Ok(Ok(res)) => {
284                            log::debug!("{}: Response is received {res:?}", self.cfg.tag());
285                            let mut buf = BytesMut::with_capacity(res.payload.len() + 5);
286                            buf.put_u8(0); // compression
287                            buf.put_u32(res.payload.len() as u32); // length
288                            buf.extend_from_slice(&res.payload);
289
290                            let _ = stream.send_payload(buf.freeze(), false).await;
291
292                            let mut trailers = HeaderMap::default();
293                            trailers.insert(consts::GRPC_STATUS, GrpcStatus::Ok.into());
294                            for (name, val) in res.headers {
295                                trailers.append(name, val);
296                            }
297
298                            stream.send_trailers(trailers);
299                        }
300                        Ok(Err(err)) => {
301                            log::debug!(
302                                "{}: Failure during service call: {:?}",
303                                self.cfg.tag(),
304                                err.message
305                            );
306                            let mut trailers = err.headers;
307                            trailers.insert(consts::GRPC_STATUS, err.status.into());
308                            trailers.insert(consts::GRPC_MESSAGE, err.message);
309                            stream.send_trailers(trailers);
310                        }
311                        Err(()) => {
312                            log::debug!(
313                                "{}: Deadline exceeded failure during service call",
314                                self.cfg.tag()
315                            );
316                            send_error(&stream, GrpcStatus::DeadlineExceeded, ERR_DEADLINE);
317                        }
318                    }
319
320                    return Ok(());
321                }
322            }
323            h2::MessageKind::Disconnect(_) => {
324                streams.remove(&id);
325            }
326        }
327        Ok(())
328    }
329}
330
331fn hdrs() -> HeaderMap {
332    let mut hdrs = HeaderMap::default();
333    hdrs.insert(CONTENT_TYPE, HDR_APP_GRPC);
334    hdrs
335}
336
337fn send_error(stream: &StreamRef, st: GrpcStatus, msg: HeaderValue) {
338    let mut trailers = HeaderMap::default();
339    trailers.insert(consts::GRPC_STATUS, st.into());
340    trailers.insert(consts::GRPC_MESSAGE, msg);
341    stream.send_trailers(trailers);
342}
343
344/// Tries to parse the `grpc-timeout` header if it is present.
345///
346/// Follows the [gRPC over HTTP2 spec](https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md).
347fn try_parse_grpc_timeout(val: &HeaderValue) -> Result<Millis, ()> {
348    let (timeout_value, timeout_unit) = val
349        .to_str()
350        .map_err(|_| ())
351        .and_then(|s| if s.is_empty() { Err(()) } else { Ok(s) })?
352        .split_at(val.len() - 1);
353
354    // gRPC spec specifies `TimeoutValue` will be at most 8 digits
355    // Caping this at 8 digits also prevents integer overflow from ever occurring
356    if timeout_value.len() > 8 {
357        return Err(());
358    }
359
360    let timeout_value: u64 = timeout_value.parse().map_err(|_| ())?;
361    let duration = match timeout_unit {
362        // Hours
363        "H" => Millis(u32::try_from(timeout_value * MILLIS_IN_HOUR).unwrap_or(u32::MAX)),
364        // Minutes
365        "M" => Millis(u32::try_from(timeout_value * MILLIS_IN_MINUTE).unwrap_or(u32::MAX)),
366        // Seconds
367        "S" => Millis(u32::try_from(timeout_value * 1000).unwrap_or(u32::MAX)),
368        // Milliseconds
369        "m" => Millis(u32::try_from(timeout_value).unwrap_or(u32::MAX)),
370        // Microseconds
371        "u" => Millis(u32::try_from(timeout_value / 1000).unwrap_or(u32::MAX)),
372        // Nanoseconds
373        "n" => Millis(u32::try_from(timeout_value / 1_000_000).unwrap_or(u32::MAX)),
374        _ => return Err(()),
375    };
376
377    Ok(duration)
378}