rquery_orm/
db.rs

1use anyhow::Result;
2use native_tls::TlsConnector;
3use postgres_native_tls::MakeTlsConnector;
4use std::sync::Arc;
5use tokio::net::TcpStream;
6use tokio::sync::Mutex;
7use tokio_postgres::NoTls;
8use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt};
9
/// Identifies which database backend a `DatabaseRef` is connected to.
///
/// This is a plain tag, so it is `Copy` and comparable. Callers can branch on the
/// backend (e.g. to pick an SQL dialect) without borrowing the connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DbKind {
    /// Microsoft SQL Server (connected through `tiberius`).
    Mssql,
    /// PostgreSQL (connected through `tokio_postgres`).
    Postgres,
}
14
15pub enum DatabaseRef {
16    Mssql(Arc<Mutex<tiberius::Client<Compat<TcpStream>>>>),
17    Postgres(Arc<tokio_postgres::Client>),
18}
19
20impl DatabaseRef {
21    pub fn kind(&self) -> DbKind {
22        match self {
23            DatabaseRef::Mssql(_) => DbKind::Mssql,
24            DatabaseRef::Postgres(_) => DbKind::Postgres,
25        }
26    }
27}
28
29pub async fn connect_mssql(
30    host: &str,
31    port: u16,
32    db: &str,
33    user: &str,
34    pass: &str,
35) -> Result<DatabaseRef> {
36    let mut config = tiberius::Config::new();
37    config.host(host);
38    config.port(port);
39    config.database(db);
40    config.authentication(tiberius::AuthMethod::sql_server(user, pass));
41    config.trust_cert();
42
43    let tcp = TcpStream::connect((host, port)).await?;
44    tcp.set_nodelay(true)?;
45    let client = tiberius::Client::connect(config, tcp.compat_write()).await?;
46    Ok(DatabaseRef::Mssql(Arc::new(Mutex::new(client))))
47}
48
49pub async fn connect_postgres(
50    host: &str,
51    port: u16,
52    db: &str,
53    user: &str,
54    pass: &str,
55) -> Result<DatabaseRef> {
56    let base = format!(
57        "host={} port={} dbname={} user={} password={}",
58        host, port, db, user, pass
59    );
60
61    let builder = TlsConnector::builder()
62        .danger_accept_invalid_certs(true)
63        .danger_accept_invalid_hostnames(true)
64        .build()?;
65    let connector = MakeTlsConnector::new(builder);
66    let tls_config = format!("{} sslmode=require", base);
67
68    match tokio_postgres::connect(&tls_config, connector).await {
69        Ok((client, connection)) => {
70            tokio::spawn(async move {
71                if let Err(e) = connection.await {
72                    eprintln!("postgres connection error: {}", e);
73                }
74            });
75            Ok(DatabaseRef::Postgres(Arc::new(client)))
76        }
77        Err(e) if e.to_string().contains("server does not support TLS") => {
78            let plain_config = format!("{} sslmode=disable", base);
79            let (client, connection) = tokio_postgres::connect(&plain_config, NoTls).await?;
80            tokio::spawn(async move {
81                if let Err(e) = connection.await {
82                    eprintln!("postgres connection error: {}", e);
83                }
84            });
85            Ok(DatabaseRef::Postgres(Arc::new(client)))
86        }
87        Err(e) => Err(e.into()),
88    }
89}