1use std::fmt::Display;
2
3use crate::{errors::PgmqError, types::Message};
4use log::LevelFilter;
5use serde::Deserialize;
6use sqlx::error::Error;
7use sqlx::postgres::PgRow;
8use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
9use sqlx::ConnectOptions;
10use sqlx::Row;
11use sqlx::{Pool, Postgres};
12use url::{ParseError, Url};
13
14pub fn conn_options(url: &str) -> Result<PgConnectOptions, ParseError> {
16 let parsed = Url::parse(url)?;
18 let options = PgConnectOptions::new()
19 .host(parsed.host_str().ok_or(ParseError::EmptyHost)?)
20 .port(parsed.port().ok_or(ParseError::InvalidPort)?)
21 .username(parsed.username())
22 .password(parsed.password().ok_or(ParseError::IdnaError)?)
23 .database(parsed.path().trim_start_matches('/'))
24 .log_statements(LevelFilter::Debug);
25 Ok(options)
26}
27
28pub async fn connect(url: &str, max_connections: u32) -> Result<Pool<Postgres>, PgmqError> {
30 let options = conn_options(url)?;
31 let pgp = PgPoolOptions::new()
32 .acquire_timeout(std::time::Duration::from_secs(10))
33 .max_connections(max_connections)
34 .connect_with(options)
35 .await?;
36 Ok(pgp)
37}
38
39pub async fn fetch_one_message<T: for<'de> Deserialize<'de>>(
43 query: &str,
44 connection: &Pool<Postgres>,
45) -> Result<Option<Message<T>>, PgmqError> {
46 let row: Result<PgRow, Error> = sqlx::query(query).fetch_one(connection).await;
48 match row {
49 Ok(row) => {
50 let raw_msg = row.get("message");
52 let parsed_msg = serde_json::from_value::<T>(raw_msg);
53 match parsed_msg {
54 Ok(parsed_msg) => Ok(Some(Message {
55 msg_id: row.get("msg_id"),
56 vt: row.get("vt"),
57 read_ct: row.get("read_ct"),
58 enqueued_at: row.get("enqueued_at"),
59 message: parsed_msg,
60 })),
61 Err(e) => Err(PgmqError::JsonParsingError(e)),
62 }
63 }
64 Err(sqlx::error::Error::RowNotFound) => Ok(None),
65 Err(e) => Err(e)?,
66 }
67}
68
69#[derive(Clone, Copy)]
71pub struct CheckedName<'a>(&'a str);
72
73impl<'a> CheckedName<'a> {
74 pub fn new(input: &'a str) -> Result<Self, PgmqError> {
76 check_input(input)?;
77
78 Ok(Self(input))
79 }
80}
81
82impl AsRef<str> for CheckedName<'_> {
83 fn as_ref(&self) -> &str {
84 self.0
85 }
86}
87
88impl Display for CheckedName<'_> {
89 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90 f.write_str(self.0)
91 }
92}
93
94pub fn check_input(input: &str) -> Result<(), PgmqError> {
96 const NAMEDATALEN: usize = 64;
101
102 const MAX_IDENTIFIER_LEN: usize = NAMEDATALEN - 1;
105 const BIGGEST_CONCAT: &str = "archived_at_idx_";
106
107 const MAX_PGMQ_QUEUE_LEN: usize = MAX_IDENTIFIER_LEN - BIGGEST_CONCAT.len();
110
111 let is_short_enough = input.len() <= MAX_PGMQ_QUEUE_LEN;
112 let has_valid_characters = input
113 .as_bytes()
114 .iter()
115 .all(|&c| c.is_ascii_alphanumeric() || c == b'_');
116 let valid = is_short_enough && has_valid_characters;
117 match valid {
118 true => Ok(()),
119 false => Err(PgmqError::InvalidQueueName {
120 name: input.to_owned(),
121 }),
122 }
123}