mobc_tonic/
lib.rs

1use tonic::{Request, Status};
2
3mod config;
4mod error;
5
6#[cfg(test)]
7mod fixtures;
8
9pub type InterceptorFn = fn(Request<()>) -> Result<Request<()>, Status>;
10
11pub use config::{CertConfig, ClientConfig};
12pub use error::MobcTonicError;
13
/// Re-exports `Error`, `Manager`, and `Pool` from `mobc`.
15pub use mobc::{Error, Manager, Pool};
16
/// Generates a TLS-enabled, `mobc`-backed connection pool for a tonic client type.
///
/// Usage: `instantiate_client_pool!(GreeterClient<Channel>);`. The expansion defines two
/// items in the invoking module:
///
/// * `ClientManager` — a [`Manager`] whose connections are clients over a TLS channel to
///   `config.uri`. The channel trusts `config.ca_cert` and uses `config.domain` as the TLS
///   domain name. When `config.client_cert` is set, it is presented as the client identity.
/// * `ClientPool` — a clonable handle around `Pool<ClientManager>` that opens at most
///   `config.pool_size` connections.
///
/// The client type must provide `new(Channel)` and `with_interceptor(Channel, InterceptorFn)`
/// constructors and must implement `Clone`. Tonic-generated clients meet these requirements.
///
/// Items from this crate are named through `$crate::` and tonic's transport types through
/// `::tonic::`. Invoking crates therefore need `tonic` as a dependency, but do not have to
/// import this crate's types or tonic's transport types themselves.
#[macro_export]
macro_rules! instantiate_client_pool {
    ($type:ty) => {
        /// Connection manager that builds clients of the type passed to the macro.
        // NOTE(review): the reason for this allow is not evident from this file; confirm
        // before removing.
        #[allow(dead_code)]
        pub struct ClientManager {
            pub(crate) config: $crate::ClientConfig,
            pub(crate) interceptor: ::std::option::Option<$crate::InterceptorFn>,
        }

        impl ClientManager {
            /// Creates a manager whose clients are built without an interceptor.
            pub fn new(config: $crate::ClientConfig) -> Self {
                Self {
                    config,
                    interceptor: None,
                }
            }

            /// Creates a manager whose clients are built with `interceptor` attached.
            pub fn with_interceptor(
                config: $crate::ClientConfig,
                interceptor: $crate::InterceptorFn,
            ) -> Self {
                Self {
                    config,
                    interceptor: Some(interceptor),
                }
            }
        }

        /// Handle to a pool of clients; clones share the same underlying pool.
        #[derive(Clone)]
        pub struct ClientPool {
            pool: $crate::Pool<ClientManager>,
        }

        impl ClientPool {
            /// Creates a pool whose clients use `config` and no interceptor.
            pub fn new(config: $crate::ClientConfig) -> Self {
                Self::from_manager(ClientManager::new(config))
            }

            /// Creates a pool whose clients use `config` with `interceptor` attached.
            pub fn with_interceptor(
                config: $crate::ClientConfig,
                interceptor: $crate::InterceptorFn,
            ) -> Self {
                Self::from_manager(ClientManager::with_interceptor(config, interceptor))
            }

            /// Builds the pool around `manager`, capping open connections at
            /// `config.pool_size`.
            fn from_manager(manager: ClientManager) -> Self {
                let max_open = manager.config.pool_size as u64;
                let pool = $crate::Pool::builder().max_open(max_open).build(manager);
                Self { pool }
            }

            /// Checks a connection out of the pool and returns a client for it.
            ///
            /// The returned client is a clone of the pooled client and shares its channel.
            /// The pooled connection itself is dropped at the end of this call, which hands
            /// it back to the pool, so later calls can reuse it instead of opening a new
            /// TLS connection.
            ///
            /// # Errors
            ///
            /// * `MobcTonicError::Timeout` if the pool timed out.
            /// * `MobcTonicError::BadConn` if the pool reported a bad connection.
            /// * Any error produced by `ClientManager::connect`, returned unchanged.
            pub async fn get(&self) -> ::std::result::Result<$type, $crate::MobcTonicError> {
                match self.pool.get().await {
                    // Clone rather than `into_inner()`: `into_inner` takes the connection
                    // out of the pool permanently, so nothing would ever be reused.
                    Ok(conn) => Ok((*conn).clone()),
                    Err($crate::Error::Timeout) => Err($crate::MobcTonicError::Timeout),
                    Err($crate::Error::BadConn) => Err($crate::MobcTonicError::BadConn),
                    Err($crate::Error::Inner(e)) => Err(e),
                }
            }
        }

        #[tonic::async_trait]
        impl $crate::Manager for ClientManager {
            type Connection = $type;
            type Error = $crate::MobcTonicError;

            /// Opens a TLS channel to `config.uri` and wraps it in a new client.
            ///
            /// The configured interceptor, if any, is attached to the client.
            async fn connect(&self) -> ::std::result::Result<Self::Connection, Self::Error> {
                use ::tonic::transport::{Certificate, Channel, ClientTlsConfig, Identity};

                let config = &self.config;
                let tls = ClientTlsConfig::new()
                    .domain_name(config.domain.clone())
                    .ca_certificate(Certificate::from_pem(&config.ca_cert));
                // Present a client identity only when a client certificate is configured.
                let tls = match config.client_cert.as_ref() {
                    Some(client_cert) => {
                        tls.identity(Identity::from_pem(&client_cert.cert, &client_cert.sk))
                    }
                    None => tls,
                };

                let channel = Channel::from_shared(config.uri.clone())?
                    .tls_config(tls)?
                    .connect()
                    .await?;

                // `InterceptorFn` is a `Copy` fn pointer, so it can be taken by value.
                let client = match self.interceptor {
                    Some(interceptor) => Self::Connection::with_interceptor(channel, interceptor),
                    None => Self::Connection::new(channel),
                };
                Ok(client)
            }

            /// Health check hook: always reports the connection as healthy.
            async fn check(
                &self,
                conn: Self::Connection,
            ) -> ::std::result::Result<Self::Connection, Self::Error> {
                Ok(conn)
            }
        }
    };
}
117
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use anyhow::Result;
    use tonic::transport::Channel;
    // Only needed when the macro expansion below names these transport types unqualified.
    #[allow(unused_imports)]
    use tonic::transport::{Certificate, ClientTlsConfig, Identity};
    use tonic::Code;

    use super::fixtures::{
        greeter_client::GreeterClient, start_server, start_server_verify_client_cert, HelloRequest,
    };
    use super::*;

    instantiate_client_pool!(GreeterClient<Channel>);

    /// Maximum number of client checkouts attempted while a spawned server starts up.
    const CONNECT_ATTEMPTS: u32 = 50;
    /// Pause between failed checkout attempts.
    const CONNECT_RETRY_DELAY: Duration = Duration::from_millis(20);

    #[tokio::test]
    async fn connect_pool_should_work() -> Result<()> {
        let server_cert = server_cert()?;
        tokio::spawn(async move { start_server("0.0.0.0:4000", server_cert).await });

        let client_config: ClientConfig = toml::from_str(include_str!("fixtures/client.toml"))?;

        let pool = ClientPool::new(client_config);
        let mut client = get_client(&pool).await;
        let reply = client.say_hello(hello_request()).await?.into_inner();

        assert_eq!(reply.message, "Hello Tyr!");
        Ok(())
    }

    #[tokio::test]
    async fn connect_pool_with_client_cert_should_work() -> Result<()> {
        let server_cert = server_cert()?;
        tokio::spawn(
            async move { start_server_verify_client_cert("0.0.0.0:4001", server_cert).await },
        );

        let client_config: ClientConfig =
            toml::from_str(include_str!("fixtures/client_with_cert.toml"))?;

        let pool = ClientPool::new(client_config);
        let mut client = get_client(&pool).await;
        let reply = client.say_hello(hello_request()).await?.into_inner();

        assert_eq!(reply.message, "Hello Tyr!");
        Ok(())
    }

    #[tokio::test]
    async fn connect_pool_with_client_cert_and_interceptor_should_work() -> Result<()> {
        let server_cert = server_cert()?;
        tokio::spawn(
            async move { start_server_verify_client_cert("0.0.0.0:4003", server_cert).await },
        );

        let mut client_config: ClientConfig =
            toml::from_str(include_str!("fixtures/client_with_cert.toml"))?;
        client_config.uri = "https://localhost:4003".to_owned();

        let pool = ClientPool::with_interceptor(client_config, intercept);
        let mut client = get_client(&pool).await;
        let status = client
            .say_hello(hello_request())
            .await
            .expect_err("the interceptor should reject every request");

        assert_eq!(status.code(), Code::FailedPrecondition);
        Ok(())
    }

    #[tokio::test]
    async fn connect_pool_with_invalid_client_cert_should_fail() -> Result<()> {
        let server_cert = server_cert()?;
        tokio::spawn(
            async move { start_server_verify_client_cert("0.0.0.0:4002", server_cert).await },
        );

        let client_config: ClientConfig =
            toml::from_str(include_str!("fixtures/client_with_invalid_cert.toml"))?;

        let pool = ClientPool::new(client_config);
        let mut client = get_client(&pool).await;
        let reply = client.say_hello(hello_request()).await;

        assert!(reply.is_err());
        Ok(())
    }

    /// Loads the server certificate fixture shared by every test.
    fn server_cert() -> Result<CertConfig> {
        Ok(toml::from_str(include_str!("fixtures/server.toml"))?)
    }

    /// The request every test sends.
    fn hello_request() -> HelloRequest {
        HelloRequest {
            name: "Tyr".to_owned(),
        }
    }

    /// Checks a client out of `pool`, retrying while the freshly spawned server is still
    /// starting up. A single fixed sleep before the first connection is racy.
    async fn get_client(pool: &ClientPool) -> GreeterClient<Channel> {
        let mut attempt = 1;
        loop {
            match pool.get().await {
                Ok(client) => return client,
                Err(_) if attempt < CONNECT_ATTEMPTS => {
                    attempt += 1;
                    tokio::time::sleep(CONNECT_RETRY_DELAY).await;
                }
                Err(e) => panic!("failed to get a client after {} attempts: {:?}", attempt, e),
            }
        }
    }

    /// Interceptor that rejects every request with `FailedPrecondition`.
    fn intercept(_req: Request<()>) -> Result<Request<()>, Status> {
        Err(Status::failed_precondition("should fail"))
    }
}