1use std::fmt::Debug;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use http::header::AUTHORIZATION;
9use http::{HeaderValue, Request};
10use tonic::body::Body;
11use tonic::transport::{Channel as TonicChannel, ClientTlsConfig, Endpoint};
12use tonic::{Code, Status};
13use tower::filter::{AsyncFilter, AsyncFilterLayer, AsyncPredicate};
14use tower::util::Either;
15use tower::{BoxError, ServiceBuilder};
16
17use token_source::{TokenSource, TokenSourceProvider};
18
/// A pooled gRPC channel handed out by [`ConnectionManager`].
///
/// It is either a tonic channel wrapped in the [`AsyncAuthInterceptor`] filter, as built for
/// [`Environment::GoogleCloud`], or a bare tonic channel, as built for [`Environment::Emulator`].
pub type Channel = Either<AsyncFilter<TonicChannel, AsyncAuthInterceptor>, TonicChannel>;
20
/// Tower [`AsyncPredicate`] that attaches an `Authorization` header to every outgoing request.
///
/// The header value is fetched per request from a shared [`TokenSource`].
#[derive(Clone, Debug)]
pub struct AsyncAuthInterceptor {
    // Shared by every channel in a pool (the Arc is cloned per channel).
    token_source: Arc<dyn TokenSource>,
}
25
26impl AsyncAuthInterceptor {
27 fn new(token_source: Arc<dyn TokenSource>) -> Self {
28 Self { token_source }
29 }
30}
31
32impl AsyncPredicate<Request<Body>> for AsyncAuthInterceptor {
33 type Future = Pin<Box<dyn Future<Output = Result<Self::Request, BoxError>> + Send>>;
34 type Request = Request<Body>;
35
36 fn check(&mut self, request: Request<Body>) -> Self::Future {
37 let ts = self.token_source.clone();
38 Box::pin(async move {
39 let token = ts
40 .token()
41 .await
42 .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {e:?}")))?;
43 let token_header = HeaderValue::from_str(token.as_str())
44 .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {e:?}")))?;
45 let (mut parts, body) = request.into_parts();
46 parts.headers.insert(AUTHORIZATION, token_header);
47 Ok(Request::from_parts(parts, body))
48 })
49 }
50}
51
/// Errors produced while setting up pooled connections in this module.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// Any boxed error, displayed transparently (its own message is forwarded).
    /// NOTE(review): nothing in this file constructs it directly; presumably token/auth failures
    /// from callers given the name — confirm.
    #[error(transparent)]
    Auth(#[from] Box<dyn std::error::Error + Send + Sync>),

    /// Failure configuring TLS on, or connecting, a tonic endpoint.
    #[error("tonic error : {0}")]
    TonicTransport(#[from] tonic::transport::Error),

    /// The emulator host could not be turned into a valid `http://{host}` endpoint URI.
    #[error("invalid emulator host: {0}")]
    InvalidEmulatorHOST(String),
}
63
/// Selects where [`ConnectionManager`] connects and whether requests are authenticated.
#[derive(Debug)]
pub enum Environment {
    /// An emulator at the given host string.
    ///
    /// The string is prefixed with `http://`, and the connection uses plaintext with no auth layer.
    Emulator(String),
    /// TLS connections whose requests carry tokens from the given provider's token source.
    GoogleCloud(Box<dyn TokenSourceProvider>),
}
69
/// A fixed set of values handed out in round-robin order.
///
/// `index` is the cursor advanced on every read; `values` is the ring itself. Trait bounds live
/// on the `impl` blocks that need them rather than on the type, so the type itself stays
/// unconstrained.
#[derive(Debug)]
struct AtomicRing<T> {
    // Advanced atomically on every read; wraps around on overflow.
    index: AtomicUsize,
    values: Vec<T>,
}
78
79impl<T> AtomicRing<T>
80where
81 T: Clone + Debug,
82{
83 fn next(&self) -> T {
84 let current = self.index.fetch_add(1, Ordering::SeqCst);
85 self.values[current % self.values.len()].clone()
87 }
88}
89
/// Optional timeouts applied to each tonic `Endpoint` before it is connected.
#[derive(Debug, Clone, Default)]
pub struct ConnectionOptions {
    /// When `Some`, forwarded to tonic's `Endpoint::timeout`; when `None`, the endpoint is left unchanged.
    pub timeout: Option<Duration>,
    /// When `Some`, forwarded to tonic's `Endpoint::connect_timeout`; when `None`, the endpoint is left unchanged.
    pub connect_timeout: Option<Duration>,
}
95
96impl ConnectionOptions {
97 fn apply(&self, mut endpoint: Endpoint) -> Endpoint {
98 endpoint = match self.timeout {
99 Some(t) => endpoint.timeout(t),
100 None => endpoint,
101 };
102 endpoint = match self.connect_timeout {
103 Some(t) => endpoint.connect_timeout(t),
104 None => endpoint,
105 };
106 endpoint
107 }
108}
109
/// A fixed pool of gRPC channels, handed out round-robin by [`ConnectionManager::conn`].
#[derive(Debug)]
pub struct ConnectionManager {
    // Ring of pooled channels; its cursor picks the next channel on each `conn()` call.
    inner: AtomicRing<Channel>,
}
114
115impl<'a> ConnectionManager {
116 pub async fn new(
117 pool_size: usize,
118 domain_name: impl Into<String>,
119 audience: &'static str,
120 environment: &Environment,
121 conn_options: &'a ConnectionOptions,
122 ) -> Result<Self, Error> {
123 let conns = match environment {
124 Environment::GoogleCloud(ts_provider) => {
125 Self::create_connections(pool_size, domain_name, audience, ts_provider.as_ref(), conn_options).await?
126 }
127 Environment::Emulator(host) => Self::create_emulator_connections(host, conn_options).await?,
128 };
129 Ok(Self {
130 inner: AtomicRing {
131 index: AtomicUsize::new(0),
132 values: conns,
133 },
134 })
135 }
136
137 async fn create_connections(
138 pool_size: usize,
139 domain_name: impl Into<String>,
140 audience: &'static str,
141 ts_provider: &dyn TokenSourceProvider,
142 conn_options: &'a ConnectionOptions,
143 ) -> Result<Vec<Channel>, Error> {
144 let tls_config = ClientTlsConfig::new().with_webpki_roots().domain_name(domain_name);
145 let mut conns = Vec::with_capacity(pool_size);
146
147 let ts = ts_provider.token_source();
148
149 for _i_ in 0..pool_size {
150 let endpoint = TonicChannel::from_static(audience).tls_config(tls_config.clone())?;
151 let endpoint = conn_options.apply(endpoint);
152
153 let con = Self::connect(endpoint).await?;
154 let auth_layer = Some(AsyncFilterLayer::new(AsyncAuthInterceptor::new(Arc::clone(&ts))));
156 let auth_con = ServiceBuilder::new().option_layer(auth_layer).service(con);
157 conns.push(auth_con);
158 }
159 Ok(conns)
160 }
161
162 async fn create_emulator_connections(
163 host: &str,
164 conn_options: &'a ConnectionOptions,
165 ) -> Result<Vec<Channel>, Error> {
166 let mut conns = Vec::with_capacity(1);
167 let endpoint = TonicChannel::from_shared(format!("http://{host}").into_bytes())
168 .map_err(|_| Error::InvalidEmulatorHOST(host.to_string()))?;
169 let endpoint = conn_options.apply(endpoint);
170
171 let con = Self::connect(endpoint).await?;
172 conns.push(
173 ServiceBuilder::new()
174 .option_layer::<AsyncFilterLayer<AsyncAuthInterceptor>>(None)
175 .service(con),
176 );
177 Ok(conns)
178 }
179
180 async fn connect(endpoint: Endpoint) -> Result<TonicChannel, tonic::transport::Error> {
181 let channel = endpoint.connect().await?;
182 Ok(channel)
183 }
184
185 pub fn num(&self) -> usize {
186 self.inner.values.len()
187 }
188
189 pub fn conn(&self) -> Channel {
190 self.inner.next()
191 }
192}
193
#[cfg(test)]
mod test {
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::conn::AtomicRing;

    /// Checks that the ring keeps rotating across the cursor's wrap-around.
    ///
    /// The cursor starts one below `usize::MAX`. The first four `next()` calls must yield four
    /// distinct values while the index goes MAX-1 -> MAX -> 0 -> 1 -> 2. The fifth call must
    /// repeat an already-seen value.
    #[test]
    fn test_atomic_ring() {
        let cm = AtomicRing::<&str> {
            index: AtomicUsize::new(usize::MAX - 1),
            values: vec!["a", "b", "c", "d"],
        };
        let mut values = HashSet::new();
        assert_eq!(usize::MAX - 1, cm.index.load(Ordering::SeqCst));
        assert!(values.insert(cm.next()));
        assert_eq!(usize::MAX, cm.index.load(Ordering::SeqCst));
        assert!(values.insert(cm.next()));
        // fetch_add wrapped the cursor from usize::MAX back to 0.
        assert_eq!(0, cm.index.load(Ordering::SeqCst));
        assert!(values.insert(cm.next()));
        assert_eq!(1, cm.index.load(Ordering::SeqCst));
        assert!(values.insert(cm.next()));
        assert_eq!(2, cm.index.load(Ordering::SeqCst));
        // Four values in the ring, so the fifth pick must be a repeat.
        assert!(!values.insert(cm.next()));
        assert_eq!(3, cm.index.load(Ordering::SeqCst));
    }
}