// endpoint_libs/libs/database/pooled.rs

1use dashmap::DashMap;
2use deadpool_postgres::Runtime;
3use deadpool_postgres::*;
4use eyre::*;
5use postgres_from_row::FromRow;
6use secrecy::ExposeSecret;
7use std::collections::hash_map::DefaultHasher;
8use std::fmt::Debug;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::result::Result::Ok;
12use std::sync::Arc;
13use std::time::Duration;
14pub use tokio_postgres::types::ToSql;
15use tokio_postgres::Statement;
16pub use tokio_postgres::{NoTls, Row, ToStatement};
17use tracing::*;
18
19use crate::libs::datatable::RDataTable;
20
21use super::DatabaseConfig;
22use super::DatabaseRequest;
23
24#[derive(Clone)]
25pub struct PooledDbClient {
26    pool: Pool,
27    prepared_stmts: Arc<DashMap<String, Statement>>,
28    conn_hash: u64,
29}
30impl PooledDbClient {
31    #[deprecated]
32    pub async fn query<T>(
33        &self,
34        statement: &T,
35        params: &[&(dyn ToSql + Sync)],
36    ) -> Result<Vec<Row>, Error>
37    where
38        T: ?Sized + Sync + Send + ToStatement,
39    {
40        Ok(self
41            .pool
42            .get()
43            .await
44            .context("Failed to connect to database")?
45            .query(statement, params)
46            .await?)
47    }
48
49    pub async fn execute<T: DatabaseRequest + Debug>(
50        &self,
51        req: T,
52    ) -> Result<RDataTable<T::ResponseRow>> {
53        let mut error = None;
54        for _ in 0..2 {
55            let begin = std::time::Instant::now();
56            let client = self
57                .pool
58                .get()
59                .await
60                .context("Failed to connect to database")?;
61            let statement =
62                tokio::time::timeout(Duration::from_secs(20), client.prepare_cached(req.statement()))
63                    .await
64                    .context("timeout preparing statement")??;
65            let rows = match tokio::time::timeout(
66                Duration::from_secs(20),
67                client.query(&statement, &req.params()),
68            )
69            .await
70            .context(format!("timeout executing statement: {}, params: {:?}", req.statement(), req.params()))?
71            {
72                Ok(rows) => rows,
73                Err(err) => {
74                    let reason = err.to_string();
75                    if reason.contains("cache lookup failed for type")
76                        || reason.contains("cached plan must not change result type")
77                        || reason.contains("prepared statement")
78                    {
79                        warn!("Database has been updated. Cleaning cache and retrying query");
80                        self.prepared_stmts.clear();
81                        error = Some(err);
82                        continue;
83                    }
84                    return Err(err.into());
85                }
86            };
87            let dur = begin.elapsed();
88            debug!(
89                "Database query took {}.{:03} seconds: {:?}",
90                dur.as_secs(),
91                dur.subsec_millis(),
92                req
93            );
94            let mut response = RDataTable::with_capacity(rows.len());
95            for row in rows {
96                response.push(T::ResponseRow::try_from_row(&row)?);
97            }
98            return Ok(response);
99        }
100        Err(error.unwrap().into())
101    }
102    pub fn conn_hash(&self) -> u64 {
103        self.conn_hash
104    }
105}
106
107pub async fn connect_to_database(config: DatabaseConfig) -> Result<PooledDbClient> {
108    let config = Config {
109        user: config.user,
110        password: config.password.map(|s| s.expose_secret().clone()),
111        dbname: config.dbname,
112        options: config.options,
113        application_name: config.application_name,
114        ssl_mode: config.ssl_mode,
115        host: config.host,
116        hosts: config.hosts,
117        port: config.port,
118        ports: config.ports,
119        connect_timeout: config.connect_timeout,
120        keepalives: config.keepalives,
121        keepalives_idle: config.keepalives_idle,
122        target_session_attrs: config.target_session_attrs,
123        channel_binding: config.channel_binding,
124        manager: config.manager.or_else(|| {
125            Some(ManagerConfig {
126                recycling_method: RecyclingMethod::Fast,
127            })
128        }),
129        pool: config.pool,
130        ..Default::default()
131    };
132    info!(
133        "Connecting to database {}:{} {}",
134        config.host.as_deref().unwrap_or(""),
135        config.port.unwrap_or(0),
136        config.dbname.as_deref().unwrap_or("")
137    );
138    let mut hasher = DefaultHasher::new();
139    config.host.hash(&mut hasher);
140    config.port.hash(&mut hasher);
141    config.dbname.hash(&mut hasher);
142    let conn_hash = hasher.finish();
143
144    let pool = config.create_pool(Some(Runtime::Tokio1), NoTls)?;
145    Ok(PooledDbClient {
146        pool,
147        prepared_stmts: Arc::new(Default::default()),
148        conn_hash,
149    })
150}