google_cloud_grpc/
conn.rs1use std::sync::atomic::{AtomicI64, Ordering};
2
3use google_cloud_auth::token_source::TokenSource;
4use google_cloud_auth::{create_token_source, Config};
5use http::header::AUTHORIZATION;
6use http::{HeaderValue, Request};
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use tonic::body::BoxBody;
11use tonic::transport::{Certificate, Channel as TonicChannel, ClientTlsConfig, Endpoint};
12use tonic::{Code, Status};
13use tower::filter::{AsyncFilter, AsyncFilterLayer, AsyncPredicate};
14use tower::util::Either;
15use tower::{BoxError, ServiceBuilder};
16
17const TLS_CERTS: &[u8] = include_bytes!("roots.pem");
18
19pub type Channel = Either<AsyncFilter<TonicChannel, AsyncAuthInterceptor>, TonicChannel>;
20
21#[derive(Clone)]
22pub struct AsyncAuthInterceptor {
23 token_source: Arc<dyn TokenSource>,
24}
25
26impl AsyncAuthInterceptor {
27 fn new(token_source: Arc<dyn TokenSource>) -> Self {
28 Self { token_source }
29 }
30}
31
32impl AsyncPredicate<Request<BoxBody>> for AsyncAuthInterceptor {
33 type Future = Pin<Box<dyn Future<Output = Result<Self::Request, BoxError>> + Send>>;
34 type Request = Request<BoxBody>;
35
36 fn check(&mut self, request: Request<BoxBody>) -> Self::Future {
37 let ts = self.token_source.clone();
38 Box::pin(async move {
39 let token = ts
40 .token()
41 .await
42 .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {:?}", e)))?;
43 let token_header = HeaderValue::from_str(token.value().as_ref())
44 .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {:?}", e)))?;
45 let (mut parts, body) = request.into_parts();
46 parts.headers.insert(AUTHORIZATION, token_header);
47 Ok(Request::from_parts(parts, body))
48 })
49 }
50}
51
52#[derive(thiserror::Error, Debug)]
53pub enum Error {
54 #[error(transparent)]
55 AuthInitialize(#[from] google_cloud_auth::error::Error),
56
57 #[error(transparent)]
58 TonicTransport(#[from] tonic::transport::Error),
59
60 #[error("invalid spanner host {0}")]
61 InvalidSpannerHOST(String),
62}
63
64pub struct ConnectionManager {
65 index: AtomicI64,
66 conns: Vec<Channel>,
67}
68
69impl ConnectionManager {
70 pub async fn new(
71 pool_size: usize,
72 domain_name: &'static str,
73 audience: &'static str,
74 scopes: Option<&'static [&'static str]>,
75 emulator_host: Option<String>,
76 ) -> Result<Self, Error> {
77 let conns = match emulator_host {
78 None => Self::create_connections(pool_size, domain_name, audience, scopes).await?,
79 Some(host) => Self::create_emulator_connections(&host).await?,
80 };
81 Ok(Self {
82 index: AtomicI64::new(0),
83 conns,
84 })
85 }
86
87 async fn create_connections(
88 pool_size: usize,
89 domain_name: &'static str,
90 audience: &'static str,
91 scopes: Option<&'static [&'static str]>,
92 ) -> Result<Vec<Channel>, Error> {
93 let tls_config = ClientTlsConfig::new()
94 .ca_certificate(Certificate::from_pem(TLS_CERTS))
95 .domain_name(domain_name);
96 let mut conns = Vec::with_capacity(pool_size);
97
98 let ts = create_token_source(Config {
99 audience: Some(audience),
100 scopes,
101 })
102 .await
103 .map(|e| Arc::from(e))?;
104
105 for _i_ in 0..pool_size {
106 let endpoint = TonicChannel::from_static(audience).tls_config(tls_config.clone())?;
107 let con = Self::connect(endpoint).await?;
108 let auth_layer = Some(AsyncFilterLayer::new(AsyncAuthInterceptor::new(
110 Arc::clone(&ts),
111 )));
112 let auth_con = ServiceBuilder::new().option_layer(auth_layer).service(con);
113 conns.push(auth_con);
114 }
115 Ok(conns)
116 }
117
118 async fn create_emulator_connections(host: &str) -> Result<Vec<Channel>, Error> {
119 let mut conns = Vec::with_capacity(1);
120 let endpoint = TonicChannel::from_shared(format!("http://{}", host).into_bytes())
121 .map_err(|_| Error::InvalidSpannerHOST(host.to_string()))?;
122 let con = Self::connect(endpoint).await?;
123 conns.push(
124 ServiceBuilder::new()
125 .option_layer::<AsyncFilterLayer<AsyncAuthInterceptor>>(None)
126 .service(con),
127 );
128 Ok(conns)
129 }
130
131 async fn connect(endpoint: Endpoint) -> Result<TonicChannel, tonic::transport::Error> {
132 let channel = endpoint.connect().await?;
133 Ok(channel)
134 }
135
136 pub fn num(&self) -> usize {
137 self.conns.len()
138 }
139
140 pub fn conn(&self) -> Channel {
141 let current = self.index.fetch_add(1, Ordering::SeqCst) as usize;
142 self.conns[current % self.conns.len()].clone()
144 }
145}