// db_testkit/backends/postgres.rs
use std::str::FromStr;
2
3use async_trait::async_trait;
4use tokio_postgres::{Client, Config, NoTls};
5use url::Url;
6
7use crate::{
8 backend::{Connection, DatabaseBackend, DatabasePool},
9 error::{PoolError, Result},
10 pool::PoolConfig,
11 template::DatabaseName,
12};
13
14pub struct PostgresConnection {
15 pub(crate) client: Client,
16}
17
18#[async_trait]
19impl Connection for PostgresConnection {
20 async fn is_valid(&self) -> bool {
21 self.client.simple_query("SELECT 1").await.is_ok()
22 }
23
24 async fn reset(&mut self) -> Result<()> {
25 self.client
26 .simple_query("DISCARD ALL")
27 .await
28 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
29 Ok(())
30 }
31
32 async fn execute(&mut self, sql: &str) -> Result<()> {
33 self.client
34 .batch_execute(sql)
35 .await
36 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
37 Ok(())
38 }
39}
40
41#[derive(Debug, Clone)]
42pub struct PostgresBackend {
43 config: Config,
44 #[allow(unused)]
45 url: Url,
46}
47
48impl PostgresBackend {
49 pub async fn new(connection_string: &str) -> Result<Self> {
50 let config = Config::from_str(connection_string)
51 .map_err(|e| PoolError::ConfigError(format!("Invalid connection string: {}", e)))?;
52 let url = Url::parse(connection_string)
53 .map_err(|e| PoolError::ConfigError(format!("Invalid connection string: {}", e)))?;
54
55 let mut postgres_url = url.clone();
57 postgres_url.set_path("/postgres");
58
59 let postgres_config = Config::from_str(postgres_url.as_str())
60 .map_err(|e| PoolError::ConfigError(format!("Invalid connection string: {}", e)))?;
61
62 if let Ok((client, connection)) = postgres_config.connect(NoTls).await {
64 tokio::spawn(async move {
65 if let Err(e) = connection.await {
66 tracing::error!("Connection error: {}", e);
67 }
68 });
69
70 let db_name = url.path().trim_start_matches('/');
71 let _ = client
72 .execute(&format!(r#"CREATE DATABASE "{}""#, db_name), &[])
73 .await;
74 }
75
76 Ok(Self { config, url })
77 }
78
79 #[allow(dead_code)]
80 fn get_database_url(&self, name: &DatabaseName) -> String {
81 let mut url = self.url.clone();
82 url.set_path(name.as_str());
83 url.to_string()
84 }
85}
86
87#[async_trait]
88impl DatabaseBackend for PostgresBackend {
89 type Connection = PostgresConnection;
90 type Pool = PostgresPool;
91
92 async fn create_database(&self, name: &DatabaseName) -> Result<()> {
93 let (client, connection) = self
94 .config
95 .connect(NoTls)
96 .await
97 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
98
99 tokio::spawn(async move {
100 if let Err(e) = connection.await {
101 tracing::error!("Connection error: {}", e);
102 }
103 });
104
105 client
106 .execute(&format!(r#"CREATE DATABASE "{}""#, name), &[])
107 .await
108 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
109
110 Ok(())
111 }
112
113 async fn drop_database(&self, name: &DatabaseName) -> Result<()> {
114 self.terminate_connections(name).await?;
116
117 let (client, connection) = self
118 .config
119 .connect(NoTls)
120 .await
121 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
122
123 tokio::spawn(async move {
124 if let Err(e) = connection.await {
125 tracing::error!("Connection error: {}", e);
126 }
127 });
128
129 client
130 .execute(&format!(r#"DROP DATABASE IF EXISTS "{}""#, name), &[])
131 .await
132 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
133
134 Ok(())
135 }
136
137 async fn create_pool(&self, name: &DatabaseName, config: &PoolConfig) -> Result<Self::Pool> {
138 let mut pool_config = self.config.clone();
139 pool_config.dbname(name.as_str());
140 Ok(PostgresPool::new(pool_config, config.max_size))
141 }
142
143 async fn terminate_connections(&self, name: &DatabaseName) -> Result<()> {
144 let (client, connection) = self
145 .config
146 .connect(NoTls)
147 .await
148 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
149
150 tokio::spawn(async move {
151 if let Err(e) = connection.await {
152 tracing::error!("Connection error: {}", e);
153 }
154 });
155
156 client
157 .execute(
158 &format!(
159 r#"
160 SELECT pg_terminate_backend(pid)
161 FROM pg_stat_activity
162 WHERE datname = '{}'
163 AND pid <> pg_backend_pid()
164 "#,
165 name
166 ),
167 &[],
168 )
169 .await
170 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
171
172 Ok(())
173 }
174
175 async fn create_database_from_template(
176 &self,
177 name: &DatabaseName,
178 template: &DatabaseName,
179 ) -> Result<()> {
180 let (client, connection) = self
181 .config
182 .connect(NoTls)
183 .await
184 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
185
186 tokio::spawn(async move {
187 if let Err(e) = connection.await {
188 tracing::error!("Connection error: {}", e);
189 }
190 });
191
192 client
193 .execute(
194 &format!(r#"CREATE DATABASE "{}" TEMPLATE "{}""#, name, template),
195 &[],
196 )
197 .await
198 .map_err(|e| PoolError::DatabaseError(e.to_string()))?;
199
200 Ok(())
201 }
202}
203
204#[derive(Debug, Clone)]
205pub struct PostgresPool {
206 config: Config,
207 #[allow(unused)]
208 max_size: usize,
209}
210
211impl PostgresPool {
212 pub fn new(config: Config, max_size: usize) -> Self {
213 Self { config, max_size }
214 }
215}
216
217#[async_trait]
218impl DatabasePool for PostgresPool {
219 type Connection = PostgresConnection;
220
221 async fn acquire(&self) -> Result<Self::Connection> {
222 let (client, connection) = self
223 .config
224 .connect(NoTls)
225 .await
226 .map_err(|e| PoolError::ConnectionAcquisitionFailed(e.to_string()))?;
227
228 tokio::spawn(async move {
229 if let Err(e) = connection.await {
230 tracing::error!("Connection error: {}", e);
231 }
232 });
233
234 Ok(PostgresConnection { client })
235 }
236
237 async fn release(&self, _conn: Self::Connection) -> Result<()> {
238 Ok(())
240 }
241}