db_testkit/backends/
postgres.rs

1use std::str::FromStr;
2
3use async_trait::async_trait;
4use tokio_postgres::{Client, Config, NoTls};
5use url::Url;
6
7use crate::{
8    backend::{Connection, DatabaseBackend, DatabasePool},
9    error::{PoolError, Result},
10    pool::PoolConfig,
11    template::DatabaseName,
12};
13
14pub struct PostgresConnection {
15    pub(crate) client: Client,
16}
17
18#[async_trait]
19impl Connection for PostgresConnection {
20    async fn is_valid(&self) -> bool {
21        self.client.simple_query("SELECT 1").await.is_ok()
22    }
23
24    async fn reset(&mut self) -> Result<()> {
25        self.client
26            .simple_query("DISCARD ALL")
27            .await
28            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
29        Ok(())
30    }
31
32    async fn execute(&mut self, sql: &str) -> Result<()> {
33        self.client
34            .batch_execute(sql)
35            .await
36            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
37        Ok(())
38    }
39}
40
41#[derive(Debug, Clone)]
42pub struct PostgresBackend {
43    config: Config,
44    #[allow(unused)]
45    url: Url,
46}
47
48impl PostgresBackend {
49    pub async fn new(connection_string: &str) -> Result<Self> {
50        let config = Config::from_str(connection_string)
51            .map_err(|e| PoolError::ConfigError(format!("Invalid connection string: {}", e)))?;
52        let url = Url::parse(connection_string)
53            .map_err(|e| PoolError::ConfigError(format!("Invalid connection string: {}", e)))?;
54
55        // Create a connection to postgres database
56        let mut postgres_url = url.clone();
57        postgres_url.set_path("/postgres");
58
59        let postgres_config = Config::from_str(postgres_url.as_str())
60            .map_err(|e| PoolError::ConfigError(format!("Invalid connection string: {}", e)))?;
61
62        // Try to connect and create the database
63        if let Ok((client, connection)) = postgres_config.connect(NoTls).await {
64            tokio::spawn(async move {
65                if let Err(e) = connection.await {
66                    tracing::error!("Connection error: {}", e);
67                }
68            });
69
70            let db_name = url.path().trim_start_matches('/');
71            let _ = client
72                .execute(&format!(r#"CREATE DATABASE "{}""#, db_name), &[])
73                .await;
74        }
75
76        Ok(Self { config, url })
77    }
78
79    #[allow(dead_code)]
80    fn get_database_url(&self, name: &DatabaseName) -> String {
81        let mut url = self.url.clone();
82        url.set_path(name.as_str());
83        url.to_string()
84    }
85}
86
87#[async_trait]
88impl DatabaseBackend for PostgresBackend {
89    type Connection = PostgresConnection;
90    type Pool = PostgresPool;
91
92    async fn create_database(&self, name: &DatabaseName) -> Result<()> {
93        let (client, connection) = self
94            .config
95            .connect(NoTls)
96            .await
97            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
98
99        tokio::spawn(async move {
100            if let Err(e) = connection.await {
101                tracing::error!("Connection error: {}", e);
102            }
103        });
104
105        client
106            .execute(&format!(r#"CREATE DATABASE "{}""#, name), &[])
107            .await
108            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
109
110        Ok(())
111    }
112
113    async fn drop_database(&self, name: &DatabaseName) -> Result<()> {
114        // First terminate all connections
115        self.terminate_connections(name).await?;
116
117        let (client, connection) = self
118            .config
119            .connect(NoTls)
120            .await
121            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
122
123        tokio::spawn(async move {
124            if let Err(e) = connection.await {
125                tracing::error!("Connection error: {}", e);
126            }
127        });
128
129        client
130            .execute(&format!(r#"DROP DATABASE IF EXISTS "{}""#, name), &[])
131            .await
132            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
133
134        Ok(())
135    }
136
137    async fn create_pool(&self, name: &DatabaseName, config: &PoolConfig) -> Result<Self::Pool> {
138        let mut pool_config = self.config.clone();
139        pool_config.dbname(name.as_str());
140        Ok(PostgresPool::new(pool_config, config.max_size))
141    }
142
143    async fn terminate_connections(&self, name: &DatabaseName) -> Result<()> {
144        let (client, connection) = self
145            .config
146            .connect(NoTls)
147            .await
148            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
149
150        tokio::spawn(async move {
151            if let Err(e) = connection.await {
152                tracing::error!("Connection error: {}", e);
153            }
154        });
155
156        client
157            .execute(
158                &format!(
159                    r#"
160                    SELECT pg_terminate_backend(pid)
161                    FROM pg_stat_activity
162                    WHERE datname = '{}'
163                    AND pid <> pg_backend_pid()
164                    "#,
165                    name
166                ),
167                &[],
168            )
169            .await
170            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
171
172        Ok(())
173    }
174
175    async fn create_database_from_template(
176        &self,
177        name: &DatabaseName,
178        template: &DatabaseName,
179    ) -> Result<()> {
180        let (client, connection) = self
181            .config
182            .connect(NoTls)
183            .await
184            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
185
186        tokio::spawn(async move {
187            if let Err(e) = connection.await {
188                tracing::error!("Connection error: {}", e);
189            }
190        });
191
192        client
193            .execute(
194                &format!(r#"CREATE DATABASE "{}" TEMPLATE "{}""#, name, template),
195                &[],
196            )
197            .await
198            .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
199
200        Ok(())
201    }
202}
203
204#[derive(Debug, Clone)]
205pub struct PostgresPool {
206    config: Config,
207    #[allow(unused)]
208    max_size: usize,
209}
210
211impl PostgresPool {
212    pub fn new(config: Config, max_size: usize) -> Self {
213        Self { config, max_size }
214    }
215}
216
217#[async_trait]
218impl DatabasePool for PostgresPool {
219    type Connection = PostgresConnection;
220
221    async fn acquire(&self) -> Result<Self::Connection> {
222        let (client, connection) = self
223            .config
224            .connect(NoTls)
225            .await
226            .map_err(|e| PoolError::ConnectionAcquisitionFailed(e.to_string()))?;
227
228        tokio::spawn(async move {
229            if let Err(e) = connection.await {
230                tracing::error!("Connection error: {}", e);
231            }
232        });
233
234        Ok(PostgresConnection { client })
235    }
236
237    async fn release(&self, _conn: Self::Connection) -> Result<()> {
238        // Connection is automatically closed when dropped
239        Ok(())
240    }
241}