deadpool_tiberius_rustls/
lib.rs1#![warn(missing_docs)]
30#![cfg_attr(docsrs, feature(doc_cfg))]
31use std::mem::replace;
32use std::time::Duration;
33
34pub use deadpool;
35use deadpool::{
36 async_trait, managed,
37 managed::{Hook, HookFuture, HookResult, Metrics, PoolConfig, RecycleError, RecycleResult},
38 Runtime,
39};
40pub use tiberius_rustls;
41use tiberius_rustls::error::Error;
42use tiberius_rustls::{AuthMethod, EncryptionLevel};
43use tokio_util::compat::TokioAsyncWriteCompatExt;
44
45pub use crate::error::SqlServerError;
46pub use crate::error::SqlServerResult;
47
48mod error;
49
50pub type Client = tiberius_rustls::Client<tokio_util::compat::Compat<tokio::net::TcpStream>>;
52pub type Pool = managed::Pool<Manager>;
54
55pub struct Manager {
59 config: tiberius_rustls::Config,
60 pool_config: PoolConfig,
61 runtime: Option<Runtime>,
62 hooks: Hooks,
63 modify_tcp_stream:
64 Box<dyn Fn(&tokio::net::TcpStream) -> tokio::io::Result<()> + Send + Sync + 'static>,
65 #[cfg(feature = "sql-browser")]
66 enable_sql_browser: bool,
67}
68
69#[async_trait]
70impl managed::Manager for Manager {
71 type Type = Client;
72 type Error = tiberius_rustls::error::Error;
73
74 #[cfg(feature = "sql-browser")]
75 async fn create(&self) -> Result<Client, Self::Error> {
76 use tiberius_rustls::SqlBrowser;
77 let tcp = if !self.enable_sql_browser {
78 tokio::net::TcpStream::connect(self.config.get_addr()).await?
79 } else {
80 tokio::net::TcpStream::connect_named(&self.config).await?
81 };
82 (self.modify_tcp_stream)(&tcp)?;
83 let client = Client::connect(self.config.clone(), tcp.compat_write()).await;
84 match client {
85 Ok(client) => Ok(client),
86 Err(Error::Routing { host, port }) => {
87 let mut config = self.config.clone();
88 config.host(host);
89 config.port(port);
90
91 let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
92 tcp.set_nodelay(true)?;
93
94 Client::connect(config, tcp.compat_write()).await
95 },
96 Err(err) => Err(err)?,
98 }
99 }
100
101 #[cfg(not(feature = "sql-browser"))]
102 async fn create(&self) -> Result<Client, Self::Error> {
103 let tcp = tokio::net::TcpStream::connect(self.config.get_addr()).await?;
104 (self.modify_tcp_stream)(&tcp)?;
105 let client = Client::connect(self.config.clone(), tcp.compat_write()).await;
106
107 match client {
108 Ok(client) => Ok(client),
109 Err(Error::Routing { host, port }) => {
110 let mut config = self.config.clone();
111 config.host(host);
112 config.port(port);
113
114 let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
115 tcp.set_nodelay(true)?;
116
117 Client::connect(config, tcp.compat_write()).await
118 },
119 Err(err) => Err(err)?,
121 }
122 }
123
124 async fn recycle(
125 &self,
126 obj: &mut Self::Type,
127 _metrics: &Metrics,
128 ) -> RecycleResult<Self::Error> {
129 match obj.simple_query("").await {
130 Ok(_) => Ok(()),
131 Err(e) => Err(RecycleError::Message(e.to_string())),
132 }
133 }
134}
135
136impl Manager {
137 pub fn new() -> Self {
139 Self::new_with_tiberius_config(tiberius_rustls::Config::new())
140 }
141
142 pub fn from_ado_string(conn_str: &str) -> SqlServerResult<Self> {
147 Ok(Self::new_with_tiberius_config(
148 tiberius_rustls::Config::from_ado_string(conn_str)?,
149 ))
150 }
151
152 pub fn from_jdbc_string(conn_str: &str) -> SqlServerResult<Self> {
157 Ok(Self::new_with_tiberius_config(
158 tiberius_rustls::Config::from_jdbc_string(conn_str)?,
159 ))
160 }
161
162 fn new_with_tiberius_config(config: tiberius_rustls::Config) -> Self {
163 Self {
164 config,
165 pool_config: Default::default(),
166 runtime: None,
167 hooks: Default::default(),
168 modify_tcp_stream: Box::new(|tcp_stream| tcp_stream.set_nodelay(true)),
169 #[cfg(feature = "sql-browser")]
170 enable_sql_browser: false,
171 }
172 }
173
174 pub fn create_pool(mut self) -> Result<Pool, error::SqlServerError> {
176 let config = self.pool_config;
177 let runtime = self.runtime;
178 let hooks = replace(&mut self.hooks, Hooks::default());
179 let mut pool = Pool::builder(self).config(config);
180 if let Some(v) = runtime {
181 pool = pool.runtime(v);
182 }
183
184 for hook in hooks.post_create {
185 pool = pool.post_create(hook);
186 }
187 for hook in hooks.pre_recycle {
188 pool = pool.pre_recycle(hook);
189 }
190 for hook in hooks.post_recycle {
191 pool = pool.post_recycle(hook);
192 }
193
194 Ok(pool.build()?)
195 }
196
197 #[cfg(feature = "sql-browser")]
199 #[cfg_attr(docsrs, doc(cfg(feature = "sql-browser")))]
200 pub fn enable_sql_browser(mut self) -> Self {
201 self.enable_sql_browser = true;
202 self
203 }
204
205 pub fn host(mut self, host: impl ToString) -> Self {
207 self.config.host(host);
208 self
209 }
210
211 pub fn port(mut self, port: u16) -> Self {
213 self.config.port(port);
214 self
215 }
216
217 pub fn database(mut self, database: impl ToString) -> Self {
219 self.config.database(database);
220 self
221 }
222
223 pub fn basic_authentication(
225 mut self,
226 username: impl ToString,
227 password: impl ToString,
228 ) -> Self {
229 self.config
230 .authentication(AuthMethod::sql_server(username, password));
231 self
232 }
233
234 pub fn authentication(mut self, authentication: AuthMethod) -> Self {
236 self.config.authentication(authentication);
237 self
238 }
239
240 pub fn trust_cert(mut self) -> Self {
242 self.config.trust_cert();
243 self
244 }
245
246 pub fn encryption(mut self, encryption: EncryptionLevel) -> Self {
248 self.config.encryption(encryption);
249 self
250 }
251
252 pub fn trust_cert_ca(mut self, path: impl ToString) -> Self {
254 self.config.trust_cert_ca(path);
255 self
256 }
257
258 pub fn instance_name(mut self, name: impl ToString) -> Self {
260 self.config.instance_name(name);
261 self
262 }
263
264 pub fn application_name(mut self, name: impl ToString) -> Self {
266 self.config.application_name(name);
267 self
268 }
269
270 pub fn max_size(mut self, value: usize) -> Self {
272 self.pool_config.max_size = value;
273 self
274 }
275
276 pub fn wait_timeout(mut self, value: Duration) -> Self {
278 self.pool_config.timeouts.wait = Some(value);
279 self.set_runtime(Runtime::Tokio1);
280 self
281 }
282
283 pub fn create_timeout(mut self, value: Duration) -> Self {
285 self.pool_config.timeouts.create = Some(value);
286 self.set_runtime(Runtime::Tokio1);
287 self
288 }
289
290 pub fn recycle_timeout(mut self, value: Duration) -> Self {
292 self.pool_config.timeouts.recycle = Some(value);
293 self.set_runtime(Runtime::Tokio1);
294 self
295 }
296
297 pub fn pre_recycle_sync<T>(mut self, hook: T) -> Self
300 where
301 T: Fn(&mut Client, &Metrics) -> HookResult<Error> + Sync + Send + 'static,
302 {
303 self.hooks.pre_recycle.push(Hook::sync_fn(hook));
304 self
305 }
306
307 pub fn pre_recycle_async<T>(mut self, hook: T) -> Self
310 where
311 T: for<'a> Fn(&'a mut Client, &'a Metrics) -> HookFuture<'a, Error> + Sync + Send + 'static,
312 {
313 self.hooks.pre_recycle.push(Hook::async_fn(hook));
314 self
315 }
316
317 pub fn post_recycle_sync<T>(mut self, hook: T) -> Self
320 where
321 T: Fn(&mut Client, &Metrics) -> HookResult<Error> + Sync + Send + 'static,
322 {
323 self.hooks.post_recycle.push(Hook::sync_fn(hook));
324 self
325 }
326
327 pub fn post_recycle_async<T>(mut self, hook: T) -> Self
330 where
331 T: for<'a> Fn(&'a mut Client, &'a Metrics) -> HookFuture<'a, Error> + Sync + Send + 'static,
332 {
333 self.hooks.post_recycle.push(Hook::async_fn(hook));
334 self
335 }
336
337 pub fn post_create_sync<T>(mut self, hook: T) -> Self
340 where
341 T: Fn(&mut Client, &Metrics) -> HookResult<Error> + Sync + Send + 'static,
342 {
343 self.hooks.post_create.push(Hook::sync_fn(hook));
344 self
345 }
346
347 pub fn post_create_async<T>(mut self, hook: T) -> Self
350 where
351 T: for<'a> Fn(&'a mut Client, &'a Metrics) -> HookFuture<'a, Error> + Sync + Send + 'static,
352 {
353 self.hooks.post_create.push(Hook::async_fn(hook));
354 self
355 }
356
357 fn set_runtime(&mut self, value: Runtime) {
358 self.runtime = Some(value);
359 }
360}
361
362struct Hooks {
363 pre_recycle: Vec<Hook<Manager>>,
364 post_recycle: Vec<Hook<Manager>>,
365 post_create: Vec<Hook<Manager>>,
366}
367
368impl Default for Hooks {
369 fn default() -> Self {
370 Hooks {
371 pre_recycle: Vec::<Hook<Manager>>::new(),
372 post_recycle: Vec::<Hook<Manager>>::new(),
373 post_create: Vec::<Hook<Manager>>::new(),
374 }
375 }
376}