// etcd_client/channel.rs

1use std::{future::Future, pin::Pin, task::ready};
2
3use http::Uri;
4use tokio::sync::mpsc::Sender;
5use tonic::transport::{channel::Change, Endpoint};
6use tower::{util::BoxCloneSyncService, Service};
7
8/// A type alias to make the below types easier to represent.
9pub type EndpointUpdater = Sender<Change<Uri, Endpoint>>;
10
11/// Creates a balanced channel.
12pub trait BalancedChannelBuilder {
13    type Error;
14
15    /// Makes a new balanced channel, given the provided options.
16    fn balanced_channel(
17        self,
18        buffer_size: usize,
19    ) -> Result<(Channel, EndpointUpdater), Self::Error>;
20}
21
/// [`BalancedChannelBuilder`] that creates a balanced channel using tonic's own transport.
///
/// This is a stateless unit type, so it is cheap to copy and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Tonic;
24
25impl BalancedChannelBuilder for Tonic {
26    type Error = tonic::transport::Error;
27
28    #[inline]
29    fn balanced_channel(
30        self,
31        buffer_size: usize,
32    ) -> Result<(Channel, EndpointUpdater), Self::Error> {
33        let (chan, tx) = tonic::transport::Channel::balance_channel(buffer_size);
34        Ok((Channel::Tonic(chan), tx))
35    }
36}
37
/// [`BalancedChannelBuilder`] that creates a balanced channel backed by an OpenSSL connector.
///
/// Only available with the `tls-openssl` feature.
#[cfg(feature = "tls-openssl")]
pub struct Openssl {
    /// Connector handed to `crate::openssl_tls::balanced_channel` when the channel is built.
    pub(crate) conn: crate::openssl_tls::OpenSslConnector,
}
43
#[cfg(feature = "tls-openssl")]
impl BalancedChannelBuilder for Openssl {
    type Error = crate::error::Error;

    /// Builds the OpenSSL-backed channel from the stored connector and wraps it in
    /// [`Channel::Openssl`].
    ///
    /// The requested buffer size is not forwarded to
    /// `crate::openssl_tls::balanced_channel`, so it has no effect here.
    #[inline]
    fn balanced_channel(
        self,
        _buffer_size: usize,
    ) -> Result<(Channel, EndpointUpdater), Self::Error> {
        let Self { conn } = self;
        let (inner, updater) = crate::openssl_tls::balanced_channel(conn)?;
        Ok((Channel::Openssl(inner), updater))
    }
}
54
55type TonicRequest = http::Request<tonic::body::Body>;
56type TonicResponse = http::Response<tonic::body::Body>;
57pub type CustomChannel = BoxCloneSyncService<TonicRequest, TonicResponse, tower::BoxError>;
58
59/// Represents a channel that can be created by a BalancedChannelBuilder
60/// or may be initialized externally and passed into the client.
61#[derive(Clone)]
62pub enum Channel {
63    /// A standard tonic channel.
64    Tonic(tonic::transport::Channel),
65
66    /// An OpenSSL channel.
67    #[cfg(feature = "tls-openssl")]
68    Openssl(crate::openssl_tls::OpenSslChannel),
69
70    /// A custom Service impl, inside a Box.
71    Custom(CustomChannel),
72}
73
74impl std::fmt::Debug for Channel {
75    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76        f.debug_struct("Channel").finish_non_exhaustive()
77    }
78}
79
80pub enum ChannelFuture {
81    Tonic(<tonic::transport::Channel as Service<TonicRequest>>::Future),
82    #[cfg(feature = "tls-openssl")]
83    Openssl(<crate::openssl_tls::OpenSslChannel as Service<TonicRequest>>::Future),
84    Custom(<CustomChannel as Service<TonicRequest>>::Future),
85}
86
87impl std::future::Future for ChannelFuture {
88    type Output = Result<TonicResponse, tower::BoxError>;
89
90    #[inline]
91    fn poll(
92        self: std::pin::Pin<&mut Self>,
93        cx: &mut std::task::Context<'_>,
94    ) -> std::task::Poll<Self::Output> {
95        // Safety: trivial projection
96        unsafe {
97            let this = self.get_unchecked_mut();
98            match this {
99                ChannelFuture::Tonic(fut) => {
100                    let fut = Pin::new_unchecked(fut);
101                    let result = ready!(Future::poll(fut, cx));
102                    result.map_err(|e| Box::new(e) as tower::BoxError).into()
103                }
104                #[cfg(feature = "tls-openssl")]
105                ChannelFuture::Openssl(fut) => {
106                    let fut = Pin::new_unchecked(fut);
107                    Future::poll(fut, cx)
108                }
109                ChannelFuture::Custom(fut) => {
110                    let fut = Pin::new_unchecked(fut);
111                    Future::poll(fut, cx)
112                }
113            }
114        }
115    }
116}
117
118impl ChannelFuture {
119    #[inline]
120    fn from_tonic(value: <tonic::transport::Channel as Service<TonicRequest>>::Future) -> Self {
121        Self::Tonic(value)
122    }
123
124    #[cfg(feature = "tls-openssl")]
125    #[inline]
126    fn from_openssl(
127        value: <crate::openssl_tls::OpenSslChannel as Service<TonicRequest>>::Future,
128    ) -> Self {
129        Self::Openssl(value)
130    }
131
132    #[inline]
133    fn from_custom(value: <CustomChannel as Service<TonicRequest>>::Future) -> Self {
134        Self::Custom(value)
135    }
136}
137
138impl Service<TonicRequest> for Channel {
139    type Response = TonicResponse;
140    type Error = tower::BoxError;
141    type Future = ChannelFuture;
142
143    #[inline]
144    fn poll_ready(
145        &mut self,
146        cx: &mut std::task::Context<'_>,
147    ) -> std::task::Poll<Result<(), Self::Error>> {
148        match self {
149            Channel::Tonic(channel) => {
150                let result = ready!(channel.poll_ready(cx));
151                result.map_err(|e| Box::new(e) as tower::BoxError).into()
152            }
153            #[cfg(feature = "tls-openssl")]
154            Channel::Openssl(openssl) => openssl.poll_ready(cx),
155            Channel::Custom(custom) => custom.poll_ready(cx),
156        }
157    }
158
159    #[inline]
160    fn call(&mut self, req: TonicRequest) -> Self::Future {
161        match self {
162            Channel::Tonic(channel) => ChannelFuture::from_tonic(channel.call(req)),
163            #[cfg(feature = "tls-openssl")]
164            Channel::Openssl(openssl) => ChannelFuture::from_openssl(openssl.call(req)),
165            Channel::Custom(custom) => ChannelFuture::from_custom(custom.call(req)),
166        }
167    }
168}