/// Error type used by the connection manager in this file.
///
/// Each variant wraps an underlying error and carries a `#[from]`
/// conversion, so `?` can lift either source error into this type.
/// `#[error(transparent)]` forwards the message to the wrapped error.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An error reported by `tiberius`.
#[error(transparent)]
Tiberius(#[from] tiberius::error::Error),
/// A `std::io::Error`, e.g. from opening or configuring the TCP stream.
#[error(transparent)]
Io(#[from] std::io::Error),
}
/// Conversion of a value into a `tiberius::Config`.
///
/// `ConnectionManager::build` accepts any implementor of this trait.
pub trait IntoConfig {
/// Consumes `self` and yields the configuration, or the `tiberius`
/// error raised while producing it.
fn into_config(self) -> tiberius::Result<tiberius::Config>;
}
impl IntoConfig for &str {
    /// Treats the string as an ADO connection string and hands it to
    /// `tiberius::Config::from_ado_string`, returning that call's result.
    fn into_config(self) -> tiberius::Result<tiberius::Config> {
        let parsed = tiberius::Config::from_ado_string(self)?;
        Ok(parsed)
    }
}
impl IntoConfig for tiberius::Config {
    /// Already a configuration: returned unchanged, never fails.
    fn into_config(self) -> tiberius::Result<Self> {
        Ok(self)
    }
}
/// Connection manager holding a `tiberius::Config` together with a hook
/// that is run on every TCP stream it opens.
///
/// Exactly one `modify_tcp_stream` field is compiled in, chosen by the
/// `with-tokio` / `with-async-std` feature flags.
// NOTE(review): the allow is presumably here because the boxed `dyn Fn`
// field types below trip `clippy::type_complexity`.
#[allow(clippy::type_complexity)]
pub struct ConnectionManager {
/// Configuration used (and cloned) for each connection attempt.
config: tiberius::Config,
/// Hook invoked with each newly opened tokio TCP stream; an `Err`
/// from it aborts the connection attempt.
#[cfg(feature = "with-tokio")]
modify_tcp_stream:
Box<dyn Fn(&tokio::net::TcpStream) -> tokio::io::Result<()> + Send + Sync + 'static>,
/// Hook invoked with each newly opened async-std TCP stream; an `Err`
/// from it aborts the connection attempt.
#[cfg(feature = "with-async-std")]
modify_tcp_stream: Box<
dyn Fn(&async_std::net::TcpStream) -> async_std::io::Result<()> + Send + Sync + 'static,
>,
/// Whether `using_named_connection` was called; starts out `false`.
#[cfg(feature = "sql-browser")]
use_named_connection: bool,
}
impl ConnectionManager {
pub fn new(config: tiberius::Config) -> Self {
Self {
config,
modify_tcp_stream: Box::new(|tcp_stream| tcp_stream.set_nodelay(true)),
#[cfg(feature = "sql-browser")]
use_named_connection: false
}
}
pub fn build<I: IntoConfig>(config: I) -> Result<Self, Error> {
Ok(config.into_config().map(Self::new)?)
}
#[cfg(feature = "sql-browser")]
pub fn using_named_connection(mut self) -> Self {
self.use_named_connection = true;
self
}
}
#[cfg(feature = "with-tokio")]
pub mod rt {
    /// Client type produced under the `with-tokio` feature: a `tiberius`
    /// client over a tokio TCP stream wrapped in `tokio_util`'s `Compat`.
    pub type Client = tiberius::Client<tokio_util::compat::Compat<tokio::net::TcpStream>>;

    impl super::ConnectionManager {
        /// Replaces the hook that is run on every newly opened TCP stream.
        pub fn with_modify_tcp_stream<F>(self, f: F) -> Self
        where
            F: Fn(&tokio::net::TcpStream) -> tokio::io::Result<()> + Send + Sync + 'static,
        {
            Self {
                modify_tcp_stream: Box::new(f),
                ..self
            }
        }

        /// Opens the initial TCP stream, going through
        /// `SqlBrowser::connect_named` only when `use_named_connection`
        /// is set; otherwise connects directly to `config.get_addr()`.
        #[cfg(feature = "sql-browser")]
        async fn connect_tcp(&self) -> Result<tokio::net::TcpStream, super::Error> {
            use tiberius::SqlBrowser;

            let stream = if self.use_named_connection {
                tokio::net::TcpStream::connect_named(&self.config).await?
            } else {
                tokio::net::TcpStream::connect(self.config.get_addr()).await?
            };
            Ok(stream)
        }

        /// Opens the initial TCP stream directly at `config.get_addr()`.
        #[cfg(not(feature = "sql-browser"))]
        async fn connect_tcp(&self) -> std::io::Result<tokio::net::TcpStream> {
            let address = self.config.get_addr();
            tokio::net::TcpStream::connect(address).await
        }

        /// Establishes a client connection.
        ///
        /// The TCP hook runs on every stream before the client handshake.
        /// If the first attempt fails with a `Routing` error, one further
        /// attempt is made against the host and port carried by that error.
        pub(crate) async fn connect_inner(&self) -> Result<Client, super::Error> {
            use tokio_util::compat::TokioAsyncWriteCompatExt;

            let stream = self.connect_tcp().await?;
            (self.modify_tcp_stream)(&stream)?;

            // Anything other than a routing redirect ends the function here.
            let first_attempt = Client::connect(self.config.clone(), stream.compat_write()).await;
            let (host, port) = match first_attempt {
                Ok(client) => return Ok(client),
                Err(tiberius::error::Error::Routing { host, port }) => (host, port),
                Err(other) => return Err(other.into()),
            };

            // Follow the redirect with a copy of the config pointed at the new target.
            let mut redirected = self.config.clone();
            redirected.host(&host);
            redirected.port(port);

            let stream = tokio::net::TcpStream::connect(redirected.get_addr()).await?;
            (self.modify_tcp_stream)(&stream)?;
            let client = Client::connect(redirected, stream.compat_write()).await?;
            Ok(client)
        }
    }
}
#[cfg(feature = "with-async-std")]
pub mod rt {
    /// Client type produced under the `with-async-std` feature: a
    /// `tiberius` client over an async-std TCP stream.
    pub type Client = tiberius::Client<async_std::net::TcpStream>;

    impl super::ConnectionManager {
        /// Replaces the hook that is run on every newly opened TCP stream.
        pub fn with_modify_tcp_stream<F>(mut self, f: F) -> Self
        where
            F: Fn(&async_std::net::TcpStream) -> async_std::io::Result<()> + Send + Sync + 'static,
        {
            self.modify_tcp_stream = Box::new(f);
            self
        }

        /// Opens the initial TCP stream.
        ///
        /// `SqlBrowser::connect_named` is used only when
        /// `using_named_connection` was requested; otherwise the stream is
        /// opened directly at `config.get_addr()`. This matches the tokio
        /// runtime module, so the flag behaves the same under both runtimes.
        ///
        /// # Errors
        ///
        /// Returns the `tiberius` error from the named lookup or the I/O
        /// error from the direct connect, wrapped in `Error`.
        #[cfg(feature = "sql-browser")]
        async fn connect_tcp(&self) -> Result<async_std::net::TcpStream, super::Error> {
            use tiberius::SqlBrowser;

            if self.use_named_connection {
                Ok(async_std::net::TcpStream::connect_named(&self.config).await?)
            } else {
                Ok(async_std::net::TcpStream::connect(self.config.get_addr()).await?)
            }
        }

        /// Opens the initial TCP stream directly at `config.get_addr()`.
        #[cfg(not(feature = "sql-browser"))]
        async fn connect_tcp(&self) -> std::io::Result<async_std::net::TcpStream> {
            async_std::net::TcpStream::connect(self.config.get_addr()).await
        }

        /// Establishes a client connection.
        ///
        /// The TCP hook runs on every stream before the client handshake.
        /// If the first attempt fails with a `Routing` error, one further
        /// attempt is made against the host and port carried by that error.
        ///
        /// # Errors
        ///
        /// Propagates failures from opening the stream, from the TCP hook,
        /// and from `tiberius::Client::connect`.
        pub(crate) async fn connect_inner(&self) -> Result<Client, super::Error> {
            let tcp = self.connect_tcp().await?;
            (self.modify_tcp_stream)(&tcp)?;

            match Client::connect(self.config.clone(), tcp).await {
                Ok(client) => Ok(client),
                Err(tiberius::error::Error::Routing { host, port }) => {
                    // Follow the redirect with a copy of the config
                    // pointed at the target the server supplied.
                    let mut config = self.config.clone();
                    config.host(&host);
                    config.port(port);

                    let tcp = async_std::net::TcpStream::connect(config.get_addr()).await?;
                    (self.modify_tcp_stream)(&tcp)?;
                    Ok(Client::connect(config, tcp).await?)
                }
                Err(e) => Err(e.into()),
            }
        }
    }
}
impl bb8::ManageConnection for ConnectionManager {
    type Connection = rt::Client;
    type Error = Error;

    /// Opens a new connection by delegating to `connect_inner`.
    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
        let client = self.connect_inner().await?;
        Ok(client)
    }

    /// Issues `SELECT 1` on the connection; any error from the query
    /// marks the connection as invalid.
    async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
        match conn.simple_query("SELECT 1").await {
            Ok(_) => Ok(()),
            Err(err) => Err(err.into()),
        }
    }

    /// Always reports the connection as not broken.
    fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
        false
    }
}