bb8_tiberius/
lib.rs

1/// The error container
2#[derive(Debug, thiserror::Error)]
3pub enum Error {
4    #[error(transparent)]
5    Tiberius(#[from] tiberius::error::Error),
6    #[error(transparent)]
7    Io(#[from] std::io::Error),
8}
9
10/// Implemented for `&str` (ADO-style string) and `tiberius::Config`
11pub trait IntoConfig {
12    fn into_config(self) -> tiberius::Result<tiberius::Config>;
13}
14
15impl IntoConfig for &str {
16    fn into_config(self) -> tiberius::Result<tiberius::Config> {
17        tiberius::Config::from_ado_string(self)
18    }
19}
20
21impl IntoConfig for tiberius::Config {
22    fn into_config(self) -> tiberius::Result<tiberius::Config> {
23        Ok(self)
24    }
25}
26
27/// Implements `bb8::ManageConnection`
28#[allow(clippy::type_complexity)]
29pub struct ConnectionManager {
30    config: tiberius::Config,
31    #[cfg(feature = "with-tokio")]
32    modify_tcp_stream:
33        Box<dyn Fn(&tokio::net::TcpStream) -> tokio::io::Result<()> + Send + Sync + 'static>,
34    #[cfg(feature = "with-async-std")]
35    modify_tcp_stream: Box<
36        dyn Fn(&async_std::net::TcpStream) -> async_std::io::Result<()> + Send + Sync + 'static,
37    >,
38    #[cfg(feature = "sql-browser")]
39    use_named_connection: bool,
40}
41
42impl ConnectionManager {
43    /// Create a new `ConnectionManager`
44    pub fn new(config: tiberius::Config) -> Self {
45        Self {
46            config,
47            modify_tcp_stream: Box::new(|tcp_stream| tcp_stream.set_nodelay(true)),
48            #[cfg(feature = "sql-browser")]
49            use_named_connection: false
50        }
51    }
52
53    /// Build a `ConnectionManager` from e.g. an ADO string
54    pub fn build<I: IntoConfig>(config: I) -> Result<Self, Error> {
55        Ok(config.into_config().map(Self::new)?)
56    }
57
58    #[cfg(feature = "sql-browser")]
59    /// Use `tiberius::SqlBrowser::connect_named` to establish the TCP stream
60    pub fn using_named_connection(mut self) -> Self {
61        self.use_named_connection = true;
62        self
63    }
64}
65
/// Runtime-dependent parts, `tokio` flavour (enabled by `with-tokio`)
#[cfg(feature = "with-tokio")]
pub mod rt {

    /// The pooled connection type: a `tiberius` client over a tokio TCP
    /// stream wrapped in the `tokio_util` compat adapter.
    pub type Client = tiberius::Client<tokio_util::compat::Compat<tokio::net::TcpStream>>;

    impl super::ConnectionManager {
        /// Replace the callback that is run on every new TCP stream before
        /// the client handshake.
        pub fn with_modify_tcp_stream<F>(self, f: F) -> Self
        where
            F: Fn(&tokio::net::TcpStream) -> tokio::io::Result<()> + Send + Sync + 'static,
        {
            Self {
                modify_tcp_stream: Box::new(f),
                ..self
            }
        }

        /// Open the TCP stream, going through `SqlBrowser::connect_named`
        /// only when named connections were requested.
        #[cfg(feature = "sql-browser")]
        async fn connect_tcp(&self) -> Result<tokio::net::TcpStream, super::Error> {
            use tiberius::SqlBrowser;

            let stream = if self.use_named_connection {
                tokio::net::TcpStream::connect_named(&self.config).await?
            } else {
                tokio::net::TcpStream::connect(self.config.get_addr()).await?
            };

            Ok(stream)
        }

        /// Open the TCP stream directly at the configured address.
        #[cfg(not(feature = "sql-browser"))]
        async fn connect_tcp(&self) -> std::io::Result<tokio::net::TcpStream> {
            let addr = self.config.get_addr();
            tokio::net::TcpStream::connect(addr).await
        }

        /// Establish a new client: connect, apply the stream callback, then
        /// perform the `tiberius` handshake, following one routing redirect
        /// if the server asks for it.
        pub(crate) async fn connect_inner(&self) -> Result<Client, super::Error> {
            use tokio_util::compat::TokioAsyncWriteCompatExt;

            let stream = self.connect_tcp().await?;
            (self.modify_tcp_stream)(&stream)?;

            match Client::connect(self.config.clone(), stream.compat_write()).await {
                // Connected on the first attempt.
                Ok(client) => Ok(client),

                // The server pointed us at a different host/port.
                Err(tiberius::error::Error::Routing { host, port }) => {
                    let mut redirected = self.config.clone();
                    redirected.host(&host);
                    redirected.port(port);

                    let stream = tokio::net::TcpStream::connect(redirected.get_addr()).await?;
                    (self.modify_tcp_stream)(&stream)?;

                    // Only a single redirect is followed; a second routing
                    // error is returned to the caller like any other error.
                    Ok(Client::connect(redirected, stream.compat_write()).await?)
                }

                // Any other failure is passed through.
                Err(other) => Err(other.into()),
            }
        }
    }
}
134
/// Runtime-dependent parts, `async-std` flavour (enabled by `with-async-std`)
#[cfg(feature = "with-async-std")]
pub mod rt {

    /// The pooled connection type: a `tiberius` client over an async-std TCP
    /// stream.
    pub type Client = tiberius::Client<async_std::net::TcpStream>;

    impl super::ConnectionManager {
        /// Perform some configuration on the TCP stream when generating connections
        pub fn with_modify_tcp_stream<F>(mut self, f: F) -> Self
        where
            F: Fn(&async_std::net::TcpStream) -> async_std::io::Result<()> + Send + Sync + 'static,
        {
            self.modify_tcp_stream = Box::new(f);
            self
        }

        /// Open the TCP stream, going through `SqlBrowser::connect_named`
        /// only when `using_named_connection` was requested.
        ///
        /// The flag is honoured here so that `using_named_connection()` has
        /// an effect on this runtime, matching the `tokio` backend.
        #[cfg(feature = "sql-browser")]
        async fn connect_tcp(&self) -> Result<async_std::net::TcpStream, super::Error> {
            use tiberius::SqlBrowser;

            if self.use_named_connection {
                Ok(async_std::net::TcpStream::connect_named(&self.config).await?)
            } else {
                Ok(async_std::net::TcpStream::connect(self.config.get_addr()).await?)
            }
        }

        /// Open the TCP stream directly at the configured address.
        #[cfg(not(feature = "sql-browser"))]
        async fn connect_tcp(&self) -> std::io::Result<async_std::net::TcpStream> {
            async_std::net::TcpStream::connect(self.config.get_addr()).await
        }

        /// Establish a new client: connect, apply the stream callback, then
        /// perform the `tiberius` handshake, following one routing redirect
        /// if the server asks for it.
        pub(crate) async fn connect_inner(&self) -> Result<Client, super::Error> {
            let tcp = self.connect_tcp().await?;

            (self.modify_tcp_stream)(&tcp)?;

            match Client::connect(self.config.clone(), tcp).await {
                // Connection successful.
                Ok(client) => Ok(client),

                // The server wants us to redirect to a different address
                Err(tiberius::error::Error::Routing { host, port }) => {
                    let mut config = self.config.clone();

                    config.host(&host);
                    config.port(port);

                    let tcp = async_std::net::TcpStream::connect(config.get_addr()).await?;

                    (self.modify_tcp_stream)(&tcp)?;

                    // we should not have more than one redirect, so we'll short-circuit here.
                    Ok(Client::connect(config, tcp).await?)
                }

                // Other error happened
                Err(e) => Err(e.into()),
            }
        }
    }
}
194
195impl bb8::ManageConnection for ConnectionManager {
196    type Connection = rt::Client;
197    type Error = Error;
198
199    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
200        self.connect_inner().await
201    }
202
203    async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
204        conn.simple_query("SELECT 1").await?;
205        Ok(())
206    }
207
208    fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
209        false
210    }
211}