conjure_runtime_raw/service/
timeout.rs1use conjure_runtime::{builder, Builder};
15use futures::ready;
16use hyper::rt::{Read, ReadBufCursor, Write};
17use hyper_util::client::legacy::connect::{Connected, Connection};
18use hyper_util::rt::TokioIo;
19use pin_project::pin_project;
20use std::future::Future;
21use std::io;
22use std::pin::Pin;
23use std::task::{Context, Poll};
24use std::time::Duration;
25use tower_layer::Layer;
26use tower_service::Service;
27
28pub struct TimeoutLayer {
30 read_timeout: Duration,
31 write_timeout: Duration,
32}
33
34impl TimeoutLayer {
35 pub fn new<T>(builder: &Builder<builder::Complete<T>>) -> TimeoutLayer {
36 TimeoutLayer {
37 read_timeout: builder.get_read_timeout(),
38 write_timeout: builder.get_write_timeout(),
39 }
40 }
41}
42
43impl<S> Layer<S> for TimeoutLayer {
44 type Service = TimeoutService<S>;
45
46 fn layer(&self, inner: S) -> Self::Service {
47 TimeoutService {
48 inner,
49 read_timeout: self.read_timeout,
50 write_timeout: self.write_timeout,
51 }
52 }
53}
54
55#[derive(Clone)]
56pub struct TimeoutService<S> {
57 inner: S,
58 read_timeout: Duration,
59 write_timeout: Duration,
60}
61
62impl<S, R> Service<R> for TimeoutService<S>
63where
64 S: Service<R>,
65 S::Response: Read + Write + Unpin,
66{
67 type Response = TimeoutStream<S::Response>;
68 type Error = S::Error;
69 type Future = TimeoutFuture<S::Future>;
70
71 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
72 self.inner.poll_ready(cx)
73 }
74
75 fn call(&mut self, req: R) -> Self::Future {
76 TimeoutFuture {
77 future: self.inner.call(req),
78 read_timeout: self.read_timeout,
79 write_timeout: self.write_timeout,
80 }
81 }
82}
83
84#[pin_project]
85pub struct TimeoutFuture<F> {
86 #[pin]
87 future: F,
88 read_timeout: Duration,
89 write_timeout: Duration,
90}
91
92impl<F, S, E> Future for TimeoutFuture<F>
93where
94 F: Future<Output = Result<S, E>>,
95 S: Read + Write + Unpin,
96{
97 type Output = Result<TimeoutStream<S>, E>;
98
99 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100 let this = self.project();
101
102 let stream = ready!(this.future.poll(cx))?;
103 let mut stream = tokio_io_timeout::TimeoutStream::new(TokioIo::new(stream));
104 stream.set_read_timeout(Some(*this.read_timeout));
105 stream.set_write_timeout(Some(*this.write_timeout));
106
107 Poll::Ready(Ok(TimeoutStream {
108 stream: Box::pin(TokioIo::new(stream)),
109 }))
110 }
111}
112
113#[derive(Debug)]
114pub struct TimeoutStream<S> {
115 stream: Pin<Box<TokioIo<tokio_io_timeout::TimeoutStream<TokioIo<S>>>>>,
116}
117
118impl<S> Read for TimeoutStream<S>
119where
120 S: Read + Write,
121{
122 fn poll_read(
123 mut self: Pin<&mut Self>,
124 cx: &mut Context<'_>,
125 buf: ReadBufCursor<'_>,
126 ) -> Poll<io::Result<()>> {
127 self.stream.as_mut().poll_read(cx, buf)
128 }
129}
130
131impl<S> Write for TimeoutStream<S>
132where
133 S: Read + Write,
134{
135 fn poll_write(
136 mut self: Pin<&mut Self>,
137 cx: &mut Context<'_>,
138 buf: &[u8],
139 ) -> Poll<io::Result<usize>> {
140 self.stream.as_mut().poll_write(cx, buf)
141 }
142
143 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
144 self.stream.as_mut().poll_flush(cx)
145 }
146
147 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
148 self.stream.as_mut().poll_shutdown(cx)
149 }
150}
151
152impl<S> Connection for TimeoutStream<S>
153where
154 S: Read + Write + Connection,
155{
156 fn connected(&self) -> Connected {
157 self.stream.inner().get_ref().inner().connected()
158 }
159}