// google_cloud_grpc/conn.rs

use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicI64, AtomicUsize, Ordering};
use std::sync::Arc;

use google_cloud_auth::token_source::TokenSource;
use google_cloud_auth::{create_token_source, Config};
use http::header::AUTHORIZATION;
use http::{HeaderValue, Request};
use tonic::body::BoxBody;
use tonic::transport::{Certificate, Channel as TonicChannel, ClientTlsConfig, Endpoint};
use tonic::{Code, Status};
use tower::filter::{AsyncFilter, AsyncFilterLayer, AsyncPredicate};
use tower::util::Either;
use tower::{BoxError, ServiceBuilder};

17const TLS_CERTS: &[u8] = include_bytes!("roots.pem");
18
19pub type Channel = Either<AsyncFilter<TonicChannel, AsyncAuthInterceptor>, TonicChannel>;
20
21#[derive(Clone)]
22pub struct AsyncAuthInterceptor {
23    token_source: Arc<dyn TokenSource>,
24}
25
26impl AsyncAuthInterceptor {
27    fn new(token_source: Arc<dyn TokenSource>) -> Self {
28        Self { token_source }
29    }
30}
31
32impl AsyncPredicate<Request<BoxBody>> for AsyncAuthInterceptor {
33    type Future = Pin<Box<dyn Future<Output = Result<Self::Request, BoxError>> + Send>>;
34    type Request = Request<BoxBody>;
35
36    fn check(&mut self, request: Request<BoxBody>) -> Self::Future {
37        let ts = self.token_source.clone();
38        Box::pin(async move {
39            let token = ts
40                .token()
41                .await
42                .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {:?}", e)))?;
43            let token_header = HeaderValue::from_str(token.value().as_ref())
44                .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {:?}", e)))?;
45            let (mut parts, body) = request.into_parts();
46            parts.headers.insert(AUTHORIZATION, token_header);
47            Ok(Request::from_parts(parts, body))
48        })
49    }
50}
51
52#[derive(thiserror::Error, Debug)]
53pub enum Error {
54    #[error(transparent)]
55    AuthInitialize(#[from] google_cloud_auth::error::Error),
56
57    #[error(transparent)]
58    TonicTransport(#[from] tonic::transport::Error),
59
60    #[error("invalid spanner host {0}")]
61    InvalidSpannerHOST(String),
62}
63
64pub struct ConnectionManager {
65    index: AtomicI64,
66    conns: Vec<Channel>,
67}
68
69impl ConnectionManager {
70    pub async fn new(
71        pool_size: usize,
72        domain_name: &'static str,
73        audience: &'static str,
74        scopes: Option<&'static [&'static str]>,
75        emulator_host: Option<String>,
76    ) -> Result<Self, Error> {
77        let conns = match emulator_host {
78            None => Self::create_connections(pool_size, domain_name, audience, scopes).await?,
79            Some(host) => Self::create_emulator_connections(&host).await?,
80        };
81        Ok(Self {
82            index: AtomicI64::new(0),
83            conns,
84        })
85    }
86
87    async fn create_connections(
88        pool_size: usize,
89        domain_name: &'static str,
90        audience: &'static str,
91        scopes: Option<&'static [&'static str]>,
92    ) -> Result<Vec<Channel>, Error> {
93        let tls_config = ClientTlsConfig::new()
94            .ca_certificate(Certificate::from_pem(TLS_CERTS))
95            .domain_name(domain_name);
96        let mut conns = Vec::with_capacity(pool_size);
97
98        let ts = create_token_source(Config {
99            audience: Some(audience),
100            scopes,
101        })
102        .await
103        .map(|e| Arc::from(e))?;
104
105        for _i_ in 0..pool_size {
106            let endpoint = TonicChannel::from_static(audience).tls_config(tls_config.clone())?;
107            let con = Self::connect(endpoint).await?;
108            // use GCP token per call
109            let auth_layer = Some(AsyncFilterLayer::new(AsyncAuthInterceptor::new(
110                Arc::clone(&ts),
111            )));
112            let auth_con = ServiceBuilder::new().option_layer(auth_layer).service(con);
113            conns.push(auth_con);
114        }
115        Ok(conns)
116    }
117
118    async fn create_emulator_connections(host: &str) -> Result<Vec<Channel>, Error> {
119        let mut conns = Vec::with_capacity(1);
120        let endpoint = TonicChannel::from_shared(format!("http://{}", host).into_bytes())
121            .map_err(|_| Error::InvalidSpannerHOST(host.to_string()))?;
122        let con = Self::connect(endpoint).await?;
123        conns.push(
124            ServiceBuilder::new()
125                .option_layer::<AsyncFilterLayer<AsyncAuthInterceptor>>(None)
126                .service(con),
127        );
128        Ok(conns)
129    }
130
131    async fn connect(endpoint: Endpoint) -> Result<TonicChannel, tonic::transport::Error> {
132        let channel = endpoint.connect().await?;
133        Ok(channel)
134    }
135
136    pub fn num(&self) -> usize {
137        self.conns.len()
138    }
139
140    pub fn conn(&self) -> Channel {
141        let current = self.index.fetch_add(1, Ordering::SeqCst) as usize;
142        //clone() reuses http/2 connection
143        self.conns[current % self.conns.len()].clone()
144    }
145}