gcloud_gax/
conn.rs

1use std::fmt::Debug;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use http::header::AUTHORIZATION;
9use http::{HeaderValue, Request};
10use tonic::body::Body;
11use tonic::transport::{Channel as TonicChannel, ClientTlsConfig, Endpoint};
12use tonic::{Code, Status};
13use tower::filter::{AsyncFilter, AsyncFilterLayer, AsyncPredicate};
14use tower::util::Either;
15use tower::{BoxError, ServiceBuilder};
16
17use token_source::{TokenSource, TokenSourceProvider};
18
19pub type Channel = Either<AsyncFilter<TonicChannel, AsyncAuthInterceptor>, TonicChannel>;
20
21#[derive(Clone, Debug)]
22pub struct AsyncAuthInterceptor {
23    token_source: Arc<dyn TokenSource>,
24}
25
26impl AsyncAuthInterceptor {
27    fn new(token_source: Arc<dyn TokenSource>) -> Self {
28        Self { token_source }
29    }
30}
31
32impl AsyncPredicate<Request<Body>> for AsyncAuthInterceptor {
33    type Future = Pin<Box<dyn Future<Output = Result<Self::Request, BoxError>> + Send>>;
34    type Request = Request<Body>;
35
36    fn check(&mut self, request: Request<Body>) -> Self::Future {
37        let ts = self.token_source.clone();
38        Box::pin(async move {
39            let token = ts
40                .token()
41                .await
42                .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {e:?}")))?;
43            let token_header = HeaderValue::from_str(token.as_str())
44                .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {e:?}")))?;
45            let (mut parts, body) = request.into_parts();
46            parts.headers.insert(AUTHORIZATION, token_header);
47            Ok(Request::from_parts(parts, body))
48        })
49    }
50}
51
52#[derive(thiserror::Error, Debug)]
53pub enum Error {
54    #[error(transparent)]
55    Auth(#[from] Box<dyn std::error::Error + Send + Sync>),
56
57    #[error("tonic error : {0}")]
58    TonicTransport(#[from] tonic::transport::Error),
59
60    #[error("invalid emulator host: {0}")]
61    InvalidEmulatorHOST(String),
62}
63
64#[derive(Debug)]
65pub enum Environment {
66    Emulator(String),
67    GoogleCloud(Box<dyn TokenSourceProvider>),
68}
69
/// Round-robin selector over a fixed list of values, usable through `&self`.
#[derive(Debug)]
struct AtomicRing<T: Clone + Debug> {
    /// Incremented by every `next` call (wrapping at `usize::MAX`); the chosen
    /// slot is `index % values.len()`.
    index: AtomicUsize,
    /// Values handed out in order.
    values: Vec<T>,
}

impl<T: Clone + Debug> AtomicRing<T> {
    /// Returns a clone of the next value in round-robin order.
    ///
    /// # Panics
    /// Panics if `values` is empty (remainder by zero).
    fn next(&self) -> T {
        let len = self.values.len();
        let slot = self.index.fetch_add(1, Ordering::SeqCst) % len;
        // For tonic channels, clone() reuses the existing HTTP/2 connection.
        self.values[slot].clone()
    }
}
89
/// Optional settings applied to every endpoint the connection manager builds.
#[derive(Clone, Default, Debug)]
pub struct ConnectionOptions {
    /// When set, forwarded to tonic's `Endpoint::timeout`.
    pub timeout: Option<Duration>,
    /// When set, forwarded to tonic's `Endpoint::connect_timeout`.
    pub connect_timeout: Option<Duration>,
}
95
96impl ConnectionOptions {
97    fn apply(&self, mut endpoint: Endpoint) -> Endpoint {
98        endpoint = match self.timeout {
99            Some(t) => endpoint.timeout(t),
100            None => endpoint,
101        };
102        endpoint = match self.connect_timeout {
103            Some(t) => endpoint.connect_timeout(t),
104            None => endpoint,
105        };
106        endpoint
107    }
108}
109
110#[derive(Debug)]
111pub struct ConnectionManager {
112    inner: AtomicRing<Channel>,
113}
114
115impl<'a> ConnectionManager {
116    pub async fn new(
117        pool_size: usize,
118        domain_name: impl Into<String>,
119        audience: &'static str,
120        environment: &Environment,
121        conn_options: &'a ConnectionOptions,
122    ) -> Result<Self, Error> {
123        let conns = match environment {
124            Environment::GoogleCloud(ts_provider) => {
125                Self::create_connections(pool_size, domain_name, audience, ts_provider.as_ref(), conn_options).await?
126            }
127            Environment::Emulator(host) => Self::create_emulator_connections(host, conn_options).await?,
128        };
129        Ok(Self {
130            inner: AtomicRing {
131                index: AtomicUsize::new(0),
132                values: conns,
133            },
134        })
135    }
136
137    async fn create_connections(
138        pool_size: usize,
139        domain_name: impl Into<String>,
140        audience: &'static str,
141        ts_provider: &dyn TokenSourceProvider,
142        conn_options: &'a ConnectionOptions,
143    ) -> Result<Vec<Channel>, Error> {
144        let tls_config = ClientTlsConfig::new().with_webpki_roots().domain_name(domain_name);
145        let mut conns = Vec::with_capacity(pool_size);
146
147        let ts = ts_provider.token_source();
148
149        for _i_ in 0..pool_size {
150            let endpoint = TonicChannel::from_static(audience).tls_config(tls_config.clone())?;
151            let endpoint = conn_options.apply(endpoint);
152
153            let con = Self::connect(endpoint).await?;
154            // use GCP token per call
155            let auth_layer = Some(AsyncFilterLayer::new(AsyncAuthInterceptor::new(Arc::clone(&ts))));
156            let auth_con = ServiceBuilder::new().option_layer(auth_layer).service(con);
157            conns.push(auth_con);
158        }
159        Ok(conns)
160    }
161
162    async fn create_emulator_connections(
163        host: &str,
164        conn_options: &'a ConnectionOptions,
165    ) -> Result<Vec<Channel>, Error> {
166        let mut conns = Vec::with_capacity(1);
167        let endpoint = TonicChannel::from_shared(format!("http://{host}").into_bytes())
168            .map_err(|_| Error::InvalidEmulatorHOST(host.to_string()))?;
169        let endpoint = conn_options.apply(endpoint);
170
171        let con = Self::connect(endpoint).await?;
172        conns.push(
173            ServiceBuilder::new()
174                .option_layer::<AsyncFilterLayer<AsyncAuthInterceptor>>(None)
175                .service(con),
176        );
177        Ok(conns)
178    }
179
180    async fn connect(endpoint: Endpoint) -> Result<TonicChannel, tonic::transport::Error> {
181        let channel = endpoint.connect().await?;
182        Ok(channel)
183    }
184
185    pub fn num(&self) -> usize {
186        self.inner.values.len()
187    }
188
189    pub fn conn(&self) -> Channel {
190        self.inner.next()
191    }
192}
193
#[cfg(test)]
mod test {
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::conn::AtomicRing;

    /// Drives the ring across the `usize::MAX` boundary. It checks that the counter
    /// wraps to zero and that four distinct values come out before the first repeat.
    #[test]
    fn test_atomic_ring() {
        let ring = AtomicRing::<&str> {
            index: AtomicUsize::new(usize::MAX - 1),
            values: vec!["a", "b", "c", "d"],
        };
        assert_eq!(usize::MAX - 1, ring.index.load(Ordering::SeqCst));

        let mut seen = HashSet::new();
        // (value expected to be new?, counter expected after the call)
        let steps = [
            (true, usize::MAX),
            (true, 0),
            (true, 1),
            (true, 2),
            (false, 3),
        ];
        for (expect_new, expected_index) in steps {
            assert_eq!(expect_new, seen.insert(ring.next()));
            assert_eq!(expected_index, ring.index.load(Ordering::SeqCst));
        }
    }
}