endpoint_libs/libs/database/pooled.rs

1use dashmap::DashMap;
2use deadpool_postgres::Runtime;
3use deadpool_postgres::*;
4use eyre::*;
5use postgres_from_row::FromRow;
6use secrecy::ExposeSecret;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt::Debug;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::result::Result::Ok;
12use std::sync::Arc;
13use std::time::Duration;
14pub use tokio_postgres::types::ToSql;
15use tokio_postgres::Statement;
16pub use tokio_postgres::{NoTls, Row, ToStatement};
17use tracing::*;
18
19use crate::libs::datatable::RDataTable;
20
21use super::DatabaseConfig;
22use super::DatabaseRequest;
23
24#[derive(Clone)]
25pub struct PooledDbClient {
26    pool: Pool,
27    prepared_stmts: Arc<DashMap<String, Statement>>,
28    conn_hash: u64,
29}
30impl PooledDbClient {
31    #[deprecated]
32    pub async fn query<T>(
33        &self,
34        statement: &T,
35        params: &[&(dyn ToSql + Sync)],
36    ) -> Result<Vec<Row>, Error>
37    where
38        T: ?Sized + Sync + Send + ToStatement,
39    {
40        Ok(self
41            .pool
42            .get()
43            .await
44            .context("Failed to connect to database")?
45            .query(statement, params)
46            .await?)
47    }
48
49    pub async fn execute<T: DatabaseRequest + Debug>(
50        &self,
51        req: T,
52    ) -> Result<RDataTable<T::ResponseRow>> {
53        let mut error = None;
54        for _ in 0..2 {
55            let begin = std::time::Instant::now();
56            let client = self
57                .pool
58                .get()
59                .await
60                .context("Failed to connect to database")?;
61            let statement = tokio::time::timeout(
62                Duration::from_secs(20),
63                client.prepare_cached(req.statement()),
64            )
65            .await
66            .context("timeout preparing statement")??;
67            let rows = match tokio::time::timeout(
68                Duration::from_secs(20),
69                client.query(&statement, &req.params()),
70            )
71            .await
72            .context(format!(
73                "timeout executing statement: {}, params: {:?}",
74                req.statement(),
75                req.params()
76            ))? {
77                Ok(rows) => rows,
78                Err(err) => {
79                    let reason = err.to_string();
80                    if reason.contains("cache lookup failed for type")
81                        || reason.contains("cached plan must not change result type")
82                        || reason.contains("prepared statement")
83                    {
84                        warn!("Database has been updated. Cleaning cache and retrying query");
85                        self.prepared_stmts.clear();
86                        error = Some(err);
87                        continue;
88                    }
89                    return Err(err.into());
90                }
91            };
92            let dur = begin.elapsed();
93            debug!(
94                "Database query took {}.{:03} seconds: {:?}",
95                dur.as_secs(),
96                dur.subsec_millis(),
97                req
98            );
99            let mut response = RDataTable::with_capacity(rows.len());
100            for row in rows {
101                response.push(T::ResponseRow::try_from_row(&row)?);
102            }
103            return Ok(response);
104        }
105        Err(error.unwrap().into())
106    }
107    pub fn conn_hash(&self) -> u64 {
108        self.conn_hash
109    }
110}
111
112pub async fn connect_to_database(config: DatabaseConfig) -> Result<PooledDbClient> {
113    let config = Config {
114        user: config.user,
115        password: config.password.map(|s| s.expose_secret().clone()),
116        dbname: config.dbname,
117        options: config.options,
118        application_name: config.application_name,
119        ssl_mode: config.ssl_mode,
120        host: config.host,
121        hosts: config.hosts,
122        port: config.port,
123        ports: config.ports,
124        connect_timeout: config.connect_timeout,
125        keepalives: config.keepalives,
126        keepalives_idle: config.keepalives_idle,
127        target_session_attrs: config.target_session_attrs,
128        channel_binding: config.channel_binding,
129        manager: config.manager.or({
130            Some(ManagerConfig {
131                recycling_method: RecyclingMethod::Fast,
132            })
133        }),
134        pool: config.pool,
135        ..Default::default()
136    };
137    info!(
138        "Connecting to database {}:{} {}",
139        config.host.as_deref().unwrap_or(""),
140        config.port.unwrap_or(0),
141        config.dbname.as_deref().unwrap_or("")
142    );
143    let mut hasher = DefaultHasher::new();
144    config.host.hash(&mut hasher);
145    config.port.hash(&mut hasher);
146    config.dbname.hash(&mut hasher);
147    let conn_hash = hasher.finish();
148
149    let pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
150    Ok(PooledDbClient {
151        pool,
152        prepared_stmts: Arc::new(Default::default()),
153        conn_hash,
154    })
155}