1use tonic::{Request, Status};
2
3mod config;
4mod error;
5
6#[cfg(test)]
7mod fixtures;
8
9pub type InterceptorFn = fn(Request<()>) -> Result<Request<()>, Status>;
10
11pub use config::{CertConfig, ClientConfig};
12pub use error::MobcTonicError;
13
14pub use mobc::{Error, Manager, Pool};
16
/// Generates a TLS-enabled `mobc` connection pool for a tonic gRPC client type.
///
/// Expanding `instantiate_client_pool!(SomeClient<Channel>)` defines two items in the
/// calling module:
///
/// * `ClientManager`: a [`Manager`] that opens a TLS channel described by a
///   [`ClientConfig`] and wraps it in the client type, optionally with an
///   [`InterceptorFn`].
/// * `ClientPool`: a cloneable wrapper around `Pool<ClientManager>` whose `get` method
///   returns a connected client.
///
/// Items from this crate are referenced through `$crate`, so the caller does not need
/// to import them. The caller's crate must depend on `tonic`, because the expansion uses
/// `tonic::async_trait` and `tonic::transport`. The client type must provide the
/// associated functions `new(channel)` and `with_interceptor(channel, interceptor)`.
///
/// # Example
///
/// ```ignore
/// instantiate_client_pool!(GreeterClient<Channel>);
///
/// let pool = ClientPool::new(config);
/// let mut client = pool.get().await?;
/// ```
#[macro_export]
macro_rules! instantiate_client_pool {
    ($type:ty) => {
        /// Builds new client connections for `ClientPool`.
        // NOTE(review): presumably silences dead-code warnings in expanding crates that
        // never use every generated item; confirm it is still needed.
        #[allow(dead_code)]
        pub struct ClientManager {
            /// Connection settings: URI, domain, CA and optional client certificate,
            /// pool size.
            pub(crate) config: $crate::ClientConfig,
            /// Interceptor handed to every client this manager creates, if any.
            pub(crate) interceptor: Option<$crate::InterceptorFn>,
        }

        impl ClientManager {
            /// Creates a manager whose clients are built with the client's `new`.
            pub fn new(config: $crate::ClientConfig) -> Self {
                Self {
                    config,
                    interceptor: None,
                }
            }

            /// Creates a manager whose clients are built by passing `interceptor` to
            /// the client's `with_interceptor`.
            pub fn with_interceptor(
                config: $crate::ClientConfig,
                interceptor: $crate::InterceptorFn,
            ) -> Self {
                Self {
                    config,
                    interceptor: Some(interceptor),
                }
            }
        }

        /// Cloneable handle to a pool of clients of the type given to the macro.
        #[derive(Clone)]
        pub struct ClientPool {
            pool: $crate::Pool<ClientManager>,
        }

        impl ClientPool {
            /// Creates a pool whose `max_open` limit is `config.pool_size`.
            pub fn new(config: $crate::ClientConfig) -> Self {
                let size = config.pool_size;
                let manager = ClientManager::new(config);
                let pool = $crate::Pool::builder().max_open(size as u64).build(manager);
                Self { pool }
            }

            /// Like [`ClientPool::new`], but every client is built with `interceptor`.
            pub fn with_interceptor(
                config: $crate::ClientConfig,
                interceptor: $crate::InterceptorFn,
            ) -> Self {
                let size = config.pool_size;
                let manager = ClientManager::with_interceptor(config, interceptor);
                let pool = $crate::Pool::builder().max_open(size as u64).build(manager);
                Self { pool }
            }

            /// Takes a connection from the pool and returns the client it wraps.
            ///
            /// # Errors
            ///
            /// The pool's timeout and bad-connection conditions are mapped to
            /// `MobcTonicError::Timeout` and `MobcTonicError::BadConn`. Errors raised by
            /// the manager (`Error::Inner`) are returned unchanged.
            pub async fn get(
                &self,
            ) -> ::std::result::Result<$type, $crate::MobcTonicError> {
                // `Pool::get` borrows the pool, so no clone of the handle is needed.
                match self.pool.get().await {
                    // NOTE(review): `into_inner` takes the raw client out of the pooled
                    // wrapper; presumably the connection is then not returned for
                    // reuse. Confirm this is the intended pooling behavior.
                    Ok(conn) => Ok(conn.into_inner()),
                    Err($crate::Error::Timeout) => Err($crate::MobcTonicError::Timeout),
                    Err($crate::Error::BadConn) => Err($crate::MobcTonicError::BadConn),
                    Err($crate::Error::Inner(e)) => Err(e),
                }
            }
        }

        #[tonic::async_trait]
        impl $crate::Manager for ClientManager {
            type Connection = $type;
            type Error = $crate::MobcTonicError;

            /// Opens a TLS channel to `config.uri` and wraps it in a new client.
            ///
            /// The TLS config always uses `config.domain` and `config.ca_cert`. A client
            /// identity is added only when `config.client_cert` is present.
            async fn connect(&self) -> ::std::result::Result<Self::Connection, Self::Error> {
                let config = &self.config;

                // `from_pem` accepts `impl AsRef<[u8]>`, so the PEM data is borrowed
                // rather than cloned.
                let tls = tonic::transport::ClientTlsConfig::new()
                    .domain_name(config.domain.clone())
                    .ca_certificate(tonic::transport::Certificate::from_pem(&config.ca_cert));
                let tls = match config.client_cert.as_ref() {
                    Some(client_cert) => tls.identity(tonic::transport::Identity::from_pem(
                        &client_cert.cert,
                        &client_cert.sk,
                    )),
                    None => tls,
                };

                let channel = tonic::transport::Channel::from_shared(config.uri.clone())?
                    .tls_config(tls)?
                    .connect()
                    .await?;

                // `InterceptorFn` is a plain fn pointer, so the `Option` is copied here.
                let client = match self.interceptor {
                    Some(interceptor) => Self::Connection::with_interceptor(channel, interceptor),
                    None => Self::Connection::new(channel),
                };

                Ok(client)
            }

            /// Returns the connection unchanged; no health check is performed.
            async fn check(
                &self,
                conn: Self::Connection,
            ) -> ::std::result::Result<Self::Connection, Self::Error> {
                Ok(conn)
            }
        }
    };
}
117
#[cfg(test)]
mod tests {
    use anyhow::Result;

    use tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};

    use fixtures::{
        greeter_client::GreeterClient, start_server, start_server_verify_client_cert, HelloRequest,
    };
    use tonic::Code;

    use super::*;

    instantiate_client_pool!(GreeterClient<Channel>);

    /// Maximum number of `get` attempts while a freshly spawned server starts up.
    const CONNECT_ATTEMPTS: u32 = 50;
    /// Delay between connection attempts, in milliseconds.
    const RETRY_DELAY_MS: u64 = 10;

    #[tokio::test]
    async fn connect_pool_should_work() -> Result<()> {
        let server_cert: CertConfig = toml::from_str(include_str!("fixtures/server.toml")).unwrap();
        tokio::spawn(async move { start_server("0.0.0.0:4000", server_cert).await });

        let client_config: ClientConfig =
            toml::from_str(include_str!("fixtures/client.toml")).unwrap();

        let pool = ClientPool::new(client_config);
        let mut client = get_client(&pool).await;
        let reply = client
            .say_hello(hello_request())
            .await
            .unwrap()
            .into_inner();

        assert_eq!(reply.message, "Hello Tyr!");
        Ok(())
    }

    #[tokio::test]
    async fn connect_pool_with_client_cert_should_work() -> Result<()> {
        let server_cert: CertConfig = toml::from_str(include_str!("fixtures/server.toml")).unwrap();
        tokio::spawn(
            async move { start_server_verify_client_cert("0.0.0.0:4001", server_cert).await },
        );

        let client_config: ClientConfig =
            toml::from_str(include_str!("fixtures/client_with_cert.toml")).unwrap();

        let pool = ClientPool::new(client_config);
        let mut client = get_client(&pool).await;
        let reply = client
            .say_hello(hello_request())
            .await
            .unwrap()
            .into_inner();

        assert_eq!(reply.message, "Hello Tyr!");
        Ok(())
    }

    #[tokio::test]
    async fn connect_pool_with_client_cert_and_interceptor_should_work() -> Result<()> {
        let server_cert: CertConfig = toml::from_str(include_str!("fixtures/server.toml")).unwrap();
        tokio::spawn(
            async move { start_server_verify_client_cert("0.0.0.0:4003", server_cert).await },
        );

        let mut client_config: ClientConfig =
            toml::from_str(include_str!("fixtures/client_with_cert.toml")).unwrap();

        // The shared fixture points at another test's port; use this test's own server.
        client_config.uri = "https://localhost:4003".to_owned();

        let pool = ClientPool::with_interceptor(client_config, intercept);
        let mut client = get_client(&pool).await;
        let reply = client.say_hello(hello_request()).await;

        // `intercept` rejects every request, so the call must fail with its status code.
        let status = reply.err().expect("interceptor should have rejected the request");
        assert_eq!(status.code(), Code::FailedPrecondition);
        Ok(())
    }

    #[tokio::test]
    async fn connect_pool_with_invalid_client_cert_should_fail() -> Result<()> {
        let server_cert: CertConfig = toml::from_str(include_str!("fixtures/server.toml")).unwrap();
        tokio::spawn(
            async move { start_server_verify_client_cert("0.0.0.0:4002", server_cert).await },
        );

        let client_config: ClientConfig =
            toml::from_str(include_str!("fixtures/client_with_invalid_cert.toml")).unwrap();

        let pool = ClientPool::new(client_config);
        let mut client = get_client(&pool).await;
        let reply = client.say_hello(hello_request()).await;

        assert!(reply.is_err());
        Ok(())
    }

    /// Gets a client from `pool`, retrying while the spawned test server is starting.
    ///
    /// This replaces a fixed startup sleep, which raced the server's bind.
    ///
    /// # Panics
    ///
    /// Panics with the last error if no attempt succeeds within `CONNECT_ATTEMPTS` tries.
    async fn get_client(pool: &ClientPool) -> GreeterClient<Channel> {
        let mut attempt = 1;
        loop {
            match pool.get().await {
                Ok(client) => return client,
                Err(e) if attempt >= CONNECT_ATTEMPTS => {
                    panic!("could not connect after {} attempts: {:?}", attempt, e)
                }
                Err(_) => {
                    attempt += 1;
                    sleep(RETRY_DELAY_MS).await;
                }
            }
        }
    }

    /// The request every test sends.
    fn hello_request() -> HelloRequest {
        HelloRequest {
            name: "Tyr".to_owned(),
        }
    }

    /// Sleeps asynchronously for `duration` milliseconds.
    async fn sleep(duration: u64) {
        tokio::time::sleep(tokio::time::Duration::from_millis(duration)).await;
    }

    /// Interceptor that rejects every request with `FailedPrecondition`.
    fn intercept(_req: Request<()>) -> Result<Request<()>, Status> {
        Err(Status::failed_precondition("should fail"))
    }
}