1use std::task::{Context, Poll};
2use std::{cell::Cell, convert::TryFrom, fmt, future::Future, mem, ops, pin::Pin, rc::Rc, time};
3
4use ntex_http::{HeaderMap, HeaderName, HeaderValue, error::Error as HttpError};
5use ntex_util::future::BoxFuture;
6
7use crate::{client::Transport, consts, service::MethodDef};
8
/// Per-call options (extra headers, timeout, flags) that are handed to the transport
/// together with the request input.
///
/// Cloning is cheap: clones share one inner state through an `Rc`. `header` copies
/// that state first when it is shared (copy-on-write) before modifying it.
pub struct RequestContext(Rc<RequestContextInner>);
10
bitflags::bitflags! {
    /// Boolean options stored in the `flags` cell of a request context.
    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
    struct Flags: u8 {
        /// Set by `RequestContext::disconnect_on_drop`, read by `get_disconnect_on_drop`.
        const DISCONNECT_ON_DROP = 0b0000_0001;
    }
}
17
/// State behind a `RequestContext`, shared between its clones.
struct RequestContextInner {
    /// First header conversion error; once set, later header additions are ignored.
    err: Option<HttpError>,
    /// Extra request headers, in insertion order.
    headers: Vec<(HeaderName, HeaderValue)>,
    /// Timeout stored by `RequestContext::timeout`.
    timeout: Cell<Option<time::Duration>>,
    /// Option flags, see `Flags`.
    flags: Cell<Flags>,
}
24
25impl RequestContext {
26 fn new() -> Self {
28 Self(Rc::new(RequestContextInner {
29 err: None,
30 headers: Vec::new(),
31 timeout: Cell::new(None),
32 flags: Cell::new(Flags::empty()),
33 }))
34 }
35
36 pub fn get_timeout(&self) -> Option<time::Duration> {
38 self.0.timeout.get()
39 }
40
41 pub fn timeout<U>(&mut self, timeout: U) -> &mut Self
48 where
49 time::Duration: From<U>,
50 {
51 let to = timeout.into();
52 self.0.timeout.set(Some(to));
53 self.header(consts::GRPC_TIMEOUT, duration_to_grpc_timeout(to));
54 self
55 }
56
57 pub fn disconnect_on_drop(&mut self) -> &mut Self {
59 let mut flags = self.0.flags.get();
60 flags.insert(Flags::DISCONNECT_ON_DROP);
61 self.0.flags.set(flags);
62 self
63 }
64
65 pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
67 where
68 HeaderName: TryFrom<K>,
69 HeaderValue: TryFrom<V>,
70 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
71 <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
72 {
73 if let Some(ctx) = ctx(self) {
74 match HeaderName::try_from(key) {
75 Ok(key) => match HeaderValue::try_from(value) {
76 Ok(value) => ctx.headers.push((key, value)),
77 Err(e) => ctx.err = Some(log_error(e)),
78 },
79 Err(e) => ctx.err = Some(log_error(e)),
80 }
81 }
82 self
83 }
84
85 pub(crate) fn headers(&self) -> &[(HeaderName, HeaderValue)] {
86 &self.0.headers
87 }
88
89 pub(crate) fn get_disconnect_on_drop(&self) -> bool {
90 self.0.flags.get().contains(Flags::DISCONNECT_ON_DROP)
91 }
92}
93
94impl Clone for RequestContext {
95 fn clone(&self) -> Self {
96 Self(self.0.clone())
97 }
98}
99
100fn log_error<T: Into<HttpError>>(err: T) -> HttpError {
101 let e = err.into();
102 log::error!("Error in Grpc Request {e}");
103 e
104}
105
106fn ctx(slf: &mut RequestContext) -> Option<&mut RequestContextInner> {
107 if slf.0.err.is_some() {
108 return None;
109 }
110
111 if Rc::get_mut(&mut slf.0).is_some() {
112 Rc::get_mut(&mut slf.0)
113 } else {
114 slf.0 = Rc::new(RequestContextInner {
115 err: None,
116 headers: slf.0.headers.clone(),
117 timeout: slf.0.timeout.clone(),
118 flags: slf.0.flags.clone(),
119 });
120 Some(Rc::get_mut(&mut slf.0).unwrap())
121 }
122}
123
pin_project_lite::pin_project! {
    /// A gRPC call over a `Transport`, resolved by awaiting it.
    ///
    /// Until it is first polled, it holds the input and a `RequestContext` that can still
    /// be adjusted through `header`/`timeout`. The first poll hands both to the transport.
    pub struct Request<'a, T, M>
    where T: Transport<M>,
          T: 'a,
          M: MethodDef
    {
        // Transport that executes the call.
        transport: &'a T,
        // Pending request before the first poll, the in-flight call future afterwards.
        #[pin]
        state: State<'a, T, M>,
    }
}
135
/// Progress of a `Request`.
enum State<'a, T, M>
where
    T: Transport<M> + 'a,
    M: MethodDef,
{
    /// The transport call has started; polling is forwarded to `fut`.
    Call {
        fut: BoxFuture<'a, Result<Response<M>, T::Error>>,
    },
    /// Not started yet: input plus a context still editable via `Request` methods.
    /// `ctx` is set to `Some` in `Request::new` and unwrapped by `poll` when the call starts.
    Request {
        input: &'a M::Input,
        ctx: Option<RequestContext>,
    },
    /// Placeholder left in place by `poll` while swapping `Request` for `Call`.
    None,
}
150
151impl<'a, T, M> Request<'a, T, M>
152where
153 T: Transport<M>,
154 M: MethodDef,
155{
156 pub fn new(transport: &'a T, input: &'a M::Input) -> Self {
157 Self {
158 transport,
159 state: State::Request {
160 input,
161 ctx: Some(RequestContext::new()),
162 },
163 }
164 }
165
166 pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
179 where
180 HeaderName: TryFrom<K>,
181 HeaderValue: TryFrom<V>,
182 <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
183 <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
184 {
185 if let Some(ctx) = parts(&mut self.state) {
186 ctx.header(key, value);
187 }
188 self
189 }
190
191 pub fn timeout<U>(&mut self, timeout: U) -> &mut Self
198 where
199 time::Duration: From<U>,
200 {
201 if let Some(ctx) = parts(&mut self.state) {
202 let to = timeout.into();
203 ctx.0.timeout.set(Some(to));
204 ctx.header(consts::GRPC_TIMEOUT, duration_to_grpc_timeout(to));
205 }
206 self
207 }
208}
209
/// Encodes `duration` as a gRPC timeout header value.
///
/// The gRPC wire format is at most 8 ASCII digits followed by a unit character:
/// `n` (nanoseconds), `u` (microseconds), `m` (milliseconds), `S` (seconds),
/// `M` (minutes) or `H` (hours). The most precise unit whose value still fits in
/// 8 digits is chosen. Coarser units truncate toward zero.
///
/// Durations longer than 99,999,999 hours saturate to `"99999999H"`, so very large
/// values such as `Duration::MAX` do not panic.
fn duration_to_grpc_timeout(duration: time::Duration) -> String {
    // Largest value expressible in the 8 digits allowed by the gRPC spec.
    const MAX_VALUE: u128 = 99_999_999;

    // Formats `duration` in `unit`, or returns `None` if the value needs more than 8 digits.
    fn try_format<T: Into<u128>>(
        duration: time::Duration,
        unit: char,
        convert: impl FnOnce(time::Duration) -> T,
    ) -> Option<String> {
        let value: u128 = convert(duration).into();
        if value > MAX_VALUE {
            None
        } else {
            Some(format!("{value}{unit}"))
        }
    }

    try_format(duration, 'n', |d| d.as_nanos())
        .or_else(|| try_format(duration, 'u', |d| d.as_micros()))
        .or_else(|| try_format(duration, 'm', |d| d.as_millis()))
        .or_else(|| try_format(duration, 'S', |d| d.as_secs()))
        .or_else(|| try_format(duration, 'M', |d| d.as_secs() / 60))
        .or_else(|| try_format(duration, 'H', |d| d.as_secs() / 3600))
        // Saturate instead of panicking on user-supplied, unrealistically large durations.
        .unwrap_or_else(|| format!("{MAX_VALUE}H"))
}
243
244#[inline]
245fn parts<'a, 'b, T: Transport<M> + 'a, M: MethodDef>(
246 parts: &'b mut State<'a, T, M>,
247) -> Option<&'b mut RequestContext> {
248 if let State::Request { ctx, .. } = parts {
249 ctx.as_mut()
250 } else {
251 None
252 }
253}
254
255impl<'a, T, M: 'a> Future for Request<'a, T, M>
256where
257 T: Transport<M>,
258 M: MethodDef,
259{
260 type Output = Result<Response<M>, T::Error>;
261
262 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
263 loop {
264 if let State::Call { ref mut fut } = self.state {
265 return Pin::new(fut).poll(cx);
266 }
267
268 if let State::Request { input, ref mut ctx } =
269 mem::replace(&mut self.state, State::None)
270 {
271 self.state = State::Call {
272 fut: Box::pin(self.transport.request(input, ctx.take().unwrap())),
273 };
274 }
275 }
276 }
277}
278
/// Successful result of a call: the output message plus response metadata.
pub struct Response<T: MethodDef> {
    /// The response message.
    pub output: T::Output,
    /// Response headers.
    pub headers: HeaderMap,
    /// Trailing headers (trailers).
    pub trailers: HeaderMap,
    /// Request size. NOTE(review): presumably bytes sent; confirm with the transport.
    pub req_size: usize,
    /// Response size. NOTE(review): presumably bytes received; confirm with the transport.
    pub res_size: usize,
}
286
287impl<T: MethodDef> Response<T> {
288 #[inline]
289 pub fn headers(&self) -> &HeaderMap {
290 &self.headers
291 }
292
293 #[inline]
294 pub fn trailers(&self) -> &HeaderMap {
295 &self.trailers
296 }
297
298 #[inline]
299 pub fn into_inner(self) -> T::Output {
300 self.output
301 }
302
303 #[inline]
304 pub fn into_parts(self) -> (T::Output, HeaderMap, HeaderMap) {
305 (self.output, self.headers, self.trailers)
306 }
307}
308
309impl<T: MethodDef> ops::Deref for Response<T> {
310 type Target = T::Output;
311
312 fn deref(&self) -> &Self::Target {
313 &self.output
314 }
315}
316
317impl<T: MethodDef> ops::DerefMut for Response<T> {
318 fn deref_mut(&mut self) -> &mut Self::Target {
319 &mut self.output
320 }
321}
322
323impl<T: MethodDef> fmt::Debug for Response<T>
324where
325 T::Output: fmt::Debug,
326{
327 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328 f.debug_struct(format!("ResponseFor<{}>", T::NAME).as_str())
329 .field("output", &self.output)
330 .field("headers", &self.headers)
331 .field("translers", &self.headers)
332 .finish()
333 }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn duration_to_grpc_timeout_picks_most_precise_unit() {
        // Fits in 8 nanosecond digits.
        assert_eq!(duration_to_grpc_timeout(time::Duration::from_nanos(5)), "5n");
        assert_eq!(
            duration_to_grpc_timeout(time::Duration::from_nanos(99_999_999)),
            "99999999n"
        );

        // One digit too many for nanoseconds: falls back to microseconds.
        assert_eq!(
            duration_to_grpc_timeout(time::Duration::from_nanos(100_000_000)),
            "100000u"
        );

        let timeout = time::Duration::from_millis(500);
        let value = duration_to_grpc_timeout(timeout);
        assert_eq!(value, format!("{}u", timeout.as_micros()));

        let timeout = time::Duration::from_secs(30);
        let value = duration_to_grpc_timeout(timeout);
        assert_eq!(value, format!("{}u", timeout.as_micros()));

        let one_hour = time::Duration::from_secs(60 * 60);
        let value = duration_to_grpc_timeout(one_hour);
        assert_eq!(value, format!("{}m", one_hour.as_millis()));
    }

    #[test]
    fn duration_to_grpc_timeout_coarse_units() {
        assert_eq!(
            duration_to_grpc_timeout(time::Duration::from_secs(200_000)),
            "200000S"
        );
        assert_eq!(
            duration_to_grpc_timeout(time::Duration::from_secs(200_000_000)),
            "3333333M"
        );
        assert_eq!(
            duration_to_grpc_timeout(time::Duration::from_secs(36_000_000_000)),
            "10000000H"
        );
    }
}