use std::future::Future;
use std::time::Duration;
use bb8::{Pool, PooledConnection};
use bb8_tiberius::ConnectionManager;
use faucet_core::FaucetError;
use tiberius::{AuthMethod, Config, EncryptionLevel};
use crate::config::{MssqlConnectionConfig, MssqlTls, MssqlTlsMode, parse_connection_url};
/// bb8 connection pool whose connections are managed by `bb8_tiberius::ConnectionManager`.
pub type MssqlPool = Pool<ConnectionManager>;
/// A connection checked out of an [`MssqlPool`]; it borrows the pool for the lifetime `'a`.
pub type MssqlPooledConnection<'a> = PooledConnection<'a, ConnectionManager>;
/// Builds a tiberius [`Config`] from the connector configuration.
///
/// The connection source is `connection_url` when it is present; otherwise the ADO.NET-style
/// `connection_string` is parsed with [`Config::from_ado_string`]. Afterwards the settings
/// from `cfg.tls` are applied and the application name is set to `"faucet"`. Both happen last,
/// so they take precedence over any encryption or application-name values parsed from
/// `connection_string`.
///
/// # Errors
///
/// Returns an error if `cfg.validate()` fails, the URL cannot be parsed, tiberius rejects the
/// connection string, or the TLS settings cannot be applied.
///
/// # Panics
///
/// Panics if neither `connection_url` nor `connection_string` is set after `validate()`
/// succeeded, i.e. only if `validate()` stops upholding that invariant.
pub fn build_config(cfg: &MssqlConnectionConfig) -> Result<Config, FaucetError> {
    cfg.validate()?;

    let mut config = match cfg.connection_url.as_ref() {
        Some(url) => {
            let parts = parse_connection_url(url)?;
            let mut from_url = Config::new();
            from_url.host(parts.host);
            from_url.port(parts.port);
            if let Some(database) = parts.database {
                from_url.database(database);
            }
            from_url.authentication(AuthMethod::sql_server(parts.username, parts.password));
            from_url
        }
        None => {
            let ado = cfg
                .connection_string
                .as_ref()
                .expect("validate() guarantees one of url/string is set");
            Config::from_ado_string(ado).map_err(|e| {
                FaucetError::Config(format!("invalid MSSQL connection_string: {e}"))
            })?
        }
    };

    apply_tls(&mut config, &cfg.tls)?;
    config.application_name("faucet");
    Ok(config)
}
/// Applies the connector's TLS settings to a tiberius [`Config`].
///
/// `tls.mode` maps onto tiberius' [`EncryptionLevel`]:
/// - `Prefer` → `EncryptionLevel::On`
/// - `Require` → `EncryptionLevel::Required`
/// - `Disable` → `EncryptionLevel::NotSupported`
/// - `TrustServerCertificate` → `EncryptionLevel::On` plus [`Config::trust_cert`], so the
///   server certificate is accepted without validation
///
/// If `tls.ca_cert_path` is set and the mode is not `Disable`, the path is passed to
/// [`Config::trust_cert_ca`]. Under `Disable` the path is ignored.
///
/// # Errors
///
/// Returns [`FaucetError::Config`] if:
/// - `ca_cert_path` is set together with `TrustServerCertificate` mode. tiberius treats
///   `trust_cert` and `trust_cert_ca` as mutually exclusive and panics when both are called;
/// - `ca_cert_path` is not valid UTF-8.
fn apply_tls(config: &mut Config, tls: &MssqlTls) -> Result<(), FaucetError> {
    // Reject the conflicting combination before touching `config`. Otherwise the
    // `trust_cert()` below followed by `trust_cert_ca()` would panic inside tiberius.
    if tls.mode == MssqlTlsMode::TrustServerCertificate && tls.ca_cert_path.is_some() {
        return Err(FaucetError::Config(
            "MSSQL tls.ca_cert_path cannot be combined with TrustServerCertificate mode".into(),
        ));
    }

    match tls.mode {
        MssqlTlsMode::Prefer => config.encryption(EncryptionLevel::On),
        MssqlTlsMode::Require => config.encryption(EncryptionLevel::Required),
        MssqlTlsMode::Disable => config.encryption(EncryptionLevel::NotSupported),
        MssqlTlsMode::TrustServerCertificate => {
            config.encryption(EncryptionLevel::On);
            config.trust_cert();
        }
    }

    // NOTE(review): if the caller built `config` from an ADO connection string that itself
    // sets TrustServerCertificate (or a CA), tiberius has presumably already called the
    // matching trust_* method, so the call below may still conflict. `Config` exposes no
    // accessor to detect that here; confirm whether validate() guards against it.
    if tls.mode != MssqlTlsMode::Disable
        && let Some(path) = &tls.ca_cert_path
    {
        // tiberius takes the CA location as a string, so non-UTF-8 paths are rejected.
        let p = path.to_str().ok_or_else(|| {
            FaucetError::Config("MSSQL tls.ca_cert_path is not valid UTF-8".into())
        })?;
        config.trust_cert_ca(p);
    }
    Ok(())
}
/// Creates a bb8 connection pool for the given MSSQL configuration.
///
/// The pool holds at most `max_connections` connections; a value of `0` is clamped to `1`.
/// One connection is checked out before returning, so an unreachable server or bad
/// credentials surface here instead of on first use. That probe connection is dropped, and
/// so handed back to the pool, right away.
///
/// # Errors
///
/// Returns [`FaucetError::Config`] if the configuration is invalid, the pool cannot be
/// built, or the initial connection attempt fails.
pub async fn build_pool(
    cfg: &MssqlConnectionConfig,
    max_connections: u32,
) -> Result<MssqlPool, FaucetError> {
    let manager = ConnectionManager::new(build_config(cfg)?);
    let size = max_connections.max(1);

    let pool = Pool::builder()
        .max_size(size)
        .build(manager)
        .await
        .map_err(|err| FaucetError::Config(format!("MSSQL pool build failed: {err}")))?;

    // Fail fast: make sure at least one connection can actually be established.
    if let Err(err) = pool.get().await {
        return Err(FaucetError::Config(format!("MSSQL connection failed: {err}")));
    }

    Ok(pool)
}
/// Awaits `fut`, giving up once `timeout` has elapsed.
///
/// If `fut` completes in time, its own result is returned unchanged, including any error it
/// produced. If the deadline passes first, `fut` is cancelled and the error built by
/// `make_timeout_err` is returned instead. The closure is invoked only in that case.
pub async fn with_statement_timeout<F, T>(
    timeout: Duration,
    fut: F,
    make_timeout_err: impl FnOnce() -> FaucetError,
) -> Result<T, FaucetError>
where
    F: Future<Output = Result<T, FaucetError>>,
{
    tokio::time::timeout(timeout, fut)
        .await
        .unwrap_or_else(|_elapsed| Err(make_timeout_err()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A future that finishes well inside the deadline yields its own value.
    #[tokio::test]
    async fn timeout_passes_through_completed_future() {
        let out: Result<i32, FaucetError> =
            with_statement_timeout(Duration::from_secs(5), async { Ok(42) }, || {
                FaucetError::Source("unused".into())
            })
            .await;
        assert_eq!(out.unwrap(), 42);
    }

    /// A future that outlives the deadline is replaced by the caller-supplied error.
    #[tokio::test]
    async fn timeout_fires_on_slow_future() {
        let out: Result<i32, FaucetError> = with_statement_timeout(
            Duration::from_millis(10),
            async {
                tokio::time::sleep(Duration::from_secs(30)).await;
                Ok(1)
            },
            || FaucetError::Sink("query timed out".into()),
        )
        .await;
        let err = out.unwrap_err();
        assert!(matches!(err, FaucetError::Sink(_)));
        assert!(err.to_string().contains("timed out"));
    }

    /// Setting both a URL and a connection string must be rejected by validation.
    #[test]
    fn build_config_rejects_invalid_source_combo() {
        let both = MssqlConnectionConfig {
            connection_url: Some("mssql://sa:pw@h/db".into()),
            connection_string: Some("Server=h".into()),
            ..Default::default()
        };
        assert!(build_config(&both).is_err());
    }

    /// Setting neither a URL nor a connection string must also be rejected by validation,
    /// before `build_config` reaches its `expect`.
    #[test]
    fn build_config_rejects_missing_source() {
        assert!(build_config(&MssqlConnectionConfig::default()).is_err());
    }

    /// The URL branch produces a config.
    #[test]
    fn build_config_from_url_succeeds() {
        let cfg = MssqlConnectionConfig {
            connection_url: Some("mssql://sa:pw@localhost:1433/sales".into()),
            ..Default::default()
        };
        assert!(build_config(&cfg).is_ok());
    }

    /// The ADO.NET connection-string branch produces a config.
    #[test]
    fn build_config_from_connection_string_succeeds() {
        let cfg = MssqlConnectionConfig {
            connection_string: Some(
                "Server=tcp:localhost,1433;Database=sales;User Id=sa;Password=pw".into(),
            ),
            ..Default::default()
        };
        assert!(build_config(&cfg).is_ok());
    }
}