// coil_data/sqlx_postgres.rs

use std::collections::BTreeSet;
use std::fmt;
use std::str::FromStr;

use coil_config::DatabaseDriver;
use sqlx::postgres::{PgArguments, PgConnectOptions, PgPoolOptions};
use sqlx::{Column, Pool, Postgres, Row};

use crate::{
    CompiledMigrationBatch, CompiledStatement, CompiledTransaction, DataModelError, DataRuntime,
    DataValue, quote_identifier,
};

13#[derive(Debug, Clone)]
14pub struct PostgresDataClient {
15    pub runtime: DataRuntime,
16    pub connection_url: String,
17    pub pool: Pool<Postgres>,
18}
19
/// Result summary of [`PostgresDataClient::execute_statement`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StatementExecution {
    /// Row count the server reported for the executed statement.
    pub rows_affected: u64,
}

/// Result summary of [`PostgresDataClient::execute_query`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct QueryExecution {
    /// Number of rows the query fetched.
    pub rows_returned: usize,
    /// Column names of the result; `execute_query` reads them from the first row, so this
    /// is empty when no rows were returned.
    pub projected_columns: Vec<String>,
}

/// Result summary of [`PostgresDataClient::execute_transaction`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TransactionExecution {
    /// Number of statements run inside the committed transaction.
    pub statements_executed: usize,
}

/// Result summary of [`PostgresDataClient::apply_migrations`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct MigrationBatchExecution {
    /// Number of migration statements run inside the committed transaction.
    pub statements_executed: usize,
}

41impl PostgresDataClient {
42    pub(crate) fn connect_lazy(runtime: &DataRuntime) -> Result<Self, DataModelError> {
43        if runtime.driver != DatabaseDriver::Postgres {
44            return Err(DataModelError::UnsupportedSqlxDriver {
45                driver: runtime.driver,
46            });
47        }
48
49        let connection_url = runtime.resolve_connection_url()?;
50        let options = PgConnectOptions::from_str(&connection_url).map_err(|error| {
51            DataModelError::InvalidConnectionUrl {
52                reason: error.to_string(),
53            }
54        })?;
55        let pool = PgPoolOptions::new()
56            .min_connections(u32::from(runtime.pool.min_connections))
57            .max_connections(u32::from(runtime.pool.max_connections))
58            .acquire_timeout(runtime.pool.statement_timeout)
59            .connect_lazy_with(options);
60
61        Ok(Self {
62            runtime: runtime.clone(),
63            connection_url,
64            pool,
65        })
66    }
67
68    pub async fn ping(&self) -> Result<(), DataModelError> {
69        sqlx::query("SELECT 1")
70            .execute(&self.pool)
71            .await
72            .map_err(|error| DataModelError::Sqlx {
73                reason: error.to_string(),
74            })?;
75        Ok(())
76    }
77
78    pub async fn execute_statement(
79        &self,
80        statement: &CompiledStatement,
81    ) -> Result<StatementExecution, DataModelError> {
82        self.apply_statement_timeout().await?;
83        let result = bind_query(sqlx::query(&statement.sql), &statement.bind_values)?
84            .execute(&self.pool)
85            .await
86            .map_err(|error| DataModelError::Sqlx {
87                reason: error.to_string(),
88            })?;
89
90        Ok(StatementExecution {
91            rows_affected: result.rows_affected(),
92        })
93    }
94
95    pub async fn execute_query(
96        &self,
97        query: &crate::CompiledQuery,
98    ) -> Result<QueryExecution, DataModelError> {
99        self.apply_statement_timeout().await?;
100        let rows = bind_query(sqlx::query(&query.sql), &query.bind_values)?
101            .fetch_all(&self.pool)
102            .await
103            .map_err(|error| DataModelError::Sqlx {
104                reason: error.to_string(),
105            })?;
106
107        let projected_columns = rows
108            .first()
109            .map(|row| {
110                row.columns()
111                    .iter()
112                    .map(|column| column.name().to_string())
113                    .collect()
114            })
115            .unwrap_or_default();
116
117        Ok(QueryExecution {
118            rows_returned: rows.len(),
119            projected_columns,
120        })
121    }
122
123    pub async fn execute_transaction(
124        &self,
125        transaction: &CompiledTransaction,
126    ) -> Result<TransactionExecution, DataModelError> {
127        let mut tx = self
128            .pool
129            .begin()
130            .await
131            .map_err(|error| DataModelError::Sqlx {
132                reason: error.to_string(),
133            })?;
134
135        sqlx::query(&format!(
136            "SET LOCAL statement_timeout = {}",
137            self.runtime.pool.statement_timeout.as_millis()
138        ))
139        .execute(&mut *tx)
140        .await
141        .map_err(|error| DataModelError::Sqlx {
142            reason: error.to_string(),
143        })?;
144
145        for statement in &transaction.statements {
146            bind_query(sqlx::query(&statement.sql), &statement.bind_values)?
147                .execute(&mut *tx)
148                .await
149                .map_err(|error| DataModelError::Sqlx {
150                    reason: error.to_string(),
151                })?;
152        }
153
154        tx.commit().await.map_err(|error| DataModelError::Sqlx {
155            reason: error.to_string(),
156        })?;
157
158        Ok(TransactionExecution {
159            statements_executed: transaction.statements.len(),
160        })
161    }
162
163    pub async fn apply_migrations(
164        &self,
165        batch: &CompiledMigrationBatch,
166    ) -> Result<MigrationBatchExecution, DataModelError> {
167        let mut tx = self
168            .pool
169            .begin()
170            .await
171            .map_err(|error| DataModelError::Sqlx {
172                reason: error.to_string(),
173            })?;
174
175        sqlx::query(&format!(
176            "SET LOCAL statement_timeout = {}",
177            self.runtime.pool.statement_timeout.as_millis()
178        ))
179        .execute(&mut *tx)
180        .await
181        .map_err(|error| DataModelError::Sqlx {
182            reason: error.to_string(),
183        })?;
184
185        for statement in &batch.statements {
186            bind_query(sqlx::query(&statement.sql), &statement.bind_values)?
187                .execute(&mut *tx)
188                .await
189                .map_err(|error| DataModelError::Sqlx {
190                    reason: error.to_string(),
191                })?;
192        }
193
194        tx.commit().await.map_err(|error| DataModelError::Sqlx {
195            reason: error.to_string(),
196        })?;
197
198        Ok(MigrationBatchExecution {
199            statements_executed: batch.statements.len(),
200        })
201    }
202
203    pub async fn applied_migration_keys(
204        &self,
205    ) -> Result<BTreeSet<(String, String)>, DataModelError> {
206        let migrations_table = quote_identifier(&format!(
207            "{}.{}",
208            self.runtime.schema, self.runtime.migrations_table
209        ));
210        sqlx::query(&format!(
211            "CREATE TABLE IF NOT EXISTS {migrations_table} (owner TEXT NOT NULL, migration_id TEXT NOT NULL, description TEXT NOT NULL, applied_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), PRIMARY KEY (owner, migration_id))"
212        ))
213        .execute(&self.pool)
214        .await
215        .map_err(|error| DataModelError::Sqlx {
216            reason: error.to_string(),
217        })?;
218
219        let rows = sqlx::query(&format!(
220            "SELECT owner, migration_id FROM {migrations_table} ORDER BY owner ASC, migration_id ASC"
221        ))
222        .fetch_all(&self.pool)
223        .await
224        .map_err(|error| DataModelError::Sqlx {
225            reason: error.to_string(),
226        })?;
227
228        Ok(rows
229            .into_iter()
230            .map(|row| (row.get("owner"), row.get("migration_id")))
231            .collect())
232    }
233
234    async fn apply_statement_timeout(&self) -> Result<(), DataModelError> {
235        sqlx::query(&format!(
236            "SET statement_timeout = {}",
237            self.runtime.pool.statement_timeout.as_millis()
238        ))
239        .execute(&self.pool)
240        .await
241        .map_err(|error| DataModelError::Sqlx {
242            reason: error.to_string(),
243        })?;
244        Ok(())
245    }
246}
247
248pub(crate) fn bind_query<'q>(
249    mut query: sqlx::query::Query<'q, Postgres, PgArguments>,
250    values: &[DataValue],
251) -> Result<sqlx::query::Query<'q, Postgres, PgArguments>, DataModelError> {
252    for value in values {
253        query = match value {
254            DataValue::String(value) => query.bind(value.clone()),
255            DataValue::Int(value) => query.bind(*value),
256            DataValue::UInt(value) => {
257                let value = i64::try_from(*value)
258                    .map_err(|_| DataModelError::UnsupportedUnsignedBindValue { value: *value })?;
259                query.bind(value)
260            }
261            DataValue::Bool(value) => query.bind(*value),
262        };
263    }
264
265    Ok(query)
266}