conjure_runtime_raw/service/
timeout.rs

1// Copyright 2020 Palantir Technologies, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14use conjure_runtime::{builder, Builder};
15use futures::ready;
16use hyper::rt::{Read, ReadBufCursor, Write};
17use hyper_util::client::legacy::connect::{Connected, Connection};
18use hyper_util::rt::TokioIo;
19use pin_project::pin_project;
20use std::future::Future;
21use std::io;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24use std::time::Duration;
25use tower_layer::Layer;
26use tower_service::Service;
27
/// A connector layer which wraps a stream in a `TimeoutStream`.
///
/// The layer itself only stores the two timeouts; they are copied into every
/// `TimeoutService` produced by `Layer::layer`.
#[derive(Clone, Debug)]
pub struct TimeoutLayer {
    /// Read timeout copied into each service created by this layer.
    read_timeout: Duration,
    /// Write timeout copied into each service created by this layer.
    write_timeout: Duration,
}
33
34impl TimeoutLayer {
35    pub fn new<T>(builder: &Builder<builder::Complete<T>>) -> TimeoutLayer {
36        TimeoutLayer {
37            read_timeout: builder.get_read_timeout(),
38            write_timeout: builder.get_write_timeout(),
39        }
40    }
41}
42
43impl<S> Layer<S> for TimeoutLayer {
44    type Service = TimeoutService<S>;
45
46    fn layer(&self, inner: S) -> Self::Service {
47        TimeoutService {
48            inner,
49            read_timeout: self.read_timeout,
50            write_timeout: self.write_timeout,
51        }
52    }
53}
54
/// A connector service which attaches read and write timeouts to the stream
/// produced by the wrapped service `S`.
#[derive(Clone, Debug)]
pub struct TimeoutService<S> {
    /// The wrapped service; every call is delegated to it.
    inner: S,
    /// Read timeout applied to each stream the wrapped service yields.
    read_timeout: Duration,
    /// Write timeout applied to each stream the wrapped service yields.
    write_timeout: Duration,
}
61
62impl<S, R> Service<R> for TimeoutService<S>
63where
64    S: Service<R>,
65    S::Response: Read + Write + Unpin,
66{
67    type Response = TimeoutStream<S::Response>;
68    type Error = S::Error;
69    type Future = TimeoutFuture<S::Future>;
70
71    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
72        self.inner.poll_ready(cx)
73    }
74
75    fn call(&mut self, req: R) -> Self::Future {
76        TimeoutFuture {
77            future: self.inner.call(req),
78            read_timeout: self.read_timeout,
79            write_timeout: self.write_timeout,
80        }
81    }
82}
83
84#[pin_project]
85pub struct TimeoutFuture<F> {
86    #[pin]
87    future: F,
88    read_timeout: Duration,
89    write_timeout: Duration,
90}
91
92impl<F, S, E> Future for TimeoutFuture<F>
93where
94    F: Future<Output = Result<S, E>>,
95    S: Read + Write + Unpin,
96{
97    type Output = Result<TimeoutStream<S>, E>;
98
99    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100        let this = self.project();
101
102        let stream = ready!(this.future.poll(cx))?;
103        let mut stream = tokio_io_timeout::TimeoutStream::new(TokioIo::new(stream));
104        stream.set_read_timeout(Some(*this.read_timeout));
105        stream.set_write_timeout(Some(*this.write_timeout));
106
107        Poll::Ready(Ok(TimeoutStream {
108            stream: Box::pin(TokioIo::new(stream)),
109        }))
110    }
111}
112
113#[derive(Debug)]
114pub struct TimeoutStream<S> {
115    stream: Pin<Box<TokioIo<tokio_io_timeout::TimeoutStream<TokioIo<S>>>>>,
116}
117
118impl<S> Read for TimeoutStream<S>
119where
120    S: Read + Write,
121{
122    fn poll_read(
123        mut self: Pin<&mut Self>,
124        cx: &mut Context<'_>,
125        buf: ReadBufCursor<'_>,
126    ) -> Poll<io::Result<()>> {
127        self.stream.as_mut().poll_read(cx, buf)
128    }
129}
130
131impl<S> Write for TimeoutStream<S>
132where
133    S: Read + Write,
134{
135    fn poll_write(
136        mut self: Pin<&mut Self>,
137        cx: &mut Context<'_>,
138        buf: &[u8],
139    ) -> Poll<io::Result<usize>> {
140        self.stream.as_mut().poll_write(cx, buf)
141    }
142
143    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
144        self.stream.as_mut().poll_flush(cx)
145    }
146
147    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148        self.stream.as_mut().poll_shutdown(cx)
149    }
150}
151
152impl<S> Connection for TimeoutStream<S>
153where
154    S: Read + Write + Connection,
155{
156    fn connected(&self) -> Connected {
157        self.stream.inner().get_ref().inner().connected()
158    }
159}