1use std::collections::BTreeSet;
2use std::str::FromStr;
3
4use coil_config::DatabaseDriver;
5use sqlx::postgres::{PgArguments, PgConnectOptions, PgPoolOptions};
6use sqlx::{Column, Pool, Postgres, Row};
7
8use crate::{
9 CompiledMigrationBatch, CompiledStatement, CompiledTransaction, DataModelError, DataRuntime,
10 DataValue, quote_identifier,
11};
12
13#[derive(Debug, Clone)]
14pub struct PostgresDataClient {
15 pub runtime: DataRuntime,
16 pub connection_url: String,
17 pub pool: Pool<Postgres>,
18}
19
/// Result summary of [`PostgresDataClient::execute_statement`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StatementExecution {
    /// Row count reported by the database for the executed statement.
    pub rows_affected: u64,
}
24
/// Result summary of [`PostgresDataClient::execute_query`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct QueryExecution {
    /// Number of rows the query returned.
    pub rows_returned: usize,
    /// Column names, taken from the first returned row (empty when no rows came back).
    pub projected_columns: Vec<String>,
}
30
/// Result summary of [`PostgresDataClient::execute_transaction`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TransactionExecution {
    /// Number of statements run inside the committed transaction.
    pub statements_executed: usize,
}
35
/// Result summary of [`PostgresDataClient::apply_migrations`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MigrationBatchExecution {
    /// Number of migration statements run inside the committed transaction.
    pub statements_executed: usize,
}
40
41impl PostgresDataClient {
42 pub(crate) fn connect_lazy(runtime: &DataRuntime) -> Result<Self, DataModelError> {
43 if runtime.driver != DatabaseDriver::Postgres {
44 return Err(DataModelError::UnsupportedSqlxDriver {
45 driver: runtime.driver,
46 });
47 }
48
49 let connection_url = runtime.resolve_connection_url()?;
50 let options = PgConnectOptions::from_str(&connection_url).map_err(|error| {
51 DataModelError::InvalidConnectionUrl {
52 reason: error.to_string(),
53 }
54 })?;
55 let pool = PgPoolOptions::new()
56 .min_connections(u32::from(runtime.pool.min_connections))
57 .max_connections(u32::from(runtime.pool.max_connections))
58 .acquire_timeout(runtime.pool.statement_timeout)
59 .connect_lazy_with(options);
60
61 Ok(Self {
62 runtime: runtime.clone(),
63 connection_url,
64 pool,
65 })
66 }
67
68 pub async fn ping(&self) -> Result<(), DataModelError> {
69 sqlx::query("SELECT 1")
70 .execute(&self.pool)
71 .await
72 .map_err(|error| DataModelError::Sqlx {
73 reason: error.to_string(),
74 })?;
75 Ok(())
76 }
77
78 pub async fn execute_statement(
79 &self,
80 statement: &CompiledStatement,
81 ) -> Result<StatementExecution, DataModelError> {
82 self.apply_statement_timeout().await?;
83 let result = bind_query(sqlx::query(&statement.sql), &statement.bind_values)?
84 .execute(&self.pool)
85 .await
86 .map_err(|error| DataModelError::Sqlx {
87 reason: error.to_string(),
88 })?;
89
90 Ok(StatementExecution {
91 rows_affected: result.rows_affected(),
92 })
93 }
94
95 pub async fn execute_query(
96 &self,
97 query: &crate::CompiledQuery,
98 ) -> Result<QueryExecution, DataModelError> {
99 self.apply_statement_timeout().await?;
100 let rows = bind_query(sqlx::query(&query.sql), &query.bind_values)?
101 .fetch_all(&self.pool)
102 .await
103 .map_err(|error| DataModelError::Sqlx {
104 reason: error.to_string(),
105 })?;
106
107 let projected_columns = rows
108 .first()
109 .map(|row| {
110 row.columns()
111 .iter()
112 .map(|column| column.name().to_string())
113 .collect()
114 })
115 .unwrap_or_default();
116
117 Ok(QueryExecution {
118 rows_returned: rows.len(),
119 projected_columns,
120 })
121 }
122
123 pub async fn execute_transaction(
124 &self,
125 transaction: &CompiledTransaction,
126 ) -> Result<TransactionExecution, DataModelError> {
127 let mut tx = self
128 .pool
129 .begin()
130 .await
131 .map_err(|error| DataModelError::Sqlx {
132 reason: error.to_string(),
133 })?;
134
135 sqlx::query(&format!(
136 "SET LOCAL statement_timeout = {}",
137 self.runtime.pool.statement_timeout.as_millis()
138 ))
139 .execute(&mut *tx)
140 .await
141 .map_err(|error| DataModelError::Sqlx {
142 reason: error.to_string(),
143 })?;
144
145 for statement in &transaction.statements {
146 bind_query(sqlx::query(&statement.sql), &statement.bind_values)?
147 .execute(&mut *tx)
148 .await
149 .map_err(|error| DataModelError::Sqlx {
150 reason: error.to_string(),
151 })?;
152 }
153
154 tx.commit().await.map_err(|error| DataModelError::Sqlx {
155 reason: error.to_string(),
156 })?;
157
158 Ok(TransactionExecution {
159 statements_executed: transaction.statements.len(),
160 })
161 }
162
163 pub async fn apply_migrations(
164 &self,
165 batch: &CompiledMigrationBatch,
166 ) -> Result<MigrationBatchExecution, DataModelError> {
167 let mut tx = self
168 .pool
169 .begin()
170 .await
171 .map_err(|error| DataModelError::Sqlx {
172 reason: error.to_string(),
173 })?;
174
175 sqlx::query(&format!(
176 "SET LOCAL statement_timeout = {}",
177 self.runtime.pool.statement_timeout.as_millis()
178 ))
179 .execute(&mut *tx)
180 .await
181 .map_err(|error| DataModelError::Sqlx {
182 reason: error.to_string(),
183 })?;
184
185 for statement in &batch.statements {
186 bind_query(sqlx::query(&statement.sql), &statement.bind_values)?
187 .execute(&mut *tx)
188 .await
189 .map_err(|error| DataModelError::Sqlx {
190 reason: error.to_string(),
191 })?;
192 }
193
194 tx.commit().await.map_err(|error| DataModelError::Sqlx {
195 reason: error.to_string(),
196 })?;
197
198 Ok(MigrationBatchExecution {
199 statements_executed: batch.statements.len(),
200 })
201 }
202
203 pub async fn applied_migration_keys(
204 &self,
205 ) -> Result<BTreeSet<(String, String)>, DataModelError> {
206 let migrations_table = quote_identifier(&format!(
207 "{}.{}",
208 self.runtime.schema, self.runtime.migrations_table
209 ));
210 sqlx::query(&format!(
211 "CREATE TABLE IF NOT EXISTS {migrations_table} (owner TEXT NOT NULL, migration_id TEXT NOT NULL, description TEXT NOT NULL, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), PRIMARY KEY (owner, migration_id))"
212 ))
213 .execute(&self.pool)
214 .await
215 .map_err(|error| DataModelError::Sqlx {
216 reason: error.to_string(),
217 })?;
218
219 let rows = sqlx::query(&format!(
220 "SELECT owner, migration_id FROM {migrations_table} ORDER BY owner ASC, migration_id ASC"
221 ))
222 .fetch_all(&self.pool)
223 .await
224 .map_err(|error| DataModelError::Sqlx {
225 reason: error.to_string(),
226 })?;
227
228 Ok(rows
229 .into_iter()
230 .map(|row| (row.get("owner"), row.get("migration_id")))
231 .collect())
232 }
233
234 async fn apply_statement_timeout(&self) -> Result<(), DataModelError> {
235 sqlx::query(&format!(
236 "SET statement_timeout = {}",
237 self.runtime.pool.statement_timeout.as_millis()
238 ))
239 .execute(&self.pool)
240 .await
241 .map_err(|error| DataModelError::Sqlx {
242 reason: error.to_string(),
243 })?;
244 Ok(())
245 }
246}
247
248pub(crate) fn bind_query<'q>(
249 mut query: sqlx::query::Query<'q, Postgres, PgArguments>,
250 values: &[DataValue],
251) -> Result<sqlx::query::Query<'q, Postgres, PgArguments>, DataModelError> {
252 for value in values {
253 query = match value {
254 DataValue::String(value) => query.bind(value.clone()),
255 DataValue::Int(value) => query.bind(*value),
256 DataValue::UInt(value) => {
257 let value = i64::try_from(*value)
258 .map_err(|_| DataModelError::UnsupportedUnsignedBindValue { value: *value })?;
259 query.bind(value)
260 }
261 DataValue::Bool(value) => query.bind(*value),
262 };
263 }
264
265 Ok(query)
266}