pgmq_core/
util.rs

1use std::fmt::Display;
2
3use crate::{errors::PgmqError, types::Message};
4use log::LevelFilter;
5use serde::Deserialize;
6use sqlx::error::Error;
7use sqlx::postgres::PgRow;
8use sqlx::postgres::{PgConnectOptions, PgPoolOptions};
9use sqlx::ConnectOptions;
10use sqlx::Row;
11use sqlx::{Pool, Postgres};
12use url::{ParseError, Url};
13
14// Configure connection options
15pub fn conn_options(url: &str) -> Result<PgConnectOptions, ParseError> {
16    // Parse url
17    let parsed = Url::parse(url)?;
18    let options = PgConnectOptions::new()
19        .host(parsed.host_str().ok_or(ParseError::EmptyHost)?)
20        .port(parsed.port().ok_or(ParseError::InvalidPort)?)
21        .username(parsed.username())
22        .password(parsed.password().ok_or(ParseError::IdnaError)?)
23        .database(parsed.path().trim_start_matches('/'))
24        .log_statements(LevelFilter::Debug);
25    Ok(options)
26}
27
28/// Connect to the database
29pub async fn connect(url: &str, max_connections: u32) -> Result<Pool<Postgres>, PgmqError> {
30    let options = conn_options(url)?;
31    let pgp = PgPoolOptions::new()
32        .acquire_timeout(std::time::Duration::from_secs(10))
33        .max_connections(max_connections)
34        .connect_with(options)
35        .await?;
36    Ok(pgp)
37}
38
39// Executes a query and returns a single row
40// If the query returns no rows, None is returned
41// This function is intended for internal use.
42pub async fn fetch_one_message<T: for<'de> Deserialize<'de>>(
43    query: &str,
44    connection: &Pool<Postgres>,
45) -> Result<Option<Message<T>>, PgmqError> {
46    // explore: .fetch_optional()
47    let row: Result<PgRow, Error> = sqlx::query(query).fetch_one(connection).await;
48    match row {
49        Ok(row) => {
50            // happy path - successfully read a message
51            let raw_msg = row.get("message");
52            let parsed_msg = serde_json::from_value::<T>(raw_msg);
53            match parsed_msg {
54                Ok(parsed_msg) => Ok(Some(Message {
55                    msg_id: row.get("msg_id"),
56                    vt: row.get("vt"),
57                    read_ct: row.get("read_ct"),
58                    enqueued_at: row.get("enqueued_at"),
59                    message: parsed_msg,
60                })),
61                Err(e) => Err(PgmqError::JsonParsingError(e)),
62            }
63        }
64        Err(sqlx::error::Error::RowNotFound) => Ok(None),
65        Err(e) => Err(e)?,
66    }
67}
68
/// A borrowed string that is known to consist only of ASCII alphanumeric
/// characters or underscores.
///
/// Values are meant to be created through `CheckedName::new`, which validates
/// the input before wrapping it.
#[derive(Debug, Clone, Copy)]
pub struct CheckedName<'a>(&'a str);
72
73impl<'a> CheckedName<'a> {
74    /// Accepts `input` as a CheckedName if it is a valid queue identifier
75    pub fn new(input: &'a str) -> Result<Self, PgmqError> {
76        check_input(input)?;
77
78        Ok(Self(input))
79    }
80}
81
82impl AsRef<str> for CheckedName<'_> {
83    fn as_ref(&self) -> &str {
84        self.0
85    }
86}
87
88impl Display for CheckedName<'_> {
89    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
90        f.write_str(self.0)
91    }
92}
93
94/// panics if input is invalid. otherwise does nothing.
95pub fn check_input(input: &str) -> Result<(), PgmqError> {
96    // Docs:
97    // https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS
98
99    // Default value of `NAMEDATALEN`, set in `src/include/pg_config_manual.h`
100    const NAMEDATALEN: usize = 64;
101
102    // The maximum length of an identifier.
103    // Longer names can be used in commands, but they'll be truncated
104    const MAX_IDENTIFIER_LEN: usize = NAMEDATALEN - 1;
105    const BIGGEST_CONCAT: &str = "archived_at_idx_";
106
107    // The max length of the name of a PGMQ queue, considering that the biggest
108    // postgres identifier created by PGMQ is the index on archived_at
109    const MAX_PGMQ_QUEUE_LEN: usize = MAX_IDENTIFIER_LEN - BIGGEST_CONCAT.len();
110
111    let is_short_enough = input.len() <= MAX_PGMQ_QUEUE_LEN;
112    let has_valid_characters = input
113        .as_bytes()
114        .iter()
115        .all(|&c| c.is_ascii_alphanumeric() || c == b'_');
116    let valid = is_short_enough && has_valid_characters;
117    match valid {
118        true => Ok(()),
119        false => Err(PgmqError::InvalidQueueName {
120            name: input.to_owned(),
121        }),
122    }
123}