1#[derive(Debug, thiserror::Error)]
3pub enum Error {
4 #[error(transparent)]
5 Tiberius(#[from] tiberius::error::Error),
6 #[error(transparent)]
7 Io(#[from] std::io::Error),
8}
9
10pub trait IntoConfig {
12 fn into_config(self) -> tiberius::Result<tiberius::Config>;
13}
14
15impl IntoConfig for &str {
16 fn into_config(self) -> tiberius::Result<tiberius::Config> {
17 tiberius::Config::from_ado_string(self)
18 }
19}
20
21impl IntoConfig for tiberius::Config {
22 fn into_config(self) -> tiberius::Result<tiberius::Config> {
23 Ok(self)
24 }
25}
26
27#[allow(clippy::type_complexity)]
29pub struct ConnectionManager {
30 config: tiberius::Config,
31 #[cfg(feature = "with-tokio")]
32 modify_tcp_stream:
33 Box<dyn Fn(&tokio::net::TcpStream) -> tokio::io::Result<()> + Send + Sync + 'static>,
34 #[cfg(feature = "with-async-std")]
35 modify_tcp_stream: Box<
36 dyn Fn(&async_std::net::TcpStream) -> async_std::io::Result<()> + Send + Sync + 'static,
37 >,
38 #[cfg(feature = "sql-browser")]
39 use_named_connection: bool,
40}
41
42impl ConnectionManager {
43 pub fn new(config: tiberius::Config) -> Self {
45 Self {
46 config,
47 modify_tcp_stream: Box::new(|tcp_stream| tcp_stream.set_nodelay(true)),
48 #[cfg(feature = "sql-browser")]
49 use_named_connection: false
50 }
51 }
52
53 pub fn build<I: IntoConfig>(config: I) -> Result<Self, Error> {
55 Ok(config.into_config().map(Self::new)?)
56 }
57
58 #[cfg(feature = "sql-browser")]
59 pub fn using_named_connection(mut self) -> Self {
61 self.use_named_connection = true;
62 self
63 }
64}
65
#[cfg(feature = "with-tokio")]
pub mod rt {
    //! Connection logic for the tokio runtime.

    /// The pooled client: tiberius over a tokio TCP stream wrapped in
    /// tokio-util's `Compat` adapter.
    pub type Client = tiberius::Client<tokio_util::compat::Compat<tokio::net::TcpStream>>;

    impl super::ConnectionManager {
        /// Replaces the hook that runs on every newly opened TCP stream.
        ///
        /// The default hook set by `new` enables `TCP_NODELAY`.
        pub fn with_modify_tcp_stream<F>(mut self, f: F) -> Self
        where
            F: Fn(&tokio::net::TcpStream) -> tokio::io::Result<()> + Send + Sync + 'static,
        {
            self.modify_tcp_stream = Box::new(f);
            self
        }

        /// Opens the initial TCP stream.
        ///
        /// Uses SQL Browser resolution only when `using_named_connection` was
        /// requested; otherwise it connects directly to the configured address.
        #[cfg(feature = "sql-browser")]
        async fn connect_tcp(&self) -> Result<tokio::net::TcpStream, super::Error> {
            use tiberius::SqlBrowser;

            let stream = if self.use_named_connection {
                tokio::net::TcpStream::connect_named(&self.config).await?
            } else {
                tokio::net::TcpStream::connect(self.config.get_addr()).await?
            };
            Ok(stream)
        }

        /// Opens the initial TCP stream to the configured address.
        #[cfg(not(feature = "sql-browser"))]
        async fn connect_tcp(&self) -> std::io::Result<tokio::net::TcpStream> {
            tokio::net::TcpStream::connect(self.config.get_addr()).await
        }

        /// Opens a stream, applies the stream hook, and logs in with
        /// `Client::connect`.
        ///
        /// If the server answers with a routing redirect, it reconnects once to
        /// the given host and port and logs in again. Any other error is
        /// returned.
        pub(crate) async fn connect_inner(&self) -> Result<Client, super::Error> {
            use tokio::net::TcpStream;
            use tokio_util::compat::TokioAsyncWriteCompatExt;

            let stream = self.connect_tcp().await?;
            (self.modify_tcp_stream)(&stream)?;

            // Either finish here or learn where the server wants us to go.
            let (host, port) = match Client::connect(self.config.clone(), stream.compat_write()).await
            {
                Ok(client) => return Ok(client),
                Err(tiberius::error::Error::Routing { host, port }) => (host, port),
                Err(other) => return Err(other.into()),
            };

            // Redirect: same settings, new endpoint, plain TCP connect.
            let mut routed = self.config.clone();
            routed.host(&host);
            routed.port(port);

            let stream = TcpStream::connect(routed.get_addr()).await?;
            (self.modify_tcp_stream)(&stream)?;

            Ok(Client::connect(routed, stream.compat_write()).await?)
        }
    }
}
134
#[cfg(feature = "with-async-std")]
pub mod rt {
    //! Connection logic for the async-std runtime.

    /// The pooled client: tiberius directly over an async-std TCP stream.
    pub type Client = tiberius::Client<async_std::net::TcpStream>;

    impl super::ConnectionManager {
        /// Replaces the hook that runs on every newly opened TCP stream.
        ///
        /// The default hook set by `new` enables `TCP_NODELAY`.
        pub fn with_modify_tcp_stream<F>(mut self, f: F) -> Self
        where
            F: Fn(&async_std::net::TcpStream) -> async_std::io::Result<()> + Send + Sync + 'static,
        {
            self.modify_tcp_stream = Box::new(f);
            self
        }

        /// Opens the initial TCP stream through `SqlBrowser::connect_named`.
        ///
        /// NOTE(review): unlike the tokio runtime, this always goes through
        /// `connect_named` and never reads `use_named_connection`, so
        /// `using_named_connection()` has no effect here. Honoring the flag
        /// would change the default path for existing users, so confirm the
        /// intended behavior before aligning it.
        #[cfg(feature = "sql-browser")]
        async fn connect_tcp(&self) -> tiberius::Result<async_std::net::TcpStream> {
            use tiberius::SqlBrowser;

            let config = &self.config;
            async_std::net::TcpStream::connect_named(config).await
        }

        /// Opens the initial TCP stream to the configured address.
        #[cfg(not(feature = "sql-browser"))]
        async fn connect_tcp(&self) -> std::io::Result<async_std::net::TcpStream> {
            let addr = self.config.get_addr();
            async_std::net::TcpStream::connect(addr).await
        }

        /// Opens a stream, applies the stream hook, and logs in with
        /// `Client::connect`.
        ///
        /// If the server answers with a routing redirect, it reconnects once to
        /// the given host and port and logs in again. Any other error is
        /// returned.
        pub(crate) async fn connect_inner(&self) -> Result<Client, super::Error> {
            let stream = self.connect_tcp().await?;
            (self.modify_tcp_stream)(&stream)?;

            // Either finish here or learn where the server wants us to go.
            let (host, port) = match Client::connect(self.config.clone(), stream).await {
                Ok(client) => return Ok(client),
                Err(tiberius::error::Error::Routing { host, port }) => (host, port),
                Err(other) => return Err(other.into()),
            };

            // Redirect: same settings, new endpoint, plain TCP connect.
            let mut routed = self.config.clone();
            routed.host(&host);
            routed.port(port);

            let stream = async_std::net::TcpStream::connect(routed.get_addr()).await?;
            (self.modify_tcp_stream)(&stream)?;

            Ok(Client::connect(routed, stream).await?)
        }
    }
}
194
195impl bb8::ManageConnection for ConnectionManager {
196 type Connection = rt::Client;
197 type Error = Error;
198
199 async fn connect(&self) -> Result<Self::Connection, Self::Error> {
200 self.connect_inner().await
201 }
202
203 async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
204 conn.simple_query("SELECT 1").await?;
205 Ok(())
206 }
207
208 fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
209 false
210 }
211}