1use std::{cell::RefCell, rc::Rc};
2
3use ntex_bytes::{Buf, BufMut, ByteString, BytesMut};
4use ntex_h2::{self as h2, StreamRef, frame::Reason, frame::StreamId};
5use ntex_http::{HeaderMap, HeaderValue, StatusCode, header::CONTENT_TYPE};
6use ntex_io::{Filter, Io, IoBoxed};
7use ntex_service::{Service, ServiceCtx, ServiceFactory, cfg::SharedCfg};
8use ntex_util::{HashMap, time::Millis, time::timeout_checked};
9
10use crate::{consts, status::GrpcStatus, utils::Data};
11
12use super::{ServerError, ServerRequest, ServerResponse};
13
14const ERR_DECODE: HeaderValue =
15 HeaderValue::from_static("Cannot decode request message: not enough data provided");
16const ERR_DATA_DECODE: HeaderValue =
17 HeaderValue::from_static("Cannot decode request message: not enough data provided");
18const ERR_DECODE_TIMEOUT: HeaderValue =
19 HeaderValue::from_static("Cannot decode grpc-timeout header");
20const ERR_DEADLINE: HeaderValue = HeaderValue::from_static("Deadline exceeded");
21const HDR_APP_GRPC: HeaderValue = HeaderValue::from_static("application/grpc");
22
/// Number of milliseconds in one minute.
const MILLIS_IN_MINUTE: u64 = 60_000;
/// Number of milliseconds in one hour.
const MILLIS_IN_HOUR: u64 = 60 * MILLIS_IN_MINUTE;
25
/// gRPC server factory.
///
/// Holds the request-service factory `T` behind an [`Rc`] so that every
/// [`GrpcService`] produced from this server shares the same factory instance.
pub struct GrpcServer<T> {
    factory: Rc<T>,
}

impl<T> GrpcServer<T> {
    /// Creates a server that will build request services with `factory`.
    pub fn new(factory: T) -> Self {
        let factory = Rc::new(factory);
        GrpcServer { factory }
    }
}
39
40impl<T> GrpcServer<T>
41where
42 T: ServiceFactory<ServerRequest, SharedCfg, Response = ServerResponse, Error = ServerError>,
43 T::Service: Clone,
44{
45 pub fn make_server(&self, cfg: SharedCfg) -> GrpcService<T> {
47 log::trace!("{}: Starting grpc service", cfg.tag());
48
49 GrpcService {
50 cfg,
51 factory: self.factory.clone(),
52 }
53 }
54}
55
56impl<F, T> ServiceFactory<Io<F>, SharedCfg> for GrpcServer<T>
57where
58 F: Filter,
59 T: ServiceFactory<ServerRequest, SharedCfg, Response = ServerResponse, Error = ServerError>
60 + 'static,
61 T::Service: Clone,
62{
63 type Response = ();
64 type Error = T::InitError;
65 type Service = GrpcService<T>;
66 type InitError = ();
67
68 async fn create(&self, cfg: SharedCfg) -> Result<Self::Service, Self::InitError> {
69 Ok(self.make_server(cfg))
70 }
71}
72
73pub struct GrpcService<T> {
74 cfg: SharedCfg,
75 factory: Rc<T>,
76}
77
78impl<T, F> Service<Io<F>> for GrpcService<T>
79where
80 F: Filter,
81 T: ServiceFactory<ServerRequest, SharedCfg, Response = ServerResponse, Error = ServerError>
82 + 'static,
83{
84 type Response = ();
85 type Error = T::InitError;
86
87 async fn call(&self, io: Io<F>, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
88 let service = self.factory.create(self.cfg.clone()).await?;
90
91 let _ = h2::server::handle_one(
92 io.into(),
93 PublishService::new(service, self.cfg.clone()),
94 ControlService,
95 )
96 .await;
97
98 Ok(())
99 }
100}
101
102impl<T> Service<IoBoxed> for GrpcService<T>
103where
104 T: ServiceFactory<ServerRequest, SharedCfg, Response = ServerResponse, Error = ServerError>
105 + 'static,
106{
107 type Response = ();
108 type Error = T::InitError;
109
110 async fn call(&self, io: IoBoxed, _: ServiceCtx<'_, Self>) -> Result<(), Self::Error> {
111 let service = self.factory.create(self.cfg.clone()).await?;
113
114 let _ = h2::server::handle_one(
115 io,
116 PublishService::new(service, self.cfg.clone()),
117 ControlService,
118 )
119 .await;
120
121 Ok(())
122 }
123}
124
125struct ControlService;
126
127impl Service<h2::Control<h2::StreamError>> for ControlService {
128 type Response = h2::ControlAck;
129 type Error = ();
130
131 async fn call(
132 &self,
133 msg: h2::Control<h2::StreamError>,
134 _: ServiceCtx<'_, Self>,
135 ) -> Result<Self::Response, Self::Error> {
136 log::trace!("Control message: {msg:?}");
137 Ok::<_, ()>(msg.ack())
138 }
139}
140
141struct PublishService<S: Service<ServerRequest>> {
142 cfg: SharedCfg,
143 service: S,
144 streams: RefCell<HashMap<StreamId, Inflight>>,
145}
146
147struct Inflight {
148 name: ByteString,
149 service: ByteString,
150 data: Data,
151 headers: HeaderMap,
152}
153
154impl<S> PublishService<S>
155where
156 S: Service<ServerRequest, Response = ServerResponse, Error = ServerError>,
157{
158 fn new(service: S, cfg: SharedCfg) -> Self {
159 Self {
160 cfg,
161 service,
162 streams: RefCell::new(HashMap::default()),
163 }
164 }
165}
166
167impl<S> Service<h2::Message> for PublishService<S>
168where
169 S: Service<ServerRequest, Response = ServerResponse, Error = ServerError> + 'static,
170{
171 type Response = ();
172 type Error = h2::StreamError;
173
174 #[allow(clippy::await_holding_refcell_ref, clippy::too_many_lines)]
175 async fn call(
176 &self,
177 msg: h2::Message,
178 ctx: ServiceCtx<'_, Self>,
179 ) -> Result<Self::Response, Self::Error> {
180 let id = msg.id();
181 let h2::Message { stream, kind } = msg;
182 let mut streams = self.streams.borrow_mut();
183
184 match kind {
185 h2::MessageKind::Headers {
186 headers,
187 pseudo,
188 eof,
189 } => {
190 let mut path = pseudo.path.unwrap().split_off(1);
191 let srvname = if let Some(n) = path.find('/') {
192 path.split_to(n)
193 } else {
194 let _ = stream.send_response(StatusCode::NOT_FOUND, hdrs(), true);
196 return Ok(());
197 };
198
199 if eof {
201 if stream.send_response(StatusCode::OK, hdrs(), false).is_ok() {
202 send_error(&stream, GrpcStatus::InvalidArgument, ERR_DECODE);
203 }
204 return Ok(());
205 }
206
207 let mut path = path.split_off(1);
208 let methodname = if let Some(n) = path.find('/') {
209 path.split_to(n)
210 } else {
211 path
212 };
213
214 let _ = streams.insert(
215 stream.id(),
216 Inflight {
217 headers,
218 data: Data::Empty,
219 name: methodname,
220 service: srvname,
221 },
222 );
223 }
224 h2::MessageKind::Data(data, _cap) => {
225 if let Some(inflight) = streams.get_mut(&stream.id()) {
226 inflight.data.push(data);
227 }
228 }
229 h2::MessageKind::Eof(data) => {
230 if let Some(mut inflight) = streams.remove(&id) {
231 match data {
232 h2::StreamEof::Data(chunk) => inflight.data.push(chunk),
233 h2::StreamEof::Trailers(hdrs) => {
234 for (name, val) in &hdrs {
235 inflight.headers.insert(name.clone(), val.clone());
236 }
237 }
238 h2::StreamEof::Error(err) => return Err(err.into_error()),
239 }
240
241 let mut data = inflight.data.get();
242 let _compressed = data.get_u8();
243 let len = data.get_u32();
244 if (len as usize) > data.len() {
245 if stream.send_response(StatusCode::OK, hdrs(), false).is_ok() {
246 send_error(&stream, GrpcStatus::InvalidArgument, ERR_DATA_DECODE);
247 }
248 return Ok(());
249 }
250 let data = data
251 .split_to_checked(len as usize)
252 .ok_or(h2::StreamError::Reset(Reason::PROTOCOL_ERROR))?;
253
254 log::debug!(
255 "{}: Call service {} method {}",
256 self.cfg.tag(),
257 inflight.service,
258 inflight.name
259 );
260 let req = ServerRequest {
261 payload: data,
262 name: inflight.name,
263 headers: inflight.headers,
264 };
265 if stream.send_response(StatusCode::OK, hdrs(), false).is_err() {
266 return Ok(());
267 }
268 drop(streams);
269
270 let to = if let Some(to) = req.headers.get(consts::GRPC_TIMEOUT) {
272 if let Ok(to) = try_parse_grpc_timeout(to) {
273 to
274 } else {
275 send_error(&stream, GrpcStatus::InvalidArgument, ERR_DECODE_TIMEOUT);
276 return Ok(());
277 }
278 } else {
279 Millis::ZERO
280 };
281
282 match timeout_checked(to, ctx.call(&self.service, req)).await {
283 Ok(Ok(res)) => {
284 log::debug!("{}: Response is received {res:?}", self.cfg.tag());
285 let mut buf = BytesMut::with_capacity(res.payload.len() + 5);
286 buf.put_u8(0); buf.put_u32(res.payload.len() as u32); buf.extend_from_slice(&res.payload);
289
290 let _ = stream.send_payload(buf.freeze(), false).await;
291
292 let mut trailers = HeaderMap::default();
293 trailers.insert(consts::GRPC_STATUS, GrpcStatus::Ok.into());
294 for (name, val) in res.headers {
295 trailers.append(name, val);
296 }
297
298 stream.send_trailers(trailers);
299 }
300 Ok(Err(err)) => {
301 log::debug!(
302 "{}: Failure during service call: {:?}",
303 self.cfg.tag(),
304 err.message
305 );
306 let mut trailers = err.headers;
307 trailers.insert(consts::GRPC_STATUS, err.status.into());
308 trailers.insert(consts::GRPC_MESSAGE, err.message);
309 stream.send_trailers(trailers);
310 }
311 Err(()) => {
312 log::debug!(
313 "{}: Deadline exceeded failure during service call",
314 self.cfg.tag()
315 );
316 send_error(&stream, GrpcStatus::DeadlineExceeded, ERR_DEADLINE);
317 }
318 }
319
320 return Ok(());
321 }
322 }
323 h2::MessageKind::Disconnect(_) => {
324 streams.remove(&id);
325 }
326 }
327 Ok(())
328 }
329}
330
331fn hdrs() -> HeaderMap {
332 let mut hdrs = HeaderMap::default();
333 hdrs.insert(CONTENT_TYPE, HDR_APP_GRPC);
334 hdrs
335}
336
337fn send_error(stream: &StreamRef, st: GrpcStatus, msg: HeaderValue) {
338 let mut trailers = HeaderMap::default();
339 trailers.insert(consts::GRPC_STATUS, st.into());
340 trailers.insert(consts::GRPC_MESSAGE, msg);
341 stream.send_trailers(trailers);
342}
343
344fn try_parse_grpc_timeout(val: &HeaderValue) -> Result<Millis, ()> {
348 let (timeout_value, timeout_unit) = val
349 .to_str()
350 .map_err(|_| ())
351 .and_then(|s| if s.is_empty() { Err(()) } else { Ok(s) })?
352 .split_at(val.len() - 1);
353
354 if timeout_value.len() > 8 {
357 return Err(());
358 }
359
360 let timeout_value: u64 = timeout_value.parse().map_err(|_| ())?;
361 let duration = match timeout_unit {
362 "H" => Millis(u32::try_from(timeout_value * MILLIS_IN_HOUR).unwrap_or(u32::MAX)),
364 "M" => Millis(u32::try_from(timeout_value * MILLIS_IN_MINUTE).unwrap_or(u32::MAX)),
366 "S" => Millis(u32::try_from(timeout_value * 1000).unwrap_or(u32::MAX)),
368 "m" => Millis(u32::try_from(timeout_value).unwrap_or(u32::MAX)),
370 "u" => Millis(u32::try_from(timeout_value / 1000).unwrap_or(u32::MAX)),
372 "n" => Millis(u32::try_from(timeout_value / 1_000_000).unwrap_or(u32::MAX)),
374 _ => return Err(()),
375 };
376
377 Ok(duration)
378}