1use std::fmt::Debug;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::sync::Arc;
6use std::time::Duration;
7
8use http::header::AUTHORIZATION;
9use http::{HeaderValue, Request};
10use tonic::body::Body;
11use tonic::transport::{Channel as TonicChannel, ClientTlsConfig, Endpoint};
12use tonic::{Code, Status};
13use tower::filter::{AsyncFilter, AsyncPredicate};
14use tower::{BoxError, ServiceBuilder};
15
16use token_source::{TokenSource, TokenSourceProvider};
17
18pub type Channel = AsyncFilter<TonicChannel, AsyncAuthInterceptor>;
19
20#[derive(Clone, Debug)]
21pub struct AsyncAuthInterceptor {
22 token_source: Option<Arc<dyn TokenSource>>,
23}
24
25impl AsyncAuthInterceptor {
26 fn new(token_source: Arc<dyn TokenSource>) -> Self {
27 Self {
28 token_source: Some(token_source),
29 }
30 }
31 fn empty() -> Self {
32 Self { token_source: None }
33 }
34}
35
36impl AsyncPredicate<Request<Body>> for AsyncAuthInterceptor {
37 type Future = Pin<Box<dyn Future<Output = Result<Self::Request, BoxError>> + Send>>;
38 type Request = Request<Body>;
39
40 fn check(&mut self, request: Request<Body>) -> Self::Future {
41 let ts = match &self.token_source {
42 Some(ts) => ts.clone(),
43 None => return Box::pin(async move { Ok(request) }),
44 };
45 Box::pin(async move {
46 let token = ts
47 .token()
48 .await
49 .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {e:?}")))?;
50 let token_header = HeaderValue::from_str(token.as_str())
51 .map_err(|e| Status::new(Code::Unauthenticated, format!("token error: {e:?}")))?;
52 let (mut parts, body) = request.into_parts();
53 parts.headers.insert(AUTHORIZATION, token_header);
54 Ok(Request::from_parts(parts, body))
55 })
56 }
57}
58
59#[derive(thiserror::Error, Debug)]
60pub enum Error {
61 #[error(transparent)]
62 Auth(#[from] Box<dyn std::error::Error + Send + Sync>),
63
64 #[error("tonic error : {0}")]
65 TonicTransport(#[from] tonic::transport::Error),
66
67 #[error("invalid emulator host: {0}")]
68 InvalidEmulatorHOST(String),
69}
70
71#[derive(Debug)]
72pub enum Environment {
73 Emulator(String),
74 GoogleCloud(Box<dyn TokenSourceProvider>),
75}
76
77#[derive(Debug)]
78struct AtomicRing<T>
79where
80 T: Clone + Debug,
81{
82 index: AtomicUsize,
83 values: Vec<T>,
84}
85
86impl<T> AtomicRing<T>
87where
88 T: Clone + Debug,
89{
90 fn next(&self) -> T {
91 let current = self.index.fetch_add(1, Ordering::SeqCst);
92 self.values[current % self.values.len()].clone()
94 }
95}
96
97#[derive(Debug, Clone, Default)]
98pub struct ConnectionOptions {
99 pub timeout: Option<Duration>,
100 pub connect_timeout: Option<Duration>,
101}
102
103impl ConnectionOptions {
104 fn apply(&self, mut endpoint: Endpoint) -> Endpoint {
105 endpoint = match self.timeout {
106 Some(t) => endpoint.timeout(t),
107 None => endpoint,
108 };
109 endpoint = match self.connect_timeout {
110 Some(t) => endpoint.connect_timeout(t),
111 None => endpoint,
112 };
113 endpoint
114 }
115}
116
117#[derive(Debug)]
118pub struct ConnectionManager {
119 inner: AtomicRing<Channel>,
120}
121
122impl<'a> ConnectionManager {
123 pub async fn new(
124 pool_size: usize,
125 domain_name: impl Into<String>,
126 audience: &str,
127 environment: &Environment,
128 conn_options: &'a ConnectionOptions,
129 ) -> Result<Self, Error> {
130 let conns = match environment {
131 Environment::GoogleCloud(ts_provider) => {
132 Self::create_connections(pool_size, domain_name, audience, ts_provider.as_ref(), conn_options).await?
133 }
134 Environment::Emulator(host) => Self::create_emulator_connections(host, conn_options).await?,
135 };
136 Ok(Self {
137 inner: AtomicRing {
138 index: AtomicUsize::new(0),
139 values: conns,
140 },
141 })
142 }
143
144 async fn create_connections(
145 pool_size: usize,
146 domain_name: impl Into<String>,
147 audience: &str,
148 ts_provider: &dyn TokenSourceProvider,
149 conn_options: &'a ConnectionOptions,
150 ) -> Result<Vec<Channel>, Error> {
151 let pool_size = Self::get_pool_size(pool_size);
152
153 let tls_config = ClientTlsConfig::new().with_webpki_roots().domain_name(domain_name);
154 let mut conns = Vec::with_capacity(pool_size);
155
156 let ts = ts_provider.token_source();
157
158 for _i_ in 0..pool_size {
159 let endpoint = TonicChannel::from_shared(audience.to_string().into_bytes())
160 .map_err(|e| Error::InvalidEmulatorHOST(e.to_string()))?
161 .tls_config(tls_config.clone())?;
162 let endpoint = conn_options.apply(endpoint);
163
164 let con = Self::connect(endpoint).await?;
165 let auth_filter = AsyncAuthInterceptor::new(Arc::clone(&ts));
167 let auth_con = ServiceBuilder::new().filter_async(auth_filter).service(con);
168 conns.push(auth_con);
169 }
170 Ok(conns)
171 }
172
173 async fn create_emulator_connections(
174 host: &str,
175 conn_options: &'a ConnectionOptions,
176 ) -> Result<Vec<Channel>, Error> {
177 let mut conns = Vec::with_capacity(1);
178 let endpoint = TonicChannel::from_shared(format!("http://{host}").into_bytes())
179 .map_err(|_| Error::InvalidEmulatorHOST(host.to_string()))?;
180 let endpoint = conn_options.apply(endpoint);
181
182 let con = Self::connect(endpoint).await?;
183 let auth_filter = AsyncAuthInterceptor::empty();
184 let auth_con = ServiceBuilder::new().filter_async(auth_filter).service(con);
185 conns.push(auth_con);
186 Ok(conns)
187 }
188
189 async fn connect(endpoint: Endpoint) -> Result<TonicChannel, tonic::transport::Error> {
190 let channel = endpoint.connect().await?;
191 Ok(channel)
192 }
193
194 fn get_pool_size(pool_size: usize) -> usize {
195 pool_size.max(1)
196 }
197
198 pub fn num(&self) -> usize {
199 self.inner.values.len()
200 }
201
202 pub fn conn(&self) -> Channel {
203 self.inner.next()
204 }
205}
206
#[cfg(test)]
mod test {
    use std::collections::HashSet;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use crate::conn::{AtomicRing, ConnectionManager};

    /// The ring keeps visiting every value, exactly once per lap, when its cursor wraps past
    /// `usize::MAX`.
    #[test]
    fn test_atomic_ring() {
        let ring = AtomicRing::<&str> {
            index: AtomicUsize::new(usize::MAX - 1),
            values: vec!["a", "b", "c", "d"],
        };
        let mut seen = HashSet::new();
        assert_eq!(usize::MAX - 1, ring.index.load(Ordering::SeqCst));

        // Each entry: (cursor expected after the call, whether the returned value is new).
        let expectations = [(usize::MAX, true), (0, true), (1, true), (2, true), (3, false)];
        for &(cursor_after, is_new) in &expectations {
            assert_eq!(is_new, seen.insert(ring.next()));
            assert_eq!(cursor_after, ring.index.load(Ordering::SeqCst));
        }
    }

    /// A requested pool size of zero is raised to one; larger sizes are kept as-is.
    #[test]
    fn test_get_pool_size() {
        for &(requested, expected) in &[(0, 1), (1, 1), (2, 2)] {
            assert_eq!(expected, ConnectionManager::get_pool_size(requested));
        }
    }
}