Skip to main content

ntex_grpc/client/
request.rs

1use std::task::{Context, Poll};
2use std::{cell::Cell, convert::TryFrom, fmt, future::Future, mem, ops, pin::Pin, rc::Rc, time};
3
4use ntex_http::{HeaderMap, HeaderName, HeaderValue, error::Error as HttpError};
5use ntex_util::future::BoxFuture;
6
7use crate::{client::Transport, consts, service::MethodDef};
8
9pub struct RequestContext(Rc<RequestContextInner>);
10
11bitflags::bitflags! {
12    #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
13    struct Flags: u8 {
14        const DISCONNECT_ON_DROP = 0b0000_0001;
15    }
16}
17
18struct RequestContextInner {
19    err: Option<HttpError>,
20    headers: Vec<(HeaderName, HeaderValue)>,
21    timeout: Cell<Option<time::Duration>>,
22    flags: Cell<Flags>,
23}
24
25impl RequestContext {
26    /// Create new `RequestContext` instance
27    fn new() -> Self {
28        Self(Rc::new(RequestContextInner {
29            err: None,
30            headers: Vec::new(),
31            timeout: Cell::new(None),
32            flags: Cell::new(Flags::empty()),
33        }))
34    }
35
36    /// Get request timeout
37    pub fn get_timeout(&self) -> Option<time::Duration> {
38        self.0.timeout.get()
39    }
40
41    /// Set the max duration the request is allowed to take.
42    ///
43    /// The duration will be formatted according to [the spec] and use the most precise
44    /// possible.
45    ///
46    /// [the spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
47    pub fn timeout<U>(&mut self, timeout: U) -> &mut Self
48    where
49        time::Duration: From<U>,
50    {
51        let to = timeout.into();
52        self.0.timeout.set(Some(to));
53        self.header(consts::GRPC_TIMEOUT, duration_to_grpc_timeout(to));
54        self
55    }
56
57    /// Disconnect connection on request drop
58    pub fn disconnect_on_drop(&mut self) -> &mut Self {
59        let mut flags = self.0.flags.get();
60        flags.insert(Flags::DISCONNECT_ON_DROP);
61        self.0.flags.set(flags);
62        self
63    }
64
65    /// Append a header to existing headers.
66    pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
67    where
68        HeaderName: TryFrom<K>,
69        HeaderValue: TryFrom<V>,
70        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
71        <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
72    {
73        if let Some(ctx) = ctx(self) {
74            match HeaderName::try_from(key) {
75                Ok(key) => match HeaderValue::try_from(value) {
76                    Ok(value) => ctx.headers.push((key, value)),
77                    Err(e) => ctx.err = Some(log_error(e)),
78                },
79                Err(e) => ctx.err = Some(log_error(e)),
80            }
81        }
82        self
83    }
84
85    pub(crate) fn headers(&self) -> &[(HeaderName, HeaderValue)] {
86        &self.0.headers
87    }
88
89    pub(crate) fn get_disconnect_on_drop(&self) -> bool {
90        self.0.flags.get().contains(Flags::DISCONNECT_ON_DROP)
91    }
92}
93
94impl Clone for RequestContext {
95    fn clone(&self) -> Self {
96        Self(self.0.clone())
97    }
98}
99
100fn log_error<T: Into<HttpError>>(err: T) -> HttpError {
101    let e = err.into();
102    log::error!("Error in Grpc Request {e}");
103    e
104}
105
106fn ctx(slf: &mut RequestContext) -> Option<&mut RequestContextInner> {
107    if slf.0.err.is_some() {
108        return None;
109    }
110
111    if Rc::get_mut(&mut slf.0).is_some() {
112        Rc::get_mut(&mut slf.0)
113    } else {
114        slf.0 = Rc::new(RequestContextInner {
115            err: None,
116            headers: slf.0.headers.clone(),
117            timeout: slf.0.timeout.clone(),
118            flags: slf.0.flags.clone(),
119        });
120        Some(Rc::get_mut(&mut slf.0).unwrap())
121    }
122}
123
124pin_project_lite::pin_project! {
125    pub struct Request<'a, T, M>
126    where T: Transport<M>,
127          T: 'a,
128          M: MethodDef
129    {
130        transport: &'a T,
131        #[pin]
132        state: State<'a, T, M>,
133    }
134}
135
136enum State<'a, T, M>
137where
138    T: Transport<M> + 'a,
139    M: MethodDef,
140{
141    Call {
142        fut: BoxFuture<'a, Result<Response<M>, T::Error>>,
143    },
144    Request {
145        input: &'a M::Input,
146        ctx: Option<RequestContext>,
147    },
148    None,
149}
150
151impl<'a, T, M> Request<'a, T, M>
152where
153    T: Transport<M>,
154    M: MethodDef,
155{
156    pub fn new(transport: &'a T, input: &'a M::Input) -> Self {
157        Self {
158            transport,
159            state: State::Request {
160                input,
161                ctx: Some(RequestContext::new()),
162            },
163        }
164    }
165
166    /// Append a header to existing headers.
167    ///
168    /// ```rust
169    /// use ntex::http::{header, Request, Response};
170    ///
171    /// fn index(req: Request) -> Response {
172    ///     Response::Ok()
173    ///         .header("X-TEST", "value")
174    ///         .header(header::CONTENT_TYPE, "application/json")
175    ///         .finish()
176    /// }
177    /// ```
178    pub fn header<K, V>(&mut self, key: K, value: V) -> &mut Self
179    where
180        HeaderName: TryFrom<K>,
181        HeaderValue: TryFrom<V>,
182        <HeaderName as TryFrom<K>>::Error: Into<HttpError>,
183        <HeaderValue as TryFrom<V>>::Error: Into<HttpError>,
184    {
185        if let Some(ctx) = parts(&mut self.state) {
186            ctx.header(key, value);
187        }
188        self
189    }
190
191    /// Set the max duration the request is allowed to take.
192    ///
193    /// The duration will be formatted according to [the spec] and use the most precise
194    /// possible.
195    ///
196    /// [the spec]: https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md
197    pub fn timeout<U>(&mut self, timeout: U) -> &mut Self
198    where
199        time::Duration: From<U>,
200    {
201        if let Some(ctx) = parts(&mut self.state) {
202            let to = timeout.into();
203            ctx.0.timeout.set(Some(to));
204            ctx.header(consts::GRPC_TIMEOUT, duration_to_grpc_timeout(to));
205        }
206        self
207    }
208}
209
/// Formats `duration` as a `grpc-timeout` header value.
///
/// The value is an integer of at most 8 digits followed by a unit suffix
/// (`n`, `u`, `m`, `S`, `M`, `H`). The most precise unit whose value still
/// fits into 8 digits is used; coarser units truncate towards zero.
///
/// Durations too large to be expressed even in hours (more than roughly
/// 11_415 years) are clamped to the largest encodable value, `99999999H`,
/// instead of panicking.
fn duration_to_grpc_timeout(duration: time::Duration) -> String {
    // The gRPC spec specifies that the timeout must be at most 8 digits, so
    // this is the largest value a unit can carry before a bigger unit is needed.
    const MAX_VALUE: u128 = 99_999_999;

    let secs = u128::from(duration.as_secs());

    // Candidates ordered from the most to the least precise unit.
    let candidates = [
        (duration.as_nanos(), 'n'),
        (duration.as_micros(), 'u'),
        (duration.as_millis(), 'm'),
        (secs, 'S'),
        (secs / 60, 'M'),
        (secs / 60 / 60, 'H'),
    ];

    candidates
        .iter()
        .find(|(value, _)| *value <= MAX_VALUE)
        .map(|(value, unit)| format!("{value}{unit}"))
        // Nothing fits: clamp to the maximum the wire format can express.
        .unwrap_or_else(|| format!("{}H", MAX_VALUE))
}
243
244#[inline]
245fn parts<'a, 'b, T: Transport<M> + 'a, M: MethodDef>(
246    parts: &'b mut State<'a, T, M>,
247) -> Option<&'b mut RequestContext> {
248    if let State::Request { ctx, .. } = parts {
249        ctx.as_mut()
250    } else {
251        None
252    }
253}
254
255impl<'a, T, M: 'a> Future for Request<'a, T, M>
256where
257    T: Transport<M>,
258    M: MethodDef,
259{
260    type Output = Result<Response<M>, T::Error>;
261
262    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
263        loop {
264            if let State::Call { ref mut fut } = self.state {
265                return Pin::new(fut).poll(cx);
266            }
267
268            if let State::Request { input, ref mut ctx } =
269                mem::replace(&mut self.state, State::None)
270            {
271                self.state = State::Call {
272                    fut: Box::pin(self.transport.request(input, ctx.take().unwrap())),
273                };
274            }
275        }
276    }
277}
278
279pub struct Response<T: MethodDef> {
280    pub output: T::Output,
281    pub headers: HeaderMap,
282    pub trailers: HeaderMap,
283    pub req_size: usize,
284    pub res_size: usize,
285}
286
287impl<T: MethodDef> Response<T> {
288    #[inline]
289    pub fn headers(&self) -> &HeaderMap {
290        &self.headers
291    }
292
293    #[inline]
294    pub fn trailers(&self) -> &HeaderMap {
295        &self.trailers
296    }
297
298    #[inline]
299    pub fn into_inner(self) -> T::Output {
300        self.output
301    }
302
303    #[inline]
304    pub fn into_parts(self) -> (T::Output, HeaderMap, HeaderMap) {
305        (self.output, self.headers, self.trailers)
306    }
307}
308
309impl<T: MethodDef> ops::Deref for Response<T> {
310    type Target = T::Output;
311
312    fn deref(&self) -> &Self::Target {
313        &self.output
314    }
315}
316
317impl<T: MethodDef> ops::DerefMut for Response<T> {
318    fn deref_mut(&mut self) -> &mut Self::Target {
319        &mut self.output
320    }
321}
322
323impl<T: MethodDef> fmt::Debug for Response<T>
324where
325    T::Output: fmt::Debug,
326{
327    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
328        f.debug_struct(format!("ResponseFor<{}>", T::NAME).as_str())
329            .field("output", &self.output)
330            .field("headers", &self.headers)
331            .field("translers", &self.headers)
332            .finish()
333    }
334}
335
#[cfg(test)]
mod tests {
    use super::*;

    /// 500ms does not fit into 8 digits of nanoseconds, so microseconds are used.
    #[test]
    fn duration_to_grpc_timeout_less_than_second() {
        let timeout = time::Duration::from_millis(500);
        assert_eq!(duration_to_grpc_timeout(timeout), "500000u");
    }

    /// 30s still fits into 8 digits of microseconds.
    #[test]
    fn duration_to_grpc_timeout_seconds() {
        let timeout = time::Duration::from_secs(30);
        assert_eq!(duration_to_grpc_timeout(timeout), "30000000u");
    }

    /// One hour exceeds 8 digits of microseconds and falls back to milliseconds.
    #[test]
    fn duration_to_grpc_timeout_one_hour() {
        let one_hour = time::Duration::from_secs(60 * 60);
        assert_eq!(duration_to_grpc_timeout(one_hour), "3600000m");
    }
}