etcd_client/
channel.rs

1use std::{future::Future, pin::Pin, task::ready};
2
3use http::Uri;
4use tokio::sync::mpsc::Sender;
5use tonic::transport::{channel::Change, Endpoint};
6use tower::{util::BoxCloneSyncService, Service};
7
8/// A type alias to make the below types easier to represent.
9pub type EndpointUpdater = Sender<Change<Uri, Endpoint>>;
10
11/// Creates a balanced channel.
12pub trait BalancedChannelBuilder {
13    type Error;
14
15    /// Makes a new balanced channel, given the provided options.
16    fn balanced_channel(
17        self,
18        buffer_size: usize,
19    ) -> Result<(Channel, EndpointUpdater), Self::Error>;
20}
21
22/// Create a simple Tonic channel.
23#[allow(dead_code)]
24pub struct Tonic;
25
26impl BalancedChannelBuilder for Tonic {
27    type Error = tonic::transport::Error;
28
29    #[inline]
30    fn balanced_channel(
31        self,
32        buffer_size: usize,
33    ) -> Result<(Channel, EndpointUpdater), Self::Error> {
34        let (chan, tx) = tonic::transport::Channel::balance_channel(buffer_size);
35        Ok((Channel::Tonic(chan), tx))
36    }
37}
38
39/// Create an Openssl-backed channel.
40#[cfg(feature = "tls-openssl")]
41pub struct Openssl {
42    pub(crate) conn: crate::openssl_tls::OpenSslConnector,
43}
44
45#[cfg(feature = "tls-openssl")]
46impl BalancedChannelBuilder for Openssl {
47    type Error = crate::error::Error;
48
49    #[inline]
50    fn balanced_channel(self, _: usize) -> Result<(Channel, EndpointUpdater), Self::Error> {
51        let (chan, tx) = crate::openssl_tls::balanced_channel(self.conn)?;
52        Ok((Channel::Openssl(chan), tx))
53    }
54}
55
56type TonicRequest = http::Request<tonic::body::Body>;
57type TonicResponse = http::Response<tonic::body::Body>;
58pub type CustomChannel = BoxCloneSyncService<TonicRequest, TonicResponse, tower::BoxError>;
59
60/// Represents a channel that can be created by a BalancedChannelBuilder
61/// or may be initialized externally and passed into the client.
62#[derive(Clone)]
63pub enum Channel {
64    /// A standard tonic channel.
65    Tonic(tonic::transport::Channel),
66
67    /// An OpenSSL channel.
68    #[cfg(feature = "tls-openssl")]
69    Openssl(crate::openssl_tls::OpenSslChannel),
70
71    /// A custom Service impl, inside a Box.
72    Custom(CustomChannel),
73}
74
75impl std::fmt::Debug for Channel {
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("Channel").finish_non_exhaustive()
78    }
79}
80
81pub enum ChannelFuture {
82    Tonic(<tonic::transport::Channel as Service<TonicRequest>>::Future),
83    #[cfg(feature = "tls-openssl")]
84    Openssl(<crate::openssl_tls::OpenSslChannel as Service<TonicRequest>>::Future),
85    Custom(<CustomChannel as Service<TonicRequest>>::Future),
86}
87
88impl std::future::Future for ChannelFuture {
89    type Output = Result<TonicResponse, tower::BoxError>;
90
91    #[inline]
92    fn poll(
93        self: std::pin::Pin<&mut Self>,
94        cx: &mut std::task::Context<'_>,
95    ) -> std::task::Poll<Self::Output> {
96        // Safety: trivial projection
97        unsafe {
98            let this = self.get_unchecked_mut();
99            match this {
100                ChannelFuture::Tonic(fut) => {
101                    let fut = Pin::new_unchecked(fut);
102                    let result = ready!(Future::poll(fut, cx));
103                    result.map_err(|e| Box::new(e) as tower::BoxError).into()
104                }
105                #[cfg(feature = "tls-openssl")]
106                ChannelFuture::Openssl(fut) => {
107                    let fut = Pin::new_unchecked(fut);
108                    Future::poll(fut, cx)
109                }
110                ChannelFuture::Custom(fut) => {
111                    let fut = Pin::new_unchecked(fut);
112                    Future::poll(fut, cx)
113                }
114            }
115        }
116    }
117}
118
119impl ChannelFuture {
120    #[inline]
121    fn from_tonic(value: <tonic::transport::Channel as Service<TonicRequest>>::Future) -> Self {
122        Self::Tonic(value)
123    }
124
125    #[cfg(feature = "tls-openssl")]
126    #[inline]
127    fn from_openssl(
128        value: <crate::openssl_tls::OpenSslChannel as Service<TonicRequest>>::Future,
129    ) -> Self {
130        Self::Openssl(value)
131    }
132
133    #[inline]
134    fn from_custom(value: <CustomChannel as Service<TonicRequest>>::Future) -> Self {
135        Self::Custom(value)
136    }
137}
138
139impl Service<TonicRequest> for Channel {
140    type Response = TonicResponse;
141    type Error = tower::BoxError;
142    type Future = ChannelFuture;
143
144    #[inline]
145    fn poll_ready(
146        &mut self,
147        cx: &mut std::task::Context<'_>,
148    ) -> std::task::Poll<Result<(), Self::Error>> {
149        match self {
150            Channel::Tonic(channel) => {
151                let result = ready!(channel.poll_ready(cx));
152                result.map_err(|e| Box::new(e) as tower::BoxError).into()
153            }
154            #[cfg(feature = "tls-openssl")]
155            Channel::Openssl(openssl) => openssl.poll_ready(cx),
156            Channel::Custom(custom) => custom.poll_ready(cx),
157        }
158    }
159
160    #[inline]
161    fn call(&mut self, req: TonicRequest) -> Self::Future {
162        match self {
163            Channel::Tonic(channel) => ChannelFuture::from_tonic(channel.call(req)),
164            #[cfg(feature = "tls-openssl")]
165            Channel::Openssl(openssl) => ChannelFuture::from_openssl(openssl.call(req)),
166            Channel::Custom(custom) => ChannelFuture::from_custom(custom.call(req)),
167        }
168    }
169}