1use anyhow::Result;
2use native_tls::TlsConnector;
3use postgres_native_tls::MakeTlsConnector;
4use std::sync::Arc;
5use tokio::net::TcpStream;
6use tokio::sync::Mutex;
7use tokio_postgres::NoTls;
8use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
9
/// Identifies which database backend a [`DatabaseRef`] is connected to.
///
/// The enum carries no data, so it is `Copy` and can be compared or
/// debug-printed freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DbKind {
    /// Microsoft SQL Server.
    Mssql,
    /// PostgreSQL.
    Postgres,
}
14
15pub enum DatabaseRef {
16 Mssql(Arc<Mutex<tiberius::Client<Compat<TcpStream>>>>),
17 Postgres(Arc<tokio_postgres::Client>),
18}
19
20impl DatabaseRef {
21 pub fn kind(&self) -> DbKind {
22 match self {
23 DatabaseRef::Mssql(_) => DbKind::Mssql,
24 DatabaseRef::Postgres(_) => DbKind::Postgres,
25 }
26 }
27}
28
29pub async fn connect_mssql(
30 host: &str,
31 port: u16,
32 db: &str,
33 user: &str,
34 pass: &str,
35) -> Result<DatabaseRef> {
36 let mut config = tiberius::Config::new();
37 config.host(host);
38 config.port(port);
39 config.database(db);
40 config.authentication(tiberius::AuthMethod::sql_server(user, pass));
41 config.trust_cert();
42
43 let tcp = TcpStream::connect((host, port)).await?;
44 tcp.set_nodelay(true)?;
45 let client = tiberius::Client::connect(config, tcp.compat_write()).await?;
46 Ok(DatabaseRef::Mssql(Arc::new(Mutex::new(client))))
47}
48
49pub async fn connect_postgres(
50 host: &str,
51 port: u16,
52 db: &str,
53 user: &str,
54 pass: &str,
55) -> Result<DatabaseRef> {
56 let base = format!(
57 "host={} port={} dbname={} user={} password={}",
58 host, port, db, user, pass
59 );
60
61 let builder = TlsConnector::builder()
62 .danger_accept_invalid_certs(true)
63 .danger_accept_invalid_hostnames(true)
64 .build()?;
65 let connector = MakeTlsConnector::new(builder);
66 let tls_config = format!("{} sslmode=require", base);
67
68 match tokio_postgres::connect(&tls_config, connector).await {
69 Ok((client, connection)) => {
70 tokio::spawn(async move {
71 if let Err(e) = connection.await {
72 eprintln!("postgres connection error: {}", e);
73 }
74 });
75 Ok(DatabaseRef::Postgres(Arc::new(client)))
76 }
77 Err(e) if e.to_string().contains("server does not support TLS") => {
78 let plain_config = format!("{} sslmode=disable", base);
79 let (client, connection) = tokio_postgres::connect(&plain_config, NoTls).await?;
80 tokio::spawn(async move {
81 if let Err(e) = connection.await {
82 eprintln!("postgres connection error: {}", e);
83 }
84 });
85 Ok(DatabaseRef::Postgres(Arc::new(client)))
86 }
87 Err(e) => Err(e.into()),
88 }
89}