deadpool_tiberius_rustls/
lib.rs

//! This crate combines the configuration of [`tiberius_rustls`] and [`deadpool`] to make it easy to create a pool of tiberius connections.
2//! # Example
3//! ```no_run
//! let pool = deadpool_tiberius_rustls::Manager::new()
//!     .host("host")
//!     .port(1433)
//!     .basic_authentication("username", "password")
//!     .database("database1")
//!     .max_size(20)
//!     .wait_timeout(std::time::Duration::from_secs_f64(1.52))
//!     .pre_recycle_sync(|_client, _metrics| {
//!         // Inspect the connection object and pool metrics here.
//!         Ok(())
//!     })
//!     .create_pool()?;
//! # Ok::<(), deadpool_tiberius_rustls::SqlServerError>(())
16//! ```
17//!
//! [`Manager::from_ado_string`] and [`Manager::from_jdbc_string`] serve as alternative entry points for constructing a [`Manager`] from a connection string.
19//! ```no_run
20//! const CONN_STR: &str = "Driver={SQL Server};Integrated Security=True;\
21//!                         Server=DESKTOP-TTTTTTT;Database=master;\
22//!                         Trusted_Connection=yes;encrypt=DANGER_PLAINTEXT;";
//! let pool = deadpool_tiberius_rustls::Manager::from_ado_string(CONN_STR)?
//!                 .max_size(20)
//!                 .wait_timeout(std::time::Duration::from_secs_f64(1.52))
//!                 .create_pool()?;
//! # Ok::<(), deadpool_tiberius_rustls::SqlServerError>(())
27//! ```
//! For all configurable options, see [`Manager`].
29#![warn(missing_docs)]
30#![cfg_attr(docsrs, feature(doc_cfg))]
31use std::mem::replace;
32use std::time::Duration;
33
34pub use deadpool;
35use deadpool::{
36    async_trait, managed,
37    managed::{Hook, HookFuture, HookResult, Metrics, PoolConfig, RecycleError, RecycleResult},
38    Runtime,
39};
40pub use tiberius_rustls;
41use tiberius_rustls::error::Error;
42use tiberius_rustls::{AuthMethod, EncryptionLevel};
43use tokio_util::compat::TokioAsyncWriteCompatExt;
44
45pub use crate::error::SqlServerError;
46pub use crate::error::SqlServerResult;
47
48mod error;
49
50/// Type aliasing for tiberius client with [`tokio`] as runtime.
51pub type Client = tiberius_rustls::Client<tokio_util::compat::Compat<tokio::net::TcpStream>>;
52/// Type aliasing for Pool.
53pub type Pool = managed::Pool<Manager>;
54
55/// Connection pool Manager served as Builder. Call [`create_pool`] after filling out your configs.
56///
57/// [`create_pool`]: struct.Manager.html#method.create_pool
58pub struct Manager {
59    config: tiberius_rustls::Config,
60    pool_config: PoolConfig,
61    runtime: Option<Runtime>,
62    hooks: Hooks,
63    modify_tcp_stream:
64        Box<dyn Fn(&tokio::net::TcpStream) -> tokio::io::Result<()> + Send + Sync + 'static>,
65    #[cfg(feature = "sql-browser")]
66    enable_sql_browser: bool,
67}
68
69#[async_trait]
70impl managed::Manager for Manager {
71    type Type = Client;
72    type Error = tiberius_rustls::error::Error;
73
74    #[cfg(feature = "sql-browser")]
75    async fn create(&self) -> Result<Client, Self::Error> {
76        use tiberius_rustls::SqlBrowser;
77        let tcp = if !self.enable_sql_browser {
78            tokio::net::TcpStream::connect(self.config.get_addr()).await?
79        } else {
80            tokio::net::TcpStream::connect_named(&self.config).await?
81        };
82        (self.modify_tcp_stream)(&tcp)?;
83        let client = Client::connect(self.config.clone(), tcp.compat_write()).await;
84        match client {
85            Ok(client) => Ok(client),
86            Err(Error::Routing { host, port }) => {
87                let mut config = self.config.clone();
88                config.host(host);
89                config.port(port);
90
91                let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
92                tcp.set_nodelay(true)?;
93
94                Client::connect(config, tcp.compat_write()).await
95            },
96            // Propagate errors
97            Err(err) => Err(err)?,
98        }
99    }
100
101    #[cfg(not(feature = "sql-browser"))]
102    async fn create(&self) -> Result<Client, Self::Error> {
103        let tcp = tokio::net::TcpStream::connect(self.config.get_addr()).await?;
104        (self.modify_tcp_stream)(&tcp)?;
105        let client = Client::connect(self.config.clone(), tcp.compat_write()).await;
106
107        match client {
108            Ok(client) => Ok(client),
109            Err(Error::Routing { host, port }) => {
110                let mut config = self.config.clone();
111                config.host(host);
112                config.port(port);
113
114                let tcp = tokio::net::TcpStream::connect(config.get_addr()).await?;
115                tcp.set_nodelay(true)?;
116
117                Client::connect(config, tcp.compat_write()).await
118            },
119            // Propagate errors
120            Err(err) => Err(err)?,
121        }
122    }
123
124    async fn recycle(
125        &self,
126        obj: &mut Self::Type,
127        _metrics: &Metrics,
128    ) -> RecycleResult<Self::Error> {
129        match obj.simple_query("").await {
130            Ok(_) => Ok(()),
131            Err(e) => Err(RecycleError::Message(e.to_string())),
132        }
133    }
134}
135
136impl Manager {
137    /// Create new ConnectionPool Manager
138    pub fn new() -> Self {
139        Self::new_with_tiberius_config(tiberius_rustls::Config::new())
140    }
141
142    /// Create a new ConnectionPool Manager and fills connection configs from ado string.
143    /// For more details about ADO_String pleas refer to [`tiberius::Config::from_ado_string`] and [`Connection Strings in ADO.NET`].
144    ///
145    /// [`Connection Strings in ADO.NET`]: https://docs.microsoft.com/en-us/dotnet/framework/data/adonet/connection-strings
146    pub fn from_ado_string(conn_str: &str) -> SqlServerResult<Self> {
147        Ok(Self::new_with_tiberius_config(
148            tiberius_rustls::Config::from_ado_string(conn_str)?,
149        ))
150    }
151
152    /// Create new ConnectionPool Manager and fills connection config from jdbc string.
153    /// For more details about jdbc_string pls refer to [`Building JDBC connection URL`].
154    ///
155    /// [`Building JDBC connection URL`]: https://docs.microsoft.com/en-us/sql/connect/jdbc/building-the-connection-url?view=sql-server-ver15
156    pub fn from_jdbc_string(conn_str: &str) -> SqlServerResult<Self> {
157        Ok(Self::new_with_tiberius_config(
158            tiberius_rustls::Config::from_jdbc_string(conn_str)?,
159        ))
160    }
161
162    fn new_with_tiberius_config(config: tiberius_rustls::Config) -> Self {
163        Self {
164            config,
165            pool_config: Default::default(),
166            runtime: None,
167            hooks: Default::default(),
168            modify_tcp_stream: Box::new(|tcp_stream| tcp_stream.set_nodelay(true)),
169            #[cfg(feature = "sql-browser")]
170            enable_sql_browser: false,
171        }
172    }
173
174    /// Consume self, builds a pool.
175    pub fn create_pool(mut self) -> Result<Pool, error::SqlServerError> {
176        let config = self.pool_config;
177        let runtime = self.runtime;
178        let hooks = replace(&mut self.hooks, Hooks::default());
179        let mut pool = Pool::builder(self).config(config);
180        if let Some(v) = runtime {
181            pool = pool.runtime(v);
182        }
183
184        for hook in hooks.post_create {
185            pool = pool.post_create(hook);
186        }
187        for hook in hooks.pre_recycle {
188            pool = pool.pre_recycle(hook);
189        }
190        for hook in hooks.post_recycle {
191            pool = pool.post_recycle(hook);
192        }
193
194        Ok(pool.build()?)
195    }
196
197    /// Whether connected via sql-browser feature, default to `false`.
198    #[cfg(feature = "sql-browser")]
199    #[cfg_attr(docsrs, doc(cfg(feature = "sql-browser")))]
200    pub fn enable_sql_browser(mut self) -> Self {
201        self.enable_sql_browser = true;
202        self
203    }
204
205    /// Server host, defaults to `localhost`.
206    pub fn host(mut self, host: impl ToString) -> Self {
207        self.config.host(host);
208        self
209    }
210
211    /// Server port, defaults to 1433.
212    pub fn port(mut self, port: u16) -> Self {
213        self.config.port(port);
214        self
215    }
216
217    /// Database, defaults to `master`.
218    pub fn database(mut self, database: impl ToString) -> Self {
219        self.config.database(database);
220        self
221    }
222
223    /// Simplified authentication for those using `username` and `password` as login method.
224    pub fn basic_authentication(
225        mut self,
226        username: impl ToString,
227        password: impl ToString,
228    ) -> Self {
229        self.config
230            .authentication(AuthMethod::sql_server(username, password));
231        self
232    }
233
234    /// Set [`tiberius::AuthMethod`] as authentication method.
235    pub fn authentication(mut self, authentication: AuthMethod) -> Self {
236        self.config.authentication(authentication);
237        self
238    }
239
240    /// See [`tiberius::Config::trust_cert`]
241    pub fn trust_cert(mut self) -> Self {
242        self.config.trust_cert();
243        self
244    }
245
246    /// Set [`tiberius::EncryptionLevel`] as enctryption method.
247    pub fn encryption(mut self, encryption: EncryptionLevel) -> Self {
248        self.config.encryption(encryption);
249        self
250    }
251
252    /// See [`tiberius::Config::trust_cert_ca`]
253    pub fn trust_cert_ca(mut self, path: impl ToString) -> Self {
254        self.config.trust_cert_ca(path);
255        self
256    }
257
258    /// Instance name defined in `Sql Browser`, defaults to None.
259    pub fn instance_name(mut self, name: impl ToString) -> Self {
260        self.config.instance_name(name);
261        self
262    }
263
264    /// See [`tiberius::Config::application_name`]
265    pub fn application_name(mut self, name: impl ToString) -> Self {
266        self.config.application_name(name);
267        self
268    }
269
270    /// Set pool size, defaults to 10.
271    pub fn max_size(mut self, value: usize) -> Self {
272        self.pool_config.max_size = value;
273        self
274    }
275
276    /// Set timeout for when waiting for a connection object to become available.
277    pub fn wait_timeout(mut self, value: Duration) -> Self {
278        self.pool_config.timeouts.wait = Some(value);
279        self.set_runtime(Runtime::Tokio1);
280        self
281    }
282
283    /// Set timeout for when creating a new connection object.
284    pub fn create_timeout(mut self, value: Duration) -> Self {
285        self.pool_config.timeouts.create = Some(value);
286        self.set_runtime(Runtime::Tokio1);
287        self
288    }
289
290    /// Set timeout for when recycling a connection object.
291    pub fn recycle_timeout(mut self, value: Duration) -> Self {
292        self.pool_config.timeouts.recycle = Some(value);
293        self.set_runtime(Runtime::Tokio1);
294        self
295    }
296
297    /// Attach a `sync fn` as hook to connection pool.
298    /// The hook will be called each time before a connection [`deadpool::managed::Object`] is recycled.
299    pub fn pre_recycle_sync<T>(mut self, hook: T) -> Self
300    where
301        T: Fn(&mut Client, &Metrics) -> HookResult<Error> + Sync + Send + 'static,
302    {
303        self.hooks.pre_recycle.push(Hook::sync_fn(hook));
304        self
305    }
306
307    /// Attach an `async fn` as hook to connection pool.
308    /// The hook will be called each time before a connection [`deadpool::managed::Object`] is recycled.
309    pub fn pre_recycle_async<T>(mut self, hook: T) -> Self
310    where
311        T: for<'a> Fn(&'a mut Client, &'a Metrics) -> HookFuture<'a, Error> + Sync + Send + 'static,
312    {
313        self.hooks.pre_recycle.push(Hook::async_fn(hook));
314        self
315    }
316
317    /// Attach a `sync fn` as hook to connection pool.
318    /// The hook will be called each time af after a connection [`deadpool::managed::Object`] is recycled.
319    pub fn post_recycle_sync<T>(mut self, hook: T) -> Self
320    where
321        T: Fn(&mut Client, &Metrics) -> HookResult<Error> + Sync + Send + 'static,
322    {
323        self.hooks.post_recycle.push(Hook::sync_fn(hook));
324        self
325    }
326
327    /// Attach an `async fn` as hook to connection pool.
328    /// The hook will be called each time after a connection [`deadpool::managed::Object`] is recycled.
329    pub fn post_recycle_async<T>(mut self, hook: T) -> Self
330    where
331        T: for<'a> Fn(&'a mut Client, &'a Metrics) -> HookFuture<'a, Error> + Sync + Send + 'static,
332    {
333        self.hooks.post_recycle.push(Hook::async_fn(hook));
334        self
335    }
336
337    /// Attach a `sync fn` as hook to connection pool.
338    /// The hook will be called each time after a connection [`deadpool::managed::Object`] is created.
339    pub fn post_create_sync<T>(mut self, hook: T) -> Self
340    where
341        T: Fn(&mut Client, &Metrics) -> HookResult<Error> + Sync + Send + 'static,
342    {
343        self.hooks.post_create.push(Hook::sync_fn(hook));
344        self
345    }
346
347    /// Attach an `async fn` as hook to connection pool.
348    /// The hook will be called each time after a connection [`deadpool::managed::Object`] is created.
349    pub fn post_create_async<T>(mut self, hook: T) -> Self
350    where
351        T: for<'a> Fn(&'a mut Client, &'a Metrics) -> HookFuture<'a, Error> + Sync + Send + 'static,
352    {
353        self.hooks.post_create.push(Hook::async_fn(hook));
354        self
355    }
356
357    fn set_runtime(&mut self, value: Runtime) {
358        self.runtime = Some(value);
359    }
360}
361
362struct Hooks {
363    pre_recycle: Vec<Hook<Manager>>,
364    post_recycle: Vec<Hook<Manager>>,
365    post_create: Vec<Hook<Manager>>,
366}
367
368impl Default for Hooks {
369    fn default() -> Self {
370        Hooks {
371            pre_recycle: Vec::<Hook<Manager>>::new(),
372            post_recycle: Vec::<Hook<Manager>>::new(),
373            post_create: Vec::<Hook<Manager>>::new(),
374        }
375    }
376}