gcloud_gax/
conn.rs

1use std::fmt::Debug;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use http::header::AUTHORIZATION;
9use http::{HeaderValue, Request};
10use tonic::body::Body;
11use tonic::transport::{Channel as TonicChannel, ClientTlsConfig, Endpoint};
12use tonic::{Code, Status};
13use tower::filter::{AsyncFilter, AsyncPredicate};
14use tower::{BoxError, ServiceBuilder};
15
16use token_source::{TokenSource, TokenSourceProvider};
17
18pub type Channel = AsyncFilter<TonicChannel, AsyncAuthInterceptor>;
19
20#[derive(Clone, Debug)]
21pub struct AsyncAuthInterceptor {
22    token_source: Option<Arc<dyn TokenSource>>,
23}
24
25impl AsyncAuthInterceptor {
26    fn new(token_source: Arc<dyn TokenSource>) -> Self {
27        Self {
28            token_source: Some(token_source),
29        }
30    }
31    fn empty() -> Self {
32        Self { token_source: None }
33    }
34}
35
36impl AsyncPredicate<Request<Body>> for AsyncAuthInterceptor {
37    type Future = Pin<Box<dyn Future<Output = Result<Self::Request, BoxError>> + Send>>;
38    type Request = Request<Body>;
39
40    fn check(&mut self, request: Request<Body>) -> Self::Future {
41        let ts = match &self.token_source {
42            Some(ts) => ts.clone(),
43            None => return Box::pin(async move { Ok(request) }),
44        };
45        Box::pin(async move {
46            let token = ts
47                .token()
48                .await
49                .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {e:?}")))?;
50            let token_header = HeaderValue::from_str(token.as_str())
51                .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {e:?}")))?;
52            let (mut parts, body) = request.into_parts();
53            parts.headers.insert(AUTHORIZATION, token_header);
54            Ok(Request::from_parts(parts, body))
55        })
56    }
57}
58
59#[derive(thiserror::Error, Debug)]
60pub enum Error {
61    #[error(transparent)]
62    Auth(#[from] Box<dyn std::error::Error + Send + Sync>),
63
64    #[error("tonic error : {0}")]
65    TonicTransport(#[from] tonic::transport::Error),
66
67    #[error("invalid emulator host: {0}")]
68    InvalidEmulatorHOST(String),
69}
70
71#[derive(Debug)]
72pub enum Environment {
73    Emulator(String),
74    GoogleCloud(Box<dyn TokenSourceProvider>),
75}
76
/// Round-robin selector over a fixed list of values, usable through `&self`.
#[derive(Debug)]
struct AtomicRing<T>
where
    T: Clone + Debug,
{
    /// Monotonically increasing ticket counter; wraps around on overflow.
    index: AtomicUsize,
    /// The values that are cycled through.
    values: Vec<T>,
}

impl<T: Clone + Debug> AtomicRing<T> {
    /// Returns a clone of the next value in rotation and advances the counter.
    ///
    /// # Panics
    ///
    /// Panics if `values` is empty (remainder by zero).
    fn next(&self) -> T {
        let ticket = self.index.fetch_add(1, Ordering::SeqCst);
        let slot = ticket % self.values.len();
        // Hand out a clone; for channels, clone() reuses the http/2 connection.
        self.values[slot].clone()
    }
}
96
97#[derive(Debug, Clone, Default)]
98pub struct ConnectionOptions {
99    pub timeout: Option<Duration>,
100    pub connect_timeout: Option<Duration>,
101}
102
103impl ConnectionOptions {
104    fn apply(&self, mut endpoint: Endpoint) -> Endpoint {
105        endpoint = match self.timeout {
106            Some(t) => endpoint.timeout(t),
107            None => endpoint,
108        };
109        endpoint = match self.connect_timeout {
110            Some(t) => endpoint.connect_timeout(t),
111            None => endpoint,
112        };
113        endpoint
114    }
115}
116
117#[derive(Debug)]
118pub struct ConnectionManager {
119    inner: AtomicRing<Channel>,
120}
121
122impl<'a> ConnectionManager {
123    pub async fn new(
124        pool_size: usize,
125        domain_name: impl Into<String>,
126        audience: &str,
127        environment: &Environment,
128        conn_options: &'a ConnectionOptions,
129    ) -> Result<Self, Error> {
130        let conns = match environment {
131            Environment::GoogleCloud(ts_provider) => {
132                Self::create_connections(pool_size, domain_name, audience, ts_provider.as_ref(), conn_options).await?
133            }
134            Environment::Emulator(host) => Self::create_emulator_connections(host, conn_options).await?,
135        };
136        Ok(Self {
137            inner: AtomicRing {
138                index: AtomicUsize::new(0),
139                values: conns,
140            },
141        })
142    }
143
144    async fn create_connections(
145        pool_size: usize,
146        domain_name: impl Into<String>,
147        audience: &str,
148        ts_provider: &dyn TokenSourceProvider,
149        conn_options: &'a ConnectionOptions,
150    ) -> Result<Vec<Channel>, Error> {
151        let pool_size = Self::get_pool_size(pool_size);
152
153        let tls_config = ClientTlsConfig::new().with_webpki_roots().domain_name(domain_name);
154        let mut conns = Vec::with_capacity(pool_size);
155
156        let ts = ts_provider.token_source();
157
158        for _i_ in 0..pool_size {
159            let endpoint = TonicChannel::from_shared(audience.to_string().into_bytes())
160                .map_err(|e| Error::InvalidEmulatorHOST(e.to_string()))?
161                .tls_config(tls_config.clone())?;
162            let endpoint = conn_options.apply(endpoint);
163
164            let con = Self::connect(endpoint).await?;
165            // use GCP token per call
166            let auth_filter = AsyncAuthInterceptor::new(Arc::clone(&ts));
167            let auth_con = ServiceBuilder::new().filter_async(auth_filter).service(con);
168            conns.push(auth_con);
169        }
170        Ok(conns)
171    }
172
173    async fn create_emulator_connections(
174        host: &str,
175        conn_options: &'a ConnectionOptions,
176    ) -> Result<Vec<Channel>, Error> {
177        let mut conns = Vec::with_capacity(1);
178        let endpoint = TonicChannel::from_shared(format!("http://{host}").into_bytes())
179            .map_err(|_| Error::InvalidEmulatorHOST(host.to_string()))?;
180        let endpoint = conn_options.apply(endpoint);
181
182        let con = Self::connect(endpoint).await?;
183        let auth_filter = AsyncAuthInterceptor::empty();
184        let auth_con = ServiceBuilder::new().filter_async(auth_filter).service(con);
185        conns.push(auth_con);
186        Ok(conns)
187    }
188
189    async fn connect(endpoint: Endpoint) -> Result<TonicChannel, tonic::transport::Error> {
190        let channel = endpoint.connect().await?;
191        Ok(channel)
192    }
193
194    fn get_pool_size(pool_size: usize) -> usize {
195        pool_size.max(1)
196    }
197
198    pub fn num(&self) -> usize {
199        self.inner.values.len()
200    }
201
202    pub fn conn(&self) -> Channel {
203        self.inner.next()
204    }
205}
206
#[cfg(test)]
mod test {
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::conn::{AtomicRing, ConnectionManager};

    /// The ring must keep cycling through its values while the counter wraps
    /// from `usize::MAX` back to zero.
    #[test]
    fn test_atomic_ring() {
        let start = usize::MAX - 1;
        let ring = AtomicRing::<&str> {
            index: AtomicUsize::new(start),
            values: vec!["a", "b", "c", "d"],
        };
        assert_eq!(start, ring.index.load(Ordering::SeqCst));

        // The first four draws are all distinct; the counter overflows on the way.
        let mut seen = HashSet::new();
        for step in 1..=4usize {
            assert!(seen.insert(ring.next()));
            assert_eq!(start.wrapping_add(step), ring.index.load(Ordering::SeqCst));
        }

        // The fifth draw repeats a value that was already handed out.
        assert!(!seen.insert(ring.next()));
        assert_eq!(3, ring.index.load(Ordering::SeqCst));
    }

    /// A requested pool size of zero is raised to one; other sizes are kept.
    #[test]
    fn test_get_pool_size() {
        for (requested, expected) in [(0, 1), (1, 1), (2, 2)] {
            assert_eq!(expected, ConnectionManager::get_pool_size(requested));
        }
    }
}