1use std::time::Duration;
4
5use async_trait::async_trait;
6use sqlx::Row;
7use sqlx::postgres::PgPoolOptions;
8
9use crate::introspector::SqlxIntrospector;
10use dbrest_core::backend::{DatabaseBackend, DbVersion, StatementResult};
11use dbrest_core::error::Error;
12use dbrest_core::query::sql_builder::{SqlBuilder, SqlParam};
13use dbrest_core::schema_cache::db::DbIntrospector;
14
15pub struct PgBackend {
17 pool: sqlx::PgPool,
18}
19
20impl PgBackend {
21 pub fn pool(&self) -> &sqlx::PgPool {
24 &self.pool
25 }
26
27 pub fn from_pool(pool: sqlx::PgPool) -> Self {
29 Self { pool }
30 }
31}
32
33fn bind_params<'q>(
38 mut q: sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments>,
39 params: &'q [SqlParam],
40) -> sqlx::query::Query<'q, sqlx::Postgres, sqlx::postgres::PgArguments> {
41 for p in params {
42 match p {
43 SqlParam::Text(t) => q = q.bind(t.as_str()),
44 SqlParam::Json(j) => q = q.bind(j.to_vec()),
45 SqlParam::Binary(b) => q = q.bind(b.to_vec()),
46 SqlParam::Null => q = q.bind(Option::<String>::None),
47 }
48 }
49 q
50}
51
52pub fn map_sqlx_error(e: sqlx::Error) -> Error {
55 let (code, message, detail, hint) = match &e {
56 sqlx::Error::Database(db_err) => {
57 let code = db_err.code().map(|c| c.to_string());
58 let message = db_err.message().to_string();
59 let detail = db_err.constraint().map(|c| c.to_string());
60
61 let hint = if let Some(pg_err) =
62 db_err.try_downcast_ref::<sqlx::postgres::PgDatabaseError>()
63 {
64 pg_err.hint().map(|s| s.to_string())
65 } else {
66 None
67 };
68
69 (code, message, detail, hint)
70 }
71 _ => {
72 return Error::Database {
73 code: None,
74 message: e.to_string(),
75 detail: None,
76 hint: None,
77 };
78 }
79 };
80
81 if code.is_some() || !message.is_empty() {
82 match code.as_deref() {
83 Some("23505") => return Error::UniqueViolation(message),
85 Some("23503") => return Error::ForeignKeyViolation(message),
86 Some("23514") => return Error::CheckViolation(message),
87 Some("23502") => return Error::NotNullViolation(message),
88 Some("23P01") => return Error::ExclusionViolation(message),
89
90 Some("42501") => {
92 let role =
93 extract_role_from_message(&message).unwrap_or_else(|| "unknown".to_string());
94 return Error::PermissionDenied { role };
95 }
96
97 Some("42883") => {
99 if message.contains("operator") {
100 return Error::Database {
101 code: Some("42883".to_string()),
102 message: message.clone(),
103 detail: Some(
104 "Operator error: The requested operator is not available for the given data types."
105 .to_string(),
106 ),
107 hint: Some(
108 "Check that the filter operator and column types are compatible."
109 .to_string(),
110 ),
111 };
112 }
113 let func_name =
114 extract_name_from_message(&message, "function").unwrap_or_else(|| {
115 tracing::debug!(
116 "Could not extract function name from PostgreSQL error: {}",
117 message
118 );
119 "unknown".to_string()
120 });
121 return Error::FunctionNotFound { name: func_name };
122 }
123 Some("42P01") => {
124 let table_name = extract_name_from_message(&message, "relation")
125 .unwrap_or_else(|| "unknown".to_string());
126 return Error::TableNotFound {
127 name: table_name,
128 suggestion: None,
129 };
130 }
131 Some("42703") => {
132 if let Some(col_start) = message.find("column ")
133 && let Some(after_col) = message.get(col_start + 7..)
134 {
135 let col_end = after_col.find(" does").unwrap_or(after_col.len());
136 let col_ref = &after_col[..col_end];
137 let col_ref = col_ref.trim();
138
139 let (table_name, col_name) = if let Some(dot_pos) = col_ref.find('.') {
140 let table = col_ref[..dot_pos].trim_matches('"').to_string();
141 let col = col_ref[dot_pos + 1..].trim_matches('"').to_string();
142 (table, col)
143 } else {
144 let col = col_ref.trim_matches('"').to_string();
145 ("unknown".to_string(), col)
146 };
147 return Error::ColumnNotFound {
148 table: table_name,
149 column: col_name,
150 };
151 }
152 return Error::InvalidQueryParam {
153 param: "column".to_string(),
154 message,
155 };
156 }
157
158 Some("P0001") => {
160 return Error::RaisedException {
161 message,
162 status: None,
163 };
164 }
165
166 Some(code) if code.starts_with("PT") => {
168 if let Some(status_str) = code.strip_prefix("PT")
169 && let Ok(status) = status_str.parse::<u16>()
170 {
171 return Error::DbrstRaise { message, status };
172 }
173 }
174
175 _ => {}
176 }
177
178 return Error::Database {
179 code,
180 message,
181 detail,
182 hint,
183 };
184 }
185
186 Error::Database {
187 code: None,
188 message: e.to_string(),
189 detail: None,
190 hint: None,
191 }
192}
193
/// Extracts the role name that follows the first `role ` in `msg`.
///
/// A double-quoted name (`role "web anon"`) is returned without its quotes and
/// may contain spaces; an unquoted name ends at the first space or line break.
///
/// Returns `None` when `msg` does not contain `role ` or when no name follows
/// it.
fn extract_role_from_message(msg: &str) -> Option<String> {
    let (_, rest) = msg.split_once("role ")?;
    let rest = rest.trim_start();

    let name = match rest.strip_prefix('"') {
        // Quoted identifier: everything up to the closing quote.
        Some(quoted) => quoted.split('"').next().unwrap_or(quoted),
        None => rest.split([' ', '\n', '\r']).next().unwrap_or(rest),
    };

    (!name.is_empty()).then(|| name.to_string())
}
204
/// Extracts the identifier that follows the first occurrence of `keyword` in
/// `msg`.
///
/// The identifier ends at the first space, comma, opening parenthesis or line
/// break; surrounding double quotes are removed. If that yields nothing, the
/// first whitespace-delimited token after the keyword is used instead.
///
/// Returns `None` when `keyword` is absent or no non-empty name can be found.
fn extract_name_from_message(msg: &str, keyword: &str) -> Option<String> {
    let keyword_at = msg.find(keyword)?;
    let tail = msg[keyword_at + keyword.len()..].trim_start();

    let delimited = tail
        .find([' ', ',', '(', '\n', '\r'])
        .map(|end| tail[..end].trim_matches('"'))
        .filter(|candidate| !candidate.is_empty());

    let name = match delimited {
        Some(candidate) => candidate,
        None => tail.split_whitespace().next()?.trim_matches('"'),
    };

    (!name.is_empty()).then(|| name.to_string())
}
226
227fn parse_statement_row(row: &sqlx::postgres::PgRow) -> StatementResult {
232 let total: Option<i64> = row
233 .try_get::<String, _>("total_result_set")
234 .ok()
235 .and_then(|s| s.parse::<i64>().ok());
236
237 let page_total: i64 = row.try_get("page_total").unwrap_or(0);
238
239 let body_str: String = row.try_get("body").unwrap_or_else(|_| "[]".to_string());
240
241 let response_headers: Option<serde_json::Value> = row
242 .try_get::<Option<String>, _>("response_headers")
243 .ok()
244 .flatten()
245 .and_then(|s| {
246 if s.is_empty() {
247 None
248 } else {
249 serde_json::from_str(&s).ok()
250 }
251 });
252
253 let response_status: Option<i32> = row
254 .try_get::<Option<String>, _>("response_status")
255 .ok()
256 .flatten()
257 .and_then(|s| {
258 if s.is_empty() {
259 None
260 } else {
261 s.parse::<i32>().ok()
262 }
263 });
264
265 StatementResult {
266 total,
267 page_total,
268 body: body_str,
269 response_headers,
270 response_status,
271 }
272}
273
274#[async_trait]
279impl DatabaseBackend for PgBackend {
280 async fn connect(
281 uri: &str,
282 pool_size: u32,
283 acquire_timeout_secs: u64,
284 max_lifetime_secs: u64,
285 idle_timeout_secs: u64,
286 ) -> Result<Self, Error> {
287 let pool = PgPoolOptions::new()
288 .max_connections(pool_size)
289 .acquire_timeout(Duration::from_secs(acquire_timeout_secs))
290 .max_lifetime(Duration::from_secs(max_lifetime_secs))
291 .idle_timeout(Duration::from_secs(idle_timeout_secs))
292 .connect(uri)
293 .await
294 .map_err(|e| Error::DbConnection(e.to_string()))?;
295
296 Ok(Self { pool })
297 }
298
299 async fn version(&self) -> Result<DbVersion, Error> {
300 let row: (String,) = sqlx::query_as("SHOW server_version")
301 .fetch_one(&self.pool)
302 .await
303 .map_err(|e| Error::DbConnection(format!("Failed to query PG version: {}", e)))?;
304
305 let version_str = &row.0;
306 let parts: Vec<&str> = version_str.split('.').collect();
307 Ok(DbVersion {
308 major: parts.first().and_then(|s| s.parse().ok()).unwrap_or(0),
309 minor: parts.get(1).and_then(|s| s.parse().ok()).unwrap_or(0),
310 patch: parts
311 .get(2)
312 .and_then(|s| s.split_whitespace().next().and_then(|v| v.parse().ok()))
313 .unwrap_or(0),
314 engine: "PostgreSQL".to_string(),
315 })
316 }
317
318 fn min_version(&self) -> (u32, u32) {
319 (12, 0)
320 }
321
322 async fn exec_raw(&self, sql: &str, params: &[SqlParam]) -> Result<(), Error> {
323 let q = sqlx::query(sql);
324 let q = bind_params(q, params);
325 q.execute(&self.pool).await.map_err(map_sqlx_error)?;
326 Ok(())
327 }
328
329 async fn exec_statement(
330 &self,
331 sql: &str,
332 params: &[SqlParam],
333 ) -> Result<StatementResult, Error> {
334 let q = sqlx::query(sql);
335 let q = bind_params(q, params);
336 let rows = q.fetch_all(&self.pool).await.map_err(map_sqlx_error)?;
337
338 if rows.is_empty() {
339 return Ok(StatementResult::empty());
340 }
341
342 Ok(parse_statement_row(&rows[0]))
343 }
344
345 async fn exec_in_transaction(
346 &self,
347 tx_vars: Option<&SqlBuilder>,
348 pre_req: Option<&SqlBuilder>,
349 _mutation: Option<&SqlBuilder>,
350 main: Option<&SqlBuilder>,
351 ) -> Result<StatementResult, Error> {
352 let mut tx = self.pool.begin().await.map_err(|e| Error::Database {
353 code: None,
354 message: e.to_string(),
355 detail: None,
356 hint: None,
357 })?;
358
359 if let Some(tv) = tx_vars {
361 let q = sqlx::query(tv.sql());
362 let q = bind_params(q, tv.params());
363 q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
364 }
365
366 if let Some(pr) = pre_req {
368 let q = sqlx::query(pr.sql());
369 let q = bind_params(q, pr.params());
370 q.execute(&mut *tx).await.map_err(map_sqlx_error)?;
371 }
372
373 let result = if let Some(main_q) = main {
375 let q = sqlx::query(main_q.sql());
376 let q = bind_params(q, main_q.params());
377 let rows = q.fetch_all(&mut *tx).await.map_err(map_sqlx_error)?;
378
379 if rows.is_empty() {
380 StatementResult::empty()
381 } else {
382 parse_statement_row(&rows[0])
383 }
384 } else {
385 StatementResult::empty()
386 };
387
388 tx.commit().await.map_err(|e| Error::Database {
389 code: None,
390 message: e.to_string(),
391 detail: None,
392 hint: None,
393 })?;
394
395 Ok(result)
396 }
397
398 fn introspector(&self) -> Box<dyn DbIntrospector + '_> {
399 Box::new(SqlxIntrospector::new(&self.pool))
400 }
401
402 async fn start_listener(
403 &self,
404 channel: &str,
405 cancel: tokio::sync::watch::Receiver<bool>,
406 on_event: std::sync::Arc<dyn Fn(String) + Send + Sync>,
407 ) -> Result<(), Error> {
408 let mut listener = sqlx::postgres::PgListener::connect_with(&self.pool)
409 .await
410 .map_err(|e| Error::Database {
411 code: None,
412 message: e.to_string(),
413 detail: None,
414 hint: None,
415 })?;
416
417 listener
418 .listen(channel)
419 .await
420 .map_err(|e| Error::Database {
421 code: None,
422 message: e.to_string(),
423 detail: None,
424 hint: None,
425 })?;
426
427 tracing::info!(channel = channel, "Subscribed to NOTIFY channel");
428
429 loop {
432 if *cancel.borrow() {
433 return Ok(());
434 }
435
436 let notification = tokio::time::timeout(Duration::from_secs(30), listener.recv()).await;
437
438 let maybe_payload: Option<Result<String, sqlx::Error>> = match notification {
442 Ok(Ok(msg)) => Some(Ok(msg.payload().to_string())),
443 Ok(Err(e)) => Some(Err(e)),
444 Err(_) => None,
445 };
446
447 match maybe_payload {
448 Some(Ok(payload)) => {
449 tracing::info!(payload = %payload, "Received NOTIFY");
450 on_event(payload);
451 }
452 Some(Err(e)) => {
453 return Err(Error::Database {
454 code: None,
455 message: e.to_string(),
456 detail: None,
457 hint: None,
458 });
459 }
460 None => continue,
461 }
462 }
463 }
464
465 fn map_error(&self, err: Box<dyn std::error::Error + Send + Sync>) -> Error {
466 if let Ok(sqlx_err) = err.downcast::<sqlx::Error>() {
467 map_sqlx_error(*sqlx_err)
468 } else {
469 Error::Internal("Unknown database error".to_string())
470 }
471 }
472}